In [3]:
# default_exp Models

# Models

> Classes and helper fuction for creating flexible 2D and 3D unets

In [90]:
#hide
from nbdev.showdoc import *
from fastcore.test import *

In [91]:
#export
import torch.nn as nn
import torch

In [92]:
#export
def _get_conv(ndim: int):
    "Get Convolution Layer of any dimension"
    assert 1 <= ndim <=3
    return getattr(nn, f'Conv{ndim}d')

def _get_bn(ndim: int):
    "Get BatchNorm Layer of any dimension"
    assert 1 <= ndim <=3
    return getattr(nn, f'BatchNorm{ndim}d')

In [93]:
test_eq(_get_conv(1), torch.nn.modules.conv.Conv1d)
test_eq(_get_conv(2), torch.nn.modules.conv.Conv2d)
test_eq(_get_conv(3), torch.nn.modules.conv.Conv3d)
test_eq(_get_bn(1), torch.nn.modules.batchnorm.BatchNorm1d)
test_eq(_get_bn(2), torch.nn.modules.batchnorm.BatchNorm2d)
test_eq(_get_bn(3), torch.nn.modules.batchnorm.BatchNorm3d)

In [94]:
show_doc(_get_conv)

<h4 id="_get_conv" class="doc_header"><code>_get_conv</code><a href="__main__.py#L2" class="source_link" style="float:right">[source]</a></h4>

> <code>_get_conv</code>(**`ndim`**:`int`)

Get Convolution Layer of any dimension

In [95]:
show_doc(_get_bn)

<h4 id="_get_bn" class="doc_header"><code>_get_bn</code><a href="__main__.py#L7" class="source_link" style="float:right">[source]</a></h4>

> <code>_get_bn</code>(**`ndim`**:`int`)

Get BatchNorm Layer of any dimension

In [176]:
#export 
def layer_types(m):
    "returns list of pytorch models type"
    if isinstance(m, list): return list(map(type, m))
    return list(map(type, m.children()))

In [213]:
#export
def extract_layer(m, name=torch.nn.modules.Conv3d):
    res = []
    for child in m.children():
        for layer in child.modules():
            if(isinstance(layer,name)):
                res.append(layer)
    return res

In [177]:
#export
class FlexConvLayer(nn.Sequential):
    '''Create Flexible Conv Layers
       \n`ni`: in_channels
       \n`nf`: out_channels
       \n`ks`: kernal_size
       \n`st`: stride
       \n`pd`: padding default is 1 
       \n`sf`: scale factore if for upsampling layer
       \n`bn`: adds BatchNorm layer if `True`
       \n`ups`: adds Upsampling layer if `True`
       \n`ndim`: number of dimensions for layers e.g if 3 will create `nn.Conv3D` or `nn.BatchNorm3d`
       \n`xtra`: adds any extra nn.Layers
       \n`act_fn`: activation function'''
    def __init__(self, ni, nf,ks=3, st=1, ndim=3, sf=2, pd=None, act_fn=None, bn=False, xtra=None, ups=False, **kwargs):
        layers = []
        if pd is None: pd = 1
        layers.append(_get_conv(ndim)(in_channels=ni, out_channels=nf,kernel_size =ks, stride=st, padding=pd, **kwargs))
        if ups       : layers.insert(0, nn.Upsample(scale_factor=sf))
        if bn        : layers.append(_get_bn(ndim)(nf))
        if act_fn    : layers.append(act_fn)
        if xtra      : layers.append(xtra)
        super().__init__(*layers)

In [178]:
show_doc(FlexConvLayer)

<h2 id="FlexConvLayer" class="doc_header"><code>class</code> <code>FlexConvLayer</code><a href="" class="source_link" style="float:right">[source]</a></h2>

> <code>FlexConvLayer</code>(**`ni`**, **`nf`**, **`ks`**=*`3`*, **`st`**=*`1`*, **`ndim`**=*`3`*, **`sf`**=*`2`*, **`pd`**=*`None`*, **`act_fn`**=*`None`*, **`bn`**=*`False`*, **`xtra`**=*`None`*, **`ups`**=*`False`*, **\*\*`kwargs`**) :: `Sequential`

Create Flexible Conv Layers
       
`ni`: in_channels
       
`nf`: out_channels
       
`ks`: kernal_size
       
`st`: stride
       
`pd`: padding default is 1 
       
`sf`: scale factore if for upsampling layer
       
`bn`: adds BatchNorm layer if `True`
       
`ups`: adds Upsampling layer if `True`
       
`ndim`: number of dimensions for layers e.g if 3 will create `nn.Conv3D` or `nn.BatchNorm3d`
       
`xtra`: adds any extra nn.Layers
       
`act_fn`: activation function

In [179]:
test_eq(layer_types(FlexConvLayer(1, 2)), [torch.nn.modules.conv.Conv3d])
test_eq(layer_types(FlexConvLayer(1, 2, bn=True)), [torch.nn.modules.conv.Conv3d, torch.nn.modules.batchnorm.BatchNorm3d])
test_eq(layer_types(FlexConvLayer(1, 2, bn=True)), [torch.nn.modules.conv.Conv3d, torch.nn.modules.batchnorm.BatchNorm3d])
test_eq(layer_types(FlexConvLayer(1, 2, bn=True, ups=True)), [torch.nn.modules.upsampling.Upsample, torch.nn.modules.conv.Conv3d, torch.nn.modules.batchnorm.BatchNorm3d])
test_eq(layer_types(FlexConvLayer(1, 2, bn=True, ups=True, act_fn=nn.ELU())), [torch.nn.modules.upsampling.Upsample, torch.nn.modules.conv.Conv3d, torch.nn.modules.batchnorm.BatchNorm3d,  torch.nn.modules.activation.ELU])
test_eq(layer_types(FlexConvLayer(1, 2, bn=True, ups=True, act_fn=nn.ELU(), xtra=nn.ReLU())), [torch.nn.modules.upsampling.Upsample, torch.nn.modules.conv.Conv3d, torch.nn.modules.batchnorm.BatchNorm3d,  torch.nn.modules.activation.ELU, torch.nn.modules.activation.ReLU])

In [180]:
#export 
class FlexUnetEncoder(nn.Module):
    '''Creates flexible encoder for Unets
       \n`ni`: in_channels
       \n`nf`: out_channels
       \n`ks`: kernal_size
       \n`st`: stride
       \n`pd`: padding default is 1 
       \n`bn`: adds BatchNorm layer if `True`
       \n`act_fn`: activation function
       \n`conv_depth`: number of conv layers
       '''
    
    def __init__(self, ni, nf, ks, st, pd, conv_depth, ndim=3,  act_fn=nn.ELU()):
        super().__init__()
        nf = nf
        self.module_dict = nn.ModuleDict()
        self.module_dict[f'pass_n'] = FlexConvLayer(ni, nf, act_fn=act_fn, ndim=ndim)
        self.module_dict[f'save_n'] = FlexConvLayer(nf, nf, act_fn=act_fn, ndim=ndim)
        for i in range(conv_depth):
            self.module_dict[f'pass_{i}_k'] = nn.Sequential(FlexConvLayer(nf, nf, ks=ks-1, st=st + 1, act_fn=act_fn, pd=pd, ndim=ndim))
            self.module_dict[f'save_{i}'] = nn.Sequential(FlexConvLayer(nf, nf*2, ks=ks, st=st, act_fn=act_fn, ndim=ndim), )
            self.module_dict[f'pass_{i}'] = nn.Sequential(FlexConvLayer(nf*2, nf*2, ks=ks, st=st, act_fn=act_fn, ndim=ndim))
            nf *=2
    def forward(self, x):
        features = []
        for i in self.module_dict:
            if     i.startswith('pass'): x = self.module_dict[i](x)
            else:  x = self.module_dict[i](x) ;features.append(x)
        return (x, features[:-1])

In [181]:
show_doc(FlexUnetEncoder)

<h2 id="FlexUnetEncoder" class="doc_header"><code>class</code> <code>FlexUnetEncoder</code><a href="" class="source_link" style="float:right">[source]</a></h2>

> <code>FlexUnetEncoder</code>(**`ni`**, **`nf`**, **`ks`**, **`st`**, **`pd`**, **`conv_depth`**, **`ndim`**=*`3`*, **`act_fn`**=*`ELU(alpha=1.0)`*) :: `Module`

Creates flexible encoder for Unets
       
`ni`: in_channels
       
`nf`: out_channels
       
`ks`: kernal_size
       
`st`: stride
       
`pd`: padding default is 1 
       
`bn`: adds BatchNorm layer if `True`
       
`act_fn`: activation function
       
`conv_depth`: number of conv layers
       

In [218]:
tst_encoder = FlexUnetEncoder(1, 48, 3, 1, 0, conv_depth=2, ndim=3)
tst_encoder

FlexUnetEncoder(
  (module_dict): ModuleDict(
    (pass_n): FlexConvLayer(
      (0): Conv3d(1, 48, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (1): ELU(alpha=1.0)
    )
    (save_n): FlexConvLayer(
      (0): Conv3d(48, 48, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (1): ELU(alpha=1.0)
    )
    (pass_0_k): Sequential(
      (0): FlexConvLayer(
        (0): Conv3d(48, 48, kernel_size=(2, 2, 2), stride=(2, 2, 2))
        (1): ELU(alpha=1.0)
      )
    )
    (save_0): Sequential(
      (0): FlexConvLayer(
        (0): Conv3d(48, 96, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        (1): ELU(alpha=1.0)
      )
    )
    (pass_0): Sequential(
      (0): FlexConvLayer(
        (0): Conv3d(96, 96, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        (1): ELU(alpha=1.0)
      )
    )
    (pass_1_k): Sequential(
      (0): FlexConvLayer(
        (0): Conv3d(96, 96, kernel_size=(2, 2, 2), stride=(2, 2, 2))
        (

In [237]:
test_eq(len(extract_layer(tst_encder)), 8)
test_eq(layer_types(extract_layer(tst_encder))[:1], [torch.nn.modules.conv.Conv3d])
test_eq(layer_types(extract_layer(tst_encder)), [torch.nn.modules.conv.Conv3d]*8)
test_eq(getattr(extract_layer(tst_encder)[0], 'kernel_size'), (3, 3, 3))
test_eq(getattr(extract_layer(tst_encder)[0], 'stride'), (1, 1, 1))
test_eq(getattr(extract_layer(tst_encder)[0], 'padding'), (1, 1, 1))
test_eq(getattr(extract_layer(tst_encder)[2], 'padding'), (0, 0, 0))

In [239]:
from nbdev.export import *
notebook2script()

Converted 00_core.ipynb.
Converted index.ipynb.
