# Imports

In [None]:
import torch
import torch.nn as nn
from torchvision import datasets
from torchvision.transforms import transforms
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

import numpy as np

import matplotlib.pyplot as plt

from PIL import Image


# Setup

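Download the CUB-200-2011 dataset from [here](https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz?download=1) and extract it into the `data` folder. The paths in `config` below expect `data/CUB_200_2011/images/`, `data/CUB_200_2011/images.txt`, and the per-image segmentation masks in `data/CUB_200_2011/segmentations/`. Make sure the masks end up in that folder as well.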

In [2]:
# Central run configuration. Every tunable setting for the notebook lives here.
config = dict(
    # Optimisation hyper-parameters
    lr=1e-3,
    batch_size=16,
    # Dataset layout (see the download instructions above)
    image_dir="data/CUB_200_2011/images",
    segmentation_dir="data/CUB_200_2011/segmentations",
    image_paths="data/CUB_200_2011/images.txt",
    epochs=10,
    # Checkpoint files; presumably model weights and optimiser state respectively
    checkpoint="checkpoint/bird_segmentation_v1.pth",
    optimiser="checkpoint/bird_segmentation_v1_optim.pth",
    continue_train=False,
    # Prefer the GPU when one is visible to torch
    device=("cuda" if torch.cuda.is_available() else "cpu"),
)

In [3]:
import os

def makeSubDir(subdir):
    """Create ``subdir`` (including any missing parents) and report the outcome.

    Args:
        subdir: path of the directory to create.

    Returns:
        True if this call created the directory, False if it already existed.

    Raises:
        FileExistsError: if ``subdir`` exists but is not a directory.
    """
    try:
        # Try the creation directly rather than checking exists() first. This avoids
        # the race where the directory appears between the check and makedirs().
        os.makedirs(subdir)
    except FileExistsError:
        if not os.path.isdir(subdir):
            # A regular file is in the way; reporting "already exists" would be
            # misleading, and later writes into it would fail anyway.
            raise
        print("Directory", subdir, "already exists!")
        return False
    print("Directory", subdir, "created successfully!")
    return True

# A for statement produces no cell output, so the returned booleans are not echoed.
for subdir in ("checkpoint", "test/pred", "test/true"):
    makeSubDir(subdir)

Directory checkpoint already exists!
Directory test/pred already exists!
Directory test/true already exists!


# Load data

In [None]:
class BirdDataset(Dataset):
    """Paired (image, segmentation mask) samples for the CUB-200-2011 birds.

    Args:
        image_paths: text file whose lines end with a relative image path. The
            last whitespace-separated field is used, e.g. ``"1 dir/name.jpg"``.
        image_dir: root directory against which the relative ``.jpg`` paths are resolved.
        segmentation_dir: root directory holding a ``.png`` mask with the same
            relative stem as each image.
        transform_image: callable applied to the RGB PIL image.
        transform_mask: callable applied to the grayscale ("L") PIL mask.
    """
    def __init__(self, image_paths, image_dir, segmentation_dir, transform_image, transform_mask):
        super(BirdDataset, self).__init__()
        self.image_dir = image_dir
        self.segmentation_dir = segmentation_dir
        self.transform_image = transform_image
        self.transform_mask = transform_mask
        with open(image_paths, 'r') as f:
            # split() with no argument also drops the trailing newline, so stored paths
            # are clean. Blank lines (e.g. a trailing empty line) are skipped rather than
            # becoming "" entries that would break __getitem__.
            self.images_paths = [fields[-1] for fields in (line.split() for line in f) if fields]

    def __len__(self):
        return len(self.images_paths)

    def __getitem__(self, index):
        """Return ``(transform_image(image), transform_mask(mask))`` for sample ``index``."""
        # splitext strips only the final extension, so dots in directory names
        # (e.g. "001.Black_footed_Albatross/") are preserved.
        image_name = os.path.splitext(self.images_paths[index])[0]

        # The context managers close the file handles. convert() returns a new image
        # that does not depend on the open file, so it stays usable after the with-block.
        with Image.open(os.path.join(self.image_dir, f"{image_name}.jpg")) as img:
            image = img.convert("RGB")
        with Image.open(os.path.join(self.segmentation_dir, f"{image_name}.png")) as mask:
            seg = mask.convert("L")

        return self.transform_image(image), self.transform_mask(seg)

In [None]:
def _resize_then_tensor(num_channels):
    """Build a pipeline: resize to 256x256, convert to a tensor, then Normalize.

    The Normalize uses mean 0 and std 1 for each of ``num_channels`` channels,
    so it leaves values unchanged.
    """
    return transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize((0.,) * num_channels, (1.,) * num_channels),
    ])

# 3 channels: the dataset converts photographs to "RGB".
transforms_image = _resize_then_tensor(3)
# 1 channel: the dataset converts masks to grayscale "L".
# NOTE(review): Resize uses its default interpolation here, which may blend mask
# values at edges. Consider nearest-neighbour if hard labels are required.
transforms_mask = _resize_then_tensor(1)

In [None]:
def load_data_set(image_paths, image_dir, segmentation_dir, transforms, batch_size=8, shuffle=True,
                  val_size=16, seed=None):
    """Build train/validation DataLoaders over a BirdDataset.

    Args:
        image_paths: index file listing the images (forwarded to BirdDataset).
        image_dir: directory containing the ``.jpg`` images.
        segmentation_dir: directory containing the ``.png`` masks.
        transforms: two-element sequence ``(image_transform, mask_transform)``.
        batch_size: batch size used by both loaders.
        shuffle: whether the *training* loader reshuffles each epoch. The
            validation loader is never shuffled, so its order is stable.
        val_size: number of samples held out for validation (default 16).
        seed: optional integer seed for the train/val split. ``None`` keeps using
            torch's global RNG, so the split is not reproducible.

    Returns:
        ``(train_loader, val_loader)``: DataLoaders, not datasets.

    Raises:
        ValueError: if ``val_size`` is negative or not smaller than the dataset size.
    """
    dataset = BirdDataset(image_paths,
                          image_dir,
                          segmentation_dir,
                          transform_image=transforms[0],
                          transform_mask=transforms[1])
    print("Complete Dataset length: ", len(dataset))

    # A dataset smaller than val_size would otherwise request a negative
    # training length; fail with a clear message instead.
    if not 0 <= val_size < len(dataset):
        raise ValueError(f"val_size must be in [0, {len(dataset)}), got {val_size}")

    split_kwargs = {} if seed is None else {"generator": torch.Generator().manual_seed(seed)}
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [len(dataset) - val_size, val_size], **split_kwargs)

    return (DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle),
            DataLoader(val_dataset, batch_size=batch_size, shuffle=False))

In [7]:
# NOTE: despite their names, both values are DataLoaders returned by
# load_data_set, so len() below counts batches, not samples.
train_dataset, val_dataset = load_data_set(
    image_paths=config['image_paths'],
    image_dir=config['image_dir'],
    segmentation_dir=config['segmentation_dir'],
    transforms=[transforms_image, transforms_mask],
    batch_size=config['batch_size'],
)

n_train_batches = len(train_dataset)
print("loaded", n_train_batches, "batches")

Complete Dataset length:  11788
loaded 736 batches


# UNet

In [8]:
class conv_block(nn.Module):
    """Two (3x3 conv -> BatchNorm -> ReLU) stages.

    padding=1 with a 3x3 kernel keeps height and width unchanged. The first conv
    maps ``in_channels`` to ``out_channels`` and the second keeps ``out_channels``.
    All layers sit in one ``nn.Sequential`` named ``conv``, so state_dict keys follow
    the ``conv.<index>`` layout.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        for stage_in in (in_channels, out_channels):
            layers += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)

In [9]:
class encoder_block(nn.Module):
    """UNet encoder stage: a conv_block followed by 2x2, stride-2 max-pooling.

    ``forward`` returns ``(pooled, features)``. ``features`` is the un-pooled
    conv_block output, kept so a decoder can concatenate it back in.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = conv_block(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        features = self.conv(x)
        pooled = self.pool(features)
        return pooled, features

In [10]:
class decoder_block(nn.Module):
    """UNet decoder stage: upsample, concatenate the encoder features, then convolve.

    ``self.up`` doubles the spatial size and maps ``in_channels`` to
    ``in_channels // 2``. The result is concatenated with ``enc_x`` along dim 1
    and fed to a conv_block that expects ``in_channels`` channels, so ``enc_x``
    presumably carries the remaining ``in_channels - in_channels // 2`` channels
    (TODO confirm at the call site).
    """
    def __init__(self, in_channels, out_channels):
        super(decoder_block, self).__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels//2, kernel_size=2, stride=2)
        self.conv = conv_block(in_channels, out_channels)

    def forward(self, x, enc_x):
        x = self.up(x)
        # Max-pooling floors odd heights/widths, so the 2x upsample can come back one
        # pixel short of enc_x. Pad (split as evenly as possible) so torch.cat sees
        # matching spatial sizes. This is a no-op when the sizes already agree.
        diff_h = enc_x.size(2) - x.size(2)
        diff_w = enc_x.size(3) - x.size(3)
        if diff_h or diff_w:
            x = nn.functional.pad(x, [diff_w // 2, diff_w - diff_w // 2,
                                      diff_h // 2, diff_h - diff_h // 2])
        x = torch.cat([x, enc_x], dim=1)
        return self.conv(x)