In [None]:
import os

import matplotlib.pyplot as plt
import seaborn as sns
import timm
import torch
import torch.nn as nn
import torch.optim as optim
from datasets import load_dataset
from torch.utils.data import DataLoader
from torcheval.metrics.functional import multiclass_precision, multiclass_confusion_matrix, multiclass_f1_score
from torcheval.metrics.functional.classification import multiclass_recall
from torchvision import transforms
from torchvision import transforms as T

from dataset import EuroSat

In [None]:
# Select the compute device: CUDA first, then Apple MPS, otherwise the CPU
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [None]:
# Transformations
# The pipeline is built through the `T` alias of torchvision.transforms
# (imported in the imports cell). Writing `transforms = transforms.Compose(...)`
# replaced the module with the pipeline object, so a second run of this cell
# failed with AttributeError. The result is still bound to `transforms`
# because the data-loading cell passes that name to EuroSat.
transforms = T.Compose([
    T.Resize(232),        # resize, then crop the central 224x224 region
    T.CenterCrop(224),
    T.ToTensor(),
    # Per-channel mean and std for normalisation
    # (presumably EuroSAT statistics - TODO confirm where they were computed)
    T.Normalize([0.3445, 0.3803, 0.4077], [0.0915, 0.0652, 0.0553])
])

# Number of images per evaluation batch
batch_size = 64

# Class names used as tick labels on the confusion matrix.
# NOTE(review): position i is assumed to be the dataset's class id i - TODO confirm.
labels = [
    "Forest", "River", "Highway", "AnnualCrop", "SeaLake",
    "HerbaceousVegetation", "Industrial", "Residential", "PermanentCrop", "Pasture",
]

In [None]:
# Test split of "cm93/eurosat", wrapped in the project's EuroSat dataset class
test_data = EuroSat(
    load_dataset("cm93/eurosat", split='test'),
    transform=transforms,
)

# Batches are served in dataset order (no shuffling) by two worker processes
test_dataloader = DataLoader(
    dataset=test_data,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
)

In [None]:
# Checkpoints produced by fine-tuning, listed in the same order as `models`
model_paths = [
    "models/resnet_18_ft_20.pth",
    "models/resnet_50_ft_17.pth",
    "models/vit_base_ft_8.pth",
]

# Build each architecture with a 10-class head and no pretrained weights;
# the evaluation cell loads the fine-tuned weights from `model_paths`.
models = [
    timm.create_model(architecture, pretrained=False, num_classes=10)
    for architecture in (
        'resnet18',
        'resnet50',
        "hf_hub:timm/vit_base_patch16_224.augreg2_in21k_ft_in1k",
    )
]
model_r18, model_r50, model_vit = models

In [None]:
# Evaluating the models on the test dataset.
# For every (checkpoint, model) pair this cell loads the fine-tuned weights,
# runs the test set once, prints the average loss / accuracy and the per-class
# and weighted precision, recall and F1, then draws the confusion matrix.
num_classes = len(labels)
criterion = nn.CrossEntropyLoss()
size = len(test_dataloader.dataset)
num_batches = len(test_dataloader)

for index, (model_path, model) in enumerate(zip(model_paths, models)):
    print(f"------{index}------")
    model.to(device)
    # map_location lets a checkpoint saved on a GPU load on a CPU/MPS-only machine.
    # NOTE: torch.load unpickles the file - only load checkpoints you trust.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    model.eval()
    running_loss = 0.0
    correct = 0
    # Per-batch results are collected in lists and concatenated once after the
    # loop, instead of re-concatenating a growing tensor on every batch.
    pred_batches = []
    targ_batches = []

    with torch.no_grad():
        for data, target in test_dataloader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)
            output = output.argmax(1)  # predicted class id per sample
            pred_batches.append(output)
            targ_batches.append(target)
            running_loss += loss.item()
            correct += (output == target).sum().item()

    acc = (100*correct)/size
    # Mean of the per-batch mean losses
    avg_loss = running_loss/num_batches
    print(f"Test: Avg loss: {avg_loss:>8f}, Accuracy: {(acc):>0.2f}%")

    pred_all = torch.cat(pred_batches).to(torch.int64)
    targ_all = torch.cat(targ_batches).to(torch.int64)

    #----------PerClass---------------
    precision_cls = multiclass_precision(pred_all, targ_all, average=None, num_classes=num_classes)
    recall_cls = multiclass_recall(pred_all, targ_all, average=None, num_classes=num_classes)
    # Fixed: this used to call multiclass_recall, so the "F1" line repeated recall
    f1_cls = multiclass_f1_score(pred_all, targ_all, average=None, num_classes=num_classes)

    print(f"precision_cls : {precision_cls}")
    print(f"recall_cls : {recall_cls}")
    print(f"f1_cls : {f1_cls}")

    #----------Global-----------------
    precision_g = multiclass_precision(pred_all, targ_all, average='weighted', num_classes=num_classes)
    recall_g = multiclass_recall(pred_all, targ_all, average='weighted', num_classes=num_classes)
    f1_g = multiclass_f1_score(pred_all, targ_all, average='weighted', num_classes=num_classes)

    print(f"precision_g : {precision_g}")
    print(f"recall_g : {recall_g}")
    print(f"f1_g : {f1_g}")

    cf_matrix = multiclass_confusion_matrix(pred_all, targ_all, num_classes)
    cf_matrix = cf_matrix.cpu().numpy().astype(int)

    # Create the heatmap
    fig, ax = plt.subplots(figsize=(12, 12))  # Adjust the size here
    sns.heatmap(
        cf_matrix,
        annot=True,
        fmt="d",
        linewidth=.5,
        cmap="crest",
        annot_kws={"size": 12},
        xticklabels=labels,
        yticklabels=labels,
        ax=ax
    )
    # torcheval's confusion matrix has true classes on rows and predictions on columns
    ax.set_title(f"Confusion matrix: {model_path}")
    ax.set_xlabel("Predicted label")
    ax.set_ylabel("True label")
    plt.xticks(rotation=45)
    plt.yticks(rotation=0)
    plt.show()
    plt.close(fig)  # avoid piling up open figures across models

    torch.cuda.empty_cache()