In [162]:
from torchvision.transforms import ToTensor, Resize, Compose
from torch.utils.data import DataLoader
import torch
from pytorch_ood.utils import ToRGB
from gtsrb import GTSRB
from pathlib import Path
from PIL import Image
import torchvision.transforms.functional as F


In [163]:
trans = Compose([
            ToRGB(),
            ToTensor(),
            Resize((64, 64), antialias=True)
        ])


batch_size = 5
test_data = GTSRB(root=".", train=False, transforms=trans)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=2)

resnet18_model = torch.load("64x64/label-net-resnet18-64.pt", map_location="cpu", weights_only=False)
wideresnet40_model = torch.load("64x64/label-net-wrn40-64.pt", map_location="cpu", weights_only=False)

img_folder = Path("imgs/")

resnet18_model.eval()
wideresnet40_model.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), strid

In [164]:
def save_images():
    # resize images
    size = (1200,1200)
    sampling = Image.Resampling.LANCZOS

    # get the first 10 images from the test set
    for x, y in test_loader:
        # x: Batch von Bildern, y: Labels
        labels = y[:,0]
        for i, img_tensor in enumerate(x):
            img = F.to_pil_image(img_tensor)
            img = img.resize(size, sampling)
            img.save(f"imgs/img_{i}.png")
        break

    return labels.tolist()

In [165]:
labels = save_images()

In [166]:
def classify_images(folder_path, model):

    imgs = [trans(Image.open(x)) for x in sorted(folder_path.glob("*.png"))]
    batch = torch.stack(imgs)

    with torch.no_grad():
        pred = model(batch)

    ergs = pred.argmax(1).tolist()
    return ergs

In [167]:
ergs = classify_images(img_folder, wideresnet40_model)
b = []
for i in range(len(ergs)):
    b.append(ergs[i]==labels[i])
b

[True, True, True, True, True]