In [None]:
import os

# download the dataset if it doesn't exist:
if not os.path.exists("msl-labeled-data-set-v2.1"):
    if not os.path.isfile("msl_v2.1.zip.zip"):
        !wget https://zenodo.org/record/4033453/files/msl-labeled-data-set-v2.1.zip?download=1 -O msl_v2.1.zip
    !unzip -q msl_v2.1.zip -d .  # quiet unzip

In [None]:
import torch
import torchvision
from torchvision.datasets.folder import IMG_EXTENSIONS, default_loader


class MSLDataset(torchvision.datasets.DatasetFolder):
    """Image dataset backed by the MSL v2.1 split files.

    The class names and indices are read from ``class_map.csv`` and the
    ``(image path, label)`` samples from ``<split>-set-v2.1.txt``, both located
    directly under ``root``. Image files are resolved under ``root/images``.

    Args:
        root: Directory containing ``class_map.csv``, the split files and the
            ``images`` folder.
        transform: Callable applied to each loaded image. Defaults to
            ``torchvision.transforms.ToTensor()`` when ``None``.
        split: One of ``"train"``, ``"val"`` or ``"test"``.

    Raises:
        ValueError: If ``split`` is not one of the supported names.
    """

    def __init__(self, root, transform=None, split="train"):
        if split not in ["train", "val", "test"]:
            raise ValueError(f"dataset.split must be train, val, or test. Got {split}")
        # Set before calling the parent constructor: make_dataset() reads
        # self.split, and the parent constructor is what invokes it.
        self.split = split
        if transform is None:
            transform = torchvision.transforms.ToTensor()
        super().__init__(root, default_loader, IMG_EXTENSIONS, transform=transform)
        # Alias matching the attribute name used by ImageFolder-style datasets.
        self.imgs = self.samples

    def find_classes(self, root):
        """Read the class mapping from ``class_map.csv``.

        Each non-blank line is expected to look like ``<index>,<name>[,...]``.

        Returns:
            A tuple ``(classes, class_to_idx)``: the class names sorted
            alphabetically, and a dict mapping each name to the integer index
            given in the file (the indices come from the file, not from the
            sorted order).
        """
        class_to_idx = {}
        with open(os.path.join(root, "class_map.csv"), "r") as f:
            for line in f:
                if not line.strip():
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    continue
                fields = line.split(",")
                class_to_idx[fields[1].strip()] = int(fields[0])
        return sorted(class_to_idx), class_to_idx

    def make_dataset(
        self,
        root,
        class_to_idx,
        extensions=None,
        is_valid_file=None,
        allow_empty=False,
    ):
        """Build the sample list from the split file for ``self.split``.

        Each non-blank line of ``<split>-set-v2.1.txt`` is expected to be
        ``<image file name> <label index>``. The label is taken to be the last
        whitespace-separated token, so file names containing spaces are kept
        intact.

        ``class_to_idx``, ``extensions``, ``is_valid_file`` and ``allow_empty``
        are accepted only for compatibility with the parent class, which passes
        them when building the dataset (``allow_empty`` only in newer
        torchvision releases); they are not used here.

        Returns:
            A list of ``(path, label)`` tuples with ``path`` pointing into
            ``root/images`` and ``label`` an ``int``.
        """
        samples = []
        with open(os.path.join(root, f"{self.split}-set-v2.1.txt"), "r") as f:
            for line in f:
                entry = line.strip()
                if not entry:
                    # Skip blank lines instead of failing on them.
                    continue
                file_name, label = entry.rsplit(None, 1)
                samples.append((os.path.join(root, "images", file_name), int(label)))
        return samples


# Choose the compute device: CUDA when PyTorch reports it as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load the fine-tuned classifier through torch.hub. force_reload=True requests
# a fresh download of the hub repository instead of reusing a cached copy;
# num_classes is passed through to the hub entry point.
model = torch.hub.load(
    "jpl-clover/weights:devel",
    "resnet18_distilled_from_r101_1x_sk0_finetuned_on_100pctMSL",
    num_classes=19,
    force_reload=True,
)
# Move the model to the selected device.
model = model.to(device)

In [14]:
from tqdm.auto import tqdm

# Evaluate the model on the test split, one image at a time.
ds = MSLDataset("msl-labeled-data-set-v2.1", split="test")
if len(ds) == 0:
    # Fail with a clear message rather than a bare ZeroDivisionError below.
    raise RuntimeError("The test split is empty; cannot compute accuracy.")

correct = 0  # number of correctly classified images so far
model.eval()
with torch.no_grad():
    progress_bar = tqdm(ds)
    for seen, (img, label) in enumerate(progress_bar, start=1):
        # Add a leading dimension so the single image is passed as a batch of one.
        img = img.unsqueeze(0).to(device)
        output = model(img)
        correct += (output.argmax(1) == label).sum().item()
        # Show the running accuracy over the samples processed so far
        # (dividing by the full dataset size would under-report it until the
        # very last iteration).
        progress_bar.set_postfix(accuracy=f"{correct / seen:.3f}")

accuracy = correct / len(ds)
print(f"Final accuracy: {accuracy:.4f}")