# Directory Settings

In [None]:
# ====================================================
# Directory settings
# ====================================================
import os


# Kaggle dataset that provides train.pkl.
INPUT_DIR = "/kaggle/input/aes2-train-data/"
# Directory holding oof_df.pkl from an earlier run; only read when CFG.sl is True.
OOF_DIR = ""
# Backbone checkpoint loaded when CFG.sl is False (path name suggests an
# MLM-adapted DeBERTa-v3-base — TODO confirm).
MLM_PATH = "/kaggle/input/lal-deberta-base-mlm/deberta_v3_base_chk/checkpoint-57824/"
# Logs, tokenizer, config.pth, per-fold checkpoints and oof_df.pkl are written here.
OUTPUT_DIR = "/kaggle/working/"
# exist_ok removes the check-then-create race and keeps the cell idempotent.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# CFG

In [None]:
# ====================================================
# CFG
# ====================================================


class CFG:
    """Hyperparameters and switches read by every later cell.

    Used as a plain attribute namespace; the tokenizer cell additionally
    attaches ``CFG.tokenizer``.
    """

    # --- run control -----------------------------------------------------
    debug = False
    train = True
    seed = 42
    print_freq = 100
    num_workers = 4
    apex = True  # toggles torch.cuda.amp autocast + GradScaler in train_fn

    # --- backbone & head -------------------------------------------------
    model = "microsoft/deberta-v3-base"
    head = "mean_pooling"  # 'mean_pooling' | 'attention' | 'lstm'
    target_size = 1
    fc_dropout = 0.0
    freeze_layer = 9  # embeddings and encoder.layer[:freeze_layer] are frozen
    max_len = 1024
    # Merged into the backbone config; every listed dropout rate is set to 0.
    model_config = {
        "attention_dropout": 0.0,
        "attention_probs_dropout_prob": 0.0,
        "hidden_dropout": 0.0,
        "hidden_dropout_prob": 0.0,
        # "layer_norm_eps": 1e-7,
    }

    # --- targets ---------------------------------------------------------
    target_cols = ["0", "1", "2", "3", "4", "5"]
    target_cols2 = ["score"]  # raw score, used for QWK
    target_cols3 = ["score_s"]  # score / 5, optionally blended with OOF preds

    # --- optimisation ----------------------------------------------------
    epochs = 5
    batch_size = 2
    gradient_accumulation_steps = 1
    encoder_lr = 1e-5
    decoder_lr = 1e-5
    min_lr = 1e-6
    eps = 1e-6
    betas = (0.9, 0.999)
    weight_decay = 0.01
    max_grad_norm = 1000  # very loose clip threshold

    # --- LR schedule -----------------------------------------------------
    scheduler = "cosine"  # 'linear' | 'cosine'
    batch_scheduler = True  # step the scheduler after every optimizer step
    num_cycles = 0.5
    num_warmup_steps = 0

    # --- cross-validation ------------------------------------------------
    n_fold = 6
    trn_fold = [0]  # previously [0, 1, 2, 3]
    flag = 0  # selects the train / second-validation split rule in train_loop

    # --- self-labelling (blend OOF predictions into the soft target) -----
    sl = False
    sl_rate = 0.2

# Library

In [None]:
# Install runtime dependencies used by the cells below.
# NOTE(review): versions are unpinned, so results may drift between runs;
# consider pinning versions and using %pip so the kernel's env is targeted.
!pip install transformers
!pip install sentencepiece

In [None]:
import ast
import copy
import gc
import itertools
import json
import math

# ====================================================
# Library
# ====================================================
import os
import pickle
import random
import re
import string
import sys
import time
import warnings

import joblib


warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
import scipy as sp


pd.set_option("display.max_rows", 500)
pd.set_option("display.max_columns", 500)
pd.set_option("display.width", 1000)
# os.system('pip uninstall -y transformers')
# os.system('python -m pip install --no-index --find-links=../input/nbme-pip-wheels transformers')
import tokenizers
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from sklearn.metrics import cohen_kappa_score, f1_score, log_loss, mean_squared_error
from sklearn.model_selection import (
    GroupKFold,
    KFold,
    StratifiedGroupKFold,
    StratifiedKFold,
)
from torch.nn import Parameter
from torch.optim import SGD, Adam, AdamW
from torch.utils.checkpoint import checkpoint
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm


print(f"tokenizers.__version__: {tokenizers.__version__}")
print(f"transformers.__version__: {transformers.__version__}")
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    DataCollatorWithPadding,
    get_cosine_schedule_with_warmup,
    get_linear_schedule_with_warmup,
)


# Explicitly set TOKENIZERS_PARALLELISM for the HuggingFace tokenizers library.
%env TOKENIZERS_PARALLELISM=true

# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Utils

In [None]:
# ====================================================
# Utils
# ====================================================


def get_score(y_trues, y_preds):
    """Quadratic-weighted Cohen's kappa between integer labels and predictions.

    ``y_preds`` is expected on a [0, 1] scale (the sigmoid outputs used in this
    notebook). It is flattened, multiplied by 5 and bucketed into the labels
    0..5 with bins of width 5/6; the outer bins are open-ended.
    """
    bin_edges = [-np.inf, 0.83333333, 1.66666667, 2.5, 3.33333333, 4.16666667, np.inf]
    scaled = y_preds.reshape(-1) * 5
    bucketed = pd.cut(scaled, bin_edges, labels=[0, 1, 2, 3, 4, 5])
    return cohen_kappa_score(y_trues, bucketed, weights="quadratic")


def get_logger(filename=OUTPUT_DIR + "train"):
    """Return the module logger, emitting bare messages to stderr and ``{filename}.log``.

    Handlers already attached to this named logger (for example from re-running
    this cell in the same kernel) are closed and removed first. Without this,
    each run would add another stream/file pair and every message would be
    emitted multiple times.
    """
    from logging import INFO, FileHandler, Formatter, StreamHandler, getLogger

    logger = getLogger(__name__)
    logger.setLevel(INFO)
    for stale in list(logger.handlers):
        logger.removeHandler(stale)
        stale.close()  # release the previous log file handle
    formatter = Formatter("%(message)s")
    for handler in (StreamHandler(), FileHandler(filename=f"{filename}.log")):
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger


LOGGER = get_logger()


def seed_everything(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN kernels.

    Also exports ``PYTHONHASHSEED`` (as a string) for any child processes.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    torch.backends.cudnn.deterministic = True


# Seed every RNG once, before data loading and model construction below.
seed_everything(seed=CFG.seed)

# Data Loading

In [None]:
# ====================================================
# Data Loading
# ====================================================
# NOTE: pd.read_pickle can execute arbitrary code; only load trusted files.
train = pd.read_pickle(f"{INPUT_DIR}train.pkl")
if CFG.sl:
    # Self-labelling: blend the scaled score with OOF predictions from a prior run.
    oof = pd.read_pickle(f"{OOF_DIR}oof_df.pkl")
    # m:1 fails loudly if oof_df.pkl repeats an essay_id, instead of silently
    # duplicating training rows.
    train = train.merge(
        oof[["essay_id", "pred"]], on="essay_id", how="left", validate="m:1"
    )
    n_missing = int(train["pred"].isna().sum())
    if n_missing:
        # A missing OOF prediction would turn the blended target into NaN.
        raise ValueError(
            f"{n_missing} essays in train.pkl have no OOF prediction in oof_df.pkl"
        )
    train[CFG.target_cols3[0]] = (
        (train[CFG.target_cols2[0]].values / 5) * (1 - CFG.sl_rate)
    ) + (train["pred"].values * CFG.sl_rate)
else:
    train[CFG.target_cols3[0]] = train[CFG.target_cols2[0]].values / 5
print(train.shape)
train.head()

# Tokenizer

In [None]:
# ====================================================
# tokenizer
# ====================================================
tokenizer = AutoTokenizer.from_pretrained(CFG.model)
# Register the "[BR]" marker *before* saving, so the tokenizer written to
# OUTPUT_DIR is identical to the one used for training.
tokenizer.add_special_tokens({"additional_special_tokens": ["[BR]"]})
tokenizer.save_pretrained(OUTPUT_DIR)
# NOTE(review): the backbone's embeddings are never resized after adding this
# token; this relies on the new id fitting inside the pretrained embedding
# matrix — confirm len(tokenizer) <= model.config.vocab_size.
CFG.tokenizer = tokenizer

In [None]:
import matplotlib.pyplot as plt


# Distribution of the precomputed ``length`` column; useful when choosing max_len.
fig, ax = plt.subplots()
ax.hist(train["length"])
ax.set(title="Distribution of train['length']", xlabel="length", ylabel="count")
plt.show()

# CV Split

In [None]:
# The ``fold`` column already comes with train.pkl. The recipe below is kept
# for reference; it shows how the folds were presumably produced (multilabel
# stratification on score and length):
#   !pip install -q iterative-stratification
#   from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
#   Fold = MultilabelStratifiedKFold(n_splits=CFG.n_fold, shuffle=True, random_state=CFG.seed)
#   for n, (train_index, val_index) in enumerate(Fold.split(train, train[['score', 'length']])):
#       train.loc[val_index, 'fold'] = int(n)
#   train['fold'] = train['fold'].astype(int)
fold_sizes = train.groupby("fold").size()
display(fold_sizes)

# Dataset

In [None]:
def prepare_input(cfg, text):
    """Tokenize ``text`` with ``cfg.tokenizer`` into fixed-length long tensors.

    The encoding is truncated / padded to ``cfg.max_len``. ``collate`` later
    trims each batch back to its longest real sequence.

    Returns:
        The tokenizer's encoding (e.g. ``input_ids``, ``attention_mask``) with
        every value converted to a ``torch.long`` tensor.
    """
    inputs = cfg.tokenizer.encode_plus(
        text,
        return_tensors=None,
        add_special_tokens=True,
        max_length=cfg.max_len,  # honour the cfg argument, not the global CFG
        padding="max_length",  # replaces the deprecated pad_to_max_length=True
        truncation=True,
    )
    for k, v in inputs.items():
        inputs[k] = torch.tensor(v, dtype=torch.long)
    return inputs


class TrainDataset(Dataset):
    """Map-style dataset yielding ``(encoded essay, score, soft target)``.

    ``full_text`` is tokenized on access. ``cfg.target_cols2`` supplies the raw
    score used for metrics, and ``cfg.target_cols3`` supplies the training
    target; both are returned as float tensors.
    """

    def __init__(self, cfg, df):
        self.cfg = cfg
        self.texts = df["full_text"].values
        self.labels, self.labels2 = (
            df[columns].values for columns in (cfg.target_cols2, cfg.target_cols3)
        )

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, item):
        encoded = prepare_input(self.cfg, self.texts[item])
        score, soft_target = (
            torch.tensor(array[item], dtype=torch.float)
            for array in (self.labels, self.labels2)
        )
        return encoded, score, soft_target


def collate(inputs):
    """Trim a padded batch to the longest real sequence it contains.

    Every tensor in ``inputs`` is sliced along dim 1 to the largest
    ``attention_mask`` row sum. The dict is modified in place and returned.
    """
    longest = int(inputs["attention_mask"].sum(axis=1).max())
    for key in list(inputs):
        inputs[key] = inputs[key][:, :longest]
    return inputs

# Model

In [None]:
# ====================================================
# Model
# ====================================================


class MeanPooling(nn.Module):
    """Average hidden states over positions whose ``attention_mask`` is non-zero."""

    def __init__(self):
        super().__init__()

    def forward(self, last_hidden_state, attention_mask):
        # Assumes mask is (batch, seq) and hidden is (batch, seq, hidden):
        # lift the mask to the hidden shape so it weights every feature.
        weights = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
        summed = (last_hidden_state * weights).sum(dim=1)
        # The clamp keeps an all-padding row from dividing by zero.
        counts = weights.sum(dim=1).clamp(min=1e-9)
        return summed / counts


class CustomModel_mean_pooling(nn.Module):
    """Transformer backbone -> masked mean pooling -> LayerNorm -> linear head.

    Args:
        cfg: configuration namespace; reads ``model``, ``model_config``, ``sl``,
            ``target_size`` and ``freeze_layer``.
        config_path: optional path to a ``torch.save``-d config object. When
            None, the config is fetched for ``cfg.model``.
        pretrained: only used when ``cfg.sl`` is true; load pretrained
            ``cfg.model`` weights (True) or build from the config alone (False).

    The embeddings and the first ``cfg.freeze_layer`` encoder layers are frozen.
    """

    def __init__(self, cfg, config_path=None, pretrained=False):
        super().__init__()
        self.cfg = cfg
        if config_path is None:
            self.config = AutoConfig.from_pretrained(cfg.model, output_hidden_states=True)
        else:
            self.config = torch.load(config_path)
        # Use the cfg argument (not the global CFG), matching the other heads.
        self.config.update(cfg.model_config)
        if cfg.sl:
            if pretrained:
                self.model = AutoModel.from_pretrained(cfg.model, config=self.config)
            else:
                self.model = AutoModel.from_config(self.config)
        else:
            # Apply the cfg.model_config overrides (dropout rates) to the MLM
            # checkpoint as well; loading it with its stored config alone
            # silently ignored them. The checkpoint's own config stays the base
            # so architecture fields still match its saved weights.
            mlm_config = AutoConfig.from_pretrained(MLM_PATH)
            mlm_config.update(cfg.model_config)
            self.model = AutoModel.from_pretrained(MLM_PATH, config=mlm_config)
        # HF gradient checkpointing: recompute activations in backward to save memory.
        self.model.gradient_checkpointing_enable()
        self.pool = MeanPooling()
        self.fc = nn.Linear(self.config.hidden_size, self.cfg.target_size)
        self._init_weights(self.fc)
        self.layer_norm1 = nn.LayerNorm(self.config.hidden_size)
        self.model.embeddings.requires_grad_(False)
        self.model.encoder.layer[: cfg.freeze_layer].requires_grad_(False)

    def _init_weights(self, module):
        """Initialise a single module in place.

        Linear/Embedding weights ~ N(0, config.initializer_range) with zero bias
        (and a zeroed padding row for embeddings); LayerNorm gets weight 1 and
        bias 0. Other module types are left untouched.
        """
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def feature(self, inputs):
        """Mean-pool the backbone's first output over non-padding tokens."""
        outputs = self.model(**inputs)
        last_hidden_states = outputs[0]
        feature = self.pool(last_hidden_states, inputs["attention_mask"])
        return feature

    def forward(self, inputs):
        """Return raw (pre-sigmoid) scores with ``cfg.target_size`` outputs per row."""
        feature = self.feature(inputs)
        feature = self.layer_norm1(feature)
        output = self.fc(feature)
        return output


class CustomModel_attention(nn.Module):
    """Transformer backbone -> learned attention pooling -> linear head.

    Args:
        cfg: configuration namespace; reads ``model``, ``model_config``, ``sl``,
            ``target_size`` and ``freeze_layer``.
        config_path: optional path to a ``torch.save``-d config object. When
            None, the config is fetched for ``cfg.model``.
        pretrained: only used when ``cfg.sl`` is true; load pretrained
            ``cfg.model`` weights (True) or build from the config alone (False).

    The embeddings and the first ``cfg.freeze_layer`` encoder layers are frozen.
    """

    def __init__(self, cfg, config_path=None, pretrained=False):
        super().__init__()
        self.cfg = cfg
        if config_path is None:
            self.config = AutoConfig.from_pretrained(cfg.model, output_hidden_states=True)
        else:
            self.config = torch.load(config_path)
        self.config.update(cfg.model_config)
        if cfg.sl:
            if pretrained:
                self.model = AutoModel.from_pretrained(cfg.model, config=self.config)
            else:
                self.model = AutoModel.from_config(self.config)
        else:
            # Apply the cfg.model_config overrides (dropout rates) to the MLM
            # checkpoint as well; its stored config alone ignored them.
            mlm_config = AutoConfig.from_pretrained(MLM_PATH)
            mlm_config.update(cfg.model_config)
            self.model = AutoModel.from_pretrained(MLM_PATH, config=mlm_config)
        # HF gradient checkpointing: recompute activations in backward to save memory.
        self.model.gradient_checkpointing_enable()
        self.fc = nn.Linear(self.config.hidden_size, self.cfg.target_size)
        self._init_weights(self.fc)
        # Token scorer; the trailing Softmax normalises scores over dim 1.
        self.attention = nn.Sequential(
            nn.Linear(self.config.hidden_size, 512),
            nn.Tanh(),
            nn.Linear(512, 1),
            nn.Softmax(dim=1),
        )
        # _init_weights only handles individual modules; calling it on the
        # Sequential itself was a no-op, so apply it to every child.
        self.attention.apply(self._init_weights)
        self.model.embeddings.requires_grad_(False)
        self.model.encoder.layer[: cfg.freeze_layer].requires_grad_(False)

    def _init_weights(self, module):
        """Initialise a single module in place.

        Linear/Embedding weights ~ N(0, config.initializer_range) with zero bias
        (and a zeroed padding row for embeddings); LayerNorm gets weight 1 and
        bias 0. Other module types are left untouched.
        """
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def feature(self, inputs):
        """Attention-pool the backbone's first output, ignoring padding tokens."""
        outputs = self.model(**inputs)
        last_hidden_states = outputs[0]
        # Score tokens with every layer except the final Softmax, so padding
        # positions (attention_mask == 0) can be excluded before normalising.
        scores = self.attention[:-1](last_hidden_states)
        padding = inputs["attention_mask"].unsqueeze(-1) == 0
        scores = scores.masked_fill(padding, torch.finfo(scores.dtype).min)
        weights = self.attention[-1](scores)
        feature = torch.sum(weights * last_hidden_states, dim=1)
        return feature

    def forward(self, inputs):
        """Return raw (pre-sigmoid) scores with ``cfg.target_size`` outputs per row."""
        feature = self.feature(inputs)
        output = self.fc(feature)
        return output


class CustomModel_lstm(nn.Module):
    """Transformer backbone -> 2-layer BiLSTM -> masked mean pooling -> LayerNorm -> head.

    Args:
        cfg: configuration namespace; reads ``model``, ``model_config``, ``sl``,
            ``target_size`` and ``freeze_layer``.
        config_path: optional path to a ``torch.save``-d config object. When
            None, the config is fetched for ``cfg.model``.
        pretrained: only used when ``cfg.sl`` is true; load pretrained
            ``cfg.model`` weights (True) or build from the config alone (False).

    The embeddings and the first ``cfg.freeze_layer`` encoder layers are frozen.
    """

    def __init__(self, cfg, config_path=None, pretrained=False):
        super().__init__()
        self.cfg = cfg
        if config_path is None:
            self.config = AutoConfig.from_pretrained(cfg.model, output_hidden_states=True)
        else:
            self.config = torch.load(config_path)
        self.config.update(cfg.model_config)
        if cfg.sl:
            if pretrained:
                self.model = AutoModel.from_pretrained(cfg.model, config=self.config)
            else:
                self.model = AutoModel.from_config(self.config)
        else:
            # Apply the cfg.model_config overrides (dropout rates) to the MLM
            # checkpoint as well; its stored config alone ignored them.
            mlm_config = AutoConfig.from_pretrained(MLM_PATH)
            mlm_config.update(cfg.model_config)
            self.model = AutoModel.from_pretrained(MLM_PATH, config=mlm_config)
        # HF gradient checkpointing: recompute activations in backward to save memory.
        self.model.gradient_checkpointing_enable()
        self.pool = MeanPooling()
        self.fc = nn.Linear(self.config.hidden_size, self.cfg.target_size)
        self._init_weights(self.fc)
        self.layer_norm1 = nn.LayerNorm(self.config.hidden_size)
        # Bidirectional with hidden_size // 2 per direction, so the output width
        # equals hidden_size and matches layer_norm1 / fc.
        # _init_weights has no nn.LSTM branch, so the LSTM keeps PyTorch's
        # default initialisation (the former _init_weights(self.lstm) was a no-op).
        self.lstm = nn.LSTM(
            self.config.hidden_size,
            (self.config.hidden_size) // 2,
            num_layers=2,
            dropout=self.config.hidden_dropout_prob,
            batch_first=True,
            bidirectional=True,
        )
        self.model.embeddings.requires_grad_(False)
        self.model.encoder.layer[: cfg.freeze_layer].requires_grad_(False)

    def _init_weights(self, module):
        """Initialise a single module in place.

        Linear/Embedding weights ~ N(0, config.initializer_range) with zero bias
        (and a zeroed padding row for embeddings); LayerNorm gets weight 1 and
        bias 0. Other module types are left untouched.
        """
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def feature(self, inputs):
        """Run the BiLSTM over the backbone output, then mean-pool non-padding tokens."""
        outputs = self.model(**inputs)
        last_hidden_states = outputs[0]
        # NOTE(review): the LSTM also consumes padding positions, so (assuming
        # right padding) the backward direction starts from padding for the
        # shorter essays in a batch; consider pack_padded_sequence.
        feature, _ = self.lstm(last_hidden_states)
        feature = self.pool(feature, inputs["attention_mask"])
        return feature

    def forward(self, inputs):
        """Return raw (pre-sigmoid) scores with ``cfg.target_size`` outputs per row."""
        feature = self.feature(inputs)
        feature = self.layer_norm1(feature)
        output = self.fc(feature)
        return output

# Helper functions

In [None]:
# ====================================================
# Helper functions
# ====================================================


class AverageMeter(object):
    """Keep the latest value plus a count-weighted running sum and mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` new observations."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


def asMinutes(s):
    """Format ``s`` seconds as ``'<m>m <s>s'``; the seconds part is truncated by %d."""
    minutes = math.floor(s / 60)
    return "%dm %ds" % (minutes, s - minutes * 60)


def timeSince(since, percent):
    """Return elapsed time since ``since`` and a linear estimate of time remaining.

    ``percent`` is the completed fraction of the work; the total duration is
    extrapolated as ``elapsed / percent``.
    """
    elapsed = time.time() - since
    remaining = elapsed / percent - elapsed
    return "%s (remain %s)" % (asMinutes(elapsed), asMinutes(remaining))


def train_fn(
    fold,
    train_loader,
    valid_loader,
    valid_labels,
    valid_loader2,
    valid_labels2,
    model,
    criterion,
    optimizer,
    epoch,
    scheduler,
    device,
    best_score,
):
    """Train ``model`` for one epoch, then validate and checkpoint on improvement.

    The loss is computed against the soft target (``labels2``). The raw score
    (``labels``) only feeds the train-set QWK. After the final batch the model
    is scored on both validation loaders. When the second-loader score
    (``score2``) beats ``best_score``, the weights and the first-loader
    predictions are saved to ``OUTPUT_DIR``.

    Returns:
        The best ``score2`` seen so far (unchanged if this epoch did not improve it).
    """
    model.train()
    # NOTE(review): a fresh GradScaler per call discards the loss scale
    # calibrated during the previous epoch.
    scaler = torch.cuda.amp.GradScaler(enabled=CFG.apex)
    losses = AverageMeter()
    start = time.time()
    preds = []
    train_labels = []
    last_step = len(train_loader) - 1
    for step, (inputs, labels, labels2) in enumerate(train_loader):
        with torch.cuda.amp.autocast(enabled=CFG.apex):
            inputs = collate(inputs)
            for k, v in inputs.items():
                inputs[k] = v.to(device)
            labels = labels.to(device)
            labels2 = labels2.to(device)
            batch_size = labels.size(0)
            y_preds = model(inputs)
            loss = criterion(y_preds, labels2)
        if CFG.gradient_accumulation_steps > 1:
            loss = loss / CFG.gradient_accumulation_steps
        losses.update(loss.item(), batch_size)
        preds.append(y_preds.sigmoid().detach().to("cpu").numpy())
        train_labels.append(labels.detach().to("cpu").numpy())
        scaler.scale(loss).backward()
        if (step + 1) % CFG.gradient_accumulation_steps == 0:
            scaler.unscale_(optimizer)  # clip the true (unscaled) gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            if CFG.batch_scheduler:
                scheduler.step()
        if step % (CFG.print_freq * CFG.gradient_accumulation_steps) == 0 or step == last_step:
            print(
                "Epoch: [{0}][{1}/{2}] "
                "Elapsed {remain:s} "
                "Loss: {loss.val:.4f}({loss.avg:.4f}) "
                "LR: {lr:.8f}  ".format(
                    epoch + 1,
                    step,
                    len(train_loader),
                    remain=timeSince(start, float(step + 1) / len(train_loader)),
                    loss=losses,
                    # get_last_lr() replaces the deprecated get_lr() outside step().
                    lr=scheduler.get_last_lr()[0],
                )
            )

    # End-of-epoch evaluation; nothing to evaluate for an empty loader.
    if not preds:
        return best_score
    predictions = np.concatenate(preds)
    train_labels = np.concatenate(train_labels)
    train_score = get_score(train_labels, predictions)
    avg_val_loss, predictions, predictions2 = valid_fn(
        valid_loader, valid_loader2, model, criterion, device
    )
    score = get_score(valid_labels, predictions)
    score2 = get_score(valid_labels2, predictions2)
    elapsed = time.time() - start
    LOGGER.info(
        f"Epoch_Step {epoch+1}_{step} - avg_train_loss: {losses.avg:.4f}  avg_val_loss: {avg_val_loss:.4f}  time: {elapsed:.0f}s"
    )
    LOGGER.info(
        f"Epoch {epoch+1} - Train Score: {train_score:.4f} Val Score: {score:.4f} Val Score2: {score2:.4f}"
    )

    if best_score < score2:
        best_score = score2
        LOGGER.info(
            f"Epoch_Step {epoch+1}_{step} - Save Best Score: {best_score:.4f} Model\n"
        )
        torch.save(
            {"model": model.state_dict(), "predictions": predictions},
            OUTPUT_DIR + f"{CFG.model.replace('/', '-')}_fold{fold}_best.pth",
        )

    return best_score


def valid_fn(valid_loader, valid_loader2, model, criterion, device):
    """Evaluate ``model`` on two loaders.

    Returns:
        ``(avg_loss, predictions, predictions2)``: the mean ``criterion`` loss on
        ``valid_loader`` against the soft target ``labels2``, and the sigmoid
        outputs for ``valid_loader`` and ``valid_loader2`` in loader order.

    Leaves the model in eval mode.
    """
    losses = AverageMeter()
    model.eval()
    preds = []
    start = time.time()
    for step, (inputs, labels, labels2) in enumerate(valid_loader):
        with torch.no_grad():
            inputs = collate(inputs)
            for k, v in inputs.items():
                inputs[k] = v.to(device)
            labels2 = labels2.to(device)
            y_preds = model(inputs)
            loss = criterion(y_preds, labels2)
        batch_size = labels2.size(0)
        # NOTE(review): scaled like the training loss logged by train_fn so the
        # two stay comparable; the unscaled value is loss * accumulation steps.
        if CFG.gradient_accumulation_steps > 1:
            loss = loss / CFG.gradient_accumulation_steps
        losses.update(loss.item(), batch_size)
        preds.append(y_preds.sigmoid().to("cpu").numpy())
        if step % CFG.print_freq == 0 or step == (len(valid_loader) - 1):
            print(
                "EVAL: [{0}/{1}] "
                "Elapsed {remain:s} "
                "Loss: {loss.val:.4f}({loss.avg:.4f}) ".format(
                    step,
                    len(valid_loader),
                    loss=losses,
                    remain=timeSince(start, float(step + 1) / len(valid_loader)),
                )
            )
    predictions = np.concatenate(preds)

    # Second loader: predictions only. Its labels are not needed here, so they
    # are no longer copied to the device.
    preds2 = []
    with torch.no_grad():
        for inputs, _labels, _labels2 in valid_loader2:
            inputs = collate(inputs)
            for k, v in inputs.items():
                inputs[k] = v.to(device)
            preds2.append(model(inputs).sigmoid().to("cpu").numpy())
    predictions2 = np.concatenate(preds2)

    return losses.avg, predictions, predictions2

In [None]:
# ====================================================
# train loop
# ====================================================


def train_loop(folds, fold):
    """Train one CV fold and return its validation rows with out-of-fold predictions.

    Args:
        folds: full training frame; uses the ``fold``, ``flag``, ``length``,
            ``essay_id``, ``full_text`` and target columns.
        fold: index of the fold held out for validation.

    Returns:
        The fold's validation rows (``flag != 2``, sorted by ``length`` then
        ``essay_id``) with a ``pred`` column from the best saved checkpoint.

    Raises:
        ValueError: if ``CFG.head`` or ``CFG.scheduler`` is not recognised.
        RuntimeError: if no checkpoint was saved during training.
    """
    LOGGER.info(f"========== fold: {fold} training ==========")

    # ====================================================
    # loader
    # ====================================================
    # CFG.flag selects which ``flag`` values are trained on and which form the
    # second validation set. Rows with flag == 2 never enter train or valid_folds.
    in_fold = folds["fold"] == fold
    flag_not_2 = folds["flag"] != 2
    if CFG.flag == 0:
        train_mask = ~in_fold & flag_not_2
        valid2_mask = in_fold & (folds["flag"] == 1)
    else:
        train_mask = ~in_fold & (folds["flag"] == 1)
        valid2_mask = in_fold & (folds["flag"] == 0)

    train_folds = folds[train_mask].reset_index(drop=True)
    # Validation loaders do not shuffle, so sorting by length groups similar
    # lengths into a batch and lets ``collate`` trim more padding (assuming
    # ``length`` tracks token count).
    valid_folds = (
        folds[in_fold & flag_not_2]
        .sort_values(["length", "essay_id"])
        .reset_index(drop=True)
    )
    valid_folds2 = (
        folds[valid2_mask].sort_values(["length", "essay_id"]).reset_index(drop=True)
    )
    valid_labels = valid_folds[CFG.target_cols2].values
    valid_labels2 = valid_folds2[CFG.target_cols2].values

    def make_loader(dataset, shuffle, drop_last):
        """DataLoader with the notebook's shared batch/worker settings."""
        return DataLoader(
            dataset,
            batch_size=CFG.batch_size,
            shuffle=shuffle,
            num_workers=CFG.num_workers,
            pin_memory=True,
            drop_last=drop_last,
        )

    train_loader = make_loader(TrainDataset(CFG, train_folds), shuffle=True, drop_last=True)
    valid_loader = make_loader(TrainDataset(CFG, valid_folds), shuffle=False, drop_last=False)
    valid_loader2 = make_loader(TrainDataset(CFG, valid_folds2), shuffle=False, drop_last=False)

    # ====================================================
    # model & optimizer
    # ====================================================
    heads = {
        "mean_pooling": CustomModel_mean_pooling,
        "attention": CustomModel_attention,
        "lstm": CustomModel_lstm,
    }
    if CFG.head not in heads:
        raise ValueError(f"Unknown CFG.head {CFG.head!r}; expected one of {sorted(heads)}")
    model_cls = heads[CFG.head]
    if CFG.sl:
        model = model_cls(CFG, config_path=None, pretrained=False)
        # NOTE(review): reads from a Colab-style "/content/" directory while new
        # checkpoints are written to OUTPUT_DIR — confirm this is intended.
        state = torch.load(
            "/content/" + f"{CFG.model.replace('/', '-')}_fold{fold}_best.pth",
            map_location=torch.device("cpu"),
        )
        model.load_state_dict(state["model"])
    else:
        model = model_cls(CFG, config_path=None, pretrained=True)
    torch.save(model.config, OUTPUT_DIR + "config.pth")
    model.to(device)

    def get_optimizer_params(model, encoder_lr, decoder_lr, weight_decay=0.0):
        """Backbone params with decay, backbone bias/LayerNorm without, then the head."""
        no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
        backbone = list(model.model.named_parameters())
        return [
            {
                "params": [p for n, p in backbone if not any(nd in n for nd in no_decay)],
                "lr": encoder_lr,
                "weight_decay": weight_decay,
            },
            {
                "params": [p for n, p in backbone if any(nd in n for nd in no_decay)],
                "lr": encoder_lr,
                "weight_decay": 0.0,
            },
            {
                # Everything outside the backbone (pooling/attention/LSTM/LN/fc).
                "params": [p for n, p in model.named_parameters() if "model" not in n],
                "lr": decoder_lr,
                "weight_decay": 0.0,
            },
        ]

    optimizer_parameters = get_optimizer_params(
        model,
        encoder_lr=CFG.encoder_lr,
        decoder_lr=CFG.decoder_lr,
        weight_decay=CFG.weight_decay,
    )
    optimizer = AdamW(
        optimizer_parameters, lr=CFG.encoder_lr, eps=CFG.eps, betas=CFG.betas
    )

    # ====================================================
    # scheduler
    # ====================================================
    def get_scheduler(cfg, optimizer, num_train_steps):
        """Build the warmup schedule named by ``cfg.scheduler``."""
        if cfg.scheduler == "linear":
            return get_linear_schedule_with_warmup(
                optimizer,
                num_warmup_steps=cfg.num_warmup_steps,
                num_training_steps=num_train_steps,
            )
        if cfg.scheduler == "cosine":
            return get_cosine_schedule_with_warmup(
                optimizer,
                num_warmup_steps=cfg.num_warmup_steps,
                num_training_steps=num_train_steps,
                num_cycles=cfg.num_cycles,
            )
        raise ValueError(
            f"Unknown scheduler {cfg.scheduler!r}; expected 'linear' or 'cosine'"
        )

    # NOTE(review): sized for all CFG.epochs (and ignoring drop_last /
    # gradient accumulation), but only n_train_epochs are run below, so the
    # schedule stops part-way — confirm this is intentional.
    num_train_steps = int(len(train_folds) / CFG.batch_size * CFG.epochs)
    scheduler = get_scheduler(CFG, optimizer, num_train_steps)

    # ====================================================
    # loop
    # ====================================================
    criterion = nn.BCEWithLogitsLoss(reduction="mean")

    best_score = -np.inf
    # Training is capped at the first 3 epochs.
    n_train_epochs = min(CFG.epochs, 3)
    for epoch in range(n_train_epochs):
        best_score = train_fn(
            fold,
            train_loader,
            valid_loader,
            valid_labels,
            valid_loader2,
            valid_labels2,
            model,
            criterion,
            optimizer,
            epoch,
            scheduler,
            device,
            best_score,
        )

    if best_score == -np.inf:
        # Without a save this run, torch.load would fail or read a stale file.
        raise RuntimeError(
            f"fold {fold}: no checkpoint was saved (score2 never exceeded -inf)"
        )
    predictions = torch.load(
        OUTPUT_DIR + f"{CFG.model.replace('/', '-')}_fold{fold}_best.pth",
        map_location=torch.device("cpu"),
    )["predictions"]
    valid_folds["pred"] = predictions

    # Drop references before emptying the CUDA cache so the memory can be reused.
    del model, optimizer, scheduler
    torch.cuda.empty_cache()
    gc.collect()

    return valid_folds

In [None]:
if __name__ == "__main__":

    def get_result(oof_df):
        """Log QWK over all rows (Score) and over the ``flag == 1`` rows (Score2)."""
        score = get_score(oof_df[CFG.target_cols2].values, oof_df["pred"].values)
        # NOTE(review): train_loop's score2 uses flag == 0 rows when CFG.flag != 0,
        # whereas this always scores flag == 1 rows — confirm which is intended.
        flagged = oof_df.loc[oof_df["flag"] == 1]
        score2 = get_score(flagged[CFG.target_cols2].values, flagged["pred"].values)
        LOGGER.info(f"Score: {score:<.4f} Score2: {score2:<.4f}")

    if CFG.train:
        # Collect per-fold frames and concatenate once instead of growing a frame.
        fold_frames = []
        for fold in range(CFG.n_fold):
            if fold not in CFG.trn_fold:
                continue
            _oof_df = train_loop(train, fold)
            fold_frames.append(_oof_df)
            LOGGER.info(f"========== fold: {fold} result ==========")
            get_result(_oof_df)
        oof_df = pd.concat(fold_frames).reset_index(drop=True)
        LOGGER.info("========== CV ==========")
        get_result(oof_df)
        oof_df.to_pickle(OUTPUT_DIR + "oof_df.pkl")