In [1]:
from typing import Callable

import numpy as np

import torch
import torch.nn as nn
from torch.autograd import Variable

In [2]:
def _make_conv_block(convolution, batch_norm):
    class ConvBlock(nn.Module):
        def __init__(self, n_chans_in, n_chans_out, kernel_size, activation: callable, stride=1, padding=0, dilation=1):
            super().__init__()
            self.convolution = convolution(in_channels=n_chans_in, out_channels=n_chans_out, kernel_size=kernel_size,
                                           stride=stride, padding=padding, dilation=dilation, bias=False)
            self.batch_norm = batch_norm(num_features=n_chans_out)
            self.activation = activation

        def forward(self, input):
            return self.activation(self.batch_norm(self.convolution(input)))

    return ConvBlock


# Conv -> BatchNorm -> activation blocks for 1D / 2D / 3D inputs ...
ConvBlock1d, ConvBlock2d, ConvBlock3d = (
    _make_conv_block(conv_cls, norm_cls)
    for conv_cls, norm_cls in ((nn.Conv1d, nn.BatchNorm1d),
                               (nn.Conv2d, nn.BatchNorm2d),
                               (nn.Conv3d, nn.BatchNorm3d))
)

# ... and the same blocks built on transposed convolutions.
ConvTransposeBlock1d, ConvTransposeBlock2d, ConvTransposeBlock3d = (
    _make_conv_block(conv_cls, norm_cls)
    for conv_cls, norm_cls in ((nn.ConvTranspose1d, nn.BatchNorm1d),
                               (nn.ConvTranspose2d, nn.BatchNorm2d),
                               (nn.ConvTranspose3d, nn.BatchNorm3d))
)


In [3]:
def copy_iterable_as_list(iterable):
    """Recursively copy a nested iterable into freshly allocated nested lists.

    ``str`` and ``bytes`` items are kept as atomic values: a one-character string
    iterates to itself, so recursing into strings would never terminate.

    Parameters
    ----------
    iterable : iterable
        A possibly nested container, e.g. a TNet ``structure`` specification.

    Returns
    -------
    list
        New list (with new inner lists) holding the same leaf values, so the copy
        can be modified without affecting ``iterable``.
    """
    return [copy_iterable_as_list(item)
            if hasattr(item, '__iter__') and not isinstance(item, (str, bytes))
            else item
            for item in iterable]


def _fit_spatial_shape(tensor, spatial_shape):
    """Crop or zero-pad the spatial axes of ``tensor`` (all axes after the first two) to ``spatial_shape``.

    A strided transposed convolution does not always undo the size reduction of the matching
    strided convolution (with kernel 3, stride 2, padding 1 a size of 10 goes down to 5 and back
    up to 9), so upsampled features are aligned with their skip connection before concatenation.
    Both cropping and padding act on the far end of each axis.

    Parameters
    ----------
    tensor : torch.Tensor
        Tensor whose first two axes are batch and channels (as produced by the conv blocks).
    spatial_shape : sequence of int
        Target sizes for the spatial axes.

    Returns
    -------
    torch.Tensor
        ``tensor`` itself when its spatial shape already matches, otherwise a cropped/padded tensor.
    """
    target = tuple(spatial_shape)
    if tuple(tensor.size()[2:]) == target:
        return tensor
    # Crop every axis that is too long ...
    crop = (slice(None), slice(None)) + tuple(slice(0, min(size, goal))
                                              for size, goal in zip(tensor.size()[2:], target))
    tensor = tensor[crop]
    # ... then zero-pad every axis that is too short; F.pad takes (before, after) pairs from the last axis.
    padding = []
    for size, goal in zip(reversed(tensor.size()[2:]), reversed(target)):
        padding += [0, goal - size]
    return nn.functional.pad(tensor, padding)


def make_tnet(conv_block, conv_transposed_block):
    """Build a ``TNet`` class (a U-Net-like encoder/decoder) from the given block classes.

    Parameters
    ----------
    conv_block : type
        Module class called as ``conv_block(n_chans_in=..., n_chans_out=..., kernel_size=..., padding=...,
        activation=..., [stride=...])``. Used for the convolutions inside every level, the strided
        downsampling steps and the output layer.
    conv_transposed_block : type
        Module class with the same constructor keywords, used for the strided upsampling steps.

    Returns
    -------
    type
        The ``TNet`` ``nn.Module`` subclass.
    """

    class TNet(nn.Module):
        """Encoder/decoder network with a convolutional "bridge" on every skip connection.

        ``structure`` lists the resolution levels from finest to coarsest. Every entry except the
        last is ``[down, bridge, up]``, three lists of channel counts:

        * ``down``: outputs of the convolutions on the way down. For every level but the first, the
          first entry is instead the channel count produced by the strided step into that level.
        * ``bridge``: outputs of convolutions applied to the level's down output before it is used as
          the skip connection (``[]`` means identity).
        * ``up``: ``up[0]`` is the channel count after concatenating the skip connection with the
          upsampled coarser features (so the upsampling step outputs ``up[0] - skip channels``); the
          remaining entries are convolution outputs.

        The last entry is ``[bottom]``: channel counts at the coarsest level, where ``bottom[0]`` is
        produced by the last downsampling step. Kernel size is fixed to 3 with "same" padding;
        ``stride`` is used by the down/up steps only.
        """

        def __init__(self, n_chans_in, n_chans_out, activation, structure, stride):
            super().__init__()
            assert len(structure) >= 2, 'structure needs at least one [down, bridge, up] level plus the bottom level'
            assert all(len(level) == 3 for level in structure[:-1])
            assert len(structure[-1]) == 1

            # Work on copies so the caller's ``structure`` is never modified.
            down_paths = [list(level[0]) for level in structure[:-1]]
            bridge_paths = [list(level[1]) for level in structure[:-1]]
            up_paths = [list(level[2]) for level in structure[:-1]]
            bottom_path = list(structure[-1][0])

            down_paths[0] = [n_chans_in, *down_paths[0]]
            # Each skip bridge starts from its level's down output. The bottom path is appended afterwards:
            # zipping it together with ``down_paths`` (one entry shorter) would silently drop it.
            bridge_paths = [[down_path[-1], *bridge_path]
                            for down_path, bridge_path in zip(down_paths, bridge_paths)]
            bridge_paths.append(bottom_path)

            # The upsampled features must contribute at least one channel to every concatenation.
            assert all(bridge_path[-1] < up_path[0] for bridge_path, up_path in zip(bridge_paths, up_paths))

            kernel_size = 3
            padding = kernel_size // 2

            def build_level(channels):
                # One conv block per consecutive (in, out) pair; a single-entry list yields an identity.
                return nn.Sequential(*[conv_block(n_chans_in=c_in, n_chans_out=c_out, padding=padding,
                                                  kernel_size=kernel_size, activation=activation)
                                       for c_in, c_out in zip(channels[:-1], channels[1:])])

            # nn.ModuleList (not plain lists) registers the sub-modules, so their parameters show up in
            # ``parameters()`` / ``state_dict()`` and follow ``.to()`` / ``.cuda()``.
            self.down_levels = nn.ModuleList([build_level(path) for path in down_paths])
            self.bridge_levels = nn.ModuleList([build_level(path) for path in bridge_paths])
            self.up_levels = nn.ModuleList([build_level(path) for path in up_paths])

            coarser_down_paths = [*down_paths[1:], bottom_path]
            self.down_steps = nn.ModuleList([
                conv_block(n_chans_in=down_path[-1], n_chans_out=coarser[0], padding=padding,
                           kernel_size=kernel_size, stride=stride, activation=activation)
                for down_path, coarser in zip(down_paths, coarser_down_paths)
            ])

            # Each up step reads the coarser level's output and supplies only the channels that the
            # skip connection does not already provide.
            coarser_up_paths = [*up_paths[1:], bottom_path]
            self.up_steps = nn.ModuleList([
                conv_transposed_block(n_chans_in=coarser[-1], n_chans_out=up_path[0] - bridge_path[-1],
                                      padding=padding, kernel_size=kernel_size, stride=stride,
                                      activation=activation)
                for up_path, bridge_path, coarser in zip(up_paths, bridge_paths, coarser_up_paths)
            ])

            # Full-resolution output layer: stride 1 and identity activation.
            self.output_layer = conv_block(n_chans_in=up_paths[0][-1], n_chans_out=n_chans_out, padding=padding,
                                           kernel_size=kernel_size, activation=lambda x: x)

        def forward(self, input):
            # Encoder: keep every level's output for its skip connection, then downsample.
            skips = []
            x = input
            for level, down_step in zip(self.down_levels, self.down_steps):
                x = level(x)
                skips.append(x)
                x = down_step(x)

            # Bridge i processes skip i; the last bridge is the bottom path and processes the output of
            # the final downsampling step.
            bridge_outputs = [bridge(features) for features, bridge in zip([*skips, x], self.bridge_levels)]

            # Decoder, coarse to fine: upsample, align with the skip connection, concatenate along the
            # channel axis, convolve.
            x = bridge_outputs[-1]
            for skip, up_step, up_level in reversed(list(zip(bridge_outputs[:-1], self.up_steps,
                                                             self.up_levels))):
                upsampled = _fit_spatial_shape(up_step(x), skip.size()[2:])
                x = up_level(torch.cat([skip, upsampled], dim=1))

            return self.output_layer(x)

    return TNet


# Dimension-specific TNet classes, each built from the matching conv / transposed-conv blocks.
TNet2d, TNet3d = (
    make_tnet(ConvBlock2d, ConvTransposeBlock2d),
    make_tnet(ConvBlock3d, ConvTransposeBlock3d),
)

# Registration is disabled here (`register` is not defined in the visible cells):
# register('TNet2d')(TNet2d)
# register('TNet3d')(TNet3d)


In [4]:
# One [down convs, skip-bridge convs, up convs] channel spec per resolution level, finest first;
# the final entry holds only the channel list of the bottom (coarsest) level.
structure = [
    [[8, 8], [], [56, 56]],
    [[16, 16], [], [48, 48]],
    [[32, 32]],
]

tnet = TNet3d(
    n_chans_in=3,
    n_chans_out=4,
    activation=nn.functional.relu,
    structure=structure,
    stride=2,
)

In [5]:
# Dummy input of ones laid out as (batch, channels, depth, height, width) = (1, 3, 10, 10, 10).
# dtype=float32 matches PyTorch's default parameter dtype (plain np.ones would give float64, which the
# conv weights reject), and the leading batch axis is required by the BatchNorm3d layers in the 3D blocks.
input = Variable(torch.from_numpy(np.ones((1, 3, 10, 10, 10), dtype=np.float32)), volatile=True)
#tnet(input)

In [6]:
# Stand-alone 3D conv block (3 -> 4 channels, kernel 3) for a quick smoke test.
# `activation` must be a callable such as nn.functional.relu; the nn.functional *module* itself
# is not callable and would raise a TypeError inside forward.
b = ConvBlock3d(3, 4, kernel_size=3, activation=nn.functional.relu)

In [None]:
# Smoke test: forward pass of the single block `b` on `input`; the result is this cell's displayed output.
b(input)