In [1]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before each cell executes so edits to project code are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [2]:
import re

import torch
from torchsummary import summary

from pyvrl.models.backbones.r3d import R2Plus1D as CTPR2Plus1D
from torchvision.models.video.resnet import r2plus1d_18 as TorchR2Plus1D

### Load model

In [3]:
# Instantiate torchvision's r2plus1d_18 with its default constructor arguments.
# Its parameter names define the key layout the checkpoint is remapped to below.
bb = TorchR2Plus1D()

In [4]:
# State dict of the freshly built torchvision model; handy for inspecting the
# target key names that the checkpoint keys must be renamed to.
bsd = bb.state_dict()

### Load checkpoint

In [5]:
# NOTE(review): absolute, machine-specific path. Point this at your own copy of
# the checkpoint before running the notebook on another system.
ckpt_path = "/scratch-shared/fmthoker/pbagad/CtP-ssl/output/ctp/r2plus1d_18_kinetics/pretraining_snellius/epoch_6.pth"

In [6]:
# Deserialize the checkpoint with all tensors placed on the CPU, so no GPU is required.
# NOTE(review): torch.load unpickles the file; only open checkpoints from a trusted source.
ckpt = torch.load(ckpt_path, map_location="cpu")

In [7]:
# Pull the weights out of the checkpoint; they are stored under the "state_dict" entry.
csd = ckpt["state_dict"]

### Map checkpoint keys to torchvision names

In [8]:
# Remove the "backbone." wrapper from every checkpoint key so the names start
# at the backbone's own modules (stem, layer1, ...).
csd = dict(
    (name.replace("backbone.", ""), tensor)
    for name, tensor in csd.items()
)

In [19]:
# Renaming rules: module-name prefix in the CtP checkpoint -> prefix used by
# torchvision's r2plus1d_18 for the same module.
#
# * Keys are regular expressions that the consuming loop matches against the
#   start of each checkpoint key.
# * NOTE: the dots are unescaped regex wildcards. The consuming loop also splits
#   these strings on "." to drop the "layerN.M" head, so escaping them here
#   would require changing that loop as well.
# * The numeric target names presumably are nn.Sequential positions in the
#   torchvision model; verify against `bb.state_dict().keys()`.
# * ReLU modules hold no parameters or buffers, so the `relu_s` rules are not
#   expected to match any state-dict key; they only document the layout.
#
# Every key appears exactly once: a repeated key in a dict literal is silently
# overwritten by its last occurrence.
mapping = {
    # Stem: spatial conv/bn followed by temporal conv/bn.
    "stem.conv_s": "stem.0",
    "stem.bn_s": "stem.1",
    "stem.conv_t": "stem.3",
    "stem.bn_t": "stem.4",
    # First (2+1)D convolution of each residual block and its trailing batch norm.
    "layer\\d{1}.\\d{1}.conv1.conv_s": "layer\\d{1}.\\d{1}.conv1.0.0",
    "layer\\d{1}.\\d{1}.conv1.bn_s": "layer\\d{1}.\\d{1}.conv1.0.1",
    "layer\\d{1}.\\d{1}.conv1.relu_s": "layer\\d{1}.\\d{1}.conv1.0.2",
    "layer\\d{1}.\\d{1}.conv1.conv_t": "layer\\d{1}.\\d{1}.conv1.0.3",
    "layer\\d{1}.\\d{1}.bn1": "layer\\d{1}.\\d{1}.conv1.1",
    # Second (2+1)D convolution of each residual block and its trailing batch norm.
    "layer\\d{1}.\\d{1}.conv2.conv_s": "layer\\d{1}.\\d{1}.conv2.0.0",
    "layer\\d{1}.\\d{1}.conv2.bn_s": "layer\\d{1}.\\d{1}.conv2.0.1",
    "layer\\d{1}.\\d{1}.conv2.relu_s": "layer\\d{1}.\\d{1}.conv2.0.2",
    "layer\\d{1}.\\d{1}.conv2.conv_t": "layer\\d{1}.\\d{1}.conv2.0.3",
    "layer\\d{1}.\\d{1}.bn2": "layer\\d{1}.\\d{1}.conv2.1",
    # Shortcut branch. The bare "downsample" rule is a prefix of the two rules
    # after it, so it must stay listed before them for a last-match-wins consumer.
    "layer\\d{1}.\\d{1}.downsample": "layer\\d{1}.\\d{1}.downsample.0",
    "layer\\d{1}.\\d{1}.downsample_bn": "layer\\d{1}.\\d{1}.downsample.1",
    "layer\\d{1}.\\d{1}.downsample.conv": "layer\\d{1}.\\d{1}.downsample.0",
}

In [20]:
# Build the table {checkpoint key -> torchvision key} by applying the renaming
# rules in `mapping` to every key of the checkpoint state dict. Keys that no
# rule matches are left out of the table.

# Compile each rule once instead of once per (key, rule) pair.
compiled_rules = [(x, re.compile(x)) for x in mapping]

csd_keys_to_bsd_keys = dict()

for k in csd.keys():
    # `re.match` anchors only at the start of the string and the rules contain
    # unescaped dots, so a short rule such as "...downsample" also matches
    # "...downsample_bn" and "...downsample.conv". Collect every hit and keep
    # the one that covers the longest prefix of the key, so the result does not
    # depend on the insertion order of `mapping`.
    hits = []
    for x, pattern in compiled_rules:
        found = pattern.match(k)
        if found:
            hits.append((found.end(), x))
    if not hits:
        continue
    x = max(hits, key=lambda hit: hit[0])[1]

    if x.startswith("layer"):
        # The first two dot-separated components are the regex for "layerN.M";
        # only the literal tail after them is renamed.
        ori = ".".join(x.split(".")[2:])
        new = ".".join(mapping[x].split(".")[2:])
    else:
        ori = x
        new = mapping[x]
    replaced = k.replace(ori, new)

    # One row per renamed key: old key, replaced fragment, new fragment, new key.
    disp = "\t\t".join([k, ori, new, replaced])
    print(disp)
    csd_keys_to_bsd_keys[k] = replaced


stem.conv_s.weight		stem.conv_s		stem.0		stem.0.weight
stem.bn_s.weight		stem.bn_s		stem.1		stem.1.weight
stem.bn_s.bias		stem.bn_s		stem.1		stem.1.bias
stem.bn_s.running_mean		stem.bn_s		stem.1		stem.1.running_mean
stem.bn_s.running_var		stem.bn_s		stem.1		stem.1.running_var
stem.bn_s.num_batches_tracked		stem.bn_s		stem.1		stem.1.num_batches_tracked
stem.conv_t.weight		stem.conv_t		stem.3		stem.3.weight
stem.bn_t.weight		stem.bn_t		stem.4		stem.4.weight
stem.bn_t.bias		stem.bn_t		stem.4		stem.4.bias
stem.bn_t.running_mean		stem.bn_t		stem.4		stem.4.running_mean
stem.bn_t.running_var		stem.bn_t		stem.4		stem.4.running_var
stem.bn_t.num_batches_tracked		stem.bn_t		stem.4		stem.4.num_batches_tracked
layer1.0.conv1.conv_s.weight		conv1.conv_s		conv1.0.0		layer1.0.conv1.0.0.weight
layer1.0.conv1.bn_s.weight		conv1.bn_s		conv1.0.1		layer1.0.conv1.0.1.weight
layer1.0.conv1.bn_s.bias		conv1.bn_s		conv1.0.1		layer1.0.conv1.0.1.bias
layer1.0.conv1.bn_s.running_mean		conv1.bn_s		conv1.0.1		laye

In [21]:
# Show the complete old-name -> new-name table for a visual check.
csd_keys_to_bsd_keys

{'stem.conv_s.weight': 'stem.0.weight',
 'stem.bn_s.weight': 'stem.1.weight',
 'stem.bn_s.bias': 'stem.1.bias',
 'stem.bn_s.running_mean': 'stem.1.running_mean',
 'stem.bn_s.running_var': 'stem.1.running_var',
 'stem.bn_s.num_batches_tracked': 'stem.1.num_batches_tracked',
 'stem.conv_t.weight': 'stem.3.weight',
 'stem.bn_t.weight': 'stem.4.weight',
 'stem.bn_t.bias': 'stem.4.bias',
 'stem.bn_t.running_mean': 'stem.4.running_mean',
 'stem.bn_t.running_var': 'stem.4.running_var',
 'stem.bn_t.num_batches_tracked': 'stem.4.num_batches_tracked',
 'layer1.0.conv1.conv_s.weight': 'layer1.0.conv1.0.0.weight',
 'layer1.0.conv1.bn_s.weight': 'layer1.0.conv1.0.1.weight',
 'layer1.0.conv1.bn_s.bias': 'layer1.0.conv1.0.1.bias',
 'layer1.0.conv1.bn_s.running_mean': 'layer1.0.conv1.0.1.running_mean',
 'layer1.0.conv1.bn_s.running_var': 'layer1.0.conv1.0.1.running_var',
 'layer1.0.conv1.bn_s.num_batches_tracked': 'layer1.0.conv1.0.1.num_batches_tracked',
 'layer1.0.conv1.conv_t.weight': 'layer1.0.con

In [22]:
# Re-key the checkpoint tensors: use the translated name where one exists and
# fall back to the original key otherwise.
new_csd = {}
for k, v in csd.items():
    new_csd[csd_keys_to_bsd_keys.get(k, k)] = csd[k]

In [23]:
# strict=False lets the load proceed even when the two key sets differ. The call
# returns the pair (missing_keys, unexpected_keys), so `incompatible` holds the
# keys the model expects but the checkpoint lacks, and `unexpected` holds the
# checkpoint keys the model has no slot for.
incompatible, unexpected = bb.load_state_dict(new_csd, strict=False)

In [24]:
# Model keys that received no weights from the checkpoint.
incompatible

['fc.weight', 'fc.bias']

In [25]:
# Checkpoint keys that were not loaded into the model.
unexpected

['head.temporal_conv.weight',
 'head.temporal_conv.bias',
 'head.fc1.weight',
 'head.fc1.bias',
 'head.pred_head.weight',
 'head.pred_head.bias']