In [None]:
# Mount Google Drive (Colab only) so the project folder is reachable from this runtime.
from google.colab import drive

drive.mount('/content/drive/')

# Work from the project directory: later cells import the local `models` / `utils`
# modules and read './train' relative to the current working directory.
%cd "/content/drive/MyDrive/CONSEGNA_ML/"


In [None]:
from torch.utils.data import DataLoader, random_split
import torchvision.transforms as T
from torchvision import transforms
import torch
from torch import nn, optim
from models import DeepLabV3, CNN_7_Layers, DeepLabV3Lite
from tqdm import tqdm
from utils import ImageSegmentationDatasetOneHotEncoding, ImageSegmentationDatasetLogit, color_map
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
from PIL import Image
from torchvision.transforms.functional import to_pil_image
import matplotlib.pyplot as plt

In [None]:
# Release cached GPU memory and reset the peak-memory counters so later readings
# only reflect allocations made from this point on.
# Guarded because reset_peak_memory_stats() initialises CUDA and raises on a
# CPU-only runtime, while the rest of the notebook supports a CPU fallback.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

## Load dataset

In [None]:
# Full dataset.
# Local (non-Colab) alternative: train_dir = '../../datasets/esame_ml/train'
train_dir = './train'  # relative to the Colab project directory
dataset = ImageSegmentationDatasetLogit(root_dir=train_dir)

# NOTE(review): the training code below iterates train_loader / val_loader,
# not this loader; keep it only for quick whole-dataset inspection.
dataloader = DataLoader(
    dataset,
    shuffle=True,
    batch_size=8,
)

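torchvision transforms are not designed to be applied to an image and its label mask in sync (e.g. the same random flip or crop on both). Synchronised augmentation is therefore done inside the dataset's loading code, enabled by a boolean flag.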

In [None]:
# Data split: hold out `val_percent` of the images for validation.
# The split uses a fixed-seed generator so it is identical across kernel restarts.
# Without it, a checkpoint reloaded in a later session would be "validated" on
# images it was trained on.
val_percent = 0.2
split_seed = 42

val_size = int(len(dataset) * val_percent)
train_size = len(dataset) - val_size
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(split_seed),
)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)

## Functions for training

In [None]:
def train(model, dataloader, optimizer, criterion, device):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: network mapping an RGB batch to per-pixel class logits.
        dataloader: sized iterable of ``(rgb, _, labels)`` batches; the middle
            item is ignored. ``labels`` must contain integer class indices
            (one per pixel), the target form used with ``nn.CrossEntropyLoss``.
            If a one-hot dataset is used instead, convert with
            ``labels.argmax(dim=1)`` before computing the loss.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss callable taking ``(outputs, targets)``.
        device: device every batch is moved to.

    Returns:
        float: sum of the batch losses divided by the number of batches.

    Raises:
        ValueError: if ``dataloader`` has no batches (the mean is undefined).
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("train() received an empty dataloader; nothing to train on.")

    model.train()
    running_loss = 0.0

    loop = tqdm(dataloader, desc="Training", leave=False)
    for rgb, _, labels in loop:
        rgb, labels = rgb.to(device), labels.to(device)

        # Class-index targets must be int64 for CrossEntropyLoss (labels may arrive as uint8).
        targets = labels.long()

        optimizer.zero_grad()
        outputs = model(rgb)

        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Read the scalar once: each .item() forces a device synchronisation.
        batch_loss = loss.item()
        running_loss += batch_loss
        loop.set_postfix(loss=batch_loss)

    return running_loss / num_batches

In [None]:
def evaluate(model, dataloader, criterion, device):
    """Return the mean per-batch loss of ``model`` over ``dataloader`` without training.

    The model is switched to eval mode (and left in it), and gradients are
    disabled, so no parameters are updated.

    Args:
        model: network mapping an RGB batch to per-pixel class logits.
        dataloader: sized iterable of ``(rgb, _, labels)`` batches; the middle
            item is ignored. ``labels`` must contain integer class indices
            (one-hot variant: use ``labels.argmax(dim=1)`` instead).
        criterion: loss callable taking ``(outputs, targets)``.
        device: device every batch is moved to.

    Returns:
        float: sum of the batch losses divided by the number of batches.

    Raises:
        ValueError: if ``dataloader`` has no batches (the mean is undefined).
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("evaluate() received an empty dataloader; nothing to evaluate.")

    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for rgb, _, labels in dataloader:
            rgb, labels = rgb.to(device), labels.to(device)
            # Class-index targets must be int64 for CrossEntropyLoss.
            targets = labels.long()
            outputs = model(rgb)
            total_loss += criterion(outputs, targets).item()

    return total_loss / num_batches

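**Shape check.** The model output (the CE *input*) has shape `[16, 9, 544, 1024]` = $(N, C, H, W)$, while the target has shape `[16, 544, 1024]` = $(N, H, W)$.

This mismatch is expected for `nn.CrossEntropyLoss`. The documentation lists input shapes $(C)$, $(N, C)$ or $(N, C, d_1, \dots, d_k)$ and, for class-index targets, target shapes $()$, $(N)$ or $(N, d_1, \dots, d_k)$.

With this target form, CE expects integer class indices (one per pixel, dtype `long`) rather than a one-hot tensor, hence `labels.long()` in the training code. CE also accepts class-probability targets with the same shape as the input, but they are not used here.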

## Loop

In [None]:
# Quick check of where tensors end up. Falls back to CPU so the cell does not
# crash on a CPU-only runtime; the printout shows whether CUDA is actually used.
check_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = torch.randn(1, 3, 224, 224).to(check_device)
print("CUDA available:", torch.cuda.is_available(), "| tensor device:", x.device)

Tensor on GPU: cuda:0


In [None]:
#!nvidia-smi

In [None]:
# Peek at a single training batch to confirm label shape, class-index range and dtype.
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    _, _, labels = first_batch
    print(labels.shape, labels.min(), labels.max(), labels.dtype)

torch.Size([4, 544, 1024]) tensor(0, dtype=torch.uint8) tensor(8, dtype=torch.uint8) torch.uint8


In [None]:
def load_model_from_checkpoint(checkpoint_path='checkpoint_epoch_10.pth', target_model=None,
                               target_optimizer=None, map_location=None):
    """Restore model and optimizer state from a checkpoint written by the training loop.

    When called with no arguments, it restores the notebook-level ``model`` and
    ``optimizer`` from ``checkpoint_epoch_10.pth``.

    Args:
        checkpoint_path: path of a checkpoint dict with ``'epoch'``,
            ``'model_state_dict'`` and ``'optimizer_state_dict'`` entries.
        target_model: module to load weights into; defaults to the global ``model``.
        target_optimizer: optimizer to restore; defaults to the global ``optimizer``.
            It must already exist, otherwise a freshly created optimizer would
            discard the restored state.
        map_location: forwarded to ``torch.load``. Pass e.g. ``device`` to load a
            GPU-saved checkpoint on a CPU-only runtime.

    Returns:
        tuple: ``(model, start_epoch)``, where ``start_epoch`` is
        ``checkpoint['epoch'] + 1``. The saved ``epoch`` is the 0-based index of
        an epoch that has already finished, so training resumes at the next one.
    """
    # Fall back to the notebook globals only when nothing was passed in.
    model_to_load = model if target_model is None else target_model
    optimizer_to_load = optimizer if target_optimizer is None else target_optimizer

    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    model_to_load.load_state_dict(checkpoint['model_state_dict'])
    optimizer_to_load.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    return model_to_load, start_epoch

In [None]:
# Hyperparameters
EPOCHS = 15
LR = 1e-3

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

start_epoch = 0
model = DeepLabV3Lite().to(device)

optimizer = optim.Adam(model.parameters(), lr=LR)
# To resume training, uncomment the line below. It must run after the optimizer
# is created: restoring into an optimizer that is replaced afterwards loses its
# saved state.
# model, start_epoch = load_model_from_checkpoint()
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

# Per-class weights for the weighted cross-entropy loss (list index = class id).
class_weights = torch.tensor(
    [
        1.5,  # 0: undefined/background - white
        1.0,  # 1: smooth trail - grey
        2.0,  # 2: traversable grass - light green
        1.0,  # 3: rough trail - brown
        1.0,  # 4: puddle - pink
        1.0,  # 5: obstacle - red
        2.0,  # 6: non-traversable low vegetation - medium green
        1.0,  # 7: high vegetation - dark green
        1.0,  # 8: sky - blue
    ],
    dtype=torch.float,
    device=device,
)

# Loss
criterion = nn.CrossEntropyLoss(weight=class_weights)




In [None]:
# Loop
CHECKPOINT_EVERY = 5  # periodic checkpoint interval, in epochs
best_val_loss = float('inf')

for epoch in range(start_epoch, EPOCHS):
    train_loss = train(model, train_loader, optimizer, criterion, device)
    val_loss = evaluate(model, val_loader, criterion, device)
    scheduler.step(val_loss)
    print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        # Saved so a resumed run continues the LR schedule instead of restarting it.
        'scheduler_state_dict': scheduler.state_dict(),
        'val_loss': val_loss,
    }

    # Keep the weights with the lowest validation loss. Val loss can start rising
    # (overfitting) long before the next periodic checkpoint is written.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(checkpoint, 'best_model.pth')

    if (epoch + 1) % CHECKPOINT_EVERY == 0:
        torch.save(checkpoint, f'checkpoint_epoch_{epoch+1}.pth')

                                                                                                                                                                                                             

Epoch 2/50 | Train Loss: 0.9500 | Val Loss: 1.3592


                                                                                                                                                                                                             

Epoch 3/50 | Train Loss: 0.7451 | Val Loss: 0.7809


                                                                                                                                                                                                             

Epoch 4/50 | Train Loss: 0.6777 | Val Loss: 0.7496


                                                                                                                                                                                                             

Epoch 5/50 | Train Loss: 0.6334 | Val Loss: 0.6599


                                                                                                                                                                                                             

Epoch 6/50 | Train Loss: 0.6176 | Val Loss: 0.6695


                                                                                                                                                                                                             

Epoch 7/50 | Train Loss: 0.5792 | Val Loss: 0.6824


                                                                                                                                                                                                             

Epoch 8/50 | Train Loss: 0.5642 | Val Loss: 0.6404


                                                                                                                                                                                                             

Epoch 9/50 | Train Loss: 0.5498 | Val Loss: 0.6826


                                                                                                                                                                                                             

Epoch 10/50 | Train Loss: 0.5471 | Val Loss: 0.6596


                                                                                                                                                                                                             

Epoch 11/50 | Train Loss: 0.5086 | Val Loss: 0.6675


                                                                                                                                                                                                             

Epoch 12/50 | Train Loss: 0.4783 | Val Loss: 0.6801


                                                                                                                                                                                                             

Epoch 13/50 | Train Loss: 0.4740 | Val Loss: 0.6858


                                                                                                                                                                                                             

Epoch 14/50 | Train Loss: 0.4530 | Val Loss: 0.7837


                                                                                                                                                                                                             

Epoch 15/50 | Train Loss: 0.3807 | Val Loss: 0.6858


                                                                                                                                                                                                             

Epoch 16/50 | Train Loss: 0.3549 | Val Loss: 0.6969


                                                                                                                                                                                                             

Epoch 17/50 | Train Loss: 0.3192 | Val Loss: 0.6852


                                                                                                                                                                                                             

Epoch 18/50 | Train Loss: 0.3008 | Val Loss: 0.7295


                                                                                                                                                                                                             

Epoch 19/50 | Train Loss: 0.2823 | Val Loss: 0.7808


                                                                                                                                                                                                             

Epoch 20/50 | Train Loss: 0.2741 | Val Loss: 0.8263


                                                                                                                                                                                                             

Epoch 21/50 | Train Loss: 0.2432 | Val Loss: 0.7498


                                                                                                                                                                                                             

Epoch 22/50 | Train Loss: 0.2124 | Val Loss: 0.7871


                                                                                                                                                                                                             

Epoch 23/50 | Train Loss: 0.2089 | Val Loss: 0.7844


                                                                                                                                                                                                             

Epoch 24/50 | Train Loss: 0.1992 | Val Loss: 0.8026


                                                                                                                                                                                                             

Epoch 25/50 | Train Loss: 0.1955 | Val Loss: 0.7785


                                                                                                                                                                                                             

Epoch 26/50 | Train Loss: 0.1933 | Val Loss: 0.7964


                                                                                                                                                                                                             

Epoch 27/50 | Train Loss: 0.1691 | Val Loss: 0.7811


                                                                                                                                                                                                             

Epoch 28/50 | Train Loss: 0.1640 | Val Loss: 0.7857


                                                                                                                                                                                                             

Epoch 29/50 | Train Loss: 0.1594 | Val Loss: 0.8216


                                                                                                                                                                                                             

Epoch 30/50 | Train Loss: 0.1643 | Val Loss: 0.8961


                                                                                                                                                                                                             

Epoch 31/50 | Train Loss: 0.1648 | Val Loss: 0.8041


                                                                                                                                                                                                             

Epoch 32/50 | Train Loss: 0.1606 | Val Loss: 0.8347


                                                                                                                                                                                                             

Epoch 33/50 | Train Loss: 0.1514 | Val Loss: 0.8158


                                                                                                                                                                                                             

Epoch 34/50 | Train Loss: 0.1444 | Val Loss: 0.8467


                                                                                                                                                                                                             

Epoch 35/50 | Train Loss: 0.1409 | Val Loss: 0.8477


                                                                                                                                                                                                             

Epoch 36/50 | Train Loss: 0.1398 | Val Loss: 0.8278


                                                                                                                                                                                                             

Epoch 37/50 | Train Loss: 0.1482 | Val Loss: 0.8346


                                                                                                                                                                                                             

Epoch 38/50 | Train Loss: 0.1424 | Val Loss: 0.8758


                                                                                                                                                                                                             

Epoch 39/50 | Train Loss: 0.1391 | Val Loss: 0.8635


                                                                                                                                                                                                             

Epoch 40/50 | Train Loss: 0.1334 | Val Loss: 0.8437


                                                                                                                                                                                                             

Epoch 41/50 | Train Loss: 0.1343 | Val Loss: 0.8482


                                                                                                                                                                                                             

Epoch 42/50 | Train Loss: 0.1339 | Val Loss: 0.8618


                                                                                                                                                                                                             

Epoch 43/50 | Train Loss: 0.1324 | Val Loss: 0.8607


                                                                                                                                                                                                             

Epoch 44/50 | Train Loss: 0.1336 | Val Loss: 0.8464


                                                                                                                                                                                                             

Epoch 45/50 | Train Loss: 0.1283 | Val Loss: 0.8573


                                                                                                                                                                                                             

Epoch 46/50 | Train Loss: 0.1309 | Val Loss: 0.8601


                                                                                                                                                                                                             

Epoch 47/50 | Train Loss: 0.1282 | Val Loss: 0.8422


                                                                                                                                                                                                             

Epoch 48/50 | Train Loss: 0.1295 | Val Loss: 0.8757


                                                                                                                                                                                                             

Epoch 49/50 | Train Loss: 0.1300 | Val Loss: 0.8611


                                                                                                                                                                                                             

Epoch 50/50 | Train Loss: 0.1271 | Val Loss: 0.8953


## Validation visualization

In [None]:
def convert_prediction_to_rgb(pred, color_map):
    """Colour a per-class score map by its winning class.

    Args:
        pred: tensor of class scores with the class axis first, e.g. ``(C, H, W)``.
        color_map: mapping ``color -> class index``; it is inverted here to look
            up the colour of each predicted class.

    Returns:
        ``(H, W, 3)`` uint8 array. Pixels whose class has no colour stay black.
    """
    labels_hw = pred.argmax(dim=0).cpu().numpy()
    class_to_color = dict((cls, col) for col, cls in color_map.items())

    height, width = labels_hw.shape
    canvas = np.zeros((height, width, 3), dtype=np.uint8)
    for cls, col in class_to_color.items():
        canvas[labels_hw == cls] = col
    return canvas

In [None]:
def visualize_sample_prediction(model, rgb_tensor, label_tensor, color_map, device):
    """Plot the input image, ground truth and model prediction side by side.

    Args:
        model: segmentation network (assumed ``nn.Module``) returning
            ``(1, C, H, W)`` logits for a ``(1, 3, H, W)`` input. It is moved to
            ``device`` and left in eval mode.
        rgb_tensor: one ``(3, H, W)`` image tensor, without a batch dimension.
        label_tensor: ground-truth tensor accepted by ``to_pil_image``.
        color_map: mapping passed on to ``convert_prediction_to_rgb``.
        device: device to run inference on.
    """
    model = model.to(device)
    model.eval()

    with torch.no_grad():
        # Add and then remove only the batch axis; a bare squeeze() would also
        # drop any spatial dimension of size 1.
        logits = model(rgb_tensor.unsqueeze(0).to(device)).squeeze(0)  # (C, H, W)

    # argmax of the logits equals argmax of their softmax, so no softmax is needed.
    pred_rgb = convert_prediction_to_rgb(logits, color_map)

    panels = (
        ("Input Image", to_pil_image(rgb_tensor.cpu())),
        ("Ground Truth", to_pil_image(label_tensor.cpu())),
        ("Model Prediction", pred_rgb),
    )
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, (title, image) in zip(axes, panels):
        ax.set_title(title)
        ax.imshow(image)
        ax.axis("off")

    plt.show()
    # Release the figure: this function is called in a loop.
    plt.close(fig)

In [None]:
# Reload a saved checkpoint to inspect its predictions on the validation set.
# NOTE: the images are the same validation subset used during training only if
# the train/val split was produced with a fixed seed.
VIS_CHECKPOINT_PATH = 'checkpoint_epoch_15.pth'

# map_location lets a checkpoint saved on GPU be loaded on a CPU-only runtime.
checkpoint_vis = torch.load(VIS_CHECKPOINT_PATH, map_location=device)
model_vis = DeepLabV3Lite().to(device)
model_vis.load_state_dict(checkpoint_vis['model_state_dict'])
model_vis.eval();

In [None]:
from itertools import islice

MAX_VIS_SAMPLES = 5  # number of validation samples to plot; None plots all of them

# Flatten the loader's batches into individual (rgb, label) samples so the cap
# applies to the total number of plots, independent of the batch size.
val_samples = (
    (rgb_tensor, label_tensor)
    for rgb_batch, label_batch, _ in val_loader
    for rgb_tensor, label_tensor in zip(rgb_batch, label_batch)
)

for rgb_tensor, label_tensor in islice(val_samples, MAX_VIS_SAMPLES):
    visualize_sample_prediction(model_vis, rgb_tensor, label_tensor, color_map, device)