# Models

> Denoising model architectures: a DnCNN and a recursive U-Net.


In [10]:
#| default_exp models

In [11]:
#| hide
from nbdev.showdoc import *


In [12]:
#| hide
from IPython.display import clear_output, DisplayHandle

def update_patch(self, obj, **kwargs):
    """Replacement for ``DisplayHandle.update``: clear the cell output, then display ``obj`` again.

    Instead of updating the existing display in place, this clears the current
    output (``wait=True`` waits until new output is available to replace it)
    and re-displays ``obj`` through this handle. Extra keyword arguments are
    forwarded to ``DisplayHandle.display``, so callers passing display options
    keep working after the patch.
    """
    clear_output(wait=True)
    self.display(obj, **kwargs)


DisplayHandle.update = update_patch

In [13]:
#| export 

from fastai.vision.all import ConvLayer, Lambda, MaxPool, nn, np
from torch import cat as torch_cat
from Noise2Model.utils import attributesFromDict

In [14]:
#| export

class DnCNN(nn.Module):
    """DnCNN-style residual denoiser: the conv stack predicts a residual that is subtracted from the input.

    Parameters:
    channels - (int) number of image channels (input and output).
    num_of_layers - (int) total number of convolution layers, including the first and the last.
    features - (int) number of feature maps in the hidden layers.
    kernel_size - (int) convolution kernel size; assumed odd so that the padding
                  below keeps the spatial size unchanged.
    """
    def __init__(self, channels, num_of_layers=9, features=64, kernel_size=3):
        super().__init__()
        # "Same" padding for odd kernels: forward() subtracts the network output
        # from the input, so every layer must preserve height and width.
        padding = (kernel_size - 1) // 2
        # First layer: conv + activation, no normalization.
        layers = [ConvLayer(channels, features, ks=kernel_size, padding=padding, norm_type=None)]
        # Hidden layers use ConvLayer's default normalization and activation.
        layers += [ConvLayer(features, features, ks=kernel_size, padding=padding)
                   for _ in range(num_of_layers - 2)]
        # Last layer: plain bias-free conv back to `channels` maps.
        layers.append(nn.Conv2d(in_channels=features, out_channels=channels,
                                kernel_size=kernel_size, padding=padding, bias=False))
        self.dncnn = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` minus the residual predicted by the conv stack."""
        residual = self.dncnn(x)
        return x - residual

In [15]:
#| export 
class UNet(nn.Module):
    """U-Net assembled recursively from ``_Net_recurse`` levels.

    Parameters:
    depth - (int) number of pooling/upsampling levels below the top level.
    mult_chan - (int) the top level produces ``in_channels * mult_chan`` feature
                maps; each lower level doubles the channel count.
    in_channels, out_channels - (int) number of input / output image channels.
    last_activation - activation class applied to the final output (after the
                      residual add, if any). ``None`` means "use ``activation``";
                      pass ``nn.Identity`` for a linear output.
    kernel_size - (int) convolution kernel size (assumed odd).
    ndim - (int) spatial dimensionality forwarded to the fastai layers.
    n_conv_per_depth - (int) convolutions per level on each side of the U.
    activation - activation class used inside the convolution blocks.
    norm_type - forwarded to ``ConvLayer`` through ``SubNetConv``.
    dropout - (float) dropout probability after each convolution (0 disables it).
    pool - pooling factory, called as ``pool(ks=pool_size, ndim=ndim)``.
    pool_size - (int) pooling / upsampling factor. NOTE(review): input spatial
                sizes presumably need to be divisible by ``pool_size**depth`` for
                the skip concatenations to line up — confirm.
    residual - (bool) add the input to the network output before the final
               activation; requires ``in_channels == out_channels``.
    prob_out - (bool) additionally predict a positive scale map (softplus +
               ``eps_scale``) and concatenate it to the output along dim 1.
    eps_scale - (float) offset added to the predicted scale.

    Raises:
    ValueError - if ``residual`` is set and ``in_channels != out_channels``.
    """
    def __init__(self,
                 depth=4,
                 mult_chan=32,
                 in_channels=1,
                 out_channels=1,
                 last_activation=None,
                 kernel_size=3,
                 ndim=2,
                 n_conv_per_depth=2,
                 activation=nn.ReLU,
                 norm_type=1,
                 dropout=0.0,
                 pool=MaxPool,
                 pool_size=2,
                 residual=False,
                 prob_out=False,
                 eps_scale=1e-3,
                 ):
        super().__init__()
        # Presumably stores every constructor argument as an attribute of the same
        # name (forward reads self.residual, self.prob_out, self.eps_scale). It must
        # run before any other local is created so that only the arguments are stored.
        attributesFromDict(locals())
        from types import SimpleNamespace

        if residual and in_channels != out_channels:
            raise ValueError("number of input and output channels must be the same for a residual net.")

        if last_activation is None:
            last_activation = activation
        self.last_activation = last_activation

        # Hand the recursive builder its own snapshot of the hyper-parameters so it
        # can never overwrite this module's attributes while it recurses.
        hparams = SimpleNamespace(depth=depth, mult_chan=mult_chan, in_channels=in_channels,
                                  out_channels=out_channels, kernel_size=kernel_size, ndim=ndim,
                                  n_conv_per_depth=n_conv_per_depth, activation=activation,
                                  norm_type=norm_type, dropout=dropout, pool=pool,
                                  pool_size=pool_size)
        self.net_recurse = _Net_recurse(hparams)

        # Channel count of the top recursion level's output.
        n_rec = in_channels * mult_chan
        # Linear output conv: the final activation is applied in forward(), after
        # the optional residual add.
        self.conv_out = ConvLayer(n_rec, out_channels, ndim=ndim, ks=kernel_size, norm_type=None,
                                  act_cls=None, padding=(kernel_size - 1) // 2)
        self.final_act = last_activation() if last_activation is not None else nn.Identity()

        if prob_out:
            # Registered here (not built in forward) so it is trained, saved and moved with the model.
            self.conv_scale = ConvLayer(n_rec, out_channels, ndim=ndim, ks=1, norm_type=None,
                                        act_cls=nn.Softplus)

    def forward(self, x):
        """Run the U-Net on ``x``.

        Returns a tensor with ``out_channels`` channels, or ``2 * out_channels``
        channels (prediction followed by scale, along dim 1) when ``prob_out`` is set.
        """
        x_rec = self.net_recurse(x)
        final = self.conv_out(x_rec)

        if self.residual:
            final = final + x
        final = self.final_act(final)

        if self.prob_out:
            scale = self.conv_scale(x_rec) + self.eps_scale
            final = torch_cat((final, scale), 1)

        return final



In [16]:
#| export
class _Net_recurse(nn.Module):
    def __init__(self, parameters, in_channels=None, mult_chan=None, depth=None):
        """Class for recursive definition of U-network.

        Parameters:
        parameters - object exposing the U-Net hyper-parameters as attributes
                     (depth, in_channels, mult_chan, kernel_size, ndim,
                     n_conv_per_depth, activation, norm_type, dropout, pool,
                     pool_size). It is only read, never modified.
        in_channels - (int) number of channels for input; defaults to parameters.in_channels.
        mult_chan - (int) factor to determine number of output channels; defaults
                    to parameters.mult_chan.
        depth - (int) number of levels below this one; if 0, this subnet is only
                convolutions. Defaults to parameters.depth.

        The output of this level has ``in_channels * mult_chan`` channels.
        """
        super().__init__()
        if in_channels is None:
            in_channels = parameters.in_channels
        if mult_chan is None:
            mult_chan = parameters.mult_chan
        if depth is None:
            depth = parameters.depth

        self.depth = depth
        n_out_channels = in_channels * mult_chan
        ndim = parameters.ndim
        n_conv = parameters.n_conv_per_depth

        SubNet_Conv = SubNetConv(ks=parameters.kernel_size, stride=1, padding=None, bias=None, ndim=ndim,
                                 norm_type=parameters.norm_type, bn_1st=True,
                                 act_cls=parameters.activation, transpose=False,
                                 dropout=parameters.dropout)

        self.sub_conv_more = SubNet_Conv(in_channels, n_out_channels, n_conv)

        if self.depth > 0:
            # The lower level outputs 2*n_out_channels maps and nn.Upsample keeps
            # the channel count, so the skip concatenation carries
            # n_out_channels + 2*n_out_channels channels.
            self.sub_conv_less = SubNet_Conv(3 * n_out_channels, n_out_channels, n_conv)
            self.reduce = parameters.pool(ks=parameters.pool_size, ndim=ndim)
            self.upsample = nn.Upsample(scale_factor=parameters.pool_size, mode='nearest')
            # Overrides are passed explicitly instead of mutating `parameters`.
            self.sub_u = _Net_recurse(parameters, in_channels=n_out_channels, mult_chan=2,
                                      depth=self.depth - 1)

    def forward(self, x):
        if self.depth == 0:
            return self.sub_conv_more(x)
        x_conv_more = self.sub_conv_more(x)                 # convolutions with increasing number of channels
        x_reduced = self.reduce(x_conv_more)                # layer reducing the image size (usually a pooling layer)
        x_sub_u = self.sub_u(x_reduced)                     # lower U-Net level
        x_upsampled = self.upsample(x_sub_u)                # layer increasing the image size (usually an upsampling layer)
        x_cat = torch_cat((x_conv_more, x_upsampled), 1)    # concatenate the skip connection with the upsampled lower-level output
        return self.sub_conv_less(x_cat)                    # convolutions with decreasing number of channels

In [17]:
#| export
def SubNetConv(ks=3,
               stride=1,
               padding=None,
               bias=None,
               ndim=2,
               norm_type=1,
               bn_1st=True,
               act_cls=nn.ReLU,
               transpose=False,
               init='auto',
               xtra=None,
               bias_std=0.01,
               dropout=0.0,
               ):
    """Return a factory that builds stacks of identically configured ``ConvLayer``s.

    All arguments except ``dropout`` are forwarded unchanged to fastai's
    ``ConvLayer``. ``dropout`` (float) inserts an ``nn.Dropout`` after every
    convolution when it is a positive number.

    The returned callable ``_conv(n_in, n_out, n_conv=1)`` builds ``n_conv``
    convolutions (always at least one). The first maps ``n_in`` -> ``n_out``
    channels and the rest ``n_out`` -> ``n_out``. A single convolution without
    dropout is returned as the bare ``ConvLayer``; anything longer is wrapped in
    ``nn.Sequential``.

    NOTE(review): ``norm_type`` defaults to the int ``1``. fastai's ``ConvLayer``
    compares ``norm_type`` against ``NormType`` enum members, which a plain int
    is unlikely to equal, so this default may disable normalization entirely.
    Confirm whether ``NormType.Batch`` was intended.
    """
    use_dropout = dropout is not None and dropout > 0

    def _make_conv(n_in, n_out):
        # One ConvLayer carrying the shared settings captured above.
        return ConvLayer(n_in, n_out, ks=ks, stride=stride, padding=padding, bias=bias, ndim=ndim,
                         norm_type=norm_type, bn_1st=bn_1st, act_cls=act_cls, transpose=transpose,
                         init=init, xtra=xtra, bias_std=bias_std)

    def _conv(n_in, n_out, n_conv=1):
        layers = []
        for i in range(max(n_conv, 1)):
            layers.append(_make_conv(n_in if i == 0 else n_out, n_out))
            if use_dropout:
                layers.append(nn.Dropout(dropout))
        # Modules are chained with nn.Sequential so the factory yields one callable module.
        return layers[0] if len(layers) == 1 else nn.Sequential(*layers)

    return _conv

Alternatively, we can use fastai's U-Net builder, `unet_learner`:


In [18]:
from fastai.vision.all import unet_learner

In [19]:
# learn = unet_learner(dls, models.resnet18, loss_func=F.l1_loss, n_in=1, n_out=1, pretrained=False, cut=None)
# learn.summary()

In [20]:
#| hide
import nbdev

# Regenerate the Python package from the notebooks' `#| export` cells.
nbdev.nbdev_export()