# Models

In [None]:
# export
# default_exp models
import torch
from torch import nn
from fastcore.dispatch import patch

In [None]:
# export
import sys
sys.path.append('..')
from attention_unet.modular_unet import ModularUNet
from attention_unet.blocks import BasicResBlock, UnetBlock, ConvLayer, DoubleConv, SqueezeExpand, DeepSupervision
from attention_unet.utils import test_forward

## ResNet-Based Models

In [None]:
# export
class UResNet(ModularUNet):
    " ModularUNet variant whose encoder and final layers are built from `BasicResBlock`s "

    def encoder_layer(self, **kwargs):
        return BasicResBlock(**kwargs)

    def middle_layer(self, **kwargs):
        return DoubleConv(**kwargs)

    def skip_layer(self, **kwargs):
        # Identity: skip features are not transformed by this model family.
        return nn.Identity()

    def decoder_layer(self, **kwargs):
        return UnetBlock(**kwargs)

    def extra_after_decoder_layer(self, **kwargs):
        # Identity by default; the deep-supervision subclasses override this hook.
        return nn.Identity()

    def final_layer(self, **kwargs):
        return BasicResBlock(**kwargs)

### UResNet18-like models

In [None]:
# export 
class UResNet18(UResNet):
    " UNet with ResNet18-like Backbone "
    n_blocks = 5
    # Every tuple below carries one entry per block (n_blocks entries).
    channels = 32, 64, 128, 256, 512
    kernel_size = (3,) * n_blocks
    stride = (2,) * n_blocks
    padding = ('auto',) * n_blocks
    # Presumably layers per encoder block; 2s mirror ResNet18's 2-blocks-per-stage layout.
    n_layers = 1, 2, 2, 2, 2

In [None]:
# Smoke test: instantiate UResNet18 and run it through `test_forward` (attention_unet.utils).
test_forward(UResNet18(3,3))

In [None]:
# export
class UResNet18WithAttention(UResNet18):
    " UNet with ResNet18-like Backbone and spatial Attention in Upsampling blocks"
    # Deliberately empty: the attention-enabled `decoder_layer` is attached afterwards
    # with fastcore's `@patch`, to demonstrate extending a model after definition.
    # Until that patch runs, this class behaves exactly like UResNet18.

In [None]:
# Before the `@patch` cell below runs, UResNet18WithAttention only inherits the plain
# `decoder_layer`, so its decoder blocks must not yet have a spatial-attention module (`sa`).
# NOTE: re-running this cell after the patch cell has executed will (correctly) fail.
m = UResNet18WithAttention(3,3)
assert not hasattr(m.decoder_block_1, 'sa'), 'Spatial attention already present in decoder block before patching'

In [None]:
# export
@patch
def decoder_layer(self: UResNet18WithAttention, **kwargs):
    " Decoder layer for `UResNet18WithAttention`: a `UnetBlock` with spatial attention switched on "
    # fastcore's `@patch` reads the `self` annotation to decide which class receives this method.
    attention_block = UnetBlock(spatial_attention=True, **kwargs)
    return attention_block

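fastcore's `@patch` decorator can be used to add or replace methods of a modular UNet class after it has been defined. Because the layer-factory methods decide which modules are built, this also changes the modules in the network.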

In [None]:
# After patching, `decoder_layer` builds UnetBlocks with spatial_attention=True,
# so the decoder blocks are expected to expose a spatial-attention module (`sa`).
m = UResNet18WithAttention(3,3)
assert hasattr(m.decoder_block_1, 'sa'), 'No spatial attention layer in decoder block'

In [None]:
# Smoke test the patched instance `m` created in the previous cell.
test_forward(m)

In [None]:
# export
class UResNet18DeepSupervision(UResNet18):
    " UNet with ResNet18-like Backbone and deep supervision after Upsampling blocks "
    def extra_after_decoder_layer(self, **kwargs):
        # Deep-supervision head: a bare convolution with no activation and no normalization.
        # NOTE(review): the UResNet34 deep-supervision variants use `DeepSupervision(**kwargs)`
        # for this hook instead; confirm whether the two heads are meant to differ.
        return ConvLayer(**kwargs, act=None, norm=None)

In [None]:
# Smoke test for the deep-supervision variant.
# NOTE(review): the second constructor argument is 2 here but 3 in every other cell — confirm intended.
test_forward(UResNet18DeepSupervision(3,2))

In [None]:
# export
class UResNet18WithAttentionAndDeepSupervision(UResNet18):
    " UNet with ResNet18-like Backbone, spatial attention in Upsampling blocks and deep supervision after Upsampling blocks "
    def decoder_layer(self, **kwargs):
        # Decoder (upsampling) block with spatial attention enabled.
        return UnetBlock(spatial_attention=True, **kwargs)
    def extra_after_decoder_layer(self, **kwargs):
        # Deep-supervision head after the decoder blocks: a bare conv (no activation, no norm).
        return ConvLayer(**kwargs, act=None, norm=None)

In [None]:
# Smoke test: attention + deep supervision variant of UResNet18.
test_forward(UResNet18WithAttentionAndDeepSupervision(3,3))

In [None]:
# export
class UResNet18WithSEAndAttentionAndDeepSupervision(UResNet18):
    " UNet with ResNet18-like Backbone, SqueezeExpand (SE) blocks in the encoder, spatial attention in Upsampling blocks and deep supervision after Upsampling blocks "
    # Ratio passed to every encoder `SqueezeExpand`; override in a subclass to change it.
    se_ratio = 0.2

    def encoder_layer(self, in_c, out_c, **kwargs):
        # Residual block followed by a SqueezeExpand block on the block's output channels.
        return nn.Sequential(
            BasicResBlock(in_c, out_c, **kwargs),
            SqueezeExpand(out_c, se_ratio=self.se_ratio)
        )

    def decoder_layer(self, **kwargs):
        # Decoder (upsampling) block with spatial attention enabled.
        return UnetBlock(spatial_attention=True, **kwargs)

    def extra_after_decoder_layer(self, **kwargs):
        # Deep-supervision head after the decoder blocks: a bare conv (no activation, no norm).
        return ConvLayer(**kwargs, act=None, norm=None)

In [None]:
# Smoke test: SE + attention + deep supervision variant of UResNet18.
test_forward(UResNet18WithSEAndAttentionAndDeepSupervision(3,3))

### UResNet34-like models

In [None]:
# export
class UResNet34(UResNet):
    " UNet with ResNet34-like Backbone "
    n_blocks = 5
    # Every tuple below carries one entry per block (n_blocks entries).
    channels = 32, 64, 128, 256, 512
    kernel_size = (3,) * n_blocks
    stride = (2,) * n_blocks
    padding = ('auto',) * n_blocks
    # Presumably layers per encoder block; 3/4/6/3 mirror ResNet34's per-stage block counts.
    n_layers = 1, 3, 4, 6, 3

In [None]:
# Smoke test: instantiate UResNet34 and run it through `test_forward` (attention_unet.utils).
test_forward(UResNet34(3,3))

In [None]:
# export
class UResNet34WithAttention(UResNet34):
    " UNet with ResNet34-like Backbone and spatial Attention in Upsampling blocks"

    def decoder_layer(self, **kwargs):
        # Same UnetBlock as the base model, but with spatial attention switched on.
        attention_block = UnetBlock(spatial_attention=True, **kwargs)
        return attention_block

In [None]:
# Smoke test: attention variant of UResNet34.
test_forward(UResNet34WithAttention(3,3))

In [None]:
# export
class UResNet34DeepSupervision(UResNet34):
    " UNet with ResNet34-like Backbone and deep supervision after Upsampling blocks "
    def extra_after_decoder_layer(self, **kwargs):
        # Deep-supervision head built by the project's `DeepSupervision` block.
        # NOTE(review): the UResNet18 variants use `ConvLayer(act=None, norm=None)` for this hook;
        # confirm whether the two heads are meant to differ.
        return DeepSupervision(**kwargs)

In [None]:
# Smoke test: deep-supervision variant of UResNet34.
test_forward(UResNet34DeepSupervision(3,3))

In [None]:
# export
class UResNet34WithAttentionAndDeepSupervision(UResNet34):
    " UNet with ResNet34-like Backbone, spatial attention in Upsampling blocks and deep supervision after Upsampling blocks "
    def decoder_layer(self, **kwargs):
        # Decoder (upsampling) block with spatial attention enabled.
        return UnetBlock(spatial_attention=True, **kwargs)
    def extra_after_decoder_layer(self, **kwargs):
        # Deep-supervision head after the decoder blocks.
        return DeepSupervision(**kwargs)

In [None]:
# Smoke test: attention + deep supervision variant of UResNet34.
test_forward(UResNet34WithAttentionAndDeepSupervision(3,3))

In [None]:
# export
class UResNet34WithSEAndAttentionAndDeepSupervision(UResNet34):
    " UNet with ResNet34-like Backbone, SqueezeExpand (SE) blocks in the encoder, spatial attention in Upsampling blocks and deep supervision after Upsampling blocks "
    # Ratio passed to every encoder `SqueezeExpand`; override in a subclass to change it.
    se_ratio = 0.2

    def encoder_layer(self, in_c, out_c, **kwargs):
        # Residual block followed by a SqueezeExpand block on the block's output channels.
        return nn.Sequential(
            BasicResBlock(in_c, out_c, **kwargs),
            SqueezeExpand(out_c, se_ratio=self.se_ratio)
        )

    def decoder_layer(self, **kwargs):
        # Decoder (upsampling) block with spatial attention enabled.
        return UnetBlock(spatial_attention=True, **kwargs)

    def extra_after_decoder_layer(self, **kwargs):
        # Deep-supervision head after the decoder blocks.
        return DeepSupervision(**kwargs)

In [None]:
# Smoke test: SE + attention + deep supervision variant of UResNet34.
test_forward(UResNet34WithSEAndAttentionAndDeepSupervision(3,3))

In [None]:
# hide
# Export the `# export` cells of the project's notebooks into the library modules.
from nbdev.export import notebook2script
notebook2script()

Converted blocks.ipynb.
Converted index.ipynb.
Converted models.ipynb.
Converted modular_unet.ipynb.
Converted utils.ipynb.
