In [2]:
import time

import torch
import torchvision

from torch import nn
#from torchsummary import summary

  from .autonotebook import tqdm as notebook_tqdm


In [20]:
class EncoderBlock(nn.Module):
    """One contracting stage of the U-Net encoder.

    Applies two 3x3 convolutions (each followed by ReLU), keeps the result as
    the skip-connection activation, then downsamples it with a 2x2 max-pool.

    Args:
        d_in: Number of input channels.
        d_out: Number of output channels of both convolutions.
    """

    def __init__(self, d_in, d_out):

        super(EncoderBlock, self).__init__()

        self.conv_1 = nn.Sequential(
            nn.Conv2d(d_in, d_out, 3, 1, "same"),
            nn.ReLU()
        )

        self.pool = nn.MaxPool2d(2, 2)

        self.conv_2 = nn.Sequential(
            nn.Conv2d(d_out, d_out, 3, 1, "same"),
            nn.ReLU()
        )

    def forward(self, inputs):
        """Run the block.

        Args:
            inputs: Feature map with ``d_in`` channels (assumed NCHW).

        Returns:
            (x, a): ``x`` is the max-pooled map (half the spatial size of
            ``a``) that feeds the next stage; ``a`` is the pre-pool activation
            used as the skip connection.
        """
        a = self.conv_2(self.conv_1(inputs))
        # Only pool here: the next stage's conv_1 does the next convolution.
        # Re-applying conv_2 to the pooled map reused the same weights twice.
        x = self.pool(a)

        return x, a

In [21]:
class LastEncoder(nn.Module):
    """Bottleneck stage of the U-Net encoder: two 3x3 conv + ReLU layers, no pooling.

    Args:
        d_in: Number of input channels.
        d_out: Number of output channels of both convolutions.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        self.conv1 = nn.Sequential(nn.Conv2d(d_in, d_out, 3, 1, "same"), nn.ReLU())
        self.conv2 = nn.Sequential(nn.Conv2d(d_out, d_out, 3, 1, "same"), nn.ReLU())

    def forward(self, inputs):
        """Return ``conv2(conv1(inputs))``; "same" padding keeps the spatial size."""
        hidden = self.conv1(inputs)
        return self.conv2(hidden)

In [24]:
class FullEncoder(nn.Module):
    """Contracting path of the U-Net.

    Builds one ``EncoderBlock`` per entry of ``filters[:-1]``, followed by a
    ``LastEncoder`` bottleneck with ``filters[-1]`` output channels.

    Args:
        d_in: Number of channels of the input image.
        filters: Output channels per stage, shallowest first; the last entry
            is the bottleneck.

    Raises:
        ValueError: If ``filters`` is empty.
    """

    def __init__(self, d_in, filters):

        super(FullEncoder, self).__init__()

        if len(filters) == 0:
            raise ValueError("filters must contain at least one entry")

        # nn.ModuleList (not a plain list) so the blocks are registered as
        # submodules: their parameters appear in .parameters()/state_dict and
        # are moved by .to(device).
        self.encoder_blocks = nn.ModuleList()
        for f in filters[:-1]:
            self.encoder_blocks.append(EncoderBlock(d_in, f))
            d_in = f

        # d_in now holds the channels of the last encoder block (or the image
        # channels if there is none), so a single-entry `filters` also works.
        self.last_encoder = LastEncoder(d_in, filters[-1])

    def forward(self, inputs):
        """Encode ``inputs``.

        Returns:
            (x, activations): ``x`` is the bottleneck feature map;
            ``activations`` lists the skip activation of every encoder block,
            shallowest first.
        """
        activations = []
        x = inputs
        for eb in self.encoder_blocks:
            x, a = eb(x)
            activations.append(a)

        x = self.last_encoder(x)

        return x, activations


In [25]:
unet_encoder = FullEncoder(d_in=3, filters=[64, 128, 256, 512, 1024])  # RGB input, 5 levels

## Decoder for the U-Net

In [27]:
class DecoderBlock(nn.Module):
    """One expanding stage of the U-Net decoder.

    Upsamples with a 2x2 stride-2 transposed convolution, concatenates the
    matching encoder activation along the channel axis, then applies two
    3x3 convolutions (each followed by ReLU).

    Args:
        d_in: Channels of the incoming (deeper) feature map.
        d_out: Channels produced by the upsampling and by both convolutions.
        d_skip: Channels of the skip activation concatenated after
            upsampling. Defaults to ``d_out`` (the standard U-Net layout).
            Pass 0 for a block that is always called with ``a=None``.
    """

    def __init__(self, d_in, d_out, d_skip=None):

        super(DecoderBlock, self).__init__()
        if d_skip is None:
            d_skip = d_out

        # kernel 2, stride 2, no padding doubles H and W exactly.
        # ConvTranspose2d takes an int/tuple padding, not "same".
        self.upconv = nn.Sequential(
            nn.ConvTranspose2d(d_in, d_out, 2, 2),
            nn.ReLU()
        )

        # Two distinct convolutions: the first consumes the concatenated
        # (skip + upsampled) channels; the second refines at d_out channels.
        self.conv = nn.Sequential(
            nn.Conv2d(d_out + d_skip, d_out, 3, 1, padding="same"),
            nn.ReLU(),
            nn.Conv2d(d_out, d_out, 3, 1, padding="same"),
            nn.ReLU()
        )

    def forward(self, inp, a):
        """Upsample ``inp``, fuse it with skip activation ``a`` (if given) and convolve.

        Args:
            inp: Deeper feature map with ``d_in`` channels (assumed NCHW).
            a: Skip activation with ``d_skip`` channels and twice the spatial
                size of ``inp``, or None.

        Returns:
            Feature map with ``d_out`` channels.
        """
        x = self.upconv(inp)
        if a is not None:
            # dim=1 is the channel axis of NCHW tensors; axis=-1 concatenated
            # along the width instead.
            x = torch.cat([a, x], dim=1)

        return self.conv(x)

In [31]:
class Decoder(nn.Module):
    """Expanding path of the U-Net: a chain of DecoderBlocks plus a 1x1 output conv.

    Args:
        d_in: Channels of the bottleneck feature map fed to the first block.
        filters: Output channels of each DecoderBlock, deepest first.
        num_classes: Output channels of the final 1x1 convolution.
    """

    def __init__(self, d_in, filters, num_classes):

        super(Decoder, self).__init__()

        # nn.ModuleList registers every block as a submodule; the previous
        # `self.db = ...` assignment registered only the last one.
        self.decoder_blocks = nn.ModuleList()
        for f in filters:
            self.decoder_blocks.append(DecoderBlock(d_in, f))
            d_in = f

        # d_in is the last block's channel count (or the input channels when
        # `filters` is empty), so this no longer depends on the loop variable.
        self.output = nn.Conv2d(d_in, num_classes, 1, 1)

    def forward(self, inputs, activations):
        """Decode ``inputs`` using one skip activation per block.

        Args:
            inputs: Bottleneck feature map.
            activations: One entry per decoder block, deepest first; each is
                passed to its block as the skip activation ``a``.

        Returns:
            The output of the 1x1 convolution (no final activation applied).

        Raises:
            ValueError: If the number of activations differs from the number
                of blocks (``zip`` would otherwise silently skip blocks).
        """
        activations = list(activations)
        if len(activations) != len(self.decoder_blocks):
            raise ValueError(
                f"expected {len(self.decoder_blocks)} skip activations, got {len(activations)}"
            )

        x = inputs
        for db, a in zip(self.decoder_blocks, activations):
            x = db(x, a)

        return self.output(x)
        


## Full U-Net

In [32]:
class UNet(nn.Module):
    """U-Net: a FullEncoder contracting path followed by a Decoder expanding path.

    Args:
        d_in: Channels of the input image.
        num_classes: Channels of the output map.
        filters: Channels per encoder stage, shallowest first; the last entry
            is the bottleneck. The decoder mirrors ``filters[:-1]``.
    """

    def __init__(self, d_in, num_classes, filters):

        super(UNet, self).__init__()
        # The encoder gets every level, including the bottleneck. Passing
        # filters[:-1] dropped a level and left the decoder with one more
        # block than there were skip activations.
        self.encoder = FullEncoder(d_in, filters)

        # The decoder starts from the bottleneck channels (filters[-1]), not
        # from the image channels, and walks back up the encoder levels.
        self.decoder = Decoder(filters[-1], filters[:-1][::-1], num_classes)

    def forward(self, inputs):
        """Return the decoder output (per-pixel class scores, no final activation).

        Assumes the input height/width are divisible by 2 ** (len(filters) - 1)
        so the pooled and upsampled maps line up for concatenation.
        """
        x, activations = self.encoder(inputs)

        # The encoder yields skips shallowest first; the decoder needs deepest first.
        return self.decoder(x, activations[::-1])

In [33]:
unet = UNet(d_in=3, num_classes=5, filters=[64, 128, 256, 512, 1024])  # RGB in, 5 classes out

## Training Loop 

In [35]:
# Sanity check of BCELoss. PyTorch's contract is loss(input, target):
# predicted probabilities first, ground truth second.
loss = torch.nn.BCELoss()

# 10x3 ground-truth map with a single positive at row 1, column 1.
gt = torch.zeros(10, 3)
gt[1, 1] = 1.0
# All-zero prediction: the positive pixel is completely missed.
pred = torch.zeros_like(gt)

gt = torch.reshape(gt, (-1,))
pred = torch.reshape(pred, (-1,))

# BCELoss clamps log outputs at -100, so the missed positive costs 100.
# Averaged over 30 elements, the loss is ~3.333.
l = loss(pred, gt)
l.item()



3.3333332538604736

In [24]:
def train_net(network, data, epochs, batch_size, loss_function, optimizer, log=True, device="cuda"):
    """Train ``network`` on pre-batched data and return the mean loss of each epoch.

    For every epoch, iterate over all batches. For each batch, move it to
    ``device``, compute the predictions, compute the loss against the ground
    truth, backpropagate, and step the optimizer.

    Args:
        network: The module to train; it must already live on ``device``.
        data: ``(images, annotations)``, two parallel iterables of batches.
        epochs: Number of passes over the data.
        batch_size: Unused; kept for backward compatibility. The data is
            already batched, so the epoch loss is averaged over the number of
            batches instead.
        loss_function: Initialized loss, called as ``loss_function(pred, target)``
            (PyTorch's ``(input, target)`` order).
        optimizer: Optimizer already initialized with ``network``'s parameters.
        log: If True, print the mean loss and duration of every epoch.
        device: Device the batches are moved to.

    Returns:
        list[float]: Mean batch loss for each epoch (0.0 for an epoch with no batches).
    """
    images, annotations = data
    losses = []
    network.train()
    for e in range(epochs):
        st = time.time()
        loss_value = 0.0
        n_batches = 0

        for img_batch, annotation_batch in zip(images, annotations):

            img_batch = img_batch.to(device)
            annotation_batch = annotation_batch.to(device)

            optimizer.zero_grad()

            # Make the model prediction
            pred = network(img_batch)

            # Compute the loss in (input, target) order. The old code used an
            # undefined local `gt`, which silently picked up a notebook global.
            loss = loss_function(pred, annotation_batch)

            # Calculate gradients through backpropagation
            loss.backward()

            # Update the model parameters
            optimizer.step()

            loss_value += loss.item()
            n_batches += 1

        # Average over batches actually seen; max() guards against empty data.
        epoch_loss = loss_value / max(n_batches, 1)
        losses.append(epoch_loss)

        if log:
            print(f"Loss at epoch : {e} : {round(epoch_loss, 3)}")
            elapsed_time = time.time() - st
            print(f"Epoch : {e} took {elapsed_time} seconds")

    return losses



