In [None]:
import os

import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

from shared_methods import all_labels, show_image_by_path
from huggingface_pretrained import get_vision_transformer

from transformers import ViTImageProcessor

# Check if we can use Cuda

In [None]:
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

# device = "cpu" # uncomment if you want to use "cpu", currently cpu is faster than cuda (maybe because the NN is very little)
print(f"Using {device} device")

# Load the model

In [None]:
run_id = "miig6ldy"

model = get_vision_transformer()
model.to(device)
model.load_state_dict(torch.load(f"trained_models/vit_{run_id}_6.pth"))
model.eval()

# Load the feature extractor (preprocessor)
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')

# Test the model

## Single image

In [None]:
with torch.no_grad():  # Disable gradient calculation
    path = "data/test/461179.jpg"
    image = Image.open(path).convert("RGB")
    inputs = processor(image, return_tensors="pt")
    inputs = inputs.to(device)

    outputs = model(**inputs)  # Run inference

    logits = outputs.logits
    predicted_label_idx = torch.argmax(logits, dim=1).item()

    print(f"I guess this is {all_labels[predicted_label_idx]}")
    show_image_by_path(path)

## All images

### Inference method 1

In [16]:
root_dir = "data/test"
batch_size = 32

all_files = [f for f in os.listdir(root_dir) if f.endswith(".jpg")]
num_batches = len(all_files) // batch_size

with open(f"solutions/solution_no1_{run_id}.csv", "w") as output_file:
    output_file.write("Id,Category\n")

    # This loop iterates over the list of all test image files in batches. The step size is batch_size.
    for i in range(0, len(all_files), batch_size):
        print(f"Predicting batch {i // batch_size + 1} of {num_batches}")

        # taking files from i to (i + batch_size)
        batch_files = all_files[i:i + batch_size]
        batch_tensors = [processor(Image.open(os.path.join(root_dir, f)).convert("RGB"), return_tensors="pt")['pixel_values'] for f in batch_files]

        # stack along a new dimension
        batch_tensors = torch.stack(batch_tensors).squeeze(1)

        with torch.no_grad():
            outputs = model(batch_tensors.to(device))['logits']
            probabilities = torch.nn.functional.softmax(outputs, dim=1)
            predicted_indices = torch.argmax(probabilities, dim=1)  # shape ([32]) -> label for each element in the batch

        for j, predicted_index in enumerate(predicted_indices):
            output_file.write(f"{batch_files[j]},{all_labels[predicted_index.item()]}\n")


Predicting batch 1 of 29
Predicting batch 2 of 29
Predicting batch 3 of 29
Predicting batch 4 of 29
Predicting batch 5 of 29
Predicting batch 6 of 29
Predicting batch 7 of 29
Predicting batch 8 of 29
Predicting batch 9 of 29
Predicting batch 10 of 29
Predicting batch 11 of 29
Predicting batch 12 of 29
Predicting batch 13 of 29
Predicting batch 14 of 29
Predicting batch 15 of 29
Predicting batch 16 of 29
Predicting batch 17 of 29
Predicting batch 18 of 29
Predicting batch 19 of 29
Predicting batch 20 of 29
Predicting batch 21 of 29
Predicting batch 22 of 29
Predicting batch 23 of 29
Predicting batch 24 of 29
Predicting batch 25 of 29
Predicting batch 26 of 29
Predicting batch 27 of 29
Predicting batch 28 of 29
Predicting batch 29 of 29
Predicting batch 30 of 29


## Inference method 2

In [15]:
root_dir = "data/test"

output_file = open(f"solutions/solution_no2_{run_id}.csv", "w")
output_file.write("Id,Category\n")

idx = 0

for filename in os.listdir(root_dir):
    if filename.endswith(".jpg"):
        if idx % 10 == 0:
            print(f"Predicting file with index {idx}")
        idx += 1

        img_path = os.path.join(root_dir, filename)
        input_tensor = processor(Image.open(img_path).convert("RGB"), return_tensors="pt")['pixel_values'].to(device)
        output = model(input_tensor)['logits']

        probabilities = F.softmax(output, dim=1)
        predicted_label_idx = torch.argmax(probabilities).item()

        output_file.write(f"{filename},{all_labels[predicted_label_idx]}\n")

output_file.close()


Predicting file with index 0
Predicting file with index 10
Predicting file with index 20
Predicting file with index 30
Predicting file with index 40
Predicting file with index 50
Predicting file with index 60
Predicting file with index 70
Predicting file with index 80
Predicting file with index 90
Predicting file with index 100
Predicting file with index 110
Predicting file with index 120
Predicting file with index 130
Predicting file with index 140
Predicting file with index 150
Predicting file with index 160
Predicting file with index 170
Predicting file with index 180
Predicting file with index 190
Predicting file with index 200
Predicting file with index 210
Predicting file with index 220
Predicting file with index 230
Predicting file with index 240
Predicting file with index 250
Predicting file with index 260
Predicting file with index 270
Predicting file with index 280
Predicting file with index 290
Predicting file with index 300
Predicting file with index 310
Predicting file wit