In [1]:
import numpy as np
import torch
import torchvision

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

## 1. Data preparation
This section prepares the dataset, applies the transforms, and wraps the result in the appropriate dataloaders.

In [2]:
from torch.utils.data import DataLoader

import albumentations as A
from albumentations.pytorch import ToTensorV2

In [3]:
# Preprocessing settings shared by the training and validation pipelines.
# Keeping them in one place guarantees both splits are resized and normalised
# identically (the two copies previously had to be kept in sync by hand).
IMAGE_HEIGHT, IMAGE_WIDTH = 384, 512   # target size passed to A.Resize(height, width)
NORMALIZE_MEAN = (0.485, 0.456, 0.406)  # per-channel mean (ImageNet statistics)
NORMALIZE_STD = (0.229, 0.224, 0.225)   # per-channel std (ImageNet statistics)

# Training pipeline: resize, random horizontal flip (the only augmentation),
# normalise, then convert to a tensor.
train_transform = A.Compose([
    A.Resize(IMAGE_HEIGHT, IMAGE_WIDTH),
    A.HorizontalFlip(),
    A.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD, max_pixel_value=255.0),
    ToTensorV2(),
])

# Validation pipeline: same resize/normalise steps, but no random augmentation
# so that evaluation is deterministic.
val_transform = A.Compose([
    A.Resize(IMAGE_HEIGHT, IMAGE_WIDTH),
    A.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD, max_pixel_value=255.0),
    ToTensorV2(),
])

In [4]:
from Datasets import Cityscapes

In [5]:
# Root folder handed to the project-local `Cityscapes` dataset class.
root_dir = 'data'

# Arguments that are identical for both splits; only `split` and the
# transform pipeline differ between training and validation.
_cityscapes_kwargs = dict(root=root_dir, mode='fine', target_type='labelTrainIds')

train_dataset = Cityscapes(split='train', transforms=train_transform, **_cityscapes_kwargs)
val_dataset = Cityscapes(split='val', transforms=val_transform, **_cityscapes_kwargs)

In [6]:
# Training loader: reshuffled every epoch so batches differ between epochs.
train_dataloader = DataLoader(train_dataset, batch_size=3,
                              shuffle=True, num_workers=2, pin_memory=True)
# Validation loader: shuffling is disabled. The metrics computed over the full
# split do not depend on sample order, so shuffling only added randomness
# (and made batch-level debugging non-reproducible) without any benefit.
val_dataloader = DataLoader(val_dataset, batch_size=3,
                            shuffle=False, num_workers=2, pin_memory=True)

## 2. Training

In [7]:
import torch.nn as nn
import torch.optim as optim
from torchmetrics import JaccardIndex
from tqdm import tqdm

In [8]:
# Model: torchvision DeepLabV3 with a ResNet-50 backbone, initialised from the
# library's default pretrained weights and placed on the selected device.
weights = torchvision.models.segmentation.DeepLabV3_ResNet50_Weights.DEFAULT
model = torchvision.models.segmentation.deeplabv3_resnet50(weights=weights).to(device)

# Training preparation: per-pixel cross-entropy loss and an Adam optimizer.
# NOTE(review): the optimizer only tracks the parameters that exist at this
# point; any layer swapped into the model afterwards has to be registered
# with an optimizer separately.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

# Evaluation metric (not a loss): mean intersection-over-union over 20 classes.
miou = JaccardIndex(task="multiclass", num_classes=20)
miou = miou.to(device)

In [9]:
## Preparing the model
# DeepLabV3 ResNet: replace the final 1x1 convolution of each head so that the
# network predicts 20 classes instead of the pretrained label set.
num_classes = 20
# New layers are created on the CPU by default, so move them to the same
# device as the rest of the model right away.
model.classifier[4] = nn.Conv2d(256, num_classes, kernel_size=1).to(device)
# Compare against None explicitly instead of relying on the truthiness of the
# head container.
if model.aux_classifier is not None:
    model.aux_classifier[4] = nn.Conv2d(256, num_classes, kernel_size=1).to(device)

# The optimizer was constructed before the heads were replaced, so it still
# references the discarded layers and knows nothing about the new ones; without
# this step the new output layers would never be updated during training.
# Rebuild it over the current parameters, reusing the configured learning rate.
optimizer = optim.Adam(model.parameters(), lr=optimizer.defaults['lr'])

In [None]:
# Training
# Each epoch runs one optimisation pass over `train_dataloader`, followed by a
# gradient-free pass over `val_dataloader` that reports pixel accuracy and mIoU.
num_epochs = 20

model.to(device)
for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()
    running_loss = 0.0

    train_progress = tqdm(train_dataloader, desc=f'Epoch {epoch+1}/{num_epochs} - Training', unit='batch') # Monitor training progress
    for batch_idx, (inputs, labels) in enumerate(train_progress, start=1):
        inputs = inputs.to(device)
        # The loss needs integer class indices; squeeze(1) drops a singleton
        # channel axis if the dataset yields one (assumed (N, 1, H, W) or
        # (N, H, W) masks — TODO confirm against the dataset class).
        labels = labels.to(device).long().squeeze(1)

        optimizer.zero_grad()
        # Only the main head ('out') contributes to the loss.
        outputs = model(inputs)['out']
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        # Show the mean loss over the batches processed so far. Dividing by
        # the total number of batches (as before) under-reported the loss for
        # most of the epoch.
        train_progress.set_postfix(loss=running_loss / batch_idx) # Monitor training progress

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_dataloader)}")

    # ---- validation pass ----
    model.eval()
    correct = 0
    total = 0
    miou.reset()  # the metric accumulates state, so clear it every epoch

    val_progress = tqdm(val_dataloader, desc=f'Epoch {epoch+1}/{num_epochs} - Validation', unit='batch') # Monitor validation progress
    with torch.no_grad():
        for inputs, labels in val_progress:
            inputs = inputs.to(device)
            labels = labels.to(device).long().squeeze(1)

            outputs = model(inputs)['out']
            # Per-pixel predicted class = index of the largest logit.
            predicted = torch.argmax(outputs, dim=1)

            # Pixel accuracy is accumulated over every label element.
            total += labels.numel()
            correct += (predicted == labels).sum().item()

            miou.update(predicted, labels)

            val_progress.set_postfix(accuracy=100 * correct / total) # Monitor validation progress

    val_accuracy = 100 * correct / total
    miou_accuracy = miou.compute().item()
    print(f"Validation Accuracy: {val_accuracy}%, MIoU: {miou_accuracy}")

In [11]:
# Save model
import os

# NOTE(review): the file name mentions "resnet101" and "512512", while this
# notebook builds a ResNet-50 model and resizes to 384x512. The path is left
# unchanged in case other code loads it — confirm and rename if appropriate.
checkpoint_path = 'checkpoints/deeplabv3resnet101_finetuned_512512_0001_1.pth'

# torch.save does not create missing directories; make sure the target folder
# exists so a long training run is not lost to a FileNotFoundError here.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
torch.save(model.state_dict(), checkpoint_path)

## 3. Inference and evaluation

In [None]:
import matplotlib.pyplot as plt

In [16]:
# RGB colour per class index, used to turn a predicted label map into an image
# via fancy indexing (CITYSCAPES_COLOR_MAP[label_map]).
# Entries 0-18 follow the official Cityscapes trainId palette; the remaining
# entries are extra colours for any additional indices.
# Fix: sky was [0, 130, 180]; the official palette value is (70, 130, 180).
CITYSCAPES_COLOR_MAP = np.array([
    [128, 64, 128],   # 0  road
    [244, 35, 232],   # 1  sidewalk
    [70, 70, 70],     # 2  building
    [102, 102, 156],  # 3  wall
    [190, 153, 153],  # 4  fence
    [153, 153, 153],  # 5  pole
    [250, 170, 30],   # 6  traffic light
    [220, 220, 0],    # 7  traffic sign
    [107, 142, 35],   # 8  vegetation
    [152, 251, 152],  # 9  terrain
    [70, 130, 180],   # 10 sky
    [220, 20, 60],    # 11 person
    [255, 0, 0],      # 12 rider
    [0, 0, 142],      # 13 car
    [0, 0, 70],       # 14 truck
    [0, 60, 100],     # 15 bus
    [0, 80, 100],     # 16 train
    [0, 0, 230],      # 17 motorcycle
    [119, 11, 32],    # 18 bicycle
    [255, 255, 255],  # 19 (presumably the extra 20th class — confirm with the dataset)
    [0, 0, 0],        # 20
    [255, 255, 0],    # 21
    [0, 255, 0],      # 22
    [0, 255, 255],    # 23
    [255, 0, 255],    # 24
    [192, 192, 192],  # 25
    [128, 0, 0],      # 26
    [128, 128, 0],    # 27
    [0, 128, 0],      # 28
    [128, 0, 128],    # 29
    [0, 128, 128],    # 30
    [0, 0, 128],      # 31
    [128, 128, 128],  # 32
    [192, 0, 0],      # 33
    [192, 192, 0],    # 34
    [0, 192, 0],      # 35
], dtype=np.uint8)

In [17]:
def visualize_segmentation(model, dataset, idx,
                           mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Show one dataset image next to the model's colour-coded prediction.

    Args:
        model: segmentation network whose forward pass returns a mapping with
            an 'out' entry of per-class scores.
        dataset: indexable dataset; ``dataset[idx]`` must return an
            ``(image, target)`` pair where ``image`` is a channel-first
            (C, H, W) tensor.
        idx: index of the sample to display.
        mean: per-channel mean that was subtracted during normalisation, used
            to undo it for display. Pass ``None`` if the dataset images are
            already in the [0, 1] range.
        std: per-channel standard deviation used during normalisation. Pass
            ``None`` together with ``mean`` to skip de-normalisation.

    The function reads the notebook-level ``device`` and
    ``CITYSCAPES_COLOR_MAP`` and renders the figure with matplotlib; it
    returns nothing.
    """
    model.eval()
    image, _ = dataset[idx]  # the ground-truth target is not displayed

    with torch.no_grad():
        # Add a batch axis for the forward pass, then drop it again after
        # taking the per-pixel argmax over the class axis.
        output = model(image.unsqueeze(0).to(device))['out']
        prediction = output.argmax(1).squeeze(0).cpu().numpy()

    # Map every predicted class index to its RGB colour.
    segmentation_map = CITYSCAPES_COLOR_MAP[prediction]

    # Convert the tensor image to a channel-last numpy array.
    image = image.permute(1, 2, 0).cpu().numpy().astype(np.float32)
    if mean is not None and std is not None:
        # Undo the mean/std normalisation so the values are back in [0, 1];
        # scaling the normalised values directly produced wrapped-around,
        # garbage colours after the uint8 cast.
        image = image * np.asarray(std, dtype=np.float32) + np.asarray(mean, dtype=np.float32)
    # Clip before casting so rounding noise cannot wrap around in uint8.
    image = (np.clip(image, 0.0, 1.0) * 255).astype(np.uint8)

    fig, (ax_image, ax_prediction) = plt.subplots(1, 2, figsize=(12, 6))

    ax_image.imshow(image)
    ax_image.set_title('Original Image')
    ax_image.axis('off')

    ax_prediction.imshow(segmentation_map)
    ax_prediction.set_title('Segmentation')
    ax_prediction.axis('off')

    plt.show()

In [18]:
# Render the prediction for a single validation sample.
visualize_segmentation(model=model, dataset=val_dataset, idx=400)