## SegNet Trial for Inference
<img src='https://ars.els-cdn.com/content/image/1-s2.0-S2452414X20300194-gr3.jpg' height=600 width=800>

The following model is set up for inference only! It has not been trained, so its weights are randomly initialized and its outputs are completely random.

In [2]:
import torch 
import torch.nn as nn
import torch.nn.functional as F

In [59]:
class SegNet(nn.Module):
    def __init__(self, classes=10):
        super(SegNet, self).__init__()

        batchNorm_momentum = 0.1

        self.conv11 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.bn11 = nn.BatchNorm2d(64, momentum=batchNorm_momentum)
        self.conv12 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn12 = nn.BatchNorm2d(64, momentum=batchNorm_momentum)

        self.conv21 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn21 = nn.BatchNorm2d(128, momentum=batchNorm_momentum)
        self.conv22 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn22 = nn.BatchNorm2d(128, momentum=batchNorm_momentum)

        self.conv31 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn31 = nn.BatchNorm2d(256, momentum=batchNorm_momentum)
        self.conv32 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn32 = nn.BatchNorm2d(256, momentum=batchNorm_momentum)
        self.conv33 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn33 = nn.BatchNorm2d(256, momentum=batchNorm_momentum)

        self.conv41 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.bn41 = nn.BatchNorm2d(512, momentum=batchNorm_momentum)
        self.conv42 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn42 = nn.BatchNorm2d(512, momentum=batchNorm_momentum)
        self.conv43 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn43 = nn.BatchNorm2d(512, momentum=batchNorm_momentum)

        self.conv51 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn51 = nn.BatchNorm2d(512, momentum=batchNorm_momentum)
        self.conv52 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn52 = nn.BatchNorm2d(512, momentum=batchNorm_momentum)
        self.conv53 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn53 = nn.BatchNorm2d(512, momentum=batchNorm_momentum)

        self.conv53d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn53d = nn.BatchNorm2d(512, momentum=batchNorm_momentum)
        self.conv52d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn52d = nn.BatchNorm2d(512, momentum=batchNorm_momentum)
        self.conv51d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn51d = nn.BatchNorm2d(512, momentum=batchNorm_momentum)

        self.conv43d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn43d = nn.BatchNorm2d(512, momentum=batchNorm_momentum)
        self.conv42d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn42d = nn.BatchNorm2d(512, momentum=batchNorm_momentum)
        self.conv41d = nn.Conv2d(512, 256, kernel_size=3, padding=1)
        self.bn41d = nn.BatchNorm2d(256, momentum=batchNorm_momentum)

        self.conv33d = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn33d = nn.BatchNorm2d(256, momentum=batchNorm_momentum)
        self.conv32d = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn32d = nn.BatchNorm2d(256, momentum=batchNorm_momentum)
        self.conv31d = nn.Conv2d(256, 128, kernel_size=3, padding=1)
        self.bn31d = nn.BatchNorm2d(128, momentum=batchNorm_momentum)

        self.conv22d = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn22d = nn.BatchNorm2d(128, momentum=batchNorm_momentum)
        self.conv21d = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.bn21d = nn.BatchNorm2d(64, momentum=batchNorm_momentum)

        self.conv12d = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn12d = nn.BatchNorm2d(64, momentum=batchNorm_momentum)
        self.conv11d = nn.Conv2d(64, classes, kernel_size=3, padding=1)

    def forward(self, x):
        # Stage 1
        x11 = F.relu(self.bn11(self.conv11(x)))
        x12 = F.relu(self.bn12(self.conv12(x11)))
        x1_size = x12.size()
        x1p, id1 = F.max_pool2d(x12, kernel_size=2, stride=2, return_indices=True)

        # Stage 2
        x21 = F.relu(self.bn21(self.conv21(x1p)))
        x22 = F.relu(self.bn22(self.conv22(x21)))
        x2_size = x22.size()
        x2p, id2 = F.max_pool2d(x22, kernel_size=2, stride=2, return_indices=True)

        # Stage 3
        x31 = F.relu(self.bn31(self.conv31(x2p)))
        x32 = F.relu(self.bn32(self.conv32(x31)))
        x33 = F.relu(self.bn33(self.conv33(x32)))
        x3_size = x33.size()
        x3p, id3 = F.max_pool2d(x33, kernel_size=2, stride=2, return_indices=True)

        # Stage 4
        x41 = F.relu(self.bn41(self.conv41(x3p)))
        x42 = F.relu(self.bn42(self.conv42(x41)))
        x43 = F.relu(self.bn43(self.conv43(x42)))
        x4_size = x43.size()
        x4p, id4 = F.max_pool2d(x43, kernel_size=2, stride=2, return_indices=True)

        # Stage 5
        x51 = F.relu(self.bn51(self.conv51(x4p)))
        x52 = F.relu(self.bn52(self.conv52(x51)))
        x53 = F.relu(self.bn53(self.conv53(x52)))
        x5_size = x53.size()
        x5p, id5 = F.max_pool2d(x53, kernel_size=2, stride=2, return_indices=True)

        # Stage 5d
        x5d = F.max_unpool2d(x5p, id5, kernel_size=2, stride=2, output_size=x5_size)
        x53d = F.relu(self.bn53d(self.conv53d(x5d)))
        x52d = F.relu(self.bn52d(self.conv52d(x53d)))
        x51d = F.relu(self.bn51d(self.conv51d(x52d)))

        # Stage 4d
        x4d = F.max_unpool2d(x51d, id4, kernel_size=2, stride=2, output_size=x4_size)
        x43d = F.relu(self.bn43d(self.conv43d(x4d)))
        x42d = F.relu(self.bn42d(self.conv42d(x43d)))
        x41d = F.relu(self.bn41d(self.conv41d(x42d)))

        # Stage 3d
        x3d = F.max_unpool2d(x41d, id3, kernel_size=2, stride=2, output_size=x3_size)
        x33d = F.relu(self.bn33d(self.conv33d(x3d)))
        x32d = F.relu(self.bn32d(self.conv32d(x33d)))
        x31d = F.relu(self.bn31d(self.conv31d(x32d)))

        # Stage 2d
        x2d = F.max_unpool2d(x31d, id2, kernel_size=2, stride=2, output_size=x2_size)
        x22d = F.relu(self.bn22d(self.conv22d(x2d)))
        x21d = F.relu(self.bn21d(self.conv21d(x22d)))

        # Stage 1d
        x1d = F.max_unpool2d(x21d, id1, kernel_size=2, stride=2, output_size=x1_size)
        x12d = F.relu(self.bn12d(self.conv12d(x1d)))
        x11d = self.conv11d(x12d)

        # Do NOT turn the logits into a mask here, because training needs the raw logits:
        # x11d = F.softmax(x11d, dim=1)
        # x11d = torch.argmax(x11d, dim=1, keepdim=True)  # argmax destroys the gradients, so the loss cannot be backpropagated!
        #
        # loss.backward() fails whenever the model output or the loss has been detached
        # from the computation graph, e.g. by:
        #   - using another library such as numpy
        #   - using non-differentiable operations such as torch.argmax
        #   - explicitly detaching the tensor via tensor = tensor.detach()
        #   - rewrapping the tensor via x = torch.tensor(x)
        # or if gradient calculation was disabled in the current context or globally,
        # so that no computation graph was created at all.

        return x11d
    
    # Post-processing for inference only: turns the raw logits into a class-index mask
    def postprocessing(self, x, classes=10):
        """
        The raw SegNet output is a tensor of shape [batch, classes, height, width], where
        element [b, k, i, j] is an unnormalized score (logit) that pixel (i, j) of image b
        belongs to class k.
        1st: apply the softmax function along dim=1, so that for every pixel the `classes`
             values become probabilities that sum to 1.
        2nd: apply argmax along dim=1 to pick, for every pixel, the class with the highest
             probability (softmax is monotonic, so this equals the argmax of the raw logits).

        Output: a [batch, 1, height, width] mask (float32) with class indices in 0..classes-1.
        The `classes` argument is currently unused.
        """

        x = torch.nn.functional.softmax(x,dim=1)
        x = torch.argmax(x, dim=1,keepdim=True)
        return x.to(torch.float32)
        


# Training code adapted from https://github.com/vinceecws/SegNet_PyTorch.
# The undefined names `hyperparam`, `optim` and `os` are replaced by explicit
# arguments and imports, and the checkpoint/loss bookkeeping is fixed.
import os
import torch.optim as optim


def save_checkpoint(state, path):
    torch.save(state, path)
    print("Checkpoint saved at {}".format(path))


def train(trainloader, epochs, lr, momentum, path=None):
    """Train SegNet for `epochs` epochs; `path` is an optional checkpoint to resume from."""
    model = SegNet()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    # Expects raw logits [batch, classes, H, W] and long class-index labels [batch, H, W]
    loss_fn = nn.CrossEntropyLoss()
    start_epoch = 0

    if path is None:
        path = os.path.join(os.getcwd(), 'segnet_weights.pth.tar')
        print("Creating new checkpoint '{}'".format(path))
    elif os.path.isfile(path):
        print("Loading checkpoint '{}'".format(path))
        checkpoint = torch.load(path)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("Loaded checkpoint '{}' (epoch {})".format(path, start_epoch))
    else:
        print("No checkpoint found at '{}'".format(path))

    model.train()
    for epoch in range(start_epoch + 1, start_epoch + epochs + 1):
        print('Epoch {}:'.format(epoch))
        sum_loss, num_batches = 0.0, 0
        for j, (images, labels) in enumerate(trainloader, 1):
            optimizer.zero_grad()
            output = model(images)          # raw logits: no softmax/argmax here
            loss = loss_fn(output, labels)  # already averaged over the mini-batch
            loss.backward()
            optimizer.step()
            sum_loss += loss.item()
            num_batches = j
            print('Loss at {} mini-batch: {}'.format(j, loss.item()))
        print('Average loss @ epoch: {}'.format(sum_loss / max(num_batches, 1)))

    print("Training complete. Saving checkpoint...")
    save_checkpoint({'epoch': start_epoch + epochs, 'state_dict': model.state_dict(),
                     'optimizer': optimizer.state_dict()}, path)
    return model
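The checkpoint used by the training code above is a plain dictionary with the keys `'epoch'`, `'state_dict'` and `'optimizer'`, saved with `torch.save` and restored with `torch.load`. A minimal round-trip sketch, assuming a single `nn.Conv2d` as a hypothetical stand-in for SegNet so that it runs in a fraction of a second:

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Hypothetical stand-in for SegNet: any nn.Module is checkpointed the same way
model = nn.Conv2d(3, 10, kernel_size=3, padding=1)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# One dummy step so the optimizer holds momentum buffers worth saving
model(torch.rand(2, 3, 8, 8)).mean().backward()
optimizer.step()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "segnet_weights.pth.tar")
    torch.save({"epoch": 1, "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict()}, path)

    # Resume into fresh objects, mirroring the "load checkpoint" branch of the training code
    restored = nn.Conv2d(3, 10, kernel_size=3, padding=1)
    restored_opt = optim.SGD(restored.parameters(), lr=0.01, momentum=0.9)
    checkpoint = torch.load(path)
    restored.load_state_dict(checkpoint["state_dict"])
    restored_opt.load_state_dict(checkpoint["optimizer"])
    start_epoch = checkpoint["epoch"]

same_weights = torch.equal(model.weight, restored.weight)
print(start_epoch, same_weights)  # 1 True
```

Saving the optimizer state as well as the weights is what lets SGD resume with its momentum buffers instead of starting from zero.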


In [60]:
model = SegNet()

In [61]:
model # print Model

SegNet(
  (conv11): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn11): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv12): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn12): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv21): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn21): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv22): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn22): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv31): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn31): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv32): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn32): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, 
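The decoder in the `SegNet` class does not learn its upsampling: each `F.max_unpool2d` call reuses the indices returned by `F.max_pool2d(..., return_indices=True)` in the matching encoder stage, so every value goes back to the exact position its maximum came from. A small sketch on a toy 4×4 feature map (the values are arbitrary, chosen only to make the result easy to follow):

```python
import torch
import torch.nn.functional as F

# A 1x1x4x4 feature map: [batch, channels, height, width]
x = torch.tensor([[[[1., 2., 5., 0.],
                    [3., 4., 1., 1.],
                    [0., 0., 2., 9.],
                    [7., 1., 3., 4.]]]])

# Encoder: 2x2 max pooling that also remembers where each maximum was
pooled, indices = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True)

# Decoder: put each pooled value back at its original position, zeros elsewhere
unpooled = F.max_unpool2d(pooled, indices, kernel_size=2, stride=2, output_size=x.size())

print(pooled)    # values: [[4, 5], [7, 9]]
print(unpooled)  # values: [[0, 0, 5, 0], [0, 4, 0, 0], [0, 0, 0, 9], [7, 0, 0, 0]]
```

The resulting sparse map is then densified by the `convXYd` convolutions that follow each unpooling step.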

In [62]:
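def main():
    x = torch.rand(1, 3, 224, 224)  # dummy RGB image batch: [batch, channels, height, width]
    model = SegNet()
    model.eval()  # BatchNorm uses its running statistics at inference time
    with torch.no_grad():  # inference only: no computation graph is needed
        out = model(x)
        mask = model.postprocessing(out)
    print(out.shape)
    print(f"Output: 1 channel for every single class\n{out.shape}")
    print(f"POSTPROCESSING: {mask.shape} \n TYPE:{mask.dtype}")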

* Softmax Classification: The softmax function is applied to the output of the final layer in the decoder. This transforms the decoder’s output into a set of values between 0 and 1, which can be interpreted as probabilities. The sum of these values for each pixel will be 1.
* K-Channel Image: The output is a K-channel image, where each channel corresponds to one of the K classes. For each pixel, the value in a channel represents the probability that the pixel belongs to the corresponding class.
* Pixel-wise Maximum Probability: The predicted segmentation is obtained by assigning each pixel to the class with the maximum probability at that pixel. In other words, for each pixel, we look at the K probabilities in the K channels, and the class corresponding to the highest probability is chosen as the prediction for that pixel.
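The three steps above can be checked numerically. This sketch uses random logits as a stand-in for a SegNet output (assumed shape [1, K, 4, 4] with K = 10; the values are not real model output):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
K = 10                            # number of classes, as in SegNet(classes=10)
logits = torch.randn(1, K, 4, 4)  # stand-in for the raw output [batch, K, height, width]

probs = F.softmax(logits, dim=1)  # K-channel "probability image"
per_pixel_sum = probs.sum(dim=1)  # shape [1, 4, 4]: every entry is ~1.0

mask = torch.argmax(probs, dim=1, keepdim=True)               # [1, 1, 4, 4], values in 0..K-1
mask_from_logits = torch.argmax(logits, dim=1, keepdim=True)  # same result without softmax

print(torch.allclose(per_pixel_sum, torch.ones_like(per_pixel_sum)))  # True
print(torch.equal(mask, mask_from_logits))                             # True
```

Because softmax is monotonic, it does not change which class wins at a pixel; it is only needed when the actual probabilities are of interest.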

In [63]:
if __name__ == '__main__':
    main()

torch.Size([1, 10, 224, 224])
Output: 1 channel for every single class
torch.Size([1, 10, 224, 224])
POSTPROCESSING: torch.Size([1, 1, 224, 224]) 
 TYPE:torch.float32
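With a 224×224 input, the five 2×2 poolings shrink the feature maps to 112, 56, 28, 14 and 7 pixels per side before the decoder reverses them. The `SegNet` class passes `output_size=xN_size` to every `F.max_unpool2d` call, which matters when a side is odd: pooling floors the size, so without the encoder-side size the unpooled map would come out one pixel short. A sketch with a hypothetical 7×7 map:

```python
import torch
import torch.nn.functional as F

# Side lengths after each of the five 2x2 poolings for a 224x224 input
sizes = [224 // 2 ** i for i in range(6)]
print(sizes)  # [224, 112, 56, 28, 14, 7]

x = torch.arange(49.0).reshape(1, 1, 7, 7)  # odd spatial size: 7x7
pooled, idx = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True)
print(pooled.shape)    # torch.Size([1, 1, 3, 3]): last row/column dropped (floor)

# Passing the encoder-side size restores the original 7x7 grid
# (without output_size it would be inferred as 3 * 2 = 6 per side)
restored = F.max_unpool2d(pooled, idx, kernel_size=2, stride=2, output_size=x.size())
print(restored.shape)  # torch.Size([1, 1, 7, 7])
```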



### SegNet and the Network Model

So, until now the masks I had in mind were [1, height, width] in int8 format, with each pixel carrying its own label. So far so good, but when you try to do the segmentation you end up with the output of both UNet and SegNet, which is [1, k, height, width], and you think you could get the prediction with an argmax, but argmax does not keep the gradients.
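The gradient problem can be seen directly: `torch.argmax` returns integer class indices, which autograd does not track, so anything computed from them is cut off from the model weights. A minimal sketch with random logits as a stand-in for the model output:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(1, 3, 4, 4, requires_grad=True)  # stand-in for a model output [1, k, H, W]

probs = F.softmax(logits, dim=1)   # differentiable: autograd records a grad_fn
mask = torch.argmax(probs, dim=1)  # integer class indices: no grad_fn, no gradient

print(probs.requires_grad, probs.grad_fn is not None)  # True True
print(mask.requires_grad, mask.grad_fn, mask.dtype)    # False None torch.int64

# Any loss built on top of the argmax mask cannot be backpropagated
backward_failed = False
try:
    mask.float().mean().backward()
except RuntimeError as err:
    backward_failed = True
    print("backward() failed:", err)
```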

So there is no approximation that gets around this; the general approach is to create the dataset from scratch, with binary masks, one for every freaking class in the end. How nice! At least I've found the way forward, so to speak. XD
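For comparison, a sketch of the common setup with `nn.CrossEntropyLoss`: the loss takes the raw [batch, k, height, width] logits plus an integer class-index mask of shape [batch, height, width] (dtype `long`), applies log-softmax internally, and stays differentiable, so argmax is only needed at inference time (as in `postprocessing`). The shapes and the random mask below are assumptions for illustration:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
K, H, W = 10, 8, 8

logits = torch.randn(1, K, H, W, requires_grad=True)             # stand-in for raw SegNet output
mask_int8 = torch.randint(0, K, (1, 1, H, W), dtype=torch.int8)  # a batch of one [1, H, W] mask

target = mask_int8.squeeze(1).long()          # [batch, H, W], int64 as CrossEntropyLoss expects
loss = nn.CrossEntropyLoss()(logits, target)  # log-softmax is applied inside the loss
loss.backward()                               # works: no argmax in the training path

print(target.shape, logits.grad.shape)  # torch.Size([1, 8, 8]) torch.Size([1, 10, 8, 8])
```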

So now we are writing the script to produce n binary masks, one per class, and combine them to match the freaking output! Oh my freaking model, damn it! This is why you need to become one with the data and one with the model beforehand! Honestly, this is a freaking adventure, and one worth pursuing.
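One way to sketch the "n binary masks" step, assuming an integer label mask as input: `F.one_hot` builds one binary channel per class, and `argmax` over the class dimension gives the label mask back. In PyTorch 1.10 and later, `nn.CrossEntropyLoss` also accepts such per-class targets with the same shape as the logits, and for one-hot targets the loss matches the class-index version. The sizes and random data here are hypothetical:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
K, H, W = 10, 8, 8
mask = torch.randint(0, K, (1, H, W))  # int64 label mask [batch, H, W]; an int8 mask needs .long() first

# One binary mask per class: [batch, H, W, K] -> [batch, K, H, W], like the model output
binary_masks = F.one_hot(mask, num_classes=K).permute(0, 3, 1, 2).float()

# Every pixel belongs to exactly one class, and argmax recovers the label mask
assert torch.all(binary_masks.sum(dim=1) == 1)
assert torch.equal(binary_masks.argmax(dim=1), mask)

# Per-class (probability) targets vs. class-index targets give the same loss
logits = torch.randn(1, K, H, W)
loss_fn = nn.CrossEntropyLoss()
loss_binary = loss_fn(logits, binary_masks)
loss_index = loss_fn(logits, mask)
print(torch.allclose(loss_binary, loss_index))  # True
```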