# Setup

In [None]:
# Check the GPU: run NVIDIA's `nvidia-smi` command-line tool in a shell so its
# report is shown in the cell output. If the command fails, no NVIDIA GPU/driver
# is visible to this runtime.
!nvidia-smi

In [None]:
# pulls repo to colab
!git clone https://github.com/char-tan/mode_connectivity.git

# checkout specific branch if needed
!cd mode_connectivity; git checkout vgg_perm_spec

In [None]:
# Enable IPython's autoreload extension so that edits to the project's .py
# files are picked up without restarting the kernel. Mode 2 reloads all
# imported modules before each cell is executed.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys

# Make the project importable: SCRIPT_DIR is the parent of the current working
# directory, and *its* parent is appended to sys.path.
# NOTE(review): this presumably targets running the notebook from a sub-folder
# inside the repository; when the repo is cloned into the working directory
# (as the clone cell above does) the import likely resolves via the working
# directory instead — confirm.
notebook_dir = os.path.abspath('.')
SCRIPT_DIR = os.path.dirname(notebook_dir)
package_parent = os.path.dirname(SCRIPT_DIR)
sys.path.append(package_parent)

from mode_connectivity.training import *
from mode_connectivity.lmc import *
from mode_connectivity.training_config import *

from mode_connectivity.models.mlp import MLP
from mode_connectivity.models.vgg import VGG
from mode_connectivity.models.resnet import ResNet

from mode_connectivity.utils import weight_matching, data, training_utils, plot, utils

import torch
from torchvision import transforms
import copy
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np

# VGG LMC

In [None]:
# pull the required model files
!cd mode_connectivity && git lfs pull -I model_files/vgg_wm1_a.pt 
!cd mode_connectivity && git lfs pull -I model_files/vgg_wm1_b.pt
!cd mode_connectivity && git lfs pull -I model_files/vgg_wm2_a.pt 
!cd mode_connectivity && git lfs pull -I model_files/vgg_wm2_b.pt
!cd mode_connectivity && git lfs pull -I model_files/vgg_wm4_a.pt 
!cd mode_connectivity && git lfs pull -I model_files/vgg_wm4_b.pt

In [None]:
# Number of points passed to linear_mode_connect, and the matching evenly
# spaced interpolation coefficients in [0, 1].
n_points = 20
# Not used in this cell; presumably consumed by later plotting cells — confirm.
lambdas = torch.linspace(0, 1, steps=n_points)

# Accuracy curves per width multiplier. Keeping them keyed by `wm` means that
# adding more widths to the loop below no longer overwrites earlier results.
# (The permuted parameters are written to disk instead of being kept here, to
# avoid holding several full parameter sets in memory.)
lmc_results = {}

for wm in [4]:

  # NOTE(review): these two instances are not used anywhere in this cell
  # (linear_mode_connect receives the class, its kwargs and the checkpoint
  # paths). They are kept in case later cells rely on them — confirm.
  model_a = VGG(width_multiplier=wm)
  model_b = VGG(width_multiplier=wm)

  # Run linear mode connectivity between checkpoints "a" and "b" on CIFAR-10.
  # Presumably the *_naive curves are for direct interpolation and the *_perm
  # curves for interpolation after permuting model b — verify against
  # mode_connectivity.lmc.
  (
      permuted_params,
      train_acc_naive,
      test_acc_naive,
      train_acc_perm,
      test_acc_perm,
  ) = linear_mode_connect(
      VGG,
      {'width_multiplier': wm},
      f'mode_connectivity/model_files/vgg_wm{wm}_a.pt',
      f'mode_connectivity/model_files/vgg_wm{wm}_b.pt',
      'cifar10',
      n_points=n_points,
      verbose=2,
      max_iter=30,
  )

  # Persist the permuted parameters next to the original checkpoints.
  torch.save(permuted_params, f'mode_connectivity/model_files/vgg_wm{wm}_b_permuted.pt')

  lmc_results[wm] = {
      'train_acc_naive': train_acc_naive,
      'test_acc_naive': test_acc_naive,
      'train_acc_perm': train_acc_perm,
      'test_acc_perm': test_acc_perm,
  }

  # Labelled output so the four curves can be told apart (and attributed to a
  # width) when reading the cell output.
  for curve_name, curve in lmc_results[wm].items():
    print(f'wm={wm} {curve_name}:', curve)