# U-Net training on Cityscapes

Libraries and setup

In [None]:
import os
import random
import numpy as np
import matplotlib.pyplot as plt # metrics visuliztion
from PIL import Image # image visulizition
import cv2
from tqdm import tqdm

import torch # main framework
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from torchvision.utils import make_grid

from google.colab import drive #dataset from drive

# Mount Google Drive so the dataset stored under /content/drive is readable.
drive.mount('/content/drive')

# Seed every RNG used in this notebook (Python, NumPy, PyTorch CPU and all
# CUDA devices) and force deterministic cuDNN kernels for reproducibility.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Paths to the Cityscapes dataset in Google Drive
DATASET_PATH = '/content/drive/MyDrive/cityscape'
IMG_PATH = os.path.join(DATASET_PATH, 'leftImg8bit')  # RGB images
GT_PATH = os.path.join(DATASET_PATH, 'gtFine')  # fine ground-truth masks

# Target resolution images and masks are resized to (reduced to save memory)
IMG_HEIGHT = 256
IMG_WIDTH = 512

# Number of Cityscapes classes used for training (the 19 evaluation classes)
NUM_CLASSES = 19


Mounted at /content/drive


Colour map and label mappings

In [None]:
# Colour map for visualization, indexed by the original Cityscapes label ID
# (0-33). Classes that are ignored during training are drawn black.
cityscapes_colors = [
    (0, 0, 0),         # 0: unlabeled
    (0, 0, 0),         # 1: ego vehicle
    (0, 0, 0),         # 2: rectification border
    (0, 0, 0),         # 3: out of roi
    (0, 0, 0),         # 4: static
    (0, 0, 0),         # 5: dynamic
    (0, 0, 0),         # 6: ground
    (0, 0, 70),        # 7: road - dark blue
    (255, 0, 255),     # 8: sidewalk - magenta
    (0, 0, 0),         # 9: parking
    (0, 0, 0),         # 10: rail track
    (255, 165, 0),     # 11: building - orange
    (190, 153, 153),   # 12: wall - light brown
    (170, 120, 220),   # 13: fence - light purple
    (0, 0, 0),         # 14: guard rail
    (0, 0, 0),         # 15: bridge
    (0, 0, 0),         # 16: tunnel
    (153, 153, 153),   # 17: pole - gray
    (0, 0, 0),         # 18: polegroup
    (250, 170, 30),    # 19: traffic light - amber
    (220, 220, 0),     # 20: traffic sign - yellow
    (35, 142, 35),     # 21: vegetation - forest green
    (152, 251, 152),   # 22: terrain - light green
    (70, 130, 180),    # 23: sky - steel blue
    (255, 0, 0),       # 24: person - bright red
    (255, 127, 0),     # 25: rider - dark orange
    (0, 0, 255),       # 26: car - bright blue
    (0, 150, 255),     # 27: truck - light blue
    (0, 80, 150),      # 28: bus - blue-gray
    (0, 0, 110),       # 29: caravan
    (0, 0, 110),       # 30: trailer
    (0, 80, 100),      # 31: train - dark blue-gray
    # Must differ from train (31): both are training classes, and a shared
    # colour made them indistinguishable in visualizations.
    (0, 128, 128),     # 32: motorcycle - teal
    (119, 11, 32),     # 33: bicycle - maroon
]

# Mapping from Cityscapes label IDs to training IDs (ignore label is 255)
id_to_trainid = {
    0: 255, 1: 255, 2: 255, 3: 255, 4: 255, 5: 255, 6: 255,
    7: 0,    # road
    8: 1,    # sidewalk
    9: 255,  # parking
    10: 255, # rail track
    11: 2,   # building
    12: 3,   # wall
    13: 4,   # fence
    14: 255, # guard rail
    15: 255, # bridge
    16: 255, # tunnel
    17: 5,   # pole
    18: 255, # polegroup
    19: 6,   # traffic light
    20: 7,   # traffic sign
    21: 8,   # vegetation
    22: 9,   # terrain
    23: 10,  # sky
    24: 11,  # person
    25: 12,  # rider
    26: 13,  # car
    27: 14,  # truck
    28: 15,  # bus
    29: 255, # caravan
    30: 255, # trailer
    31: 16,  # train
    32: 17,  # motorcycle
    33: 18,  # bicycle
}

# Class names for the 19 training classes, indexed by training ID
class_names = [
    'road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
    'traffic light', 'traffic sign', 'vegetation', 'terrain',
    'sky', 'person', 'rider', 'car', 'truck', 'bus',
    'train', 'motorcycle', 'bicycle'
]

# Palette indexed by training ID (0-18), derived from the tables above, for
# colouring masks/predictions that are already expressed in training IDs.
trainid_colors = [
    cityscapes_colors[label_id]
    for label_id, train_id in sorted(id_to_trainid.items(), key=lambda kv: kv[1])
    if train_id != 255
]


In [None]:
class CityscapesSubset(Dataset):
    """Cityscapes gtFine image/mask pairs, optionally down-sampled per city.

    Images are read from ``IMG_PATH/<split>/<city>/*_leftImg8bit.png``.
    Masks come from ``GT_PATH/<split>/<city>/<id>_gtFine_labelIds.png``.
    Only pairs whose mask file exists are kept. Each item is resized to
    ``(IMG_WIDTH, IMG_HEIGHT)``, and its mask is remapped from Cityscapes
    label IDs to training IDs via ``id_to_trainid`` (unmapped IDs -> 255).

    Args:
        root: Stored as ``self.root``. NOTE(review): file paths are built from
            the module-level ``IMG_PATH``/``GT_PATH``, not from ``root``.
        split: Sub-directory under ``IMG_PATH``/``GT_PATH`` (e.g. ``'train'``).
        transforms: Callable ``(pil_image, pil_mask) -> (image, mask)``.
            It is only applied when ``split == 'train'``.
        subset_fraction: Fraction of pairs to keep; values >= 1.0 keep all.
            Sampling is roughly proportional per city (at least one image per
            city before trimming) and uses the global ``random`` state.
    """

    def __init__(self, root, split='train', transforms=None, subset_fraction=0.2):
        self.root = root
        self.split = split
        self.transforms = transforms
        self.subset_fraction = subset_fraction

        split_dir = os.path.join(IMG_PATH, split)
        # os.listdir order is filesystem-dependent. Sort it so that, with a
        # fixed seed, the same subset is drawn on every run. Also skip stray
        # non-directory entries, which would otherwise crash the inner listdir.
        self.cities = sorted(
            name for name in os.listdir(split_dir)
            if os.path.isdir(os.path.join(split_dir, name))
        )

        self.images = []
        self.masks = []
        image_cities = []  # city of each collected pair, parallel to self.images
        suffix = '_leftImg8bit.png'
        for city in self.cities:
            img_dir = os.path.join(IMG_PATH, split, city)
            mask_dir = os.path.join(GT_PATH, split, city)
            for file_name in sorted(os.listdir(img_dir)):
                if not file_name.endswith(suffix):
                    continue
                image_id = file_name[:-len(suffix)]
                mask_path = os.path.join(mask_dir, f"{image_id}_gtFine_labelIds.png")
                if os.path.exists(mask_path):
                    self.images.append(os.path.join(img_dir, file_name))
                    self.masks.append(mask_path)
                    image_cities.append(city)

        # Create a subset of the dataset if needed
        if subset_fraction < 1.0:
            keep = self._sample_indices(image_cities, subset_fraction)
            self.images = [self.images[i] for i in keep]
            self.masks = [self.masks[i] for i in keep]

        print(f"Created {split} set with {len(self.images)} images from {len(self.cities)} cities")

    @staticmethod
    def _sample_indices(image_cities, subset_fraction):
        """Return indices of a city-proportional random subset of ``image_cities``.

        At most ``int(len(image_cities) * subset_fraction)`` indices are returned.
        """
        total = len(image_cities)
        num_samples = int(total * subset_fraction)

        by_city = {}
        for i, city in enumerate(image_cities):
            by_city.setdefault(city, []).append(i)

        indices = []
        for samples in by_city.values():
            city_ratio = len(samples) / total
            # Proportional share, but at least one image from every city.
            share = max(1, int(num_samples * city_ratio))
            indices.extend(random.sample(samples, min(share, len(samples))))

        if len(indices) > num_samples:
            # The one-per-city minimum overshot the budget: trim at random.
            indices = random.sample(indices, num_samples)
        elif len(indices) < num_samples:
            # Rounding down left a shortfall: top up from unused images.
            used = set(indices)
            unused = [i for i in range(total) if i not in used]
            if unused:
                remaining = num_samples - len(indices)
                indices.extend(random.sample(unused, min(remaining, len(unused))))
        return indices

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for sample ``idx``.

        Without train-time transforms, ``image`` comes from ``TF.to_tensor``
        and ``mask`` is an int64 (``.long()``) tensor of training IDs, where
        255 means ignore. With a train split and transforms set, whatever
        ``self.transforms`` returns is passed through.
        """
        # Context managers close the underlying files; convert()/resize()
        # return new images, so the results stay valid after closing.
        with Image.open(self.images[idx]) as img:
            image = img.convert('RGB')
        with Image.open(self.masks[idx]) as raw_mask:
            mask = raw_mask.resize((IMG_WIDTH, IMG_HEIGHT), Image.NEAREST)
        image = image.resize((IMG_WIDTH, IMG_HEIGHT), Image.BILINEAR)

        # Remap label IDs -> training IDs with a 256-entry lookup table.
        # Any ID not in id_to_trainid maps to the ignore label 255.
        lut = np.full(256, 255, dtype=np.uint8)
        for label_id, train_id in id_to_trainid.items():
            lut[label_id] = train_id
        mask_out = lut[np.clip(np.array(mask), 0, 255)]

        # Apply transformations (train split only)
        if self.transforms and self.split == 'train':
            return self.transforms(image, Image.fromarray(mask_out))
        return TF.to_tensor(image), torch.from_numpy(mask_out).long()

class CityscapesTransforms:
    """Joint random augmentation for an ``(image, mask)`` pair of PIL images.

    The image is converted with ``TF.to_tensor`` up front. The mask stays a
    PIL image, so geometric ops on it use nearest-neighbour sampling, and it
    is returned as an int64 (``.long()``) tensor.

    Args:
        p_flip: Probability of a horizontal flip (applied to image and mask).
        p_rotate: Probability of a random rotation (applied to image and mask).
        p_color: Probability of colour jitter (image only).
        max_angle: Rotation angle in degrees, drawn from
            ``uniform(-max_angle, max_angle)``.
        ignore_index: Mask value written into the area exposed by rotation.
            Defaults to 255, the ignore label of ``id_to_trainid``.
    """

    def __init__(self, p_flip=0.5, p_rotate=0.3, p_color=0.5, max_angle=10, ignore_index=255):
        self.p_flip = p_flip
        self.p_rotate = p_rotate
        self.p_color = p_color
        self.max_angle = max_angle
        self.ignore_index = ignore_index
        self.color_jitter = transforms.ColorJitter(
            brightness=0.2,
            contrast=0.2,
            saturation=0.2,
            hue=0.1
        )

    def __call__(self, image, mask):
        image = TF.to_tensor(image)
        if random.random() < self.p_flip:
            image = TF.hflip(image)
            mask = TF.hflip(mask)
        if random.random() < self.p_rotate:
            angle = random.uniform(-self.max_angle, self.max_angle)
            image = TF.rotate(image, angle, interpolation=TF.InterpolationMode.BILINEAR)
            # The default fill (0) is training ID 0 ("road"), which would label
            # the rotated-in corners as road. Fill with the ignore label
            # instead; presumably the loss skips it (ignore_index=255), but
            # confirm against the training loop.
            mask = TF.rotate(
                mask,
                angle,
                interpolation=TF.InterpolationMode.NEAREST,
                fill=self.ignore_index,
            )
        if random.random() < self.p_color:
            image = self.color_jitter(image)
        mask = torch.from_numpy(np.array(mask)).long()
        return image, mask
