In [None]:
from torchvision.io import read_image
from matplotlib import pyplot as plt

import torch
from torch import nn
from torch import optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torchvision.transforms import ToTensor

# Models and pre-trained weights
[Additional Reading](https://pytorch.org/vision/stable/models.html)

The torchvision.models subpackage contains definitions of models for addressing different tasks, including: image classification, pixelwise semantic segmentation, object detection, instance segmentation, person keypoint detection, video classification, and optical flow.

In [None]:
# Import the torchvision model subpackage and list the names it exposes.
from torchvision import models
dir(models)

### Initializing pre-trained models

In [None]:
# If the weight download fails with an SSL error, uncomment and run the two lines below first:
#import ssl
#ssl._create_default_https_context = ssl._create_stdlib_context

# Build ResNet-50 initialised with the default pre-trained weights.
resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)

# Switch to evaluation mode: the model is used here for inference, not training.
# This step matters because some models behave differently at training and evaluation time.
resnet.eval()

In [None]:
# Keep a handle on the default ResNet-50 weights; the "categories" entry of its metadata
# is the list used further down to turn a predicted class index into a class name.
weights=models.ResNet50_Weights.DEFAULT
weights.meta["categories"]

### Using the pre-trained models
Before using the pre-trained models, one must preprocess the image (resize with the right resolution/interpolation, apply inference transforms, rescale the values, etc.). There is no standard way to do this, as it depends on how a given model was trained. It can vary across model families, variants or even weight versions. Using the correct preprocessing method is critical, and failing to do so may lead to decreased accuracy or incorrect outputs.

All the necessary information for the inference transforms of each pre-trained model is provided in its weights documentation. To simplify inference, TorchVision bundles the necessary preprocessing transforms into each model weight. These are accessible by calling the `transforms()` method of the weights:

[Reference](https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html)

In [None]:
# Initialize the transforms (preprocessing) bundled with the pre-trained weights
preprocess = models.ResNet50_Weights.DEFAULT.transforms()

# Read the image from disk into a tensor using torchvision's read_image
img = read_image('./data/images/motorbike.jpeg')

### or get the image from internet using the following:
#from PIL import Image
#import requests
#img = Image.open(requests.get('http://farm8.staticflickr.com/7090/7399887950_8845d3e6e4_z.jpg', stream=True).raw)


# Apply the preprocessing to the input image. Preprocess accepts Pytorch tensor or PIL image
batch = preprocess(img).unsqueeze(0)  # unsqueeze(0) adds the leading batch dimension the model expects

# Apply the model to the image.
# Gradients are not needed for inference, so disable autograd to avoid building
# (and keeping in memory) a computation graph that is never used.
with torch.no_grad():
    prediction = resnet(batch).squeeze(0).softmax(0)

# Most probable class, its probability, and its human-readable name
class_id = prediction.argmax().item()
score = prediction[class_id].item()
category_name = weights.meta["categories"][class_id]
print(f"{category_name}: {100 * score}%")

# Display the image
plt.imshow(transforms.ToPILImage()(img))
#plt.imshow(img) #if the img is already a PIL image

In [None]:
# Display the layer-by-layer structure of the pre-trained model.
resnet

## Fine tuning
Rather than training from scratch, the preferred technique is transfer learning, achieved by fine-tuning pre-trained models on custom datasets. This approach reuses the knowledge the model already has and tailors it to our specific task, saving significant time and computational resources.

In [None]:
num_classes = 10  # let's say we want to fine-tune resnet on a dataset containing 10 classes

# Load the pre-trained model together with its matching preprocessing transforms
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
preprocess = models.ResNet50_Weights.DEFAULT.transforms()

# Freeze all pre-trained layer weights so that they are not updated during training
for param in model.parameters():
    param.requires_grad = False

# Replace the model head for fine-tuning. The new layer is created after the freeze,
# so its parameters keep requires_grad=True and are the only ones that get trained.
# NOTE: there must be no trailing comma here; a trailing comma would assign a
# 1-tuple instead of a module, which nn.Module rejects with a TypeError.
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, num_classes)
                        

In [None]:
# Seed torch's global random number generator.
torch.manual_seed(0)

# Per-sample pipeline: ToTensor first, then the preprocessing bundled with the
# pre-trained weights.
transform = transforms.Compose([transforms.ToTensor(), preprocess])

# CIFAR-10 training and test splits (downloaded under ~/Downloads/ when missing);
# both go through the same transform.
trainset = datasets.CIFAR10(root='~/Downloads/', train=True, download=True, transform=transform)
testset = datasets.CIFAR10(root='~/Downloads/', train=False, download=True, transform=transform)

# Mini-batch loaders: only the training data is shuffled.
batchsize = 64
trainloader = DataLoader(dataset=trainset, batch_size=batchsize, shuffle=True)
testloader = DataLoader(dataset=testset, batch_size=batchsize, shuffle=False)

In [None]:
# Pick the best available device instead of assuming Apple-silicon hardware:
# MPS first (the original choice), then CUDA, and the CPU as a last resort.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Loss function for multi-class classification
criteria = nn.CrossEntropyLoss()

# Hand the optimizer only the parameters that are actually trainable; parameters
# with requires_grad=False never receive gradients, so tracking them is pointless.
optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr = 0.001)

# Per-epoch average losses, filled in by the training loop
train_history = []
val_history = []

In [None]:
# Training loop
model.to(device)

for epoch in range(10):
    # --- training ---
    # Re-enter training mode at the start of every epoch, because the validation
    # pass below switches the model to evaluation mode.
    model.train()
    train_loss = 0.0
    for inputs, labels in trainloader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)

        loss = criteria(outputs, labels)
        loss.backward()

        optimizer.step()

        # `criteria` returns the mean loss over the batch; weight it by the batch
        # size so that dividing by the dataset size gives a true per-sample average
        # (the last batch may be smaller than the others).
        train_loss += loss.item() * inputs.size(0)

    # --- validation ---
    # Evaluation mode makes layers such as BatchNorm/Dropout use their inference
    # behaviour, and keeps the test data from altering BatchNorm running statistics.
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            loss = criteria(outputs, labels)
            val_loss += loss.item() * inputs.size(0)

    # Average loss per sample for this epoch
    epoch_train_loss = train_loss / len(trainset)
    epoch_val_loss = val_loss / len(testset)

    print(f'Epoch [{epoch}], train loss: {epoch_train_loss}, val loss: {epoch_val_loss}')
    train_history.append(epoch_train_loss)
    val_history.append(epoch_val_loss)
print("Finished Training")

In [None]:
# Plot the per-epoch training (blue) and validation (red) loss curves.
# `plt` is already imported in the first cell of the notebook.
plt.plot(train_history, 'b')
plt.plot(val_history, 'r')
plt.title('Convergence plot of gradient descent')
plt.xlabel('No of Epochs')
plt.ylabel('J')
# The labels must be passed as ONE sequence, in the order the lines were plotted;
# two positional strings would be interpreted as (handles, labels).
plt.legend(['train loss', 'val loss'])
plt.show()

In [None]:
# Undo the normalisation applied by `preprocess` so the images can be displayed.
# Normalize computes (x - mean) / std, so its inverse is Normalize(-mean / std, 1 / std),
# using the same mean/std that the weights' preprocessing transform carries.
inv_mean = [-m / s for m, s in zip(preprocess.mean, preprocess.std)]
inv_std = [1.0 / s for s in preprocess.std]
transform_back = transforms.Compose([
    transforms.Normalize(inv_mean, inv_std),
    # Clamp to the valid [0, 1] range so the 8-bit conversion in ToPILImage cannot wrap around
    transforms.Lambda(lambda t: t.clamp(0.0, 1.0)),
    transforms.ToPILImage(),
])

# Map class indices back to class names
idx_to_class = {value: key for key, value in trainset.class_to_idx.items()}

# Take the first batch of the test set
images, labels = next(iter(testloader))
images = images.to(device)
labels = labels.to(device)

# Predict in evaluation mode and without tracking gradients
model.eval()
with torch.no_grad():
    outputs = model(images)
_, predicted = torch.max(outputs, dim=1)

# Show the first 10 images with the predicted class as the title
plt.figure(figsize=(20,30))
for i in range(10):
    plt.subplot(1,10,i+1)
    plt.imshow(transform_back(images[i].cpu()))
    plt.axis('off')
    plt.title(idx_to_class[predicted[i].item()])
    print(f"Actual {idx_to_class[labels[i].item()]}\tPredicted: {idx_to_class[predicted[i].item()]}")
plt.tight_layout()
plt.show()

In [None]:
# Measure classification accuracy over the whole test set.
correct = 0
total = 0

# Evaluation mode: use the inference behaviour of layers such as BatchNorm/Dropout
model.eval()
with torch.no_grad():
    for images, labels in testloader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        # Index of the highest score along the class dimension is the predicted class
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# In an f-string a single '%' is literal ('%%' would print two percent signs);
# report the real number of evaluated images rather than a hard-coded count.
print(f'Accuracy of the network on the {total} test images: {100 * correct / total} %')

In [None]:
# Display the layer-by-layer structure of the fine-tuned model.
model