# Preparations 

In [None]:
!pip install albumentations
import glob
import os
import random

import albumentations as A
import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms.functional as TF
from PIL import Image, ImageChops
from torch.utils.data import Dataset, DataLoader
from torchvision.models import VGG16_Weights
from tqdm import tqdm, trange
from transformers import SegformerForSemanticSegmentation

In [None]:
# Run on the first CUDA GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# helper functions
def show_images(list_of_images, titles=None):
    """Plot images side by side in a single row of subplots.

    Args:
        list_of_images: sequence of array-like images. 2-D arrays, or arrays whose
            third axis has length 1, are drawn with a gray colormap.
        titles: optional sequence of subplot titles, parallel to ``list_of_images``.
    """
    plt.figure(figsize=(16, 10))
    total = len(list_of_images)
    for position, picture in enumerate(list_of_images, start=1):
        axis = plt.subplot(1, total, position)
        if titles is not None:
            axis.set_title(titles[position - 1])
        is_single_channel = len(picture.shape) == 2 or picture.shape[2] == 1
        axis.imshow(picture, cmap='gray' if is_single_channel else None)
    plt.show()

def plot_losses(losses_dict):
    """Draw one curve per entry of ``losses_dict`` on a shared loss-vs-epoch chart.

    Args:
        losses_dict: mapping from legend label to a sequence of per-epoch losses.
    """
    plt.figure(figsize=(16, 10))
    for curve_label, curve in losses_dict.items():
        plt.plot(curve, label=curve_label)
    plt.title("Training and Validation Losses")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.legend()
    plt.grid()
    plt.show()

class IoULoss(nn.Module):
    """Soft IoU (Jaccard) loss: ``1 - sum(P*T) / (sum(P) + sum(T) - sum(P*T))``.

    Prediction and target are flattened, so a single IoU is computed over every
    element of the batch (it is not averaged per sample). ``y_pred`` is expected to
    hold probabilities (the models in this notebook end with a sigmoid) and
    ``y_true`` binary {0, 1} values.

    ``eps`` is added to both numerator and denominator. This avoids division by zero,
    and an empty prediction against an empty target then counts as a perfect match
    (loss 0) rather than a total miss (loss 1).
    """

    def __init__(self, eps=1e-6):
        super(IoULoss, self).__init__()
        self.eps = eps

    def forward(self, y_pred, y_true):
        # reshape (unlike view) also accepts non-contiguous tensors, e.g. transposed ones.
        y_pred = y_pred.reshape(-1)
        y_true = y_true.reshape(-1)

        intersection = (y_pred * y_true).sum()
        union = y_pred.sum() + y_true.sum() - intersection

        iou = (intersection + self.eps) / (union + self.eps)
        return 1 - iou


In [None]:
# class to store hyperparameters for model training
class Config:
    """Training hyperparameters and checkpoint location for one of the notebook's models.

    Attributes set on construction:
        epochs: maximum number of training epochs.
        lr: optimizer learning rate.
        loss_fn: loss module used for training and evaluation (``IoULoss``).
        save_path: checkpoint file written during training / read when not training.
        need_training: if True the notebook trains the model, otherwise it loads
            the checkpoint at ``save_path``.

    Raises:
        ValueError: if ``model_name`` is not ``"UNet"``, ``"FCN"`` or ``"SegFormer"``.
    """

    # Per-model settings; all models share the same loss (IoULoss).
    _SETTINGS = {
        "UNet": {"epochs": 100, "lr": 1e-3, "save_path": "checkpoints/unet_best_model.pth"},
        "FCN": {"epochs": 100, "lr": 1e-3, "save_path": "checkpoints/fcn_best_model.pth"},
        "SegFormer": {"epochs": 10, "lr": 5e-5, "save_path": "checkpoints/seg_former_best_model.pth"},
    }

    def __init__(self, model_name, need_training=False):
        # Fail fast on a typo instead of building an object with no attributes
        # that only errors later with an AttributeError.
        if model_name not in self._SETTINGS:
            raise ValueError(
                f"Unknown model_name {model_name!r}; expected one of {sorted(self._SETTINGS)}"
            )
        settings = self._SETTINGS[model_name]
        self.epochs = settings["epochs"]
        self.lr = settings["lr"]
        self.loss_fn = IoULoss()
        self.save_path = settings["save_path"]
        self.need_training = need_training

# Dataset

## Data analysis

In [None]:
# Root folder of the KITTI road dataset. To run elsewhere (e.g. '/kaggle/input/' on
# Kaggle), set the KITTI_DATASET_PATH environment variable instead of editing this cell.
dataset_path = os.environ.get("KITTI_DATASET_PATH", "./content/")

In [None]:
# List the sub-directories and files of every directory under the dataset root.
print(f"Dataset Path: {dataset_path}")
for current_dir, child_dirs, child_files in os.walk(dataset_path):
    print(f"\nDirectory: {current_dir}")
    print(f"  Subdirectories: {child_dirs}")
    print(f"  Files: {child_files}")

In [None]:
def count_files_in_dir(directory):
    """Return how many regular files (sub-directories excluded) sit directly in ``directory``."""
    return sum(
        1
        for entry in os.listdir(directory)
        if os.path.isfile(os.path.join(directory, entry))
    )

In [None]:
# Locate the training sub-folders and report how many files each one holds.
training_path = os.path.join(dataset_path, 'training')
calib_path, gt_path, img_path = (
    os.path.join(training_path, sub_dir) for sub_dir in ('calib', 'gt_image_2', 'image_2')
)

print("Training Data Analysis:")
for description, folder in (
    ("Calibration Files", calib_path),
    ("Ground Truth Files", gt_path),
    ("Image Files", img_path),
):
    print(f"{description} in {folder}: {count_files_in_dir(folder)}")

In [None]:
# Inspect the test videos: frame count, frame rate and resolution of each .mp4.
testing_path = os.path.join(dataset_path, 'testing')
print("Testing Data Analysis:")
# sorted() gives a deterministic report order (os.listdir order is arbitrary).
video_files = sorted(f for f in os.listdir(testing_path) if f.endswith('.mp4'))
print(f"Number of Videos: {len(video_files)}")

for video in video_files:
    video_path = os.path.join(testing_path, video)
    cap = cv2.VideoCapture(video_path)
    try:
        print(f"\nVideo: {video}")
        if not cap.isOpened():
            # An unreadable file would otherwise report zeros for every property.
            print("  Could not open video")
            continue
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Keep FPS fractional; int() truncated rates such as 29.97 down to 29.
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        print(f"  Frame Count: {frame_count}")
        print(f"  FPS: {fps:.2f}")
        print(f"  Resolution: {width}x{height}")
    finally:
        # Release the capture even if a property query raises.
        cap.release()

The KITTI Road Dataset contains images and their corresponding ground truth segmentation files. However, **the number of images does not match the number of ground truth files**. This discrepancy arises because the ground-truth folder also contains **additional lane label files** (named `*_lane_*.png`), which are not relevant for our current task of road segmentation.

To resolve this, we will **filter out the lane label files** and ensure a **one-to-one correspondence** between the input images and the ground truth segmentation files.


In [None]:
def filter_and_match_files(image_dir, label_dir):
    """Pair KITTI images with their road ground-truth masks by file name.

    Road masks are named ``<type>_road_<id>.png`` and images ``<type>_<id>.png``.
    Other label files (e.g. lane masks) are excluded by the ``*_road_*`` glob.

    Args:
        image_dir: directory containing the input ``*.png`` images.
        label_dir: directory containing the ``*_road_*.png`` ground-truth masks.

    Returns:
        matching_files: ``(image_name, normalized_label_name)`` basename tuples for
            every image that has a road mask (both elements are equal).
        img_files: sorted full paths of all images.
        label_files: sorted full paths of all road masks. These two lists are
            index-aligned only when every image has exactly one road mask.
    """
    img_files = sorted(glob.glob(os.path.join(image_dir, "*.png")))
    label_files = sorted(glob.glob(os.path.join(label_dir, "*_road_*.png")))

    # Match by name rather than by position. Zipping the two sorted lists
    # misaligned every pair after the first image lacking a mask.
    label_names = {os.path.basename(f).replace("_road_", "_") for f in label_files}
    matching_files = [
        (name, name)
        for name in (os.path.basename(f) for f in img_files)
        if name in label_names
    ]

    print(f"Total images: {len(img_files)}, Total road labels: {len(label_files)}")
    print(f"Matching files: {len(matching_files)}")

    return matching_files, img_files, label_files
    
def convert_to_binary_mask(mask, road_label=[255, 0, 255]):
    """Return a uint8 mask that is 1 where the pixel colour equals ``road_label``, else 0.

    Args:
        mask: colour label image; channels are compared along axis 2.
        road_label: colour that marks road pixels (magenta by default).
    """
    is_road = (mask == np.array(road_label)).all(axis=2)
    return is_road.astype(np.uint8)
    
def visualize_overlay(image_path, label_path):
    """Show an image, its colour label, the binary road mask, and the mask blended onto the image.

    Args:
        image_path: path of the RGB input image.
        label_path: path of the matching colour ground-truth mask.
    """
    # Open each file once and close both handles deterministically.
    with Image.open(image_path) as img, Image.open(label_path) as label:
        img_np = np.array(img)
        label_np = np.array(label)

        binary_mask = convert_to_binary_mask(label_np)
        binary_img = Image.fromarray(binary_mask * 255)
        # ImageChops.add computes (img + mask) / scale: road pixels get the white
        # mask added, while all pixels are dimmed by 1.7, so the road stands out.
        overlay_binary = ImageChops.add(img, binary_img.convert("RGB"), scale=1.7)
        overlay_binary_np = np.array(overlay_binary)

    show_images(
        [img_np, label_np, binary_mask, overlay_binary_np],
        titles=["Image", "Label", "Binary Mask", "Overlay Binary"])

def visualize_training_examples(image_path, label_path):
    """Display an input image next to its raw ground-truth label.

    Args:
        image_path: path of the input image.
        label_path: path of the matching ground-truth file.
    """
    opened = [Image.open(path) for path in (image_path, label_path)]
    arrays = [np.array(picture) for picture in opened]
    show_images(arrays, titles=["Image", "Label"])

In [None]:
# Training image / ground-truth folders and the image-to-road-mask pairing between them.
train_img_dir, train_label_dir = (
    os.path.join(dataset_path, sub_dir) for sub_dir in ("training/image_2", "training/gt_image_2")
)

matching_files, img_files, label_files = filter_and_match_files(train_img_dir, train_label_dir)

In [None]:
# Raw image / label pairs for the first five training files.
for sample_idx in range(5):
    image_file = img_files[sample_idx]
    label_file = label_files[sample_idx]
    visualize_training_examples(image_file, label_file)

In [None]:
# Mask overlays for training files 5 through 9.
for sample_idx in range(5, 10):
    image_file = img_files[sample_idx]
    label_file = label_files[sample_idx]
    visualize_overlay(image_file, label_file)

The KITTI Road Dataset contains images categorized into three distinct types based on road conditions and markings:

1. **`uu` (Urban Unmarked)**:
   - Roads in urban areas **without lane markings**.
   - Example: Regular streets in cities where lane boundaries are not explicitly marked.

2. **`um` (Urban Marked)**:
   - Urban roads with **clearly marked lanes**.
   - Example: Streets with visible lane lines that define driving paths.

3. **`umm` (Urban Multiple Marked)**:
   - Urban roads with **multiple lanes and lane markings**.
   - Example: Complex intersections or multi-lane roads with clear markings for different lanes.

Each file in the dataset is named with a prefix (`uu`, `um`, `umm`) to indicate its type. Below are visual examples of each type.

In [None]:
def filter_images_by_prefix(image_dir, prefixes):
    """Group the ``*.png`` files in ``image_dir`` by KITTI scene-type prefix.

    Args:
        image_dir: directory to search.
        prefixes: iterable of prefixes such as ``"uu"``, ``"um"``, ``"umm"``. A file
            matches when its name starts with ``f"{prefix}_"``, so ``"um"`` does not
            pick up ``umm_*`` files.

    Returns:
        dict mapping each prefix to a sorted list of matching paths. Sorting makes
        the "first N examples" shown later deterministic, since glob order is arbitrary.
    """
    return {
        prefix: sorted(glob.glob(os.path.join(image_dir, f"{prefix}_*.png")))
        for prefix in prefixes
    }

def visualize_images_by_type(image_dict, num_samples=3):
    """For each type prefix, print its name and plot up to ``num_samples`` of its images.

    Args:
        image_dict: mapping from type prefix to a list of image paths.
        num_samples: maximum number of images shown per type.
    """
    for prefix, files in image_dict.items():
        print(f"Type: {prefix}")
        samples = [np.array(Image.open(path)) for path in files[:num_samples]]
        captions = [f"{prefix} example {n}" for n in range(1, len(samples) + 1)]
        show_images(samples, titles=captions)
        
# KITTI scene types, identified by file-name prefix; show three examples of each.
prefixes = ["uu", "um", "umm"]
images_by_type = filter_images_by_prefix(image_dir=train_img_dir, prefixes=prefixes)
visualize_images_by_type(image_dict=images_by_type, num_samples=3)

## Creating a dataset 

In [None]:
# Constants
IMG_SIZE = (256, 256)  # size passed to PIL's Image.resize (width, height)
BATCH_SIZE = 32
AUGMENTATION_COUNT = 10  # augmented copies generated per source image
TRAIN_RATIO = 0.8
VAL_RATIO = 0.1  # the remaining 1 - TRAIN_RATIO - VAL_RATIO goes to the test split

# Paths. os.path.join matches the other cells and avoids the '//' that string
# concatenation produced with a trailing-slash dataset_path.
MASK_PATH = os.path.join(dataset_path, 'training', 'gt_image_2')
IMG_PATH = os.path.join(dataset_path, 'training', 'image_2')


# Utility functions
def load_images(path, mask=False):
    """Load and resize images from a given directory.

    Files are read in sorted name order and resized to ``IMG_SIZE``.

    Args:
        path: directory to read.
        mask: if True, only files whose name contains ``"road"`` are loaded, and
            they are resized with nearest-neighbour sampling.

    Returns:
        numpy array stacking the resized images.
    """
    names = sorted(f for f in os.listdir(path) if not mask or "road" in f)
    result = []
    for name in tqdm(names, desc=f"Loading {'masks' if mask else 'images'}"):
        # `with` closes each file handle instead of leaving it to the garbage collector.
        with Image.open(os.path.join(path, name)) as img:
            if mask:
                # Masks encode classes as exact colours. PIL's default interpolating
                # filter blends road pixels at region borders into colours that no
                # longer equal the road label; nearest-neighbour keeps them exact.
                resized = img.resize(IMG_SIZE, resample=Image.NEAREST)
            else:
                resized = img.resize(IMG_SIZE)
            result.append(np.asarray(resized))
    return np.array(result)

def convert_masks_to_binary(masks, road_label=(255, 0, 255)):
    """Turn colour masks into float32 road masks with a trailing channel axis.

    Args:
        masks: iterable of colour masks whose last axis holds the colour channels.
        road_label: colour marking road pixels.

    Returns:
        float32 array with 1.0 on road pixels and 0.0 elsewhere; each mask gains a
        trailing axis of length 1.
    """
    converted = []
    for colour_mask in tqdm(masks, desc="Converting masks to binary"):
        road = np.all(colour_mask == road_label, axis=-1).astype(np.float32)
        converted.append(road[..., np.newaxis])
    return np.array(converted)

def normalize_images(images):
    """Scale 8-bit pixel values into floats in [0, 1] by dividing by 255."""
    max_pixel_value = 255.0
    return images / max_pixel_value

def augment_data(images, masks, pipeline, augment_count):
    """Generate ``augment_count`` randomly augmented copies of every (image, mask) pair.

    Args:
        images: sequence of images.
        masks: sequence of masks aligned with ``images``.
        pipeline: albumentations-style callable taking ``image=`` and ``mask=`` and
            returning a dict with ``'image'`` and ``'mask'`` entries.
        augment_count: number of augmented copies per pair.

    Returns:
        ``(augmented_images, augmented_masks)`` numpy arrays; the copies of each
        input pair are stored consecutively.
    """
    out_images = []
    out_masks = []
    pairs = zip(images, masks)
    for image, mask in tqdm(pairs, desc="Augmenting data", total=len(images)):
        image32 = image.astype(np.float32)
        mask32 = mask.astype(np.float32)
        for _ in range(augment_count):
            sample = pipeline(image=image32, mask=mask32)
            out_images.append(sample['image'])
            out_masks.append(sample['mask'])
    return np.array(out_images), np.array(out_masks)

def split_data(data, labels, train_ratio, val_ratio, seed=None):
    """Split data into training, validation, and test sets by random permutation.

    Args:
        data: array indexable by an integer index array (e.g. a numpy array).
        labels: array aligned with ``data`` (same length).
        train_ratio: fraction of samples assigned to training.
        val_ratio: fraction assigned to validation; the remainder goes to test.
        seed: optional int. When given, the permutation comes from a dedicated
            ``np.random.default_rng(seed)``, so the split is reproducible. When None,
            the global NumPy RNG is used as before.

    Returns:
        ``(X_train, y_train, X_val, y_val, X_test, y_test)``.

    Raises:
        ValueError: if ``data`` and ``labels`` differ in length, or the ratios are
            negative or sum to more than 1.
    """
    total_samples = len(data)
    if len(labels) != total_samples:
        raise ValueError(
            f"data and labels differ in length: {total_samples} vs {len(labels)}"
        )
    if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1 + 1e-9:
        raise ValueError(f"invalid ratios: train={train_ratio}, val={val_ratio}")

    train_size = int(total_samples * train_ratio)
    val_size = int(total_samples * val_ratio)
    if seed is None:
        indices = np.random.permutation(total_samples)
    else:
        indices = np.random.default_rng(seed).permutation(total_samples)

    train_indices = indices[:train_size]
    val_indices = indices[train_size:train_size + val_size]
    test_indices = indices[train_size + val_size:]

    return (
        data[train_indices], labels[train_indices],
        data[val_indices], labels[val_indices],
        data[test_indices], labels[test_indices]
    )

# Dataset class
class ImageDataset(Dataset):
    """Map-style dataset over pre-loaded, index-aligned images and masks.

    Items are returned as ``(image, mask)`` exactly as stored; no transforms are applied.
    """

    def __init__(self, images, masks):
        self.images, self.masks = images, masks

    def __getitem__(self, idx):
        image = self.images[idx]
        mask = self.masks[idx]
        return image, mask

    def __len__(self):
        return len(self.images)

# Load and preprocess data.
# Seed every RNG involved (NumPy split, albumentations, DataLoader shuffling) so the
# split and the augmentations are reproducible across Restart & Run All.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

images = load_images(IMG_PATH)
masks = load_images(MASK_PATH, mask=True)
masks = convert_masks_to_binary(masks)
# float32 halves memory versus float64. The training/evaluation loops call
# .float() on every batch anyway, so model inputs are unchanged.
images = normalize_images(images).astype(np.float32)
assert len(images) == len(masks), f"{len(images)} images but {len(masks)} road masks"

# Augmentation pipeline
augmentation_pipeline = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.Rotate(limit=10, p=0.5, border_mode=cv2.BORDER_CONSTANT),
    A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2, p=0.5),
    A.RGBShift(r_shift_limit=1, g_shift_limit=1, b_shift_limit=1, p=0.5),
])

# Split BEFORE augmenting. Augmented copies are near-duplicates of their source
# image, so augmenting first let copies of validation/test images land in the
# training set and made validation/test losses optimistic.
X_train, y_train, X_val, y_val, X_test, y_test = split_data(images, masks, TRAIN_RATIO, VAL_RATIO)

# Augment the training split only, keeping the originals alongside the copies.
augmented_images, augmented_masks = augment_data(X_train, y_train, augmentation_pipeline, AUGMENTATION_COUNT)
X_train = np.concatenate((X_train, augmented_images), axis=0)
y_train = np.concatenate((y_train, augmented_masks), axis=0)

# Create DataLoaders (NHWC -> NCHW, the channel-first layout the models consume)
train_dataset = ImageDataset(X_train.transpose((0, 3, 1, 2)), y_train.transpose((0, 3, 1, 2)))
val_dataset = ImageDataset(X_val.transpose((0, 3, 1, 2)), y_val.transpose((0, 3, 1, 2)))
test_dataset = ImageDataset(X_test.transpose((0, 3, 1, 2)), y_test.transpose((0, 3, 1, 2)))

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)


In [None]:
# Print dataset sizes
for split_name, split_ds in (("training", train_dataset), ("validation", val_dataset), ("test", test_dataset)):
    print(f"Number of {split_name} samples: {len(split_ds)}")

In [None]:
def show_images_from_loader(data_loader, num_images=5):
    """Plot the first ``num_images`` (image, mask) pairs yielded by ``data_loader``.

    Batches are expected to be channel-first tensors. Images are transposed to
    channel-last for display, and only channel 0 of each mask is drawn, in grayscale.
    """
    shown = 0
    for batch_images, batch_masks in data_loader:
        batch_images = batch_images.numpy().transpose(0, 2, 3, 1)  # NCHW -> NHWC for imshow
        batch_masks = batch_masks.numpy()

        for row in range(len(batch_images)):
            if shown >= num_images:
                return
            plt.figure(figsize=(8, 4))
            panels = ((batch_images[row], "Image", None), (batch_masks[row][0], "Mask", 'gray'))
            for position, (picture, caption, cmap) in enumerate(panels, start=1):
                plt.subplot(1, 2, position)
                plt.imshow(picture, cmap=cmap)
                plt.title(caption)
                plt.axis("off")
            plt.show()
            shown += 1

# Preview 20 (image, mask) pairs from the shuffled training loader.
show_images_from_loader(data_loader=train_loader, num_images=20)


# U-net model

In [None]:
def train_model(
    model, train_loader, val_loader, loss_fn, optimizer, num_epochs, device, save_path, patience=5
):
    """Train ``model`` with early stopping, checkpointing the best validation epoch.

    Each epoch does one optimisation pass over ``train_loader`` and one gradient-free
    pass over ``val_loader``. Whenever the mean validation loss improves,
    ``model.state_dict()`` is saved to ``save_path``. Training stops after ``patience``
    consecutive epochs without improvement, or after ``num_epochs`` epochs.

    On return the model holds the weights of the best (lowest validation loss)
    epoch, matching the checkpoint on disk, rather than those of the last epoch.

    Args:
        model: module to train; moved batches are passed to it directly.
        train_loader / val_loader: iterables of ``(input, mask)`` batches.
        loss_fn: callable ``loss_fn(outputs, masks)`` returning a scalar tensor.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: maximum number of epochs.
        device: device batches are moved to.
        save_path: file the best ``state_dict`` is written to.
        patience: epochs without improvement tolerated before stopping.

    Returns:
        ``(train_losses, validation_losses)``: per-epoch means of the batch losses.
    """
    best_val_loss = float("inf")
    checkpoint_saved = False
    # os.makedirs("") raises, so only create a directory when save_path names one.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    train_losses = []
    validation_losses = []

    epoch_tqdm = trange(num_epochs, desc="Epochs", position=0, leave=True)

    # Early stopping: consecutive epochs without a new best validation loss.
    no_improvement_count = 0

    for epoch in epoch_tqdm:
        # Training
        model.train()
        epoch_train_losses = []
        train_loader_tqdm = tqdm(
            train_loader, desc=f"Training Epoch {epoch+1}/{num_epochs}", position=1, leave=False
        )

        for train_input, train_mask in train_loader_tqdm:
            train_input = train_input.to(device).float()
            train_mask = train_mask.to(device).float()

            optimizer.zero_grad()
            outputs = model(train_input)
            loss = loss_fn(outputs, train_mask)
            loss.backward()
            optimizer.step()

            epoch_train_losses.append(loss.item())
            train_loader_tqdm.set_postfix({"Batch Loss": f"{loss.item():.4f}"})

        avg_train_loss = np.mean(epoch_train_losses)
        train_losses.append(avg_train_loss)

        # Validation: eval mode (BatchNorm uses running statistics), no gradients.
        model.eval()
        epoch_val_losses = []
        val_loader_tqdm = tqdm(
            val_loader, desc=f"Validation Epoch {epoch+1}/{num_epochs}", position=2, leave=False
        )

        with torch.no_grad():
            for val_input, val_mask in val_loader_tqdm:
                val_input = val_input.to(device).float()
                val_mask = val_mask.to(device).float()

                outputs = model(val_input)
                loss = loss_fn(outputs, val_mask)

                epoch_val_losses.append(loss.item())
                val_loader_tqdm.set_postfix({"Batch Loss": f"{loss.item():.4f}"})

        avg_val_loss = np.mean(epoch_val_losses)
        validation_losses.append(avg_val_loss)

        epoch_tqdm.set_postfix(
            {"Train Loss": avg_train_loss, "Validation Loss": avg_val_loss}
        )

        # Checkpoint on improvement
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), save_path)
            checkpoint_saved = True
            no_improvement_count = 0
        else:
            no_improvement_count += 1

        if no_improvement_count >= patience:
            print(f"Early stopping triggered after {epoch + 1} epochs.")
            break

    # The loop may run up to `patience` epochs past the best one; restore the best
    # weights so evaluation right after training matches the saved checkpoint.
    if checkpoint_saved:
        model.load_state_dict(torch.load(save_path, map_location=device))

    return train_losses, validation_losses


In [None]:
class DoubleConv(nn.Module):
    """Two stages of 3x3 conv -> BatchNorm -> ReLU; padding=1 keeps the spatial size.

    The convolutions are bias-free; the BatchNorm after each supplies a learned shift.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        for stage_in in (in_channels, out_channels):
            layers += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        # The attribute name `dconv` and the layer order define the state_dict keys
        # ("dconv.<index>.<param>") that saved checkpoints rely on.
        self.dconv = nn.Sequential(*layers)

    def forward(self, x):
        return self.dconv(x)

    
    
class UNET(nn.Module):
    """U-Net for binary segmentation, returning per-pixel sigmoid probabilities.

    The encoder applies a DoubleConv per entry of ``features`` followed by 2x2
    max-pooling. A bottleneck DoubleConv doubles the last width. The decoder
    mirrors the encoder: a stride-2 transposed conv, concatenation with the
    matching encoder output, then a DoubleConv. Upsampled maps are resized when
    their shape differs from the skip connection (odd input sizes).
    """

    def __init__(self, in_channels=3, out_channels=1, features=[64, 128, 256, 512]):
        super().__init__()

        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: channel widths follow `features`.
        channels = in_channels
        for width in features:
            self.downs.append(DoubleConv(channels, width))
            channels = width

        # Decoder: ups alternates [ConvTranspose2d, DoubleConv] per level, deepest first.
        for width in reversed(features):
            self.ups.append(nn.ConvTranspose2d(width * 2, width, kernel_size=2, stride=2))
            self.ups.append(DoubleConv(width * 2, width))

        self.left_over = DoubleConv(features[-1], features[-1] * 2)  # bottleneck
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skips = []
        for down in self.downs:
            x = down(x)
            skips.append(x)
            x = self.pool(x)

        x = self.left_over(x)

        upsamplers = self.ups[0::2]
        decoders = self.ups[1::2]
        for upsample, decode, skip in zip(upsamplers, decoders, reversed(skips)):
            x = upsample(x)
            if x.shape != skip.shape:
                x = TF.resize(x, size=skip.shape[2:])
            x = decode(torch.cat((skip, x), dim=1))

        return torch.sigmoid(self.final_conv(x))
        

In [None]:
# Hyperparameters, model and optimizer for the U-Net.
config = Config('UNet')
unet_model = UNET().to(device)
optimizer = optim.Adam(params=unet_model.parameters(), lr=config.lr)

In [None]:
if not config.need_training:
    # Reuse the saved checkpoint: tensors are loaded onto the CPU, then copied into the model's parameters.
    unet_model.load_state_dict(torch.load(config.save_path, map_location=torch.device('cpu')))
    unet_model.eval()
else:
    # Train from scratch, then plot the per-epoch loss curves.
    train_losses, val_losses = train_model(
        model=unet_model,
        train_loader=train_loader,
        val_loader=val_loader,
        loss_fn=config.loss_fn,
        optimizer=optimizer,
        num_epochs=config.epochs,
        device=device,
        save_path=config.save_path,
    )
    plot_losses({"Train Loss": train_losses, "Validation Loss": val_losses})

In [None]:
def evaluate_model(model, data_loader, loss_fn, device):
    """Return the mean of ``loss_fn`` over ``data_loader``, computed without gradients.

    Each batch loss is weighted by its batch size, so a smaller final batch does not
    count as much as a full one.

    Args:
        model: module to evaluate; it is switched to eval mode.
        data_loader: iterable of ``(images, masks)`` batches.
        loss_fn: callable ``loss_fn(predictions, masks)`` returning a scalar tensor.
        device: device batches are moved to.

    Returns:
        float: the sample-weighted mean loss (also printed).

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    model.eval()
    weighted_loss_sum = 0.0
    sample_count = 0
    with torch.no_grad():
        for images, masks in tqdm(data_loader, desc="Evaluating Model"):
            images = images.to(device).float()
            masks = masks.to(device).float()

            predictions = model(images)

            batch_size = images.shape[0]
            weighted_loss_sum += loss_fn(predictions, masks).item() * batch_size
            sample_count += batch_size

    if sample_count == 0:
        raise ValueError("data_loader yielded no samples; cannot compute an average loss")

    avg_loss = weighted_loss_sum / sample_count
    print(f"Average IoU Loss on Test Set: {avg_loss:.4f}")
    return avg_loss


def visualize_inference(model, test_loader, device, num_samples=3):
    """Plot image, ground-truth mask and thresholded prediction for the first samples.

    Predictions are binarised at 0.5; the models in this notebook output sigmoid
    probabilities. Display stops after ``num_samples`` samples.
    """
    model.eval()
    shown = 0

    for batch_images, batch_masks in test_loader:
        batch_images = batch_images.to(device).float()
        batch_masks = batch_masks.to(device).float()

        with torch.no_grad():
            binary_preds = (model(batch_images) > 0.5).float()

        for row in range(len(batch_images)):
            if shown >= num_samples:
                return
            panels = [
                batch_images[row].cpu().numpy().transpose(1, 2, 0),  # CHW -> HWC
                batch_masks[row].cpu().numpy().squeeze(),
                binary_preds[row].cpu().numpy().squeeze(),
            ]
            show_images(panels, titles=["Image", "Original Mask", "Predicted Mask"])
            shown += 1

In [None]:
avg_test_loss_unet = evaluate_model(
    model=unet_model, data_loader=test_loader, loss_fn=config.loss_fn, device=device
)


In [None]:
visualize_inference(model=unet_model, test_loader=test_loader, device=device, num_samples=3)

# FCN model

In [None]:
class FCN(nn.Module):
    """FCN-style segmenter: frozen ImageNet VGG16 encoder plus a transposed-conv decoder.

    Outputs of VGG16 ``features`` layers 16, 23 and 30 (the 3rd, 4th and 5th
    max-pool outputs in torchvision's layout) feed the decoder. The two shallower
    ones are projected by 1x1 convs and added as skip connections. Five stride-2
    transposed convs restore the input resolution, and a 1x1 conv with a sigmoid
    produces per-pixel probabilities.

    ``height``/``width`` only size ``max_pooling``, which ``forward`` does not use.
    """

    def __init__(self, num_of_classes=1, height=512, width=512):
        super().__init__()
        # Encoder: pretrained VGG16 convolutional trunk with all weights frozen.
        self.vgg16_model = models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features
        self.vgg16_model.requires_grad_(False)

        self.num_of_classes = num_of_classes
        self.height = height
        self.width = width

        # 1x1 projections of the 512- and 256-channel skip features.
        self.skip_layer_4 = nn.Conv2d(512, 256, kernel_size=(1, 1), stride=1, padding=0)
        self.skip_layer_3 = nn.Conv2d(256, 128, kernel_size=(1, 1), stride=1, padding=0)

        # Each stride-2 transposed conv doubles the spatial size.
        self.upsampling_1 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2, padding=0)
        self.upsampling_2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2, padding=0)
        self.upsampling_3 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2, padding=0)
        self.upsampling_4 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2, padding=0)
        self.upsampling_5 = nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2, padding=0)

        # NOTE(review): not referenced in forward().
        self.max_pooling = nn.AdaptiveMaxPool3d(output_size=(1, self.height, self.width))

        self.batch_norm_1 = nn.BatchNorm2d(256)
        self.batch_norm_2 = nn.BatchNorm2d(128)
        self.batch_norm_3 = nn.BatchNorm2d(64)
        self.batch_norm_4 = nn.BatchNorm2d(32)
        self.batch_norm_5 = nn.BatchNorm2d(16)

        self.final_conv = nn.Conv2d(16, num_of_classes, kernel_size=1)

    def forward(self, image):
        # Encoder: run the whole VGG trunk, capturing layers 16, 23 and 30.
        taps = {}
        x = image
        for layer_idx, layer in enumerate(self.vgg16_model):
            x = layer(x)
            if layer_idx in (16, 23, 30):
                taps[layer_idx] = x
        features_3, features_4, features_7 = taps[16], taps[23], taps[30]

        # Decoder: the first two stages add the projected skip connections.
        x = self.batch_norm_1(F.relu(self.upsampling_1(features_7)) + self.skip_layer_4(features_4))
        x = self.batch_norm_2(F.relu(self.upsampling_2(x)) + self.skip_layer_3(features_3))
        for upsample, norm in (
            (self.upsampling_3, self.batch_norm_3),
            (self.upsampling_4, self.batch_norm_4),
            (self.upsampling_5, self.batch_norm_5),
        ):
            x = norm(F.relu(upsample(x)))

        return torch.sigmoid(self.final_conv(x))

In [None]:
# Hyperparameters, model and optimizer for the FCN. The frozen VGG weights have
# requires_grad=False, so they never receive gradients and Adam leaves them untouched.
config = Config('FCN')
fcn_model = FCN().to(device)
optimizer = optim.Adam(params=fcn_model.parameters(), lr=config.lr)

In [None]:
if not config.need_training:
    # Reuse the saved checkpoint: tensors are loaded onto the CPU, then copied into the model's parameters.
    fcn_model.load_state_dict(torch.load(config.save_path, map_location=torch.device('cpu')))
    fcn_model.eval()
else:
    # Train from scratch, then plot the per-epoch loss curves.
    train_losses, val_losses = train_model(
        model=fcn_model,
        train_loader=train_loader,
        val_loader=val_loader,
        loss_fn=config.loss_fn,
        optimizer=optimizer,
        num_epochs=config.epochs,
        device=device,
        save_path=config.save_path,
    )
    plot_losses({"Train Loss": train_losses, "Validation Loss": val_losses})

In [None]:
avg_test_loss_fcn = evaluate_model(
    model=fcn_model, data_loader=test_loader, loss_fn=config.loss_fn, device=device
)

In [None]:
visualize_inference(model=fcn_model, test_loader=test_loader, device=device, num_samples=3)


# State-Of-The-Art model

In [None]:
class SegFormer(nn.Module):
    """Binary-segmentation wrapper around Hugging Face's SegFormer-B0.

    Loads ``nvidia/segformer-b0-finetuned-ade-512-512`` with ``num_labels=num_classes``.
    ``ignore_mismatched_sizes=True`` lets the classifier be re-initialised at that
    size. The model's logits are bilinearly resized and passed through a sigmoid.

    Args:
        num_classes: number of output channels.
        image_size: ``(H, W)`` of the returned probability map. ``None`` resizes to
            the spatial size of each input batch instead.
    """

    def __init__(self, num_classes=1, image_size=(256, 256)):
        super(SegFormer, self).__init__()
        # Pretrained encoder + lightweight all-MLP decoder head.
        # (The previous `config.hidden_size = 256` assignment after loading had no
        # effect on the already-built model and was dropped.)
        self.segformer = SegformerForSemanticSegmentation.from_pretrained(
            "nvidia/segformer-b0-finetuned-ade-512-512",
            num_labels=num_classes,
            ignore_mismatched_sizes=True
        )
        self.image_size = image_size
        self.num_classes = num_classes

    def forward(self, x):
        logits = self.segformer(pixel_values=x).logits
        # The logits are produced at a lower resolution than the input, so resize
        # them to the requested output size.
        target_size = tuple(x.shape[-2:]) if self.image_size is None else self.image_size
        logits = F.interpolate(logits, size=target_size, mode="bilinear", align_corners=False)
        return torch.sigmoid(logits)


In [None]:
# Hyperparameters, model and optimizer for SegFormer.
config = Config('SegFormer')
seg_former_model = SegFormer().to(device)
optimizer = optim.Adam(params=seg_former_model.parameters(), lr=config.lr)

In [None]:
if not config.need_training:
    # Reuse the saved checkpoint: tensors are loaded onto the CPU, then copied into the model's parameters.
    seg_former_model.load_state_dict(torch.load(config.save_path, map_location=torch.device('cpu')))
    seg_former_model.eval()
else:
    # Fine-tune, then plot the per-epoch loss curves.
    train_losses, val_losses = train_model(
        model=seg_former_model,
        train_loader=train_loader,
        val_loader=val_loader,
        loss_fn=config.loss_fn,
        optimizer=optimizer,
        num_epochs=config.epochs,
        device=device,
        save_path=config.save_path,
    )
    plot_losses({"Train Loss": train_losses, "Validation Loss": val_losses})

In [None]:
# Test-set IoU loss for SegFormer.
avg_test_loss_segformer = evaluate_model(
    model=seg_former_model, data_loader=test_loader, loss_fn=config.loss_fn, device=device
)

In [None]:
# Qualitative check: three SegFormer predictions next to their ground truth.
visualize_inference(model=seg_former_model, test_loader=test_loader, device=device, num_samples=3)

# Resources
- Lab 2 from the F24 Computer Vision course
- https://www.kaggle.com/code/hossamemamo/kitti-road-segmentation-pytorch-unet-from-scratch
- https://www.youtube.com/watch?v=cPOtULagNnI
- https://www.kaggle.com/datasets/sakshaymahna/kittiroadsegmentation
- https://www.kaggle.com/code/sakshaymahna/fully-convolutional-network/input
- https://www.kaggle.com/code/satyaprakashshukl/road-segmentation-using-unet-model
- SegFormer paper (Xie et al., 2021): https://arxiv.org/abs/2105.15203