## Imports

In [None]:
import copy
import logging
from pathlib import Path
from typing import Dict

import hydra
import matplotlib
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import omegaconf
import seaborn as sns
import torch  # noqa
import wandb
from hydra.utils import instantiate
from matplotlib import tri
from matplotlib.offsetbox import AnnotationBbox, OffsetImage
from omegaconf import DictConfig
from pytorch_lightning import LightningModule
from scipy.stats import qmc
from torch.utils.data import DataLoader
from tqdm import tqdm
from ccmm.matching.utils import perm_indices_to_perm_matrix

from nn_core.callbacks import NNTemplateCore
from nn_core.common import PROJECT_ROOT
from nn_core.common.utils import seed_index_everything
from nn_core.model_logging import NNLogger

from torch.utils.data import DataLoader, Subset, SubsetRandomSampler

import ccmm  # noqa
from ccmm.matching.utils import (
    apply_permutation_to_statedict,
    get_all_symbols_combinations,
    plot_permutation_history_animation,
    restore_original_weights,
)
from ccmm.utils.utils import (
    linear_interpolation,
    load_model_from_info,
    load_permutations,
    map_model_seed_to_symbol,
    save_factored_permutations,
)

from ccmm.utils.utils import vector_to_state_dict
import pytorch_lightning

In [None]:
# Figure styling: serif, LaTeX-rendered text at seaborn's "talk" scale.
plot_rc = matplotlib.rcParams
plot_rc["font.family"] = "serif"
sns.set_context("talk")
plot_rc["text.usetex"] = True

# Silence INFO-level chatter from Lightning / torch during the many trainer.test calls below.
for noisy_logger_name in ("lightning.pytorch", "torch", "pytorch_lightning.accelerators.cuda"):
    logging.getLogger(noisy_logger_name).setLevel(logging.WARNING)
pylogger = logging.getLogger(__name__)

## Configuration

In [None]:
%load_ext autoreload
%autoreload 2

import hydra
from hydra import initialize, compose
from typing import Dict, List

# Drop any Hydra instance left over from a previous execution of this cell, then point Hydra at ../conf.
hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(
    version_base=None,
    config_path="../conf",
    job_name="matching_n_models",
)

In [None]:
# Compose the root config for this experiment; no overrides are applied from the notebook.
cfg = compose(
    config_name="matching_n_models",
    overrides=[],
)

In [None]:
# Keep a handle on the full composed config; from here on `cfg` is only its `matching` sub-config.
core_cfg = cfg  # NOQA
cfg = core_cfg.matching

# Seed the RNGs from the matching config (presumably via its seed index — see seed_index_everything).
seed_index_everything(cfg)

## Hyperparameters

In [None]:
# Number of Sobol points for the loss landscape (2048 gives a denser plot) and size of the test subset.
num_sampled_points, num_test_samples = 100, 500

## Load dataset

In [None]:
# Both splits are built with the *test* transform (presumably to avoid training-time augmentation).
eval_transform = instantiate(core_cfg.dataset.test.transform)
transform = eval_transform

train_dataset = instantiate(core_cfg.dataset.train, transform=eval_transform)
test_dataset = instantiate(core_cfg.dataset.test, transform=eval_transform)

train_loader = DataLoader(
    train_dataset,
    batch_size=cfg.batch_size,
    num_workers=cfg.num_workers,
)

# Every evaluation uses the same fixed prefix of the test set to keep landscape evaluations cheap.
first_test_indices = list(range(num_test_samples))
test_subset = Subset(test_dataset, first_test_indices)
test_loader = DataLoader(test_subset, batch_size=1000, num_workers=cfg.num_workers)

In [None]:
# Quiet trainer: in this notebook it is only used for trainer.test(...) evaluations.
trainer = instantiate(
    cfg.trainer,
    enable_progress_bar=False,
    enable_model_summary=False,
)

## Load models

In [None]:
# Map each model seed to a short symbol: {a: seed_0, b: seed_1, ...} (symbol -> seed).
# The annotation was previously reversed (Dict[int, str]); keys are the symbols, values the seeds.
symbols_to_seed: Dict[str, int] = {map_model_seed_to_symbol(seed): seed for seed in cfg.model_seeds}

# Load one trained model per seed, keyed by the same symbol (reuses the mapping above
# instead of calling map_model_seed_to_symbol a second time).
models: Dict[str, LightningModule] = {
    symbol: load_model_from_info(cfg.model_info_path, seed) for symbol, seed in symbols_to_seed.items()
}

pylogger.info(f"Using model {core_cfg.model.name}")

In [None]:
# Convention: always permute the model with the larger symbol onto the smaller one (c -> b, b -> a, ...).
symbols = set(symbols_to_seed)
sorted_symbols = sorted(symbols)

# Pairs of symbols such as (a, b), (a, c), (b, c), ...
all_combinations = get_all_symbols_combinations(symbols)
# Keep only the canonical orientation, i.e. (a, b) and never (b, a).
canonical_combinations = [(src, tgt) for src, tgt in all_combinations if src < tgt]

## Matching

In [None]:
# Record which (source, target) pairs are handed to the matcher.
pairs_message = f"Matching the following model pairs: {canonical_combinations}"
pylogger.info(pairs_message)

### Load permutation specification

In [None]:
# Build the permutation spec for this architecture from the model config.
permutation_spec_builder = instantiate(core_cfg.model.permutation_spec_builder)
permutation_spec = permutation_spec_builder.create_permutation()

# Sanity check: the spec must cover exactly the keys of the model's state dict.
# An explicit raise (instead of `assert`) survives `python -O` and reports which keys disagree.
ref_model = next(iter(models.values()))
spec_keys = set(permutation_spec.layer_and_axes_to_perm.keys())
state_keys = set(ref_model.model.state_dict().keys())
if spec_keys != state_keys:
    raise ValueError(
        "Permutation spec does not match the model state dict. "
        f"Missing from spec: {sorted(state_keys - spec_keys)}; "
        f"unknown to the model: {sorted(spec_keys - state_keys)}"
    )

In [None]:
# The concrete matcher comes from the config and is parameterized by the spec built above.
matcher = instantiate(
    cfg.matcher,
    permutation_spec=permutation_spec,
)
pylogger.info(f"Matcher: {matcher.name}")

In [None]:
# Jointly match all models; `permutations[symbol]` is applied below to bring each model to the universe.
permutations, perm_history = matcher(
    models,
    symbols=sorted_symbols,
    combinations=canonical_combinations,
)

In [None]:
# Move every model to the CPU (presumably the matcher may have used an accelerator).
cpu_models = {}
for symbol_key, lightning_model in models.items():
    cpu_models[symbol_key] = lightning_model.to("cpu")
models = cpu_models

### Permute models to universe

In [None]:
# Work on deep copies so the original models keep their own weights.
models_permuted_to_universe = {symbol: copy.deepcopy(model) for symbol, model in models.items()}

for symbol, permuted_model in models_permuted_to_universe.items():
    original_state = permuted_model.model.state_dict()
    universe_state = apply_permutation_to_statedict(permutation_spec, permutations[symbol], original_state)
    permuted_model.model.load_state_dict(universe_state)

In [None]:
# Evaluate each model in the universe frame and in its original frame. A weight permutation is expected
# not to change the function, so the two results should match.
for symbol in models_permuted_to_universe:
    trainer.test(models_permuted_to_universe[symbol], test_loader)
    trainer.test(models[symbol], test_loader)

## Plots

### Sample points in the 2D plane

In [None]:
# Square region of the 2D plane to sample: [-0.5, 0.5] along both axes.
boundaries = [[-0.5], [0.5]]
low, high = boundaries[0][0], boundaries[1][0]
lower_bounds = np.array([low, low])
upper_bounds = np.array([high, high])

pylogger.info(f"Lower bounds: {lower_bounds}")
pylogger.info(f"Upper bounds: {upper_bounds}")

In [None]:
# Scrambled (seeded, hence reproducible) Sobol points rescaled to the square defined by
# lower_bounds/upper_bounds, so the sampling region and the plotting grid share one source of truth
# instead of repeating the [-0.5, 0.5] literals here.
# Note: scipy warns when num_sampled_points is not a power of two (Sobol balance properties).
sobol_sampler = qmc.Sobol(d=2, scramble=True, seed=cfg.seed_index)
random_points_plane = qmc.scale(sobol_sampler.random(num_sampled_points), lower_bounds, upper_bounds)

pylogger.info(random_points_plane[:10])

### Flatten models

In [None]:
# Flatten every model's parameters into one 1-D vector (a copy, not a view of the parameters).
# Built under no_grad: the vectors are only used for analysis (norms, cosines, interpolation).
# Keeping them attached would retain autograd graphs over every full parameter vector and can make
# numpy conversions (e.g. filling `cosine_matrix`) raise "requires grad" errors.
with torch.no_grad():
    flat_models = {
        symbol: torch.nn.utils.parameters_to_vector(model.parameters()) for symbol, model in models.items()
    }
    flat_models_permuted_to_universe = {
        symbol: torch.nn.utils.parameters_to_vector(model.parameters())
        for symbol, model in models_permuted_to_universe.items()
    }

In [None]:
# Compare each model's parameter norm before/after mapping to the universe. A pure permutation of the
# entries should leave the L2 norm unchanged, while the difference vector shows how much the weights moved.
# Local names avoid rebinding the module-level name `norm`, which later cells use as a helper.
for symbol in models:
    original_norm = torch.norm(flat_models[symbol])
    permuted_norm = torch.norm(flat_models_permuted_to_universe[symbol])
    pylogger.info(f"Norm of {symbol}: {original_norm}")
    pylogger.info(f"Norm of {symbol} permuted: {permuted_norm}")

    norm_diff = torch.norm(flat_models[symbol] - flat_models_permuted_to_universe[symbol])
    pylogger.info(f"Norm diff of {symbol}: {norm_diff}")

In [None]:
# Pairwise cosine similarity: rows = model i mapped to the universe, columns = original model j.
cosine_matrix = np.zeros((len(models), len(models)))

with torch.no_grad():
    # Norms are loop-invariant, so compute each one once instead of once per (i, j) pair.
    universe_norms = {symbol: torch.norm(vec) for symbol, vec in flat_models_permuted_to_universe.items()}
    original_norms = {symbol: torch.norm(vec) for symbol, vec in flat_models.items()}

    for i, symbol_i in enumerate(models):
        for j, symbol_j in enumerate(models):
            dot = flat_models_permuted_to_universe[symbol_i].dot(flat_models[symbol_j])
            # .item() stores a plain Python float, independent of the vectors' autograd state.
            cosine_matrix[i, j] = (dot / (universe_norms[symbol_i] * original_norms[symbol_j])).item()

In [None]:
# Display the cosine-similarity matrix (rows: models in the universe frame, columns: original models).
cosine_matrix

## Barycentric coordinates

In [None]:
import math


def get_pentagon_vertices(center_x, center_y, radius, num_vertices=5):
    """
    Get the vertices of a regular polygon (a pentagon by default) centered at (center_x, center_y).

    The first vertex lies at angle 0, i.e. at (center_x + radius, center_y). The remaining
    vertices follow counter-clockwise, evenly spaced by 360 / num_vertices degrees
    (72 degrees for the default pentagon).

    Args:
        center_x: x coordinate of the polygon center.
        center_y: y coordinate of the polygon center.
        radius: distance from the center to every vertex.
        num_vertices: number of vertices; defaults to 5 (a pentagon), matching the original helper.

    Returns:
        numpy array of shape (num_vertices, 2), one (x, y) row per vertex.

    Raises:
        ValueError: if num_vertices is smaller than 1.
    """
    if num_vertices < 1:
        raise ValueError(f"num_vertices must be positive, got {num_vertices}")

    step_deg = 360.0 / num_vertices
    angles_rad = [math.radians(step_deg * i) for i in range(num_vertices)]

    return np.array(
        [(radius * math.cos(angle) + center_x, radius * math.sin(angle) + center_y) for angle in angles_rad]
    )

In [None]:
def represent_barycentric_coordinates(x, origins=None):
    """
    Express a 2D point as non-negative weights over the vertices of a polygon.

    The weights start as the least-squares solution of ``[origins^T; 1^T] z = [x; 1]``
    (a row of ones enforces sum(z) == 1). They are then shifted so the smallest weight
    is 0 and renormalised to sum to 1. The shift keeps the weights non-negative, but
    ``origins^T z`` generally no longer equals ``x`` afterwards.

    Args:
        x: point in the plane, shape (2,).
        origins: polygon vertices, shape (num_models, 2). Defaults to the pentagon of
            radius 0.45 centered at the origin that the rest of the notebook uses.

    Returns:
        float32 torch tensor of shape (num_models,) with non-negative entries summing to 1.
    """
    if origins is None:
        origins = get_pentagon_vertices(0.0, 0.0, 0.45)
    origins = np.asarray(origins, dtype=float)
    num_vertices = origins.shape[0]

    # (3, num_vertices): vertex coordinates plus a row of ones for the sum-to-one constraint
    A = np.vstack([origins.T, np.ones(num_vertices)])
    # (3,): the point in homogeneous coordinates
    target = np.append(x, 1)

    z, _, _, _ = np.linalg.lstsq(A, target, rcond=None)

    z = z - z.min()
    total = z.sum()
    # If all weights were (numerically) equal, the shift leaves only rounding noise, and dividing
    # by it would produce NaN or garbage. Equal weights summing to one mean the point is the
    # vertex centroid, whose barycentric coordinates are uniform.
    if total <= 1e-12:
        z = np.full(num_vertices, 1.0 / num_vertices)
    else:
        z = z / total

    return torch.from_numpy(z).float()

In [None]:
# Sanity check: solve for the weights that reproduce the first pentagon vertex itself.
origins = get_pentagon_vertices(0.0, 0.0, 0.45)

# (3, num_models): vertex coordinates stacked on a row of ones (sum-to-one constraint)
A = np.vstack([origins.T, np.ones(5)])

# (3,): first vertex in homogeneous coordinates
x = np.append(origins[0], 1)

z, residuals, rank, s = np.linalg.lstsq(A, x, rcond=None)

z

In [None]:
# origins = get_pentagon_vertices(0.0, 0.0, 0.9)
# A = origins.transpose(1, 0)

# A.shape
# # # plot the origins
# # plt.scatter(A[0], A[1], c='black', s=100)
# # plt.axis('equal')
# # # plot the sphere with radius = 0.9
# # circle = plt.Circle((0, 0), 0.9, color='black', fill=False)

# # plt.gca().add_patch(circle)

### Represent the permuted models in barycentric coordinates with respect to the $n$ original models


In [None]:
def represent_wrt_models(model_to_repr, flat_models, scaling=1):
    """
    Express a flattened model as an affine combination of reference models.

    Solves, in the least-squares sense, for ``z`` such that
    ``sum_k z_k * flat_models[k] ≈ model_to_repr`` together with the soft constraint
    ``scaling * sum_k z_k ≈ scaling``. A larger ``scaling`` weights the sum-to-one
    constraint more heavily relative to the parameter-space residual.

    Args:
        model_to_repr: flat parameter vector, shape (num_params_per_model,).
        flat_models: mapping symbol -> flat parameter vector of the same length. Its
            iteration order defines the order of the returned coefficients.
        scaling: weight of the sum-to-one constraint row (default 1, as before).

    Returns:
        numpy array of shape (num_models, 1) with the coefficients.
    """
    # (num_params_per_model, num_models)
    A = torch.stack(list(flat_models.values()), dim=1)

    # Constraint row and target are created with A's dtype/device, so the concatenations neither
    # rely on implicit int64 -> float promotion nor break when the vectors live off the CPU.
    # (num_params_per_model + 1, num_models)
    ones_row = torch.full((1, A.shape[1]), float(scaling), dtype=A.dtype, device=A.device)
    A_augmented = torch.cat([A, ones_row], dim=0)

    # (num_params_per_model + 1,)
    constraint_target = torch.full((1,), float(scaling), dtype=A.dtype, device=A.device)
    target_augmented = torch.cat([model_to_repr.to(device=A.device, dtype=A.dtype), constraint_target])

    # want z such that A_augmented @ z ≈ target_augmented
    barycentric_coords = torch.linalg.lstsq(A_augmented, target_augmented.unsqueeze(1)).solution

    return barycentric_coords.cpu().detach().numpy()

In [None]:
# Pentagon vertices used as 2D anchors: the k-th original model is drawn at vertex k.
origins = get_pentagon_vertices(0.0, 0.0, 0.45)

model_2D_repr = dict.fromkeys(symbols_to_seed)
universe_model_2D_repr = dict.fromkeys(symbols_to_seed)

for model_num, symbol in enumerate(flat_models_permuted_to_universe):
    perm_model = flat_models_permuted_to_universe[symbol]

    # (num_models, 1) coefficients of the permuted model w.r.t. the original models, normalised to sum to 1
    coeffs = represent_wrt_models(perm_model, flat_models)
    coeffs = coeffs / coeffs.sum(axis=0)

    model_2D_repr[symbol] = origins[model_num]
    # The universe-permuted model is drawn at the same combination of the pentagon vertices
    universe_model_2D_repr[symbol] = (coeffs * origins).sum(axis=0)

In [None]:
# also solve for the barycentric coordinates of the models in the high-dimensional space
# scale the entire equation constraining the coefficients to sum up to 1 by a large scalar

In [None]:
import numpy as np


def evaluate_model_interp_on_point(point, flat_models, model, trainer, test_loader):
    """
    Evaluate the model obtained by mixing the flat models with the weights of a 2D point.

    Args:
        point: 2D point in the plot plane, shape (2,).
        flat_models: mapping symbol -> flat parameter vector. Its iteration order must match the
            vertex order used by represent_barycentric_coordinates.
        model: LightningModule whose `.model` weights are overwritten in place with the mixture.
        trainer: trainer used for evaluation.
        test_loader: dataloader evaluated by `trainer.test`.

    Returns:
        The value returned by `trainer.test(model, test_loader, verbose=False)`.
    """
    # (num_models, 1)
    baryc_coords = represent_barycentric_coordinates(point).unsqueeze(1)

    # The mixture is pure analysis: no_grad avoids building an autograd graph over every full
    # parameter vector for each sampled point.
    with torch.no_grad():
        # (num_models, num_params_per_model)
        stacked = torch.stack(list(flat_models.values()))
        weights = baryc_coords.to(device=stacked.device, dtype=stacked.dtype)
        new_flat_params = (stacked * weights).sum(dim=0)

    new_params = vector_to_state_dict(new_flat_params, model.model)
    model.model.load_state_dict(new_params)

    return trainer.test(model, test_loader, verbose=False)


# random_points_plane = np.concatenate((origins, random_points_plane))

# Scratch copy of model "a": its weights are overwritten at every sampled point.
model = copy.deepcopy(models["a"])
point_results = []
for point in tqdm(random_points_plane):
    point_results.append(evaluate_model_interp_on_point(point, flat_models, model, trainer, test_loader))
eval_results = np.array(point_results)

In [None]:
# trainer.test returns one metrics dict per dataloader; pull the test loss from the first one.
loss_key = "loss/test"
test_losses = np.array([per_point[0][loss_key] for per_point in eval_results])
# Show (at most) the first 100 losses.
test_losses[:100]

## Reference models as basis

In [None]:
# proj = lambda a, b: torch.dot(a, b) / torch.dot(b, b) * b
# norm = lambda a: torch.sqrt(torch.dot(a, a))
# normalize = lambda a: a / norm(a)

In [None]:
# def get_basis_vectors(origin_model, basis_model_1, basis_model_2):
#     basis1 = basis_model_1 - origin_model
#     scale1 = norm(basis1)
#     basis1_normed = normalize(basis1)

#     basis2 = basis_model_2 - origin_model
#     scale2 = norm(basis2)
#     basis2 = basis2 - proj(basis2, basis1_normed)
#     basis2_normed = normalize(basis2)

#     return basis1_normed, basis2_normed, scale1, scale2

# basis_model_1, basis_model_2, scale_1, scale_2 = get_basis_vectors(origin_model=flat_models['a'], basis_model_1=flat_models['b'], basis_model_2=flat_models['c'])

In [None]:
# import numpy as np
# from multiprocessing import Pool

# def evaluate_model_interp_on_point(point, basis_model_1, basis_model_2, origin_model, ref_model, trainer, test_loader, scale_1, scale_2):

#     # (num_models, )
#     new_flat_params = origin_model + (scale_1 * basis_model_1 * point[0] + scale_2 * basis_model_2 * point[1])

#     new_params = vector_to_state_dict(new_flat_params, ref_model.model)

#     ref_model.model.load_state_dict(new_params)

#     eval_results = trainer.test(ref_model, test_loader, verbose=False)

#     return eval_results

# ref_model = copy.deepcopy(models['a'])
# origin_model = flat_models['a']

# eval_results = np.array([evaluate_model_interp_on_point(point, scale_1=scale_1, scale_2=scale_2, basis_model_1=basis_model_1, basis_model_2=basis_model_2, origin_model=origin_model, ref_model=ref_model, trainer=trainer, test_loader=test_loader) for point in tqdm(random_points_plane)])

In [None]:
# from functools import partial

# pool = Pool() #defaults to number of available CPU's


# eval_func = partial(evaluate_model_interp_on_point, basis_model_1=basis_model_1, basis_model_2=basis_model_2, origin_model=origin_model, ref_model=ref_model, trainer=trainer, test_loader=test_loader)

# results = np.zeros(len(random_points_plane))
# for ind, res in enumerate(tqdm(pool.imap(eval_func, iter(random_points_plane)), total=len(random_points_plane))):
#     results[ind] = res

# # eval_results = np.array([evaluate_model_interp_on_point(point, basis_model_1=basis_model_1, basis_model_2=basis_model_2, origin_model=origin_model, ref_model=ref_model, trainer=trainer, test_loader=test_loader) for point in tqdm(random_points_plane)])

In [None]:
# test_losses = np.array([res[0]['loss/test'] for res in eval_results])
# test_losses[:100]

## Represent models in 2D

In [None]:
# def represent_wrt_models(model_to_repr, origin_model, basis1, basis2, scale_1, scale_2):

#     x_coord = torch.dot(model_to_repr - origin_model, basis1) / scale_1
#     y_coord = torch.dot(model_to_repr - origin_model, basis2) / scale_2

#     return torch.stack([x_coord, y_coord]).detach().cpu().numpy()

In [None]:
# model_2D_repr = {symbol: None for symbol in symbols_to_seed.keys()}
# universe_model_2D_repr = {symbol: None for symbol in symbols_to_seed.keys()}
# for model_num, (symbol, perm_model) in enumerate(flat_models_permuted_to_universe.items()):

#     model_2D = represent_wrt_models(flat_models[symbol], origin_model=flat_models['a'], basis1= basis_model_1, basis2=basis_model_2, scale_1=scale_1, scale_2=scale_2)
#     model_2D_perm = represent_wrt_models(perm_model, origin_model=flat_models['a'], basis1= basis_model_1, basis2=basis_model_2, scale_1=scale_1, scale_2=scale_2)

#     model_2D_repr[symbol] = model_2D
#     universe_model_2D_repr[symbol] = model_2D_perm

In [None]:
# Log the plot positions of the original and universe-permuted models.
for positions in (model_2D_repr, universe_model_2D_repr):
    pylogger.info(positions)

## Plot

In [None]:
# Display the landscape test losses (one per sampled point) before plotting.
test_losses

In [None]:
# Regular evaluation grid spanning the sampled square (np.linspace defaults to 50 points per axis).
grid_min, grid_max = boundaries[0][0], boundaries[1][0]
xi = np.linspace(grid_min, grid_max)
yi = np.linspace(grid_min, grid_max)

# Triangulate the scattered sample points, then linearly interpolate the losses onto the grid.
triang = tri.Triangulation(random_points_plane[:, 0], random_points_plane[:, 1])
# Cap losses at 5 so a few exploding values do not saturate the contour levels.
capped_losses = np.clip(test_losses, None, 5)
interpolator = tri.LinearTriInterpolator(triang, capped_losses)
grid_x, grid_y = np.meshgrid(xi, yi)
zi = interpolator(grid_x, grid_y)

In [None]:
# Colormap for the filled loss contours ("_r" = reversed coolwarm).
cmap_name = "coolwarm_r"

In [None]:
def truncate_colormap(cmap, minval=0.0, maxval=1.0, n=100):
    """
    Return a new colormap covering only the [minval, maxval] slice of ``cmap``.

    Args:
        cmap: source matplotlib colormap (must expose ``.name`` and be callable on an array).
        minval: lower end of the slice, in [0, 1].
        maxval: upper end of the slice, in [0, 1].
        n: number of colors sampled from the slice.

    Returns:
        A LinearSegmentedColormap named "trunc(<name>,<minval>,<maxval>)".
    """
    truncated_name = "trunc({n},{a:.2f},{b:.2f})".format(n=cmap.name, a=minval, b=maxval)
    sampled_colors = cmap(np.linspace(minval, maxval, n))
    return colors.LinearSegmentedColormap.from_list(truncated_name, sampled_colors)

In [None]:
# Loss-landscape figure: interpolated test-loss contours over the 2D plane, with each original model
# (x marker, labelled) connected by a dashed curved arrow to its universe-permuted position (o marker).
plt.figure()
num_levels = 13

# Thin grey iso-loss lines.
plt.contour(xi, yi, zi, levels=num_levels, linewidths=0.25, colors="grey", alpha=0.5)

# cmap = truncate_colormap(plt.get_cmap(cmap_name), 0.0, 1)

# Filled contours of the (clipped) interpolated losses.
plt.contourf(xi, yi, zi, levels=num_levels, cmap=plt.get_cmap(cmap_name), extend="both")
plt.colorbar()

# Shared style for the model labels.
label_bboxes = dict(facecolor="tab:grey", boxstyle="round", edgecolor="none", alpha=0.5)

for symbol, point in model_2D_repr.items():
    # Original model position and its LaTeX label (bold Theta with the model symbol as subscript).
    plt.scatter(point[0], point[1], marker="x", color="black", zorder=10)
    plt.text(
        point[0] - 0.075,
        point[1] + 0.1,
        r"${\bf \Theta_" + symbol + r"}$",
        color="white",
        fontsize=24,
        bbox=label_bboxes,
        horizontalalignment="right",
        verticalalignment="top",
    )

    # Position of the same model after being permuted to the universe.
    universe_point = universe_model_2D_repr[symbol]
    plt.scatter(universe_point[0], universe_point[1], marker="o", color="black", zorder=10)

    # The arrow is drawn in two passes sharing one curved path: a thick dashed shaft ("-" style),
    # then a filled head only. The "<|-" style places the head at the xytext end (universe point).
    connectionstyle = "arc3,rad=-0.3"
    plt.annotate(
        "",
        xy=(point[0], point[1]),
        xytext=(universe_point[0], universe_point[1]),
        arrowprops=dict(
            arrowstyle="-",
            edgecolor="black",
            facecolor="none",
            linewidth=5,
            linestyle=(0, (5, 3)),
            shrinkA=20,
            shrinkB=15,
            connectionstyle=connectionstyle,
        ),
    )

    # Draw arrow head only
    plt.annotate(
        "",
        xy=(point[0], point[1]),
        xytext=(universe_point[0], universe_point[1]),
        arrowprops=dict(
            arrowstyle="<|-",
            edgecolor="none",
            facecolor="black",
            mutation_scale=40,
            linewidth=0,
            shrinkA=12.5,
            shrinkB=15,
            connectionstyle=connectionstyle,
        ),
    )


# NOTE(review): the commented-out overlays below call the old projection-style
# represent_wrt_models(..., scale=...) signature and `interp_params`; they do not match the
# barycentric helper defined in this notebook.
# mean_ae = flat_models['a'] / 2 + flat_models['e'] / 2
# mean_ae_2D = represent_wrt_models(mean_ae, origin_model=flat_models['a'], basis1= basis_model_1, basis2=basis_model_2, scale=scale)
# plt.scatter(mean_ae_2D[0], mean_ae_2D[1], marker="o", color="yellow", zorder=10)


# mean_ab = flat_models['a'] / 2 + flat_models['c'] / 2
# mean_ab_2D = represent_wrt_models(mean_ab, origin_model=flat_models['a'], basis1= basis_model_1, basis2=basis_model_2, scale=scale)
# plt.scatter(mean_ab_2D[0], mean_ab_2D[1], marker="o", color="yellow", zorder=10)

# mean_bc = flat_models['b'] / 2 + flat_models['c'] / 2
# mean_bc_2D = represent_wrt_models(mean_bc, origin_model=flat_models['a'], basis1= basis_model_1, basis2=basis_model_2, scale=scale)
# plt.scatter(mean_bc_2D[0], mean_bc_2D[1], marker="o", color="yellow", zorder=10)

# mean_abc = flat_models['a'] / 3 + flat_models['b'] / 3 + flat_models['c'] / 3
# for interp_par in interp_params:
#     interp_par_2D = represent_wrt_models(interp_par, origin_model=flat_models['a'], basis1= basis_model_1, basis2=basis_model_2, scale=scale)
#     plt.scatter(interp_par_2D[0], interp_par_2D[1], marker="o", color="yellow", zorder=10)
# mean_abc_2D = represent_wrt_models(mean_abc, origin_model=flat_models['a'], basis1= basis_model_1, basis2=basis_model_2, scale=scale)
# plt.scatter(mean_abc_2D[0], mean_abc_2D[1], marker="o", color="yellow", zorder=10)

# Method title in a rounded box, in data coordinates.
# NOTE(review): y = 1.5 lies above the ylim of 1.3 set below — this looks left over from a
# different layout; confirm the title is actually visible.
box_x = 0.5
box_y = 1.5
title_text = r"$C^2M^2$"

# Draw box only
# (the text is fully transparent and only sizes the box; the visible text is drawn next,
# nudged down by 0.0115 to sit centered in the box)
plt.text(
    box_x,
    box_y,
    title_text,
    color=(0.0, 0.0, 0.0, 0.0),
    fontsize=24,
    horizontalalignment="center",
    verticalalignment="center",
    bbox=dict(boxstyle="round", fc=(1, 1, 1, 1), ec="black", pad=0.4),
)
# Draw text only
plt.text(
    box_x,
    box_y - 0.0115,
    title_text,
    color=(0.0, 0.0, 0.0, 1.0),
    fontsize=24,
    horizontalalignment="center",
    verticalalignment="center",
)


# plt.colorbar()
# NOTE(review): the sampled grid spans [-0.5, 0.5] on both axes and the pentagon has radius 0.45,
# so x up to 1.4 shows empty space. These limits look left over from another layout — confirm.
# plt.axis("equal") is called afterwards and may adjust these limits.
plt.xlim(-0.4, 1.4)
plt.ylim(-0.45, 1.3)
#   plt.xlim(-0.9, 1.9)
#   plt.ylim(-0.9, 1.9)
# plt.xticks([])
# plt.yticks([])
plt.axis("equal")
# plt.tight_layout()
# NOTE(review): output file name is hard-coded regardless of the configured model/dataset.
plt.savefig("resnet_cifar_loss_contour.png", dpi=300)

In [None]:
# Compare linear interpolation between models a and e done directly in full parameter space against
# the same interpolation done on the 2D plane and mapped back to parameter space.
# NOTE(review): this cell depends on basis_model_1, basis_model_2, scale_1, scale_2 and origin_model,
# which are only defined in commented-out cells ("Reference models as basis"). It also calls
# represent_wrt_models with that cell's projection-style keyword arguments, whereas the live
# represent_wrt_models takes (model_to_repr, flat_models). As written it raises before the loop runs.
# It also needs a model with symbol "e", i.e. at least five models.
lambdas = np.linspace(0, 1, 25)

interp_results = []
interp_params = []

model_a_2D = represent_wrt_models(
    flat_models["a"],
    origin_model=flat_models["a"],
    basis1=basis_model_1,
    basis2=basis_model_2,
    scale_1=scale_1,
    scale_2=scale_2,
)
model_e_2D = represent_wrt_models(
    flat_models["e"],
    origin_model=flat_models["a"],
    basis1=basis_model_1,
    basis2=basis_model_2,
    scale_1=scale_1,
    scale_2=scale_2,
)

# norms: distance between the plane-reconstructed interpolant and the exact one, per lambda.
norms = []
# NOTE(review): "2D_interp" receives the loss of the exact full-space interpolant and "N_interp" the
# loss of the plane reconstruction — the labels look swapped; confirm before reading the plot.
results = {"2D_interp": [], "N_interp": []}
for lamb in lambdas:
    # lamb = 1 gives model a, lamb = 0 gives model e
    interp_model = flat_models["a"] * lamb + flat_models["e"] * (1 - lamb)
    interp_params.append(interp_model)

    # NOTE(review): ref_model is `list(models.values())[0]` from the permutation-spec cell, i.e. it appears
    # to be the same object as the first model in `models`; loading weights here overwrites it in place.
    new_params = vector_to_state_dict(interp_model, ref_model.model)
    ref_model.model.load_state_dict(new_params)
    res = trainer.test(ref_model, test_loader, verbose=False)[0]["loss/test"]

    results["2D_interp"].append(res)
    # Same lambda applied to the 2D coordinates, then mapped back to parameter space via the basis.
    interp_point = model_a_2D * lamb + model_e_2D * (1 - lamb)
    # new_params_reconstructed = origin_model + (scale_1 * basis_model_1 * interp_point[0] + scale_2 * basis_model_2 * interp_point[1])
    new_params_reconstructed = origin_model + (
        scale_1 * basis_model_1 * interp_point[0] + scale_2 * basis_model_2 * interp_point[1]
    )

    ref_model.model.load_state_dict(vector_to_state_dict(new_params_reconstructed, ref_model.model))
    res = trainer.test(ref_model, test_loader, verbose=False)[0]["loss/test"]
    results["N_interp"].append(res)

    norms.append(torch.norm(new_params_reconstructed - interp_model).detach().cpu().numpy())

    # interp_results.append(trainer.test(ref_model, test_loader, verbose=False))

plt.figure()
plt.plot(lambdas, norms, marker="o")

In [None]:
# Turn both loss curves into arrays and overlay them against lambda.
for curve_name in ("2D_interp", "N_interp"):
    results[curve_name] = np.array(results[curve_name])

plt.figure()
plt.plot(lambdas, results["2D_interp"], marker="o")
plt.plot(lambdas, results["N_interp"], marker="x")

In [None]:
# Cosine similarity between the flattened models c and b.
# Uses torch.norm directly: the `norm` helper this cell used to call exists only in a commented-out cell.
(flat_models["c"] / torch.norm(flat_models["c"])) @ (flat_models["b"] / torch.norm(flat_models["b"]))

In [None]:
# Points on the 2D plane along the segment between models a and c.
# NOTE(review): model_c_2D is never defined in this notebook (only model_a_2D / model_e_2D are), so this
# raises NameError as written. `model_interp` is computed and discarded. It also weights a by (1 - lamb),
# while interps_on_plane weights a by lamb, so the two interpolations run in opposite directions.
pylogger.info(model_a_2D)
pylogger.info(model_c_2D)

interps_on_plane = []
for lamb in lambdas:
    interps_on_plane.append(model_a_2D * lamb + model_c_2D * (1 - lamb))

    model_interp = flat_models["a"] * (1 - lamb) + flat_models["c"] * lamb


interps_on_plane

In [None]:
# Project each interpolated parameter vector onto the pentagon plot with the same mapping used for
# the permuted models: affine coefficients w.r.t. the original models, normalised to sum to one,
# applied to the pentagon vertices `origins`. Results are collected so the loop has an effect.
interp_points_2D = []
for par in interp_params:
    coeffs = represent_wrt_models(par, flat_models)
    coeffs = coeffs / coeffs.sum(axis=0)
    interp_points_2D.append((coeffs * origins).sum(axis=0))

interp_points_2D

In [None]:
# Dot product of the two plane basis vectors (presumably ~0, since the commented-out construction
# removes basis 1's component from basis 2).
# NOTE(review): basis_model_1 / basis_model_2 are only defined in a commented-out cell; NameError as-is.
basis_model_1 @ basis_model_2

In [None]:
# Plot the per-lambda test losses collected during interpolation.
# NOTE(review): `interp_results` is initialised empty and its only append is commented out in the
# interpolation cell, so this yields an empty array and plt.plot fails on the length mismatch with
# `lambdas`. It also rebinds `test_losses`, which the contour-plot cells read.
test_losses = np.array([res[0]["loss/test"] for res in interp_results])
test_losses

plt.figure()
plt.plot(lambdas, test_losses)

In [None]:
# Print each model's 2D position after permutation to the universe
# (swap in model_2D_repr to print the original positions instead).
for symb in universe_model_2D_repr:
    model_repr = universe_model_2D_repr[symb]
    print(f"{symb}: {model_repr}")

In [None]:
# Evaluate the plain average of models a, b and c.
# NOTE(review): `mean_abc` is only defined in commented-out code in the plotting cell, so this raises
# NameError as-is. `ref_model` is `list(models.values())[0]`, so this overwrites that model's weights in
# place. It also rebinds `eval_results`, which previously held the landscape results.
new_params = vector_to_state_dict(mean_abc, ref_model.model)

ref_model.model.load_state_dict(new_params)

eval_results = trainer.test(ref_model, test_loader, verbose=False)

In [None]:
# Display the test loss from the most recent `eval_results` (the first dataloader's metrics dict).
eval_results[0]["loss/test"]