-
Notifications
You must be signed in to change notification settings - Fork 8
/
helpers.go
75 lines (65 loc) · 2.37 KB
/
helpers.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
// Copyright (c) 2022, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package axon
import (
"fmt"
"cogentcore.org/core/base/mpi"
"cogentcore.org/core/core"
"github.com/emer/emergent/v2/ecmd"
)
////////////////////////////////////////////////////
// Misc
// ToggleLayersOff calls SetOff(off) on every layer of net named in layerNames,
// which can be used to disable layers, for example in an ablation study.
// A name that does not resolve to a layer is reported on standard output
// and skipped; the remaining names are still processed.
func ToggleLayersOff(net *Network, layerNames []string, off bool) {
	for i := range layerNames {
		name := layerNames[i]
		if ly := net.AxonLayerByName(name); ly != nil {
			ly.SetOff(off)
		} else {
			fmt.Printf("layer not found: %s\n", name)
		}
	}
}
/////////////////////////////////////////////
// Weights files
// WeightsFilename returns the default name for the current weights file,
// in the form <network name>_<runName>_<ctrString>.wts.gz.
// ctrString is intended to hold the train run and epoch counters from looper,
// and runName the string identifying tag, parameters and starting run.
func WeightsFilename(net *Network, ctrString, runName string) string {
	return fmt.Sprintf("%s_%s_%s.wts.gz", net.Name(), runName, ctrString)
}
// SaveWeights saves the network weights in JSON format to the file named by
// WeightsFilename(net, ctrString, runName), printing the target name first.
// When running under MPI, only the rank 0 process saves; every other rank
// returns immediately without touching the file system.
// Returns the name of the file saved to, or empty if not saved
// (non-zero MPI rank, or the save reported an error).
func SaveWeights(net *Network, ctrString, runName string) string {
	if mpi.WorldRank() > 0 {
		return ""
	}
	fnm := WeightsFilename(net, ctrString, runName)
	fmt.Printf("Saving Weights to: %s\n", fnm)
	// Do not report a file name for a save that failed: callers treat a
	// non-empty result as "weights are on disk under this name".
	if err := net.SaveWtsJSON(core.Filename(fnm)); err != nil {
		fmt.Printf("saving weights to %s: %v\n", fnm, err)
		return ""
	}
	return fnm
}
// SaveWeightsIfArgSet saves the network weights via SaveWeights only when
// the boolean "wts" argument in args is true.
// The file name is derived from ctrString and runName by WeightsFilename,
// and under MPI only the rank 0 process saves.
// Returns the name of the file saved to, or empty if not saved.
func SaveWeightsIfArgSet(net *Network, args *ecmd.Args, ctrString, runName string) string {
	if !args.Bool("wts") {
		return ""
	}
	return SaveWeights(net, ctrString, runName)
}
// SaveWeightsIfConfigSet saves the network weights via SaveWeights only when
// the given config bool value cfgWts is true.
// The file name is derived from ctrString and runName by WeightsFilename,
// and under MPI only the rank 0 process saves.
// Returns the name of the file saved to, or empty if not saved.
func SaveWeightsIfConfigSet(net *Network, cfgWts bool, ctrString, runName string) string {
	if !cfgWts {
		return ""
	}
	return SaveWeights(net, ctrString, runName)
}