In [25]:
# !pip install albumentations==0.5.2

In [1]:
import albumentations
import albumentations.pytorch
import numpy as np
import os
import pandas as pd
import matplotlib.pyplot as plt
import cv2
from PIL import Image
from tqdm import tqdm

import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms import Resize, ToTensor, Normalize
from sklearn.metrics import f1_score

tqdm.pandas()

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"{device} is using!")

cuda:0 is using!


In [3]:
NUM_EPOCHS = 15
BATCH_SIZE = 8
OUTPUT_CLASS = 18
LEARNING_RATE = 1e-4

In [4]:
class MyDataset(Dataset):
    def __init__(self, path, label, transform):
        img_list = []
        for p in tqdm(path):
            img = cv2.imread(p)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img_list.append(img)
        
        self.X = img_list
        self.y = label
        self.transform = transform

    def __len__(self):
        len_dataset = len(self.X)
        return len_dataset

    def __getitem__(self, idx):
        X,y = self.X[idx], self.y[idx]
        X = self.transform(image=X)
        return X, y

In [5]:
class TestDataset(Dataset):
    def __init__(self, path, label, transform):
        img_list = []
        for p in tqdm(path):
            img = cv2.imread(p)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img_list.append(img)
        
        self.X = img_list
        self.y = label
        self.transform = transform

    def __len__(self):
        len_dataset = len(self.X)
        return len_dataset

    def __getitem__(self, idx):
        X,y = self.X[idx], self.y[idx]
        X = self.transform(image=X)
        return X

In [6]:
train_transform = albumentations.Compose(
  [
      albumentations.Resize(600,600),
#       albumentations.RandomRotation(15),
      albumentations.HorizontalFlip(p=0.3),
      albumentations.OneOf([albumentations.MotionBlur(p=1),
                            albumentations.OpticalDistortion(p=1),
                            albumentations.GaussNoise(p=1)], p=1),
      albumentations.Normalize((0.548, 0.504, 0.479), (0.237, 0.247, 0.246)),
      albumentations.pytorch.transforms.ToTensorV2(),
      #       이미지 원본 사이즈는 384, 512   
  ]
)

test_transform = albumentations.Compose(
  [
      albumentations.Resize(600,600),
      albumentations.Normalize((0.548, 0.504, 0.479), (0.237, 0.247, 0.246)),
      albumentations.pytorch.transforms.ToTensorV2()
      #       이미지 원본 사이즈는 384, 512   
  ]
)

In [7]:
train_df = pd.read_csv('./train_df.csv')
valid_df = pd.read_csv('./valid_df.csv')

dataset_train = MyDataset(path=train_df['full_path'].values,
                          label=train_df['label'].values,
                          transform=train_transform)

dataset_valid = MyDataset(path=valid_df['full_path'].values,
                          label=valid_df['label'].values,
                          transform=test_transform)

train_dataloader = DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
valid_dataloader = DataLoader(dataset_valid, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

100%|██████████| 17010/17010 [00:48<00:00, 347.69it/s]
100%|██████████| 1890/1890 [00:05<00:00, 338.62it/s]


In [9]:
model = timm.create_model('efficientnet_b7', num_classes=18, pretrained=True).to(device)
model.load_state_dict(torch.load('./epoch/e7/epoch_9_e7_0.9216931216931217.pth'))

No pretrained weights exist for this model. Using random initialization.


<All keys matched successfully>

In [10]:
def check_f1_score(valid_dataloader, model, device):
    gt_list = []
   pred_list = []
    
    model.eval()
    with torch.no_grad():
        for x, y in tqdm(valid_dataloader):
            x = x['image'].to(device)
            y = y.to(device)
            
            logits = model(x)
            _, pred = torch.max(logits, 1)
            
            for i in y.cpu().numpy():
                gt_list.append(i)
            for j in pred.cpu().numpy():
                pred_list.append(j)
    
    f1 = f1_score(gt_list, pred_list, average='micro')
    del gt_list, pred_list
    
    print(f'Validation f1_score : {f1}')
    return f1

In [11]:
criterion = nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.3)

In [12]:
VIS = 600

for epoch in range(NUM_EPOCHS):
    running_loss = 0.0
    model.train()
    for i, (inputs, labels) in enumerate(tqdm(train_dataloader)):
        optimizer.zero_grad()
        inputs = inputs['image'].to(device)
        labels = labels.to(device)
        logits = model(inputs)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
            
        running_loss += loss.item()
        if i % VIS == VIS-1:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / VIS))
            running_loss = 0.0
            
    f1 = check_f1_score(valid_dataloader, model, device)
    print("-"*50)
    torch.save(model.state_dict(), f'./epoch/epoch_{epoch}_e7_{f1}.pth')
    
    if epoch < NUM_EPOCHS:
        scheduler.step()
#         torch.save(target_model.state_dict(), f'./epoch_{epoch}_vit.pth')

 28%|██▊       | 600/2127 [07:14<18:16,  1.39it/s]

[1,   600] loss: 0.379


 56%|█████▋    | 1200/2127 [14:25<11:04,  1.39it/s]

[1,  1200] loss: 0.369


 85%|████████▍ | 1800/2127 [21:35<03:53,  1.40it/s]

[1,  1800] loss: 0.377


100%|██████████| 2127/2127 [25:28<00:00,  1.39it/s]
  0%|          | 0/237 [00:00<?, ?it/s]

VALIDATION F1


100%|██████████| 237/237 [00:38<00:00,  6.22it/s]


Validation f1_score : 0.8899470899470899
--------------------------------------------------


 28%|██▊       | 600/2127 [07:11<18:13,  1.40it/s]

[2,   600] loss: 0.311


 56%|█████▋    | 1200/2127 [14:21<11:06,  1.39it/s]

[2,  1200] loss: 0.329


 85%|████████▍ | 1800/2127 [21:32<03:57,  1.38it/s]

[2,  1800] loss: 0.305


100%|██████████| 2127/2127 [25:28<00:00,  1.39it/s]
  0%|          | 0/237 [00:00<?, ?it/s]

VALIDATION F1


100%|██████████| 237/237 [00:38<00:00,  6.22it/s]


Validation f1_score : 0.8920634920634921
--------------------------------------------------


 28%|██▊       | 600/2127 [07:12<18:13,  1.40it/s]

[3,   600] loss: 0.149


 56%|█████▋    | 1200/2127 [14:24<11:04,  1.39it/s]

[3,  1200] loss: 0.115


 85%|████████▍ | 1800/2127 [21:35<03:54,  1.39it/s]

[3,  1800] loss: 0.106


100%|██████████| 2127/2127 [25:30<00:00,  1.39it/s]
  0%|          | 0/237 [00:00<?, ?it/s]

VALIDATION F1


100%|██████████| 237/237 [00:38<00:00,  6.21it/s]


Validation f1_score : 0.964021164021164
--------------------------------------------------


 28%|██▊       | 600/2127 [07:12<18:11,  1.40it/s]

[4,   600] loss: 0.082


 56%|█████▋    | 1200/2127 [14:24<11:04,  1.39it/s]

[4,  1200] loss: 0.095


 85%|████████▍ | 1800/2127 [21:35<03:54,  1.39it/s]

[4,  1800] loss: 0.083


100%|██████████| 2127/2127 [25:30<00:00,  1.39it/s]
  0%|          | 0/237 [00:00<?, ?it/s]

VALIDATION F1


100%|██████████| 237/237 [00:38<00:00,  6.23it/s]


Validation f1_score : 0.9624338624338624
--------------------------------------------------


 28%|██▊       | 600/2127 [07:11<18:12,  1.40it/s]

[5,   600] loss: 0.049


 56%|█████▋    | 1200/2127 [14:21<11:05,  1.39it/s]

[5,  1200] loss: 0.039


 85%|████████▍ | 1800/2127 [21:32<03:55,  1.39it/s]

[5,  1800] loss: 0.035


100%|██████████| 2127/2127 [25:26<00:00,  1.39it/s]
  0%|          | 0/237 [00:00<?, ?it/s]

VALIDATION F1


100%|██████████| 237/237 [00:38<00:00,  6.22it/s]


Validation f1_score : 0.9830687830687831
--------------------------------------------------


 28%|██▊       | 600/2127 [07:10<18:13,  1.40it/s]

[6,   600] loss: 0.031


 56%|█████▋    | 1200/2127 [14:22<11:04,  1.40it/s]

[6,  1200] loss: 0.028


 85%|████████▍ | 1800/2127 [21:33<04:02,  1.35it/s]

[6,  1800] loss: 0.026


100%|██████████| 2127/2127 [25:28<00:00,  1.39it/s]
  0%|          | 0/237 [00:00<?, ?it/s]

VALIDATION F1


100%|██████████| 237/237 [00:38<00:00,  6.23it/s]


Validation f1_score : 0.9793650793650793
--------------------------------------------------


 19%|█▊        | 396/2127 [04:45<20:48,  1.39it/s]


KeyboardInterrupt: 

In [14]:
tqdm.pandas()

def make_test_full_path(s):
    path = 'input/data/eval/images/'
    return path + s

test_df = pd.read_csv('input/data/eval/info.csv')
test_df['path'] = test_df['ImageID'].progress_apply(make_test_full_path)

dataset_test = TestDataset(path=test_df['path'].values,
                         label=test_df['ans'].values,
                         transform=test_transform)

test_dataloader = DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

  from pandas import Panel
100%|██████████| 12600/12600 [00:00<00:00, 606823.18it/s]
100%|██████████| 12600/12600 [00:34<00:00, 362.99it/s]


In [19]:
model = timm.create_model('efficientnet_b7', num_classes=18, pretrained=True).to(device)
model.load_state_dict(torch.load('./epoch/epoch_5_e7_0.9793650793650793.pth'))

No pretrained weights exist for this model. Using random initialization.


<All keys matched successfully>

In [20]:
def inferences(test_dataloader, submission_df):
    pred_list = []
    for images in tqdm(test_dataloader):
        model.eval()
        with torch.no_grad():
            images = images['image'].to(device)
            pred = model(images)
            pred = pred.argmax(dim=-1)
            pred_list.extend(pred.cpu().numpy())
    submission_df['ans'] = pred_list
    return submission_df

In [21]:
submission_df = pd.read_csv('input/data/eval/info.csv')
submission = inferences(test_dataloader, submission_df)
submission.head()

100%|██████████| 1575/1575 [04:15<00:00,  6.17it/s]


Unnamed: 0,ImageID,ans
0,cbc5c6e168e63498590db46022617123f1fe1268.jpg,14
1,0e72482bf56b3581c081f7da2a6180b8792c7089.jpg,2
2,b549040c49190cedc41327748aeb197c1670f14d.jpg,17
3,4f9cb2a045c6d5b9e50ad3459ea7b791eb6e18bc.jpg,14
4,248428d9a4a5b6229a7081c32851b90cb8d38d0c.jpg,12


In [22]:
submission.to_csv('./submission_0825_e7_15.csv', index=False)

In [None]:
# fig, axes = plt.subplots(5, 5, figsize=(30, 15))