In [None]:
from conf import *
from loader import *
from models import *
from trainer import *
from loss import *
from utils import *
from scheduler import *

import random

import os
import sys
import time
import numpy as np
import pandas as pd
import cv2
import PIL.Image

from tqdm.notebook import tqdm
from sklearn.metrics import roc_auc_score
import torch
from torch.utils.data import TensorDataset, DataLoader, Dataset

from transformers import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
from torch.utils.data.sampler import SubsetRandomSampler, RandomSampler, SequentialSampler

import albumentations as A
import geffnet

from sklearn.model_selection import StratifiedKFold

In [None]:
from conf import *

import torch
import random
import numpy as np
import os

from typing import Dict, Tuple, Any

from sklearn.metrics import roc_auc_score
from scipy.special import expit, softmax
from sklearn.metrics import precision_score


def set_seed(seed=0):
    """Seed the Python, NumPy and PyTorch (CPU + current CUDA device) generators.

    Also writes the seed to the ``PYTHONHASHSEED`` environment variable and
    configures cuDNN with the autotuner enabled and deterministic mode off.

    Args:
        seed: Integer seed shared by every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # One seed for every source of randomness the notebook touches.
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)

    # cuDNN flags: benchmark mode on, deterministic mode off.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = False
    cudnn.benchmark = True



def count_parameters(model):
    """Return the total element count of the model's parameters that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def global_average_precision_score(y_true, y_pred, ignore_non_landmarks=False):
    """Global average precision over one prediction per query.

    Predictions are ranked by confidence (highest first). At every ranked
    position the running precision is accumulated when the prediction is
    correct, and the total is divided by the number of queries whose true
    label is a real class.

    Args:
        y_true: Array of ground-truth class ids. The id ``args.n_classes`` is
            treated as the "no landmark" label.
        y_pred: Pair ``(predicted_ids, confidences)``, each indexable in
            parallel with ``y_true``.
        ignore_non_landmarks: When True, queries whose true label is
            ``args.n_classes`` are dropped from the ranking entirely.

    Returns:
        The score as a float; ``0.0`` when no query has a valid target (the
        original expression evaluated to ``1/0 * 0`` in that case).
    """
    pred_ids, pred_conf = y_pred[0], y_pred[1]
    n_classes = args.n_classes

    queries_with_target = (y_true < n_classes).sum()
    if queries_with_target == 0:
        # Nothing to score against; avoid dividing by zero.
        return 0.0

    # Indices of the predictions, most confident first.
    ranking = np.argsort(pred_conf)[::-1]

    correct_predictions = 0
    total_score = 0.
    rank = 1
    for k in ranking:
        if ignore_non_landmarks and y_true[k] == n_classes:
            continue
        # A "no landmark" prediction does not occupy a rank.
        if pred_ids[k] == n_classes:
            continue
        if y_true[k] == pred_ids[k]:
            correct_predictions += 1
            # precision@rank only contributes for correct predictions
            total_score += correct_predictions / rank
        rank += 1
    return 1 / queries_with_target * total_score

def comp_metric(y_true, logits, ignore_non_landmarks=False):
    """Competition metric: thin wrapper around ``global_average_precision_score``.

    ``logits`` is forwarded unchanged as the ``y_pred`` argument, i.e. it is
    expected to be the ``(predicted_ids, confidences)`` pair that function indexes.
    """
    return global_average_precision_score(
        y_true,
        logits,
        ignore_non_landmarks=ignore_non_landmarks,
    )

def cos_similarity_matrix(a, b, eps=1e-8):
    """Pairwise cosine similarity between the rows of ``a`` and the rows of ``b``.

    Row norms are floored at ``eps`` before dividing, so an all-zero row
    yields zeros rather than NaN.

    Returns:
        Tensor of shape ``(a.size(0), b.size(0))``.
    """
    def _unit_rows(mat):
        # L2 norm of each row, kept as a column so it broadcasts over the row.
        row_norm = mat.norm(dim=1, keepdim=True)
        return mat / torch.max(row_norm, torch.full_like(row_norm, eps))

    return torch.mm(_unit_rows(a), _unit_rows(b).t())

def get_topk_cossim(test_emb, tr_emb, batchsize=64, k=10, device='cuda:0', verbose=True):
    """Top-``k`` cosine-similarity neighbours in ``tr_emb`` for each row of ``test_emb``.

    Both inputs are copied to float32 tensors on ``device``; the test rows are
    processed in chunks of ``batchsize`` and the per-chunk results are moved
    back to the CPU before being concatenated.

    Args:
        test_emb: Query embeddings (anything ``torch.tensor`` accepts).
        tr_emb: Reference embeddings (anything ``torch.tensor`` accepts).
        batchsize: Number of query rows scored per chunk.
        k: Number of neighbours to keep per query.
        device: Device string used for the similarity computation.
        verbose: Accepted but not used by this function.

    Returns:
        ``(values, indices)`` CPU tensors, one row per query and ``k`` columns.
    """
    target = torch.device(device)
    tr_emb = torch.tensor(tr_emb, dtype=torch.float32, device=target)
    test_emb = torch.tensor(test_emb, dtype=torch.float32, device=target)

    top_vals, top_inds = [], []
    for chunk in test_emb.split(batchsize):
        similarities = cos_similarity_matrix(chunk, tr_emb)
        chunk_vals, chunk_inds = torch.topk(similarities, k=k, dim=1)
        top_vals.append(chunk_vals.detach().cpu())
        top_inds.append(chunk_inds.detach().cpu())

    return torch.cat(top_vals), torch.cat(top_inds)

In [None]:
import os
import cv2
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader 
import albumentations as A


class LMDataset(Dataset):
    """Image dataset driven by a DataFrame with ``filepath`` (and ``landmark_id``) columns.

    Each item is a dict with:
      * ``'idx'``    - the integer index as a long tensor,
      * ``'input'``  - the image as a float32 tensor, channels first,
      * ``'target'`` - the label as a long tensor (omitted when ``is_test``).

    Args:
        csv: DataFrame providing ``filepath`` and, unless ``is_test`` is set,
            ``landmark_id``.
        aug: Optional callable used as ``aug(image=img)['image']``
            (albumentations-style); skipped when falsy.
        normalization: ``'imagenet'``, ``'inception'`` or anything else for
            no rescaling; a falsy value skips the step entirely.
        is_test: When True no target is returned and the ``landmark_id``
            column is optional.
    """

    def __init__(self, csv, aug=None, normalization='simple', is_test=False):
        # Labels are only needed for training/validation; tolerate their
        # absence for pure inference frames.
        if is_test and 'landmark_id' not in csv.columns:
            self.labels = None
        else:
            self.labels = csv.landmark_id.values
        self.csv = csv.filepath.values
        self.aug = aug
        self.normalization = normalization
        self.is_test = is_test

    def __getitem__(self, index):
        img_path = self.csv[index]
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread returns None instead of raising for a missing or
            # unreadable file; fail here with the offending path.
            raise FileNotFoundError(f"could not read image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if self.aug:
            img = self.augment(img)
        img = img.astype(np.float32)

        if self.normalization:
            img = self.normalize_img(img)

        feature_dict = {'idx': torch.tensor(index).long(),
                        'input': self.to_torch_tensor(img)}
        if not self.is_test:
            # Convert straight to long: a float32 detour cannot represent
            # every integer id above 2**24.
            feature_dict['target'] = torch.tensor(self.labels[index]).long()
        return feature_dict

    def __len__(self):
        return len(self.csv)

    def augment(self, img):
        """Apply ``self.aug`` to an image array and return the result as float32."""
        img_aug = self.aug(image=img)['image']
        return img_aug.astype(np.float32)

    def normalize_img(self, img):
        """Rescale an H x W x 3 image according to ``self.normalization``.

        ``'imagenet'`` subtracts per-channel means and divides by per-channel
        stds expressed on the 0-255 scale; ``'inception'`` maps 0-255 to
        [-1, 1]; any other value returns the input unchanged.
        """
        if self.normalization == 'imagenet':
            mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)
            std = np.array([58.395, 57.120, 57.375], dtype=np.float32)
            img = img.astype(np.float32)  # astype copies, so the caller's array is untouched
            img -= mean
            img *= np.reciprocal(std, dtype=np.float32)
        elif self.normalization == 'inception':
            mean = np.array([0.5, 0.5, 0.5], dtype=np.float32)
            std = np.array([0.5, 0.5, 0.5], dtype=np.float32)
            img = img.astype(np.float32)
            img = img / 255.
            img = img - mean
            img = img * np.reciprocal(std, dtype=np.float32)
        return img

    def to_torch_tensor(self, img):
        """Convert an H x W x C array into a C x H x W tensor (shares memory with a transposed view)."""
        return torch.from_numpy(img.transpose((2, 0, 1)))


def clahe(img, clip_limit=2.0, tile_grid_size=(8, 8)):
    """Apply OpenCV's CLAHE to a uint8 image.

    A 2-D image is equalised directly. Any other image is converted
    RGB -> LAB, channel 0 is equalised, and the result is converted back to RGB.

    Args:
        img: uint8 array, either 2-D or an RGB image.
        clip_limit: Passed to ``cv2.createCLAHE`` as ``clipLimit``.
        tile_grid_size: Passed to ``cv2.createCLAHE`` as ``tileGridSize``.

    Raises:
        TypeError: If ``img`` is not of dtype uint8.
    """
    if img.dtype != np.uint8:
        raise TypeError("clahe supports only uint8 inputs")

    equalizer = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)

    is_single_channel = len(img.shape) == 2
    if is_single_channel:
        return equalizer.apply(img)

    lab = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)
    lab[:, :, 0] = equalizer.apply(lab[:, :, 0])
    return cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)

In [None]:
from conf import *

import torch
from torch import nn
import math


class FocalLoss(nn.Module):
    """Focal loss built on per-sample cross entropy, averaged over the batch.

    Each sample's cross entropy ``ce`` is scaled by ``(1 - exp(-ce)) ** gamma``;
    with ``gamma=0`` this reduces to the plain mean cross entropy.

    Args:
        gamma: Focusing exponent.
        eps: Stored on the module but not used by ``forward``.
    """

    def __init__(self, gamma=0, eps=1e-7):
        super().__init__()
        self.gamma = gamma
        self.eps = eps
        self.ce = torch.nn.CrossEntropyLoss(reduction="none")

    def forward(self, input, target):
        ce_per_sample = self.ce(input, target)
        # exp(-ce) is the softmax probability assigned to the true class.
        prob_true = torch.exp(-ce_per_sample)
        modulation = (1 - prob_true) ** self.gamma
        return (modulation * ce_per_sample).mean()


class ArcFaceLoss(nn.modules.Module):
    """Additive angular margin (ArcFace) loss computed from cosine logits.

    ``forward`` expects ``logits`` to already be cosines (one column per
    class). The target-class cosine is replaced by ``cos(theta + m)``, all
    values are scaled by ``s`` and handed to the configured criterion.

    Args:
        s: Scale applied to the margin-adjusted cosines. ``None`` makes the
            scale a learnable parameter initialised to 45.
        m: Angular margin in radians.
        crit: ``"bce"`` selects per-sample ``nn.CrossEntropyLoss`` (despite the
            name); ``"focal"`` selects ``FocalLoss`` with
            ``args.focal_loss_gamma``.
        weight: Optional 1-D tensor of per-class weights indexed by label.
        reduction: ``"mean"`` or ``"sum"``; any other value returns the
            criterion output unreduced. Ignored when ``weight`` is given.

    Raises:
        ValueError: If ``crit`` is neither ``"focal"`` nor ``"bce"``.
    """

    def __init__(self, s=45.0, m=0.1, crit="bce", weight=None, reduction="mean"):
        super().__init__()

        self.weight = weight
        self.reduction = reduction

        if crit == "focal":
            self.crit = FocalLoss(gamma=args.focal_loss_gamma)
        elif crit == "bce":
            self.crit = nn.CrossEntropyLoss(reduction="none")
        else:
            # Previously this left self.crit undefined and failed only on the
            # first forward() call.
            raise ValueError(f"unsupported crit {crit!r}; expected 'focal' or 'bce'")

        if s is None:
            # Learnable scale; fall back to CPU when no GPU is present.
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            self.s = torch.nn.Parameter(torch.tensor([45.], requires_grad=True, device=device))
        else:
            self.s = s

        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # Below this cosine the margin formula is swapped for a linear penalty.
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, logits, labels):
        cosine = logits
        # Rounding can push cos^2 marginally above 1; clamp so sqrt never
        # sees a negative number (which would turn the whole batch into NaN).
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(min=0.0))
        # cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m)
        phi = cosine * self.cos_m - sine * self.sin_m
        phi = torch.where(cosine > self.th, phi, cosine - self.mm)

        # One-hot mask: margin applies to the target class only.
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, labels.view(-1, 1).long(), 1)
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)

        output = output * self.s
        loss = self.crit(output, labels)

        if self.weight is not None:
            # Move the weight table to the labels' device before indexing so a
            # CPU table can be used with GPU labels.
            w = self.weight.to(logits.device)[labels]
            # Weighted average over the batch (normalised by the batch's
            # total weight).
            return (loss * w).sum() / w.sum()

        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            loss = loss.sum()
        return loss


def loss_fn(metric_crit, target_dict, output_dict, val=False):
    """Evaluate ``metric_crit`` on the samples whose target is a valid class id.

    Args:
        metric_crit: Callable taking ``(logits, targets)``.
        target_dict: Mapping with a ``'target'`` tensor.
        output_dict: Mapping with a ``'logits'`` tensor.
        val: Accepted but not used by this function.

    Returns:
        The criterion's output on the filtered samples, or a one-element zero
        tensor on the logits' device when no target is below ``args.n_classes``.
    """
    targets = target_dict['target']
    logits = output_dict['logits']

    # ignore invalid classes for val loss
    valid = targets < args.n_classes
    if not valid.any():
        return torch.zeros(1, device=logits.device)

    return metric_crit(logits[valid], targets[valid])



In [None]:
from torch.optim import lr_scheduler
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.optim.lr_scheduler import CosineAnnealingLR
from warmup_scheduler import GradualWarmupScheduler  # https://github.com/ildoonet/pytorch-gradual-warmup-lr


class GradualWarmupSchedulerV2(GradualWarmupScheduler):
    """Warm-up scheduler that overrides ``get_lr`` of its ``GradualWarmupScheduler`` base.

    Once ``last_epoch`` exceeds ``total_epoch`` the learning rates come from
    ``after_scheduler.get_last_lr()`` (or stay at ``base_lr * multiplier``
    when there is no follow-up scheduler); before that they ramp linearly.
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        super().__init__(optimizer, multiplier, total_epoch, after_scheduler)

    def get_lr(self):
        warmup_done = self.last_epoch > self.total_epoch

        if warmup_done and self.after_scheduler:
            if not self.finished:
                # First hand-over: the follow-up scheduler starts from the
                # warmed-up learning rates.
                self.after_scheduler.base_lrs = [lr * self.multiplier for lr in self.base_lrs]
                self.finished = True
            return self.after_scheduler.get_last_lr()

        if warmup_done:
            return [lr * self.multiplier for lr in self.base_lrs]

        if self.multiplier == 1.0:
            # Ramp from 0 up to base_lr.
            scale = float(self.last_epoch) / self.total_epoch
        else:
            # Ramp from base_lr up to base_lr * multiplier.
            scale = (self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.
        return [lr * scale for lr in self.base_lrs]


class GradualWarmupScheduler(_LRScheduler):
    """Gradually warm up (increase) the learning rate of an optimizer.

    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.

    NOTE(review): this notebook also imports a class with the same name from
    ``warmup_scheduler``; this definition replaces that name from here on.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        multiplier: Target learning rate = base lr * multiplier if
            multiplier > 1.0. If multiplier == 1.0, the lr starts from 0 and
            ends up at the base lr. Values below 1 raise ``ValueError``.
        total_epoch: Epoch at which the target learning rate is reached.
        after_scheduler: Scheduler to use once ``total_epoch`` has passed
            (e.g. ReduceLROnPlateau).
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        self.multiplier = multiplier
        if self.multiplier < 1.:
            raise ValueError('multiplier should be greater thant or equal to 1.')
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        # Flipped to True by get_lr() the first time control is handed to after_scheduler.
        self.finished = False
        # Attributes are set before the base initialiser runs.
        # NOTE(review): presumably because the base __init__ already calls step()/get_lr() - confirm.
        super(GradualWarmupScheduler, self).__init__(optimizer)

    def get_lr(self):
        """Return one learning rate per base lr for the current ``last_epoch``."""
        if self.last_epoch > self.total_epoch:
            # Warm-up is over.
            if self.after_scheduler:
                if not self.finished:
                    # One-time hand-over: rescale the follow-up scheduler's base lrs by the multiplier.
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_last_lr()
            # No follow-up scheduler: stay at the warmed-up rate.
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        if self.multiplier == 1.0:
            # Linear ramp from 0 to base_lr.
            return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
        else:
            # Linear ramp from base_lr to base_lr * multiplier.
            return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]

    def step_ReduceLROnPlateau(self, metrics, epoch=None):
        """Step variant used when ``after_scheduler`` is a ``ReduceLROnPlateau``.

        During warm-up the ramped learning rates are written directly into the
        optimizer's param groups; afterwards the call is forwarded to
        ``after_scheduler.step`` with the epoch shifted by ``total_epoch``.
        """
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch if epoch != 0 else 1  # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
        if self.last_epoch <= self.total_epoch:
            warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
            for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
                param_group['lr'] = lr
        else:
            # `epoch` was given a value at the top of this method, so the
            # `is None` branch below is never taken.
            if epoch is None:
                self.after_scheduler.step(metrics, None)
            else:
                self.after_scheduler.step(metrics, epoch - self.total_epoch)

    def step(self, epoch=None, metrics=None):
        """Advance the schedule, dispatching on the type of ``after_scheduler``."""
        if type(self.after_scheduler) != ReduceLROnPlateau:
            if self.finished and self.after_scheduler:
                # Warm-up already handed over: drive the follow-up scheduler directly.
                if epoch is None:
                    self.after_scheduler.step(None)
                else:
                    self.after_scheduler.step(epoch - self.total_epoch)
                # Mirror the follow-up scheduler's rates on this object.
                self._last_lr = self.after_scheduler.get_last_lr()
            else:
                # Still warming up (or no follow-up scheduler): use the base class step.
                return super(GradualWarmupScheduler, self).step(epoch)
        else:
            self.step_ReduceLROnPlateau(metrics, epoch)

In [None]:
from conf import *
from utils import *

#from pytorchcv.model_provider import get_model as ptcv_get_model
import timm
from torch import nn

import geffnet

import math
import torch
from torch.nn import functional as F
from torch.nn.parameter import Parameter

class ArcMarginProduct(nn.Module):
    """Cosine-similarity classification head.

    Holds one weight row per class; ``forward`` returns the cosine between
    each L2-normalised feature row and each L2-normalised weight row.

    Args:
        in_features: Size of the incoming feature vectors.
        out_features: Number of classes (rows of the weight matrix).
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.reset_parameters()

    def reset_parameters(self):
        """(Re-)initialise the class weights with Xavier-uniform values."""
        nn.init.xavier_uniform_(self.weight)

    def forward(self, features):
        unit_features = F.normalize(features)
        unit_weight = F.normalize(self.weight)
        return F.linear(unit_features, unit_weight)

def gem(x, p=3, eps=1e-6):
    """Generalised-mean pooling over the last two dimensions of ``x``.

    Values are floored at ``eps``, raised to the power ``p``, averaged over
    the full spatial window, and the ``p``-th root of that average is returned.
    """
    window = (x.size(-2), x.size(-1))
    powered = x.clamp(min=eps).pow(p)
    pooled = F.avg_pool2d(powered, window)
    return pooled.pow(1. / p)

class GeM(nn.Module):
    """Module wrapper around ``gem`` pooling with an optionally learnable exponent.

    Args:
        p: Initial (or fixed) pooling exponent.
        eps: Lower clamp passed through to ``gem``.
        p_trainable: When True ``p`` is stored as a one-element ``Parameter``;
            otherwise it is kept as the plain value that was passed in.
    """

    def __init__(self, p=3, eps=1e-6, p_trainable=True):
        super(GeM, self).__init__()
        if p_trainable:
            self.p = Parameter(torch.ones(1) * p)
        else:
            self.p = p
        self.eps = eps

    def forward(self, x):
        return gem(x, p=self.p, eps=self.eps)

    def __repr__(self):
        # self.p is a tensor only when trainable; a fixed exponent is a plain
        # number and has no `.data`, which used to make repr() raise.
        p_value = self.p.data.tolist()[0] if torch.is_tensor(self.p) else float(self.p)
        return f"{self.__class__.__name__}(p={p_value:.4f}, eps={self.eps})"

    
class Backbone(nn.Module):
    """Feature extractor built from a timm model.

    ``out_features`` is taken from the ``in_features`` of the model's
    classification layer; which attribute holds that layer is decided by
    substring match on the model name. ``forward`` returns the output of the
    model's ``forward_features``.

    Args:
        name: Model name passed to ``timm.create_model``.
        pretrained: Forwarded to ``timm.create_model``.
    """

    def __init__(self, name='resnet18', pretrained=True):
        super().__init__()
        self.net = timm.create_model(name, pretrained=pretrained)

        # (substring of the model name, accessor for its classification layer).
        # Order matters: the first matching substring wins.
        classifier_lookup = (
            ('regnet', lambda net: net.head.fc),
            ('csp', lambda net: net.head.fc),
            ('res', lambda net: net.fc),  # works also for resnest
            ('efficientnet', lambda net: net.classifier),
            ('densenet', lambda net: net.classifier),
            ('senet', lambda net: net.fc),
            ('inception', lambda net: net.last_linear),
        )
        for keyword, locate in classifier_lookup:
            if keyword in name:
                classifier = locate(self.net)
                break
        else:
            # No known family matched: fall back to `.classifier`.
            classifier = self.net.classifier
        self.out_features = classifier.in_features

    def forward(self, x):
        return self.net.forward_features(x)

    
class Net(nn.Module):
    """Backbone -> global pooling -> neck -> cosine head.

    Configuration is read from ``args``: ``backbone``, ``pool``,
    ``p_trainable``, ``embedding_size``, ``neck``, ``n_classes`` and
    ``pretrained_weights``.

    ``forward`` takes a dict with an ``'input'`` tensor and returns a dict
    with ``'logits'`` (the head's cosine output) and, when requested,
    ``'embeddings'`` (the neck output).
    """

    def __init__(self, args, pretrained=True):
        super().__init__()

        self.args = args
        self.backbone = Backbone(args.backbone, pretrained=pretrained)
        self.global_pool = self._make_pool(args)
        self.embedding_size = args.embedding_size

        # https://www.groundai.com/project/arcface-additive-angular-margin-loss-for-deep-face-recognition
        self.neck = self._make_neck(args.neck, self.backbone.out_features, self.embedding_size)
        self.head = ArcMarginProduct(self.embedding_size, args.n_classes)

        if args.pretrained_weights is not None:
            # strict=False: keys that do not match the current architecture are skipped.
            state = torch.load(args.pretrained_weights, map_location='cpu')
            self.load_state_dict(state, strict=False)
            print('weights loaded from', args.pretrained_weights)

    @staticmethod
    def _make_pool(args):
        """Build the pooling layer named by ``args.pool`` (average pooling by default)."""
        if args.pool == "gem":
            return GeM(p_trainable=args.p_trainable)
        if args.pool == "identity":
            return torch.nn.Identity()
        return nn.AdaptiveAvgPool2d(1)

    @staticmethod
    def _make_neck(kind, in_features, embedding_size):
        """Build the projection from backbone features to the embedding space."""
        if kind in ("option-D", "option-F"):
            # option-F is option-D preceded by dropout.
            layers = [nn.Dropout(0.3)] if kind == "option-F" else []
            layers += [
                nn.Linear(in_features, embedding_size, bias=True),
                nn.BatchNorm1d(embedding_size),
                torch.nn.PReLU(),
            ]
        else:
            layers = [
                nn.Linear(in_features, embedding_size, bias=False),
                nn.BatchNorm1d(embedding_size),
            ]
        return nn.Sequential(*layers)

    def forward(self, input_dict, get_embeddings=False, get_attentions=False):
        # `get_attentions` is accepted but not used.
        feature_map = self.backbone(input_dict['input'])

        # Keep the value at spatial position (0, 0) of the pooled map.
        pooled = self.global_pool(feature_map)[:, :, 0, 0]

        embeddings = self.neck(pooled)
        result = {'logits': self.head(embeddings)}
        if get_embeddings:
            result['embeddings'] = embeddings
        return result

In [None]:
from conf import *
from utils import *

from sklearn.utils.class_weight import compute_class_weight

import os
import sys
import time
import numpy as np
import pandas as pd
import cv2
import PIL.Image
from tqdm import tqdm
from sklearn.metrics import roc_auc_score

from sklearn.preprocessing import StandardScaler, MinMaxScaler, RobustScaler, QuantileTransformer

import torch
from torch.utils.data import TensorDataset, DataLoader,Dataset
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import albumentations as A
import geffnet

from loss import *

def optimizer_zero_grad(epoch, batch_idx, optimizer, optimizer_idx):
    """Clear gradients by setting ``.grad`` to ``None`` on every parameter of ``optimizer``.

    The original body iterated over ``self.model.parameters()``, but this is a
    free function with no ``self``, so any call raised ``NameError``. The
    parameters are now taken from the optimizer that is passed in.

    Args:
        epoch: Unused; kept for signature compatibility.
        batch_idx: Unused; kept for signature compatibility.
        optimizer: Optimizer whose parameters should have their gradients dropped.
        optimizer_idx: Unused; kept for signature compatibility.
    """
    for group in optimizer.param_groups:
        for param in group['params']:
            # None (rather than a zero tensor) so the next backward() allocates afresh.
            param.grad = None

def train_epoch(metric_crit, epoch, model, loader, optimizer):
    """Train ``model`` for one pass over ``loader`` and return the mean batch loss.

    For every batch the ``'input'`` and ``'target'`` tensors are moved to
    ``args.device``, the loss is computed through ``loss_fn`` and one
    optimizer step is taken. The progress bar shows the batch loss and the
    current ArcFace scale.

    Args:
        metric_crit: Loss module; its ``s`` attribute is shown in the progress bar.
        epoch: Current epoch number (not used inside this function).
        model: Network called as ``model(batch)``.
        loader: Iterable yielding batch dicts with ``'input'`` and ``'target'``.
        optimizer: Optimizer stepped once per batch.

    Returns:
        Mean of the per-batch loss values.

    Raises:
        RuntimeError: If a batch produces a non-finite loss. (The original
            called ``exit(1)``, which is not appropriate inside a notebook
            kernel and cannot be handled by the caller.)
    """
    model.train()
    train_loss = []
    bar = tqdm(loader)
    for batch in bar:
        batch['input'] = batch['input'].to(args.device)
        batch['target'] = batch['target'].to(args.device)

        optimizer.zero_grad()

        output_dict = model(batch)
        loss = loss_fn(metric_crit, batch, output_dict)
        if not torch.isfinite(loss):
            print('WARNING: non-finite loss, ending training ')
            print(loss, batch, output_dict, batch['input'].shape, batch['target'].shape)
            raise RuntimeError('non-finite loss encountered during training')

        # ArcFace scale for display only; read before the optimizer step.
        if args.arcface_s is None:
            # learnable one-element parameter -> plain float for '%.5f'
            s = float(metric_crit.s.detach().cpu())
        elif args.arcface_s == -1:
            s = 0
        else:
            s = metric_crit.s

        loss.backward()
        optimizer.step()

        loss_np = loss.detach().cpu().numpy()
        train_loss.append(loss_np)

        bar.set_description('loss: %.5f, arcface_s: %.5f' % (loss_np, s))

    return np.mean(train_loss)

def get_trans(img, I):
    """Return one of eight flip/transpose variants of a 4-D image batch.

    ``I >= 4`` first swaps dims 2 and 3; ``I % 4`` then selects which of
    dims 2 / 3 are flipped (0: none, 1: dim 2, 2: dim 3, 3: both). Returns
    ``None`` when ``I % 4`` matches none of 0-3.
    """
    if I >= 4:
        img = img.transpose(2, 3)

    # Position in this tuple == value of I % 4; entries are the dims to flip, in order.
    flip_plan = ((), (2,), (3,), (2, 3))
    mode = I % 4
    for candidate, axes in enumerate(flip_plan):
        if mode == candidate:
            for axis in axes:
                img = img.flip(axis)
            return img
    return None


def val_epoch(metric_crit_val, model, loader, n_test=1, get_output=False):
    """Run the model over ``loader`` without gradients and collect per-batch outputs.

    Args:
        metric_crit_val: Loss module passed to ``loss_fn`` for the validation loss.
        model: Network called as ``model(batch, get_embeddings=True)``.
        loader: Iterable of batch dicts with ``'idx'``, ``'input'`` and ``'target'``.
        n_test: Accepted but not used (test-time augmentation is not active).
        get_output: Accepted but not used.

    Returns:
        A list with one dict per batch holding ``'idx'``, ``'embeddings'``,
        ``'val_loss'`` (one-element tensor), ``'preds'`` (arg-max class),
        ``'preds_conf'`` (its softmax probability) and ``'targets'``.
    """
    model.eval()
    collected = []
    with torch.no_grad():
        for batch in tqdm(loader):
            for key in ('input', 'target'):
                batch[key] = batch[key].to(args.device)

            output_dict = model(batch, get_embeddings=True)
            loss = loss_fn(metric_crit_val, batch, output_dict, val=True)

            # Most likely class and its softmax probability per sample.
            class_probs = output_dict['logits'].softmax(1)
            preds_conf, preds = torch.max(class_probs, 1)

            collected.append({
                'idx': batch['idx'],
                'embeddings': output_dict['embeddings'],
                'val_loss': loss.view(1),
                'preds': preds,
                'preds_conf': preds_conf,
                'targets': batch['target'],
            })

    return collected

def val_end(val_outputs):
    """Aggregate the per-batch dicts from ``val_epoch`` into summary metrics.

    Every key is concatenated across batches and converted to a float32
    NumPy array before scoring.

    Args:
        val_outputs: Non-empty list of dicts as produced by ``val_epoch``.

    Returns:
        Dict with
          * ``'val_loss'`` - sum of the per-batch loss values,
          * ``'val_gap'`` - ``comp_metric`` over all queries,
          * ``'val_gap_landmarks'`` - ``comp_metric`` with
            ``ignore_non_landmarks=True`` (previously this was computed with
            the same arguments as ``'val_gap'`` and so always equalled it).
    """
    out_val = {
        key: torch.cat([batch_out[key] for batch_out in val_outputs])
        .detach().cpu().numpy().astype(np.float32)
        for key in val_outputs[0].keys()
    }

    targets = out_val["targets"]
    predictions = [out_val["preds"], out_val["preds_conf"]]

    val_score = comp_metric(targets, predictions)
    val_score_landmarks = comp_metric(targets, predictions, ignore_non_landmarks=True)

    # Sum (not mean) of the per-batch losses, as in the original.
    val_loss_total = np.sum(out_val["val_loss"])

    results = {'val_loss': val_loss_total,
               'val_gap': val_score,
               'val_gap_landmarks': val_score_landmarks,
               }

    return results

In [None]:
set_seed(args.seed)

# ---- data: fold assignment and image paths --------------------------------
train = pd.read_csv('../data/public/train.csv')
skf = StratifiedKFold(n_splits=args.n_splits, shuffle=True, random_state=args.seed)
train['fold'] = 0
# Each row's `fold` is the index of the split in which it is held out.
# (Loop names chosen so they do not collide with the `trn` / `val` frames below.)
for fold_idx, (_, holdout_rows) in enumerate(skf.split(train, train['landmark_id'])):
    train.loc[holdout_rows, 'fold'] = fold_idx
train['filepath'] = [
    os.path.join('../data/train', str(lm_id), str(img_id) + '.JPG')
    for lm_id, img_id in zip(train['landmark_id'], train['id'])
]

if args.class_weights == "log":
    # Inverse log-frequency weights, rescaled so they sum to n_classes.
    # NOTE(review): the loss indexes this table by label, which assumes the
    # landmark ids are exactly 0..K-1 - confirm against the data.
    val_counts = train.landmark_id.value_counts().sort_index().values
    class_weights = 1/np.log1p(val_counts)
    class_weights = (class_weights / class_weights.sum()) * args.n_classes
    class_weights = torch.tensor(class_weights, dtype=torch.float32)
else:
    class_weights = None

trn = train.loc[train['fold']!=args.fold].reset_index(drop=True)
val = train.loc[train['fold']==args.fold].reset_index(drop=True)

print(f'trn size : {trn.landmark_id.nunique()}, last batch size : {trn.shape[0]%args.batch_size}') #: 1049
# print(len(trn)) #: 70481
# image size : (540, 960, 3)

if args.DEBUG:
    trn = trn.iloc[:2500]
    val = val.iloc[:2500]

train_dataset = LMDataset(trn, aug=args.tr_aug, normalization=args.normalization)
valid_dataset = LMDataset(val, aug=args.val_aug, normalization=args.normalization)

train_loader = DataLoader(dataset=train_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True, pin_memory=True)
valid_loader = DataLoader(dataset=valid_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False, pin_memory=False)

# ---- model, losses, optimizer, schedule -----------------------------------
model = Net(args)
model = model.to(args.device)

metric_crit = ArcFaceLoss(args.arcface_s, args.arcface_m, crit=args.crit, weight=class_weights)
metric_crit_val = ArcFaceLoss(args.arcface_s, args.arcface_m, crit=args.crit, weight=None, reduction="sum")

# The loss module's parameters are optimised alongside the model's.
param_groups = [{'params': model.parameters()}, {'params': metric_crit.parameters()}]
if args.optim=='sgd':
    optimizer = torch.optim.SGD(param_groups, lr=args.lr, momentum=0.9, nesterov=True, weight_decay=args.weight_decay)
elif args.optim=='adamw':
    optimizer = torch.optim.AdamW(param_groups, lr=args.lr, weight_decay=args.weight_decay, amsgrad=False)
else:
    # Previously `optimizer` was simply left undefined and the failure
    # surfaced later as a NameError.
    raise ValueError(f"unsupported args.optim: {args.optim!r} (expected 'sgd' or 'adamw')")

scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.cosine_epo)
scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=args.warmup_epo, after_scheduler=scheduler_cosine)

# One no-op optimizer step before the first scheduler step.
# NOTE(review): presumably to silence PyTorch's "scheduler.step() before optimizer.step()" warning - confirm.
optimizer.zero_grad()
optimizer.step()

# ---- training loop --------------------------------------------------------
val_pp = 0.  # best validation GAP seen so far
model_file = f'../model/{args.backbone}_best_fold_{args.fold}.pth'
# NOTE(review): 49 extra epochs beyond warm-up + cosine are hard-coded; the reason is not visible here.
for epoch in range(1, args.cosine_epo+args.warmup_epo+1+49):

    scheduler_warmup.step(epoch-1)
    print(time.ctime(), 'Epoch:', epoch)

    train_loss = train_epoch(metric_crit, epoch, model, train_loader, optimizer)
    if epoch>1:  # validation is skipped after the first epoch
        val_outputs = val_epoch(metric_crit_val, model, valid_loader)
        # NOTE(review): written every validated epoch, not only on improvement, despite the "best" file name.
        np.save('../submit/val_outputs_best.npy', val_outputs)
        results = val_end(val_outputs)
        print(results)

        val_loss = results['val_loss']
        val_gap = results['val_gap']

        content = time.ctime() + ' ' + f'Fold {args.fold}, Epoch {epoch}, lr: {optimizer.param_groups[0]["lr"]:.7f}, train loss: {train_loss:.5f}, valid loss: {val_loss:.5f}, val_gap: {val_gap:.4f}'
        print(content)
        with open(f'../model/log_fold_{args.backbone}_{args.fold}.txt', 'a') as appender:
            appender.write(content + '\n')

        # Checkpoint whenever the validation GAP improves.
        val_gap_pp = val_gap
        if val_gap_pp > val_pp:
            print('val_gap_pp_max ({:.6f} --> {:.6f}). Saving model ...'.format(val_pp, val_gap_pp))
            torch.save(model.state_dict(), model_file)
            val_pp = val_gap_pp