In [63]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

In [64]:
class encoderBlocks(nn.Module):
    """One contracting-path stage: two unpadded 3x3 conv+ReLU layers, then a 2x2 max-pool.

    Args:
        levelwise_filter_config: indexable of channel counts; entries 0, 1 and 2 are
            the input, intermediate and output channel counts of the two convolutions.
    """

    def __init__(self, levelwise_filter_config):
        super().__init__()
        in_ch = levelwise_filter_config[0]
        mid_ch = levelwise_filter_config[1]
        out_ch = levelwise_filter_config[2]
        # Layer order (and therefore state_dict keys network_block.0 / .2) is significant.
        self.network_block = nn.Sequential(
            nn.Conv2d(in_channels=in_ch, out_channels=mid_ch, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(in_channels=mid_ch, out_channels=out_ch, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
        )
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

    def forward(self, x):
        """Return ``(pooled, features)``.

        ``features`` is the pre-pool activation, which the caller keeps for the skip
        connection; ``pooled`` is the same tensor after 2x2 / stride-2 max-pooling.
        """
        features = self.network_block(x)
        return self.pool(features), features

In [65]:
class encoder(nn.Module):
    """Contracting path of the U-Net: a sequence of ``encoderBlocks`` applied in order.

    Args:
        no_of_blocks: number of stages; must equal ``len(block_wise_filter_config)``.
        block_wise_filter_config: one channel-count config per stage, passed as-is to
            ``encoderBlocks``.
    """

    def __init__(self, no_of_blocks, block_wise_filter_config):
        super().__init__()
        assert no_of_blocks==len(block_wise_filter_config)
        stages = []
        for idx in range(no_of_blocks):
            stages.append(encoderBlocks(block_wise_filter_config[idx]))
        self.encoder = nn.ModuleList(stages)

    def forward(self, x):
        """Run every stage and return ``(deepest_pooled_output, skips)``.

        ``skips`` holds each stage's pre-pool features, shallowest first.
        """
        out = x
        skips = []
        for stage in self.encoder:
            out, features = stage(out)
            skips.append(features)
        return out, skips

In [66]:
class bottleNeckLayer(nn.Module):
    """Bottom of the U-Net: two unpadded 3x3 conv+ReLU layers, then a 2x2 up-convolution.

    Each 3x3 convolution (stride 1, no padding) shrinks height and width by 2; the
    stride-2, kernel-2 transposed convolution then doubles them.

    Args:
        bottleneck_filter_config: indexable of channel counts; entries 0, 1 and 2 are
            the input, intermediate and output channel counts of the two convolutions.
            The transposed convolution emits ``bottleneck_filter_config[2] // 2``
            channels.
    """

    def __init__(self, bottleneck_filter_config):
        super().__init__()
        self.network = nn.Sequential(
            nn.Conv2d(in_channels=bottleneck_filter_config[0], out_channels=bottleneck_filter_config[1], kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(in_channels=bottleneck_filter_config[1], out_channels=bottleneck_filter_config[2], kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=bottleneck_filter_config[2], out_channels=bottleneck_filter_config[2]//2, kernel_size=2, stride=2, padding=0)
        )

    # Must be defined at class level (not nested inside __init__) so that
    # nn.Module.__call__ dispatches to it.
    def forward(self, x):
        """Apply the bottleneck convolutions and upsample the result 2x spatially."""
        bottleneck_output = self.network(x)
        return bottleneck_output

In [67]:
class decoderBlock(nn.Module):
    """One expanding-path stage of the U-Net.

    The skip-connection tensor is centre-cropped to the spatial size of the incoming
    feature map, and the two are concatenated along dim 1. The result then passes
    through two unpadded 3x3 conv+ReLU layers and a 2x2 stride-2 transposed
    convolution.

    Args:
        levelwise_filter_config: indexable of channel counts. Entry 0 must equal the
            channel count after concatenation (channels of ``x`` plus channels of the
            skip input). Entries 1 and 2 are the output channels of the two
            convolutions. The transposed convolution emits
            ``levelwise_filter_config[2] // 2`` channels.
    """

    def __init__(self, levelwise_filter_config):
        super().__init__()
        self.network = nn.Sequential(
            nn.Conv2d(in_channels=levelwise_filter_config[0], out_channels=levelwise_filter_config[1], kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(in_channels=levelwise_filter_config[1], out_channels=levelwise_filter_config[2], kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=levelwise_filter_config[2], out_channels=levelwise_filter_config[2]//2, kernel_size=2, stride=2, padding=0)
        )

    def forward(self, x, y):
        """Crop skip input ``y`` to ``x``'s spatial size, concatenate, and run the stage.

        Args:
            x: output of the previous (deeper) stage, laid out as
                (batch, channel, height, width).
            y: skip-connection features from the matching encoder stage, with the
                same layout. Its height and width must be at least those of ``x``.

        Returns:
            ``self.network`` applied to ``torch.cat([x, cropped_y], dim=1)``.

        Raises:
            ValueError: if the batch sizes differ or ``y`` is spatially smaller
                than ``x``.
        """
        target_h, target_w = x.shape[2], x.shape[3]
        skip_h, skip_w = y.shape[2], y.shape[3]
        if y.shape[0] != x.shape[0]:
            raise ValueError(
                f"batch size mismatch: x has {x.shape[0]}, skip input has {y.shape[0]}"
            )
        if skip_h < target_h or skip_w < target_w:
            raise ValueError(
                f"skip input {tuple(y.shape[2:])} is smaller than x {tuple(x.shape[2:])}"
            )
        # Height and width are cropped independently, and each slice length comes
        # straight from x. This yields exactly x's spatial size even for odd size
        # differences or non-square maps. Channel counts may differ, because
        # concatenation along dim 1 only needs the other dims to match.
        top = (skip_h - target_h) // 2
        left = (skip_w - target_w) // 2
        y = y[:, :, top:top + target_h, left:left + target_w]
        combined_input = torch.cat([x, y], dim=1)
        output = self.network(combined_input)
        return output


In [68]:
class decoder(nn.Module):
    """Expanding path of the U-Net: upsampling ``decoderBlock`` stages plus a final head.

    Each decoder block consumes one skip input, deepest first. The final head consumes
    the shallowest skip input (``skip_inputs[0]``). It centre-crops that input,
    concatenates it with the upsampled features along dim 1, and applies two unpadded
    3x3 conv+ReLU layers and a 1x1 convolution.

    Args:
        no_of_blocks: number of ``decoderBlock`` stages (the head is extra).
        block_wise_filter_config: one channel-count config per stage, passed as-is to
            ``decoderBlock``.
        final_filter_config: (input, intermediate, output) channel counts of the two
            3x3 convolutions in the head. The input count must equal the channels
            after concatenation. Defaults to the previously hard-coded (128, 64, 64).
        out_channels: channels produced by the final 1x1 convolution (default 2,
            the previously hard-coded value).
    """

    def __init__(self, no_of_blocks, block_wise_filter_config,
                 final_filter_config=(128, 64, 64), out_channels=2):
        super().__init__()
        self.decoder = nn.ModuleList(
            [decoderBlock(block_wise_filter_config[i]) for i in range(no_of_blocks)]
        )
        self.final_block = nn.Sequential(
            nn.Conv2d(in_channels=final_filter_config[0], out_channels=final_filter_config[1], kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(in_channels=final_filter_config[1], out_channels=final_filter_config[2], kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(in_channels=final_filter_config[2], out_channels=out_channels, kernel_size=1, stride=1, padding=0)
        )

    @staticmethod
    def _center_crop(skip, reference):
        """Centre-crop ``skip`` to the height/width (dims 2 and 3) of ``reference``."""
        target_h, target_w = reference.shape[2], reference.shape[3]
        skip_h, skip_w = skip.shape[2], skip.shape[3]
        if skip_h < target_h or skip_w < target_w:
            raise ValueError(
                f"skip input {tuple(skip.shape[2:])} is smaller than features "
                f"{tuple(reference.shape[2:])}"
            )
        top = (skip_h - target_h) // 2
        left = (skip_w - target_w) // 2
        return skip[:, :, top:top + target_h, left:left + target_w]

    def forward(self, bottlenecks_output, skip_inputs):
        """Run the decoder stages, then the head.

        Args:
            bottlenecks_output: upsampled output of the bottleneck.
            skip_inputs: encoder skip features, shallowest first. There must be
                exactly ``len(self.decoder) + 1`` entries: one per decoder block
                plus one for the head.

        Returns:
            Output of the final 1x1 convolution.

        Raises:
            ValueError: if the number of skip inputs is wrong, or a skip input is
                spatially smaller than the features it is combined with.
        """
        if len(skip_inputs) != len(self.decoder) + 1:
            raise ValueError(
                f"expected {len(self.decoder) + 1} skip inputs, got {len(skip_inputs)}"
            )
        for i, block in enumerate(self.decoder):
            # Deepest skip first: block 0 pairs with skip_inputs[-1], and so on.
            bottlenecks_output = block(bottlenecks_output, skip_inputs[-(i + 1)])

        # nn.Sequential takes a single tensor, so the shallowest skip must be
        # cropped and concatenated before the head is applied.
        skip = self._center_crop(skip_inputs[0], bottlenecks_output)
        combined_input = torch.cat([bottlenecks_output, skip], dim=1)
        final_output = self.final_block(combined_input)
        return final_output


In [69]:
class Unet(nn.Module):
    """U-Net assembled from an ``encoder``, a ``bottleNeckLayer`` and a ``decoder``.

    ``config`` is read with these keys:
        config['encoder']['no_of_blocks'], config['encoder']['block_wise_filter_config'],
        config['bottleneck_filter_config'],
        config['decoder']['no_of_blocks'], config['decoder']['block_wise_filter_config'].
    """

    def __init__(self, config):
        super().__init__()
        # Submodules are registered in the order encoder, bottle_neck, decoder;
        # this order determines the state_dict layout.
        enc_cfg = config['encoder']
        self.encoder = encoder(enc_cfg['no_of_blocks'], enc_cfg['block_wise_filter_config'])
        self.bottle_neck = bottleNeckLayer(config['bottleneck_filter_config'])
        dec_cfg = config['decoder']
        self.decoder = decoder(dec_cfg['no_of_blocks'], dec_cfg['block_wise_filter_config'])

    def forward(self, x):
        """Encode ``x``, pass the deepest features through the bottleneck, and decode."""
        pooled, skips = self.encoder(x)
        return self.decoder(self.bottle_neck(pooled), skips)