In [None]:
import torch
import torchvision.models as models
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

In [None]:


def load_best_model(model_path, num_classes=4, input_channels=7, dropout_rate=0.1454303215712593):
    """
    Rebuild the modified ResNet-18 and load trained weights into it.

    The network is torchvision's ResNet-18 with two changes: the first
    convolution accepts ``input_channels`` channels, and the classifier head
    is ``Dropout(p=dropout_rate)`` followed by a ``Linear`` layer with
    ``num_classes`` outputs. The checkpoint must have been saved from a model
    with this exact layout.

    Args:
        model_path: Path to a ``state_dict`` file written by ``torch.save``.
        num_classes: Number of outputs of the final linear layer.
        input_channels: Number of channels expected by the first convolution.
        dropout_rate: Dropout probability of the head. Dropout has no
            parameters and is inactive in eval mode, so this does not change
            the predictions of the returned model.

    Returns:
        tuple: ``(model, device)`` where ``model`` is on ``device`` and in
        eval mode, and ``device`` is CUDA when available, otherwise CPU.

    Raises:
        RuntimeError: If the checkpoint keys or tensor shapes do not match
            the architecture built here (``load_state_dict`` is strict).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Build the architecture only. The strict load_state_dict below overwrites
    # every parameter and buffer, so requesting pretrained ImageNet weights
    # would just trigger a pointless download (and fail without network).
    model = models.resnet18(weights=None)
    model.conv1 = torch.nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
    model.fc = torch.nn.Sequential(
        torch.nn.Dropout(p=dropout_rate),
        torch.nn.Linear(model.fc.in_features, num_classes)
    )

    # NOTE(security): torch.load unpickles the file. Only load checkpoints
    # from a trusted source; consider weights_only=True on PyTorch versions
    # that support it.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()  # disable dropout and use running batch-norm statistics

    return model, device

def plot_confusion_matrix(model, loader, device, class_names):
    """
    Generate and display a confusion matrix for the model on the given loader.

    The model is put in eval mode and run without gradient tracking over every
    batch of ``loader``; the predicted class of a sample is the argmax of the
    model output along dimension 1.

    Args:
        model: Network to evaluate; called as ``model(inputs)``.
        loader: Iterable yielding ``(inputs, labels)`` batches of tensors.
        device: Device the inputs are moved to before the forward pass.
        class_names: Tick labels for the matrix, or ``None`` for sklearn's
            defaults. When given, entry ``i`` is assumed to name integer
            class ``i`` (the same indexing as the model output), and the
            matrix always has ``len(class_names)`` rows and columns.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    model.eval()
    all_preds, all_labels = [], []

    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device, non_blocking=True)
            preds = model(inputs).argmax(dim=1)
            all_preds.extend(preds.cpu().tolist())
            # Labels are only needed on the host, so skip the device round trip.
            all_labels.extend(labels.tolist())

    # Pin the label set so that a class missing from both the labels and the
    # predictions still gets a row/column; otherwise the matrix would shrink
    # and no longer line up with class_names.
    label_ids = None if class_names is None else list(range(len(class_names)))
    conf_matrix = confusion_matrix(all_labels, all_preds, labels=label_ids)
    disp = ConfusionMatrixDisplay(confusion_matrix=conf_matrix, display_labels=class_names)

    # ConfusionMatrixDisplay.plot() creates its own figure unless an axes is
    # passed in, so a separate plt.figure(figsize=...) call would leave an
    # empty extra figure and the requested size would be ignored.
    _, ax = plt.subplots(figsize=(8, 6))
    disp.plot(cmap=plt.cm.Blues, ax=ax)
    ax.set_title("Confusion Matrix")
    plt.show()

# --- Evaluation configuration -------------------------------------------
model_path = "resnet_model_optuna.pth"
test_dataset_path = 'H:/Datasets/int_split/Testing/'
class_names = ['Und.', 'Low', 'Med', 'High']

# --- Restore the trained network ----------------------------------------
model, device = load_best_model(model_path)

# --- Build the test data pipeline ---------------------------------------
test_dataset = PTDataset(root_dir=test_dataset_path, target_size=(500, 500))
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=32,
    shuffle=False,  # keep a fixed sample order for evaluation
    num_workers=0,
    pin_memory=device.type == 'cuda',  # pinned host memory only when a GPU is used
)

# --- Report --------------------------------------------------------------
plot_confusion_matrix(model, test_loader, device, class_names)