# Setup

In [1]:
# Print the GPU(s) attached to this runtime along with driver and CUDA versions.
get_ipython().system('nvidia-smi')

Fri Dec 30 21:38:04 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   45C    P0    27W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
# pulls repo to colab
!git clone https://github.com/char-tan/mode_connectivity.git

# checkout specific branch if needed
!cd mode_connectivity; git checkout main

Cloning into 'mode_connectivity'...
remote: Enumerating objects: 402, done.[K
remote: Counting objects: 100% (180/180), done.[K
remote: Compressing objects: 100% (95/95), done.[K
remote: Total 402 (delta 106), reused 125 (delta 85), pack-reused 222[K
Receiving objects: 100% (402/402), 7.12 MiB | 24.45 MiB/s, done.
Resolving deltas: 100% (225/225), done.
Already on 'main'
Your branch is up to date with 'origin/main'.


In [11]:
# Enable IPython's autoreload so edits to the project's .py files are picked up
# by already-imported modules without restarting the kernel.
get_ipython().run_line_magic('load_ext', 'autoreload')
get_ipython().run_line_magic('autoreload', '2')

In [9]:
import sys, os

# Make the `mode_connectivity` package importable by putting the directory that
# contains it on sys.path. The entry added is two levels above the working
# directory, i.e. the repo's parent when this notebook is run from a
# sub-directory of the repo (e.g. mode_connectivity/<subdir>/).
# NOTE(review): when the repo is cloned into the working directory itself (as
# in the clone cell above), the import presumably resolves via the working
# directory instead and this entry is unused — confirm.
SCRIPT_DIR = os.path.dirname(os.path.abspath('.'))
_PACKAGE_PARENT = os.path.dirname(SCRIPT_DIR)
# Guard so re-running the cell does not keep appending duplicate entries.
if _PACKAGE_PARENT not in sys.path:
    sys.path.append(_PACKAGE_PARENT)

from mode_connectivity.training import *
from mode_connectivity.lmc import *
from mode_connectivity.training_config import *

from mode_connectivity.models.mlp import MLP
from mode_connectivity.models.vgg import VGG
from mode_connectivity.models.resnet import ResNet

from mode_connectivity.utils import weight_matching, data, training_utils, plot

import torch
from torchvision import transforms
import copy
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np

# Training

In [5]:
# Train two MLPs on MNIST with the same config, differing only in random seed.
# Their saved weights are the two endpoints used in the LMC section below.
EPOCHS = 1
SEED_A, SEED_B = 7, 42

# Work on a copy: assigning to MLP_MNIST_DEFAULT directly would mutate the
# shared preset, and the epochs/seed overrides would leak into any other cell
# that uses it.
training_config = copy.copy(MLP_MNIST_DEFAULT)
training_config.epochs = EPOCHS


def train_and_save(config, seed, path):
    """Train a model from `config` with `seed` and save its state_dict to `path`.

    Sets `config.seed` in place before training and returns the trained model.
    """
    config.seed = seed
    model = train_model(*setup_train(config), verbose=1)
    torch.save(model.state_dict(), path)
    return model


model_a = train_and_save(training_config, SEED_A, 'model_a.pt')
model_b = train_and_save(training_config, SEED_B, 'model_b.pt')

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



  return nn.functional.log_softmax(x)


Train Epoch: 1, Train Accuracy: (90%) 
Average loss: 0.1298, Accuracy: (96%)
Train Epoch: 1, Train Accuracy: (90%) 
Average loss: 0.1377, Accuracy: (96%)


# Linear mode connectivity (LMC)

In [13]:
# Interpolate between the two saved MLPs, first naively and then after
# permuting model_b's weights, evaluating at 3 points along the path.
linear_mode_connect(
    MLP,
    'model_a.pt',
    'model_b.pt',
    "mnist",
    n_points=3,
)


performing naive interpolation
lam = 0.0, train loss = 0.1224782431602478, test loss = 0.12978225574493407
lam = 0.5, train loss = 1.5429939069112142, test loss = 1.5321542510986328
lam = 1.0, train loss = 0.12981847489674886, test loss = 0.13772883453369142

permuting model
iteration 0 P_0: loss? 49.641780853271484
iteration 0 P_1: loss? 26.808156967163086
iteration 0 P_2: loss? 13.301342010498047
iteration 1 P_0: loss? 3.686370849609375
iteration 1 P_2: loss? 0.0
iteration 1 P_1: loss? 1.620452880859375
iteration 2 P_2: loss? 2.6010475158691406
iteration 2 P_1: loss? 0.3034782409667969
iteration 2 P_0: loss? 0.6556549072265625
iteration 3 P_0: loss? 0.0
iteration 3 P_1: loss? 0.15951919555664062
iteration 3 P_2: loss? 1.1921653747558594
iteration 4 P_1: loss? 0.14936065673828125
iteration 4 P_0: loss? 0.31650543212890625
iteration 4 P_2: loss? 0.5059909820556641
iteration 5 P_0: loss? 0.0
iteration 5 P_2: loss? 0.0
iteration 5 P_1: loss? 0.11322784423828125
iteration 6 P_0: loss? 0.

In [16]:
!cd mode_connectivity && git diff

[1mdiff --git a/lmc.py b/lmc.py[m
[1mindex dd018b1..1c6d8fe 100644[m
[1m--- a/lmc.py[m
[1m+++ b/lmc.py[m
[36m@@ -1,11 +1,11 @@[m
 import torch[m
 import copy[m
 [m
[31m-from utils.data import get_data_loaders[m
[31m-from utils.utils import *[m
[31m-from utils.training_utils import test[m
[31m-from utils.weight_matching import *[m
[31m-from utils.plot import plot_interp_acc[m
[32m+[m[32mfrom .utils.data import get_data_loaders[m
[32m+[m[32mfrom .utils.utils import *[m
[32m+[m[32mfrom .utils.training_utils import test[m
[32m+[m[32mfrom .utils.weight_matching import *[m
[32m+[m[32mfrom .utils.plot import plot_interp_acc[m
 [m
 [m
 def model_interpolation(model_a, model_b, train_loader, test_loader, device, n_points=25):[m
[36m@@ -91,5 +91,3 @@[m [mdef linear_mode_connect([m
 [m
     # interpolate between model_a and permuted model_b[m
     train_acc_perm, test_acc_perm = model_interpolation(model_a, model_b, train_loader, test_loader, de