This repository has been archived by the owner on Jan 24, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
session.py
executable file
·281 lines (223 loc) · 7.8 KB
/
session.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
# -*- coding: utf-8 -*-
'''
Justin Chen
session.py
4.30.2018
Tool for examining session objects in the parameter server
'''
from __future__ import division

import argparse
import json
import os
import sys
import time
from pprint import pprint

import redis
class ToolBox(object):
    '''Utility for inspecting session objects stored in the Redis parameter
    server and for pulling/validating experiment log files.'''

    def __init__(self, args):
        '''
        Input: args (argparse.Namespace) Provides Redis host, port, and database
        '''
        self.cache = redis.StrictRedis(host=args.host, port=args.port, db=args.database)

    def keys(self):
        '''Print every key currently stored in the Redis cache.'''
        pprint('keys: {}'.format([key for key in self.cache.scan_iter("*")]))

    def get_object(self, var):
        '''
        Fetch and JSON-decode a cached object
        Input:  var (string) Redis key
        Output: Deserialized object, or an error string if the key does not exist
        '''
        if not self.cache.exists(var):
            return 'key dne: {}'.format(var)
        return json.loads(self.cache.get(var))

    def clear(self, log_dir):
        '''
        Clear all logs
        Input: log_dir (string) Path to directory containing logs
        '''
        for name in os.listdir(log_dir):
            if name.endswith('.log'):
                os.remove(os.path.join(log_dir, name))

    def size(self, sess):
        '''
        Display session object memory size
        NOTE: sys.getsizeof is shallow (does not follow contained references),
        so this is a lower bound.
        Input: sess Session object
        '''
        print('{} (bytes)'.format(sys.getsizeof(sess)))

    def print_files(self, files, title='files'):
        '''
        Print files
        Input: files (list) List of files; tuple entries print as "file\n⤷detail"
               title (string) Section header
        '''
        div = '-' * (len(title) + 1)
        print('\n{}\n{}:\n{}'.format(div, title, div))
        for entry in files:
            if isinstance(entry, tuple):
                print('{}\n⤷{}\n'.format(*entry))
            else:
                print(entry)

    def pull_logs(self):
        '''
        Aggregate logs from peers by scp-ing each party member's *.log files
        into the local log directory.
        '''
        ssh_key = '~/.ssh/smpl-0.pem'
        with open('distributed/config/party.json', 'r') as party_conf:
            party = json.load(party_conf)
        for p in party:
            remote = 'ubuntu@{}:/home/ubuntu/smpl/logs/*.log'.format(p['host'])
            os.system(' '.join(['scp', '-i', ssh_key, remote, '/home/ubuntu/smpl/logs/']))

    def get_logs(self, log_dir):
        '''
        Return all log files
        Input:  log_dir (string) Path to directory containing logs
        Output: logs (list) List of all log file paths
        '''
        return [os.path.join(log_dir, name) for name in os.listdir(log_dir)
                if name.endswith('.log')]

    def is_valid_log(self, log):
        '''
        Check if log is valid
        Input:  log (string) Absolute path to log file
        Output: (bool) True if a valid log and False otherwise
        '''
        with open(log, 'rb') as f:
            # Decode so the containment check works on both Python 2 and 3
            first = f.readline().decode('utf-8', 'replace')
        return 'train_hyperedge' in first or 'api:establish_session' in first

    def grep_all(self, term, paths, case=False, debug=False):
        '''
        Find all files containing given term
        Input: term  (string) Search term
               paths (list)   List of file paths
               case  (bool)   True for case-sensitive matching
                              (BUGFIX: the flag was previously inverted --
                              case=True lower-cased both sides)
               debug (bool)   Set to True to grab last function call in logs
        Output: count (int)  Number of files containing term
                files (dict) 'match'/'mismatch' lists of file names
        '''
        files = {'match': [], 'mismatch': []}
        count = 0
        for path in paths:
            name = os.path.basename(path)
            with open(path, 'rb') as f:
                content = f.read().decode('utf-8', 'replace')
            if case:
                match = term in content
            else:
                match = term.lower() in content.lower()
            if match:
                count += 1
                files['match'].append(name)
            elif debug:
                # Attach the last function call recorded in the file.
                # Working from the already-read content avoids the fragile
                # backwards seek that raised IOError on tiny/empty files.
                lines = [ln for ln in content.splitlines() if ln.strip()]
                if lines:
                    last = [s for s in lines[-1].split(' ') if len(s) > 0]
                    if len(last) > 1:
                        last.pop(1)  # drop the timestamp-like second token
                    files['mismatch'].append((name, ' '.join(last[:2])))
                else:
                    files['mismatch'].append((name, ''))
            else:
                files['mismatch'].append(name)
        return count, files

    def check_logs(self, log_dir, pull=False):
        '''
        Check if all the sessions completed training
        Input: log_dir (string) Path to log directory
               pull (bool) True if logs should first be pulled from peers
        '''
        if pull:
            self.pull_logs()
            # BUGFIX: sleep() was previously unqualified (NameError); wait for
            # the pulled logs to land before scanning.
            while len(os.listdir(log_dir)) == 0:
                time.sleep(0.5)
        all_logs = []
        ps_logs = []
        for l in self.get_logs(log_dir):
            if 'ps' in l:
                ps_logs.append(l)
            elif self.is_valid_log(l):
                all_logs.append(l)
            else:
                # Invalid worker logs are discarded
                os.remove(l)
        total = len(all_logs)
        if total > 0:
            complete, found = self.grep_all('hyperedge complete', all_logs, debug=True)
            self.print_files(found['mismatch'], 'incomplete hyperedges')
            self.print_files(found['match'], 'completed')
            print('completed hyperedges: {}/{} ({}%)'.format(complete, total, 100 * complete / total))
            total = len(ps_logs)
            if total > 0:
                complete, found = self.grep_all('hypergraph complete', ps_logs)
                self.print_files(found['mismatch'], 'incomplete peers')
                print('completed training: {}/{} ({}%)'.format(complete, total, 100 * complete / total))
            self.print_files(self.get_edges(), 'variables')
        else:
            print('no logs. rerun experiment.')

    def get_edges(self):
        '''
        Output: (list) Formatted strings describing edge-related state variables
        '''
        return ['origin edges: {}'.format(self.get_object('origin_edges')),
                'current edges: {}'.format(self.get_object('curr_edges')),
                'hyperedge count: {}'.format(self.get_object('hyperedges'))]

    def query(self, sess, args):
        '''
        Query session objects
        Input: sess (dict) Session object
               args (argparse.Namespace) Query options (property, properties,
                    ignore, minimal)
        '''
        result = sess
        if args.property is not None:
            result = 'key:{}, value: {}'.format(args.property, sess[args.property])
        elif args.properties:
            result = 'properties: {}'.format(list(sess.keys()))
        elif args.ignore is not None:
            for k in args.ignore:
                if k in sess:
                    del sess[k]
            result = sess
        elif args.minimal:
            # Drop the bulky tensor fields so the rest of the session is readable
            for k in ('parameters', 'gradients', 'multistep'):
                if k in sess:
                    del sess[k]
            result = sess
        pprint(result)
def main():
    '''Parse CLI options and dispatch to the requested ToolBox operations.

    Options are not mutually exclusive: clearing, checking, grepping, and
    key listing can all run in a single invocation. Session queries
    (--sess/--property/--properties/--ignore/--minimal) and raw variable
    retrieval (--variable) are handled last.
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument('--host', type=str, default='localhost', help='Redis host (default: localhost)')
    parser.add_argument('--port', type=int, default=6379, help='Redis port (default: 6379)')
    parser.add_argument('--case', '-c', action='store_true', help='Set for case sensitive matching when using the --grep option')
    parser.add_argument('--check', '-ch', action='store_true', help='Check that all hyperedges completed training')
    parser.add_argument('--clear', action='store_true', help='Clear all logs')
    parser.add_argument('--database', '-db', type=int, default=0, help='Redis db')
    parser.add_argument('--edges', '-e', action='store_true', help='Display edge count')
    parser.add_argument('--grep', '-g', type=str, help='Grep all files for given term')
    parser.add_argument('--ignore', '-i', type=str, nargs='+', help='Ignores a particular key/value in the session object')
    parser.add_argument('--keys', '-k', action='store_true', help='Get all Redis keys')
    parser.add_argument('--log_dir', '-l', type=str, default=os.path.join(os.getcwd(), 'logs'), help='Log directory')
    parser.add_argument('--minimal', '-m', action='store_true', help='Ignore parameters and gradients')
    parser.add_argument('--sess', '-s', type=str, help='Session objection id')
    parser.add_argument('--size', '-z', action='store_true', help='Get size of cache object')
    parser.add_argument('--property', '-p', type=str, help='Session object property')
    parser.add_argument('--properties', '-ps', action='store_true', help='Get all properties of object')
    parser.add_argument('--pull', '-pl', action='store_true', help='Pull logs from all peers')
    parser.add_argument('--variable', '-v', type=str, help='Retrieve state variable. If using this, do not set --sess')
    args = parser.parse_args()

    tb = ToolBox(args)

    # Clear logs
    if args.clear:
        tb.clear(args.log_dir)

    # Verify training completion across all peer logs
    if args.check:
        tb.check_logs(args.log_dir, pull=args.pull)

    # Display edge-count state variables
    if args.edges:
        tb.print_files(tb.get_edges(), 'variables')

    # Grep every log for a term and report the match rate
    if args.grep is not None:
        all_logs = tb.get_logs(args.log_dir)
        total = len(all_logs)
        count, files = tb.grep_all(args.grep, all_logs, case=args.case)
        tb.print_files(files['mismatch'], 'mismatch')
        tb.print_files(files['match'], 'matches')
        print('match rate: {}/{} ({}%)'.format(count, total, 100 * count / total))

    # Display all available keys
    if args.keys:
        tb.keys()

    sess = ''
    if args.sess is not None:
        sess = tb.get_object(args.sess)

    if len(sess) > 0:
        # Return size of session object
        if args.size:
            tb.size(sess)
        # Query session objects and variables
        tb.query(sess, args)
    elif args.variable is not None:
        print(tb.get_object(args.variable))
# Script entry point: only run when executed directly, not when imported.
if __name__ == '__main__':
    main()