/
stats.py
537 lines (475 loc) · 21.7 KB
/
stats.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
#!/usr/bin/env python3
""" Stats functions for the GUI """
import logging
import time
import os
import warnings
from math import ceil, sqrt
import numpy as np
import tensorflow as tf
from tensorflow.python import errors_impl as tf_errors # pylint:disable=no-name-in-module
from lib.serializer import get_serializer
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
def convert_time(timestamp):
    """ Split a duration in seconds into hours, minutes and seconds for display.

    Minutes and seconds are always returned as zero padded 2 character strings. Hours are
    returned as a zero padded string when below 10, otherwise as an integer.
    """
    hours = int(timestamp // 3600)
    # The remainder of the hour is truncated to whole seconds before being split
    minutes, seconds = divmod(int(timestamp % 3600), 60)
    hrs = "{0:02d}".format(hours) if hours < 10 else hours
    return hrs, "{0:02d}".format(minutes), "{0:02d}".format(seconds)
class TensorBoardLogs():
    """ Parse and return data from TensorBoard logs

    Parameters
    ----------
    logs_folder: str
        The folder that is searched (including all sub-folders) for TensorBoard event files
    """
    def __init__(self, logs_folder):
        self.folder_base = logs_folder
        # {session id (int): {side (str): full path to the event file}}
        self.log_filenames = self.set_log_filenames()

    def set_log_filenames(self):
        """ Set the TensorBoard log filenames for all existing sessions

        The folder holding the event files supplies the session id (the integer that follows
        the final underscore in the folder name) and that folder's parent supplies the side.
        Folders whose name does not end in an integer are skipped with a warning.

        Returns
        -------
        dict
            ``{session id (int): {side (str): full path to the event file (str)}}``
        """
        logger.debug("Loading log filenames. base_dir: '%s'", self.folder_base)
        log_filenames = dict()
        for dirpath, _, filenames in os.walk(self.folder_base):
            logfiles = sorted(filename for filename in filenames
                              if filename.startswith("events.out.tfevents"))
            if not logfiles:
                continue
            side_folder, session_folder = os.path.split(dirpath)
            try:
                session = int(session_folder[session_folder.rfind("_") + 1:])
            except ValueError:
                # Event files outside of a numbered session folder cannot be keyed to a session
                logger.warning("Skipping TensorBoard logs in unrecognized folder: '%s'", dirpath)
                continue
            # Take the last logfile, in case of previous crash
            logfile = os.path.join(dirpath, logfiles[-1])
            side = os.path.split(side_folder)[1]
            log_filenames.setdefault(session, dict())[side] = logfile
        logger.debug("logfiles: %s", log_filenames)
        return log_filenames

    def get_loss(self, side=None, session=None):
        """ Read the loss from the TensorBoard logs

        Parameters
        ----------
        side: str, optional
            Only read the log file for this side. ``None`` reads all sides
        session: int, optional
            Only read the logs for this session id. ``None`` reads all sessions

        Returns
        -------
        dict
            ``{session id: {loss tag: {side: [loss values]}}}``. Only summaries with "loss" in
            their tag are collected, and "batch_" is removed from the tag name. Every session
            that is not skipped receives an entry, even when no loss values were found.
        """
        # NOTE(review): unlike get_timestamps, a DataLossError raised for a corrupted log is not
        # handled here and propagates to the caller - confirm whether that is intended.
        logger.debug("Getting loss: (side: %s, session: %s)", side, session)
        all_loss = dict()
        for sess, sides in self.log_filenames.items():
            if session is not None and sess != session:
                logger.debug("Skipping session: %s", sess)
                continue
            loss = dict()
            for sde, logfile in sides.items():
                if side is not None and sde != side:
                    logger.debug("Skipping side: %s", sde)
                    continue
                for event in tf.train.summary_iterator(logfile):
                    for summary in event.summary.value:
                        if "loss" not in summary.tag:
                            continue
                        tag = summary.tag.replace("batch_", "")
                        loss.setdefault(tag,
                                        dict()).setdefault(sde,
                                                           list()).append(summary.simple_value)
            all_loss[sess] = loss
        return all_loss

    def get_timestamps(self, session=None):
        """ Read the timestamps from the TensorBoard logs

        NB: For all intents and purposes timestamps are the same for both sides, so just read
        from one side

        Parameters
        ----------
        session: int, optional
            Only read the logs for this session id. ``None`` reads all sessions

        Returns
        -------
        dict
            ``{session id: [wall time of each event that holds summary values]}``. Sessions
            whose log raises a ``DataLossError`` are reported with a warning and left out.
        """
        logger.debug("Getting timestamps")
        all_timestamps = dict()
        for sess, sides in self.log_filenames.items():
            if session is not None and sess != session:
                logger.debug("Skipping sessions: %s", sess)
                continue
            try:
                for logfile in sides.values():
                    timestamps = [event.wall_time
                                  for event in tf.train.summary_iterator(logfile)
                                  if event.summary.value]
                    logger.debug("Total timestamps for session %s: %s", sess, len(timestamps))
                    all_timestamps[sess] = timestamps
                    break  # break after first file read
            except tf_errors.DataLossError as err:
                logger.warning("The logs for Session %s are corrupted and cannot be displayed. "
                               "The totals do not include this session. Original error message: "
                               "'%s'", sess, str(err))
        return all_timestamps
class Session():
    """ The loaded, or currently training, session.

    Gives access to the model's state file and its TensorBoard logs, either for a single
    session id or collated across every session.
    """
    def __init__(self, model_dir=None, model_name=None):
        logger.debug("Initializing %s: (model_dir: %s, model_name: %s)",
                     self.__class__.__name__, model_dir, model_name)
        self.serializer = get_serializer("json")
        # Populated by initialize_session
        self.state = None
        self.tb_logs = None
        self.initialized = False
        # Set to a specific session id, or the current training session
        self.session_id = None
        # Set and reset by wrapper for training sessions
        self.modeldir = model_dir
        self.modelname = model_name
        self.summary = SessionsSummary(self)
        logger.debug("Initialized %s", self.__class__.__name__)

    @property
    def batchsize(self):
        """ The batch size of the selected session """
        return self.session["batchsize"]

    @property
    def config(self):
        """ A copy of the state file config with the training size and input size added """
        retval = self.state["config"].copy()
        retval["training_size"] = self.state["training_size"]
        # The first item of the first input whose name starts with "face"
        face_inputs = [val[0]
                       for key, val in self.state["inputs"].items()
                       if key.startswith("face")]
        retval["input_size"] = face_inputs[0]
        return retval

    @property
    def full_summary(self):
        """ The compiled summary data for all sessions """
        return self.summary.compile_stats()

    @property
    def iterations(self):
        """ The iterations of the selected session """
        return self.session["iterations"]

    @property
    def logging_disabled(self):
        """ Whether logging is disabled for the selected session (no logs or ping-pong) """
        current = self.session
        return current["no_logs"] or current["pingpong"]

    @property
    def loss(self):
        """ The loss read from the logs for the selected session """
        return self.tb_logs.get_loss(session=self.session_id)[self.session_id]

    @property
    def loss_keys(self):
        """ List of unique loss keys for the selected session, or for all sessions when no
        session id is set """
        if self.session_id is None:
            return list(self.total_loss_keys)
        unique_keys = {loss_key
                       for side_keys in self.session["loss_names"].values()
                       for loss_key in side_keys}
        return list(unique_keys)

    @property
    def lowest_loss(self):
        """ The lowest average loss per save iteration seen """
        return self.state["lowest_avg_loss"]

    @property
    def session(self):
        """ The state file entry for the selected session (empty dict when it has none) """
        all_sessions = self.state["sessions"]
        return all_sessions.get(str(self.session_id), dict())

    @property
    def session_ids(self):
        """ Sorted list of every session id that exists in the state file """
        return sorted(map(int, self.state["sessions"]))

    @property
    def timestamps(self):
        """ The timestamps read from the logs for the selected session """
        session_timestamps = self.tb_logs.get_timestamps(session=self.session_id)
        return session_timestamps[self.session_id]

    @property
    def total_batchsize(self):
        """ The batch size of every session, keyed by integer session id """
        batchsizes = dict()
        for sess_id, sess in self.state["sessions"].items():
            key = int(sess_id)
            batchsizes[key] = sess["batchsize"]
        return batchsizes

    @property
    def total_iterations(self):
        """ The total iterations recorded in the state file """
        return self.state["iterations"]

    @property
    def total_loss(self):
        """ The loss for all sessions, joined in session id order for each loss key and side """
        collated = dict()
        per_session = self.tb_logs.get_loss()
        for sess_id in sorted(int(idx) for idx in per_session):
            for loss_key, side_loss in per_session[sess_id].items():
                for side, values in side_loss.items():
                    collated.setdefault(loss_key, dict()).setdefault(side, list()).extend(values)
        return collated

    @property
    def total_loss_keys(self):
        """ List of unique loss keys across all sessions """
        unique_keys = set()
        for session in self.state["sessions"].values():
            for side_keys in session["loss_names"].values():
                for loss_key in side_keys:
                    unique_keys.add(loss_key)
        return list(unique_keys)

    @property
    def total_timestamps(self):
        """ The timestamps read from the logs for all sessions, separated per session """
        return self.tb_logs.get_timestamps()

    def initialize_session(self, is_training=False, session_id=None):
        """ Load the state file and logs, then select the session to work on.

        When training, the highest session id in the state file is selected, otherwise the
        given session id is used.
        """
        logger.debug("Initializing session: (is_training: %s, session_id: %s)",
                     is_training, session_id)
        self.load_state_file()
        logs_folder = os.path.join(self.modeldir, "{}_logs".format(self.modelname))
        self.tb_logs = TensorBoardLogs(logs_folder)
        self.session_id = (max(int(key) for key in self.state["sessions"].keys())
                           if is_training else session_id)
        self.initialized = True
        logger.debug("Initialized session. Session_ID: %s", self.session_id)

    def load_state_file(self):
        """ Read the model's state file into :attr:`state` """
        state_file = os.path.join(self.modeldir, "{}_state.json".format(self.modelname))
        logger.debug("Loading State: '%s'", state_file)
        self.state = self.serializer.load(state_file)
        logger.debug("Loaded state: %s", self.state)

    def get_iterations_for_session(self, session_id):
        """ The number of iterations for the given session id, or 0 (with a warning) when the
        state file holds no data for it """
        session = self.state["sessions"].get(str(session_id), None)
        if session is not None:
            return session["iterations"]
        logger.warning("No session data found for session id: %s", session_id)
        return 0
class SessionsSummary():
    """ Calculations for analysis summary stats

    Parameters
    ----------
    session: :class:`Session`
        The session that supplies the TensorBoard logs and the state file data
    """
    def __init__(self, session):
        logger.debug("Initializing %s: (session: %s)", self.__class__.__name__, session)
        self.session = session
        logger.debug("Initialized %s", self.__class__.__name__)

    @property
    def time_stats(self):
        """ dict: The earliest timestamp ("start_time"), latest timestamp ("end_time") and
        number of timestamps ("datapoints") for each session id found in the logs. All three
        are 0 for a session that has no timestamps """
        ts_data = self.session.tb_logs.get_timestamps()
        time_stats = {sess_id: {"start_time": min(timestamps) if timestamps else 0,
                                "end_time": max(timestamps) if timestamps else 0,
                                "datapoints": len(timestamps) if timestamps else 0}
                      for sess_id, timestamps in ts_data.items()}
        return time_stats

    @property
    def sessions_stats(self):
        """ list: One stats dictionary per session, sorted by session id, or ``None`` if the
        session's state was removed whilst compiling.

        Each dictionary holds the keys "session", "start", "end", "elapsed", "rate", "batch"
        and "iterations". The rate is 0 when no time elapsed """
        compiled = list()
        for sess_idx, ts_data in self.time_stats.items():
            logger.debug("Compiling session ID: %s", sess_idx)
            if self.session.state is None:
                logger.debug("Session state dict doesn't exist. Most likely task has been "
                             "terminated during compilation")
                return None
            iterations = self.session.get_iterations_for_session(sess_idx)
            elapsed = ts_data["end_time"] - ts_data["start_time"]
            # Sessions that exist in the logs but not in the state file get a batch size of 0
            batchsize = self.session.total_batchsize.get(sess_idx, 0)
            compiled.append({"session": sess_idx,
                             "start": ts_data["start_time"],
                             "end": ts_data["end_time"],
                             "elapsed": elapsed,
                             "rate": (batchsize * iterations) / elapsed if elapsed != 0 else 0,
                             "batch": batchsize,
                             "iterations": iterations})
        compiled = sorted(compiled, key=lambda k: k["session"])
        return compiled

    def compile_stats(self):
        """ Compile sessions stats with totals, format and return

        Returns
        -------
        list or ``None``
            The formatted per session stats followed by the formatted totals. ``None`` if the
            per session stats could not be compiled
        """
        logger.debug("Compiling sessions summary data")
        compiled_stats = self.sessions_stats
        if compiled_stats is None:
            return compiled_stats
        logger.debug("sessions_stats: %s", compiled_stats)
        total_stats = self.total_stats(compiled_stats)
        compiled_stats.append(total_stats)
        compiled_stats = self.format_stats(compiled_stats)
        logger.debug("Final stats: %s", compiled_stats)
        return compiled_stats

    @staticmethod
    def total_stats(sessions_stats):
        """ Return total stats

        Parameters
        ----------
        sessions_stats: list
            The per session stats dictionaries, in the order they should be totalled

        Returns
        -------
        dict
            The totals in the same layout as a session's stats, with "session" set to "Total"
            and "batch" set to a comma separated string of the unique batch sizes. The start is
            taken from the first entry and the end from the last entry. An empty list gives
            zeroed totals
        """
        logger.debug("Compiling Totals")
        # With no sessions there is no first or last entry to read the times from
        starttime = sessions_stats[0]["start"] if sessions_stats else 0
        endtime = sessions_stats[-1]["end"] if sessions_stats else 0
        elapsed = sum(summary["elapsed"] for summary in sessions_stats)
        examples = sum(summary["batch"] * summary["iterations"] for summary in sessions_stats)
        iterations = sum(summary["iterations"] for summary in sessions_stats)
        # Sorted so that the displayed batch sizes do not depend on set ordering
        batchset = sorted(set(summary["batch"] for summary in sessions_stats))
        batch = ",".join(str(bs) for bs in batchset)
        totals = {"session": "Total",
                  "start": starttime,
                  "end": endtime,
                  "elapsed": elapsed,
                  "rate": examples / elapsed if elapsed != 0 else 0,
                  "batch": batch,
                  "iterations": iterations}
        logger.debug(totals)
        return totals

    @staticmethod
    def format_stats(compiled_stats):
        """ Format for display

        The given dictionaries are updated in place: "start" and "end" become locale formatted
        date/time strings, "elapsed" becomes "hours:minutes:seconds" and "rate" becomes a string
        with 1 decimal place.

        Returns
        -------
        list
            The list that was passed in
        """
        logger.debug("Formatting stats")
        for summary in compiled_stats:
            hrs, mins, secs = convert_time(summary["elapsed"])
            summary["start"] = time.strftime("%x %X", time.localtime(summary["start"]))
            summary["end"] = time.strftime("%x %X", time.localtime(summary["end"]))
            summary["elapsed"] = "{}:{}:{}".format(hrs, mins, secs)
            summary["rate"] = "{0:.1f}".format(summary["rate"])
        return compiled_stats
class Calculations():
    """ Class to pull raw data for given session(s) and perform calculations

    Parameters
    ----------
    session: :class:`Session`
        The session to pull the raw data from
    display: str, optional
        "loss" (case insensitive) to collect loss values. Any other value collects the rate.
        Default: "loss"
    loss_keys: list, optional
        The loss names to collect when displaying loss. Default: ``["loss"]``
    selections: list, optional
        The data to output. "raw" keeps the raw data. Any other entry is calculated by the
        method named ``calc_<entry>`` (e.g. "avg", "smoothed", "trend"). Default: ``["raw"]``
    avg_samples: int, optional
        The number of data points in the rolling average window. Default: 500
    smooth_amount: float, optional
        The weight given to the previous smoothed value when smoothing. Default: 0.90
    flatten_outliers: bool, optional
        ``True`` to replace values further than 1 standard deviation from the mean with the
        mean. Default: ``False``
    is_totals: bool, optional
        ``True`` to collect data for all sessions rather than the session's selected session id.
        Default: ``False``
    """
    def __init__(self, session, display="loss", loss_keys=None, selections=None,
                 avg_samples=500, smooth_amount=0.90, flatten_outliers=False, is_totals=False):
        # Defaults are built per instance so that one instance's lists are never shared
        loss_keys = ["loss"] if loss_keys is None else loss_keys
        selections = ["raw"] if selections is None else selections
        logger.debug("Initializing %s: (session: %s, display: %s, loss_keys: %s, selections: %s, "
                     "avg_samples: %s, smooth_amount: %s, flatten_outliers: %s, is_totals: %s",
                     self.__class__.__name__, session, display, loss_keys, selections, avg_samples,
                     smooth_amount, flatten_outliers, is_totals)
        # RankWarning was removed from numpy's main namespace in NumPy 2.0
        rank_warning = getattr(np, "RankWarning", None)
        if rank_warning is None:
            rank_warning = np.exceptions.RankWarning
        warnings.simplefilter("ignore", rank_warning)

        self.session = session
        self.display = display
        self.loss_keys = loss_keys
        self.selections = selections
        self.is_totals = is_totals
        self.args = {"avg_samples": avg_samples,
                     "smooth_amount": smooth_amount,
                     "flatten_outliers": flatten_outliers}
        self.iterations = 0
        self.stats = None
        self.refresh()
        logger.debug("Initialized %s", self.__class__.__name__)

    def refresh(self):
        """ Refresh the stats

        Returns
        -------
        :class:`Calculations` or ``None``
            This object once :attr:`stats` has been rebuilt, or ``None`` (leaving the stats
            untouched) if the session has not been initialized
        """
        logger.debug("Refreshing")
        if not self.session.initialized:
            logger.warning("Session data is not initialized. Not refreshing")
            return None
        self.iterations = 0
        self.stats = self.get_raw()
        self.get_calculations()
        self.remove_raw()
        logger.debug("Refreshed")
        return self

    def get_raw(self):
        """ Collect the raw data and set :attr:`iterations` to the number of data points

        Returns
        -------
        dict
            For loss: ``{"raw_<loss name>_<side>": [values]}``, with every list cropped to the
            shortest list's length. For rate: ``{"raw_rate": [values]}``
        """
        logger.debug("Getting Raw Data")
        raw = dict()
        iterations = set()
        if self.display.lower() == "loss":
            loss_dict = self.session.total_loss if self.is_totals else self.session.loss
            for loss_name, side_loss in loss_dict.items():
                if loss_name not in self.loss_keys:
                    continue
                for side, loss in side_loss.items():
                    if self.args["flatten_outliers"]:
                        loss = self.flatten_outliers(loss)
                    iterations.add(len(loss))
                    raw["raw_{}_{}".format(loss_name, side)] = loss

            self.iterations = 0 if not iterations else min(iterations)
            if len(iterations) > 1:
                # Crop all losses to the same number of items
                if self.iterations == 0:
                    raw = {lossname: list() for lossname in raw}
                else:
                    raw = {lossname: loss[:self.iterations] for lossname, loss in raw.items()}
        else:  # Rate calculation
            data = self.calc_rate_total() if self.is_totals else self.calc_rate()
            if self.args["flatten_outliers"]:
                data = self.flatten_outliers(data)
            self.iterations = len(data)
            raw = {"raw_rate": data}
        logger.debug("Got Raw Data")
        return raw

    def remove_raw(self):
        """ Remove raw values from stats if not requested """
        if "raw" in self.selections:
            return
        logger.debug("Removing Raw Data from output")
        # Iterate a copy of the keys as the dictionary is changed inside the loop
        for key in list(self.stats.keys()):
            if key.startswith("raw"):
                del self.stats[key]
        logger.debug("Removed Raw Data from output")

    @staticmethod
    def _rates_from_timestamps(batchsize, timestamps):
        """ Return batch size divided by the time between each consecutive pair of timestamps.
        A pair with no time between them is given a rate of 0 """
        rates = list()
        for previous, current in zip(timestamps, timestamps[1:]):
            elapsed = current - previous
            rates.append(batchsize / elapsed if elapsed != 0 else 0)
        return rates

    def calc_rate(self):
        """ Calculate rate per iteration

        Returns
        -------
        list
            The rate between each consecutive pair of the session's timestamps. Empty when there
            are fewer than 2 timestamps
        """
        logger.debug("Calculating rate")
        rate = self._rates_from_timestamps(self.session.batchsize, self.session.timestamps)
        logger.debug("Calculated rate: Item_count: %s", len(rate))
        return rate

    def calc_rate_total(self):
        """ Calculate rate per iteration

        NB: For totals, gaps between sessions can be large so time difference has to be reset
        for each session's rate calculation

        Returns
        -------
        list
            Each session's rates joined in session id order. A session that has no batch size
            recorded is given a batch size of 0
        """
        logger.debug("Calculating totals rate")
        batchsizes = self.session.total_batchsize
        total_timestamps = self.session.total_timestamps
        rate = list()
        for sess_id in sorted(total_timestamps.keys()):
            batchsize = batchsizes.get(sess_id, 0)
            rate.extend(self._rates_from_timestamps(batchsize, total_timestamps[sess_id]))
        logger.debug("Calculated totals rate: Item_count: %s", len(rate))
        return rate

    @staticmethod
    def flatten_outliers(data):
        """ Replace the outliers in a provided list with the mean of the list

        A value is an outlier when it is further from the mean than the (population) standard
        deviation of the data.

        Returns
        -------
        list
            A new list of the same length as the provided data
        """
        logger.debug("Flattening outliers")
        retdata = list()
        samples = len(data)
        if samples == 0:
            # There is no mean to calculate, and nothing to flatten
            logger.debug("Flattened outliers")
            return retdata
        mean = (sum(data) / samples)
        limit = sqrt(sum([(item - mean)**2 for item in data]) / samples)
        logger.debug("samples: %s, mean: %s, limit: %s", samples, mean, limit)

        for idx, item in enumerate(data):
            if (mean - limit) <= item <= (mean + limit):
                retdata.append(item)
            else:
                logger.trace("Item idx: %s, value: %s flattened to %s", idx, item, mean)
                retdata.append(mean)
        logger.debug("Flattened outliers")
        return retdata

    def get_calculations(self):
        """ Perform the required calculations

        For every selection other than "raw", each "raw_<name>" entry in :attr:`stats` is passed
        to the selection's ``calc_<selection>`` method and stored as "<selection>_<name>".
        """
        for selection in self.selections:
            if selection == "raw":
                continue
            logger.debug("Calculating: %s", selection)
            method = getattr(self, "calc_{}".format(selection))
            raw_keys = [key for key in self.stats.keys() if key.startswith("raw_")]
            for key in raw_keys:
                selected_key = "{}_{}".format(selection, key.replace("raw_", ""))
                self.stats[selected_key] = method(self.stats[key])

    def calc_avg(self, data):
        """ Calculate rolling average

        Returns
        -------
        list
            Empty if there are not more than twice as many data points as average samples.
            Otherwise the same length as the data, holding ``None`` for the positions at either
            end that do not have a full window around them
        """
        logger.debug("Calculating Average")
        avgs = list()
        presample = ceil(self.args["avg_samples"] / 2)
        postsample = self.args["avg_samples"] - presample
        datapoints = len(data)

        if datapoints <= (self.args["avg_samples"] * 2):
            logger.info("Not enough data to compile rolling average")
            return avgs

        for idx in range(0, datapoints):
            if idx < presample or idx >= datapoints - postsample:
                avgs.append(None)
                continue
            avg = sum(data[idx - presample:idx + postsample]) / self.args["avg_samples"]
            avgs.append(avg)
        logger.debug("Calculated Average")
        return avgs

    def calc_smoothed(self, data):
        """ Smooth the data

        Each output value is the previous output multiplied by the smooth amount, plus the
        current data point multiplied by (1 - smooth amount). The first data point is used as
        the initial previous output.

        Returns
        -------
        list
            The smoothed values, the same length as the data. Empty for empty data
        """
        smoothed = list()
        if not data:
            # There is no first value to anchor the smoothing to
            return smoothed
        last = data[0]  # First value in the plot (first timestep)
        weight = self.args["smooth_amount"]
        for point in data:
            smoothed_val = last * weight + (1 - weight) * point  # Calculate smoothed value
            smoothed.append(smoothed_val)                        # Save it
            last = smoothed_val                                  # Anchor the last smoothed value
        return smoothed

    @staticmethod
    def calc_trend(data):
        """ Compile trend data

        Returns
        -------
        list or :class:`numpy.ndarray`
            A list of ``None`` the same length as the data if there are fewer than 10 data
            points, otherwise the values of a degree 3 polynomial fitted to the data
        """
        logger.debug("Calculating Trend")
        points = len(data)
        if points < 10:
            dummy = [None for i in range(points)]
            return dummy
        x_range = range(points)
        fit = np.polyfit(x_range, data, 3)
        poly = np.poly1d(fit)
        trend = poly(x_range)
        logger.debug("Calculated Trend")
        return trend