In [1]:
import torch
import torchvision.transforms as T
from torch import nn
from torch.nn import functional as F
import os
from PIL import Image
from enum import Enum
import nbimporter
from dataset import VocDataset

In [2]:
def build_conv_block(in_channels, out_channels):
    print(f'in_channels={in_channels}, out_channels={out_channels}')
    return nn.Sequential(
        nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(num_features=out_channels),
        nn.ReLU(),
        
        nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(num_features=out_channels),
        nn.ReLU(),
    )

In [None]:
class ConvBlock(nn.Module):
    skip_connections = []

    def __init__(self, in_channels, out_channels, encode=True):
        super(ConvBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.encode = encode
        self.conv = build_conv_block(in_channels=self.in_channels, out_channels=self.out_channels)
    
    def forward(self, X):
        print(f'X.shape input={X.shape}')
        if self.encode:
            X = self.conv(X)
            print(f'X.shape output={X.shape}')
            self.skip_connections.append(X)
            return nn.MaxPool2d(kernel_size=2, stride=2)(X)
        else:
            X = nn.ConvTranspose2d(in_channels=self.in_channels, out_channels=self.in_channels//2, kernel_size=2, stride=2)(X)
            # if X[0].shape[0] <= 256:
            #     print(X.shape)
            #     print(self.skip_connections[-1].shape)
            X = torch.cat((X, self.skip_connections.pop()), dim=1)
            print(f'X.shape output={X.shape}')
            return self.conv(X)

In [4]:
class YouNet(nn.Module):
    def __init__(self, in_channels=64, out_channels=64):
        super(YouNet, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        # The down-sampling layers
        self.contractive_path = nn.ModuleDict({
            'encode0': ConvBlock(in_channels=self.in_channels, out_channels=64, encode=True),
            'encode1': ConvBlock(in_channels=64, out_channels=128, encode=True),
            'encode2': ConvBlock(in_channels=128, out_channels=256, encode=True),
            'encode3': ConvBlock(in_channels=256, out_channels=512, encode=True),
        })

        # The bottleneck
        self.trough = build_conv_block(in_channels=512, out_channels=1024)

        # The up-sampling layers
        # in_channels takes input from previous layer and skip connections
        self.expansive_path = nn.ModuleDict({
            'decode3': ConvBlock(in_channels=512*2, out_channels=512, encode=False),
            'decode2': ConvBlock(in_channels=256*2, out_channels=256, encode=False),
            'decode1': ConvBlock(in_channels=128*2, out_channels=128, encode=False),
            'decode0': ConvBlock(in_channels=64*2, out_channels=64, encode=False)
        })

        # The prediction layer
        self.final = nn.Conv2d(in_channels=64, out_channels=self.out_channels, kernel_size=1)
    
    def forward(self, X):
        # Train the contractive path
        for conv_block in self.contractive_path:
            # print(f'X.shape input={X.shape}')
            X = self.contractive_path[conv_block](X)
            # print(f'X.shape output={X.shape}')
        
        # Train the trough
        X = self.trough(X)
        
        # Train the expansive path
        for conv_block in self.expansive_path:
            X = self.expansive_path[conv_block](X)
        
        return self.final(X)
    
    def print_hook_shape(self, module, input, output):
        '''Prints the input and output tensor of a given layer. Used by YouNet.print_forward_hooks'''
        print(f'{module.__class__.__name__}(input shape: {input[0].shape}, output shape: {output.shape})')

    def print_forward_hooks(self):
        '''Prints the input and output tensors of each layer.'''
        for name, layer in self.contractive_path.items():
            layer.register_forward_hook(self.print_hook_shape)
        
        self.trough.register_forward_hook(self.print_hook_shape)
        
        for name, layer in self.expansive_path.items():
            layer.register_forward_hook(self.print_hook_shape)
        
        self.final.register_forward_hook(self.print_hook_shape)
    
    def print_network(self):
        '''Prints the entire network architecture.'''
        for name, conv_block in self.contractive_path.items():
            print('Layer:', name)
            print(conv_block)

        print('Layer: bottleneck')
        print(self.trough)

        for name, conv_block in self.expansive_path.items():
            print('Layer:', name)
            print(conv_block)
        
        print('Layer: final\n', self.final, sep='')

# YouNet expects inputs with shape (any_batch_size, any_channel_size, any_even_height, any_even_width)
# TODO: Implement padding or shaving of inputs to accept any input height or width (even or odd)
# X = torch.randn((32, 3, 320, 480))
# net = YouNet(3, 3)

# net.print_network()
# print()
# net.print_forward_hooks()

# with torch.no_grad():
#     preds = net(X)

# print()
# print('Input shape:', X.shape)
# print('Output shape:', preds.shape)
# assert preds.shape == X.shape

In [5]:
voc_dir = 'C:/Users/Hayden/Machine Learning/d2l/d2l-en/pytorch/chapter_computer-vision/data/VOCdevkit/VOC2012/'
transform = T.Compose([
    T.ToTensor()
])
crop_size = (256, 256)
dataset = 'train'
train_set = VocDataset(voc_dir, transform, transform, crop_size, dataset)
val_set = VocDataset(voc_dir, transform, transform, crop_size, 'val')
test_set = VocDataset(voc_dir, transform, transform, crop_size, 'test')
print(len(train_set))
print(len(val_set))
print(len(test_set))

1444
1426
1429


In [6]:
test_set[0]

tensor([[[-0.3883, -0.2856, -0.1999,  ...,  0.8447,  0.3823,  0.6392],
         [-0.3712, -0.3883, -0.3541,  ...,  0.8789,  0.3823,  0.6563],
         [-0.3712, -0.4911, -0.5082,  ...,  0.8447,  0.3652,  0.6392],
         ...,
         [ 0.6906, -0.3712, -0.3541,  ..., -0.9705, -0.8849, -0.7822],
         [-0.1657, -0.1143, -0.2684,  ..., -0.9020, -0.7308, -0.6281],
         [-0.0629, -0.0629, -0.2171,  ..., -1.0390, -0.8678, -0.7479]],

        [[-0.8277, -0.7227, -0.6527,  ...,  0.3277, -0.3200, -0.1450],
         [-0.8102, -0.8277, -0.8102,  ...,  0.3102, -0.3200, -0.1275],
         [-0.8102, -0.9328, -0.9678,  ...,  0.2752, -0.3725, -0.1800],
         ...,
         [-0.3901, -1.6506, -1.5980,  ..., -1.5805, -1.4755, -1.3880],
         [-1.5805, -1.4755, -1.5630,  ..., -1.7031, -1.6155, -1.5105],
         [-1.4580, -1.4230, -1.5805,  ..., -1.7556, -1.6506, -1.5455]],

        [[-1.4907, -1.3861, -1.3164,  ..., -0.8981, -1.3861, -1.1073],
         [-1.4384, -1.4907, -1.4733,  ..., -0

In [9]:
image, mask = train_set[0]
print(image.shape, mask.shape)
image = torch.unsqueeze(image, dim=0)
mask = torch.unsqueeze(mask, dim=0)
print(image.shape, mask.shape)
mynet = YouNet(3, 3).cuda()
with torch.no_grad():
    pred = mynet(image.cuda())

torch.Size([3, 256, 256]) torch.Size([1, 256, 256])
torch.Size([1, 3, 256, 256]) torch.Size([1, 1, 256, 256])
in_channels=3, out_channels=64
in_channels=64, out_channels=128
in_channels=128, out_channels=256
in_channels=256, out_channels=512
in_channels=512, out_channels=1024
in_channels=1024, out_channels=512
in_channels=512, out_channels=256
in_channels=256, out_channels=128
in_channels=128, out_channels=64
X.shape input=torch.Size([1, 3, 256, 256])
X.shape output=torch.Size([1, 64, 256, 256])
X.shape input=torch.Size([1, 64, 128, 128])
X.shape output=torch.Size([1, 128, 128, 128])
X.shape input=torch.Size([1, 128, 64, 64])
X.shape output=torch.Size([1, 256, 64, 64])
X.shape input=torch.Size([1, 256, 32, 32])
X.shape output=torch.Size([1, 512, 32, 32])
X.shape input=torch.Size([1, 1024, 16, 16])


RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

In [8]:
print(pred.shape)
print(pred)

torch.Size([1, 3, 256, 256])
tensor([[[[-0.6650, -0.3076, -0.6735,  ..., -1.1186, -0.8151, -1.6791],
          [-1.0963, -0.7336, -0.3296,  ..., -0.6975,  0.2084, -0.7240],
          [-1.0764, -0.3929, -0.6100,  ..., -1.0809, -0.5923, -0.4833],
          ...,
          [-0.3296, -0.4373, -0.3289,  ..., -0.9929, -0.8269, -0.6770],
          [-0.6967, -0.9322, -0.7660,  ..., -0.8244, -1.0234, -1.1647],
          [-0.2097, -0.1567,  0.1230,  ..., -0.1336,  0.0516, -0.0035]],

         [[-0.7101, -0.7239, -0.4165,  ..., -0.2840, -0.6845, -0.2599],
          [-0.2840, -0.2204, -0.2974,  ...,  0.6209, -0.4773,  0.4224],
          [-0.3414, -0.8442, -0.5291,  ..., -0.6757, -0.2749, -0.2343],
          ...,
          [-0.6379, -0.4814, -0.8504,  ..., -0.4851, -0.4969, -0.0518],
          [-0.4699,  0.2723, -0.2632,  ..., -0.1172, -0.5123, -0.7423],
          [-0.5320, -0.5711, -0.4483,  ..., -0.2985, -0.7793, -0.4902]],

         [[-0.7541, -0.9422, -0.5056,  ..., -0.7995, -0.2589,  0.2776],
 