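# Transfer learning

Fine-tune an ImageNet-pretrained VGG19 on the Animals-10 Kaggle dataset: the pretrained backbone is kept and only a new, small classification head is trained.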

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os

import yaml
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision as tv
from torch.cuda.amp import autocast, GradScaler

import cv2

from tqdm.auto import tqdm

In [2]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
!pip install kaggle
from google.colab import files
if 'kaggle.json' not in os.listdir():
  files.upload() #API token
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json
!kaggle datasets download 'alessiocorrado99/animals10'
!unzip 'animals10.zip'

## Dataset

In [35]:
# Preprocessing for an ImageNet-pretrained backbone: fixed 224x224 resize,
# conversion to a float tensor, then per-channel normalization.
transform = tv.transforms.Compose(
    [
        tv.transforms.Resize((224, 224)),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize(
            mean=[0.485, 0.456, 0.406],  # ImageNet per-channel means
            std=[0.229, 0.224, 0.225],   # ImageNet per-channel standard deviations
        ),
    ]
)

# Images extracted from the Kaggle archive, labelled by sub-directory.
# NOTE(review): the whole dataset is used for training (no held-out split), so the
# accuracy printed during training is training accuracy only.
dataset_path = 'raw-img'
dataset_train = tv.datasets.ImageFolder(root=dataset_path, transform=transform)


## DataLoader

In [36]:
# Shuffled mini-batches of 16 images; the final incomplete batch is dropped so
# every batch has the same size.
train_loader = torch.utils.data.DataLoader(
    dataset_train,
    batch_size=16,
    shuffle=True,
    num_workers=1,
    drop_last=True,
)


In [37]:
# Sanity check on the first batch only: image-tensor shape, then its class indices.
for first_batch in train_loader:
    images, labels = first_batch[0], first_batch[1]
    print(images.shape)
    print(labels)
    break

torch.Size([16, 3, 224, 224])
tensor([5, 0, 5, 4, 5, 8, 1, 0, 6, 4, 0, 3, 0, 4, 0, 9])


## Architecture

In [38]:
def count_parameters(model):
    """Count the trainable scalar parameters of ``model``.

    Only parameters with ``requires_grad=True`` contribute; each adds its ``numel()``.
    Returns 0 for a model without trainable parameters.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [None]:
# VGG19 with ImageNet-pretrained weights. `weights` must be a member of the
# VGG19_Weights enum (DEFAULT = the recommended pretrained weights), not the enum class.
model_vgg = tv.models.vgg19(weights=tv.models.VGG19_Weights.DEFAULT)

In [None]:
# Display the module tree to see the `features` and `classifier` sub-modules before modifying them.
model_vgg

In [62]:
# Trainable parameters: whole model, then the `classifier` head, then the `features` backbone.
for submodule in (model_vgg, model_vgg.classifier, model_vgg.features):
    print(count_parameters(submodule))

143667240
123642856
20024384


In [70]:
# New, much smaller classification head for VGG19.
# Input: 25088 = 512 * 7 * 7, the flattened feature size the original VGG19 classifier accepted.
# Output: one logit per class discovered by ImageFolder (10 for Animals-10), so the head
# stays in sync with the dataset instead of relying on a hard-coded count.
NUM_CLASSES = len(dataset_train.classes)

classificator = nn.Sequential(
    nn.Linear(512 * 7 * 7, 100),
    nn.LeakyReLU(0.2),
    nn.Linear(100, NUM_CLASSES),
)

In [71]:
# Trainable parameters of the new head (compare with 123,642,856 for VGG19's original classifier above).
count_parameters(classificator)

2509910

In [72]:
# Swap in the new head and freeze the pretrained `features` backbone.
# Only `model_vgg.classifier.parameters()` are given to the optimizer in the training
# setup, so backbone gradients would otherwise be computed on every step for nothing
# (and keep accumulating, because `optimizer.zero_grad()` never clears them).
model_vgg.classifier = classificator
for param in model_vgg.features.parameters():
    param.requires_grad_(False)

In [None]:
# Display the model again to confirm the new classification head is in place.
model_vgg

## Train

In [73]:
# Objects used by `train_mode`: model on the target device, AMP gradient scaler,
# loss function, and an optimizer over the new classification head only.
model_vgg = model_vgg.to(device)
scaler = GradScaler()

loss_function = nn.CrossEntropyLoss().to(device)

optimizer = torch.optim.Adam(
    model_vgg.classifier.parameters(),
    lr=0.001,
    betas=(0.9, 0.999),
)

In [74]:
def train_mode(model, train_loader, params):
    """Train ``model`` for ``params['epochs']`` epochs and return per-epoch training metrics.

    Args:
        model: network whose output is a (batch, num_classes) tensor of class scores.
        train_loader: sized iterable of batches where ``batch[0]`` is the input tensor and
            ``batch[1]`` holds integer class indices (e.g. a DataLoader over ImageFolder).
        params: dict with the keys
            ``'epochs'`` (number of passes over ``train_loader``),
            ``'device'`` (where inputs and labels are moved),
            ``'use_amp'`` (enable ``autocast`` mixed precision), and optionally
            ``'optimizer'``, ``'loss_function'`` and ``'scaler'``. Any of the last three
            that is missing falls back to the notebook-level object with the same name.

    Returns:
        ``(loss_history, acc_history)``: per-epoch mean training loss and mean training
        accuracy, each averaged over the batches of the epoch.

    Raises:
        ValueError: if ``train_loader`` has no batches (the averages would be undefined).
    """
    # Training objects can be passed explicitly; otherwise use the notebook globals.
    opt = params['optimizer'] if 'optimizer' in params else optimizer
    criterion = params['loss_function'] if 'loss_function' in params else loss_function
    grad_scaler = params['scaler'] if 'scaler' in params else scaler
    device = params['device']

    n_batches = len(train_loader)
    if n_batches == 0:
        raise ValueError('train_loader yields no batches')

    loss_history, acc_history = [], []
    for epoch in range(params['epochs']):
        model.train()
        loss_sum, acc_sum = 0, 0
        for sample in (pbar := tqdm(train_loader)):
            img = sample[0].to(device)
            label = sample[1].to(device)

            opt.zero_grad()
            with autocast(params['use_amp']):
                pred = model(img)
                # One-hot targets sized from the model output instead of a hard-coded
                # class count, so heads with a different number of classes also work.
                target = F.one_hot(label, num_classes=pred.shape[1]).float()
                loss = criterion(pred, target)

            # Mixed-precision update through the GradScaler: scaled backward, step, rescale.
            grad_scaler.scale(loss).backward()
            grad_scaler.step(opt)
            grad_scaler.update()

            loss_item = loss.item()
            loss_sum += loss_item
            acc_current = accuracy(pred.cpu().float(), target.cpu())
            acc_sum += acc_current

            pbar.set_description(f'epoch: {epoch}\tloss: {loss_item:.5f}\taccuracy: {acc_current:.3f}')

        # Leave the model in eval mode after each epoch (no evaluation is run here).
        model.eval()

        epoch_loss = loss_sum / n_batches
        epoch_acc = acc_sum / n_batches
        loss_history.append(epoch_loss)
        acc_history.append(epoch_acc)
        print(f'loss: {epoch_loss}')
        print(f'train: {epoch_acc}')
    return loss_history, acc_history

def accuracy(pred, label):
    """Fraction of samples whose predicted class equals the target class.

    Args:
        pred: tensor of class scores, shape (batch, num_classes); the prediction is the
            argmax over dim 1.
        label: either one-hot / probability targets of shape (batch, num_classes), whose
            argmax over dim 1 is the class, or integer class indices of shape (batch,).

    Returns:
        Mean accuracy as a numpy float (``nan`` for an empty batch, as with ``np.mean([])``).
    """
    # Softmax is monotonic, so the argmax of the raw scores is the same prediction;
    # skipping it also avoids the implicit-dim softmax warning.
    predicted = pred.detach().argmax(dim=1).cpu()
    target = label.argmax(dim=1) if label.dim() > 1 else label
    return (predicted == target.cpu()).numpy().mean()

In [75]:
# Training configuration. The device comes from the detection cell at the top of the
# notebook (the model was moved there too), so inputs and model always agree; the
# CUDA autocast used for AMP is only enabled when training on CUDA.
params = {'epochs': 3,
          'device': device,
          'use_amp': device == 'cuda'}

loss_history, acc_history = train_mode(model_vgg, train_loader, params)
loss_history, acc_history

  0%|          | 0/1636 [00:00<?, ?it/s]

  answer = F.softmax(pred.detach()).numpy().argmax(1) == label.numpy().argmax(1)


loss: 0.2717819530699029
train: 0.9271088019559902


  0%|          | 0/1636 [00:00<?, ?it/s]

loss: 0.10427597209738344
train: 0.9713860024449877


  0%|          | 0/1636 [00:00<?, ?it/s]

loss: 0.11730264858194697
train: 0.9752444987775061


([0.2717819530699029, 0.10427597209738344, 0.11730264858194697],
 [0.9271088019559902, 0.9713860024449877, 0.9752444987775061])