<a href="https://colab.research.google.com/github/choki0715/UnLiteFlowNet-PIV/blob/master/baseline3_pit_b_distilled_224_fold7.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive

# Mount Google Drive: the dataset zip and the checkpoints live under MyDrive/beef.
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
# -p / -o make this cell safe to re-run: mkdir does not error on an existing directory,
# and unzip overwrites existing files instead of stopping at an interactive "replace?" prompt.
!mkdir -p data
!unzip -q -o /content/drive/MyDrive/beef/data.zip -d ./data

In [None]:
# Pin timm to the version this notebook was run with (0.5.4): the training loop depends on the
# train-mode output format of distilled PiT models, which later timm releases changed.
%pip install -q timm==0.5.4

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting timm
  Downloading timm-0.5.4-py3-none-any.whl (431 kB)
[K     |████████████████████████████████| 431 kB 7.9 MB/s 
Installing collected packages: timm
Successfully installed timm-0.5.4


In [None]:
import pandas as pd
import numpy as np
import os
from os import path as osp
import cv2
from tqdm.notebook import tqdm

import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
import albumentations
import timm

import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns

from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import StratifiedKFold

from warnings import filterwarnings
filterwarnings("ignore")

# Use the GPU when one is available, otherwise fall back to CPU instead of failing at .to(device).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# timm.list_models(pretrained=True)

In [None]:
print(timm.__version__)
# The notebook was developed against timm 0.5.4; the outputs of distilled models differ between releases.
if timm.__version__ != '0.5.4':
    print(f'WARNING: expected timm 0.5.4, found {timm.__version__}')

0.5.4


In [None]:
import random

# --- paths ---
train_path = './data/train'
test_path = './data/test'
train_csv_path = osp.join(train_path, 'grade_labels.csv')
test_csv_path = osp.join(test_path, 'test_images.csv')

# --- data ---
image_size = 224  # images are resized to image_size x image_size
label_dict = {'1++': 0, '1+': 1, '1': 2, '2': 3, '3': 4}  # grade string -> class index
reversed_label_dict = {v: k for k, v in label_dict.items()}  # class index -> grade string

kfold = 8                 # number of stratified folds
fold_id = 7               # fold held out for validation in this run
mean_pixel_min_value = 0  # > 0 drops training images whose mean grayscale value is <= this; 0 disables it

# --- model ---
num_classes = len(label_dict)
backbone_name = 'pit_b_distilled_224'  # timm model name

# --- loader / optimisation ---
batch_size = 32
n_worker = 4

init_lr = 5e-5
n_epochs = 200

# Seed for the fold split and for the Python / NumPy / PyTorch RNGs, so that repeated runs
# are comparable (GPU kernels may still introduce some non-determinism).
random_state = 139
random.seed(random_state)
np.random.seed(random_state)
torch.manual_seed(random_state)


In [None]:
def load_data(csv_path, is_train=True):
  """Read a split CSV and add a `file_path` column pointing into the sibling `images/` folder.

  Args:
    csv_path: path to the CSV; it must have an `imname` column (and `grade` when is_train).
    is_train: when True, also add an integer `label` column mapped through `label_dict`.

  Returns:
    The loaded DataFrame with the extra column(s).
  """
  image_dir = osp.join(osp.dirname(csv_path), 'images')
  frame = pd.read_csv(csv_path)
  frame['file_path'] = [osp.join(image_dir, name) for name in frame['imname']]
  if not is_train:
    return frame
  frame['label'] = [label_dict[grade] for grade in frame['grade']]
  return frame

def set_fold_column(df, n_splits=kfold):
  """Assign a stratified (by `label`) fold id in [0, n_splits) to every row.

  Args:
    df: DataFrame with a `label` column; it is modified in place.
    n_splits: number of folds (previously ignored in favour of the global `kfold`).

  Returns:
    The same DataFrame with an integer `fold` column.
  """
  skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)
  df['fold'] = -1
  fold_col = df.columns.get_loc('fold')
  for fold, (_, valid_idx) in enumerate(skf.split(df, df.label)):
    # skf.split yields positional indices, so index positionally (works for any index).
    df.iloc[valid_idx, fold_col] = fold
  return df

def get_mean_pixel_value(file_path):
  """Return the mean grayscale intensity of the image at `file_path`.

  Raises:
    FileNotFoundError: if OpenCV cannot read the file. cv2.imread returns None rather
      than raising, which would otherwise surface as an unclear cvtColor error.
  """
  img = cv2.imread(file_path)
  if img is None:
    raise FileNotFoundError(f'Could not read image: {file_path}')
  img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  return np.mean(img)

# Measure each image's mean brightness into a 'mean_pixel' column and keep only the rows whose
# value is greater than min_value (used to remove dark images).
def set_mean_pixel_column(df, min_value=10):
  """Add `mean_pixel` to `df` (in place) and return the rows with mean_pixel > min_value, re-indexed."""
  df['mean_pixel'] = df['file_path'].apply(get_mean_pixel_value)
  bright_enough = df['mean_pixel'] > min_value
  return df.loc[bright_enough].reset_index(drop=True)

def sigmoid(x):
  """Logistic function 1 / (1 + exp(-x)), element-wise for scalars or NumPy arrays."""
  return np.reciprocal(1 + np.exp(-x))

In [None]:
# Training frame: load, assign stratified folds, then optionally drop dark images.
train_all_df = set_fold_column(load_data(train_csv_path, is_train=True), n_splits=kfold)
if mean_pixel_min_value > 0:
  train_all_df = set_mean_pixel_column(train_all_df, min_value=mean_pixel_min_value)

# Test frame: no grade column, so no labels or folds.
test_df = load_data(test_csv_path, is_train=False)

print(f'train shape : {train_all_df.shape}, test shape : {test_df.shape}')
display(test_df.head())
train_all_df.head()

train shape : (10000, 5), test shape : (8658, 2)


Unnamed: 0,imname,file_path
0,WuSUZJHN6t.jpg,./data/test/images/WuSUZJHN6t.jpg
1,hrua4NW4Cj.jpg,./data/test/images/hrua4NW4Cj.jpg
2,GDOHhHZJug.jpg,./data/test/images/GDOHhHZJug.jpg
3,Xewfe9T1kN.jpg,./data/test/images/Xewfe9T1kN.jpg
4,y3vLHbbHFs.jpg,./data/test/images/y3vLHbbHFs.jpg


Unnamed: 0,imname,grade,file_path,label,fold
0,cow_1++_4567.jpg,1++,./data/train/images/cow_1++_4567.jpg,0,1
1,cow_2_1390.jpg,2,./data/train/images/cow_2_1390.jpg,3,0
2,cow_1++_2581.jpg,1++,./data/train/images/cow_1++_2581.jpg,0,2
3,cow_2_1689.jpg,2,./data/train/images/cow_2_1689.jpg,3,1
4,cow_3_3287.jpg,3,./data/train/images/cow_3_3287.jpg,4,0


In [None]:
# Distinct grade strings present in the training labels.
train_all_df['grade'].unique()

array(['1++', '2', '3', '1+', '1'], dtype=object)

In [None]:
# Class balance of the encoded labels.
train_all_df['label'].value_counts()

2    2201
0    2134
1    2134
3    2090
4    1441
Name: label, dtype: int64

In [None]:
class BeefDataset(Dataset):
    """Image dataset backed by a DataFrame with a `file_path` column (and `label` unless mode == 'test').

    Args:
        df: source DataFrame; its index is reset so positional lookups work.
        mode: 'test' returns only the image tensor; any other value returns (image, label).
        transform: albumentations-style callable, called as transform(image=img)['image'].
    """
    def __init__(self, df, mode, transform=None):
        self.df = df.reset_index(drop=True)
        self.mode = mode
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        row = self.df.loc[index]
        img = cv2.imread(row.file_path)
        if img is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise FileNotFoundError(f'Could not read image: {row.file_path}')
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if self.transform is not None:
            img = self.transform(image=img)['image']
        # HWC -> CHW whether or not a transform ran, so the model always gets channels-first input.
        img = torch.tensor(img.transpose(2, 0, 1)).float()

        if self.mode == 'test':
            return img
        return img, torch.tensor(row.label)


In [None]:

class BeefModel(nn.Module):
    """Wrapper around a timm backbone whose classifier is sized for the beef grades.

    Args:
        backbone: timm model name; defaults to the notebook-level `backbone_name`
            (read when the model is constructed).
        n_classes: number of output classes; defaults to the notebook-level `num_classes`.
        pretrained: whether timm loads pretrained weights (True, as before).

    forward() returns whatever the timm model returns. NOTE(review): for distilled
    backbones such as pit_b_distilled_224 this may be a (cls_logits, dist_logits)
    tuple in train mode and a single logits tensor in eval mode, depending on the
    timm version.
    """
    def __init__(self, backbone=None, n_classes=None, pretrained=True):
        super().__init__()
        self.model = timm.create_model(
            backbone if backbone is not None else backbone_name,
            pretrained=pretrained,
            num_classes=n_classes if n_classes is not None else num_classes,
        )

    def forward(self, x):
        return self.model(x)
        

In [None]:
# Label-smoothing factor for the cross-entropy loss (0 = plain cross-entropy).
lsr = 0
model = BeefModel().to(device)
criterion = nn.CrossEntropyLoss(label_smoothing=lsr).to(device)
optimizer = optim.Adam(params=model.parameters(), lr=init_lr)

Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_b_distill_840.pth" to /root/.cache/torch/hub/checkpoints/pit_b_distill_840.pth


In [None]:
# print(model(torch.randn(3, 224, 224)))

In [None]:
# Training augmentation: flips, brightness and contrast jitter, shift/scale/rotate,
# then resize and albumentations' default Normalize (ImageNet mean/std).
transforms_train = albumentations.Compose([
    albumentations.VerticalFlip(p=0.5),
    albumentations.HorizontalFlip(p=0.5),
    # RandomBrightness / RandomContrast are deprecated aliases of RandomBrightnessContrast;
    # these two calls reproduce them exactly (each jitters one factor with its own p).
    albumentations.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0, p=0.75),
    albumentations.RandomBrightnessContrast(brightness_limit=0, contrast_limit=0.2, p=0.75),
    # border_mode=0 is cv2.BORDER_CONSTANT: exposed areas are filled with a constant value.
    albumentations.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=90, border_mode=0, p=1),
    albumentations.Resize(image_size, image_size),
    albumentations.Normalize()
])

# Validation / test: deterministic resize + the same normalization only.
transforms_valid = albumentations.Compose([
    albumentations.Resize(image_size, image_size),
    albumentations.Normalize()
])


In [None]:
# Fold `fold_id` is held out for validation; all other folds are used for training.
train_df = train_all_df.loc[train_all_df['fold'] != fold_id]
valid_df = train_all_df.loc[train_all_df['fold'] == fold_id]

train_dataset = BeefDataset(train_df, 'train', transform=transforms_train)
valid_dataset = BeefDataset(valid_df, 'valid', transform=transforms_valid)
test_dataset = BeefDataset(test_df, 'test', transform=transforms_valid)

# Only the training loader shuffles; valid/test keep row order so predictions line up with their frames.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=n_worker)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=n_worker)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=n_worker)

In [None]:
def train_epoch(train_loader):
    """Run one training epoch and return the mean batch loss.

    Uses the notebook-level `model`, `optimizer`, `criterion` and `device`.

    Distilled timm models (e.g. pit_b_distilled_224 under timm 0.5.4) return a
    (cls_logits, dist_logits) tuple in train mode but average both heads in eval
    mode. The loss is therefore applied to every returned head. Supervising only
    the distillation head would leave the class head at its random initialisation,
    while it still contributed half of every evaluation prediction. A model that
    returns a single tensor is handled as usual.
    """
    model.train()
    bar = tqdm(train_loader)
    losses = []
    for images, targets in bar:
        images, targets = images.to(device), targets.to(device)

        optimizer.zero_grad()
        output = model(images)
        if isinstance(output, (tuple, list)):
            # Average the per-head losses so the scale matches a single-head loss.
            loss = sum(criterion(head_out, targets) for head_out in output) / len(output)
        else:
            loss = criterion(output, targets)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

        bar.set_description(f'loss: {loss.item():.5f}')

    return np.mean(losses)

def evaluate(valid_loader):
    """Evaluate the notebook-level `model` on `valid_loader`, print and return the metrics.

    Returns:
        (mean_loss, acc): the mean of the per-batch `criterion` values as a float
        (the same number that is printed), and accuracy in percent over all samples.
    """
    total_loss = 0.0
    correct = 0
    model.eval()
    with torch.no_grad():
        for images, targets in tqdm(valid_loader):
            images, targets = images.to(device), targets.to(device)

            output = model(images)
            # .item() keeps a Python float instead of accumulating a device tensor.
            total_loss += criterion(output, targets).item()
            correct += (output.argmax(dim=1) == targets).sum().item()

    n_samples = len(valid_loader.dataset)
    mean_loss = total_loss / len(valid_loader)
    acc = 100 * correct / n_samples
    print('Valid set: Loss: {:.4f}, Accuracy: {}/{} {:.4f}%'.format(mean_loss, correct, n_samples, acc))
    return mean_loss, acc

In [None]:
# train
# Train for n_epochs and save a TorchScript checkpoint to Drive on every new best validation
# accuracy. A CPU copy of the best weights is also kept and loaded back into `model` when the
# loop ends, including when it is interrupted. Later cells therefore use the best epoch
# rather than the last one.
best_acc = 0
best_state = None
try:
  for epoch in range(1, n_epochs + 1):
    print(f'{epoch} Epoch')
    train_epoch(train_loader)
    val_loss, val_acc = evaluate(valid_loader)

    if best_acc < val_acc:
      best_acc = val_acc
      best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
      jit_model = torch.jit.script(model)
      torch.jit.save(jit_model, f'./drive/MyDrive/beef/{backbone_name}_{fold_id}_{kfold}_fold_epoch{epoch}_{val_acc:.3f}.pt')
      print('Model saved')
    print()
finally:
  if best_state is not None:
    model.load_state_dict(best_state)
    print(f'Restored best weights (valid acc {best_acc:.4f}%)')

1 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 1.2899, Accuracy: 1498/3333 44.9445%
Model saved

2 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 1.1836, Accuracy: 1660/3333 49.8050%
Model saved

3 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 1.1256, Accuracy: 1740/3333 52.2052%
Model saved

4 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 1.0580, Accuracy: 1823/3333 54.6955%
Model saved

5 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.9219, Accuracy: 2142/3333 64.2664%
Model saved

6 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.8172, Accuracy: 2314/3333 69.4269%
Model saved

7 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.8690, Accuracy: 2214/3333 66.4266%

8 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.6710, Accuracy: 2568/3333 77.0477%
Model saved

9 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.5654, Accuracy: 2693/3333 80.7981%
Model saved

10 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.5332, Accuracy: 2741/3333 82.2382%
Model saved

11 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.5592, Accuracy: 2688/3333 80.6481%

12 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.5159, Accuracy: 2771/3333 83.1383%
Model saved

13 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.5214, Accuracy: 2756/3333 82.6883%

14 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.5057, Accuracy: 2773/3333 83.1983%
Model saved

15 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.5277, Accuracy: 2751/3333 82.5383%

16 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.5011, Accuracy: 2787/3333 83.6184%
Model saved

17 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.5044, Accuracy: 2793/3333 83.7984%
Model saved

18 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4969, Accuracy: 2786/3333 83.5884%

19 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4830, Accuracy: 2798/3333 83.9484%
Model saved

20 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4715, Accuracy: 2834/3333 85.0285%
Model saved

21 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4847, Accuracy: 2793/3333 83.7984%

22 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4849, Accuracy: 2802/3333 84.0684%

23 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4991, Accuracy: 2756/3333 82.6883%

24 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4906, Accuracy: 2804/3333 84.1284%

25 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4526, Accuracy: 2842/3333 85.2685%
Model saved

26 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4399, Accuracy: 2850/3333 85.5086%
Model saved

27 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4678, Accuracy: 2816/3333 84.4884%

28 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4623, Accuracy: 2858/3333 85.7486%
Model saved

29 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4535, Accuracy: 2834/3333 85.0285%

30 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4727, Accuracy: 2819/3333 84.5785%

31 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4642, Accuracy: 2857/3333 85.7186%

32 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4571, Accuracy: 2831/3333 84.9385%

33 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4656, Accuracy: 2813/3333 84.3984%

34 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4206, Accuracy: 2876/3333 86.2886%
Model saved

35 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4769, Accuracy: 2797/3333 83.9184%

36 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4669, Accuracy: 2826/3333 84.7885%

37 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4737, Accuracy: 2814/3333 84.4284%

38 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4873, Accuracy: 2813/3333 84.3984%

39 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4442, Accuracy: 2851/3333 85.5386%

40 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4520, Accuracy: 2832/3333 84.9685%

41 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4354, Accuracy: 2866/3333 85.9886%

42 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4025, Accuracy: 2884/3333 86.5287%
Model saved

43 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4773, Accuracy: 2816/3333 84.4884%

44 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4543, Accuracy: 2847/3333 85.4185%

45 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4658, Accuracy: 2834/3333 85.0285%

46 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4376, Accuracy: 2850/3333 85.5086%

47 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4913, Accuracy: 2816/3333 84.4884%

48 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4660, Accuracy: 2824/3333 84.7285%

49 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4844, Accuracy: 2831/3333 84.9385%

50 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4776, Accuracy: 2817/3333 84.5185%

51 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4362, Accuracy: 2877/3333 86.3186%

52 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4681, Accuracy: 2826/3333 84.7885%

53 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4751, Accuracy: 2843/3333 85.2985%

54 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4292, Accuracy: 2883/3333 86.4986%

55 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4523, Accuracy: 2856/3333 85.6886%

56 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4257, Accuracy: 2897/3333 86.9187%
Model saved

57 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4581, Accuracy: 2855/3333 85.6586%

58 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4188, Accuracy: 2902/3333 87.0687%
Model saved

59 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4759, Accuracy: 2840/3333 85.2085%

60 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4743, Accuracy: 2812/3333 84.3684%

61 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4831, Accuracy: 2820/3333 84.6085%

62 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4441, Accuracy: 2879/3333 86.3786%

63 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4437, Accuracy: 2859/3333 85.7786%

64 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4543, Accuracy: 2858/3333 85.7486%

65 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4516, Accuracy: 2872/3333 86.1686%

66 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4228, Accuracy: 2900/3333 87.0087%

67 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4707, Accuracy: 2833/3333 84.9985%

68 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4426, Accuracy: 2859/3333 85.7786%

69 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4689, Accuracy: 2832/3333 84.9685%

70 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.3952, Accuracy: 2918/3333 87.5488%
Model saved

71 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4553, Accuracy: 2852/3333 85.5686%

72 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4152, Accuracy: 2916/3333 87.4887%

73 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4802, Accuracy: 2853/3333 85.5986%

74 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4592, Accuracy: 2860/3333 85.8086%

75 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4287, Accuracy: 2875/3333 86.2586%

76 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4491, Accuracy: 2868/3333 86.0486%

77 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4396, Accuracy: 2883/3333 86.4986%

78 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4270, Accuracy: 2907/3333 87.2187%

79 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4479, Accuracy: 2871/3333 86.1386%

80 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4321, Accuracy: 2883/3333 86.4986%

81 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4116, Accuracy: 2916/3333 87.4887%

82 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4834, Accuracy: 2823/3333 84.6985%

83 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4535, Accuracy: 2861/3333 85.8386%

84 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4253, Accuracy: 2904/3333 87.1287%

85 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4795, Accuracy: 2817/3333 84.5185%

86 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4729, Accuracy: 2834/3333 85.0285%

87 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4180, Accuracy: 2904/3333 87.1287%

88 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4130, Accuracy: 2880/3333 86.4086%

89 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4821, Accuracy: 2839/3333 85.1785%

90 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4341, Accuracy: 2879/3333 86.3786%

91 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.5633, Accuracy: 2729/3333 81.8782%

92 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4543, Accuracy: 2845/3333 85.3585%

93 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4259, Accuracy: 2902/3333 87.0687%

94 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4190, Accuracy: 2911/3333 87.3387%

95 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4287, Accuracy: 2886/3333 86.5887%

96 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4463, Accuracy: 2871/3333 86.1386%

97 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4172, Accuracy: 2900/3333 87.0087%

98 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4101, Accuracy: 2921/3333 87.6388%
Model saved

99 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4566, Accuracy: 2887/3333 86.6187%

100 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4204, Accuracy: 2905/3333 87.1587%

101 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4195, Accuracy: 2890/3333 86.7087%

102 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4421, Accuracy: 2884/3333 86.5287%

103 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4488, Accuracy: 2869/3333 86.0786%

104 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4360, Accuracy: 2884/3333 86.5287%

105 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4606, Accuracy: 2885/3333 86.5587%

106 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4224, Accuracy: 2912/3333 87.3687%

107 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4361, Accuracy: 2887/3333 86.6187%

108 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4838, Accuracy: 2845/3333 85.3585%

109 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4499, Accuracy: 2874/3333 86.2286%

110 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4188, Accuracy: 2863/3333 85.8986%

111 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4066, Accuracy: 2905/3333 87.1587%

112 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4079, Accuracy: 2910/3333 87.3087%

113 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4107, Accuracy: 2902/3333 87.0687%

114 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4321, Accuracy: 2889/3333 86.6787%

115 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4766, Accuracy: 2822/3333 84.6685%

116 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4253, Accuracy: 2877/3333 86.3186%

117 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4352, Accuracy: 2877/3333 86.3186%

118 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4128, Accuracy: 2902/3333 87.0687%

119 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4306, Accuracy: 2896/3333 86.8887%

120 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Valid set: Loss: 0.4711, Accuracy: 2850/3333 85.5086%

121 Epoch


  0%|          | 0/209 [00:00<?, ?it/s]

In [None]:
# test
# Predict a grade for every test image and write submit.csv in sample_submission's row order.
preds = []
model.eval()
with torch.no_grad():
    for images in tqdm(test_loader):
        logit = model(images.to(device))
        preds.extend(logit.argmax(dim=1).tolist())

grade_preds = [reversed_label_dict[pred] for pred in preds]
assert len(grade_preds) == len(test_df), 'expected one prediction per test image'

submission = pd.read_csv('./data/sample_submission.csv')
# Attach predictions by image name rather than by row position. The result then does not
# depend on sample_submission.csv and test_images.csv sharing the same length and index.
pred_by_id = pd.Series(grade_preds, index=test_df.imname.values)
submission['grade'] = submission['id'].map(pred_by_id)
assert submission['grade'].notna().all(), 'some submission ids have no prediction'
submission.to_csv('submit.csv', index=False)