# Convert the ModelsGenesis weights from Keras to PyTorch

In [1]:
%load_ext autoreload
%autoreload 2

import os

import keras
import numpy as np
import torch

# Need to clone https://github.com/MrGiovanni/ModelsGenesis in advance
from ModelsGenesis.keras.unet3d import *
from src.model.nets import ModelsGenesisSegNet

Using TensorFlow backend.


In [4]:
# Input: ModelsGenesis pretrained Keras checkpoint (.h5).
keras_weight_path = './ModelsGenesis/pretrained_weights/Genesis_Chest_CT.h5' # ModelGenesis pretrained weights
# Output: file the converted PyTorch state dict is written to.
pytorch_weight_path = './weights/models_genesis.pth' # path to save the transformed Pytorch weights

def forward_hook(self, inputs, outputs):
    """Forward hook that records the output of the module it is attached to.

    Appends ``outputs`` to the notebook-level list ``features_hook``, which
    must be defined before the hooked module is run. ``self`` and ``inputs``
    are not used. Returns None.
    """
    features_hook.append(outputs)

## Keras model

In [5]:
# Build the complete Keras U-Net and restore the pretrained weights.
full_unet = unet_model_3d((1, 64, 64, 32), batch_normalization=True)
full_unet.load_weights(keras_weight_path)

# Cut the network at the 'depth_13_relu' layer so that its feature map
# becomes the model output used as the conversion reference.
output = full_unet.get_layer('depth_13_relu').output
keras_model = keras.models.Model(inputs=full_unet.input, outputs=output)
keras_model.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            (None, 1, 64, 64, 32 0                                            
__________________________________________________________________________________________________
depth_0_conv (Conv3D)           (None, 32, 64, 64, 3 896         input_2[0][0]                    
__________________________________________________________________________________________________
depth_0_bn (BatchNormalization) (None, 32, 64, 64, 3 128         depth_0_conv[0][0]               
__________________________________________________________________________________________________
depth_0_relu (Activation)       (None, 32, 64, 64, 3 0           depth_0_bn[0][0]                 
__________________________________________________________________________________________________
depth_1_co

In [6]:
# Weight arrays of the truncated Keras model, in layer order.
keras_weights = keras_model.get_weights()

# Collect the variable names layer by layer so they can be paired with the arrays.
keras_layer_names = []
for keras_layer in keras_model.layers:
    keras_layer_names.extend(variable.name for variable in keras_layer.weights)

# Show each weight's name next to its shape to make the layout explicit.
for name, weight in zip(keras_layer_names, keras_weights):
    print(name, weight.shape)

depth_0_conv_1/kernel:0 (3, 3, 3, 1, 32)
depth_0_conv_1/bias:0 (32,)
depth_0_bn_1/gamma:0 (32,)
depth_0_bn_1/beta:0 (32,)
depth_0_bn_1/moving_mean:0 (32,)
depth_0_bn_1/moving_variance:0 (32,)
depth_1_conv_1/kernel:0 (3, 3, 3, 32, 64)
depth_1_conv_1/bias:0 (64,)
depth_1_bn_1/gamma:0 (64,)
depth_1_bn_1/beta:0 (64,)
depth_1_bn_1/moving_mean:0 (64,)
depth_1_bn_1/moving_variance:0 (64,)
depth_2_conv_1/kernel:0 (3, 3, 3, 64, 64)
depth_2_conv_1/bias:0 (64,)
depth_2_bn_1/gamma:0 (64,)
depth_2_bn_1/beta:0 (64,)
depth_2_bn_1/moving_mean:0 (64,)
depth_2_bn_1/moving_variance:0 (64,)
depth_3_conv_1/kernel:0 (3, 3, 3, 64, 128)
depth_3_conv_1/bias:0 (128,)
depth_3_bn_1/gamma:0 (128,)
depth_3_bn_1/beta:0 (128,)
depth_3_bn_1/moving_mean:0 (128,)
depth_3_bn_1/moving_variance:0 (128,)
depth_4_conv_1/kernel:0 (3, 3, 3, 128, 128)
depth_4_conv_1/bias:0 (128,)
depth_4_bn_1/gamma:0 (128,)
depth_4_bn_1/beta:0 (128,)
depth_4_bn_1/moving_mean:0 (128,)
depth_4_bn_1/moving_variance:0 (128,)
depth_5_conv_1/kernel:0

## PyTorch model

In [7]:
# Instantiate the PyTorch counterpart and switch it to inference mode so the
# batch-norm layers use their running statistics.
pytorch_model = ModelsGenesisSegNet(in_channels=1, out_channels=1)
pytorch_model.eval()

# Record the activation of up_block3.body.relu2 through forward_hook; it is
# compared against the Keras 'depth_13_relu' output during validation.
hooked_layer = pytorch_model.up_block3.body.relu2
hooked_layer.register_forward_hook(forward_hook)

print(pytorch_model)

ModelsGenesisSegNet(
  (in_block): _InBlock(
    (conv1): Conv3d(1, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (norm1): BatchNorm3d(32, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
    (relu1): ReLU(inplace=True)
    (conv2): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (norm2): BatchNorm3d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
    (relu2): ReLU(inplace=True)
  )
  (down_block1): _DownBlock(
    (pool): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv1): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (norm1): BatchNorm3d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
    (relu1): ReLU(inplace=True)
    (conv2): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (norm2): BatchNorm3d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
    (relu2

### Trainable parameters

In [8]:
# Copy every trainable Keras weight into the matching PyTorch parameter.
#
# keras_weights is ordered [kernel, bias] per conv layer and
# [gamma, beta, moving_mean, moving_variance] per batch-norm layer, while
# named_parameters() yields only the trainable tensors. The two moving
# statistics that follow each batch-norm bias are therefore skipped here and
# their positions are stored in moving_idx for the running mean / var cell.
idx = 0
moving_idx = []
for pytorch_name, param in pytorch_model.named_parameters():
    # The truncated Keras model has fewer weights than the PyTorch network
    # (presumably only the output block is left over), so stop once the Keras
    # weights are exhausted. An explicit bounds check replaces the former bare
    # `except:`, which would also have hidden unrelated errors.
    if idx >= len(keras_weights):
        break
    weight = keras_weights[idx]

    if 'conv' in pytorch_name and 'weight' in pytorch_name:
        # Keras kernel axes (k0, k1, k2, in, out) -> PyTorch (out, in, k2, k0, k1).
        # The spatial axes are permuted the same way as the input volume in
        # the validation cells.
        weight = np.transpose(weight, (4, 3, 2, 0, 1))

    # Fail loudly on a layer mismatch instead of silently installing a tensor
    # with the wrong shape.
    if tuple(param.shape) != weight.shape:
        raise ValueError(
            f"Shape mismatch for '{pytorch_name}': PyTorch {tuple(param.shape)} "
            f"vs Keras weight #{idx} {weight.shape}"
        )
    param.data = torch.from_numpy(weight)

    if 'norm' in pytorch_name and 'bias' in pytorch_name:
        # Skip moving_mean / moving_variance but remember where they are.
        moving_idx.append(idx + 1)
        moving_idx.append(idx + 2)
        idx += 3
    else:
        idx += 1

### Normalization running mean / var 

In [9]:
# Copy the Keras moving statistics into the PyTorch batch-norm buffers.
# moving_idx holds, per batch-norm layer, the position of moving_mean followed
# by the position of moving_variance inside keras_weights.
idx = 0
for name, module in pytorch_model.named_modules():
    if 'norm' not in name:
        continue
    var_position = moving_idx[idx + 1]
    mean_position = moving_idx[idx]
    keras_moving_var = keras_weights[var_position]
    print(name, keras_moving_var)
    module.running_mean = torch.from_numpy(keras_weights[mean_position])
    module.running_var = torch.from_numpy(keras_moving_var)
    idx += 2

in_block.norm1 [0.10773456 0.0156241  0.07792253 1.1480795  0.02501738 0.09905564
 0.19020303 0.41333827 0.08917138 0.6359106  0.3284189  2.2120872
 0.15645984 0.14176716 0.14004321 0.12683997 0.06108763 0.32820657
 0.05714194 0.04843716 0.4967732  0.10198742 0.10573964 0.08588026
 0.6854433  0.06085574 0.0786974  0.294242   0.27014118 0.08691396
 0.17866541 0.04669904]
in_block.norm2 [ 3.7916975  3.8988621 19.649761   3.6805334  5.529005   2.2508006
  4.4314017 10.00479    2.4139037  4.296064   2.122373   2.119363
  6.2018237  9.0575905  2.5414429 14.059286   4.6595473  4.4932485
  3.884869   5.5471125  4.308636   1.9267719  2.3988662  4.2054043
  2.7728386  3.1002681  6.69639    3.6184905  4.2444334  2.0157583
  2.1239355  3.3594143  2.9200778  2.024735   6.532754   2.6827655
  3.7127178  2.5198367  3.071708   2.3560147  1.0687444  2.2905595
  8.323762  12.777962   3.3162234  2.023244   2.1988385  3.094626
  2.5355198  2.0601332  2.2036538  9.026318   2.1935406 20.833904
  2.4403563 

## Validate whether the transformation is successful

In [10]:
# Fresh buffer for the activations captured by forward_hook.
features_hook = []
# Random test volume in the Keras axis order. The `np.float` alias was removed
# in NumPy 1.24; np.float64 is the dtype it used to resolve to.
input = np.random.randn(1, 1, 64, 64, 32).astype(np.float64)

In [11]:
# Run the converted PyTorch model on the same random volume.
# Use the first GPU when present, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Move the last axis in front of the two 64-sized axes:
# (1, 1, 64, 64, 32) -> (1, 1, 32, 64, 64).
pytorch_input = torch.from_numpy(input.transpose(0, 1, 4, 2, 3)).float().to(device)
pytorch_model.to(device)

# Drop activations left over from earlier runs so index 0 is always the
# result of this forward pass, even when the cell is executed again.
features_hook.clear()
with torch.no_grad():  # inference only, no autograd graph needed
    pytorch_model(pytorch_input)

# Undo the axis permutation so the result lines up with the Keras output.
pytorch_output = features_hook[0].detach().cpu().numpy().transpose(0, 1, 3, 4, 2)

In [12]:
# Reference activations of the truncated Keras model for the same random
# input; the array is passed in its original axis order.
keras_output = keras_model.predict_on_batch(input)

In [13]:
# The element-wise comparison is only meaningful when both shapes agree.
shapes_match = keras_output.shape == pytorch_output.shape
assert shapes_match, f"{keras_output.shape} != {pytorch_output.shape}"

# Largest absolute deviation between the two frameworks.
max_abs_diff = np.abs(keras_output - pytorch_output).max()
print(max_abs_diff)

0.0005493164


In [11]:
# Save the converted weights without the out_block parameters (presumably
# because their size depends on the out_channels of the downstream network).
state_dict = pytorch_model.state_dict()
state_dict.pop('out_block.weight')
state_dict.pop('out_block.bias')

# The model may have been moved to the GPU for validation; store CPU tensors
# so the checkpoint also loads on machines without CUDA. Updating the entries
# in place keeps the original state-dict object.
for key in list(state_dict):
    state_dict[key] = state_dict[key].cpu()

# torch.save does not create missing directories.
os.makedirs(os.path.dirname(pytorch_weight_path) or '.', exist_ok=True)
torch.save(state_dict, pytorch_weight_path)

---------------------

# Validate the segmentation / classification network

In [14]:
%load_ext autoreload
%autoreload 2

import keras
import torch
from ModelsGenesis.keras.unet3d import *
from src.model import ModelsGenesisSegNet, ModelsGenesisClfNet

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [16]:
# ModelsGenesis pretrained Keras checkpoint, used to build the reference model.
keras_weight_path = './ModelsGenesis/pretrained_weights/Genesis_Chest_CT.h5' # ModelGenesis pretrained weights
# In this section the converted weights are loaded from this path (passed as
# weight_path to the networks below), not written to it.
pytorch_weight_path = './weights/models_genesis.pth' # path to save the transformed Pytorch weights

def forward_hook(self, inputs, outputs):
    """Forward hook that records the output of the module it is attached to.

    Appends ``outputs`` to the notebook-level list ``features_hook``, which
    must be defined before the hooked module is run. ``self`` and ``inputs``
    are not used. Returns None.
    """
    features_hook.append(outputs)

## Segmentation

In [17]:
# Fresh buffer for the activations captured by forward_hook.
features_hook = []
# Random test volume in the Keras axis order. The `np.float` alias was removed
# in NumPy 1.24; np.float64 is the dtype it used to resolve to.
input = np.random.randn(1, 1, 64, 64, 32).astype(np.float64)

In [18]:
# Segmentation network initialised from the converted checkpoint, switched to
# inference mode so batch norm uses its running statistics.
seg_net = ModelsGenesisSegNet(in_channels=1, out_channels=10, weight_path=pytorch_weight_path)
seg_net.eval()

# Capture the activation of up_block3.body.relu2 for the comparison below.
seg_hooked_layer = seg_net.up_block3.body.relu2
seg_hooked_layer.register_forward_hook(forward_hook)

<torch.utils.hooks.RemovableHandle at 0x7f9a702278d0>

In [19]:
# Run the segmentation network on the random volume.
# Use the first GPU when present, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Move the last axis in front of the two 64-sized axes:
# (1, 1, 64, 64, 32) -> (1, 1, 32, 64, 64).
pytorch_input = torch.from_numpy(input.transpose(0, 1, 4, 2, 3)).float().to(device)
seg_net.to(device)

# Drop activations left over from earlier runs so index 0 is always the
# result of this forward pass, even when the cell is executed again.
features_hook.clear()
with torch.no_grad():  # inference only, no autograd graph needed
    seg_net(pytorch_input)

# Undo the axis permutation so the result lines up with the Keras output.
pytorch_output = features_hook[0].detach().cpu().numpy().transpose(0, 1, 3, 4, 2)

In [20]:
# Rebuild the Keras reference: full U-Net with the pretrained weights ...
full_unet = unet_model_3d((1, 64, 64, 32), batch_normalization=True)
full_unet.load_weights(keras_weight_path)

# ... truncated at 'depth_13_relu', the layer compared against the hooked
# PyTorch activation.
output = full_unet.get_layer('depth_13_relu').output
keras_model = keras.models.Model(inputs=full_unet.input, outputs=output)

keras_output = keras_model.predict_on_batch(input)

In [21]:
# Largest absolute deviation between the PyTorch and Keras feature maps.
np.max(np.abs(pytorch_output - keras_output))

0.000579834

## Classification

In [23]:
# Fresh buffer for the activations captured by forward_hook.
features_hook = []
# Random test volume in the Keras axis order. The `np.float` alias was removed
# in NumPy 1.24; np.float64 is the dtype it used to resolve to.
input = np.random.randn(1, 1, 64, 64, 32).astype(np.float64)

In [22]:
# Classification network initialised from the converted checkpoint, switched
# to inference mode so batch norm uses its running statistics.
clf_net = ModelsGenesisClfNet(in_channels=1, out_channels=20, weight_path=pytorch_weight_path)
clf_net.eval()

# Capture the activation of down_block3.relu2 for the comparison below.
clf_hooked_layer = clf_net.down_block3.relu2
clf_hooked_layer.register_forward_hook(forward_hook)

<torch.utils.hooks.RemovableHandle at 0x7f9a70267e48>

In [24]:
# Run the classification network on the random volume.
# Use the first GPU when present, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Move the last axis in front of the two 64-sized axes:
# (1, 1, 64, 64, 32) -> (1, 1, 32, 64, 64).
pytorch_input = torch.from_numpy(input.transpose(0, 1, 4, 2, 3)).float().to(device)
clf_net.to(device)

# Drop activations left over from earlier runs so index 0 is always the
# result of this forward pass, even when the cell is executed again.
features_hook.clear()
with torch.no_grad():  # inference only, no autograd graph needed
    clf_net(pytorch_input)

# Undo the axis permutation so the result lines up with the Keras output.
pytorch_output = features_hook[0].detach().cpu().numpy().transpose(0, 1, 3, 4, 2)

In [26]:
# Rebuild the Keras reference: full U-Net with the pretrained weights ...
full_unet = unet_model_3d((1, 64, 64, 32), batch_normalization=True)
full_unet.load_weights(keras_weight_path)

# ... truncated at 'depth_7_relu', the layer compared against the hooked
# PyTorch activation.
output = full_unet.get_layer('depth_7_relu').output
keras_model = keras.models.Model(inputs=full_unet.input, outputs=output)

keras_output = keras_model.predict_on_batch(input)

In [27]:
# Largest absolute deviation between the PyTorch and Keras feature maps.
np.max(np.abs(pytorch_output - keras_output))

2.670288e-05