In [171]:
import torch
import torchvision

from torch import nn
#from torchsummary import summary

In [20]:
class EncoderBlock(nn.Module):

    def __init__(self, d_in, d_out):

        super(EncoderBlock, self).__init__()

        self.conv_1 = nn.Sequential(
            nn.Conv2d(d_in, d_out, 3, 1, "same"),
            nn.ReLU()
        )

        self.pool = nn.MaxPool2d(2, 2)

        self.conv_2 = nn.Sequential(
            nn.Conv2d(d_out, d_out, 3, 1, "same"),
            nn.ReLU()
        )
    
    def forward(self, inputs):

        a = self.conv_1(inputs)
        a = self.conv_2(a)
        x = self.pool(a)
        x = self.conv_2(x)

        return x, a

In [21]:
class LastEncoder(nn.Module):

    def __init__(self, d_in, d_out):

        super(LastEncoder, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(d_in, d_out, 3, 1, "same"),
            nn.ReLU()
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(d_out, d_out, 3, 1, "same"),
            nn.ReLU()
        )

    def forward(self, inputs):

        x = self.conv2(self.conv1(inputs))

        return x

In [24]:
class FullEncoder(nn.Module):

    def __init__(self, d_in, filters):

        super(FullEncoder, self).__init__()

        self.encoder_blocks = []
        for f in filters[:-1]:

            encoder = EncoderBlock(d_in, f)
            self.encoder_blocks.append(encoder)
            d_in = f
        
        self.last_encoder = LastEncoder(f, filters[-1])


    def forward(self, inputs):

        activations = []
        x = inputs
        for eb in self.encoder_blocks:
            x, a = eb(x)
            activations.append(a)
        
        x = self.last_encoder(x)

        return x, activations


In [25]:
unet_encoder = FullEncoder(3, [64, 128, 256, 512, 1024])

## Decoder for the U-Net

In [27]:
class DecoderBlock(nn.Module):

    def __init__(self, d_in, d_out):

        super(DecoderBlock, self).__init__()
        self.upconv = nn.Sequential(
            nn.ConvTranspose2d(d_in, d_out, 2, 2, padding="same"),
            nn.ReLU()
        )

        self.conv = nn.Sequential(
            nn.Conv2d(d_out, d_out, 3, 1, padding="same"),
            nn.ReLU()
        )

    def forward(self, inp, a):

        x = self.upconv(inp)
        if a is not None:
            x = torch.cat([a, x], axis=-1)
        
        x = self.conv(self.conv(x))
    
        return x

In [31]:
class Decoder(nn.Module):

    def __init__(self, d_in, filters, num_classes):

        super(Decoder, self).__init__()

        self.decoder_blocks = []

        for f in filters:

            self.db = DecoderBlock(d_in, f)
            self.decoder_blocks.append(self.db)
            d_in = f
        
        self.output = nn.Conv2d(f, num_classes, 1, 1)
    
    def forward(self, inputs, activations):

        x = inputs
        for db, a in zip(self.decoder_blocks, activations):

            x = db(x, a)
        
        output = self.output(x)

        return output
        


## Full U-Net

In [32]:
class UNet(nn.Module):

    def __init__(self, d_in, num_classes, filters):

        super(UNet, self).__init__()
        self.encoder = FullEncoder(d_in, filters[:-1])

        self.decoder = Decoder(d_in, filters[:-1][::-1], num_classes)
    
    def forward(self,inputs):

        x, activations = self.encoder(inputs)

        o = self.decoder(x, activations[::-1])

In [33]:
unet = UNet(3, 5, [64, 128, 256, 512, 1024])

## Training Loop 

## VGG U-NET

In [152]:
from torchvision.models import vgg19, VGG19_Weights

vgg19 = vgg19(weights=VGG19_Weights.DEFAULT)

In [153]:
vgg19_backbone = nn.Sequential(*(list(vgg19.children())[0][:-1]))

In [154]:
vgg19_backbone

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace=True)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace=True)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace=True)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace=True)
  (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (17): ReLU(inplace=True)
  (18): MaxPoo

In [155]:
activations = []
def getActivations():
    def hook(model, input, output):
        activations.append(output)
    return hook

In [156]:
h1 = vgg19_backbone[3].register_forward_hook(getActivations())
h2 = vgg19_backbone[8].register_forward_hook(getActivations())
h3 = vgg19_backbone[17].register_forward_hook(getActivations())
h4 = vgg19_backbone[26].register_forward_hook(getActivations())


In [157]:
test_img = torch.zeros((1, 3, 512, 512))

In [158]:
activations

[]

In [159]:
out = vgg19_backbone(test_img)

In [160]:
out.shape

torch.Size([1, 512, 32, 32])

In [161]:
print(activations[0].shape)
print(activations[1].shape)
print(activations[2].shape)
print(activations[3].shape)

torch.Size([1, 64, 512, 512])
torch.Size([1, 128, 256, 256])
torch.Size([1, 256, 128, 128])
torch.Size([1, 512, 64, 64])


## ResUNet

In [130]:
from torchvision.models import resnet50, ResNet50_Weights

resnet_model = resnet50(weights=ResNet50_Weights.DEFAULT)

In [132]:
resnet_backbone = nn.Sequential(*(list(resnet_model.children())[0:7]))

In [170]:
resnet_backbone

Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)


In [133]:
hr1 = resnet_backbone[2].register_forward_hook(getActivations())
hr2 = resnet_backbone[4][2].relu.register_forward_hook(getActivations())
hr3 = resnet_backbone[5][-1].relu.register_forward_hook(getActivations())

In [134]:
out = resnet_backbone(test_img)

In [135]:
print(activations[0].shape)
print(activations[1].shape)
print(activations[-1].shape)

torch.Size([1, 64, 256, 256])
torch.Size([1, 64, 128, 128])
torch.Size([1, 512, 64, 64])


In [136]:
out.shape

torch.Size([1, 1024, 32, 32])

## Pretrained U-Net Testing

In [169]:
net = torch.hub.load('milesial/Pytorch-UNet', 'unet_carvana', pretrained=True, scale=0.5)

Using cache found in /Users/akhtar/.cache/torch/hub/milesial_Pytorch-UNet_master


TypeError: unet_carvana() got an unexpected keyword argument 'map_location'