# Basic setup

In [None]:
pip install timm

In [1]:
import random
import pandas as pd
import numpy as np
import os
import glob
import cv2

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt

from tqdm import tqdm

from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix

import timm

import warnings
warnings.filterwarnings(action='ignore') 

In [2]:
# Flag for quick experiments; nothing in this section reads it.
debug = False

# Hyper-parameters shared by the cells below.
CFG = {
    "CLF_LR": 1e-4,      # learning rate of the case classifier
    "BATCH_SIZE": 64,    # samples per mini-batch
    "SEED": 42,          # seed handed to seed_everything()
    "EPOCHS": 100,       # number of training epochs
}

In [3]:
def seed_everything(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    exports ``PYTHONHASHSEED`` and configures cuDNN for deterministic runs.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # also cover multi-GPU setups
    torch.backends.cudnn.deterministic = True
    # The cuDNN auto-tuner may pick a different algorithm on each run, which
    # defeats `deterministic = True`; it has to be off for reproducibility.
    torch.backends.cudnn.benchmark = False

In [4]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Prepare Dataset

In [5]:
import re


def case_from_path(path):
    """Return the case label encoded in a SEM image path.

    The label is the tens digit of the number following ``Depth_`` in the
    path (e.g. ``Depth_110`` -> 1). Searching for the ``Depth_`` marker
    instead of slicing fixed character positions keeps the result correct
    for any root directory and either path separator.

    Args:
        path: Path of an image stored below a ``Depth_<number>`` directory.

    Returns:
        The integer case label.

    Raises:
        ValueError: If the path contains no ``Depth_<number>`` component.
    """
    match = re.search(r'Depth_(\d+)', path)
    if match is None:
        raise ValueError(f'No "Depth_<number>" component in path: {path!r}')
    return int(match.group(1)) // 10 % 10


# Collect every training SEM image (layout: <depth dir>/<site dir>/<image>.png).
train_data_path = glob.glob('./train/SEM/*/*/*.png')

df_train_SEM = pd.DataFrame({'path': train_data_path})
df_train_SEM['case'] = df_train_SEM['path'].apply(case_from_path)

In [6]:
# Preview the first rows to sanity-check the collected paths and case labels.
df_train_SEM.head()

Unnamed: 0,path,case
0,./train/SEM\Depth_110\site_00000\SEM_043510.png,1
1,./train/SEM\Depth_110\site_00000\SEM_043987.png,1
2,./train/SEM\Depth_110\site_00000\SEM_045397.png,1
3,./train/SEM\Depth_110\site_00000\SEM_046894.png,1
4,./train/SEM\Depth_110\site_00000\SEM_049394.png,1


# Case Classifier Dataset

In [7]:
class Classifier_Dataset(Dataset):
    """Dataset yielding (grayscale image, one-hot case label) pairs.

    Args:
        df: DataFrame whose first column holds the image path and whose
            second column holds the 1-based integer case label.
        num_classes: Number of distinct case labels (length of the one-hot
            vector). Defaults to 4.
    """

    def __init__(self, df, num_classes=4):
        self.df = df
        self.num_classes = num_classes
        # Row i of the identity matrix is the one-hot vector of class i;
        # build it once instead of on every __getitem__ call.
        self._one_hot = torch.eye(num_classes)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            A tuple ``(img, case)``: ``img`` is a float tensor with a leading
            channel axis and pixel values scaled to [0, 1]; ``case`` is the
            one-hot label vector of length ``num_classes``.

        Raises:
            FileNotFoundError: If the image cannot be read.
            ValueError: If the label is outside ``1..num_classes``.
        """
        path = self.df.iloc[idx, 0]
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise FileNotFoundError(f'Could not read image: {path}')
        img = torch.as_tensor(img / 255, dtype=torch.float32).unsqueeze(0)

        case = int(self.df.iloc[idx, 1])
        if not 1 <= case <= self.num_classes:
            # Without this check, label 0 would wrap around to the last class.
            raise ValueError(
                f'Case label {case} outside 1..{self.num_classes} for {path}'
            )
        return img, self._one_hot[case - 1]

# CNN Classifier Model

In [8]:
class Case_Classifier(nn.Module):
    """EfficientNet-B0 classifier over single-channel images.

    The timm backbone is created with one input channel and four output
    classes; ``forward`` returns class probabilities (softmax over dim 1).
    """

    def __init__(self):
        super().__init__()
        # tf_efficientnet_b0_ns
        self.model = timm.create_model(
            'tf_efficientnet_b0_ns', pretrained=True, num_classes=4, in_chans=1
        )
        # Explicit dim: the implicit-dimension form of Softmax is deprecated
        # and emits a warning on every call.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return per-class probabilities for the batch ``x``."""
        x = self.model(x)  # backbone output, one score per class
        return self.softmax(x)

# Training Classifier

In [None]:
from sklearn.model_selection import StratifiedKFold

seed_everything(CFG['SEED'])  # fix the seeds

# Stratified split so train and validation share the same case distribution.
kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=CFG['SEED'])
for i, [train_idx, val_idx] in enumerate(kf.split(df_train_SEM, df_train_SEM['case'])):
    df_train = df_train_SEM.iloc[train_idx]
    df_val = df_train_SEM.iloc[val_idx]

    cls_set = Classifier_Dataset(df_train)
    cls_val_set = Classifier_Dataset(df_val)
    cls_loader = DataLoader(cls_set, batch_size=CFG['BATCH_SIZE'], shuffle=True)
    # Sample order does not influence the validation metrics; keep it fixed.
    cls_val_loader = DataLoader(cls_val_set, batch_size=CFG['BATCH_SIZE'], shuffle=False)
    classifier = Case_Classifier()
    classifier.to(device)

    optimizer = torch.optim.AdamW(params=classifier.parameters(), lr=CFG['CLF_LR'])
    # Case_Classifier.forward already applies softmax, so the loss is the
    # negative log-likelihood of those probabilities. CrossEntropyLoss would
    # apply a second softmax on top of them.
    criterion = nn.NLLLoss()
    eps = 1e-12  # keeps log() finite when a probability underflows to 0

    # 'max' mode: halve the LR when validation accuracy stops improving.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=8, factor=0.5)

    best_acc = 0
    np.set_printoptions(precision=6, suppress=True)
    for epoch in range(CFG['EPOCHS']):
        train_losses = []
        val_losses = []

        classifier.train()
        for sem, case in tqdm(cls_loader):
            sem = sem.to(device)
            case = case.to(device)

            optimizer.zero_grad()
            pred = classifier(sem)
            # Loss signature is (input, target): predicted log-probabilities
            # first, then the class index recovered from the one-hot label.
            loss = criterion(torch.log(pred.clamp_min(eps)), case.argmax(dim=1))
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())

        classifier.eval()
        cm = np.zeros((4, 4))  # confusion matrix: rows = true, cols = predicted
        correct = 0

        with torch.no_grad():
            for sem, case in tqdm(cls_val_loader):
                sem = sem.to(device)
                case = case.to(device)

                pred = classifier(sem)
                case = case.argmax(dim=1)
                loss = criterion(torch.log(pred.clamp_min(eps)), case)
                val_losses.append(loss.item())

                pred = pred.argmax(dim=1)

                for l, p in zip(case.tolist(), pred.tolist()):
                    cm[l, p] += 1
                correct += (case == pred).count_nonzero().item()

        # Divide once at the end instead of summing per-batch fractions.
        accuracy = correct / len(cls_val_set)

        if best_acc < accuracy:
            torch.save(classifier.state_dict(), './cnn_classifier.pth')
            print('##########Model Saved!##########')
            best_acc = accuracy

        scheduler.step(accuracy)

        train_losses = np.mean(train_losses)
        val_losses = np.mean(val_losses)

        print(f'[EPOCH:{epoch+1}/{CFG["EPOCHS"]}] [Train Loss:{train_losses}] [Val Loss:{val_losses}] [Val Accuracy:{accuracy}]')

    # Only the first fold is trained: the 5-fold split serves as a single
    # stratified hold-out rather than a full cross-validation.
    break