# Convert PyTorch to Flax

In [1]:
import torch
import jax.numpy as jnp
from jax import random
from flax.serialization import from_bytes
import sys
sys.path.append('../../')
from resnet20_jax import BLOCKS_PER_GROUP, ResNet
from flax.traverse_util import flatten_dict
import numpy as np
from utils import flatten_params, unflatten_params
from weight_matching import resnet20_permutation_spec
from resnet20_torch import resnet20
import os
from jax import random as jax_random
import re

import argparse
import pickle
from pathlib import Path

import jax.numpy as jnp
import matplotlib.pyplot as plt
import wandb
from flax.serialization import from_bytes
from jax import random
from tqdm import tqdm
from scipy.optimize import linear_sum_assignment
from resnet20 import BLOCKS_PER_GROUP, ResNet
from utils import (ec2_get_instance_type, flatten_params, lerp, timeblock, unflatten_params)
from weight_matching import (apply_permutation, resnet20_permutation_spec, weight_matching)

import os
import pickle
import flax
from collections import defaultdict
from typing import NamedTuple

import jax.numpy as jnp
from jax import random
import pdb

from utils import rngmix

import clip
import torch
import torch.nn as nn
import numpy as np
from torch.cuda.amp import autocast


  from .autonotebook import tqdm as notebook_tqdm


In [27]:
def find_pairs(str_splits):
    """Return every pair of split names whose integer ids do not overlap.

    A split name is a string of integers joined by underscores (e.g. ``"0_1"``).
    Names that contain no underscore are ignored. Pairs are reported in input
    order, with the earlier name first in each tuple.
    """
    # Parse each candidate name once and keep it next to its id set.
    candidates = [
        (name, {int(token) for token in name.split('_')})
        for name in str_splits
        if '_' in name
    ]
    disjoint = []
    for position, (left_name, left_ids) in enumerate(candidates):
        for right_name, right_ids in candidates[position + 1:]:
            if left_ids.isdisjoint(right_ids):
                disjoint.append((left_name, right_name))
    return disjoint


def split_str_to_ints(split):
    """Parse an underscore-separated split name such as ``"0_1_2"`` into a list of ints."""
    return list(map(int, split.split('_')))


def is_valid_pair(model_dir, pair, model_type):
    """Return True when every split directory in `pair` holds a `model_type` file.

    Args:
        model_dir: Directory that contains one sub-directory per split.
        pair: Iterable of split directory names (relative to `model_dir`).
        model_type: Substring that must occur in at least one file name of
            each split directory.

    Returns:
        True if each split directory lists at least one entry whose name
        contains `model_type`, otherwise False.

    Raises:
        OSError: Propagated from `os.listdir` when a split directory cannot be
            listed (e.g. it does not exist).
    """
    # Both members of the pair are loaded later, so both must provide the model;
    # checking only the first one lets half-populated pairs through.
    return all(
        any(model_type in entry for entry in os.listdir(os.path.join(model_dir, split)))
        for split in pair
    )


def torch_to_linen(torch_params, get_flax_keys):
    """Convert PyTorch parameters to Linen nested dictionaries.

    Args:
        torch_params: Mapping from dotted PyTorch parameter names to tensors.
        get_flax_keys: Callable that receives the name split on ``"."`` and
            returns the sequence of nested Flax keys. Entries whose last
            returned key is None are skipped.

    Returns:
        A dict ``{'params': ..., 'batch_stats': ...}``. Entries whose last key
        is ``'mean'`` or ``'var'`` go under ``'batch_stats'``, everything else
        under ``'params'``.

    Raises:
        ValueError: If a tensor cannot be transposed into the expected layout
            (for example a non-4D tensor stored under a conv key).
    """

    def add_to_params(params_dict, nested_keys, param, is_conv=False):
        if len(nested_keys) == 1:
            key, = nested_keys
            try:
                # Conv kernels are reordered with axes (2, 3, 1, 0); every
                # other array has its axes reversed by a plain transpose.
                params_dict[key] = np.transpose(param, (2, 3, 1, 0)) if is_conv else np.transpose(param)
            except ValueError as err:
                # Fail loudly with context instead of stopping in a debugger.
                raise ValueError(
                    f"cannot convert parameter {key!r} with shape {param.shape} (is_conv={is_conv})"
                ) from err
        else:
            assert len(nested_keys) > 1
            first_key = nested_keys[0]
            if first_key not in params_dict:
                params_dict[first_key] = {}
            # A child is treated as a conv kernel when its parent key mentions
            # 'conv' and the leaf is not a bias.
            child_is_conv = 'conv' in first_key and nested_keys[-1] != 'bias'
            add_to_params(params_dict[first_key], nested_keys[1:], param, child_is_conv)

    flax_params = {'params': {}, 'batch_stats': {}}
    for key, tensor in torch_params.items():
        flax_keys = get_flax_keys(key.split('.'))
        if flax_keys[-1] is not None:
            # .cpu() keeps this working for tensors that live on an accelerator.
            array = tensor.detach().cpu().numpy()
            if flax_keys[-1] in ('mean', 'var'):
                add_to_params(flax_params['batch_stats'], flax_keys, array)
            else:
                add_to_params(flax_params['params'], flax_keys, array)

    return flax_params


def fix_keys(old_key):
    """Translate a dotted PyTorch state-dict key into the Flax naming scheme.

    The key goes through plain substring renames (``bn``→``norm``,
    ``layer``→``blockgroups_``, ``running_mean``→``mean``,
    ``running_var``→``var``, ``weight``→``kernel``) and then through regex
    rewrites that shift the block-group index down by one, prefix the block
    index with ``blocks_``, prefix shortcut children with ``layers_`` and turn
    ``norm<d>.kernel`` into ``norm<d>.scale``.

    Args:
        old_key: Dotted PyTorch key, e.g. ``"layer1.0.bn1.weight"``.

    Returns:
        The renamed dotted key, e.g. ``"blockgroups_0.blocks_0.norm1.scale"``.
    """
    new_key = old_key
    substitutions = [
        ("bn", "norm"),
        ("layer", "blockgroups_"),
        ("running_mean", "mean"),
        ("running_var", "var"),
        ("weight", "kernel"),
    ]
    for torch_name, flax_name in substitutions:
        new_key = new_key.replace(torch_name, flax_name)
    # The block-group index is shifted down by one (layer1 -> blockgroups_0).
    new_key = re.sub(r"blockgroups_(\d)", lambda m: f"blockgroups_{int(m.group(1)) - 1}", new_key)
    # Raw strings keep `\g<1>` a regex group reference rather than an invalid
    # string escape.
    new_key = re.sub(r"blockgroups_(\d)\.(\d)\.", r"blockgroups_\g<1>.blocks_\g<2>.", new_key)
    new_key = re.sub(r"shortcut\.", "shortcut.layers_", new_key)
    # The dot is escaped so only a literal "norm<d>.kernel" is rewritten.
    new_key = re.sub(r"norm(\d)\.kernel", r"norm\g<1>.scale", new_key)
    return new_key


def expand_dict(parent_dict, key, value):
    """Insert `value` into `parent_dict`, nesting along the periods of `key`.

    Two renames are applied to the split key first: a key that contains
    ``shortcut``, ``layers_1`` and ``kernel`` has its last component replaced
    by ``scale``, and a ``linear`` component becomes ``dense``.

    Args:
        parent_dict: Dict that is mutated in place.
        key: Dotted path such as ``"a.b.c"``.
        value: Object stored at the innermost level.
    """
    keys = key.split(".")
    if 'shortcut' in keys and 'layers_1' in keys and 'kernel' in keys:
        keys[-1] = 'scale'
    if 'linear' in keys:
        keys[keys.index('linear')] = 'dense'
    curr_dict = parent_dict
    for new_key in keys[:-1]:
        # Identity check: `== None` would invoke the stored object's __eq__.
        if curr_dict.get(new_key) is None:
            curr_dict[new_key] = {}
        curr_dict = curr_dict[new_key]
    curr_dict[keys[-1]] = value


def fix_vals(old_key, old_val):
    """Convert a PyTorch tensor to an array, transposed according to its key.

    Keys containing ``conv`` or ``shortcut.0`` are transposed with axes
    ``(2, 3, 1, 0)``; keys containing ``linear.weight`` are transposed with
    axes ``(1, 0)``; everything else is returned as the plain NumPy array.
    """
    array = old_val.detach().cpu().numpy()
    if "conv" in old_key or 'shortcut.0' in old_key:
        return jnp.transpose(array, (2, 3, 1, 0))
    if 'linear.weight' in old_key:
        return jnp.transpose(array, (1, 0))
    return array


def convert_torch_sd_to_flax_sd(torch_state_dict):
    """Build a nested ``{"batch_stats": ..., "params": ...}`` dict from a PyTorch state dict.

    Every entry is renamed with `fix_keys` and converted with `fix_vals`.
    Entries for which `not_bn_stat` is true go under ``"params"``; of the
    remaining ones, those whose renamed key does not contain ``"num_batches"``
    go under ``"batch_stats"``, and the rest are dropped.
    """
    converted = {
        fix_keys(name): fix_vals(name, tensor)
        for name, tensor in torch_state_dict.items()
    }

    batch_stats_tree = {}
    params_tree = {}
    for name, array in converted.items():
        if not_bn_stat(name):
            expand_dict(params_tree, name, array)
        elif "num_batches" not in name:
            expand_dict(batch_stats_tree, name, array)

    return {"batch_stats": batch_stats_tree, "params": params_tree}

def not_bn_stat(x):
    """Return True unless `x` contains ``num_batches_tracked``, ``mean`` or ``var``."""
    stat_markers = ("num_batches_tracked", 'mean', 'var')
    return not any(marker in x for marker in stat_markers)

class PermutationSpec(NamedTuple):
    """Describe which named permutation acts on which axis of which weight array."""
    # Maps a permutation name to the list of (weight key, axis index) pairs it acts on.
    perm_to_axes: dict
    # Maps a weight key to a tuple with one entry per axis: a permutation name, or None.
    axes_to_perm: dict
    


def permutation_spec_from_axes_to_perm(axes_to_perm: dict) -> PermutationSpec:
    """Build a `PermutationSpec` by inverting an axes-to-permutation mapping.

    Args:
        axes_to_perm: Maps a weight key to a tuple holding, per axis, either a
            permutation name or None.

    Returns:
        A spec whose `perm_to_axes` lists, for every permutation name, the
        (weight key, axis index) pairs it appears at, in iteration order.
    """
    inverse = {}
    for weight_key, per_axis in axes_to_perm.items():
        for axis_index, perm_name in enumerate(per_axis):
            if perm_name is None:
                continue  # this axis is not permuted
            inverse.setdefault(perm_name, []).append((weight_key, axis_index))
    return PermutationSpec(perm_to_axes=inverse, axes_to_perm=axes_to_perm)

def resnet20_permutation_spec() -> PermutationSpec:
    """Return the permutation spec for the ResNet20 parameter layout used here.

    Weight keys are slash-separated paths (e.g. ``blockgroups_0/blocks_1/conv1/kernel``).
    Each block group shares one permutation (``P_bg0``..``P_bg2``) on its
    residual stream, and every block has its own ``P_<block>_inner``
    permutation between its two convolutions.
    """

    def conv(name, p_in, p_out):
        # The first two kernel axes are left unpermuted.
        return {f"{name}/kernel": (None, None, p_in, p_out)}

    def norm(name, p):
        return {f"{name}/{field}": (p, ) for field in ("scale", "bias", "mean", "var")}

    def dense(name, p_in, p_out):
        return {f"{name}/kernel": (p_in, p_out), f"{name}/bias": (p_out, )}

    def easyblock(name, p):
        # Residual block that keeps the number of channels unchanged.
        inner = f"P_{name}_inner"
        entries = {}
        entries.update(conv(f"{name}/conv1", p, inner))
        entries.update(norm(f"{name}/norm1", inner))
        entries.update(conv(f"{name}/conv2", inner, p))
        entries.update(norm(f"{name}/norm2", p))
        return entries

    def shortcutblock(name, p_in, p_out):
        # Residual block that changes the channel count through a conv shortcut.
        inner = f"P_{name}_inner"
        entries = {}
        entries.update(conv(f"{name}/conv1", p_in, inner))
        entries.update(norm(f"{name}/norm1", inner))
        entries.update(conv(f"{name}/conv2", inner, p_out))
        entries.update(norm(f"{name}/norm2", p_out))
        entries.update(conv(f"{name}/shortcut/layers_0", p_in, p_out))
        entries.update(norm(f"{name}/shortcut/layers_1", p_out))
        return entries

    axes_to_perm = {}
    axes_to_perm.update(conv("conv1", None, "P_bg0"))
    axes_to_perm.update(norm("norm1", "P_bg0"))

    # Group 0: three blocks, all on the stem's permutation.
    for block_index in range(3):
        axes_to_perm.update(easyblock(f"blockgroups_0/blocks_{block_index}", "P_bg0"))

    # Groups 1 and 2: one shortcut block, then two plain blocks.
    for group_index, p_in, p_out in ((1, "P_bg0", "P_bg1"), (2, "P_bg1", "P_bg2")):
        axes_to_perm.update(shortcutblock(f"blockgroups_{group_index}/blocks_0", p_in, p_out))
        for block_index in (1, 2):
            axes_to_perm.update(easyblock(f"blockgroups_{group_index}/blocks_{block_index}", p_out))

    axes_to_perm.update(dense("dense", "P_bg2", None))
    return permutation_spec_from_axes_to_perm(axes_to_perm)
    
def get_permuted_param(ps: PermutationSpec, perm, k: str, params, except_axis=None):
    """Return `params[k]` with every applicable permutation applied.

    For each axis of the parameter the spec names a permutation (or None).
    Axes mapped to None and the axis equal to `except_axis` are left as they
    are; all other axes are re-indexed with `perm[name]`.
    """
    permuted = params[k]
    for axis_index, perm_name in enumerate(ps.axes_to_perm[k]):
        # Skip the axis the caller is solving for, and axes without a permutation.
        if axis_index == except_axis or perm_name is None:
            continue
        permuted = jnp.take(permuted, perm[perm_name], axis=axis_index)
    return permuted

def apply_permutation(ps: PermutationSpec, perm, params):
    """Return a new dict holding every entry of `params` permuted by `perm`."""
    permuted = {}
    for weight_key in params:
        permuted[weight_key] = get_permuted_param(ps, perm, weight_key, params)
    return permuted

def weight_matching(rng,
                    ps: PermutationSpec,
                    params_a,
                    params_b,
                    max_iter=100,
                    init_perm=None,
                    silent=False):
    """Find a permutation of `params_b` to make them match `params_a`.

    Each sweep visits the permutations in a random order. For one permutation
    it accumulates a similarity matrix over all weight axes the permutation
    acts on (with all other current permutations applied to `params_b`),
    solves a maximising linear assignment on it and stores the result. The
    loop stops early after a sweep in which no permutation improved its
    objective by more than 1e-12.

    Args:
        rng: PRNG key; combined with the sweep index through `rngmix` to draw
            the visiting order.
        ps: Permutation spec describing which permutation acts on which axis.
        params_a: Flat mapping from weight key to array (the reference model).
        params_b: Flat mapping from weight key to array (the model to permute).
        max_iter: Maximum number of sweeps.
        init_perm: Optional starting permutations keyed by permutation name.
            When given, this dict is updated in place and returned.
        silent: When False, print the objective change for every update.

    Returns:
        Dict mapping each permutation name to an index array.

    Raises:
        KeyError: If `params_a` lacks a weight key listed in `ps`.
    """
    # The size of a permutation is read from the first (weight, axis) it acts on.
    perm_sizes = {p: params_a[axes[0][0]].shape[axes[0][1]] for p, axes in ps.perm_to_axes.items()}

    perm = {p: jnp.arange(n) for p, n in perm_sizes.items()} if init_perm is None else init_perm
    perm_names = list(perm.keys())

    for iteration in tqdm(range(max_iter)):
        progress = False
        for p_ix in random.permutation(rngmix(rng, iteration), len(perm_names)):
            p = perm_names[p_ix]
            n = perm_sizes[p]
            A = jnp.zeros((n, n))
            for wk, axis in ps.perm_to_axes[p]:
                # A missing key surfaces as a KeyError naming `wk` rather than
                # stopping in an interactive debugger.
                w_a = params_a[wk]
                w_b = get_permuted_param(ps, perm, wk, params_b, except_axis=axis)
                # Bring the permuted axis to the front and flatten the rest.
                w_a = jnp.moveaxis(w_a, axis, 0).reshape((n, -1))
                w_b = jnp.moveaxis(w_b, axis, 0).reshape((n, -1))
                A += w_a @ w_b.T

            ri, ci = linear_sum_assignment(A, maximize=True)
            assert (ri == jnp.arange(len(ri))).all()

            # Objective under the previous and the newly solved assignment.
            oldL = jnp.vdot(A, jnp.eye(n)[perm[p]])
            newL = jnp.vdot(A, jnp.eye(n)[ci, :])
            if not silent: print(f"{iteration}/{p}: {newL - oldL}")
            progress = progress or newL > oldL + 1e-12

            perm[p] = jnp.array(ci)

        if not progress:
            break

    return perm

def load_pickle(path):
    """Load and return the object pickled at `path`.

    The file is opened in a context manager so the handle is closed even when
    unpickling fails.

    NOTE(review): `pickle.load` can execute arbitrary code; only use this on
    files from a trusted source.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


In [28]:
# Build the Flax ResNet20 (width multiplier 4) with a 512-way output.
model = ResNet(blocks_per_group=BLOCKS_PER_GROUP["resnet20"],
                   num_classes=512,
                   width_multiplier=4)

# Initialise the model on a dummy (1, 32, 32, 3) input; only key2 is used below.
key1 , key2 = random.split(random.PRNGKey(0))
model_params = model.init(key2, jnp.zeros((1, 32, 32, 3)))
# presumably flattens the nested parameter tree into a flat dict — confirm in utils
params = flatten_params(model_params)
# Weight key -> per-axis permutation names, as defined by resnet20_permutation_spec.
spec = resnet20_permutation_spec().axes_to_perm

In [29]:
# Index of the valid split pair to process.
eval_pair = 4

DEVICE = "cpu"

model_dir = '/srv/share2/gstoica3/checkpoints/cifar50_traincliphead/'
model_name = 'resnet20x4'
# Keep the split pairs with disjoint ids that pass is_valid_pair, then pick one by index.
pair = [pair for pair in find_pairs(os.listdir(model_dir)) if is_valid_pair(model_dir, pair, model_name)][eval_pair]
# One checkpoint path per split of the chosen pair.
model_save_paths = [os.path.join(model_dir, split, f'{model_name}_v0.pth.tar') for split in pair]


# Load the first checkpoint into a PyTorch ResNet20 (w=4) on DEVICE.
model1 = resnet20(w=4).to(DEVICE)
sd = torch.load(model_save_paths[0], map_location=torch.device(DEVICE))
sd = {k: v.cpu() for k, v in sd.items()}
model1.load_state_dict(sd)

# Same for the second checkpoint.
model2 = resnet20(w=4).to(DEVICE)
sd = torch.load(model_save_paths[1], map_location=torch.device(DEVICE))
sd = {k: v.cpu() for k, v in sd.items()}
model2.load_state_dict(sd)

model1_state_dict = dict(model1.state_dict())
model2_state_dict = dict(model2.state_dict())

# Convert both PyTorch state dicts into nested {"params", "batch_stats"} dicts.
model1_flax_sd = convert_torch_sd_to_flax_sd(model1_state_dict)
model2_flax_sd = convert_torch_sd_to_flax_sd(model2_state_dict)

# Apply Git Re-Basin

In [30]:
a_sd = model1_flax_sd # load_pickle(os.path.join(model_save_paths[0]))
b_sd = model2_flax_sd # load_pickle(os.path.join(model_save_paths[1]))

# Merge params and batch stats into one flat dict per model: the permutation
# spec lists mean/var keys next to kernel/scale/bias keys.
a_sd_union = flatten_params(a_sd['params'])
a_sd_union.update(flatten_params(a_sd['batch_stats']))

b_sd_union = flatten_params(b_sd['params'])
b_sd_union.update(flatten_params(b_sd['batch_stats']))

permutation_spec = resnet20_permutation_spec()

# Solve for the permutations of model b that best match model a.
final_permutation = weight_matching(
    random.PRNGKey(0), permutation_spec,
    # flatten_params(model_a), 
    # flatten_params(model_b)
    a_sd_union, 
    b_sd_union
)

# Apply the permutation to b's params and batch stats separately.
model_b_params_clever = unflatten_params(apply_permutation(permutation_spec, final_permutation, flatten_params(b_sd['params'])))
model_b_stats_clever = unflatten_params(apply_permutation(permutation_spec, final_permutation, flatten_params(b_sd['batch_stats'])))
# presumably lerp(.5, x, y) is the midpoint interpolation of the two trees — confirm in utils
# NOTE(review): .unfreeze() assumes unflatten_params returns a FrozenDict — confirm
clever_params_sd = lerp(.5, a_sd['params'], model_b_params_clever.unfreeze())
clever_stats_sd = lerp(.5, a_sd['batch_stats'], model_b_stats_clever.unfreeze())
clever_p_sd = {'params': clever_params_sd, 'batch_stats': clever_stats_sd}
# train_ds, test_ds = load_cifar100()

# model = ResNet(blocks_per_group=BLOCKS_PER_GROUP["resnet20"], num_classes=512, width_multiplier=4)
# stuff = make_stuff(model)
# test_loss, test_acc1, test_acc5 = stuff["dataset_loss_and_accuracies"](clever_p_sd, test_ds, 1000)
# print('Acc: {}'.format(test_acc1))
# save_path = os.path.join(model_dir, pair[0], 'flax_cifar50_2_permuted_to_1.pkl')
# print(save_path)
# pickle.dump(clever_p_sd, open(save_path, 'wb'))

  0%|          | 0/100 [00:00<?, ?it/s]

0/P_blockgroups_1/blocks_1_inner: 7.041866779327393
0/P_blockgroups_1/blocks_2_inner: 7.580317974090576
0/P_blockgroups_1/blocks_0_inner: 5.360034465789795
0/P_blockgroups_2/blocks_0_inner: 12.64707088470459
0/P_blockgroups_0/blocks_0_inner: 4.85614013671875
0/P_blockgroups_2/blocks_2_inner: 15.293581008911133
0/P_bg0: 60.49738311767578


  1%|          | 1/100 [00:00<00:36,  2.73it/s]

0/P_bg1: 14.953239440917969
0/P_blockgroups_0/blocks_1_inner: 5.483088493347168
0/P_blockgroups_2/blocks_1_inner: 15.382516860961914
0/P_blockgroups_0/blocks_2_inner: 2.7440671920776367
0/P_bg2: 4.778938293457031
1/P_blockgroups_2/blocks_0_inner: 5.143930435180664
1/P_blockgroups_0/blocks_0_inner: 5.997053623199463
1/P_blockgroups_2/blocks_2_inner: 3.9344635009765625
1/P_bg0: 3.6104812622070312
1/P_bg1: 0.32691192626953125
1/P_blockgroups_1/blocks_0_inner: 5.436430931091309


  2%|▏         | 2/100 [00:00<00:35,  2.73it/s]

1/P_bg2: 0.6397781372070312
1/P_blockgroups_1/blocks_2_inner: 1.9989070892333984
1/P_blockgroups_2/blocks_1_inner: 4.427860260009766
1/P_blockgroups_0/blocks_2_inner: 1.3585796356201172
1/P_blockgroups_0/blocks_1_inner: 1.1153879165649414
1/P_blockgroups_1/blocks_1_inner: 4.134324073791504
2/P_blockgroups_0/blocks_2_inner: 0.0
2/P_blockgroups_1/blocks_1_inner: 0.0
2/P_blockgroups_1/blocks_2_inner: 0.0
2/P_bg1: 0.79278564453125
2/P_blockgroups_2/blocks_1_inner: 0.0
2/P_bg2: 0.244598388671875
2/P_blockgroups_0/blocks_0_inner: 0.08503055572509766
2/P_bg0: 1.1651229858398438
2/P_blockgroups_1/blocks_0_inner: 0.17174053192138672
2/P_blockgroups_2/blocks_2_inner: 1.54412841796875


  3%|▎         | 3/100 [00:01<00:35,  2.76it/s]

2/P_blockgroups_2/blocks_0_inner: 2.198131561279297
2/P_blockgroups_0/blocks_1_inner: 0.08898067474365234
3/P_bg0: 0.0
3/P_blockgroups_2/blocks_0_inner: 0.0
3/P_blockgroups_2/blocks_2_inner: 0.0
3/P_blockgroups_0/blocks_1_inner: 0.0
3/P_blockgroups_1/blocks_2_inner: 0.3057107925415039
3/P_bg1: 0.000457763671875
3/P_blockgroups_0/blocks_0_inner: 0.006317138671875


  4%|▍         | 4/100 [00:01<00:34,  2.78it/s]

3/P_blockgroups_1/blocks_0_inner: 0.0
3/P_blockgroups_1/blocks_1_inner: 0.24361228942871094
3/P_blockgroups_0/blocks_2_inner: 0.1865367889404297
3/P_bg2: 0.08533477783203125
3/P_blockgroups_2/blocks_1_inner: 0.9787864685058594
4/P_blockgroups_0/blocks_1_inner: 0.0
4/P_blockgroups_1/blocks_2_inner: 0.00106048583984375
4/P_bg2: 0.0570526123046875
4/P_bg0: 0.01183319091796875
4/P_blockgroups_1/blocks_0_inner: 0.01424407958984375
4/P_blockgroups_0/blocks_2_inner: 0.0047893524169921875
4/P_blockgroups_2/blocks_2_inner: 0.9189910888671875
4/P_blockgroups_2/blocks_1_inner: 0.5577049255371094
4/P_bg1: 0.04776763916015625


  5%|▌         | 5/100 [00:01<00:34,  2.72it/s]

4/P_blockgroups_1/blocks_1_inner: 0.07779121398925781
4/P_blockgroups_2/blocks_0_inner: 0.4142265319824219
4/P_blockgroups_0/blocks_0_inner: 0.0
5/P_bg0: 0.019683837890625
5/P_blockgroups_2/blocks_1_inner: 0.0
5/P_blockgroups_1/blocks_1_inner: 0.0
5/P_blockgroups_1/blocks_2_inner: 0.1697835922241211
5/P_blockgroups_0/blocks_1_inner: 0.0
5/P_blockgroups_1/blocks_0_inner: 0.04564094543457031
5/P_bg2: 0.0482330322265625
5/P_blockgroups_2/blocks_0_inner: 0.2035999298095703
5/P_blockgroups_0/blocks_2_inner: 0.0018587112426757812
5/P_blockgroups_2/blocks_2_inner: 0.45946502685546875


  6%|▌         | 6/100 [00:02<00:35,  2.66it/s]

5/P_bg1: 0.1015777587890625
5/P_blockgroups_0/blocks_0_inner: 0.0
6/P_blockgroups_2/blocks_1_inner: 0.5788841247558594
6/P_blockgroups_2/blocks_2_inner: 0.0
6/P_blockgroups_0/blocks_0_inner: 0.0
6/P_blockgroups_0/blocks_2_inner: 0.0
6/P_blockgroups_1/blocks_0_inner: 0.012342453002929688
6/P_blockgroups_0/blocks_1_inner: 0.0


  7%|▋         | 7/100 [00:02<00:34,  2.67it/s]

6/P_bg1: 0.00377655029296875
6/P_blockgroups_1/blocks_1_inner: 0.09410667419433594
6/P_blockgroups_1/blocks_2_inner: 0.07547855377197266
6/P_bg0: 0.0
6/P_bg2: 0.018280029296875
6/P_blockgroups_2/blocks_0_inner: 0.15760231018066406
7/P_blockgroups_1/blocks_1_inner: 0.0
7/P_blockgroups_0/blocks_2_inner: 0.0
7/P_bg2: 0.0
7/P_blockgroups_0/blocks_0_inner: 0.0
7/P_blockgroups_1/blocks_2_inner: 0.0
7/P_blockgroups_2/blocks_1_inner: 0.10795974731445312
7/P_blockgroups_1/blocks_0_inner: 0.028123855590820312
7/P_blockgroups_2/blocks_2_inner: 0.24777984619140625


  8%|▊         | 8/100 [00:02<00:34,  2.70it/s]

7/P_blockgroups_2/blocks_0_inner: 0.0
7/P_bg1: 6.866455078125e-05
7/P_blockgroups_0/blocks_1_inner: 0.0
7/P_bg0: 0.0
8/P_blockgroups_2/blocks_1_inner: 0.0
8/P_bg2: 0.063385009765625
8/P_bg0: 0.0
8/P_blockgroups_0/blocks_1_inner: 0.0
8/P_blockgroups_0/blocks_2_inner: 0.0
8/P_blockgroups_2/blocks_0_inner: 0.1425914764404297
8/P_blockgroups_1/blocks_1_inner: 0.0
8/P_blockgroups_1/blocks_0_inner: 0.030544281005859375
8/P_blockgroups_0/blocks_0_inner: 0.0
8/P_blockgroups_1/blocks_2_inner: 0.020730972290039062


  9%|▉         | 9/100 [00:03<00:32,  2.76it/s]

8/P_bg1: 0.0
8/P_blockgroups_2/blocks_2_inner: 0.25704193115234375
9/P_blockgroups_2/blocks_1_inner: 0.215667724609375
9/P_bg0: 0.0124664306640625
9/P_blockgroups_0/blocks_1_inner: 0.0
9/P_blockgroups_1/blocks_1_inner: 0.0
9/P_blockgroups_0/blocks_0_inner: 0.0
9/P_blockgroups_2/blocks_2_inner: 0.0


 10%|█         | 10/100 [00:03<00:32,  2.77it/s]

9/P_bg2: 0.030914306640625
9/P_blockgroups_1/blocks_0_inner: 0.018289566040039062
9/P_blockgroups_0/blocks_2_inner: 0.019852638244628906
9/P_blockgroups_1/blocks_2_inner: 0.0
9/P_bg1: 0.0
9/P_blockgroups_2/blocks_0_inner: 0.11023139953613281
10/P_blockgroups_1/blocks_1_inner: 0.0
10/P_blockgroups_2/blocks_2_inner: 0.2968482971191406
10/P_blockgroups_1/blocks_2_inner: 0.0
10/P_blockgroups_0/blocks_0_inner: 0.0
10/P_blockgroups_2/blocks_0_inner: 0.0
10/P_blockgroups_2/blocks_1_inner: 0.21407699584960938
10/P_bg1: 0.0
10/P_blockgroups_1/blocks_0_inner: 0.0
10/P_blockgroups_0/blocks_2_inner: 0.0


 11%|█         | 11/100 [00:04<00:31,  2.79it/s]

10/P_bg2: 0.0
10/P_bg0: 0.0
10/P_blockgroups_0/blocks_1_inner: 0.0
11/P_blockgroups_2/blocks_1_inner: 0.0
11/P_bg0: 0.0
11/P_bg1: 0.0


 11%|█         | 11/100 [00:04<00:35,  2.52it/s]

11/P_bg2: 0.0
11/P_blockgroups_1/blocks_1_inner: 0.0
11/P_blockgroups_0/blocks_2_inner: 0.0
11/P_blockgroups_2/blocks_2_inner: 0.0
11/P_blockgroups_0/blocks_1_inner: 0.0
11/P_blockgroups_0/blocks_0_inner: 0.0
11/P_blockgroups_1/blocks_2_inner: 0.0
11/P_blockgroups_1/blocks_0_inner: 0.0
11/P_blockgroups_2/blocks_0_inner: 0.0





# Convert Re-Basin Model from Flax to PyTorch

In [31]:
# convert back to pytorch from jax.
def collapse_dict(jax_params_dict):
    """Flatten a nested Flax dict into a flat dict of PyTorch-style entries.

    The tree is flattened to dotted keys with `recursively_build_dict`, the
    keys are renamed with `fix_keys`, and each value is then converted with
    `fix_vals` using its renamed key.
    """
    flat = {}
    recursively_build_dict([], jax_params_dict, flat)

    renamed = {fix_keys(name): array for name, array in flat.items()}
    return {name: fix_vals(name, array) for name, array in renamed.items()}
def recursively_build_dict(old_keys, old_dict, new_dict):
    """Flatten a nested mapping into `new_dict` using dot-joined keys.

    Args:
        old_keys: List of keys on the path from the root to `old_dict`.
        old_dict: Either a mapping (plain ``dict`` or ``FrozenDict``) to
            descend into, or a leaf value to store.
        new_dict: Flat dict that is filled in place; a leaf reached through
            keys ``a``, ``b``, ``c`` is stored under ``"a.b.c"``.
    """
    # Local import keeps the block self-contained.
    from collections.abc import Mapping

    # Any mapping is descended into, so plain nested dicts work without being
    # wrapped in a FrozenDict first.
    if isinstance(old_dict, Mapping):
        for old_key, old_val in old_dict.items():
            recursively_build_dict(old_keys + [old_key], old_val, new_dict)
    else:
        # Leaf reached: store it under the dotted path.
        new_dict[".".join(old_keys)] = old_dict
def fix_keys(old_key):
    """Translate a dotted Flax key back into the PyTorch state-dict naming scheme.

    Regex rewrites are applied first (``norm<d>.scale``→``norm<d>.kernel``,
    drop the ``layers_`` prefix of shortcut children, ``shortcut.1.scale``→
    ``shortcut.1.weight``, drop the ``blocks_`` prefix, shift the block-group
    index up by one), followed by plain substring renames
    (``dense``→``linear``, ``kernel``→``weight``, ``var``→``running_var``,
    ``mean``→``running_mean``, ``blockgroups_``→``layer``, ``norm``→``bn``).

    Args:
        old_key: Dotted Flax key, e.g. ``"params.blockgroups_0.blocks_0.norm1.scale"``.

    Returns:
        The renamed dotted key, e.g. ``"params.layer1.0.bn1.weight"``.
    """
    new_key = old_key
    # Raw strings keep `\g<1>` a regex group reference rather than an invalid
    # string escape; the dot before "scale" is escaped to match literally.
    new_key = re.sub(r"norm(\d)\.scale", r"norm\g<1>.kernel", new_key)
    new_key = re.sub(r"shortcut\.layers_", "shortcut.", new_key)
    new_key = re.sub(r"shortcut\.1\.scale", "shortcut.1.weight", new_key)
    new_key = re.sub(r"blockgroups_(\d)\.blocks_(\d)\.", r"blockgroups_\g<1>.\g<2>.", new_key)
    new_key = re.sub(r"blockgroups_(\d)", lambda m: f"blockgroups_{int(m.group(1)) + 1}", new_key)
    # (torch name, flax name) pairs, applied flax -> torch from last to first.
    substitutions = [
        ("bn", "norm"),
        ("layer", "blockgroups_"),
        ("running_mean", "mean"),
        ("running_var", "var"),
        ("weight", "kernel"),
        ("linear", "dense"),
    ]
    for torch_name, flax_name in reversed(substitutions):
        new_key = new_key.replace(flax_name, torch_name)
    return new_key
def fix_vals(old_key, old_val):
    """Convert an array to a `torch.Tensor`, transposed according to its key.

    Keys containing ``conv`` or ``shortcut.0`` are transposed with axes
    ``(3, 2, 0, 1)``; keys containing ``linear.weight`` are transposed with
    axes ``(1, 0)``; all other values are converted without transposition.
    """
    if "conv" in old_key or 'shortcut.0' in old_key:
        axes = (3, 2, 0, 1)
    elif 'linear.weight' in old_key:
        axes = (1, 0)
    else:
        axes = None

    array = old_val if axes is None else jnp.transpose(old_val, axes)
    return torch.tensor(np.array(array))

In [32]:
# Flatten the interpolated model into a flat dict of torch tensors.
output_dict = collapse_dict(flax.core.frozen_dict.FrozenDict(clever_p_sd))
print(len(output_dict))

save_dir = '/srv/share2/gstoica3/checkpoints/cifar50_traincliphead/gitrebasins'
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, f'{eval_pair}.pkl')
# The context manager flushes and closes the file once the dump is done.
with open(save_path, "wb") as out_file:
    pickle.dump(output_dict, out_file)

107


In [33]:
# Display the path of the pickled output.
save_path

'/srv/share2/gstoica3/checkpoints/cifar50_traincliphead/gitrebasins/4.pkl'

In [36]:
# Inspect the shape of the converted first conv weight.
output_dict['params.conv1.weight'].shape

torch.Size([64, 3, 3, 3])