## Import Library

In [None]:
import numpy as np 
import pandas as pd 

import os 
import cv2 
import timm 

import albumentations 
from albumentations.pytorch.transforms import ToTensorV2

import torch 
import torch.nn.functional as F 
from torch import nn 
from torch.optim import Adam

import math
import neptune
from sklearn.model_selection import StratifiedKFold

from tqdm.notebook import tqdm 
from sklearn.preprocessing import LabelEncoder

from torch.optim.lr_scheduler import StepLR, ExponentialLR, OneCycleLR, _LRScheduler, ReduceLROnPlateau

### Config

In [None]:
class Config:
    """Class-level constants read by the data, model and training code in this file."""

    # --- data locations / run bookkeeping ---
    DATA_DIR = '/mnt/hdd1/wearly/compatibility_rec/data/images/'  # joined with each row's image_name
    TRAIN_CSV = '/mnt/hdd1/wearly/deep_rec/separ_train.csv'       # table read by run_training
    SAVE_NAME = 'separ'                                           # prefix of the checkpoint folder name
    SEED = 225                                                    # RNG seed, also used for the fold split

    # --- image preprocessing ---
    IMG_SIZE = 512                # images are resized to IMG_SIZE x IMG_SIZE
    MEAN = [0.485, 0.456, 0.406]  # per-channel mean passed to Normalize
    STD = [0.229, 0.224, 0.225]   # per-channel std passed to Normalize

    # --- training loop ---
    EPOCHS = 50
    BATCH_SIZE = 16
    N_FOLDS = 10      # number of stratified folds; fold 0 is held out
    NUM_WORKERS = 4   # DataLoader worker processes
    DEVICE = 'cuda:1'

    # --- model / margin head ---
    MODEL_NAME = 'tf_efficientnet_b4'  # timm backbone identifier
    FC_DIM = 512                       # embedding size after the optional FC layer
    CLASSES = 352                      # output size of the margin head
    SCALE = 30                         # logit scale used by the margin head
    MARGIN = 0.5                       # angular margin (radians) used by the margin head

    # --- optimisation ---
    LR_START = 1e-5
    weight_decay = 0.0
    optimizer_name = 'adam'  # only used to build the checkpoint folder name

In [None]:
def seed_setting():
    """Select the configured CUDA device and seed the random number generators.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and CUDA) with
    ``Config.SEED`` and asks cuDNN / PyTorch for deterministic kernels.
    """
    import random  # local import: `random` is not among the file-level imports

    seed = Config.SEED
    torch.cuda.set_device(Config.DEVICE)

    # Seed the non-torch generators as well; presumably the augmentation
    # pipeline draws from them -- TODO confirm for the installed version.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # `torch.use_deterministic_algorithms` is a function. Assigning `True` to
    # it (as the previous code did) replaces the function with a bool and
    # enables nothing. `warn_only=True` keeps ops that have no deterministic
    # implementation usable (they warn instead of raising).
    try:
        torch.use_deterministic_algorithms(True, warn_only=True)
    except (AttributeError, TypeError):
        # Best effort: PyTorch builds without this function or without the
        # `warn_only` argument simply skip this extra guarantee.
        pass

In [None]:
def stratify_df(df):
    """Split ``df`` into train / validation frames using stratified folds.

    Every row is assigned a fold id in ``[0, Config.N_FOLDS)`` with
    ``StratifiedKFold`` stratified on the ``label_group`` column. Fold 0 is
    returned as the validation set and all other folds as the training set.

    Args:
        df: DataFrame containing a ``label_group`` column.

    Returns:
        ``(train, valid)`` DataFrames with reset indices and an added integer
        ``fold`` column. The input frame is not modified.
    """
    train_df = df.copy()
    train_df['fold'] = -1
    # Resolve the column position explicitly instead of assuming it is last.
    fold_col = train_df.columns.get_loc('fold')

    strat_kfold = StratifiedKFold(n_splits=Config.N_FOLDS, random_state=Config.SEED, shuffle=True)

    # split() yields (train_idx, test_idx); each row's fold id is the split in
    # which it appears on the held-out side.
    splits = strat_kfold.split(train_df.index, train_df['label_group'])
    for fold, (_, holdout_idx) in enumerate(splits):
        train_df.iloc[holdout_idx, fold_col] = fold

    train_df['fold'] = train_df['fold'].astype('int')

    # Previously a result was only returned when N_FOLDS == 10 (None
    # otherwise, which broke the caller's tuple unpacking). Holding out
    # fold 0 works for any fold count.
    is_valid = train_df['fold'] == 0
    train = train_df[~is_valid].reset_index(drop=True)
    valid = train_df[is_valid].reset_index(drop=True)

    return train, valid

### Neptune 

In [None]:
# neptune.init(project_qualified_name = "younghoon/Deep-rec",
#               api_token = os.environ["NEPTUNE_API_TOKEN"],  # credential removed from source; supply it via the environment
#               )
# #pass parameters to create experiment
# neptune.create_experiment(params=  None , name= Config.MODEL_NAME, description = f'train {Config.EPOCHS}'
#                           , tags=['efficientnet_b4','30epochs','one-cycle-lr'] )

### Dataset

In [None]:
class KfashionDataset(torch.utils.data.Dataset):
    """Dataset that loads an image and its ``label_group`` from a DataFrame.

    Each row must provide ``image_name`` (path relative to
    ``Config.DATA_DIR``) and ``label_group`` (value convertible by
    ``torch.tensor``).
    """

    def __init__(self, df, transform=None):
        """
        Args:
            df: DataFrame with ``image_name`` and ``label_group`` columns.
            transform: optional callable invoked as ``transform(image=...)``
                that returns a mapping with an ``'image'`` entry.
        """
        self.df = df
        self.root_dir = Config.DATA_DIR
        self.transform = transform

    def __len__(self):
        """Return the number of rows (samples)."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            dict with ``'image'`` (RGB image, transformed if a transform was
            given) and ``'label'`` (int64 tensor).

        Raises:
            FileNotFoundError: if the image file cannot be read.
        """
        row = self.df.iloc[idx]
        img_path = os.path.join(self.root_dir, row.image_name)

        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread signals failure by returning None; fail here with
            # the offending path instead of an opaque cvtColor error.
            raise FileNotFoundError(f"Could not read image file: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR
        label = row.label_group

        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']

        return {
            'image': image,
            'label': torch.tensor(label).long(),
        }

### Augmentation

In [None]:
def get_train_transforms():
    """Build the augmentation pipeline used for training images.

    Steps, in order: resize to ``Config.IMG_SIZE`` square, random horizontal
    and vertical flips, random rotation, random brightness change,
    normalisation with ``Config.MEAN`` / ``Config.STD``, conversion to tensor.
    """
    side = Config.IMG_SIZE
    steps = [
        albumentations.Resize(side, side, always_apply=True),
        albumentations.HorizontalFlip(p=0.5),
        albumentations.VerticalFlip(p=0.5),
        albumentations.Rotate(limit=120, p=0.8),
        albumentations.RandomBrightness(limit=(0.09, 0.6), p=0.5),
        albumentations.Normalize(mean=Config.MEAN, std=Config.STD),
        ToTensorV2(p=1.0),
    ]
    return albumentations.Compose(steps)

In [None]:
def get_valid_transforms():
    """Build the deterministic preprocessing pipeline for validation images.

    Resizes to ``Config.IMG_SIZE`` square, normalises and converts to tensor.
    No random augmentation is applied.
    """
    return albumentations.Compose(
        [
            albumentations.Resize(Config.IMG_SIZE, Config.IMG_SIZE, always_apply=True),
            # Use the same statistics as the training pipeline so the two
            # cannot drift apart if Config.MEAN / Config.STD are changed.
            albumentations.Normalize(mean=Config.MEAN, std=Config.STD),
            ToTensorV2(p=1.0),
        ]
    )

### Scheduler

### Optimizer

### Activation

In [None]:
#credit : https://github.com/tyunist/memory_efficient_mish_swish/blob/master/mish.py

''' I just wanted to understand and implement custom backward activation in PyTorch so I choose this.
    You can also simply use this function below too.

class Mish(nn.Module):
    def __init__(self):
        super(Mish, self).__init__()

    def forward(self, input):
        return input * (torch.tanh(F.softplus(input)))
'''

class Mish_func(torch.autograd.Function):
    """Mish activation, x * tanh(softplus(x)), with a hand-written backward.

    Only the input is saved for backward; the intermediate values are
    recomputed there.
    """

    @staticmethod
    def forward(ctx, i):
        # Keep just the input around; backward rebuilds everything else.
        ctx.save_for_backward(i)
        return i * torch.tanh(F.softplus(i))

    @staticmethod
    def backward(ctx, grad_output):
        (i,) = ctx.saved_tensors

        # h = softplus(x) = log(1 + exp(x))
        softplus_x = (1. + i.exp()).log()

        # d tanh(h) / dh = 1 / cosh(h)^2
        dtanh_dh = 1. / softplus_x.cosh().pow_(2)

        # dh / dx = sigmoid(x)
        dh_dx = i.sigmoid()

        # chain rule: d tanh(softplus(x)) / dx
        dtanh_dx = dtanh_dh * dh_dx

        # product rule on x * tanh(softplus(x))
        local_grad = torch.tanh(F.softplus(i)) + i * dtanh_dx

        return grad_output * local_grad


class Mish(nn.Module):
    """Module wrapper that applies ``Mish_func`` to its input."""

    def __init__(self, **kwargs):
        # Keyword arguments are accepted but not used.
        super().__init__()
        print("Mish initialized")

    def forward(self, input_tensor):
        activated = Mish_func.apply(input_tensor)
        return activated

In [None]:
def replace_activations(model, existing_layer, new_layer):
    """Recursively swap every sub-module of exact type ``existing_layer``.

    The module tree is modified in place and the same ``new_layer`` object is
    installed at every matching position. Returns ``model``.
    """
    # Walk the direct children last-to-first.
    for child_name, child in reversed(list(model._modules.items())):
        has_grandchildren = next(iter(child.children()), None) is not None
        if has_grandchildren:
            model._modules[child_name] = replace_activations(child, existing_layer, new_layer)

        # Exact type match only; subclasses of existing_layer are left alone.
        if type(child) == existing_layer:
            model._modules[child_name] = new_layer
    return model

### Modeling

In [None]:
class ArcMarginProduct(nn.Module):
    """Additive angular margin head: cos(theta + m) for the target class.

    Args:
        in_features: size of each input embedding.
        out_features: number of classes.
        scale: factor applied to the final logits.
        margin: angular margin ``m`` in radians added to the target angle.
        easy_margin: if True, apply the margin only where cos(theta) > 0;
            otherwise apply it where cos(theta) > cos(pi - m) and fall back
            to ``cos(theta) - sin(pi - m) * m`` elsewhere.
        ls_eps: label-smoothing epsilon applied to the one-hot targets.
    """

    def __init__(self, in_features, out_features, scale=30.0, margin=0.50, easy_margin=False, ls_eps=0.0):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.scale = scale
        self.margin = margin
        self.ls_eps = ls_eps  # label smoothing
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        self.cos_m = math.cos(margin)
        self.sin_m = math.sin(margin)
        self.th = math.cos(math.pi - margin)            # threshold on cos(theta)
        self.mm = math.sin(math.pi - margin) * margin   # fallback penalty

    def forward(self, input, label):
        """Return ``(scaled_logits, cross_entropy_loss)``.

        Args:
            input: embeddings, one row per sample.
            label: integer class index per sample, on the same device as
                ``input``.
        """
        # --------------------------- cos(theta) & phi(theta) ---------------------------
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        # Rounding can push cos^2 marginally above 1, which would make sqrt
        # return NaN, and sqrt has an infinite gradient at exactly 0. Clamping
        # to a tiny positive floor keeps both passes finite.
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(min=1e-7, max=1.0))
        phi = cosine * self.cos_m - sine * self.sin_m  # cos(theta + m)
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # --------------------------- convert label to one-hot ---------------------------
        # Allocate next to `cosine` (same device and dtype) rather than on a
        # globally configured device, so the head works wherever it is placed.
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        if self.ls_eps > 0:
            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.out_features
        # Target positions take phi, every other position keeps plain cosine.
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output = output * self.scale
        return output, F.cross_entropy(output, label)

class KfashionModel(nn.Module):
    """timm backbone -> average pooling -> optional FC + BatchNorm -> margin head.

    Args:
        n_classes: number of output classes of the margin head.
        model_name: timm model identifier. Supported: ``'resnext50_32x4d'``
            and any name containing ``'efficientnet'`` or ``'nfnet'``.
        fc_dim: embedding size produced by the FC layer when ``use_fc``.
        margin: angular margin passed to ``ArcMarginProduct``.
        scale: logit scale passed to ``ArcMarginProduct``.
        use_fc: whether to add Dropout -> Linear -> BatchNorm1d after pooling.
        pretrained: forwarded to ``timm.create_model``.

    Raises:
        ValueError: if ``model_name`` is not one of the supported families.
    """

    def __init__(
        self,
        n_classes = Config.CLASSES,
        model_name = Config.MODEL_NAME,
        fc_dim = Config.FC_DIM,
        margin = Config.MARGIN,
        scale = Config.SCALE,
        use_fc = True,
        pretrained = True):

        super(KfashionModel,self).__init__()
        print('Building Model Backbone for {} model'.format(model_name))

        self.backbone = timm.create_model(model_name, pretrained=pretrained)

        # Strip the backbone's own classifier and pooling so that it returns
        # a feature map; pooling is done by self.pooling below.
        if model_name == 'resnext50_32x4d':
            final_in_features = self.backbone.fc.in_features
            self.backbone.fc = nn.Identity()
            self.backbone.global_pool = nn.Identity()

        elif 'efficientnet' in model_name:
            final_in_features = self.backbone.classifier.in_features
            self.backbone.classifier = nn.Identity()
            self.backbone.global_pool = nn.Identity()

        elif 'nfnet' in model_name:
            final_in_features = self.backbone.head.fc.in_features
            self.backbone.head.fc = nn.Identity()
            self.backbone.head.global_pool = nn.Identity()

        else:
            # Without this branch an unknown name left final_in_features
            # unbound and failed later with an unrelated-looking error.
            raise ValueError(
                "Unsupported model_name {!r}: expected 'resnext50_32x4d' or a name "
                "containing 'efficientnet' or 'nfnet'".format(model_name)
            )

        self.pooling =  nn.AdaptiveAvgPool2d(1)

        self.use_fc = use_fc

        if use_fc:
            self.dropout = nn.Dropout(p=0.0)
            self.fc = nn.Linear(final_in_features, fc_dim)
            self.bn = nn.BatchNorm1d(fc_dim)
            self._init_params()
            final_in_features = fc_dim

        self.final = ArcMarginProduct(
            final_in_features,
            n_classes,
            scale = scale,
            margin = margin,
            easy_margin = False,
            ls_eps = 0.0
        )

    def _init_params(self):
        """Initialise the FC layer (Xavier normal, zero bias) and BatchNorm (1 / 0)."""
        nn.init.xavier_normal_(self.fc.weight)
        nn.init.constant_(self.fc.bias, 0)
        nn.init.constant_(self.bn.weight, 1)
        nn.init.constant_(self.bn.bias, 0)

    def forward(self, image, label):
        """Return whatever the margin head returns: ``(scaled_logits, loss)``."""
        feature = self.extract_feat(image)
        return self.final(feature, label)

    def extract_feat(self, x):
        """Compute one embedding vector per image (without the margin head)."""
        batch_size = x.shape[0]
        x = self.backbone(x)
        x = self.pooling(x).view(batch_size, -1)

        if self.use_fc:
            x = self.dropout(x)
            x = self.fc(x)
            x = self.bn(x)
        return x

### Training function

In [None]:
def train_fn(model, data_loader, optimizer, scheduler, i):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: module called as ``model(**batch)`` and returning ``(_, loss)``.
        data_loader: iterable of dict batches whose values support ``.to()``.
        optimizer: optimizer stepped once per batch.
        scheduler: accepted for interface compatibility; currently not used.
        i: zero-based epoch index, shown one-based in the progress bar.
    """
    model.train()
    running_loss = 0.0
    progress = tqdm(data_loader, desc="Epoch" + " [TRAIN] " + str(i + 1))

    for step, batch in enumerate(progress, start=1):
        # Move every tensor of the batch onto the configured device.
        batch = {key: value.to(Config.DEVICE) for key, value in batch.items()}

        optimizer.zero_grad()
        _, loss = model(**batch)
        batch_loss = loss.mean()
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()

        current_lr = optimizer.param_groups[0]['lr']
        progress.set_postfix({'loss': '%.6f' % float(running_loss / step), 'LR': current_lr})

    return running_loss / len(data_loader)

def eval_fn(model, data_loader, i):
    """Evaluate ``model`` on ``data_loader`` and return the mean per-batch loss.

    Runs in eval mode with gradient tracking disabled. ``i`` is the
    zero-based epoch index, shown one-based in the progress bar.
    """
    model.eval()
    running_loss = 0.0
    progress = tqdm(data_loader, desc="Epoch" + " [VALID] " + str(i + 1))

    with torch.no_grad():
        for step, batch in enumerate(progress, start=1):
            # Move every tensor of the batch onto the configured device.
            batch = {key: value.to(Config.DEVICE) for key, value in batch.items()}
            _, loss = model(**batch)
            running_loss += loss.mean().item()

            progress.set_postfix({'loss': '%.6f' % float(running_loss / step)})

    return running_loss / len(data_loader)

In [None]:
def run_training():
    """Train ``KfashionModel`` end to end using the settings in ``Config``.

    Reads ``Config.TRAIN_CSV``, splits it into train / validation frames,
    builds the data loaders, model and Adam optimizer, then runs
    ``Config.EPOCHS`` epochs. A checkpoint of the model state dict is written
    after every epoch (regardless of the loss value) into a folder named after
    the save name, model name, epoch count and optimizer name.
    """
    seed_setting()

    df = pd.read_csv(Config.TRAIN_CSV, index_col=0)
    df = df.reset_index(drop=True)

    train, valid = stratify_df(df)
    print(f"train shape : {train.shape}")
    print(f"validation shape : {valid.shape}")

    print(train.label_group.nunique())
    print(valid.label_group.nunique())

    # training data: shuffled, incomplete last batch dropped
    train_dataset = KfashionDataset(train, transform = get_train_transforms())
    trainloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size = Config.BATCH_SIZE,
        pin_memory = True,
        num_workers = Config.NUM_WORKERS,
        shuffle = True,
        drop_last = True
    )

    # validation data: fixed order, every sample kept
    valid_dataset = KfashionDataset(valid, transform = get_valid_transforms())
    validloader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size = Config.BATCH_SIZE,
        num_workers = Config.NUM_WORKERS,
        shuffle = False,
        pin_memory = True,
        drop_last = False
        )

    model = KfashionModel()
    model.to(Config.DEVICE)

    # Pass the configured weight decay so the Config value is actually
    # honoured (it is 0.0, Adam's default, so current runs are unchanged).
    optimizer = Adam(model.parameters(), lr = Config.LR_START, weight_decay = Config.weight_decay)

    save_dir = f'./{Config.SAVE_NAME}_{Config.MODEL_NAME}_{Config.EPOCHS}_{Config.optimizer_name}_Weights'

    if not os.path.exists(save_dir):
        print('Making Weights Folder')
        # exist_ok avoids a crash if the folder appears between check and creation
        os.makedirs(save_dir, exist_ok=True)

    for i in range(Config.EPOCHS):
        avg_loss_train = train_fn(model, trainloader, optimizer, None, i)
        avg_loss_valid = eval_fn(model, validloader, i)
        # Report the epoch averages, which were previously computed and dropped.
        print(f"Epoch {i + 1}/{Config.EPOCHS} - train loss: {avg_loss_train:.6f} - valid loss: {avg_loss_valid:.6f}")
        torch.save(model.state_dict(), f'{save_dir}/best_{i}EpochStep.pt')

In [None]:
# Start the full training run; this executes as soon as the cell/script runs.
run_training()