# In [None]:
import random
import pandas as pd
import numpy as np
import os
import cv2
import timm
import math

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
import torchvision.models as models

from tqdm.auto import tqdm
from sklearn.metrics import accuracy_score
from datetime import date, datetime, timezone, timedelta
from torch.optim.lr_scheduler import _LRScheduler

from albumentations.core.transforms_interface import ImageOnlyTransform
from copy import deepcopy
    
# Calendar date of this run in ISO form (YYYY-MM-DD).
exp_day = date.today().isoformat()

# Wall-clock time of this run in Korea Standard Time (UTC+9),
# rendered as "MM-DD_HH:MM:SS" and used to tag the checkpoint directory.
KST = timezone(timedelta(hours=9))
time_record = datetime.now(KST)
now = time_record.strftime('%m-%d_%H:%M:%S')

# Silence every Python warning for the rest of the run.
import warnings
warnings.filterwarnings(action='ignore')


# The experiment is named after this script: the part of the file name
# that precedes the first ".py".
model_name = os.path.basename(__file__).split('.py')[0]
# Each run writes its checkpoints into its own time-stamped directory.
weight_save_path = f"checkpoints/{model_name}_{now}"
os.makedirs(weight_save_path, exist_ok=True)
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"weight save path:{weight_save_path}")

# Hyper-parameters shared by the rest of the script.
CFG = {
    'IMG_SIZE': 480,            # side length the images are resized to
    'EPOCHS': 60,               # number of training epochs
    'LEARNING_RATE': 0.000001,  # initial optimizer learning rate
    'BATCH_SIZE': 32,           # samples per batch for both data loaders
    'SEED': 41,                 # seed handed to seed_everything
}



# wandb.init(
#     project='02_포디블록구조추출AI경진대회',
#     name = f'{model_name}_{now}',
#     config = CFG
# )


class CosineAnnealingWarmUpRestarts(_LRScheduler):
    """Cosine-annealing learning-rate schedule with linear warm-up and restarts.

    Within one cycle of ``T_i`` steps the learning rate of every parameter
    group rises linearly from its base value to ``eta_max`` during the first
    ``T_up`` steps, then follows half a cosine wave back down to the base
    value.  When a cycle ends the schedule restarts and the peak is
    multiplied by ``gamma`` once per completed cycle.

    Args:
        optimizer: Wrapped optimizer whose parameter groups are updated.
        T_0: Length of the first cycle in steps (positive int).
        T_mult: Cycle-length growth factor (int >= 1).
        eta_max: Peak learning rate of the first cycle.
        T_up: Number of linear warm-up steps per cycle (non-negative int).
        gamma: Factor applied to the peak learning rate per completed cycle.
        last_epoch: Index of the last step already taken; ``-1`` starts fresh.

    Raises:
        ValueError: If ``T_0``, ``T_mult`` or ``T_up`` is not an int in its
            allowed range.
    """

    def __init__(self, optimizer, T_0, T_mult=1, eta_max=0.1, T_up=0, gamma=1., last_epoch=-1):
        # Check the type first so that a non-numeric argument produces the
        # intended ValueError instead of a TypeError from the comparison.
        if not isinstance(T_0, int) or T_0 <= 0:
            raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
        if not isinstance(T_mult, int) or T_mult < 1:
            raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
        if not isinstance(T_up, int) or T_up < 0:
            raise ValueError("Expected non-negative integer T_up, but got {}".format(T_up))
        self.T_0 = T_0
        self.T_mult = T_mult
        self.base_eta_max = eta_max  # peak of the very first cycle
        self.eta_max = eta_max       # peak of the current cycle
        self.T_up = T_up
        self.T_i = T_0               # length of the current cycle
        self.gamma = gamma
        self.cycle = 0               # number of completed cycles
        self.T_cur = last_epoch      # position inside the current cycle
        # All of the state above has to exist before the base constructor
        # runs, because the base constructor performs an initial step().
        super(CosineAnnealingWarmUpRestarts, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return one learning rate per parameter group for the current position."""
        if self.T_cur == -1:
            return self.base_lrs
        elif self.T_cur < self.T_up:
            # Linear warm-up from base_lr towards eta_max.
            return [(self.eta_max - base_lr)*self.T_cur / self.T_up + base_lr for base_lr in self.base_lrs]
        else:
            # Half-cosine decay from eta_max back down to base_lr.
            return [base_lr + (self.eta_max - base_lr) * (1 + math.cos(math.pi * (self.T_cur-self.T_up) / (self.T_i - self.T_up))) / 2
                    for base_lr in self.base_lrs]

    def step(self, epoch=None):
        """Advance the schedule and write the new learning rates to the optimizer.

        Args:
            epoch: ``None`` advances by one step from the current position.
                A number recomputes the cycle index and the position inside
                the cycle directly from that absolute epoch.
        """
        if epoch is None:
            epoch = self.last_epoch + 1
            self.T_cur = self.T_cur + 1
            if self.T_cur >= self.T_i:
                # Cycle finished: restart and stretch the post-warm-up part.
                self.cycle += 1
                self.T_cur = self.T_cur - self.T_i
                self.T_i = (self.T_i - self.T_up) * self.T_mult + self.T_up
        else:
            if epoch >= self.T_0:
                if self.T_mult == 1:
                    self.T_cur = epoch % self.T_0
                    self.cycle = epoch // self.T_0
                else:
                    n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult))
                    self.cycle = n
                    self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1)
                    self.T_i = self.T_0 * self.T_mult ** (n)
            else:
                self.T_i = self.T_0
                self.T_cur = epoch
                # An epoch inside the first cycle means no cycle has been
                # completed; without this reset a stale cycle count would
                # keep eta_max decayed after stepping back to an early epoch.
                self.cycle = 0

        self.eta_max = self.base_eta_max * (self.gamma**self.cycle)
        self.last_epoch = math.floor(epoch)
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr

class BaseModel(nn.Module):
    """Image classifier: a timm backbone followed by an element-wise sigmoid.

    Args:
        num_classes: Number of outputs of the backbone's classification head.
        backbone_name: Model name passed to ``timm.create_model``.  The
            created model must expose its head as ``.classifier`` with a
            ``weight`` attribute, because that weight is re-initialised here.
        pretrained: Forwarded to ``timm.create_model``; pass ``False`` to
            skip loading pretrained weights, e.g. when a checkpoint is
            loaded afterwards.
    """

    def __init__(self, num_classes=10, backbone_name='tf_efficientnetv2_m', pretrained=True):
        super(BaseModel, self).__init__()
        self.backbone = timm.create_model(backbone_name, pretrained=pretrained, num_classes=num_classes)
        # Re-initialise the classification head with Xavier-normal weights.
        nn.init.xavier_normal_(self.backbone.classifier.weight)

    def forward(self, x):
        """Return the backbone outputs squashed into (0, 1) by a sigmoid."""
        return torch.sigmoid(self.backbone(x))

class CustomDataset(Dataset):
    """Dataset that loads images from disk with OpenCV.

    Args:
        img_path_list: Indexable collection of image file paths.
        label_list: Indexable collection of per-image labels, or ``None``
            for unlabeled data (``__getitem__`` then returns the image only).
        transforms: Optional callable invoked as ``transforms(image=image)``
            that returns a mapping with an ``'image'`` entry.
    """

    def __init__(self, img_path_list, label_list, transforms=None):
        self.img_path_list = img_path_list
        self.label_list = label_list
        self.transforms = transforms

    def __getitem__(self, index):
        """Return ``(image, label)``, or just ``image`` when there are no labels.

        Raises:
            FileNotFoundError: If the image path does not point to a file.
            ValueError: If the file exists but OpenCV cannot decode it.
        """
        img_path = self.img_path_list[index]

        # cv2.imread signals failure by returning None instead of raising,
        # so fail loudly here and name the offending path.
        if not os.path.isfile(img_path):
            raise FileNotFoundError(f"Image file not found: {img_path}")
        image = cv2.imread(img_path)
        if image is None:
            raise ValueError(f"Could not decode image: {img_path}")

        if self.transforms is not None:
            image = self.transforms(image=image)['image']

        if self.label_list is None:
            return image

        label = torch.FloatTensor(self.label_list[index])
        return image, label

    def __len__(self):
        """Return the number of image paths."""
        return len(self.img_path_list)

def _crop_and_resize():
    """Return fresh transforms: center-crop to 384x384, then resize to IMG_SIZE."""
    side = CFG['IMG_SIZE']
    return [
        A.CenterCrop(height=384, width=384, p=1),
        A.Resize(side, side),
    ]


def _normalize_to_tensor():
    """Return fresh transforms: normalize with mean/std 0.5 per channel, then to tensor."""
    half = (0.5, 0.5, 0.5)
    return [
        A.Normalize(mean=half, std=half, max_pixel_value=255.0, always_apply=False, p=1.0),
        ToTensorV2(),
    ]


# Random augmentations applied only while training, between the shared
# crop/resize head and the shared normalize/tensor tail.
_train_augmentations = [
    A.Affine(
        p=1,
        translate_percent=[-0.1, 0.1],
        scale=[0.9, 1.1],
        shear=[-10, 10],
        interpolation=cv2.INTER_LINEAR,
        cval=(255, 255, 255),
    ),
    A.HorizontalFlip(always_apply=False, p=0.5),
    # A.RandomToneCurve(scale=0.1, always_apply=False, p=0.5),  # disabled
    A.ColorJitter(brightness=[0.9, 1.1], contrast=0.1, saturation=0.1, hue=0.01, p=0.5),
    A.GaussNoise(always_apply=False, p=0.5, var_limit=(0.0, 26.849998474121094)),
]

# Training pipeline: crop/resize -> random augmentations -> normalize/tensor.
train_transform = A.Compose(_crop_and_resize() + _train_augmentations + _normalize_to_tensor())

# Evaluation pipeline: the same deterministic steps without augmentation.
test_transform = A.Compose(_crop_and_resize() + _normalize_to_tensor())

def train(model, optimizer, train_loader, val_loader, scheduler, device):
    """Train ``model`` for ``CFG['EPOCHS']`` epochs and return it with its best weights.

    Each epoch runs one pass over ``train_loader`` with ``nn.BCELoss``, then
    evaluates on ``val_loader``, prints the metrics, steps the scheduler
    with the epoch number and saves a checkpoint to ``weight_save_path``.
    The weights of the epoch with the lowest mean training loss are kept
    and loaded back into the model before it is returned.

    Relies on the module-level names ``CFG``, ``weight_save_path`` and
    ``model_name``.

    Args:
        model: Network whose outputs are probabilities in [0, 1].
        optimizer: Optimizer that updates ``model``'s parameters.
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Iterable of ``(images, labels)`` validation batches.
        scheduler: Learning-rate scheduler stepped once per epoch, or ``None``.
        device: ``torch.device`` on which to train.

    Returns:
        The trained model (wrapped in ``nn.DataParallel`` when several GPUs
        are used) holding the best-epoch weights, or ``None`` if no epoch
        produced a comparable (non-NaN) training loss.
    """
    if (device.type == 'cuda') and (torch.cuda.device_count() > 1):
        print('Multi GPU activate')
        # NOTE: checkpoints saved from a DataParallel wrapper carry a
        # "module." prefix on every state-dict key.
        model = nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
    model.to(device)

    criterion = nn.BCELoss().to(device)

    # Start from +inf so the first epoch always counts as the best so far;
    # a small finite start value would leave the result at None whenever
    # the loss never dropped below it.
    best_train_loss = float('inf')
    best_state = None
    for epoch in range(1, CFG['EPOCHS'] + 1):
        model.train()
        train_loss = []
        for imgs, labels in tqdm(train_loader, desc=f"Train end epoch{CFG['EPOCHS']}"):
            imgs = imgs.float().to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            output = model(imgs)
            loss = criterion(output, labels)

            loss.backward()
            optimizer.step()

            train_loss.append(loss.item())

        _val_loss, _val_acc = validation(model, criterion, val_loader, device)
        _train_loss = np.mean(train_loss)
        print(f'Epoch [{epoch}], Train Loss : [{_train_loss:.5f}] Val Loss : [{_val_loss:.5f}] Val ACC : [{_val_acc:.5f}]')

        if scheduler is not None:
            scheduler.step(epoch)

        if best_train_loss > _train_loss:
            best_train_loss = _train_loss
            # Snapshot the weights on the CPU.  Keeping only a reference to
            # the live model would silently track the latest epoch instead.
            best_state = {
                name: tensor.detach().cpu().clone()
                for name, tensor in model.state_dict().items()
            }
        # A checkpoint is written every epoch, not only for the best one.
        torch.save(model.state_dict(), os.path.join(weight_save_path,f"{model_name}_epoch{str(epoch).zfill(2)}.pth"))

#         metrics = {
#             "train/train_loss" : _train_loss,
#             "train/lr" : optimizer.param_groups[0]['lr'],
#             "val/val_loss" : _val_loss,
#             "val/val_acc" : _val_acc
#         }
#         wandb.log(metrics)

#     wandb.alert(
#         title="Finish",
#         text=f"Finish training {model_name}"
#     )

    if best_state is None:
        return None
    model.load_state_dict(best_state)
    return model

def validation(model, criterion, val_loader, device, threshold=0.5):
    """Evaluate ``model`` on ``val_loader`` and return ``(loss, accuracy)``.

    Both metrics are averaged over the whole loader with every batch
    weighted by its size, so a shorter final batch does not count as much
    as a full one.

    Args:
        model: Network whose outputs are probabilities.
        criterion: Loss called as ``criterion(probs, labels)``.  The loss
            weighting assumes it returns the batch mean, which is the
            default reduction of the ``nn.BCELoss`` used by ``train``.
        val_loader: Iterable of ``(images, labels)`` batches.
        device: ``torch.device`` on which to run the model.
        threshold: Probability above which an output counts as positive.

    Returns:
        Tuple ``(mean_loss, accuracy)`` of floats, where accuracy is the
        fraction of label entries predicted correctly.  Both are NaN when
        the loader yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    n_samples = 0
    n_correct = 0
    n_entries = 0
    with torch.no_grad():
        for imgs, labels in tqdm(iter(val_loader)):
            imgs = imgs.float().to(device)
            labels = labels.to(device)

            probs = model(imgs)

            loss = criterion(probs, labels)

            batch_size = labels.shape[0]
            loss_sum += loss.item() * batch_size
            n_samples += batch_size

            probs = probs.cpu().detach().numpy()
            labels = labels.cpu().detach().numpy()
            preds = probs > threshold
            n_correct += int((labels == preds).sum())
            n_entries += labels.size

    if n_samples == 0 or n_entries == 0:
        return float('nan'), float('nan')

    _val_loss = loss_sum / n_samples
    _val_acc = n_correct / n_entries

    return _val_loss, _val_acc

def inference(model, test_loader, device, threshold=0.5):
    """Run ``model`` over ``test_loader`` and return thresholded predictions.

    Args:
        model: Network whose outputs are probabilities.
        test_loader: Iterable that yields batches of images only.
        device: ``torch.device`` on which to run the model.
        threshold: Probability above which an output is predicted as 1.

    Returns:
        A list with one entry per sample, in loader order; each entry is
        that sample's outputs converted to 0/1 integers.
    """
    model.to(device)
    model.eval()
    predictions = []
    with torch.no_grad():
        for imgs in tqdm(iter(test_loader)):
            imgs = imgs.float().to(device)

            probs = model(imgs).cpu().detach().numpy()

            preds = (probs > threshold).astype(int)
            predictions.extend(preds.tolist())
    return predictions

def get_labels(df):
    """Return every column from the third onward as a 2-D array.

    The first two columns of ``df`` are skipped; the rest are returned
    row by row.
    """
    label_frame = df.iloc[:, 2:]
    return label_frame.values

def seed_everything(seed):
    """Seed every random number generator used by the script.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA
    devices), and configures cuDNN for deterministic behaviour.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    # NOTE: this only affects interpreters started after this point (such as
    # subprocesses); hashing in the current process is already set up.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one, since training may
    # run on several devices.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN auto-tuner may pick different algorithms from run to run,
    # which defeats the determinism requested above, so keep it off.
    torch.backends.cudnn.benchmark = False

seed_everything(CFG['SEED'])  # fix the random seeds

# NOTE(review): everything below runs at import time and the loaders use
# worker processes; on platforms that spawn workers this usually needs an
# `if __name__ == "__main__":` guard -- confirm for the target platform.

_DATA_ROOT = "../data"


def _read_fold(csv_path):
    """Read a fold CSV and prefix its ``img_path`` column with the data root."""
    frame = pd.read_csv(csv_path)
    frame['img_path'] = frame['img_path'].apply(lambda x: os.path.join(_DATA_ROOT, x))
    return frame


train_df = _read_fold('../data/kfold_csv_1차학습_7/fold1_train.csv')

val_df = _read_fold('../data/kfold_csv_val_5/fold1_val.csv')
# Keep only the first 12 columns and cast columns 2..11 to int.
val_df = val_df[val_df.columns[:12]]
_val_label_cols = val_df.columns[2:12]
val_df[_val_label_cols] = val_df[_val_label_cols].astype('int')

train_labels = get_labels(train_df)
val_labels = get_labels(val_df)

_batch_size = CFG['BATCH_SIZE']

# Training data is augmented and shuffled.
train_dataset = CustomDataset(train_df['img_path'].values, train_labels, train_transform)
train_loader = DataLoader(train_dataset, batch_size=_batch_size, shuffle=True, num_workers=8)

# Validation data uses the test-time transform and keeps its order.
val_dataset = CustomDataset(val_df['img_path'].values, val_labels, test_transform)
val_loader = DataLoader(val_dataset, batch_size=_batch_size, shuffle=False, num_workers=8)


model = BaseModel()
optimizer = torch.optim.Adam(params=model.parameters(), lr=CFG["LEARNING_RATE"])
# 30-step cycles with 10 warm-up steps; the peak starts at 1e-4 and is
# multiplied by 0.7 after each completed cycle.
scheduler = CosineAnnealingWarmUpRestarts(
    optimizer,
    T_0=30,
    T_mult=1,
    eta_max=0.0001,
    T_up=10,
    gamma=0.7,
)

infer_model = train(model, optimizer, train_loader, val_loader, scheduler, device)


