In [None]:
import os

import torch
import torch.nn as nn
from sklearn.metrics import f1_score
from torch.utils.data import DataLoader
from torchvision import models, transforms
from torchvision.datasets import ImageFolder

In [None]:
# Target device for evaluation; evaluate() moves each batch of inputs/targets here.
device: str = "cpu"

In [None]:
# Standard ImageNet per-channel statistics used to normalize inputs for
# torchvision's ImageNet-style ResNet backbones.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Evaluation image folder; set PLANT_DATA_DIR to point elsewhere on another machine.
DATA_DIR = os.environ.get("PLANT_DATA_DIR", "/home/jayh/Pictures/Plant")

# Per-crop transform: PIL image -> float tensor -> ImageNet-normalized tensor.
tensor_norm = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

# Test-time augmentation: resize the shorter side to 256, take the four corner
# crops plus the centre crop at 224x224, and stack them so each sample becomes
# a (5, C, 224, 224) tensor. The evaluation loop averages logits over the crops.
ds = ImageFolder(
    DATA_DIR,
    transforms.Compose(
        [
            transforms.Resize(256),
            transforms.FiveCrop(224),
            transforms.Lambda(
                lambda crops: torch.stack([tensor_norm(crop) for crop in crops])
            ),
        ]
    ),
)

dl = DataLoader(
    ds,
    batch_size=32,
    shuffle=False,
    num_workers=4,
    # Pinned host memory only speeds up host-to-GPU copies; pinning while
    # evaluating on the CPU just wastes page-locked memory.
    pin_memory=torch.device(device).type == "cuda",
)

In [None]:
def evaluate(model, loader=None):
    """Evaluate a classifier and return ``(loss, accuracy, macro_f1)``.

    Batches shaped ``(batch, ncrops, C, H, W)`` (e.g. from a FiveCrop transform)
    are scored by running every crop through the model and averaging the
    logits over the crop axis. Plain ``(batch, C, H, W)`` batches are passed
    to the model unchanged.

    Args:
        model: classifier producing logits of shape ``(N, num_classes)``.
            It is switched to eval mode.
        loader: iterable of ``(inputs, targets)`` batches. Defaults to the
            notebook-level ``dl``.

    Returns:
        ``(loss, acc, f1)``: per-sample mean cross-entropy, fraction of
        correct top-1 predictions, and sklearn's macro-averaged F1.

    Raises:
        ValueError: if the loader yields no samples.
    """
    if loader is None:
        loader = dl  # notebook-level DataLoader from the data cell

    model.eval()
    criterion = nn.CrossEntropyLoss()

    running_loss = 0.0
    running_corrects = 0
    total_samples = 0
    total_preds = []
    total_targets = []

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)

            if x.dim() == 5:
                bs, ncrops, c, h, w = x.shape
                # Fold crops into the batch axis, score them, then average the
                # logits of each image's crops back to one row per image.
                outputs = model(x.reshape(-1, c, h, w)).reshape(bs, ncrops, -1).mean(1)
            else:
                outputs = model(x)

            preds = outputs.argmax(1)
            loss = criterion(outputs, y)

            batch_size = y.size(0)
            total_samples += batch_size
            # criterion returns the batch mean; re-weight by batch size so the
            # final figure is a true per-sample mean even for a short last batch.
            running_loss += loss.item() * batch_size
            running_corrects += (preds == y).sum().item()

            total_preds.extend(preds.cpu().tolist())
            total_targets.extend(y.cpu().tolist())

    if total_samples == 0:
        raise ValueError("evaluate() received an empty loader; cannot compute metrics")

    loss = running_loss / total_samples
    acc = running_corrects / total_samples
    f1 = f1_score(total_targets, total_preds, average="macro")

    return loss, acc, f1

In [None]:
# Every checkpoint below is loaded strictly into a resnet18 with this many outputs.
NUM_CLASSES = 29

# ImageFolder assigns label indices from the sorted class-folder names, so the
# evaluation folder must expose the same set of classes as the checkpoints'
# heads or every label index silently shifts.
assert len(ds.classes) == NUM_CLASSES, (
    f"expected {NUM_CLASSES} class folders, found {len(ds.classes)}"
)

# No pretrained download: the strict load_state_dict below overwrites every
# parameter and buffer, so ImageNet weights would be discarded immediately.
model = models.resnet18(weights=None)
model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)
model.to(device)

model_paths = [
    "../frontend/noise_gwb_p1.pt",
    "../frontend/noise_gwb_p0.5.pt",
    "../frontend/noise.pt",
    "../frontend/ft.pt",
    "../frontend/cv.pt",
    "../frontend/segmented.pt",
]

for path in model_paths:
    # Checkpoints are plain state dicts, so weights_only=True suffices and
    # refuses arbitrary pickled objects; map_location keeps tensors on `device`.
    state_dict = torch.load(path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)

    model.eval()

    loss, acc, f1 = evaluate(model)

    print(path)
    print("loss:", loss, "acc:", acc, "f1:", f1)
    print()