In [1]:
import torch
import torchvision
import cv2
from PIL import Image
from torchvision import transforms
from typing import List
import PIL
from categories import _IMAGENET_CATEGORIES

In [2]:
from settings import settings

from config import TrainingConfig, MODELS_ROOT

torch.hub.set_dir(MODELS_ROOT)

In [3]:
# Define transforms to preprocess the image
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load the ResNet-18 pre-trained model
model = torchvision.models.resnet18(pretrained=True)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /data/models/checkpoints/resnet18-f37072fd.pth
100.0%


In [4]:
def get_image_tensor(filename):
    image = cv2.imread(f"images/{filename}", cv2.COLOR_BGR2RGB)
    image = PIL.Image.fromarray(image)
    return transform(image)

In [5]:
# Load an image and preprocess it
def predict(filename):
    
    input_batch = get_image_tensor(filename).unsqueeze(0) # create a mini-batch as expected by the model
    input_batch = input_batch.to('cuda')
    model.to('cuda')
    model.eval()
    
    with torch.no_grad():
        output = model(input_batch)
        
    _, predicted = torch.topk(output, 2)
    
    for num in predicted[0]:
        print('Predicted class: ', _IMAGENET_CATEGORIES[num])
        
    return predicted
    

In [6]:
predicted = predict("image.jpg")

Predicted class:  brambling
Predicted class:  bulbul


In [7]:
predicted = predict("car.jpg")

Predicted class:  sports car
Predicted class:  car wheel


In [8]:
predicted = predict("horse.jpg")

Predicted class:  Saluki
Predicted class:  whippet


In [9]:
predicted = predict("cow.jpg")

Predicted class:  ox
Predicted class:  oxcart
