In [None]:
import copy
import logging
from pathlib import Path
from typing import Dict
import math
import itertools
from ccmm.utils.utils import l2_norm_models
import hydra
import matplotlib
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import omegaconf
import seaborn as sns
import torch  # noqa
import wandb
from hydra.utils import instantiate
from matplotlib import tri
from matplotlib.offsetbox import AnnotationBbox, OffsetImage
from omegaconf import DictConfig
from pytorch_lightning import LightningModule
from scipy.stats import qmc
from torch.utils.data import DataLoader
from tqdm import tqdm
from ccmm.matching.utils import perm_indices_to_perm_matrix
from ccmm.utils.utils import normalize_unit_norm, project_onto
from functools import partial

from nn_core.callbacks import NNTemplateCore
from nn_core.common import PROJECT_ROOT
from nn_core.common.utils import seed_index_everything
from nn_core.model_logging import NNLogger
from ccmm.utils.utils import fuse_batch_norm_into_conv
from torch.utils.data import DataLoader, Subset, SubsetRandomSampler

import ccmm  # noqa
from ccmm.matching.utils import (
    apply_permutation_to_statedict,
    get_all_symbols_combinations,
    plot_permutation_history_animation,
    restore_original_weights,
)
from ccmm.utils.utils import (
    linear_interpolate,
    load_model_from_info,
    map_model_seed_to_symbol,
    save_factored_permutations,
)

from ccmm.matching.utils import load_permutations

from ccmm.utils.utils import vector_to_state_dict, get_interpolated_loss_acc_curves
import pytorch_lightning

In [None]:
%load_ext autoreload
%autoreload 2

import hydra
from hydra import initialize, compose
from typing import Dict, List

hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(version_base=None, config_path=str("../../conf"), job_name="matching_n_models")

In [None]:
# Keep cell output readable: only WARNING and above from these third-party loggers.
for _noisy_logger in ("lightning.pytorch", "torch", "pytorch_lightning.accelerators.cuda"):
    logging.getLogger(_noisy_logger).setLevel(logging.WARNING)
del _noisy_logger

# Logger for this notebook's own messages.
pylogger = logging.getLogger(__name__)

In [None]:
# Root experiment config, composed without any overrides.
cfg = compose(overrides=[], config_name="matching_n_models")

In [None]:
# `compose` returns the root config; this notebook mostly works with its `matching`
# sub-tree. Keep the root as `core_cfg` and rebind `cfg` to the sub-tree.
# The guard makes the cell idempotent: without it, re-running the cell would try to
# read `cfg.matching.matching` and overwrite `core_cfg` with the sub-tree.
if "matching" in cfg:
    core_cfg = cfg  # NOQA
    cfg = cfg.matching

seed_index_everything(cfg)

## Hyperparameters

In [None]:
# How many leading examples of each split go into the evaluation subsets built below.
num_train_samples = 5_000
num_test_samples = 5_000

## Load dataset

In [None]:
# Both splits are built with the *test* transform, so train and test data share the
# same preprocessing (presumably deliberate since the data is only evaluated here — TODO confirm).
transform = instantiate(core_cfg.dataset.test.transform)

train_dataset = instantiate(core_cfg.dataset.train, transform=transform)
test_dataset = instantiate(core_cfg.dataset.test, transform=transform)

# Deterministic subsets: the first N indices of each split.
train_subset = Subset(train_dataset, list(range(num_train_samples)))
test_subset = Subset(test_dataset, list(range(num_test_samples)))

train_loader = DataLoader(
    train_subset,
    batch_size=5000,
    num_workers=cfg.num_workers,
)
test_loader = DataLoader(
    test_subset,
    batch_size=1000,
    num_workers=cfg.num_workers,
)

In [None]:
# Trainer used for the `trainer.test` evaluations below, with console noise disabled.
trainer = instantiate(
    cfg.trainer,
    enable_model_summary=False,
    enable_progress_bar=False,
)

In [None]:
from ccmm.utils.utils import load_model_from_artifact

run = wandb.init(project=core_cfg.core.project_name, entity=core_cfg.core.entity, job_type="matching")

# symbol -> seed, e.g. {"a": 1, "b": 2, "c": 3, ...}
# (keys are the symbols produced by `map_model_seed_to_symbol`, values are the configured seeds)
symbols_to_seed: Dict[str, int] = {map_model_seed_to_symbol(seed): seed for seed in cfg.model_seeds}


def artifact_path(seed: int) -> str:
    """Return the W&B artifact reference ``<entity>/<project>/<model_identifier>_<seed>:v0``.

    Args:
        seed: training seed of the model whose artifact should be fetched.

    Returns:
        The fully-qualified artifact name passed to ``load_model_from_artifact``.
    """
    return f"{core_cfg.core.entity}/{core_cfg.core.project_name}/{core_cfg.model.model_identifier}_{seed}:v0"


# symbol -> trained model, e.g. {"a": model_a, "b": model_b, "c": model_c, ...}
# Reuse the symbol mapping computed above instead of re-deriving it per seed.
models: Dict[str, LightningModule] = {
    symbol: load_model_from_artifact(run, artifact_path(seed)) for symbol, seed in symbols_to_seed.items()
}

num_models = len(models)

pylogger.info(f"Using {num_models} models with architecture {core_cfg.model.model_identifier}")

In [None]:
# Author's convention (not enforced in this cell): the model with the lexicographically
# larger symbol is the one that gets permuted, i.e. c -> b, b -> a, ...
symbols = set(symbols_to_seed)
sorted_symbols = sorted(symbols)

# All symbol pairs, e.g. (a, b), (a, c), (b, c), ... — presumably in both orientations,
# given the filter below (TODO confirm against get_all_symbols_combinations).
all_combinations = get_all_symbols_combinations(symbols)
# One orientation per pair: keep (a, b), (a, c), (b, c), ... and drop (b, a), (c, a), ...
canonical_combinations = [(src, tgt) for src, tgt in all_combinations if src < tgt]

In [None]:
# Record which (source, target) pairs the matchers will operate on.
pylogger.info(msg=f"Matching the following model pairs: {canonical_combinations}")

In [None]:
permutation_spec_builder = instantiate(core_cfg.model.permutation_spec_builder)
permutation_spec = permutation_spec_builder.create_permutation()

# Sanity check: the spec must cover exactly the parameters of the model's state_dict.
# Any loaded model can serve as reference; they are all built from the same
# `model_identifier` (see the artifact path used when loading).
ref_model = next(iter(models.values()))
spec_keys = set(permutation_spec.layer_and_axes_to_perm.keys())
model_keys = set(ref_model.model.state_dict().keys())
assert spec_keys == model_keys, (
    "Permutation spec does not match the model state_dict: "
    f"missing from spec={sorted(model_keys - spec_keys)}, "
    f"not in model={sorted(spec_keys - model_keys)}"
)

## Git Re-Basin: match every model to a reference ("universe") model

In [None]:
# Pairwise matcher from the `matcher` entry of the matching config, bound to the spec above.
matcher = instantiate(cfg.matcher, permutation_spec=permutation_spec)
pylogger.info(msg=f"Matcher: {matcher.name}")

In [None]:
from ccmm.matching.utils import get_inverse_permutations

# Match every model (the universe itself included) against one reference "universe" model.
# For each symbol we keep:
#   map_to_universe[symb]   - permutations returned by matcher(fixed=universe, permutee=model)
#   map_from_universe[symb] - their inverses
universe_symbol = "c"
if universe_symbol not in models:
    raise KeyError(f"Universe model '{universe_symbol}' not loaded; available models: {sorted(models)}")
universe_model = models[universe_symbol]

map_to_universe = {}
map_from_universe = {}

for symb, model in models.items():
    permutations, perm_history = matcher(fixed=universe_model.model, permutee=model.model)
    map_to_universe[symb] = permutations
    map_from_universe[symb] = get_inverse_permutations(permutations)

In [None]:
# Chain A -> B -> C through the universe space:
#   a --to_univ[a]--> univ --from_univ[b]--> b --to_univ[b]--> univ --from_univ[c]--> c
# (map_from_universe[x] holds the inverse of map_to_universe[x]).
perm = partial(apply_permutation_to_statedict, permutation_spec)

# a's weights, expressed in the universe space
A_to_univ = perm(map_to_universe["a"], models["a"].model.state_dict())
# ... then mapped into b's space
A_to_B = perm(map_from_universe["b"], A_to_univ)
# ... back to the universe, and finally into c's space
A_B_C = perm(map_from_universe["c"], perm(map_to_universe["b"], A_to_B))

# Load the composed weights into a copy of model a so it can be evaluated.
ref_model = copy.deepcopy(models["a"])
ref_model.model.load_state_dict(A_B_C)

In [None]:
# Evaluate the A -> B -> C composed model on the test subset.
trainer.test(model=ref_model, dataloaders=test_loader)

In [None]:
# from ccmm.matching.utils import unfactor_permutations

# models_permuted_pairwise = {
#     symbol: {other_symb: None for other_symb in set(symbols).difference(symbol)} for symbol in symbols
# }
# pairwise_permutations = unfactor_permutations(permutations)

# for fixed, permutee in all_combinations:
#     ref_model = copy.deepcopy(models["a"])

#     permuted_params = apply_permutation_to_statedict(
#         permutation_spec, pairwise_permutations[fixed][permutee], models[permutee].model.state_dict()
#     )

#     ref_model.model.load_state_dict(permuted_params)
#     models_permuted_pairwise[fixed][permutee] = ref_model

## Synchronized matching (Frank-Wolfe over all models jointly)

In [None]:
from ccmm.matching.matcher import FrankWolfeSynchronizedMatcher

# Synchronized matcher: receives all models at once (see the call below).
# initialization_method="identity" presumably starts from identity permutations — TODO confirm.
fw_matcher = FrankWolfeSynchronizedMatcher(
    name="frank_wolfe",
    permutation_spec=permutation_spec,
    initialization_method="identity",
    max_iter=200,
)

In [None]:
# Jointly match all models over every canonical (source, target) pair.
(permutations, perm_history) = fw_matcher(
    models,
    sorted_symbols,
    canonical_combinations,
)

In [None]:
from ccmm.matching.utils import unfactor_permutations

# models_permuted_pairwise[fixed][permutee]: a copy of model `permutee` whose weights were
# permuted with pairwise_permutations[fixed][permutee] (presumably into `fixed`'s weight space).
# Use `set - {symbol}`: `set.difference(symbol)` would iterate the *characters* of the
# string and break for multi-character symbols.
models_permuted_pairwise = {
    symbol: {other_symb: None for other_symb in set(symbols) - {symbol}} for symbol in symbols
}
pairwise_permutations = unfactor_permutations(permutations)

for fixed, permutee in all_combinations:
    # Copy the permutee itself rather than a hard-coded model (which also had to exist),
    # and use a dedicated name so the global `ref_model` evaluated earlier is not clobbered.
    permuted_model = copy.deepcopy(models[permutee])

    permuted_params = apply_permutation_to_statedict(
        permutation_spec, pairwise_permutations[fixed][permutee], models[permutee].model.state_dict()
    )

    permuted_model.model.load_state_dict(permuted_params)
    models_permuted_pairwise[fixed][permutee] = permuted_model

In [None]:
# Model "a" with its weights permuted by the synchronized pairwise permutation for the
# ("c", "a") pair (presumably into model "c"'s weight space).
A_to_C = models_permuted_pairwise["c"]["a"]

In [None]:
# Evaluate the synchronized a -> c model on the test subset.
trainer.test(model=A_to_C, dataloaders=test_loader)