In [6]:
from google.colab import drive
drive.mount('/content/drive')
%cd /content/drive/MyDrive/The-Role-of-Spatial-Context-in-Deep-Learning-based-Semantic-Segmentation-of-Remote-Sensing-Imagery/dfc20/
!pip install rasterio

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/The-Role-of-Spatial-Context-in-Deep-Learning-based-Semantic-Segmentation-of-Remote-Sensing-Imagery/dfc20


In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import os
import sys
from torchvision import transforms
import time


# The current working directory is taken as the project root.
project_root = os.getcwd()
# Add <project_root>/models to the module search path before the project imports below.
sys.path.append(os.path.join(project_root, "models"))
# NOTE(review): wildcard import — the DFC20 class used later presumably comes
# from this module; confirm and consider importing it explicitly.
from dataset import *
from models.unet import UNet

In [30]:
############ LOAD DATA ############

path = "./data"

# Options shared by the train and val subsets: of the input-selection flags only
# `use_s2_RGB` is switched on, and `as_tensor` is enabled.
dataset_options = dict(
    use_s1=False,
    use_s2_RGB=True,
    use_s2_hr=False,
    use_s2_all=False,
    as_tensor=True,
)

# load datasets
train_set = DFC20(path, subset="train", **dataset_options)
val_set = DFC20(path, subset="val", **dataset_options)

# Reported by the dataset; handed to the model as its channel count later on.
n_inputs = train_set.n_inputs

# tuned hyperparams
# -- optimisation
num_epochs = 3
learning_rate = 1e-4
scheduler_factor = 0.5
scheduler_patience = 3

# -- data loading
batch_size = 8
num_workers = 4
prefetch_factor = 2

# set up dataloaders
# Settings shared by both loaders.
loader_options = dict(
    batch_size=batch_size,
    pin_memory=True,
    drop_last=False,
    num_workers=num_workers,
    prefetch_factor=prefetch_factor,
)

# Training batches are reshuffled every epoch.
train_loader = DataLoader(train_set, shuffle=True, **loader_options)

# Validation has to iterate over the held-out `val_set` (not `train_set`),
# otherwise the reported validation metrics and the LR scheduler driven by them
# only reflect the training data. Shuffling is unnecessary for evaluation.
val_loader = DataLoader(val_set, shuffle=False, **loader_options)



[Load]: 100%|██████████| 4270/4270 [00:00<00:00, 375724.88it/s]


loaded 4270 samples from the DFC20 subset train


[Load]: 100%|██████████| 684/684 [00:00<00:00, 353879.85it/s]

loaded 684 samples from the DFC20 subset val





In [31]:
############ MODEL CHOICE ############

# Use CUDA when PyTorch reports it as available, otherwise stay on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Build the U-Net with the dataset's input channel count and move it to the
# selected device in one step.
model = UNet(n_channels=n_inputs).to(device)

Using device: cuda


In [None]:
############ TRAIN ############

# Loss function
criterion = nn.CrossEntropyLoss()  # maybe dice, maybe weights?

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Scheduler: multiplies the learning rate by `scheduler_factor` once the
# monitored value (passed to scheduler.step) has not decreased for
# `scheduler_patience` epochs.
# The `verbose` argument is deprecated in recent PyTorch releases and is
# therefore not passed; the current learning rate is logged to TensorBoard
# every epoch by the training loop.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=scheduler_factor, patience=scheduler_patience
)

# mIoU helper
def compute_miou(pred, label, num_classes):
    """Mean intersection-over-union between predicted and reference class maps.

    Both tensors are flattened and compared element-wise, one class at a time.
    Classes that occur in neither tensor have an undefined IoU and are left out
    of the mean entirely (they are not counted in the denominator).

    Args:
        pred: tensor of predicted class indices.
        label: tensor of reference class indices with as many elements as `pred`.
        num_classes: class ids ``0 .. num_classes - 1`` are evaluated.

    Returns:
        float: IoU averaged over the classes present in `pred` or `label`;
        ``0.0`` when no class can be evaluated (e.g. empty tensors), so that
        running sums kept by the caller stay finite.
    """
    # reshape (unlike view) also works for non-contiguous tensors
    pred = pred.reshape(-1)
    label = label.reshape(-1)

    ious = []
    for cls in range(num_classes):
        pred_inds = pred == cls
        label_inds = label == cls
        union = (pred_inds | label_inds).sum().item()
        if union == 0:
            # class absent from both prediction and reference: IoU undefined
            continue
        intersection = (pred_inds & label_inds).sum().item()
        ious.append(intersection / union)

    if not ious:
        return 0.0
    return sum(ious) / len(ious)

# Training loop
# Each epoch runs one pass over `train_loader` with gradient updates, then one
# pass over `val_loader` without gradients. Loss, pixel accuracy (in percent)
# and mIoU are averaged over batches, printed, and written to TensorBoard.
writer = SummaryWriter(log_dir='logs/param_tuning') # param_tuning, baseline, ...
best_val_accuracy = 0.0

# try/finally guarantees the writer is closed (and its pending events written)
# even when training is interrupted or raises.
try:
    for epoch in range(num_epochs):
        epoch_start = time.time()
        model.train()
        running_loss = 0.0
        total_accuracy = 0.0
        total_miou = 0.0
        num_batches = 0

        # start of the wait for the first batch
        load_start = time.time()

        for batch in train_loader:
            # time spent waiting for the loader to deliver this batch
            load_end = time.time()
            data_loading_time = load_end - load_start
            batch_start = time.time()

            # unpack sample
            inputs = batch['image'].to(device)
            labels = batch['label'].to(device)
            # reset gradients
            optimizer.zero_grad()
            # forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            # backward pass
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            # prediction: index of the largest output along dim 1
            predicted_labels = torch.argmax(outputs, dim=1)
            # Accuracy: share of matching elements, in percent
            accuracy = (predicted_labels == labels).float().mean().item() * 100
            total_accuracy += accuracy
            # mIoU: the size of output dim 1 is used as the number of classes
            miou = compute_miou(predicted_labels, labels, num_classes=outputs.shape[1])
            total_miou += miou

            num_batches += 1

            batch_time = time.time() - batch_start

            print(f"[Batch {num_batches}] Load: {data_loading_time:.2f}s | Batch: {batch_time:.2f}s")

            # restart the loader-wait timer for the next batch
            load_start = time.time()

        avg_loss = running_loss / num_batches
        avg_accuracy = total_accuracy / num_batches
        avg_miou = total_miou / num_batches

        # Validation
        model.eval()
        val_loss = 0.0
        val_accuracy = 0.0
        val_miou = 0.0
        with torch.no_grad():
            for batch in val_loader:
                inputs = batch['image'].to(device)
                labels = batch['label'].to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                predicted = torch.argmax(outputs, dim=1)
                acc = (predicted == labels).float().mean().item() * 100
                miou = compute_miou(predicted, labels, num_classes=outputs.shape[1])
                val_accuracy += acc
                val_miou += miou

        val_loss /= len(val_loader)
        val_accuracy /= len(val_loader)
        val_miou /= len(val_loader)

        # the scheduler monitors the mean validation loss
        scheduler.step(val_loss)

        # Print epoch summary
        epoch_time = time.time() - epoch_start
        print(f"\n=== Epoch {epoch+1}/{num_epochs} — {epoch_time:.2f}s ===")
        print(f"Train     — Loss: {avg_loss:.4f}, Accuracy: {avg_accuracy:.2f}%, mIoU: {avg_miou:.4f}")
        print(f"Validate  — Loss: {val_loss:.4f}, Accuracy: {val_accuracy:.2f}%, mIoU: {val_miou:.4f}\n")

        # TensorBoard logging
        writer.add_scalar('Loss/train', avg_loss, epoch)
        writer.add_scalar('Accuracy/train', avg_accuracy, epoch)
        writer.add_scalar('mIoU/train', avg_miou, epoch)
        writer.add_scalar('Loss/val', val_loss, epoch)
        writer.add_scalar('Accuracy/val', val_accuracy, epoch)
        writer.add_scalar('mIoU/val', val_miou, epoch)
        writer.add_scalar('LearningRate', optimizer.param_groups[0]['lr'], epoch)

        # Track the best validation accuracy; saving the checkpoint is currently disabled.
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            # torch.save(model.state_dict(), 'trained_models/unet_baseline.pth')
finally:
    writer.close()


In [None]:
############ TRAIN LOGS ############

%load_ext tensorboard
%tensorboard --logdir=runs