forked from rai-project/dlframework
-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.go
69 lines (60 loc) · 1.34 KB
/
utils.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
package cmd
import (
"strings"
"time"
"github.com/cheggaaa/pb"
)
// ParseModelName splits a model identifier of the form "<name>_<version>"
// at its last underscore and returns the name and the version.
//
// The name itself may contain underscores (e.g. "BVLC_AlexNet_1.0" yields
// "BVLC_AlexNet" and "1.0"). The special identifier "all" is returned as the
// name with an empty version. If model contains no underscore, the name is
// empty and the entire input is returned as the version.
func ParseModelName(model string) (string, string) {
	if model == "all" {
		return "all", ""
	}
	cut := strings.LastIndex(model, "_")
	if cut < 0 {
		// No separator: everything is treated as the version part.
		return "", model
	}
	return model[:cut], model[cut+1:]
}
// NewProgress builds a progress bar via pb.New(count), sets its refresh rate
// to 100 milliseconds, starts it, and returns it.
//
// NOTE(review): prefix is currently unused. Earlier attempts to apply it
// (pb.New(count).Prefix(prefix) and bar.Set("prefix", prefix)) were disabled
// pending a TODO; which of those calls is valid depends on the cheggaaa/pb
// API version in use — confirm before wiring it up.
func NewProgress(prefix string, count int) *pb.ProgressBar {
	const refreshEvery = 100 * time.Millisecond

	progress := pb.New(count)
	progress.SetRefreshRate(refreshEvery)
	progress.Start()
	return progress
}
var (
	// DefaultEvaulationModels lists model identifiers, each in the
	// "<name>_<version>" form that ParseModelName splits. Judging by the name
	// it is the default model set for evaluation (usage is outside this view).
	//
	// NOTE: "Evaulation" is a misspelling; the identifier is exported and is
	// kept unchanged so existing references keep compiling.
	DefaultEvaulationModels = []string{
		"SqueezeNet_1.0",
		"SqueezeNet_1.1",
		"BVLC_AlexNet_1.0",
		"BVLC_Reference_CaffeNet_1.0",
		"BVLC_GoogLeNet_1.0",
		"ResNet101_1.0",
		"ResNet101_2.0",
		"WRN50_2.0",
		"BVLC_Reference_RCNN_ILSVRC13_1.0",
		"Inception_3.0",
		"Inception_4.0",
		"ResNeXt50_32x4d_1.0",
		"VGG16_1.0",
		"VGG19_1.0",
	}

	// DefaultEvaluationFrameworks lists lowercase framework keys; every entry
	// has a display name in FrameworkNames.
	DefaultEvaluationFrameworks = []string{
		"mxnet",
		"cntk",
		"caffe2",
		"tensorflow",
		"tensorrt",
		"caffe",
	}

	// FrameworkNames maps a lowercase framework key to its display name.
	//
	// NOTE(review): "pytorch" has a display name here but is absent from
	// DefaultEvaluationFrameworks — confirm whether that is intentional.
	FrameworkNames = map[string]string{
		"caffe":      "Caffe",
		"caffe2":     "Caffe2",
		"cntk":       "CNTK",
		"mxnet":      "MXNet",
		"pytorch":    "PyTorch",
		"tensorflow": "TensorFlow",
		"tensorrt":   "TensorRT",
	}
)