# 3.1 Training the fine-tuned models

Code is based on the following Kaggle Notebook:
https://www.kaggle.com/code/sovitrath/caltech-256-amp-pytorch

In [None]:
# Third-party packages used further down: albumentations (augmentations in
# dataset.py), pretrainedmodels (imported by model.py) and torchmetrics
# (AUROC computation).
!pip install albumentations
!pip install pretrainedmodels
!pip install torchmetrics

In [None]:
import torch
print(torch.__version__)
# get_device_name(0) raises when no CUDA device is present. Guard the call so
# this cell also runs on CPU-only machines, in line with the device selection
# later in the notebook, which already falls back to "cpu".
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
else:
    print("CUDA is not available; running on CPU")

In [None]:
# Watermark detection

import time
import copy
import numpy as np
import matplotlib.pyplot as plt
import os
import glob

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Seed PyTorch's default random generator for this notebook process only.
# train.py is launched as a separate process further down and does not set a
# seed of its own.
torch.manual_seed(17)

In [None]:
# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

# Per-channel statistics handed to Normalize (presumably the standard ImageNet
# values expected by torchvision's pretrained models -- confirm).
_channel_mean = [0.485, 0.456, 0.406]
_channel_std = [0.229, 0.224, 0.225]

_to_tensor = torchvision.transforms.ToTensor()
_normalize = torchvision.transforms.Normalize(mean=_channel_mean,
                                              std=_channel_std)

# NOTE: this rebinds the name ``transforms`` (imported above as the
# torchvision.transforms module) to a Compose instance; the full
# ``torchvision.transforms`` path is used here so the cell can be re-run.
transforms = torchvision.transforms.Compose([_to_tensor, _normalize])


from tqdm import tqdm
from torch.utils.data import DataLoader,Dataset
import cv2

# Dataset sub-folder names, one per class. The position in this list is the
# class index along axis 1 of the stored activations; index 0 ('baseline') is
# the class every other one is compared against when the AUCs are computed.
class_names = ['baseline', 'chinese', 'latin', 'hindi', 'arabic_numerals']

class ImageDataset(Dataset):
    """Dataset over every ``*.JPEG`` file found directly inside ``root``.

    Args:
        root: Directory to search (not recursive). A trailing path separator
            is optional.
        transform: Callable applied to each RGB image array; its result is
            what ``__getitem__`` returns.
    """

    def __init__(self,root,transform):
        self.root=root
        self.transform=transform

        # os.path.join keeps the pattern correct whether or not ``root`` ends
        # with a separator; plain string concatenation silently matched
        # nothing for a root such as '../dataset/latin'.
        # Sorted so that index -> file is stable from run to run.
        self.image_names=sorted(glob.glob(os.path.join(self.root, '*.JPEG')))

    # The __len__ function returns the number of samples in our dataset.
    def __len__(self):
        return len(self.image_names)

    def __getitem__(self,index):
        """Load image ``index``, convert it from BGR to RGB and transform it.

        Raises:
            OSError: If OpenCV cannot read the file.
        """
        path=self.image_names[index]
        image=cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than inside cvtColor.
        if image is None:
            raise OSError(f"Could not read image file: {path}")
        image=cv2.cvtColor(image,cv2.COLOR_BGR2RGB)

        return self.transform(image)
    
    
import warnings

# Feature extraction: run every image of every class through a pretrained
# DenseNet-161 and store, per image, the output of ``model.features`` averaged
# over axes 2 and 3. No gradients are needed anywhere in this cell.
with torch.no_grad():
    activation = {}
    model_name = 'densenet161'
    model = torchvision.models.densenet161(pretrained = True).to(device)
    d = 2208           # size of the last axis of the stored features
    n_per_class = 998  # rows reserved per class; the AUC cell assumes the same count

    def get_activation(name):
        """Return a forward hook that stores the hooked module's output,
        averaged over axes 2 and 3, in ``activation[name]``."""
        def hook(model, input, output):
            activation[name] = output.mean(axis = [2,3])
        return hook
    model.features.register_forward_hook(get_activation('features'))
    model.eval()

    # Axis 0: image index within a class, axis 1: class index (order of
    # class_names), axis 2: feature index.
    logit_scores = torch.zeros([n_per_class, len(class_names), d])

    for c, class_name in enumerate(class_names):
        dataset = ImageDataset('../dataset/{class_name}/'.format(class_name = class_name),
                               transforms
                              )
        testloader = torch.utils.data.DataLoader(dataset,
                                      batch_size=512,
                                      shuffle=False,
                                      num_workers=2)

        counter = 0
        for x in tqdm(testloader):
            x = x.float().to(device)

            # The return value is not needed; the forward pass only serves to
            # trigger the hook that fills activation["features"].
            model(x)
            logit_scores[counter:counter + x.shape[0],c,:] = activation["features"]

            counter += x.shape[0]

        # Rows that were never written stay zero and would silently enter the
        # AUC computation, so make a short class visible.
        if counter != n_per_class:
            warnings.warn(
                f"class '{class_name}': found {counter} images, expected {n_per_class}; "
                "the remaining rows are left as zeros"
            )

    torch.save(logit_scores, '../activations/{name}_features_wtrmrks.tnsr'.format(name = model_name))

In [None]:
import torchmetrics

# For every non-baseline class j and every feature r, compute the AUROC of
# that single feature for separating baseline images (label 0) from class-j
# images (label 1).
with torch.no_grad():
    activations = torch.load('../activations/densenet161_features_wtrmrks.tnsr')
    # The original (misspelled) name is kept in case another cell refers to it.
    acitvations = activations

    # Take the sizes from the stored tensor instead of hard-coding them.
    n, n_classes, d = activations.shape
    target = torch.cat((torch.zeros(n, dtype=torch.long),
                        torch.ones(n, dtype=torch.long)))

    aucs = torch.zeros([n_classes - 1, d])

    for j in range(n_classes - 1):
        for r in range(d):
            scores = torch.cat((activations[:, 0, r], activations[:, j+1, r]), 0)
            aucs[j, r] = torchmetrics.functional.classification.binary_auroc(scores, target)

    # model.py loads "aucs.tnsr" from the working directory; write the result
    # there so that file actually exists when train.py is run.
    torch.save(aucs, 'aucs.tnsr')

In [None]:
%%writefile preprocess.py

import os
import pandas as pd
import numpy as np

from tqdm import tqdm
from sklearn.preprocessing import LabelBinarizer


# put the path to the dataset
root_dir = '../input/caltech256/256_ObjectCategories'
# get all the folder paths
all_paths = os.listdir(root_dir)

# Collect paths and label names in plain lists and build the DataFrame once at
# the end; growing a DataFrame one row at a time through .loc is needlessly
# slow for tens of thousands of rows.
image_files = []
label_names = []
for folder_path in tqdm(all_paths, total=len(all_paths)):
    # the text after the last '.' of the folder name is the label
    label = folder_path.split('.')[-1]

    if label == 'clutter':
        continue

    # keep only files whose extension is exactly 'jpg'
    for image_path in os.listdir(f"{root_dir}/{folder_path}"):
        if image_path.split('.')[-1] == 'jpg':
            image_files.append(f"{root_dir}/{folder_path}/{image_path}")
            label_names.append(label)

# one-hot encode the labels
lb = LabelBinarizer()
labels = lb.fit_transform(np.array(label_names))

# The integer target of a sample is the position of the 1 in its one-hot row.
# It is stored as an integer column (the row-by-row version ended up with
# floats such as 3.0 in the CSV).
data = pd.DataFrame({
    'image_path': image_files,
    'target': np.argmax(labels, axis=1),
})

# shuffle the dataset
# NOTE: no random_state is given, so the row order differs from run to run.
data = data.sample(frac=1).reset_index(drop=True)

print(f"Number of labels or classes: {len(lb.classes_)}")
print(f"The first one hot encoded labels: {labels[0]}")
# Look up the category that belongs to the first one-hot row (previously this
# always printed lb.classes_[0], whatever the first label was).
print(f"Mapping the first one hot encoded label to its category: {lb.classes_[np.argmax(labels[0])]}")
print(f"Total instances: {len(data)}")

# save as CSV file
data.to_csv('data.csv', index=False)

print(data.head(5))

In [None]:
# Run the script written by the previous cell; it writes data.csv, which
# train.py reads.
!python preprocess.py

In [None]:
%%writefile dataset.py

import albumentations
import numpy as np
import torch

from PIL import Image
from torch.utils.data import Dataset

# custom dataset
class ImageDataset(Dataset):
    """Image dataset that loads files from disk and augments them.

    Args:
        images: Indexable collection of image file paths.
        labels: Indexable collection of integer class ids aligned with
            ``images``, or None for an unlabeled dataset.
        tfms: ``0`` selects the validation pipeline (resize only). Any other
            value, including the default None, selects the training pipeline
            (resize, random horizontal flip, random shift/scale/rotate).
    """

    def __init__(self, images, labels=None, tfms=None):
        self.X = images
        self.y = labels

        # apply augmentations
        if tfms == 0: # if validating
            self.aug = albumentations.Compose([
                albumentations.Resize(224, 224, always_apply=True),
            ])
        else: # if training
            self.aug = albumentations.Compose([
                albumentations.Resize(224, 224, always_apply=True),
                albumentations.HorizontalFlip(p=0.5),
                albumentations.ShiftScaleRotate(
                    shift_limit=0.3,
                    scale_limit=0.3,
                    rotate_limit=15,
                    p=0.5
                ),
            ])

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        """Return sample ``i`` as a dict.

        The dict holds ``'image'`` (float tensor, channels first) and, when
        labels were supplied, ``'target'`` (long tensor).
        """
        # The context manager closes the underlying file as soon as the pixel
        # data has been copied into the array.
        with Image.open(self.X[i]) as img:
            image = np.array(img.convert('RGB'))
        image = self.aug(image=image)['image']
        # height, width, channel -> channel, height, width
        image = np.transpose(image, (2, 0, 1)).astype(np.float32)
        # NOTE(review): pixel values are passed on without rescaling or
        # mean/std normalisation, whereas the transform used for feature
        # extraction applies ToTensor and Normalize -- confirm that this
        # difference is intended.
        sample = {'image': torch.tensor(image, dtype=torch.float)}
        # labels=None is an accepted constructor argument; only attach a
        # target when labels exist instead of failing on ``None[i]``.
        if self.y is not None:
            sample['target'] = torch.tensor(self.y[i], dtype=torch.long)
        return sample

In [None]:
%%writefile model.py

import pretrainedmodels
import torchvision
import torch
import torch.nn as nn
import torch.nn.functional as F
    
class DenseNet161(nn.Module):
    """Frozen pretrained DenseNet-161 features followed by a trainable linear layer.

    Args:
        percentile: Fraction of feature channels to zero out before the
            linear layer. Channels are ranked by ``|auc - 0.5|`` using the
            first row of the tensor stored in ``aucs.tnsr``; those above the
            ``1 - percentile`` quantile are zeroed. ``0.`` (default) disables
            the masking and does not require ``aucs.tnsr``.

    NOTE(review): the training loop calls ``model.train()``, which also puts
    the frozen backbone (including any normalisation layers) into training
    mode -- confirm that this is intended.
    """

    def __init__(self, percentile = 0.):
        super().__init__()
        self.model = torchvision.models.densenet161(pretrained = True)

        # Freeze the backbone; only the linear layer created below is trained.
        for param in self.model.parameters():
            param.requires_grad = False

        if percentile == 0.:
            ommited_indx = None
        else:
            # Loaded only when masking is requested, so the unmasked baseline
            # does not depend on the file being present.
            aucs = torch.load("aucs.tnsr")
            deviation = (aucs[0]-0.5).abs()
            a = torch.quantile(deviation, 1-percentile, dim=0, keepdim=True)
            ommited_indx = (deviation > a).nonzero()
        # A non-persistent buffer follows the module across ``.to(device)``
        # calls but is left out of ``state_dict()``, so checkpoint keys are
        # unchanged.
        self.register_buffer('ommited_indx', ommited_indx, persistent=False)

        self.l0 = nn.Linear(2208, 256)

    def forward(self, x):
        """Return the linear layer's output for the (optionally masked) pooled features."""
        x = self.model.features(x)
        # Global average pool, then flatten to (batch, features).
        x = F.adaptive_avg_pool2d(x, 1).flatten(1)
        if self.ommited_indx is not None:
            x[:, self.ommited_indx] = 0
        return self.l0(x)

# model = DenseNet161()
# print(model)

In [None]:
%%writefile engine.py

from tqdm import tqdm

import torch

def _amp_requested(use_amp):
    """Return True when ``use_amp`` asks for mixed precision ('yes' or True)."""
    return use_amp == 'yes' or use_amp is True


def _forward_pass(model, criterion, data, target, amp):
    """Run one batch through ``model`` and return ``(outputs, loss)``.

    The forward pass and loss are computed under CUDA autocast when ``amp``
    is true.
    """
    if amp:
        with torch.cuda.amp.autocast():
            outputs = model(data)
            loss = criterion(outputs, target)
    else:
        outputs = model(data)
        loss = criterion(outputs, target)
    return outputs, loss


# training function
def fit(model, dataloader, optimizer, criterion, train_data, device, use_amp):
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        model: Module to train; switched to training mode.
        dataloader: Yields dicts with ``'image'`` and ``'target'`` tensors.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss function called as ``criterion(outputs, target)``.
        train_data: Unused; kept for backward compatibility.
        device: Device the batches are moved to.
        use_amp: ``'yes'`` (or True) enables automatic mixed precision; any
            other value runs in full precision.

    Returns:
        ``(train_loss, train_accuracy)``: mean loss per sample and accuracy in
        percent, both over ``dataloader.dataset``.
    """
    print('Training')
    amp = _amp_requested(use_amp)
    # A new scaler is created on every call, so its scale state is not
    # carried over from one epoch to the next.
    scaler = torch.cuda.amp.GradScaler() if amp else None

    model.train()
    train_running_loss = 0.0
    train_running_correct = 0
    # len(dataloader) is the true number of batches, including a final
    # partial one.
    for batch in tqdm(dataloader, total=len(dataloader)):
        data, target = batch['image'].to(device), batch['target'].to(device)
        optimizer.zero_grad()

        outputs, loss = _forward_pass(model, criterion, data, target, amp)

        # ``criterion`` is assumed to return the mean over the batch (the
        # nn.CrossEntropyLoss default). Weight it by the batch size so that
        # dividing by the dataset size below gives a per-sample mean.
        train_running_loss += loss.item() * target.size(0)
        preds = outputs.argmax(dim=1)
        train_running_correct += (preds == target).sum().item()

        if amp:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

    n_samples = len(dataloader.dataset)
    train_loss = train_running_loss/n_samples
    train_accuracy = 100. * train_running_correct/n_samples
    return train_loss, train_accuracy

# validation function
def validate(model, dataloader, optimizer, criterion, val_data, device, use_amp):
    """Evaluate ``model`` on ``dataloader`` without updating any weights.

    Args:
        model: Module to evaluate; switched to evaluation mode.
        dataloader: Yields dicts with ``'image'`` and ``'target'`` tensors.
        optimizer: Unused; kept for backward compatibility.
        criterion: Loss function called as ``criterion(outputs, target)``.
        val_data: Unused; kept for backward compatibility.
        device: Device the batches are moved to.
        use_amp: ``'yes'`` (or True) evaluates under autocast; any other
            value runs in full precision.

    Returns:
        ``(val_loss, val_accuracy)``: mean loss per sample and accuracy in
        percent, both over ``dataloader.dataset``.
    """
    print('Validating')
    amp = _amp_requested(use_amp)

    model.eval()
    val_running_loss = 0.0
    val_running_correct = 0
    with torch.no_grad():
        for batch in tqdm(dataloader, total=len(dataloader)):
            data, target = batch['image'].to(device), batch['target'].to(device)

            outputs, loss = _forward_pass(model, criterion, data, target, amp)

            # Same per-sample weighting as in fit(), so that training and
            # validation losses are on the same scale.
            val_running_loss += loss.item() * target.size(0)
            preds = outputs.argmax(dim=1)
            val_running_correct += (preds == target).sum().item()

    n_samples = len(dataloader.dataset)
    val_loss = val_running_loss/n_samples
    val_accuracy = 100. * val_running_correct/n_samples
    return val_loss, val_accuracy

In [None]:
%%writefile train.py

from sklearn.model_selection import train_test_split
from model import DenseNet161
from dataset import ImageDataset
from torch.utils.data import DataLoader
from engine import fit, validate

import torch.optim as optim
import time
import torch.nn as nn
import argparse
import pandas as pd
import matplotlib
import torch 
import matplotlib.pyplot as plt

# build and parse the argument parser
parser = argparse.ArgumentParser()
parser.add_argument('-b', '--batch-size', dest='batch_size', type=int,
                    help='batch size for the dataset', default=512)
parser.add_argument('-p', '--percentile', dest='percentile', type=float,
                    help='percentile for ommiting representations', default=0.)
parser.add_argument('-a', '--use-amp', dest='use_amp',
                    help='to use Automatic Mixed Precision or not',
                    default='yes', choices=['yes', 'no'])
args = vars(parser.parse_args())

# learning parameters
batch_size = args['batch_size']
print(f"Batch size: {batch_size}")

percentile = args['percentile']
print(f"Percentile: {percentile}")

epochs = 15
lr = 0.0001
use_amp = args['use_amp']
print(f"Use AMP: {use_amp}")

# computation device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# get the dataset ready
df = pd.read_csv('data.csv')
X = df.image_path.values # image paths
y = df.target.values # targets
# hold out 10% of the rows for validation
(xtrain, xtest, ytrain, ytest) = train_test_split(X, y,
    test_size=0.10, random_state=42)
print(f"Training instances: {len(xtrain)}")
print(f"Validation instances: {len(xtest)}")
# tfms=1 -> training augmentations, tfms=0 -> resize only
train_data = ImageDataset(xtrain, ytrain, tfms=1)
test_data = ImageDataset(xtest, ytest, tfms=0)
# dataloaders
train_data_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
valid_data_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

# model
model = DenseNet161(percentile = percentile)
model.to(device)
# optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)
# loss function
criterion = nn.CrossEntropyLoss()
# total parameters and trainable parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"{total_params:,} total parameters.")
total_trainable_params = sum(
    p.numel() for p in model.parameters() if p.requires_grad)
print(f"{total_trainable_params:,} trainable parameters.")

# per-epoch history used for the plots at the end
train_loss , train_accuracy = [], []
val_loss , val_accuracy = [], []
if use_amp == 'yes':
    print('Tranining and validating with Automatic Mixed Precision')
elif use_amp == 'no':
    print('Tranining and validating without Automatic Mixed Precision')

start = time.time()
for epoch in range(epochs):
    print(f"Epoch {epoch+1} of {epochs}")
    train_epoch_loss, train_epoch_accuracy = fit(model, train_data_loader,
                                                 optimizer, criterion,
                                                 train_data, device, use_amp)
    val_epoch_loss, val_epoch_accuracy = validate(model, valid_data_loader,
                                                 optimizer, criterion,
                                                 test_data, device, use_amp)
    train_loss.append(train_epoch_loss)
    train_accuracy.append(train_epoch_accuracy)
    val_loss.append(val_epoch_loss)
    val_accuracy.append(val_epoch_accuracy)
    print(f"Train Loss: {train_epoch_loss:.4f}, Train Acc: {train_epoch_accuracy:.2f}")
    print(f'Val Loss: {val_epoch_loss:.4f}, Val Acc: {val_epoch_accuracy:.2f}')
end = time.time()

print(f"Took {((end-start)/60):.3f} minutes to train for {epochs} epochs")

# save model checkpoint
# NOTE: 'model' stores the whole pickled module in addition to its state
# dict, and 'loss' stores the criterion object, not a loss value.
torch.save({
            'epoch': epochs,
            'model_state_dict': model.state_dict(),
            'model': model,
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': criterion,
            }, f"{percentile}_model.pth")

# The plot file names carry the percentile, like the checkpoint above, so that
# successive runs with different percentiles do not overwrite each other's
# figures.

# accuracy plots
plt.figure(figsize=(10, 7))
plt.plot(train_accuracy, color='green', label='train accuracy')
plt.plot(val_accuracy, color='blue', label='validataion accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.savefig(f"{percentile}_amp_{use_amp}_accuracy.png")
plt.show()

# loss plots
plt.figure(figsize=(10, 7))
plt.plot(train_loss, color='orange', label='train loss')
plt.plot(val_loss, color='red', label='validataion loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.savefig(f"{percentile}_amp_{use_amp}_loss.png")
plt.show()

In [None]:
# Release cached, currently unused GPU memory held by this notebook process.
torch.cuda.empty_cache()

In [None]:
import os
import subprocess
import sys

# Train one model per percentile, each in its own process.
for percentile in [0, 0.005, 0.01, 0.02, 0.03, 0.05, 0.1, 0.15, 0.25, 0.5]:
    torch.cuda.empty_cache()
    # An argument list avoids going through a shell, and sys.executable makes
    # sure train.py runs under the same interpreter (and therefore the same
    # installed packages) as this notebook.
    command = [sys.executable, 'train.py',
               '--batch-size', '256',
               '--percentile', str(percentile),
               '--use-amp', 'no']
    completed = subprocess.run(command)
    # Report a failed run instead of silently moving on; the remaining
    # percentiles are still attempted.
    if completed.returncode != 0:
        print(f"train.py failed for percentile={percentile} "
              f"(exit code {completed.returncode})")