In [1]:
import sys
import torch
sys.path.append("/hdd02/zhangyiyang/temporal-shift-module")
from tsm.models import TSN, resnet50, MobileNetV2

In [4]:
# Offline TSN model: ResNet-50 backbone with temporal shift (is_shift=True,
# shift_place='blockres', uni_direction=False), 8 segments, 27 classes, avg consensus.
# pretrain=None keeps construction free of any pretrained-weight loading step.
resnet_offline_model = TSN(
    num_classes=27,
    num_segments=8,
    modality='RGB',
    base_model='resnet50',
    new_length=None,
    consensus_type='avg',
    before_softmax=False,
    dropout=0.8,
    img_feature_dim=256,
    crop_num=1,
    partial_bn=False,
    print_spec=True,
    pretrain=None,
    is_shift=True,
    shift_div=8,
    shift_place='blockres',
    fc_lr5=False,
    temporal_pool=False,
    non_local=False,
    uni_direction=False,
)
# This notebook only inspects state-dict key names, so fall back to CPU instead of
# failing on `.cuda()` when no GPU is present. DataParallel registers the wrapped
# model as `module` either way, so every key still gets the `module.` prefix.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
resnet_offline_model = torch.nn.DataParallel(resnet_offline_model).to(device)


                    Initializing TSN with base model: resnet50.
                    TSN Configurations:
                        input_modality:     RGB
                        num_segments:       8
                        new_length:         1
                        consensus_module:   avg
                        dropout_ratio:      0.8
                        img_feature_dim:    256
            


In [5]:
# Online ResNet-50 from tsm.models, built with its default constructor arguments.
# Its state-dict keys are unprefixed (e.g. `conv1.weight`, `layer1.0.conv1.weight`, as in
# the listing below), so they must be renamed before comparing with the offline TSN model.
# NOTE(review): default classifier size is not visible here; the offline model uses
# num_classes=27 — TODO confirm whether parameter shapes (not just names) need to match.
resnet_online_model = resnet50()

In [6]:
# Parameter/buffer names of the offline model, in state-dict order. DataParallel adds a
# `module.` prefix, and backbone entries additionally live under `base_model.`.
offline_model_keys = resnet_offline_model.state_dict().keys()
[*offline_model_keys]  # display every key as a list

['module.base_model.conv1.weight',
 'module.base_model.bn1.weight',
 'module.base_model.bn1.bias',
 'module.base_model.bn1.running_mean',
 'module.base_model.bn1.running_var',
 'module.base_model.bn1.num_batches_tracked',
 'module.base_model.layer1.0.conv1.net.weight',
 'module.base_model.layer1.0.bn1.weight',
 'module.base_model.layer1.0.bn1.bias',
 'module.base_model.layer1.0.bn1.running_mean',
 'module.base_model.layer1.0.bn1.running_var',
 'module.base_model.layer1.0.bn1.num_batches_tracked',
 'module.base_model.layer1.0.conv2.weight',
 'module.base_model.layer1.0.bn2.weight',
 'module.base_model.layer1.0.bn2.bias',
 'module.base_model.layer1.0.bn2.running_mean',
 'module.base_model.layer1.0.bn2.running_var',
 'module.base_model.layer1.0.bn2.num_batches_tracked',
 'module.base_model.layer1.0.conv3.weight',
 'module.base_model.layer1.0.bn3.weight',
 'module.base_model.layer1.0.bn3.bias',
 'module.base_model.layer1.0.bn3.running_mean',
 'module.base_model.layer1.0.bn3.running_var',
 

In [7]:
# Parameter/buffer names of the online ResNet-50, in state-dict order; these carry no
# wrapper prefixes (e.g. `conv1.weight`, `layer1.0.conv1.weight`).
online_model_keys = resnet_online_model.state_dict().keys()
[*online_model_keys]  # display every key as a list

['conv1.weight',
 'bn1.weight',
 'bn1.bias',
 'bn1.running_mean',
 'bn1.running_var',
 'bn1.num_batches_tracked',
 'layer1.0.conv1.weight',
 'layer1.0.bn1.weight',
 'layer1.0.bn1.bias',
 'layer1.0.bn1.running_mean',
 'layer1.0.bn1.running_var',
 'layer1.0.bn1.num_batches_tracked',
 'layer1.0.conv2.weight',
 'layer1.0.bn2.weight',
 'layer1.0.bn2.bias',
 'layer1.0.bn2.running_mean',
 'layer1.0.bn2.running_var',
 'layer1.0.bn2.num_batches_tracked',
 'layer1.0.conv3.weight',
 'layer1.0.bn3.weight',
 'layer1.0.bn3.bias',
 'layer1.0.bn3.running_mean',
 'layer1.0.bn3.running_var',
 'layer1.0.bn3.num_batches_tracked',
 'layer1.0.downsample.0.weight',
 'layer1.0.downsample.1.weight',
 'layer1.0.downsample.1.bias',
 'layer1.0.downsample.1.running_mean',
 'layer1.0.downsample.1.running_var',
 'layer1.0.downsample.1.num_batches_tracked',
 'layer1.1.conv1.weight',
 'layer1.1.bn1.weight',
 'layer1.1.bn1.bias',
 'layer1.1.bn1.running_mean',
 'layer1.1.bn1.running_var',
 'layer1.1.bn1.num_batches_trac

In [21]:
# Materialize both key views as plain lists (raw, un-renamed names).
# NOTE: both names are reassigned by the key-comparison cell further down, where
# online_keys holds the keys translated into offline naming instead.
online_keys = list(online_model_keys)
offline_keys = list(offline_model_keys)


In [24]:
def key_offline_to_online(key: str) -> str:
    """Map an offline TSN state-dict key to the corresponding online ResNet-50 key.

    Renaming rules, applied to dot-separated path segments:

    * a leading ``module`` segment (DataParallel wrapper) is dropped;
    * a following ``base_model`` segment (TSN's backbone attribute) is dropped;
    * every ``net`` segment is dropped (e.g. ``layer1.0.conv1.net.weight`` ->
      ``layer1.0.conv1.weight``);
    * a ``new_`` prefix on a segment is removed (``new_fc`` -> ``fc``).

    Matching whole segments (instead of ``str.replace`` on the whole key) keeps
    names that merely contain these substrings, such as ``resnet`` or
    ``attention_module``, intact.

    Args:
        key: State-dict key of the offline model.

    Returns:
        The key with the offline-only wrapper segments removed.
    """
    parts = key.split(".")
    if parts and parts[0] == "module":
        parts = parts[1:]
    if parts and parts[0] == "base_model":
        parts = parts[1:]
    parts = [
        part[len("new_"):] if part.startswith("new_") else part
        for part in parts
        if part != "net"
    ]
    return ".".join(parts)

In [34]:
def key_online_to_offline(key: str) -> str:
    """Map an online ResNet-50 state-dict key to the corresponding offline TSN key.

    Renaming rules, applied to dot-separated path segments:

    * a key whose first segment is ``fc`` becomes ``module.new_fc.*``; the
      classifier sits outside ``base_model`` and is named ``new_fc`` on the
      offline side;
    * every other key is prefixed with ``module.base_model.``;
    * inside residual stages (any segment starting with ``layer``), each ``conv1``
      segment gains a following ``net`` segment, matching offline keys such as
      ``layer1.0.conv1.net.weight``. The stem ``conv1.weight`` is left unchanged.

    Segment matching avoids rewriting names that only contain ``fc`` or ``conv1``
    as a substring (e.g. ``fc1``, ``conv10``).

    Args:
        key: State-dict key of the online model.

    Returns:
        The key as it is expected to appear in the DataParallel-wrapped TSN model.
    """
    parts = key.split(".")
    if parts[0] == "fc":
        return ".".join(["module", "new_fc"] + parts[1:])

    in_residual_stage = any(part.startswith("layer") for part in parts)
    mapped = ["module", "base_model"]
    for part in parts:
        mapped.append(part)
        # The offline model keeps the original block conv1 under a `.net` attribute
        # (presumably the temporal-shift wrapper for shift_place='blockres').
        if in_residual_stage and part == "conv1":
            mapped.append("net")
    return ".".join(mapped)

In [35]:
# Translate every online key into offline naming, then report keys present on only one
# side: offline-only keys first, a blank line, then (translated) online-only keys.
# Empty output means no key name is unmatched in either direction.
online_keys = [key_online_to_offline(k) for k in online_model_keys]
offline_keys = list(offline_model_keys)

# Sets give O(1) membership tests instead of a list scan per key; iteration still runs
# over the lists, so the printed order follows state-dict order as before.
online_key_set = set(online_keys)
offline_key_set = set(offline_keys)

for k in offline_keys:
    if k not in online_key_set:
        print(k)
print()
for k in online_keys:
    if k not in offline_key_set:
        print(k)


