In [1]:
# Load some packages
import os
import glob
import json
import datetime
from copy import deepcopy

import matplotlib.pyplot as plt
import pprint
from IPython.display import clear_output
from tqdm.auto import tqdm 

import numpy as np
import random
import torch

from torch.utils.tensorboard import SummaryWriter

In [3]:
# Collect the per-run training logs saved under ``checkpoint/<exp_name>/`` and
# write one TensorBoard HParams entry per log file to ``runs/hparam``.
#
# Hyperparameters that a log may not store itself (AWGN level, weight decay,
# final pooling) are recovered from the experiment folder name, e.g. a folder
# named ``..._awgn0.5_wd0.01_avg`` yields awgn=0.5, weight_decay=0.01,
# final_pool='avg'.
root_folder = 'checkpoint/*'
# Optional substring filter on log_dict['model'] (e.g. '1D-ResNet'); None keeps every run.
MODEL_FILTER = None


def _number_after(name, key):
    """Parse the float that directly follows ``key`` in ``name``.

    The number extends up to the next '_' or the end of ``name``, so
    ``_number_after('run_awgn0.5_wd0.01', 'awgn') == 0.5``.
    Callers must check ``key in name`` first. Raises ValueError if the text
    after ``key`` is not a valid float.
    """
    start = name.find(key) + len(key)
    return float(name[start:].split('_', 1)[0])


with SummaryWriter(log_dir='runs/hparam') as w:
    # Sorted so the export order (and printed output) is deterministic.
    for folder in sorted(glob.glob(root_folder)):
        if not os.path.isdir(folder):
            continue
        # Folder name only; works for both '/' and '\\' path separators.
        exp_name = os.path.basename(os.path.normpath(folder))

        for log in sorted(glob.glob(os.path.join(folder, '*'))):
            if not (log.endswith('log') and os.path.isfile(log)):
                continue
            # torch.load unpickles the file: only run this on trusted checkpoints.
            # map_location='cpu' lets logs saved from GPU runs load anywhere.
            log_dict = torch.load(log, map_location='cpu')

            if MODEL_FILTER is not None and MODEL_FILTER not in log_dict['model']:
                continue

            # Runs that stopped before logging anything would crash on [-1] below.
            histories = (log_dict['losses'], log_dict['train_acc_history'], log_dict['val_acc_history'])
            if any(len(history) == 0 for history in histories):
                print(f'Skipping {log}: empty training history')
                continue

            awgn = _number_after(exp_name, 'awgn') if 'awgn' in exp_name else 0.0

            # Prefer the value stored in the log; otherwise parse it from the
            # folder name, falling back to 1e-4 when neither provides one.
            if 'wd' in exp_name and 'weight_decay' not in log_dict:
                weight_decay = _number_after(exp_name, 'wd')
            else:
                weight_decay = log_dict.get('weight_decay', 1e-4)

            final_pool = 'avg' if 'avg' in exp_name else 'max'

            num_losses = len(log_dict['losses'])
            # NOTE(review): a loss count that is not a multiple of 100 is taken
            # to be missing one final entry; confirm against the training loop.
            steps = num_losses if num_losses % 100 == 0 else num_losses + 1

            hparam_dict = {
                'model': log_dict['model'],
                'num_params': log_dict.get('num_params', 0),
                'final_pool': final_pool,
                'starting_lr': log_dict['starting_lr'],
                'weight_decay': weight_decay,
                'awgn': awgn,
                'steps': steps,
            }
            metric_dict = {
                'final_loss': log_dict['losses'][-1],
                'train_accuracy': log_dict['train_acc_history'][-1],
                'validation_accuracy': log_dict['val_acc_history'][-1],
                # NOTE(review): taking the better of two test-set numbers is an
                # optimistic estimate of test performance.
                'test_accuracy': max(log_dict['best_test_accuracy'], log_dict['last_test_accuracy']),
            }

            print(hparam_dict, metric_dict)
            w.add_hparams(hparam_dict, metric_dict)

{'model': '1D-ResNet-50', 'num_params': 26291075, 'final_pool': 'max', 'starting_lr': 0.002486974865972968, 'weight_decay': 0.01, 'awgn': 0.5, 'steps': 100000} {'final_loss': 0.020050985738635063, 'train_accuracy': 99.25, 'validation_accuracy': 60.76923076923077, 'test_accuracy': 58.782051282051285}
{'model': '1D-ResNet', 'num_params': 16729219, 'final_pool': 'max', 'starting_lr': 0.000733633660232979, 'weight_decay': 0.01, 'awgn': 0.5, 'steps': 100000} {'final_loss': 0.0006636945181526244, 'train_accuracy': 99.5625, 'validation_accuracy': 59.03846153846154, 'test_accuracy': 64.2948717948718}
{'model': 'M5', 'num_params': 8411651, 'final_pool': 'max', 'starting_lr': 0.0003704642991107269, 'weight_decay': 0.01, 'awgn': 0.5, 'steps': 100000} {'final_loss': 0.0008513953071087599, 'train_accuracy': 99.703125, 'validation_accuracy': 61.92307692307692, 'test_accuracy': 63.46153846153846}
