# Ewaluacja modelu na zbiorze testowym

Podział zbioru jest deterministyczny więc po prostu wywołujemy `get_dataset_split` z takimi samymi parametrami
Weźmiemy najlepszy checkpoint z treningu i przeprowadzimy ewaluację na nim.

In [1]:
import torch
from dataclasses import dataclass
from pathlib import Path

from supernova.modeling.model import SupernovaClassifierV1

checkpoint_path = Path("/home/mgarbowski/Desktop/supernova-epoch=21-val_acc=0.70.ckpt")


def load_model_from_checkpoint(checkpoint_path: Path):
    checkpoint = torch.load(
        checkpoint_path, weights_only=False, map_location=torch.device("cpu")
    )
    cfg = checkpoint.get("hyper_parameters")["model_config"]
    return SupernovaClassifierV1(cfg)


model = load_model_from_checkpoint(checkpoint_path)
model

SupernovaClassifierV1(
  (metadata_mlp): MLP(
    (network): Sequential(
      (0): Linear(in_features=20, out_features=256, bias=True)
      (1): ReLU()
      (2): Dropout(p=0.1943098578634063, inplace=False)
      (3): Sequential(
        (0): Linear(in_features=256, out_features=256, bias=True)
        (1): ReLU()
        (2): Dropout(p=0.1943098578634063, inplace=False)
      )
      (4): Sequential(
        (0): Linear(in_features=256, out_features=256, bias=True)
        (1): ReLU()
        (2): Dropout(p=0.1943098578634063, inplace=False)
      )
      (5): Sequential(
        (0): Linear(in_features=256, out_features=256, bias=True)
        (1): ReLU()
        (2): Dropout(p=0.1943098578634063, inplace=False)
      )
      (6): Sequential(
        (0): Linear(in_features=256, out_features=256, bias=True)
        (1): ReLU()
        (2): Dropout(p=0.1943098578634063, inplace=False)
      )
      (7): Linear(in_features=256, out_features=32, bias=True)
    )
  )
  (lightcurve_lst

In [2]:
from supernova.sweep import DATASET_PATH, VAL_SPLIT, TEST_SPLIT
from supernova.dataset import get_data_loaders, get_dataset_split

datasets = get_dataset_split(DATASET_PATH, VAL_SPLIT, TEST_SPLIT)
loaders = get_data_loaders(datasets, batch_size=32)
test_loader = loaders["test"]

In [3]:
len(datasets["test"])

1177

In [4]:
import torch
from sklearn.metrics import (
    confusion_matrix,
    classification_report,
    accuracy_score,
    precision_recall_fscore_support,
)
import numpy as np


@dataclass(frozen=True)
class Metrics:
    accuracy: float
    precision: float
    recall: float
    f1_score: float


@dataclass(frozen=True)
class EvaluationResults:
    predicted_labels: np.ndarray
    real_labels: np.ndarray

    @property
    def confusion_matrix(self):
        return confusion_matrix(self.real_labels, self.predicted_labels)

    @property
    def report(self):
        # TODO class labels
        return classification_report(self.real_labels, self.predicted_labels)

    @property
    def micro_averaged(self):
        precision, recall, f1, _ = precision_recall_fscore_support(
            self.real_labels, self.predicted_labels, average="micro"
        )
        accuracy = accuracy_score(self.real_labels, self.predicted_labels)
        return Metrics(
            accuracy=accuracy, precision=precision, recall=recall, f1_score=f1
        )

    @property
    def macro_averaged(self):
        precision, recall, f1, _ = precision_recall_fscore_support(
            self.real_labels, self.predicted_labels, average="macro"
        )
        cm = self.confusion_matrix
        accuracy = (cm.diagonal() / cm.sum(axis=1)).mean()
        return Metrics(
            accuracy=accuracy, precision=precision, recall=recall, f1_score=f1
        )


def evaluate_model(
    model, test_loader, device="cuda" if torch.cuda.is_available() else "cpu"
):
    model.to(device)
    model.eval()

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in test_loader:
            metadata = batch["metadata"].to(device)
            sequences = {k: v.to(device) for k, v in batch["sequences"].items()}
            lengths = batch["lengths"]
            labels = batch["labels"]

            logits = model(metadata, sequences, lengths)
            preds = logits.argmax(dim=1).cpu().numpy()

            all_preds.extend(preds)
            all_labels.extend(labels.numpy())

    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)

    return EvaluationResults(predicted_labels=all_preds, real_labels=all_labels)

In [5]:
results = evaluate_model(model, test_loader)

In [6]:
print(results.report)

              precision    recall  f1-score   support

           0       0.00      0.00      0.00        22
           1       0.00      0.00      0.00        68
           2       0.11      0.89      0.20       139
           3       0.00      0.00      0.00       183
           4       0.00      0.00      0.00        29
           5       0.00      0.00      0.00         6
           6       0.00      0.00      0.00        85
           7       0.00      0.00      0.00        15
           8       0.00      0.00      0.00       135
           9       0.00      0.00      0.00        31
          10       0.08      0.09      0.09        56
          11       0.00      0.00      0.00       338
          12       0.00      0.00      0.00        38
          13       0.00      0.00      0.00        32

    accuracy                           0.11      1177
   macro avg       0.01      0.07      0.02      1177
weighted avg       0.02      0.11      0.03      1177



  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


In [7]:
print(results.confusion_matrix)

[[  0   0  15   0   0   0   0   0   0   0   7   0   0   0]
 [  0   0  68   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0 124   0   0   0   0   0   0   0  15   0   0   0]
 [  0   0 183   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0  28   0   0   0   0   0   0   0   1   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   6   0   0   0]
 [  0   0  84   0   0   0   0   0   0   0   1   0   0   0]
 [  0   0  15   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0 133   0   0   0   0   0   0   0   2   0   0   0]
 [  0   0  31   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0  51   0   0   0   0   0   0   0   5   0   0   0]
 [  0   0 337   0   0   0   0   0   0   0   1   0   0   0]
 [  0   0  17   0   0   0   0   0   0   0  21   0   0   0]
 [  0   0  31   0   0   0   0   0   0   0   1   0   0   0]]
