In [68]:
from typing import List, Tuple
import torch
import torch.nn as nn
import numpy as np
from scipy.optimize import linear_sum_assignment

import sys
sys.path.append("./")
from rebasin.weight_matching import PermutationSpec, PartialPermutationSpec


def append_zeros(dim, size, param):
    """Zero-pad ``param`` at the end of axis ``dim`` so that axis has length ``size``.

    All other axes are left untouched. Raises AssertionError when the axis is
    already longer than ``size``.
    """
    missing = size - param.shape[dim]
    assert missing >= 0
    widths = [(0, 0) for _ in param.shape]
    widths[dim] = (0, missing)
    return np.pad(param, widths)

def mse_loss(a, b):
    """Pairwise squared Euclidean distances between the rows of ``a`` and ``b``.

    Despite the name this is a *sum* (not a mean) of squared differences:
    ``out[i, j] = sum_d (a[i, d] - b[j, d]) ** 2``, with shape
    ``(a.shape[0], b.shape[0])``. An ``(a.shape[0], d, b.shape[0])`` difference
    array is materialized, so memory grows with all three sizes.
    """
    n_rows_a, width = a.shape[0], a.shape[1]
    lhs = a.reshape(n_rows_a, width, 1)
    rhs = b.T.reshape(1, b.shape[1], b.shape[0])
    diff = lhs - rhs
    return (diff * diff).sum(axis=1)

def crop_cost_matrix(cost, n_a, n_b):
    """Return a copy of ``cost`` in which only the top-left ``n_a x n_b`` block survives.

    Every entry outside ``cost[:n_a, :n_b]`` is set to 0; ``cost`` itself is not
    modified.
    """
    kept = cost.copy()
    kept[n_a:, :] = 0
    kept[:, n_b:] = 0
    return kept

def partial_alignment(perm_to_axes: List[Tuple[str, int]], params_a, params_b, k, alpha=0, eps=1):
    """Match the units of two parameter sets when only some of them correspond.

    Every ``(layer_name, dim)`` pair in ``perm_to_axes`` names an axis acted on by
    the same permutation. All of these axes must have one common length ``n_a`` in
    ``params_a`` and one common length ``n_b`` in ``params_b``. Each axis is
    zero-padded to ``k`` slots, the remaining axes are flattened, and the
    pairwise costs from ``mse_loss`` are summed into a ``k x k`` matrix.

    The matrix handed to the assignment solver keeps only the real
    ``n_a x n_b`` block of that cost. Real A rows against slot columns
    ``max(n_a, n_b)..k-1``, and slot rows ``max(n_a, n_b)..k-1`` against real B
    columns, get a linear ramp from ``min(cost) - eps`` to
    ``alpha * max(cost) + eps / 2``. All other entries are 0.

    Args:
        perm_to_axes: non-empty list of ``(layer_name, dim)`` pairs sharing one
            permutation.
        params_a, params_b: mappings from layer name to numpy array.
        k: number of slots; must be at least ``max(n_a, n_b)``. At least
            ``n_a + n_b - k`` real-to-real matches are forced.
        alpha: scales the upper end of the ramp relative to ``max(cost)``.
        eps: offset applied to both ramp endpoints.

    Returns:
        ``(s_col, loss)``. ``s_col[i]`` is the slot of B assigned to slot ``i`` of
        A. ``loss`` is the summed real-to-real cost of that assignment; ramp
        entries are excluded.

    Raises:
        ValueError: if ``perm_to_axes`` is empty.
        AssertionError: if the permuted axes disagree in length, or if ``k`` is
            smaller than an axis (raised by ``append_zeros``).
    """
    if not perm_to_axes:
        raise ValueError("perm_to_axes must list at least one (layer_name, dim) pair")

    cost = np.zeros([k, k])
    n_a, n_b = [], []
    for layer_name, dim in perm_to_axes:
        n_a.append(params_a[layer_name].shape[dim])
        n_b.append(params_b[layer_name].shape[dim])
        assert n_a[0] == n_a[-1] and n_b[0] == n_b[-1], (
            f"axis {dim} of {layer_name!r} has a different length than the "
            f"first permuted axis")
        pad_a = append_zeros(dim, k, params_a[layer_name])
        pad_b = append_zeros(dim, k, params_b[layer_name])
        pad_a = np.moveaxis(pad_a, dim, 0).reshape(k, -1)
        pad_b = np.moveaxis(pad_b, dim, 0).reshape(k, -1)
        cost += mse_loss(pad_a, pad_b)
    n_a, n_b = n_a[0], n_b[0]
    n_max = max(n_a, n_b)

    # Keep the real-to-real block separately: the solver sees it plus the ramp,
    # but the reported loss is measured on the real block alone.
    real_cost = crop_cost_matrix(cost, n_a, n_b)
    solver_cost = real_cost.copy()
    cost_ramp = np.linspace(np.min(cost) - eps,
                            alpha * np.max(cost) + eps / 2, k - n_max)
    solver_cost[:n_a, n_max:] = cost_ramp.reshape(1, -1)
    solver_cost[n_max:, :n_b] = cost_ramp.reshape(-1, 1)

    s_row, s_col = linear_sum_assignment(solver_cost, maximize=False)
    # For a square matrix the solver returns rows in order, so s_col is indexed by A slot.
    assert np.all(s_row == np.arange(k))
    loss = np.sum(real_cost[s_row, s_col])
    return s_col, loss


# Synthetic check of partial_alignment.
# A and B share `n` units: B's row i is A's shared row perm[i] plus small noise.
# Each side also has `a` / `b` units of its own.
# For each alpha, k is swept from rows(A) + rows(B) down to max(rows(A), rows(B)),
# so the forced number of real-to-real matches (n_aligned) grows from 0 to
# n + min(a, b). Each line reports whether the shared units were matched per `perm`.
SEED = 0
np.random.seed(SEED)  # make the permutation, data and printed table reproducible

outputs = 20  # length of each unit's parameter vector
n = 10        # units shared by A and B
a = 5         # extra units only in A
b = 5         # extra units only in B
perm = np.random.permutation(n)
perm_matrix = np.zeros([n, n])
perm_matrix[np.arange(n), perm] = 1
overlapping_params = np.random.randn(n, outputs)
error = np.random.randn(n, outputs) * 0.1
A = {
    "layer0": np.concatenate([overlapping_params, np.random.randn(a, outputs)], axis=0),
}
B = {
    "layer0": np.concatenate([perm_matrix @ overlapping_params + error, np.random.randn(b, outputs)], axis=0),
}
perm_to_axes = [("layer0", 0)]
for alpha in np.linspace(0., 0.05, 11):
    for k in reversed(range(max(n + a, n + b), 2*n + a + b + 1)):
        s, cost = partial_alignment(perm_to_axes, A, B, k, alpha=alpha)
        n_aligned = 2*n + a + b - k
        # s[perm[i]] == i  <=>  A's shared row perm[i] was matched to B's row i.
        print(alpha, n_aligned, int(cost), np.all(s[perm] == np.arange(n)), s[perm], s)

0.0 0 0 False [26 17 22 19 27 16 15 21 24 29] [15 16 17 29 21 24 27 22 19 26 20 28 23 25 18  0  1  2  3  4  5  6  7  8
  9 10 11 12 13 14]
0.0 1 0 False [20  1 23 24 22 16 15 19 25 28] [15 16  1 28 19 25 22 23 24 20 21 17 18 26 27  0  2  3  4  5  6  7  8  9
 10 11 12 13 14]
0.0 2 0 False [24  1 25 21  4 19 15 27 22 18] [15 19  1 18 27 22  4 25 21 24 26 17 16 23 20  0  2  3  5  6  7  8  9 10
 11 12 13 14]
0.0 3 0 False [19  1 18 24  4 16  6 21 20 25] [ 6 16  1 25 21 20  4 18 24 19 17 22 26 23 15  0  2  3  5  7  8  9 10 11
 12 13 14]
0.0 4 0 False [18  1 22 20  4 16  6 24 17  9] [ 6 16  1  9 24 17  4 22 20 18 23 21 19 15 25  0  2  3  5  7  8 10 11 12
 13 14]
0.0 5 0 False [21  1 20 19  4 16  6 18  8  9] [ 6 16  1  9 18  8  4 20 19 21 24 22 15 17 23  0  2  3  5  7 10 11 12 13
 14]
0.0 6 0 False [ 0  1 22 23  4 16  6 19  8  9] [ 6 16  1  9 19  8  4 22 23  0 21 15 18 20 17  2  3  5  7 10 11 12 13 14]
0.0 7 1 False [ 0  1 17 21  4  5  6 19  8  9] [ 6  5  1  9 19  8  4 17 21  0 15 18 22 20 16

In [None]:
def pad(state_dict, perm_spec: PartialPermutationSpec):
    """Zero-pad every permuted axis of ``state_dict`` to its permutation's target size.

    For each ``perm_name`` in ``perm_spec.perm_to_axes``, each listed
    ``(layer_name, dim)`` axis is padded at the end up to
    ``perm_spec.perm_to_target_size[perm_name]``. A layer listed under several
    permutations, e.g. a weight matrix with both an input and an output
    permutation, is padded along every listed axis.

    Only layers that appear in ``perm_spec.perm_to_axes`` are present in the
    returned dict. ``state_dict`` itself is not modified.
    """
    padded_state_dict = {}
    for perm_name, layers in perm_spec.perm_to_axes.items():
        size = perm_spec.perm_to_target_size[perm_name]
        for layer_name, dim in layers:
            # Continue from the array another permutation already padded.
            # Re-reading state_dict here would discard the earlier padding.
            current = padded_state_dict.get(layer_name, state_dict[layer_name])
            padded_state_dict[layer_name] = append_zeros(dim, size, current)
    return padded_state_dict

def alignment_cost(perm_layers, padded_a, padded_b, n, loss_fn=mse_loss):
    """Accumulate the ``n x n`` matching cost of one permutation over its layers.

    For every ``(layer_name, dim)`` in ``perm_layers``, axis ``dim`` of the
    already padded arrays is moved to the front and the remaining axes are
    flattened, giving ``n`` rows per model. ``loss_fn`` maps the two row sets to
    an ``n x n`` matrix, and the per-layer matrices are summed. The sum starts
    from a float64 zero matrix.
    """
    def _rows(arr, axis):
        # Reshape so that `axis` becomes the leading (row) axis of an (n, -1) array.
        return np.moveaxis(arr, axis, 0).reshape(n, -1)

    total = np.zeros([n, n])
    for name, axis in perm_layers:
        total += loss_fn(_rows(padded_a[name], axis), _rows(padded_b[name], axis))
    return total