In [None]:
import numpy as np
import json
import seaborn as sns
import matplotlib.pyplot as plt

import sys
sys.path.append("../")
from dataset_utils import ImageNet, Places365, ArtPlaces

### Konstanten

In [None]:
# Path to the evaluation results JSON that the "load data" cell reads into `data`.
# NOTE(review): absolute, machine-specific path -- adjust before running elsewhere.
RESULTS_PATH = r"C:\Users\mariu\Documents\Studium\Praktikum\Evaluation\20251116_161434_Evaluation_L2\results.json"

# Visualization settings (read by visualize_confusion_matrix)
# True: divide every row of the confusion matrix by its row sum before plotting.
NORMALIZE_CONFUSION_MATRIX = True
# True: draw diagonal cells with the "Greens" colormap and all other cells with "Reds";
# False: draw the whole matrix with the "Blues" colormap.
DIFFERENT_COLORS_FOR_DIAG = False
# True: outline every diagonal cell with a thin black rectangle.
HIGHLIGHT_DIAG = False

# Result settings
# Earlier single-result selection, kept for reference (superseded by RESULTS_TO_VISUALIZE):
# NAME = "MoCo (resnet50, dim=128, MoCo weights)"
# DATASET = "Places365"
# METRIC = "Accuracy"
# K = 5
# Each entry selects matrices from the results file. The cells below look up
# data[name][dataset]["per_class"]["matrix"]["<metric>@<k>"] for every combination
# of the listed datasets, metrics and k values. Commented-out entries are disabled.
RESULTS_TO_VISUALIZE = [
    # {
    #     "name": "Perceptual Loss (resnet18, dim=25088, ImageNet weights)",
    #     "dataset": ["ImageNet", "Places365", "ArtPlaces"],
    #     "metric": ["Accuracy"],
    #     "k": [5]
    # },
    # {
    #     "name": "ResNet18SiameseNetwork (resnet50, dim=1000, finetuned on ImageNet)",
    #     "dataset": ["ImageNet", "Places365", "ArtPlaces"],
    #     "metric": ["Accuracy"],
    #     "k": [5]
    # },
    {
        "name": "MoCo (resnet50, dim=128, MoCo weights)",
        "dataset": ["ImageNet", "Places365", "ArtPlaces"],
        "metric": ["Accuracy"],
        "k": [5]
    },
    {
        "name": "MoCo (resnet50, dim=128, trained on Places365)",
        "dataset": ["Places365"], # ["Places365", "ArtPlaces"],
        "metric": ["Accuracy"],
        "k": [5]
    },
    {
        "name": "MoCo (resnet50, dim=128, trained on ArtPlaces)",
        "dataset": ["ArtPlaces"],
        "metric": ["Accuracy"],
        "k": [5]
    },
    {
        "name": "DINO_v2 (vitb14, dim=768, pretrained)",
        "dataset": ["ImageNet", "Places365", "ArtPlaces"],
        "metric": ["Accuracy"],
        "k": [5]
    },
    # {
    #     "name": "CLIP (ViT-B/32, dim=512, pretrained)",
    #     "dataset": ["ImageNet", "Places365", "ArtPlaces"],
    #     "metric": ["Accuracy"],
    #     "k": [5]
    # },
]

# CLASSES = [
#     # ImageNet
#     "tabby",
#     "tiger cat",
#     "Persian cat",
#     "Egyptian cat",
#     "cougar",
#     "lynx",
#     "German shepherd",
#     "French bulldog",
#     "Eskimo dog",
#     "brown bear",
#     "tractor",
#     "warplane",
#     "passenger car"

#     # Places365 & ArtPlaces
#     # "amphitheater",
#     # "castle",
#     # "palace",
#     # "aqueduct",
#     # "alley",
#     # "highway",
#     # "beer_hall",
#     # "chalet",
#     # "botanical_garden",
#     # "beach",
#     # "mountain",
#     # "canyon",
# ]

### Daten laden

In [None]:
# Load the evaluation results into `data`.
# The encoding is given explicitly: without it open() falls back to the platform
# locale encoding (e.g. cp1252 on Windows), which can mis-decode or reject a
# UTF-8 JSON file that contains non-ASCII characters.
with open(RESULTS_PATH, "r", encoding="utf-8") as f:
    data = json.load(f)

### Fehlermatrix visualisieren

In [None]:
def visualize_confusion_matrix(cm, method_name, dataset_name):
    """Show a confusion matrix as a seaborn heatmap in its own figure.

    Rendering is controlled by the module-level flags
    NORMALIZE_CONFUSION_MATRIX, DIFFERENT_COLORS_FOR_DIAG and HIGHLIGHT_DIAG.

    Args:
        cm: 2-D array-like confusion matrix. Rows are plotted on the
            "true class" axis, columns on the "predicted class" axis.
        method_name: Name of the evaluated method; used in the figure identifier.
        dataset_name: Name of the dataset; used in the figure identifier.

    Returns:
        None. The figure is displayed via ``plt.show()``.
    """
    # Accept plain nested lists as well as arrays.
    cm = np.asarray(cm)

    if NORMALIZE_CONFUSION_MATRIX:
        # Row-normalize. Rows that sum to zero (a class without any samples)
        # are left at 0 instead of producing NaN and a divide-by-zero warning.
        row_sums = cm.sum(axis=1, keepdims=True)
        cm = np.divide(
            cm.astype(float),
            row_sums,
            out=np.zeros(cm.shape, dtype=float),
            where=row_sums != 0,
        )
        fmt = ".2f"
    else:
        fmt = "d"

    plt.figure(method_name + " - " + dataset_name, figsize=(8, 6))
    # Note: this changes the global matplotlib font size, not just this figure's.
    plt.rcParams.update({'font.size': 25})

    if DIFFERENT_COLORS_FOR_DIAG:
        # Masks selecting the diagonal and everything else.
        diagonal_mask = np.eye(cm.shape[0], dtype=bool)
        off_diag_mask = ~diagonal_mask

        # Off-diagonal cells (diagonal masked out) in red.
        sns.heatmap(cm, mask=diagonal_mask, annot=False, fmt=fmt, cmap="Reds", cbar=False)

        # Diagonal cells drawn on top in green.
        sns.heatmap(cm, mask=off_diag_mask, annot=False, fmt=fmt, cmap="Greens", cbar=False)
    else:
        sns.heatmap(cm, annot=False, fmt=fmt, cmap="Blues", cbar=False)

    if HIGHLIGHT_DIAG:
        # Outline every diagonal cell.
        for i in range(cm.shape[0]):
            plt.gca().add_patch(plt.Rectangle((i, i), 1, 1, fill=False, edgecolor="black", lw=1))

    plt.xlabel('Vorhergesagte Klasse')
    plt.ylabel('Richtige Klasse')
    # plt.title(method_name + " - " + dataset_name)
    plt.show()

In [None]:
%matplotlib qt

# for method, d1 in data.items():
#     for dataset, d2 in d1.items():
#         pass

for model in RESULTS_TO_VISUALIZE:
    method = model["name"]
    for dataset in model["dataset"]:
        for metric in model["metric"]:
            for k in model["k"]:
                cm = np.array(data[method][dataset]["per_class"]["matrix"][metric + "@" + str(k)])
                visualize_confusion_matrix(cm, method, dataset)

### Größte Fehler ausgeben

In [None]:
# Dataset wrappers from dataset_utils; the cell below uses their `idx_to_class`
# lookup to turn class indices into class names.
# NOTE(review): the roots are absolute, machine-specific paths -- adjust before
# running on another machine.
imagenet = ImageNet(root=r"C:\Users\mariu\Documents\Studium\Praktikum\ImageNet_Evaluation_Subset")
places365 = Places365(root=r"C:\Users\mariu\Documents\Studium\Praktikum\Places365_Evaluation_Subset")
artplaces = ArtPlaces(root=r"C:\Users\mariu\Documents\Development\Datasets\ArtPlaces_13371280")

In [None]:
def get_top_confusions(cm, n=5, class_labels=None):
    """Find the n largest off-diagonal entries (confusions) of a confusion matrix.

    Diagonal entries (correct predictions) are never returned, even when ``n``
    is larger than the number of off-diagonal cells; in that case fewer than
    ``n`` results are returned. Entries with equal value are returned in
    row-major order. Zero-valued off-diagonal cells are still eligible.

    Args:
        cm: 2-D array-like confusion matrix; rows are true classes, columns
            are predicted classes. The input is not modified.
        n: Maximum number of confusions to return.
        class_labels: Optional sequence of class names indexed by class
            index. Defaults to the integer class indices themselves.

    Returns:
        list of tuples ``(true_class, predicted_class, value)``, sorted by
        ``value`` in descending order, where ``value`` is ``cm[true, predicted]``.
    """
    cm = np.asarray(cm)
    n_rows, n_cols = cm.shape

    # Collect only off-diagonal cells (row-major order) instead of overwriting
    # the diagonal with a sentinel: a sentinel could still be selected when n is
    # large, and -1 is not representable in unsigned matrices.
    off_diagonal = ~np.eye(n_rows, n_cols, dtype=bool)
    rows, cols = np.nonzero(off_diagonal)
    values = cm[rows, cols]

    # A stable sort on the negated values gives descending order with
    # deterministic (row-major) tie-breaking.
    order = np.argsort(-values.astype(float), kind="stable")[:max(n, 0)]

    if class_labels is None:
        class_labels = list(range(max(n_rows, n_cols)))

    top_confusions = []
    for pos in order:
        i = rows[pos]  # row index: true class
        j = cols[pos]  # column index: predicted class
        top_confusions.append((class_labels[i], class_labels[j], cm[i, j]))

    return top_confusions


In [None]:
for model in RESULTS_TO_VISUALIZE:
    method = model["name"]
    for dataset in model["dataset"]:
        match dataset:
            case "ImageNet":
                d = imagenet
            case "Places365":
                d = places365
            case "ArtPlaces":
                d = artplaces

        for metric in model["metric"]:
            for k in model["k"]:
                cm = np.array(data[method][dataset]["per_class"]["matrix"][metric + "@" + str(k)])
                top_confusions = get_top_confusions(cm, n=5)
                print(method + " - " + dataset + ":")
                for c in top_confusions:
                    r = d.idx_to_class[c[0]]
                    p = d.idx_to_class[c[1]]
                    print(f"Richtig: {r}, Vorhergesagt: {p}, Fehler: {c[2]}")
                print()