In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, random_split
from PIL import Image


def predict_image(model, image_path, class_names, device, transform=None):
    """Classify a single image with a trained model and return its label.

    Args:
        model: A ``torch.nn.Module`` already placed on ``device``. It is
            expected to return scores of shape ``(1, num_classes)`` for a
            single-image batch.
        image_path: Anything accepted by ``PIL.Image.open`` (a path or a
            binary file object).
        class_names: Sequence mapping class index -> label, in the same order
            as the model's output units (e.g. ``ImageFolder.classes``).
        device: Device the preprocessed image tensor is moved to.
        transform: Optional callable turning an RGB PIL image into a
            ``(C, H, W)`` tensor. Defaults to the standard torchvision/ImageNet
            evaluation pipeline (resize 256, center-crop 224, normalize).

    Returns:
        The entry of ``class_names`` at the index of the highest score.

    The model is switched to eval mode for the forward pass and its previous
    train/eval mode is restored afterwards, so calling this in the middle of
    training does not leave dropout / batch-norm frozen.
    """
    if transform is None:
        # Standard ImageNet evaluation preprocessing (torchvision mean/std).
        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    # Open inside a context manager so the file handle is always released;
    # convert() returns a fully loaded copy that stays valid after closing.
    with Image.open(image_path) as img:
        image = img.convert('RGB')

    # Add the batch dimension and move to the target device.
    batch = transform(image).unsqueeze(0).to(device)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            outputs = model(batch)
    finally:
        # Restore the caller's mode even if the forward pass raises.
        model.train(was_training)

    predicted_index = outputs.argmax(dim=1).item()
    return class_names[predicted_index]


# Example usage:
if __name__ == '__main__':
    # Inputs: a folder of per-class sub-directories, the trained weights,
    # and the image to classify.
    data_dir = 'dataset'
    model_path = 'model.pth'
    image_path = 'mustard_bottle.png'

    # ImageFolder derives class names from the sub-directory names of
    # data_dir. Only `.classes` is read here, so no transform is needed.
    # NOTE(review): this order must match the model's output units, i.e.
    # presumably the same folder layout that was used for training.
    test_dataset = datasets.ImageFolder(root=data_dir)
    class_names = test_dataset.classes

    # Use the first GPU when available, otherwise fall back to CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Rebuild the architecture the checkpoint was saved from: ResNet-18 with
    # its final linear layer resized to the number of classes. Calling
    # resnet18() with no arguments requests no pretrained weights on both old
    # and new torchvision, avoiding the deprecated `pretrained=` keyword.
    model = models.resnet18()
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, len(class_names))

    # map_location remaps tensors saved on GPU onto `device`, so a GPU-trained
    # checkpoint still loads on a CPU-only machine.
    # NOTE: torch.load unpickles the file; only load trusted checkpoints
    # (on torch >= 1.13, weights_only=True restricts loading to tensors).
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model = model.to(device)

    # Predict the class of the image
    predicted_class = predict_image(model, image_path, class_names, device)
    print(f'The model predicted: {predicted_class}')


The model predicted: mustard_bottle


<a style='text-decoration:none;line-height:16px;display:flex;color:#5B5B62;padding:10px;justify-content:end;' href='https://deepnote.com?utm_source=created-in-deepnote-cell&projectId=4929c60d-9325-4a04-bca1-550c19632d0a' target="_blank">
Created in <span style='font-weight:600;margin-left:4px;'>Deepnote</span></a>