In [36]:
import pdb
import numpy as np

import torch
import torch.nn as nn
from torch.autograd import Variable

In [2]:
def _make_conv_block(convolution, batch_norm):
    class ConvBlock(nn.Module):
        def __init__(self, n_chans_in, n_chans_out, kernel_size, activation=lambda x: x, stride=1,
                     padding=0, dilation=1):
            super().__init__()
            self.convolution = convolution(in_channels=n_chans_in, out_channels=n_chans_out, kernel_size=kernel_size,
                                           stride=stride, padding=padding, dilation=dilation, bias=False)
            self.batch_norm = batch_norm(num_features=n_chans_out)
            self.activation = activation

        def forward(self, input):
            return self.activation(self.batch_norm(self.convolution(input)))

    return ConvBlock


ConvBlock1d = _make_conv_block(nn.Conv1d, nn.BatchNorm1d)
ConvBlock2d = _make_conv_block(nn.Conv2d, nn.BatchNorm2d)
ConvBlock3d = _make_conv_block(nn.Conv3d, nn.BatchNorm3d)

ConvTransposeBlock1d = _make_conv_block(nn.ConvTranspose1d, nn.BatchNorm1d)
ConvTransposeBlock2d = _make_conv_block(nn.ConvTranspose2d, nn.BatchNorm2d)
ConvTransposeBlock3d = _make_conv_block(nn.ConvTranspose3d, nn.BatchNorm3d)


In [39]:
def make_tnet(conv_block, conv_transposed_block):
    class TNet(nn.Module):
        def __init__(self, n_chans_in, n_chans_out, activation, structure, stride):
            super().__init__()
            assert all([len(level) == 3 for level in structure[:-1]])
            assert len(structure[-1]) == 1
            
            pdb.set_trace()

            down_paths = [level[0] for level in structure[:-1]]
            bridge_paths = [level[1] for level in structure[:-1]] + structure[-1]
            up_paths = [level[2] for level in structure[:-1]]
            
            
            
            down_paths[0] = [n_chans_in, *down_paths[0]]
            bridge_paths = [[down_level[-1], *bridge_level]
                            for down_level, bridge_level in zip(down_paths, bridge_paths)]

            assert all([bridge_level[-1] < up_level[0] for bridge_level, up_level in zip(bridge_paths, up_paths)])

            kernel_size = 3
            padding = kernel_size // 2

            def build_level(level):
                return nn.Sequential(*[conv_block(n_chans_in=n_chans_in, n_chans_out=n_chans_out, padding=padding,
                                                  kernel_size=kernel_size, activation=activation)
                                       for n_chans_in, n_chans_out in zip(level[:-1], level[1:])])

            self.down_levels = nn.ModuleList([build_level(level) for level in down_paths])
            self.bridge_levels = nn.ModuleList([build_level(level) for level in bridge_paths])
            self.up_levels = nn.ModuleList([build_level(level) for level in up_paths])

            print([*down_paths[1:], bridge_paths[-1]])
            #print([])
            
            self.down_steps = nn.ModuleList(
                [conv_block(n_chans_in=level[-1], n_chans_out=down_level[0], padding=padding,
                            kernel_size=kernel_size, stride=stride, activation=activation)
                 for level, down_level in zip(down_paths, [*down_paths[1:], bridge_paths[-1]])]
            )

            self.up_steps = nn.ModuleList(
                [conv_transposed_block(n_chans_in=down_level[-1], n_chans_out=level[0] - bridge_level[-1], padding=padding,
                                       kernel_size=kernel_size, stride=stride, activation=activation)
                 for bridge_level, level, down_level in zip(bridge_paths, up_paths, [*up_paths[1:], bridge_paths[-1]])]
            )

            self.output_layer = (conv_block(n_chans_in=up_paths[0][-1], n_chans_out=n_chans_out, padding=padding,
                                           kernel_size=kernel_size, stride=stride, activation=lambda x: x))

        def forward(self, input):
            down_outputs = []
            for level, down_step in zip(self.down_levels, self.down_steps):
                input = level(input)
                down_outputs.append(input)
                input = down_step(input)

            bridge_outputs = [level(input) for input, level in zip([input, *down_outputs], self.bridge_levels)]
            bottom_input = bridge_outputs[-1]
            for bridge_output, up_step, up_level in reversed(list(zip(bridge_outputs[:-1], self.up_steps,
                                                                      self.up_levels))):
                bottom_input = up_level(torch.cat([bridge_output, up_step(bottom_input)], dim=1))
            
            print('4', flush=True)
            return self.output_layer(bottom_input)

    return TNet


TNet2d = make_tnet(ConvBlock2d, ConvTransposeBlock2d)
TNet3d = make_tnet(ConvBlock3d, ConvTransposeBlock3d)

In [41]:
structure = [
    [[8, 8],   [], [56, 56]],
    [[16, 16], [], [48, 48]],
    [[32, 32]]
]


tnet = TNet3d(n_chans_in=3, n_chans_out=4, activation=nn.functional.relu, structure=structure, stride=2).cuda()

> <ipython-input-39-1dbe15723be1>(10)__init__()
-> down_paths = [level[0] for level in structure[:-1]]
(Pdb) n
> <ipython-input-39-1dbe15723be1>(11)__init__()
-> bridge_paths = [level[1] for level in structure[:-1]] + structure[-1]
(Pdb) n
> <ipython-input-39-1dbe15723be1>(12)__init__()
-> up_paths = [level[2] for level in structure[:-1]]
(Pdb) n
> <ipython-input-39-1dbe15723be1>(16)__init__()
-> down_paths[0] = [n_chans_in, *down_paths[0]]
(Pdb) n
> <ipython-input-39-1dbe15723be1>(17)__init__()
-> bridge_paths = [[down_level[-1], *bridge_level]
(Pdb) n
> <ipython-input-39-1dbe15723be1>(18)__init__()
-> for down_level, bridge_level in zip(down_paths, bridge_paths)]
(Pdb) bridge_paths
[[], [], [32, 32]]
(Pdb) l
 13  	
 14  	
 15  	
 16  	            down_paths[0] = [n_chans_in, *down_paths[0]]
 17  	            bridge_paths = [[down_level[-1], *bridge_level]
 18  ->	                            for down_level, bridge_level in zip(down_paths, bridge_paths)]
 19  	
 20  	            assert

BdbQuit: 

> [0;32m/nmnt/media/home/krivov/anaconda3/lib/python3.6/bdb.py[0m(67)[0;36mdispatch_line[0;34m()[0m
[0;32m     65 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0mstop_here[0m[0;34m([0m[0mframe[0m[0;34m)[0m [0;32mor[0m [0mself[0m[0;34m.[0m[0mbreak_here[0m[0;34m([0m[0mframe[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0m
[0m[0;32m     66 [0;31m            [0mself[0m[0;34m.[0m[0muser_line[0m[0;34m([0m[0mframe[0m[0;34m)[0m[0;34m[0m[0m
[0m[0;32m---> 67 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0mquitting[0m[0;34m:[0m [0;32mraise[0m [0mBdbQuit[0m[0;34m[0m[0m
[0m[0;32m     68 [0;31m        [0;32mreturn[0m [0mself[0m[0;34m.[0m[0mtrace_dispatch[0m[0;34m[0m[0m
[0m[0;32m     69 [0;31m[0;34m[0m[0m
[0m
ipdb> exit


In [34]:
input = Variable(torch.from_numpy(np.array(np.random.randn(1, 3, 100, 100, 100), dtype=np.float32)),
                 volatile=True).cuda()

In [31]:
tnet.down_steps

ModuleList (
  (0): ConvBlock (
    (convolution): Conv3d(8, 16, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
    (batch_norm): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True)
  )
  (1): ConvBlock (
    (convolution): Conv3d(16, 16, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
    (batch_norm): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True)
  )
)

In [11]:
tnet(input);

RuntimeError: Need input of dimension 5 and input.size[1] == 48 but got input to be of shape: [1 x 8 x 100 x 100 x 100] at /opt/conda/conda-bld/pytorch_1501971235237/work/pytorch-0.1.12/torch/lib/THCUNN/generic/VolumetricFullConvolution.cu:65