In [1]:
import kagglehub, os

# Download latest version
path = kagglehub.dataset_download("apollo2506/eurosat-dataset")

print("Path to dataset files:", path)

Path to dataset files: /root/.cache/kagglehub/datasets/apollo2506/eurosat-dataset/versions/6


In [2]:
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import torch
from PIL import Image
import os
import numpy as np
import torchvision.transforms as transforms
from torch import nn

from src.datasets import EuroSATDataset

In [3]:
from src.models.backbones.resnet import ResNet
from src.models.heads import FFN

class EuroSATModel(nn.Module):
    def __init__(self, backbone, head):
        super(EuroSATModel, self).__init__()
        self.backbone = backbone
        self.head = head
        self.pool = nn.AdaptiveAvgPool2d((1, 1))  # Global average pooling
        self.flatten = nn.Flatten()
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        x = self.backbone(x)
        x = self.pool(x)
        x = self.flatten(x)
        x = self.head(x)
        return x
    
    def loss(self, input, target):
        logits = self.forward(input)
        loss = self.criterion(logits, target)
        return loss
    
    def forward_train(self, x, target):
        img = x['image']
        logits = self.forward(img)
        loss = self.criterion(logits, target)
        return logits, loss
    
    def forward_test(self, x):
        img = x['image']
        logits = self.forward(img)
        return logits
    
    def predict(self, x):
        logits = self.forward_test(x)
        return torch.argmax(logits, dim=1)
    
model = EuroSATModel(
    backbone=ResNet(idims=3, odims=64, arch=(2, 2, 2, 2), base_dims=32), 
    head=FFN(idims=64, odims=10, hidden_dims=64, dropout=0.5, nlayers=6))

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import f1_score, average_precision_score, confusion_matrix
from sklearn.preprocessing import label_binarize
from src.runner.utils import progress_bar

class Evaluator:
    def __init__(self, model, device, num_classes, class_names=None, save_dir = None):
        """
        Args:
            model (torch.nn.Module): Trained model for inference.
            device (str or torch.device): 'cpu' or 'cuda'.
            num_classes (int): Number of target classes.
            class_names (list[str], optional): Names of classes for plots.
        """
        self.model = model.to(device)
        self.device = torch.device(device)
        self.num_classes = num_classes
        self.class_names = class_names or [str(i) for i in range(num_classes)]
        self.save_dir = save_dir

    def predict(self, dataloader):
        """
        Runs model inference on a DataLoader.
        Returns:
            y_true (np.ndarray): True labels shape (N,).
            y_pred (np.ndarray): Predicted labels shape (N,).
            y_scores (np.ndarray): Predicted probabilities shape (N, num_classes).
        """
        self.model.eval()
        y_true, y_pred, y_scores = [], [], []
        with torch.no_grad():
            for i, batch in enumerate(dataloader):
                imgs = batch['image'].to(self.device)
                labels = batch['label'].to(self.device)
                logits = self.model(imgs)
                probs = torch.softmax(logits, dim=1)
                preds = probs.argmax(dim=1)

                progress_bar(
                    iteration=i + 1,
                    total_iterations=len(dataloader),
                    prefix="Evaluating",
                    postfix=f" {(i+1/len(dataloader)):.2f}%",
                    style="arrow"
                )

                y_true.extend(labels.cpu().numpy())
                y_pred.extend(preds.cpu().numpy())
                y_scores.extend(probs.cpu().numpy())

        return np.array(y_true), np.array(y_pred), np.array(y_scores)

    def compute_f1(self, y_true, y_pred, average='macro'):
        """Compute F1 score."""
        return f1_score(y_true, y_pred, average=average)

    def compute_map(self, y_true, y_scores):
        """
        Compute mean Average Precision (mAP) for multiclass.
        Returns:
            mean_ap (float): Mean of per-class AP.
            ap_per_class (np.ndarray): AP for each class.
        """
        # Binarize labels for one-vs-rest
        y_true_bin = label_binarize(y_true, classes=list(range(self.num_classes)))
        ap_per_class = average_precision_score(y_true_bin, y_scores, average=None)
        mean_ap = np.mean(ap_per_class)
        return mean_ap, ap_per_class

    def plot_confusion_matrix(self, y_true, y_pred, normalize=True):
        """
        Plots the confusion matrix.
        Args:
            y_true (array): True labels.
            y_pred (array): Predicted labels.
            normalize (bool): Whether to normalize by row sums.
        Returns:
            fig: Matplotlib figure object.
        """
        cm = confusion_matrix(y_true, y_pred)
        if normalize:
            cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

        fig, ax = plt.subplots()
        im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
        ax.figure.colorbar(im, ax=ax)

        # Setup labels
        ax.set(
            xticks=np.arange(self.num_classes),
            yticks=np.arange(self.num_classes),
            xticklabels=self.class_names,
            yticklabels=self.class_names,
            ylabel='True label',
            xlabel='Predicted label',
            title='Confusion Matrix'
        )
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right', rotation_mode='anchor')

        # Annotate cells
        fmt = '.2f' if normalize else 'd'
        thresh = cm.max() / 2.
        for i in range(self.num_classes):
            for j in range(self.num_classes):
                ax.text(j, i, format(cm[i, j], fmt),
                        ha='center', va='center',
                        color='white' if cm[i, j] > thresh else 'black')

        fig.tight_layout()
        return fig


In [5]:
from src.runner.utils import progress_bar

class Runner:
    def __init__(self, model: nn.Module,
                 loading_cfg: dict,
                 data_cfg: dict,
                 optim_cfg: dict,
                 save_dir: str = None,
                 device: str = 'cpu'):
        self.device = torch.device(device)
        self.model  = model.to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), **optim_cfg)
        self.criterion = nn.CrossEntropyLoss()

        self.train_data = EuroSATDataset(**data_cfg, split='train')
        self.val_data   = EuroSATDataset(**data_cfg, split='validation')
        self.test_data  = EuroSATDataset(**data_cfg, split='test')

        self.batch_size = loading_cfg['batch_size']

        self.history = {'train_loss': [], 'val_loss': []}

    def run(self,
            mode: str = 'train',
            val_interval: int= 10,
            log_interval: int= 10,
            epochs: int = 100,
            start_epoch: int = 1):
        if mode == 'train':
            return self._train_loop(start_epoch, epochs, val_interval, log_interval)

        elif mode == 'validation':
            return self.evaluate(self.val_data, batch_size=1)

        elif mode == 'test':
            return self.evaluate(self.test_data, batch_size=self.batch_size)

        else:
            raise ValueError("Mode must be 'train', 'validation', or 'test'.")
        
        

    def _train_loop(self, start, epochs, val_interval, log_interval):
        train_loader = DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True)


        for epoch in range(start, epochs):
            train_loss = self._train_epoch(train_loader, epoch=epoch, total_epochs=epochs, log_interval=log_interval)
            self.history['train_loss'].append(train_loss)

            if epoch % val_interval == 0:
                val_loss = self.evaluate(self.val_data, batch_size=self.batch_size, loss=True)
                self.history['val_loss'].append(val_loss)
                

        return self.history

    def _train_epoch(self, loader, epoch, total_epochs, log_interval=10):
        self.model.train()
        total_loss = 0.
        total_batches = len(loader)

        for i, batch in enumerate(loader):
            imgs = batch['image'].to(self.device)
            labels = batch['label'].to(self.device)

            self.optimizer.zero_grad()
            logits = self.model(imgs)
            loss = self.criterion(logits, labels)
            loss.backward()
            self.optimizer.step()


            if i % log_interval == 0:
                progress_bar(
                    epoch=epoch,
                    total_epochs=total_epochs,
                    iteration=i+1,
                    total_iterations=total_batches,
                    vars={
                        'loss': loss.item(),
                        'lr': self.optimizer.param_groups[0]['lr'],
                    },)                

            total_loss += loss.item()

        return total_loss / len(loader)

    def evaluate(self, dataset, batch_size, loss=False):
        evaluator = Evaluator(self.model, self.device, num_classes=dataset.num_classes, class_names=dataset.class_names)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
        y_true, y_pred, y_scores = evaluator.predict(dataloader)
        if loss:
            val_loss = self.criterion(torch.tensor(y_scores), torch.tensor(y_true)).item()
        else:
            val_loss = None
        f1 = evaluator.compute_f1(y_true, y_pred)
        mean_ap, ap_per_class = evaluator.compute_map(y_true, y_scores)

        if evaluator.save_dir:
            fig = evaluator.plot_confusion_matrix(y_true, y_pred)
            fig.savefig(os.path.join(evaluator.save_dir, 'confusion_matrix.png'))
            plt.close(fig)
        
        print(f"F1 Score: {f1:.4f}, Mean Average Precision (mAP): {mean_ap:.4f}") 

        return val_loss

In [None]:
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

loading_cfg = {
    'batch_size': 1024,
    'num_workers': 4,
}

data_cfg = {
    'root_dir': os.path.join(path, 'EuroSAT'),
    'transform': transform,
}

optim_cfg = {
    'lr': 0.001,
    'weight_decay': 1e-4,
}

runner = Runner(model=model, loading_cfg=loading_cfg, data_cfg=data_cfg, optim_cfg=optim_cfg, device='cuda:2', save_dir='results/eurosat')

In [None]:
runner.run(mode='train', val_interval=1, log_interval=1, epochs=10, start_epoch=1)

 | Epoch 01/10 | Iter 19/19 | [██████████] |  | loss: 1.8676 | lr: 1.0000e-03
Epoch [2/10], Train Loss: 2.1220

Evaluating | Iter 6/6 | [>>>>>>>>>>] |  5.166666666666667:.2f%%%
F1 Score: 0.0426, Mean Average Precision (mAP): 0.1892
 | Epoch 02/10 | Iter 1/19 | [----------] |  | loss: 1.8404 | lr: 1.0000e-03