<a href="https://colab.research.google.com/github/manmeetsingh7781/csci167project/blob/main/image_recognition_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
from torchvision import models, transforms
from PIL import Image
import json
import os
from torchvision.models import ResNet50_Weights

In [2]:
# Load the pre-trained ResNet-50 model (ImageNet-1k weights, V1 recipe).
# On first use the checkpoint is downloaded into the torch hub cache.
model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
model.eval()  # Evaluation mode: BatchNorm layers use their running statistics

# Preprocessing settings expected by the ImageNet-trained weights.
RESIZE_SIZE = 256                      # shorter image edge is scaled to this many pixels
CROP_SIZE = 224                        # side of the square center crop fed to the network
IMAGENET_MEAN = [0.485, 0.456, 0.406]  # per-channel (R, G, B) mean
IMAGENET_STD = [0.229, 0.224, 0.225]   # per-channel (R, G, B) std

# Define the preprocessing pipeline: PIL image -> normalized float tensor
# (for an RGB input the result has shape (3, CROP_SIZE, CROP_SIZE)).
# NOTE: ResNet50_Weights.IMAGENET1K_V1.transforms() bundles a pipeline with the
# same resize/crop/normalization settings.
preprocess = transforms.Compose([
    # An int size resizes the *shorter* edge to RESIZE_SIZE and keeps the aspect
    # ratio; it does NOT produce a RESIZE_SIZE x RESIZE_SIZE square.
    transforms.Resize(RESIZE_SIZE),
    transforms.CenterCrop(CROP_SIZE),  # crop the central CROP_SIZE x CROP_SIZE region
    transforms.ToTensor(),             # HWC values in [0, 255] -> CHW float in [0.0, 1.0]
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:01<00:00, 94.0MB/s]


In [13]:
# Classify one local image with the `model` and `preprocess` pipeline from the previous cell.
# Both paths are absolute Colab-style paths; edit them to point at your own files.
image_path = "/content/images/image2.png"  # Replace with your local image path
# JSON list of ImageNet class names indexed by class id (e.g. imagenet-simple-labels.json
# uploaded to /content in Colab).
labels_path = "/content/imagenet-simple-labels.json"

# Validate both inputs before doing any work, so a missing labels file fails fast
# instead of after a full forward pass.
for required_path, description in ((image_path, "Image"), (labels_path, "Labels")):
    if not os.path.exists(required_path):
        raise FileNotFoundError(f"{description} file not found: {required_path}")

# Load the class names; entry i is the label for model output column i.
# JSON is UTF-8 by specification, so don't rely on the platform's default encoding.
with open(labels_path, "r", encoding="utf-8") as f:
    class_labels = json.load(f)

# Decode the image. The context manager closes the file handle, and convert("RGB")
# guarantees 3 channels (drops alpha, expands grayscale/palette images).
with Image.open(image_path) as raw_image:
    image = raw_image.convert("RGB")

# Apply preprocessing and add a batch dimension: (C, H, W) -> (1, C, H, W).
input_tensor = preprocess(image).unsqueeze(0)

# Perform inference without building an autograd graph.
with torch.no_grad():
    output = model(input_tensor)  # raw class scores, one row per batch item

# A labels file whose length differs from the model's output width would
# silently map indices to the wrong names (or raise IndexError).
num_classes = output.shape[1]
if len(class_labels) != num_classes:
    raise ValueError(
        f"Labels file has {len(class_labels)} entries but the model predicts {num_classes} classes"
    )

# Highest-scoring class per row; kept as a 1-element tensor (same as torch.max's indices).
predicted_class = output.argmax(dim=1)
predicted_index = predicted_class.item()

print(f"Predicted class index: {predicted_index}")
print(f"Predicted label: {class_labels[predicted_index]}")

Predicted class index: 340
Predicted label: zebra
