In [37]:
# default_exp Models

# Models

> Classes and helper functions for creating flexible 2D and 3D U-Nets

In [38]:
#hide

from nbdev.showdoc import *
from fastcore.test import *
from pdb import set_trace

In [39]:
#export
import torch.nn as nn
import torch
import types

In [40]:
#export
def _get_conv(ndim: int):
    "Get Convolution Layer of any dimension"
    assert 1 <= ndim <=3
    return getattr(nn, f'Conv{ndim}d')

def _get_bn(ndim: int):
    "Get BatchNorm Layer of any dimension"
    assert 1 <= ndim <=3
    return getattr(nn, f'BatchNorm{ndim}d')

In [41]:
# Each getter must map `ndim` to the matching torch layer class
expected_layers = {1: (torch.nn.modules.conv.Conv1d, torch.nn.modules.batchnorm.BatchNorm1d),
                   2: (torch.nn.modules.conv.Conv2d, torch.nn.modules.batchnorm.BatchNorm2d),
                   3: (torch.nn.modules.conv.Conv3d, torch.nn.modules.batchnorm.BatchNorm3d)}
for n_dims, (conv_cls, bn_cls) in expected_layers.items():
    test_eq(_get_conv(n_dims), conv_cls)
    test_eq(_get_bn(n_dims), bn_cls)

In [42]:
# Render the API documentation for `_get_conv`
show_doc(_get_conv)

<h4 id="_get_conv" class="doc_header"><code>_get_conv</code><a href="__main__.py#L2" class="source_link" style="float:right">[source]</a></h4>

> <code>_get_conv</code>(**`ndim`**:`int`)

Get Convolution Layer of any dimension

In [43]:
# Render the API documentation for `_get_bn`
show_doc(_get_bn)

<h4 id="_get_bn" class="doc_header"><code>_get_bn</code><a href="__main__.py#L7" class="source_link" style="float:right">[source]</a></h4>

> <code>_get_bn</code>(**`ndim`**:`int`)

Get BatchNorm Layer of any dimension

In [44]:
#export
def init_func(m, func=nn.init.kaiming_normal_):
    "Initialize pytorch model `m` weights with `func`"
    # `hasattr` alone is not enough: modules such as `BatchNorm*d(affine=False)`
    # register `weight` as `None`, and handing `None` to an init function crashes.
    weight = getattr(m, 'weight', None)
    if func and weight is not None: func(weight)
    return m

In [45]:
# Render the API documentation for `init_func`
show_doc(init_func)

<h4 id="init_func" class="doc_header"><code>init_func</code><a href="__main__.py#L2" class="source_link" style="float:right">[source]</a></h4>

> <code>init_func</code>(**`m`**, **`func`**=*`'kaiming_normal_'`*)

Initialize pytorch model `m` weights with `func`

In [46]:
#export 
def layer_types(m):
    "Return the types of the layers in `m`: the items of a list, or the direct children of a module"
    layers = m if isinstance(m, list) else m.children()
    return [type(layer) for layer in layers]

In [47]:
#export
def extract_layer(m, name=torch.nn.modules.Conv3d):
    "Return every module of type `name` found (recursively) inside the children of `m`; `m` itself is not checked"
    return [layer
            for child in m.children()
            for layer in child.modules()
            if isinstance(layer, name)]

In [48]:
#export
class FlexConvLayer(nn.Sequential):
    '''Create Flexible Conv Layers: [Upsample] -> ConvNd -> [BatchNormNd] -> [activation] -> [extra layer]
       \n`ni`: in_channels
       \n`nf`: out_channels
       \n`ks`: kernel_size
       \n`st`: stride
       \n`pd`: padding, 1 when left as `None`
       \n`sf`: scale factor of the upsampling layer (only used when `ups=True`)
       \n`bn`: adds a BatchNorm layer after the conv if `True`
       \n`ups`: adds an Upsample layer before the conv if `True`
       \n`ndim`: number of dimensions for layers e.g if 3 will create `nn.Conv3d` and `nn.BatchNorm3d`
       \n`func`: initialization function applied to the conv weights, by default `nn.init.kaiming_normal_`
       \n`xtra`: any extra layer appended at the end
       \n`act_fn`: activation class; it is instantiated with no arguments
       \n`kwargs`: forwarded to the conv constructor'''
    def __init__(self, ni, nf,ks=3, st=1, ndim=3, sf=2, pd=None, act_fn=None, bn=False, xtra=None, ups=False, func=nn.init.kaiming_normal_, **kwargs):
        padding = 1 if pd is None else pd
        # the conv is built and initialized first; the optional layers wrap around it
        conv = _get_conv(ndim)(in_channels=ni, out_channels=nf, kernel_size=ks, stride=st, padding=padding, **kwargs)
        init_func(conv, func)
        stack = [nn.Upsample(scale_factor=sf)] if ups else []
        stack.append(conv)
        if bn:
            stack.append(_get_bn(ndim)(nf))
        if act_fn:
            stack.append(act_fn())
        if xtra:
            stack.append(xtra)
        super().__init__(*stack)

In [49]:
# Render the API documentation for `FlexConvLayer`
show_doc(FlexConvLayer)

<h2 id="FlexConvLayer" class="doc_header"><code>class</code> <code>FlexConvLayer</code><a href="" class="source_link" style="float:right">[source]</a></h2>

> <code>FlexConvLayer</code>(**`ni`**, **`nf`**, **`ks`**=*`3`*, **`st`**=*`1`*, **`ndim`**=*`3`*, **`sf`**=*`2`*, **`pd`**=*`None`*, **`act_fn`**=*`None`*, **`bn`**=*`False`*, **`xtra`**=*`None`*, **`ups`**=*`False`*, **`func`**=*`'kaiming_normal_'`*, **\*\*`kwargs`**) :: `Sequential`

Create Flexible Conv Layers
       
`ni`: in_channels
       
`nf`: out_channels
       
`ks`: kernal_size
       
`st`: stride
       
`pd`: padding default is 1 
       
`sf`: scale factore if for upsampling layer
       
`bn`: adds BatchNorm layer if `True`
       
`ups`: adds Upsampling layer if `True`
       
`ndim`: number of dimensions for layers e.g if 3 will create `nn.Conv3D` or `nn.BatchNorm3d`
       
`func`: initiation function by default `nn.init.kaiming_normal_`
       
`xtra`: adds any extra nn.Layers
       
`act_fn`: activation function

In [50]:
# By default only the conv is created; optional layers are added on request
test_eq(layer_types(FlexConvLayer(1, 2)), [torch.nn.modules.conv.Conv3d])
test_eq(layer_types(FlexConvLayer(1, 2, bn=True)), [torch.nn.modules.conv.Conv3d, torch.nn.modules.batchnorm.BatchNorm3d])
# `ndim` switches both the conv and the BatchNorm to the 2D variants
test_eq(layer_types(FlexConvLayer(1, 2, ndim=2, bn=True)), [torch.nn.modules.conv.Conv2d, torch.nn.modules.batchnorm.BatchNorm2d])
# Upsample is always placed in front of the conv; the activation goes last
test_eq(layer_types(FlexConvLayer(1, 2, bn=True, ups=True)), [torch.nn.modules.upsampling.Upsample, torch.nn.modules.conv.Conv3d, torch.nn.modules.batchnorm.BatchNorm3d])
test_eq(layer_types(FlexConvLayer(1, 2, bn=True, ups=True, act_fn=nn.ELU)), [torch.nn.modules.upsampling.Upsample, torch.nn.modules.conv.Conv3d, torch.nn.modules.batchnorm.BatchNorm3d, torch.nn.modules.activation.ELU])

In [51]:
#export
def init_default(m, func=nn.init.kaiming_normal_):
    "Initialize `m` weights with `func` and set `bias` to 0."
    # `hasattr` alone is not enough: modules such as `BatchNorm*d(affine=False)`
    # register `weight` (and `bias`) as `None`, and handing `None` to an init function crashes.
    weight = getattr(m, 'weight', None)
    if func and weight is not None: func(weight)
    with torch.no_grad():
        if getattr(m, 'bias', None) is not None: m.bias.fill_(0.)
    return m

In [52]:
#export 
class FlexUnetEncoder(nn.Module):
    '''Creates flexible encoder for Unets
       \n`ni`: in_channels
       \n`nf`: number of filters of the first block; doubled after every downsampling stage
       \n`ks`: kernel_size of the per-stage convs (the downsampling conv uses `ks-1`)
       \n`st`: stride of the per-stage convs (the downsampling conv uses `st+1`)
       \n`pd`: padding of the downsampling convs (all other convs use FlexConvLayer's default padding of 1)
       \n`conv_depth`: number of downsampling stages
       \n`ndim`: number of spatial dimensions (e.g. 2 or 3)
       \n`act_fn`: activation class, instantiated for every conv layer
       \n`kwargs`: forwarded to every `FlexConvLayer` (e.g. `bn=True`)
       \nThe two stem convs always use FlexConvLayer's defaults (ks=3, st=1).
       \n`forward` returns `(x, features)`, where `features` holds every saved activation except the deepest one.
       '''

    def __init__(self, ni, nf, ks, st, pd, conv_depth, ndim=3,  act_fn=nn.ELU, **kwargs):
        super().__init__()
        self.module_dict = nn.ModuleDict()
        blocks = self.module_dict
        # stem at full resolution; the 'save_*' entries are the ones collected as skip features
        blocks['pass_n'] = FlexConvLayer(ni, nf, act_fn=act_fn, ndim=ndim, **kwargs)
        blocks['save_n'] = FlexConvLayer(nf, nf, act_fn=act_fn, ndim=ndim, **kwargs)
        width = nf
        for stage in range(conv_depth):
            # strided conv halves the resolution, then two convs double the channels
            blocks[f'pass_{stage}_k'] = nn.Sequential(FlexConvLayer(width, width, ks=ks-1, st=st + 1, act_fn=act_fn, pd=pd, ndim=ndim, **kwargs))
            blocks[f'save_{stage}'] = nn.Sequential(FlexConvLayer(width, width*2, ks=ks, st=st, act_fn=act_fn, ndim=ndim, **kwargs))
            blocks[f'pass_{stage}'] = nn.Sequential(FlexConvLayer(width*2, width*2, ks=ks, st=st, act_fn=act_fn, ndim=ndim, **kwargs))
            width = width * 2

    def forward(self, x):
        skips = []
        for key, block in self.module_dict.items():
            x = block(x)
            if not key.startswith('pass'):
                skips.append(x)
        # the deepest saved activation continues as `x`, so it is not returned as a skip
        return (x, skips[:-1])

In [53]:
# Render the API documentation for `FlexUnetEncoder`
show_doc(FlexUnetEncoder)

<h2 id="FlexUnetEncoder" class="doc_header"><code>class</code> <code>FlexUnetEncoder</code><a href="" class="source_link" style="float:right">[source]</a></h2>

> <code>FlexUnetEncoder</code>(**`ni`**, **`nf`**, **`ks`**, **`st`**, **`pd`**, **`conv_depth`**, **`ndim`**=*`3`*, **`act_fn`**=*`'ELU'`*, **\*\*`kwargs`**) :: `Module`

Creates flexible encoder for Unets
       
`ni`: in_channels
       
`nf`: out_channels
       
`ks`: kernal_size
       
`st`: stride
       
`pd`: padding default is 1 
       
`bn`: adds BatchNorm layer if `True`
       
`act_fn`: activation function
       
`conv_depth`: number of conv layers
       

In [54]:
#2d
# 2D encoder: 1 input channel, 48 base filters, ks=3, st=1, pd=0 for the downsampling convs, 2 stages
tst_encoder = FlexUnetEncoder(1, 48, 3, 1, 0, conv_depth=2, ndim=2)
tst_encoder

FlexUnetEncoder(
  (module_dict): ModuleDict(
    (pass_n): FlexConvLayer(
      (0): Conv2d(1, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ELU(alpha=1.0)
    )
    (save_n): FlexConvLayer(
      (0): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ELU(alpha=1.0)
    )
    (pass_0_k): Sequential(
      (0): FlexConvLayer(
        (0): Conv2d(48, 48, kernel_size=(2, 2), stride=(2, 2))
        (1): ELU(alpha=1.0)
      )
    )
    (save_0): Sequential(
      (0): FlexConvLayer(
        (0): Conv2d(48, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ELU(alpha=1.0)
      )
    )
    (pass_0): Sequential(
      (0): FlexConvLayer(
        (0): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ELU(alpha=1.0)
      )
    )
    (pass_1_k): Sequential(
      (0): FlexConvLayer(
        (0): Conv2d(96, 96, kernel_size=(2, 2), stride=(2, 2))
        (1): ELU(alpha=1.0)
      )
    )
    (save_1): S

In [55]:
#3d
# Same configuration as above but with 3D layers; reused by the shape tests below
tst_encoder = FlexUnetEncoder(1, 48, 3, 1, 0, conv_depth=2, ndim=3)
tst_encoder

FlexUnetEncoder(
  (module_dict): ModuleDict(
    (pass_n): FlexConvLayer(
      (0): Conv3d(1, 48, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (1): ELU(alpha=1.0)
    )
    (save_n): FlexConvLayer(
      (0): Conv3d(48, 48, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (1): ELU(alpha=1.0)
    )
    (pass_0_k): Sequential(
      (0): FlexConvLayer(
        (0): Conv3d(48, 48, kernel_size=(2, 2, 2), stride=(2, 2, 2))
        (1): ELU(alpha=1.0)
      )
    )
    (save_0): Sequential(
      (0): FlexConvLayer(
        (0): Conv3d(48, 96, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        (1): ELU(alpha=1.0)
      )
    )
    (pass_0): Sequential(
      (0): FlexConvLayer(
        (0): Conv3d(96, 96, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        (1): ELU(alpha=1.0)
      )
    )
    (pass_1_k): Sequential(
      (0): FlexConvLayer(
        (0): Conv3d(96, 96, kernel_size=(2, 2, 2), stride=(2, 2, 2))
        (

In [56]:
# Random 3D batch in Conv3d layout (N=20, C=1, 20x20x20 volume)
x_batch = torch.rand(20, 1, 20, 20, 20)
res = tst_encoder(x_batch)

In [57]:
# The encoder returns (bottleneck, skip features)
test_eq(len(res), 2)
# conv_depth=2 doubles the 48 base filters twice -> 192 bottleneck channels
test_eq(res[0].shape[1], 192)
# Three activations are saved (save_n, save_0, save_1) but the deepest is dropped
test_eq(len(res[1]), 2)

In [58]:
# Skip features: full resolution with 48 channels, then half resolution with 96 channels
res[1][0].shape, res[1][1].shape

(torch.Size([20, 48, 20, 20, 20]), torch.Size([20, 96, 10, 10, 10]))

In [59]:
enc_convs = extract_layer(tst_encoder)
# 2 stem convs + 3 convs per stage with conv_depth=2 -> 8 Conv3d layers
test_eq(len(enc_convs), 8)
test_eq(layer_types(enc_convs)[:1], [torch.nn.modules.conv.Conv3d])
test_eq(layer_types(enc_convs), [torch.nn.modules.conv.Conv3d]*8)
# The first stem conv uses FlexConvLayer's defaults
test_eq(enc_convs[0].kernel_size, (3, 3, 3))
test_eq(enc_convs[0].stride, (1, 1, 1))
test_eq(enc_convs[0].padding, (1, 1, 1))
# The third conv is the first downsampling conv, built with the explicit pd=0
test_eq(enc_convs[2].padding, (0, 0, 0))

In [60]:
#export 
class FlexUnetDecoder(nn.Module):
    '''Creates flexible decoder for Unets
       \n`nf`: number of filters of the encoder's first block; the decoder starts at `nf * 2**conv_depth` channels and halves them at every stage
       \n`ks`: kernel_size
       \n`st`: stride
       \n`pd`: padding of the conv that follows each upsampling step (the refining convs use FlexConvLayer's default padding of 1)
       \n`conv_depth`: number of upsampling stages; should match the encoder's `conv_depth`
       \n`ndim`: number of spatial dimensions (e.g. 2 or 3)
       \n`act_fn`: activation class, instantiated for every conv layer
       \n`kwargs`: forwarded to every `FlexConvLayer`
       \n`forward(x, features)` does not modify the `features` list it is given.
       '''
    def __init__(self, nf,  ks, st, pd, conv_depth, ndim=3,  act_fn=nn.ELU, **kwargs):
        super().__init__()
        nf = self._get_enc_filter(nf, conv_depth)
        self.module_dict = nn.ModuleDict()

        for i in range(conv_depth):
            # upsample + conv halves the channels; the skip feature is concatenated in `forward`
            self.module_dict[f'conc_{i}'] = nn.Sequential(FlexConvLayer(nf, nf//2, ks=ks, st=st, ups = True, act_fn=act_fn, pd=pd, ndim=ndim, **kwargs))
            # after concatenation there are `nf` channels again; two convs bring them down to nf//2
            self.module_dict[f'pass_{i}'] = nn.Sequential(FlexConvLayer(nf, nf//2, ks=ks, st=st, act_fn=act_fn, ndim=ndim, **kwargs),
                                                          FlexConvLayer(nf//2, nf//2, ks=ks, st=st, act_fn=act_fn, ndim=ndim, **kwargs))
            nf //= 2

    def forward(self, x, features):
        # Work on a copy: popping from the caller's list would empty it, so a second
        # call with the same encoder output would fail.
        skips = list(features)
        for name, layer in self.module_dict.items():
            x = layer(x)
            if name.startswith('conc'):
                # deepest remaining skip first, concatenated on the channel axis
                x = torch.cat([x, skips.pop()], 1)
        return x

    @staticmethod
    def _get_enc_filter(nf, conv_depth):
        "Channel count at the encoder's bottleneck: `nf` doubled `conv_depth` times"
        for _ in range(conv_depth): nf *= 2
        return nf

In [61]:
# Render the API documentation for `FlexUnetDecoder`
show_doc(FlexUnetDecoder)

<h2 id="FlexUnetDecoder" class="doc_header"><code>class</code> <code>FlexUnetDecoder</code><a href="" class="source_link" style="float:right">[source]</a></h2>

> <code>FlexUnetDecoder</code>(**`nf`**, **`ks`**, **`st`**, **`pd`**, **`conv_depth`**, **`ndim`**=*`3`*, **`act_fn`**=*`'ELU'`*, **\*\*`kwargs`**) :: `Module`

Creates flexible encoder for Unets
       
`nf`: out_channels
       
`ks`: kernal_size
       
`st`: stride
       
`pd`: padding default is 1 
       
`conv_depth`: number of conv layers
       

In [62]:
# Shared configuration for the decoder tests below
nf_enc_in = 48   # filters of the encoder's first block (and of the decoder output)
conv_depth = 2   # number of down/up-sampling stages
ndim = 3
ks = 3
st = 1
pd = 1           # padding of the conv after each upsampling step
# FlexConvLayer instantiates `act_fn()` itself, so this must be the class, not an instance
act_fn = nn.ELU

In [63]:
# Decoder for an encoder with `nf_enc_in` base filters; ndim and act_fn keep their defaults (3, nn.ELU)
tst_decoder = FlexUnetDecoder(nf_enc_in, ks, st, pd, conv_depth)
tst_decoder

FlexUnetDecoder(
  (module_dict): ModuleDict(
    (conc_0): Sequential(
      (0): FlexConvLayer(
        (0): Upsample(scale_factor=2.0, mode=nearest)
        (1): Conv3d(192, 96, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        (2): ELU(alpha=1.0)
      )
    )
    (pass_0): Sequential(
      (0): FlexConvLayer(
        (0): Conv3d(192, 96, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        (1): ELU(alpha=1.0)
      )
      (1): FlexConvLayer(
        (0): Conv3d(96, 96, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        (1): ELU(alpha=1.0)
      )
    )
    (conc_1): Sequential(
      (0): FlexConvLayer(
        (0): Upsample(scale_factor=2.0, mode=nearest)
        (1): Conv3d(96, 48, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        (2): ELU(alpha=1.0)
      )
    )
    (pass_1): Sequential(
      (0): FlexConvLayer(
        (0): Conv3d(96, 48, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        

In [64]:
# Build the encoder from the same config as `tst_decoder` so their channel counts stay in sync;
# the literal 0 is the padding of the encoder's downsampling convs
tst_encoder = FlexUnetEncoder(1, nf_enc_in, ks, st, 0, conv_depth=conv_depth, ndim=ndim)
x_batch     = torch.rand(20, 1, *[20] * ndim)   # (N, C, *spatial) with `ndim` spatial axes
res         = tst_encoder(x_batch)
res_dec     = tst_decoder(*res)

In [65]:
# The decoder halves the bottleneck channels back down to the encoder's base width
test_eq(res_dec.shape[1], nf_enc_in)

In [66]:
#export
class SUNET(nn.Module):
    '''General UNET for 2D or 3D
       \n`ni`: in_channels
       \n`nc`: number of filters of the first encoder block; also the number of output channels
       \n`ks`: kernel_size
       \n`st`: stride
       \n`pd`: padding of the encoder's downsampling convs; the decoder's post-upsampling convs use `pd+1`
       \n`conv_depth`: number of down/up-sampling stages
       \n`ndim`: (2, 3) 2D or 3D depending on dimensions
       \n`kwargs`: forwarded to both the encoder and the decoder
    '''
    def __init__(self, ni, nc, ks, st, pd, conv_depth, ndim, **kwargs):
        super().__init__()
        # encoder is built first, then the decoder (same construction order as before)
        self.encoder = FlexUnetEncoder(ni, nc, ks, st, pd, conv_depth, ndim, **kwargs)
        self.decoder = FlexUnetDecoder(nc, ks, st, pd + 1, conv_depth, ndim, **kwargs)

    def forward(self, x):
        encoded = self.encoder(x)          # (bottleneck, skip features)
        return self.decoder(*encoded)
        

In [67]:
# Render the API documentation for `SUNET`
show_doc(SUNET)

<h2 id="SUNET" class="doc_header"><code>class</code> <code>SUNET</code><a href="" class="source_link" style="float:right">[source]</a></h2>

> <code>SUNET</code>(**`ni`**, **`nc`**, **`ks`**, **`st`**, **`pd`**, **`conv_depth`**, **`ndim`**, **\*\*`kwargs`**) :: `Module`

General UNET for 2D or 3D
       
`ni`: in_channels
       
`nf`: out_channels
       
`ks`: kernal_size
       
`st`: stride
       
`pd`: padding default is 1 
       
`ndim`: (2, 3) 2D or 3D depending on dimensions
       
`conv_depth`: number of conv layers
    

In [68]:
#3d
# 3D U-Net: ni=1, nc=48, ks=3, st=1, pd=0, conv_depth=2, ndim=3
tst_unet = SUNET(1, 48, 3, 1, 0, 2, 3)
tst_unet

SUNET(
  (encoder): FlexUnetEncoder(
    (module_dict): ModuleDict(
      (pass_n): FlexConvLayer(
        (0): Conv3d(1, 48, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        (1): ELU(alpha=1.0)
      )
      (save_n): FlexConvLayer(
        (0): Conv3d(48, 48, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        (1): ELU(alpha=1.0)
      )
      (pass_0_k): Sequential(
        (0): FlexConvLayer(
          (0): Conv3d(48, 48, kernel_size=(2, 2, 2), stride=(2, 2, 2))
          (1): ELU(alpha=1.0)
        )
      )
      (save_0): Sequential(
        (0): FlexConvLayer(
          (0): Conv3d(48, 96, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
          (1): ELU(alpha=1.0)
        )
      )
      (pass_0): Sequential(
        (0): FlexConvLayer(
          (0): Conv3d(96, 96, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
          (1): ELU(alpha=1.0)
        )
      )
      (pass_1_k): Sequential(
        (0): FlexConvLayer(


In [69]:
# Forward a random (N=20, C=1, 20x20x20) batch through the 3D U-Net
x_batch = torch.rand(20, 1, 20, 20, 20)
res     = tst_unet(x_batch)

In [70]:
# The U-Net output has `nc` (=48) channels
test_eq(res.shape[1], 48)

In [71]:
#2d unet
# Same configuration as the 3D U-Net, built with 2D layers (ndim=2)
tst_unet = SUNET(1, 48, 3, 1, 0, 2, 2)
tst_unet

SUNET(
  (encoder): FlexUnetEncoder(
    (module_dict): ModuleDict(
      (pass_n): FlexConvLayer(
        (0): Conv2d(1, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ELU(alpha=1.0)
      )
      (save_n): FlexConvLayer(
        (0): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ELU(alpha=1.0)
      )
      (pass_0_k): Sequential(
        (0): FlexConvLayer(
          (0): Conv2d(48, 48, kernel_size=(2, 2), stride=(2, 2))
          (1): ELU(alpha=1.0)
        )
      )
      (save_0): Sequential(
        (0): FlexConvLayer(
          (0): Conv2d(48, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ELU(alpha=1.0)
        )
      )
      (pass_0): Sequential(
        (0): FlexConvLayer(
          (0): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ELU(alpha=1.0)
        )
      )
      (pass_1_k): Sequential(
        (0): FlexConvLayer(
          (0): Conv2d(96, 96, kernel_size=

In [72]:
# Export all `#export` cells of the project's notebooks to the python library
from nbdev.export import *
notebook2script()

Converted 00_models.ipynb.
Converted index.ipynb.
