-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlogging.py
126 lines (104 loc) · 4.39 KB
/
logging.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import tensorflow as tf
import os
import boto3
import datetime
import json
from util.config import CONFIG
# Shared AWS handles used by the helpers in this module: a low-level S3
# client (put_object / upload_file) and a DynamoDB service resource (Table).
s3, db = boto3.client('s3'), boto3.resource('dynamodb')
def get_run_name(conf):
    """Build the default run name from an experiment config.

    Joins the ``name`` entry of the ``dataset``, ``model`` and ``loss``
    sections (in that order) with underscores: ``"<dataset>_<model>_<loss>"``.

    Raises:
        KeyError: if a section or its ``name`` entry is missing.
    """
    sections = ('dataset', 'model', 'loss')
    return '_'.join(conf[section]['name'] for section in sections)
def set_tensorboard_writer(conf, experiment_name):
    """Create a TensorBoard summary writer for a new run.

    The run directory is named ``<run_name>_<YYYYmmddHHMMSS-ffffff>``.
    ``run_name`` is ``experiment_name`` when that is truthy, otherwise
    ``get_run_name(conf)``. The timestamp is the current UTC time shifted by
    +9 hours. The directory lives under
    ``<CONFIG tensorboard local_dir>/tensorboard``, which is created if needed.
    When the ``s3_upload`` option is enabled, an empty marker object with key
    ``<s3_key>/tensorboard/<run_dir>/`` is put into the configured bucket.

    Args:
        conf: experiment config dict, used only to derive a default name.
        experiment_name: optional explicit run name.

    Returns:
        ``(writer, run_dir)``: the summary file writer (``flush_millis=10000``)
        and the run directory name (a bare name, not a full path).
    """
    run_name = experiment_name if experiment_name else get_run_name(conf)
    tb_conf = CONFIG['tensorboard']

    base_dir = os.path.join(tb_conf['local_dir'], 'tensorboard')
    if not tf.gfile.Exists(base_dir):
        tf.gfile.MakeDirs(base_dir)

    # UTC+9 wall-clock time (presumably the team's local timezone -- confirm);
    # microseconds keep concurrent runs from colliding on the same name.
    now_plus9 = datetime.datetime.utcnow() + datetime.timedelta(hours=9)
    run_dir = '{}_{}'.format(run_name, now_plus9.strftime('%Y%m%d%H%M%S-%f'))

    if tb_conf.getboolean('s3_upload'):
        bucket = tb_conf['s3_bucket']
        marker_key = '{}/tensorboard/{}/'.format(tb_conf['s3_key'], run_dir)
        s3.put_object(Bucket=bucket, Body='', Key=marker_key)

    print('Starting {}'.format(run_dir))
    log_dir = os.path.join(base_dir, run_dir)
    writer = tf.contrib.summary.create_file_writer(log_dir, flush_millis=10000)
    return writer, run_dir
def upload_tensorboard_log_to_s3(run_name):
    """Upload every file of a local TensorBoard run directory to S3.

    Files are read from ``<local_dir>/tensorboard/<run_name>`` and uploaded to
    ``<s3_key>/tensorboard/<run_name>/<relative path>`` in the configured
    bucket (all values from ``CONFIG['tensorboard']``). Sub-directories are
    walked recursively, and their relative path, joined with ``/``, is kept
    in the object key. Files directly inside the run directory get exactly
    the same key as before (``.../<run_name>/<filename>``).

    Args:
        run_name: run directory name, e.g. as returned by
            ``set_tensorboard_writer``.

    Raises:
        OSError: if the run directory (or a sub-directory) cannot be listed,
            e.g. ``FileNotFoundError`` when it does not exist.
    """
    run_dir = os.path.join(CONFIG['tensorboard']['local_dir'], 'tensorboard', run_name)

    def _raise(err):
        # os.walk silently skips directories it cannot list by default.
        # Re-raise so a missing run dir still fails loudly, as os.listdir did.
        raise err

    # os.listdir also returned sub-directories, and upload_file raises on a
    # directory. Walk the tree and upload regular files only.
    for root, _dirnames, filenames in os.walk(run_dir, onerror=_raise):
        rel_dir = os.path.relpath(root, run_dir)
        parts = [] if rel_dir == os.curdir else rel_dir.split(os.sep)
        for filename in filenames:
            rel_key = '/'.join(parts + [filename])
            s3.upload_file(
                os.path.join(root, filename),
                CONFIG['tensorboard']['s3_bucket'],
                '{}/tensorboard/{}/{}'.format(CONFIG['tensorboard']['s3_key'], run_name, rel_key))
def save_config(conf, run_name, experiment_name):
    """Persist an experiment config and optionally index it in DynamoDB.

    Where the config is written depends on the ``s3_upload`` option of
    ``CONFIG['tensorboard']``. Either way it is pretty-printed JSON
    (indent 4):

    - enabled: uploaded to ``<s3_key>/experiments/<run_name>/config.json``
      in the configured bucket;
    - disabled: written to ``<local_dir>/experiments/<run_name>/config.json``,
      creating the directory if needed.

    When ``dynamodb_upload`` is enabled, an item is also put into the
    ``Experiment`` table. The item holds:

    - ``id`` (= ``run_name``) and ``experiment_name``;
    - ``config``: the full config as a JSON string;
    - ``timestamp``: UTC+9, formatted ``YYYY-mm-dd HH:MM:SS``;
    - one ``"<section>:<name>"`` string attribute for each entry of every
      dict-valued top-level section except ``image`` and ``metrics``.

    Args:
        conf: JSON-serialisable experiment config dict.
        run_name: run identifier, used in paths/keys and as the item id.
        experiment_name: optional human-readable experiment name.
    """
    tb_conf = CONFIG['tensorboard']
    if tb_conf.getboolean('s3_upload'):
        upload_string_to_s3(
            bucket=tb_conf['s3_bucket'],
            body=json.dumps(conf, indent=4),
            key='{}/experiments/{}/config.json'.format(tb_conf['s3_key'], run_name),
        )
    else:
        config_dir = os.path.join(tb_conf['local_dir'], 'experiments', run_name)
        if not tf.gfile.Exists(config_dir):
            tf.gfile.MakeDirs(config_dir)
        with open(os.path.join(config_dir, 'config.json'), 'w') as f:
            json.dump(conf, f, indent=4)

    if tb_conf.getboolean('dynamodb_upload'):
        dt = datetime.datetime.utcnow() + datetime.timedelta(hours=9)
        data = {
            'id': run_name,
            'experiment_name': experiment_name,
            'config': json.dumps(conf),
            'timestamp': dt.strftime('%Y-%m-%d %H:%M:%S'),
        }
        for section, section_conf in conf.items():
            # These sections are not flattened into separate attributes; they
            # are still present in the 'config' JSON string above.
            if section in ('image', 'metrics'):
                continue
            # A scalar/list top-level value has no entries to flatten. It used
            # to crash here with AttributeError after the config was already
            # saved. It remains available in the 'config' JSON string.
            if not isinstance(section_conf, dict):
                continue
            for name, value in section_conf.items():
                data['{}:{}'.format(section, name)] = str(value)
        db.Table('Experiment').put_item(Item=data)
def upload_file_to_s3(file_path, bucket, key):
    """Upload the local file at ``file_path`` to object ``key`` in ``bucket``.

    Thin wrapper over the module-level S3 client's ``upload_file``; returns
    ``None``.
    """
    s3.upload_file(Filename=file_path, Bucket=bucket, Key=key)
def upload_string_to_s3(body, bucket, key):
    """Store ``body`` as the content of object ``key`` in ``bucket``.

    Thin wrapper over the module-level S3 client's ``put_object``. The
    service response is discarded and the function returns ``None``.
    """
    request = {'Bucket': bucket, 'Key': key, 'Body': body}
    s3.put_object(**request)
def create_checkpoint(checkpoint, run_name, s3_upload):
    """Save ``checkpoint`` locally and optionally move the saved files to S3.

    The checkpoint is saved with prefix
    ``<local_dir>/experiments/<run_name>/checkpoints/ckpt``, creating the
    directory if needed. The ``checkpoint`` index file in that directory is
    then copied to ``checkpoint_compat``, with the ``<checkpoint dir>/``
    prefix removed from every line.

    When ``s3_upload`` is truthy, every file found under the checkpoint
    directory except ``checkpoint`` is uploaded to
    ``<s3_key>/experiments/<run_name>/checkpoints/<filename>``.
    ``checkpoint_compat`` is uploaded under the name ``checkpoint``. Each
    file is deleted locally right after its upload, so later calls only find
    files written after this one.

    Args:
        checkpoint: object with a ``save(file_prefix=...)`` method --
            presumably a ``tf.train.Checkpoint``; confirm against callers.
        run_name: run identifier used in the local path and the S3 key.
        s3_upload: whether to upload (and then delete) the local files.
    """
    prefix = '{}/experiments/{}/checkpoints/ckpt'.format(
        CONFIG['tensorboard']['local_dir'],
        run_name)
    checkpoint_dir = os.path.dirname(prefix)
    if not tf.gfile.Exists(checkpoint_dir):
        tf.gfile.MakeDirs(checkpoint_dir)
    checkpoint.save(file_prefix=prefix)

    # Copy the index with the local directory prefix stripped, presumably so
    # its paths stay usable once the files are copied elsewhere (e.g. S3).
    # Both handles are closed even if reading/writing raises (the original
    # leaked the output handle on error).
    dir_prefix = checkpoint_dir + '/'
    compat_path = os.path.join(checkpoint_dir, 'checkpoint_compat')
    index_path = os.path.join(checkpoint_dir, 'checkpoint')
    with open(compat_path, 'w') as compat, open(index_path, 'r') as index:
        for line in index:
            print(line.rstrip('\n').replace(dir_prefix, ''), file=compat)

    if not s3_upload:
        return

    for root, _dirnames, filenames in os.walk(checkpoint_dir):
        for filename in filenames:
            if filename == 'checkpoint':
                # The absolute-path index stays local; its stripped copy is
                # uploaded under this name instead.
                continue
            dest_filename = 'checkpoint' if filename == 'checkpoint_compat' else filename
            local_path = os.path.join(root, filename)
            s3.upload_file(
                local_path,
                CONFIG['tensorboard']['s3_bucket'],
                '{}/experiments/{}/checkpoints/{}'.format(
                    CONFIG['tensorboard']['s3_key'],
                    run_name,
                    dest_filename,
                ),
            )
            os.remove(local_path)