-
Notifications
You must be signed in to change notification settings - Fork 6
/
helper.py
43 lines (40 loc) · 1.41 KB
/
helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import sys
from os.path import join

import numpy as np
def getBestRunAll(modeltopdir, trialnum, logfile, keyword):
    """Select the best (trial, iteration) pair across several training trials.

    Trial ``k`` (for ``k`` in ``range(trialnum)``) is read from the directory
    ``join(modeltopdir, 'trial<k>')`` via ``getBestRun(…, logfile, keyword)``,
    whose returned score is treated as "larger is better".

    Args:
        modeltopdir: parent directory holding the ``trial<k>`` subdirectories.
        trialnum: number of trials to examine.
        logfile: log file name inside each trial directory.
        keyword: metric name forwarded to ``getBestRun``.

    Returns:
        ``(trial_index, iteration)`` of the highest-scoring trial; on equal
        scores the earliest trial is kept.  ``(-1, -1)`` when no trial is
        examined (``trialnum <= 0``).
    """
    winner = (-1, -1)
    top_score = None  # no trial scored yet
    for trial_idx in range(trialnum):
        trial_dir = join(modeltopdir, 'trial%d' % trial_idx)
        iteration, score = getBestRun(trial_dir, logfile, keyword)
        # Strict '>' keeps the earliest trial when scores tie.
        if top_score is None or score > top_score:
            top_score = score
            winner = (trial_idx, iteration)
    return winner
def getBestRun(modeldir, logfile, keyword):
    """Return the best test-phase iteration and its score from one training log.

    The log at ``join(modeldir, logfile)`` is scanned for lines containing
    ``'Testing net'``.  Each such line starts a test phase; its iteration
    number is the token following ``'Iteration'`` with a trailing comma
    removed.  The first phase is skipped (presumably the evaluation run made
    before training starts -- confirm against the solver settings).  For each
    remaining phase, the first later line containing ``keyword`` as a whole
    token supplies the metric value, read two tokens after the keyword, i.e. a
    ``<keyword> = <value>`` layout (this looks like Caffe solver output --
    confirm).

    Args:
        modeldir: directory containing the log file.
        logfile: log file name, relative to ``modeldir``.
        keyword: metric name to look up, e.g. ``'accuracy'`` or ``'loss'``.

    Returns:
        ``(iteration, score)`` where ``iteration`` is the iteration number as a
        string.  For ``keyword == 'accuracy'`` the highest value is chosen and
        returned unchanged.  For any other keyword the lowest value is chosen
        and returned negated, so a larger score is always better for callers.
        On ties the earliest phase wins.

    Raises:
        ValueError: if the log has fewer than two ``'Testing net'`` sections.
        SystemExit: with status 1, after printing a message to stdout, if a
            phase's metric is not found before the next ``'Iteration'`` line
            or the end of the log.
    """
    with open(join(modeldir, logfile), 'r') as f:
        data = f.readlines()

    headers = [i for i, line in enumerate(data) if 'Testing net' in line]
    iter_cnt = []
    metric_cnt = []
    for i in headers[1:]:
        # split() (not split(' ')) tolerates runs of spaces and tabs.
        tokens = data[i].split()
        iter_cnt.append(tokens[tokens.index('Iteration') + 1].split(',')[0])
        metric_cnt.append(_find_metric(data, i + 1, keyword))

    if not metric_cnt:
        raise ValueError(
            "no usable test results in %s: need at least two 'Testing net' "
            "sections" % join(modeldir, logfile))

    if keyword == 'accuracy':
        best = int(np.argmax(metric_cnt))
        return (iter_cnt[best], metric_cnt[best])
    # Lower is better for non-accuracy metrics; negate so callers can
    # uniformly maximize the returned score.
    best = int(np.argmin(metric_cnt))
    return (iter_cnt[best], -metric_cnt[best])


def _find_metric(data, start, keyword):
    """Return the ``keyword`` value from the first matching line at or after ``start``.

    Scanning stops at the next line containing ``'Iteration'``.  If that line
    or the end of ``data`` is reached first, a message is printed to stdout
    and the process exits with status 1.
    """
    for j in range(start, len(data)):
        line = data[j]
        if 'Iteration' in line:
            break
        tokens = line.split()
        # Whole-token match, so e.g. 'accuracy_top5' does not match 'accuracy'
        # (a substring match would make tokens.index() raise below).
        if keyword in tokens:
            return float(tokens[tokens.index(keyword) + 2])
    # sys.stdout.write works under both Python 2 and 3 and emits the same
    # text a Python 2 `print 'msg:', keyword` statement would.
    sys.stdout.write("Can't find the target metric: %s\n" % keyword)
    sys.exit(1)