In [None]:
##### INSTRUCTIONS

# There are a few files and moving parts, since this one notebook is used to train all architectures.
# First, camvid.zip is read from Google Drive, so it only has to be uploaded once and can be reused for all six architectures.
# dataset.py and the selected network architecture file (see the folder) must each be uploaded directly into the Colab runtime.
# After unzipping camvid.zip, rename the extracted "CamVid" folder to "data".
# Finally, in the TRAIN cell, enable the import statement for the architecture you want to train.

In [None]:
from google.colab import drive
drive.mount('/content/drive')

!unzip '/content/drive/MyDrive/camvid.zip'

!unzip camvid.zip

Mounted at /content/drive
Archive:  /content/drive/MyDrive/camvid.zip
  inflating: CamVid/class_dict.csv   
  inflating: CamVid/test/0001TP_006690.png  
  inflating: CamVid/test/0001TP_006720.png  
  inflating: CamVid/test/0001TP_006750.png  
  inflating: CamVid/test/0001TP_006780.png  
  inflating: CamVid/test/0001TP_006810.png  
  inflating: CamVid/test/0001TP_006840.png  
  inflating: CamVid/test/0001TP_006870.png  
  inflating: CamVid/test/0001TP_006900.png  
  inflating: CamVid/test/0001TP_006930.png  
  inflating: CamVid/test/0001TP_006960.png  
  inflating: CamVid/test/0001TP_006990.png  
  inflating: CamVid/test/0001TP_007020.png  
  inflating: CamVid/test/0001TP_007050.png  
  inflating: CamVid/test/0001TP_007080.png  
  inflating: CamVid/test/0001TP_007110.png  
  inflating: CamVid/test/0001TP_007140.png  
  inflating: CamVid/test/0001TP_007170.png  
  inflating: CamVid/test/0001TP_007200.png  
  inflating: CamVid/test/0001TP_007230.png  
  inflating: CamVid/test/0001TP_00726

In [None]:
### TRAIN
# ****Requires uploading dataset.py and the selected architecture file (e.g. basic_fcn.py)****

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
from datetime import datetime
import matplotlib.pyplot as plt
plt.style.use('bmh')

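#### TODO: SELECT - remove the triple quotes around exactly ONE import below.
# fcn_with_skips (FCN with skips) is used for both hybrid FCN128 and FCN256.
# unet_dilate_opt exports UNetV2, but main() instantiates `UNet`; import it as
# `from unet_dilate_opt import UNetV2 as UNet`.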
'''from basic_fcn import UNet'''
'''from fcn_with_skips import UNet'''
'''from unet_dilate_opt import UNetV2'''
'''from unet_with_dilate import UNet'''
'''from fcn_no_batch_norm import UNet'''

from dataset import SegmentationDataset, get_train_transform, get_val_transform, CAMVID_CLASSES

class DiceScore:
    """Mean per-class Dice coefficient between predicted logits and an integer label map.

    Classes that appear in neither the prediction nor the target are skipped, so the
    mean is taken only over classes that actually occur in the batch.
    """

    def __init__(self, num_classes):
        self.num_classes = num_classes

    @torch.no_grad()  # metric bookkeeping must never record autograd history
    def __call__(self, pred, target):
        """Return the mean Dice over present classes as a Python float.

        Args:
            pred: logits with the class dimension at dim 1 (e.g. (B, C, H, W)).
            target: class-index map with the same shape as ``pred.argmax(dim=1)``.

        Returns:
            float: mean Dice over classes present in pred or target; NaN if none are.
        """
        # argmax of the raw logits equals argmax of their softmax (softmax is monotonic),
        # so no softmax is needed.
        pred_labels = pred.argmax(dim=1)

        # Build (num_classes, num_pixels) boolean masks in one shot instead of looping over
        # classes in Python, which forced two host/device syncs per class per batch.
        class_ids = torch.arange(self.num_classes, device=pred_labels.device).unsqueeze(1)
        pred_masks = pred_labels.reshape(1, -1) == class_ids
        target_masks = target.reshape(1, -1) == class_ids

        intersection = (pred_masks & target_masks).sum(dim=1).float()
        union = pred_masks.sum(dim=1) + target_masks.sum(dim=1)

        present = union > 0
        if not present.any():
            # Nothing to average; make the NaN explicit instead of np.mean([]) + RuntimeWarning.
            return float('nan')

        dice = (2. * intersection[present]) / (union[present] + 1e-8)
        # .item() yields a plain Python float (not np.float64), which keeps checkpoints that
        # store this value loadable with torch.load(weights_only=True).
        return dice.double().mean().item()

def plot_metrics(train_losses, val_losses, train_dices, val_dices, save_dir):
    """Save a two-panel figure comparing training and validation curves.

    The left panel shows loss and the right panel shows Dice score, both against the
    1-based epoch index. The figure is written to ``<save_dir>/training_metrics.png``
    (the same file on every call) and then closed.

    Args:
        train_losses, val_losses: per-epoch loss values.
        train_dices, val_dices: per-epoch Dice scores.
        save_dir: existing directory to write the PNG into.
    """
    epoch_axis = range(1, len(train_losses) + 1)
    fig, (loss_ax, dice_ax) = plt.subplots(1, 2, figsize=(15, 5))

    # (axes, train series, val series, train label, val label, title, y-axis label)
    panels = (
        (loss_ax, train_losses, val_losses,
         'Training Loss', 'Validation Loss', 'Loss vs. Epochs', 'Loss'),
        (dice_ax, train_dices, val_dices,
         'Training Dice', 'Validation Dice', 'Dice Score vs. Epochs', 'Dice Score'),
    )
    for ax, train_series, val_series, train_label, val_label, title, y_label in panels:
        ax.plot(epoch_axis, train_series, 'b-', label=train_label)
        ax.plot(epoch_axis, val_series, 'r-', label=val_label)
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.legend()
        ax.grid(True)

    fig.tight_layout()
    fig.savefig(os.path.join(save_dir, 'training_metrics.png'))
    plt.close(fig)

def _run_epoch(model, loader, criterion, dice_metric, device, optimizer=None, desc=''):
    """Run one pass over `loader`: train when `optimizer` is given, otherwise evaluate.

    Returns:
        (mean loss, mean Dice) over the batches, both as Python floats.

    Raises:
        ValueError: if `loader` has no batches (epoch averages would be undefined).
    """
    num_batches = len(loader)
    if num_batches == 0:
        raise ValueError(f'{desc}: data loader is empty; cannot compute epoch averages')

    is_train = optimizer is not None
    model.train(is_train)  # model.train(False) is equivalent to model.eval()
    total_loss = 0.0
    total_dice = 0.0

    # Gradients are only tracked during the training pass.
    with torch.set_grad_enabled(is_train):
        for images, masks in tqdm(loader, desc=desc):
            # non_blocking lets pinned host batches copy asynchronously to the GPU.
            images = images.to(device, non_blocking=True)
            masks = masks.to(device, non_blocking=True)

            outputs = model(images)
            loss = criterion(outputs, masks)

            if is_train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            total_loss += loss.item()
            # detach(): the metric is bookkeeping and must not extend the autograd graph.
            total_dice += dice_metric(outputs.detach(), masks)

    # float() strips numpy scalar types so the values saved in checkpoints stay plain Python
    # floats (numpy scalars are rejected by torch.load(weights_only=True)).
    return float(total_loss / num_batches), float(total_dice / num_batches)


def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device, save_dir):
    """Training loop with per-epoch validation, plotting and best-model checkpointing.

    After every epoch the loss/Dice curves are re-plotted into ``<save_dir>/plots``. Whenever
    the validation loss improves, a checkpoint dict (model and optimizer state, current
    metrics and full history) is written to ``<save_dir>/best_model.pth``.

    Returns:
        dict with the per-epoch lists 'train_losses', 'val_losses', 'train_dices', 'val_dices'.
    """
    best_val_loss = float('inf')
    dice_metric = DiceScore(num_classes=len(CAMVID_CLASSES))

    # Lists to store metrics for plotting
    train_losses = []
    val_losses = []
    train_dices = []
    val_dices = []

    # Create plots directory
    plots_dir = os.path.join(save_dir, 'plots')
    os.makedirs(plots_dir, exist_ok=True)

    for epoch in range(num_epochs):
        avg_train_loss, avg_train_dice = _run_epoch(
            model, train_loader, criterion, dice_metric, device,
            optimizer=optimizer, desc=f'Epoch {epoch+1}/{num_epochs} - Training')
        avg_val_loss, avg_val_dice = _run_epoch(
            model, val_loader, criterion, dice_metric, device, desc='Validation')

        # Store metrics for plotting
        train_losses.append(avg_train_loss)
        val_losses.append(avg_val_loss)
        train_dices.append(avg_train_dice)
        val_dices.append(avg_val_dice)

        # Plot and save metrics
        plot_metrics(train_losses, val_losses, train_dices, val_dices, plots_dir)

        print(f'Epoch {epoch+1}/{num_epochs}:')
        print(f'Train Loss: {avg_train_loss:.4f}, Train Dice: {avg_train_dice:.4f}')
        print(f'Val Loss: {avg_val_loss:.4f}, Val Dice: {avg_val_dice:.4f}')

        # Save best model (torch.save serializes the history lists as they are right now)
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': avg_train_loss,
                'val_loss': avg_val_loss,
                'train_dice': avg_train_dice,
                'val_dice': avg_val_dice,
                'train_history': {
                    'losses': train_losses,
                    'dices': train_dices
                },
                'val_history': {
                    'losses': val_losses,
                    'dices': val_dices
                }
            }, os.path.join(save_dir, 'best_model.pth'))

    return {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'train_dices': train_dices,
        'val_dices': val_dices,
    }

def main():
    """Train the architecture selected above on CamVid.

    Results are saved under ``models/unet_<timestamp>``. Requires `UNet` to have been imported
    in the TODO: SELECT section, and dataset.py to be importable.
    """
    # Set device - check for MPS (Apple Silicon GPU) first, then CUDA, then fall back to CPU
    if torch.backends.mps.is_available():
        device = torch.device("mps")
        print("Using MPS (Apple Silicon GPU)")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
        print("Using CUDA GPU")
    else:
        device = torch.device("cpu")
        print("Using CPU")

    # Hyperparameters
    BATCH_SIZE = 4
    NUM_EPOCHS = 30
    LEARNING_RATE = 1e-4
    IMAGE_SIZE = 256
    # Up to 8 loader workers, but never more than the CPUs actually available: extra
    # workers only contend for the same cores.
    NUM_WORKERS = min(8, os.cpu_count() or 1)
    PREFETCH_FACTOR = 2  # Load 2 batches per worker in advance

    # Create save directory
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    save_dir = os.path.join('models', f'unet_{timestamp}')
    os.makedirs(save_dir, exist_ok=True)

    # Create datasets
    train_dataset = SegmentationDataset(
        split='train',
        transform=get_train_transform(IMAGE_SIZE)
    )

    val_dataset = SegmentationDataset(
        split='val',
        transform=get_val_transform(IMAGE_SIZE)
    )

    # Loader settings shared by train and val. Pinned host memory only benefits host-to-CUDA
    # copies, and DataLoader accepts prefetch_factor / persistent_workers only when worker
    # processes are used (num_workers > 0).
    loader_kwargs = {
        'batch_size': BATCH_SIZE,
        'num_workers': NUM_WORKERS,
        'pin_memory': device.type == 'cuda',
    }
    if NUM_WORKERS > 0:
        loader_kwargs['prefetch_factor'] = PREFETCH_FACTOR
        loader_kwargs['persistent_workers'] = True  # Keep workers alive between epochs

    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    # Initialize model, criterion, and optimizer
    model = UNet(in_channels=3).to(device)  # Output channels automatically set to number of classes
    criterion = nn.CrossEntropyLoss()  # Use CrossEntropyLoss for multi-class segmentation
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    print(f"Training U-Net for {len(CAMVID_CLASSES)} classes:")
    for i, class_name in enumerate(CAMVID_CLASSES):
        print(f"{i}: {class_name}")

    # Train model
    train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=optimizer,
        num_epochs=NUM_EPOCHS,
        device=device,
        save_dir=save_dir
    )

if __name__ == '__main__':
    main()

  check_for_updates()


Using CUDA GPU
Initial vals
3
32
FTRS
512
Training U-Net for 32 classes:
0: Animal
1: Archway
2: Bicyclist
3: Bridge
4: Building
5: Car
6: CartLuggagePram
7: Child
8: Column_Pole
9: Fence
10: LaneMkgsDriv
11: LaneMkgsNonDriv
12: Misc_Text
13: MotorcycleScooter
14: OtherMoving
15: ParkingBlock
16: Pedestrian
17: Road
18: RoadShoulder
19: Sidewalk
20: SignSymbol
21: Sky
22: SUVPickupTruck
23: TrafficCone
24: TrafficLight
25: Train
26: Tree
27: Truck_Bus
28: Tunnel
29: VegetationMisc
30: Void
31: Wall


Epoch 1/30 - Training: 100%|██████████| 93/93 [04:46<00:00,  3.08s/it]
Validation: 100%|██████████| 25/25 [01:16<00:00,  3.06s/it]


Epoch 1/30:
Train Loss: 2.9354, Train Dice: 0.0211
Val Loss: 2.3236, Val Dice: 0.0299


Epoch 2/30 - Training: 100%|██████████| 93/93 [05:02<00:00,  3.25s/it]
Validation: 100%|██████████| 25/25 [00:10<00:00,  2.31it/s]


Epoch 2/30:
Train Loss: 2.4578, Train Dice: 0.0293
Val Loss: 2.1039, Val Dice: 0.0707


Epoch 3/30 - Training: 100%|██████████| 93/93 [05:02<00:00,  3.25s/it]
Validation: 100%|██████████| 25/25 [00:11<00:00,  2.12it/s]


Epoch 3/30:
Train Loss: 2.0206, Train Dice: 0.0546
Val Loss: 1.8403, Val Dice: 0.0784


Epoch 4/30 - Training: 100%|██████████| 93/93 [04:50<00:00,  3.12s/it]
Validation: 100%|██████████| 25/25 [00:11<00:00,  2.12it/s]


Epoch 4/30:
Train Loss: 1.7569, Train Dice: 0.0797
Val Loss: 1.4943, Val Dice: 0.1164


Epoch 5/30 - Training: 100%|██████████| 93/93 [05:04<00:00,  3.27s/it]
Validation: 100%|██████████| 25/25 [00:11<00:00,  2.13it/s]


Epoch 5/30:
Train Loss: 1.5935, Train Dice: 0.0948
Val Loss: 1.4206, Val Dice: 0.1224


Epoch 6/30 - Training: 100%|██████████| 93/93 [05:05<00:00,  3.28s/it]
Validation: 100%|██████████| 25/25 [00:11<00:00,  2.12it/s]


Epoch 6/30:
Train Loss: 1.5836, Train Dice: 0.0975
Val Loss: 1.4285, Val Dice: 0.1199


Epoch 7/30 - Training: 100%|██████████| 93/93 [05:07<00:00,  3.30s/it]
Validation: 100%|██████████| 25/25 [00:15<00:00,  1.65it/s]


Epoch 7/30:
Train Loss: 1.5248, Train Dice: 0.1002
Val Loss: 1.3496, Val Dice: 0.1216


Epoch 8/30 - Training: 100%|██████████| 93/93 [05:06<00:00,  3.29s/it]
Validation: 100%|██████████| 25/25 [00:17<00:00,  1.44it/s]


Epoch 8/30:
Train Loss: 1.4858, Train Dice: 0.1013
Val Loss: 1.3358, Val Dice: 0.1247


Epoch 9/30 - Training: 100%|██████████| 93/93 [05:09<00:00,  3.33s/it]
Validation: 100%|██████████| 25/25 [00:13<00:00,  1.89it/s]


Epoch 9/30:
Train Loss: 1.4221, Train Dice: 0.1037
Val Loss: 1.3284, Val Dice: 0.1245


Epoch 10/30 - Training: 100%|██████████| 93/93 [05:09<00:00,  3.33s/it]
Validation: 100%|██████████| 25/25 [00:14<00:00,  1.72it/s]


Epoch 10/30:
Train Loss: 1.4281, Train Dice: 0.1048
Val Loss: 1.2894, Val Dice: 0.1249


Epoch 11/30 - Training: 100%|██████████| 93/93 [05:07<00:00,  3.31s/it]
Validation: 100%|██████████| 25/25 [00:12<00:00,  1.93it/s]


Epoch 11/30:
Train Loss: 1.4101, Train Dice: 0.1038
Val Loss: 1.3189, Val Dice: 0.1246


Epoch 12/30 - Training: 100%|██████████| 93/93 [05:09<00:00,  3.32s/it]
Validation: 100%|██████████| 25/25 [00:13<00:00,  1.84it/s]


Epoch 12/30:
Train Loss: 1.3932, Train Dice: 0.1059
Val Loss: 1.3127, Val Dice: 0.1230


Epoch 13/30 - Training: 100%|██████████| 93/93 [05:02<00:00,  3.25s/it]
Validation: 100%|██████████| 25/25 [00:14<00:00,  1.70it/s]


Epoch 13/30:
Train Loss: 1.3629, Train Dice: 0.1068
Val Loss: 1.2428, Val Dice: 0.1265


Epoch 14/30 - Training: 100%|██████████| 93/93 [04:59<00:00,  3.22s/it]
Validation: 100%|██████████| 25/25 [00:14<00:00,  1.68it/s]


Epoch 14/30:
Train Loss: 1.3718, Train Dice: 0.1065
Val Loss: 1.3158, Val Dice: 0.1238


Epoch 15/30 - Training: 100%|██████████| 93/93 [05:10<00:00,  3.33s/it]
Validation: 100%|██████████| 25/25 [00:14<00:00,  1.75it/s]


Epoch 15/30:
Train Loss: 1.3495, Train Dice: 0.1089
Val Loss: 1.2358, Val Dice: 0.1265


Epoch 16/30 - Training: 100%|██████████| 93/93 [04:54<00:00,  3.17s/it]
Validation: 100%|██████████| 25/25 [00:15<00:00,  1.62it/s]


Epoch 16/30:
Train Loss: 1.3012, Train Dice: 0.1109
Val Loss: 1.2116, Val Dice: 0.1266


Epoch 17/30 - Training: 100%|██████████| 93/93 [05:04<00:00,  3.28s/it]
Validation: 100%|██████████| 25/25 [00:13<00:00,  1.79it/s]


Epoch 17/30:
Train Loss: 1.3384, Train Dice: 0.1104
Val Loss: 1.3166, Val Dice: 0.1220


Epoch 18/30 - Training:   9%|▊         | 8/93 [00:40<03:18,  2.33s/it]

In [None]:
# Dump the contents of a saved checkpoint to test.txt for inspection.
# NOTE: checkpoints written by train_model are dicts (model/optimizer state dicts, metrics and
# history), so `model` here holds that dict, not an nn.Module.
# TODO: point this at your own run directory (models/unet_<timestamp>).
CHECKPOINT_PATH = "/content/models/unet_20241217_223314/best_model.pth"

# map_location="cpu" lets a checkpoint saved on a GPU be opened on a CPU-only runtime.
# weights_only=False is needed because older checkpoints store numpy scalar metrics, which the
# safe loader rejects. It unpickles arbitrary objects, so only load checkpoints you produced yourself.
model = torch.load(CHECKPOINT_PATH, map_location="cpu", weights_only=False)
str_model = str(model)
with open("test.txt", "w") as test_text:
    test_text.write(str_model)
