In [None]:
import os

import hydra
import torch
from torch import nn, Tensor
from torch.nn.functional import softmax

from skin_disease_recognition.config import MODEL_DIR, PROJECT_ROOT
from skin_disease_recognition.data.loaders import make_loaders

from sklearn.metrics import classification_report, confusion_matrix

In [None]:
from hydra.core.global_hydra import GlobalHydra

# Hydra refuses to initialize twice in one process, so drop any instance left
# over from a previous execution of this cell before (re)initializing.
GlobalHydra.instance().clear()

# NOTE(review): config_path is relative; this runs before the os.chdir cell
# below, so it is resolved from the notebook's starting location — confirm.
hydra.initialize(
    config_path='../conf',
    version_base=None,
)
cfg = hydra.compose(config_name='config.yaml')

In [None]:
# Target device from the Hydra config; used below both for torch.device(...)
# and for Module/Tensor .to(...).
device = cfg['device']

In [None]:
# SECURITY: weights_only=False makes torch.load unpickle the full nn.Module,
# which can execute arbitrary code stored in the file — only load trusted
# checkpoints from MODEL_DIR.
# map_location remaps the stored tensors onto `device` while loading.
model: nn.Module = torch.load(
    MODEL_DIR / 'efficientnetb0.pth',
    map_location=torch.device(device),
    weights_only=False,
)
model = model.to(device)

In [None]:
# Switch the kernel's working directory to the project root for the rest of
# the session; presumably make_loaders resolves data paths relative to the
# cwd — TODO confirm.
os.chdir(os.fspath(PROJECT_ROOT))

In [None]:
# Build the data loaders from the Hydra config; only `test_loader` is used below.
train_loader, test_loader = make_loaders(cfg)

In [None]:
# Predicted and true class indices accumulated over the test set.
y_preds, y_trues = [], []

In [None]:
# `miss`: confidently-wrong samples kept for plotting;
# `miss_probs`: top-class probability of every wrong prediction.
miss, miss_probs = [], []

In [None]:
# Run the model over the whole test set. Collect the predicted/true labels for
# the metrics below, plus the wrong predictions for inspection.
# All accumulators are reset here, so re-running this cell does not append a
# second copy of every prediction.
HIGH_CONF_THRESHOLD = 0.99  # wrong predictions above this confidence go into `miss`

y_preds = []
y_trues = []
miss = []
miss_probs = []

model.eval()
with torch.no_grad():
    for batch_idx, (images, labels) in enumerate(test_loader):
        print(f'Batch {batch_idx}')
        images = images.to(device)
        labels = labels.to(device)

        probs: Tensor = softmax(model(images), 1)
        # Top-class probability and index for every sample in the batch.
        top_probs, pred_label = probs.max(dim=1)

        wrong = pred_label != labels
        miss_probs.extend(top_probs[wrong].tolist())

        # Keep confidently-wrong samples. Copy each image to the CPU: images[j]
        # is a view of the on-device batch, and storing the view would keep the
        # whole batch alive in device memory.
        confident_wrong = wrong & (top_probs > HIGH_CONF_THRESHOLD)
        for j in confident_wrong.nonzero(as_tuple=True)[0].tolist():
            miss.append((
                images[j].cpu(),
                int(pred_label[j]),
                int(labels[j]),
                float(top_probs[j]),
            ))

        # Plain Python ints, so sklearn receives ordinary integer labels.
        y_preds.extend(pred_label.tolist())
        y_trues.extend(labels.tolist())

In [None]:
# Class names used to label the report and the confusion matrix. Assumes
# classes[i] is the name of label index i (ImageFolder-style) — TODO confirm.
classes = test_loader.dataset.classes

### Classification report
Per-class precision, recall and F1-score on the test set.

In [None]:
# Pin the label set to every class index, so `target_names` always lines up.
# Without `labels`, sklearn infers the labels from the data and raises
# ValueError when a class is absent from both y_trues and y_preds.
# zero_division=0 gives the same numbers as the default, without warnings.
print(classification_report(
    y_true=y_trues,
    y_pred=y_preds,
    labels=list(range(len(classes))),
    target_names=classes,
    zero_division=0,
))

### Confusion matrix
Rows are the true classes and columns are the predicted classes.

In [None]:
# Fix the label set so the matrix is always len(classes) x len(classes), and
# its rows/columns match the tick labels used in the heatmap, even when some
# class never appears in y_trues or y_preds.
conf = confusion_matrix(y_trues, y_preds, labels=list(range(len(classes))))

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Annotated heatmap of the confusion matrix: rows are true classes, columns
# are predicted classes.
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(
    conf,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=classes,
    yticklabels=classes,
    ax=ax,
)
ax.set_title("Confusion Matrix on Test Set")
ax.set_xlabel("Predicted")
ax.set_ylabel("Actual")
plt.setp(ax.get_xticklabels(), rotation=90)
plt.setp(ax.get_yticklabels(), rotation=0)
fig.tight_layout()
plt.show()

In [None]:
def plot_missclassification(case, cls_names, ax=None):
    """Show one misclassified image with its true class, predicted class and confidence.

    Args:
        case: sequence ``(image, pred_idx, true_idx, prob)`` as collected in
            ``miss``. ``image`` is a channels-first 3-D tensor. The indices may
            be ints or 0-dim integer tensors. ``prob`` may be a float or a
            0-dim tensor.
        cls_names: sequence mapping class index -> class name.
        ax: matplotlib Axes to draw on; defaults to the current Axes.

    Returns:
        The Axes the image was drawn on.
    """
    img, pred_idx, true_idx, prob = case[:4]
    if ax is None:
        ax = plt.gca()
    # Convert CHW -> HWC, the layout imshow expects.
    img = img.detach().cpu().permute(1, 2, 0)
    # Undo the input normalisation. Assumes the loader normalised each channel
    # with mean=0.5, std=0.5 — TODO confirm against make_loaders.
    # Clamp, because imshow expects float RGB in [0, 1] and would otherwise
    # clip with a warning.
    img = (img * 0.5 + 0.5).clamp(0, 1)
    ax.imshow(img.numpy())
    # int()/float() accept both plain numbers and 0-dim tensors, so the title
    # never shows a raw tensor repr.
    ax.set_title(
        f'Real: {cls_names[int(true_idx)]} / '
        f'Predicted: {cls_names[int(pred_idx)]} / '
        f'Prob: {float(prob):.4f}'
    )
    return ax

In [None]:
# Show one of the confidently-wrong test samples collected above. The number
# of such samples depends on the model and the threshold, so guard the index
# instead of risking an IndexError.
MISS_INDEX = 1
if len(miss) > MISS_INDEX:
    plot_missclassification(miss[MISS_INDEX], classes)
    plt.show()
else:
    print(f'Only {len(miss)} high-confidence misclassification(s); nothing to show at index {MISS_INDEX}.')

In [None]:
# Distribution of the model's top-class probability over every wrong test
# prediction; mass near 1.0 means the model is confidently wrong.
fig, ax = plt.subplots()
sns.histplot(miss_probs, bins=100, ax=ax)
ax.set(
    title='Confidence of misclassified test predictions',
    xlabel='Predicted-class probability',
    ylabel='Count',
)
plt.show()