In [None]:
import os

import hydra
from sklearn.metrics import classification_report, confusion_matrix
import torch
from torch import Tensor, nn
from torch.nn.functional import softmax

from skin_disease_recognition.core.config import MODEL_DIR, PROJECT_ROOT
from skin_disease_recognition.data.factory import make_loaders

In [None]:
from hydra.core.global_hydra import GlobalHydra

# Drop any Hydra instance left over from a previous execution of this cell so
# that re-running it does not complain that Hydra is already initialized.
GlobalHydra.instance().clear()

# NOTE(review): config_path is relative, presumably resolved against the
# notebook's location — run this cell before the os.chdir(PROJECT_ROOT) cell.
hydra.initialize(config_path='../conf', version_base=None)
cfg = hydra.compose(config_name='config.yaml')

In [None]:
# Device taken from the Hydra config; reused for map_location and every .to() call.
# Presumably a string like 'cuda' or 'cpu' — confirm in conf/config.yaml.
device = cfg.device

In [None]:
# The checkpoint is loaded and used directly as an nn.Module (apparently a whole
# pickled model rather than a state_dict), hence weights_only=False.
# SECURITY: weights_only=False unpickles arbitrary objects — only load trusted files.
model: nn.Module = torch.load(
    MODEL_DIR / 'efficientnetb0_weight_decay.pth',
    map_location=torch.device(device),
    weights_only=False,
)
model = model.to(device)

In [None]:
# Switch the working directory to the project root, presumably so that relative
# paths used by make_loaders(cfg) resolve — TODO confirm.
# NOTE(review): this runs after hydra.initialize(config_path='../conf'); re-running
# the Hydra cell after this one may resolve '../conf' against the new directory.
os.chdir(PROJECT_ROOT)

In [None]:
# make_loaders returns a (train, test) pair; only the test loader is used below.
(train_loader, test_loader) = make_loaders(cfg)

In [None]:
# Predicted and ground-truth label accumulators (two distinct lists).
y_preds, y_trues = [], []

In [None]:
# Misclassified samples split by confidence, plus the confidence of every miss.
high_miss, low_miss, miss_probs = [], [], []

In [None]:
# Run the model over the whole test set, collecting predicted/true labels for the
# metrics below and the misclassified samples for inspection.
# Accumulators are (re)created here so re-running this cell does not append
# duplicates to the results of a previous run.
HIGH_CONF = 0.99  # misses with predicted-class probability above this are "confidently wrong"
LOW_CONF = 0.4    # misses with predicted-class probability below this are "unsure and wrong"

y_preds: list[int] = []
y_trues: list[int] = []
high_miss = []  # (image, pred_idx, true_idx, prob) tuples with prob > HIGH_CONF
low_miss = []   # (image, pred_idx, true_idx, prob) tuples with prob < LOW_CONF
miss_probs: list[float] = []  # predicted-class probability of every misclassified sample

model.eval()
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)

        probs: Tensor = softmax(model(images), dim=1)
        # Max over the class dimension gives the predicted label and its
        # probability in one call (first-max tie-breaking, same as argmax).
        top_prob, pred_label = probs.max(dim=1)

        # Move per-batch results to CPU once instead of syncing per element.
        top_prob = top_prob.cpu()
        pred_label = pred_label.cpu()
        true_label = labels.cpu()

        missed = (pred_label != true_label).nonzero(as_tuple=True)[0]
        for j in missed.tolist():
            prob = top_prob[j].item()
            miss_probs.append(prob)
            if prob > HIGH_CONF or prob < LOW_CONF:
                # Keep the image on CPU so stored samples don't pin device memory.
                entry = (images[j].cpu(), int(pred_label[j]), int(true_label[j]), prob)
                if prob > HIGH_CONF:
                    high_miss.append(entry)
                else:
                    low_miss.append(entry)

        # Plain ints rather than 0-d tensors, for sklearn.
        y_preds.extend(pred_label.tolist())
        y_trues.extend(true_label.tolist())

In [None]:
# Class names in label-index order (labels are used as indices into this list
# when plotting). Assumes the dataset exposes `.classes` like torchvision's
# ImageFolder — TODO confirm if the loader wraps the dataset (e.g. in a Subset).
classes = test_loader.dataset.classes

### Classification report

Per-class precision, recall and F1 on the test set.

In [None]:
# Per-class precision/recall/F1. `labels` pins the report to every class index
# (labels index into `classes`), so a class absent from both y_true and y_pred
# still gets a row instead of raising a target_names length mismatch.
# zero_division=0 reports undefined scores as 0 without emitting warnings.
print(
    classification_report(
        y_true=y_trues,
        y_pred=y_preds,
        labels=list(range(len(classes))),
        target_names=classes,
        zero_division=0,
    )
)

### Confusion matrix

Counts of true (rows) versus predicted (columns) classes on the test set.

In [None]:
# Fix the label set/order to all class indices so the matrix is always
# len(classes) x len(classes) and lines up with the heatmap tick labels,
# even when some class never appears in y_true or y_pred.
conf = confusion_matrix(y_trues, y_preds, labels=list(range(len(classes))))

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Annotated count heatmap: rows are actual classes, columns are predictions.
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(
    conf,
    ax=ax,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=classes,
    yticklabels=classes,
)
ax.set_title('Confusion Matrix on Test Set')
ax.set_xlabel('Predicted')
ax.set_ylabel('Actual')
plt.setp(ax.get_xticklabels(), rotation=90)
plt.setp(ax.get_yticklabels(), rotation=0)
fig.tight_layout()
plt.show()

In [None]:
def plot_missclassification(case, cls_names, mean=0.5, std=0.5):
    """Show one misclassified test image titled with its true/predicted class.

    Args:
        case: ``(image, pred_idx, true_idx, prob)`` tuple as collected by the
            evaluation loop. ``image`` is a channel-first tensor (permuted to
            channel-last for ``imshow``); ``pred_idx``/``true_idx`` index into
            ``cls_names`` (int or 0-d tensor); ``prob`` is the predicted-class
            probability (float or 0-d tensor).
        cls_names: sequence of class names indexed by label.
        mean, std: constants used to undo input normalization
            (``img * std + mean``). The 0.5/0.5 defaults reproduce the previous
            hard-coded values; presumably they match the dataset transform —
            confirm against the data config.
    """
    img, pred_idx, true_idx, prob = case
    # CHW -> HWC, undo normalization, and clamp so imshow doesn't clip with a warning.
    img = (img.detach().permute(1, 2, 0) * std + mean).clamp(0, 1)
    plt.imshow(img.cpu())
    pred_class = cls_names[int(pred_idx)]
    true_class = cls_names[int(true_idx)]
    # float() so a tensor prob renders as a number, not "tensor(..., device=...)".
    plt.title(f'Real: {true_class} / Predicted: {pred_class} / Prob: {float(prob):.3f}')

In [None]:
# Show the first low-confidence miss, if any exist (avoids IndexError on an empty list).
if low_miss:
    plot_missclassification(low_miss[0], classes)
else:
    print('No low-confidence misclassifications to show.')

In [None]:
# Distribution of the model's confidence on the samples it got wrong.
fig, ax = plt.subplots(figsize=(6, 4))
sns.histplot(miss_probs, bins=100, ax=ax)
ax.set_title('Predicted-class probability of misclassified samples')
ax.set_xlabel('Predicted-class probability')
ax.set_ylabel('Count')
plt.show()