In [53]:
import torch

if not torch.backends.mps.is_available():
    if not torch.backends.mps.is_built():
        raise Exception("MPS not available because the current PyTorch install was not "
              "built with MPS enabled.")
    else:
        raise Exception("MPS not available because the current MacOS version is not 12.3+ "
              "and/or you do not have an MPS-enabled device on this machine.")

In [54]:
import numpy as np
from dlordinal.datasets import FGNet
from dlordinal.losses import TriangularCrossEntropyLoss
from dlordinal.metrics import amae, mmae
from skorch import NeuralNetClassifier
from torch import nn
from torch.optim import Adam
from torchvision import models
from torchvision.transforms import Compose, ToTensor

# Download the FGNet dataset
fgnet_train = FGNet(
    root="./datasets",
    train=True,
    target_transform=np.array,
    transform=Compose([ToTensor()]),
)
fgnet_test = FGNet(
    root="./datasets",
    train=False,
    target_transform=np.array,
    transform=Compose([ToTensor()]),
)

num_classes_fgnet = len(fgnet_train.classes)

# Model
model = models.resnet18(weights="IMAGENET1K_V1")
model.fc = nn.Linear(model.fc.in_features, num_classes_fgnet)

# Loss function
loss_fn = TriangularCrossEntropyLoss(num_classes=num_classes_fgnet)

# Skorch estimator
estimator = NeuralNetClassifier(
    module=model,
    criterion=loss_fn,
    optimizer=Adam,
    lr=1e-3,
    max_epochs=25,
    device='mps',
)

estimator.fit(X=fgnet_train, y=fgnet_train.targets)
train_probs = estimator.predict_proba(fgnet_train)
test_probs = estimator.predict_proba(fgnet_test)

# Metrics
amae_metric = amae(np.array(fgnet_test.targets), test_probs)
mmae_metric = mmae(np.array(fgnet_test.targets), test_probs)
print(f"Test AMAE: {amae_metric}, Test MMAE: {mmae_metric}")

Files already downloaded and verified
Files already processed and verified
Files already split and verified
Files already downloaded and verified
Files already processed and verified
Files already split and verified
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m1.6924[0m       [32m0.3292[0m        [35m2.0299[0m  8.2437
      2        [36m0.9279[0m       0.3292        2.7551  2.3382
      3        [36m0.6834[0m       [32m0.4596[0m        2.2886  2.3409
      4        [36m0.5570[0m       0.3851        2.2740  2.3239
      5        [36m0.4407[0m       [32m0.5342[0m        [35m1.6565[0m  2.3237
      6        [36m0.4043[0m       [32m0.5528[0m        [35m1.3586[0m  2.3223
      7        [36m0.3782[0m       0.5528        [35m1.3014[0m  2.3161
      8        [36m0.3647[0m       [32m0.5839[0m        [35m1.2299[0m  2.3190
      9        [36m0.3567[0m       0.5839        