In [None]:
import math, os, pickle, random, re, sys, time

import numpy as np
import matplotlib.pyplot as plt
import torch
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

## 1a. Load the CIFAR-10 (VGG16) checkpoints — run this *or* 1b

In [None]:
# Load every checkpoint of the CIFAR-10 / VGG16 weight-decay runs, keeping only
# the tensor named `focus_on` plus the per-checkpoint metrics. Each run
# (sub-folder) appends one entry to every *_all list; within a run all arrays
# are ordered by epoch.
models_dir = '/gstore/project/prescient/data/loukasa1/alg-stability/model_weights'
focus_on   = 'features.0.weight'
data       = 'cifar'

epochs_all, weights_all, loss_all, accuracy_all, accuracy_test_all, noise_all = [], [], [], [], [], []

# sorted(): os.listdir order is arbitrary, and the position of a run in the
# *_all lists becomes its run index `r` in later cells.
for folder in sorted(os.listdir(models_dir)):

    # Only CIFAR-10 VGG16 runs with weight decay; skip the 20.04 / 21.04 batches.
    if 'cifar10-vgg16' not in folder: continue
    if '20.04' in folder or '21.04' in folder: continue
    if 'wd' not in folder: continue

    model_dir = os.path.join(models_dir, folder)
    models = sorted(m for m in os.listdir(model_dir) if os.path.isfile(os.path.join(model_dir, m)))
    print(model_dir)

    # Noise level tagged in the folder name as `noise:<value>` (0 when absent).
    noise_match = re.search(r'noise:(\d*\.?\d+)', folder)
    noise_all.append(float(noise_match.group(1)) if noise_match else 0)

    epochs, weights, loss, accuracy, accuracy_test = [], [], [], [], []

    for model in models:
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts.
        dummy = torch.load(os.path.join(model_dir, model), map_location='cpu')
        epochs.append(dummy['epoch'])
        weights.append(dummy['model_state_dict'][focus_on].detach().cpu().numpy())
        # Missing metrics are recorded as NaN so every list stays aligned with `epochs`.
        loss.append(dummy.get('loss', np.nan))
        accuracy.append(dummy['accuracy'].detach().cpu().numpy() if 'accuracy' in dummy else np.nan)
        accuracy_test.append(dummy['accuracy_test'].detach().cpu().numpy() if 'accuracy_test' in dummy else np.nan)

    # Reorder all per-checkpoint lists by epoch (stable: ties keep file-name order).
    idx = np.argsort(np.array(epochs), kind='stable')
    weights       = np.array([weights[i] for i in idx])
    epochs        = np.array([epochs[i] for i in idx])
    loss          = np.array([loss[i] for i in idx])
    accuracy      = np.array([accuracy[i] for i in idx])
    accuracy_test = np.array([accuracy_test[i] for i in idx])

    weights_all.append(weights)
    epochs_all.append(epochs)
    loss_all.append(loss)
    accuracy_all.append(accuracy)
    accuracy_test_all.append(accuracy_test)

In [None]:
# for k in dummy['model_state_dict'].keys(): print(k)

In [None]:
# Stack the selected checkpoints of all long-enough runs into one matrix X
# (one row per checkpoint, flattened weights), plus per-row metadata.
epoch_start, min_n_epochs = 400, 4000

# (run, checkpoint) index pairs that are kept: runs whose last epoch exceeds
# min_n_epochs, and within them checkpoints from epoch_start onwards.
keep = [(i, j)
        for i, epochs in enumerate(epochs_all)
        if len(epochs) > 0 and max(epochs) > min_n_epochs
        for j, epoch in enumerate(epochs)
        if epoch >= epoch_start]

# dtype=float (float64): the np.float alias was removed in NumPy 1.24.
X         = np.array([weights_all[i][j].reshape(-1) for i, j in keep], dtype=float)
obs_train = np.array([loss_all[i][j]          for i, j in keep], dtype=float)
obs_test  = np.array([accuracy_test_all[i][j] for i, j in keep], dtype=float)
e         = np.array([epochs_all[i][j]        for i, j in keep], dtype=float)
r         = np.array([i                       for i, _ in keep], dtype=float)  # run index of each row
factor    = np.array([noise_all[i]            for i, _ in keep], dtype=float)  # noise level of each row

# Quantities used to colour the plots and their panel titles; plotting cells
# apply np.log to the colour when the title contains 'log'.
obs, obs_names = [obs_train, obs_test, e], ['log(loss)', 'test accuracy', 'epoch']

X.shape, obs_train.shape, obs_test.shape, e.shape, r.shape, factor.shape

## 1b. Load the WikiText (Transformer) checkpoints — run this *or* 1a

In [None]:
# Load every checkpoint of the WikiText runs, keeping only the tensor named
# `focus_on` plus the per-checkpoint losses. Each run (sub-folder) appends one
# entry to every *_all list; within a run all arrays are ordered by epoch.
models_dir = '/gstore/scratch/u/loukasa1/alg-stability/model_weights' # TODO: change this 
focus_on   = 'transformer_encoder.layers.0.linear1.weight'
data       = 'wiki'

epochs_all, weights_all, loss_all, loss_val_all, loss_test_all, prc_all = [], [], [], [], [], []

# sorted(): os.listdir order is arbitrary, and the position of a run in the
# *_all lists becomes its run index `r` in later cells.
for folder in sorted(os.listdir(models_dir)):

    if 'wiki' not in folder: continue

    model_dir = os.path.join(models_dir, folder)
    models = sorted(m for m in os.listdir(model_dir) if os.path.isfile(os.path.join(model_dir, m)))
    print(model_dir)

    # Data fraction tagged in the folder name as `prc:<value>` (1 when absent).
    # Parsed generically so that fractions other than 0.01/0.1/0.3 are not
    # silently recorded as 1.
    prc_match = re.search(r'prc:(\d*\.?\d+)', folder)
    prc_all.append(float(prc_match.group(1)) if prc_match else 1)

    epochs, weights, loss, loss_val, loss_test = [], [], [], [], []

    for model in models:
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts.
        dummy = torch.load(os.path.join(model_dir, model), map_location='cpu')
        epochs.append(dummy['epoch'])
        # These checkpoints store the tensor at the top level of the dict.
        weights.append(dummy[focus_on].detach().cpu().numpy())
        # Missing metrics are recorded as NaN so every list stays aligned with `epochs`.
        loss.append(dummy.get('loss', np.nan))
        loss_val.append(dummy.get('loss_val', np.nan))
        loss_test.append(dummy.get('loss_test', np.nan))

    # Reorder all per-checkpoint lists by epoch (stable: ties keep file-name order).
    idx = np.argsort(np.array(epochs), kind='stable')
    weights   = np.array([weights[i] for i in idx])
    epochs    = np.array([epochs[i] for i in idx])
    loss      = np.array([loss[i] for i in idx])
    loss_val  = np.array([loss_val[i] for i in idx])
    loss_test = np.array([loss_test[i] for i in idx])

    weights_all.append(weights)
    epochs_all.append(epochs)
    loss_all.append(loss)
    loss_val_all.append(loss_val)
    loss_test_all.append(loss_test)

In [None]:
# Stack the selected checkpoints of all runs into one matrix X (one row per
# checkpoint, flattened weights), plus per-row metadata.
epoch_start, min_n_epochs = 0, 0

# (run, checkpoint) index pairs that are kept: runs whose last epoch exceeds
# min_n_epochs, and within them checkpoints from epoch_start onwards.
keep = [(i, j)
        for i, epochs in enumerate(epochs_all)
        if len(epochs) > 0 and max(epochs) > min_n_epochs
        for j, epoch in enumerate(epochs)
        if epoch >= epoch_start]

# dtype=float (float64): the np.float alias was removed in NumPy 1.24.
X         = np.array([weights_all[i][j].reshape(-1) for i, j in keep], dtype=float)
obs_train = np.array([loss_all[i][j]      for i, j in keep], dtype=float)
obs_test  = np.array([loss_test_all[i][j] for i, j in keep], dtype=float)
# NOTE(review): epoch * 101544324 * 0.01 * prc presumably rescales epochs of
# runs with different data fractions to a common unit (e.g. amount of data
# seen); 101544324 looks like a corpus size — TODO confirm.
e         = np.array([epochs_all[i][j]*101544324*0.01*prc_all[i] for i, j in keep], dtype=float)
r         = np.array([i          for i, _ in keep], dtype=float)  # run index of each row
factor    = np.array([prc_all[i] for i, _ in keep], dtype=float)  # data fraction of each row

# Quantities used to colour the plots and their panel titles; plotting cells
# apply np.log to the colour when the title contains 'log'.
obs, obs_names = [obs_train, obs_test, e], ['loss', 'test loss', 'log(epoch)']

X.shape, obs_train.shape, obs_test.shape, e.shape, r.shape, factor.shape

## 2. PCA of the flattened weights (run after either 1a or 1b)

In [None]:
# PCA of the pooled checkpoint matrix X. The log-spectrum shows how fast the
# variance decays; Xpca holds every checkpoint's coordinates in that basis.
# PCA rejects more components than min(n_samples, n_features), so cap at that.
n_components = min(40, *X.shape)
pca = PCA(n_components=n_components).fit(X)

fig, ax = plt.subplots()
ax.plot(np.log(pca.explained_variance_))
ax.set(xlabel='component', ylabel='log(explained variance)', title='PCA spectrum of X')

U = pca.components_   # one principal direction per row
# NOTE: X is projected without subtracting pca.mean_ (unlike pca.transform).
# This only translates the embedding by a constant offset.
Xpca = X @ U.T
X.shape, U.shape, Xpca.shape

## 3. Visualization along random directions in the PCA subspace

### 3a. CIFAR

In [None]:
# Project every checkpoint onto pairs of random directions (n_iters rows of
# panels) and colour the points by each observable. Trajectory lines are
# black for runs with noise level 0, yellow otherwise.
assert data == 'cifar'

n_iters = 5
SEED = 0   # fixed so the random directions, and hence the figures, are reproducible
rng = np.random.default_rng(SEED)

fig = plt.figure(figsize=(3*10, 9*n_iters), facecolor=(1,1,1))
for it in range(n_iters):

    # Two Gaussian vectors in weight space, expressed by their coordinates
    # along the principal directions (rows of U), then normalised.
    directions = rng.standard_normal((2, U.shape[1])) @ U.T
    directions /= np.linalg.norm(directions, axis=1, keepdims=True)
    Xe = Xpca @ directions.T   # one 2-D point per checkpoint

    for panel, (x, name) in enumerate(zip(obs, obs_names), start=1):
        ax = fig.add_subplot(n_iters, 3, it*3 + panel)

        for run in np.unique(r):
            mask = r == run
            style = 'k-' if noise_all[int(run)] == 0 else 'y-'
            ax.plot(Xe[mask, 0], Xe[mask, 1], style, alpha=0.5)

        colour = np.log(x) if 'log' in name else x
        sc = ax.scatter(Xe[:, 0], Xe[:, 1], c=colour, cmap=plt.cm.Spectral.reversed(), marker='o', alpha=1)
        fig.colorbar(sc, ax=ax)
        ax.set_title(name)

### 3b. WikiText

In [None]:
# Project every checkpoint onto pairs of random directions (n_iters rows of
# panels) and colour the points by each observable.
assert data == 'wiki'

n_iters = 5
SEED = 0                      # fixed so the random directions, and hence the figures, are reproducible
COLOR_BY_PRC = False          # True: colour each run's trajectory by its data fraction and add a legend
LOSS_VMIN, LOSS_VMAX = 2, 14  # shared colour limits for the linear-scale ('loss', 'test loss') panels
rng = np.random.default_rng(SEED)

fig = plt.figure(figsize=(3*10, 9*n_iters), facecolor=(1,1,1))
for it in range(n_iters):

    # Two Gaussian vectors in weight space, expressed by their coordinates
    # along the principal directions (rows of U), then normalised.
    directions = rng.standard_normal((2, U.shape[1])) @ U.T
    directions /= np.linalg.norm(directions, axis=1, keepdims=True)
    Xe = Xpca @ directions.T   # one 2-D point per checkpoint

    for panel, (x, name) in enumerate(zip(obs, obs_names), start=1):
        ax = fig.add_subplot(n_iters, 3, it*3 + panel)

        labelled = set()   # label each data fraction only once in the legend
        for run in np.unique(r):
            mask = r == run
            if COLOR_BY_PRC:
                prc = prc_all[int(run)]
                colour = {0.01: 'r', 0.1: 'm'}.get(prc, 'g')
                label = f'{prc:.0%}' if prc not in labelled else None
                labelled.add(prc)
                ax.plot(Xe[mask, 0], Xe[mask, 1], colour + '-', alpha=0.31, label=label, linewidth=5)
            else:
                ax.plot(Xe[mask, 0], Xe[mask, 1], 'k-', alpha=0.1, linewidth=2)

        if 'log' in name:
            sc = ax.scatter(Xe[:, 0], Xe[:, 1], c=np.log(x), cmap=plt.cm.Spectral.reversed(), marker='o', alpha=1, s=50)
        else:
            sc = ax.scatter(Xe[:, 0], Xe[:, 1], c=x, cmap=plt.cm.Spectral.reversed(), marker='o', alpha=1, s=50,
                            vmin=LOSS_VMIN, vmax=LOSS_VMAX)
        fig.colorbar(sc, ax=ax)
        ax.set_title(name)
        if COLOR_BY_PRC:
            ax.legend()

## 4. Visualization based on t-SNE (CIFAR & WikiText)

In [None]:
# 2-D t-SNE embedding of the PCA coordinates of every checkpoint, one panel
# per observable.
TSNE_SEED = 0              # t-SNE is stochastic; fix the seed for reproducible figures
DRAW_TRAJECTORIES = False  # True: connect the checkpoints of each run with a line

# scikit-learn requires perplexity < n_samples.
perplexity = min(10, Xpca.shape[0] - 1)
Xe = TSNE(n_components=2, perplexity=perplexity, random_state=TSNE_SEED).fit_transform(Xpca)

fig = plt.figure(figsize=(3*8, 6))
for panel, (x, name) in enumerate(zip(obs, obs_names), start=1):
    ax = fig.add_subplot(1, 3, panel)

    if DRAW_TRAJECTORIES:
        for run in np.unique(r):
            mask = r == run
            # On CIFAR, runs with a non-zero noise level are drawn in yellow.
            noisy = data == 'cifar' and noise_all[int(run)] != 0
            ax.plot(Xe[mask, 0], Xe[mask, 1], 'y-' if noisy else 'k-', alpha=0.5)

    colour = np.log(x) if 'log' in name else x
    sc = ax.scatter(Xe[:, 0], Xe[:, 1], c=colour, cmap=plt.cm.Spectral.reversed(), marker='o', alpha=1)
    fig.colorbar(sc, ax=ax)
    ax.set_title(name)

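## 5. Visualization based on principal components (CIFAR & WikiText)

The first figure uses the top-2 principal directions of all runs pooled; each following figure uses the directions of a single run.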

In [None]:
# Project every checkpoint in X onto the top-2 principal directions of
# (run == -1) all selected checkpoints pooled, and then of each run on its
# own. One figure is drawn per choice of directions.
for run in np.arange(-1, len(epochs_all)):

    if run == -1:
        T = X.copy()
        title = 'principal directions of all runs'
    else:
        epochs = epochs_all[run]
        # dtype=float (float64): the np.float alias was removed in NumPy 1.24.
        T = np.array([weights_all[run][j].reshape(-1) for j, epoch in enumerate(epochs)
                      if max(epochs) > min_n_epochs and epoch >= epoch_start], dtype=float)
        title = f'principal directions of run {run}'

    # Two principal directions need at least two checkpoints.
    if T.shape[0] < 2: continue

    # The top right-singular vectors of the centred T are the top eigenvectors
    # of T.T @ T. The thin SVD avoids forming that n_features x n_features
    # matrix, which is prohibitive for large layers. Signs are arbitrary in
    # both methods.
    T -= T.mean(axis=0)
    _, _, vt = np.linalg.svd(T, full_matrices=False)
    v1, v2 = vt[0], vt[1]

    Xe = X @ np.stack([v1, v2], axis=1)   # uncentred projection, one 2-D point per checkpoint

    fig = plt.figure(figsize=(3*8, 6))
    fig.suptitle(title)
    for panel, (x, name) in enumerate(zip(obs, obs_names), start=1):
        ax = fig.add_subplot(1, 3, panel)
        for run_id in np.unique(r):
            mask = r == run_id
            ax.plot(Xe[mask, 0], Xe[mask, 1], 'k-', alpha=0.4)
        colour = np.log(x) if 'log' in name else x
        sc = ax.scatter(Xe[:, 0], Xe[:, 1], c=colour, cmap=plt.cm.Spectral.reversed(), marker='o', alpha=1)
        fig.colorbar(sc, ax=ax)
        ax.set_title(name)

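## 6. How each run moves relative to its mean weights (CIFAR)

Distance of every checkpoint from the mean of its own run's checkpoints, plotted against epoch.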

In [None]:
# For every run, plot the distance of each selected checkpoint from the mean
# of that run's checkpoints against epoch. The three panels colour the points
# by log(loss), by the checkpoint's 'accuracy' metric and by epoch.
# accuracy_all is produced only by the CIFAR loader (1a).
assert data == 'cifar'

fig, axes = plt.subplots(1, 3, figsize=(3*8, 6))

# Panel title -> fixed (vmin, vmax) colour limits. They are shared by all runs,
# so one colour bar per panel is valid for every run drawn in it.
panel_limits = {'log(loss)': (-3, -1), 'accuracy': (0.94, 0.97), 'epoch': (1000, 5000)}
last_scatter = {}

for run in range(len(epochs_all)):

    epochs = epochs_all[run]
    sel = [j for j, epoch in enumerate(epochs) if max(epochs) > min_n_epochs and epoch >= epoch_start]
    if not sel: continue

    # dtype=float (float64): the np.float alias was removed in NumPy 1.24.
    T          = np.array([weights_all[run][j].reshape(-1) for j in sel], dtype=float)
    run_epochs = np.array([epochs[j]                       for j in sel], dtype=float)
    run_loss   = np.array([loss_all[run][j]                for j in sel], dtype=float)
    # NOTE(review): checkpoint key 'accuracy' — presumably training accuracy; confirm.
    run_acc    = np.array([accuracy_all[run][j]            for j in sel], dtype=float)

    dist = np.linalg.norm(T - T.mean(axis=0), axis=1)

    colours = {'log(loss)': np.log(run_loss), 'accuracy': run_acc, 'epoch': run_epochs}
    for ax, (name, (vmin, vmax)) in zip(axes, panel_limits.items()):
        ax.plot(run_epochs, dist, 'k-', alpha=0.4)
        last_scatter[name] = ax.scatter(run_epochs, dist, c=colours[name], cmap=plt.cm.Spectral.reversed(),
                                        marker='o', alpha=1, vmin=vmin, vmax=vmax)

# Colour bars and labels once, after all runs. This works even when the last
# run was skipped for having no selected checkpoints.
for ax, name in zip(axes, panel_limits):
    if name in last_scatter:
        fig.colorbar(last_scatter[name], ax=ax)
    ax.set(title=name, xlabel='epoch', ylabel='distance to run mean')

## Scratch space

In [None]:
# Scratch: Euclidean norm of the flattened `focus_on` tensor at every
# checkpoint, one line per run.
fig, ax = plt.subplots()
for i, epochs in enumerate(epochs_all):
    norm = np.array([np.linalg.norm(w.reshape(-1)) for w in weights_all[i]])
    ax.plot(epochs, norm)
ax.set(xlabel='epoch', ylabel=f'||{focus_on}||', title='weight norm during training')

In [None]:
# Scratch check: shape of the checkpoint matrix `T` left behind by whichever
# earlier cell assigned it most recently (the principal-direction or the
# distance-to-mean cell). This is hidden kernel state: it raises NameError on a
# fresh kernel if neither of those cells has run.
T.shape