# Fine-tuning a pretrained ViT on CIFAR-10

In [None]:
# Mount Google Drive (Colab only). Note: the cells visible in this notebook
# write data to ./data and checkpoints to ./checkpoint, not to the Drive mount.
import google.colab.drive as drive
drive.mount('/content/drive')

## Setup

In [4]:
!pip install timm

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting timm
  Downloading timm-0.6.13-py3-none-any.whl (549 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m549.1/549.1 KB[0m [31m10.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting huggingface-hub
  Downloading huggingface_hub-0.13.3-py3-none-any.whl (199 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m199.8/199.8 KB[0m [31m27.2 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: huggingface-hub, timm
Successfully installed huggingface-hub-0.13.3 timm-0.6.13


In [5]:
!pip install wandb -q

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m27.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.8/194.8 KB[0m [31m22.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m184.3/184.3 KB[0m [31m21.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.7/62.7 KB[0m [31m6.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for pathtools (setup.py) ... [?25l[?25hdone


In [6]:
import os
from getpass import getpass

import wandb

# Optional: log runs to Weights & Biases. Read the API key from the environment
# (or prompt for it) rather than pasting it into the notebook, where it would be
# saved along with the file.
api_key = os.environ.get('WANDB_API_KEY') or getpass('W&B API key: ')
wandb.login(key=api_key)

[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.optim as optim
import os

import torchvision

import timm

## Dataset: CIFAR-10

CIFAR-10 images (32×32) are resized to the 224×224 input size expected by the `*_patch16_224` ViT.

In [8]:
import torchvision
import torchvision.transforms as transforms

# Fall back to CPU when no GPU is present so the notebook still runs (slowly).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained ViT weights expect the input normalization they were trained with.
# For timm's ViT ImageNet checkpoints this is mean = std = (0.5, 0.5, 0.5).
# NOTE(review): confirm with timm.data.resolve_data_config({}, model=model)
# whenever model_name is changed to a different family.
VIT_MEAN = (0.5, 0.5, 0.5)
VIT_STD = (0.5, 0.5, 0.5)

transform_train = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(VIT_MEAN, VIT_STD),
])

transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(VIT_MEAN, VIT_STD),
])

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=100, shuffle=False, num_workers=2)
     

Files already downloaded and verified
Files already downloaded and verified


## Model: ViT

ViT variants available in `timm` (this notebook uses `vit_base_patch16_224`):
* vit_tiny_patch16_224
* vit_tiny_patch16_384
* vit_small_patch16_224
* vit_small_patch16_384
* vit_small_patch32_224
* vit_small_patch32_384
* vit_base_patch16_224
* vit_base_patch16_384
* vit_base_patch32_224
* vit_base_patch32_384
* vit_large_patch16_224
* vit_large_patch16_384
* vit_large_patch32_224
* vit_large_patch32_384


In [9]:
def requires_grad(model, num_trainable=2):
  """Freeze every parameter of ``model`` except the last ``num_trainable`` ones.

  Parameters are taken in ``model.parameters()`` order. With the default of 2,
  only the final two parameter tensors stay trainable. For the timm ViT used in
  this notebook, these are presumably the classifier head's weight and bias
  (confirm with the printout in the next cell). Parameters that are kept are not
  touched; in particular, ones that were already frozen stay frozen.

  Args:
    model: a ``torch.nn.Module``; modified in place.
    num_trainable: number of trailing parameter tensors to leave trainable.
      0 freezes everything. A value larger than the parameter count freezes
      nothing.

  Returns:
    The same ``model`` object, so the call can be chained or reassigned.

  Raises:
    ValueError: if ``num_trainable`` is negative.
  """
  if num_trainable < 0:
    raise ValueError('num_trainable must be >= 0')
  params = list(model.parameters())
  # Index of the first parameter that remains trainable. Clamped so models with
  # fewer than `num_trainable` parameters do not fail.
  keep_from = max(len(params) - num_trainable, 0)
  for idx, param in enumerate(params):
    if idx < keep_from:
      param.requires_grad = False

  return model

In [None]:
# Load the ImageNet-pretrained ViT from timm with a new 10-way head for CIFAR-10,
# then freeze everything except the last two parameter tensors.
model_name = 'vit_base_patch16_224'
model = timm.create_model(model_name, pretrained=True, num_classes=10)

model = requires_grad(model)

# Summarize instead of printing one line per parameter tensor.
trainable_names = [name for name, param in model.named_parameters() if param.requires_grad]
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
n_total = sum(p.numel() for p in model.parameters())
print('Trainable parameters:', trainable_names)
print(f'{n_trainable:,} / {n_total:,} parameters trainable')

## Train

In [11]:
def train(epoch, net, optimizer, criterion, device, loader=None):
    """Run one training epoch over ``loader`` and report accuracy and loss.

    Args:
        epoch: epoch index, used in the printout and as the wandb ``epoch`` value.
        net: model to train; moved to ``device`` and switched to train mode.
        optimizer: optimizer over (a subset of) ``net``'s parameters.
        criterion: loss function. The epoch loss is a per-sample mean, which
            assumes ``criterion`` uses mean reduction (the default for
            ``nn.CrossEntropyLoss``).
        device: device the model and each batch are moved to.
        loader: iterable of ``(inputs, targets)`` batches. Defaults to the
            notebook-level ``train_loader``.

    Returns:
        ``(train_acc, train_loss)``: epoch accuracy in percent and the
        per-sample mean training loss.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    if loader is None:
        loader = train_loader
    print('\n[ Train epoch: %d ]' % epoch)
    net.to(device)
    net.train()
    loss_sum = 0.0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(loader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()

        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        batch_size = targets.size(0)
        _, predicted = outputs.max(1)
        batch_correct = predicted.eq(targets).sum().item()
        # Weight each mean-reduced batch loss by its size so a smaller final
        # batch does not skew the epoch average.
        loss_sum += loss.item() * batch_size
        correct += batch_correct
        total += batch_size

        if batch_idx % 100 == 0:
            print('\nCurrent batch:', str(batch_idx))
            print('Current benign train accuracy:', str(batch_correct / batch_size))
            print('Current benign train loss:', loss.item())

    if total == 0:
        raise ValueError('train loader yielded no samples')

    train_acc = 100. * correct / total
    train_loss = loss_sum / total
    print('\nTotal benign train accuracy:', train_acc)
    print('Total benign train loss:', train_loss)

    try:
        # Key was misspelled 'Tran Loss' in runs logged before this fix.
        wandb.log({'Train Accuracy': train_acc, 'Train Loss': train_loss, 'epoch': epoch})
    except Exception as exc:
        # Logging is best-effort: keep training if wandb is missing or not initialised.
        print('No wandb:', exc)

    return train_acc, train_loss


best_test_acc = 0
def test(epoch, net, criterion, file_name):
    print('\n[ Test epoch: %d ]' % epoch)
    net.eval()
    loss = 0
    correct = 0
    total = 0

    for batch_idx, (inputs, targets) in enumerate(test_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        total += targets.size(0)

        outputs = net(inputs)
        loss += criterion(outputs, targets).item()

        _, predicted = outputs.max(1)
        correct += predicted.eq(targets).sum().item()

    test_acc = 100. * correct / total
    test_loss = loss / total
    print('\nTest accuarcy:', 100. * correct / total)
    print('Test average loss:', loss / total)

    # epoch 5마다 저장 & best model 저장
    if epoch % 5 == 0:
      state = {
          'net': net.state_dict()
      }
      if not os.path.isdir('checkpoint'):
          os.mkdir('checkpoint')
      torch.save(state, f'./checkpoint/epoch{epoch}_{file_name}')
      print('Model Saved!')

    # best model 저장
    if best_test_acc < test_acc:
      state = {
          'net': net.state_dict()
      }
      if not os.path.isdir('checkpoint'):
          os.mkdir('checkpoint')
      torch.save(state, f'./checkpoint/best_mode.pt')
      print('Best Model Saved!')

    try:
      wandb.log({'Test Accuracy': test_acc, 'Test Loss': test_loss, 'epoch':epoch})
    except:
      print('No wandb')
      pass


In [None]:
# Fine-tune the ViT. CONFIG is logged to wandb and also drives the optimizer and
# the epoch count, so the logged hyperparameters are the ones actually used.
CONFIG = {
    'model': model_name,
    'batch_size': 128,  # informational only: must match train_loader's batch_size
    'epoch': 50,
    'pretrained': 'imagenet',
    'optimizer': 'sgd',
    'lr': 1e-4,
    'momentum': 0.9,
    'weight_decay': 0.0002,
    'criterion': 'crossentropy',
}

criterion = nn.CrossEntropyLoss()
# Only hand the optimizer the parameters left trainable by `requires_grad`.
optimizer = optim.SGD(
    [p for p in model.parameters() if p.requires_grad],
    lr=CONFIG['lr'],
    momentum=CONFIG['momentum'],
    weight_decay=CONFIG['weight_decay'],
)

wandb.init(project='knowledge_distillation', config=CONFIG, name=model_name)
try:
    # 1-based epochs so exactly CONFIG['epoch'] epochs run (range(51) ran one extra).
    for epoch in range(1, CONFIG['epoch'] + 1):
        train(epoch, model, optimizer, criterion, device)
        test(epoch, model, criterion, f'{model_name}.pt')
finally:
    # Close the wandb run even if training is interrupted or raises.
    wandb.finish()