In [12]:
import json
import os
import collections.abc
import csv
from tabulate import tabulate

def get_metrics(res_dir):
  with open(f"{res_dir}/trainlog.csv") as f:
    reader = csv.reader(f)
    keys =next(reader)
    metrics = {k:[] for k in keys}
    
    for row in reader:
      for k,v in zip(keys, row):
        metrics[k].append(float(v))
  return metrics

def get_accs_at_epoch(m, epoch, test_acc=None):
  accs = {}
  for k in ['accuracy', 'val_accuracy', 'test_accuracy']:
    if k in m:
      accs[k] = round(m[k][epoch] * 100, 1)
  if test_acc is not None and 'test_accuracy' not in accs:
    accs['test_accuracy'] = round(test_acc * 100, 1)
  
  return accs

def get_results(week):
  results = []
  for name in os.listdir(f"{week}"):
    if name.startswith("-"):
      continue
    for seed in os.listdir(f"{week}/{name}"):
      if seed.startswith("-"):
        continue
      item = {
        "name": name,
        'week': week,
        'seed': seed[4:]
      }

      
      if os.path.exists(f"{week}/{name}/{seed}/config.json"):
        with open(f"{week}/{name}/{seed}/config.json") as f:
          config = json.load(f)

        item.update(config)
        
      
      if os.path.exists(f"{week}/{name}/{seed}/overview.json"):
        with open(f"{week}/{name}/{seed}/overview.json") as f:
          overview = json.load(f)
        
        test_accuracy = overview.get("test_accuracy", None)
        if isinstance(test_accuracy, collections.abc.Sequence):
          test_accuracy = test_accuracy[0]

        item.update({"test_accuracy": test_accuracy, 'train_time': round(overview.get("train_time", 0) / (60*60), 2)})
      
      if os.path.exists(f"{week}/{name}/{seed}/trainlog.csv"):
        train_log = get_metrics(f"{week}/{name}/{seed}")
        max_val = train_log['val_accuracy'].index(max(train_log['val_accuracy']))
        accs = get_accs_at_epoch(train_log, max_val, item.get('test_accuracy', None))
        item.update(accs)
        item.update({"best_val_epoch": max_val + 1})
      else: 
        continue

      results.append(item)
  return results

def print_table(data, keys):
  data = [[r.get(k, '') for k in keys] for r in data]
  print(tabulate(data, headers=keys))

keys = ['name', 'week', 'seed', 'accuracy', 'val_accuracy',  'test_accuracy', 'best_val_epoch', 'G_epoch', 'hidden_size', 'batch_size', 'learning_rate', 'dropout', 'train_time',]

def print_week(week):
  results = get_results(week)
  print_table(results, keys)

all_keys = ['name', 'runs', 'accuracy', 'val_accuracy',  'test_accuracy', 'best_val_epoch', 'G_epoch', 'hidden_size', 'batch_size', 'learning_rate', 'dropout', 'train_time',]
average_keys = ['accuracy', 'val_accuracy',  'test_accuracy', 'best_val_epoch']
sum_keys = ['train_time']

def print_week_average(week):
  results = get_results(week)
  names = set([r['name'] for r in results])
  aggregates = []
  for name in names:
      seeds = [r for r in results if r['name'] == name]
      item = seeds[0]
      item['runs'] = len(seeds)
      for k in average_keys:
        vals = [s[k] for s in seeds if k in s]
        if(len(vals) > 0):
          item[k] = round(sum([r for r in vals]) / len(vals),1)
      for k in sum_keys:
        vals = [s[k] for s in seeds if k in s]
        if(len(vals) > 0):
          item[k] = round(sum([r for r in vals]),1)
          
      aggregates.append(item)
  print_table(aggregates, all_keys)

In [33]:

#print("\t".join(keys))
#print_week("week47")
#print_week("week48")
#print_week("week49")
print_week("week50")

name                                 week      seed    accuracy    val_accuracy  test_accuracy      best_val_epoch    G_epoch    hidden_size    batch_size    learning_rate    dropout  train_time
-----------------------------------  ------  ------  ----------  --------------  ---------------  ----------------  ---------  -------------  ------------  ---------------  ---------  ------------
G-LIGHT-GRU128-G1-LR7-B256-REG       week50      89        82.1            80.7  84.1                          998          1            128           256            1e-07        0    7.15
G-LIGHT-GRU128-G1-LR7-B256-REG       week50     196        83.1            81.6  79.4                          998          1            128           256            1e-07        0    7.59
G-LIGHT-GRU128-G1-LR9-B32-REG-DROP   week50     196        88.1            41.9  44.3                         1000          1            128            32            1e-09        0.5  13.73
G-LIGHT-GRU128-G1-LR9-B32-REG-DROP   wee

In [29]:
print_week_average("week49")
print_week_average("week50")

name                                   runs    accuracy    val_accuracy    test_accuracy    best_val_epoch    G_epoch    hidden_size    batch_size    learning_rate    dropout    train_time
-----------------------------------  ------  ----------  --------------  ---------------  ----------------  ---------  -------------  ------------  ---------------  ---------  ------------
G-LIGHT-GRU128-G1-LR3-B32-REG-DROP        1        98.6            22.1             30.9             169            1            128            32           0.001         0.5           3.6
G-LIGHT-GRU128-G1-LR4-B256-REG            2        91.5            87               86.6             353            1            128           256           0.0001        0             7.6
G-LIGHT-GRU128-G1-LR6-B256-REG            2        84.7            82.7             85.6             744.5          1            128           256           1e-06         0             7.5
G-LIGHT-GRU128-G1-LR5-B256-REG            2        85.4