In [1]:
import os
import gc

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torchvision.transforms as T

from torch import Tensor
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm

In [2]:
from datasets import map_labels_to_trainIds, CityscapesDataset


class FeatureDataset(Dataset):
    """Pre-extracted feature tensors paired with Cityscapes ``labelIds`` masks.

    Every file under ``<feat_dir>/<split>/timestep_<NN>`` is loaded with
    ``torch.load`` and the results are concatenated along dim 0 into
    ``self.features``. Masks are the files whose name contains ``labelIds``
    under ``<mask_dir>/gtFine/<split>/<city>/``, sorted by path. Item ``i``
    pairs ``features[i]`` with the ``i``-th mask path.

    NOTE(review): the pairing is purely positional, so it is only meaningful
    if the features were extracted in sorted-mask-path order -- confirm
    against the extraction script.

    Args:
        feat_dir: Root directory of the saved features.
        mask_dir: Dataset root, i.e. the directory that contains ``gtFine``.
        split: Sub-directory name used under both roots (e.g. ``"train"``).
        timestep: Value in ``[0, 1]``; selects the ``timestep_<round(100*t)>`` folder.
        feat_transform: Optional callable applied to each feature tensor.
        mask_size: ``(height, width)`` masks are resized to (nearest neighbour).

    Raises:
        AssertionError: If ``timestep`` is out of range, or the number of
            feature rows differs from the number of masks.
        FileNotFoundError: If the feature directory contains no files.
    """

    def __init__(self, feat_dir, mask_dir, split, timestep: float, feat_transform=None,
                 mask_size=(224, 224)):
        assert 0.0 <= timestep <= 1.0
        # round() before int(): plain int() truncates, e.g. 0.29 * 100 is
        # 28.999999999999996 and would select "timestep_28".
        self.feat_dir = os.path.join(feat_dir, split, f"timestep_{int(round(timestep * 100))}")
        self.mask_dir = os.path.join(mask_dir, "gtFine", split)
        self.split = split
        self.timestep = timestep

        ## Get features as tensors
        feature_paths = self._sorted_feature_paths(self.feat_dir)
        if not feature_paths:
            raise FileNotFoundError(f"No feature files found in {self.feat_dir}")

        # map_location="cpu": files saved from a GPU run still open on a CPU-only machine.
        # weights_only=True: restricts unpickling to tensors and plain containers.
        batches = [
            torch.load(path, map_location="cpu", weights_only=True)
            for path in feature_paths
        ]
        self.features = torch.cat(batches, dim=0)
        del batches

        ## Get seg_mask paths
        mask_paths = []
        for city in os.listdir(self.mask_dir):
            city_path = os.path.join(self.mask_dir, city)
            if not os.path.isdir(city_path):
                continue  # stray files next to the city folders
            for fname in os.listdir(city_path):
                if "labelIds" in fname:
                    mask_paths.append(os.path.join(city_path, fname))
        mask_paths.sort()

        assert len(mask_paths) == len(self.features), f"Mismatched files! features={len(self.features)} | masks={len(mask_paths)}"

        self.mask_paths = mask_paths

        ## Transforms
        self.feat_transform = feat_transform
        self._pil_to_tensor = T.PILToTensor()
        self._resize_mask = T.Resize(mask_size, interpolation=T.InterpolationMode.NEAREST)
        # Still exposed as a callable attribute. A bound method replaces the
        # former Compose-of-lambdas, which could not be pickled.
        self.mask_transform = self._transform_mask

    @staticmethod
    def _sorted_feature_paths(feat_dir):
        """Return the regular files in ``feat_dir`` in natural (numeric-aware) order.

        ``os.listdir`` returns entries in arbitrary order; sorting makes the
        concatenation order deterministic. Digit runs are compared as numbers
        so that ``batch_10`` sorts after ``batch_2``.
        """
        import re  # local import: only this helper needs it

        def natural_key(name):
            # re.split with a capturing group alternates text / digits / text ...,
            # so odd positions are always the digit runs.
            return [int(tok) if i % 2 else tok
                    for i, tok in enumerate(re.split(r"(\d+)", name))]

        names = sorted(os.listdir(feat_dir), key=natural_key)
        paths = (os.path.join(feat_dir, name) for name in names)
        return [path for path in paths if os.path.isfile(path)]

    def _transform_mask(self, mask_img):
        """Convert a PIL label image into an ``(H, W)`` long tensor of class ids.

        Labels are mapped with ``map_labels_to_trainIds``; the value 19 is then
        replaced by -1 (the ignore index used with ``CrossEntropyLoss``).
        """
        labels = self._pil_to_tensor(mask_img).squeeze(0).long()  # (1, H, W) -> (H, W)
        labels = map_labels_to_trainIds(labels, CityscapesDataset.label2trainId)
        labels = torch.where(labels == 19, torch.tensor(-1, dtype=torch.long), labels)
        # Resize expects [..., H, W] image tensors; give it an explicit channel
        # dimension instead of a bare 2-D map, then drop it again.
        return self._resize_mask(labels.unsqueeze(0)).squeeze(0)

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        feat = self.features[idx]
        if self.feat_transform is not None:
            feat = self.feat_transform(feat)

        # Context manager closes the file handle once the pixels are copied out.
        with Image.open(self.mask_paths[idx]) as mask_img:
            mask = self.mask_transform(mask_img)

        return feat, mask

In [3]:
# ---------------- Model ----------------
class SimpleSegmentationHead(nn.Module):
    """Decode a flat feature vector into per-pixel class logits.

    The input vector is reshaped into a ``(in_dim // 9, 3, 3)`` feature map,
    upsampled by five stride-2 transposed convolutions (3 -> 6 -> 12 -> 24 ->
    48 -> 96) and finally resized bilinearly to ``out_size x out_size``.

    Args:
        in_dim: Length of the flat input vector; must be a positive multiple
            of 9 (the default 1152 gives 128 channels).
        n_class: Number of output channels (classes).
        out_size: Side length of the square output map.

    Raises:
        ValueError: If ``in_dim`` cannot be reshaped into a 3x3 feature map.
    """

    def __init__(self, in_dim=1152, n_class=19, out_size=224):
        super().__init__()
        grid = 3
        if in_dim <= 0 or in_dim % (grid * grid) != 0:
            raise ValueError(
                f"in_dim must be a positive multiple of {grid * grid}, got {in_dim}"
            )

        self.in_dim = in_dim
        self.n_class = n_class
        self.out_size = out_size

        # Reshape flat vector to feature map; previously hard-coded to
        # (128, 3, 3), which silently ignored in_dim.
        channels = in_dim // (grid * grid)
        self.feature_dim = (channels, grid, grid)

        # Convolutional decoder: each transposed conv doubles the resolution.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(channels, 64, kernel_size=4, stride=2, padding=1),  # 6x6
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1),        # 12x12
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),        # 24x24
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),        # 48x48
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(16, n_class, kernel_size=4, stride=2, padding=1),   # 96x96
            nn.Upsample(size=(out_size, out_size), mode='bilinear', align_corners=False),
        )

    def forward(self, x):
        """Map ``x`` with ``in_dim`` values per sample to ``(B, n_class, out_size, out_size)`` logits."""
        batch_size = x.size(0)
        # reshape (not view) so inputs with extra singleton dims or odd strides still work.
        feature_map = x.reshape(batch_size, *self.feature_dim)
        return self.decoder(feature_map)

In [4]:
# ---------------- Metrics ----------------
def compute_iou(preds, labels, num_classes, device):
    """Per-class intersection-over-union between predictions and labels.

    Pixels whose label lies outside ``[0, num_classes)`` (e.g. the -1 ignore
    value) are dropped before counting, so predictions made on ignored pixels
    no longer inflate a class's union.

    Args:
        preds: Integer tensor of predicted class ids (any shape).
        labels: Tensor of ground-truth class ids with the same number of elements.
        num_classes: Number of classes to score.
        device: Device both tensors are moved to before comparison.

    Returns:
        A list of ``num_classes`` floats; ``nan`` for a class that appears in
        neither the predictions nor the labels.
    """
    # Flatten and put both operands on the same device before comparing.
    preds = preds.reshape(-1).to(device)
    labels = labels.reshape(-1).to(device)

    scored = (labels >= 0) & (labels < num_classes)
    preds = preds[scored]
    labels = labels[scored]

    ious = []
    for cls in range(num_classes):
        pred_inds = preds == cls
        target_inds = labels == cls
        intersection = (pred_inds & target_inds).sum().item()
        union = (pred_inds | target_inds).sum().item()
        if union == 0:
            ious.append(float('nan'))  # class absent from this input
        else:
            ious.append(intersection / union)
    return ious


def pixel_accuracy(preds, labels, ignore_index=-1):
    """Fraction of scored pixels whose prediction equals the label.

    Args:
        preds: Tensor of predicted class ids.
        labels: Tensor of ground-truth class ids, same shape as ``preds``.
        ignore_index: Label value excluded from both numerator and
            denominator. Pass ``None`` to score every pixel.

    Returns:
        ``correct / scored`` as a float, or ``nan`` when no pixel is scored.
    """
    # Ensure labels are on the same device as preds before comparison
    labels = labels.to(preds.device)

    matches = preds == labels
    if ignore_index is None:
        total = labels.numel()
    else:
        scored = labels != ignore_index
        matches = matches & scored
        total = scored.sum().item()

    if total == 0:
        return float('nan')
    return matches.sum().item() / total


# ---------------- Training Loop ----------------
def train(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over ``dataloader``.

    Puts ``model`` in training mode, then for every batch moves the inputs and
    targets to ``device``, performs a forward/backward pass and one optimizer
    step.

    Returns:
        The sum of the per-batch loss values divided by ``len(dataloader)``.
    """
    model.train()
    loss_sum = 0.0

    progress = tqdm(dataloader, desc="Training")
    for feats, targets in progress:
        feats = feats.to(device)
        # Targets are cast to long for CrossEntropyLoss.
        targets = targets.to(device).long()

        optimizer.zero_grad()
        batch_loss = criterion(model(feats), targets)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item()

    n_batches = len(dataloader)
    return loss_sum / n_batches


# ---------------- Evaluation Loop ----------------
def evaluate(model, dataloader, criterion, device, num_classes=19):
    """Evaluate ``model`` over ``dataloader`` with dataset-level metrics.

    Intersection, union and correct-pixel counts are accumulated over the whole
    dataloader and divided once at the end, so the result does not depend on
    the batch size (averaging per-batch ratios does). Pixels whose label lies
    outside ``[0, num_classes)`` -- e.g. the -1 ignore value -- are excluded
    from IoU and accuracy. Batches with no scorable pixel are skipped
    entirely, because a mean-reduced loss over zero targets is NaN.

    Args:
        model: Module mapping a batch of inputs to ``(B, C, H, W)`` logits.
        dataloader: Iterable of ``(inputs, masks)`` batches.
        criterion: Loss callable taking ``(logits, masks)``.
        device: Device the batches and the accumulators are placed on.
        num_classes: Number of classes to score.

    Returns:
        ``(mean_loss, per_class_iou, pixel_accuracy)``. ``mean_loss`` is the
        average over scored batches, ``per_class_iou`` is a float array of
        length ``num_classes`` with ``nan`` for classes that never occur, and
        ``pixel_accuracy`` is a float. Loss and accuracy are ``nan`` when
        nothing could be scored.
    """
    model.eval()

    loss_sum = 0.0
    scored_batches = 0
    intersection = torch.zeros(num_classes, dtype=torch.long, device=device)
    union = torch.zeros(num_classes, dtype=torch.long, device=device)
    correct = torch.zeros((), dtype=torch.long, device=device)
    scored_pixels = 0

    with torch.no_grad():
        for images, masks in tqdm(dataloader, desc="Evaluating"):
            images = images.to(device)
            # Ensure masks are long type for CrossEntropyLoss
            masks = masks.to(device).long()

            scorable = (masks >= 0) & (masks < num_classes)
            if not scorable.any():
                continue  # nothing to score; the loss would be NaN

            outputs = model(images)
            loss_sum += criterion(outputs, masks).item()
            scored_batches += 1

            preds = torch.argmax(outputs, dim=1)[scorable]
            targets = masks[scorable]

            correct += (preds == targets).sum()
            scored_pixels += targets.numel()

            # Counts stay on `device`; they are transferred once after the loop.
            for cls in range(num_classes):
                pred_is_cls = preds == cls
                target_is_cls = targets == cls
                intersection[cls] += (pred_is_cls & target_is_cls).sum()
                union[cls] += (pred_is_cls | target_is_cls).sum()

    intersection_np = intersection.cpu().numpy().astype(np.float64)
    union_np = union.cpu().numpy().astype(np.float64)
    with np.errstate(divide="ignore", invalid="ignore"):
        mean_iou = np.where(union_np > 0, intersection_np / union_np, np.nan)

    overall_acc = correct.item() / scored_pixels if scored_pixels else float("nan")
    mean_loss = loss_sum / scored_batches if scored_batches else float("nan")
    return mean_loss, mean_iou, overall_acc

In [6]:
gc.collect()

# Define training parameters
DATA_DIR = "cityscapes"
FEAT_DIR = "cityscapes_features"

EPOCHS = 10
L_RATE = 0.001
B_SIZE = 64
NUM_CLASSES = 19
TIMESTEP = 0.95
SEED = 0
# The publicly released Cityscapes gtFine annotations for the "test" split are
# dummies (ground truth is withheld), so metrics computed against them are
# meaningless. Evaluate on "val", which also matches the "Val ..." labels
# printed below. Features for this split must exist under FEAT_DIR.
EVAL_SPLIT = "val"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed weight initialisation and DataLoader shuffling for repeatable runs.
torch.manual_seed(SEED)


## Load data
# FeatureDataset asserts one feature row per gtFine mask. The run recorded
# with this notebook stopped there (features=640 vs masks=2975), i.e. feature
# extraction for the train split looks incomplete -- re-run it before training.
trainset = FeatureDataset(FEAT_DIR, DATA_DIR, split="train", timestep=TIMESTEP)
testset = FeatureDataset(FEAT_DIR, DATA_DIR, split=EVAL_SPLIT, timestep=TIMESTEP)

model = SimpleSegmentationHead(n_class=NUM_CLASSES).to(device)
optim = torch.optim.Adam(model.parameters(), lr=L_RATE)
# Set ignore_index to -1 to ignore those values in the loss calculation
loss_fn = nn.CrossEntropyLoss(ignore_index=-1)

train_loader = DataLoader(trainset, batch_size=B_SIZE, shuffle=True)
test_loader = DataLoader(testset, batch_size=1, shuffle=False)


for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")
    train_loss = train(model, train_loader, optim, loss_fn, device)
    val_loss, mean_iou, val_acc = evaluate(model, test_loader, loss_fn, device, NUM_CLASSES)

    print(f"Train Loss: {train_loss:.4f}")
    print(f"Val Loss: {val_loss:.4f}")
    print(f"Val Pixel Accuracy: {val_acc:.4f}")
    print(f"Val Mean IoU: {np.nanmean(mean_iou):.4f}")

  batch = torch.load(path)


AssertionError: Mismatched files! features=640 | masks=2975