In [1]:
import torch
import torch.nn as nn
from torchvision import models
from torchsummary import summary
import numpy as np

!git clone https://github.com/franciscocms/Deep_Learning
import sys  
sys.path.insert(0, './')    

import Deep_Learning.resnet_module as resnet_module


Cloning into 'Deep_Learning'...


In [2]:
def activation_func(activation):
    return  nn.ModuleDict([
        ['relu', nn.ReLU(inplace=True)],
        ['leaky_relu', nn.LeakyReLU(negative_slope=0.01, inplace=True)],
        ['selu', nn.SELU(inplace=True)],
        ['none', nn.Identity()],
        ['sigmoid', nn.Sigmoid()]
    ])[activation]

In [6]:
class ResidualBlock(nn.Module):
    def __init__(self, conv1, conv2, bn1 = None, bn2 = None, activation = 'relu'):
        super().__init__()
        
        self.conv1, self.conv2 = conv1, conv2
        if bn1 is None:
            self.bn1 = nn.BatchNorm2d(num_features = self.conv1.out_channels)
            self.bn2 = nn.BatchNorm2d(num_features = self.conv2.out_channels)
        else:
            self.bn1, self.bn2 = bn1, bn2
        
        self.activate = activation_func(activation)
        
        if self.conv1.in_channels == self.conv2.in_channels:
            self.shortcut = nn.Identity()  
        else:
            self.shortcut = nn.Conv2d(self.conv1.in_channels, self.conv2.out_channels, kernel_size = 1, stride = 1, bias = False)

        
    def forward(self, x):

        residual = self.shortcut(x)
        x = self.activate(self.bn1(self.conv1(x)))
        x = self.bn2(self.conv2(x))
        x += residual
        x = self.activate(x)

        return x


class middle_conv_block(nn.Module):
    def __init__(self, in_channels, out_channels, residual_block = False, activation = 'relu'):
        super().__init__()
        
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = 1, padding = 1, bias = False)
        self.bn1 = nn.BatchNorm2d(num_features = out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size = 3, stride = 1, padding = 1, bias = False)
        self.bn2 = nn.BatchNorm2d(num_features = out_channels)
        self.activate = activation_func(activation)
        
        self.residual_block = residual_block
        
    def forward(self, x):
        
        if self.residual_block:
            resblock = ResidualBlock(self.conv1, self.conv2, self.bn1, self.bn2)
            x = resblock(x)
        else:
            x = self.activate(self.bn1(self.conv1(x)))
            x = self.activate(self.bn2(self.conv2(x)))
                      
        return x
    
class UnetBlock(nn.Module):
    def __init__(self, in_channels, out_channels, residual_block, activation = 'relu'):
        super().__init__() 
        
        self.tconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size = 2, stride = 2)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = 1, padding = 1, bias = False)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size = 3, stride = 1, padding = 1, bias = False)
        self.activate = activation_func(activation)
        

        self.residual_block = residual_block
    
    def forward(self, x, x_connection):
        
        x = self.activate(self.tconv(x))
        x = torch.cat((x_connection, x), dim = 1)
        
        if self.residual_block:
            resblock = ResidualBlock(self.conv1, self.conv2)
            x = resblock(x)
        else:
            x = self.activate(self.conv1(x))         
            x = self.activate(self.conv2(x))
                      

        return x
    
class UnetDecoder(nn.Module):
    def __init__(self, channels):
        super().__init__()
        
        unet_block = UnetBlock
        
        self.blocks = nn.ModuleList([])
        for i in range(len(channels) - 1):
            self.blocks.append(unet_block(channels[i], channels[i+1], True))
        self.final_layers = nn.ModuleList([])
        self.final_layers.append(nn.Conv2d(64, 1, kernel_size = 1, stride = 1, padding = 0, bias = False))
        self.final_layers.append(nn.Conv2d(2, 1, kernel_size = 1, stride = 1, padding = 0, bias = False))
                
       
    def forward(self, x, x_connections, original):
        
        keys = ['block_2', 'block_1','block_0', 'gate']
        
        for i in range(len(self.blocks)):
            x = self.blocks[i](x, x_connections[keys[i]])
        
        x = self.final_layers[0](x)
        x = torch.cat((original, x), dim = 1) # concatenating original image 
        x = self.final_layers[1](x)
        
        return x

    
class UnetEncoder(nn.Module):
    def __init__(self, encoder):
        super().__init__()
        
        self.encoder = encoder
        
        self.activation = {}
        def get_activation(name):
            def hook(model, input, output):
                self.activation[name] = output.detach()#.cpu().numpy()
            return hook
        
        self.encoder.gate[2].register_forward_hook(get_activation('gate'))
        self.encoder.blocks[0].activate.register_forward_hook(get_activation('block_0'))
        self.encoder.blocks[1].activate.register_forward_hook(get_activation('block_1'))   
        self.encoder.blocks[2].activate.register_forward_hook(get_activation('block_2'))
        
        
    def forward(self, x):
        
        original = x.clone()
        x = self.encoder(x)
        
        #print('\nActivations:\t')
        #for k, v in self.activation.items():
        #    print(k, v.shape)
        #print()
  
        return x, self.activation, original

class DynamicUNet(nn.Module):
    def __init__(self, encoder_network):
        super().__init__()
        
        def get_encoder(network, implemented):
            return nn.Sequential(*list(network.children())[:-2]) if not implemented else network.encoder        

        if encoder_network == 'resnet18':
            resnet = resnet_module.resnet18(in_channels = 1, n_classes = 1, unet_encoder = True)
        elif encoder_network == 'resnet34':
            resnet = resnet_module.resnet34(in_channels = 1, n_classes = 1, unet_encoder = True)
        encoder = get_encoder(resnet, implemented = True)
        
        self.encoder = UnetEncoder(encoder)
        self.middle_conv = middle_conv_block(1024, 1024, residual_block = True)
        self.decoder = UnetDecoder([1024, 512, 256, 128, 64])
        
    def forward(self, x):    
        
        x, self.activation, original = self.encoder(x)
        
        print('After encoder shape: {}\t' .format(x.shape))
        x = self.middle_conv(x)
        print('After middle convolutions shape: {}\t' .format(x.shape))
        x = self.decoder(x, self.activation, original)
        
        return x


# Choose the resnetxx encoder (ex. resnet18, resnet34, ...)
encoder = 'resnet34'

model = DynamicUNet(encoder)

x = torch.randn(1, 1, 256, 256)
print('Input shape: {}\t' .format(x.shape))

output = model(x)
print('Output shape: {}\t' .format(output.shape))




Input shape: torch.Size([1, 1, 256, 256])	
After encoder shape: torch.Size([1, 1024, 16, 16])	
After middle convolutions shape: torch.Size([1, 1024, 16, 16])	
Output shape: torch.Size([1, 1, 256, 256])	
