# Image Classification with ResNet18


In [1]:
# @title # Run the following cell to install the necessary libraries for this practical. { display-mode: "form" }
# @markdown Don't worry about what's in this collapsed cell


import urllib.request
import os
%pip install -q Pillow


if not os.path.exists('data'):
    os.mkdir('data')

# Download duck.jpg
if not os.path.exists('data/duck.jpg'):
    print('Downloading duck.jpg...')
    urllib.request.urlretrieve(
        'https://s3-eu-west-1.amazonaws.com/aicore-portal-public-prod-307050600709/practicals_files/f0c57e1d-f903-496d-8561-002c618a1c7d/duck.jpg', 'data/duck.jpg')

# Download imagenet_classes.txt
if not os.path.exists('data/imagenet_classes.txt'):
    print('Downloading imagenet_classes.txt...')
    urllib.request.urlretrieve(
        'https://s3-eu-west-1.amazonaws.com/aicore-portal-public-prod-307050600709/practicals_files/f0c57e1d-f903-496d-8561-002c618a1c7d/imagenet_classes.txt', 'data/imagenet_classes.txt')


Note: you may need to restart the kernel to use updated packages.


Using a pre-trained image classifier is a quick and efficient way to classify images by content. It involves loading a model that has already been trained on a large dataset and using it to make predictions on new images. Models such as ResNet have been trained to classify a wide range of objects, so the pre-trained model can often be used for basic classification tasks without any additional training.


## Import Dependencies


In [2]:
# import dependencies
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image

# utils
def open_image(path: str, transform) -> torch.Tensor:
    """Get tensor of a local image and apply a transform

    Args:
        path (str): Path to image
        transform (Any): Transform applied to an image tensor

    Returns:
        torch.Tensor: Transformed image tensor
    """
    
    img = Image.open(path).convert('RGB')
    img_tensor = transform(img).unsqueeze(0)
    
    return img_tensor


## Image Transforms

To use the pre-trained model, the image must be transformed into a format that is compatible with the model's architecture and that reflects the preprocessing applied when the model was trained. For this we can use the `torchvision.transforms` module.

The following code block uses the `transforms.Compose` class to compose a sequence of transforms. Important considerations when preparing images for a pre-trained model include ensuring that the image is the correct size for the input layer and applying the same transforms that were used on the original training set.


In [3]:
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
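
To make the preprocessing concrete, here is a minimal pure-Python sketch of the arithmetic behind the pipeline above. It is not torchvision's actual implementation. It assumes that `Resize(256)` scales the shorter side to 256 pixels while keeping the aspect ratio, that `CenterCrop(224)` takes the central 224×224 region, and that `Normalize` computes `(x - mean) / std` per channel on values already scaled to [0, 1] by `ToTensor`. The helper names `resized_size`, `center_crop_box`, and `normalize_pixel`, and the 640x480 input size, are hypothetical.

```python
# Illustrative sketch only: mirrors the arithmetic of the transforms,
# not torchvision's internal code. All helper names are hypothetical.

MEAN = [0.485, 0.456, 0.406]  # per-channel means from the Normalize call
STD = [0.229, 0.224, 0.225]   # per-channel standard deviations


def resized_size(width, height, target=256):
    """Scale so the shorter side equals `target`, keeping aspect ratio."""
    if width <= height:
        return target, int(target * height / width)
    return int(target * width / height), target


def center_crop_box(width, height, crop=224):
    """Return (left, top, right, bottom) of the central crop."""
    left = int(round((width - crop) / 2.0))
    top = int(round((height - crop) / 2.0))
    return left, top, left + crop, top + crop


def normalize_pixel(rgb):
    """Apply (x - mean) / std to one RGB pixel with values in [0, 1]."""
    return [(x - m) / s for x, m, s in zip(rgb, MEAN, STD)]


w, h = resized_size(640, 480)  # a hypothetical 640x480 photo
box = center_crop_box(w, h)
pixel = normalize_pixel([0.485, 0.456, 0.406])  # a pixel equal to the mean
print(w, h, box, pixel)
```

A pixel whose colour equals the channel means maps to zero in every channel, which is the point of the normalisation: inputs end up roughly centred around zero, as they were when the model was trained.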


## Load an image from your directory structure


In [8]:
img_tensor = open_image("data/duck.jpg", transform)


To make sense of the model's output, it is necessary to have a decoder: essentially a dictionary where the keys are the integer labels used by the classifier and the values are the human-readable class names. The code block below loads the classes as a list, with the keys implicit in the index positions.


In [9]:
# get imagenet classes
with open("data/imagenet_classes.txt", "r") as f:
    classes = [line.strip() for line in f.readlines()]

classes[1:20]


['goldfish',
 'great white shark',
 'tiger shark',
 'hammerhead',
 'electric ray',
 'stingray',
 'cock',
 'hen',
 'ostrich',
 'brambling',
 'goldfinch',
 'house finch',
 'junco',
 'indigo bunting',
 'robin',
 'bulbul',
 'jay',
 'magpie',
 'chickadee']
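
Because the keys are implicit in the list positions, looking up a class name is just list indexing. The sketch below uses a toy three-entry list (not the real class file, so its indices do not match ImageNet's) to show that the list behaves like an explicit integer-to-name dictionary. The names `toy_classes` and `decoder` are illustrative only.

```python
# Toy stand-in for the class list; the real file has one name per line.
toy_classes = ["goldfish", "great white shark", "tiger shark"]

# The equivalent explicit decoder: integer label -> class name.
decoder = dict(enumerate(toy_classes))

label = 2  # a hypothetical integer label produced by a classifier
print(toy_classes[label])  # lookup by list position
print(decoder[label])      # same lookup through the dictionary
```

Keeping the names in a list is enough as long as the line order of the file matches the label order the model was trained with.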

In [10]:
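# Load ResNet18 with ImageNet weights. On torchvision >= 0.13 this replaces
# the deprecated `pretrained=True` argument.
model = torchvision.models.resnet18(
    weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)
model.eval()  # inference mode: fixes batch-norm statistics
with torch.no_grad():  # no gradients are needed for prediction
    output = model(img_tensor)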




In [11]:
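# torch.max over dim 1 returns (max values, indices); only the index is needed
_, pred = torch.max(output, 1)
print("Prediction label: ", pred)
class_label = classes[pred.item()]  # convert the 1-element tensor to an int
print("Prediction category: ", class_label)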


Prediction label:  tensor([285])
Prediction category:  Egyptian cat
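
The model's `output` holds one raw score (logit) per class, and `torch.max` only reports the largest. To see how confident a prediction is, the scores are commonly converted to probabilities with a softmax and the top few classes inspected. The following pure-Python sketch illustrates that idea on made-up scores and class names. With real tensors, `torch.softmax` and `torch.topk` would do the same job. The helpers `softmax` and `top_k` are hypothetical and written only for this illustration.

```python
import math


def softmax(scores):
    """Convert raw scores to probabilities that sum to 1."""
    shift = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - shift) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(probs, names, k=3):
    """Return the k most probable (name, probability) pairs."""
    ranked = sorted(zip(names, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


# Made-up scores and names, purely for illustration.
toy_scores = [2.0, 0.5, 3.1, -1.0]
toy_names = ["hen", "ostrich", "goldfinch", "jay"]

probs = softmax(toy_scores)
for name, p in top_k(probs, toy_names, k=2):
    print(f"{name}: {p:.3f}")
```

Inspecting several top candidates instead of a single label makes it easier to judge whether an unexpected prediction was a near miss or a confident mistake.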
