In [1]:
import torch
import torchvision.transforms as T
from torch import nn
from torch.nn import functional as F
import os
from PIL import Image
from enum import Enum
import nbimporter
from dataset import VocDataset

In [None]:
def build_conv_block(in_channels, out_channels):
    print(f'in_channels={in_channels}, out_channels={out_channels}')
    return nn.Sequential(
        nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(num_features=out_channels),
        nn.ReLU(),
        nn.Dropout(0.3),
        
        nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(num_features=out_channels),
        nn.ReLU(),
        nn.Dropout(0.3)
    )

In [None]:
class ConvBlock(nn.Module):
    skip_connections = []

    def __init__(self, in_channels, out_channels, encode=True):
        super(ConvBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.encode = encode
        self.conv = build_conv_block(in_channels=self.in_channels, out_channels=self.out_channels)
        if self.encode:
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        else:
            self.upconv = nn.ConvTranspose2d(in_channels=self.in_channels, out_channels=self.in_channels//2, kernel_size=2, stride=2)
    
    def forward(self, X):
        # print(f'X.shape input={X.shape}')
        if self.encode:
            X = self.conv(X)
            # print(f'X.shape output={X.shape}')
            self.skip_connections.append(X)
            return self.pool(X)
        else:
            X = self.upconv(X)
            # if X[0].shape[0] <= 256:
            #     print(X.shape)
            #     print(self.skip_connections[-1].shape)
            X = torch.cat((X, self.skip_connections.pop()), dim=1)
            # print(f'X.shape output={X.shape}')
            return self.conv(X)

In [None]:
class YouNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=21):
        super(YouNet, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        # The down-sampling layers
        self.contractive_path = nn.ModuleDict({
            'encode0': ConvBlock(in_channels=self.in_channels, out_channels=64, encode=True),
            'encode1': ConvBlock(in_channels=64, out_channels=128, encode=True),
            'encode2': ConvBlock(in_channels=128, out_channels=256, encode=True),
            'encode3': ConvBlock(in_channels=256, out_channels=512, encode=True),
        })

        # The bottleneck
        self.trough = build_conv_block(in_channels=512, out_channels=1024)

        # The up-sampling layers
        # in_channels takes input from previous layer and skip connections
        self.expansive_path = nn.ModuleDict({
            'decode3': ConvBlock(in_channels=512*2, out_channels=512, encode=False),
            'decode2': ConvBlock(in_channels=256*2, out_channels=256, encode=False),
            'decode1': ConvBlock(in_channels=128*2, out_channels=128, encode=False),
            'decode0': ConvBlock(in_channels=64*2, out_channels=64, encode=False)
        })

        # The prediction layer
        self.final = nn.Conv2d(in_channels=64, out_channels=self.out_channels, kernel_size=1)
    
    def forward(self, X):
        skip_connections = []
        # Train the contractive path
        for conv_block in self.contractive_path:
            # print(f'X.shape input={X.shape}')
            X = self.contractive_path[conv_block](X)
            # print(f'X.shape output={X.shape}')
        
        # Train the trough
        X = self.trough(X)
        
        # Train the expansive path
        for conv_block in self.expansive_path:
            X = self.expansive_path[conv_block](X)
        
        return self.final(X)
    
    def print_hook_shape(self, module, input, output):
        '''Prints the input and output tensor of a given layer. Used by YouNet.print_forward_hooks'''
        print(f'{module.__class__.__name__}(input shape: {input[0].shape}, output shape: {output.shape})')

    def print_forward_hooks(self):
        '''Prints the input and output tensors of each layer.'''
        for name, layer in self.contractive_path.items():
            layer.register_forward_hook(self.print_hook_shape)
        
        self.trough.register_forward_hook(self.print_hook_shape)
        
        for name, layer in self.expansive_path.items():
            layer.register_forward_hook(self.print_hook_shape)
        
        self.final.register_forward_hook(self.print_hook_shape)
    
    def print_network(self):
        '''Prints the entire network architecture.'''
        for name, conv_block in self.contractive_path.items():
            print('Layer:', name)
            print(conv_block)

        print('Layer: bottleneck')
        print(self.trough)

        for name, conv_block in self.expansive_path.items():
            print('Layer:', name)
            print(conv_block)
        
        print('Layer: final\n', self.final, sep='')

# YouNet expects inputs with shape (any_batch_size, any_channel_size, any_even_height, any_even_width)
# TODO: Implement padding or shaving of inputs to accept any input height or width (even or odd)
# X = torch.randn((32, 3, 320, 480))
# net = YouNet(3, 3)

# net.print_network()
# print()
# net.print_forward_hooks()

# with torch.no_grad():
#     preds = net(X)

# print()
# print('Input shape:', X.shape)
# print('Output shape:', preds.shape)
# assert preds.shape == X.shape

In [5]:
voc_dir = 'C:/Users/Hayden/Machine Learning/d2l/d2l-en/pytorch/chapter_computer-vision/data/VOCdevkit/VOC2012/'
transform = T.Compose([
    T.ToTensor()
])
crop_size = (256, 256)
dataset = 'train'
train_set = VocDataset(voc_dir, transform, transform, crop_size, dataset)
val_set = VocDataset(voc_dir, transform, transform, crop_size, 'val')
test_set = VocDataset(voc_dir, transform, transform, crop_size, 'test')
print(len(train_set))
print(len(val_set))
print(len(test_set))

1444
1426
1429


In [6]:
test_set[0]

tensor([[[-1.6727, -1.8268, -1.9809,  ...,  0.9303,  1.1187,  1.4098],
         [-1.6898, -1.8268, -1.9980,  ...,  0.9474,  1.1700,  1.4612],
         [-1.7069, -1.8097, -1.9809,  ...,  0.9646,  1.2214,  1.3927],
         ...,
         [ 2.0263,  2.1804,  2.2489,  ..., -2.1179, -2.1179, -1.9980],
         [ 1.5810,  1.9749,  2.2318,  ..., -2.0494, -2.0665, -2.1179],
         [ 0.3823,  0.9474,  1.5810,  ..., -2.0323, -2.1179, -2.1179]],

        [[-1.4230, -1.4930, -1.6681,  ...,  1.3606,  1.5532,  1.8333],
         [-1.4405, -1.4930, -1.6856,  ...,  1.3782,  1.6057,  1.8508],
         [-1.4755, -1.5105, -1.6856,  ...,  1.3957,  1.6583,  1.7808],
         ...,
         [ 2.0784,  2.3060,  2.4286,  ..., -1.8782, -1.9657, -1.9132],
         [ 1.3606,  1.9559,  2.3410,  ..., -1.8431, -1.9657, -2.0357],
         [-0.3200,  0.4328,  1.2031,  ..., -1.8782, -1.9832, -2.0357]],

        [[-1.5604, -1.6999, -1.7870,  ...,  1.0714,  1.2631,  1.5071],
         [-1.5779, -1.6999, -1.8044,  ...,  1

In [7]:
image, mask = train_set[0]
print(image.shape, mask.shape)
image = torch.unsqueeze(image, dim=0)
mask = torch.unsqueeze(mask, dim=0)
print(image.shape, mask.shape)
mynet = YouNet(3, 3).cuda()
with torch.no_grad():
    pred = mynet(image.cuda())

torch.Size([3, 256, 256]) torch.Size([1, 256, 256])
torch.Size([1, 3, 256, 256]) torch.Size([1, 1, 256, 256])
in_channels=3, out_channels=64
in_channels=64, out_channels=128
in_channels=128, out_channels=256
in_channels=256, out_channels=512
in_channels=512, out_channels=1024
in_channels=1024, out_channels=512
in_channels=512, out_channels=256
in_channels=256, out_channels=128
in_channels=128, out_channels=64
X.shape input=torch.Size([1, 3, 256, 256])
X.shape output=torch.Size([1, 64, 256, 256])
X.shape input=torch.Size([1, 64, 128, 128])
X.shape output=torch.Size([1, 128, 128, 128])
X.shape input=torch.Size([1, 128, 64, 64])
X.shape output=torch.Size([1, 256, 64, 64])
X.shape input=torch.Size([1, 256, 32, 32])
X.shape output=torch.Size([1, 512, 32, 32])
X.shape input=torch.Size([1, 1024, 16, 16])
X.shape output=torch.Size([1, 1024, 32, 32])
X.shape input=torch.Size([1, 512, 32, 32])
X.shape output=torch.Size([1, 512, 64, 64])
X.shape input=torch.Size([1, 256, 64, 64])
X.shape output=t

In [8]:
print(pred.shape)
print(pred)

torch.Size([1, 3, 256, 256])
tensor([[[[ 0.0979, -0.8878, -0.1129,  ...,  0.2894, -0.4189,  0.6418],
          [-0.6275,  0.1431,  0.4982,  ...,  0.0651,  0.1573,  1.4438],
          [ 0.0277, -0.0104, -0.0405,  ..., -0.2748, -0.2521,  0.3780],
          ...,
          [ 0.2540,  0.6698,  0.6048,  ...,  0.6254,  0.9863,  0.5160],
          [ 0.1728,  0.4521,  0.1488,  ...,  0.4075,  0.5467,  0.5657],
          [ 0.6099,  0.2704,  0.1651,  ...,  0.2091,  0.4707,  0.2409]],

         [[-0.7779, -0.5001, -1.8573,  ..., -1.0764, -1.1688, -0.0066],
          [-0.2200,  0.1282, -1.1821,  ..., -0.7356, -1.0485, -0.6675],
          [-0.4095, -0.8056, -1.1746,  ...,  0.0949, -1.2966, -0.1276],
          ...,
          [-0.4032, -0.0628, -0.3616,  ..., -0.5273, -0.6424, -0.5477],
          [ 0.0255,  0.1283,  0.1863,  ..., -0.2552, -0.2172, -0.5096],
          [ 0.1114, -0.2135, -0.4068,  ..., -0.3544, -1.0439, -0.8867]],

         [[ 1.3020, -0.2616,  0.2513,  ...,  0.2454,  0.5113, -0.1288],
 