In [1]:
import numpy as np
import torch
from dlordinal.datasets import FGNet
from dlordinal.losses import TriangularLoss
from dlordinal.metrics import amae, mmae
from skorch import NeuralNetClassifier
from torch import nn
from torch.optim import Adam
from torchvision import models
from torchvision.transforms import Compose, ToTensor

# Download the FGNet dataset (train and test splits share the same transforms)


def to_int64_target(target):
    """Cast an FGNet target label to a NumPy int64 array.

    Labels are forced to int64 so they are integer class indices of a fixed dtype.
    A named module-level function is used instead of a lambda so the transform
    stays picklable (lambdas cannot be pickled, which breaks DataLoader workers
    started with the "spawn" method).
    """
    return np.array(target, dtype=np.int64)


def load_fgnet(train: bool, root: str = "./datasets") -> FGNet:
    """Build the FGNet train (``train=True``) or test (``train=False``) split.

    Both splits use identical image and target transforms; only the ``train``
    flag differs. ``root`` is the directory the dataset is downloaded/read from.
    """
    return FGNet(
        root=root,
        train=train,
        target_transform=to_int64_target,
        transform=Compose([ToTensor()]),
    )


fgnet_train = load_fgnet(train=True)
fgnet_test = load_fgnet(train=False)

# The number of output classes is taken from the training split.
num_classes_fgnet = len(fgnet_train.classes)

# Model: ImageNet-pretrained ResNet-18 whose final fully connected layer is
# replaced by a fresh linear head producing one score per FGNet class.
# Seed torch's global RNG first so the new head's random initialisation (and any
# other torch-RNG-driven randomness during training) is reproducible on
# Restart & Run All. This is best-effort: some GPU kernels remain nondeterministic.
SEED = 0
torch.manual_seed(SEED)

model = models.resnet18(weights="IMAGENET1K_V1")
model.fc = nn.Linear(model.fc.in_features, num_classes_fgnet)
# NOTE(review): the IMAGENET1K_V1 weights were trained on ImageNet-normalised
# inputs, but the FGNet image transform applies only ToTensor(); consider adding
# transforms.Normalize with the ImageNet mean/std - confirm before changing results.

# Loss function: dlordinal's TriangularLoss with cross-entropy as its base loss.
# NOTE(review): presumably it replaces hard targets with triangular soft labels
# over neighbouring ordinal classes - see the dlordinal documentation.
loss_fn = TriangularLoss(base_loss=nn.CrossEntropyLoss(), num_classes=num_classes_fgnet)

# Skorch estimator.
# skorch's default predict_nonlinearity ("auto") only applies a softmax when the
# criterion is nn.CrossEntropyLoss; with TriangularLoss it falls back to the
# identity, so predict_proba would return raw logits. The softmax is passed
# explicitly so predict_proba really returns class probabilities. The argmax, and
# therefore predict() and the label-based metrics, is unaffected by this.
LEARNING_RATE = 1e-3
MAX_EPOCHS = 25

estimator = NeuralNetClassifier(
    module=model,
    criterion=loss_fn,
    optimizer=Adam,
    lr=LEARNING_RATE,
    max_epochs=MAX_EPOCHS,
    predict_nonlinearity=nn.Softmax(dim=-1),
)

# Train on the FGNet training split. The labels are passed explicitly as `y`.
# NOTE(review): presumably skorch uses them for its internal train/validation split
# (the log reports valid_acc/valid_loss) - confirm against skorch defaults.
estimator.fit(fgnet_train, y=fgnet_train.targets)

# Per-class scores for both splits, computed train first, then test.
# They are kept as module-level names so later cells can reuse them.
train_probs, test_probs = (
    estimator.predict_proba(split) for split in (fgnet_train, fgnet_test)
)

# Metrics: dlordinal's ordinal error metrics on the test split, given the true
# labels as an array and the predicted scores.
y_test_true = np.array(fgnet_test.targets)
amae_metric, mmae_metric = (metric(y_test_true, test_probs) for metric in (amae, mmae))
print(f"Test AMAE: {amae_metric}, Test MMAE: {mmae_metric}")

Files already downloaded and verified
Files already processed and verified
Files already split and verified
Files already downloaded and verified
Files already processed and verified
Files already split and verified
  epoch    train_loss    valid_acc    valid_loss      dur
-------  ------------  -----------  ------------  -------
      1        [36m1.7190[0m       [32m0.3168[0m        [35m1.7735[0m  27.6544
      2        [36m0.9176[0m       [32m0.3602[0m        1.7816  22.2978
      3        [36m0.5885[0m       [32m0.5093[0m        [35m1.4999[0m  22.1870
      4        [36m0.4544[0m       0.4658        1.9440  23.4645
      5        [36m0.3989[0m       [32m0.5776[0m        [35m1.1674[0m  22.7371
      6        [36m0.3826[0m       0.5776        1.1716  22.6531
      7        [36m0.3633[0m       [32m0.5963[0m        [35m1.1281[0m  22.1734
      8        [36m0.3565[0m       [32m0.6211[0m        [35m1.0612[0m  22.7198
      9        [36m0.3518[0m   