forked from rai-project/evaluation
/
options.go
98 lines (83 loc) · 1.96 KB
/
options.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
package plotting
import (
"path/filepath"
"github.com/AlekSi/pointer"
"github.com/mitchellh/go-homedir"
)
// Options holds the settings used by the plotting package. All fields are
// unexported; they are populated by NewOptions and adjusted through the
// OptionModifier values built from Option.
type Options struct {
	baseDir, machineHostName        string
	frameworkName, frameworkVersion string
	modelName, modelVersion         string
	batchSize                       int

	// useGPU is a pointer, presumably so that "not set" (nil) can be told
	// apart from an explicit false — confirm against the readers of this field.
	useGPU           *bool
	ignoreReadErrors bool
}
type (
	// OptionModifier mutates an Options in place. NewOptions applies a
	// sequence of them, in order, on top of its defaults.
	OptionModifier func(opts *Options)

	// OptionModifiers is a stateless namespace type whose methods build
	// OptionModifier values; use it through the package-level Option
	// variable, e.g. Option.BatchSize(8).
	OptionModifiers struct{}
)

// Option is the entry point for constructing OptionModifier values.
var Option OptionModifiers
// MachineHostName returns an OptionModifier that sets the machine host name
// stored in Options to hostName.
func (OptionModifiers) MachineHostName(hostName string) OptionModifier {
	return func(opts *Options) {
		opts.machineHostName = hostName
	}
}
// BaseDir returns an OptionModifier that sets the base directory stored in
// Options to dir. The value is stored as given (no cleaning or expansion).
func (OptionModifiers) BaseDir(dir string) OptionModifier {
	return func(opts *Options) {
		opts.baseDir = dir
	}
}
// FrameworkName returns an OptionModifier that sets the framework name
// stored in Options to s.
func (OptionModifiers) FrameworkName(s string) OptionModifier {
	return func(opts *Options) {
		opts.frameworkName = s
	}
}
// FrameworkVersion returns an OptionModifier that sets the framework version
// stored in Options to s.
func (OptionModifiers) FrameworkVersion(s string) OptionModifier {
	return func(opts *Options) {
		opts.frameworkVersion = s
	}
}
// ModelName returns an OptionModifier that sets the model name stored in
// Options to s.
func (OptionModifiers) ModelName(s string) OptionModifier {
	return func(opts *Options) {
		opts.modelName = s
	}
}
// ModelVersion returns an OptionModifier that sets the model version stored
// in Options to s.
func (OptionModifiers) ModelVersion(s string) OptionModifier {
	return func(opts *Options) {
		opts.modelVersion = s
	}
}
// BatchSize returns an OptionModifier that sets the batch size stored in
// Options to val. No validation is done; zero or negative values are stored
// as given.
func (OptionModifiers) BatchSize(val int) OptionModifier {
	return func(opts *Options) {
		opts.batchSize = val
	}
}
// UseGPU returns an OptionModifier that sets the GPU flag stored in Options
// to val.
//
// Each application stores a freshly allocated bool. If the closure stored
// &val directly, every Options the same modifier is applied to would share
// one *bool, and writing through one instance's useGPU would silently change
// the others.
func (OptionModifiers) UseGPU(val bool) OptionModifier {
	return func(opts *Options) {
		flag := val // new variable per call, so each Options owns its pointer
		opts.useGPU = &flag
	}
}
// IgnoreReadErrors returns an OptionModifier that sets the ignore-read-errors
// flag stored in Options to val.
func (OptionModifiers) IgnoreReadErrors(val bool) OptionModifier {
	return func(opts *Options) {
		opts.ignoreReadErrors = val
	}
}
// NewOptions returns an Options initialised with the defaults below and then
// updated by each modifier in mods, in order. A later modifier overrides an
// earlier one that writes the same field. Nil modifiers are skipped instead
// of causing a panic.
//
// Defaults: baseDir "<home>/experiments", machineHostName "ip-172-31-20-197",
// frameworkName "TensorFlow", frameworkVersion "1.12", modelName
// "BVLC_AlexNet_Caffe", modelVersion "1.0", batchSize 1, useGPU true,
// ignoreReadErrors false.
//
// NOTE(review): the default machineHostName looks like one specific
// (EC2-style) host name baked into the package; confirm it is still intended.
func NewOptions(mods ...OptionModifier) *Options {
	// The signature cannot report errors, so construction stays infallible.
	// If the home directory cannot be determined, use an empty prefix:
	// filepath.Join drops empty elements, so baseDir becomes the relative
	// path "experiments".
	home, err := homedir.Dir()
	if err != nil {
		home = ""
	}

	opts := &Options{
		baseDir:          filepath.Join(home, "experiments"),
		machineHostName:  "ip-172-31-20-197",
		frameworkName:    "TensorFlow",
		frameworkVersion: "1.12",
		modelName:        "BVLC_AlexNet_Caffe",
		modelVersion:     "1.0",
		batchSize:        1,
		useGPU:           pointer.ToBool(true),
		ignoreReadErrors: false,
	}
	for _, mod := range mods {
		if mod == nil {
			continue // tolerate nil entries, e.g. from conditionally built lists
		}
		mod(opts)
	}
	return opts
}