In [None]:
# Download ImageNet labels (one-time setup).
# Uncomment the next line to fetch imagenet_classes.txt, the label file that a later
# cell opens by that name relative to the working directory.
# !wget https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt

In [None]:
import torch
import torchvision
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import os
import random

# Pick the compute device: use CUDA when a usable GPU is present, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Build Swin-T with the pretrained 'IMAGENET1K_V1' weights
model = torchvision.models.swin_t(weights='IMAGENET1K_V1')
# Move the parameters to the selected device and switch to evaluation mode
# (.to() and .eval() both return the module itself, so the chain keeps the same object)
model = model.to(device).eval()

In [None]:
# Define image transformations following the inference recipe documented for
# Swin_T_Weights.IMAGENET1K_V1: resize the shorter side to 232 px with bicubic
# interpolation, center-crop to 224x224, scale to [0, 1], then normalize with the
# ImageNet channel statistics.
# Resizing straight to (224, 224) would squash non-square images and change their
# aspect ratio, which is not how the pretrained weights were evaluated.
transform = transforms.Compose([
    transforms.Resize(232, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(224),  # Swin-T expects 224x224 input
    transforms.ToTensor(),       # PIL image -> float tensor (C, H, W) in [0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [None]:
# Read the ImageNet class names used to turn a class index into a readable label.
# The file holds one label per line; surrounding whitespace/newlines are stripped.
imagenet_labels = []  # bound up front so the name exists even if opening the file fails
with open('imagenet_classes.txt', 'r') as f:
    imagenet_labels = list(map(str.strip, f))

In [None]:
# Select a random image from the Pascal VOC JPEGImages folder
image_folder = './pascalvoc/VOCdevkit/VOC2012/JPEGImages'
# Match extensions case-insensitively (so '.JPG' / '.PNG' are picked up too) and sort
# the listing: os.listdir returns entries in arbitrary order, so sorting makes the
# choice repeatable whenever the `random` module is seeded.
image_files = sorted(
    name for name in os.listdir(image_folder)
    if name.lower().endswith(('.jpg', '.png'))
)
if not image_files:
    # Name the folder that was actually searched, not just the dataset root
    raise FileNotFoundError(f"No images found in {image_folder}")
image_path = os.path.join(image_folder, random.choice(image_files))

In [None]:
# Open the selected file and force a 3-channel RGB image
image = Image.open(image_path).convert('RGB')
# Apply the preprocessing pipeline, then prepend a batch axis and move to the device
image_tensor = transform(image)
image_tensor = image_tensor.unsqueeze(0).to(device)
# Quick sanity report of what will be fed to the model
print("Tensor shape:", image_tensor.shape)
print("Tensor element type:", image_tensor.dtype)

In [None]:
# Forward pass with autograd disabled (inference only)
with torch.no_grad():
    output = model(image_tensor)
    # Softmax over the scores of the first (only) item in the batch
    probabilities = output[0].softmax(dim=0)

# Top-1 prediction: class index, its readable label, and its probability
predicted_idx = probabilities.argmax().item()
predicted_label = imagenet_labels[predicted_idx]
confidence = probabilities[predicted_idx].item()

In [None]:
# Show the input image with the predicted label and its probability as the title
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(image)
ax.axis('off')  # hide ticks and frame; only the picture matters here
ax.set_title(f'Predicted: {predicted_label} ({confidence:.2%})')
plt.show()