# torchvision example: image classification with a pre-trained AlexNet
Source: https://www.learnopencv.com/pytorch-for-beginners-image-classification-using-pre-trained-models/

In [None]:
from torchvision import models
import torch

# Let us look at the Deep learning architectures implemented in the torch vision library.
dir(models)

Notice that there is one entry called **AlexNet** and one called **alexnet**. The capitalised name is the Python class (`AlexNet`), whereas `alexnet` is a convenience function that returns a model instantiated from the `AlexNet` class.

In [None]:
# Build AlexNet with ImageNet-trained weights (downloaded the first time this runs).
# torchvision >= 0.13 replaced the deprecated `pretrained=True` flag with the
# `weights=` enum; fall back to the old flag on releases that predate it.
if hasattr(models, "AlexNet_Weights"):
    alexnet = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1)
else:
    alexnet = models.alexnet(pretrained=True)
alexnet

In [None]:
from PIL import Image
import matplotlib.pyplot as plt

# Other sample images: 'img/hamster.jpg', 'img/centipede.jpg'
IMG_PATH = 'img/cat.jpg'

# convert('RGB') guarantees exactly 3 channels. Greyscale, RGBA or CMYK files would
# otherwise yield 1- or 4-channel tensors that the 3-channel normalisation and the
# network's first convolution cannot accept.
# (skimage is no longer imported: it was unused, and its `transform` name collided
# with the preprocessing pipeline named `transform` defined further down.)
img = Image.open(IMG_PATH).convert('RGB')
img

In [None]:
# Preview the resize step of the preprocessing pipeline below. transforms.Resize(256)
# scales the *shorter* side to 256 px and keeps the aspect ratio; it does not squash
# the image to 256×256. Compute the same output size with plain PIL.
w, h = img.size
short_side, long_side = min(w, h), max(w, h)
new_long = int(256 * long_side / short_side)
img.resize((256, new_long) if w <= h else (new_long, 256))

In [None]:
from torchvision import transforms
transform = transforms.Compose([        # Defining a variable transforms
 transforms.Resize(256),                # Resize the image to 256×256 pixels
 transforms.CenterCrop(224),            # Crop the image to 224×224 pixels about the center
 transforms.ToTensor(),                 # Convert the image to PyTorch Tensor data type
 transforms.Normalize(                  # Normalize the image
 mean=[0.485, 0.456, 0.406],            # Mean and std of image as also used when training the network
 std=[0.229, 0.224, 0.225]      
 )])

In [None]:
# Apply the preprocessing pipeline to the image; expected size (3, 224, 224) for RGB input.
img_t = transform(img)
img_t.size()

In [None]:
batch_t = img_t.unsqueeze(dim=0)  # prepend a batch axis of size 1: the model consumes batches

In [None]:
alexnet.train(mode=False)  # same as alexnet.eval(): inference behaviour for dropout/batch-norm layers

In [None]:
# Forward pass. no_grad() skips recording the autograd graph, which inference
# never uses and which would only cost memory.
with torch.no_grad():
    out = alexnet(batch_t)
out.shape  # (batch size, number of output scores)

In [None]:
# Human-readable labels, one per line; they are indexed by the model's output
# position when the predictions are decoded below.
with open('imagenet_classes.txt', encoding='utf-8') as f:
    # Iterate the file directly; skip blank lines (e.g. a trailing empty line).
    classes = [line.strip() for line in f if line.strip()]
print("Number of classes: {}".format(len(classes)))
# A label/output mismatch would silently mislabel every prediction.
assert len(classes) == out.shape[1], "label file does not match the model's output size"
classes[:10]  # preview rather than dumping every label

In [None]:
TOP_K = 10  # number of predictions to display

# Softmax turns the scores of the first (only) image in the batch into
# probabilities; scale them to percent.
percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100
# topk returns the k largest values already in descending order, so no full sort
# is needed. Named results also avoid assigning to `_`, which in IPython stops the
# automatic `_`/`__`/`___` last-output shortcuts from updating.
top_pct, top_idx = torch.topk(percentage, k=min(TOP_K, percentage.numel()))
[(classes[i], p) for i, p in zip(top_idx.tolist(), top_pct.tolist())]