Wrong model architecture in residual network? #15
Thank you.

I tried to apply your suggestion (without further investigation), but it fails:

Weird, it seems to work on my side 😓 Can I take a look at your modification?

I made this patch:

@@ -65,9 +65,9 @@ func (m *maebe) res(input *G.Node, filterCount int, name string) (*G.Node, batch
}
func (m *maebe) share(input *G.Node, filterCount, layer int) (*G.Node, batchNormOp, batchNormOp) {
- layer1, l1Op := m.res(input, filterCount, fmt.Sprintf("Layer1 of Shared Layer %d", layer))
+ _, l1Op := m.res(input, filterCount, fmt.Sprintf("Layer1 of Shared Layer %d", layer))
layer2, l2Op := m.res(input, filterCount, fmt.Sprintf("Layer2 of Shared Layer %d", layer))
- added := m.do(func() (*G.Node, error) { return G.Add(layer1, layer2) })
+ added := m.do(func() (*G.Node, error) { return G.Add(input, layer2) })
retVal := m.rectify(added)
return retVal, l1Op, l2Op
}

and the learning function is basically:

func learn() error {
conf := agogo.Config{
Name: "Tic Tac Toe",
NNConf: dual.DefaultConf(3, 3, 10),
MCTSConf: mcts.DefaultConfig(3),
UpdateThreshold: 0.52,
}
conf.NNConf.BatchSize = 100
conf.NNConf.Features = 2 // write a better encoding of the board, and increase features (and that allows you to increase K as well)
conf.NNConf.K = 3
conf.NNConf.SharedLayers = 3
conf.MCTSConf = mcts.Config{
PUCT: 1.0,
M: 3,
N: 3,
Timeout: 50 * time.Millisecond,
PassPreference: mcts.DontPreferPass,
Budget: 1000,
DumbPass: true,
RandomCount: 0,
}
outEnc := NewEncoder()
go func(h http.Handler) {
mux := http.NewServeMux()
mux.Handle("/ws", h)
mux.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.Dir("./htdocs"))))
log.Println("go to http://localhost:8080/static")
http.ListenAndServe(":8080", mux)
}(outEnc)
conf.Encoder = encodeBoard
conf.OutputEncoder = outEnc
g := mnk.TicTacToe()
a := agogo.New(g, conf)
reader := bufio.NewReader(os.Stdin)
fmt.Print("press enter when ready")
reader.ReadString('\n')
//a.Learn(5, 30, 200, 30) // 5 epochs, 30 episodes, 200 NN iters, 30 games.
err := a.Learn(5, 50, 100, 100) // 5 epochs, 50 episodes, 100 NN iters, 100 games.
if err != nil {
return err
}
err = a.Save("example.model")
if err != nil {
return err
}
return nil
}

I see, but I think it should look like this. Could you give it a try?

I wonder whether the model architecture is misconfigured, specifically in this function: https://github.com/gorgonia/agogo/blob/master/dualnet/ermahagerdmonards.go#L67
Based on my understanding of the paper (link), page 8/18 says:
Point 6 means that the add operation should add the block's input to the output of the modules, and that the modules should be applied in sequence. I wonder whether this is a correct implementation: