In [1]:
# Mount Google Drive so the project folder under MyDrive/RemoteSensingLandTypeClassification
# (dataset.py, the dataset archive, and the output directory used as PATH) is reachable here.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# Copy the project's dataset module (provides LandCoverNetDataset and CLASSES) from Drive into
# the working directory so that `from dataset import ...` can import it.
!cp /content/drive/MyDrive/RemoteSensingLandTypeClassification/model/dataset.py .

In [None]:
!mkdir data
!mkdir dataset
!unzip /content/drive/MyDrive/RemoteSensingLandTypeClassification/model/dataset/dataset.zip -d ./dataset/

In [4]:
import torch, torchvision
from datetime import datetime
import matplotlib.pyplot as plt
import numpy as np
import os, time

In [5]:
# Pick the first CUDA GPU when one is available, otherwise run on the CPU,
# and report which device (and GPU model, if any) will be used.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

if device.type == 'cuda':
    print(torch.cuda.get_device_name())

cuda:0
NVIDIA A100-SXM4-40GB


In [6]:
from dataset import LandCoverNetDataset, CLASSES

NUM_WORKERS = 4  # DataLoader worker processes
BATCH_SIZE = 32
# Output directory on Google Drive for trained weights and figures.
PATH = "/content/drive/MyDrive/RemoteSensingLandTypeClassification/model/data"
SEED = 42

# Seed the RNGs behind weight initialisation and DataLoader shuffling so runs are repeatable.
# This does not make every CUDA kernel deterministic, so GPU results may still vary slightly.
torch.manual_seed(SEED)
np.random.seed(SEED)

np.set_printoptions(suppress=True, precision=6)

In [8]:
def compute_iou(pred, target, num_classes, absent_value=np.nan):
    """Per-class intersection-over-union between a predicted and a ground-truth label map.

    Args:
        pred: array-like of predicted class indices.
        target: array-like of ground-truth class indices, same (or broadcast-compatible)
            shape as ``pred``.
        num_classes: IoU is reported for class ids ``0 .. num_classes - 1``.
        absent_value: value reported for a class present in neither ``pred`` nor ``target``.
            IoU is 0/0 (undefined) there, so the default is NaN, which ``np.nanmean``
            skips when averaging. Scoring it as 0.0 would unfairly drag the mean down.

    Returns:
        list[float] of length ``num_classes``.
    """
    pred = np.asarray(pred)
    target = np.asarray(target)
    ious = []
    for cls in range(num_classes):
        pred_cls = pred == cls
        target_cls = target == cls
        intersection = np.logical_and(pred_cls, target_cls).sum()
        union = np.logical_or(pred_cls, target_cls).sum()
        if union == 0:
            ious.append(float(absent_value))
        else:
            ious.append(float(intersection) / float(union))
    return ious

In [None]:
def calculate_dice_coefficient(pred_mask, true_mask, num_classes, absent_value=0.0):
    """Per-class Dice coefficient, 2*|P ∩ T| / (|P| + |T|), between two label maps.

    Args:
        pred_mask: array-like of predicted class indices.
        true_mask: array-like of ground-truth class indices, same (or broadcast-compatible)
            shape as ``pred_mask``.
        num_classes: Dice is reported for class ids ``0 .. num_classes - 1``.
        absent_value: value reported for a class present in neither mask (Dice is 0/0 there).
            Defaults to 0.0. Pass ``np.nan`` to have such classes skipped by ``np.nanmean``.

    Returns:
        list[float] of length ``num_classes``.
    """
    pred_mask = np.asarray(pred_mask)
    true_mask = np.asarray(true_mask)
    dice_scores = []
    for i in range(num_classes):
        pred_class = pred_mask == i
        true_class = true_mask == i
        intersection = np.logical_and(pred_class, true_class).sum()
        # Sum of the two set sizes (not their union), as the Dice formula requires.
        cardinality = pred_class.sum() + true_class.sum()
        if cardinality == 0:
            dice_scores.append(float(absent_value))
        else:
            dice_scores.append(2.0 * float(intersection) / float(cardinality))
    return dice_scores

In [9]:
def seconds_to_time(seconds):
    """Format a duration as a compact string: '45s', '3m7s' or '2h5m0s'.

    Fractional seconds are truncated with ``int``. Leading zero units are omitted,
    but once a larger unit is shown every smaller unit is shown too.
    """
    total = int(seconds)
    hours = total // 3600
    minutes = (total // 60) % 60
    secs = total % 60

    if total < 60:
        return f'{secs}s'
    if total < 3600:
        return f'{minutes}m{secs}s'
    return f'{hours}h{minutes}m{secs}s'

In [11]:
def train(model, dataloader, epochs, lr=0.001):
    """Train a segmentation model with cross-entropy loss and Adam, tracking per-class IoU.

    The model is moved to the global ``device``. Per-image IoU (via ``compute_iou``) is
    measured on the training batches as they are processed, so it is a training metric.

    Args:
        model: segmentation model whose forward returns a dict with an ``'out'`` logits
            tensor (e.g. torchvision's ``deeplabv3_resnet50``).
        dataloader: iterable of ``(images, labels)`` batches that supports ``len()``.
        epochs: number of passes over ``dataloader``.
        lr: Adam learning rate (defaults to the previously hard-coded 0.001).

    Returns:
        ``(iou_history, loss_history)``. ``iou_history[c]`` is the list of per-epoch mean
        IoU values for class ``c`` (one list per entry in ``CLASSES``). ``loss_history`` is
        the list of per-epoch mean training losses.
    """
    num_classes = len(CLASSES)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # One history list per class, sized from CLASSES instead of a hard-coded count.
    iou_history = [[] for _ in range(num_classes)]
    loss_history = []
    num_batches = len(dataloader)

    start_time = time.time()

    model.to(device)
    for epoch in range(epochs):
        epoch_start_time = time.time()

        model.train()
        running_loss = 0.0
        total_ious = []
        for batch_idx, (images, labels) in enumerate(dataloader):
            if batch_idx % 100 == 0:
                print(f"Processing batch {batch_idx}/{num_batches}")

            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)['out']
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            # Per-image IoU for this batch; argmax over dim 1 (the class logits) gives label maps.
            preds = outputs.argmax(dim=1).cpu().numpy()
            targets = labels.cpu().numpy()
            total_ious.extend(
                compute_iou(pred, target, num_classes) for pred, target in zip(preds, targets)
            )

        epoch_loss = running_loss / num_batches
        loss_history.append(epoch_loss)
        # nanmean skips any NaN entries compute_iou reports (classes absent from an image).
        mean_iou = np.nanmean(total_ious, axis=0)
        for class_idx, class_iou in enumerate(mean_iou):
            iou_history[class_idx].append(class_iou)
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss:.4f}, Mean IoU: {mean_iou}")
        print(f"Epoch time {seconds_to_time(time.time() - epoch_start_time)}, total time {seconds_to_time(time.time() - start_time)}")

    torch.cuda.empty_cache()
    return (iou_history, loss_history)

In [12]:
def plot_graph(iou_history, loss_history, name, path):
    """Plot per-class IoU and training loss against epoch and save the figure to ``path``.

    IoU curves share a 0-1 left axis. The loss gets its own right-hand axis because
    cross-entropy is not bounded by 1 and would otherwise be clipped by the IoU limits.

    Args:
        iou_history: sequence of per-epoch IoU lists, one per plotted class. Entry ``i`` is
            labelled ``CLASSES[i + 1]``, so callers pass ``iou_history[1:]`` from ``train``.
        loss_history: per-epoch mean training loss. Its length sets the number of epochs shown.
        name: figure title.
        path: output image path passed to ``savefig``.
    """
    classes = CLASSES[1:]  # Class 0 doesn't actually appear in the dataset

    # Custom colors for each class; the last entry is used for the loss curve.
    colors = ['#0000ff', '#888888', '#d1a46d', '#e5e5ef', '#d64c2b', '#186818', '#00ff00', 'purple']

    fig, iou_ax = plt.subplots(figsize=(12, 6))
    epochs = range(1, len(loss_history) + 1)

    for i, class_iou in enumerate(iou_history):
        iou_ax.plot(epochs, class_iou, color=colors[i], label=classes[i])

    iou_ax.set_xlabel('Epochs')
    iou_ax.set_ylabel('IoU')
    iou_ax.set_title(name)
    iou_ax.set_xticks(epochs)
    iou_ax.set_ylim(0, 1)
    iou_ax.grid(axis='y', linestyle='--', alpha=0.7)

    # Separate, unclipped axis for the loss curve.
    loss_ax = iou_ax.twinx()
    loss_ax.plot(epochs, loss_history, color=colors[-1], label="Loss")
    loss_ax.set_ylabel('Loss')
    loss_ax.set_ylim(bottom=0)

    # One legend for both axes. It is drawn on the twin axis so the loss line cannot cover it.
    iou_handles, iou_labels = iou_ax.get_legend_handles_labels()
    loss_handles, loss_labels = loss_ax.get_legend_handles_labels()
    loss_ax.legend(iou_handles + loss_handles, iou_labels + loss_labels, loc='upper left')

    fig.savefig(path, bbox_inches='tight')

In [None]:
def plot_dice(model, dataloader, name, path):
    """Evaluate per-class Dice of ``model`` on ``dataloader``, print it and save a bar chart.

    Dice is computed for every batch with ``calculate_dice_coefficient``. The per-batch
    scores are averaged with ``np.nanmean``: each batch counts equally, and NaN entries are
    skipped if the helper reports any. Class 0 is left out of both the printout and the
    chart, matching ``plot_graph``.

    Args:
        model: segmentation model whose forward returns a dict with an ``"out"`` logits
            tensor. It is moved to ``device`` and put in eval mode.
        dataloader: iterable of ``(images, masks)`` batches, with ``masks`` as CPU tensors.
        name: chart title.
        path: output image path passed to ``savefig``.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.to(device)
    model.eval()

    num_classes = len(CLASSES)
    # Bar colors indexed by class id; entry 0 (class 0) is never plotted.
    class_colors = ("#000000", "#0000ff", "#888888", "#d1a46d", "#f5f5ff", "#d64c2b", "#186818", "#00ff00")
    batch_scores = []

    with torch.no_grad():
        for images, masks in dataloader:
            logits = model(images.to(device))["out"]
            # Assumes logits are (batch, classes, H, W); argmax over dim 1 gives label maps.
            preds = logits.argmax(dim=1).cpu().numpy()
            batch_scores.append(calculate_dice_coefficient(preds, masks.numpy(), num_classes))

    if not batch_scores:
        raise ValueError("plot_dice: dataloader yielded no batches")

    dice_scores = np.nanmean(np.asarray(batch_scores, dtype=np.float64), axis=0)

    # Pair each score with its own class name (skipping class 0).
    print("Dice scores: ")
    for class_name, score in zip(CLASSES[1:], dice_scores[1:]):
        print(f"  {class_name}: {score}")

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.bar(CLASSES[1:], dice_scores[1:], color=class_colors[1:num_classes])

    ax.set_title(name)
    ax.set_xlabel("Classes")
    ax.set_ylabel("Dice score")
    ax.set_ylim(0, 1)
    ax.tick_params(axis="x", labelrotation=45)

    fig.savefig(path, bbox_inches='tight')

In [None]:
# Train DeepLabV3-ResNet50 on the "TrueColor" inputs, save its weights, then plot the
# training history and the per-class Dice scores.
# NOTE(review): plot_dice reuses the training loader, so the Dice chart shows training-set
# performance. Confirm what LandCoverNetDataset's third argument selects and whether a
# held-out loader should be used for evaluation.
dataset = LandCoverNetDataset("dataset", "TrueColor", True)
loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

model = torchvision.models.segmentation.deeplabv3_resnet50(num_classes=len(CLASSES))

iou_history, loss_history = train(model=model, dataloader=loader, epochs=50)
torch.save(model.state_dict(), os.path.join(PATH, "model_true_color.pt"))

plot_graph(
    iou_history=iou_history[1:],  # drop class 0 to match plot_graph's CLASSES[1:] labels
    loss_history=loss_history,
    name="True Color training history",
    path=os.path.join(PATH, "graph_true_color.png"),
)
plot_dice(
    model=model,
    dataloader=loader,
    name="True Color DICE Scores",
    path=os.path.join(PATH, "graph_true_color_dice.png"),
)


In [None]:
# Train DeepLabV3-ResNet50 on the "FalseColor" inputs, save its weights, then plot the
# training history and the per-class Dice scores.
# NOTE(review): plot_dice reuses the training loader, so the Dice chart shows training-set
# performance. Confirm what LandCoverNetDataset's third argument selects and whether a
# held-out loader should be used for evaluation.
dataset = LandCoverNetDataset("dataset", "FalseColor", True)
loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

model = torchvision.models.segmentation.deeplabv3_resnet50(num_classes=len(CLASSES))

iou_history, loss_history = train(model=model, dataloader=loader, epochs=50)
torch.save(model.state_dict(), os.path.join(PATH, "model_false_color.pt"))

plot_graph(
    iou_history=iou_history[1:],  # drop class 0 to match plot_graph's CLASSES[1:] labels
    loss_history=loss_history,
    name="False Color training history",
    path=os.path.join(PATH, "graph_false_color.png"),
)
plot_dice(
    model=model,
    dataloader=loader,
    name="False Color DICE Scores",
    path=os.path.join(PATH, "graph_false_color_dice.png"),
)


In [None]:
# Train DeepLabV3-ResNet50 on the "SWIR" inputs, save its weights, then plot the
# training history and the per-class Dice scores.
# NOTE(review): plot_dice reuses the training loader, so the Dice chart shows training-set
# performance. Confirm what LandCoverNetDataset's third argument selects and whether a
# held-out loader should be used for evaluation.
dataset = LandCoverNetDataset("dataset", "SWIR", True)
loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

model = torchvision.models.segmentation.deeplabv3_resnet50(num_classes=len(CLASSES))

iou_history, loss_history = train(model=model, dataloader=loader, epochs=50)
torch.save(model.state_dict(), os.path.join(PATH, "model_swir.pt"))

plot_graph(
    iou_history=iou_history[1:],  # drop class 0 to match plot_graph's CLASSES[1:] labels
    loss_history=loss_history,
    name="SWIR training history",
    path=os.path.join(PATH, "graph_swir.png"),
)
plot_dice(
    model=model,
    dataloader=loader,
    name="SWIR DICE Scores",
    path=os.path.join(PATH, "graph_swir_dice.png"),
)


In [None]:
# Train DeepLabV3-ResNet50 on the "NDVI" inputs, save its weights, then plot the
# training history and the per-class Dice scores.
# NOTE(review): plot_dice reuses the training loader, so the Dice chart shows training-set
# performance. Confirm what LandCoverNetDataset's third argument selects and whether a
# held-out loader should be used for evaluation.
dataset = LandCoverNetDataset("dataset", "NDVI", True)
loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

model = torchvision.models.segmentation.deeplabv3_resnet50(num_classes=len(CLASSES))

iou_history, loss_history = train(model=model, dataloader=loader, epochs=50)
torch.save(model.state_dict(), os.path.join(PATH, "model_ndvi.pt"))

plot_graph(
    iou_history=iou_history[1:],  # drop class 0 to match plot_graph's CLASSES[1:] labels
    loss_history=loss_history,
    name="NDVI training history",
    path=os.path.join(PATH, "graph_ndvi.png"),
)
plot_dice(
    model=model,
    dataloader=loader,
    name="NDVI DICE Scores",
    path=os.path.join(PATH, "graph_ndvi_dice.png"),
)
