## Pretrained MedicalNet to SynapseCLR adaptation

Generate a SynapseCLR checkpoint whose encoder is initialized from the MedicalNet pre-trained 3D-ResNet18 weights.

In [1]:
import numpy as np
import torch

import synapse_simclr
import torchinfo

First, we need to ascertain that our ResNet-3D implementation (with the right arguments) produces exactly the same output as the MedicalNet ResNet-3D implementation.

In [2]:
# Load the MedicalNet pre-trained ResNet-18 checkpoint.
# `map_location='cpu'` remaps every stored tensor to the CPU, so the load also
# works on machines without a GPU if the checkpoint holds CUDA tensors.
# NOTE: `torch.load` unpickles the file -- only load checkpoints from a trusted source.
state_dict_medical = torch.load(
    '../../ext/MedicalNet/pretrain/resnet_18_23dataset.pth',
    map_location='cpu')

# fix the state dict key names: drop a leading "module." prefix (presumably left
# by a DataParallel-style wrapper -- TODO confirm) so the keys match a bare model
module_prefix = 'module.'
fixed_state_dict_medical = {
    (key[len(module_prefix):] if key.startswith(module_prefix) else key): value
    for key, value in state_dict_medical['state_dict'].items()
}

In [3]:
# list every entry name of the prefix-stripped MedicalNet state dict, one per line
for entry_name in fixed_state_dict_medical:
    print(entry_name)

conv1.weight
bn1.weight
bn1.bias
bn1.running_mean
bn1.running_var
bn1.num_batches_tracked
layer1.0.conv1.weight
layer1.0.bn1.weight
layer1.0.bn1.bias
layer1.0.bn1.running_mean
layer1.0.bn1.running_var
layer1.0.bn1.num_batches_tracked
layer1.0.conv2.weight
layer1.0.bn2.weight
layer1.0.bn2.bias
layer1.0.bn2.running_mean
layer1.0.bn2.running_var
layer1.0.bn2.num_batches_tracked
layer1.1.conv1.weight
layer1.1.bn1.weight
layer1.1.bn1.bias
layer1.1.bn1.running_mean
layer1.1.bn1.running_var
layer1.1.bn1.num_batches_tracked
layer1.1.conv2.weight
layer1.1.bn2.weight
layer1.1.bn2.bias
layer1.1.bn2.running_mean
layer1.1.bn2.running_var
layer1.1.bn2.num_batches_tracked
layer2.0.conv1.weight
layer2.0.bn1.weight
layer2.0.bn1.bias
layer2.0.bn1.running_mean
layer2.0.bn1.running_var
layer2.0.bn1.num_batches_tracked
layer2.0.conv2.weight
layer2.0.bn2.weight
layer2.0.bn2.bias
layer2.0.bn2.running_mean
layer2.0.bn2.running_var
layer2.0.bn2.num_batches_tracked
layer2.1.conv1.weight
layer2.1.bn1.weight
laye

In [4]:
from synapse_simclr.modules import Identity

# reference model: the MedicalNet-style 3D-ResNet18 shipped with synapse_simclr
resnet18_medicalnet = synapse_simclr.modules.resnet18_medicalnet(shortcut_type='A')

# drop the fully connected layer by swapping it for an identity module; this is
# done before loading, presumably because the checkpoint carries no `fc.*`
# entries (verify against the full key listing)
resnet18_medicalnet.fc = Identity()

# copy the pre-trained MedicalNet weights into the model
resnet18_medicalnet.load_state_dict(fixed_state_dict_medical)

# switch to evaluation mode (the trailing `;` hides the module repr in the notebook)
resnet18_medicalnet.eval();

  m.weight = nn.init.kaiming_normal(m.weight, mode='fan_out')


In [5]:
# generate our own ResNet
resnet18_syn = synapse_simclr.modules.generate_resnet_3d(
    model_depth=18,
    n_input_channels=1,
    no_max_pool=False,
    conv1_kernel_size=7,
    conv1_stride=2,
    conv1_padding=3,
    shortcut_type='A',
    block_dilations=[1, 1, 2, 4],
    block_strides=[1, 2, 1, 1])

# drop the fully connected layer
resnet18_syn.fc = Identity()
resnet18_syn.load_state_dict(fixed_state_dict_medical)
resnet18_syn.eval();

In [6]:
# Show the layer-by-layer parameter counts of the MedicalNet reference model,
# for side-by-side comparison with our own ResNet-3D.
torchinfo.summary(resnet18_medicalnet)

Layer (type:depth-idx)                   Param #
ResNet                                   --
├─Conv3d: 1-1                            21,952
├─BatchNorm3d: 1-2                       128
├─ReLU: 1-3                              --
├─MaxPool3d: 1-4                         --
├─Sequential: 1-5                        --
│    └─BasicBlock: 2-1                   --
│    │    └─Conv3d: 3-1                  110,592
│    │    └─BatchNorm3d: 3-2             128
│    │    └─ReLU: 3-3                    --
│    │    └─Conv3d: 3-4                  110,592
│    │    └─BatchNorm3d: 3-5             128
│    └─BasicBlock: 2-2                   --
│    │    └─Conv3d: 3-6                  110,592
│    │    └─BatchNorm3d: 3-7             128
│    │    └─ReLU: 3-8                    --
│    │    └─Conv3d: 3-9                  110,592
│    │    └─BatchNorm3d: 3-10            128
├─Sequential: 1-6                        --
│    └─BasicBlock: 2-3                   --
│    │    └─Conv3d: 3-11                 2

In [7]:
# Show the layer-by-layer parameter counts of our own ResNet-3D; they should
# match the MedicalNet reference model's summary.
torchinfo.summary(resnet18_syn)

Layer (type:depth-idx)                   Param #
ResNet                                   --
├─Conv3d: 1-1                            21,952
├─BatchNorm3d: 1-2                       128
├─ReLU: 1-3                              --
├─MaxPool3d: 1-4                         --
├─Sequential: 1-5                        --
│    └─BasicBlock: 2-1                   --
│    │    └─Conv3d: 3-1                  110,592
│    │    └─BatchNorm3d: 3-2             128
│    │    └─ReLU: 3-3                    --
│    │    └─Conv3d: 3-4                  110,592
│    │    └─BatchNorm3d: 3-5             128
│    └─BasicBlock: 2-2                   --
│    │    └─Conv3d: 3-6                  110,592
│    │    └─BatchNorm3d: 3-7             128
│    │    └─ReLU: 3-8                    --
│    │    └─Conv3d: 3-9                  110,592
│    │    └─BatchNorm3d: 3-10            128
├─Sequential: 1-6                        --
│    └─BasicBlock: 2-3                   --
│    │    └─Conv3d: 3-11                 2

In [8]:
# Random test input of shape (1, 1, 64, 64, 64) that is fed to both models.
# A locally seeded generator keeps the input identical across kernel restarts
# without modifying the global RNG state.
x = torch.rand(
    (1, 1, 64, 64, 64),
    generator=torch.Generator().manual_seed(0))

In [9]:
# Run the same input through both models. This is inference only, so gradient
# tracking is disabled to avoid building autograd graphs for the comparison.
with torch.no_grad():
    resnet18_syn_out = resnet18_syn(x)
    resnet18_medicalnet_out = resnet18_medicalnet(x)

In [10]:
# Element-wise difference between the outputs of the two implementations.
resnet18_out_diff = resnet18_syn_out - resnet18_medicalnet_out

# Fail loudly on any mismatch instead of relying on a visual scan of the
# (possibly truncated) printout below.
assert torch.equal(resnet18_syn_out, resnet18_medicalnet_out), (
    "ResNet-3D outputs differ from MedicalNet; max abs difference = "
    f"{resnet18_out_diff.abs().max().item():.3e}")

# still display the difference as the cell output
resnet18_out_diff

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.

If the difference is all zeros, the two implementations are equivalent and we're good!

## Adaptation

In [11]:
# load a "prototype" Synapse-SimCLR checkpoint
state_dict_syn = torch.load(
    '../../output/checkpoint__synapseclr__so3__second_stage/model_checkpoint_99.pt')

for key in state_dict_syn.keys():
    print(key)

module.encoder.conv1.weight
module.encoder.bn1.weight
module.encoder.bn1.bias
module.encoder.bn1.running_mean
module.encoder.bn1.running_var
module.encoder.bn1.num_batches_tracked
module.encoder.layer1.0.conv1.weight
module.encoder.layer1.0.bn1.weight
module.encoder.layer1.0.bn1.bias
module.encoder.layer1.0.bn1.running_mean
module.encoder.layer1.0.bn1.running_var
module.encoder.layer1.0.bn1.num_batches_tracked
module.encoder.layer1.0.conv2.weight
module.encoder.layer1.0.bn2.weight
module.encoder.layer1.0.bn2.bias
module.encoder.layer1.0.bn2.running_mean
module.encoder.layer1.0.bn2.running_var
module.encoder.layer1.0.bn2.num_batches_tracked
module.encoder.layer1.1.conv1.weight
module.encoder.layer1.1.bn1.weight
module.encoder.layer1.1.bn1.bias
module.encoder.layer1.1.bn1.running_mean
module.encoder.layer1.1.bn1.running_var
module.encoder.layer1.1.bn1.num_batches_tracked
module.encoder.layer1.1.conv2.weight
module.encoder.layer1.1.bn2.weight
module.encoder.layer1.1.bn2.bias
module.encode

In [12]:
# Load MedicalNet pretrained weights
state_dict_medical = torch.load(
    '../../ext/MedicalNet/pretrain/resnet_18_23dataset.pth')

# fix the state dict key names
fixed_state_dict_medical = dict()
for key, value in state_dict_medical['state_dict'].items():
    fixed_state_dict_medical[key[7:] if key.find("module.") == 0 else key] = value
    
for key in fixed_state_dict_medical.keys():
    print(key)

conv1.weight
bn1.weight
bn1.bias
bn1.running_mean
bn1.running_var
bn1.num_batches_tracked
layer1.0.conv1.weight
layer1.0.bn1.weight
layer1.0.bn1.bias
layer1.0.bn1.running_mean
layer1.0.bn1.running_var
layer1.0.bn1.num_batches_tracked
layer1.0.conv2.weight
layer1.0.bn2.weight
layer1.0.bn2.bias
layer1.0.bn2.running_mean
layer1.0.bn2.running_var
layer1.0.bn2.num_batches_tracked
layer1.1.conv1.weight
layer1.1.bn1.weight
layer1.1.bn1.bias
layer1.1.bn1.running_mean
layer1.1.bn1.running_var
layer1.1.bn1.num_batches_tracked
layer1.1.conv2.weight
layer1.1.bn2.weight
layer1.1.bn2.bias
layer1.1.bn2.running_mean
layer1.1.bn2.running_var
layer1.1.bn2.num_batches_tracked
layer2.0.conv1.weight
layer2.0.bn1.weight
layer2.0.bn1.bias
layer2.0.bn1.running_mean
layer2.0.bn1.running_var
layer2.0.bn1.num_batches_tracked
layer2.0.conv2.weight
layer2.0.bn2.weight
layer2.0.bn2.bias
layer2.0.bn2.running_mean
layer2.0.bn2.running_var
layer2.0.bn2.num_batches_tracked
layer2.1.conv1.weight
layer2.1.bn1.weight
laye

In [13]:
import os
from collections import OrderedDict

# directory that receives the adapted checkpoint
output_checkpoint_path = '../../output/checkpoint__medicalnet'
# when True, projector entries are taken over from the prototype checkpoint
copy_projector_params = False

# key prefixes used by the Synapse-SimCLR checkpoint (see the key listing above)
encoder_prefix = 'module.encoder.'
projector_prefix = 'module.projector.'

adapated_state_dict_medical = OrderedDict()

# copy encoder parameters: take exactly the entries our ResNet-3D (with the
# fully connected layer replaced) exposes, and re-key them with the encoder prefix
encoder_keys = list(resnet18_syn.state_dict().keys())
for key in encoder_keys:
    adapated_state_dict_medical[encoder_prefix + key] = fixed_state_dict_medical[key]

# copy projector parameters
if copy_projector_params:
    for key, value in state_dict_syn.items():
        if key.startswith(projector_prefix):
            adapated_state_dict_medical[key] = value

# `torch.save` does not create missing directories, so make sure the target exists
os.makedirs(output_checkpoint_path, exist_ok=True)
torch.save(
    adapated_state_dict_medical,
    os.path.join(output_checkpoint_path, 'model_checkpoint_0.pt'))