In [1]:
# Imports for the whole notebook, plus a report of the installed PyTorch / torchvision
# versions so results can be tied to a library version.
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)
import pandas as pd
# NOTE(review): wildcard import pollutes the namespace; the annotations visible in this
# notebook only use builtins (int, float), so this looks removable — confirm before deleting.
from typing import *

PyTorch Version:  2.0.1
Torchvision Version:  0.15.2


In [2]:
# Experiment configuration: the values a reader is expected to tune.
model_name = "resnet"
# None => per-channel mean/std are computed from the training split in the statistics cell.
means = None
batch_size = 8
num_epochs = 15
# Binary case: cat vs. dog.
num_classes_binary = 2
# Multi-class case: one class per breed.
num_classes_category = 37

# Seed torch's global RNG so the train/val split, DataLoader shuffling and the new
# classifier head's initialisation are reproducible on Restart & Run All.
seed = 42
torch.manual_seed(seed)

# Task actually trained; switch to num_classes_category for the per-breed variant.
num_of_classes = num_classes_binary

In [3]:
# Split the dataset's 1-based CLASS-IDs into cat and dog sets, using the species column
# of the annotations list file (species 1 -> cat set, anything else -> dog set).
filename = './data/train/oxford-iiit-pet/annotations/list.txt'
cat_ids_set = set()
dog_ids_set = set()

with open(filename, 'r') as file:
    for line in file:
        # Skip the '#'-prefixed header block and blank lines rather than a fixed line offset.
        # NOTE(review): the header appears to be 6 lines long, in which case the previous
        # `readlines()[7:]` also dropped the first data row.
        if not line.strip() or line.startswith('#'):
            continue
        fields = line.split()
        class_id: int = int(fields[1])
        species: int = int(fields[2])

        if species == 1:
            cat_ids_set.add(class_id)
        else:
            dog_ids_set.add(class_id)
print(dog_ids_set)
print(cat_ids_set)

{2, 3, 4, 5, 9, 11, 13, 14, 15, 16, 17, 18, 19, 20, 22, 23, 25, 26, 29, 30, 31, 32, 35, 36, 37}
{1, 33, 34, 6, 7, 8, 10, 12, 21, 24, 27, 28}


In [4]:
def transform_label(label, cat_ids=None, num_classes=None):
    """Map an Oxford-IIIT Pet breed label to a one-hot cat/dog target.

    Args:
        label: breed index as yielded by torchvision's ``OxfordIIITPet`` (int or 0-d tensor).
            NOTE(review): torchvision stores these 0-based (CLASS-ID - 1), while
            ``cat_ids_set`` holds the 1-based CLASS-IDs read from list.txt, so the label is
            shifted by one before the lookup — confirm against the installed torchvision.
        cat_ids: 1-based CLASS-IDs that are cats; defaults to the notebook's ``cat_ids_set``.
        num_classes: width of the one-hot vector; defaults to the notebook's ``num_of_classes``.

    Returns:
        float64 tensor of shape ``(num_classes,)``; index 0 = cat, index 1 = dog.
    """
    if cat_ids is None:
        cat_ids = cat_ids_set
    if num_classes is None:
        num_classes = num_of_classes
    species_index = 0 if int(label) + 1 in cat_ids else 1
    # num_classes must be passed to one_hot (it was previously passed to the LongTensor
    # constructor, a TypeError); without it one_hot infers the width from the data.
    one_hot = torch.nn.functional.one_hot(torch.tensor([species_index]), num_classes)
    return torch.flatten(one_hot).double()

In [5]:
# Build the Oxford-IIIT Pet trainval/test datasets as un-normalised 224x224 tensors
# (normalisation is added once the training-set statistics are known).
# Resize step is required as we will use a ResNet model, which accepts at least 224x224 images.
pet_image_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
])
# Binary task: map each breed label to a one-hot cat/dog target via transform_label.
# Multi-class task: keep torchvision's breed index unchanged. (Both branches used to apply
# the binary mapping, so the multi-class configuration silently trained on cat/dog targets.)
pet_target_transform = (torchvision.transforms.Lambda(transform_label)
                        if num_of_classes == 2 else None)

trainval_dataset = torchvision.datasets.OxfordIIITPet(
    './data/train/', split='trainval', download=True,
    transform=pet_image_transform, target_transform=pet_target_transform)
test_dataset = torchvision.datasets.OxfordIIITPet(
    './data/test/', split='test', download=True,
    transform=pet_image_transform, target_transform=pet_target_transform)

In [6]:
# NOTE(review): this cell rebuilds the datasets from the previous cell and, running later,
# is the one that takes effect. It used to omit target_transform entirely, so the binary
# configuration trained on raw breed labels. It now applies the same label mapping;
# consider deleting one of the two cells.
# Resize step is required as we will use a ResNet model, which accepts at least 224x224 images.
pet_image_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
])
# Binary task: one-hot cat/dog targets; multi-class task: breed index as-is.
pet_target_transform = (torchvision.transforms.Lambda(transform_label)
                        if num_of_classes == 2 else None)

trainval_dataset = torchvision.datasets.OxfordIIITPet(
    './data/train/', split='trainval', download=True,
    transform=pet_image_transform, target_transform=pet_target_transform)
test_dataset = torchvision.datasets.OxfordIIITPet(
    './data/test/', split='test', download=True,
    transform=pet_image_transform, target_transform=pet_target_transform)

In [6]:
# Hold out 20% of the official trainval split for validation.
train_split_percentage: float = 0.8
# The training share is floored; the remainder goes to validation so both sizes sum to the total.
train_split: int = int(len(trainval_dataset) * train_split_percentage)
val_split: int = len(trainval_dataset) - train_split

In [7]:
# Random, disjoint train/validation Subsets of trainval_dataset with the sizes computed above.
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset=trainval_dataset,
    lengths=(train_split, val_split),
)

## Compute per-channel mean and standard deviation of the training images

In [8]:
# Loader over the un-normalised training split, consumed by the statistics pass below.
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)

In [None]:
# Sanity check: pull one batch and inspect the image tensor shape, expected to be
# (batch_size, channels, 224, 224) after Resize + ToTensor. Replaces a bare
# `train_dataloader.__iter__()` that only created (and discarded) an iterator.
sample_images, _sample_labels = next(iter(train_dataloader))
assert sample_images.shape[-2:] == (224, 224), sample_images.shape
sample_images.shape

In [11]:
# Per-channel mean / std of the (resized, un-normalised) training images, for Normalize below.
# Exact pooled statistics come from per-channel sums and sums of squares: averaging
# per-batch means is skewed by a smaller last batch, and the mean of per-batch standard
# deviations is not the standard deviation of the whole training set.
if means is None:
    channel_sum = 0.0
    channel_sq_sum = 0.0
    pixels_per_channel = 0
    for images, _labels in train_dataloader:
        images = images.double()  # float64 keeps the sum of squares precise
        # Dimensions 0,2,3 are respectively the batch, height and width dimensions
        channel_sum = channel_sum + images.sum(dim=(0, 2, 3))
        channel_sq_sum = channel_sq_sum + images.square().sum(dim=(0, 2, 3))
        pixels_per_channel += images.numel() // images.shape[1]

    if pixels_per_channel == 0:
        raise ValueError("train_dataloader yielded no images; cannot compute normalisation statistics")

    mean64 = channel_sum / pixels_per_channel
    # Population variance E[x^2] - E[x]^2, clamped against tiny negative rounding error.
    var64 = (channel_sq_sum / pixels_per_channel - mean64 ** 2).clamp_min(0)
    mean = mean64.float()
    stdev = var64.sqrt().float()
    # Mark the statistics as computed so re-running this cell skips the pass.
    means = mean

In [None]:
def _resize_to_tensor_normalize():
    """Return a fresh pipeline: resize to 224x224, convert to a tensor, normalise with the training mean/stdev."""
    steps = [
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean, stdev),
    ]
    return torchvision.transforms.Compose(steps)


# One (currently identical) pipeline object per split.
train_transforms = _resize_to_tensor_normalize()
val_transforms = _resize_to_tensor_normalize()
test_transforms = _resize_to_tensor_normalize()

In [None]:
# Install the normalising transforms on the datasets that actually load the images.
# train_dataset / val_dataset are torch.utils.data.Subset views of trainval_dataset and have
# no transform hook of their own, so assigning `.transform` on them was a no-op. The
# transform is installed on the shared trainval_dataset instead, so train and val always
# share one pipeline (train_transforms and val_transforms are identical, so
# train_transforms is used for both).
# torchvision's VisionDataset builds `self.transforms` from transform/target_transform in
# __init__, and OxfordIIITPet applies `self.transforms` per sample, so that wrapper is rebuilt
# here, keeping each dataset's existing target_transform.
from torchvision.datasets.vision import StandardTransform

trainval_dataset.transform = train_transforms
trainval_dataset.transforms = StandardTransform(train_transforms, trainval_dataset.target_transform)

test_dataset.transform = test_transforms
test_dataset.transforms = StandardTransform(test_transforms, test_dataset.target_transform)

In [None]:
# Rebuild the loaders now that normalisation is active. Validation and test loaders read
# their own splits (all three previously read train_dataset, so "validation" measured
# training accuracy) and keep a fixed order, since shuffling does not affect evaluation.
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

In [None]:
# Pull a single training batch and keep only its labels for inspection in the next cell.
lab = next(iter(train_dataloader))[1]

In [None]:
# Display the sampled label batch (one-hot rows when the datasets apply transform_label).
lab

## Setting up model

In [None]:
# Number of visible CUDA devices, and the device string used for training.
num_gpus: int = torch.cuda.device_count()
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
def get_net(num_classes=None, pretrained=True):
    """Build a ResNet-18 with a freshly initialised classification head.

    Args:
        num_classes: number of output logits; defaults to the notebook's ``num_of_classes``.
            The head was previously hard-coded to 10 outputs, matching neither the binary
            (2) nor the per-breed task.
        pretrained: load the ImageNet-1k weights (the ones the deprecated
            ``pretrained=True`` flag selected); ``False`` gives a randomly initialised backbone.

    Returns:
        torchvision ResNet-18 whose ``fc`` layer maps to ``num_classes`` logits.
    """
    if num_classes is None:
        num_classes = num_of_classes
    weights = torchvision.models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
    resnet = torchvision.models.resnet18(weights=weights)

    # Substitute the FC output layer; only its weight is Xavier-initialised (bias keeps nn.Linear's default).
    resnet.fc = torch.nn.Linear(resnet.fc.in_features, num_classes)
    torch.nn.init.xavier_uniform_(resnet.fc.weight)
    return resnet

In [None]:
# Instantiate the network to fine-tune; get_net() defines the backbone and the replacement head.
resnet_18 = get_net()

In [None]:
def _count_correct(preds, y):
    """Count rows whose arg-max logit equals the target class.

    ``y`` may hold class indices of shape (N,) or one-hot / probability rows of shape (N, C)
    (as produced by ``transform_label``); the latter are reduced with arg-max first.
    """
    targets = y.argmax(dim=1) if y.dim() == preds.dim() else y
    return (torch.argmax(preds, dim=1) == targets).sum()


def _run_epoch(net, dataloader, criterion, optimizer, device):
    """Make one pass over ``dataloader``; update ``net`` when ``optimizer`` is given.

    Returns ``(summed loss, number correct)`` as tensors on ``device``. The summed loss
    assumes ``criterion`` averages over the batch (the CrossEntropyLoss default).
    """
    total_loss = torch.tensor(0., device=device)  # kept on device to avoid per-batch transfers
    total_correct = torch.tensor(0., device=device)
    for X, y in dataloader:
        X = X.to(device)
        y = y.to(device)
        preds = net(X)
        if y.is_floating_point():
            # One-hot / soft targets (float64 from transform_label) are matched to the logits' dtype.
            y = y.to(preds.dtype)
        loss = criterion(preds, y)

        if optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        with torch.no_grad():
            # Weight by the real batch size: the last batch may be smaller than dataloader.batch_size.
            total_loss += loss.detach() * X.size(0)
            total_correct += _count_correct(preds, y)
    return total_loss, total_correct


def train(net, train_dataloader, valid_dataloader, criterion, optimizer, scheduler=None, epochs=10, device='cpu', checkpoint_epochs=10, checkpoint_path='./checkpoint.pth.tar'):
    """Train ``net`` for ``epochs`` epochs, printing loss and accuracy after each epoch.

    Args:
        net: model returning logits whose class dimension is dim 1; it is expected to
            already be on ``device`` (this function only moves the batches).
        train_dataloader: yields ``(X, y)``; ``y`` may be class indices or one-hot rows.
        valid_dataloader: same format, or ``None`` to skip validation.
        criterion: batch-averaged loss, e.g. ``torch.nn.CrossEntropyLoss()``.
        optimizer: optimizer over the parameters to update.
        scheduler: optional LR scheduler, stepped once per epoch.
        epochs: number of epochs to run.
        device: device the batches are moved to.
        checkpoint_epochs: write a checkpoint every this many epochs.
        checkpoint_path: file receiving ``{'epoch', 'state_dict', 'optimizer'}``.

    Returns:
        ``net`` itself, trained in place.
    """
    start = time.time()
    print(f'Training for {epochs} epochs on {device}')

    for epoch in range(1, epochs + 1):
        print(f"Epoch {epoch}/{epochs}")

        net.train()  # train mode for Dropout and Batch Normalization
        train_loss, train_accuracy = _run_epoch(net, train_dataloader, criterion, optimizer, device)

        if valid_dataloader is not None:
            net.eval()  # eval mode: Dropout off, BatchNorm uses running statistics
            with torch.no_grad():
                valid_loss, valid_accuracy = _run_epoch(net, valid_dataloader, criterion, None, device)

        if scheduler is not None:
            scheduler.step()

        print(f'Training loss: {train_loss/len(train_dataloader.dataset):.2f}')
        print(f'Training accuracy: {100*train_accuracy/len(train_dataloader.dataset):.2f}')

        if valid_dataloader is not None:
            print(f'Valid loss: {valid_loss/len(valid_dataloader.dataset):.2f}')
            print(f'Valid accuracy: {100*valid_accuracy/len(valid_dataloader.dataset):.2f}')

        if epoch % checkpoint_epochs == 0:
            torch.save({
                'epoch': epoch,
                'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, checkpoint_path)

        print()

    end = time.time()
    print(f'Total training time: {end-start:.1f} seconds')
    return net

In [None]:
resnet_18.to(device)

# Standard CrossEntropy Loss for multi-class classification problems
criterion = torch.nn.CrossEntropyLoss()

# Fine-tuning: the pretrained backbone trains at the base learning rate and the freshly
# initialised `fc` head at 10x. The head's parameters were previously left out of the
# optimizer entirely, so the randomly initialised classifier was never updated.
base_lr = 0.0001
params_1x = [param for name, param in resnet_18.named_parameters() if not name.startswith('fc.')]

optimizer = torch.optim.Adam([
    {'params': params_1x},
    {'params': resnet_18.fc.parameters(), 'lr': base_lr * 10},
], lr=base_lr)

resnet_18_trained = train(resnet_18,
                          train_dataloader,
                          val_dataloader,
                          criterion,
                          optimizer,
                          epochs=num_epochs,
                          device=device)