In [1]:
import json
import os
import collections.abc
import csv
from tabulate import tabulate

def get_metrics(res_dir):
  """Load a run's training log as one list of floats per column.

  Reads ``{res_dir}/trainlog.csv``. The first row supplies the column names;
  every cell of the following rows is converted with ``float``.

  Args:
    res_dir: directory that contains ``trainlog.csv``.

  Returns:
    dict mapping each header name to the list of that column's values,
    in file order.

  Raises:
    ValueError: if a cell cannot be parsed as a float.
  """
  with open(f"{res_dir}/trainlog.csv") as log_file:
    rows = csv.reader(log_file)
    header = next(rows)
    columns = {name: [] for name in header}

    for row in rows:
      # zip stops at the shorter of header/row, so surplus cells are ignored.
      for name, cell in zip(header, row):
        columns[name].append(float(cell))
  return columns

def get_accs_at_epoch(m, epoch, test_acc=None):
  """Return the accuracies recorded at ``epoch``, scaled by 100 and rounded.

  Args:
    m: mapping of metric name to a per-epoch list of values.
    epoch: index into those lists.
    test_acc: fallback test accuracy, used only when ``m`` has no
      'test_accuracy' column.

  Returns:
    dict with whichever of 'accuracy', 'val_accuracy' and 'test_accuracy'
    are available, each multiplied by 100 and rounded to one decimal.
  """
  accs = {
    metric: round(m[metric][epoch] * 100, 1)
    for metric in ('accuracy', 'val_accuracy', 'test_accuracy')
    if metric in m
  }
  # The logged column wins; the fallback only fills a gap.
  if 'test_accuracy' not in accs and test_acc is not None:
    accs['test_accuracy'] = round(test_acc * 100, 1)

  return accs

def get_results(week):
  """Collect one summary record per training run stored under ``week``.

  Walks ``{week}/{name}/{seed}/``. A run is reported only if it has a
  ``trainlog.csv`` with at least one 'val_accuracy' value; ``config.json``
  and ``overview.json`` are optional extras. Entries whose name starts with
  "-" are skipped, and so are top-level entries that are not directories.

  Args:
    week: path of the directory holding the experiment folders.

  Returns:
    list of dicts. Each holds 'name', 'week', 'seed', every key of
    ``config.json``, 'test_accuracy' and 'train_time' (when
    ``overview.json`` exists), the accuracies at the best validation epoch
    (scaled by 100, one decimal) and 'best_val_epoch' (1-based).
  """
  results = []
  for name in os.listdir(f"{week}"):
    exp_dir = f"{week}/{name}"
    # A stray file next to the experiment folders would make the inner
    # os.listdir raise NotADirectoryError, so skip anything that is not a dir.
    if name.startswith("-") or not os.path.isdir(exp_dir):
      continue
    for seed in os.listdir(exp_dir):
      if seed.startswith("-"):
        continue
      run_dir = f"{exp_dir}/{seed}"
      # Runs without a training log are never reported, so check first and
      # avoid parsing their JSON files.
      if not os.path.exists(f"{run_dir}/trainlog.csv"):
        continue

      item = {
        "name": name,
        'week': week,
        # Drops the first four characters of the folder name
        # (presumably a "seed" prefix -- TODO confirm).
        'seed': seed[4:]
      }

      if os.path.exists(f"{run_dir}/config.json"):
        with open(f"{run_dir}/config.json") as f:
          config = json.load(f)

        item.update(config)

      if os.path.exists(f"{run_dir}/overview.json"):
        with open(f"{run_dir}/overview.json") as f:
          overview = json.load(f)

        test_accuracy = overview.get("test_accuracy", None)
        # Some overviews store a sequence here; only its first entry is used.
        if isinstance(test_accuracy, collections.abc.Sequence):
          test_accuracy = test_accuracy[0]

        # train_time is divided by 3600 (presumably seconds -> hours; confirm).
        item.update({"test_accuracy": test_accuracy, 'train_time': round(overview.get("train_time", 0) / (60*60), 2)})

      train_log = get_metrics(run_dir)
      val_curve = train_log.get('val_accuracy')
      if not val_curve:
        # Header-only log or no validation column: there is no best epoch
        # to pick, so skip the run instead of crashing on max([]).
        continue

      max_val = val_curve.index(max(val_curve))
      accs = get_accs_at_epoch(train_log, max_val, item.get('test_accuracy', None))
      item.update(accs)
      item.update({"best_val_epoch": max_val + 1})

      results.append(item)
  return results

def print_table(data, keys):
  """Print ``data`` as a text table with one column per entry of ``keys``.

  Args:
    data: iterable of dict records.
    keys: column names; a record missing a key gets an empty cell.
  """
  rows = []
  for record in data:
    rows.append([record.get(column, '') for column in keys])
  print(tabulate(rows, headers=keys))

# Columns, in display order, of the per-run table printed by print_week.
keys = [
  'name', 'week', 'seed',
  'accuracy', 'val_accuracy', 'test_accuracy', 'best_val_epoch',
  'G_epoch', 'hidden_size', 'batch_size', 'learning_rate', 'dropout',
  'train_time',
]

def print_week(week):
  """Print a table with one row per run found under ``week``.

  Args:
    week: directory passed on to ``get_results``.
  """
  print_table(get_results(week), keys)

# Columns, in display order, of the table printed by print_week_average.
all_keys = [
  'name', 'runs',
  'accuracy', 'val_accuracy', 'test_accuracy', 'best_val_epoch',
  'G_epoch', 'hidden_size', 'batch_size', 'learning_rate', 'dropout',
  'train_time',
]
# Metrics that print_week_average averages over the runs sharing a name.
average_keys = [
  'accuracy',
  'val_accuracy',
  'test_accuracy',
  'best_val_epoch',
]

def print_week_average(week):
  """Print one row per experiment name with metrics averaged over its runs.

  The entries of ``average_keys`` are averaged (rounded to one decimal) over
  the runs that actually have a value for them. All other columns are taken
  from the first run of that name, and 'runs' is the number of runs found.

  Args:
    week: directory passed on to ``get_results``.
  """
  results = get_results(week)
  # dict.fromkeys keeps first-seen order, so rows appear in the order the
  # runs were collected rather than in arbitrary set order.
  names = list(dict.fromkeys(r['name'] for r in results))
  aggregates = []
  for name in names:
    seeds = [r for r in results if r['name'] == name]
    # Work on a copy so the per-run record is not overwritten with averages.
    item = dict(seeds[0])
    item['runs'] = len(seeds)
    for k in average_keys:
      # A run may lack a metric or carry it as None (e.g. an overview
      # without 'test_accuracy'); neither can take part in the mean.
      vals = [s[k] for s in seeds if s.get(k) is not None]
      if not vals:
        # Nothing to average: leave the cell as it is instead of dividing by 0.
        continue
      item[k] = round(sum(vals) / len(vals), 1)
    aggregates.append(item)
  print_table(aggregates, all_keys)

In [2]:

# Per-run results for week49. The commented-out calls below are alternatives
# for inspecting other weeks or printing the column names tab-separated.
#print("\t".join(keys))
#print_week("week47")
#print_week("week48")
print_week("week49")

name                          week      seed    accuracy    val_accuracy  test_accuracy      best_val_epoch    G_epoch    hidden_size    batch_size    learning_rate    dropout  train_time
----------------------------  ------  ------  ----------  --------------  ---------------  ----------------  ---------  -------------  ------------  ---------------  ---------  ------------
GRU128-G1-LR8-BS32            week49      89        85              63.9  71.0                           99          1            128            32            1e-08        0    20.91
GRU128-G1-LR8-BS32            week49      23        87.9            76.5                                692          1            128            32            1e-08        0
GRU128-G1-LR8-BS32            week49     196        86.9            72.1                                840          1            128            32            1e-08        0
GRU128-G1-LR3-BS256-REG-DROP  week49      89        98.3            87.2  85.5             

In [4]:
# Same week, with the metrics averaged over the runs that share a name.
print_week_average("week49")

name                            runs    accuracy    val_accuracy    test_accuracy    best_val_epoch    G_epoch    hidden_size    batch_size    learning_rate    dropout    train_time
----------------------------  ------  ----------  --------------  ---------------  ----------------  ---------  -------------  ------------  ---------------  ---------  ------------
GRU128-G1-LR8-BS32                 3        86.6            70.8             71               544.7          1            128            32            1e-08        0           20.91
GRU128-G1-LR3-BS256-REG-DROP       2        89.4            84.8             81.2             229.5          1            128           256            0.001        0.5         13.44
