##Imports

In [None]:
from google.colab import drive
drive.mount('/content/drive')

!pip install torchmetrics --quiet
!pip install rasterio --quiet
import torchmetrics
import rasterio

import os
os.chdir('/content/drive/MyDrive/dataset')

import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
import torchvision
import torch.nn as nn
import torchvision.transforms.functional as TF
from torch import randint
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torch.optim as optim

from sklearn.metrics import average_precision_score
from sklearn.preprocessing import label_binarize

import albumentations as A
from albumentations.pytorch import ToTensorV2

#Dataset

In [None]:
class DynamicEarthNetDataset(Dataset):
    """Image / segmentation-mask pairs listed in a DynamicEarthNet split file.

    Each item is ``(image, mask)``:
      * ``image`` - float32 array exactly as returned by rasterio's ``read()``
        (bands first, i.e. ``(bands, H, W)``).
      * ``mask`` - int64 ``(H, W)`` array built from the label raster, whose band
        ``i`` is 255 where the pixel belongs to label ``i``. Bands 3 and 6 map to
        ``-1`` (the ignore index used by the loss), bands 4 and 5 are shifted
        down to classes 3 and 4, and bands 0-2 keep their index, giving 5
        trainable classes. Pixels with no band equal to 255 stay 0; if several
        bands are 255, the highest band wins (later bands overwrite earlier ones).

    NOTE(review): ``transform`` is stored but NOT applied in ``__getitem__``.
    Enabling it requires an albumentations pipeline that accepts this data
    (albumentations expects ``(H, W, C)`` images, and Normalize statistics must
    match the band count) - confirm before wiring it in.
    """

    # Label bands whose pixels are excluded from training (mapped to -1).
    IGNORED_CHANNELS = (3, 6)

    def __init__(self, file_name, transform=None):
        self.transform = transform
        self.images, self.masks = get_paths(file_name)
        self.num_classes = 7  # number of label bands iterated in __getitem__

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img_path = self.images[index]
        mask_path = self.masks[index]

        # Context managers close the raster handles immediately instead of
        # relying on garbage collection (matters with many DataLoader workers).
        with rasterio.open(img_path) as src:
            image = src.read().astype(np.float32)
        with rasterio.open(mask_path) as src:
            label = src.read()

        mask = np.zeros((label.shape[1], label.shape[2]), dtype=np.int64)

        for channel in range(self.num_classes):
            if channel in self.IGNORED_CHANNELS:
                target = -1
            elif channel > 3:
                # Close the gap left by ignored band 3.
                target = channel - 1
            else:
                target = channel
            mask[label[channel] == 255] = target

        return image, mask

#Loaders

In [None]:
def get_loaders(
    batch_size,
    train_transform,
    val_transform,
    num_workers=4,
    pin_memory=True,
):
    """Build DataLoaders for the train, validation and test splits.

    The split files ``train.txt``, ``val.txt`` and ``test.txt`` (relative to the
    working directory) are wrapped in ``DynamicEarthNetDataset``. Only the
    training loader shuffles; the test split reuses ``val_transform``.

    Returns:
        ``(train_loader, val_loader, test_loader)``.
    """

    def make_loader(split_file, transform, shuffle):
        # Shared construction so the three splits differ only where intended.
        dataset = DynamicEarthNetDataset(file_name=split_file, transform=transform)
        return DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=pin_memory,
            shuffle=shuffle,
        )

    train_loader = make_loader('train.txt', train_transform, True)
    val_loader = make_loader('val.txt', val_transform, False)
    test_loader = make_loader('test.txt', val_transform, False)
    return train_loader, val_loader, test_loader

#Model

In [None]:
class DoubleConv(nn.Module):
    """Two stacked ``3x3 Conv2d -> BatchNorm2d -> ReLU`` stages.

    Convolutions use stride 1 and padding 1, so the spatial size is preserved;
    they have no bias term because each one is immediately followed by
    BatchNorm. The layers live in ``self.conv`` as a single ``nn.Sequential``
    (``conv.0`` ... ``conv.5``), a layout saved state_dicts depend on.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First stage maps in_channels -> out_channels, second keeps out_channels.
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)

class UNET(nn.Module):
    """U-Net encoder/decoder producing ``out_channels`` maps per pixel.

    Args:
        in_channels: number of input channels (bands).
        out_channels: number of output channels (e.g. class logits).
        features: channel width of each encoder level, shallowest first; the
            bottleneck uses twice the last width. Any indexable sequence that
            supports ``reversed`` works. The default is a tuple so no mutable
            default object is shared between instances.

    The output has the same spatial size as the input. When an input side is
    not divisible by ``2 ** len(features)``, each upsampled map is resized to
    its skip connection before concatenation.
    """

    def __init__(
            self, in_channels=3, out_channels=1, features=(64, 128, 256, 512),
    ):
        super(UNET, self).__init__()
        # Registration / construction order matches the original so state_dict
        # keys and parameter-initialisation order stay the same.
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: one DoubleConv per level.
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Decoder (deepest level first): a 2x transposed conv halving the channels,
        # then a DoubleConv over [skip, upsampled] which has 2 * feature channels.
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(
                    feature * 2, feature, kernel_size=2, stride=2,
                )
            )
            self.ups.append(DoubleConv(feature * 2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)
        # 1x1 conv from the shallowest decoder width to the requested outputs.
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []

        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]

        # self.ups alternates [ConvTranspose2d, DoubleConv] for each level.
        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx // 2]

            # 2x2/stride-2 pooling floors odd sides, so the upsampled map can be one
            # pixel smaller than its skip (input not divisible by 16 with 4 levels).
            # Only the spatial dimensions need to agree for the concatenation.
            if x.shape[2:] != skip_connection.shape[2:]:
                x = TF.resize(x, size=skip_connection.shape[2:])

            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx + 1](concat_skip)

        return self.final_conv(x)

def test():
    """Smoke-test UNET: output shape must equal input shape for a 161x161 input.

    161 is not divisible by 16, so this exercises the resize path in the decoder.
    """
    batch, channels, side = 3, 1, 161
    inputs = torch.randn((batch, channels, side, side))
    outputs = UNET(in_channels=channels, out_channels=channels)(inputs)
    assert outputs.shape == inputs.shape

if __name__ == "__main__":
    test()

#Utils

In [None]:
def get_paths(file_name):
    """Read a split file with one ``<image_path> <mask_path>`` pair per line.

    Blank lines (including a trailing empty line) are skipped.

    Args:
        file_name: path to the split file (e.g. ``train.txt``).

    Returns:
        ``(images, masks)``: two parallel lists of path strings.

    Raises:
        ValueError: if a non-blank line does not have exactly two
            whitespace-separated fields.
    """
    images = []
    masks = []
    with open(file_name, 'r', encoding='utf-8') as file:
        for line in file:
            if not line.strip():
                continue
            image, label = line.split()
            images.append(image)
            masks.append(label)
    return images, masks


def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    """Announce on stdout, then serialize ``state`` to ``filename`` via ``torch.save``."""
    banner = "=> Saving checkpoint"
    print(banner)
    torch.save(obj=state, f=filename)

def load_checkpoint(checkpoint, model):
    """Copy the weights stored under ``checkpoint["state_dict"]`` into ``model`` (strict load)."""
    print("=> Loading checkpoint")
    weights = checkpoint["state_dict"]
    model.load_state_dict(weights)


def multispectral_to_rgb_visualization(img, lower_percentile=5, upper_percentile=95):
    """Turn a channel-first array into a displayable 8-bit ``(H, W, 3)`` image.

    For a ``(C, H, W)`` input (C >= 3), bands 2, 1, 0 become R, G, B. A 2-D
    ``(H, W)`` input is shown as grayscale (the single band repeated three times).
    Values are clipped to the given percentiles, computed jointly over the three
    displayed channels, then min-max scaled to 0..255.

    A constant image (zero value range after clipping) returns all zeros instead
    of dividing by zero.
    """
    if img.ndim == 2:
        rgb = np.repeat(img[:, :, np.newaxis], 3, axis=2)
    else:
        # (C, H, W) -> (H, W, C), then pick bands 2, 1, 0 as R, G, B.
        rgb = img.transpose(1, 2, 0)[:, :, [2, 1, 0]]

    low = np.percentile(rgb, lower_percentile)
    high = np.percentile(rgb, upper_percentile)
    rgb = np.clip(rgb, low, high)

    min_value = np.min(rgb)
    value_range = np.max(rgb) - min_value
    if value_range == 0:
        return np.zeros(rgb.shape, dtype=np.uint8)

    rgb = (rgb - min_value) / value_range
    return (rgb * 255).astype(np.uint8)


def prediction_to_image(predicted_mask, num_classes=5):
    """Expand an ``(H, W)`` class-index map into the 7-band 0/255 label layout.

    Predicted class ``k`` is written as 255 into band ``(0, 1, 2, 4, 5)[k]``.
    Bands 3 and 6 (the labels ignored during training) stay 0. ``num_classes``
    is accepted for API compatibility but not used.

    Accepts a torch tensor (moved to CPU and converted) or a NumPy array. A 3-D
    input with a leading dimension of 1 has that dimension dropped.

    Returns:
        float32 array of shape ``(7, H, W)``.
    """
    if isinstance(predicted_mask, torch.Tensor):
        predicted_mask = predicted_mask.cpu().numpy()

    has_unit_batch = predicted_mask.ndim == 3 and predicted_mask.shape[0] == 1
    if has_unit_batch:
        predicted_mask = predicted_mask[0]

    height, width = predicted_mask.shape
    # Position = predicted class, value = label band it is written to.
    target_bands = (0, 1, 2, 4, 5)

    label_format = np.zeros((7, height, width), dtype=np.float32)
    for predicted_class, band in enumerate(target_bands):
        label_format[band][predicted_mask == predicted_class] = 255
    return label_format

def predict_and_visualize(model, image_index=10):
    """Plot a training image next to the model's predicted mask and its ground truth.

    Reads the ``image_index``-th image/mask pair listed in ``train.txt`` and runs
    ``model`` on the image (on the global ``DEVICE``, without gradients). The
    argmax prediction is converted back to the 7-band label layout with
    ``prediction_to_image``. The three panels are rendered via
    ``multispectral_to_rgb_visualization``.

    The model is put in eval mode for inference and then restored to the
    train/eval mode it had on entry.
    """
    # Read the split file once instead of once per path list.
    image_paths, mask_paths = get_paths('train.txt')

    with rasterio.open(image_paths[image_index]) as src:
        image = src.read().astype(np.float32)
    with rasterio.open(mask_paths[image_index]) as src:
        ground_truth_mask = src.read().astype(np.float32)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            input_image = torch.from_numpy(image).unsqueeze(0).to(DEVICE)  # add batch dim
            prediction = model(input_image)
            predicted_mask_raw = torch.argmax(prediction, dim=1).cpu().numpy()
    finally:
        model.train(was_training)

    predicted_mask = prediction_to_image(predicted_mask_raw)

    predicted_mask_rgb = multispectral_to_rgb_visualization(predicted_mask)
    ground_truth_mask_rgb = multispectral_to_rgb_visualization(ground_truth_mask)
    image_rgb = multispectral_to_rgb_visualization(image)

    panels = (
        (image_rgb, 'Original Image'),
        (predicted_mask_rgb, 'Predicted Mask'),
        (ground_truth_mask_rgb, 'Ground Truth Mask'),
    )

    plt.figure(figsize=(15, 5))
    for position, (panel, title) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(panel)
        plt.title(title)

    plt.show()

#Metrics

In [None]:
def mean_average_precision(y_true, y_scores, ap_scores=None, n_classes=5):
    """Mean one-vs-rest average precision over the classes present in ``y_true``.

    Args:
        y_true: 1-D array of integer labels. Labels outside ``[0, n_classes)``
            (e.g. the ``-1`` ignore index) are dropped together with their
            score rows, instead of being counted as negatives for every class.
        y_scores: array of shape ``(len(y_true), n_classes)`` with per-class scores.
        ap_scores: optional list. The per-class APs computed by this call are
            appended to it (kept for backward compatibility). The returned mean
            covers only this call's classes, never values already in the list.
        n_classes: number of classes / score columns.

    Returns:
        ``(mean_ap, ap_scores)``. ``mean_ap`` is 0.0 when no class has a
        positive sample. ``ap_scores`` is the extended list that was passed in,
        or a new list if none was given.
    """
    if ap_scores is None:
        ap_scores = []

    y_true = np.asarray(y_true).reshape(-1)
    y_scores = np.asarray(y_scores)

    valid = (y_true >= 0) & (y_true < n_classes)
    y_true = y_true[valid]
    y_scores = y_scores[valid]

    # One-hot encode by hand: label_binarize collapses to a single column when
    # n_classes == 2, which would break the per-class indexing below.
    y_true_bin = (y_true[:, np.newaxis] == np.arange(n_classes)).astype(int)

    call_aps = []
    for i in range(n_classes):
        if y_true_bin[:, i].any():  # AP is undefined for a class with no positives
            call_aps.append(average_precision_score(y_true_bin[:, i], y_scores[:, i]))

    ap_scores.extend(call_aps)
    mean_ap = float(np.mean(call_aps)) if call_aps else 0.0
    return mean_ap, ap_scores


def calculate_class_weights(loader, num_classes=5, device=None):
    """Inverse-square-root-frequency class weights for ``CrossEntropyLoss``.

    Counts labelled pixels per class over every ``(input, target)`` batch in
    ``loader``. Targets outside ``[0, num_classes)``, such as the ``-1`` ignore
    index, are skipped. Weights are ``1 / sqrt(freq + 1e-8)``, rescaled so they
    sum to ``num_classes``.

    Args:
        loader: iterable of ``(inputs, targets)`` with integer target tensors.
        num_classes: number of trainable classes.
        device: device for the returned tensor; defaults to the global ``DEVICE``.

    Returns:
        float32 tensor of shape ``(num_classes,)`` on ``device``.

    Raises:
        ValueError: if the loader contains no labelled pixels (weights would be NaN).
    """
    # Integer counters: float32 stops counting exactly past 2**24 pixels.
    pixel_counts = torch.zeros(num_classes, dtype=torch.int64)

    print("Calculating weights...")
    for _, targets in tqdm(loader):
        targets = targets.reshape(-1)

        # ignore -1 index (and anything else outside the class range)
        valid = (targets >= 0) & (targets < num_classes)
        pixel_counts += torch.bincount(targets[valid].long().cpu(), minlength=num_classes)

    total_pixels = pixel_counts.sum()
    if int(total_pixels) == 0:
        raise ValueError("No labelled pixels found; cannot compute class weights.")

    frequencies = pixel_counts.float() / total_pixels
    print(f"Frequencies: {frequencies}")

    weights = 1.0 / torch.sqrt(frequencies + 1e-8)
    weights = weights / weights.sum() * len(weights)
    print(f"Calculated weights: {weights}")

    if device is None:
        device = DEVICE
    return weights.to(device)

#Calculate metrics

In [None]:
def calc_metrics(loader, model, loss_fn, device="cuda", ignore_index=-1):
    """Evaluate ``model`` on ``loader`` without gradients.

    Computes:
      * the mean of ``loss_fn`` over batches;
      * the mean over images of the per-image mean average precision, using
        only pixels whose label is not ``ignore_index``; images with no such
        pixel are skipped;
      * pixel accuracy in percent over pixels whose label is not ``ignore_index``.

    Prints the mAP and accuracy. The model is evaluated in eval mode and then
    restored to the train/eval mode it had on entry.

    Returns:
        ``(avg_loss, epoch_mean_ap, accuracy)`` as floats; each is NaN when
        there is nothing to average.
    """
    total_loss = 0.0
    num_batches = 0
    num_correct = 0
    num_pixels = 0
    image_aps = []

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                y = y.to(device)
                preds = model(x)

                # loss
                total_loss += loss_fn(preds, y).item()
                num_batches += 1

                # softmax and argmax
                probabilities = torch.softmax(preds, dim=1)
                preds_argmax = torch.argmax(probabilities, dim=1)

                # Accuracy only over labelled pixels, matching the loss's ignore_index.
                valid = y != ignore_index
                num_correct += (preds_argmax[valid] == y[valid]).sum().item()
                num_pixels += valid.sum().item()

                # average precision, one image at a time
                num_classes = probabilities.shape[1]
                probs_cpu = probabilities.cpu().numpy()
                y_cpu = y.cpu().numpy()
                for probs_img, y_img in zip(probs_cpu, y_cpu):
                    # (C, H, W) -> (H*W, C), rows aligned with the flattened labels.
                    probs_flat = probs_img.reshape(num_classes, -1).transpose()
                    y_flat = y_img.reshape(-1)

                    keep = y_flat != ignore_index
                    if not keep.any():
                        continue  # nothing labelled in this image

                    # A fresh list per image keeps earlier images' APs out of this mean.
                    image_ap, _ = mean_average_precision(
                        y_flat[keep], probs_flat[keep], [], n_classes=num_classes
                    )
                    image_aps.append(image_ap)
    finally:
        model.train(was_training)

    avg_loss = total_loss / num_batches if num_batches else float("nan")
    epoch_mean_ap = float(np.mean(image_aps)) if image_aps else float("nan")
    accuracy = num_correct / num_pixels * 100 if num_pixels else float("nan")

    print(f"Mean AP: {epoch_mean_ap}")
    print(
        f"Got {num_correct}/{num_pixels} with acc {accuracy:.2f}"
    )
    return avg_loss, epoch_mean_ap, accuracy

#Train

In [None]:
# Optimisation hyper-parameters.
LEARNING_RATE = 1e-5
BATCH_SIZE = 8
NUM_EPOCHS = 31

# Runtime / data-loading configuration.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NUM_WORKERS = 4
PIN_MEMORY = False
LOAD_MODEL = False  # when True, main() restores weights from "my_checkpoint.pth.tar"

# Metric histories, module-level so the plotting cells can read them after main().
tot_train_loss, tot_eval_loss = [], []
tot_train_mAP, tot_eval_mAP = [], []
tot_train_accuracy, tot_eval_accuracy = [], []

# Additional history lists; main() does not append to these.
train_loss, eval_loss, train_mAP, train_accuracy = [], [], [], []

# `model` becomes a module-level global when main() assigns it.

def train_fn(loader, model, optimizer, loss_fn, scaler):
    """Run one training epoch with mixed precision on the global ``DEVICE``.

    For each ``(data, targets)`` batch:
      1. forward pass and loss under CUDA autocast (disabled when ``DEVICE`` is
         not ``"cuda"``);
      2. a ``scaler``-scaled backward pass;
      3. an optimizer step.

    The epoch's mean loss is appended to the module-level ``tot_train_loss``
    list and also returned (NaN for an empty loader).
    """
    loop = tqdm(loader)
    model.train()
    total_loss = 0.0
    num_batches = 0
    use_amp = DEVICE == "cuda"

    for data, targets in loop:
        data = data.to(device=DEVICE)
        targets = targets.long().to(device=DEVICE)

        # forward
        with torch.amp.autocast('cuda', enabled=use_amp):
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        # backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        loss_value = loss.item()  # one device->host sync per step
        loop.set_postfix(loss=loss_value)

        total_loss += loss_value
        num_batches += 1

    avg_train_loss = total_loss / num_batches if num_batches else float("nan")
    tot_train_loss.append(avg_train_loss)
    return avg_train_loss


def main():
    """Train the 4-band, 5-class UNET on DynamicEarthNet and evaluate it.

    Steps:
      * Builds train/val/test loaders once.
      * Derives class weights from the training split for a weighted
        ``CrossEntropyLoss`` with ``ignore_index=-1``.
      * Trains for ``NUM_EPOCHS`` epochs with Adam. After each epoch it saves a
        checkpoint, plots four sample predictions and evaluates on the
        validation set; every 5th epoch (excluding epoch 0) it also evaluates
        on the training set.
      * Finally evaluates on the test set.

    Side effects: assigns the module-level ``model``. Appends to the
    module-level metric lists; ``tot_train_loss`` gets one value per epoch
    from ``train_fn`` only.
    """
    train_transform = A.Compose(
        [
            A.Rotate(limit=35, p=1.0),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.1),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )

    val_transforms = A.Compose(
        [
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )

    global model
    model = UNET(in_channels=4, out_channels=5).to(DEVICE)

    # Build the loaders once and reuse the training loader for the weight pass.
    train_loader, val_loader, test_loader = get_loaders(
        BATCH_SIZE,
        train_transform,
        val_transforms,
        NUM_WORKERS,
        PIN_MEMORY,
    )
    class_weights = calculate_class_weights(train_loader)

    loss_fn = nn.CrossEntropyLoss(weight=class_weights, ignore_index=-1)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    if LOAD_MODEL:
        # map_location lets a GPU-saved checkpoint load on a CPU-only runtime.
        load_checkpoint(torch.load("my_checkpoint.pth.tar", map_location=DEVICE), model)

    # Baseline metrics before any training.
    calc_metrics(val_loader, model, loss_fn, device=DEVICE)
    scaler = torch.amp.GradScaler('cuda', enabled=DEVICE == "cuda")

    for epoch in range(NUM_EPOCHS):
        print(f'epoca {epoch}')
        train_fn(train_loader, model, optimizer, loss_fn, scaler)

        # save model
        checkpoint = {
            'epoch': epoch,
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        torch.save(checkpoint, f'checkpoint_epoch_{epoch+1}.pth')
        print(f'checkpoint saved at epoch: {epoch+1}')

        # show images
        for image_index in (10, 30, 55, 100):
            predict_and_visualize(model, image_index=image_index)

        # training-set metrics every 5 epochs
        evaluate_train = epoch % 5 == 0 and epoch != 0
        if evaluate_train:
            print('Metrics on train_loader: ')
            _, train_mAP, train_accuracy = calc_metrics(train_loader, model, loss_fn, device=DEVICE)
        print()
        print('Metrics on val_loader: ')
        eval_loss, eval_mAP, eval_accuracy = calc_metrics(val_loader, model, loss_fn, device=DEVICE)

        tot_eval_loss.append(eval_loss)
        tot_eval_mAP.append(eval_mAP)

        if evaluate_train:
            # tot_train_loss already receives one per-epoch value from train_fn.
            # Appending the eval-mode training loss here would interleave two
            # different series in the per-epoch training-loss plot.
            tot_train_mAP.append(train_mAP)
            tot_train_accuracy.append(train_accuracy)

        tot_eval_accuracy.append(eval_accuracy)

        print()

    # test to test_set
    print('test on test_set: ')
    calc_metrics(test_loader, model, loss_fn, device=DEVICE)


if __name__ == "__main__":
    main()

#Visualize results

##Visualize training loss

In [None]:
# Per-epoch training loss, drawn in the left half of a 12x4 figure.
ax = plt.figure(figsize=(12, 4)).add_subplot(1, 2, 1)
ax.plot(tot_train_loss)
ax.set_title('Training Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')

##Visualize validation loss

In [None]:
# Per-epoch validation loss, drawn in the left half of a 12x4 figure.
ax = plt.figure(figsize=(12, 4)).add_subplot(1, 2, 1)
ax.plot(tot_eval_loss)
ax.set_title('Validation Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')

##Visualize train accuracy

In [None]:
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)

# Training accuracy is recorded every 5 epochs; entries may be tensors, so
# convert them to NumPy for matplotlib.
y_values = [acc.cpu().numpy() if isinstance(acc, torch.Tensor) else acc for acc in tot_train_accuracy]
# x-axis steps by 5 starting at 5 (5, 10, 15, ...), sized from the plotted series itself.
x_values = range(5, len(y_values) * 5 + 5, 5)

plt.plot(x_values, y_values)
plt.title('Training Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.xticks(list(x_values))  # one tick per recorded point, whatever NUM_EPOCHS is

##Visualize val mAP

In [None]:
# Per-epoch validation mean average precision, left half of a 12x4 figure.
ax = plt.figure(figsize=(12, 4)).add_subplot(1, 2, 1)
ax.plot(tot_eval_mAP)
ax.set_title('Validation mean AP')
ax.set_xlabel('Epoch')
ax.set_ylabel('mAP')

##Visualize train mAP

In [None]:
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)

# Training mAP is recorded every 5 epochs; entries may be tensors, so convert
# them to NumPy for matplotlib.
y_values = [m.cpu().numpy() if isinstance(m, torch.Tensor) else m for m in tot_train_mAP]
# x-axis steps by 5 starting at 5 (5, 10, 15, ...), sized from the plotted series itself.
x_values = range(5, len(y_values) * 5 + 5, 5)

plt.plot(x_values, y_values)
plt.title('Training mAP')
plt.xlabel('Epoch')
plt.ylabel('mAP')
plt.xticks(list(x_values))  # one tick per recorded point, whatever NUM_EPOCHS is

##Visualize Validation Accuracy

In [None]:
# Per-epoch validation accuracy, left half of a 12x4 figure.
ax = plt.figure(figsize=(12, 4)).add_subplot(1, 2, 1)
# Entries may be tensors; convert them to NumPy for matplotlib.
eval_accuracy_values = [
    acc.cpu().numpy() if isinstance(acc, torch.Tensor) else acc
    for acc in tot_eval_accuracy
]
ax.plot(eval_accuracy_values)
ax.set_title('Validation Accuracy')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')