In [2]:
import gc

import numpy
import torch
from mapie.conformity_scores import LACConformityScore, APSConformityScore, RPSConformityScore

from datasets.retina_mnist import RetinaMNISTDataset

DATA_ROOT = '.'

# Prefer Apple-silicon MPS, then CUDA, and fall back to CPU.
# torch.backends.mps.is_available() is the documented MPS availability check.
device = (
    'mps' if torch.backends.mps.is_available()
    else 'cuda' if torch.cuda.is_available()
    else 'cpu'
)

# Seed both torch and NumPy's global RNG so "Restart & Run All" is reproducible
# (MAPIE is seeded separately through its own random_state argument).
SEED = 1
torch.manual_seed(SEED)
numpy.random.seed(SEED)

def free_garbage():
    """Release memory held by dead Python objects and by accelerator caches.

    The garbage collector runs first, so tensors kept alive only by reference
    cycles are actually freed. Only then are the device allocator caches
    emptied, so the memory released by the collection is handed back to the
    device. Both the CUDA and the MPS caches are cleared when the respective
    backend is available.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()

In [3]:
from datasets.fgnet import FGNetDataset

# NOTE(review): 0.15 is presumably the hold-out fraction expected by
# RetinaMNISTDataset — confirm against its signature and consider lifting it
# into the config cell as a named constant.
dataset = RetinaMNISTDataset(0.15, DATA_ROOT)

# Each getter returns an (X, y) pair. Downstream, the train split fits the
# network, the hold-out split conformalizes it, and the test split is used for
# prediction sets and final metrics.
(X_train, y_train), (X_hold_out, y_hold_out), (X_test, y_test) = (
    dataset.get_train_data(),
    dataset.get_hold_out_data(),
    dataset.get_test_data(),
)

num_classes = dataset.get_num_classes()

100%|██████████| 3.29M/3.29M [00:00<00:00, 11.7MB/s]


In [20]:
from typing import Any
from numpy import ndarray, dtype
from mapie.classification import SplitConformalClassifier
from torch.optim.lr_scheduler import ReduceLROnPlateau
from skorch import NeuralNetClassifier
from skorch.dataset import ValidSplit
from skorch.callbacks import EarlyStopping, LRScheduler, Checkpoint
from torch.optim import AdamW, Adam
from torch import nn
from dlordinal.output_layers import COPOC
from torchvision import models
from dlordinal.losses import TriangularLoss, WKLoss, EMDLoss
from torch.nn import CrossEntropyLoss

losses = [
    'COPOC',
    # CrossEntropyLoss(),
    # TriangularLoss(base_loss=CrossEntropyLoss(), num_classes=num_classes),
    # WKLoss(num_classes=num_classes, use_logits=False),
    # EMDLoss(num_classes=num_classes),
]

scores = [
    LACConformityScore(),
    APSConformityScore(),
    # RPSConformityScore(),
]

# Confidence levels handed to MAPIE; predict_set builds one prediction set per level.
CONFIDENCE_LEVELS = [0.98, 0.97, 0.95, 0.92, 0.9, 0.8, 0.7]

# Maps '<loss name>_<conformity score class>' -> (point predictions, prediction sets).
preds: dict[str, tuple[ndarray[Any, dtype[Any]], ndarray[Any, dtype[Any]]]] = {}


def _build_model_and_criterion(loss):
    """Return ``(loss_name, model, criterion)`` for one entry of ``losses``.

    A fresh ImageNet-pretrained ResNet-18 is created on every call, and its
    ``fc`` head is replaced according to ``loss``:

    * ``'COPOC'``: linear layer + dlordinal ``COPOC`` output layer, trained with
      ``CrossEntropyLoss``.
    * ``TriangularLoss`` / ``WKLoss`` / ``EMDLoss`` instances: linear layer +
      softmax, trained with the given loss instance.
    * any other loss instance: plain linear layer, trained with that loss.
    """
    model = models.resnet18(weights="IMAGENET1K_V1")
    in_features = model.fc.in_features

    if loss == 'COPOC':
        model.fc = nn.Sequential(nn.Linear(in_features, num_classes), COPOC())
        return 'COPOC', model, CrossEntropyLoss()

    if isinstance(loss, (TriangularLoss, WKLoss, EMDLoss)):
        # NOTE(review): TriangularLoss is constructed with base_loss=CrossEntropyLoss(),
        # which expects logits; if it applies that base loss to the model output,
        # this Softmax makes the softmax run twice. Confirm which of these losses
        # really expect probabilities (WKLoss is built with use_logits=False).
        model.fc = nn.Sequential(nn.Linear(in_features, num_classes), nn.Softmax(dim=1))
    else:
        model.fc = nn.Linear(in_features, num_classes)
    return type(loss).__name__, model, loss


for loss in losses:
    loss_name, model, loss_function = _build_model_and_criterion(loss)

    classifier = NeuralNetClassifier(
        module=model.to(device),
        criterion=loss_function.to(device),
        optimizer=AdamW,
        lr=0.001,
        batch_size=128,
        train_split=None,
        callbacks=[
            # NOTE(review): with max_epochs=25 a patience of 40 means this can
            # never stop training early; lower it to make it effective.
            EarlyStopping(patience=40, monitor="train_loss"),
            # LRScheduler(policy=ReduceLROnPlateau, patience=10, factor=0.5, min_lr=1e-6),
            # Checkpoint(monitor="valid_loss_best", load_best=True)
        ],
        max_epochs=25,
        device=device,
    )

    # The trained network does not depend on the conformity score, so fit it
    # once and conformalize that same fitted model with every score. Otherwise
    # the net would be retrained per score (prefit=False clones and refits it),
    # wasting time and comparing scores on differently trained models.
    classifier.fit(X_train, y_train)

    for score in scores:
        cp = SplitConformalClassifier(
            estimator=classifier,
            conformity_score=score,
            confidence_level=CONFIDENCE_LEVELS,
            prefit=True,
            random_state=1,
        )
        cp.conformalize(X_hold_out, y_hold_out)

        y_pred, y_pred_set = cp.predict_set(X_test)
        preds[f'{loss_name}_{type(score).__name__}'] = (y_pred, y_pred_set)

    free_garbage()


  epoch    train_loss     dur
-------  ------------  ------
      1        [36m1.5377[0m  0.4387
      2        [36m1.3491[0m  0.2077
      3        [36m1.2238[0m  0.2030
      4        [36m1.2057[0m  0.2048
      5        [36m1.1162[0m  0.2030
      6        [36m1.1042[0m  0.2104
      7        [36m1.0990[0m  0.2054
      8        [36m1.0859[0m  0.2027
      9        [36m1.0784[0m  0.2041
     10        [36m1.0149[0m  0.2045
     11        [36m0.9724[0m  0.2022
     12        [36m0.9673[0m  0.2032
     13        [36m0.9405[0m  0.2044
     14        [36m0.9032[0m  0.2044
     15        0.9223  0.2037
     16        0.9269  0.2032
     17        [36m0.8728[0m  0.2033
     18        [36m0.8358[0m  0.2035
     19        [36m0.8318[0m  0.2048
     20        [36m0.8135[0m  0.2037
     21        [36m0.7741[0m  0.2042
     22        [36m0.7206[0m  0.2035
     23        [36m0.7176[0m  0.2038
     24        [36m0.7125[0m  0.2060
     25        [36m0.

In [22]:
from metrics import calc_accuracy, calc_mae, calc_qwk

# Point-prediction metrics for every run in `preds`, keyed like preds
# ('<loss>_<conformity score>'): (accuracy, MAE, QWK) against the test labels.
# The prediction sets themselves are not scored here.
metrics = {
    run_name: (
        calc_accuracy(y_test, point_pred),
        calc_mae(y_test, point_pred),
        calc_qwk(y_test, point_pred),
    )
    for run_name, (point_pred, _) in preds.items()
}
