# Models

> Denoising network architectures: a DnCNN residual denoiser and a U-Net.

In [None]:
#| default_exp models

In [None]:
#| hide
from nbdev.showdoc import *


In [None]:
#| hide
from IPython.display import clear_output, DisplayHandle

def update_patch(self, obj):
    """Show `obj` through this display handle, replacing the cell's output.

    Installed below as `DisplayHandle.update`. Rather than updating the
    handle's existing output in place, it clears the cell output and then
    displays `obj` again via the handle's own `display` method.
    """
    # wait=True: the clear is deferred until replacement output is available.
    clear_output(wait=True)
    self.display(obj)


# Monkey-patch IPython's DisplayHandle so every `handle.update(obj)` call
# goes through the clear-and-redisplay function above.
setattr(DisplayHandle, "update", update_patch)

In [None]:
#| export 

from fastai.vision.all import ConvLayer, nn
from torch import cat as torch_cat
from Noise2Model.utils import attributesFromDict

In [None]:
#| export

class DnCNN(nn.Module):
    """DnCNN-style residual denoiser.

    The convolutional stack predicts a residual with the same shape as the
    input, and `forward` returns the input minus that residual.

    Parameters:
    channels - (int) number of channels of the input (and output) image.
    num_of_layers - (int) total number of convolutional layers, counting the
        first layer (built with norm_type=None) and the final plain nn.Conv2d.
    features - (int) number of feature maps in the hidden layers.
    kernel_size - (int) convolution kernel size. It should be odd so that the
        "same" padding below keeps the spatial size unchanged.
    """
    def __init__(self, channels, num_of_layers=9, features=64, kernel_size=3):
        super().__init__()
        # "Same" padding: the residual must keep the input's spatial size, or
        # `x - residual` in `forward` fails. A hard-coded padding of 1 only
        # satisfies this for kernel_size=3 (the default, which is unchanged).
        padding = kernel_size // 2
        layers = [ConvLayer(channels, features, ks=kernel_size, padding=padding, norm_type=None)]
        layers += [
            ConvLayer(features, features, ks=kernel_size, padding=padding)
            for _ in range(num_of_layers - 2)
        ]
        # Final projection back to `channels`; a bare conv (no ConvLayer
        # wrapper), so no normalization/activation is applied to the residual.
        layers.append(nn.Conv2d(in_channels=features, out_channels=channels,
                                kernel_size=kernel_size, padding=padding, bias=False))
        self.dncnn = nn.Sequential(*layers)

    def forward(self, x):
        """Return the denoised image: the input minus the predicted residual."""
        residual = self.dncnn(x)
        return x - residual

In [None]:
#| export 
class UNet(nn.Module):
    """U-Net style network: a recursive `_Net_recurse` core plus an output conv.

    Parameters:
    depth - (int) recursion depth of the core network.
    mult_chan - (int) factor the core's first block applies to `in_channels`.
    in_channels - (int) number of input channels.
    out_channels - (int) number of output channels.
    """
    def __init__(self,
                 depth=4,
                 mult_chan=32,
                 in_channels=1,
                 out_channels=1,
    ):
        super().__init__()
        # Copies the constructor arguments onto `self` (self.depth,
        # self.mult_chan, self.in_channels, self.out_channels), which the
        # lines below rely on.
        attributesFromDict(locals( ))

        self.net_recurse = _Net_recurse(n_in_channels=self.in_channels, mult_chan=self.mult_chan, depth=self.depth)
        # The core's top level outputs n_in_channels * mult_chan feature maps,
        # not just mult_chan (the two only agree when in_channels == 1).
        n_core_out = self.in_channels * self.mult_chan
        # The core is built from fastai ConvLayers, which are 2-D by default, so
        # the head must be 2-D as well. An nn.Conv3d here would treat a
        # (N, C, H, W) batch as a single unbatched (C, D, H, W) volume.
        self.conv_out = nn.Conv2d(n_core_out, self.out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        """Run the recursive core, then project to `out_channels`."""
        x_rec = self.net_recurse(x)
        return self.conv_out(x_rec)


class _Net_recurse(nn.Module):
    """Recursive definition of a U-network level.

    Parameters:
    n_in_channels - (int) number of channels of the input.
    mult_chan - (int) factor applied to `n_in_channels` to get this level's
        channel count: n_out_channels = n_in_channels * mult_chan.
    depth - (int) number of nested levels below this one. If 0, this subnet is
        only the two convolutions that take the channel count from
        n_in_channels to n_out_channels. Otherwise, a nested `_Net_recurse`
        (always with mult_chan=2) sits between a "down" conv and a transposed
        "up" conv, and a skip connection concatenates the two paths.

    The output of every level has n_in_channels * mult_chan channels.
    """
    def __init__(self, n_in_channels, mult_chan=2, depth=0):
        super().__init__()
        self.depth = depth
        n_out_channels = n_in_channels * mult_chan
        self.sub_2conv_more = SubNet2Conv(n_in_channels, n_out_channels)

        if depth > 0:
            ks = 3
            self.sub_2conv_less = SubNet2Conv(2 * n_out_channels, n_out_channels)
            # NOTE(review): no stride is passed, so with fastai's default
            # stride=1 this conv does not downsample (and `convt` does not
            # upsample); every level works at the input's resolution. Confirm
            # whether strided U-Net down/up-sampling was intended.
            self.conv_down = ConvLayer(n_out_channels, n_out_channels, ks=ks)
            # Explicit padding: fastai's ConvLayer uses padding=0 when
            # transpose=True, which would enlarge the map and break the concat.
            self.convt = ConvLayer(2 * n_out_channels, n_out_channels, ks=ks, transpose=True, padding=(ks - 1) // 2)
            self.sub_u = _Net_recurse(n_out_channels, mult_chan=2, depth=(depth - 1))

    def forward(self, x):
        """Apply this level (and, for depth > 0, all nested levels) to `x`."""
        x_2conv_more = self.sub_2conv_more(x)
        if self.depth == 0:
            return x_2conv_more
        x_conv_down = self.conv_down(x_2conv_more)
        x_sub_u = self.sub_u(x_conv_down)
        x_convt = self.convt(x_sub_u)
        # Skip connection: concatenate along dim 1 (the channel axis for
        # batched (N, C, H, W) input).
        x_cat = torch_cat((x_2conv_more, x_convt), 1)
        return self.sub_2conv_less(x_cat)


class SubNet2Conv(nn.Module):
    """Two stacked 3x3 `ConvLayer`s mapping `n_in` -> `n_out` -> `n_out` channels.

    Both convolutions use padding=1, so at fastai's default stride of 1 the
    spatial size is preserved.
    """
    def __init__(self, n_in, n_out):
        super().__init__()
        conv_kwargs = dict(ks=3, padding=1)
        self.C1 = ConvLayer(n_in, n_out, **conv_kwargs)
        self.C2 = ConvLayer(n_out, n_out, **conv_kwargs)

    def forward(self, x):
        hidden = self.C1(x)
        return self.C2(hidden)

Alternatively, we can use fastai's built-in U-Net builder, `unet_learner`:

In [None]:
#| eval: false
# Illustrative only, so nbdev does not execute it: this notebook never defines
# `dls` (a fastai DataLoaders) and never imports `unet_learner`, `models` or
# `F`. Bring them into scope first, e.g. via `from fastai.vision.all import *`.
learn = unet_learner(dls, models.resnet18, loss_func=F.l1_loss, n_in=1, n_out=1, pretrained=False, cut=None)
learn.summary()

In [None]:
#| hide
import nbdev

# Regenerate the library modules from the notebooks' `#| export` cells.
nbdev.nbdev_export()