In [9]:
import gc
import torch
from mapie.conformity_scores import LACConformityScore, APSConformityScore, RPSConformityScore

DATA_ROOT = '.'

# Prefer Apple-silicon MPS, then CUDA, then CPU. torch.backends.mps.is_available()
# is the documented availability check for the MPS backend.
device = (
    'mps' if torch.backends.mps.is_available()
    else 'cuda' if torch.cuda.is_available()
    else 'cpu'
)


def free_garbage():
    """Release unreferenced Python objects and cached accelerator memory.

    The cyclic garbage collector runs first, so tensors kept alive only by
    reference cycles are actually freed. Only then is each available backend's
    allocator cache returned to the driver; emptying the cache before collecting
    would miss those blocks. Objects still bound to a name are not affected.

    Returns:
        None
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()

In [10]:
from datasets.fgnet import FGNetDataset

# FG-NET dataset wrapper rooted at DATA_ROOT. The positional 0.15 is forwarded
# unchanged to FGNetDataset; presumably a split fraction. TODO: confirm in datasets/fgnet.py.
dataset = FGNetDataset(0.15, DATA_ROOT)

# Three splits, fetched in train -> hold-out -> test order. Downstream, train is used
# to fit the network, hold-out to conformalize, and test to build prediction sets.
(X_train, y_train), (X_hold_out, y_hold_out), (X_test, y_test) = (
    dataset.get_train_data(),
    dataset.get_hold_out_data(),
    dataset.get_test_data(),
)

num_classes = dataset.get_num_classes()

Files already downloaded and verified
Files already processed and verified
Files already split and verified
Files already downloaded and verified
Files already processed and verified
Files already split and verified


In [11]:
from typing import Any
from numpy import ndarray, dtype
from mapie.classification import SplitConformalClassifier
from torch.optim.lr_scheduler import ReduceLROnPlateau
from skorch import NeuralNetClassifier
from skorch.dataset import ValidSplit
from skorch.callbacks import EarlyStopping, LRScheduler, Checkpoint
from torch.optim import AdamW, Adam
from torch import nn
from dlordinal.output_layers import COPOC
from torchvision import models
from dlordinal.losses import TriangularLoss, WKLoss, EMDLoss
from torch.nn import CrossEntropyLoss

from functools import partial

losses = [
    # 'COPOC',
    CrossEntropyLoss(),
    # TriangularLoss(base_loss=CrossEntropyLoss(), num_classes=num_classes),
    # WKLoss(num_classes=num_classes, use_logits=True),
    # EMDLoss(num_classes=num_classes),
]

scores = [
    LACConformityScore(),
    # APSConformityScore(),
    # RPSConformityScore(),
]

# Confidence levels evaluated for every (loss, conformity score) pair.
CONFIDENCE_LEVELS = [0.98, 0.97, 0.95, 0.92, 0.9, 0.8, 0.7]
# Seeds the new head's weight init and batch shuffling. ValidSplit and MAPIE
# already get random_state=1.
SEED = 1

# Keyed by f"{loss_name}_{score_class_name}"; values are the (y_pred, y_pred_set)
# pair returned by SplitConformalClassifier.predict_set.
preds: dict[str, tuple[ndarray[Any, dtype[Any]], ndarray[Any, dtype[Any]]]] = {}

for loss in losses:
    # Drop the previous run's network before building the next one.
    # free_garbage() can only reclaim objects that are no longer referenced.
    model = classifier = cp = None
    free_garbage()

    torch.manual_seed(SEED)
    model = models.resnet18(weights="IMAGENET1K_V1")
    in_features = model.fc.in_features
    # skorch's 'auto' predict_nonlinearity applies softmax only when the criterion
    # is CrossEntropyLoss; any other criterion gets the raw module output.
    predict_nonlinearity = 'auto'

    if loss == 'COPOC':
        loss_name = loss
        # NOTE(review): CrossEntropyLoss log-softmaxes the COPOC output. Confirm
        # that COPOC returns logits or log-probabilities rather than probabilities.
        model.fc = nn.Sequential(nn.Linear(in_features, num_classes), COPOC())
        loss_function = CrossEntropyLoss()
    elif isinstance(loss, (TriangularLoss, WKLoss)):
        # Both losses consume logits: TriangularLoss wraps CrossEntropyLoss, and
        # WKLoss is built with use_logits=True. A Softmax head would apply softmax
        # twice during training. Keep raw logits and apply softmax only in
        # predict_proba, where MAPIE needs probabilities.
        loss_name = type(loss).__name__
        model.fc = nn.Linear(in_features, num_classes)
        loss_function = loss
        predict_nonlinearity = partial(torch.softmax, dim=-1)
    elif isinstance(loss, EMDLoss):
        # NOTE(review): the original Softmax head is kept. Confirm whether
        # dlordinal's EMDLoss expects probabilities or logits.
        loss_name = type(loss).__name__
        model.fc = nn.Sequential(nn.Linear(in_features, num_classes), nn.Softmax(dim=1))
        loss_function = loss
    else:
        loss_name = type(loss).__name__
        model.fc = nn.Linear(in_features, num_classes)
        loss_function = loss

    classifier = NeuralNetClassifier(
        module=model.to(device),
        criterion=loss_function.to(device),
        optimizer=Adam,
        lr=0.001,
        batch_size=128,
        train_split=ValidSplit(0.1, random_state=1),
        predict_nonlinearity=predict_nonlinearity,
        callbacks=[
            # NOTE: patience=40 exceeds max_epochs=25, so this never stops training
            # early. Model selection is effectively done by Checkpoint(load_best=True).
            EarlyStopping(patience=40, monitor="valid_loss"),
            # LRScheduler monitors 'train_loss' by default. Watch the same validation
            # loss as EarlyStopping/Checkpoint so plateaus are actually detected.
            LRScheduler(
                policy=ReduceLROnPlateau,
                monitor="valid_loss",
                patience=10,
                factor=0.5,
                min_lr=1e-6,
            ),
            # One directory per loss so runs do not overwrite each other's checkpoints.
            Checkpoint(
                monitor="valid_loss_best",
                load_best=True,
                dirname=f"checkpoints/{loss_name}",
            ),
        ],
        max_epochs=25,
        device=device,
    )

    classifier.fit(X_train, y_train)

    for score in scores:
        # prefit=True: reuse the network trained above and only calibrate on hold-out.
        cp = SplitConformalClassifier(
            estimator=classifier,
            conformity_score=score,
            confidence_level=CONFIDENCE_LEVELS,
            prefit=True,
            random_state=1,
        )

        cp.conformalize(X_hold_out, y_hold_out)

        y_pred, y_pred_set = cp.predict_set(X_test)
        preds[f'{loss_name}_{type(score).__name__}'] = (y_pred, y_pred_set)

    free_garbage()


  epoch    train_loss    valid_acc    valid_loss    cp      lr     dur
-------  ------------  -----------  ------------  ----  ------  ------
      1        [36m1.7109[0m       [32m0.4265[0m        [35m1.7659[0m     +  0.0010  1.1000
      2        [36m0.7853[0m       0.3676        2.4689        0.0010  0.7383
      3        [36m0.3677[0m       [32m0.4559[0m        2.2243        0.0010  0.7215
      4        [36m0.1097[0m       [32m0.5147[0m        1.8335        0.0010  0.7188
      5        [36m0.0331[0m       [32m0.5294[0m        [35m1.7647[0m     +  0.0010  0.7252
      6        [36m0.0074[0m       0.4853        1.8437        0.0010  0.7137
      7        [36m0.0039[0m       [32m0.5588[0m        1.8837        0.0010  0.7210
      8        [36m0.0019[0m       0.5441        1.8439        0.0010  0.7198
      9        [36m0.0011[0m       0.5441        1.7754        0.0010  0.7251
     10        [36m0.0008[0m       [32m0.6029[0m        [35m1.7149[0m