In [None]:
%%capture
!pip install torchdata==0.4.1

In [None]:
from os import path, listdir, makedirs

import torch

from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

from datasets import get_data
from models import init_model

# Arguments

In [None]:
# Root of the checkpoint tree, laid out as
# <CHECKPOINTS_DIR>/<dataset>/<seed>/<optimizer>/ck-<i>.pt
CHECKPOINTS_DIR = '.'
# Output root; each dataset's figures are written to <PATHS_DIR>/<dataset>/seed_<seed>.pdf
PATHS_DIR = '.'
# Checkpoints ck-0 .. ck-NUM_STEPS (inclusive) are loaded for every (seed, optimizer) run.
NUM_STEPS = 60
# One line colour per optimizer, using matplotlib's default colour-cycle names.
COLORS = [f'C{k}' for k in range(3)]

# Function and Constant Definitions

In [None]:
# Preferred torch device for this session: the GPU when CUDA is available, else the CPU.
# NOTE(review): not read by any cell visible here; checkpoints are loaded with
# load_checkpoint's default device='cpu'.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
def get_model_params(model):
    """
    Flatten every parameter of ``model`` into one 1-D tensor.

    Parameters are visited in ``model.parameters()`` order; each one is
    flattened and the pieces are concatenated, so the result has one entry
    per scalar parameter.
    """
    return torch.cat([tensor.flatten() for tensor in model.parameters()])

In [None]:
def load_checkpoint(model, dir, i, device='cpu'):
    """
    Load checkpoint number ``i`` from directory ``dir`` into ``model`` in place.

    The file read is ``<dir>/ck-<i>.pt`` and must hold a state dict accepted
    by ``model.load_state_dict``; tensors are mapped to ``device`` while
    loading. The same ``model`` object is returned so calls can be chained.

    NOTE: ``torch.load`` is pickle-based (unless a weights-only restriction is
    active in the installed torch version) — only load trusted checkpoints.
    """
    checkpoint_file = path.join(dir, f'ck-{i}.pt')
    state_dict = torch.load(checkpoint_file, map_location=device)
    model.load_state_dict(state_dict)
    return model

In [None]:
def plot_save_paths(params, inits, optimizers, paths_dir='paths'):
    """
    Plot one figure per init and save it as ``<paths_dir>/seed_<init>.pdf``.

    ``params`` is a 2-D array-like (numpy array or torch tensor) with exactly
    two columns. Its rows must be grouped init-major, then optimizer, then
    step: the rows for ``inits[i]`` / ``optimizers[j]`` are the contiguous
    slice starting at ``(i * len(optimizers) + j) * L`` of length ``L``, where
    ``L = len(params) // (len(inits) * len(optimizers))`` is the common path
    length. The first point of every path is marked with a dot.

    Parameters
    ----------
    params : 2-D array-like of shape (len(inits) * len(optimizers) * L, 2)
        Points of all paths, laid out as described above.
    inits : list
        One figure per entry; ``init`` is formatted into the file name and title.
    optimizers : list
        One path per entry in every figure; ``optimizer`` is the legend label.
    paths_dir : str
        Output directory, created if missing.

    Raises
    ------
    ValueError
        If ``params`` is not 2-D with two columns, if ``inits`` or
        ``optimizers`` is empty, or if the rows cannot be split into
        non-empty paths of equal length.
    """
    # Validate everything before touching the filesystem.
    if params.ndim != 2:
        raise ValueError(f'Expected a 2-D array of points, got shape {tuple(params.shape)}.')
    if params.shape[1] != 2:
        raise ValueError(f'Got dim={params.shape[1]}. Cannot visualize when dim!=2.')

    num_groups = len(inits) * len(optimizers)
    if num_groups == 0:
        raise ValueError('inits and optimizers must both be non-empty.')
    path_len, remainder = divmod(params.shape[0], num_groups)
    if remainder or path_len == 0:
        raise ValueError(
            f'Cannot split {params.shape[0]} points into {num_groups} non-empty paths of equal length.'
        )

    makedirs(paths_dir, exist_ok=True)

    for i, init in enumerate(inits):
        # A fresh figure per init: under backends where plt.show() does not
        # close the current figure, lines from earlier inits would pile up.
        fig, ax = plt.subplots()
        for j, optimizer in enumerate(optimizers):
            start = (i * len(optimizers) + j) * path_len
            current_path = params[start:start + path_len]
            # Cycle through COLORS so more optimizers than colours does not raise IndexError.
            color = COLORS[j % len(COLORS)]
            ax.plot(current_path[0, 0], current_path[0, 1], marker='o', color=color)  # Mark init weights
            ax.plot(current_path[:, 0], current_path[:, 1], color=color, label=f'{optimizer}')

        ax.set_title(f'seed {init}')
        ax.legend()
        fig.savefig(path.join(paths_dir, f'seed_{init}.pdf'))
        plt.show()
        # Release the figure; many datasets x seeds would otherwise keep figures alive.
        plt.close(fig)

In [None]:
# Each non-hidden sub-directory of CHECKPOINTS_DIR is treated as one dataset.
# Plain files (typically including this notebook when CHECKPOINTS_DIR is '.')
# and hidden folders such as .ipynb_checkpoints are skipped, because every
# name kept here is later passed to get_data / init_model as a dataset name.
dataset_names = sorted(
    name for name in listdir(CHECKPOINTS_DIR)
    if not name.startswith('.') and path.isdir(path.join(CHECKPOINTS_DIR, name))
)

In [None]:
# t-SNE is re-fit from scratch for every dataset, so one configured instance is
# shared across the loop. random_state makes re-runs reproduce the same figures.
tsne = TSNE(n_components=2, verbose=0, init='pca', learning_rate='auto', perplexity=10, random_state=0)

# Upper bound on the number of PCA components computed before t-SNE.
MAX_PCA_COMPONENTS = 100


def _list_subdirs(parent):
    """Return the sorted names of the non-hidden sub-directories of ``parent``."""
    return sorted(
        name for name in listdir(parent)
        if not name.startswith('.') and path.isdir(path.join(parent, name))
    )


for dataset_name in dataset_names:
    dataset_dir = path.join(CHECKPOINTS_DIR, dataset_name)

    # NOTE(review): the return values are discarded; presumably this call is kept
    # for a side effect (e.g. making the dataset available) — confirm or remove.
    _, _ = get_data(dataset_name)
    model = init_model(dataset_name)

    seeds = _list_subdirs(dataset_dir)

    # Snapshots are appended in (seed, optimizer, step) order — exactly the row
    # layout plot_save_paths uses to slice the embedding back into paths.
    param_snapshots = []
    optimizers = None

    for seed in seeds:
        seed_dir = path.join(dataset_dir, seed)
        seed_optimizers = _list_subdirs(seed_dir)

        # plot_save_paths receives a single optimizer list for all seeds; if the
        # sets differed, the paths would be silently mis-sliced.
        if optimizers is None:
            optimizers = seed_optimizers
        elif seed_optimizers != optimizers:
            raise ValueError(
                f'{dataset_name}: seed {seed!r} has optimizers {seed_optimizers}, '
                f'expected {optimizers}.'
            )

        for optimizer in seed_optimizers:
            optimizer_dir = path.join(seed_dir, optimizer)
            # Add checkpoints ck-0 .. ck-NUM_STEPS. The model is overwritten in place
            # each time, but get_model_params returns a freshly concatenated tensor.
            param_snapshots.extend(
                get_model_params(load_checkpoint(model, optimizer_dir, i))
                for i in range(NUM_STEPS + 1)
            )

    if not param_snapshots:
        print(f'No checkpoints found for {dataset_name}; skipping.')
        continue

    # One row per checkpoint, one column per scalar parameter.
    params_matrix = torch.stack(param_snapshots).detach().cpu().numpy()

    # Standardize each parameter (column) to zero mean / unit variance.
    scaler = StandardScaler()
    all_params = scaler.fit_transform(params_matrix)

    # PCA cannot request more components than min(n_samples, n_features).
    pca = PCA(n_components=min(MAX_PCA_COMPONENTS, *all_params.shape))
    pca_params = pca.fit_transform(all_params)
    print(f'PCA transformation done. Paths are in the shape: {pca_params.shape}')

    transformed_params = tsne.fit_transform(pca_params)
    print(f'TSNE transformation done. Paths are in the shape: {transformed_params.shape}')

    plot_save_paths(transformed_params, seeds, optimizers, paths_dir=path.join(PATHS_DIR, dataset_name))