forked from matpalm/snli_nn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
stats.py
52 lines (44 loc) · 1.58 KB
/
stats.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import json
import numpy as np
import os
import sys
import time
import util
class Stats(object):
    """Accumulate per-epoch training/dev statistics and print them as JSON.

    Each flush_to_stdout() call prints one line of the form
    "STATS<TAB><json>" and then clears the per-epoch accumulators
    (train/dev costs, dev accuracy, norms). The start time and the running
    count of trained examples persist across flushes.
    """

    def __init__(self, model, opts):
        """Record the run identity and a snapshot of the options.

        Args:
          model: value stored under the "model" key of every STATS line.
          opts: options object; every attribute listed by dir() whose name
            does not start with "_" and whose value is not callable is
            copied into the base stats (and may override "model"/"run").
        """
        self.start_time = int(time.time())
        self.n_egs_trained = 0
        self.base_stats = {"model": model,
                           "run": "RUN_%s_%s" % (self.start_time, os.getpid())}
        for opt in dir(opts):
            if opt.startswith("_"):
                continue
            value = getattr(opts, opt)
            # dir() also lists methods (e.g. optparse.Values.ensure_value);
            # they are not option values and json.dumps cannot encode them.
            if callable(value):
                continue
            self.base_stats[opt] = value
        self.reset()

    def reset(self):
        """Clear the per-epoch accumulators (n_egs_trained is kept)."""
        self.train_costs = []
        self.dev_costs = []
        self.dev_accuracy = None
        self.norms = None

    def record_training_cost(self, cost):
        """Append one training cost and count one more trained example."""
        self.train_costs.append(cost)
        self.n_egs_trained += 1

    def record_dev_cost(self, cost):
        """Append one dev-set cost for the current epoch."""
        self.dev_costs.append(cost)

    def set_dev_accuracy(self, dev_accuracy):
        """Store the dev accuracy for the current epoch.

        Raises:
          AssertionError: if an accuracy was already set since the last
            reset() / flush_to_stdout().
        """
        # Explicit raise rather than `assert` so the check survives
        # `python -O`; the exception type is unchanged for callers.
        if self.dev_accuracy is not None:
            raise AssertionError("dev accuracy already set for this epoch")
        self.dev_accuracy = dev_accuracy

    def set_param_norms(self, norms):
        """Store parameter norms to include in the next STATS line."""
        self.norms = norms

    @staticmethod
    def _json_default(obj):
        """json.dumps fallback converting numpy values to native Python.

        numpy scalars (e.g. np.float32) and arrays are not JSON serializable
        by default; any other unsupported object still raises TypeError, as
        the json module expects from a `default` hook.
        """
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        raise TypeError("%r is not JSON serializable" % (obj,))

    def flush_to_stdout(self, epoch):
        """Print one "STATS<TAB><json>" line for `epoch`, then reset().

        The JSON object holds the base stats plus "dts_h" (from util.dts(),
        presumably a human-readable timestamp), "epoch", "n_egs_trained",
        "elapsed_time" (whole seconds since construction), "train_cost" and
        "dev_cost" (util.mean_sd summaries, presumably mean/std-dev),
        "dev_acc", and "norms" when norms were set and are truthy.
        """
        stats = dict(self.base_stats)
        stats.update({"dts_h": util.dts(),
                      "epoch": epoch,
                      "n_egs_trained": self.n_egs_trained,
                      "elapsed_time": int(time.time()) - self.start_time,
                      "train_cost": util.mean_sd(self.train_costs),
                      "dev_cost": util.mean_sd(self.dev_costs),
                      "dev_acc": self.dev_accuracy})
        if self.norms:
            stats["norms"] = self.norms
        # Single-argument print(...) behaves identically on Python 2 and 3.
        print("STATS\t%s" % json.dumps(stats, default=self._json_default))
        sys.stdout.flush()
        self.reset()