In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np


In [7]:
def weight_init(w):
    if isinstance(w, nn.Conv2d):
        nn.init.xavier_normal_(w.weight)
        nn.init.constant_(w.bias,0)
        

def initialize_weigths(*models):
    for model in models:
        for module in model.modules():
            if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    module.bias.data.zero_()
                    
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
                

def conv(in_channels, out_channels, kernel_size, stride=1,
            padding=1, bias=True, groups=1):
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        stride=stride,
        padding=padding,
        bias=bias,
        groups=groups)


def up_conv(in_channels, out_channels, kernel_size, mode='transpose'):
    if mode == 'transpose':
        return nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=2)

    else:
        return nn.Sequential(
            nn.Upsample(mode='bilinear', scale_factor=2),
            conv(in_channels, out_channels, kernel_size=1))

In [20]:
class EncoderBlock(nn.Module):
    """
    A helper class that performs 2 convlutions, 1 MaxPool and 1 batch norm
    """
    def __init__(self,in_ch, out_ch, kernel_size=3, pooling = True):
        super(EncoderBlock, self).__init__()
        
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.kernel_size = kernel_size
        self.pooling = pooling
        
        
        
        self.conv1 = conv(self.in_ch, self.out_ch, self.kernel_size,groups=1)
        print(self.conv1)
        self.conv2 = conv(self.out_ch,self.out_ch,self.kernel_size,groups=1)
        print(self.conv2)
        
        if self.pooling:
            self.max_pool = nn.MaxPool2d(kernel_size =2, stride=2)
            
        self.batch_norm = nn.BatchNorm2d(self.out_ch)
        
        
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        
        before_pool = x
        if self.pooling:
            x = self.max_pool(x)
        
        x = self.batch_norm(x)
        
        return x, before_pool
    

class DecoderBlock(nn.Module):
    """
    A helper class that performs 2 convlutions, 1 up convolution
    """
    def __init__(self, in_ch, out_ch, kernel_size=3,
                 merge_mode = 'concat', up_mode = 'transpose'):
        super(DecoderBlock,self).__init__()
        
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.kernel_size = kernel_size
        self.merge_mode = merge_mode
        self.up_mode = up_mode
        
        self.up_conv_2x2 = up_conv(self.in_ch, self.out_ch, kernel_size =2, mode = self.up_mode)
        self.up_conv_3x3 = up_conv(self.in_ch, self.out_ch, kernel_size =3 , mode = self.up_mode)
        
        if self.merge_mode =='concat':
            self.conv1 = conv(2*self.out_ch, self.out_ch, kernel_size)
        else:
            self.conv1 = conv(self.out_ch,self.out_ch, kernel_size)
            
        self.conv2 = conv(self.out_ch,self.out_ch, kernel_size)
        
    def forward(self, from_encoder, to_decoder):
        if from_encoder.shape[2] % 2 ==0:
            to_decoder = self.up_conv_2x2(to_decoder)
        else:
            to_decoder = self.up_conv_3x3(to_decoder)
            
        if self.merge_mode =='concat':
            x = torch.cat((to_decoder, from_encoder),1)
        else:
            x = to_decoder + from_encoder
        
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        
        return x
    
    
        
        
        

In [21]:
class UNET(nn.Module):
    '''
    UNET model for OCT segmenation
    '''
    def __init__(self, num_classes, in_ch = 3, depth = 5, 
                start_filter_num = 64, filter_num_factor = 2, up_mode = 'transpose', merge_mode = 'concat'):
        super(UNET, self).__init__()
        
        '''
        Value checker, Error check
        '''
        if up_mode in ('transpose', 'upsample'):
            self.up_mode = up_mode
        else:
            raise ValueError("{} is not a valid mode for upsampling. "
                             "Only \"transpose\" and \"upsample\" are allowed".format(up_mode))
       
        if merge_mode in ('concat', 'add'):
            self.merge_mode = merge_mode
        else:
            raise ValueError("{} is not a valid mode for merging encoder and decoder paths. "
                             "Only \"concat\" and \"add\" are valid".format(merge_mode))
            
        self.in_ch = in_ch
        self.num_classes = num_classes
        self.depth = depth
        self.start_filter_num = start_filter_num
        self.filter_num_factor = filter_num_factor
        
        self.encoder_path = []
        self.decoder_path = []
        
        out_filters = self.in_ch
        
        '''
        encoder paths
            depth : how many pooling in encoder blocks.
            
            
        '''
        for i in range(depth):
            print('[DEBUG] : DEPTH of ENCODER : ',i)
            in_filters = out_filters
            out_filters = self.start_filter_num * (self.filter_num_factor ** i)
            pooling = True if i < depth-1 else False
            encoder_block = EncoderBlock(in_filters, out_filters, pooling=pooling)
            self.encoder_path.append(encoder_block)
            
        
        '''
        decoder paths
        '''
        for i in range(depth-1):
            print('[DEBUG] : DEPTH of DECODER : ',i)
            in_filters = out_filters
            out_filters = in_filters // filter_num_factor
            decoder_block = DecoderBlock(in_filters, out_filters,
                                        merge_mode = self.merge_mode, up_mode = self.up_mode)
            self.decoder_path.append(decoder_block)
        
        
        self.ending_conv = conv(out_filters, self.num_classes, kernel_size=1, padding = 0)
        
        
        #TODO using ModuleDict with Encoder and Decoder as keys
        # 이녀석들은 뭐지
        self.encoder_module = nn.ModuleList(self.encoder_path)
        self.decoder_module = nn.ModuleList(self.decoder_path)
        
        initialize_weigths(self)
        
    def forward(self, x):
        encoder_outs = []
        
        for i, module in enumerate(self.encoder_module):
            x, before_pool = module(x)
            encoder_outs.append(before_pool)
            
        for i, module in enumerate(self.decoder_module):
            before_pool = encoder_outs[-(i+2)]
            x = module(before_pool, x)
            
        x = self.ending_conv(x)
        return x
    
    
            

In [22]:
cuda = torch.cuda.is_available()
print(cuda)

True


In [23]:
model = UNET(num_classes=3, in_ch=1, depth=5, filter_num_factor=2, merge_mode='concat')
print(model)
model.apply(weight_init)
cuda = torch.cuda.is_available()
x = Variable(torch.rand(1, 1, 584, 584))
if cuda:
    model.cuda()
    x = x.cuda()
with torch.no_grad():
    out = model(x)
    print(out.shape)

[DEBUG] : DEPTH of ENCODER :  0
Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
[DEBUG] : DEPTH of ENCODER :  1
Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
[DEBUG] : DEPTH of ENCODER :  2
Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
[DEBUG] : DEPTH of ENCODER :  3
Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
[DEBUG] : DEPTH of ENCODER :  4
Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
[DEBUG] : DEPTH of DECODER :  0
[DEBUG] : DEPTH of DECODER :  1
[DEBUG] : DEPTH of DECODER :  2
[DEBUG] : DEPTH of DECODER :  3
UNET(
  (ending_conv): Conv2d(64, 3

In [None]:
model

In [1]:
from model_helper import *

In [2]:
conv

<function model_helper.conv(in_channels, out_channels, kernel_size, stride=1, padding=1, bias=True, groups=1)>

In [6]:
import sys 
import os
sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(lib))))


NameError: name 'lib' is not defined