In [2]:
## Load numbers


import os
import numpy as np
import pandas as pd


# Directory that is listed for task sub-directories when the results are loaded.
# Set the OVERSHOOT_TABLE1_ROOT environment variable to point the notebook at a
# different location without editing this cell; the default is the original
# hard-coded path.
# Previously used location: '/home/kopi/kinit/table1/'
root = os.environ.get(
    'OVERSHOOT_TABLE1_ROOT',
    '/home/kopi/kinit/overshoot/lightning_logs2/table1',
)


# Task directory name -> short display label.
task_name_mapping = {
    "mlp_housing": "MLP-CA",
    "vae_f-mnist": "VAE-FM",
    "vae_mnist": "VAE-M",
    "2c2d_f-mnist": "2c2d-FM",
    "3c3d_cifar10": "3c3d-C10",
    "resnet18_cifar100": "ResNet-C100",
    "gpt_hf_qqp": "GPT-2",
}



def average_convergance(path):
    """Average the `base_loss_1` column over every seed run found under `path`.

    Every entry of `path` that contains a `training_stats.csv` is treated as
    one seed. Entries without that file (stray files, run directories that
    have not written their stats yet) are skipped instead of aborting the
    whole load.

    Args:
        path: Directory holding one sub-directory per seed.

    Returns:
        1-D numpy array: element-wise mean of `base_loss_1` across the seeds.

    Raises:
        FileNotFoundError: if no entry under `path` has a `training_stats.csv`.
        ValueError: if the seeds' `base_loss_1` columns differ in length, so
            they cannot be averaged element-wise.
    """
    # Sorted so the set of files (and any error message) is reproducible
    # regardless of the order the filesystem returns entries in.
    stats_files = [
        os.path.join(path, seed, 'training_stats.csv')
        for seed in sorted(os.listdir(path))
    ]
    curves = [
        pd.read_csv(stats_file)['base_loss_1'].to_numpy()
        for stats_file in stats_files
        if os.path.isfile(stats_file)
    ]
    if not curves:
        raise FileNotFoundError(f"no training_stats.csv found in any seed directory under {path!r}")

    lengths = sorted({len(curve) for curve in curves})
    if len(lengths) > 1:
        raise ValueError(f"seed runs under {path!r} have different lengths: {lengths}")

    return np.mean(curves, axis=0)

def process_task(path):
    """Build a {method name: averaged loss curve} dict for one task directory.

    Each sub-directory of `path` is passed to `average_convergance` and stored
    under the sub-directory's name; entries that are not directories are
    ignored.
    """
    candidates = ((name, os.path.join(path, name)) for name in os.listdir(path))
    return {
        name: average_convergance(method_dir)
        for name, method_dir in candidates
        if os.path.isdir(method_dir)
    }



# Load every entry of `root` whose name is a key of `task_name_mapping`;
# everything else found in `root` is ignored.
tasks = {}
for task_name in os.listdir(root):
    if task_name in task_name_mapping:
        task_path = os.path.join(root, task_name)
        print("Processing", task_name)
        tasks[task_name] = process_task(task_path)

Processing resnet18_cifar100


KeyboardInterrupt: 

In [109]:
# Smooth every loss curve with a running mean over `avg_size` consecutive values.
avg_size = 200
task_running_avg = {}
for task_name, task in tasks.items():
    task_avg = {}
    for method_name, values in task.items():
        # Each window is values[end - avg_size:end]; `end` starts at avg_size,
        # so the window start is never negative and needs no clamping.
        # NOTE(review): `end` stops at len(values) - 1, so the final sample of
        # a curve is never part of any window.
        window_ends = range(avg_size, len(values))
        smoothed = [np.mean(values[end - avg_size:end]) for end in window_ends]
        task_avg[method_name] = np.array(smoothed)
    task_running_avg[task_name] = task_avg

In [151]:

import matplotlib.pyplot as plt

# A curve's convergence threshold is the loss reached once this fraction of
# its total decrease (first value -> last value) has been covered.
CONVERGENCE_FRACTION = 0.95
# Suffixes <c> of the `sgd_overshoot_<c>` / `adam_overshoot_<c>` methods to report.
OVERSHOOT_VARIANTS = (3, 5, 7)


def _loss_threshold(curve, fraction=CONVERGENCE_FRACTION):
    """Return the loss reached after `fraction` of `curve`'s first-to-last decrease."""
    return curve[0] - fraction * (curve[0] - curve[-1])


def _steps_to_reach(curve, threshold):
    """Return the index of the first entry of `curve` that is <= `threshold`.

    Raises:
        IndexError: if no entry of `curve` is <= `threshold`.
    """
    hits = np.where(curve <= threshold)[0]
    if hits.size == 0:
        raise IndexError(f"loss never drops to the threshold {threshold}")
    return hits[0]


# For every task, report how many (smoothed) steps each method needs to reach
# the threshold of its baseline (sgd_baseline or adam_baseline).
# `save` collects, per task, the percentage of steps adam_overshoot_7 saves
# relative to adam_baseline.
save = []
for task_name, task in task_running_avg.items():
    try:
        print("================")
        print(task_name)
        loss_sgd_tr_1 = _loss_threshold(task['sgd_baseline'])
        loss_adam_tr_1 = _loss_threshold(task['adam_baseline'])

        sgd_steps = _steps_to_reach(task['sgd_baseline'], loss_sgd_tr_1)
        print("SGD", sgd_steps)
        print("Nesterov", _steps_to_reach(task['nesterov'], loss_sgd_tr_1))
        for c in OVERSHOOT_VARIANTS:
            steps = _steps_to_reach(task[f'sgd_overshoot_{c}'], loss_sgd_tr_1)
            print(f'Overshoot {c}:', steps, 100 * steps / sgd_steps)
            # if c == 7:
            #     save.append(100 - 100 * steps / sgd_steps)

        print("==")
        adam_steps = _steps_to_reach(task['adam_baseline'], loss_adam_tr_1)
        print("Adam", adam_steps)
        print("Nadam", _steps_to_reach(task['nadam'], loss_adam_tr_1))
        for c in OVERSHOOT_VARIANTS:
            steps = _steps_to_reach(task[f'adam_overshoot_{c}'], loss_adam_tr_1)
            print(f'Overshoot {c}:', steps, 100 * steps / adam_steps)
            if c == 7:
                save.append(100 - 100 * steps / adam_steps)

        # plt.plot(task['sgd_baseline'], label='sgd')
        # plt.plot(task['sgd_overshoot_7'], label='overshoot-7')
        # plt.yscale('log')
        # plt.legend()
        # break
    except (KeyError, IndexError) as err:
        # KeyError: a method is missing for this task. IndexError: a curve is
        # empty or never reaches the baseline threshold. Report the skip
        # instead of hiding it; any other exception (including
        # KeyboardInterrupt) is left to propagate.
        print(f"Skipping the rest of {task_name}: {type(err).__name__}: {err}")

if save:
    print(np.mean(save))
else:
    # np.mean([]) would only emit a RuntimeWarning and print nan.
    print("No adam_overshoot_7 results were collected.")



# print("##############")
# print("##############")
# for task_name, task in tasks.items():
#     print("=================")
#     print(task_name)
#     for method_name, values in task.items():
#         if 'sgd' in method_name or 'nesterov' in method_name:
#             print(method_name, values.mean())
            
#     print("===")
#     for method_name, values in task.items():
#         if not('sgd' in method_name or 'nesterov' in method_name):
#             print(method_name, values.mean())


# import torch
# x = torch.Tensor([[1, 2], [2, 3]])
# print(len(x))


# LC: M,M,A


# adam: LC + M-LC + SQRT+LC

# over: LC + M-LC + SQRT+LC + LC






SGD 12492
Nesterov 12457
Overshoot 3: 11890 95.1809157861031
Overshoot 5: 11890 95.1809157861031
Overshoot 7: 11892 95.1969260326609
==
Adam 27014
Nadam 26742
Overshoot 3: 25209 93.31827941067594
Overshoot 5: 23104 85.52602354334789
Overshoot 7: 23087 85.46309321092767
SGD 13491
Nesterov 12796
Overshoot 3: 11511 85.32354903268846
Overshoot 5: 10892 80.73530501816026
Overshoot 7: 10943 81.11333481580313
==
Adam 8352
Nadam 7464
Overshoot 3: 6605 79.08285440613027
Overshoot 5: 5905 70.70162835249042
Overshoot 7: 6051 72.44971264367815
SGD 20721
Nesterov 12939
Overshoot 3: 8880 42.85507456203851
Overshoot 5: 7413 35.7753004198639
Overshoot 7: 6598 31.842092563100238
==
Adam 16297
Nadam 16075
Overshoot 3: 14745 90.47677486653986
Overshoot 5: 14208 87.18168988157329
Overshoot 7: 13770 84.49407866478494
19.197705160203082
