In [None]:
import torch
import torch.nn as nn
from sklearn.metrics import f1_score
from torch.utils.data import DataLoader
from torchvision import models, transforms
from torchvision.datasets import ImageFolder
from tqdm import tqdm

In [None]:
# Choose the compute device once for the whole notebook: CUDA when present,
# otherwise Apple's MPS backend when this torch build exposes it and it is
# usable, otherwise the CPU.
_mps_backend = getattr(torch.backends, "mps", None)

if torch.cuda.is_available():
    device = torch.device("cuda")
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")


In [None]:
# Per-channel mean / std shared by every normalization step below.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Settings common to both evaluation loaders: fixed sample order, four worker
# processes, and pinned host memory only when CUDA is available.
_LOADER_OPTIONS = {
    "batch_size": 32,
    "shuffle": False,
    "num_workers": 4,
    "pin_memory": torch.cuda.is_available(),
}

# PIL image -> normalized float tensor.
tensor_norm = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)


def _stack_normalized_crops(crops):
    """Normalize each crop and stack the results along a new leading axis."""
    normalized = [tensor_norm(crop) for crop in crops]
    return torch.stack(normalized)


# "Google" image set: every image is resized, cut into five 224-pixel crops,
# and returned as one stacked tensor of crops per image.
# NOTE(review): absolute, user-specific path; the notebook will not run on
# another machine unless this directory exists there.
_google_transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.FiveCrop(224),
        transforms.Lambda(_stack_normalized_crops),
    ]
)
ds_google = ImageFolder("/home/jayh/Pictures/Plant", _google_transform)
dl_google = DataLoader(ds_google, **_LOADER_OPTIONS)

# Held-out test set: a single center crop per image.
_test_transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]
)
ds_test = ImageFolder("../data/test", _test_transform)
dl_test = DataLoader(ds_test, **_LOADER_OPTIONS)

In [None]:
def _model_device(model, fallback):
    """Return the device of ``model``'s first parameter, or ``fallback`` if it has none."""
    parameters = getattr(model, "parameters", None)
    first = next(iter(parameters()), None) if callable(parameters) else None
    return fallback if first is None else first.device


def evaluate(model, test: bool):
    """Score ``model`` on one of the notebook's two evaluation loaders.

    Args:
        model: Callable mapping an image batch to one row of logits per image.
            It is used as-is; this function does not change its train/eval
            mode, so callers should put it in eval mode first.
        test: ``True`` evaluates ``dl_test`` (one crop per image). ``False``
            evaluates ``dl_google``, whose batches carry several crops per
            image; the logits are averaged over the crops before the
            prediction and the loss are computed.

    Returns:
        ``(loss, acc, f1)``: the sample-weighted mean cross-entropy loss, the
        accuracy, and the macro-averaged F1 score over the whole loader.

    Raises:
        ValueError: If the selected loader yields no samples.
    """
    loader = dl_test if test else dl_google

    # Send batches to wherever the model's weights actually live, so inputs
    # and weights can never end up on different devices. The notebook-wide
    # ``device`` is only the fallback for models without parameters.
    target_device = _model_device(model, device)

    # Built once; the default "mean" reduction is undone below by weighting
    # each batch loss with its batch size.
    criterion = nn.CrossEntropyLoss()

    running_loss = 0.0
    running_corrects = 0
    total_samples = 0
    total_preds = []
    total_targets = []

    with torch.no_grad():
        for x, y in tqdm(loader, "Eval", leave=False):
            x = x.to(target_device)
            y = y.to(target_device)

            if test:
                outputs = model(x)
            else:
                # Fold the crop axis into the batch axis for the forward pass,
                # then average the per-crop logits back to one row per image.
                bs, ncrops, c, h, w = x.size()
                outputs = model(x.view(-1, c, h, w))
                outputs = outputs.view(bs, ncrops, -1).mean(1)

            preds = outputs.argmax(dim=1)
            loss = criterion(outputs, y)

            batch_size = x.size(0)
            total_samples += batch_size
            running_loss += loss.item() * batch_size
            running_corrects += (preds == y).sum().item()

            total_preds.extend(preds.cpu().tolist())
            total_targets.extend(y.cpu().tolist())

    if total_samples == 0:
        raise ValueError("evaluate(): the selected data loader produced no samples")

    loss = running_loss / total_samples
    acc = running_corrects / total_samples
    f1 = f1_score(total_targets, total_preds, average="macro")

    return loss, acc, f1

In [None]:
# Output size of the replacement classification layer. It has to match the
# checkpoints below, otherwise the strict load_state_dict call raises.
NUM_CLASSES = 29

# No pretrained weights are requested: load_state_dict is strict by default,
# so every parameter and buffer is overwritten by each checkpoint anyway and
# the ImageNet download would be wasted.
model = models.resnet18(weights=None)
model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)

# evaluate() moves every batch to `device`; the model must be there as well,
# or the first forward pass fails with a device mismatch on CUDA / MPS hosts.
model = model.to(device)
# Loading a state dict does not change the train/eval mode, so set it once.
model.eval()

model_paths = [
    "../frontend/noise_gwb_p1.pt",
    "../frontend/noise_gwb_p0.5.pt",
    "../frontend/noise.pt",
    "../frontend/ft.pt",
    "../frontend/cv.pt",
    "../frontend/segmented.pt",
]

for path in model_paths:
    # Deserialize on the CPU so checkpoints saved from a GPU load anywhere;
    # load_state_dict then copies the tensors into the model's own device.
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints (consider weights_only=True where the torch version has it).
    state_dict = torch.load(path, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)

    print(path)
    for label, use_test_set in (("test:", True), ("google:", False)):
        loss, acc, f1 = evaluate(model, use_test_set)
        print(label)
        print("loss:", loss, "acc:", acc, "f1:", f1)
    print()
