In [1]:
# default_exp Models

# Models

> Classes and helper fuction for creating flexible 2D and 3D unets

In [2]:
#hide
from nbdev.showdoc import *
from fastcore.test import *
from pdb import set_trace

In [3]:
#export
import torch.nn as nn
import torch
import types

In [67]:
#export
def _get_conv(ndim: int):
    "Get Convolution Layer of any dimension"
    assert 1 <= ndim <=3
    return getattr(nn, f'Conv{ndim}d')

def _get_bn(ndim: int):
    "Get BatchNorm Layer of any dimension"
    assert 1 <= ndim <=3
    return getattr(nn, f'BatchNorm{ndim}d')



In [68]:
_get_bn(2)

torch.nn.modules.batchnorm.BatchNorm2d

In [5]:
test_eq(_get_conv(1), torch.nn.modules.conv.Conv1d)
test_eq(_get_conv(2), torch.nn.modules.conv.Conv2d)
test_eq(_get_conv(3), torch.nn.modules.conv.Conv3d)
test_eq(_get_bn(1), torch.nn.modules.batchnorm.BatchNorm1d)
test_eq(_get_bn(2), torch.nn.modules.batchnorm.BatchNorm2d)
test_eq(_get_bn(3), torch.nn.modules.batchnorm.BatchNorm3d)

In [6]:
show_doc(_get_conv)

<h4 id="_get_conv" class="doc_header"><code>_get_conv</code><a href="__main__.py#L2" class="source_link" style="float:right">[source]</a></h4>

> <code>_get_conv</code>(**`ndim`**:`int`)

Get Convolution Layer of any dimension

In [7]:
show_doc(_get_bn)

<h4 id="_get_bn" class="doc_header"><code>_get_bn</code><a href="__main__.py#L7" class="source_link" style="float:right">[source]</a></h4>

> <code>_get_bn</code>(**`ndim`**:`int`)

Get BatchNorm Layer of any dimension

In [8]:
#export
def init_func(m, func=nn.init.kaiming_normal_):
    "Initialize pytorch model `m` weights with `func`"
    if func and hasattr(m, 'weight'): func(m.weight)
    return m

In [9]:
show_doc(init_func)

<h4 id="init_func" class="doc_header"><code>init_func</code><a href="__main__.py#L2" class="source_link" style="float:right">[source]</a></h4>

> <code>init_func</code>(**`m`**, **`func`**=*`'kaiming_normal_'`*)

Initialize pytorch model `m` weights with `func`

In [10]:
#export 
def layer_types(m):
    "returns list of pytorch models type"
    if isinstance(m, list): return list(map(type, m))
    return list(map(type, m.children()))

In [11]:
#export
def extract_layer(m, name=torch.nn.modules.Conv3d):
    res = []
    for child in m.children():
        for layer in child.modules():
            if(isinstance(layer,name)):
                res.append(layer)
    return res

In [69]:
#export
class FlexConvLayer(nn.Sequential):
    '''
      Create Flexible Convolution layer.
    
      This module allows to create 1, 2, or 3D convolutional layers containing (optional) activation function,
      batch normalization, upsampling or additional Pytorch Classes 
      
      Parameters:
       \n`ni`: in_channels
       \n`nf`: out_channels
       \n`ks`: kernal_size
       \n`st`: stride
       \n`pd`: padding default is 1 
       \n`ups`: adds Upsampling layer if `True`
       \n`sf`: scale factore if `ups` = True  upsampling layer
       \n`bn`: adds BatchNorm layer if `True`
       \n`ndim`: number of dimensions for layers e.g if 3 will create `nn.Conv3D` or `nn.BatchNorm3d`
       \n`func`: initiation function by default `nn.init.kaiming_normal_`
       \n`xtra`: adds any extra nn.Layers
       \n`act_fn`: activation function
       
       \nReturns:
       Sequential model containing specified Paramaters 
       
       '''  
    def __init__(self, ni: int, nf: int, ks: int=3, st:int=1, ndim: int=3, sf: int=2, pd: int=None, act_fn=None, bn: bool=False, ups: bool=False, xtra=None, func=nn.init.kaiming_normal_, **kwargs):
        layers = []
        if pd is None: pd = 1
        conv_l = _get_conv(ndim)(in_channels=ni, out_channels=nf,kernel_size =ks, stride=st, padding=pd, **kwargs)
        init_func(conv_l, func)
        layers.append(conv_l)
        if ups       : layers.insert(0, nn.Upsample(scale_factor=sf))
        if bn        : layers.append(_get_bn(ndim)(nf))
        if act_fn    : layers.append(act_fn())
        if xtra      : layers.append(xtra)
        super().__init__(*layers)
        
        

In [70]:
show_doc(FlexConvLayer)

<h2 id="FlexConvLayer" class="doc_header"><code>class</code> <code>FlexConvLayer</code><a href="" class="source_link" style="float:right">[source]</a></h2>

> <code>FlexConvLayer</code>(**`ni`**:`int`, **`nf`**:`int`, **`ks`**:`int`=*`3`*, **`st`**:`int`=*`1`*, **`ndim`**:`int`=*`3`*, **`sf`**:`int`=*`2`*, **`pd`**:`int`=*`None`*, **`act_fn`**=*`None`*, **`bn`**:`bool`=*`False`*, **`ups`**:`bool`=*`False`*, **`xtra`**=*`None`*, **`func`**=*`'kaiming_normal_'`*, **\*\*`kwargs`**) :: `Sequential`

      Create Flexible Convolution layer.
    
      This module allows to create 1, 2, or 3D convolutional layers containing (optional) activation function,
      batch normalization, upsampling or additional Pytorch Classes 
      
      Parameters:
       
`ni`: in_channels
       
`nf`: out_channels
       
`ks`: kernal_size
       
`st`: stride
       
`pd`: padding default is 1 
       
`ups`: adds Upsampling layer if `True`
       
`sf`: scale factore if `ups` = True  upsampling layer
       
`bn`: adds BatchNorm layer if `True`
       
`ndim`: number of dimensions for layers e.g if 3 will create `nn.Conv3D` or `nn.BatchNorm3d`
       
`func`: initiation function by default `nn.init.kaiming_normal_`
       
`xtra`: adds any extra nn.Layers
       
`act_fn`: activation function
       
       
Returns:
       Sequential model containing specified Paramaters 
       
       

In [71]:
 #2D convolutional layer with in_channel=1, out_channel=2 and BatchNorm
FlexConvLayer(1, 2, ks=4, bn=True, ndim=2)    

FlexConvLayer(
  (0): Conv2d(1, 2, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
  (1): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

In [81]:
#3D convolutional layer with in_channel=1, out_channel=2 and ReLU activation function 
FlexConvLayer(1, 2, ndim=3, act_fn = nn.ReLU).children()

<generator object Module.children at 0x7f14b94eab50>

In [83]:
test_eq(layer_types(FlexConvLayer(1, 2)), [torch.nn.modules.conv.Conv3d])
test_eq(layer_types(FlexConvLayer(1, 2, bn=True)), [torch.nn.modules.conv.Conv3d, torch.nn.modules.batchnorm.BatchNorm3d])
test_eq(layer_types(FlexConvLayer(1, 2, bn=True, ndim=2)), [torch.nn.modules.conv.Conv2d, torch.nn.modules.batchnorm.BatchNorm2d])
test_eq(layer_types(FlexConvLayer(1, 2, bn=True, ups=True)), [torch.nn.modules.upsampling.Upsample, torch.nn.modules.conv.Conv3d, torch.nn.modules.batchnorm.BatchNorm3d])
test_eq(layer_types(FlexConvLayer(1, 2, bn=True, ups=True, act_fn=nn.ELU)), [torch.nn.modules.upsampling.Upsample, torch.nn.modules.conv.Conv3d, torch.nn.modules.batchnorm.BatchNorm3d, torch.nn.modules.activation.ELU])

In [17]:
#export
def init_default(m, func=nn.init.kaiming_normal_):
    "Initialize `m` weights with `func` and set `bias` to 0."
    if func and hasattr(m, 'weight'): func(m.weight)
    with torch.no_grad():
        if getattr(m, 'bias', None) is not None: m.bias.fill_(0.)
    return m

In [84]:
#export 
class FlexUnetEncoder(nn.Module):
    '''
    Creates flexible encoder for Unets.
        
    Provided convolution depth will generate unet encoder based on `FlexConvLayer`. During forward pass will also return `features` tensor, contaning stored features which will be used in decoder for  concatinating during  upsampling. Last element of `feature` is `x` which used to enter first layer in `decoder`
       
    Parameters:
    \n`ni`: in_channels
    \n`nf`: out_channels
    \n`ks`: kernal_size
    \n`st`: stride
    \n`pd`: padding default is 1 
    \n`bn`: adds BatchNorm layer if `True`
    \n`act_fn`: activation function
    \n`conv_depth`: number of conv layers
    
    \nReturns:
    Sequential encoder, and feautre list 
    
    '''
    
    def __init__(self, ni: int, nf: int, ks: int, st: int, pd: int, conv_depth: int, ndim: int=3,  act_fn=nn.ELU, **kwargs):
        super().__init__()
        nf = nf
        self.module_dict = nn.ModuleDict()
        self.module_dict['pass_n'] = FlexConvLayer(ni, nf, act_fn=act_fn, ndim=ndim, **kwargs)
        self.module_dict['save_n'] = FlexConvLayer(nf, nf, act_fn=act_fn, ndim=ndim, **kwargs)
        for i in range(conv_depth):
            self.module_dict[f'pass_{i}_k'] = nn.Sequential(FlexConvLayer(nf, nf, ks=ks-1, st=st + 1, act_fn=act_fn, pd=pd, ndim=ndim, **kwargs))
            self.module_dict[f'save_{i}']  = nn.Sequential(FlexConvLayer(nf, nf*2, ks=ks, st=st, act_fn=act_fn, ndim=ndim, **kwargs), )
            self.module_dict[f'pass_{i}'] = nn.Sequential(FlexConvLayer(nf*2, nf*2, ks=ks, st=st, act_fn=act_fn, ndim=ndim, **kwargs))
            nf *=2
    def forward(self, x):
        features = []
        for i in self.module_dict:
            x = self.module_dict[i](x)
            if  i.startswith('save'): features.append(x)
                
        features = features[:-1]; 
        features.append(x)
        return features
    
    
#RENAME everything

In [85]:
show_doc(FlexUnetEncoder)

<h2 id="FlexUnetEncoder" class="doc_header"><code>class</code> <code>FlexUnetEncoder</code><a href="" class="source_link" style="float:right">[source]</a></h2>

> <code>FlexUnetEncoder</code>(**`ni`**:`int`, **`nf`**:`int`, **`ks`**:`int`, **`st`**:`int`, **`pd`**:`int`, **`conv_depth`**:`int`, **`ndim`**:`int`=*`3`*, **`act_fn`**=*`'ELU'`*, **\*\*`kwargs`**) :: `Module`

    Creates flexible encoder for Unets.
        
    Provided convolution depth will generate unet encoder based on [`FlexConvLayer`](/nbdev_template/models#FlexConvLayer). During forward pass will also return `features` tensor, contaning stored features which will be used in decoder for  concatinating during  upsampling. Last element of `feature` is `x` which used to enter first layer in `decoder`
       
    Parameters:
    
`ni`: in_channels
    
`nf`: out_channels
    
`ks`: kernal_size
    
`st`: stride
    
`pd`: padding default is 1 
    
`bn`: adds BatchNorm layer if `True`
    
`act_fn`: activation function
    
`conv_depth`: number of conv layers
    
    
Returns:
    Sequential encoder, and feautre list 
    
    

In [86]:
ndim       = 2
conv_depth = 2

In [93]:
#2d
tst_encoder = FlexUnetEncoder(ni=1, nf =48, ks =3, st =1, pd =0, conv_depth=conv_depth, ndim=ndim)
tst_encoder

FlexUnetEncoder(
  (module_dict): ModuleDict(
    (pass_n): FlexConvLayer(
      (0): Conv3d(1, 48, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (1): BatchNorm3d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ELU(alpha=1.0)
    )
    (save_n): FlexConvLayer(
      (0): Conv3d(48, 48, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (1): BatchNorm3d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ELU(alpha=1.0)
    )
    (pass_0_k): Sequential(
      (0): FlexConvLayer(
        (0): Conv3d(48, 48, kernel_size=(2, 2, 2), stride=(2, 2, 2))
        (1): BatchNorm3d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ELU(alpha=1.0)
      )
    )
    (save_0): Sequential(
      (0): FlexConvLayer(
        (0): Conv3d(48, 96, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        (1): BatchNorm3d(96, eps=1e-05, momentum=0.1, affine=True, track_running_

In [88]:
#3d
ndim = 3
tst_encoder = FlexUnetEncoder(1, 48, 3, 1, 0, conv_depth=conv_depth, ndim=ndim)
tst_encoder

FlexUnetEncoder(
  (module_dict): ModuleDict(
    (pass_n): FlexConvLayer(
      (0): Conv3d(1, 48, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (1): ELU(alpha=1.0)
    )
    (save_n): FlexConvLayer(
      (0): Conv3d(48, 48, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (1): ELU(alpha=1.0)
    )
    (pass_0_k): Sequential(
      (0): FlexConvLayer(
        (0): Conv3d(48, 48, kernel_size=(2, 2, 2), stride=(2, 2, 2))
        (1): ELU(alpha=1.0)
      )
    )
    (save_0): Sequential(
      (0): FlexConvLayer(
        (0): Conv3d(48, 96, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        (1): ELU(alpha=1.0)
      )
    )
    (pass_0): Sequential(
      (0): FlexConvLayer(
        (0): Conv3d(96, 96, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        (1): ELU(alpha=1.0)
      )
    )
    (pass_1_k): Sequential(
      (0): FlexConvLayer(
        (0): Conv3d(96, 96, kernel_size=(2, 2, 2), stride=(2, 2, 2))
        (

In [89]:
x_batch = torch.rand(20, 1, 20, 20, 20)
res = tst_encoder(x_batch)

In [90]:
test_eq(len(res), conv_depth + 1)
test_eq(res[-1].shape[1], 192)

In [99]:
test_eq(len(extract_layer(tst_encoder)), 8)
test_eq(layer_types(extract_layer(tst_encoder))[:1], [torch.nn.modules.conv.Conv3d])
test_eq(layer_types(extract_layer(tst_encoder)), [torch.nn.modules.conv.Conv3d]*8)
test_eq(getattr(extract_layer(tst_encoder)[0], 'kernel_size'), (3, 3, 3))
test_eq(getattr(extract_layer(tst_encoder)[0], 'stride'), (1, 1, 1))
test_eq(getattr(extract_layer(tst_encoder)[0], 'padding'), (1, 1, 1))
test_eq(getattr(extract_layer(tst_encoder)[2], 'padding'), (0, 0, 0))

In [26]:
#export 
class FlexUnetDecoder(nn.Module):
    '''
    Creates flexible decoder for Unets.
    
    This class will autmatically create decoder based on encoder paramaters. In forward pass, it will take features generated from FlexUnetEcnder and concatinate them on upsampled layers. 
    
    Parameters:
    \n`nf`: out_channels for the first layer in encoder
    \n`ks`: kernal_size
    \n`st`: stride
    \n`pd`: padding default is 1 
    \n`conv_depth`: number of conv layers
    \n`act_fn`: activation function by default its `nn.ELU`
    
    Returns:
    Decoder Model, and output of unet Model wchich should match input Dimensions
    ''' 
    def __init__(self, nf: int,  ks: int, st: int, pd: int, conv_depth: int, ndim=3,  act_fn=nn.ELU, **kwargs):
        super().__init__()
        nf = self._get_enc_filter(nf, conv_depth)
        self.module_dict = nn.ModuleDict()

        for i in range(conv_depth):
            self.module_dict[f'conc_{i}'] = nn.Sequential(FlexConvLayer(nf, nf//2, ks=ks, st=st, ups = True, act_fn=act_fn, pd=pd, ndim=ndim, **kwargs))
            self.module_dict[f'pass_{i}'] = nn.Sequential(FlexConvLayer(nf, nf//2, ks=ks, st=st, act_fn=act_fn, ndim=ndim, **kwargs),
                                                          FlexConvLayer(nf//2, nf//2, ks=ks, st=st, act_fn=act_fn, ndim=ndim, **kwargs))
            nf //=2
            
    def forward(self, features):
        #x replaced with features 
        #rework this a bit 
        x = features.pop()
        for i in self.module_dict:
            if i.startswith('conc'): 
                  x = self.module_dict[i](x)
                  x = torch.cat([x,features.pop()],1) 
            else: x = self.module_dict[i](x)
        return x
    
    @staticmethod
    def _get_enc_filter(nf, conv_depth):
        '''calculates number of in filters for decoder model given conv_depth and nf
        in the first conv layer in unet encoder'''
        nf = nf
        for i in range(conv_depth): nf *=2 
        return nf

In [27]:
show_doc(FlexUnetDecoder)

<h2 id="FlexUnetDecoder" class="doc_header"><code>class</code> <code>FlexUnetDecoder</code><a href="" class="source_link" style="float:right">[source]</a></h2>

> <code>FlexUnetDecoder</code>(**`nf`**:`int`, **`ks`**:`int`, **`st`**:`int`, **`pd`**:`int`, **`conv_depth`**:`int`, **`ndim`**=*`3`*, **`act_fn`**=*`'ELU'`*, **\*\*`kwargs`**) :: `Module`

    Creates flexible decoder for Unets.
    
    This class will autmatically create decoder based on encoder paramaters. In forward pass, it will take features generated from FlexUnetEcnder and concatinate them on upsampled layers. 
    
    Parameters:
    
`nf`: out_channels for the first layer in encoder
    
`ks`: kernal_size
    
`st`: stride
    
`pd`: padding default is 1 
    
`conv_depth`: number of conv layers
    
`act_fn`: activation function by default its `nn.ELU`
    
    Returns:
    Decoder Model, and output of unet Model wchich should match input Dimensions
    

In [28]:
nf_enc_in = 48
conv_depth = 2
ndim = 3
ks = 3
st = 1
pd = 1
act_fn = nn.ELU()

In [29]:
tst_decoder = FlexUnetDecoder(nf_enc_in, ks, st, pd, conv_depth)
tst_decoder

FlexUnetDecoder(
  (module_dict): ModuleDict(
    (conc_0): Sequential(
      (0): FlexConvLayer(
        (0): Upsample(scale_factor=2.0, mode=nearest)
        (1): Conv3d(192, 96, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        (2): ELU(alpha=1.0)
      )
    )
    (pass_0): Sequential(
      (0): FlexConvLayer(
        (0): Conv3d(192, 96, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        (1): ELU(alpha=1.0)
      )
      (1): FlexConvLayer(
        (0): Conv3d(96, 96, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        (1): ELU(alpha=1.0)
      )
    )
    (conc_1): Sequential(
      (0): FlexConvLayer(
        (0): Upsample(scale_factor=2.0, mode=nearest)
        (1): Conv3d(96, 48, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        (2): ELU(alpha=1.0)
      )
    )
    (pass_1): Sequential(
      (0): FlexConvLayer(
        (0): Conv3d(96, 48, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        

In [30]:
tst_encoder = FlexUnetEncoder(1, 48, 3, 1, 0, conv_depth=2, ndim=3)
x_batch     = torch.rand(20, 1, 20, 20, 20)
res         = tst_encoder(x_batch)
res_dec     = tst_decoder(res)

In [31]:
test_eq(res_dec.shape[1], nf_enc_in)

In [32]:
#export
class SUNET(nn.Module):
    '''
    Generates 1D, 2D or 3D Unet.
    
    Autmatically genrates UNET model based on user specified paramaters
    
    Parameters:
    \n`ni`: in_channels
    \n`nf`: out_channels
    \n`ks`: kernal_size
    \n`st`: stride
    \n`pd`: padding default is 1 
    \n`ndim`: (1, 2, 3) 2D or 3D depending on dimensions
    \n`conv_depth`: number of conv layers
    \n **kwargs: see `FlexConvLayer` for generating flexible conv layers
    
    Returns:
    \n Unet Model
    '''
    def __init__(self, ni, nc, ks, st, pd, conv_depth, ndim, **kwargs):
        super().__init__()
        self.encoder = FlexUnetEncoder(ni, nc, ks, st, pd, conv_depth, ndim, **kwargs)
        self.decoder = FlexUnetDecoder(nc, ks, st, pd+1, conv_depth, ndim, **kwargs)
    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x
        

In [33]:
show_doc(SUNET)

<h2 id="SUNET" class="doc_header"><code>class</code> <code>SUNET</code><a href="" class="source_link" style="float:right">[source]</a></h2>

> <code>SUNET</code>(**`ni`**, **`nc`**, **`ks`**, **`st`**, **`pd`**, **`conv_depth`**, **`ndim`**, **\*\*`kwargs`**) :: `Module`

    Generates 1D, 2D or 3D Unet.
    
    Autmatically genrates UNET model based on user specified paramaters
    
    Parameters:
    
`ni`: in_channels
    
`nf`: out_channels
    
`ks`: kernal_size
    
`st`: stride
    
`pd`: padding default is 1 
    
`ndim`: (1, 2, 3) 2D or 3D depending on dimensions
    
`conv_depth`: number of conv layers
    
 **kwargs: see [`FlexConvLayer`](/nbdev_template/models#FlexConvLayer) for generating flexible conv layers
    
    Returns:
    
 Unet Model
    

In [104]:
#3d
tst_unet = SUNET(1, 48, 3, 1, 0, 2, 3)
tst_unet

SUNET(
  (encoder): FlexUnetEncoder(
    (module_dict): ModuleDict(
      (pass_n): FlexConvLayer(
        (0): Conv3d(1, 48, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        (1): ELU(alpha=1.0)
      )
      (save_n): FlexConvLayer(
        (0): Conv3d(48, 48, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        (1): ELU(alpha=1.0)
      )
      (pass_0_k): Sequential(
        (0): FlexConvLayer(
          (0): Conv3d(48, 48, kernel_size=(2, 2, 2), stride=(2, 2, 2))
          (1): ELU(alpha=1.0)
        )
      )
      (save_0): Sequential(
        (0): FlexConvLayer(
          (0): Conv3d(48, 96, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
          (1): ELU(alpha=1.0)
        )
      )
      (pass_0): Sequential(
        (0): FlexConvLayer(
          (0): Conv3d(96, 96, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
          (1): ELU(alpha=1.0)
        )
      )
      (pass_1_k): Sequential(
        (0): FlexConvLayer(


In [105]:
x_batch = torch.rand(20, 1, 20, 20, 20)
res     = tst_unet(x_batch)

In [106]:
test_eq(res.shape[1], 48)

In [40]:
#2d unet
tst_unet = SUNET(1, 48, 3, 1, 0, 2, 3)
tst_unet

SUNET(
  (encoder): FlexUnetEncoder(
    (module_dict): ModuleDict(
      (pass_n): FlexConvLayer(
        (0): Conv3d(1, 48, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        (1): ELU(alpha=1.0)
      )
      (save_n): FlexConvLayer(
        (0): Conv3d(48, 48, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        (1): ELU(alpha=1.0)
      )
      (pass_0_k): Sequential(
        (0): FlexConvLayer(
          (0): Conv3d(48, 48, kernel_size=(2, 2, 2), stride=(2, 2, 2))
          (1): ELU(alpha=1.0)
        )
      )
      (save_0): Sequential(
        (0): FlexConvLayer(
          (0): Conv3d(48, 96, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
          (1): ELU(alpha=1.0)
        )
      )
      (pass_0): Sequential(
        (0): FlexConvLayer(
          (0): Conv3d(96, 96, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
          (1): ELU(alpha=1.0)
        )
      )
      (pass_1_k): Sequential(
        (0): FlexConvLayer(


In [43]:
tst_unet = SUNET(1, 48, 3, 1, 0, 5, 3, bn=True, act_fn=nn.ReLU)
tst_unet

SUNET(
  (encoder): FlexUnetEncoder(
    (module_dict): ModuleDict(
      (pass_n): FlexConvLayer(
        (0): Conv3d(1, 48, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        (1): BatchNorm3d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (save_n): FlexConvLayer(
        (0): Conv3d(48, 48, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        (1): BatchNorm3d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (pass_0_k): Sequential(
        (0): FlexConvLayer(
          (0): Conv3d(48, 48, kernel_size=(2, 2, 2), stride=(2, 2, 2))
          (1): BatchNorm3d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
        )
      )
      (save_0): Sequential(
        (0): FlexConvLayer(
          (0): Conv3d(48, 96, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
          (1): BatchNorm3d(96, eps=1e-05, m

In [39]:
from nbdev.export import *
notebook2script()

Converted 00_models.ipynb.
Converted 01_simulation.ipynb.
Converted index.ipynb.
