-
Notifications
You must be signed in to change notification settings - Fork 60
/
utils.py
44 lines (30 loc) · 1.2 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from comet_ml import config
import tensorflow as tf
# Module-level alias for comet_ml's `config.experiment`, evaluated once when
# this module is imported. NOTE(review): `exp` is not read anywhere in the code
# shown here (finalize_model receives its experiment as a parameter), so it is
# presumably imported by other modules. Because this is a one-time binding, if
# `config.experiment` is rebound later (presumably when an experiment is
# created after this import), `exp` keeps the old value — confirm callers
# expect that.
exp = config.experiment
def finalize_model(model, test_examples, test_labels, experiment,
                   num_text_samples=20):
    """Log sample texts, a confusion matrix and the trained model to Comet.

    Args:
        model: Trained model exposing ``predict(examples)`` and ``save(path)``.
        test_examples: Indexable, sliceable sequence of review texts. Elements
            that are ``bytes`` are decoded before being shown as examples.
        test_labels: Labels aligned with ``test_examples``; each element must
            provide ``.item()`` (e.g. a NumPy scalar).
        experiment: Comet experiment receiving all ``log_*`` calls.
        num_text_samples: Maximum number of (text, label) pairs logged with
            ``log_text``. Defaults to 20.
    """
    import os

    # Log the first few distinct test examples together with their labels.
    for text, label in zip(test_examples[:num_text_samples],
                           test_labels[:num_text_samples]):
        experiment.log_text(text, metadata={"label": label.item()})

    # Build a 2-class one-hot prediction for each example.
    preds = model.predict(test_examples)

    def onehot(val):
        """Turn a single-output prediction ``val[0]`` into ``[1, 0]`` or ``[0, 1]``."""
        # (val + 1) / 2 maps [-1, 1] onto [0, 1]; rounding picks the class and
        # the clamp keeps out-of-range outputs inside {0, 1}.
        # NOTE(review): this thresholds at roughly val[0] == 0, which suits
        # logit/tanh-style outputs. For sigmoid outputs in [0, 1] almost every
        # prediction lands in class 1 — confirm against the model's last layer.
        retval = [0, 0]
        tmp = int(round((val[0] + 1) / 2))
        retval[max(min(1, tmp), 0)] = 1
        return retval

    new_preds = [onehot(v) for v in preds]

    def index_to_example(index):
        """Return the Comet example payload for test example ``index``."""
        text = test_examples[index]
        sample = text.decode() if isinstance(text, bytes) else str(text)
        return {"sample": sample,
                "assetId": None,
                "type": "string"}

    # Pass by keyword: Comet's signature is (y_true, y_predicted, ...), so the
    # ground-truth labels go in y_true and the model output in y_predicted.
    experiment.log_confusion_matrix(y_true=test_labels,
                                    y_predicted=new_preds,
                                    index_to_example_function=index_to_example,
                                    file_name="movie-reviews")

    # Save and log the model; create the target directory if it is missing.
    model_path = "models/movie-reviews-nn.h5"
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    model.save(model_path)
    experiment.log_model("movie-reviews-nn", model_path)