In [None]:
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torchvision.transforms as T
from torch.utils.data import Dataset
import torch.optim as optim
from torch.optim import Optimizer
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from PIL import Image
from glob import glob

## Define Dataset

In [None]:
# Maps an (R, G, B) triple to a class name. The table is written grouped by
# class name, because several triples share one name, and then flattened into
# the triple -> name lookup.
RGBLabel2LabelName = {
    rgb: label_name
    for label_name, rgb_triples in (
        ("Sky", [(128, 128, 128)]),
        (
            "Building",
            [(0, 128, 64), (128, 0, 0), (64, 192, 0), (64, 0, 64), (192, 0, 128)],
        ),
        ("Pole", [(192, 192, 128), (0, 0, 64)]),
        ("Road", [(128, 64, 128), (128, 0, 192), (192, 0, 64)]),
        ("Sidewalk", [(0, 0, 192), (64, 192, 128), (128, 128, 192)]),
        ("Tree", [(128, 128, 0), (192, 192, 0)]),
        ("SignSymbol", [(192, 128, 128), (128, 128, 64), (0, 64, 64)]),
        ("Fence", [(64, 64, 128)]),
        (
            "Car",
            [(64, 0, 128), (64, 128, 192), (192, 128, 192), (192, 64, 128), (128, 64, 64)],
        ),
        ("Pedestrian", [(64, 64, 0), (192, 128, 64), (64, 0, 192), (64, 128, 64)]),
        ("Bicyclist", [(0, 128, 192), (192, 0, 192)]),
        ("Void", [(0, 0, 0)]),
    )
    for rgb in rgb_triples
}

# Maps a class name to an integer index; the index is simply the position of
# the name in the tuple below (0 for "Sky" through 11 for "Void").
LabelName2LabelIndex = {
    label_name: label_index
    for label_index, label_name in enumerate(
        (
            "Sky",
            "Building",
            "Pole",
            "Road",
            "Sidewalk",
            "Tree",
            "SignSymbol",
            "Fence",
            "Car",
            "Pedestrian",
            "Bicyclist",
            "Void",
        )
    )
}


class CustomDataset(Dataset):
    """Segmentation dataset that pairs each image with its annotation mask.

    Images and masks are the ``*.png`` files found directly inside
    ``image_root`` and ``mask_root``; they are paired by position after sorting
    both file lists.

    Every sample is randomly cropped and randomly mirrored. The crop window and
    the flip decision are drawn once per sample and applied to the image AND
    the mask, so each mask pixel keeps describing the image pixel at the same
    location.

    Args:
        image_root: Directory containing the input images.
        mask_root: Directory containing the annotation masks.
        crop_size: Side length (in pixels) of the square random crop.
        flip_prob: Probability of mirroring a sample horizontally.

    Raises:
        ValueError: If the two directories contain a different number of PNG
            files, since positional pairing would then be wrong.
    """

    def __init__(
        self,
        image_root: str,
        mask_root: str,
        crop_size: int = 224,
        flip_prob: float = 0.3,
    ):
        super().__init__()

        self.image_paths = sorted(glob(f"{image_root}/*.png"))
        self.gt_paths = sorted(glob(f"{mask_root}/*.png"))
        if len(self.image_paths) != len(self.gt_paths):
            raise ValueError(
                f"found {len(self.image_paths)} images in {image_root!r} but "
                f"{len(self.gt_paths)} masks in {mask_root!r}"
            )

        self.crop_size = crop_size
        self.flip_prob = flip_prob

        # Built once here instead of on every __getitem__ call.
        self.to_tensor = T.ToTensor()
        self.normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        """Return ``(image, mask)`` for the sample at ``index``.

        Returns:
            image: float tensor, channels first, cropped to
                ``crop_size x crop_size`` and normalized with the mean/std
                configured in ``__init__``.
            mask: int64 tensor of shape ``(crop_size, crop_size)`` holding the
                values stored in the mask file, unchanged.

        Raises:
            ValueError: If the mask is not single-channel, if image and mask
                sizes differ, or if the sample is smaller than the crop.
        """
        # load data
        with Image.open(self.image_paths[index]) as image_file:
            image = self.to_tensor(np.array(image_file.convert("RGB")))
        with Image.open(self.gt_paths[index]) as mask_file:
            # Read the stored values directly as integers; no rescaling is
            # wanted for class indices.
            gt_mask = torch.from_numpy(np.array(mask_file, dtype=np.int64))
        if gt_mask.ndim != 2:
            raise ValueError(
                f"expected a single-channel mask, got shape {tuple(gt_mask.shape)} "
                f"for {self.gt_paths[index]!r}"
            )

        # transform data
        image, gt_mask = self._paired_crop_and_flip(image, gt_mask)
        return self.normalize(image), gt_mask

    def _paired_crop_and_flip(self, image, mask):
        """Apply one random crop and one random horizontal flip to both tensors.

        The last two dimensions of ``image`` and ``mask`` are treated as
        (height, width). Randomness comes from torch's global generator.
        """
        height, width = mask.shape[-2:]
        if tuple(image.shape[-2:]) != (height, width):
            raise ValueError(
                f"image size {tuple(image.shape[-2:])} does not match "
                f"mask size {(height, width)}"
            )
        size = self.crop_size
        if height < size or width < size:
            raise ValueError(
                f"crop size {size} is larger than the sample size {(height, width)}"
            )

        # One crop window shared by image and mask.
        top = int(torch.randint(0, height - size + 1, (1,)))
        left = int(torch.randint(0, width - size + 1, (1,)))
        image = image[..., top : top + size, left : left + size]
        mask = mask[..., top : top + size, left : left + size]

        # One flip decision shared by image and mask.
        if torch.rand(1).item() < self.flip_prob:
            image = image.flip(-1)
            mask = mask.flip(-1)

        # contiguous() detaches the crops from the full-size source tensors.
        return image.contiguous(), mask.contiguous()


## Define Model

In [None]:
class Block(nn.Module):
    """One stride-2 resampling layer followed by a ReLU.

    ``down=True`` uses a 3x3 ``Conv2d`` (stride 2, padding 1), which halves an
    even spatial size. ``down=False`` uses a 4x4 ``ConvTranspose2d`` (stride 2,
    padding 1), which doubles the spatial size.
    """

    def __init__(self, in_channels, out_channels, down=True):
        super().__init__()
        if down:
            resample = nn.Conv2d(
                in_channels, out_channels, kernel_size=3, stride=2, padding=1
            )
        else:
            resample = nn.ConvTranspose2d(
                in_channels, out_channels, kernel_size=4, stride=2, padding=1
            )
        # Kept under the attribute name `conv` so checkpoint keys stay the same.
        self.conv = nn.Sequential(resample, nn.ReLU())

    def forward(self, x):
        activated = self.conv(x)
        return activated


class UNet(nn.Module):
    """Encoder/decoder network with skip connections built from ``Block`` layers.

    The encoder is a stride-1 stem followed by three downsampling blocks. A
    stride-1 bottleneck sits in the middle. The decoder is three upsampling
    blocks followed by a 1x1 transposed convolution that emits ``out_channels``
    maps. Before every decoder stage the running tensor is concatenated
    (channel axis) with the encoder output of matching resolution.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the output tensor.
        features: Channel count of the stem; deeper stages use multiples of it.
    """

    def __init__(self, in_channels: int, out_channels: int, features: int = 64):
        super().__init__()

        width_1, width_2, width_4, width_8, width_16 = (
            features * factor for factor in (1, 2, 4, 8, 16)
        )

        # --- encoder ---
        self.initial_down = nn.Sequential(
            nn.Conv2d(in_channels, width_1, 3, padding=1),
            nn.ReLU(),
        )
        self.down1 = Block(width_1, width_2, down=True)
        self.down2 = Block(width_2, width_4, down=True)
        self.down3 = Block(width_4, width_8, down=True)

        # --- bottleneck (stride 1, spatial size unchanged) ---
        self.bottleneck = nn.Sequential(
            nn.Conv2d(width_8, width_16, 3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(width_16, width_8, 3, padding=1),
            nn.ReLU(),
        )

        # --- decoder: input widths are doubled by the concatenated skip ---
        self.up1 = Block(width_16, width_4, down=False)
        self.up2 = Block(width_8, width_2, down=False)
        self.up3 = Block(width_4, width_1, down=False)

        self.final_up = nn.Sequential(nn.ConvTranspose2d(width_2, out_channels, 1))

    def forward(self, x):
        # Encoder pass; every intermediate result is kept as a skip connection.
        skips = [self.initial_down(x)]
        for encoder_stage in (self.down1, self.down2, self.down3):
            skips.append(encoder_stage(skips[-1]))

        # Decoder pass; skips are consumed deepest-first.
        hidden = self.bottleneck(skips[-1])
        for decoder_stage in (self.up1, self.up2, self.up3):
            hidden = decoder_stage(torch.cat([hidden, skips.pop()], 1))

        return self.final_up(torch.cat([hidden, skips.pop()], 1))

## Define Training Functions

In [None]:
def train_val_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    loss_fn: CrossEntropyLoss,
    epoch: int,
    optimizer: Optimizer = None,
    train: bool = True,
):
    """Run one pass over ``dataloader``, either training or evaluating ``model``.

    Args:
        model: Network to run. Batches are moved to the device and dtype of its
            first parameter, so the model decides where computation happens.
        dataloader: Yields ``(images, targets)`` batches.
        loss_fn: Loss applied to ``(model(images), targets)``.
        epoch: Zero-based epoch number, only used for the progress-bar label.
        optimizer: Optimizer stepped after every batch; required when
            ``train`` is true, ignored otherwise.
        train: When true, gradients are computed and weights are updated. When
            false, the pass runs without gradient tracking and leaves the
            weights untouched.

    Returns:
        The mean loss over all batches (``nan`` for an empty dataloader).

    Raises:
        ValueError: If ``train`` is true but no optimizer was given.
    """
    if train and optimizer is None:
        raise ValueError("an optimizer is required when train=True")

    # Take device and dtype from the model so inputs always match its weights.
    reference_param = next(model.parameters(), None)
    if reference_param is None:
        device, dtype = torch.device("cpu"), torch.get_default_dtype()
    else:
        device, dtype = reference_param.device, reference_param.dtype

    model.train(train)

    log_every = 5  # the progress-bar loss is averaged over this many batches
    running_loss = 0.0
    total_loss = 0.0
    num_batches = 0

    # tqdm appends ": " to the description itself.
    pbar = tqdm(dataloader, desc=f"[Epoch {str(epoch + 1).zfill(3)}]")
    # No autograd graph is needed (or built) for a validation pass.
    with torch.set_grad_enabled(train):
        for i, (images, gts) in enumerate(pbar):
            images = images.to(device=device, dtype=dtype)
            # Tensor.to returns a new tensor; the result must be kept.
            gts = gts.to(device)

            # zero gradients
            if train:
                optimizer.zero_grad()

            # pass data through model
            preds = model(images)
            loss = loss_fn(preds, gts)

            # adjust weights
            if train:
                loss.backward()
                optimizer.step()

            # update progress bar
            batch_loss = loss.item()
            running_loss += batch_loss
            total_loss += batch_loss
            num_batches += 1
            if i % log_every == log_every - 1:
                last_loss = running_loss / log_every
                pbar.set_postfix_str(f"loss: {last_loss}")
                running_loss = 0.0

    return total_loss / num_batches if num_batches else float("nan")


## Train Model

In [None]:
# define hyperparameters
batch_size = 8
max_epochs = 20
lr = 0.0001
betas = (0.9, 0.999)
num_classes = 12
seed = 42

# Seed torch so weight initialisation and data shuffling are repeatable.
torch.manual_seed(seed)

# Use the GPU when one is available. The model is moved BEFORE the optimizer is
# created so the optimizer is built over the parameters that are trained.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

loss = CrossEntropyLoss()
model = UNet(3, num_classes).double().to(device)
optimizer = optim.Adam(model.parameters(), lr=lr, betas=betas)

# NOTE(review): the training split is read from ~/Downloads/CamVid while the
# validation split is read from ~/Documents/data/CamVid. Confirm that both
# locations are intended; the paths are left exactly as they were.
train_dataset = CustomDataset(
    image_root="/home/jonas/Downloads/CamVid/train",
    mask_root="/home/jonas/Downloads/CamVid/trainannot",
)
val_dataset = CustomDataset(
    image_root="/home/jonas/Documents/data/CamVid/val",
    mask_root="/home/jonas/Documents/data/CamVid/valannot",
)

train_dataloader = DataLoader(
    train_dataset, batch_size, shuffle=True, num_workers=4
)
val_dataloader = DataLoader(val_dataset, batch_size, shuffle=False, num_workers=4)

for epoch in range(max_epochs):
    train_val_epoch(model, train_dataloader, loss, epoch, optimizer, train=True)
    train_val_epoch(model, val_dataloader, loss, epoch, train=False)

# Store CPU tensors so the checkpoint can be loaded on a machine without a GPU.
cpu_state_dict = {name: tensor.cpu() for name, tensor in model.state_dict().items()}
torch.save(cpu_state_dict, "model.pth")