In [6]:
import os
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision.transforms import Resize
import torch.nn.functional as F
from matplotlib.colors import ListedColormap
import torch.optim as optim
import wandb
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint
import lightning as pl
import torchmetrics

wandb_logger = WandbLogger(log_model="all", project="VOCSegmentation", name='exp1')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

ALL_CLASSES = [
    'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat',
    'chair', 'cow', 'dining table', 'dog', 'horse', 'motorbike', 'person', 'potted plant',
    'sheep', 'sofa', 'train', 'tv/monitor'
]

LABEL_COLORS_LIST = [
    [0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128], [128, 0, 128],
    [0, 128, 128], [128, 128, 128], [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
    [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0], [128, 64, 0],
    [0, 192, 0], [128, 192, 0], [0, 64, 128]
]

jaccard = torchmetrics.JaccardIndex(task="multiclass", num_classes=len(ALL_CLASSES)).to(device)

normalized_colors = [[r / 255, g / 255, b / 255] for r, g, b in LABEL_COLORS_LIST]

cmap = ListedColormap(normalized_colors)

class VOCDataSet(Dataset):
    def __init__(self, root_dir, dataset_type='train', transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_folder = os.path.join(root_dir, dataset_type + '_images')
        self.label_folder = os.path.join(root_dir, dataset_type + '_labels')
        self.image_list = os.listdir(self.image_folder)
        self.label_list = os.listdir(self.label_folder)

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        img_name = os.path.join(self.image_folder, self.image_list[idx])
        label_name = os.path.join(self.label_folder, self.label_list[idx])
        image = Image.open(img_name).convert('RGB')
        label = Image.open(label_name).convert('RGB')
        if self.transform:
            image = self.transform(image)
            label = self.transform(label)
        label = np.array(label)
        label_integer = self.rgb_to_integer(label)
        return image, label_integer

    def rgb_to_integer(self, label_rgb):
        label_integer = np.zeros(label_rgb.shape[:2], dtype=np.uint8)
        for i, color in enumerate(LABEL_COLORS_LIST):
            mask = np.all(label_rgb == color, axis=-1)
            label_integer[mask] = i
        return label_integer

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_dataset = VOCDataSet(root_dir='voc_2012_segmentation_data', dataset_type='train', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

val_dataset = VOCDataSet(root_dir='voc_2012_segmentation_data', dataset_type='val', transform=transform)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

class VOCUNet(pl.LightningModule):
    def __init__(self, in_channels=3, n_classes=21, features=64, learning_rate=1e-3):
        super(VOCUNet, self).__init__()
        self.learning_rate = learning_rate
        self.criterion = nn.CrossEntropyLoss()

        self.encoder1 = nn.Conv2d(in_channels, features, kernel_size=3, padding=1)
        self.encoder2 = nn.Conv2d(features, features * 2, kernel_size=3, padding=1)

        self.bottleneck = nn.Conv2d(features * 2, features * 4, kernel_size=3, padding=1)

        # Upsampling layer should match the feature maps it will concatenate with
        self.upsample1 = nn.ConvTranspose2d(in_channels=features * 4, out_channels=features * 2, kernel_size=4, stride=2, padding=1)
        self.upsample2 = nn.ConvTranspose2d(in_channels=features * 3, out_channels=features, kernel_size=4, stride=2, padding=1)

        # Adjust decoder layers to match the concatenated channel counts
        self.decoder2 = nn.Conv2d(features * 4, features * 2, kernel_size=3, padding=1)  # features*2 (from upsample) + features*2 (from encoder2)
        self.decoder1 = nn.Conv2d(features * 2, features, kernel_size=3, padding=1)  # features (from upsample) + features (from encoder1)

        self.outconv = nn.Conv2d(features, n_classes, kernel_size=1)

        self.bn1 = nn.BatchNorm2d(features)
        self.bn2 = nn.BatchNorm2d(features * 2)
        self.bn3 = nn.BatchNorm2d(features * 4)

    def forward(self, x):
        x1 = self.encoder1(x)
        x1 = self.bn1(x1)
        x1 = nn.ReLU()(x1)

        x = nn.MaxPool2d(2, stride=2)(x1)
        x2 = self.encoder2(x)
        x2 = self.bn2(x2)
        x2 = nn.ReLU()(x2)

        x = nn.MaxPool2d(2, stride=2)(x2)
        x = self.bottleneck(x)
        x = self.bn3(x)
        x = nn.ReLU()(x)

        x = self.upsample1(x)
        x = torch.cat([x, x2], dim=1)
        x = self.decoder2(x)
        x = self.bn2(x)
        x = nn.ReLU()(x)

        x = self.upsample2(x)
        x = torch.cat([x, x1], dim=1)
        x = self.decoder1(x)
        x = self.bn1(x)
        x = nn.ReLU()(x)

        y = self.outconv(x)
        return y


    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)

        softmaxed_logits = torch.softmax(logits, dim=1)
        predicted_mask = torch.argmax(softmaxed_logits, dim=1)

        iou = jaccard(predicted_mask, y)
        self.log('train/iou', iou, on_epoch=True, on_step=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)

        softmaxed_logits = torch.softmax(logits, dim=1)
        predicted_mask = torch.argmax(softmaxed_logits, dim=1)

        iou = jaccard(predicted_mask, y)
        self.log('val/iou', iou, on_epoch=True, on_step=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.learning_rate)

    def configure_callbacks(self):
        checkpoint_callback = ModelCheckpoint(
            monitor='val/iou',
            dirpath='./checkpoints',
            filename='best_model',
            save_top_k=1,
            mode='max',
            verbose=True
        )
        return [checkpoint_callback]

model = VOCUNet()
callbacks = model.configure_callbacks()

trainer = pl.Trainer(logger=wandb_logger, max_epochs=50, devices=1, accelerator="auto", callbacks=callbacks)

trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)


GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
The following callbacks returned in `LightningModule.configure_callbacks` will override existing callbacks passed to Trainer: ModelCheckpoint

   | Name       | Type             | Params
-------------------------------------------------
0  | criterion  | CrossEntropyLoss | 0     
1  | encoder1   | Conv2d           | 1.8 K 
2  | encoder2   | Conv2d           | 73.9 K
3  | bottleneck | Conv2d           | 295 K 
4  | upsample1  | ConvTranspose2d  | 524 K 
5  | upsample2  | ConvTranspose2d  | 196 K 
6  | decoder2   | Conv2d           | 295 K 
7  | decoder1   | Conv2d           | 73.8 K
8  | outconv    | Conv2d           | 1.4 K 
9  | bn1        | BatchNorm2d      | 128   
10 | bn2        | BatchNorm2d      | 256   
11 | bn3        | BatchNorm2d      | 512   
-------------------------------------------------
1.5 M     Trainable params
0      

Sanity Checking: |                                        | 0/? [00:00<?, ?it/s]

  mask = np.all(label_rgb == color, axis=-1)


RuntimeError: Given transposed=1, weight of size [192, 64, 4, 4], expected input[32, 128, 128, 128] to have 192 channels, but got 128 channels instead