In [1]:
from google.colab import drive
drive.mount('/content/drive',force_remount=True)
%cd /content/drive/MyDrive/The-Role-of-Spatial-Context-in-Deep-Learning-based-Semantic-Segmentation-of-Remote-Sensing-Imagery/dfc20/
!pip install rasterio

Mounted at /content/drive
/content/drive/MyDrive/The-Role-of-Spatial-Context-in-Deep-Learning-based-Semantic-Segmentation-of-Remote-Sensing-Imagery/dfc20
Collecting rasterio
  Downloading rasterio-1.4.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.1 kB)
Collecting affine (from rasterio)
  Downloading affine-2.4.0-py3-none-any.whl.metadata (4.0 kB)
Collecting cligj>=0.5 (from rasterio)
  Downloading cligj-0.7.2-py3-none-any.whl.metadata (5.0 kB)
Collecting click-plugins (from rasterio)
  Downloading click_plugins-1.1.1-py2.py3-none-any.whl.metadata (6.4 kB)
Downloading rasterio-1.4.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (22.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m22.2/22.2 MB[0m [31m54.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading cligj-0.7.2-py3-none-any.whl (7.1 kB)
Downloading affine-2.4.0-py3-none-any.whl (15 kB)
Downloading click_plugins-1.1.1-py2.py3-none-any.whl (7.5 kB)
Installing collected package

In [9]:
!cp -r /content/drive/MyDrive/The-Role-of-Spatial-Context-in-Deep-Learning-based-Semantic-Segmentation-of-Remote-Sensing-Imagery/dfc20/data /content/data

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import os
import sys
from torchvision import transforms
import albumentations as A
import time

sys.path.append(os.path.join(os.getcwd(), "models"))
from dataset import *
from models.unet import UNetSmall, UNetBig
from utilities.loss import ComboLoss, DiceLoss

In [60]:
############ (TUNED) PARAMETERS ############

# data
use_s1 = False
use_s2_hr = True
use_s2_all = False
use_s2_RGB = False

# pre-process
normalize = True
standardize = False
as_tensor = True
#augment = None
#"""
augment = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    #A.RandomBrightnessContrast(p=0.3),
    A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit=15, p=0.5),
    #A.GaussianBlur(p=0.2),
    #.GaussNoise(var_limit=(10.0, 50.0), p=0.2),
    #A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=10, val_shift_limit=10, p=0.3),
])
#"""

# data loader
batch_size = 32 # 16 32
num_workers = 2
prefetch_factor = 3
persistent_workers = True

# training
num_epochs = 15

learning_rate = 5e-5 # 1e-3 1e-4 5e-5 1e-5
weight_decay = 1e-4 # 0 1e-4 1e-5

scheduler_factor = 0.5
scheduler_patience = 3 # 3 5 7

ce_loss = False
dice_loss = False
combo_loss = True
weighted_loss = True

# model
big = True
small = False

# logging
log = True

In [61]:
############ LOAD DATA ############

path = "/content/data"
#path = "./data"

# load datasets
train_set = DFC20(path,
                  subset="train",
                  use_s1=use_s1,
                  use_s2_RGB=use_s2_RGB,
                  use_s2_hr=use_s2_hr,
                  use_s2_all=use_s2_all,
                  as_tensor=as_tensor,
                  normalize=normalize,
                  standardize=standardize,
                  augment=augment)

val_set = DFC20(path,
                subset="val",
                use_s1=use_s1,
                use_s2_RGB=use_s2_RGB,
                use_s2_hr=use_s2_hr,
                use_s2_all=use_s2_all,
                as_tensor=as_tensor,
                normalize=normalize,
                standardize=standardize)

n_inputs = train_set.n_inputs

# set up dataloaders
train_loader = DataLoader(train_set,
                            batch_size=batch_size,
                            shuffle=True,
                            pin_memory=True,
                            drop_last=False,
                            num_workers=num_workers,
                            persistent_workers=persistent_workers)

val_loader = DataLoader(val_set,
                            batch_size=batch_size,
                            shuffle=False,
                            pin_memory=True,
                            drop_last=False,
                            num_workers=num_workers,
                            persistent_workers=persistent_workers)



[Load]: 100%|██████████| 4270/4270 [00:00<00:00, 497822.94it/s]


loaded 4270 samples from the DFC20 subset train


[Load]: 100%|██████████| 684/684 [00:00<00:00, 348252.48it/s]

loaded 684 samples from the DFC20 subset val





In [62]:
############ MODEL CHOICE ############
if small:
  model = UNetSmall(n_channels=n_inputs)
else:
  model = UNetBig(n_channels=n_inputs)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

model = model.to(device)

Using device: cuda


In [63]:
############ TRAIN SETUP ############
# Loss function
if weighted_loss:
  weights = torch.tensor(1 / train_set.freq, dtype=torch.float32)
  weights = weights / weights.sum()
  if ce_loss:
    criterion = nn.CrossEntropyLoss(weight=weights.to(device))
  elif combo_loss:
    criterion = ComboLoss(weight=weights.to(device))
  elif dice_loss:
    criterion = DiceLoss(weight=weights.to(device))

else:
  if ce_loss:
    criterion = nn.CrossEntropyLoss()
  elif combo_loss:
    criterion = ComboLoss()
  elif dice_loss:
    criterion = DiceLoss()

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# Scheduler
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=scheduler_factor, patience=scheduler_patience, verbose=True
)

# mIoU helper
def compute_miou(pred, label, num_classes):
    ious = []
    pred = pred.view(-1)
    label = label.view(-1)
    for cls in range(num_classes):
        pred_inds = (pred == cls)
        label_inds = (label == cls)
        intersection = (pred_inds & label_inds).sum().item()
        union = (pred_inds | label_inds).sum().item()
        if union == 0:
            continue  # skip this class entirely
        ious.append(intersection / union)
    return sum(ious) / len(ious) if ious else 0.0



In [64]:
############ TRAIN ############

# innit logger
if log:
  writer = SummaryWriter(log_dir='logs/param_tuning') # param_tuning, baseline, ...

# Training loop
best_val_miou = 0.0

for epoch in range(num_epochs):
    epoch_start = time.time()
    model.train()
    running_loss = 0.0
    total_accuracy = 0.0
    total_miou = 0.0
    num_batches = 0

    #time_start = time.time()

    for batch in train_loader:
        # unpack sample
        inputs = batch['image'].to(device)
        labels = batch['label'].to(device)
        # reset gradients
        optimizer.zero_grad()
        # forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        # backward pass
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # prediction
        predicted_labels = torch.argmax(outputs, dim=1)
        # Accuracy
        accuracy = (predicted_labels == labels).float().mean().item() * 100
        total_accuracy += accuracy
        # mIoU
        miou = compute_miou(predicted_labels, labels, num_classes=outputs.shape[1])
        total_miou += miou

        num_batches += 1

        #batch_time = time.time() - time_start
        #print(f"[Batch {num_batches}] Time: {batch_time:.2f}s")
        #time_start = time.time()

    avg_loss = running_loss / len(train_loader)
    avg_accuracy = total_accuracy / len(train_loader)
    avg_miou = total_miou / len(train_loader)

    # Validation
    model.eval()
    val_loss = 0.0
    val_accuracy = 0.0
    val_miou = 0.0
    with torch.no_grad():
        for batch in val_loader:
            inputs = batch['image'].to(device)
            labels = batch['label'].to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            predicted = torch.argmax(outputs, dim=1)
            acc = (predicted == labels).float().mean().item() * 100
            miou = compute_miou(predicted, labels, num_classes=outputs.shape[1])
            val_accuracy += acc
            val_miou += miou

    val_loss /= len(val_loader)
    val_accuracy /= len(val_loader)
    val_miou /= len(val_loader)

    scheduler.step(val_miou)

    # Print epoch summary
    epoch_time = time.time() - epoch_start
    print(f"\n=== Epoch {epoch+1}/{num_epochs} — {epoch_time:.2f}s ===")
    print(f"Train     — Loss: {avg_loss:.4f}, Accuracy: {avg_accuracy:.2f}%, mIoU: {avg_miou:.4f}")
    print(f"Validate  — Loss: {val_loss:.4f}, Accuracy: {val_accuracy:.2f}%, mIoU: {val_miou:.4f}")
    print(f"Learning Rate: {optimizer.param_groups[0]['lr']}")

    # TensorBoard logging
    if log:
      writer.add_scalar('Loss/train', avg_loss, epoch)
      writer.add_scalar('Accuracy/train', avg_accuracy, epoch)
      writer.add_scalar('mIoU/train', avg_miou, epoch)
      writer.add_scalar('Loss/val', val_loss, epoch)
      writer.add_scalar('Accuracy/val', val_accuracy, epoch)
      writer.add_scalar('mIoU/val', val_miou, epoch)
      writer.add_scalar('LearningRate', optimizer.param_groups[0]['lr'], epoch)

    if val_miou > best_val_miou:
        best_val_miou = val_miou
        torch.save(model.state_dict(), 'trained_models/unet_baseline.pth')


if log:
  writer.close()


=== Epoch 1/15 — 203.81s ===
Train     — Loss: 1.1924, Accuracy: 54.65%, mIoU: 0.2982
Validate  — Loss: 1.2830, Accuracy: 48.22%, mIoU: 0.2562
Learning Rate: 5e-05

=== Epoch 2/15 — 203.36s ===
Train     — Loss: 0.9901, Accuracy: 65.67%, mIoU: 0.4010
Validate  — Loss: 1.3586, Accuracy: 47.80%, mIoU: 0.2530
Learning Rate: 5e-05

=== Epoch 3/15 — 203.12s ===
Train     — Loss: 0.9133, Accuracy: 67.93%, mIoU: 0.4271
Validate  — Loss: 1.3354, Accuracy: 46.02%, mIoU: 0.2647
Learning Rate: 5e-05

=== Epoch 4/15 — 202.45s ===
Train     — Loss: 0.8704, Accuracy: 68.48%, mIoU: 0.4362
Validate  — Loss: 1.3010, Accuracy: 47.01%, mIoU: 0.2762
Learning Rate: 5e-05

=== Epoch 5/15 — 202.58s ===
Train     — Loss: 0.8370, Accuracy: 69.69%, mIoU: 0.4481
Validate  — Loss: 1.3378, Accuracy: 44.01%, mIoU: 0.2634
Learning Rate: 2.5e-05

=== Epoch 6/15 — 202.13s ===
Train     — Loss: 0.8018, Accuracy: 70.79%, mIoU: 0.4633
Validate  — Loss: 1.2735, Accuracy: 51.23%, mIoU: 0.2861
Learning Rate: 2.5e-05

=== E

KeyboardInterrupt: 

In [None]:
  ############ TRAIN LOGS ############

%load_ext tensorboard
%tensorboard --logdir=runs

In [8]:
############ TESTING ############

test_set = DFC20(path,
                subset="test",
                use_s1=use_s1,
                use_s2_RGB=use_s2_RGB,
                use_s2_hr=use_s2_hr,
                use_s2_all=use_s2_all,
                as_tensor=as_tensor,
                normalize=normalize,
                standardize=standardize)

test_loader = DataLoader(test_set,
                            batch_size=batch_size,
                            shuffle=False,
                            pin_memory=True,
                            drop_last=False,
                            num_workers=num_workers,
                            persistent_workers=persistent_workers)


model_path = 'trained_models/unet_baseline.pth'

model_test = UNetSmall(n_channels=n_inputs)  # or ...
model_test.load_state_dict(torch.load(model_path))
#model = torch.load(model_path)
#print(model)
model_test.eval()
model_test.to(device)

# Testing loop
test_loss = 0.0
test_accuracy = 0.0
test_miou = 0.0


with torch.no_grad():
    for batch in test_loader:
        inputs = batch['image'].to(device)
        labels = batch['label'].to(device)

        outputs = model_test(inputs)
        loss = criterion(outputs, labels)
        test_loss += loss.item()

        predicted_labels = torch.argmax(outputs, dim=1)
        accuracy = (predicted_labels == labels).float().mean().item() * 100
        test_accuracy += accuracy
        miou = compute_miou(predicted_labels, labels, num_classes=outputs.shape[1])
        test_miou += miou


test_loss /= len(test_loader)
test_accuracy /= len(test_loader)
test_miou /= len(test_loader)

print("\n============ TESTING RESULTS ============")
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_accuracy:.2f}%")
print(f"Test mIoU: {test_miou:.4f}")
print("========================================")

[Load]: 100%|██████████| 1160/1160 [00:00<00:00, 259723.09it/s]


loaded 1160 samples from the DFC20 subset test

Test Loss: 1.2210
Test Accuracy: 38.64%
Test mIoU: 0.1681
