In [None]:
from torchvision.datasets import Cityscapes
from torchvision import transforms
from torch.utils.data import DataLoader, Subset
import torch

# (height, width) that every image and label mask is resized to.
image_size = (64, 128)

# RGB inputs: resize, then convert the PIL image to a tensor.
input_transform = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.ToTensor(),
    ]
)

# Label masks: nearest-neighbour resize so that no interpolated (non-existent)
# label values are introduced. The mask stays a PIL image at this stage.
# NOTE(review): tensor conversion presumably happens in CityscapesWrapper — confirm.
target_transform = transforms.Compose(
    [
        transforms.Resize(image_size, interpolation=transforms.InterpolationMode.NEAREST),
    ]
)

# Directory the Cityscapes files are read from.
root_dir = "datasets/cityscapes"

# Options shared by both splits: fine annotations with semantic masks.
_cityscapes_options = dict(
    root=root_dir,
    mode='fine',
    target_type='semantic',
    transform=input_transform,
    target_transform=target_transform,
)

# Full train / val splits; they are cut down to small subsets further below.
full_train_dataset = Cityscapes(split='train', **_cityscapes_options)
full_val_dataset = Cityscapes(split='val', **_cityscapes_options)

In [None]:
# from torch.utils.data import Dataset


# class CityscapesWrapper(Dataset):
#     def __init__(self, base_dataset, target_transform=None):
#         self.base = base_dataset
#         self.target_transform = target_transform
#         self.to_tensor = transforms.PILToTensor()

#     def __getitem__(self, idx):
#         img, target = self.base[idx]
#         if self.target_transform:
#             target = self.target_transform(target)
#         target = self.to_tensor(target)
#         return img, target

#     def __len__(self):
#         return len(self.base)

In [None]:
# Wrap subsets to apply target_transform to masks
from utils.CityscapesWrapper import CityscapesWrapper

# How many leading samples of each split are kept for this small experiment.
train_samples, val_samples = 400, 100


def _leading_subset(dataset, count):
    """Wrap the first ``count`` samples of ``dataset`` in a CityscapesWrapper."""
    return CityscapesWrapper(Subset(dataset, range(count)), target_transform=target_transform)


train_subset = _leading_subset(full_train_dataset, train_samples)
val_subset = _leading_subset(full_val_dataset, val_samples)

# Batches of 4, loaded in the main process; only the training data is shuffled.
train_loader = DataLoader(train_subset, shuffle=True, batch_size=4, num_workers=0)
val_loader = DataLoader(val_subset, shuffle=False, batch_size=4, num_workers=0)

In [None]:
import torch.nn as nn


class UNet(nn.Module):
    """Three-level U-Net producing per-pixel class logits.

    The encoder runs three double-convolution stages (64, 128, 256 channels)
    separated by 2x2 max-pooling, followed by a 512-channel bottleneck. The
    decoder mirrors it: each stage upsamples with a stride-2 transposed
    convolution, concatenates the matching encoder feature map (skip
    connection) along the channel axis and applies a double convolution. A
    final 1x1 convolution maps the 64 decoder channels to ``out_channels``.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of output channels (one logit per class).
    """

    def __init__(self, in_channels=3, out_channels=34):  # Cityscapes has 34 classes
        super().__init__()

        def conv_block(in_c, out_c):
            """Two 3x3 convolutions (padding 1, size-preserving), each followed by ReLU."""
            return nn.Sequential(
                nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
            )

        # Encoder path; a single parameter-free pooling layer is shared by all stages.
        self.encoder1 = conv_block(in_channels, 64)
        self.encoder2 = conv_block(64, 128)
        self.encoder3 = conv_block(128, 256)
        self.pool = nn.MaxPool2d(2, 2)
        self.middle = conv_block(256, 512)

        # Decoder path. Each decoder block takes twice the channels produced by
        # its upsampling layer because the skip connection is concatenated in.
        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.decoder3 = conv_block(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder2 = conv_block(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder1 = conv_block(128, 64)

        self.final = nn.Conv2d(64, out_channels, kernel_size=1)

    @staticmethod
    def _fit_to(upsampled, skip):
        """Zero-pad ``upsampled`` so its height/width equal those of ``skip``.

        Max-pooling floors odd sizes, so after the stride-2 upsampling the map
        can be one pixel shorter than the encoder map it is concatenated with.
        When the sizes already agree the tensor is returned unchanged.
        """
        pad_h = skip.shape[-2] - upsampled.shape[-2]
        pad_w = skip.shape[-1] - upsampled.shape[-1]
        if pad_h == 0 and pad_w == 0:
            return upsampled
        # F.pad order for the last two dims: (left, right, top, bottom).
        return nn.functional.pad(
            upsampled,
            [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2],
        )

    def forward(self, x):
        """Compute per-pixel logits for a batch of images.

        Args:
            x: Tensor of shape (N, in_channels, H, W). H and W no longer need
                to be divisible by 8, but must be large enough to survive
                three 2x2 poolings (at least 8 pixels each).

        Returns:
            Tensor of shape (N, out_channels, H, W).
        """
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool(enc1))
        enc3 = self.encoder3(self.pool(enc2))
        mid = self.middle(self.pool(enc3))

        up3 = self._fit_to(self.up3(mid), enc3)
        dec3 = self.decoder3(torch.cat([up3, enc3], dim=1))
        up2 = self._fit_to(self.up2(dec3), enc2)
        dec2 = self.decoder2(torch.cat([up2, enc2], dim=1))
        up1 = self._fit_to(self.up1(dec2), enc1)
        dec1 = self.decoder1(torch.cat([up1, enc1], dim=1))
        return self.final(dec1)

In [None]:
# === 3. Training Setup ===
import torch.optim as optim

# Seed PyTorch's global RNG so weight initialisation and DataLoader shuffling
# are repeatable after "Restart Kernel -> Run All".
SEED = 42
torch.manual_seed(SEED)

# Prefer Apple's MPS backend (the original target), then CUDA, and fall back to
# the CPU so the notebook also runs on machines without an accelerator.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model = UNet(
    in_channels=3,
    out_channels=34
).to(device)

# Per-pixel multi-class classification loss over the model's logits.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [None]:
# === 4. Training Loop ===
from tqdm import tqdm

def _model_device(model):
    """Return the device holding the model's parameters (CPU if it has none)."""
    first_param = next(model.parameters(), None)
    return first_param.device if first_param is not None else torch.device("cpu")


def _prepare_batch(images, targets, target_device):
    """Move a batch to ``target_device`` and put the masks in class-index form."""
    images = images.to(target_device)
    # Class-index losses such as nn.CrossEntropyLoss expect int64 targets;
    # squeeze(1) drops the channel axis when it has size 1 and is a no-op otherwise.
    targets = targets.to(target_device).long().squeeze(1)
    return images, targets


def _mean_or_nan(total, count):
    """Average ``total`` over ``count`` batches; NaN when no batch was seen."""
    return total / count if count else float("nan")


def train_one_epoch(model, loader, optimizer, criterion):
    """Run one optimisation pass over ``loader``.

    Batches are moved to the device the model lives on, so the function does
    not depend on a notebook-level ``device`` variable.

    Args:
        model: Network to train; switched to training mode.
        loader: Iterable yielding ``(images, targets)`` batches.
        optimizer: Optimizer that updates ``model``'s parameters.
        criterion: Loss called as ``criterion(outputs, targets)``.

    Returns:
        Mean of the per-batch losses as a float, or NaN if ``loader``
        yielded no batches.
    """
    model.train()
    run_device = _model_device(model)
    total_loss = 0.0
    batch_count = 0
    for images, targets in tqdm(loader, desc="Training"):
        images, targets = _prepare_batch(images, targets, run_device)

        outputs = model(images)
        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        batch_count += 1
    # Batches are counted explicitly so loaders without __len__ also work.
    return _mean_or_nan(total_loss, batch_count)


def validate(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without computing gradients.

    Args:
        model: Network to evaluate; switched to evaluation mode.
        loader: Iterable yielding ``(images, targets)`` batches.
        criterion: Loss called as ``criterion(outputs, targets)``.

    Returns:
        Mean of the per-batch losses as a float, or NaN if ``loader``
        yielded no batches.
    """
    model.eval()
    run_device = _model_device(model)
    total_loss = 0.0
    batch_count = 0
    with torch.no_grad():
        for images, targets in tqdm(loader, desc="Validation"):
            images, targets = _prepare_batch(images, targets, run_device)

            outputs = model(images)
            loss = criterion(outputs, targets)

            total_loss += loss.item()
            batch_count += 1
    return _mean_or_nan(total_loss, batch_count)

In [None]:
# === 5. Run Training ===
# Number of full passes over the training subset.
num_epochs = 5

for epoch in range(num_epochs):
    # Epochs are reported 1-based.
    print(f"\nEpoch {epoch + 1}/{num_epochs}")
    # Train first, then score the updated weights on the validation subset.
    train_loss, val_loss = (
        train_one_epoch(model, train_loader, optimizer, criterion),
        validate(model, val_loader, criterion),
    )
    print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

Training:  22%|██▏       | 22/100 [00:07<00:24,  3.17it/s]