In [None]:
import os
import glob
import json
import torch
import joblib
import numpy as np
from types import SimpleNamespace
from torch.nn import functional as F
from scripts import launch_pretraining, launch_finetuning

import seaborn as sns
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from IPython.display import clear_output
from scipy.special import xlogy
from utils import load_stats
from model import *
from dataset import *

from matplotlib.colors import LogNorm, LinearSegmentedColormap
from matplotlib.cm import ScalarMappable

In [None]:
# Collect the config of every pretraining run and order them deterministically
# by (data_seed, pt_seed), so that index i of every results array below refers
# to the same run.
PT_RUN_PATTERN = 'experiments-final/f_all=tick/PT-*'
dirs = sorted(glob.glob(PT_RUN_PATTERN))
assert dirs, f'No pretraining runs match {PT_RUN_PATTERN!r}'

configs = []
for dir_name in dirs:
    with open(os.path.join(dir_name, 'config.json')) as file:
        configs.append(SimpleNamespace(**json.load(file)))
configs = sorted(configs, key=lambda x: (x.data_seed, x.pt_seed))

# Later cells build checkpoint paths for *every* run from this single grid of
# pretraining learning rates, so all runs must agree on it.
lrs = configs[0].lrs
assert all(c.lrs == lrs for c in configs), 'runs use different learning-rate grids'

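## Single-feature test sets

For each feature group *i*, build a test set whose `multiview_probs` is 1.0 for group *i* and 0 for every other group. Per-group accuracy on these sets is used below as a feature-importance measure.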

In [None]:
# One test set per feature group: in set `g` only group `g` has a non-zero
# multiview probability. The sets are concatenated in group order, each with
# `samples_per_group` rows (downstream code slices by that layout).
config = configs[0]
num_groups = config.num_features // 2
samples_per_group = 1000


def _single_group_test_set(group):
    """Return (X, y) from `generate_data` when only `group` is active."""
    probs = [0] * num_groups
    probs[group] = 1.0
    # NOTE(review): positions 1 and 3 are presumably the test inputs/labels,
    # mirroring the (X_train, X_test, y_train, y_test) layout of data.pt — confirm.
    _, X_part, _, y_part, _ = generate_data(
        config.data_protocol, probs,
        config.data_seed + group, config.num_features,
        config.train_samples, samples_per_group, 'cpu'
    )
    return X_part, y_part


group_sets = [_single_group_test_set(g) for g in range(num_groups)]
X_group = torch.cat([X_part for X_part, _ in group_sets], dim=0)
y_group = torch.cat([y_part for _, y_part in group_sets], dim=0)

In [None]:
def calculate_feature_importance(flr=None, num_swa=None):
    """Per-feature-group test accuracy for every run and pretraining LR.

    For each config in `configs` and each learning rate in `lrs`, load a
    checkpoint and measure accuracy on each single-group test set stored in
    `X_group` / `y_group` (num_groups blocks of `samples_per_group` rows).

    Args:
        flr: if None, evaluate the pretrained checkpoints
            (`{savedir}/pt_lr=...pt`); otherwise evaluate the checkpoints
            fine-tuned with this learning rate
            (`{ft_savedir}/pt_lr=...-ft_lr=...pt`).
        num_swa: if given, replace the loaded weights with the mean of the
            last `num_swa` entries of the checkpoint's weight trace.

    Returns:
        np.ndarray of shape (len(configs), len(lrs), num_groups) holding
        accuracies in [0, 1].
    """
    # Sized by num_groups (not a hard-coded 16) so any config.num_features works.
    group_accs = np.zeros((len(configs), len(lrs), num_groups))

    for i, config in enumerate(tqdm(configs)):
        model = init_model(
            config.num_layers, config.num_hidden, config.num_features,
            config.last_layer_norm, config.activation
        )

        for j, lr in enumerate(lrs):
            if flr is None:
                ckpt = torch.load(f'{config.savedir}/pt_lr={lr:.3e}.pt')
            else:
                ckpt = torch.load(f'{config.ft_savedir}/pt_lr={lr:.3e}-ft_lr={flr:.3e}.pt')
            model.load_state_dict(ckpt['model'])

            if num_swa is not None:
                w_swa = torch.stack(ckpt['trace']['weight'][-num_swa:], dim=0).mean(0)
                set_weights(model, w_swa)

            with torch.no_grad():
                group_preds = (model(X_group)[:, 0] > 0).to(torch.long)
            correct = (group_preds == y_group).to(torch.float)
            # One row per group; view() raises if X_group does not have exactly
            # num_groups * samples_per_group rows, instead of mis-slicing silently.
            per_group = correct.view(num_groups, samples_per_group).mean(dim=1)
            group_accs[i, j] = per_group.tolist()

    return group_accs

In [None]:
pt_group_accs_path = 'experiments-final/f_all=tick/pt_group_accs.pickle'
group_accs = calculate_feature_importance()
joblib.dump(group_accs, pt_group_accs_path)

In [None]:
low_flr = lrs[0]
group_accs = calculate_feature_importance(flr=low_flr)
joblib.dump(group_accs, f'experiments-final/f_all=tick/ft_group_accs-flr={low_flr:.3e}.pickle')

In [None]:
high_flr = lrs[10]
group_accs = calculate_feature_importance(flr=high_flr)
joblib.dump(group_accs, f'experiments-final/f_all=tick/ft_group_accs-flr={high_flr:.3e}.pickle')

In [None]:
group_accs = calculate_feature_importance(num_swa=5)
swa_group_accs_path = 'experiments-final/f_all=tick/swa_group_accs.pickle'
joblib.dump(group_accs, swa_group_accs_path)

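## Angular distance and error barriers

For each pretraining LR, compare the low-FLR fine-tuned, high-FLR fine-tuned and SWA solutions pairwise. The comparison uses two quantities: the angle between their weight vectors, and the error barrier along the linear path between them.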

In [None]:
alphas = np.linspace(start=0, stop=1, num=11)  # interpolation coefficients 0.0, 0.1, ..., 1.0

@torch.no_grad()
def get_error(model, X, y):
    """Classification error of `model` on (X, y), in percent.

    The first output column is thresholded at 0: positive -> class 1,
    otherwise class 0.
    """
    logits = model(X)[:, 0]
    predicted = (logits > 0).to(torch.long)
    wrong = (predicted != y).to(torch.float)
    return 100 * wrong.mean().item()

@torch.no_grad()
def get_accuracy(model, X, y):
    """Classification accuracy of `model` on (X, y), in percent.

    The first output column is thresholded at 0: positive -> class 1,
    otherwise class 0.
    """
    logits = model(X)[:, 0]
    predicted = (logits > 0).to(torch.long)
    right = (predicted == y).to(torch.float)
    return 100 * right.mean().item()

def get_barrier(model, w1, w2, X, y, alphas):
    """Error barrier (percentage points) on the linear path from w1 to w2.

    Evaluates the error of `model` at (1 - alpha) * w1 + alpha * w2 for each
    alpha, then returns the largest excess of that curve over the straight
    line joining errors[0] and errors[-1]. The baseline therefore assumes
    alphas runs from 0 to 1. `model` is left holding the weights of the last
    alpha evaluated.
    """
    errors = np.zeros_like(alphas)
    for idx, alpha in enumerate(alphas):
        set_weights(model, (1 - alpha) * w1 + alpha * w2)
        errors[idx] = get_error(model, X, y)

    excess = errors - (1 - alphas) * errors[0] - alphas * errors[-1]
    return np.max(excess)

def get_angle(w1, w2):
    """Angle in radians, in [0, pi], between two flat weight vectors.

    The dot product is divided by both norms, so the result is a true angle
    even for vectors that are not unit-length. For unit-norm inputs it equals
    arccos(w1 @ w2). The cosine is clipped to [-1, 1] to absorb rounding
    before arccos. A zero vector yields cos = 0, i.e. pi/2.
    """
    norms = (w1.norm() * w2.norm()).clamp_min(1e-12)
    cos = torch.clip((w1 @ w2) / norms, -1, 1)
    return cos.arccos().item()

In [None]:
from joblib import Parallel, delayed

config = configs[0]
# Set required FLRs and the number of models in SWA
flr1, flr2 = lrs[0], lrs[10]
num_swa = 5
N_JOBS = 8
alphas = np.linspace(0, 1, 11)
# Leading axis (size 3) indexes the pairs (low-FLR, high-FLR),
# (high-FLR, SWA), (low-FLR, SWA), in that order.
angles = np.zeros((3, len(configs), len(lrs)))
train_barriers = np.zeros((3, len(configs), len(lrs)))
train_errors = np.zeros((3, len(configs), len(lrs), len(alphas)))
test_barriers = np.zeros((3, len(configs), len(lrs)))
test_errors = np.zeros((3, len(configs), len(lrs), len(alphas)))

# NOTE(review): process_plr mutates this shared `model`. That is only safe
# because joblib's default process-based (loky) backend hands each task its
# own pickled copy; do not switch to a threading backend.
model = init_model(
    config.num_layers, config.num_hidden, config.num_features,
    config.last_layer_norm, config.activation
)


def _interpolation_errors(model, w1, w2, X, y, alphas):
    """Error (%) of `model` at each point (1 - alpha) * w1 + alpha * w2."""
    errors = np.zeros_like(alphas)
    for idx, alpha in enumerate(alphas):
        set_weights(model, (1 - alpha) * w1 + alpha * w2)
        errors[idx] = get_error(model, X, y)
    return errors


def _barrier(errors, alphas):
    """Max excess of an error curve over the line joining its endpoint errors."""
    return np.max(errors - (1 - alphas) * errors[0] - alphas * errors[-1])


def process_plr(config, plr, data=None):
    """Compare the low-FLR, high-FLR and SWA solutions for one pretraining LR.

    Args:
        config: run config (reads `savedir` and `ft_savedir`).
        plr: pretraining learning rate identifying the checkpoints.
        data: (X_train, X_test, y_train, y_test). Defaults to the notebook
            globals of the same names. Passing it explicitly avoids depending
            on which globals a worker process happens to capture.

    Returns:
        (angles, train_barriers, train_errors, test_barriers, test_errors).
        Angles and barriers have shape (3,); the error curves have shape
        (3, len(alphas)). The leading axis is ordered as the pairs
        (low, high), (high, swa), (low, swa).
    """
    if data is None:
        data = (X_train, X_test, y_train, y_test)
    X_tr, X_te, y_tr, y_te = data

    ckpt = torch.load(f'{config.savedir}/pt_lr={plr:.3e}.pt')
    ckpt1 = torch.load(f'{config.ft_savedir}/pt_lr={plr:.3e}-ft_lr={flr1:.3e}.pt')
    ckpt2 = torch.load(f'{config.ft_savedir}/pt_lr={plr:.3e}-ft_lr={flr2:.3e}.pt')
    model.load_state_dict(ckpt['model'])

    w_swa = torch.stack(ckpt['trace']['weight'][-num_swa:], dim=0).mean(0)
    w_low = ckpt1['trace']['weight'][-1]
    w_high = ckpt2['trace']['weight'][-1]

    pair_angles = np.zeros(3)
    pair_train_barriers, pair_test_barriers = np.zeros(3), np.zeros(3)
    pair_train_errors = np.zeros((3, len(alphas)))
    pair_test_errors = np.zeros((3, len(alphas)))
    for k, (w1, w2) in enumerate([(w_low, w_high), (w_high, w_swa), (w_low, w_swa)]):
        pair_angles[k] = get_angle(w1, w2)

        # Keep the full error curves (previously discarded, so train_errors /
        # test_errors were always dumped as zeros).
        pair_train_errors[k] = _interpolation_errors(model, w1, w2, X_tr, y_tr, alphas)
        pair_train_barriers[k] = _barrier(pair_train_errors[k], alphas)
        pair_test_errors[k] = _interpolation_errors(model, w1, w2, X_te, y_te, alphas)
        pair_test_barriers[k] = _barrier(pair_test_errors[k], alphas)

    return (pair_angles, pair_train_barriers, pair_train_errors,
            pair_test_barriers, pair_test_errors)


for i, config in enumerate(tqdm(configs)):
    X_train, X_test, y_train, y_test = torch.load(f'{config.savedir}/data.pt')[:4]
    data = (X_train, X_test, y_train, y_test)
    # Parallel returns results in the order of `lrs`, so j indexes lrs.
    results = Parallel(n_jobs=N_JOBS)(
        delayed(process_plr)(config, plr, data) for plr in lrs
    )
    for j, (a, tr_b, tr_e, te_b, te_e) in enumerate(results):
        angles[:, i, j] = a
        train_barriers[:, i, j] = tr_b
        train_errors[:, i, j] = tr_e
        test_barriers[:, i, j] = te_b
        test_errors[:, i, j] = te_e

In [None]:
barrier_results = (angles, train_barriers, train_errors, test_barriers, test_errors)
barrier_path = f'experiments-final/f_all=tick/barriers-low_flr={flr1:.3e}-high_flr={flr2:.3e}.pickle'
joblib.dump(barrier_results, barrier_path)

In [None]:
# Test accuracy (%) of the average of the last `window` pretraining iterates,
# for every run, pretraining LR and window size in `num_swa_list`.
config = configs[0]
num_swa_list = [2, 5, 10, 20, 50]
swa_accs = np.zeros((len(configs), len(lrs), len(num_swa_list)))
model = init_model(
    config.num_layers, config.num_hidden, config.num_features,
    config.last_layer_norm, config.activation
)

for i, config in enumerate(tqdm(configs)):
    # Only the test split is needed here.
    _, X_test, _, y_test = torch.load(f'{config.savedir}/data.pt')[:4]

    for j, plr in enumerate(lrs):
        ckpt = torch.load(f'{config.savedir}/pt_lr={plr:.3e}.pt')
        model.load_state_dict(ckpt['model'])

        # Stack the longest window once; every shorter window is a suffix of it.
        tail = torch.stack(ckpt['trace']['weight'][-max(num_swa_list):], dim=0)
        # `window` (not `num_swa`) avoids rebinding the notebook-level
        # `num_swa`, which `process_plr` reads as a global.
        for k, window in enumerate(num_swa_list):
            set_weights(model, tail[-window:].mean(0))
            swa_accs[i, j, k] = get_accuracy(model, X_test, y_test)

In [None]:
swa_accs_path = 'experiments-final/f_all=tick/swa.pickle'
joblib.dump(swa_accs, swa_accs_path)