# unet training

Libraries

In [None]:
import os
import random
import numpy as np
import matplotlib.pyplot as plt # metrics visualization
from PIL import Image # image loading and resizing
import cv2
from tqdm import tqdm

import torch # main framework
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from torchvision.utils import make_grid

from google.colab import drive # dataset is read from Google Drive

# Mount Google Drive so the dataset paths below resolve under /content/drive
drive.mount('/content/drive')

# Set random seeds for reproducibility (Python, NumPy and PyTorch CPU/CUDA RNGs)
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# Request deterministic cuDNN behaviour and switch off its benchmark mode
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Paths to Cityscapes dataset in Google Drive
DATASET_PATH = '/content/drive/MyDrive/cityscape'
IMG_PATH = os.path.join(DATASET_PATH, 'leftImg8bit') # input images
GT_PATH = os.path.join(DATASET_PATH, 'gtFine') # annotation masks

# Define lower resolution for images (to save memory)
IMG_HEIGHT = 256
IMG_WIDTH = 512

# Number of Cityscapes classes used for training
NUM_CLASSES = 19


Mounted at /content/drive


colour map

In [None]:
# Color map for visualization.
# NOTE: this list is indexed by the RAW Cityscapes label ID (0-33), not by the
# training ID. Re-index it before coloring arrays that hold training IDs.
# Labels that are not drawn are black.
cityscapes_colors = [
    (0, 0, 0),         # 0: unlabeled
    (0, 0, 0),         # 1: ego vehicle
    (0, 0, 0),         # 2: rectification border
    (0, 0, 0),         # 3: out of roi
    (0, 0, 0),         # 4: static
    (0, 0, 0),         # 5: dynamic
    (0, 0, 0),         # 6: ground
    (0, 0, 70),        # 7: road - dark blue
    (255, 0, 255),     # 8: sidewalk - magenta
    (0, 0, 0),         # 9: parking
    (0, 0, 0),         # 10: rail track
    (255, 165, 0),     # 11: building - orange
    (190, 153, 153),   # 12: wall - light brown
    (170, 120, 220),   # 13: fence - light purple
    (0, 0, 0),         # 14: guard rail
    (0, 0, 0),         # 15: bridge
    (0, 0, 0),         # 16: tunnel
    (153, 153, 153),   # 17: pole - gray
    (0, 0, 0),         # 18: polegroup
    (250, 170, 30),    # 19: traffic light - amber
    (220, 220, 0),     # 20: traffic sign - yellow
    (35, 142, 35),     # 21: vegetation - forest green
    (152, 251, 152),   # 22: terrain - light green
    (70, 130, 180),    # 23: sky - steel blue
    (255, 0, 0),       # 24: person - bright red
    (255, 127, 0),     # 25: rider - dark orange
    (0, 0, 255),       # 26: car - bright blue
    (0, 150, 255),     # 27: truck - light blue
    (0, 80, 150),      # 28: bus - blue-gray
    (0, 0, 110),       # 29: caravan
    (0, 0, 110),       # 30: trailer
    (0, 80, 100),      # 31: train - dark blue-gray
    # Was (0, 80, 100), identical to train, which made the two classes
    # indistinguishable; now actually teal as the label says.
    (0, 128, 128),     # 32: motorcycle - teal
    (119, 11, 32),     # 33: bicycle - maroon
]
# Mapping from raw Cityscapes label IDs (0-33) to training IDs.
# Every label starts as 255 (the ignore label); the labels listed below are
# then numbered 0..18 in the order they appear.
id_to_trainid = {label_id: 255 for label_id in range(34)}
id_to_trainid.update(
    (label_id, train_id)
    for train_id, label_id in enumerate((
        7,   # road
        8,   # sidewalk
        11,  # building
        12,  # wall
        13,  # fence
        17,  # pole
        19,  # traffic light
        20,  # traffic sign
        21,  # vegetation
        22,  # terrain
        23,  # sky
        24,  # person
        25,  # rider
        26,  # car
        27,  # truck
        28,  # bus
        31,  # train
        32,  # motorcycle
        33,  # bicycle
    ))
)

# Human-readable names of the 19 training classes; the list position is the
# training ID.
class_names = [
    'road',           # 0
    'sidewalk',       # 1
    'building',       # 2
    'wall',           # 3
    'fence',          # 4
    'pole',           # 5
    'traffic light',  # 6
    'traffic sign',   # 7
    'vegetation',     # 8
    'terrain',        # 9
    'sky',            # 10
    'person',         # 11
    'rider',          # 12
    'car',            # 13
    'truck',          # 14
    'bus',            # 15
    'train',          # 16
    'motorcycle',     # 17
    'bicycle',        # 18
]


dataset path

In [None]:
class CityscapesSubset(Dataset):
    """Cityscapes image/mask pairs, optionally reduced to a random subset.

    The directory layout read by the constructor is
    ``<root>/leftImg8bit/<split>/<city>/*_leftImg8bit.png`` for images and
    ``<root>/gtFine/<split>/<city>/*_gtFine_labelIds.png`` for masks. Images
    without a matching mask file are skipped.

    Args:
        root: Dataset root directory containing ``leftImg8bit`` and ``gtFine``.
        split: Sub-directory to read (``'train'``, ``'val'``, ...).
        transforms: Callable ``(PIL image, PIL mask) -> (image, mask)``. It is
            only applied when ``split == 'train'``; every other case converts
            the pair to tensors without augmentation.
        subset_fraction: Fraction of the pairs to keep. Values below 1.0 draw
            a random subset spread over the cities roughly in proportion to
            their size (at least one sample per city before trimming).

    Each item is ``(image, mask)`` where the mask holds training IDs and 255
    marks pixels to ignore.
    """

    def __init__(self, root, split='train', transforms=None, subset_fraction=0.2):
        self.root = root
        self.split = split
        self.transforms = transforms
        self.subset_fraction = subset_fraction

        # Derive the folders from `root` so the argument is actually honoured.
        img_root = os.path.join(root, 'leftImg8bit', split)
        mask_root = os.path.join(root, 'gtFine', split)

        # Sorted so the seeded sampling below is reproducible (os.listdir
        # order is arbitrary); stray files in the split folder are skipped.
        self.cities = sorted(
            entry for entry in os.listdir(img_root)
            if os.path.isdir(os.path.join(img_root, entry))
        )

        self.images = []
        self.masks = []
        image_cities = []  # city of each entry in self.images
        for city in self.cities:
            img_dir = os.path.join(img_root, city)
            mask_dir = os.path.join(mask_root, city)
            for file_name in sorted(os.listdir(img_dir)):
                if not file_name.endswith('_leftImg8bit.png'):
                    continue
                image_id = file_name.replace('_leftImg8bit.png', '')
                mask_path = os.path.join(mask_dir, f"{image_id}_gtFine_labelIds.png")
                if os.path.exists(mask_path):
                    self.images.append(os.path.join(img_dir, file_name))
                    self.masks.append(mask_path)
                    image_cities.append(city)

        # Create a subset of the dataset if needed
        if subset_fraction < 1.0:
            num_samples = int(len(self.images) * subset_fraction)

            city_samples = {}
            for index, city in enumerate(image_cities):
                city_samples.setdefault(city, []).append(index)

            indices = []
            for samples in city_samples.values():
                city_ratio = len(samples) / len(self.images)
                num_city_samples = max(1, int(num_samples * city_ratio))
                indices.extend(random.sample(samples, min(num_city_samples, len(samples))))

            if len(indices) > num_samples:
                # The per-city minimum of one can overshoot the target size.
                indices = random.sample(indices, num_samples)
            elif len(indices) < num_samples:
                # Integer truncation can undershoot; top up from unused images.
                remaining = num_samples - len(indices)
                unused_indices = sorted(set(range(len(self.images))) - set(indices))
                if unused_indices:
                    indices.extend(random.sample(unused_indices, min(remaining, len(unused_indices))))

            self.images = [self.images[i] for i in indices]
            self.masks = [self.masks[i] for i in indices]

        print(f"Created {split} set with {len(self.images)} images from {len(self.cities)} cities")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = Image.open(self.images[idx]).convert('RGB')
        mask = Image.open(self.masks[idx])

        # Resize images; nearest-neighbour for the mask so no new label
        # values are interpolated into it.
        image = image.resize((IMG_WIDTH, IMG_HEIGHT), Image.BILINEAR)
        mask = mask.resize((IMG_WIDTH, IMG_HEIGHT), Image.NEAREST)

        # Convert raw label IDs to training IDs; anything unmapped stays 255.
        mask_np = np.array(mask)
        mask_out = np.full_like(mask_np, 255)
        for label_id, train_id in id_to_trainid.items():
            mask_out[mask_np == label_id] = train_id
        mask = Image.fromarray(mask_out.astype(np.uint8))

        # Augment training samples only; everything else is a plain conversion.
        if self.transforms is not None and self.split == 'train':
            image, mask = self.transforms(image, mask)
        else:
            image = TF.to_tensor(image)
            mask = torch.from_numpy(np.array(mask)).long()

        return image, mask

class CityscapesTransforms:
    """Joint random augmentation for an (image, mask) training pair.

    Args:
        p_flip: Probability of a horizontal flip (applied to both).
        p_rotate: Probability of a rotation by a random angle in [-10, 10]
            degrees (applied to both).
        p_color: Probability of color jitter (image only).

    Calling the object with a PIL image and a PIL mask returns
    ``(image_tensor, mask_tensor)``; the mask tensor has dtype long.
    """

    # Label written into mask areas uncovered by rotation, so the loss ignores
    # them instead of treating them as class 0.
    IGNORE_LABEL = 255

    def __init__(self, p_flip=0.5, p_rotate=0.3, p_color=0.5):
        self.p_flip = p_flip
        self.p_rotate = p_rotate
        self.p_color = p_color
        self.color_jitter = transforms.ColorJitter(
            brightness=0.2,
            contrast=0.2,
            saturation=0.2,
            hue=0.1
        )

    def __call__(self, image, mask):
        image = TF.to_tensor(image)
        if random.random() < self.p_flip:
            image = TF.hflip(image)
            mask = TF.hflip(mask)
        if random.random() < self.p_rotate:
            angle = random.uniform(-10, 10)
            image = TF.rotate(image, angle, interpolation=TF.InterpolationMode.BILINEAR)
            # Without an explicit fill the exposed corners become 0, which is
            # a real training class; fill with the ignore label instead.
            mask = TF.rotate(
                mask,
                angle,
                interpolation=TF.InterpolationMode.NEAREST,
                fill=self.IGNORE_LABEL,
            )
        if random.random() < self.p_color:
            image = self.color_jitter(image)
        mask = torch.from_numpy(np.array(mask)).long()
        return image, mask


u-net architecture

In [None]:
# U-Net model components
class DoubleConv(nn.Module):
    """Two stacked 3x3 convolution -> batch norm -> ReLU stages.

    Both convolutions use padding 1, so height and width are unchanged; only
    the channel count goes from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First stage changes the channel count, second keeps it.
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        features = self.double_conv(x)
        return features

class Down(nn.Module):
    """Encoder stage: 2x2 max pooling followed by a DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        pooled_features = self.maxpool_conv(x)
        return pooled_features

class Up(nn.Module):
    """Decoder stage: upsample, concatenate with the skip tensor, DoubleConv.

    Args:
        in_channels: Channel count of the concatenated tensor fed to the
            DoubleConv (upsampled features plus skip features).
        out_channels: Channel count produced by the stage.
        bilinear: If True, upsample with parameter-free bilinear
            interpolation; otherwise use a learned transposed convolution
            that also halves the channel count.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super(Up, self).__init__()
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            # The incoming tensor carries `in_channels` channels in this mode
            # (the encoder does not halve them), so the transposed conv must
            # accept `in_channels` and reduce to `in_channels // 2`; the skip
            # connection supplies the other half for the concatenation.
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to ``x2``'s spatial size, fuse and convolve."""
        x1 = self.up(x1)
        # Odd input sizes lose a pixel when pooled, so the upsampled tensor
        # can be smaller than the skip tensor; pad it symmetrically.
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

class OutConv(nn.Module):
    """Output head: a single 1x1 convolution that remaps the channel count."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        projected = self.conv(x)
        return projected

class UNet(nn.Module):
    """U-Net built from DoubleConv / Down / Up / OutConv stages.

    Args:
        n_channels: Number of input image channels.
        n_classes: Number of output channels (one logit map per class).
        bilinear: Passed to every Up stage; when True the bottleneck and the
            first three decoder outputs use half the channel count.
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        if bilinear:
            factor = 2
        else:
            factor = 1

        # Encoder path
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)

        # Decoder path
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        # Encoder: keep every intermediate output for the skip connections.
        skip1 = self.inc(x)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        skip4 = self.down3(skip3)
        decoded = self.down4(skip4)

        # Decoder: fuse with the skips from deepest to shallowest.
        decoder_stages = (self.up1, self.up2, self.up3, self.up4)
        skips = (skip4, skip3, skip2, skip1)
        for stage, skip in zip(decoder_stages, skips):
            decoded = stage(decoded, skip)

        return self.outc(decoded)


metrics

In [None]:
# Calculate Intersection over Union (IoU)
def calculate_iou(pred, target, n_classes):
    """Return the per-class Intersection over Union of two label tensors.

    Args:
        pred: Integer tensor of predicted class indices.
        target: Integer tensor of ground-truth class indices with the same
            number of elements; entries equal to 255 are ignored.
        n_classes: Classes ``0 .. n_classes - 1`` are evaluated.

    Returns:
        A list of ``n_classes`` floats. A class that occurs in neither the
        prediction nor the target (after dropping ignored pixels) gets
        ``float('nan')``.
    """
    # reshape (not view) so non-contiguous inputs such as transposed or
    # sliced tensors are accepted instead of raising a RuntimeError.
    pred = pred.reshape(-1)
    target = target.reshape(-1)

    valid = target != 255
    pred = pred[valid]
    target = target[valid]

    ious = []
    for cls in range(n_classes):
        pred_inds = pred == cls
        target_inds = target == cls
        intersection = (pred_inds & target_inds).sum().item()
        union = (pred_inds | target_inds).sum().item()
        if union == 0:
            ious.append(float('nan'))
        else:
            ious.append(intersection / union)
    return ious

# Calculate accuracy metrics
def calculate_accuracy_metrics(pred, target, n_classes):
    """Compute pixel accuracy, mean class accuracy and per-class accuracies.

    Pixels whose target is 255 are excluded from every figure.

    Returns:
        ``(pixel_acc, mean_acc, class_accuracies)`` where ``class_accuracies``
        has one entry per class (``nan`` for classes absent from the target)
        and ``mean_acc`` averages the non-nan entries (0 if there are none).
    """
    valid = (target != 255)
    kept_pred = pred[valid]
    kept_target = target[valid]

    n_correct = (kept_pred == kept_target).sum().item()
    n_valid = valid.sum().item()
    if n_valid > 0:
        pixel_acc = n_correct / n_valid
    else:
        pixel_acc = 0

    class_accuracies = []
    for cls in range(n_classes):
        is_cls = kept_target == cls
        if is_cls.sum().item() > 0:
            hits = ((kept_pred == cls) & is_cls).sum().item()
            support = is_cls.sum().item()
            class_accuracies.append(hits / support)
        else:
            # Class not present in the target: no accuracy can be defined.
            class_accuracies.append(float('nan'))

    present = [acc for acc in class_accuracies if not np.isnan(acc)]
    if present:
        mean_acc = np.mean(present)
    else:
        mean_acc = 0
    return pixel_acc, mean_acc, class_accuracies

# Visualize predictions
def visualize_prediction(image, pred, target, class_colors):
    """Plot the input image, the colored prediction and the colored target.

    Args:
        image: Image tensor; its first axis is moved last for display.
        pred: 2-D tensor of predicted class indices.
        target: 2-D tensor of ground-truth class indices; 255 is drawn black.
        class_colors: Sequence of RGB tuples where entry ``i`` colors class
            index ``i``. Only the first ``NUM_CLASSES`` entries are used, so
            the sequence must be indexed the same way as ``pred``/``target``.

    Returns:
        The matplotlib figure holding the three panels.
    """
    image = image.cpu().permute(1, 2, 0).numpy()
    pred = pred.cpu().numpy()
    target = target.cpu().numpy()

    # Min-max normalize for display. A constant image has zero range, which
    # would otherwise divide by zero and fill the panel with NaNs.
    lowest = image.min()
    value_range = image.max() - lowest
    if value_range > 0:
        image = (image - lowest) / value_range
    else:
        image = np.zeros_like(image)

    pred_color = np.zeros((pred.shape[0], pred.shape[1], 3), dtype=np.uint8)
    target_color = np.zeros((target.shape[0], target.shape[1], 3), dtype=np.uint8)
    for i, color in enumerate(class_colors[:NUM_CLASSES]):
        pred_color[pred == i] = color
        target_color[target == i] = color
    target_color[target == 255] = (0, 0, 0)

    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    panels = (
        (image, 'Input Image'),
        (pred_color, 'Prediction'),
        (target_color, 'Ground Truth'),
    )
    for ax, (picture, title) in zip(axs, panels):
        ax.imshow(picture)
        ax.set_title(title)
        ax.axis('off')
    plt.tight_layout()
    return fig

# Training function
def train_model(model, train_loader, val_loader, criterion, optimizer, device, num_epochs, scheduler=None):
    """Train ``model`` and validate it after every epoch.

    After each epoch the weights are written to ``best_unet_model_iou.pth``
    when the mean IoU improves and to ``best_unet_model_accuracy.pth`` when
    the pixel accuracy improves. When training ends, the best-accuracy
    weights saved during this call are loaded back into ``model``.

    Args:
        model: Network to train; called as ``model(images)``.
        train_loader: Loader yielding ``(images, masks)`` training batches.
        val_loader: Loader yielding ``(images, masks)`` validation batches.
        criterion: Loss called as ``criterion(outputs, masks)``.
        optimizer: Optimizer stepping the model parameters.
        device: Device the batches (and reloaded weights) are moved to.
        num_epochs: Number of epochs to run.
        scheduler: Optional LR scheduler. A ``ReduceLROnPlateau`` is stepped
            with the epoch's mean IoU, any other scheduler without arguments.

    Returns:
        ``(model, train_losses, val_losses, miou_scores, pixel_acc_scores,
        class_acc_scores)`` with one list entry per epoch.

    NOTE(review): validation metrics are averages of per-batch values, not
    dataset-level figures accumulated over all pixels.
    """
    best_miou = 0.0
    best_accuracy = 0.0
    # True once THIS call has written the accuracy checkpoint, so a stale file
    # from an earlier run is never loaded and a missing file never crashes.
    accuracy_checkpoint_saved = False
    train_losses = []
    val_losses = []
    miou_scores = []
    pixel_acc_scores = []
    class_acc_scores = []

    # Training IDs (positions in class_names): road, building, person, car.
    important_classes = [0, 2, 11, 13]

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)

        # ---- training pass ----
        model.train()
        running_loss = 0.0
        for images, masks in tqdm(train_loader, desc='Training'):
            images = images.to(device)
            masks = masks.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch loss is a per-sample mean.
            running_loss += loss.item() * images.size(0)
        epoch_loss = running_loss / len(train_loader.dataset)
        train_losses.append(epoch_loss)

        # ---- validation pass ----
        model.eval()
        val_loss = 0.0
        iou_scores = []
        pixel_accuracies = []
        mean_accuracies = []
        class_accuracies = [[] for _ in range(NUM_CLASSES)]
        with torch.no_grad():
            for images, masks in tqdm(val_loader, desc='Validation'):
                images = images.to(device)
                masks = masks.to(device)
                outputs = model(images)
                loss = criterion(outputs, masks)
                val_loss += loss.item() * images.size(0)
                preds = torch.argmax(outputs, dim=1)
                batch_ious = calculate_iou(preds, masks, NUM_CLASSES)
                batch_pixel_acc, batch_mean_acc, batch_class_accs = calculate_accuracy_metrics(preds, masks, NUM_CLASSES)
                valid_ious = [iou for iou in batch_ious if not np.isnan(iou)]
                if valid_ious:
                    iou_scores.append(np.mean(valid_ious))
                pixel_accuracies.append(batch_pixel_acc)
                mean_accuracies.append(batch_mean_acc)
                for cls in range(NUM_CLASSES):
                    if not np.isnan(batch_class_accs[cls]):
                        class_accuracies[cls].append(batch_class_accs[cls])
        epoch_val_loss = val_loss / len(val_loader.dataset)
        val_losses.append(epoch_val_loss)

        # NaN (without a NumPy warning) when no batch produced a value.
        mean_iou = np.mean(iou_scores) if iou_scores else float('nan')
        miou_scores.append(mean_iou)
        mean_pixel_acc = np.mean(pixel_accuracies) if pixel_accuracies else float('nan')
        mean_class_acc = np.mean(mean_accuracies) if mean_accuracies else float('nan')
        pixel_acc_scores.append(mean_pixel_acc)
        class_acc_scores.append(mean_class_acc)

        # One entry per class (NaN when the class never occurred) so that the
        # list index always equals the class ID.
        avg_class_accuracies = [
            np.mean(accs) if accs else float('nan')
            for accs in class_accuracies
        ]

        print(f'Training Loss: {epoch_loss:.4f}')
        print(f'Validation Loss: {epoch_val_loss:.4f}')
        print(f'Mean IoU: {mean_iou:.4f}')
        print(f'Overall Pixel Accuracy: {mean_pixel_acc:.4f}')
        print(f'Mean Class Accuracy: {mean_class_acc:.4f}')
        print("\nPer-class accuracies for important classes:")
        for cls in important_classes:
            if cls >= len(class_names) or cls >= len(avg_class_accuracies):
                continue
            # Skip only classes without data; a genuine 0.0 is still reported.
            if not np.isnan(avg_class_accuracies[cls]):
                print(f"{class_names[cls]}: {avg_class_accuracies[cls]:.4f}")

        if scheduler is not None:
            if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(mean_iou)
            else:
                scheduler.step()

        if mean_iou > best_miou:
            best_miou = mean_iou
            torch.save(model.state_dict(), 'best_unet_model_iou.pth')
            print('Best IoU model saved!')
        if mean_pixel_acc > best_accuracy:
            best_accuracy = mean_pixel_acc
            torch.save(model.state_dict(), 'best_unet_model_accuracy.pth')
            accuracy_checkpoint_saved = True
            print('Best Accuracy model saved!')
        print()

    if accuracy_checkpoint_saved:
        print("Loading best accuracy model for evaluation...")
        # map_location keeps the reload working when the device differs from
        # the one the checkpoint tensors were saved from.
        model.load_state_dict(torch.load('best_unet_model_accuracy.pth', map_location=device))
    else:
        print("No accuracy checkpoint was saved in this run; keeping the current weights.")
    return model, train_losses, val_losses, miou_scores, pixel_acc_scores, class_acc_scores


Setup and data transformation

In [None]:
def _trainid_palette(label_colors):
    """Re-index a palette keyed by raw label ID so it is keyed by train ID."""
    palette = [(0, 0, 0)] * NUM_CLASSES
    for label_id, train_id in id_to_trainid.items():
        # Ignored labels map to 255 and have no slot in the palette.
        if train_id < NUM_CLASSES and label_id < len(label_colors):
            palette[train_id] = label_colors[label_id]
    return palette


def main():
    """Run the whole pipeline: data, training, plots, evaluation, exports.

    Files written to the working directory: ``training_progress.png``,
    ``training_metrics.csv``, ``prediction_<i>.png`` for up to five
    validation samples and, when the first validation city has more than ten
    images, ``prediction_video.mp4``. The model checkpoints are written by
    ``train_model``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')

    # Create datasets with transformations
    train_transforms = CityscapesTransforms(p_flip=0.5, p_rotate=0.3, p_color=0.8)
    train_dataset = CityscapesSubset(
        root=DATASET_PATH,
        split='train',
        transforms=train_transforms,
        subset_fraction=0.2  # Use 20% of training data
    )
    val_dataset = CityscapesSubset(
        root=DATASET_PATH,
        split='val',
        transforms=None,  # No augmentation for validation
        subset_fraction=0.3  # Use 30% of validation data
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=8,
        shuffle=True,
        num_workers=2,
        pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=8,
        shuffle=False,
        num_workers=2,
        pin_memory=True
    )

    # Create model
    model = UNet(n_channels=3, n_classes=NUM_CLASSES, bilinear=True)
    model = model.to(device)

    # Define loss and optimizer (with class weights)
    class_weights = torch.ones(NUM_CLASSES).to(device)
    class_weights[6] = 2.0   # traffic light
    class_weights[11] = 2.0  # person
    class_weights[12] = 2.0  # rider
    class_weights[17] = 2.0  # motorcycle
    class_weights[18] = 2.0  # bicycle
    criterion = nn.CrossEntropyLoss(weight=class_weights, ignore_index=255)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    # mode='max' because the scheduler is stepped with the mean IoU.
    # `verbose` is deprecated in recent PyTorch releases, so it is not passed.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='max',
        factor=0.5,
        patience=3
    )

    # Train the model
    model, train_losses, val_losses, miou_scores, pixel_acc_scores, class_acc_scores = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=optimizer,
        device=device,
        num_epochs=50,  # Adjust number of epochs as needed
        scheduler=scheduler
    )

    # Plot training progress: (title, y-label, [(series, legend label), ...])
    panels = [
        ('Loss Curves', 'Loss', [(train_losses, 'Training Loss'), (val_losses, 'Validation Loss')]),
        ('IoU Metric', 'Mean IoU', [(miou_scores, 'Mean IoU')]),
        ('Overall Pixel Accuracy', 'Accuracy', [(pixel_acc_scores, 'Pixel Accuracy')]),
        ('Mean Class Accuracy', 'Accuracy', [(class_acc_scores, 'Mean Class Accuracy')]),
    ]
    plt.figure(figsize=(15, 10))
    for position, (title, ylabel, curves) in enumerate(panels, start=1):
        plt.subplot(2, 2, position)
        for values, label in curves:
            plt.plot(values, label=label)
        plt.xlabel('Epoch')
        plt.ylabel(ylabel)
        plt.legend()
        plt.title(title)
    plt.tight_layout()
    plt.savefig('training_progress.png')
    plt.show()

    # Save training metrics
    import pandas as pd
    metrics_df = pd.DataFrame({
        'Epoch': range(1, len(train_losses) + 1),
        'Training_Loss': train_losses,
        'Validation_Loss': val_losses,
        'Mean_IoU': miou_scores,
        'Pixel_Accuracy': pixel_acc_scores,
        'Class_Accuracy': class_acc_scores
    })
    metrics_df.to_csv('training_metrics.csv', index=False)

    # Final evaluation on validation set
    model.eval()
    all_pixel_accs = []
    all_class_accs = []
    all_ious = []
    print("\nEvaluating final model on validation set...")
    with torch.no_grad():
        for images, masks in tqdm(val_loader, desc='Final Evaluation'):
            images = images.to(device)
            masks = masks.to(device)
            outputs = model(images)
            preds = torch.argmax(outputs, dim=1)
            batch_ious = calculate_iou(preds, masks, NUM_CLASSES)
            batch_pixel_acc, batch_mean_acc, _ = calculate_accuracy_metrics(preds, masks, NUM_CLASSES)
            valid_ious = [iou for iou in batch_ious if not np.isnan(iou)]
            if valid_ious:
                all_ious.append(np.mean(valid_ious))
            all_pixel_accs.append(batch_pixel_acc)
            all_class_accs.append(batch_mean_acc)
    final_miou = np.mean(all_ious)
    final_pixel_acc = np.mean(all_pixel_accs)
    final_class_acc = np.mean(all_class_accs)
    print(f"\nFinal Evaluation Results:")
    print(f"Mean IoU: {final_miou:.4f}")
    print(f"Overall Pixel Accuracy: {final_pixel_acc:.4f}")
    print(f"Mean Class Accuracy: {final_class_acc:.4f}")

    # Predictions and converted masks hold training IDs, while
    # cityscapes_colors is indexed by raw label ID; coloring with it directly
    # paints most classes black or with another class's color.
    trainid_colors = _trainid_palette(cityscapes_colors)

    # Visualize some validation predictions
    test_samples = min(5, len(val_dataset))
    for i in range(test_samples):
        image, mask = val_dataset[i]
        image = image.unsqueeze(0).to(device)
        with torch.no_grad():
            output = model(image)
            pred = torch.argmax(output, dim=1).squeeze(0).cpu()
        fig = visualize_prediction(image.squeeze(0), pred, mask, trainid_colors)
        fig.savefig(f'prediction_{i}.png')
        plt.close(fig)

    # Create prediction video for a sequence (if available). This step is
    # best-effort: any failure is reported and the run still completes.
    try:
        sequential_city = val_dataset.cities[0]
        city_img_dir = os.path.join(IMG_PATH, 'val', sequential_city)
        city_images = sorted([f for f in os.listdir(city_img_dir) if f.endswith('_leftImg8bit.png')])
        if len(city_images) > 10:
            print(f"\nCreating prediction video for {sequential_city}...")
            frames = []
            for img_file in city_images[:20]:
                img_path = os.path.join(city_img_dir, img_file)
                image = Image.open(img_path).convert('RGB')
                image = image.resize((IMG_WIDTH, IMG_HEIGHT), Image.BILINEAR)
                image_tensor = TF.to_tensor(image).unsqueeze(0).to(device)
                with torch.no_grad():
                    output = model(image_tensor)
                    pred = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()
                pred_color = np.zeros((pred.shape[0], pred.shape[1], 3), dtype=np.uint8)
                for train_id, color in enumerate(trainid_colors):
                    pred_color[pred == train_id] = color
                # Input frame on the left, colored prediction on the right.
                frames.append(np.hstack([np.array(image), pred_color]))
            if frames:
                out_path = 'prediction_video.mp4'
                height, width, _ = frames[0].shape
                fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                video = cv2.VideoWriter(out_path, fourcc, 5, (width, height))
                try:
                    # An unopened writer drops frames silently; fail loudly so
                    # the success message is never printed for a missing file.
                    if not video.isOpened():
                        raise RuntimeError(f"video writer could not open {out_path}")
                    for frame in frames:
                        video.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
                finally:
                    video.release()
                print(f"Prediction video saved to {out_path}")
    except Exception as e:
        print(f"Could not create video: {e}")

# Start the pipeline only when this code runs as the main module, not on import.
if __name__ == '__main__':
    main()
