In [1]:
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
import os
from PIL import Image

In [None]:
def build_conv_block(in_channels, encode=True):
    '''Build the convolutional stack for one U-Net stage.

    Args:
        in_channels: Channel count of the incoming feature map.
        encode: If True, build an encoder stage: two 3x3 conv -> BatchNorm ->
            ReLU layers mapping ``in_channels`` -> ``2 * in_channels`` while
            keeping H and W unchanged (padding=1). If False, build a decoder
            stage: a 2x2 stride-2 transposed conv mapping ``in_channels`` ->
            ``in_channels // 2`` and doubling H and W, followed by two 3x3
            conv -> BatchNorm -> ReLU layers at the halved width.

    Returns:
        nn.Sequential implementing the stage.
    '''
    if encode:
        out_channels = in_channels * 2
        return nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),

            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    # Integer division: layer channel counts must be ints; ``/`` would give a
    # float that nn.Conv2d / nn.BatchNorm2d reject with a TypeError.
    out_channels = in_channels // 2
    return nn.Sequential(
        # padding=0 makes the output exactly 2*H x 2*W; padding=1 would trim
        # a pixel from every border (2*H - 2) and break skip alignment.
        nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=2, stride=2, padding=0),

        nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),

        nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
    )

In [None]:
class ConvBlock(nn.Module):
    '''One U-Net stage wrapping the stack produced by ``build_conv_block``.

    In encode mode ``forward`` returns ``(skip_connection, pooled)``: the
    conv-stack output (kept for the matching decoder stage) and that same
    output downsampled by a 2x2 stride-2 max-pool. In decode mode it returns
    the decoder conv-stack output.
    '''

    def __init__(self, in_channels, encode=True):
        super(ConvBlock, self).__init__()
        self.encode = encode
        self.conv = build_conv_block(in_channels=in_channels, encode=encode)
        # Parameter-free, so it adds nothing to state_dict; built once here
        # and applied to the skip output in forward.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, X):
        '''Run the stage; see the class docstring for the return contract.'''
        if self.encode:
            skip_connection = self.conv(X)
            return skip_connection, self.pool(skip_connection)
        return self.conv(X)

In [None]:
class YouNet(nn.Module):
    '''U-Net style encoder/decoder producing per-pixel class logits.

    ``forward`` maps an ``(N, in_channels, H, W)`` tensor to
    ``(N, num_classes, H, W)`` logits. H and W need not be divisible by
    ``2 ** depth``: upsampled maps are zero-padded to match their skip
    connection before concatenation.

    Args:
        in_channels: Channels of the input images.
        num_classes: Number of output logit maps.
        base_channels: Width of the first encoder stage; each deeper stage
            doubles it (defaults give 64-128-256-512, bottleneck 1024).
        depth: Number of encoder/decoder stages.
    '''

    def __init__(self, in_channels, num_classes, base_channels=64, depth=4):
        super(YouNet, self).__init__()
        if depth < 1:
            raise ValueError(f"depth must be >= 1, got {depth}")
        widths = [base_channels * 2 ** i for i in range(depth)]

        # Contracting path: encode1 is the shallowest, full-resolution stage.
        self.contractive_path = nn.ModuleDict()
        prev_channels = in_channels
        for i, width in enumerate(widths, start=1):
            self.contractive_path[f'encode{i}'] = self._double_conv(prev_channels, width)
            prev_channels = width
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Bottleneck between the two paths; doubles the deepest width.
        self.trough = self._double_conv(widths[-1], widths[-1] * 2)

        # Expansive path: decodeN mirrors encodeN. Each stage upsamples
        # (2*width -> width), concatenates the width-channel skip (back to
        # 2*width), then a double conv reduces it to width.
        self.expansive_path = nn.ModuleDict()
        for i, width in enumerate(widths, start=1):
            self.expansive_path[f'decode{i}'] = nn.ModuleDict({
                'up': nn.ConvTranspose2d(in_channels=width * 2, out_channels=width, kernel_size=2, stride=2),
                'conv': self._double_conv(width * 2, width),
            })
        self.final_layer = nn.Conv2d(in_channels=widths[0], out_channels=num_classes, kernel_size=1)

    @staticmethod
    def _double_conv(in_channels, out_channels):
        '''Two 3x3 conv -> BatchNorm -> ReLU layers; padding=1 keeps H and W.'''
        return nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),

            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, X):
        '''Forward pass: encode, bottleneck, decode with skips, 1x1 classify.'''
        skip_connections = []

        # Contracting path: keep each stage's output as a skip, then pool.
        for encoder in self.contractive_path.values():
            X = encoder(X)
            skip_connections.append(X)
            X = self.pool(X)

        X = self.trough(X)

        # Expansive path, deepest stage first so each pops its own skip.
        for decoder in reversed(list(self.expansive_path.values())):
            skip = skip_connections.pop()
            X = decoder['up'](X)
            # Pooling floors odd sizes, so the upsampled map can be one pixel
            # short of the skip; pad it to match before concatenating.
            diff_h = skip.size(-2) - X.size(-2)
            diff_w = skip.size(-1) - X.size(-1)
            X = F.pad(X, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2])
            X = torch.cat((skip, X), dim=1)  # concatenate along channels
            X = decoder['conv'](X)

        return self.final_layer(X)
    

In [None]:
def build_encoder_block(in_channels):
    '''Build one stage of the U-Net contracting path.

    Returns a ``(conv_stack, pool)`` pair. ``conv_stack`` is the encoder stack
    from ``build_conv_block`` (``in_channels`` -> ``2 * in_channels``); its
    output is intended as the stage's skip connection. ``pool`` is a 2x2
    stride-2 max-pool meant to downsample that output for the next stage.
    '''
    return (
        build_conv_block(in_channels=in_channels, encode=True),
        nn.MaxPool2d(kernel_size=2, stride=2),
    )

In [None]:
def build_decoder_block(in_channels):
    '''Build one stage of the U-Net expansive path.

    Thin wrapper over ``build_conv_block(..., encode=False)``. It does not
    concatenate any skip connection; that has to happen outside this block.
    '''
    block = build_conv_block(encode=False, in_channels=in_channels)
    return block

In [None]:
def build_trough_block(in_channels):
    '''Build the U-Net bottleneck ("trough") between the two paths.

    Two 3x3 conv -> BatchNorm -> ReLU layers mapping ``in_channels`` ->
    ``2 * in_channels``. ``padding=1`` keeps H and W unchanged, matching the
    other conv blocks in this notebook so decoder upsampling lines up with
    the skip connections.

    Returns:
        nn.Sequential implementing the bottleneck.
    '''
    out_channels = in_channels * 2
    trough = nn.Sequential(
        nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(num_features=out_channels),
        nn.ReLU(),

        nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
        # Keyword was misspelled as ``num_feeatures`` (TypeError).
        nn.BatchNorm2d(num_features=out_channels),
        nn.ReLU()
    )

    return trough

In [None]:
def train_U_Net(base_channels=32, depth=4):
    '''Assemble the U-Net building blocks with a consistent channel plan.

    NOTE(review): despite the name, nothing is trained yet. There is no data,
    loss or optimizer here; this only constructs the blocks so their shapes
    can be checked. Skip-connection concatenation is not wired up either.
    With the helpers as written the stages chain directly: the encoders take
    base, 2*base, ... channels and double them, the trough doubles the
    deepest width, and each decoder halves it (for the defaults the decoder
    chain is 1024 -> 512 -> 256 -> 128 -> 64).
    TODO: add an input stem from image channels to ``base_channels``.

    Args:
        base_channels: Input channels of the first encoder block.
        depth: Number of encoder/decoder stages.

    Returns:
        nn.ModuleDict with keys 'encoder1'..'encoderN' (each a ModuleDict
        with 'conv' and 'pool'), 'trough', and 'decoder1'..'decoderN'.
    '''
    blocks = nn.ModuleDict()

    channels = base_channels
    for i in range(1, depth + 1):
        conv, pool = build_encoder_block(channels)
        blocks[f'encoder{i}'] = nn.ModuleDict({'conv': conv, 'pool': pool})
        channels *= 2  # encoder stack doubles its input width

    blocks['trough'] = build_trough_block(channels)
    channels *= 2  # trough doubles the deepest width

    for i in range(1, depth + 1):
        blocks[f'decoder{i}'] = build_decoder_block(channels)
        channels //= 2  # decoder stack halves its input width

    return blocks

    

