In [1]:
from itertools import chain

In [2]:
import numpy as np

In [3]:
from torch import tensor, no_grad

In [4]:
from thesis_v2.blocks.rcnn_basic_kriegeskorte import nn_modules

In [5]:
# Fixed-seed test input with values in [-1, 1). It is drawn channels-last
# (1, 128, 128, 3), cast to float32, then permuted to channels-first
# (1, 3, 128, 128) for the PyTorch model.
# NOTE(review): presumably the TF reference outputs were produced from the same
# seed/draw in channels-last layout -- confirm against the TF script.
np.random.seed(0)
random_img = (
    (np.random.random_sample((1, 128, 128, 3)) * 2 - 1)
    .astype(np.float32)
    .transpose((0, 3, 1, 2))
)

In [6]:
# Recurrent conv stack with one kernel size per layer (7 layers).
# channel_list[0] = 3 is the input's channel count; every later entry is the
# width of one layer (the repr shows layer 0 as Conv2d(3, 96) + Conv2d(96, 96)).
# The stack is unrolled for 8 timesteps.
a = nn_modules.BLConvLayerStack(
    channel_list=[3, 96, 128, 192, 256, 512, 1024, 2048],
    ksize_list=[7, 5, 3, 3, 3, 3, 1],
    n_timesteps=8,
    bn_eps=0.001,  # same batch-norm epsilon as the TF model
)

In [7]:
# Show every registered buffer (batch-norm running stats and batch counters)
# with its shape and dtype, to check the layout expected by the weight loader.
for buf_name, buf in a.named_buffers():
    print(buf_name, buf.size(), buf.dtype)

bn_layer_list.0.running_mean torch.Size([96]) torch.float32
bn_layer_list.0.running_var torch.Size([96]) torch.float32
bn_layer_list.0.num_batches_tracked torch.Size([]) torch.int64
bn_layer_list.1.running_mean torch.Size([128]) torch.float32
bn_layer_list.1.running_var torch.Size([128]) torch.float32
bn_layer_list.1.num_batches_tracked torch.Size([]) torch.int64
bn_layer_list.2.running_mean torch.Size([192]) torch.float32
bn_layer_list.2.running_var torch.Size([192]) torch.float32
bn_layer_list.2.num_batches_tracked torch.Size([]) torch.int64
bn_layer_list.3.running_mean torch.Size([256]) torch.float32
bn_layer_list.3.running_var torch.Size([256]) torch.float32
bn_layer_list.3.num_batches_tracked torch.Size([]) torch.int64
bn_layer_list.4.running_mean torch.Size([512]) torch.float32
bn_layer_list.4.running_var torch.Size([512]) torch.float32
bn_layer_list.4.num_batches_tracked torch.Size([]) torch.int64
bn_layer_list.5.running_mean torch.Size([1024]) torch.float32
bn_layer_list.5.runn

In [8]:
import h5py

def load_weight_for_BLConvLayerStack(mod, filepath, verbose=True):
    """Copy TF weights from an HDF5 export into a BLConvLayerStack, in place.

    Name mapping (HDF5 keys repeat their scope name, e.g.
    ``RCL_0_BConv/RCL_0_BConv/kernel:0``):

    1. ``layer_list.#.b_conv.weight`` <- ``RCL_#_BConv/RCL_#_BConv/kernel:0``,
       transposed with ``(3, 2, 0, 1)``.
    2. ``layer_list.#.l_conv.weight`` <- ``RCL_#_LConv/RCL_#_LConv/kernel:0``,
       transposed with ``(3, 2, 0, 1)``.
    3. ``bn_layer_list.(#2 * n_layer + #1).{weight, bias, running_mean, running_var}``
       <- ``BatchNorm_Layer_#1_Time_#2/BatchNorm_Layer_#1_Time_#2/
       {gamma, beta, moving_mean, moving_variance}:0``.
       ``num_batches_tracked`` has no TF counterpart and is left untouched.

    Args:
        mod: module to fill. It must expose ``n_layer``, and all of its
            parameters/buffers must live under ``layer_list.`` or
            ``bn_layer_list.``.
        filepath: path to the HDF5 file, or an already-open h5py File/Group
            (or any mapping from key to an array-like supporting ``[()]``).
        verbose: if True (default), print one line per copied tensor:
            torch name, HDF5 key, stored shape, target shape.

    Raises:
        AssertionError: if a name has an unexpected structure, or a stored
            array's shape (after transposing) differs from the target tensor.
        KeyError: for an unrecognised sub-module/field name, or a key that is
            missing from the file.
        RuntimeError: for a parameter/buffer outside layer_list/bn_layer_list.
    """
    import os

    if isinstance(filepath, (str, bytes, os.PathLike)):
        with h5py.File(filepath, 'r') as f:
            _copy_tf_weights(mod, f, verbose)
    else:
        _copy_tf_weights(mod, filepath, verbose)


def _tf_key_for(torch_name, num_layer):
    """Return ``(h5_key, transpose_axes)`` for one torch name, or None if it has no TF counterpart."""
    parts = torch_name.split('.')
    if parts[0] == 'layer_list':
        # e.g. layer_list.0.b_conv.weight
        assert len(parts) == 4, torch_name
        assert parts[3] == 'weight', torch_name
        layer_idx = int(parts[1])
        conv_kind = {'b_conv': 'BConv', 'l_conv': 'LConv'}[parts[2]]
        scope = f'RCL_{layer_idx}_{conv_kind}'
        # Stored kernels are (kh, kw, in, out); torch Conv2d wants (out, in, kh, kw).
        return f'{scope}/{scope}/kernel:0', (3, 2, 0, 1)
    if parts[0] == 'bn_layer_list':
        # e.g. bn_layer_list.45.running_var
        assert len(parts) == 3, torch_name
        if parts[2] == 'num_batches_tracked':
            return None
        # BN layers are stored flat, timestep-major: index = time * num_layer + layer.
        time_idx, layer_idx = divmod(int(parts[1]), num_layer)
        tf_field = {
            'weight': 'gamma',
            'bias': 'beta',
            'running_mean': 'moving_mean',
            'running_var': 'moving_variance',
        }[parts[2]]
        scope = f'BatchNorm_Layer_{layer_idx}_Time_{time_idx}'
        return f'{scope}/{scope}/{tf_field}:0', None
    raise RuntimeError(f'unexpected parameter/buffer name: {torch_name!r}')


def _copy_tf_weights(mod, h5, verbose):
    """Copy each mapped array from ``h5`` into the matching parameter/buffer of ``mod``."""
    num_layer = mod.n_layer
    with no_grad():
        for name, target in chain(mod.named_parameters(), mod.named_buffers()):
            key_and_axes = _tf_key_for(name, num_layer)
            if key_and_axes is None:
                continue
            h5_key, axes = key_and_axes
            weight_tf = h5[h5_key][()]
            if axes is not None:
                weight_tf = weight_tf.transpose(axes)
            if verbose:
                print(name, h5_key, weight_tf.shape, target.size())
            assert weight_tf.shape == target.size(), (name, weight_tf.shape, tuple(target.size()))
            target.copy_(tensor(weight_tf))
            

In [9]:
# Fill `a` in place with the TF weights exported to HDF5 (prints one line per copied tensor).
load_weight_for_BLConvLayerStack(mod=a, filepath='bl_ecoset.h5')

layer_list.0.b_conv.weight RCL_0_BConv/RCL_0_BConv/kernel:0 (96, 3, 7, 7) torch.Size([96, 3, 7, 7])
layer_list.0.l_conv.weight RCL_0_LConv/RCL_0_LConv/kernel:0 (96, 96, 7, 7) torch.Size([96, 96, 7, 7])
layer_list.1.b_conv.weight RCL_1_BConv/RCL_1_BConv/kernel:0 (128, 96, 5, 5) torch.Size([128, 96, 5, 5])
layer_list.1.l_conv.weight RCL_1_LConv/RCL_1_LConv/kernel:0 (128, 128, 5, 5) torch.Size([128, 128, 5, 5])
layer_list.2.b_conv.weight RCL_2_BConv/RCL_2_BConv/kernel:0 (192, 128, 3, 3) torch.Size([192, 128, 3, 3])
layer_list.2.l_conv.weight RCL_2_LConv/RCL_2_LConv/kernel:0 (192, 192, 3, 3) torch.Size([192, 192, 3, 3])
layer_list.3.b_conv.weight RCL_3_BConv/RCL_3_BConv/kernel:0 (256, 192, 3, 3) torch.Size([256, 192, 3, 3])
layer_list.3.l_conv.weight RCL_3_LConv/RCL_3_LConv/kernel:0 (256, 256, 3, 3) torch.Size([256, 256, 3, 3])
layer_list.4.b_conv.weight RCL_4_BConv/RCL_4_BConv/kernel:0 (512, 256, 3, 3) torch.Size([512, 256, 3, 3])
layer_list.4.l_conv.weight RCL_4_LConv/RCL_4_LConv/kernel:

bn_layer_list.45.weight BatchNorm_Layer_3_Time_6/BatchNorm_Layer_3_Time_6/gamma:0 (256,) torch.Size([256])
bn_layer_list.45.bias BatchNorm_Layer_3_Time_6/BatchNorm_Layer_3_Time_6/beta:0 (256,) torch.Size([256])
bn_layer_list.46.weight BatchNorm_Layer_4_Time_6/BatchNorm_Layer_4_Time_6/gamma:0 (512,) torch.Size([512])
bn_layer_list.46.bias BatchNorm_Layer_4_Time_6/BatchNorm_Layer_4_Time_6/beta:0 (512,) torch.Size([512])
bn_layer_list.47.weight BatchNorm_Layer_5_Time_6/BatchNorm_Layer_5_Time_6/gamma:0 (1024,) torch.Size([1024])
bn_layer_list.47.bias BatchNorm_Layer_5_Time_6/BatchNorm_Layer_5_Time_6/beta:0 (1024,) torch.Size([1024])
bn_layer_list.48.weight BatchNorm_Layer_6_Time_6/BatchNorm_Layer_6_Time_6/gamma:0 (2048,) torch.Size([2048])
bn_layer_list.48.bias BatchNorm_Layer_6_Time_6/BatchNorm_Layer_6_Time_6/beta:0 (2048,) torch.Size([2048])
bn_layer_list.49.weight BatchNorm_Layer_0_Time_7/BatchNorm_Layer_0_Time_7/gamma:0 (96,) torch.Size([96])
bn_layer_list.49.bias BatchNorm_Layer_0_Tim

In [10]:
# Move parameters/buffers to the GPU, then switch to inference mode so batch
# norm normalises with the loaded running statistics instead of batch statistics.
a.cuda()
a.eval()

BLConvLayerStack(
  (layer_list): ModuleList(
    (0): BLConvLayer(
      (b_conv): Conv2d(3, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
      (l_conv): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
    )
    (1): BLConvLayer(
      (b_conv): Conv2d(96, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
      (l_conv): Conv2d(128, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
    )
    (2): BLConvLayer(
      (b_conv): Conv2d(128, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (l_conv): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    )
    (3): BLConvLayer(
      (b_conv): Conv2d(192, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (l_conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    )
    (4): BLConvLayer(
      (b_conv): Conv2d(256, 512, kernel_size=(3, 3), strid

In [11]:
# Inference only: no_grad() skips building the autograd graph, so the
# intermediate activations of every timestep are not retained for a backward pass.
with no_grad():
    pytorch_out = a(tensor(random_img).cuda())

In [12]:
# let's compare
from pickle import load
with open('debug_rcnn_basic_kriegeskorte_tf.pkl', 'rb') as f_tf:
    tf_out = load(f_tf)

In [13]:
from numpy.linalg import norm

In [14]:
def compare(pytorch_outputs=None, tf_outputs=None, n_timesteps=8, atol=1e-3):
    """Check that the PyTorch per-timestep outputs match the TF reference outputs.

    Each PyTorch output is converted from channels-first to channels-last
    (``transpose((0, 2, 3, 1))``), the mirror of the permutation applied to the
    input. The norm of its difference from the TF output is then printed and
    asserted to be below ``atol``.

    Args:
        pytorch_outputs: sequence of torch tensors, one per timestep.
            Defaults to the notebook-level ``pytorch_out``.
        tf_outputs: sequence of channels-last numpy arrays, one per timestep.
            Defaults to the notebook-level ``tf_out``.
        n_timesteps: number of timesteps both sequences must contain.
        atol: per-timestep upper bound (exclusive) on the difference norm.

    Raises:
        AssertionError: on a length mismatch, a shape mismatch, or a
            difference norm >= ``atol``.
    """
    if pytorch_outputs is None:
        pytorch_outputs = pytorch_out
    if tf_outputs is None:
        tf_outputs = tf_out
    assert len(pytorch_outputs) == len(tf_outputs) == n_timesteps, (
        len(pytorch_outputs), len(tf_outputs), n_timesteps)
    for t, (pytorch_this, tf_this) in enumerate(zip(pytorch_outputs, tf_outputs)):
        pytorch_nhwc = pytorch_this.detach().cpu().numpy().transpose((0, 2, 3, 1))
        assert pytorch_nhwc.shape == tf_this.shape, (t, pytorch_nhwc.shape, tf_this.shape)
        diff_norm = norm(pytorch_nhwc - tf_this)
        print(diff_norm)
        assert diff_norm < atol, (t, diff_norm)

In [15]:
# Each printed value is the norm of the PyTorch - TF output difference for one
# timestep; compare() raises AssertionError if any of them reaches 1e-3.
compare()

0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
