# Setup

In [None]:
# Run nvidia-smi to report the GPU (if any) attached to this runtime.
!nvidia-smi

In [None]:
# Clone the repo into the current working directory (Colab). The `test -d`
# guard skips the clone when a previous run already created the checkout, so
# the cell can be re-run without git erroring on an existing directory.
!test -d mode_connectivity || git clone https://github.com/char-tan/mode_connectivity.git

# Check out the branch this notebook relies on. `&&` (rather than `;`) ensures
# git only runs when the cd into the checkout succeeded.
!cd mode_connectivity && git checkout resnet_perm_spec

In [None]:
# Load IPython's autoreload extension in mode 2 so that edits to imported
# .py files are picked up on the next cell execution without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys

# SCRIPT_DIR is the parent of the current working directory; the directory one
# level above that is appended to the import search path.
# NOTE(review): presumably this is meant for running the notebook from a
# sub-folder of the repo so the `mode_connectivity` package resolves - confirm.
SCRIPT_DIR = os.path.dirname(os.path.abspath('.'))
_search_root = os.path.dirname(SCRIPT_DIR)
sys.path.append(_search_root)

from mode_connectivity.training import *
from mode_connectivity.lmc import *
from mode_connectivity.training_config import *

from mode_connectivity.models.mlp import MLP
from mode_connectivity.models.vgg import VGG
from mode_connectivity.models.resnet import ResNet

from mode_connectivity.utils import weight_matching, data, training_utils, plot, utils

import torch
from torchvision import transforms
import copy
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np

# ResNet LMC

In [None]:
# Fetch the two ResNet (width multiplier 2) checkpoints from Git LFS; `-I`
# restricts each pull to the single listed file.
!cd mode_connectivity && git lfs pull -I model_files/resnet_wm2_a.pt 
!cd mode_connectivity && git lfs pull -I model_files/resnet_wm2_b.pt

In [None]:
# Number of interpolation points, and the matching evenly spaced values in [0, 1].
n_points = 20
lambdas = torch.linspace(0, 1, steps=n_points)

# Run linear mode connectivity for each width multiplier (currently only 2).
# The unpacked results and `wm` stay in the notebook namespace for later cells.
for wm in [2]:
    lmc_results = linear_mode_connect(
        ResNet,
        {'width_multiplier': wm},
        f'mode_connectivity/model_files/resnet_wm{wm}_a.pt',
        f'mode_connectivity/model_files/resnet_wm{wm}_b.pt',
        'cifar10',
        n_points=n_points,
        verbose=2,
        max_iter=30,
    )
    (
        permuted_params,
        train_acc_naive,
        test_acc_naive,
        train_acc_perm,
        test_acc_perm,
    ) = lmc_results

    # Save the returned permuted parameters next to the original checkpoints.
    torch.save(
        permuted_params,
        f'mode_connectivity/model_files/resnet_wm{wm}_b_permuted.pt',
    )

    # One result per line, in the same order as the unpacking above.
    print(
        train_acc_naive,
        test_acc_naive,
        train_acc_perm,
        test_acc_perm,
        sep='\n',
    )

In [None]:
# Basic linear interpolation plot: accuracy against lambda for the naive and
# permuted results computed by the LMC cell.
lambdas = torch.linspace(0, 1, steps=n_points)
fig = plot.plot_interp_metric(
    "accuracy",
    lambdas,
    train_acc_naive,
    test_acc_naive,
    train_acc_perm,
    test_acc_perm,
)

In [None]:
# Now we generate the contour plot.
# `wm` is the width multiplier left in scope by the `for wm in [...]` loop of
# the LMC cell. The paths must be f-strings, otherwise the literal text
# "{wm}" ends up in the file name and torch.load cannot find the checkpoint.
a_params = torch.load(f"mode_connectivity/model_files/resnet_wm{wm}_a.pt")
b_params = torch.load(f"mode_connectivity/model_files/resnet_wm{wm}_b.pt")

# Flatten the three parameter sets (A, B, permuted B) with the repo helper.
v1, v2, v3 = (
    utils.state_dict_to_numpy_array(p) for p in [a_params, b_params, permuted_params]
)

# The model must be built with the same width multiplier as the checkpoint,
# otherwise load_state_dict raises on mismatched tensor shapes.
model_a = ResNet(width_multiplier=wm)
model_a.load_state_dict(a_params)

# NOTE(review): presumably the plane spanned by the three parameter vectors -
# confirm against utils.generate_orthogonal_basis.
contour_plane = utils.generate_orthogonal_basis(v1, v2, v3)
train_loader, test_loader = data.get_data_loaders(
    dataset="cifar10", train_kwargs={"batch_size":512}, test_kwargs={"batch_size":512}
)
# get_device() returns a pair; only the first element is used here.
device, _ = get_device()
(
    t1s,
    t2s,
    test_acc_grid,
    test_loss_grid,
    train_acc_grid,
    train_loss_grid,
) = utils.generate_loss_landscape_contour(
    model_a, device, train_loader, test_loader, contour_plane, granularity=10
)

In [None]:
# Contour of test accuracy over the (t1s, t2s) grid, with one labelled marker
# per model obtained by passing its vector and the plane to utils.projection.
plot.plot_metric_contour(
    "accuracy",
    t1s,
    t2s,
    test_acc_grid,
    model_vectors_dict={
        label: utils.projection(vector, contour_plane)
        for label, vector in (("A", v1), ("B", v2), ("B permuted", v3))
    },
)

In [None]:
# Persist the test accuracy/loss grids. The names are f-strings so the files
# carry the actual width multiplier (`wm`, left in scope by the LMC loop)
# instead of the literal text "{wm}".
np.save(f'resnet_wm{wm}_test_acc_grid.npy', test_acc_grid)
np.save(f'resnet_wm{wm}_test_loss_grid.npy', test_loss_grid)