# setup

In [1]:
import torch

from PIL import Image
from torchvision import transforms

from datasets import load_dataset

from tqdm import tqdm

In [2]:
# Number of samples per batch; passed as `batch_size` to the validation DataLoader.
BATCH_SIZE = 512

In [3]:
# Pick the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Display the chosen device as the cell output.
device

'cuda'

In [4]:
# Per-channel (R, G, B) statistics handed to Normalize.
# NOTE(review): these look like the values torchvision documents for its
# ImageNet-pretrained models — confirm against the model card if the model changes.
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]

# Image -> tensor pipeline: resize, centre-crop to 224, convert to a tensor,
# then normalize each channel with the statistics above.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD),
    ]
)

def transform(examples):
    """Run `preprocess` over every image in a batch of examples.

    Each entry of ``examples['image']`` is converted to RGB and passed through
    ``preprocess``; the resulting list replaces ``examples['image']`` on the
    same dict object.

    Args:
        examples: batch dict with an ``'image'`` key holding objects that
            expose ``.convert`` (presumably PIL images — verify against the
            dataset).

    Returns:
        The same ``examples`` dict, with ``'image'`` replaced by the
        preprocessed results.
    """
    processed = []
    for image in examples['image']:
        rgb_image = image.convert("RGB")
        processed.append(preprocess(rgb_image))
    examples['image'] = processed
    return examples

def collate_fn(examples):
    """Merge a list of per-sample dicts into one batched dict.

    Args:
        examples: sequence of dicts, each with an ``'image'`` tensor and a
            ``'label'`` value. All image tensors must share a shape so they
            can be stacked.

    Returns:
        Dict with ``'image'`` (the images stacked along a new leading
        dimension) and ``'label'`` (a tensor built from the labels, in the
        same order as ``examples``).
    """
    image_batch = torch.stack([sample['image'] for sample in examples])
    label_batch = torch.tensor([sample['label'] for sample in examples])
    return {'image': image_batch, 'label': label_batch}

# data

In [5]:
# Load the 'val' split of the evanarlian/imagenet_1k_resized_256 dataset
# through `datasets.load_dataset`.
val = load_dataset("evanarlian/imagenet_1k_resized_256", split = 'val')
# Display the dataset summary (features and row count) as the cell output.
val

Resolving data files:   0%|          | 0/52 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/52 [00:00<?, ?it/s]

Dataset({
    features: ['image', 'label'],
    num_rows: 50000
})

In [6]:
# Attach `transform` to the dataset; with_transform applies it when rows are
# accessed rather than materialising the preprocessed images up front.
val = val.with_transform(transform)
# Batches of BATCH_SIZE samples are assembled by `collate_fn`; shuffling is not requested.
val_loader = torch.utils.data.DataLoader(val, collate_fn = collate_fn, batch_size = BATCH_SIZE)
# Rows divided by batch size: a fractional result means the final batch is
# smaller than BATCH_SIZE.
val.num_rows / BATCH_SIZE

97.65625

# model

In [7]:
# Name of the torchvision hub entry point to evaluate. Change this single value
# instead of commenting lines in and out.
# Variants previously toggled by hand in this cell:
#   'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'
MODEL_NAME = 'resnet34'

# Fetch the architecture together with its pretrained weights from the
# torchvision hub repository pinned at tag v0.10.0.
model = torch.hub.load('pytorch/vision:v0.10.0', MODEL_NAME, pretrained=True)

Using cache found in /home/josegfer/.cache/torch/hub/pytorch_vision_v0.10.0


In [8]:
# Move the model onto the device selected earlier so it matches the input batches.
model = model.to(device)

# eval

In [10]:
# Measure the top-1 error of `model` over every batch of `val_loader`.
total = 0    # samples seen so far
correct = 0  # samples whose predicted class equals the label

model.eval()  # put the module in evaluation mode before inference
with torch.no_grad():  # no gradients are needed, so skip autograd tracking
    # Wrap the loader itself rather than enumerate(...): tqdm can then read the
    # loader's length and render a real progress bar with an ETA.
    for sample in tqdm(val_loader):
        x = sample['image'].to(device)
        y = sample['label'].to(device)
        # Call the module instead of .forward() so any registered hooks run.
        yhat = model(x)

        # The predicted class is the index of the largest score in each row.
        predicted = yhat.argmax(dim=1)
        total += y.size(0)
        correct += (predicted == y).sum().item()

# Top-1 error rate (1 - accuracy), displayed as the cell output.
1 - correct / total

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
98it [01:44,  1.06s/it]


0.28935999999999995