# Import

In [1]:
import warnings
warnings.filterwarnings('ignore')

import logging
import logging.config
import random
import os
import time
import datetime
import math

import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.nn.modules.transformer import _get_activation_fn
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler, Sampler
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torchvision.ops import FeaturePyramidNetwork, DeformConv2d

import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2
from RandAugment import RandAugment

import pandas as pd
import numpy as np
from glob import glob
from matplotlib import pyplot as plt
from sklearn.metrics import f1_score
from sklearn.utils.class_weight import compute_class_weight
from sklearn.model_selection import train_test_split, KFold
from sklearn import preprocessing
from sklearn.impute import SimpleImputer

import torch_optimizer as optim
import ttach as tta

from tqdm.notebook import tqdm
from pprint import pprint

from robust_loss_pytorch import AdaptiveLossFunction
import robust_loss_pytorch

import neptune.new as neptune

# Hyperparameter

In [2]:
neptune_run = neptune.init(
    project="djlee/dacon-growing",
    api_token=os.environ["NEPTUNE_API_TOKEN"],
)

params = {
    "model_name": "tf_efficientnetv2_b0",
    "optimizer": "sgd_sam",
    "criterion": "l1",
    "huber_delta": 0,
    "scheduler": "cosineannealinglr",
    "num_class": 1,
    "epochs": 25,
    "batch": 64,
    "learning_rate": 16e-2,
    "weight_decay": 5e-2,
    "drop_out_rate": 0.4,
    "drop_path_rate": 0.15,
    "max_norm": 1,
    "num_workers": 10,
    "train_image_size": (224, 224),
    "test_image_size": (224, 224),
    "ib_start_epoch": 9999,
    "cutmix": False,
    "mixup": False,
    "mixup_alpha": 0.2,
    "mix_end_epoch": -1,
    "seed": 42,
    "scaler": "RobustScaler",
    
    # LSTM
    # "lstm_input_size": 18,
    # "lstm_hidden_size": 64,
    # "lstm_num_layers": 2,
    # "lstm_bidirectional": True,
    # "lstm_batch_first": True,
}

neptune_run["parameters"] = params

TRAIN_PATH = "data/train/"
TEST_PATH = "data/test/"
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

SAVE_PATH = f"models/{params['model_name']}_{datetime.datetime.now().strftime('%y-%m-%d-%H:%M:%S')}/"
os.makedirs(SAVE_PATH)

https://app.neptune.ai/djlee/dacon-growing/e/DAC-409
Remember to stop your run once you’ve finished logging your metadata (https://docs.neptune.ai/api-reference/run#.stop). It will be stopped automatically only when the notebook kernel/interactive console is terminated.


# Logger

In [3]:
config = {
    "version": 1,
    "formatters": {
        "simple": {"format": "[%(asctime)s] %(message)s", "datefmt": "%Y-%m-%d %H:%M:%S"},
    },
    "handlers": {
        "console": {
            "class": "logging.StreamHandler",
            "formatter": "simple",
            "level": "INFO",
        },
        "file": {
            "class": "logging.FileHandler",
            "filename": f"{SAVE_PATH}/train.log",
            "formatter": "simple",
            "level": "INFO",
        },
    },
    "root": {"handlers": ["console", "file"], "level": "INFO"},
    "loggers": {"parent": {"level": "INFO"}, "parent.child": {"level": "DEBUG"},},
}

logging.config.dictConfig(config)
logger = logging.getLogger()
logger.info(params)

[2022-05-20 16:46:15] {'model_name': 'tf_efficientnetv2_b0', 'optimizer': 'sgd_sam', 'criterion': 'l1', 'huber_delta': 0, 'scheduler': 'cosineannealinglr', 'num_class': 1, 'epochs': 25, 'batch': 64, 'learning_rate': 0.16, 'weight_decay': 0.05, 'drop_out_rate': 0.4, 'drop_path_rate': 0.15, 'max_norm': 1, 'num_workers': 10, 'train_image_size': (224, 224), 'test_image_size': (224, 224), 'ib_start_epoch': 9999, 'cutmix': False, 'mixup': False, 'mixup_alpha': 0.2, 'mix_end_epoch': -1, 'seed': 42, 'scaler': 'RobustScaler'}


# Fix Seed

In [4]:
random.seed(params["seed"])
np.random.seed(params["seed"])
os.environ["PYTHONHASHSEED"] = str(params["seed"])
torch.manual_seed(params["seed"])
torch.cuda.manual_seed(params["seed"])
torch.backends.cudnn.deterministic = False  # True 할 시 연산속도 감소. 마지막에 고정시킬 때 사용 권장.
torch.backends.cudnn.benchmark = True

# Class

## Mean, Std

In [5]:
class OnlineMeanStd:
    def __init__(self):
        pass

    def __call__(self, dataset, batch_size, method='strong', mode="train"):
        """
        Calculate mean and std of a dataset in lazy mode (online)
        On mode strong, batch size will be discarded because we use batch_size=1 to minimize leaps.
        :param dataset: Dataset object corresponding to your dataset
        :param batch_size: higher size, more accurate approximation
        :param method: weak: fast but less accurate, strong: slow but very accurate - recommended = strong
        :return: A tuple of (mean, std) with size of (3,)
        """

        if method == 'weak':
            loader = DataLoader(dataset=dataset,
                                batch_size=batch_size,
                                shuffle=False,
                                num_workers=1,
                                pin_memory=0)
            mean = 0.
            std = 0.
            nb_samples = 0.
            for data in loader:
                data = data[0]
                batch_samples = data.size(0)
                data = data.view(batch_samples, data.size(1), -1)
                mean += data.mean(2).sum(0)
                std += data.std(2).sum(0)
                nb_samples += batch_samples

            mean /= nb_samples
            std /= nb_samples

            return mean, std

        elif method == 'strong':
            loader = DataLoader(dataset=dataset,
                                batch_size=1,
                                shuffle=False,
                                num_workers=1,
                                pin_memory=0)
            cnt = 0
            fst_moment = torch.empty(3)
            snd_moment = torch.empty(3)

            for data in loader:
                if mode == "train":
                    data = data[0] # train -> data[0] / test -> data
                elif mode == "test":
                    data = data
                b, c, h, w = data.shape
                nb_pixels = b * h * w
                sum_ = torch.sum(data, dim=[0, 2, 3])
                sum_of_square = torch.sum(data ** 2, dim=[0, 2, 3])
                fst_moment = (cnt * fst_moment + sum_) / (cnt + nb_pixels)
                snd_moment = (cnt * snd_moment + sum_of_square) / (cnt + nb_pixels)

                cnt += nb_pixels

            return fst_moment, torch.sqrt(snd_moment - fst_moment ** 2)

## Dataset

In [6]:
class CustomDataset(Dataset):
    def __init__(self,
                 df,
                 mode,
                 transform,
                 scaler,
                 imputer,
                 mean=[0.485, 0.456, 0.406],
                 std=[0.229, 0.224, 0.225]):
        self.mean = mean
        self.std = std
        self.mode = mode
        self.transform = transform
        self.scaler = scaler
        self.imputer = imputer
        self.data_list = df
            
    def __len__(self):
        if self.mode == "train":
            return len(self.data_list)
        elif self.mode == "test":
            return len(self.data_list[0])

    def __getitem__(self, index):
        if self.mode == "train":
            image = cv2.imread(self.data_list[index][0])
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            label = np.array(self.data_list[index][2]).astype(np.float32)
            image = self.transform(image)["image"]
            meta = pd.read_csv(self.data_list[index][1])
            meta.drop(["시간"], axis=1, inplace=True)
            meta.interpolate(inplace=True)
            meta = self.imputer.transform(meta)
            meta = self.scaler.transform(meta)
            # meta = meta.mean(axis=1)
            return image, torch.from_numpy(meta), torch.from_numpy(label)

        elif self.mode == "test":
            image = cv2.imread(self.data_list[0][index])
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image = self.transform(image)["image"]
            meta = pd.read_csv(self.data_list[1][index])
            meta.drop(["시간"], axis=1, inplace=True)
            meta.interpolate(inplace=True)
            meta = self.imputer.transform(meta)
            meta = self.scaler.transform(meta)
            # meta = meta.mean(axis=1)
            return image, torch.from_numpy(meta)

## Sampler

In [7]:
class RandomCycleIter:
    
    def __init__ (self, data, test_mode=False):
        self.data_list = list(data)
        self.length = len(self.data_list)
        self.i = self.length - 1
        self.test_mode = test_mode
        
    def __iter__ (self):
        return self
    
    def __next__ (self):
        self.i += 1
        
        if self.i == self.length:
            self.i = 0
            if not self.test_mode:
                random.shuffle(self.data_list)
            
        return self.data_list[self.i]
    
def class_aware_sample_generator (cls_iter, data_iter_list, n, num_samples_cls=1, is_infinite=False):

    i = 0
    j = 0
    while i < n or is_infinite:
        
        if j >= num_samples_cls:
            j = 0
    
        if j == 0:
            temp_tuple = next(zip(*[data_iter_list[next(cls_iter)]]*num_samples_cls))
            yield temp_tuple[j]
        else:
            yield temp_tuple[j]
        
        i += 1
        j += 1

class ClassAwareSampler (Sampler):
    
    def __init__(self, data_source, num_samples_cls=1, is_infinite=False):
        num_classes = len(np.unique(data_source.labels))
        self.class_iter = RandomCycleIter(range(num_classes))
        cls_data_list = [list() for _ in range(num_classes)]
        for i, label in enumerate(data_source.labels):
            cls_data_list[label].append(i)
        self.data_iter_list = [RandomCycleIter(x) for x in cls_data_list]
        self.num_samples = max([len(x) for x in cls_data_list]) * len(cls_data_list)
        self.num_samples_cls = num_samples_cls

        self.is_infinite = is_infinite
        
    def __iter__ (self):
        return class_aware_sample_generator(self.class_iter, self.data_iter_list,
                                            self.num_samples, self.num_samples_cls, self.is_infinite)
    
    def __len__ (self):
        return self.num_samples
    
def get_sampler():
    return ClassAwareSampler

## Loss

### Cosine CrossEntropy Loss

In [8]:
class CosineCrossEntropyLoss(nn.Module):
    def __init__(self, xent=.1, reduction="mean", weight=None):
        super(CosineCrossEntropyLoss, self).__init__()
        self.xent = xent
        self.reduction = reduction
        self.weight = weight
        self.y = torch.Tensor([1]).cuda()
        
    def forward(self, input, target):
        cosine_loss = F.cosine_embedding_loss(input, F.one_hot(target, num_classes=input.size(-1)), self.y, reduction=self.reduction)
        cent_loss = F.cross_entropy(F.normalize(input), target, reduction=self.reduction, weight=self.weight)
        
        return cosine_loss + self.xent * cent_loss

### Cosine Focal Loss

In [9]:
class FocalCosineLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2, xent=.1, reduction="mean", weight=None):
        super(FocalCosineLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.xent = xent
        self.y = torch.Tensor([1]).cuda()
        self.reduction = reduction
        self.weight = weight

    def forward(self, input, target, reduction="mean"):
        cosine_loss = F.cosine_embedding_loss(input, F.one_hot(target, num_classes=input.size(-1)), self.y, reduction=self.reduction)

        cent_loss = F.cross_entropy(F.normalize(input), target, reduction=self.reduction, weight=self.weight)
        pt = torch.exp(-cent_loss)
        focal_loss = self.alpha * (1-pt)**self.gamma * cent_loss

        if self.reduction == "mean":
            focal_loss = torch.mean(focal_loss)

        return cosine_loss + self.xent * focal_loss

### Seesaw Loss

In [10]:
class SeesawLoss(torch.nn.Module):
    """
    Implementation of seesaw loss.
    Refers to `Seesaw Loss for Long-Tailed Instance Segmentation (CVPR 2021)
    <https://arxiv.org/abs/2008.10032>
    Args:
        num_classes (int): The number of classes.
                Default to 1000 for the ImageNet dataset.
        p (float): The ``p`` in the mitigation factor.
                Defaults to 0.8.
        q (float): The ``q`` in the compensation factor.
                Defaults to 2.0.
        eps (float): The min divisor to smooth the computation of compensation factor.
                Default to 1e-2.
    """

    def __init__(self, num_classes=params["num_class"],
                 p=0.8, q=1.0, eps=1e-2):
        super().__init__()
        self.num_classes = num_classes
        self.p = p
        self.q = q
        self.eps = eps

        # cumulative samples for each category
        self.register_buffer('accumulated',
                             torch.zeros(self.num_classes, dtype=torch.float))

    def forward(self, outputs, targets):
        # accumulate the samples for each category
        for unique in targets.unique():
            self.accumulated[unique] += (targets == unique.item()).sum()

        onehot_targets = F.one_hot(targets, self.num_classes)
        seesaw_weights = outputs.new_ones(onehot_targets.size())

        # mitigation factor
        if self.p > 0:
            matrix = self.accumulated[None, :].clamp(min=1) / self.accumulated[:, None].clamp(min=1)
            index = (matrix < 1.0).float()
            sample_weights = matrix.pow(self.p) * index + (1 - index)
            mitigation_factor = sample_weights[targets.long(), :]
            seesaw_weights = seesaw_weights * mitigation_factor

        # compensation factor
        if self.q > 0:
            scores = F.softmax(outputs.detach(), dim=1)
            self_scores = scores[torch.arange(0, len(scores)).to(scores.device).long(), targets.long()]
            score_matrix = scores / self_scores[:, None].clamp(min=self.eps)
            index = (score_matrix > 1.0).float()
            compensation_factor = score_matrix.pow(self.q) * index + (1 - index)
            seesaw_weights = seesaw_weights * compensation_factor

        outputs = outputs + (seesaw_weights.log() * (1 - onehot_targets))
        return F.cross_entropy(outputs, targets)

### Focal Loss

In [11]:
class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, label_smoothing=0.1, reduce=True):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.label_smoothing = label_smoothing
        self.reduce = reduce

    def forward(self, inputs, targets):
        ce_loss = nn.CrossEntropyLoss(label_smoothing=self.label_smoothing)(inputs, targets)
        pt = torch.exp(-ce_loss)
        F_loss = self.alpha * (1-pt)**self.gamma * ce_loss

        if self.reduce:
            return torch.mean(F_loss)
        else:
            return F_loss

### CDB Loss

In [12]:
def sigmoid(x):
    return (1/(1+np.exp(-x)))

class CDB_loss(nn.Module):
  
    def __init__(self, class_difficulty, tau='dynamic', reduction='none'):
        
        super(CDB_loss, self).__init__()
        self.class_difficulty = class_difficulty
        if tau == 'dynamic':
            bias = (1 - np.min(class_difficulty))/(1 - np.max(class_difficulty) + 0.01)
            tau = sigmoid(bias)
        else:
            tau = float(tau) 
        self.weights = self.class_difficulty ** tau
        self.weights = self.weights / self.weights.sum() * len(self.weights)
        self.reduction = reduction
        self.loss = nn.CrossEntropyLoss(weight=torch.FloatTensor(self.weights), reduction=self.reduction).cuda()
        
    def forward(self, input, target):
        return self.loss(input, target)

### CB Loss

In [13]:
def CB_loss(labels, logits, samples_per_cls, no_of_classes, loss_type, beta, gamma, device):
    effective_num = 1.0 - np.power(beta, samples_per_cls)
    weights = (1.0 - beta) / np.array(effective_num)
    weights = weights / np.sum(weights) * no_of_classes

    labels_one_hot = F.one_hot(labels, no_of_classes).float()
    labels_one_hot = labels_one_hot.to(device)

    weights = torch.tensor(weights).float()
    weights = weights.unsqueeze(0)
    weights = weights.repeat(labels_one_hot.shape[0],1).to(device)
    weights = weights* labels_one_hot
    weights = weights.sum(1)
    weights = weights.unsqueeze(1)
    weights = weights.repeat(1,no_of_classes)

    if loss_type == "focal":
        cb_loss = focal_loss(labels_one_hot, logits, weights, gamma)
    elif loss_type == "sigmoid":
        cb_loss = F.binary_cross_entropy_with_logits(input = logits,target = labels_one_hot, weights = weights)
    elif loss_type == "softmax":
        pred = logits.softmax(dim = 1)
        cb_loss = F.binary_cross_entropy(input = pred, target = labels_one_hot, weight = weights)
    return cb_loss

### Equalized Focal Loss

In [14]:
def equalized_focal_loss(logits,
                         targets,
                         gamma_b=2,
                         scale_factor=8,
                         reduction="mean"):
    """ EFL loss"""
    ce_loss = F.cross_entropy(logits, targets, reduction="none", label_smoothing=0.1)
    outputs = F.cross_entropy(logits, targets, label_smoothing=0.1)
    log_pt = -ce_loss
    pt = torch.exp(log_pt)

    targets = targets.view(-1, 1)
    grad_i = torch.autograd.grad(outputs=-outputs, inputs=logits)[0]
    grad_i = grad_i.gather(1, targets)
    pos_grad_i = F.relu(grad_i).sum()
    neg_grad_i = F.relu(-grad_i).sum()
    neg_grad_i += 1e-9
    grad_i = pos_grad_i / neg_grad_i
    grad_i = torch.clamp(grad_i, min=0, max=1)

    dy_gamma = gamma_b + scale_factor * (1 - grad_i)
    dy_gamma = dy_gamma.view(-1)
    # weighting factor
    wf = dy_gamma / gamma_b
    weights = wf * (1 - pt) ** dy_gamma

    efl = weights * ce_loss

    if reduction == "sum":
        efl = efl.sum()
    elif reduction == "mean":
        efl = efl.mean()
    else:
        raise ValueError(f"reduction '{reduction}' is not valid")
    return efl


def balanced_equalized_focal_loss(logits,
                                  targets,
                                  alpha_t=0.25,
                                  gamma_b=2,
                                  scale_factor=8,
                                  reduction="mean"):
    """balanced EFL loss"""
    return alpha_t * equalized_focal_loss(logits, targets, gamma_b,
                                          scale_factor, reduction)

### IB Loss

In [15]:
def ib_loss(input_values, ib):
    """Computes the focal loss"""
    loss = input_values * ib
    return loss.mean()

class IBLoss(nn.Module):
    def __init__(self, weight=None, alpha=10000.):
        super(IBLoss, self).__init__()
        assert alpha > 0
        self.alpha = alpha
        self.epsilon = 0.001
        self.weight = weight

    def forward(self, outputs, target):
        input, features = outputs
        grads = torch.sum(torch.abs(F.softmax(input, dim=1) - F.one_hot(target, NUM_CLASS)),1) # N * 1
        ib = grads*features.reshape(-1)
        ib = self.alpha / (ib + self.epsilon)
        return ib_loss(F.cross_entropy(input, target, reduction='none', weight=self.weight), ib)


def ib_focal_loss(input_values, ib, gamma):
    """Computes the ib focal loss"""
    p = torch.exp(-input_values)
    loss = (1 - p) ** gamma * input_values * ib
    return loss.mean()

class IB_FocalLoss(nn.Module):
    def __init__(self, weight=None, alpha=10000., gamma=0.):
        super(IB_FocalLoss, self).__init__()
        assert alpha > 0
        self.alpha = alpha
        self.epsilon = 0.001
        self.weight = weight
        self.gamma = gamma

    def forward(self, outputs, target):
        input, features = outputs
        grads = torch.sum(torch.abs(F.softmax(input, dim=1) - F.one_hot(target, NUM_CLASS)),1) # N * 1
        ib = grads*(features.reshape(-1))
        ib = self.alpha / (ib + self.epsilon)
        return ib_focal_loss(F.cross_entropy(input, target, reduction='none', weight=self.weight), ib, self.gamma) # weight=self.weight
    
    
class IB_CosineFocalLoss(nn.Module):
    def __init__(self, weight=None, alpha=10000., gamma=0.):
        super(IB_CosineFocalLoss, self).__init__()
        assert alpha > 0
        self.alpha = alpha
        self.epsilon = 0.001
        self.weight = weight
        self.gamma = gamma

    def forward(self, outputs, target):
        input, features = outputs
        grads = torch.sum(torch.abs(F.softmax(input, dim=1) - F.one_hot(target, NUM_CLASS)),1) # N * 1
        ib = grads*(features.reshape(-1))
        ib = self.alpha / (ib + self.epsilon)
        cosine_ce = CosineCrossEntropyLoss(reduction='none', weight=self.weight).cuda()
        return ib_focal_loss(cosine_ce(input, target), ib, self.gamma)

### Smooth Crossentropy Loss

In [16]:
class SmoothCrossentropy(nn.Module):
    def __init__(self, smoothing=0.1):
        super(SmoothCrossentropy, self).__init__()
        self.smoothing=smoothing

    def forward(self, pred, gold):
        n_class = pred.size(1)

        one_hot = torch.full_like(pred, fill_value=self.smoothing / (n_class - 1))
        one_hot.scatter_(dim=1, index=gold.unsqueeze(1), value=1.0 - self.smoothing)
        log_prob = F.log_softmax(pred, dim=1)

        return F.kl_div(input=log_prob, target=one_hot, reduction='none').sum(-1).mean()

### Blanced Softmax Loss

In [17]:
class BalancedSoftmax(nn.Module):
    """
    Balanced Softmax Loss
    """
    def __init__(self):
        super(BalancedSoftmax, self).__init__()

    def forward(self, input, label, reduction='mean'):
        return balanced_softmax_loss(input, label, reduction)


def balanced_softmax_loss(logits, labels, reduction="mean", weight=None):
    spc = torch.cuda.FloatTensor(SAMPLES_PER_CLS)
    spc = spc.unsqueeze(0).expand(logits.shape[0], -1)
    logits = logits + spc.log()
    loss = F.cross_entropy(input=logits, target=labels, reduction=reduction, label_smoothing=0.1)
    return loss

### ASL loss

In [18]:
class ASLSingleLabel(nn.Module):
    '''
    This loss is intended for single-label classification problems
    '''
    def __init__(self, gamma_pos=0, gamma_neg=4, eps=0.1, reduction='mean'):
        super(ASLSingleLabel, self).__init__()

        self.eps = eps
        self.logsoftmax = nn.LogSoftmax(dim=-1)
        self.targets_classes = []
        self.gamma_pos = gamma_pos
        self.gamma_neg = gamma_neg
        self.reduction = reduction

    def forward(self, inputs, target):
        '''
        "input" dimensions: - (batch_size,number_classes)
        "target" dimensions: - (batch_size)
        '''
        num_classes = inputs.size()[-1]
        log_preds = self.logsoftmax(inputs)
        self.targets_classes = torch.zeros_like(inputs).scatter_(1, target.long().unsqueeze(1), 1)

        # ASL weights
        targets = self.targets_classes
        anti_targets = 1 - targets
        xs_pos = torch.exp(log_preds)
        xs_neg = 1 - xs_pos
        xs_pos = xs_pos * targets
        xs_neg = xs_neg * anti_targets
        asymmetric_w = torch.pow(1 - xs_pos - xs_neg,
                                 self.gamma_pos * targets + self.gamma_neg * anti_targets)
        log_preds = log_preds * asymmetric_w

        if self.eps > 0:  # label smoothing
            self.targets_classes = self.targets_classes.mul(1 - self.eps).add(self.eps / num_classes)

        # loss calculation
        loss = - self.targets_classes.mul(log_preds)

        loss = loss.sum(dim=-1)
        if self.reduction == 'mean':
            loss = loss.mean()

        return loss

### Ldam Loss

In [19]:
class LDAMLoss(nn.Module):
    
    def __init__(self, cls_num_list, max_m=0.5, weight=None, s=30):
        super(LDAMLoss, self).__init__()
        m_list = 1.0 / np.sqrt(np.sqrt(cls_num_list))
        m_list = m_list * (max_m / np.max(m_list))
        m_list = torch.cuda.FloatTensor(m_list)
        self.m_list = m_list
        assert s > 0
        self.s = s
        self.weight = weight

    def forward(self, x, target):
        index = torch.zeros_like(x, dtype=torch.uint8)
        index.scatter_(1, target.data.view(-1, 1), 1)
        
        index_float = index.type(torch.cuda.FloatTensor)
        batch_m = torch.matmul(self.m_list[None, :], index_float.transpose(0,1))
        batch_m = batch_m.view((-1, 1))
        x_m = x - batch_m
    
        output = torch.where(index, x_m, x)
        return F.cross_entropy(self.s*output, target, weight=self.weight)

### Diverse Expert loss

In [20]:
# class DiverseExpertLoss(nn.Module):
#     def __init__(self, cls_num_list=None, tau=5):
#         super().__init__()
#         self.base_loss = F.cross_entropy
#         prior = np.array(cls_num_list) / np.sum(cls_num_list)
#         self.prior = torch.tensor(prior).float().cuda()
        
#         self.tau = tau 
#         self.cls_num_list = cls_num_list

#     def inverse_prior(self, prior): 
#         value, idx0 = torch.sort(prior)
#         _, idx1 = torch.sort(idx0)
#         idx2 = prior.shape[0]-1-idx1 # reverse the order
#         inverse_prior = value.index_select(0,idx2)
        
#         return inverse_prior

#     def forward(self, output_logits, target):
#         loss = 0
        
#         # Obtain logits from each expert  
#         expert1_logits = output_logits["base"]
#         # expert2_logits = output_logits["balance"]
#         expert3_logits = output_logits["inverse"]
 
#         # Softmax loss for expert 1 
#         # loss += LDAMLoss(cls_num_list=self.cls_num_list, weight=prior)(expert1_logits, target)
#         loss += self.base_loss(expert1_logits, target)
        
#         # Balanced Softmax loss for expert 2 
#         # expert2_logits = expert2_logits + torch.log(self.prior + 1e-9) 
#         # loss += LDAMLoss(cls_num_list=self.cls_num_list)(expert2_logits, target)
#         # loss += self.base_loss(expert2_logits, target)
        
#         # Inverse Softmax loss for expert 3
#         inverse_prior = self.inverse_prior(self.prior)
#         expert3_logits = expert3_logits + torch.log(self.prior + 1e-9) - self.tau * torch.log(inverse_prior+ 1e-9)
#         # loss += LDAMLoss(cls_num_list=self.cls_num_list)(expert3_logits, target)
#         loss += self.base_loss(expert3_logits, target)
   
#         return loss

### Wing Loss

In [21]:
class WingLoss(nn.Module):
    def __init__(self, omega=10, epsilon=2):
        super(WingLoss, self).__init__()
        self.omega = omega
        self.epsilon = epsilon

    def forward(self, pred, target):
        y = target
        y_hat = pred
        delta_y = (y - y_hat).abs()
        delta_y1 = delta_y[delta_y < self.omega]
        delta_y2 = delta_y[delta_y >= self.omega]
        loss1 = self.omega * torch.log(1 + delta_y1 / self.epsilon)
        C = self.omega - self.omega * math.log(1 + self.omega / self.epsilon)
        loss2 = delta_y2 - C
        return (loss1.sum() + loss2.sum()) / (len(loss1) + len(loss2))
    

class AdaptiveWingLoss(nn.Module):
    def __init__(self, omega=14, theta=0.5, epsilon=1, alpha=2.1):
        super(AdaptiveWingLoss, self).__init__()
        self.omega = omega
        self.theta = theta
        self.epsilon = epsilon
        self.alpha = alpha

    def forward(self, pred, target):
        '''
        :param pred: BxNxHxH
        :param target: BxNxHxH
        :return:
        '''

        y = target
        y_hat = pred
        delta_y = (y - y_hat).abs()
        delta_y1 = delta_y[delta_y < self.theta]
        delta_y2 = delta_y[delta_y >= self.theta]
        y1 = y[delta_y < self.theta]
        y2 = y[delta_y >= self.theta]
        loss1 = self.omega * torch.log(1 + torch.pow(delta_y1 / self.omega, self.alpha - y1))
        A = self.omega * (1 / (1 + torch.pow(self.theta / self.epsilon, self.alpha - y2))) * (self.alpha - y2) * (
            torch.pow(self.theta / self.epsilon, self.alpha - y2 - 1)) * (1 / self.epsilon)
        C = self.theta * A - self.omega * torch.log(1 + torch.pow(self.theta / self.epsilon, self.alpha - y2))
        loss2 = A * delta_y2 - C
        return (loss1.sum() + loss2.sum()) / (len(loss1) + len(loss2))


class DiverseExpertLoss(nn.Module):
    def __init__(self, cls_num_list=None, tau=5):
        super().__init__()
        self.base_loss = WingLoss().cuda()

    def forward(self, output_logits, target):
        loss = 0
        
        # Obtain logits from each expert  
        expert1_logits = output_logits["image"].squeeze(1)
        expert2_logits = output_logits["meta"].squeeze(1)
 
        # wing loss for expert 1
        loss += self.base_loss(expert1_logits, target)
        
        # wing loss for expert 2
        loss += self.base_loss(expert2_logits, target)
        return loss

## Optimizer

### SAM

In [22]:
class SAM(torch.optim.Optimizer):
    def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
        super(SAM, self).__init__(params, defaults)

        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)

            for p in group["params"]:
                if p.grad is None: continue
                self.state[p]["old_p"] = p.data.clone()
                e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
                p.add_(e_w)  # climb to the local maximum "w + e(w)"

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def second_step(self, scaler=None, zero_grad=False):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None: continue
                p.data = self.state[p]["old_p"]  # get back to "w" from "w + e(w)"

        if scaler:
            scaler.step(self.base_optimizer)
            scaler.update()
        else:
            self.base_optimizer.step()  # do the actual "sharpness-aware" update

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
        closure = torch.enable_grad()(closure)  # the closure should do a full forward-backward pass

        self.first_step(zero_grad=True)
        closure()
        self.second_step()

    def _grad_norm(self):
        shared_device = self.param_groups[0]["params"][0].device  # put everything on the same device, in case of model parallelism
        norm = torch.norm(
                    torch.stack([
                        ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
                        for group in self.param_groups for p in group["params"]
                        if p.grad is not None
                    ]),
                    p=2
               )
        return norm
    
    def load_state_dict(self, state_dict):
        super().load_state_dict(state_dict)
        self.base_optimizer.param_groups = self.param_groups
        

def disable_running_stats(model):
    def _disable(module):
        if isinstance(module, nn.BatchNorm2d):
            module.backup_momentum = module.momentum
            module.momentum = 0

    model.apply(_disable)

def enable_running_stats(model):
    def _enable(module):
        if isinstance(module, nn.BatchNorm2d) and hasattr(module, "backup_momentum"):
            module.momentum = module.backup_momentum

    model.apply(_enable)

### CosineAnnealingWarmUpRestarts

In [23]:
class CosineAnnealingWarmUpRestarts(torch.optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, T_0, T_mult=1, eta_max=0.1, T_up=0, gamma=1., last_epoch=-1):
        if T_0 <= 0 or not isinstance(T_0, int):
            raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
        if T_mult < 1 or not isinstance(T_mult, int):
            raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
        if T_up < 0 or not isinstance(T_up, int):
            raise ValueError("Expected positive integer T_up, but got {}".format(T_up))
        self.T_0 = T_0
        self.T_mult = T_mult
        self.base_eta_max = eta_max
        self.eta_max = eta_max
        self.T_up = T_up
        self.T_i = T_0
        self.gamma = gamma
        self.cycle = 0
        self.T_cur = last_epoch
        super(CosineAnnealingWarmUpRestarts, self).__init__(optimizer, last_epoch)
    
    def get_lr(self):
        if self.T_cur == -1:
            return self.base_lrs
        elif self.T_cur < self.T_up:
            return [(self.eta_max - base_lr)*self.T_cur / self.T_up + base_lr for base_lr in self.base_lrs]
        else:
            return [base_lr + (self.eta_max - base_lr) * (1 + math.cos(math.pi * (self.T_cur-self.T_up) / (self.T_i - self.T_up))) / 2
                    for base_lr in self.base_lrs]

    def step(self, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
            self.T_cur = self.T_cur + 1
            if self.T_cur >= self.T_i:
                self.cycle += 1
                self.T_cur = self.T_cur - self.T_i
                self.T_i = (self.T_i - self.T_up) * self.T_mult + self.T_up
        else:
            if epoch >= self.T_0:
                if self.T_mult == 1:
                    self.T_cur = epoch % self.T_0
                    self.cycle = epoch // self.T_0
                else:
                    n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult))
                    self.cycle = n
                    self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1)
                    self.T_i = self.T_0 * self.T_mult ** (n)
            else:
                self.T_i = self.T_0
                self.T_cur = epoch
                
        self.eta_max = self.base_eta_max * (self.gamma**self.cycle)
        self.last_epoch = math.floor(epoch)
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr
        print("learning rate:", lr)
        neptune_run["train/learning_rate"].log(lr)

## AverageMeter

In [24]:
class AverageMeter(object):
    
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

# Function

In [25]:
def score_function(real, pred):
    score = f1_score(real, pred, average="macro")
    return score
    
def calc_loss(z, j):
    squared_error = torch.sum(z**2, (1, 2, 3)) / 2
    jacob = torch.sum(j, (1, 2, 3))
    return squared_error - jacob

def adjust_learning_rate(optimizer, epoch):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    lr = LR/2 * 0.01
    if epoch >= int(EPOCHS*0.9):
        lr = LR/2 * 0.000001
    elif epoch >= int(EPOCHS*0.8):
        lr = LR/2 * 0.0001
    logger.info("learing rate:", lr)
    neptune_run["train/learning_rate"].log(lr)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
        
def cdb_adjust_learning_rate(optimizer, epoch):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    if epoch >= int(EPOCHS*0.9):
        lr = LR/3 * 0.0001
    elif epoch >= int(EPOCHS*0.8):
        lr = LR/3 * 0.01
    logger.info("learing rate:", lr)
    neptune_run["train/learning_rate"].log(lr)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
        
def mixup_data(x, y, alpha=1.0):
    '''Returns mixed inputs, pairs of targets, and lambda'''
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size()[0]
    index = torch.randperm(batch_size).cuda()
    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def mix_criterion(cr, pred, y_a, y_b, lam):
    if cr == "ib":
        return lam * criterion_ib(pred, y_a) + (1 - lam) * criterion_ib(pred, y_b)
    elif cr == "cr":
        return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

def get_transform():
    transform = A.Compose([
        A.Crop(x_min=408, y_min=0, x_max=2872, y_max=2464, p=1),
        A.Resize(params["train_image_size"][0], params["train_image_size"][1], interpolation=cv2.INTER_AREA),
        # A.Resize(params["train_image_size"][0]+int(params["train_image_size"][0]*0.2), params["train_image_size"][1]+int(params["train_image_size"][1]*0.2), interpolation=cv2.INTER_AREA),
        # A.RandomCrop(params["train_image_size"][0], params["train_image_size"][1], p=1),
        # A.RandomRotate90(p=0.5),
        # A.HorizontalFlip(p=0.5),
        # A.VerticalFlip(p=0.5),
        # A.Affine(rotate=(-45, 45), translate_percent=(-0.1, 0.1), p=0.5),
        # A.Cutout(num_holes=16, max_h_size=16, max_w_size=16, p=0.5),
        # A.GaussNoise(p=0.5),
        # A.RandomBrightnessContrast(brightness_limit=(-0.1,0.1), contrast_limit=(-0.1, 0.1), p=0.5),
        A.Normalize(
            # mean = [0.4849, 0.5310, 0.5719],
            # std = [0.5511, 0.4834, 0.5929],
            mean = [0.5, 0.5, 0.5],
            std = [0.5, 0.5, 0.5],
        ),
        ToTensorV2()
    ])
    return transform

def train_transform(image):
    transform = get_transform()
    return transform(image=image)

def test_transform(image):
    transform = A.Compose([
                    A.Crop(x_min=408, y_min=0, x_max=2872, y_max=2464, p=1),
                    A.Resize(params["test_image_size"][0], params["test_image_size"][1], interpolation=cv2.INTER_AREA),
                    A.Normalize(
                        # mean = [0.4766, 0.5207, 0.6369],
                        # std = [0.4530, 0.4062, 0.4842]
                        mean = [0.5, 0.5, 0.5],
                        std = [0.5, 0.5, 0.5],
                        # mean = [0.4849, 0.5310, 0.5719],
                        # std = [0.5511, 0.4834, 0.5929],
                    ),
                    ToTensorV2(),
                ])
    return transform(image=image)

def cutmix(inputs, targets, beta=1.0, prob=1.0):
    r = np.random.rand(1)[0]
    if beta > 0 and r < prob:
        # generate mixed sample
        lam = np.random.beta(beta, beta)
        rand_index = torch.randperm(inputs.size()[0]).cuda()
        targets_a = targets
        targets_b = targets[rand_index]
        bbx1, bby1, bbx2, bby2 = rand_bbox(inputs.size(), lam)
        inputs[:, :, bbx1:bbx2, bby1:bby2] = inputs[rand_index, :, bbx1:bbx2, bby1:bby2]
        # adjust lambda to exactly match pixel ratio
        lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (inputs.size()[-1] * inputs.size()[-2]))
        return inputs, targets_a, targets_b, lam
        
def rand_bbox(size, lam):
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w = np.int(W * cut_rat)
    cut_h = np.int(H * cut_rat)

    # uniform
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2

def get_nmae(outputs, targets):
    mae = np.mean(np.abs(targets-outputs))
    score = mae / np.mean(np.abs(targets))
    return score

# Train Data

In [26]:
def get_train_data(data_dir):
    data_list = []
    for case_name in os.listdir(data_dir):
        current_path = os.path.join(data_dir, case_name)
        label_df = pd.read_csv(current_path+'/label.csv')
        for img_name, leaf_weight in zip(label_df["img_name"], label_df["leaf_weight"]):
            data_list.append((os.path.join(current_path, "image", img_name), os.path.join(current_path, "meta", f"{os.path.splitext(img_name)[0]}.csv"), leaf_weight))
    return data_list

def get_test_data(data_dir):
    # get image path
    img_path_list = glob(os.path.join(data_dir, 'image', '*.jpg'))
    img_path_list.extend(glob(os.path.join(data_dir, 'image', '*.png')))
    img_path_list.sort(key=lambda x:int(x.split('/')[-1].split('.')[0]))
    
    # get meta path
    meta_path_list = glob(os.path.join(data_dir, 'meta', '*.csv'))
    meta_path_list.sort(key=lambda x:int(x.split('/')[-1].split('.')[0]))
    return (img_path_list, meta_path_list)

def get_scaler(data_list):
    df = pd.DataFrame()
    for data in data_list:
        df = pd.concat([df, pd.read_csv(data[1]).dropna()])
    df.drop(["시간"], axis=1, inplace=True)
    scaler = getattr(preprocessing, params["scaler"])()
    scaler = scaler.fit(df)
    imputer = SimpleImputer(strategy='median')
    imputer = imputer.fit(df)
    return scaler, imputer


train_data_list = get_train_data(TRAIN_PATH)
scaler, imputer = get_scaler(train_data_list)

# train_len = int(len(train_data_list)*0.8)
# val_data_list = train_data_list[train_len:]
# train_data_list = train_data_list[:train_len]

# train_dataset = CustomDataset(train_data_list, mode="train", transform=train_transform)
# train_loader = DataLoader(
#                     train_dataset,
#                     batch_size=params["batch"],
#                     shuffle=True,
#                     pin_memory=True,
#                     num_workers=params["num_workers"])

# val_dataset = CustomDataset(val_data_list, mode="train", transform=test_transform)
# val_loader = DataLoader(
#                     val_dataset,
#                     batch_size=params["batch"],
#                     shuffle=True,
#                     pin_memory=True,
#                     num_workers=params["num_workers"])



# Get strong mean, std
# strong_mean_std = OnlineMeanStd()
# print("train:", strong_mean_std(train_dataset, 1, "strong", "train"))
# print("test:", strong_mean_std(test_dataset, 1, "strong", "test"))
neptune_run["transform"] = str(get_transform())

# Visualize

In [27]:
# def visualize(inputs, targets):
#     # 이미지 정규화 해제하기
#     inputs = inputs.moveaxis(0, -1)
#     mean = np.array([0.4849, 0.5310, 0.5719])
#     std = np.array([0.5511, 0.4834, 0.5929])
#     inputs = inputs * std + mean
#     inputs = np.clip(inputs, 0, 1)

#     # 이미지 출력
#     fig = plt.figure()
#     plt.axis("off")
#     plt.imshow(inputs, rasterized=True)
#     plt.title(targets)
#     plt.show()
#     return fig
    
# for batch in train_loader:
#     images, metas, labels = batch[0], batch[1], batch[2]
#     for image, meta, label in zip(images, metas, labels):
#     # for b in batch:
#         # visualize(b, None)
#         fig = visualize(image, label.item())
#         neptune_run["preprocessed_image"].log(fig)
#     break

# Cutmix
# for batch in train_loader:
#     inputs, targets = batch[0], batch[2]
#     inputs, targets_a, targets_b, lam = cutmix(inputs, targets)
#     for i, a, b in zip(inputs, targets_a, targets_b):
#         visualize(i, f"a:{a} / b:{b} / l: l")
#     break

# Model

In [28]:
class RNN(nn.Module):
    def __init__(self):
        super(RNN, self).__init__()
        self.rnn = nn.GRU(
            input_size=params["lstm_input_size"],
            hidden_size=params["lstm_hidden_size"],
            num_layers=params["lstm_num_layers"],
            batch_first=params["lstm_batch_first"],
            bidirectional=params["lstm_bidirectional"],
            dropout=params["drop_out_rate"],
        )

    def forward(self, x): 
        h0 = torch.zeros(params["lstm_num_layers"]*2 if params["lstm_bidirectional"] else params["lstm_num_layers"], x.size(0), params["lstm_hidden_size"]).to(DEVICE)
        # c0 = torch.zeros(params["lstm_num_layers"]*2 if params["lstm_bidirectional"] else params["lstm_num_layers"], x.size(0), params["lstm_hidden_size"]).to(DEVICE)
        out, _ = self.rnn(x, h0)#(h0, c0))
        return out[:, -1]

    
class Network(nn.Module):
    def __init__(self, mode="train"):
        super(Network, self).__init__()
        self.mode = mode
        if self.mode == "train":
            self.drop_out_rate = params["drop_out_rate"]
            self.drop_path_rate = params["drop_path_rate"]
        elif self.mode == "test":
            self.drop_out_rate = 0
            self.drop_path_rate = 0
        
        if "tresnet" in params["model_name"]:
            self.backbone = timm.create_model(params["model_name"],
                                              pretrained=True,
                                              num_classes=0,
                                              drop_rate=self.drop_out_rate,)
        else:
            self.backbone = timm.create_model(params["model_name"],
                                              pretrained=True,
                                              num_classes=0,
                                              drop_rate=self.drop_out_rate,
                                              drop_path_rate=self.drop_path_rate)
        if "swin" in params["model_name"] or "beit" in params["model_name"]:
            self.classifier = nn.Sequential(
                nn.Linear(self.backbone.norm.normalized_shape[0]+(params["lstm_hidden_size"]*2 if params["lstm_bidirectional"] else params["lstm_hidden_size"]), params["num_class"]),
                nn.ReLU(inplace=True)
            )
        elif "efficient" in params["model_name"]:
            self.classifier = nn.Sequential(
                    nn.Linear(self.backbone.conv_head.out_channels+1440, params["num_class"]),
                    nn.ReLU(inplace=True)
                )
        elif "tresnet" in params["model_name"]:
            self.classifier = nn.Sequential(
                nn.Linear(2432 + 1440, params["num_class"]),
                nn.ReLU(inplace=True)
            )
        # self.rnn = RNN()
        
    def forward(self, x, meta):
        if self.mode == "train":
            # meta = self.rnn(meta)
            meta = torch.mean(meta, dim=2)
            x = self.backbone(x) # torch.Size([32, 1280]) torch.Size([32, 1440])
            x = torch.cat([x, meta], 1) # 2720
            x = self.classifier(x)
            return x
        
        elif self.mode == "test":
            # meta = self.rnn(meta)
            meta = torch.mean(meta, dim=2)
            x = self.backbone(x)
            x = torch.cat([x, meta], 1)
            x = self.classifier(x)
            return x


# class NormedLinear(nn.Module):

#     def __init__(self, in_features, out_features):
#         super(NormedLinear, self).__init__()
#         self.weight = Parameter(torch.Tensor(in_features, out_features))
#         self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)

#     def forward(self, x):
#         out = F.normalize(x, dim=1).mm(F.normalize(self.weight, dim=0))
#         return out
    

class NormedLinear(nn.Linear):
    """Normalized Linear Layer.
    Args:
        tempeature (float, optional): Tempeature term. Default to 20.
        power (int, optional): Power term. Default to 1.0.
        eps (float, optional): The minimal value of divisor to
             keep numerical stability. Default to 1e-6.
    """

    def __init__(self, *args, tempearture=20, power=1.0, eps=1e-6, **kwargs):
        super(NormedLinear, self).__init__(*args, **kwargs)
        self.tempearture = tempearture
        self.power = power
        self.eps = eps
        self.init_weights()

    def init_weights(self):
        nn.init.normal_(self.weight, mean=0, std=0.01)
        if self.bias is not None:
            nn.init.constant_(self.bias, 0)

    def forward(self, x):
        weight_ = self.weight / (
            self.weight.norm(dim=1, keepdim=True).pow(self.power) + self.eps)
        x_ = x / (x.norm(dim=1, keepdim=True).pow(self.power) + self.eps)
        x_ = x_ * self.tempearture

        return F.linear(x_, weight_, self.bias)
    
    
# model = Network()
# model
# # logger.info(model(torch.randn(1, 3, 224, 224)))
# logger.info(model)
# model.to(DEVICE)
# logger.info(f"model {MODEL_NAME} create!")

# model
# from pprint import pprint
# model_names = timm.list_models("*hrnet*", pretrained=False)
# pprint(model_names)
# z

## Dyhead

In [29]:
# # # from pprint import pprint
# # # model_names = timm.list_models("*swin*", pretrained=True)
# # # pprint(model_names)
# from collections import OrderedDict

# class Network(nn.Module):
#     def __init__(self, mode="train"):
#         super(Network, self).__init__()
#         out_featrues = 256
#         self.mode = mode
#         # self.backbone = timm.create_model(MODEL_NAME, pretrained=True, num_classes=NUM_CLASS, drop_path_rate=DROP_PATH_RATIO)#num_classes=0, global_pool="")#, drop_path_rate=0.2)
#         self.backbone = timm.create_model(MODEL_NAME, pretrained=True, features_only=True, drop_path_rate=DROP_PATH_RATIO)
#         self.fpn = FeaturePyramidNetwork([48, 64, 160, 256], out_featrues)
#         self.concat_fpn = concat_feature_maps()
#         # self.scale_layer = Scale_Aware_Layer
#         # self.spatial_layer = Spatial_Aware_Layer
#         # self.task_layer = Task_Aware_Layer
#         # L: 4 S: 441 C: 256
#         self.dyhead = DyHead(num_blocks=6, L=4, S=784, C=256)
#         self.avg_pool = nn.AdaptiveAvgPool2d(1)
#         self.classifier = nn.Linear(256, NUM_CLASS)
#         # self.ml_decoder = MLDecoder(NUM_CLASS, initial_num_features=1280)
#         # for _, param in self.backbone.named_parameters():
#         #     param.requires_grad = False

#     def forward(self, x):
#         x = self.backbone(x)[1:]

#         x_dic = OrderedDict()
#         for i in range(len(x)):
#             x_dic[f"feat{i}"] = x[i]

#         x = self.fpn(x_dic)
#         x = self.concat_fpn(x)
#         x = self.dyhead(x)
        
#         F_tensor = x.permute(0, 3, 1, 2)
#         kernel_size = F_tensor.shape[2:] # Getting HxW of F
#         gap_output = F.avg_pool2d(F_tensor, kernel_size)

#         # Flattening gap_output from (batch_size, C, 1, 1) to (batch_size, C)
#         gap_output = gap_output.flatten(start_dim=1)
        
#         # x = x.transpose(dim0=1, dim1=-1)
#         # x = self.avg_pool(x)
#         # x = x.flatten(1)
#         x = self.classifier(gap_output)
#         return x
#         # feats = []
#         # for k, v in x.items():
#         #     v = self.avg_pool(v)
#         #     feats.append(v.flatten(1))
#         # feats = torch.cat(feats, 1)
        
#         # x = self.classifier(x)
#         # feats = self.backbone(x)
#         # x = self.classifier(feats)
#         # x = self.ml_decoder(x)
#         # return x
#         # if self.mode == "train":
#         #     return x, torch.sum(torch.abs(feats), 1).reshape(-1, 1)
#         # else:

        
# class concat_feature_maps(nn.Module):
#     def __init__(self):
#         super(concat_feature_maps, self).__init__()

#     def forward(self, fpn_output):
#         # Calculating median height to upsample or desample each fpn levels
#         heights = []
#         level_tensors = []
#         for key, values in fpn_output.items():
#             heights.append(values.shape[2])
#             level_tensors.append(values)
#         median_height = int(np.median(heights))

#         # Upsample and Desampling tensors to median height and width
#         for i in range(len(level_tensors)):
#             level = level_tensors[i]
#             # If level height is greater than median, then downsample with interpolate
#             if level.shape[2] > median_height:
#                 level = F.interpolate(input=level, size=(median_height, median_height),mode='nearest')
#             # If level height is less than median, then upsample
#             else:
#                 level = F.interpolate(input=level, size=(median_height, median_height), mode='nearest')
#             level_tensors[i] = level
        
#         # Concating all levels with dimensions (batch_size, levels, C, H, W)
#         concat_levels = torch.stack(level_tensors, dim=1)

#         # Reshaping tensor from (batch_size, levels, C, H, W) to (batch_size, levels, HxW=S, C)
#         concat_levels = concat_levels.flatten(start_dim=3).transpose(dim0=2, dim1=3)
#         return concat_levels


# class Scale_Aware_Layer(nn.Module):
#     # Constructor
#     def __init__(self, s_size):
#         super(Scale_Aware_Layer, self).__init__()

#         # Average Pooling
#         self.avg_layer = nn.AvgPool2d(kernel_size=3, stride=1, padding=1)
        
#         #1x1 Conv layer
#         self.conv = nn.Conv2d(in_channels=s_size, out_channels=1, kernel_size=1)

#         # Hard Sigmoid
#         self.hard_sigmoid = nn.Hardsigmoid()

#         # ReLU function
#         self.relu = nn.ReLU()

#     def forward(self, F):

#         # Transposing input from (batch_size, L, S, C) to (batch_size, S, L, C) so we can use convolutional layer over the level dimension L
#         x = F.transpose(dim0=2, dim1=1)

#         # Passing tensor through avg pool layer
#         x = self.avg_layer(x)

#         # Passing tensor through Conv layer
#         x = self.conv(x)
        
#         # Reshaping Tensor from (batch_size, 1, L, C) to (batch_size, L, 1, C) to then be multiplied to F
#         x = x.transpose(dim0=1, dim1=2)

#         # Passing conv output to relu
#         x = self.relu(x)

#         # Passing tensor to hard sigmoid function
#         pi_L = self.hard_sigmoid(x)

#         # pi_L: (batch_size, L, 1, C)
#         # F: (batch_size, L, S, C)
#         return pi_L * F

    
# class Spatial_Aware_Layer(nn.Module):
#     # Constructor
#     def __init__(self, L_size, kernel_height=3, kernel_width=3, padding=1, stride=1, dilation=1, groups=1):
#         super(Spatial_Aware_Layer, self).__init__()

#         self.in_channels = L_size
#         self.out_channels = L_size

#         self.kernel_size = (kernel_height, kernel_width)
#         self.padding = padding
#         self.stride = stride
#         self.dilation = dilation
#         self.K = kernel_height * kernel_width
#         self.groups = groups

#         # 3x3 Convolution with 3K out_channel output as described in Deform Conv2 paper
#         self.offset_and_mask_conv = nn.Conv2d(in_channels=self.in_channels,
#                                               out_channels=3*self.K, #3K depth
#                                               kernel_size=self.kernel_size,
#                                               stride=self.stride,
#                                               padding=self.padding,
#                                               dilation=dilation)
        
#         self.deform_conv = DeformConv2d(in_channels=self.in_channels,
#                                         out_channels=self.out_channels,
#                                         kernel_size=self.kernel_size,
#                                         stride=self.stride,
#                                         padding=self.padding,
#                                         dilation=self.dilation,
#                                         groups=self.groups)
#     def forward(self, F):
#         # Generating offesets and masks (or modulators) for convolution operation
#         offsets_and_masks = self.offset_and_mask_conv(F)

#         # Separating offsets and masks as described in Deform Conv v2 paper
#         offset = offsets_and_masks[:, :2*self.K, :, :] # First 2K channels 
#         mask = torch.sigmoid(offsets_and_masks[:, 2*self.K:, : , :]) # Last 1K channels and passing it through sigmoid

#         # Passing offsets, masks, and F into deform conv layer
#         spacial_output = self.deform_conv(F, offset, mask)
#         return spacial_output

    
# # DyReLUA technique from Dynamic ReLU paper
# class DyReLUA(nn.Module):
#     def __init__(self, channels, reduction=8, k=2, lambdas=None, init_values=None):
#         super(DyReLUA, self).__init__()

#         self.fc1 = nn.Linear(channels, channels // reduction)
#         self.fc2 = nn.Linear(channels//reduction, 2*k)
#         self.relu = nn.ReLU(inplace=True)
#         self.sigmoid = nn.Sigmoid()

#         # Defining lambdas in form of [La1, La2, Lb1, Lb2]
#         if lambdas is not None:
#             self.lambdas = lambdas
#         else:
#             # Default lambdas from DyReLU paper
#             self.lambdas = torch.cuda.HalfTensor([1.0, 1.0, 0.5, 0.5])

#         # Defining Initializing values in form of [alpha1, alpha2, Beta1, Beta2]
#         if lambdas is not None:
#             self.init_values = init_values
#         else:
#             # Default initializing values of DyReLU paper
#             self.init_values = torch.cuda.HalfTensor([1.0, 0.0, 0.0, 0.0])

#     def forward(self, F_tensor):

#         # Global Averaging F
#         kernel_size = F_tensor.shape[2:] # Getting HxW of F
#         gap_output = F.avg_pool2d(F_tensor, kernel_size)

#         # Flattening gap_output from (batch_size, C, 1, 1) to (batch_size, C)
#         gap_output = gap_output.flatten(start_dim=1)

#         # Passing Global Average output through Fully-Connected Layers
#         x = self.relu(self.fc1(gap_output))
#         x = self.fc2(x)
        
#         # Normalization between (-1, 1)
#         residuals = 2 * self.sigmoid(x) - 1

#         # Getting values of theta, and separating alphas and betas
#         theta = self.init_values + self.lambdas * residuals # Contains[alpha1(x), alpha2(x), Beta1(x), Beta2(x)]
#         alphas = theta[0, :2]
#         betas = theta[0, 2:]

#         # Performing maximum on both piecewise functions
#         output = torch.maximum((alphas[0] * F_tensor + betas[0]), (alphas[1] * F_tensor + betas[1]))

#         return output

    
# class Task_Aware_Layer(nn.Module):
#     # Defining constructor
#     def __init__(self, num_channels):
#         super(Task_Aware_Layer, self).__init__()

#         # DyReLUA relu
#         self.dynamic_relu = DyReLUA(num_channels)
    
#     def forward(self, F_tensor):
#         # Permutating F from (batch_size, L, S, C) to (batch_size, C, L, S) so we can reduce the dimensions over LxS
#         F_tensor = F_tensor.permute(0, 3, 1, 2)
#         output = self.dynamic_relu(F_tensor)
        
#         # Reversing the permutation
#         output = output.permute(0, 2, 3, 1)

#         return output


# class DyHead_Block(nn.Module):
#     def __init__(self, L, S, C):
#         super(DyHead_Block, self).__init__()
#         # Saving all dimension sizes of F
#         self.L_size = L
#         self.S_size = S
#         self.C_size = C

#         # Inititalizing all attention layers
#         self.scale_attention = Scale_Aware_Layer(s_size=self.S_size)
#         self.spatial_attention = Spatial_Aware_Layer(L_size=self.L_size)
#         self.task_attention = Task_Aware_Layer(num_channels=self.C_size)

#     def forward(self, F_tensor):
#         scale_output = self.scale_attention(F_tensor)
#         spacial_output = self.spatial_attention(scale_output)
#         task_output = self.task_attention(spacial_output)

#         return task_output

# def DyHead(num_blocks, L, S, C):
#     blocks = [('Block_{}'.format(i+1),DyHead_Block(L, S, C)) for i in range(num_blocks)]

#     return nn.Sequential(OrderedDict(blocks))


# model = Network()
# model.to(DEVICE)
# logger.info(f"model {MODEL_NAME} create!")
# # for i, (k, param) in enumerate(model.named_parameters()):
# #     if i < 644:
# #         param.requires_grad = False
# #     print(i, k, param.requires_grad)

## ML-Decoder

In [30]:
def add_ml_decoder_head(model, num_classes=-1, num_of_groups=-1, decoder_embedding=768, zsl=0):
    if num_classes == -1:
        num_classes = model.num_classes
    num_features = model.num_features
    if hasattr(model, 'global_pool') and hasattr(model, 'fc'):  # resnet50
        model.global_pool = nn.Identity()
        del model.fc
        model.fc = MLDecoder(num_classes=num_classes, initial_num_features=num_features, num_of_groups=num_of_groups,
                             decoder_embedding=decoder_embedding, zsl=zsl)
    elif hasattr(model, 'head'):  # tresnet
        if hasattr(model, 'global_pool'):
            model.global_pool = nn.Identity()
        del model.head
        model.head = MLDecoder(num_classes=num_classes, initial_num_features=num_features, num_of_groups=num_of_groups,
                               decoder_embedding=decoder_embedding, zsl=zsl)
    else:
        print("model is not suited for ml-decoder")
        exit(-1)

    return model


class TransformerDecoderLayerOptimal(nn.Module):
    def __init__(self, d_model, nhead=8, dim_feedforward=2048, dropout=0.1, activation="relu",
                 layer_norm_eps=1e-5) -> None:
        super(TransformerDecoderLayerOptimal, self).__init__()
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout = nn.Dropout(dropout)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = torch.nn.functional.relu
        super(TransformerDecoderLayerOptimal, self).__setstate__(state)

    def forward(self, tgt, memory, tgt_mask = None,
                memory_mask = None,
                tgt_key_padding_mask = None,
                memory_key_padding_mask = None):
        tgt = tgt + self.dropout1(tgt)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(tgt, memory, memory)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

class GroupFC(object):
    def __init__(self, embed_len_decoder: int):
        self.embed_len_decoder = embed_len_decoder

    def __call__(self, h, duplicate_pooling, out_extrap):
        for i in range(h.shape[1]):
            h_i = h[:, i, :]
            if len(duplicate_pooling.shape)==3:
                w_i = duplicate_pooling[i, :, :]
            else:
                w_i = duplicate_pooling
            out_extrap[:, i, :] = torch.matmul(h_i, w_i)


class MLDecoder(nn.Module):
    def __init__(self, num_classes, num_of_groups=-1, decoder_embedding=768,
                 initial_num_features=2048, zsl=0):
        super(MLDecoder, self).__init__()
        embed_len_decoder = 100 if num_of_groups < 0 else num_of_groups
        if embed_len_decoder > num_classes:
            embed_len_decoder = num_classes

        # switching to 768 initial embeddings
        decoder_embedding = 768 if decoder_embedding < 0 else decoder_embedding
        embed_standart = nn.Linear(initial_num_features, decoder_embedding)

        # non-learnable queries
        if not zsl:
            query_embed = nn.Embedding(embed_len_decoder, decoder_embedding)
            query_embed.requires_grad_(False)
        else:
            query_embed = None

        # decoder
        decoder_dropout = 0.1
        num_layers_decoder = 1
        dim_feedforward = 2048
        layer_decode = TransformerDecoderLayerOptimal(d_model=decoder_embedding,
                                                      dim_feedforward=dim_feedforward, dropout=decoder_dropout)
        self.decoder = nn.TransformerDecoder(layer_decode, num_layers=num_layers_decoder)
        self.decoder.embed_standart = embed_standart
        self.decoder.query_embed = query_embed
        self.zsl = zsl

        if self.zsl:
            if decoder_embedding != 300:
                self.wordvec_proj = nn.Linear(300, decoder_embedding)
            else:
                self.wordvec_proj = nn.Identity()
            self.decoder.duplicate_pooling = torch.nn.Parameter(torch.Tensor(decoder_embedding, 1))
            self.decoder.duplicate_pooling_bias = torch.nn.Parameter(torch.Tensor(1))
            self.decoder.duplicate_factor = 1
        else:
            # group fully-connected
            self.decoder.num_classes = num_classes
            self.decoder.duplicate_factor = int(num_classes / embed_len_decoder + 0.999)
            self.decoder.duplicate_pooling = torch.nn.Parameter(
                torch.Tensor(embed_len_decoder, decoder_embedding, self.decoder.duplicate_factor))
            self.decoder.duplicate_pooling_bias = torch.nn.Parameter(torch.Tensor(num_classes))
        torch.nn.init.xavier_normal_(self.decoder.duplicate_pooling)
        torch.nn.init.constant_(self.decoder.duplicate_pooling_bias, 0)
        self.decoder.group_fc = GroupFC(embed_len_decoder)
        self.train_wordvecs = None
        self.test_wordvecs = None

    def forward(self, x):
        if len(x.shape) == 4:  # [bs,2048, 7,7]
            embedding_spatial = x.flatten(2).transpose(1, 2)
        else:  # [bs, 197,468]
            embedding_spatial = x
        embedding_spatial_786 = self.decoder.embed_standart(embedding_spatial)
        embedding_spatial_786 = torch.nn.functional.relu(embedding_spatial_786, inplace=True)

        bs = embedding_spatial_786.shape[0]
        if self.zsl:
            query_embed = torch.nn.functional.relu(self.wordvec_proj(self.decoder.query_embed))
        else:
            query_embed = self.decoder.query_embed.weight
        # tgt = query_embed.unsqueeze(1).repeat(1, bs, 1)
        tgt = query_embed.unsqueeze(1).expand(-1, bs, -1)  # no allocation of memory with expand
        h = self.decoder(tgt, embedding_spatial_786.transpose(0, 1))  # [embed_len_decoder, batch, 768]
        h = h.transpose(0, 1)

        out_extrap = torch.zeros(h.shape[0], h.shape[1], self.decoder.duplicate_factor, device=h.device, dtype=h.dtype)
        self.decoder.group_fc(h, self.decoder.duplicate_pooling, out_extrap)
        if not self.zsl:
            h_out = out_extrap.flatten(1)[:, :self.decoder.num_classes]
        else:
            h_out = out_extrap.flatten(1)
        h_out += self.decoder.duplicate_pooling_bias
        logits = h_out
        return logits

# Optimizer & Criterion & Scheduler

In [31]:
# for i, (k, param) in enumerate(model.named_parameters()):
#     if i < 363:
#         param.requires_grad = False
#     print(i, k, param.requires_grad)
def set_optimizer():
    if "sam" in params["optimizer"]:
        if "adamw" in params["optimizer"]:
            base_optimizer = torch.optim.AdamW
        elif "sgd" in params["optimizer"]:
            base_optimizer = torch.optim.SGD
            optimizer = SAM(model.parameters(), base_optimizer, lr=params["learning_rate"], rho=2.0, adaptive=True, momentum=0.9, weight_decay=params["weight_decay"])
        elif "lamb" in params["optimizer"]:
            base_optimizer = optim.Lamb
            optimizer = SAM(model.parameters(), base_optimizer, lr=params["learning_rate"], rho=2.0, adaptive=True, betas=(0.9, 0.999), eps=1e-6, weight_decay=params["weight_decay"])
        if "cosineannealinglr" in params["scheduler"]:
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer.base_optimizer, T_max=params["epochs"], verbose=True)
        elif "cosineannealingwarmuprestarts" in params["scheduler"]:
            scheduler = CosineAnnealingWarmUpRestarts(optimizer.base_optimizer, T_0=25, T_mult=1, eta_max=params["learning_rate"],  T_up=5, gamma=0.1)
    else:
        if "adamw" in params["optimizer"]:
            optimizer = torch.optim.AdamW(model.parameters(), lr=1e-30, weight_decay=params["weight_decay"])
        elif "sgd" in params["optimizer"]:
            optimizer = torch.optim.SGD(model.parameters(), lr=params["learning_rate"], weight_decay=params["weight_decay"], momentum=0.9)
        elif "lamb" in params["optimizer"]:
            optimizer = optim.Lamb(model.parameters(), lr=params["learning_rate"], betas=(0.9, 0.999), eps=1e-6, weight_decay=params["weight_decay"])
        if "cosineannealinglr" in params["scheduler"]:
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=params["epochs"], verbose=True)
        elif "cosineannealingwarmuprestarts" in params["scheduler"]:
            scheduler = CosineAnnealingWarmUpRestarts(optimizer, T_0=25, T_mult=1, eta_max=params["learning_rate"],  T_up=0, gamma=0.1)

    if "ib" in params["criterion"]:
        per_cls_weights = 1.0 / np.array(SAMPLES_PER_CLS)
        per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(SAMPLES_PER_CLS)
        per_cls_weights = torch.FloatTensor(per_cls_weights).cuda()
        criterion = nn.CrossEntropyLoss().cuda()
        criterion_ib = IB_FocalLoss(weight=per_cls_weights, alpha=1000, gamma=1).cuda()
        # criterion_ib = IBLoss(weight=per_cls_weights, alpha=1000).cuda()
    elif "ce" in params["criterion"]:
        criterion = nn.CrossEntropyLoss().cuda()
    elif "sc" in params["criterion"]:
        criterion = SmoothCrossentropy().cuda()
    elif "de" in params["criterion"]:
        criterion = DiverseExpertLoss(SAMPLES_PER_CLS).cuda()
    elif "bs" in params["criterion"]:
        criterion = BalancedSoftmax().cuda()
    elif "l1" in params["criterion"]:
        criterion = nn.L1Loss().cuda()
    elif "huber" in params["criterion"]:
        criterion = nn.HuberLoss(delta=params["huber_delta"]).cuda()
    elif "cosine_similarity" in params["criterion"]:
        criterion = nn.CosineSimilarity().cuda()
    elif "wing" in params["criterion"]:
        criterion = WingLoss().cuda()
    elif "adaptivewing" in params["criterion"]:
        criterion = AdaptiveWingLoss().cuda()
    elif "expert" in params["criterion"]:
        criterion = DiverseExpertLoss().cuda()
    # elif "robust" in params["criterion"]:
    #     criterion = robust_loss_pytorch.adaptive.AdaptiveLossFunction(num_dims=1, float_dtype=np.float32, device=0)
    #     robust_params = list(model.parameters()) + list(criterion.parameters())
    #     base_optimizer = optim.Lamb
    #     optimizer = SAM(robust_params, base_optimizer, lr=params["learning_rate"], rho=2.0, adaptive=True, betas=(0.9, 0.999), eps=1e-6, weight_decay=params["weight_decay"])
    #     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer.base_optimizer, T_max=params["epochs"], verbose=True)

    grad_scaler = torch.cuda.amp.GradScaler()
    logger.info(f"model {params['model_name']} create!")
    return optimizer, criterion, scheduler, grad_scaler

# Train

## train

In [32]:
def train(model, optimizer, data_loader, epoch, epochs, criterion=None):
    start=time.time()
    batch_time = AverageMeter("Time", ":.0f")
    train_losses = AverageMeter("Loss", ":.4e")
    train_nmae = AverageMeter("Nmae", ":.5f")

    model.train()
    for batch in tqdm(data_loader): # set_postfix
        inputs = batch[0].float().cuda(non_blocking=True)
        metas = batch[1].float().cuda(non_blocking=True)
        targets = batch[2].float().cuda(non_blocking=True).squeeze()

        if params["mixup"] and epoch < params["mixup_end_epoch"]:
            inputs, targets_a, targets_b, lam = mixup_data(inputs, targets, params["mixup_alpha"])

        elif params["cutmix"]:
            inputs, targets_a, targets_b, lam = cutmix(inputs, targets)

        if "sam" in params["optimizer"]:
            # first forward-backward step
            enable_running_stats(model)
            # with torch.cuda.amp.autocast():
            if "ib" in params["criterion"]:
                outputs, features = model(inputs)
            else:
                outputs = model(inputs, metas).squeeze()
            if (params["mixup"] and epoch < params["mixup_end_epoch"]) or params["cutmix"]:
                if "ib" in params["criterion"] and epoch > params["ib_start_epoch"]:
                    loss = mix_criterion("ib", (outputs, features), targets_a, targets_b, lam)
                else:
                    loss = mix_criterion("cr", outputs, targets_a, targets_b, lam)
            else:
                if "ib" in params["criterion"] and epoch > params["ib_start_epoch"]:
                    loss = criterion_ib((outputs, features), targets)
                else:
                    loss = criterion(outputs, targets)
                    # loss = torch.mean(criterion.lossfun((outputs - targets)[:,None]))
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), params["max_norm"])
            optimizer.first_step(zero_grad=True)

            # second forward-backward step
            disable_running_stats(model)
            # with torch.cuda.amp.autocast():
            if (params["mixup"] and epoch < params["mix_end_epoch"]) or (params["cutmix"] and epoch < params["mix_end_epoch"]):
                if "ib" in params["criterion"]:
                    if epoch > params["ib_start_epoch"]:
                        second_loss = mix_criterion("ib", model(inputs), targets_a, targets_b, lam)
                    else:
                        second_loss = mix_criterion("cr", model(inputs)[0], targets_a, targets_b, lam)
                else:
                    second_loss = mix_criterion("cr", model(inputs), targets_a, targets_b, lam)
            else:
                if "ib" in params["criterion"]:
                    if epoch > params["ib_start_epoch"]:
                        second_loss = criterion_ib(model(inputs), targets)
                    else:
                        second_loss = criterion(model(inputs)[0], targets)
                else:
                    second_loss = criterion(model(inputs, metas).squeeze(1), targets)
                    # second_loss = torch.mean(criterion.lossfun((model(inputs, metas).squeeze(1) - targets)[:,None]))
            second_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), params["max_norm"])
            optimizer.second_step(zero_grad=True)

        else:
            optimizer.zero_grad()
            with torch.cuda.amp.autocast():
                if "ib" in params["criterion"]:
                    outputs, features = model(inputs)
                else:
                    outputs = model(inputs, metas).squeeze()
                if (params["mixup"] and epoch < params["mix_end_epoch"]) or (params["cutmix"] and epoch < params["mix_end_epoch"]):
                    if "ib" in params["criterion"] and epoch > params["ib_start_epoch"]:
                        loss = mix_criterion("ib", (outputs, features), targets_a, targets_b, lam)
                    else:
                        loss = mix_criterion("cr", outputs, targets_a, targets_b, lam)
                else:
                    if "ib" in params["criterion"] and epoch > params["ib_start_epoch"]:
                        loss = criterion_ib((outputs, features), targets)
                    else:
                        loss = criterion(outputs, targets)
                        # loss = torch.mean(criterion.lossfun((outputs - targets)[:,None]))
            grad_scaler.scale(loss).backward()
            grad_scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), params["max_norm"])
            grad_scaler.step(optimizer)
            grad_scaler.update()

        # outputs = torch.stack([outputs["image"], outputs["meta"]], dim=1).mean(dim=1).squeeze(1)
        train_losses.update(loss.item(), inputs.size(0))
        train_nmae.update(get_nmae(outputs.detach().cpu().numpy(), targets.detach().cpu().numpy()), inputs.size(0))

    if "ib" in params["criterion"] and epoch >= params["ib_start_epoch"]:
        adjust_learning_rate(optimizer.base_optimizer, epoch)
    else:
        scheduler.step()

    batch_time.update(time.time() - start)
    train_log = f"epoch : {epoch+1}/{epochs} | time : {batch_time.val:.0f}s/{batch_time.val*(epochs-epoch-1):.0f}s | TRAIN | loss : {train_losses.avg:.5f} | nmae : {train_nmae.avg:.5f}"
    logger.info(train_log)
    return train_losses.avg, train_nmae.avg

## val

In [33]:
def val(model, optimizer, data_loader, epoch, epochs, criterion=None):
    start = time.time()
    val_batch_time = AverageMeter("Time", ":.0f")
    val_losses = AverageMeter("Loss", ":.4e")
    val_correct = AverageMeter("Acc", ":.5f")
    val_nmae = AverageMeter("Nmae", ":.5f")

    model.eval()
    with torch.no_grad():
        for batch in tqdm(data_loader):                              
            inputs = batch[0].float().cuda(non_blocking=True)
            metas = batch[1].float().cuda(non_blocking=True)
            targets = batch[2].long().cuda(non_blocking=True).squeeze()
            # with torch.cuda.amp.autocast():
            if "ib" in params["criterion"]:
                outputs, features = model(inputs)
            else:
                outputs = model(inputs, metas).squeeze()
            if "ib" in params["criterion"] and epoch > params["ib_start_epoch"]:
                loss = criterion_ib((outputs, features), targets)
            else:
                loss = criterion(outputs, targets)
                # loss = torch.mean(criterion.lossfun((outputs - targets)[:,None]))

            val_losses.update(loss.item(), inputs.size(0))
            val_nmae.update(get_nmae(outputs.detach().cpu().numpy(), targets.detach().cpu().numpy()), inputs.size(0))
    # torch.save(model.state_dict(), SAVE_PATH+params["model_name"]+"_last.pt")

    val_batch_time.update(time.time() - start)
    val_log = f"epoch : {epoch+1}/{epochs} | time : {val_batch_time.val:.0f}s/{val_batch_time.val*(epochs-epoch-1):.0f}s | VAL | loss : {val_losses.avg:.5f} | nmae : {val_nmae.avg:.5f}"
    logger.info(val_log)
    return val_losses.avg, val_nmae.avg

## Start Train

In [None]:
cv = KFold(n_splits=10, shuffle=True, random_state=params["seed"])

for idx, (train_idx, val_idx) in enumerate(cv.split(train_data_list)):
    logger.info(f"KFold5 Start!! Current -> {idx}")
    best_nmae = 9999
    train_dataset = np.array(train_data_list)[train_idx]
    val_dataset = np.array(train_data_list)[val_idx]

    train_dataset = CustomDataset(train_dataset, mode="train", transform=train_transform, scaler=scaler, imputer=imputer)
    train_loader = DataLoader(
                        train_dataset,
                        batch_size=params["batch"],
                        shuffle=True,
                        pin_memory=True,
                        num_workers=params["num_workers"])

    val_dataset = CustomDataset(val_dataset, mode="train", transform=test_transform, scaler=scaler, imputer=imputer)
    val_loader = DataLoader(
                        val_dataset,
                        batch_size=params["batch"],
                        shuffle=True,
                        pin_memory=True,
                        num_workers=params["num_workers"])
    
    model = Network()
    if idx == 0:
        neptune_run["model_classifier"].log(model.classifier)
    model.to(DEVICE)
    optimizer, criterion, scheduler, grad_scaler = set_optimizer()
    for epoch in range(params["epochs"]):
        train_losses, train_nmae = train(model=model, optimizer=optimizer, data_loader=train_loader, epoch=epoch, epochs=params["epochs"], criterion=criterion)
        val_losses, val_nmae = val(model=model, optimizer=optimizer, data_loader=val_loader, epoch=epoch, epochs=params["epochs"], criterion=criterion)
        
        neptune_run[f"KFold{idx}_train/loss"].log(train_losses)
        neptune_run[f"KFold{idx}_train/nmae"].log(train_nmae)
        neptune_run[f"KFold{idx}_val/loss"].log(val_losses)
        neptune_run[f"KFold{idx}_val/nmae"].log(val_nmae)
        if val_nmae <= best_nmae:
            best_nmae = val_nmae
            torch.save(model.state_dict(), SAVE_PATH+params["model_name"]+f"_best{idx}.pt")
            if best_nmae == 0:
                logger.info("best_loss is 0!!")
                break
        torch.save(model.state_dict(), SAVE_PATH+params["model_name"]+f"_last{idx}.pt")
    break

[2022-05-20 16:46:58] KFold5 Start!! Current -> 0
[2022-05-20 16:47:00] model tf_efficientnetv2_b0 create!


Adjusting learning rate of group 0 to 1.6000e-01.


  0%|          | 0/22 [00:00<?, ?it/s]

[2022-05-20 16:47:26] epoch : 1/25 | time : 26s/623s | TRAIN | loss : 73.73714 | nmae : 0.90999


Adjusting learning rate of group 0 to 1.5937e-01.


  0%|          | 0/3 [00:00<?, ?it/s]

[2022-05-20 16:47:33] epoch : 1/25 | time : 7s/159s | VAL | loss : 67.74374 | nmae : 0.90993


  0%|          | 0/22 [00:00<?, ?it/s]

[2022-05-20 16:47:58] epoch : 2/25 | time : 26s/588s | TRAIN | loss : 73.30623 | nmae : 0.90559


Adjusting learning rate of group 0 to 1.5749e-01.


  0%|          | 0/3 [00:00<?, ?it/s]

[2022-05-20 16:48:05] epoch : 2/25 | time : 6s/147s | VAL | loss : 63.73503 | nmae : 0.85627


  0%|          | 0/22 [00:00<?, ?it/s]

[2022-05-20 16:48:30] epoch : 3/25 | time : 26s/561s | TRAIN | loss : 67.15659 | nmae : 0.83030


Adjusting learning rate of group 0 to 1.5438e-01.


  0%|          | 0/3 [00:00<?, ?it/s]

[2022-05-20 16:48:37] epoch : 3/25 | time : 6s/143s | VAL | loss : 64.31225 | nmae : 0.86704


  0%|          | 0/22 [00:00<?, ?it/s]

[2022-05-20 16:49:02] epoch : 4/25 | time : 25s/535s | TRAIN | loss : 73.26136 | nmae : 0.90608


Adjusting learning rate of group 0 to 1.5010e-01.


  0%|          | 0/3 [00:00<?, ?it/s]

[2022-05-20 16:49:09] epoch : 4/25 | time : 6s/134s | VAL | loss : 62.61366 | nmae : 0.83400


  0%|          | 0/22 [00:00<?, ?it/s]

In [None]:
# z

# Test

## Test data

In [None]:
test_data_list = get_test_data(TEST_PATH)
test_dataset = CustomDataset(test_data_list, mode="test", transform=test_transform, scaler=scaler, imputer=imputer)
test_loader = DataLoader(
                    test_dataset,
                    batch_size=params["batch"]//2,
                    shuffle=False,
                    pin_memory=True,
                    num_workers=params["num_workers"])

## TTA

In [None]:
torch.cuda.empty_cache()

# tta_transform = tta.Compose(
#     [
#         # tta.HorizontalFlip(),
#         # tta.VerticalFlip(),
#         tta.Rotate90(angles=[0, 90, 180, 270]),
#         # tta.Scale(scales=[0.6, 0.8, 1]),
#         # tta.Multiply(factors=[0.9, 1, 1.1]),        
#     ]
# )

## Test predict

In [None]:
pred_ensemble = []
for i in range(1):
    model = Network(mode="test")
    model.load_state_dict(torch.load(SAVE_PATH+params["model_name"]+f"_best{i}.pt"))
    model.to(DEVICE)
    # tta_model = tta.ClassificationTTAWrapper(model, tta_transform)

    model.eval()
    f_pred = []

    with torch.no_grad():
        for batch in tqdm(test_loader):
            inputs = batch[0].float().cuda(non_blocking=True)
            metas = batch[1].float().cuda(non_blocking=True)
            with torch.cuda.amp.autocast():
                # outputs = tta_model(x)
                outputs = model(inputs, metas).squeeze()
            f_pred.extend(outputs.detach().cpu().numpy().tolist())

    # Pred Ensemble
    pred_ensemble.append(f_pred)

f_result = np.mean(pred_ensemble, axis=0)

# Submission
submission = pd.read_csv("data/sample_submission.csv")
submission["leaf_weight"] = f_result
submission.to_csv(SAVE_PATH+"submission.csv", index = False)
neptune_run.stop()