In [None]:
# Work from the project folder on the mounted Drive (path suggests Google Colab); later cells use paths relative to it (./dataset, ./checkpoint_vit).
cd /content/drive/MyDrive/workspace/lotte/

In [None]:
# Extract the competition archive into /content. The dataset cell reads /content/train and /content/test, so the archive presumably contains train/ and test/ at its top level — confirm.
!unzip ./dataset/LPD_competition.zip -d /content

## 라이브러리 import

In [None]:
# Installs the package that provides `pytorch_pretrained_vit.ViT`, imported in the next cell.
# TODO: pin a version (and prefer %pip so it installs into the running kernel's environment).
!pip install pytorch_pretrained_vit

In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from pytorch_pretrained_vit import ViT
from torchvision import datasets, transforms, models
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
from tqdm.auto import tqdm
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import models
from glob import glob
import cv2
from PIL import Image
import torch.nn.functional as F
from sklearn.model_selection import KFold, StratifiedKFold
from torch.optim.lr_scheduler import ReduceLROnPlateau 

## Configuration

In [None]:
# CONFIG
SEED = 128
torch.manual_seed(SEED)
# cutmix_data decides whether to mix a batch with np.random, so numpy must be
# seeded too for a run to be reproducible.
np.random.seed(SEED)
BATCH_SIZE = 50
EPOCHS = 40
LEARNING_RATE = 1e-6
# DEVICE
print(f'PyTorch Version : [{torch.__version__}]')
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f'Device : [{device}]')

## Custom Datasets

In [None]:
class LotteDataset(Dataset):
    """Image dataset over the LPD competition folders.

    Train mode globs ``data_root/<label>/*.jpg`` and yields ``(PIL image, int label)``,
    where the label is the integer name of the image's parent directory.
    Test mode globs ``data_root/*.jpg`` and yields the PIL image only, ordered by
    the integer file stem (``1.jpg, 2.jpg, 10.jpg``) so predictions line up with
    the submission rows.

    Args:
        data_root: directory to scan; may live at any depth in the filesystem.
        train_mode: True for the labelled per-class layout, False for the flat test layout.
    """

    def __init__(self, data_root, train_mode):
        super(LotteDataset, self).__init__()
        self.train_mode = train_mode

        if not self.train_mode:
            self.img_list = glob(os.path.join(data_root, '*.jpg'))
            # Numeric (not lexicographic) order on the file stem.
            self.img_list.sort(key=self._file_index)
        else:
            # glob order depends on the filesystem; sort so the sample order (and
            # therefore a seeded shuffle) is reproducible.
            self.img_list = sorted(glob(os.path.join(data_root, '*/*.jpg')))
            self.train_y = [self._dir_label(p) for p in self.img_list]

        self.len = len(self.img_list)

    @staticmethod
    def _file_index(img_path):
        """Integer value of the file name without extension, e.g. '.../12.jpg' -> 12."""
        return int(os.path.splitext(os.path.basename(img_path))[0])

    @staticmethod
    def _dir_label(img_path):
        """Integer class label taken from the name of the file's parent directory."""
        return int(os.path.basename(os.path.dirname(img_path)))

    def __getitem__(self, index):
        img_path = self.img_list[index]
        # Decode eagerly inside `with` so the file handle is closed right away.
        # Convert to RGB so grayscale/RGBA files still produce the 3 channels
        # that the 3-value Normalize transform expects.
        with Image.open(img_path) as im:
            img = im.convert('RGB')
        if self.train_mode:
            return img, self.train_y[index]
        return img

    def __len__(self):
        return self.len

In [None]:
class MapTransform(Dataset):
    """Wrap a dataset and apply ``transform`` to the image of each sample.

    In train mode the wrapped dataset is expected to yield ``(image, label, ...)``,
    and this wrapper yields ``(transform(image), label)``. Otherwise it yields
    ``transform(sample)``.
    """

    def __init__(self, dataset, transform, train_mode):
        self.dataset = dataset
        self.transform = transform
        self.train_mode = train_mode

    def __getitem__(self, index):
        # Index the wrapped dataset exactly once. Each access may load and decode
        # an image from disk; the previous version did that twice per train sample.
        sample = self.dataset[index]
        if self.train_mode:
            return self.transform(sample[0]), sample[1]
        return self.transform(sample)

    def __len__(self):
        return len(self.dataset)

In [None]:
# ImageNet channel mean/std, shared by the train and test pipelines.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training augmentation: one random colour jitter, then either an affine warp
# (white fill) or a 224x224 random crop, then tensor conversion, resize and normalisation.
train_transforms = transforms.Compose([
    transforms.RandomChoice([
        transforms.ColorJitter(brightness=(1, 1.1)),
        transforms.ColorJitter(contrast=0.1),
        transforms.ColorJitter(saturation=0.1),
    ]),
    transforms.RandomChoice([
        # `resample=` is deprecated in torchvision; `interpolation=` with
        # InterpolationMode is the supported replacement.
        transforms.RandomAffine(degrees=15, translate=(0.2, 0.2), scale=(0.8, 1.2), shear=10,
                                interpolation=transforms.InterpolationMode.BILINEAR, fill=255),
        transforms.RandomCrop((224, 224)),
    ]),
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])
test_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

train_data = LotteDataset('/content/train', train_mode=True)
test_data = LotteDataset('/content/test', train_mode=False)

trans_train_data = MapTransform(train_data, train_transforms, train_mode=True)
trans_test_data = MapTransform(test_data, test_transforms, train_mode=False)

train_iter = DataLoader(trans_train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
# shuffle=False: predictions are written row-by-row into the submission CSV,
# so the test order must match the file order.
test_iter = DataLoader(trans_test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

  "Argument resample is deprecated and will be removed since v0.10.0. Please, use interpolation instead"


### CutMix

In [None]:
def rand_bbox(W, H, lam):
    """Sample a random CutMix box whose side lengths are sqrt(1 - lam) of the plane.

    Args:
        W: size of the first spatial axis the box indexes (cutmix_data passes x.size(2)).
        H: size of the second spatial axis (cutmix_data passes x.size(3)).
        lam: mixing ratio, as a Python float or a scalar tensor.

    Returns:
        (x1, y1, x2, y2) as one-element long tensors on lam's device, clamped to
        [0, W] / [0, H]. The box is ``x[..., x1:x2, y1:y2]``. Clamping at the
        border can make the box smaller than requested.
    """
    lam = torch.as_tensor(lam)
    cut_rat = torch.sqrt(1.0 - lam)
    cut_w = (W * cut_rat).type(torch.long)
    cut_h = (H * cut_rat).type(torch.long)
    # Box centre drawn uniformly (from the CPU generator, as before), then moved to
    # lam's device so the arithmetic below never mixes devices and no longer
    # depends on a global `device`.
    cx = torch.randint(W, (1,)).to(lam.device)
    cy = torch.randint(H, (1,)).to(lam.device)
    x1 = torch.clamp(cx - cut_w // 2, 0, W)
    y1 = torch.clamp(cy - cut_h // 2, 0, H)
    x2 = torch.clamp(cx + cut_w // 2, 0, W)
    y2 = torch.clamp(cy + cut_h // 2, 0, H)
    return x1, y1, x2, y2


def cutmix_data(x, y, alpha=1.0, p=0.5):
    """Apply CutMix to a batch with probability ``p``.

    Args:
        x: image batch. The box is cut on dims 2 and 3 (presumably N, C, H, W — confirm).
        y: integer labels for ``x``.
        alpha: concentration of the Beta(alpha, alpha) prior on the mixing ratio.
        p: probability of mixing this batch (decided with numpy's global RNG).

    Returns:
        (mixed_x, y_a, y_b, lam). The training loss should be
        ``lam * L(y_a) + (1 - lam) * L(y_b)``. When the batch is not mixed,
        returns ``(x, y, zeros_like(y), 1.0)`` so the ``y_b`` term has zero weight.
        ``x`` itself is never modified.
    """
    if np.random.random() > p:
        return x, y, torch.zeros_like(y), 1.0
    # Named W, H for rand_bbox's parameters; these are dims 2 and 3 of x.
    W, H = x.size(2), x.size(3)
    shuffle = torch.randperm(x.size(0)).to(x.device)
    # Work on a copy. Plain assignment aliased `x`, so the caller's batch was
    # overwritten in place.
    cutmix_x = x.clone()

    lam = torch.distributions.beta.Beta(alpha, alpha).sample().to(x.device)

    x1, y1, x2, y2 = rand_bbox(W, H, lam)
    cutmix_x[:, :, x1:x2, y1:y2] = x[shuffle, :, x1:x2, y1:y2]
    # Recompute lambda from the actual (clamped) box so it matches the pasted pixel ratio.
    lam = 1 - ((x2 - x1) * (y2 - y1) / float(W * H)).item()
    y_a, y_b = y, y[shuffle]
    return cutmix_x, y_a, y_b, lam

### Label Smoothing

In [None]:
def loss_fn(outputs, targets):
    """Mean cross-entropy accepting either hard or soft targets.

    A 1-D ``targets`` tensor holds class indices and is passed to
    ``F.cross_entropy``. Any other rank is treated as a per-class target
    distribution (expected to match ``outputs``' shape). The loss is then the
    batch mean of ``-sum(targets * log_softmax(outputs, dim=1), dim=1)``.
    """
    if targets.dim() != 1:
        log_probs = F.log_softmax(outputs, dim=1)
        per_sample = (-targets * log_probs).sum(dim=1)
        return per_sample.mean()
    return F.cross_entropy(outputs, targets)

def label_smooth_loss_fn(outputs, targets, epsilon=0.1, num_classes=None):
    """Cross-entropy against label-smoothed targets.

    Each hard label becomes ``(1 - epsilon) * one_hot + epsilon / num_classes``,
    and the result is scored with ``loss_fn``'s soft-target branch.

    Args:
        outputs: logits with classes on dim 1.
        targets: integer class indices.
        epsilon: probability mass spread uniformly over all classes.
        num_classes: number of classes. Defaults to ``outputs.size(1)``
            (previously hard-coded to 1000).
    """
    if num_classes is None:
        num_classes = outputs.size(1)
    # Build targets on the logits' device instead of relying on a global `device`.
    onehot = F.one_hot(targets, num_classes).float().to(outputs.device)
    # Adding a scalar replaces the old CPU `torch.ones(...)` allocation that was
    # then copied to the device.
    smoothed = (1 - epsilon) * onehot + epsilon / num_classes
    return loss_fn(outputs, smoothed)

## Model Training

In [None]:
def get_submission(Model, data_iter, epoch):
    """Predict a class for every test batch and write a submission CSV.

    Args:
        Model: classifier returning logits with classes on dim 1.
        data_iter: iterable of unlabelled image batches, in the same order as the
            rows of ./dataset/sample.csv.
        epoch: used only in the output file name.

    Writes ./checkpoint_vit/submission2_<epoch>_single.csv, a copy of the sample
    file with its 'prediction' column replaced. The model's train/eval mode is
    restored afterwards, even if inference raises.

    Raises:
        ValueError: if the number of predictions differs from the number of rows.
    """
    was_training = Model.training
    model_device = next(Model.parameters()).device
    pred_label = []
    Model.eval()
    try:
        with torch.no_grad():
            print("Final Testing....\n")
            # Iterate the loader passed in (the old code ignored `data_iter` and
            # read the global `test_iter`).
            for imgs in tqdm(data_iter):
                model_pred = Model(imgs.to(model_device))
                pred_label.extend(model_pred.argmax(dim=1).tolist())
    finally:
        Model.train(was_training)

    submission = pd.read_csv('./dataset/sample.csv', encoding='utf-8')
    if len(pred_label) != len(submission):
        raise ValueError('got %d predictions for %d submission rows'
                         % (len(pred_label), len(submission)))
    submission['prediction'] = pred_label
    submission.to_csv('./checkpoint_vit/submission2_' + str(epoch) + '_single.csv', index=False)

In [None]:
def freeze(Model, idx):
    """Make every direct child of ``Model`` trainable, then freeze ``children[:idx]``.

    With a negative ``idx`` this leaves only the last ``-idx`` children trainable:
      idx = -1 -> freeze everything except the final child (the FC layer, per the original note)
      idx = -3 -> freeze everything except the last three children (noted as "freeze the extractor";
                  which modules that covers depends on the model's child order — confirm)
    Parameters registered directly on ``Model`` rather than inside a child are not touched.
    """
    children = list(Model.children())
    for child in children:
        child.requires_grad_(True)
    for child in children[:idx]:
        child.requires_grad_(False)

In [None]:
# Resume from a previously saved ViT checkpoint (path relative to the project folder).
model_path = './checkpoint_vit/ViT_epoch_27.tar'

Model = ViT('B_16_imagenet1k', pretrained=False, image_size=224)
# map_location='cpu': a checkpoint saved from a GPU run otherwise fails to load
# on a CPU-only machine. The model is moved to `device` before training.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
checkpoint = torch.load(model_path, map_location='cpu')
Model.load_state_dict(checkpoint['model_state_dict'])
# NOTE(review): checkpoint['optimizer_state_dict'] is saved but never restored,
# so Adam's moment estimates restart from scratch — confirm that is intended.
freeze(Model, -3)
Model.eval()

In [None]:
# Fine-tune with CutMix + label smoothing. After every epoch: write a submission,
# log the losses, save a checkpoint and step the LR scheduler on the train loss.
Model.to(device)  # move before building the optimizer so it holds the on-device parameters
Model.train()

use_amp = device.type == 'cuda'
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
optimizer = optim.Adam(Model.parameters(), lr=LEARNING_RATE)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, threshold_mode='abs', min_lr=1e-8, verbose=True)
loss = label_smooth_loss_fn

# `with` guarantees the log file is closed even if training raises.
with open("./ViT_Single_trainlog.txt", 'w') as f:
    for epoch in range(EPOCHS):
        loss_val_sum = 0.0
        visual_loss_sum = 0.0  # plain (unsmoothed) CutMix loss, for monitoring only
        for imgs, labels in tqdm(train_iter):
            imgs, labels = imgs.to(device), labels.to(device)
            # CutMix with p=0.5
            imgs, labels_a, labels_b, lam = cutmix_data(imgs, labels)

            optimizer.zero_grad(set_to_none=True)
            # GradScaler only pays off when the forward pass runs under autocast.
            with torch.cuda.amp.autocast(enabled=use_amp):
                model_pred = Model(imgs)
                # Label smoothing + CutMix
                loss_out = lam * loss(model_pred, labels_a) + (1 - lam) * loss(model_pred, labels_b)

            scaler.scale(loss_out).backward()
            scaler.step(optimizer)
            scaler.update()

            # Monitoring loss without smoothing. No autograd graph is needed.
            with torch.no_grad():
                pred = model_pred.detach().float()
                normal_loss_out = lam * loss_fn(pred, labels_a) + (1 - lam) * loss_fn(pred, labels_b)

            # Accumulate detached values. Summing `loss_out` directly kept every
            # batch's autograd graph alive for the whole epoch.
            loss_val_sum += loss_out.detach()
            visual_loss_sum += normal_loss_out

        loss_val_avg = float(loss_val_sum) / len(train_iter)
        visual_loss_avg = float(visual_loss_sum) / len(train_iter)
        get_submission(Model, test_iter, epoch)

        # The old print also formatted an undefined `score`, which raised NameError.
        log_line = "epoch:[%d] train loss smooth:[%.5f] train normal loss:[%.5f]\n" % (epoch, loss_val_avg, visual_loss_avg)
        print(log_line)
        f.write(log_line)
        f.flush()  # keep the log usable if the run is interrupted
        print("Model Save....\n")
        # NOTE(review): epoch numbering restarts at 0 even though weights were loaded
        # from ViT_epoch_27.tar, so these files overwrite older checkpoints with the
        # same names (including ViT_epoch_27.tar itself) — confirm intended.
        torch.save({'model_state_dict': Model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict()}, './checkpoint_vit/ViT_epoch_' + str(epoch) + '.tar')
        scheduler.step(loss_val_avg)  # LR scheduler