In [None]:
Training

Imports, Random Seeding

In [None]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.transforms.functional as TF
import matplotlib.pyplot as plt
from PIL import Image
import cv2
from google.colab import drive
from tqdm import tqdm
import gc  # For garbage collection
import time

# Single source of truth for the seed, so every RNG below is seeded with the
# same value and the number is not repeated as a magic literal.
SEED = 42

# Seed every RNG this notebook draws from: Python's `random`, NumPy's legacy
# global generator, and PyTorch (CPU and all CUDA devices).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)  # safe to call even when CUDA is unavailable

# Mount Google Drive so the dataset paths under /content/drive resolve.
drive.mount('/content/drive')


Define Dataset Paths

In [None]:
# Define paths
cityscapes_path = "/content/drive/MyDrive/cityscape"
subset_path = "/content/drive/MyDrive/cityscapes_subset"

# Create the subset directory skeleton: <subset>/<leftImg8bit|gtFine>/<train|val>.
# Each leaf is created unconditionally with exist_ok=True (os.makedirs also
# creates the missing parents), so the cell is idempotent on re-run and repairs
# a tree that was only partially created, instead of doing nothing as soon as
# `subset_path` itself exists.
for data_kind in ("leftImg8bit", "gtFine"):
    for split in ("train", "val"):
        os.makedirs(os.path.join(subset_path, data_kind, split), exist_ok=True)


Subset Creation Function

In [None]:
def create_subset(source_path, target_path, split, cities_per_split=2, images_per_city=15, target_size=(384, 192)):
    """
    Create a reduced-resolution subset of the Cityscapes dataset.

    Selects the first `cities_per_split` city directories (alphabetically) that
    exist under both `leftImg8bit/<split>` and `gtFine/<split>`, then the first
    `images_per_city` files (alphabetically) of each selected city. Every
    selected image that has a matching `<id>_gtFine_labelIds.png` is resized
    and written, together with its resized label map, to the same relative
    location under `target_path`. Images without a label file are reported and
    skipped, so a city can end up with fewer than `images_per_city` pairs.

    Args:
        source_path: Path to the original Cityscapes dataset
        target_path: Path to store the subset
        split: 'train' or 'val'
        cities_per_split: Maximum number of cities to include
        images_per_city: Maximum number of images per city
        target_size: Target image size (width, height)

    Raises:
        FileNotFoundError: If `leftImg8bit/<split>` or `gtFine/<split>` does
            not exist under `source_path`.
    """
    img_root = os.path.join(source_path, "leftImg8bit", split)
    label_root = os.path.join(source_path, "gtFine", split)

    img_entries = os.listdir(img_root)
    label_entries = set(os.listdir(label_root))

    # Keep only names that are real directories on BOTH sides. A stray file
    # (e.g. a synced ".DS_Store") would otherwise be treated as a city and
    # crash the per-city listing below. Always sort so the selection and the
    # processing order are deterministic regardless of os.listdir order.
    cities = sorted(
        entry for entry in img_entries
        if entry in label_entries
        and os.path.isdir(os.path.join(img_root, entry))
        and os.path.isdir(os.path.join(label_root, entry))
    )[:cities_per_split]

    print(f"Selected cities for {split}: {cities}")

    for city in cities:
        src_img_dir = os.path.join(img_root, city)
        src_label_dir = os.path.join(label_root, city)
        dst_img_dir = os.path.join(target_path, "leftImg8bit", split, city)
        dst_label_dir = os.path.join(target_path, "gtFine", split, city)

        # Create target directories
        os.makedirs(dst_img_dir, exist_ok=True)
        os.makedirs(dst_label_dir, exist_ok=True)

        # First `images_per_city` files in alphabetical order
        img_files = sorted(os.listdir(src_img_dir))[:images_per_city]

        print(f"Processing {len(img_files)} images from {city}")

        for img_file in img_files:
            # "<city>_<seq>_<frame>_leftImg8bit.png" -> "<city>_<seq>_<frame>"
            img_id = img_file.split('_leftImg8bit')[0]
            labelIds_file = f"{img_id}_gtFine_labelIds.png"
            label_path = os.path.join(src_label_dir, labelIds_file)

            # Only image/label pairs are copied; an unlabeled image is skipped.
            if not os.path.exists(label_path):
                print(f"Warning: {labelIds_file} not found, skipping {img_file}")
                continue

            # Context managers release the source file handles deterministically.
            # Bilinear interpolation is fine for the photo ...
            with Image.open(os.path.join(src_img_dir, img_file)) as img:
                img_resized = img.resize(target_size, Image.BILINEAR)
                img_resized.save(os.path.join(dst_img_dir, img_file))

            # ... but label maps hold class ids, so they must use nearest
            # neighbour to avoid inventing ids that blend two classes.
            with Image.open(label_path) as label:
                label_resized = label.resize(target_size, Image.NEAREST)
                label_resized.save(os.path.join(dst_label_dir, labelIds_file))

# Build the reduced dataset: a larger training subset and a small validation
# subset, both written under `subset_path` at create_subset's default size.
print("Creating dataset subset...")
for subset_split, n_cities, n_images in (("train", 2, 25), ("val", 1, 10)):
    create_subset(
        cityscapes_path,
        subset_path,
        subset_split,
        cities_per_split=n_cities,
        images_per_city=n_images,
    )
print("Dataset subset created!")
