-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.go
81 lines (68 loc) · 1.85 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
package main
import (
	"fmt"
	"log"
	"math/rand"
	"time"

	"github.com/dsprn/gograd/grad"
)
// main trains a small fully-connected network on a single sample drawn from
// the dataset bundled with the grad package.
//
// It first runs a cross-validation search over a range of L2 regularization
// strengths to pick lambda. Then, on one randomly chosen sample, it repeats
// forward pass -> MSE loss + L2 penalty -> backward pass -> parameter update.
// Training stops after maxRounds passes, or earlier once the unregularized
// loss drops below lossThreshold.
func main() {
	const (
		// nInputs is the number of input features fed to the network. It is
		// shared by the model constructor and the cross-validation
		// architecture so the two cannot drift apart.
		nInputs = 2
		// alpha is the learning rate passed to each parameter update.
		alpha = 0.0005
		// maxRounds is the maximum number of training passes.
		maxRounds = 100
		// lossThreshold ends training early once the MSE falls below it.
		lossThreshold = 0.0001
	)

	// Layer sizes after the input layer; the last layer has one neuron, whose
	// output is read as pred[0] below.
	// TODO: check this against the Rust version of the project.
	arch := []int{16, 16, 1}

	// creating model
	m := grad.NewModel(nInputs, arch)

	// Cross-validation struct used to find the best L2 lambda hyperparameter.
	// NOTE(review): cross-validation uses grad.Alpha while training below uses
	// the local alpha constant; confirm these are meant to agree.
	xv := grad.NewXVal(
		grad.GetInputs(),
		grad.GetLabels(),
		append([]int{nInputs}, arch...), // prepend number of inputs to network architecture
		grad.NewFloatingRange(0.0, 0.01, 0.0005),
		grad.Alpha,
		grad.MSE,
		10,
	)
	l2Lambda := xv.SearchBestHyperpar()
	fmt.Printf("==> L2 lambda value=%.4f\n", l2Lambda.GetData())

	// Fetch the preloaded dataset once. Validate it before sampling: Intn
	// panics on a non-positive bound, and mismatched lengths would pair an
	// input with the wrong label or index out of range.
	dataset := grad.GetInputs()
	labels := grad.GetLabels()
	if len(dataset) == 0 || len(dataset) != len(labels) {
		log.Fatalf("invalid dataset: %d inputs, %d labels", len(dataset), len(labels))
	}

	// seed and random int generator
	r1 := rand.New(rand.NewSource(time.Now().UnixNano()))

	// Choose a random sample. The index bound comes from the real dataset
	// size rather than a hard-coded count.
	fmt.Println("\n==> Choosing inputs and relative label from a preloaded dataset...")
	dataIndex := r1.Intn(len(dataset))
	fmt.Printf("==> Getting inputs at index %d and relative label\n", dataIndex)
	inputs := dataset[dataIndex]
	label := labels[dataIndex]
	fmt.Printf("==> Input values=%v\n", inputs)
	fmt.Printf("==> Expected value=%f\n", label)

	// main loop: rounds are numbered 1..maxRounds inclusive
	fmt.Println("\n==> Start training the model...")
	for round := 1; round <= maxRounds; round++ {
		// Reset gradients accumulated by the previous backward pass.
		m.ZeroGrad()

		// forward pass
		pred := m.FeedForward(inputs)
		loss := grad.MSE(pred[0], label)

		// L2 regularization is added to the loss that is backpropagated.
		reg := grad.L2(m.Params(), l2Lambda)
		totLoss := loss.Add(reg)

		// backward pass, then update every parameter with the learning rate
		totLoss.BackwardPass()
		for _, el := range m.Params() {
			el.Update(alpha)
		}

		fmt.Printf(
			"pass=%d, predicted=%f, expected=%.1f, loss=%f, reg=%f, tot_loss=%f\n",
			round,
			pred[0].GetData(),
			label,
			loss.GetData(),
			reg.GetData(),
			totLoss.GetData(),
		)

		// Early exit once the unregularized loss is good enough.
		if loss.GetData() < lossThreshold {
			break
		}
	}
	fmt.Println("==> DONE")
}