In [1]:
from google.colab import drive
drive.mount('/content/drive',force_remount=True)
%cd /content/drive/MyDrive/The-Role-of-Spatial-Context-in-Deep-Learning-based-Semantic-Segmentation-of-Remote-Sensing-Imagery/dfc20/
!pip install rasterio

Mounted at /content/drive
/content/drive/MyDrive/The-Role-of-Spatial-Context-in-Deep-Learning-based-Semantic-Segmentation-of-Remote-Sensing-Imagery/dfc20
Collecting rasterio
  Downloading rasterio-1.4.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.1 kB)
Collecting affine (from rasterio)
  Downloading affine-2.4.0-py3-none-any.whl.metadata (4.0 kB)
Collecting cligj>=0.5 (from rasterio)
  Downloading cligj-0.7.2-py3-none-any.whl.metadata (5.0 kB)
Collecting click-plugins (from rasterio)
  Downloading click_plugins-1.1.1-py2.py3-none-any.whl.metadata (6.4 kB)
Downloading rasterio-1.4.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (22.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m22.2/22.2 MB[0m [31m101.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading cligj-0.7.2-py3-none-any.whl (7.1 kB)
Downloading affine-2.4.0-py3-none-any.whl (15 kB)
Downloading click_plugins-1.1.1-py2.py3-none-any.whl (7.5 kB)
Installing collected packag

In [34]:
!cp -r /content/drive/MyDrive/The-Role-of-Spatial-Context-in-Deep-Learning-based-Semantic-Segmentation-of-Remote-Sensing-Imagery/dfc20/data /content/data

^C


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import os
import sys
from torchvision import transforms
import time


# The working directory at the time this cell runs; the setup cell %cd's into
# the dfc20 project folder, so this is expected to be that folder.
project_root = os.getcwd()
# Make modules inside <project_root>/models importable by their bare name.
# NOTE: re-running this cell appends the same entry to sys.path again.
sys.path.append(os.path.join(project_root, "models"))
# Star import from the project's dataset module; the DFC20 class used in the
# data-loading cell presumably comes from here - verify against dataset.py.
from dataset import *
from models.unet import UNet

In [38]:
############ (TUNED) PARAMETERS ############
# Single place for the values the cells below read by name.

# data loader
# Samples per batch; passed to both the training and the validation DataLoader.
batch_size = 16
# Number of DataLoader worker processes; passed to both loaders.
num_workers = 1
#prefetch_factor = 2
#persistent_workers = True

# training
# Number of iterations of the outer epoch loop in the training cell.
num_epochs = 5
# Initial learning rate handed to the Adam optimizer.
learning_rate = 1e-5
# `factor` and `patience` arguments of ReduceLROnPlateau, which is stepped on
# the validation loss once per epoch.
scheduler_factor = 0.5
scheduler_patience = 3

# logging
# When True, the training cell creates a TensorBoard SummaryWriter and logs
# per-epoch scalars; when False, nothing is written.
log = False

In [39]:
############ LOAD DATA ############

# Dataset root. "/content/data" is the copy made on the Colab VM by the copy
# cell; switch to "./data" to read from the project folder instead.
path = "/content/data"

# Both subsets are built with identical input flags so they yield the same
# channels. The flag names suggest "all Sentinel-2 bands, no Sentinel-1" -
# confirm against dataset.py.
dataset_kwargs = dict(use_s1=False,
                      use_s2_RGB=False,
                      use_s2_hr=False,
                      use_s2_all=True,
                      as_tensor=True)

# load datasets
train_set = DFC20(path, subset="train", **dataset_kwargs)
val_set = DFC20(path, subset="val", **dataset_kwargs)

# Number of input channels reported by the dataset; used to size the model.
n_inputs = train_set.n_inputs

# Settings shared by both loaders.
loader_kwargs = dict(batch_size=batch_size,
                     pin_memory=True,
                     drop_last=False,
                     num_workers=num_workers)

# set up dataloaders
train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)
# Validation is not shuffled: the training cell averages mIoU per batch, so a
# random batch composition would make the validation score (and the best-model
# comparison based on it) vary between otherwise identical evaluations.
val_loader = DataLoader(val_set, shuffle=False, **loader_kwargs)



[Load]: 100%|██████████| 4270/4270 [00:00<00:00, 744251.91it/s]


loaded 4270 samples from the DFC20 subset train


[Load]: 100%|██████████| 684/684 [00:00<00:00, 397740.74it/s]

loaded 684 samples from the DFC20 subset val





In [40]:
############ MODEL CHOICE ############

# Build the network with one input channel per dataset input.
model = UNet(n_channels=n_inputs)

# Prefer the GPU when one is visible to PyTorch, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Move the model's parameters to the selected device.
model = model.to(device)

Using device: cuda


In [41]:
############ TRAIN ############

# Loss function: unweighted cross-entropy over the class dimension.
criterion = nn.CrossEntropyLoss()  # maybe dice, maybe weights?

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Scheduler: multiplies the learning rate by `scheduler_factor` once the
# monitored value (the validation loss passed to scheduler.step) has not
# improved for more than `scheduler_patience` epochs.
# The `verbose` argument is deprecated since PyTorch 2.2 and is therefore not
# passed; read the current rate from optimizer.param_groups[0]['lr'] or
# scheduler.get_last_lr() instead.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=scheduler_factor, patience=scheduler_patience
)

# mIoU helper
def compute_miou(pred, label, num_classes):
    """Return the mean intersection-over-union of `pred` against `label`.

    Args:
        pred: tensor of predicted class indices (any shape).
        label: tensor of ground-truth class indices with the same number of
            elements as `pred`.
        num_classes: class indices 0 .. num_classes - 1 are evaluated.

    Returns:
        float: IoU averaged over the classes that occur in `pred` or `label`.
        Classes absent from both have an undefined IoU and are left out of the
        average instead of being counted as 0. Returns 0.0 when no evaluated
        class occurs at all (e.g. empty tensors).
    """
    # reshape (unlike view) also accepts non-contiguous tensors.
    pred = pred.reshape(-1)
    label = label.reshape(-1)

    ious = []
    for cls in range(num_classes):
        pred_inds = (pred == cls)
        label_inds = (label == cls)
        union = (pred_inds | label_inds).sum().item()
        if union == 0:
            # Class appears in neither tensor: skip it entirely.
            continue
        intersection = (pred_inds & label_inds).sum().item()
        ious.append(intersection / union)

    if not ious:
        return 0.0
    # Divide by the number of classes that were actually scored, not by
    # num_classes, so absent classes do not drag the mean towards zero.
    return sum(ious) / len(ious)

# init logger
if log:
    writer = SummaryWriter(log_dir='logs/param_tuning') # param_tuning, baseline, ...

# Training loop: one training pass and one validation pass per epoch.
# Tracks the best validation mIoU seen so far (checkpoint saving is disabled).
best_val_miou = 0.0

try:
    for epoch in range(num_epochs):
        epoch_start = time.time()
        model.train()
        running_loss = 0.0
        total_accuracy = 0.0
        total_miou = 0.0
        num_batches = 0

        time_start = time.time()

        for batch in train_loader:
            # unpack sample
            inputs = batch['image'].to(device)
            labels = batch['label'].to(device)
            # reset gradients
            optimizer.zero_grad()
            # forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            # backward pass
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            # prediction: most likely class per position (argmax over dim 1)
            predicted_labels = torch.argmax(outputs, dim=1)
            # Accuracy in percent for this batch
            accuracy = (predicted_labels == labels).float().mean().item() * 100
            total_accuracy += accuracy
            # mIoU for this batch; the class count is taken from the model output
            miou = compute_miou(predicted_labels, labels, num_classes=outputs.shape[1])
            total_miou += miou

            num_batches += 1

            # Wall-clock time since the previous batch finished, so it includes
            # waiting on the DataLoader. Deliberately NOT named `time`: rebinding
            # that name would shadow the `time` module and break the next
            # time.time() call.
            batch_time = time.time() - time_start
            print(f"[Batch {num_batches}] Time: {batch_time:.2f}s")

            time_start = time.time()

        avg_loss = running_loss / len(train_loader)
        avg_accuracy = total_accuracy / num_batches
        avg_miou = total_miou / num_batches

        # Validation: same metrics, no gradient tracking, no parameter updates.
        model.eval()
        val_loss = 0.0
        val_accuracy = 0.0
        val_miou = 0.0
        with torch.no_grad():
            for batch in val_loader:
                inputs = batch['image'].to(device)
                labels = batch['label'].to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                predicted = torch.argmax(outputs, dim=1)
                acc = (predicted == labels).float().mean().item() * 100
                miou = compute_miou(predicted, labels, num_classes=outputs.shape[1])
                val_accuracy += acc
                val_miou += miou

        # Per-batch averages over the validation loader.
        val_loss /= len(val_loader)
        val_accuracy /= len(val_loader)
        val_miou /= len(val_loader)

        # Let the scheduler decide on a learning-rate reduction from the
        # validation loss, then read back the rate now in effect.
        scheduler.step(val_loss)
        current_lr = optimizer.param_groups[0]['lr']

        # Print epoch summary
        epoch_time = time.time() - epoch_start
        print(f"\n=== Epoch {epoch+1}/{num_epochs} — {epoch_time:.2f}s ===")
        print(f"LR: {current_lr:.2e}")
        print(f"Train     — Loss: {avg_loss:.4f}, Accuracy: {avg_accuracy:.2f}%, mIoU: {avg_miou:.4f}")
        print(f"Validate  — Loss: {val_loss:.4f}, Accuracy: {val_accuracy:.2f}%, mIoU: {val_miou:.4f}\n")

        # TensorBoard logging
        if log:
            writer.add_scalar('Loss/train', avg_loss, epoch)
            writer.add_scalar('Accuracy/train', avg_accuracy, epoch)
            writer.add_scalar('mIoU/train', avg_miou, epoch)
            writer.add_scalar('Loss/val', val_loss, epoch)
            writer.add_scalar('Accuracy/val', val_accuracy, epoch)
            writer.add_scalar('mIoU/val', val_miou, epoch)
            writer.add_scalar('LearningRate', current_lr, epoch)

        if val_miou > best_val_miou:
            best_val_miou = val_miou
            # torch.save(model.state_dict(), 'trained_models/unet_baseline.pth')
finally:
    # Close the writer even when training is interrupted or raises, so the
    # scalars logged so far are flushed to disk.
    if log:
        writer.close()

[Batch 1] Total: 1.41s | Load: 0.77s | Batch: 0.65s
[Batch 2] Total: 2.26s | Load: 0.24s | Batch: 0.61s
[Batch 3] Total: 2.96s | Load: 0.10s | Batch: 0.60s
[Batch 4] Total: 3.57s | Load: 0.00s | Batch: 0.60s
[Batch 5] Total: 4.17s | Load: 0.00s | Batch: 0.60s
[Batch 6] Total: 4.77s | Load: 0.00s | Batch: 0.61s
[Batch 7] Total: 5.37s | Load: 0.00s | Batch: 0.60s
[Batch 8] Total: 5.98s | Load: 0.00s | Batch: 0.60s
[Batch 9] Total: 6.58s | Load: 0.00s | Batch: 0.60s
[Batch 10] Total: 7.19s | Load: 0.00s | Batch: 0.60s
[Batch 11] Total: 7.79s | Load: 0.00s | Batch: 0.61s
[Batch 12] Total: 8.40s | Load: 0.00s | Batch: 0.61s
[Batch 13] Total: 9.01s | Load: 0.00s | Batch: 0.61s
[Batch 14] Total: 9.61s | Load: 0.00s | Batch: 0.61s
[Batch 15] Total: 10.22s | Load: 0.00s | Batch: 0.61s
[Batch 16] Total: 10.83s | Load: 0.00s | Batch: 0.61s
[Batch 17] Total: 11.45s | Load: 0.01s | Batch: 0.61s
[Batch 18] Total: 12.06s | Load: 0.00s | Batch: 0.61s
[Batch 19] Total: 12.66s | Load: 0.00s | Batch: 0.6

In [None]:
############ TRAIN LOGS ############

%load_ext tensorboard
%tensorboard --logdir=runs