In [22]:
import torch
import nbimporter
import numpy as np
import h5py
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from torch.utils.data import DataLoader

from two_sats import SatelliteDataset, ConvNet
from sentinal_1 import Sentinel1Dataset, Sentinel1ConvNet
from sentinal_2 import Sentinel2Dataset, Sentinel2ConvNet

In [23]:
# Open the test-split HDF5 file once; the dataset handles below are reused by get_dataloader().
# NOTE(review): h5py.File only opens the file -- presumably samples are read from disk when the
# handles are indexed, so the file has to stay open while loaders are in use. It is never closed here.

# NOTE(review): absolute, machine-specific path -- consider a configurable data directory.
h5_file_path_test = r"C:\Users\nadav.k\Documents\DS\DL_classification\classification_data\testing_10perc_of_20_subset.h5"

# Open the H5 file read-only and grab the three datasets stored under 'sen1', 'sen2' and 'label'
h5_test = h5py.File(h5_file_path_test, 'r')
test_sen1_data = h5_test['sen1']
test_sen2_data = h5_test['sen2']
test_labels = h5_test['label']


In [25]:
from torch.utils.data import DataLoader

def get_dataloader(dataset_type, batch_size=32, shuffle=False):
    """
    Build a DataLoader over the test data for one of the supported dataset classes.

    The dataset is constructed from the module-level handles ``test_sen1_data``,
    ``test_sen2_data`` and ``test_labels``; only the handles needed by the chosen
    dataset class are passed to it.

    Args:
        dataset_type (str): One of 'SatelliteDataset', 'Sentinel1Dataset', 'Sentinel2Dataset'.
        batch_size (int): Number of samples per batch.
        shuffle (bool): Forwarded to the DataLoader's ``shuffle`` argument.

    Returns:
        DataLoader: Loader wrapping the freshly built dataset.

    Raises:
        ValueError: If ``dataset_type`` is not one of the supported names.
    """
    # Factories are lazy so that only the requested dataset is ever instantiated
    # (and only the globals it needs are looked up).
    factories = (
        ("SatelliteDataset",
         lambda: SatelliteDataset(sen1_data=test_sen1_data, sen2_data=test_sen2_data, labels=test_labels)),
        ("Sentinel1Dataset",
         lambda: Sentinel1Dataset(sen1_data=test_sen1_data, labels=test_labels)),
        ("Sentinel2Dataset",
         lambda: Sentinel2Dataset(sen2_data=test_sen2_data, labels=test_labels)),
    )

    make_dataset = next((factory for name, factory in factories if dataset_type == name), None)
    if make_dataset is None:
        raise ValueError(f"Unsupported dataset type: {dataset_type}")

    dataset = make_dataset()
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
    print(f"{dataset_type} loaded with {len(dataset)} samples.")
    return loader


In [28]:
# Build the evaluation loader over the Sentinel-2 test data; shuffling is disabled.
test_loader = get_dataloader(dataset_type="Sentinel2Dataset", batch_size=32, shuffle=False)

Sentinel2Dataset loaded with 1597 samples.


In [None]:
# General function to load a model
def load_model(model_class, path, num_classes=17, device="cuda"):
    """
    Instantiate a model and restore its weights from a saved state dict.

    Args:
        model_class: Callable that builds the model; it is invoked as
            ``model_class(num_classes=num_classes)``.
        path: Full path to the saved state dict
            (e.g., './models/sentinel2_classification_model.pth').
        num_classes: Number of classes passed to the model constructor.
        device: Device to load the model onto ('cuda', 'cpu', or a ``torch.device``).
            If a CUDA device is requested but CUDA is not available, the model is
            loaded on the CPU instead of failing.

    Returns:
        The model with the saved weights loaded, moved to the resolved device and
        switched to evaluation mode.
    """
    # Resolve the device first: torch.load(map_location="cuda") raises on CUDA-less hosts.
    target_device = torch.device(device)
    if target_device.type == "cuda" and not torch.cuda.is_available():
        print("CUDA is not available; loading the model on CPU instead.")
        target_device = torch.device("cpu")

    model = model_class(num_classes=num_classes)

    # NOTE(review): torch.load is pickle-based; only load checkpoints from trusted
    # sources (consider weights_only=True on PyTorch versions that support it).
    state_dict = torch.load(path, map_location=target_device)
    model.load_state_dict(state_dict)

    model.to(target_device)
    model.eval()
    print(f"Model loaded from {path}")
    return model


In [None]:
def _collect_predictions(model, test_loader, device):
    """Run ``model`` over every batch and return (true_labels, predicted_labels) as NumPy arrays."""
    true_labels = []
    predicted_labels = []

    with torch.no_grad():
        for batch in test_loader:
            # The last element of a batch is the label tensor; everything before it is
            # model input. This covers (data, labels) batches as well as batches with
            # several inputs, e.g. (sen1, sen2, labels).
            # NOTE(review): assumes a multi-input model takes its inputs positionally
            # in batch order -- confirm against the model's forward().
            *inputs, labels = batch
            outputs = model(*(tensor.to(device) for tensor in inputs))

            # Predicted class = index of the largest score along dim 1
            preds = torch.argmax(outputs, dim=1)

            # Labels are only needed on the host, so they are never moved to `device`.
            true_labels.extend(labels.cpu().numpy())
            predicted_labels.extend(preds.cpu().numpy())

    return np.array(true_labels), np.array(predicted_labels)


def analyze_model_performance_general(model, test_loader, num_classes=17, satellite_type="both", device="cuda"):
    """
    General evaluation function for Sentinel-1, Sentinel-2, or both.

    Runs the model over ``test_loader``, prints and plots the confusion matrix, and
    plots a per-class bar chart of correct vs. incorrect predictions.

    Args:
        model: The trained model to evaluate. It is moved to the resolved device and
            put into evaluation mode.
        test_loader: Iterable of batches. Each batch is a sequence whose last element
            is the label tensor and whose preceding elements are passed to the model
            as positional inputs.
        num_classes: Number of classes; class ids ``0 .. num_classes - 1`` define the
            rows/columns of the confusion matrix.
        satellite_type: 'sentinel1', 'sentinel2', or 'both'. Only used (capitalized)
            in the printed header and the plot titles.
        device: Device for computation ('cuda', 'cpu', or a ``torch.device``). If a
            CUDA device is requested but CUDA is not available, the CPU is used.

    Returns:
        None. Results are printed and shown as matplotlib figures.
    """
    # Fall back to CPU instead of crashing when the default 'cuda' is not usable.
    target_device = torch.device(device)
    if target_device.type == "cuda" and not torch.cuda.is_available():
        print("CUDA is not available; evaluating on CPU instead.")
        target_device = torch.device("cpu")

    model.to(target_device)
    model.eval()

    true_labels, predicted_labels = _collect_predictions(model, test_loader, target_device)

    # Generate confusion matrix (rows: true class, columns: predicted class)
    class_ids = list(range(num_classes))
    cm = confusion_matrix(true_labels, predicted_labels, labels=class_ids)

    # Display confusion matrix
    title_suffix = satellite_type.capitalize()
    print(f"Confusion Matrix for {title_suffix}:\n", cm)
    ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_ids).plot(cmap=plt.cm.Blues)
    plt.title(f"Confusion Matrix - {title_suffix}")
    plt.show()

    # Correct vs incorrect predictions per true class:
    # the diagonal holds the hits, each row sum is the number of samples of that class.
    correct_per_label = np.diag(cm)
    total_per_label = np.sum(cm, axis=1)
    incorrect_per_label = total_per_label - correct_per_label

    plt.figure(figsize=(12, 6))
    x = np.arange(len(correct_per_label))
    plt.bar(x - 0.2, correct_per_label, width=0.4, label="Correct", color="g")
    plt.bar(x + 0.2, incorrect_per_label, width=0.4, label="Incorrect", color="r")
    plt.xticks(ticks=x, labels=class_ids)
    plt.title(f"Correct vs Incorrect Predictions - {title_suffix}")
    plt.xlabel("Labels")
    plt.ylabel("Count")
    plt.legend()
    plt.grid(axis="y")
    plt.show()
