## Imports

In [1]:
import copy
import logging
from pathlib import Path
from typing import Dict
import math

import hydra
import matplotlib
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import omegaconf
import seaborn as sns
import torch  # noqa
import wandb
from hydra.utils import instantiate
from matplotlib import tri
from matplotlib.offsetbox import AnnotationBbox, OffsetImage
from omegaconf import DictConfig
from pytorch_lightning import LightningModule
from scipy.stats import qmc
from torch.utils.data import DataLoader
from tqdm import tqdm
from ccmm.matching.utils import perm_indices_to_perm_matrix
from ccmm.utils.utils import normalize_unit_norm, project_onto

from nn_core.callbacks import NNTemplateCore
from nn_core.common import PROJECT_ROOT
from nn_core.common.utils import seed_index_everything
from nn_core.model_logging import NNLogger
from ccmm.utils.utils import fuse_batch_norm_into_conv
from torch.utils.data import DataLoader, Subset, SubsetRandomSampler

import ccmm  # noqa
from ccmm.matching.utils import (
    apply_permutation_to_statedict,
    get_all_symbols_combinations,
    plot_permutation_history_animation,
    restore_original_weights,
)
from ccmm.utils.utils import (
    linear_interpolate_state_dicts,
    load_model_from_info,
    map_model_seed_to_symbol,
    save_factored_permutations,
)

from ccmm.matching.utils import load_permutations

from ccmm.utils.utils import vector_to_state_dict
import pytorch_lightning

  from .autonotebook import tqdm as notebook_tqdm
  warn(


In [2]:
matplotlib.rcParams["font.family"] = "serif"
sns.set_context("talk")
matplotlib.rcParams["text.usetex"] = True
cmap_name = "coolwarm_r"

logging.getLogger("lightning.pytorch").setLevel(logging.WARNING)
logging.getLogger("torch").setLevel(logging.WARNING)
logging.getLogger("pytorch_lightning.accelerators.cuda").setLevel(logging.WARNING)
pylogger = logging.getLogger(__name__)

## Configuration

In [3]:
%load_ext autoreload
%autoreload 2

import hydra
from hydra import initialize, compose
from typing import Dict, List

hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(version_base=None, config_path=str("../conf"), job_name="matching_n_models")

hydra.initialize()

In [4]:
cfg = compose(config_name="matching_n_models", overrides=[])

In [5]:
core_cfg = cfg  # NOQA
cfg = cfg.matching

seed_index_everything(cfg)

1608637542

## Hyperparameters

In [6]:
num_test_samples = 5000
num_train_samples = 5000

## Load dataset

In [7]:
transform = instantiate(core_cfg.dataset.test.transform)

train_dataset = instantiate(core_cfg.dataset.train, transform=transform)
test_dataset = instantiate(core_cfg.dataset.test, transform=transform)

train_subset = Subset(train_dataset, list(range(num_train_samples)))
train_loader = DataLoader(train_subset, batch_size=5000, num_workers=cfg.num_workers)

test_subset = Subset(test_dataset, list(range(num_test_samples)))

test_loader = DataLoader(test_subset, batch_size=1000, num_workers=cfg.num_workers)

Files already downloaded and verified
Files already downloaded and verified


In [8]:
trainer = instantiate(cfg.trainer, enable_progress_bar=False, enable_model_summary=False)

  rank_zero_deprecation(


## Load models

In [9]:
from ccmm.utils.utils import load_model_from_artifact

run = wandb.init(project=core_cfg.core.project_name, entity=core_cfg.core.entity, job_type="matching")

# {a: 1, b: 2, c: 3, ..}
symbols_to_seed: Dict[int, str] = {map_model_seed_to_symbol(seed): seed for seed in cfg.model_seeds}

artifact_path = (
    lambda seed: f"{core_cfg.core.entity}/{core_cfg.core.project_name}/{core_cfg.model.model_identifier}_{seed}:v0"
)

# {a: model_a, b: model_b, c: model_c, ..}
models: Dict[str, LightningModule] = {
    map_model_seed_to_symbol(seed): load_model_from_artifact(run, artifact_path(seed)) for seed in cfg.model_seeds
}

num_models = len(models)

pylogger.info(f"Using {num_models} models with architecture {core_cfg.model.model_identifier}")

[34m[1mwandb[0m: Currently logged in as: [33mdbaieri[0m ([33mgladia[0m). Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Downloading large artifact VGG16_1:v0, 54.17MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.4
[34m[1mwandb[0m: Downloading large artifact VGG16_2:v0, 54.17MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.4
[34m[1mwandb[0m: Downloading large artifact VGG16_3:v0, 54.17MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.4
[34m[1mwandb[0m: Downloading large artifact VGG16_4:v0, 54.17MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.4
[34m[1mwandb[0m: Downloading large artifact VGG16_5:v0, 54.17MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.4


In [10]:
# always permute the model having larger character order, i.e. c -> b, b -> a and so on ...
symbols = set(symbols_to_seed.keys())
sorted_symbols = sorted(symbols, reverse=False)

# (a, b), (a, c), (b, c), ...
all_combinations = get_all_symbols_combinations(symbols)
# combinations of the form (a, b), (a, c), (b, c), .. and not (b, a), (c, a) etc
canonical_combinations = [(source, target) for (source, target) in all_combinations if source < target]

## Matching

In [11]:
pylogger.info(f"Matching the following model pairs: {canonical_combinations}")

### Load permutation specification

In [12]:
permutation_spec_builder = instantiate(core_cfg.model.permutation_spec_builder)
permutation_spec = permutation_spec_builder.create_permutation()

ref_model = list(models.values())[0]
assert set(permutation_spec.layer_and_axes_to_perm.keys()) == set(ref_model.model.state_dict().keys())

In [13]:
matcher = instantiate(cfg.matcher, permutation_spec=permutation_spec)
pylogger.info(f"Matcher: {matcher.name}")

In [14]:
permutations, perm_history = matcher(models, symbols=sorted_symbols, combinations=canonical_combinations)

Weight matching: 100%|██████████| 200/200 [11:38<00:00,  3.49s/it]


In [15]:
models = {symb: model.to("cpu") for symb, model in models.items()}

### Permute models to universe

In [16]:
from ccmm.matching.utils import perm_matrix_to_perm_indices

models_permuted_to_universe = {symbol: copy.deepcopy(model) for symbol, model in models.items()}

for symbol, model in models_permuted_to_universe.items():
    perms_to_universe = {}

    for perm_name, perm in permutations[symbol].items():
        perm = perm_indices_to_perm_matrix(perm)
        perm_to_universe = perm.T
        perm_to_universe = perm_matrix_to_perm_indices(perm_to_universe)
        perms_to_universe[perm_name] = perm_to_universe

    permuted_params = apply_permutation_to_statedict(permutation_spec, perms_to_universe, model.model.state_dict())
    models_permuted_to_universe[symbol].model.load_state_dict(permuted_params)

### Permute models pairwise

In [17]:
from ccmm.matching.utils import unfactor_permutations

models_permuted_pairwise = {
    symbol: {other_symb: None for other_symb in set(symbols).difference(symbol)} for symbol in symbols
}
pairwise_permutations = unfactor_permutations(permutations)

for fixed, permutee in all_combinations:
    ref_model = copy.deepcopy(models["a"])

    permuted_params = apply_permutation_to_statedict(
        permutation_spec, pairwise_permutations[fixed][permutee], models[permutee].model.state_dict()
    )

    ref_model.model.load_state_dict(permuted_params)
    models_permuted_pairwise[fixed][permutee] = ref_model

  permutations[symbol][perm_name] = torch.tensor(perm)


### Check performance of models before and after permutation

In [18]:
for symbol, model in models_permuted_to_universe.items():
    trainer.test(models_permuted_to_universe[symbol], test_loader)
    trainer.test(models[symbol], test_loader)

## Analyze models as vectors

### Flatten models

In [19]:
other_symbs = {symbol: set(symbols).difference(symbol) for symbol in symbols}
print(other_symbs)

{'a': {'c', 'b', 'e', 'd'}, 'c': {'a', 'b', 'e', 'd'}, 'e': {'a', 'c', 'b', 'd'}, 'd': {'a', 'c', 'b', 'e'}, 'b': {'a', 'c', 'e', 'd'}}


In [20]:
flat_models = {symbol: torch.nn.utils.parameters_to_vector(model.parameters()) for symbol, model in models.items()}
flat_models_permuted_to_universe = {
    symbol: torch.nn.utils.parameters_to_vector(model.parameters())
    for symbol, model in models_permuted_to_universe.items()
}

flat_models_permuted_pairwise = {
    symbol: {
        other_symb: torch.nn.utils.parameters_to_vector(models_permuted_pairwise[symbol][other_symb].parameters())
        for other_symb in other_symbs[symbol]
    }
    for symbol in symbols
}

In [21]:
flat_models_permuted_pairwise["e"].keys()

dict_keys(['a', 'c', 'b', 'd'])

In [22]:
for symbol in symbols:
    flat_models_permuted_pairwise[symbol][symbol] = flat_models[symbol]
    models_permuted_pairwise[symbol][symbol] = models[symbol]

## Interpolation curves

In [23]:
def linear_interpolation(model_a, model_b, lamb):
    return (1 - lamb) * model_a + lamb * model_b


def get_interp_curve(lambdas, model_a, model_b, ref_model):

    interp_losses = []
    interp_accs = []

    for lamb in lambdas:
        interp_params = linear_interpolation(model_a=model_a, model_b=model_b, lamb=lamb)

        interp_params = vector_to_state_dict(interp_params, ref_model.model)

        ref_model.model.load_state_dict(interp_params)
        results = trainer.test(ref_model, test_loader, verbose=False)

        interp_losses.append(results[0][f"loss/test"])
        interp_accs.append(results[0][f"acc/test"])

    return interp_losses, interp_accs

In [24]:
lambdas = np.linspace(0, 1, 3)
loss_curve_a_univ_c_univ, acc_curve_a_univ_c_univ = get_interp_curve(
    lambdas=lambdas,
    model_a=flat_models_permuted_to_universe["a"],
    model_b=flat_models_permuted_to_universe["c"],
    ref_model=ref_model,
)
loss_curve_a_univ_d_univ, acc_curve_a_univ_d_univ = get_interp_curve(
    lambdas=lambdas,
    model_a=flat_models_permuted_to_universe["a"],
    model_b=flat_models_permuted_to_universe["d"],
    ref_model=ref_model,
)
loss_curve_a_perm_c, acc_curve_a_perm_c = get_interp_curve(
    lambdas=lambdas, model_a=flat_models["a"], model_b=flat_models_permuted_pairwise["a"]["c"], ref_model=ref_model
)
loss_curve_a_b, acc_curve_a_b = get_interp_curve(
    lambdas=lambdas, model_a=flat_models["a"], model_b=flat_models["c"], ref_model=ref_model
)

In [25]:
# plt.plot(lambdas, loss_curve_a_b, label=r"$A, B$")
# plt.plot(lambdas, loss_curve_a_univ_c_univ, label=r"$P_{A}^\top (A), P_{C}^\top (C)$")
# plt.plot(lambdas, loss_curve_a_univ_d_univ, label=r"$P_{A}^\top (A), P_{D}^\top (D)$")
# plt.plot(lambdas, loss_curve_a_perm_c, label=r"$A, P_{A} P_{C}^\top (C)$")
# plt.legend()

In [26]:
# plt.plot(lambdas, acc_curve_a_b, label=r"$A, B$")
# plt.plot(lambdas, acc_curve_a_univ_c_univ, label=r"$P_{A}^\top (A), P_{C}^\top (C)$")
# plt.plot(lambdas, acc_curve_a_univ_d_univ, label=r"$P_{A}^\top (A), P_{D}^\top (D)$")
# plt.plot(lambdas, acc_curve_a_perm_c, label=r"$A, P_{A} P_{C}^\top (C)$")
# plt.legend()

In [27]:
'''
from ccmm.utils.utils import average_models


for ref_symbol in symbols:

    ref_model = copy.deepcopy(models[ref_symbol])

    all_models_permuted_to_ref = {symb: models_permuted_pairwise[ref_symbol][symb] for symb in symbols}

    model_params = {symbol: model.state_dict() for symbol, model in all_models_permuted_to_ref.items()}

    mean_params = average_models(model_params)

    ref_model.load_state_dict(mean_params)

    results = trainer.test(ref_model, test_loader, verbose=True)
'''

'\nfrom ccmm.utils.utils import average_models\n\n\nfor ref_symbol in symbols:\n\n    ref_model = copy.deepcopy(models[ref_symbol])\n\n    all_models_permuted_to_ref = {symb: models_permuted_pairwise[ref_symbol][symb] for symb in symbols}\n\n    model_params = {symbol: model.state_dict() for symbol, model in all_models_permuted_to_ref.items()}\n\n    mean_params = average_models(model_params)\n\n    ref_model.load_state_dict(mean_params)\n\n    results = trainer.test(ref_model, test_loader, verbose=True)\n'

In [28]:
%%capture

from ccmm.matching.weight_matching import PermutationSpec, weight_matching
from ccmm.matching.utils import get_inverse_permutations


pairwise_perms_gitrebasin = {
 symb: {other_symb: None for other_symb in set(symbols).difference(symb)} for symb in symbols   
}

for fixed, permutee in canonical_combinations:
    permutation = weight_matching(
        permutation_spec,
        fixed=models[fixed].model.state_dict(),
        permutee=models[permutee].model.state_dict(),
    )
    pairwise_perms_gitrebasin[fixed][permutee] = permutation
    pairwise_perms_gitrebasin[permutee][fixed] = get_inverse_permutations(permutation)
    

In [30]:
def cyclic_permute(pairwise_perms, symbols, models):
    ordered_symbs = sorted(list(symbols))
    model_current = models[ordered_symbs[0]].model.state_dict()
    for i, symb in enumerate(ordered_symbs[1:] + [ordered_symbs[0]]):
        # print("next: {} -- prev: {}".format(symb, ordered_symbs[i]))
        permutation = pairwise_perms[symb][ordered_symbs[i]]
        model_current = apply_permutation_to_statedict(
            permutation_spec, permutation, model_current
        )
    return model_current

In [31]:

def evaluate_interpolated_model(lambd, model_a, model_b, ref_model):

    interp_params = linear_interpolate_state_dicts(
        t1=model_a.model.state_dict(), t2=model_b.model.state_dict(), lam=lambd
    )

    ref_model.model.load_state_dict(interp_params)

    test_results = trainer.test(ref_model, test_loader, verbose=False)[0]

    return test_results

def get_interp_curve(lambdas, model_a, model_b, ref_model):

    interp_losses = []
    interp_accs = []

    for lamb in lambdas:
        interp_results = evaluate_interpolated_model(model_a=model_a, model_b=model_b, lambd=lamb, ref_model=ref_model)

        interp_losses.append(interp_results[f"loss/test"])
        interp_accs.append(interp_results[f"acc/test"])

    return interp_losses, interp_accs

def l2_norm_models(state_dict1, state_dict2):
    """Calculate the L2 norm of the difference between two state dictionaries."""
    diff_squared_sum = sum(torch.sum((state_dict1[key] - state_dict2[key]) ** 2) for key in state_dict1)
    return torch.sqrt(diff_squared_sum)


In [32]:
import itertools

n = 4
steps = 15
cycles = list(itertools.combinations(list(symbols)[:-1], n-1))

output = {}

for c in cycles:
    print(f"Cycle: {c}")
    key = ''.join(c)
    model_c = cyclic_permute(pairwise_perms_gitrebasin, list(c), models)
    ordered_cycle = sorted(list(c))

    initial_model = models[ordered_cycle[0]]
    permuted_model = copy.deepcopy(initial_model)
    permuted_model.model.load_state_dict(model_c)

    print(f"Model distance: {l2_norm_models(model_c, initial_model.model.state_dict())}")

    lambdas = torch.linspace(0, 1, steps, device='cpu')
    losses, accs = get_interp_curve(lambdas, initial_model, permuted_model, ref_model)

    output[key] = {'x': lambdas.numpy(), 'loss': np.array(losses), 'acc': np.array(accs)}

np.savez('../results/git_rebasin_subcycles.npz', **output)
    

Cycle: ('a', 'c', 'e')
Model distance: 41.07636642456055


Cycle: ('a', 'c', 'd')
Model distance: 41.22671127319336


Cycle: ('a', 'e', 'd')
Model distance: 41.2719612121582


Cycle: ('c', 'e', 'd')
Model distance: 40.73612976074219


FileNotFoundError: [Errno 2] No such file or directory: './results/git_rebasin_subcycles.npz'

In [35]:
output = {}

for c in cycles:
    print(f"Cycle: {c}")
    key = ''.join(c)
    model_c = cyclic_permute(pairwise_permutations, list(c), models)
    ordered_cycle = sorted(list(c))

    initial_model = models[ordered_cycle[0]]
    permuted_model = copy.deepcopy(initial_model)
    permuted_model.model.load_state_dict(model_c)

    print(f"Model distance: {l2_norm_models(model_c, initial_model.model.state_dict())}")

    lambdas = torch.linspace(0, 1, steps, device='cpu')
    losses, accs = get_interp_curve(lambdas, initial_model, permuted_model, ref_model)

    output[key] = {'x': lambdas.numpy(), 'loss': np.array(losses), 'acc': np.array(accs)}

np.savez('../results/frank_wolfe_subcycles.npz', **output)

Cycle: ('a', 'c', 'e')
Model distance: 0.0


  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


TypeError: 'NoneType' object is not subscriptable

In [76]:
'''
ordered_symbs = sorted(list(symbols))
model_current = models[ordered_symbs[0]].model.state_dict()

for i, symb in enumerate(ordered_symbs[1:] + [ordered_symbs[0]]):
    print("next: {} -- prev: {}".format(symb, ordered_symbs[i]))
    permutation = pairwise_perms_gitrebasin[symb][ordered_symbs[i]]
    model_current = apply_permutation_to_statedict(
        permutation_spec, permutation, model_current
    )
'''
model_current = cyclic_permute(pairwise_perms_gitrebasin, symbols, models)

next: b -- prev: a
next: c -- prev: b
next: d -- prev: c
next: e -- prev: d
next: a -- prev: e


In [77]:
permuted_model = copy.deepcopy(models[ordered_symbs[0]])
permuted_model.model.load_state_dict(model_current)
initial_model = models[ordered_symbs[0]]

print(f"Model distance: {l2_norm_models(permuted_model.model.state_dict(), initial_model.model.state_dict())}")

lambdas = torch.linspace(0, 1, 5, device='cpu')
losses, accs = get_interp_curve(lambdas, initial_model, permuted_model, ref_model)

print(lambdas)
print(losses)
print(accs)

Model distance: 41.915000915527344


tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])
[0.46679627895355225, 0.4383860230445862, 2.2594735622406006, 0.45566055178642273, 0.46679627895355225]
[0.9043999910354614, 0.8659999966621399, 0.1273999959230423, 0.8679999709129333, 0.9043999910354614]


In [63]:
errs=[]
# errs2=[]
for v in pairwise_permutations['a']['b'].keys():
    P1=perm_indices_to_perm_matrix(pairwise_permutations['a']['b'][v])
    P2=perm_indices_to_perm_matrix(pairwise_permutations['b']['c'][v])
    P3=perm_indices_to_perm_matrix(pairwise_permutations['c']['d'][v])
    P4=perm_indices_to_perm_matrix(pairwise_permutations['d']['e'][v])
    P5=perm_indices_to_perm_matrix(pairwise_permutations['e']['a'][v])

    # P5=perm_indices_to_perm_matrix(pairwise_permutations['c']['a'][v])
    # P6=perm_indices_to_perm_matrix(pairwise_permutations['c']['b'][v])
    #P7=perm_indices_to_perm_matrix(pairwise_permutations['a']['c'][v])
    errs.append(((P1@P2@P3@P4@P5)-torch.eye(P4.shape[0])).abs().sum())
    # errs2.append(((P5@P1)-P6).abs().sum())

print(errs) # , print(errs2)

[tensor(0.), tensor(0.), tensor(0.), tensor(0.), tensor(0.), tensor(0.), tensor(0.), tensor(0.), tensor(0.), tensor(0.), tensor(0.), tensor(0.), tensor(0.), tensor(0.), tensor(0.)]
