In [1]:
import re
from pathlib import Path
from collections import OrderedDict

import pandas as pd
from tqdm import tqdm_notebook as tqdm

import torch
import tensorflow as tf

from highresnet import HighRes3DNet

In [2]:
# --- Input: NiftyNet checkpoint to convert -----------------------------------
output_csv_path = '/tmp/state_dict_tf.csv'
models_dir = Path('~/niftynet/models/').expanduser()
checkpoint_name = 'model.ckpt-33000'
checkpoint_path = models_dir / 'highres3dnet_brain_parcellation' / 'models' / checkpoint_name

# --- Output: serialized state dictionaries -----------------------------------
# TensorFlow tensors under their (shortened) TensorFlow names.
state_dict_tf_path = '/tmp/state_dict_tf.pth'
# The same tensors under PyTorch state-dict keys.
state_dict_pt_path = '/tmp/state_dict_pt.pth'

# --- Display options ----------------------------------------------------------
# `None` means "no limit". The former sentinel -1 for max_colwidth was
# deprecated in pandas 1.0 and is rejected by current pandas releases.
pd.set_option('display.max_colwidth', None)  # do not truncate strings when displaying data frames
pd.set_option('display.max_rows', None)  # show all rows

# When True, checkpoint variables whose name contains 'Adam', 'Exponential' or
# 'biased', and variables with an empty shape, are skipped.
filter_variables = True

## TensorFlow

In [3]:
# Clear the default graph before creating the variables below.
tf.reset_default_graph()

variables_list = tf.train.list_variables(str(checkpoint_path))
variables_dict = OrderedDict()
rows = []
for var_name, var_shape in variables_list:
    # Variables with an empty shape, or whose name contains one of these
    # substrings, are left out when filtering is enabled.
    is_excluded = (
        'Adam' in var_name
        or 'Exponential' in var_name
        or 'biased' in var_name
        or not var_shape
    )
    if filter_variables and is_excluded:
        continue
    # Create a graph variable with the checkpoint's full name and shape;
    # the dictionary is keyed by that full name.
    variables_dict[var_name] = tf.get_variable(var_name, shape=var_shape)
    # The summary table uses the name without the 'HighRes3DNet/' scope prefix.
    rows.append({
        'name': var_name.replace('HighRes3DNet/', ''),
        'shape': ', '.join(str(n) for n in var_shape),
    })
df_tf = pd.DataFrame(rows)
df_tf

Unnamed: 0,name,shape
0,conv_0_bn_relu/bn_/beta,16
1,conv_0_bn_relu/bn_/gamma,16
2,conv_0_bn_relu/bn_/moving_mean,16
3,conv_0_bn_relu/bn_/moving_variance,16
4,conv_0_bn_relu/conv_/w,"3, 3, 3, 1, 16"
5,conv_1_bn_relu/bn_/beta,80
6,conv_1_bn_relu/bn_/gamma,80
7,conv_1_bn_relu/bn_/moving_mean,80
8,conv_1_bn_relu/bn_/moving_variance,80
9,conv_1_bn_relu/conv_/w,"1, 1, 1, 64, 80"


In [4]:
# Restore the checkpoint into the variables held in `variables_dict`, read each
# one back and store it as a torch tensor under its name without the
# 'HighRes3DNet/' prefix.
saver = tf.train.Saver()
state_dict_tf = {}
with tf.Session() as sess:
    print('Restoring session...')
    saver.restore(sess, str(checkpoint_path))
    # `variables_dict` already contains exactly the variables selected when the
    # graph was built, in order. Iterating it avoids repeating the name/shape
    # filter here (two copies can drift apart) and gives the progress bar the
    # true number of exported variables.
    for name, variable in tqdm(variables_dict.items()):
        array = variable.eval()  # evaluated in the session opened above
        name = name.replace('HighRes3DNet/', '')
        state_dict_tf[name] = torch.tensor(array)
print('Saving state dictionary...')
torch.save(state_dict_tf, state_dict_tf_path)

Restoring session...
INFO:tensorflow:Restoring parameters from /home/fernando/niftynet/models/highres3dnet_brain_parcellation/models/model.ckpt-33000


HBox(children=(IntProgress(value=0, max=380), HTML(value='')))


Saving state dictionary...


## PyTorch

In [5]:
# Instantiate the PyTorch model and tabulate the entries of its state dict.
# NOTE(review): the positional arguments are presumably (input channels,
# output classes) — confirm against the highresnet package.
model = HighRes3DNet(1, 160, add_dropout_layer=True)
state_dict_pt = model.state_dict()
rows = []
for name, parameters in state_dict_pt.items():
    shape = ', '.join(str(n) for n in parameters.shape)
    rows.append({'name': name, 'shape': shape})
df_pt = pd.DataFrame(rows)
# `Styler.set_properties` returns a new Styler and leaves `df_pt` untouched, so
# the styled object has to be the cell's last expression for the left alignment
# to show up in the rendered table.
df_pt.style.set_properties(**{'text-align': 'left'})

Unnamed: 0,name,shape
0,block.0.convolutional_block.1.weight,"16, 1, 3, 3, 3"
1,block.0.convolutional_block.2.weight,1
2,block.0.convolutional_block.2.bias,1
3,block.0.convolutional_block.2.running_mean,1
4,block.0.convolutional_block.2.running_var,1
5,block.0.convolutional_block.2.num_batches_tracked,
6,block.1.dilation_block.0.residual_block.0.convolutional_block.0.weight,16
7,block.1.dilation_block.0.residual_block.0.convolutional_block.0.bias,16
8,block.1.dilation_block.0.residual_block.0.convolutional_block.0.running_mean,16
9,block.1.dilation_block.0.residual_block.0.convolutional_block.0.running_var,16


In [9]:
def tf2pt_name(name_tf):
    """Translate a TensorFlow variable name into a PyTorch state-dict key.

    Two families of names are handled:

    * ``res_<d>_<r>/<layer>_<i>/<param>`` (e.g. ``res_2_0/bn_0/moving_variance``),
      mapped to ``block.<d>.dilation_block.<r>.residual_block.<i>
      .convolutional_block.<k>.<param_pt>``;
    * the fixed ``conv_*`` names listed in ``conv_layers_dict`` below.

    Args:
        name_tf: Variable name with the ``HighRes3DNet/`` prefix already removed.

    Returns:
        The corresponding PyTorch state-dict key as a string.

    Raises:
        ValueError: If ``name_tf`` starts with neither ``res_`` nor ``conv_``,
            or starts with ``res_`` but does not have the expected structure.
        KeyError: If the parameter type of a ``res_`` name, or a ``conv_``
            name, is not one of the known entries.
    """
    param_type_dict = {
        'w': 'weight',
        'gamma': 'weight',
        'beta': 'bias',
        'moving_mean': 'running_mean',
        'moving_variance': 'running_var',
    }

    if name_tf.startswith('res_'):
        # res_2_0/bn_0/moving_variance
        pattern = (
            r'res'
            r'_(\d+)'  # 2
            r'_(\d+)'  # 0
            r'/(\w+)'  # bn
            r'_(\d+)'  # 0
            r'/(\w+)'  # moving_variance
        )
        # fullmatch (not match): a name with extra trailing components, such as
        # an optimizer slot '.../w/Adam', must not be silently mapped onto the
        # key of the real parameter.
        match = re.fullmatch(pattern, name_tf)
        if match is None:
            raise ValueError(
                f'Unexpected residual-block variable name: {name_tf!r}'
            )
        dil_idx, res_idx, layer_type, layer_idx, param_type = match.groups()
        # Position of the sub-module inside `convolutional_block`.
        # NOTE(review): 3 for 'conv' and 0 otherwise presumably reflects the
        # layer order inside the PyTorch block — confirm against highresnet.
        param_idx = 3 if layer_type == 'conv' else 0

        name_pt = (
            f'block.{dil_idx}.dilation_block.{res_idx}.residual_block'
            f'.{layer_idx}.convolutional_block.{param_idx}.{param_type_dict[param_type]}'
        )
    elif name_tf.startswith('conv_'):
        conv_layers_dict = {
            'conv_0_bn_relu/conv_/w': 'block.0.convolutional_block.1.weight',  # first conv layer
            'conv_0_bn_relu/bn_/gamma': 'block.0.convolutional_block.2.weight',
            'conv_0_bn_relu/bn_/beta': 'block.0.convolutional_block.2.bias',
            'conv_0_bn_relu/bn_/moving_mean': 'block.0.convolutional_block.2.running_mean',
            'conv_0_bn_relu/bn_/moving_variance': 'block.0.convolutional_block.2.running_var',

            'conv_1_bn_relu/conv_/w': 'block.4.convolutional_block.1.weight',  # layer with dropout
            'conv_1_bn_relu/bn_/gamma': 'block.4.convolutional_block.2.weight',
            'conv_1_bn_relu/bn_/beta': 'block.4.convolutional_block.2.bias',
            'conv_1_bn_relu/bn_/moving_mean': 'block.4.convolutional_block.2.running_mean',
            'conv_1_bn_relu/bn_/moving_variance': 'block.4.convolutional_block.2.running_var',

            # NOTE(review): the original comment here also read "layer with
            # dropout", which looks copy-pasted from the group above — confirm.
            'conv_2_bn/conv_/w': 'block.6.convolutional_block.1.weight',
            'conv_2_bn/bn_/gamma': 'block.6.convolutional_block.2.weight',
            'conv_2_bn/bn_/beta': 'block.6.convolutional_block.2.bias',
            'conv_2_bn/bn_/moving_mean': 'block.6.convolutional_block.2.running_mean',
            'conv_2_bn/bn_/moving_variance': 'block.6.convolutional_block.2.running_var',
        }
        name_pt = conv_layers_dict[name_tf]
    else:
        # Previously this fell through to `return name_pt` with the name unbound.
        raise ValueError(
            f'Variable name must start with "res_" or "conv_", got {name_tf!r}'
        )
    return name_pt
    
    
    
def tf2pt(name_tf, tensor_tf):
    """Convert one TensorFlow entry (name and tensor) to its PyTorch counterpart.

    Args:
        name_tf: TensorFlow variable name, as accepted by ``tf2pt_name``.
        tensor_tf: Tensor exposing ``dim()`` and ``permute()``; it must have
            either 1 or 5 dimensions.

    Returns:
        Tuple ``(name_pt, tensor_pt)``. 1-D tensors are returned unchanged;
        5-D tensors are returned with their axes reordered to ``(4, 3, 0, 1, 2)``,
        i.e. the last two axes are moved to the front in reverse order
        (a ``(3, 3, 3, 1, 16)`` kernel becomes ``(16, 1, 3, 3, 3)``).

    Raises:
        ValueError: If the tensor has a number of dimensions other than 1 or 5.
    """
    name_pt = tf2pt_name(name_tf)
    num_dimensions = tensor_tf.dim()
    if num_dimensions == 1:
        tensor_pt = tensor_tf
    elif num_dimensions == 5:
        tensor_pt = tensor_tf.permute(4, 3, 0, 1, 2)
    else:
        # Previously this fell through to the return with `tensor_pt` unbound.
        raise ValueError(
            f'Cannot convert {name_tf!r}: expected a tensor with 1 or 5'
            f' dimensions, got {num_dimensions}'
        )
    return name_pt, tensor_pt

In [13]:
# Convert every TensorFlow entry, log its shape and PyTorch key, and write it
# into the PyTorch state dictionary before saving the result to disk.
# NOTE(review): entries are assigned without checking that `name_pt` already
# exists in `state_dict_pt` or that the shapes agree — consider validating.
tf_entries = list(state_dict_tf.items())
for name_tf, tensor_tf in tqdm(tf_entries):
    name_pt, tensor_pt = tf2pt(name_tf, tensor_tf)
    shape_text = str(tuple(tensor_pt.shape))
    print(f'{shape_text:18}', name_pt)
    state_dict_pt[name_pt] = tensor_pt
torch.save(state_dict_pt, state_dict_pt_path)

HBox(children=(IntProgress(value=0, max=105), HTML(value='')))

(16,)              block.0.convolutional_block.2.bias
(16,)              block.0.convolutional_block.2.weight
(16,)              block.0.convolutional_block.2.running_mean
(16,)              block.0.convolutional_block.2.running_var
(16, 1, 3, 3, 3)   block.0.convolutional_block.1.weight
(80,)              block.4.convolutional_block.2.bias
(80,)              block.4.convolutional_block.2.weight
(80,)              block.4.convolutional_block.2.running_mean
(80,)              block.4.convolutional_block.2.running_var
(80, 64, 1, 1, 1)  block.4.convolutional_block.1.weight
(160,)             block.6.convolutional_block.2.bias
(160,)             block.6.convolutional_block.2.weight
(160,)             block.6.convolutional_block.2.running_mean
(160,)             block.6.convolutional_block.2.running_var
(160, 80, 1, 1, 1) block.6.convolutional_block.1.weight
(16,)              block.1.dilation_block.0.residual_block.0.convolutional_block.0.bias
(16,)              block.1.dilation_block.0.r