In [3]:
%pip install torchmetrics

Collecting torchmetrics
  Downloading torchmetrics-1.7.4-py3-none-any.whl.metadata (21 kB)
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.14.3-py3-none-any.whl.metadata (5.6 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=2.0.0->torchmetrics)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=2.0.0->torchmetrics)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=2.0.0->torchmetrics)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=2.0.0->torchmetrics)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=2.0.0->torchmetrics)
  D

In [1]:
# models/resnet18.py
import torchvision.models as models
import torch.nn as nn

def get_model(num_classes, pretrained=True):
    """Build a ResNet-18 with a 3x3/stride-1 stem, no max-pool, and a ``num_classes``-way head.

    The stock ``conv1`` is replaced by a 3x3, stride-1, padding-1 convolution and
    ``maxpool`` by an identity, so the stem no longer downsamples. Presumably this
    targets small inputs such as 32x32 images — TODO confirm against the dataset.
    The new ``conv1`` and ``fc`` layers are randomly initialised. All other layers
    keep the ImageNet weights when ``pretrained`` is True.

    Args:
        num_classes: Number of output classes for the final linear layer (>= 1).
        pretrained: If True (the default, matching the previous behaviour), load the
            ImageNet ``IMAGENET1K_V1`` weights for the backbone. If False, build an
            untrained network.

    Returns:
        The modified ``torchvision`` ResNet-18 module.

    Raises:
        ValueError: if ``num_classes`` is less than 1.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes!r}")
    # ``pretrained=True`` is deprecated since torchvision 0.13. The weights enum
    # below loads the same checkpoint that ``pretrained=True`` selected.
    weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
    model = models.resnet18(weights=weights)
    # Replacing conv1 discards its pretrained kernel; the rest of the backbone is kept.
    model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    model.maxpool = nn.Identity()
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model


In [4]:
# scripts/train_model.py
import os

import matplotlib.pyplot as plt
import seaborn as sns
import torch
from sklearn.metrics import f1_score, confusion_matrix
from torchmetrics.classification import F1Score

def _collect_predictions(model, dataloader):
    """Run ``model`` over every batch of ``dataloader``; return ``(y_true, y_pred)`` as 1-D CPU tensors.

    The model is put in eval mode with autograd disabled for the pass. Its previous
    train/eval mode is restored afterwards, even if inference raises.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    was_training = model.training
    model.eval()  # BatchNorm/Dropout must use inference behaviour when scoring
    true_batches, pred_batches = [], []
    try:
        with torch.no_grad():
            for images, targets in dataloader:
                outputs = model(images.to(device))
                # Same index as torch.max(outputs, 1)[1]: the first maximal logit per row.
                pred_batches.append(outputs.argmax(dim=1).cpu())
                true_batches.append(targets.cpu())
    finally:
        model.train(was_training)
    if not true_batches:
        raise ValueError("dataloader yielded no batches; nothing to evaluate")
    return torch.cat(true_batches), torch.cat(pred_batches)


def evaluate_model(model, dataloader, labels, save_path="outputs/confusion_matrix.png"):
    """Score ``model`` on ``dataloader``, print macro-F1 two ways, and save a confusion matrix.

    Args:
        model: A ``torch.nn.Module`` whose output is reduced with ``argmax`` over dim 1.
            Presumably logits of shape (batch, num_classes) — TODO confirm.
        dataloader: Iterable of ``(images, targets)`` tensor batches. ``targets`` are
            integer class ids in ``range(len(set(labels)))``.
        labels: Class names, one per class id. They are used as the heatmap tick labels
            and to derive the number of classes.
        save_path: File the confusion-matrix figure is written to. Its parent directory
            is created if missing.

    Returns:
        dict with ``"sklearn_f1"`` (float), ``"torchmetrics_f1"`` (0-d tensor) and
        ``"confusion_matrix"`` (ndarray; rows = true class, columns = predicted class).

    Raises:
        ValueError: if ``dataloader`` yields no batches.

    The figure is left open so a notebook renders it inline.
    """
    y_true, y_pred = _collect_predictions(model, dataloader)
    y_true_np, y_pred_np = y_true.numpy(), y_pred.numpy()
    num_classes = len(set(labels))

    sklearn_f1 = f1_score(y_true_np, y_pred_np, average='macro')
    print("Custom F1:", sklearn_f1)
    # torchmetrics >= 0.11 requires ``task``; F1Score(num_classes=...) alone raises TypeError.
    tm_f1 = F1Score(task="multiclass", num_classes=num_classes, average='macro')
    torchmetrics_f1 = tm_f1(y_pred, y_true)
    print("Torchmetrics F1:", torchmetrics_f1)

    cm = confusion_matrix(y_true_np, y_pred_np, labels=list(range(num_classes)))
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    # Draw on a fresh figure so the heatmap never lands on top of a previous plot.
    fig, ax = plt.subplots()
    sns.heatmap(cm, annot=True, fmt='d', xticklabels=labels, yticklabels=labels, ax=ax)
    ax.set_title("Confusion Matrix")
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    fig.savefig(save_path)

    return {
        "sklearn_f1": sklearn_f1,
        "torchmetrics_f1": torchmetrics_f1,
        "confusion_matrix": cm,
    }


In [5]:
def _save_image(img, title, path):
    """Render a single channel-first image tensor to ``path`` on its own figure, then close it."""
    # Assumes ``img`` is a (C, H, W) torch tensor, as the permute implies — TODO confirm.
    array = img.detach().cpu().permute(1, 2, 0).numpy()
    cmap = None
    if array.shape[-1] == 1:
        # matplotlib rejects (H, W, 1) arrays; show single-channel images as grayscale.
        array = array[..., 0]
        cmap = "gray"
    fig, ax = plt.subplots()
    try:
        ax.imshow(array, cmap=cmap)
        ax.set_title(title)
        fig.savefig(path)
    finally:
        # One figure per image, closed immediately, so images never stack on shared axes.
        plt.close(fig)


def plot_predictions(correct, incorrect, save_dir="outputs/predictions/", max_per_group=5):
    """Save up to ``max_per_group`` correct and incorrect prediction examples as PNGs.

    Args:
        correct: Sequence of ``(img, true, pred)`` tuples for correctly classified samples.
        incorrect: Sequence of ``(img, true, pred)`` tuples for misclassified samples.
        save_dir: Output directory, created if missing. Files are named
            ``correct_{i}.png`` and ``incorrect_{i}.png``.
        max_per_group: Maximum number of images written from each sequence (default 5).

    Values outside matplotlib's displayable range (e.g. mean/std-normalised pixels)
    are clipped by ``imshow``. Un-normalise first if true colours matter.
    """
    os.makedirs(save_dir, exist_ok=True)
    for i, (img, true, _pred) in enumerate(correct[:max_per_group]):
        _save_image(img, f"Correct: {true}", os.path.join(save_dir, f"correct_{i}.png"))

    for i, (img, true, pred) in enumerate(incorrect[:max_per_group]):
        _save_image(
            img,
            f"Incorrect: True={true}, Pred={pred}",
            os.path.join(save_dir, f"incorrect_{i}.png"),
        )
