In [1]:
from torchvision.models import resnet18
from pprint import pprint
import torch.nn as nn

class Conv2dLayerNorm(nn.LayerNorm):        
    def forward(self, x):
        transposed = x.transpose(1,3)
        result = super().forward(transposed)
        
        return result.transpose(3,1)

model = resnet18(num_classes=10, norm_layer=Conv2dLayerNorm) 
pprint({n: p.shape for n, p in model.named_parameters()})

  from .autonotebook import tqdm as notebook_tqdm


{'bn1.bias': torch.Size([64]),
 'bn1.weight': torch.Size([64]),
 'conv1.weight': torch.Size([64, 3, 7, 7]),
 'fc.bias': torch.Size([10]),
 'fc.weight': torch.Size([10, 512]),
 'layer1.0.bn1.bias': torch.Size([64]),
 'layer1.0.bn1.weight': torch.Size([64]),
 'layer1.0.bn2.bias': torch.Size([64]),
 'layer1.0.bn2.weight': torch.Size([64]),
 'layer1.0.conv1.weight': torch.Size([64, 64, 3, 3]),
 'layer1.0.conv2.weight': torch.Size([64, 64, 3, 3]),
 'layer1.1.bn1.bias': torch.Size([64]),
 'layer1.1.bn1.weight': torch.Size([64]),
 'layer1.1.bn2.bias': torch.Size([64]),
 'layer1.1.bn2.weight': torch.Size([64]),
 'layer1.1.conv1.weight': torch.Size([64, 64, 3, 3]),
 'layer1.1.conv2.weight': torch.Size([64, 64, 3, 3]),
 'layer2.0.bn1.bias': torch.Size([128]),
 'layer2.0.bn1.weight': torch.Size([128]),
 'layer2.0.bn2.bias': torch.Size([128]),
 'layer2.0.bn2.weight': torch.Size([128]),
 'layer2.0.conv1.weight': torch.Size([128, 64, 3, 3]),
 'layer2.0.conv2.weight': torch.Size([128, 128, 3, 3]),
 '

In [2]:
from typing import NamedTuple
from collections import defaultdict
from numpy import random

class PermutationSpec(NamedTuple):
    perm_to_axes: dict
    axes_to_perm: dict

rngmix = lambda rng, x: random.default_rng([rng._bit_generator._seed_seq.entropy, hash(x)])

def permutation_spec_from_axes_to_perm(axes_to_perm: dict):
    perm_to_axes = defaultdict(list)
    for wk, axis_perms in axes_to_perm.items():
        for axis, perm in enumerate(axis_perms):
            if perm is not None:
                perm_to_axes[perm].append((wk, axis))
    return PermutationSpec(perm_to_axes=dict(perm_to_axes), axes_to_perm=axes_to_perm)

def resnet18_permutation_spec():
    conv = lambda name, p_in, p_out: {f"{name}.weight": (p_out, p_in, None, None)}
    norm = lambda name, p: {f"{name}.weight": (p, ), f"{name}.bias": (p, )}
    dense = lambda name, p_in, p_out: {f"{name}.weight": (p_out, p_in), f"{name}.bias": (p_out, )}

    # This is for easy blocks that use a residual connection, without any change in the number of channels.
    easyblock = lambda name, p: {
        **conv(f"{name}.conv1", p, f"P_{name}_inner"),
        **norm(f"{name}.bn1", f"P_{name}_inner"),
        **conv(f"{name}.conv2", f"P_{name}_inner", p),
        **norm(f"{name}.bn2", p)
    }

    # This is for blocks that use a residual connection, but change the number of channels via a Conv.
    shortcutblock = lambda name, p_in, p_out: {
        **conv(f"{name}.conv1", p_in, f"P_{name}_inner"),
        **norm(f"{name}.bn1", f"P_{name}_inner"),
        **conv(f"{name}.conv2", f"P_{name}_inner", p_out),
        **norm(f"{name}.bn2", p_out),
        **conv(f"{name}.downsample.0", p_in, p_out),
        **norm(f"{name}.downsample.1", p_out),
    }

    ps = permutation_spec_from_axes_to_perm({
        **conv("conv1", None, "P_bg0"),
        **norm("bn1", "P_bg0"),
        #
        **easyblock("layer1.0", "P_bg0"),
        **easyblock("layer1.1", "P_bg0"),
        #
        **shortcutblock("layer2.0", "P_bg0", "P_bg1"),
        **easyblock("layer2.1", "P_bg1"),
        
        **shortcutblock("layer3.0", "P_bg1", "P_bg2"),
        **easyblock("layer3.1", "P_bg2"),
        
        **shortcutblock("layer4.0", "P_bg2", "P_bg3"),
        **easyblock("layer4.1", "P_bg3"),
        #
        **dense("fc", "P_bg3", None),
    })

    return ps

In [3]:
import numpy as np
from scipy.optimize import linear_sum_assignment

def get_permuted_param(ps, perm, k: str, params, except_axis=None):
    """Get parameter `k` from `params`, with the permutations applied."""
    w = params[k]
    for axis, p in enumerate(ps.axes_to_perm[k]):
        # Skip the axis we're trying to permute.
        if axis == except_axis:
            continue

        # None indicates that there is no permutation relevant to that axis.
        if p is not None:
            try:
                w = np.take(w, perm[p], axis=axis)
            except:
                print(k, w.shape, perm[p], axis)
                pprint(perm)
                raise

    return w

def apply_permutation(ps, perm, params):
    """Apply a `perm` to `params`."""
    return {k: get_permuted_param(ps, perm, k, params) for k in params.keys()}

def weight_matching(rng,
                    ps,
                    params_a,
                    params_b,
                    max_iter=100,
                    init_perm=None,
                    silent=False):
    """Find a permutation of `params_b` to make them match `params_a`."""
    perm_sizes = {p: params_a[axes[0][0]].shape[axes[0][1]] for p, axes in ps.perm_to_axes.items()}

    perm = {p: np.arange(n) for p, n in perm_sizes.items()} if init_perm is None else init_perm
    perm_names = list(perm.keys())

    for iteration in range(max_iter):
        progress = False
        for p_ix in rngmix(rng, iteration).permutation(len(perm_names)):
            p = perm_names[p_ix]
            n = perm_sizes[p]
            A = np.zeros((n, n))
            for wk, axis in ps.perm_to_axes[p]:
                w_a = params_a[wk]
                w_b = get_permuted_param(ps, perm, wk, params_b, except_axis=axis)
                try:
                    w_a = np.moveaxis(w_a, axis, 0).reshape((n, -1))
                except:
                    print(wk, w_a.shape)
                    raise
                w_b = np.moveaxis(w_b, axis, 0).reshape((n, -1))
                A += w_a @ w_b.T

            ri, ci = linear_sum_assignment(A, maximize=True)
            assert (ri == np.arange(len(ri))).all()

            oldL = np.vdot(A, np.eye(n)[perm[p]])
            newL = np.vdot(A, np.eye(n)[ci, :])
            if not silent: print(f"{iteration}/{p}: {newL - oldL}")
            progress = progress or newL > oldL + 1e-12

            perm[p] = np.array(ci)

        if not progress:
            break

    return perm

In [4]:
import pickle

def load_model(path):
    with open(path, 'rb') as f:
        return pickle.load(f)
model_a = load_model('model_1.pkl')
model_b = load_model('model_2.pkl')
print(model_b)

{'conv1.weight': array([[[[-0.00483706, -0.0118632 , -0.0071967 , ..., -0.01501707,
          -0.00970301,  0.00161217],
         [-0.03555159,  0.03898596, -0.03391916, ..., -0.01041285,
          -0.01973364, -0.02127761],
         [ 0.01037329, -0.04783124, -0.02337327, ..., -0.00029991,
           0.02189129,  0.02342977],
         ...,
         [ 0.05153418,  0.02454337, -0.0116587 , ..., -0.02387299,
           0.03056212,  0.00909587],
         [ 0.00141551,  0.03295042, -0.03573171, ...,  0.02326334,
          -0.03620902,  0.01090094],
         [ 0.02257729, -0.01892342,  0.01763644, ..., -0.04283727,
          -0.02614045, -0.02200495]],

        [[-0.02103114,  0.00119688, -0.00077752, ..., -0.00334863,
          -0.01624731,  0.0183492 ],
         [ 0.01379603,  0.05084515, -0.00350681, ...,  0.02854057,
          -0.00974901, -0.0043844 ],
         [ 0.03084417, -0.00228273, -0.03781604, ...,  0.01971008,
           0.00480129,  0.02760803],
         ...,
         [-0.0046

In [5]:
rng = np.random.default_rng(123)
permutation_spec = resnet18_permutation_spec()
permutation = weight_matching(rng, permutation_spec, model_a, model_b)
model_b_permuted = apply_permutation(permutation_spec, permutation, model_b)

0/P_layer1.1_inner: 19.430933961426923
0/P_layer3.1_inner: 40.84967173825143
0/P_bg0: 18.235891337124144
0/P_bg3: 120.1953341290664
0/P_bg2: 6.766230011968446
0/P_bg1: 30.258572096626494
0/P_layer4.1_inner: 30.691959497736548
0/P_layer1.0_inner: 9.477370107272705
0/P_layer2.1_inner: 15.11738615579037
0/P_layer3.0_inner: 28.712707135561345
0/P_layer2.0_inner: 16.066960370983026
0/P_layer4.0_inner: 36.35536920213042
1/P_layer4.1_inner: 0.0
1/P_bg1: 0.0
1/P_bg3: 0.9599979232493752
1/P_layer3.1_inner: 8.491260409383642
1/P_layer1.0_inner: 0.0
1/P_bg0: 0.15782936009546233
1/P_layer3.0_inner: 0.0
1/P_layer4.0_inner: 1.988126702681484
1/P_layer1.1_inner: 7.270217111834796
1/P_layer2.1_inner: 0.0
1/P_layer2.0_inner: 0.10694106054754116
1/P_bg2: 0.08204371248336884
2/P_layer1.0_inner: 0.36516844419820416
2/P_bg2: 0.0
2/P_layer4.0_inner: 0.34898747184092827
2/P_layer3.1_inner: 0.6755240731116601
2/P_layer3.0_inner: 0.47260172372273246
2/P_layer4.1_inner: 3.1370289247929577
2/P_bg0: 0.0
2/P_layer

In [6]:
{k: v.shape for k,v in model_b_permuted.items()}

{'conv1.weight': (64, 3, 7, 7),
 'bn1.weight': (64,),
 'bn1.bias': (64,),
 'layer1.0.conv1.weight': (64, 64, 3, 3),
 'layer1.0.bn1.weight': (64,),
 'layer1.0.bn1.bias': (64,),
 'layer1.0.conv2.weight': (64, 64, 3, 3),
 'layer1.0.bn2.weight': (64,),
 'layer1.0.bn2.bias': (64,),
 'layer1.1.conv1.weight': (64, 64, 3, 3),
 'layer1.1.bn1.weight': (64,),
 'layer1.1.bn1.bias': (64,),
 'layer1.1.conv2.weight': (64, 64, 3, 3),
 'layer1.1.bn2.weight': (64,),
 'layer1.1.bn2.bias': (64,),
 'layer2.0.conv1.weight': (128, 64, 3, 3),
 'layer2.0.bn1.weight': (128,),
 'layer2.0.bn1.bias': (128,),
 'layer2.0.conv2.weight': (128, 128, 3, 3),
 'layer2.0.bn2.weight': (128,),
 'layer2.0.bn2.bias': (128,),
 'layer2.0.downsample.0.weight': (128, 64, 1, 1),
 'layer2.0.downsample.1.weight': (128,),
 'layer2.0.downsample.1.bias': (128,),
 'layer2.1.conv1.weight': (128, 128, 3, 3),
 'layer2.1.bn1.weight': (128,),
 'layer2.1.bn1.bias': (128,),
 'layer2.1.conv2.weight': (128, 128, 3, 3),
 'layer2.1.bn2.weight': (12

In [7]:
import torch
model_c = dict()
for n in model_a:
    model_c[n] = model_a[n] / 2 + model_b_permuted[n] / 2

In [8]:
model_c_torch = resnet18(num_classes=10,norm_layer=Conv2dLayerNorm)
model_c_torch.load_state_dict({k: torch.from_numpy(v) for k, v in model_c.items()})

<All keys matched successfully>

In [9]:
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader

valid_dataset = CIFAR10('.', train=False, transform=ToTensor())
valid_loader = DataLoader(valid_dataset, batch_size=512)

model_c_torch.eval()
total_samples = 0
val_score = 0
for data, target in valid_loader:
    outputs = model_c_torch(data)
    pred = outputs.argmax(dim=1)
    val_score += pred.eq(target).sum().cpu().numpy()
    total_samples += len(target)
print(val_score / total_samples)

0.0899
