## Load the models and compute the posteriors

In [None]:
import torch
from torch import nn

from permute import permute
from scale import scale
from utils import (
    to_vector,
    to_state_dict,
    col_wise_multiplication,
    row_wise_multiplication,
)

from tqdm.auto import tqdm

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

In [None]:
# Load the trained checkpoints: a sequence of state dicts, one per model.
# You may replace the path with your own file if you trained the models.
# map_location="cpu" keeps every tensor on the CPU regardless of the device
# they were saved from; later cells call `.numpy()`, which needs CPU tensors.
models = torch.load("models.pt", map_location="cpu")
# models = torch.load("your_models.pt", map_location="cpu")

# Second element returned by `to_vector`; passed to `to_state_dict` later to
# rebuild a state dict from a flat parameter vector.
model_shape = to_vector(models[0])[1]
# Weights of the final nn.Linear(2, 1) layer ("2.weight", shape (1, 2)) of
# every model, squeezed and stacked into shape (n_models, 2).
model_weights = torch.stack([m["2.weight"].squeeze() for m in models])

In [None]:
# Remove the scaling symmetry from every checkpoint. Checkpoints for which
# `scale` returns None are skipped; the next cell reports how many.
scaled_chkpts = []
for state in tqdm(models):
    # Strip any "model." prefix so the keys match a bare nn.Sequential.
    renamed = {name.replace("model.", ""): tensor for name, tensor in state.items()}

    net = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 1)).float()
    net.load_state_dict(renamed)

    result = scale(net)
    if result is None:
        # scaling failed for this checkpoint
        continue
    # Store the scaled parameters as one flat vector.
    scaled_chkpts.append(torch.concat([w.reshape(-1) for w in result.values()]))

In [None]:
n_scaled = len(scaled_chkpts)
n_failed = len(models) - n_scaled  # checkpoints dropped because scale() returned None
print(f"Number of scaled checkpoints: {n_scaled} - Failed to scale: {n_failed}")

In [None]:
# Remove the permutation symmetry from every scaled checkpoint.
permuted_chkpts = []
for flat in tqdm(scaled_chkpts):
    # Rebuild a state dict from the flat vector and strip any "model." prefix.
    state = {
        name.replace("model.", ""): tensor
        for name, tensor in to_state_dict(flat, model_shape).items()
    }
    net = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 1)).float()
    net.load_state_dict(state)

    # Instead of permute's default ordering, order the hidden units by the
    # ascending value of their last-layer weight ("2.weight") to improve the
    # visuals of the final plot.
    order = state["2.weight"].sort().indices[0, :]
    permuted = permute(net, custom_pis=[order, None])[0]
    permuted_chkpts.append(torch.concat([w.reshape(-1) for w in permuted.values()]))

In [None]:
# Global figure styling shared by all panels below.
plt.rc("axes", axisbelow=True)  # draw grid lines beneath the plotted artists
# NOTE(review): with text.usetex enabled, LaTeX typesets the text, so the
# "Cambria Math" serif entry may have no visible effect — confirm intent.
plt.rc("font", family="serif", serif=["Cambria Math"])
plt.rc("text", usetex=True)
# Sequential colormap built from the accent colour #c96004.
cmap = sns.light_palette("#c96004", as_cmap=True)

In [None]:
f, (ax1, ax2, ax3) = plt.subplots(1, 3, sharey=True, figsize=(24, 8))


def _plot_posterior(ax, lin0, lin1, title, diagonal=False):
    """Draw the joint density of two weight coordinates on one panel.

    Only samples with -4 < lin0 < 4 and -4 < lin1 < 4 are kept. They are shown
    as a hexbin density (cells with fewer than 3 samples hidden, colour
    clipped at 30 samples), overlaid with the first 1000 kept samples as
    faint red dots.

    Args:
        ax: matplotlib Axes to draw on.
        lin0: 1-D array of first-coordinate samples.
        lin1: 1-D array of second-coordinate samples, same length as lin0.
        title: panel title, placed below the axes.
        diagonal: if True, also draw a dashed y = x reference line.
    """
    # Boolean mask; keeps the same samples, in the same order, as intersecting
    # the index sets of each bound.
    keep = (lin0 < 4) & (lin0 > -4) & (lin1 < 4) & (lin1 > -4)
    x, y = lin0[keep], lin1[keep]

    ax.set_aspect("equal")
    ax.hexbin(
        x=x,
        y=y,
        cmap=cmap,
        mincnt=3,
        gridsize=85,
        alpha=1,
        norm=matplotlib.colors.Normalize(vmin=0, vmax=30, clip=True),
        antialiased=True,
    )
    sns.scatterplot(
        x=x[:1000],
        y=y[:1000],
        s=2,
        alpha=0.1,
        color="red",
        ax=ax,
        antialiased=True,
    )
    if diagonal:
        ax.plot([-4, 4], [-4, 4], color="#c96004", linestyle="--", alpha=0.5)

    ax.grid(True, linestyle="--", alpha=0.7)
    ax.set_xlim(-4.01, 4.01)
    ax.set_ylim(-4.01, 4.01)
    ax.set_xticks([-4, -2, 0, 2, 4])
    ax.set_yticks([-4, -2, 0, 2, 4])
    ax.tick_params(axis="both", labelsize=24)
    ax.set_title(title, fontsize=38, y=-0.12)


# Detach before `.numpy()` so the conversion also works when the tensors
# track gradients.
original = model_weights.detach().numpy()
scaled = torch.stack(scaled_chkpts).detach().numpy()
permuted = torch.stack(permuted_chkpts).detach().numpy()

# For the flat checkpoint vectors, columns -3 and -2 are read as the two
# last-layer weights. This assumes the flattened layout ends with "2.weight"
# (2 values) followed by "2.bias" (1 value) — TODO confirm.
_plot_posterior(ax1, original[:, 0], original[:, 1], "Original Posterior")
_plot_posterior(ax2, scaled[:, -3], scaled[:, -2], "After Scaling Removal")
# Raw string: "\&" is an escaped ampersand for LaTeX (usetex), not a Python escape.
_plot_posterior(
    ax3,
    permuted[:, -3],
    permuted[:, -2],
    r"After Scaling \& Permutation Removal",
    diagonal=True,
)

# tight_layout recomputes subplot spacing, so no manual subplots_adjust is needed.
plt.tight_layout()
sns.despine(f, top=True, right=True, left=True, bottom=True)