In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import json
import torch
import numpy as np
from types import SimpleNamespace
from scripts import launch_pretraining, launch_finetuning

import seaborn as sns
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from IPython.display import clear_output

In [None]:
# All training in this notebook runs on the CPU; the bare name on the last
# line makes the cell display the selected device.
device = torch.device(type='cpu')
device

In [None]:
def get_config(data_seed, train_seed, f):
    """Build the experiment configuration for one (data seed, train seed) run.

    Args:
        data_seed: Seed for synthetic data generation. It also forms the
            hundreds part of the run id used in the output directory names.
        train_seed: Seed used for pre-training, fine-tuning and model
            initialization. Must lie in ``[0, 100)`` so that the run id
            ``data_seed * 100 + train_seed`` is unique per seed pair.
        f: Feature type applied to every feature pair of the data protocol
            (only ``'tick'`` is used in the paper). It also names the
            experiment directory (``f_all={f}``).

    Returns:
        types.SimpleNamespace with the full configuration; each field is
        described by the inline comment next to it.

    Raises:
        ValueError: if ``train_seed`` is outside ``[0, 100)``. Such values
            would map two different seed pairs to the same run id, and hence
            to the same pre-training / fine-tuning directories, silently
            overwriting another run's checkpoints.
    """
    if not 0 <= train_seed < 100:
        raise ValueError(
            f'train_seed must be in [0, 100) to keep run ids unique, got {train_seed}'
        )

    seed = data_seed * 100 + train_seed  # run id, unique per (data_seed, train_seed)
    exp_name = f'f_all={f}'

    num_features = 32  # input dimensionality
    # Features are sampled in consecutive pairs (0, 1), (2, 3), ...; each pair
    # gets one data-protocol entry and one multiview probability, so both
    # lengths are derived from num_features instead of being hard-coded.
    num_feature_pairs = num_features // 2

    # Each entry describes one feature pair:
    #   feature_type - feature distribution (only "tick" is used in the paper)
    #   ids          - ids of the features to sample
    #   margin       - width of the separating boundary between classes
    #   noise        - probability of swapping classes for datapoints (always 0.0 in the paper)
    data_protocol = [
        {'feature_type': f, 'ids': (r, r + 1), 'margin': 0.1, 'noise': 0.0}
        for r in range(0, num_features, 2)
    ]

    # Learning-rate grid (23 values): log-spaced with a step of 0.25 decades on
    # [10**-4.5, 10**-2.25] and [10**-1.25, 1], and a finer 0.125-decade step on
    # [10**-2.25, 10**-1.25]. The last point of each of the first two segments
    # is dropped because it equals the first point of the next segment.
    lrs = (
        np.logspace(-4.5, -2.25, 10).tolist()[:-1]
        + np.logspace(-2.25, -1.25, 9).tolist()[:-1]
        + np.logspace(-1.25, 0, 6).tolist()
    )

    return SimpleNamespace(
        data_seed = data_seed,         # seed for data generation
        data_protocol = data_protocol, # protocol for data generation
        multiview_probs = [1.0] * num_feature_pairs,  # utility of each feature pair (1 to make all points separable,
                                                      # 0 to initialize all points along the separating boundary)
        num_features = num_features,   # input dimensionality
        train_samples = 512,           # train size
        test_samples = 2000,           # test size
        batch_size = 16,               # training batch size
        num_hidden = 32,               # number of hidden units in MLP
        num_layers = 3,                # number of layers in MLP
        activation = 'relu',           # activation in MLP
        last_layer_norm = 10,          # last layer norm in MLP (last layer is fixed and not trained)
        riemann_opt = False,           # whether to use spherical SGD instead of projected SGD
                                       # (not described in the paper)
        dirichlet_init = None,         # initialize singular values of weight matrices with dirichlet
                                       # distribution for low rank initialization (not described in the paper)
        pt_iters = 40000,              # pre-training iterations
        ft_iters = 20000,              # fine-tuning iterations
        ckpt_iters = 100,              # how often to checkpoint model
        log_iters = 5,                 # how often to calculate metrics
        pt_seed = train_seed,          # pre-training seed
        ft_seed = train_seed,          # fine-tuning seed
        init_point_seed = train_seed,  # model initialization seed
        savedir = f'experiments-final/{exp_name}/PT-FCN-seed={seed}',  # pre-training path
        ft_savedir = f'experiments-final/{exp_name}/FT-FCN-seed={seed}',  # fine-tuning path
        lrs = lrs,                     # learning rate range to use
    )

In [None]:
# Full sweep: for every feature type, data seed (1-5) and training seed (1-10),
# pre-train a model and then fine-tune it. The printed progress is wiped with
# clear_output() after each data seed's ten runs.
# NOTE: f, data_seed, train_seed and config keep their names on purpose; they
# stay in the notebook namespace after this cell and later cells may use them.
for f in ('tick',):
    for data_seed in range(1, 6):
        for train_seed in range(1, 11):
            print(f'{f}, data seed: {data_seed}, #{train_seed}')
            config = get_config(data_seed=data_seed, train_seed=train_seed, f=f)
            launch_pretraining(config, device)
            # NOTE(review): num_ft_lr=11 is consumed by launch_finetuning (not
            # visible here); presumably it selects how many of config.lrs to use.
            launch_finetuning(config, device, num_ft_lr=11)
        clear_output()