<a href="https://colab.research.google.com/github/matiasguzmanp/vit-yoga-82/blob/main/vit_train_example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Vision Transformer on Yoga82

Author: Matías Guzmán Parra

## Download the repo and install dependencies

In [1]:
!git clone https://github.com/matiasguzmanp/vit-yoga-82

Cloning into 'vit-yoga-82'...
remote: Enumerating objects: 68, done.
remote: Counting objects: 100% (68/68), done.
remote: Compressing objects: 100% (52/52), done.
remote: Total 68 (delta 21), reused 48 (delta 12), pack-reused 0
Receiving objects: 100% (68/68), 144.85 KiB | 1.11 MiB/s, done.
Resolving deltas: 100% (21/21), done.


In [2]:
!pip install wandb onnx -Uq


In [3]:
import sys
sys.path.insert(0,'/content/vit-yoga-82')

## Data

This notebook assumes that the images are in a folder called `./Images` and that the file `./Yoga-82.rar` is in the root folder. If they are not, we can download them.


In [4]:
!gdown --id 1Jc-Dbg2oOPHuwEzvnaC5tJGUUL2_BP46

Downloading...
From: https://drive.google.com/uc?id=1Jc-Dbg2oOPHuwEzvnaC5tJGUUL2_BP46
To: /content/Images.rar
100% 3.10G/3.10G [00:38<00:00, 81.4MB/s]


In [None]:
!unrar x "./Images.rar"

In [6]:
!gdown --id 1jcRgz_mgFiWw5VtchUbxdS8b1oGm7PWF

Downloading...
From: https://drive.google.com/uc?id=1jcRgz_mgFiWw5VtchUbxdS8b1oGm7PWF
To: /content/Yoga-82.rar
100% 1.03M/1.03M [00:00<00:00, 8.64MB/s]


In [None]:
!unrar x "./Yoga-82.rar"

Then we can remove the bad images from the dataset and create new `.csv` files that list only the valid images.

In [None]:
from data.clean import clean_dataset

clean_dataset(csv_path = "./Yoga-82/yoga_train.txt").to_csv("train_dataframe.csv", index=False)
clean_dataset(csv_path = "./Yoga-82/yoga_test.txt").to_csv("test_dataframe.csv", index=False)
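The implementation of `clean_dataset` lives in the repository and is not shown here, so what counts as a "bad" image is not specified. As a rough illustration of the idea, the sketch below keeps only the annotation rows whose image file exists and is non-empty. The function name `keep_existing_rows`, the assumption that each annotation line starts with a relative image path followed by comma-separated labels, and the filtering criterion itself are all hypothetical.

```python
import csv
import os
import tempfile


def keep_existing_rows(annotation_path, image_root):
    """Keep annotation rows whose image exists and is non-empty (assumed criterion)."""
    kept = []
    with open(annotation_path, newline="") as f:
        for row in csv.reader(f):
            if not row:
                continue  # skip blank lines
            # Assumption: the first column is the image path relative to image_root
            image_path = os.path.join(image_root, row[0])
            if os.path.isfile(image_path) and os.path.getsize(image_path) > 0:
                kept.append(row)
    return kept


# Toy demonstration on a temporary folder: one valid file, one empty, one missing
with tempfile.TemporaryDirectory() as root:
    with open(os.path.join(root, "good.jpg"), "wb") as f:
        f.write(b"\xff\xd8\xff")
    open(os.path.join(root, "empty.jpg"), "wb").close()
    annotations = os.path.join(root, "toy.txt")
    with open(annotations, "w") as f:
        f.write("good.jpg,0\nempty.jpg,1\nmissing.jpg,2\n")
    kept_rows = keep_existing_rows(annotations, root)

print(kept_rows)
```

A stricter check could also try to decode each file with an image library, which would catch truncated or corrupt files that this size-based test lets through.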

## Train

We use `wandb` to monitor training, so we need to log in first.

In [10]:
import wandb
wandb.login()

wandb: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
wandb: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:
wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc

True

We can calculate the `mean`, `std` and class `weights` of the training data. This can take some time.

In [11]:
from data.measure import mean_and_std_calculator, compute_weights
from data.dataset import Yoga82
from torchvision import transforms
from torch.utils.data import DataLoader
import numpy as np

transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor()
])

dataset = Yoga82(train_val_test="train", csv_path="./train_dataframe.csv", transform=transform, n_classes=82)

loader = DataLoader(
    dataset,
    batch_size=10,
    num_workers=0,
    shuffle=False,
    drop_last=False,
)

mean, std = mean_and_std_calculator(loader)
weights = compute_weights(loader)
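The implementations of `mean_and_std_calculator` and `compute_weights` live in the repository and are not shown here. As a rough illustration of what helpers like these typically compute, the sketch below accumulates per-channel statistics batch by batch and derives inverse-frequency class weights. It uses NumPy arrays as stand-ins for the tensors a `DataLoader` yields. The function names `channel_mean_std` and `inverse_frequency_weights`, and the weighting scheme itself, are assumptions rather than the repository's actual code.

```python
from collections import Counter

import numpy as np


def channel_mean_std(batches):
    """Per-channel mean and std over batches of images shaped (B, C, H, W)."""
    total, total_sq, count = 0.0, 0.0, 0
    for images, _labels in batches:
        images = np.asarray(images, dtype=np.float64)
        total = total + images.sum(axis=(0, 2, 3))
        total_sq = total_sq + (images ** 2).sum(axis=(0, 2, 3))
        count += images.shape[0] * images.shape[2] * images.shape[3]
    mean = total / count
    # Var[x] = E[x^2] - E[x]^2, clipped at 0 to guard against rounding error
    var = np.maximum(total_sq / count - mean ** 2, 0.0)
    return mean, np.sqrt(var)


def inverse_frequency_weights(batches, n_classes):
    """Weight each class by n_samples / (n_classes * class_count)."""
    counts = Counter()
    for _images, labels in batches:
        counts.update(int(label) for label in labels)
    n_samples = sum(counts.values())
    # Rare classes get larger weights; classes never seen get weight 0
    return np.array([
        n_samples / (n_classes * counts[c]) if counts[c] else 0.0
        for c in range(n_classes)
    ])


# Toy data: two batches of four 3-channel 8x8 images, with imbalanced labels
rng = np.random.default_rng(0)
toy_batches = [(rng.random((4, 3, 8, 8)), np.array([0, 0, 0, 1])) for _ in range(2)]

toy_mean, toy_std = channel_mean_std(toy_batches)
toy_weights = inverse_frequency_weights(toy_batches, n_classes=2)
print(toy_mean, toy_std, toy_weights)
```

Accumulating sums rather than averaging per-batch means keeps the result exact even when the last batch is smaller than the others, which matters here because the loader is built with `drop_last=False`.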



## Training

We can now train the `ViT` on the 82 classes, using the `Adam` optimizer and `CrossEntropyLoss`.

In [12]:
from utils.train import train, make
from utils.test import test

from data.dataset import Yoga82
from data.measure import compute_weights, mean_and_std_calculator

from torch.utils.data import DataLoader

import torchvision.transforms as transforms
import torch

In [22]:
def model_pipeline(hyperparameters, train_mean, train_std, class_weights):
    # tell wandb to get started
    with wandb.init(project="vit-yoga82", config=hyperparameters):
        # access all HPs through wandb.config, so logging matches execution!
        config = wandb.config

        # make the model, data, and optimization problem
        model, train_loader, val_loader, test_loader, criterion, optimizer = make(
            config, train_mean, train_std, class_weights, n_classes=config.n_classes
        )
        print(model)

        # and use them to train the model
        train_loss, val_loss = train(model, train_loader, val_loader, criterion, optimizer, config)

        # and test its final performance
        conf_mat, acc = test(model, test_loader, device=config.device)

    return model, train_loss, val_loss, conf_mat, acc

In [23]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

config = dict(
    chw = (3,128,128),
    patch_size = 8,
    D = 768,
    n_classes = 82,
    heads = 12,
    layers = 12,
    epochs = 1,
    lr = 1e-5,
    batch_size = 32,
    device = device
    )
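As a quick sanity check on these settings, the arithmetic below works out the token sequence a ViT would see for this configuration. It assumes the standard ViT design (non-overlapping square patches, one extra class token, and `D` split evenly across the attention heads). The repository's model may differ in these details, and the helper name `vit_shapes` is hypothetical.

```python
def vit_shapes(chw, patch_size, D, heads):
    """Derive basic ViT dimensions from the config (standard ViT assumed)."""
    c, h, w = chw
    # The image must tile exactly into patch_size x patch_size patches
    assert h % patch_size == 0 and w % patch_size == 0
    # The embedding dimension must split evenly across attention heads
    assert D % heads == 0
    n_patches = (h // patch_size) * (w // patch_size)
    return {
        "n_patches": n_patches,
        "patch_dim": c * patch_size * patch_size,  # length of one flattened patch
        "seq_len": n_patches + 1,  # assumes one class token is prepended
        "head_dim": D // heads,
    }


shapes = vit_shapes(chw=(3, 128, 128), patch_size=8, D=768, heads=12)
print(shapes)
# {'n_patches': 256, 'patch_dim': 192, 'seq_len': 257, 'head_dim': 64}
```

Under these assumptions, each 128×128 image becomes 256 patches of 192 values each, and every attention head works on 64 dimensions. Halving `patch_size` would quadruple the number of patches, which increases the cost of self-attention considerably.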

In [None]:
model, train_loss, val_loss, conf_mat, acc = model_pipeline(config,
                                                  train_mean = mean,
                                                  train_std = std,
                                                  class_weights = weights)

## Plotting results

In [26]:
import matplotlib.pyplot as plt
import seaborn as sns


def plot_loss(train_loss, val_loss, info):
    # Plot the per-epoch training and validation loss on the same axes
    plt.figure()
    plt.plot(train_loss, label="Training loss")
    plt.plot(val_loss, label="Validation loss")
    plt.legend()
    plt.grid(True)
    plt.title(f"Training and validation loss per epoch: {info}")
    plt.show()


def plot_conf_mat(conf_mat, acc, info):
    # Show the confusion matrix as a heatmap, with the accuracy in the title
    plt.figure(figsize=(10, 8))
    sns.heatmap(conf_mat)
    plt.title(f"Confusion matrix {info}.\nAccuracy={acc:.4f}")
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.show()
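With 82 classes the heatmap can be hard to read, so it may help to summarize the confusion matrix numerically as well. The sketch below assumes rows are true labels and columns are predicted labels, matching the axis labels used in the plot. How `test` actually computes `acc` is not shown here, so the helper names `accuracy_from_conf_mat` and `per_class_recall` are hypothetical.

```python
import numpy as np


def accuracy_from_conf_mat(conf_mat):
    """Overall accuracy: correctly classified samples (diagonal) over all samples."""
    conf_mat = np.asarray(conf_mat, dtype=np.float64)
    return np.trace(conf_mat) / conf_mat.sum()


def per_class_recall(conf_mat):
    """Fraction of each true class (row) that was predicted correctly."""
    conf_mat = np.asarray(conf_mat, dtype=np.float64)
    row_totals = conf_mat.sum(axis=1)
    # Classes with no samples get recall 0 instead of a division by zero
    return np.divide(
        np.diag(conf_mat),
        row_totals,
        out=np.zeros_like(row_totals),
        where=row_totals > 0,
    )


# Toy 3-class confusion matrix (rows: true labels, columns: predicted labels)
toy_conf_mat = [
    [8, 2, 0],
    [1, 7, 2],
    [0, 0, 10],
]

toy_acc = accuracy_from_conf_mat(toy_conf_mat)
toy_recall = per_class_recall(toy_conf_mat)
print(f"accuracy={toy_acc:.4f}", toy_recall)
```

Sorting the per-class recall values is a simple way to find the poses the model confuses most often.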