## Imports

In [None]:
import copy
import logging
from pathlib import Path
from typing import Dict
import math

import hydra
import matplotlib
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import omegaconf
import seaborn as sns
import torch  # noqa
import wandb
from hydra.utils import instantiate
from matplotlib import tri
from matplotlib.offsetbox import AnnotationBbox, OffsetImage
from omegaconf import DictConfig
from pytorch_lightning import LightningModule
from scipy.stats import qmc
from torch.utils.data import DataLoader
from tqdm import tqdm
from ccmm.matching.utils import perm_indices_to_perm_matrix
from ccmm.utils.utils import normalize_unit_norm, project_onto

from nn_core.callbacks import NNTemplateCore
from nn_core.common import PROJECT_ROOT
from nn_core.common.utils import seed_index_everything
from nn_core.model_logging import NNLogger
from ccmm.utils.utils import fuse_batch_norm_into_conv
from torch.utils.data import DataLoader, Subset, SubsetRandomSampler

import ccmm  # noqa
from ccmm.matching.utils import (
    apply_permutation_to_statedict,
    get_all_symbols_combinations,
    plot_permutation_history_animation,
    restore_original_weights,
)
from ccmm.utils.utils import (
    linear_interpolate_state_dicts,
    load_model_from_info,
    map_model_seed_to_symbol,
    save_factored_permutations,
)

from ccmm.matching.utils import load_permutations

from ccmm.utils.utils import vector_to_state_dict
import pytorch_lightning

In [None]:
# Plot styling: serif font family, seaborn "talk" context, LaTeX text rendering.
matplotlib.rcParams["font.family"] = "serif"
sns.set_context("talk")
matplotlib.rcParams["text.usetex"] = True
cmap_name = "coolwarm_r"

# Only let WARNING and above through for the framework loggers listed here.
for _logger_name in ("lightning.pytorch", "torch", "pytorch_lightning.accelerators.cuda"):
    logging.getLogger(_logger_name).setLevel(logging.WARNING)
pylogger = logging.getLogger(__name__)

## Configuration

In [None]:
# IPython autoreload extension in mode 2.
%load_ext autoreload
%autoreload 2

import hydra
from hydra import initialize, compose
from typing import Dict, List

# Clear the global Hydra instance before initialising it against ../conf
# (presumably so this cell can be re-run in the same kernel; confirm).
hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(version_base=None, config_path=str("../conf"), job_name="matching_n_models")

In [None]:
# Compose the "matching_n_tasks" configuration with no overrides.
cfg = compose(config_name="matching_n_tasks", overrides=[])

In [None]:
# Keep the full config reachable as core_cfg; from here on cfg is only its
# `matching` sub-config.
core_cfg = cfg  # NOQA
cfg = cfg.matching

# Project helper called with the matching config (presumably seeds the RNGs; confirm).
seed_index_everything(cfg)

## Hyperparameters

In [None]:
# NOTE(review): neither constant is referenced in the visible part of this
# notebook; presumably they are sample budgets for evaluation. Confirm before relying on them.
num_test_samples = 5000
num_train_samples = 5000

## Load dataset

In [None]:
import pytorch_lightning as pl

# Build the datamodule from the nn.data config; _recursive_=False asks Hydra not
# to instantiate nested configs.
datamodule: pl.LightningDataModule = hydra.utils.instantiate(core_cfg.nn.data, _recursive_=False)

# One test loader and one train loader per task index, stored in task order.
test_dataloaders = []
train_dataloaders = []

# Task indices run from 0 to datamodule.num_tasks inclusive.
for task_ind in range(datamodule.num_tasks + 1):
    # Point the datamodule at this task, give it a freshly instantiated transform
    # and re-run setup() before asking for the loaders.
    datamodule.task_ind = task_ind
    datamodule.transform_func = hydra.utils.instantiate(core_cfg.dataset.transform_func, _recursive_=True)
    datamodule.setup()
    # test_dataloader() is indexed here, so only its first loader is kept.
    test_dataloaders.append(datamodule.test_dataloader()[0])
    train_dataloaders.append(datamodule.train_dataloader())

In [None]:
# Trainer built from the matching config, with the progress bar and the model summary disabled.
trainer = instantiate(cfg.trainer, enable_progress_bar=False, enable_model_summary=False)

## Load models

In [None]:
from ccmm.utils.utils import load_model_from_artifact

run = wandb.init(project=core_cfg.core.project_name, entity=core_cfg.core.entity, job_type="matching")

# Symbol -> seed lookup, e.g. {a: 1, b: 2, c: 3, ..}
symbols_to_seed: Dict[str, int] = {map_model_seed_to_symbol(seed): seed for seed in cfg.model_seeds}


def artifact_path(task):
    """Return the wandb artifact reference (``entity/project/name:v0``) of the model trained on ``task``."""
    return f"{core_cfg.core.entity}/{core_cfg.core.project_name}/{core_cfg.model.model_identifier}_T{task}_{cfg.seed_index}:v0"


# One model per task index in 0..cfg.num_tasks, keyed by the symbol of that index:
# {a: model_a, b: model_b, c: model_c, ..}
models: Dict[str, LightningModule] = {
    map_model_seed_to_symbol(task): load_model_from_artifact(run, artifact_path(task))
    for task in range(cfg.num_tasks + 1)
}

num_models = len(models)

pylogger.info(f"Using {num_models} models with architecture {core_cfg.model.model_identifier}")

In [None]:
# always permute the model having larger character order, i.e. c -> b, b -> a and so on ...
# "o" is the model trained over all the dataset, so it takes no part in the matching
symbols = set(symbols_to_seed) - {"o"}

sorted_symbols = sorted(symbols)

# (a, b), (a, c), (b, c), ...
all_combinations = get_all_symbols_combinations(symbols)
# keep only the pairs whose first symbol sorts before the second:
# (a, b), (a, c), (b, c), .. and not (b, a), (c, a) etc
canonical_combinations = [(first, second) for first, second in all_combinations if first < second]

## Matching

In [None]:
# Log the canonical (source < target) pairs.
pylogger.info(f"Matching the following model pairs: {canonical_combinations}")

### Load permutation specification

In [None]:
# Instantiate the spec builder configured for this architecture and let it
# create the permutation specification.
permutation_spec_builder = instantiate(core_cfg.model.permutation_spec_builder)
permutation_spec = permutation_spec_builder.create_permutation()

# Sanity check against the first loaded model: the spec must name exactly the
# entries of that model's state dict, no more and no fewer.
ref_model = list(models.values())[0]
assert set(permutation_spec.layer_and_axes_to_perm.keys()) == set(ref_model.model.state_dict().keys())

In [None]:
# Build the configured matcher, handing it the permutation specification.
matcher = instantiate(cfg.matcher, permutation_spec=permutation_spec)
pylogger.info(f"Matcher: {matcher.name}")

In [None]:
# Run the matcher on all models. `permutations` is later indexed as
# permutations[symbol][perm_name]; `perm_history` is not used in the visible cells.
permutations, perm_history = matcher(models, symbols=sorted_symbols, combinations=canonical_combinations)

In [None]:
# Move every model to the CPU; the dict is rebuilt with the same keys.
models = {symb: model.to("cpu") for symb, model in models.items()}

### Permute models to universe

In [None]:
from ccmm.matching.utils import perm_matrix_to_perm_indices

# Work on deep copies so the original models stay untouched.
models_permuted_to_universe = {symbol: copy.deepcopy(models[symbol]) for symbol in symbols}

for symbol, model in models_permuted_to_universe.items():
    # Each stored permutation is turned into a matrix, transposed (the inverse of
    # a permutation matrix) and converted back to index form.
    perms_to_universe = {
        perm_name: perm_matrix_to_perm_indices(perm_indices_to_perm_matrix(perm).T)
        for perm_name, perm in permutations[symbol].items()
    }

    # `model` is the deep copy stored under `symbol`, so loading into it updates that entry.
    permuted_params = apply_permutation_to_statedict(permutation_spec, perms_to_universe, model.model.state_dict())
    model.model.load_state_dict(permuted_params)

### Permute models pairwise

In [None]:
from ccmm.matching.utils import unfactor_permutations

# models_permuted_pairwise[fixed][permutee] starts as None for every other symbol.
# NOTE(review): `difference(symbol)` iterates the characters of `symbol`, which only
# equals "every symbol except this one" when symbols are single characters.
models_permuted_pairwise = {
    symbol: dict.fromkeys(set(symbols).difference(symbol)) for symbol in symbols
}
pairwise_permutations = unfactor_permutations(permutations)

for fixed, permutee in all_combinations:
    # Permute the parameters of `permutee` with the permutation stored under [fixed][permutee].
    permuted_params = apply_permutation_to_statedict(
        permutation_spec, pairwise_permutations[fixed][permutee], models[permutee].model.state_dict()
    )

    # A deep copy of model "a" only serves as the container for the permuted parameters.
    ref_model = copy.deepcopy(models["a"])
    ref_model.model.load_state_dict(permuted_params)
    models_permuted_pairwise[fixed][permutee] = ref_model

### Check performance of models before and after permutation

In [None]:
# For every test loader, evaluate each model permuted to the universe and then
# its original counterpart.
for loader in test_dataloaders:
    for symbol, model in models_permuted_to_universe.items():
        trainer.test(model, loader)
        trainer.test(models[symbol], loader)

## Analyze models as vectors

### Flatten models

In [None]:
# Flatten the parameters of each original model into one vector.
flat_models = {name: torch.nn.utils.parameters_to_vector(net.parameters()) for name, net in models.items()}

# Same for the models permuted to the universe.
flat_models_permuted_to_universe = {
    name: torch.nn.utils.parameters_to_vector(net.parameters())
    for name, net in models_permuted_to_universe.items()
}

# Same for the pairwise-permuted models, keeping the two-level [fixed][permutee] layout.
flat_models_permuted_pairwise = {
    fixed: {
        permutee: torch.nn.utils.parameters_to_vector(net.parameters()) for permutee, net in permuted_partners.items()
    }
    for fixed, permuted_partners in models_permuted_pairwise.items()
}

## Interpolation curves

In [None]:
def linear_interpolation(model_a, model_b, lamb):
    """Return ``(1 - lamb) * model_a + lamb * model_b``: ``model_a`` at ``lamb=0``, ``model_b`` at ``lamb=1``."""
    weight_a = 1 - lamb
    return weight_a * model_a + lamb * model_b


def get_interp_curve(lambdas, model_a, model_b, ref_model, test_loader):
    """
    Evaluate the linear interpolation between two flattened models at each value in ``lambdas``.

    For every ``lamb`` the interpolated vector is converted to a state dict, loaded into
    ``ref_model.model`` (which is therefore overwritten) and tested on ``test_loader`` with
    the module-level ``trainer``.

    Returns a pair of lists ``(losses, accuracies)`` holding the "loss/test" and "acc/test"
    entries of the first test result for each lambda, in the order of ``lambdas``.
    """
    losses, accuracies = [], []

    for lamb in lambdas:
        blended = linear_interpolation(model_a=model_a, model_b=model_b, lamb=lamb)
        ref_model.model.load_state_dict(vector_to_state_dict(blended, ref_model.model))

        metrics = trainer.test(ref_model, test_loader, verbose=False)[0]
        losses.append(metrics["loss/test"])
        accuracies.append(metrics["acc/test"])

    return losses, accuracies

## Average in the universe

In [None]:
# A deep copy of model "a" receives the averaged parameters.
merged_model = copy.deepcopy(models["a"])

# Element-wise mean of the flattened models that were permuted to the universe.
mean_model_params = torch.stack(list(flat_models_permuted_to_universe.values())).mean(dim=0)

merged_model.load_state_dict(vector_to_state_dict(mean_model_params, merged_model))

In [None]:
# Put the merged model in eval mode, then test it on every test loader.
merged_model.eval()

for loader in test_dataloaders:
    trainer.test(merged_model, loader)

## Repair

In [None]:
import torch.nn as nn


class ResetConv(nn.Module):
    """
    Wrap a convolution together with a fresh ``nn.BatchNorm2d`` over its output channels.

    While ``rescale`` is False the batch norm only observes the convolution output
    (its result is discarded); once ``rescale`` is True the batch-norm output is returned.
    """

    def __init__(self, conv):
        super().__init__()
        self.conv = conv
        self.out_channels = conv.out_channels
        self.bn = nn.BatchNorm2d(conv.out_channels)
        self.rescale = False

    def set_stats(self, goal_mean, goal_var, eps=1e-5):
        """Set the batch-norm bias to ``goal_mean`` and its weight to ``sqrt(goal_var + eps)``."""
        self.bn.bias.data = goal_mean
        self.bn.weight.data = (goal_var + eps).sqrt()

    def forward(self, x):
        conv_out = self.conv(x)
        # The batch norm runs in both modes; its output is only used when rescaling.
        normalized = self.bn(conv_out)
        return normalized if self.rescale else conv_out

In [None]:
def reset_bn_stats(model, epochs, loader):
    """
    Recompute the running statistics of every ``nn.BatchNorm2d`` layer in ``model``.

    Each batch-norm layer gets ``momentum = None`` (cumulative average) and its running
    statistics reset; the model is then put in train mode and fed the ``"x"`` entry of
    every batch of ``loader``, ``epochs`` times, without gradient tracking.

    Inputs are moved to the device of the model's first parameter, so this works for a
    model on the CPU or on any GPU. A model without parameters receives the inputs unchanged.

    The original authors' note: use the train loader with data augmentation, as this gives
    better results.

    The model is left in train mode; nothing is returned.
    """
    # resetting stats to baseline first as below is necessary for stability
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.momentum = None  # use simple average
            m.reset_running_stats()

    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    # forward passes in train mode are what update the running statistics
    model.train()
    with torch.no_grad():
        for _ in range(epochs):
            for batch in loader:
                x = batch["x"]
                if device is not None:
                    x = x.to(device)
                model(x)

In [None]:
def replace_conv_layers(module):
    """Recursively replace, in place, every ``nn.Conv2d`` child of ``module`` with a ``ResetConv`` wrapping it."""
    for child_name, submodule in module.named_children():
        if not isinstance(submodule, nn.Conv2d):
            # not a convolution: look for convolutions further down
            replace_conv_layers(submodule)
            continue
        setattr(module, child_name, ResetConv(submodule))


def make_tracked_net(model):
    """
    Return a deep copy of ``model``, in eval mode, whose convolutional layers are each
    wrapped in a ResetConv module. The input model is not modified.
    """
    tracked_model = copy.deepcopy(model)
    replace_conv_layers(tracked_model)
    tracked_model.eval()
    return tracked_model

In [None]:
def fuse_batch_norm_into_conv_recursive(module):
    """
    Recursively replace, in place, every ``ResetConv`` child of ``module`` with the result of
    ``fuse_batch_norm_into_conv`` applied to its convolution and batch norm.
    """
    for child_name, submodule in module.named_children():
        if not isinstance(submodule, ResetConv):
            fuse_batch_norm_into_conv_recursive(submodule)
            continue
        fused_conv = fuse_batch_norm_into_conv(submodule.conv, submodule.bn)
        setattr(module, child_name, fused_conv)


def fuse_tracked_net(tracked_model):
    """
    Fuse every ResetConv found under ``tracked_model.model`` back into a plain layer.

    The fusion happens in place; the same ``tracked_model`` object is returned.
    """

    fuse_batch_norm_into_conv_recursive(tracked_model.model)

    return tracked_model

In [None]:
def repair_model_generalized(model_to_repair, endpoint_models):
    """
    Set the goal statistics of every ResetConv in ``model_to_repair`` and turn its rescaling on.

    The modules of ``model_to_repair`` and of each endpoint model are walked in lockstep.
    For each ResetConv, the goal mean and goal variance are the averages, across endpoints,
    of the running mean and running variance of the corresponding batch norms.
    ``model_to_repair`` is modified in place; nothing is returned.
    """
    endpoint_module_iters = [endpoint.modules() for endpoint in endpoint_models]

    for target, *sources in zip(model_to_repair.modules(), *endpoint_module_iters):
        if isinstance(target, ResetConv):
            running_means = torch.stack([source.bn.running_mean for source in sources])
            running_vars = torch.stack([source.bn.running_var for source in sources])

            # store the averaged endpoint statistics in the wrapper's batch norm
            target.set_stats(running_means.mean(dim=0), running_vars.mean(dim=0))

            # from now on the wrapper returns the batch-norm output
            target.rescale = True

In [None]:
# Deep-copied, conv-wrapped version of the merged model, moved to the GPU.
merged_model_wrapped = make_tracked_net(merged_model).cuda()

In [None]:
# Iterate the symbols in sorted order: `symbols` is a set, so its iteration order is
# arbitrary, while each wrapped model is paired with a train loader by list position.
# NOTE(review): this assumes the i-th sorted symbol belongs with train_dataloaders[i + 1]; confirm.
wrapped_models = [make_tracked_net(models_permuted_to_universe[symbol]).cuda() for symbol in sorted_symbols]

# Collect batch-norm running statistics for each wrapped endpoint on its own train loader.
for model_ind, model in enumerate(wrapped_models):
    reset_bn_stats(model.cuda(), epochs=1, loader=train_dataloaders[model_ind + 1])

# Give the merged model's wrappers the averaged endpoint statistics and turn rescaling on.
repair_model_generalized(merged_model_wrapped, wrapped_models)

In [None]:
# Recompute the BatchNorm2d running statistics of the wrapped merged model with one
# pass over the first train loader, then switch to eval mode for testing.
reset_bn_stats(merged_model_wrapped.cuda(), epochs=1, loader=train_dataloaders[0])
merged_model_wrapped.eval()

# Test the wrapped merged model on every test loader.
for loader in test_dataloaders:
    trainer.test(merged_model_wrapped, loader)