In [1]:
from pathlib import Path
import torch
import numpy as np
import matplotlib.pyplot as plt

from nnperm.perm import Permutations, PermutationSpec, perm_compose, perm_inverse
from nnperm.utils import get_open_lth_model

# Maps each model name to the name of its checkpoint directory under
# ../refactor-outputs/ckpts/.
model_name_to_dir = {
    "cifar_vgg_16_8": "lottery_3d9c91d3d4133cfcdcb2006da1507cbb",
    "cifar_vgg_16_16": "lottery_2915b34d8b29a209ffee2288466cf5f6",
    "cifar_vgg_16_32": "lottery_c855d7c25ffef997a89799dc08931e82",
    "cifar_vgg_16_128": "lottery_8d561a7b273e4d6b2705ba6d627a69bd",
    "cifar_vgg_16_256": "lottery_a309ac4ab15380928661e70ca8b054a1",
}
# target_sizes = get_model_perm_size("lottery_2915b34d8b29a209ffee2288466cf5f6")
# Load the "cifar_vgg_16_8" checkpoint and list its state-dict keys. The
# directory comes from the table above so the hash is written in one place.
model = get_open_lth_model(
    Path("../refactor-outputs/ckpts") / model_name_to_dir["cifar_vgg_16_8"] / "dummy_ckpt.pth",
    "cpu",
)
print(list(model.state_dict().keys()))

def get_model_perm_size(model_name):
    """Return the permutation sizes of the checkpoint stored under *model_name*.

    *model_name* is interpolated as a directory name into
    ``../refactor-outputs/ckpts/{model_name}/dummy_ckpt.pth``. The checkpoint
    is loaded with ``get_open_lth_model`` (second argument ``"cpu"``,
    presumably the device), a ``PermutationSpec`` is built from the loaded
    model's state dict, and the result of its ``get_sizes`` on that state dict
    is returned.
    """
    ckpt_path = Path(f"../refactor-outputs/ckpts/{model_name}/dummy_ckpt.pth")
    net = get_open_lth_model(ckpt_path, "cpu")
    spec = PermutationSpec.from_sequential_model(net.state_dict())
    sizes = spec.get_sizes(net.state_dict())
    return sizes

def classify_perm_fractions(perms_a, perms_b, align_sizes):
    """Classify, per key, how the leading entries of two permutations relate.

    For every key ``k`` of ``perms_a`` only the first ``align_sizes[k]``
    entries of ``perms_a[k]`` and ``perms_b[k]`` are compared. Each compared
    position falls into exactly one category:

    * identical  -- both permutations hold the same index at that position;
    * repermuted -- the indices differ, but the index at that position also
      occurs somewhere in the other permutation's compared prefix;
    * unaligned  -- the index at that position does not occur in the other
      permutation's compared prefix.

    Args:
        perms_a: mapping from key to an index array (must support ``keys()``).
        perms_b: mapping with the same keys as ``perms_a``.
        align_sizes: mapping from key to the number of leading entries to
            compare.

    Yields:
        Tuples ``(k, n_align, identical, repermuted_a, repermuted_b,
        unaligned_a, unaligned_b)`` where the last five items are boolean
        masks over the compared positions; the ``_a`` / ``_b`` variants are
        taken from the point of view of ``perms_a`` / ``perms_b``.

    Raises:
        AssertionError: if the category counts of the two sides disagree or a
            position does not fall into exactly one category.
    """
    for k in perms_a.keys():
        n_align = align_sizes[k]
        perm_a = perms_a[k][:n_align]
        perm_b = perms_b[k][:n_align]
        # Membership masks over the index range: True where the index occurs
        # in the compared prefix of the respective permutation.
        max_size = max(len(perms_a[k]), len(perms_b[k]))
        in_a = np.zeros(max_size, dtype=bool)
        in_a[perm_a] = True
        in_b = np.zeros(max_size, dtype=bool)
        in_b[perm_b] = True

        identical = (perm_a == perm_b)
        differs = (perm_a != perm_b)

        # Indices present in both prefixes, but sitting at different positions.
        shared = np.logical_and(in_a, in_b)
        repermuted_a = np.logical_and(shared[perm_a], differs)
        repermuted_b = np.logical_and(shared[perm_b], differs)
        assert np.sum(repermuted_a) == np.sum(repermuted_b)

        # Indices present in only one of the two prefixes.
        unaligned = np.logical_xor(in_a, in_b)
        unaligned_a = unaligned[perm_a]
        unaligned_b = unaligned[perm_b]
        assert np.sum(unaligned_a) == np.sum(unaligned_b)

        # Count categories per position as integers: adding boolean arrays
        # with `+` is a logical OR and would only check "at least one".
        assert np.all(np.sum([identical, repermuted_a, unaligned_a], axis=0) == 1)
        assert np.all(np.sum([identical, repermuted_b, unaligned_b], axis=0) == 1)
        yield k, n_align, identical, repermuted_a, repermuted_b, unaligned_a, unaligned_b

# For every "*linear.pt" stats file in stats_dir, print per-key fractions of
# identical / repermuted / unaligned entries, first between the stored perm_a
# and perm_b, then between the directly computed permutation and
# perm_a.inverse().compose(perm_b).
stats_dir = Path("../refactor-outputs/fix-embed-lottery_c855d7c25ffef997a89799dc08931e82/")
for file in stats_dir.glob("*linear.pt"):
    name = file.stem
    # Reset per file: otherwise an unmatched file would reuse the previous
    # file's key (or raise NameError on the first iteration).
    key = None
    for k, v in model_name_to_dir.items():
        if v in file.stem:
            name = k
            key = v
    if key is None:
        print(f"skipping {file.stem}: no known model directory in file name")
        continue
    direct_perm_stats_dict = torch.load(f"../refactor-outputs/kernel-test/scratch/open_lth_data/{key}_1_2_pretrain_ep160_linear.pt")
    print(name, file.stem)
    sizes = get_model_perm_size(key)
    stats_dict = torch.load(file)
    perm_a = Permutations(stats_dict["perm_a"])
    perm_b = Permutations(stats_dict["perm_b"])
    direct_perm = Permutations(direct_perm_stats_dict['perm'])
    target_sizes = stats_dict["target_sizes"]
    # Stored perm_a versus stored perm_b.
    for k, n, identical, r_a, r_b, u_a, u_b in classify_perm_fractions(perm_a, perm_b, sizes):
        print(k, n, np.sum(identical) / n, np.sum(r_a) / n, np.sum(u_a) / n)
    # Direct permutation versus the composition of perm_a's inverse with perm_b.
    for k, n, identical, r_a, r_b, u_a, u_b in classify_perm_fractions(direct_perm, perm_a.inverse().compose(perm_b), sizes):
        print(k, n, np.sum(identical) / n, np.sum(r_a) / n, np.sum(u_a) / n)


['layers.0.conv.weight', 'layers.0.conv.bias', 'layers.0.bn.layernorm.weight', 'layers.0.bn.layernorm.bias', 'layers.1.conv.weight', 'layers.1.conv.bias', 'layers.1.bn.layernorm.weight', 'layers.1.bn.layernorm.bias', 'layers.3.conv.weight', 'layers.3.conv.bias', 'layers.3.bn.layernorm.weight', 'layers.3.bn.layernorm.bias', 'layers.4.conv.weight', 'layers.4.conv.bias', 'layers.4.bn.layernorm.weight', 'layers.4.bn.layernorm.bias', 'layers.6.conv.weight', 'layers.6.conv.bias', 'layers.6.bn.layernorm.weight', 'layers.6.bn.layernorm.bias', 'layers.7.conv.weight', 'layers.7.conv.bias', 'layers.7.bn.layernorm.weight', 'layers.7.bn.layernorm.bias', 'layers.8.conv.weight', 'layers.8.conv.bias', 'layers.8.bn.layernorm.weight', 'layers.8.bn.layernorm.bias', 'layers.10.conv.weight', 'layers.10.conv.bias', 'layers.10.bn.layernorm.weight', 'layers.10.bn.layernorm.bias', 'layers.11.conv.weight', 'layers.11.conv.bias', 'layers.11.bn.layernorm.weight', 'layers.11.bn.layernorm.bias', 'layers.12.conv.wei