In [1]:
# notebook with basic loading tests for dataset classes

# test new token dataset class with epoch resolution

In [1]:
# dataset_base

import logging
logging.basicConfig(level=logging.WARNING)

from pathlib import Path
import json

from typing import Union, List, Tuple, Dict, Any

from shrp.datasets.dataset_tokens import DatasetTokens
from shrp.datasets.augmentations import WindowCutter

from shrp.git_re_basin.git_re_basin import (
    PermutationSpec,
    resnet18_permutation_spec,
    weight_matching,
    apply_permutation,
    zoo_cnn_permutation_spec,
)


import logging

import json
import torch

from shrp.datasets.dataset_auxiliaries import tokenize_checkpoint, tokens_to_checkpoint

# import logging
# logging.basicConfig(level=logging.DEBUG)

In [2]:
# --- zoo location and tokenization -------------------------------------------
# NOTE(review): absolute, user-specific path; adjust to the local zoo location.
zoo_path = [
    Path(
        "/netscratch2/kschuerholt/code/versai/model_zoos/zoos/CIFAR10/resnet19/kaiming_uniform/tune_zoo_cifar10_resnet18_kaiming_uniform"
    ).absolute()
]
permutation_spec = resnet18_permutation_spec()
# the token size is a count: use integer division (576 / 2 would give the float 288.0)
tokensize = 576 // 2

# alternative configuration (small CNN zoo):
# zoo_path = Path("/netscratch2/dtaskiran/zoos/SVHN/tune_zoo_svhn_uniform/")
# permutation_spec=zoo_cnn_permutation_spec()
# tokensize = 64

# --- dataset options ---------------------------------------------------------
epoch_list = [1, 5]
map_to_canonical = False
# standardize = True
standardize = False
ds_split = [0.7, 0.15, 0.15]
max_samples = 20
weight_threshold = 2500
# value that was effectively used for loading (previously hard-coded in the call below)
num_threads = 12
shuffle_path = True
windowsize = 1024
supersample = "auto"
precision = "32"
ignore_bn = True

# --- properties to load alongside the weights --------------------------------
result_key_list = ["test_acc", "training_iteration", "ggap"]
config_key_list = []
property_keys = {
    "result_keys": result_key_list,
    "config_keys": config_key_list,
}

# every option is passed through the variables above, so the config block
# is the single place to change a setting
dataset = DatasetTokens(
    root=zoo_path,
    epoch_lst=epoch_list,
    permutation_spec=permutation_spec,
    map_to_canonical=map_to_canonical,
    standardize=standardize,
    train_val_test="train",  # determines which dataset split to use
    ds_split=ds_split,
    max_samples=max_samples,
    weight_threshold=weight_threshold,
    precision=precision,
    filter_function=None,  # gets sample path as argument and returns True if model needs to be filtered out
    property_keys=property_keys,
    num_threads=num_threads,
    shuffle_path=shuffle_path,
    verbosity=3,
    getitem="tokens+props",
    ignore_bn=ignore_bn,
    tokensize=tokensize,
)

2023-05-10 21:17:01,904	INFO worker.py:1553 -- Started a local Ray instance.
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:08<00:00,  3.28it/s]
14it [00:00, 70.55it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:48<00:00,  3.44s/it]


In [147]:
import copy
import torch
import ray
if ray.is_initialized():
    ray.shutdown()
from shrp.datasets.progress_bar import ProgressBar

# from shrp.datasets.dataset_tokens import tokenize_checkpoint, tokens_to_checkpoint

def precompute_permutations(ref_checkpoint,permutation_number, perm_spec, tokensize, ignore_bn, num_threads=6):
    """
    Precompute random permutations of a checkpoint as index maps on its tokenized form.

    Procedure:
        1: create a reference tokenized checkpoint whose values are a running global
           index over the flattened token sequence
        2: map that index sequence back to checkpoint layout
        3: apply the random permutations on the index checkpoint (parallel ray tasks)
        4: tokenize the permuted index checkpoints again
        (5: at getitem: apply the resulting index maps on tokenized checkpoints)

    Args:
        ref_checkpoint: reference checkpoint; it is deep-copied before tokenization
        permutation_number: number of random permutations to draw
        perm_spec: permutation spec handed to weight_matching / apply_permutation
        tokensize: token size handed to tokenize_checkpoint
        ignore_bn: handed to tokenize_checkpoint / tokens_to_checkpoint
        num_threads: number of CPUs handed to ray.init
    Returns:
        permutations_global: list with one index tensor (cast to torch.int) per permutation
        permutation_dicts: list of the random permutation dicts that were applied
    """
    logging.info("start precomputing permutations")
    model_curr = ref_checkpoint
    # find permutation of model to itself as reference (provides keys and sizes)
    reference_permutation = weight_matching(
        ps=perm_spec, params_a=model_curr, params_b=model_curr
    )

    logging.info("get random permutation dicts")
    # compute random permutations
    permutation_dicts = []
    for _ in range(permutation_number):
        perm = copy.deepcopy(reference_permutation)
        for key in perm.keys():
            # get permuted indices for current layer
            perm[key] = torch.randperm(perm[key].shape[0]).float()
        # append to list of permutation dicts
        permutation_dicts.append(perm)

    logging.info("get permutation indices")
    # 1: get reference checkpoint
    ref_checkpoint = copy.deepcopy(model_curr)
    ## tokenize reference
    ref_tok_global, positions = tokenize_checkpoint(
        checkpoint=ref_checkpoint, tokensize = tokensize, return_mask=False, ignore_bn=ignore_bn,
        )

    # from here on use the token size actually produced by the tokenizer
    seqlen, tokensize = ref_tok_global.shape[0],ref_tok_global.shape[1]

    # global index on flattened vector, viewed in the original token shape
    ref_tok_global = torch.arange(seqlen*tokensize).view(seqlen,tokensize)
    print(ref_tok_global.shape)

    # 2: map reference positions
    ref_checkpoint_global = tokens_to_checkpoint(
        tokens=ref_tok_global,
        pos=positions,
        reference_checkpoint=ref_checkpoint, ignore_bn=ignore_bn
    )

    # 3 + 4: apply permutations on checkpoints and tokenize them again
    ray.init(num_cpus=num_threads)
    try:
        pb = ProgressBar(total=permutation_number)
        pb_actor = pb.actor
        # schedule one remote task per permutation
        pending = [
            compute_single_perm.remote(
                reference_checkpoint=ref_checkpoint_global,
                permutation_dict=perm_dict,
                perm_spec=perm_spec,
                tokensize=tokensize,
                ignore_bn=ignore_bn,
                pba=pb_actor,
            )
            for perm_dict in permutation_dicts
        ]
        permutations_global = ray.get(pending)
    finally:
        # always release the local ray instance, also when a remote task fails;
        # otherwise the next ray.init call would find a running instance
        ray.shutdown()

    # cast to torch.int
    permutations_global = [perm_g.to(torch.int) for perm_g in permutations_global]

    return permutations_global, permutation_dicts
    
@ray.remote(num_returns=1)
def compute_single_perm(reference_checkpoint, permutation_dict, perm_spec, tokensize, ignore_bn, pba):
    """
    Ray task: apply one permutation to a checkpoint and tokenize the result.

    Args:
        reference_checkpoint: checkpoint to permute; a deep copy is permuted
        permutation_dict: permutation handed to apply_permutation
        perm_spec: permutation spec handed to apply_permutation
        tokensize: token size handed to tokenize_checkpoint
        ignore_bn: handed to tokenize_checkpoint
        pba: progress-bar actor; its update method is called once with 1
    Returns:
        first return value of tokenize_checkpoint for the permuted checkpoint
    """
    # permute a private copy of the reference checkpoint
    permuted_checkpoint = apply_permutation(
        ps=perm_spec,
        perm=permutation_dict,
        params=copy.deepcopy(reference_checkpoint),
    )
    # tokenize the permuted checkpoint, the returned positions are not needed
    permuted_tokens, _ = tokenize_checkpoint(
        checkpoint=permuted_checkpoint,
        tokensize=tokensize,
        return_mask=False,
        ignore_bn=ignore_bn,
    )
    # report one finished permutation
    pba.update.remote(1)
    return permuted_tokens


In [148]:
def tokens_to_checkpoint(tokens, pos, reference_checkpoint, ignore_bn=False):
    """
    Rebuild a checkpoint from a sequence of tokens.

    Args:
        tokens: sequence of tokens; the rows of one layer are selected via pos
        pos: sequence of positions; column 1 is matched against the running layer counter
        reference_checkpoint: checkpoint providing keys and shapes (deep-copied, not modified)
        ignore_bn: if True, weight keys containing "bn" or "downsample.1" are skipped
    Returns
        checkpoint: copy of the reference with weights (and biases where present) read from the tokens
    """
    # operate on a copy so the reference stays untouched
    checkpoint = copy.deepcopy(reference_checkpoint)
    # running layer counter, compared against pos[:, 1]
    idx = 0
    for key in checkpoint.keys():
        # only weight entries are visited; the matching bias is handled together with them
        if "weight" not in key:
            continue
        if ignore_bn and ("bn" in key or "downsample.1" in key):
            continue

        target_shape = checkpoint[key].shape
        bias_key = key.replace("weight", "bias")

        # rows of the token sequence that belong to the current layer
        layer_rows = torch.where(pos[:, 1] == idx)[0]
        layer_tokens = torch.index_select(input=tokens, index=layer_rows, dim=0)

        # number of weight values per output channel
        n_values = torch.prod(torch.tensor(target_shape))
        contentlength = int(n_values / target_shape[0])

        # weights: one row per output channel, cut off whatever follows the content
        try:
            per_channel = layer_tokens.view(target_shape[0], -1)
            checkpoint[key] = per_channel[:, :contentlength].view(target_shape)
        except Exception as e:
            print(f'error matching layer {idx} {key}')
            print(e)

        # bias: the value directly after the weight content of each channel
        try:
            if bias_key in checkpoint:
                per_channel = layer_tokens.view(target_shape[0], -1)
                checkpoint[bias_key] = per_channel[:, contentlength]
        except Exception as e:
            print(f'bias error matching layer {idx} {key}')
            print(e)

        # move on to the next layer
        idx += 1

    return checkpoint


In [149]:
def permute_model_vector(wdx,idx_start,window,perm):
    """
    Apply a precomputed index permutation to a window of tokens.

    Args:
        wdx: token tensor; it is flattened and indexed with the permutation entries
        idx_start: first token row of the window inside perm
        window: number of token rows in the window
        perm: index tensor with one row per token; perm.shape[1] is the token size
    Returns:
        tensor of shape (window, perm.shape[1]) holding the values of the flattened
        wdx gathered at the indices perm[idx_start:idx_start + window]
    """
    token_width = perm.shape[1]

    # rows of the permutation that belong to the window, as one flat index vector
    flat_index = perm[idx_start:idx_start + window].view(-1)

    # gather the values from the flattened tokens
    gathered = torch.index_select(input=wdx.view(-1), index=flat_index, dim=0)

    # back to token layout
    return gathered.view(window, token_width)

# Check equivalence

In [150]:
import ray
if ray.is_initialized():
    ray.shutdown()
# permutation_spec = resnet18_permutation_spec(batchnorm=True)
perms, permdicts = precompute_permutations(
    ref_checkpoint=dataset.reference_checkpoint,
    permutation_number=3, 
    tokensize=tokensize,
    ignore_bn=False,
    perm_spec=permutation_spec)

torch.Size([43924, 288])


2023-05-10 15:54:08,930	INFO worker.py:1553 -- Started a local Ray instance.


In [151]:
# check that permuted checkpoint == token_to_checkpoint(perm(tokenize(check)))

check = copy.deepcopy(dataset.reference_checkpoint)

check_p = apply_permutation(
        ps=permutation_spec, perm=permdicts[0], params=copy.deepcopy(check)
    )

In [152]:
# build a reference checkpoint whose entries are replaced by random noise
# (same keys and shapes as `check`); with ignore_bn set, the weights of
# "bn" / "downsample.1" layers keep their copied values
check_n = copy.deepcopy(check)
for key in check_n.keys():
    print(key)
    is_skipped_bn_weight = (
        "weight" in key
        and ignore_bn
        and ("bn" in key or "downsample.1" in key)
    )
    if is_skipped_bn_weight:
        continue
    # every other entry is overwritten with standard-normal noise of the same shape
    check_n[key] = torch.randn(check_n[key].shape)

conv1.weight
bn1.weight
bn1.bias
bn1.running_mean
bn1.running_var
bn1.num_batches_tracked
layer1.0.conv1.weight
layer1.0.bn1.weight
layer1.0.bn1.bias
layer1.0.bn1.running_mean
layer1.0.bn1.running_var
layer1.0.bn1.num_batches_tracked
layer1.0.conv2.weight
layer1.0.bn2.weight
layer1.0.bn2.bias
layer1.0.bn2.running_mean
layer1.0.bn2.running_var
layer1.0.bn2.num_batches_tracked
layer1.1.conv1.weight
layer1.1.bn1.weight
layer1.1.bn1.bias
layer1.1.bn1.running_mean
layer1.1.bn1.running_var
layer1.1.bn1.num_batches_tracked
layer1.1.conv2.weight
layer1.1.bn2.weight
layer1.1.bn2.bias
layer1.1.bn2.running_mean
layer1.1.bn2.running_var
layer1.1.bn2.num_batches_tracked
layer2.0.conv1.weight
layer2.0.bn1.weight
layer2.0.bn1.bias
layer2.0.bn1.running_mean
layer2.0.bn1.running_var
layer2.0.bn1.num_batches_tracked
layer2.0.conv2.weight
layer2.0.bn2.weight
layer2.0.bn2.bias
layer2.0.bn2.running_mean
layer2.0.bn2.running_var
layer2.0.bn2.num_batches_tracked
layer2.0.downsample.0.weight
layer2.0.downsamp

In [153]:
wdx, pos = tokenize_checkpoint(
            checkpoint=copy.deepcopy(check), tokensize = tokensize, return_mask=False, ignore_bn=False,
            )
perm1 = perms[0]
print(wdx.shape)
print(perm1[0].shape)
print(perm1[1].shape)

torch.Size([43924, 288])
torch.Size([288])
torch.Size([288])


In [156]:
wdx2 = permute_model_vector(wdx=wdx,idx_start=0,window=wdx.shape[0],perm=perm1)
# wdx2 = torch.stack(wdx2, dim=0)
print(wdx2.shape)

torch.Size([43924, 288])


In [157]:
check2 = tokens_to_checkpoint(tokens=wdx2, pos=pos, reference_checkpoint=check_n, ignore_bn=False)

In [158]:
check2.keys() == check.keys()

True

In [159]:
torch.allclose(wdx,wdx2)

False

In [162]:
# compare the checkpoint rebuilt from permuted tokens (check2) against the
# directly permuted checkpoint (check_p), key by key
allgood = True
for key in check2.keys():
    skipped_bn = ignore_bn and ("bn" in key or "downsample.1" in key)
    if skipped_bn or torch.allclose(check2[key], check_p[key]):
        continue
    # entries differ for this key
    print(f'missmatch at {key}')
    allgood = False
if allgood:
    print('all keys match')

all keys match


In [164]:
torch.randperm(n=10, dtype=torch.int32, device='cpu')[:3]

tensor([8, 3, 0], dtype=torch.int32)

In [None]:
# check augmentation

In [3]:
from shrp.datasets.augmentations import PermutationAugmentation

permaug = PermutationAugmentation(
    ref_checkpoint=dataset.reference_checkpoint,
    permutation_number=100,
    perm_spec=permutation_spec,
    tokensize=tokensize,
    ignore_bn=True,
    windowsize=128,
    permutations_per_sample = 1
    )

torch.Size([39124, 288])


2023-05-10 21:18:39,704	INFO worker.py:1553 -- Started a local Ray instance.


In [4]:
dataset.transforms = permaug

In [6]:
ddx, mask, pos, props = dataset.__getitem__(0)
print(ddx.shape)
print(mask.shape)
print(pos.shape)
print(props)

torch.Size([1, 128, 288])
torch.Size([128, 288])
torch.Size([128, 3])
tensor([ 0.4363,  1.0000, -0.1091])


In [7]:
dataset.transforms.perms_per_sample = 3

In [8]:
ddx, mask, pos, props = dataset.__getitem__(0)
print(ddx.shape)
print(mask.shape)
print(pos.shape)
print(props)

torch.Size([3, 128, 288])
torch.Size([128, 288])
torch.Size([128, 3])
tensor([ 0.4363,  1.0000, -0.1091])


In [13]:
%%timeit
idx = torch.randint(len(dataset),size=(1,))

4.9 µs ± 30.6 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [16]:
%%timeit
idx = torch.randint(len(dataset),size=(1,))
ddx, mask, pos, props = dataset.__getitem__(idx)

40.6 ms ± 6.9 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
# build multi-view augmentation to split 

In [19]:
perm_ids = torch.randperm(
            n=ddx.shape[0], dtype=torch.int8, device=ddx.device
        )[:1]
perm_ids[0]

tensor(2, dtype=torch.int8)

In [1]:
from shrp.datasets.augmentations import PermutationSelector


In [2]:
import torch

wdx = torch.randn(128,64)
mdx = torch.randn(128,64)
pdx = torch.randn(128,3)
props = torch.randn(1,4)

print(wdx.shape)
ps = PermutationSelector(mode="canonical")
wdx, mdx, pdx, props = ps(wdx, mdx, pdx, props)
print(wdx.shape)

torch.Size([128, 64])
torch.Size([64])


In [3]:
wdx = torch.randn(128,64)
mdx = torch.randn(128,64)
pdx = torch.randn(128,3)
props = torch.randn(1,4)

print(wdx.shape)
ps = PermutationSelector(mode="identity")
wdx, mdx, pdx, props = ps(wdx, mdx, pdx, props)
print(wdx.shape)

torch.Size([128, 64])
torch.Size([128, 64])


In [6]:
# NOTE(review): `ps2` is not defined in any cell visible in this notebook, so this
# cell only runs with leftover kernel state and fails after Restart & Run All.
# Presumably a selector that accepts the whole batch tuple as a single argument;
# TODO: define it explicitly in a cell above.
batch = (wdx, mdx, pdx, props)
wdx, mdx, pdx, props = ps2(batch)

In [7]:
wdx

tensor([-0.0775, -0.4073, -0.2606,  0.1260,  0.1148, -0.8796, -0.5024,  0.6747,
        -0.0297, -1.8589,  1.6551,  0.1600, -2.0293,  0.3282, -0.5994, -0.6216,
         1.4658, -0.9540, -0.6741, -1.2107,  0.2330, -0.3016,  0.0740,  1.1781,
         0.3280, -1.1359,  0.8650, -0.1146,  0.8067,  0.0961,  0.2878, -0.9106,
        -0.2778, -1.3343, -0.0446,  0.4382,  0.3909, -0.3165, -0.8315, -0.7822,
         2.0207,  1.7968,  0.7648,  0.3996, -0.8529, -0.1022,  1.2889, -0.4179,
         0.6985, -0.4624,  1.0634, -0.4380, -0.1787, -0.5775, -1.0185, -1.3324,
         0.7661, -1.6492, -1.7587, -0.4814, -1.7970, -1.2670, -1.0457, -0.6354])

In [6]:
12976128
import math
math.log2(12976128)

23.62935662007961

In [7]:
2**24

16777216

In [8]:
4*1 << 21

8388608

In [9]:
math.log2(8388608)

23.0

In [None]:
12976128