Inspired by https://www.kaggle.com/code/yasufuminakama/nbme-deberta-base-baseline-train and similar notebooks.

# Init

In [None]:
import os
from IPython.core.display import display
from pathlib import Path
import random
import re
import yaml
import numpy as np
from numpy import ndarray
import pandas as pd
import torch
from logging import Logger, getLogger, INFO, StreamHandler, FileHandler, Formatter
import wandb
from wandb.sdk.wandb_config import Config


def init_pandas() -> None:
    """Widen pandas' display limits so notebook tables are not truncated."""
    display_options = (
        ("display.max_rows", 500),
        ("display.max_columns", 500),
        ("display.width", 1000),
    )
    for option_name, option_value in display_options:
        pd.set_option(option_name, option_value)


def get_logger(filename: str) -> Logger:
    """Return this module's logger, writing plain messages to stderr and a file.

    ``getLogger(__name__)`` always returns the same logger object, so calling
    this function again (e.g. re-running the notebook cell) would otherwise
    stack another pair of handlers on it and emit every message several
    times. Handlers left over from earlier calls are therefore closed and
    removed before the new ones are attached.

    Args:
        filename: Path prefix of the log file; ``.log`` is appended.

    Returns:
        The configured logger (level INFO).
    """
    logger = getLogger(__name__)
    logger.setLevel(INFO)
    # Drop (and close) handlers from a previous call to avoid duplicate output
    # and leaked file descriptors.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    formatter = Formatter("%(message)s")
    stream_handler = StreamHandler()
    stream_handler.setFormatter(formatter)
    file_handler = FileHandler(filename=f"{filename}.log")
    file_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)
    logger.addHandler(file_handler)
    return logger


def seed_everything(seed: int = 42) -> None:
    """Seed Python, NumPy and torch RNGs and force deterministic cuDNN.

    ``PYTHONHASHSEED`` is also exported; note that it only affects Python
    interpreters started afterwards, not the hash seed of this process.

    Args:
        seed: Seed value shared by every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


def init_wandb(wandb_key: str, config_path: str = "./config.yml") -> Config:
    """Log in to W&B, load the YAML run parameters and start a W&B run.

    Args:
        wandb_key: W&B API key.
        config_path: YAML file with the run parameters; it must contain a
            ``project`` key. Defaults to ``./config.yml``.

    Returns:
        ``wandb.config`` of the started run (populated from the YAML file).
    """
    wandb.login(key=wandb_key)

    # The float pattern registered below also accepts exponent notation
    # without a decimal point (e.g. ``1e-5``), which PyYAML's default YAML 1.1
    # float pattern loads as a string. It is registered on a private subclass:
    # calling add_implicit_resolver on yaml.SafeLoader itself would alter every
    # yaml.safe_load in the process and append a duplicate resolver on each
    # call of this function.
    class _FloatFriendlyLoader(yaml.SafeLoader):
        pass

    _FloatFriendlyLoader.add_implicit_resolver(
        "tag:yaml.org,2002:float",
        re.compile(
            """^(?:
         [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
        |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
        |\\.[0-9_]+(?:[eE][-+][0-9]+)?
        |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
        |[-+]?\\.(?:inf|Inf|INF)
        |\\.(?:nan|NaN|NAN))$""",
            re.X,
        ),
        list("-+0123456789."),
    )
    with open(config_path, encoding="utf-8") as f:
        param = yaml.load(f, Loader=_FloatFriendlyLoader)
    wandb.init(project=param["project"], config=param)
    wandb.config.update(param)
    print(f"run name: {wandb.run.name}")
    return wandb.config


def mk_output_dir(path: str) -> None:
    """Create directory ``path`` (with missing parents) unless it already exists.

    Args:
        path: Directory to create. If anything already exists at ``path``,
            nothing is done.
    """
    if os.path.exists(path):
        return
    os.makedirs(path)

In [None]:
import warnings
from getpass import getpass

# Read the W&B API key interactively (getpass does not echo it), then start the
# run and load the YAML parameters into `config`.
wandb_key = getpass()
config = init_wandb(wandb_key=wandb_key)
mk_output_dir(path=config.output_dir)
# Log file is "<output_dir>train.log". This is plain string concatenation, so it
# assumes config.output_dir ends with a path separator — TODO confirm in config.yml.
logger = get_logger(filename=config.output_dir + "train")
seed_everything(seed=config.seed)
init_pandas()
# Suppresses every Python warning for the rest of the session.
warnings.filterwarnings("ignore")

# Helper functions for scoring

In [None]:
from sklearn.metrics import f1_score


def get_score(y_true: ndarray, y_pred: ndarray) -> float:
    """Return the span-level micro F1 of ``y_pred`` against ``y_true``.

    Despite the ``ndarray`` annotations, each argument is a per-sample list of
    ``[start, end]`` character spans (``get_result`` passes the outputs of
    ``create_labels_for_scoring`` and ``get_predictions``).
    """
    return span_micro_f1(y_true, y_pred)


def micro_f1(preds: list, truths: list) -> float:
    """Micro-averaged F1 over per-sample binary arrays.

    All samples are flattened into a single vector before scoring, so each
    element (character) weighs the same regardless of its sample.

    Args:
        preds (list of 0/1 array-likes): Predicted labels, one array per sample.
        truths (list of 0/1 array-likes): Ground-truth labels, aligned with ``preds``.

    Returns:
        float: F1 score of the positive class (``sklearn.metrics.f1_score``).
    """
    flat_preds = np.concatenate(preds)
    flat_truths = np.concatenate(truths)
    return f1_score(flat_truths, flat_preds)


def spans_to_binary(spans: list, length=None):
    """Rasterize half-open character spans into a 0/1 vector.

    Args:
        spans (list of [start, end] pairs): Character ranges ``[start, end)``.
        length (int, optional): Output length. Defaults to the largest number
            appearing anywhere in ``spans``.

    Returns:
        np.ndarray of float: 1.0 at every position covered by some span.
    """
    if length is None:
        length = np.max(spans)
    mask = np.zeros(length)
    for span_start, span_end in spans:
        mask[span_start:span_end] = 1
    return mask


def span_micro_f1(preds, truths):
    """Micro F1 between predicted and ground-truth character spans.

    Each sample's spans are rasterized to a common length (the largest span
    boundary in either list), then scored with ``micro_f1``. Samples with no
    spans on either side are skipped entirely.

    Args:
        preds (list of lists of [start, end]): Predicted spans per sample.
        truths (list of lists of [start, end]): Ground-truth spans per sample.

    Returns:
        float: Micro F1 score.
    """
    binarized_preds, binarized_truths = [], []
    for pred_spans, truth_spans in zip(preds, truths):
        if len(pred_spans) == 0 and len(truth_spans) == 0:
            continue
        pred_extent = np.max(pred_spans) if len(pred_spans) else 0
        truth_extent = np.max(truth_spans) if len(truth_spans) else 0
        common_length = max(pred_extent, truth_extent)
        binarized_preds.append(spans_to_binary(pred_spans, common_length))
        binarized_truths.append(spans_to_binary(truth_spans, common_length))

    return micro_f1(binarized_preds, binarized_truths)

In [None]:
import ast
from numpy import ndarray
from pandas import DataFrame
from transformers.tokenization_utils import PreTrainedTokenizer


def create_labels_for_scoring(df: DataFrame):
    """Convert the ``location`` column into per-row lists of ``[start, end]`` spans.

    Each row's ``location`` is a list of strings such as ``["0 1", "3 4;5 6"]``.
    Every ``"start end"`` fragment becomes one span, so that row yields
    ``[[0, 1], [3, 4], [5, 6]]``. A row with an empty list yields ``[]``.

    Rows are read positionally (not via ``df.loc[i]``), so the result is
    correct for any index, e.g. a filtered or concatenated OOF frame.

    Side effect (kept for backward compatibility): writes the column
    ``location_for_create_labels`` on ``df``, holding ``[";".join(location)]``
    for non-empty rows and ``[]`` otherwise.

    Args:
        df: Frame with a ``location`` column of lists of span strings.

    Returns:
        list: One list of ``[start, end]`` int pairs per row, in row order.
    """
    # example: ['0 1', '3 4'] -> ['0 1;3 4']
    joined_locations = [[";".join(lst)] if lst else [] for lst in df["location"].values]
    df["location_for_create_labels"] = joined_locations

    truths = []
    for location_list in joined_locations:
        truth = []
        if len(location_list) > 0:
            for fragment in location_list[0].split(";"):
                loc = fragment.split()
                truth.append([int(loc[0]), int(loc[1])])
        truths.append(truth)

    return truths


def get_char_probs(
    texts: list, predictions: ndarray, tokenizer: PreTrainedTokenizer
) -> list:
    """Spread token-level probabilities back onto characters.

    Every character inside a token's offset range receives that token's
    probability. Characters between the previous token's end offset and the
    current token's start offset receive the mean of the two tokens'
    probabilities. ``zip`` stops at the shorter of the offset mapping and the
    prediction row. Texts without a matching prediction row keep all-zero
    arrays.

    Args:
        texts: Raw texts.
        predictions: One row of token probabilities per text, aligned with the
            tokenizer output of that text including special tokens.
        tokenizer: Tokenizer that supports ``return_offsets_mapping=True``.

    Returns:
        list of np.ndarray: One float array of length ``len(text)`` per text.
    """
    char_probs = [np.zeros(len(text)) for text in texts]
    for probs, text, token_probs in zip(char_probs, texts, predictions):
        offsets = tokenizer(
            text, add_special_tokens=True, return_offsets_mapping=True
        )["offset_mapping"]
        last_prob = 0
        # Starts at -1, so the first "gap" slice [-1:start] is normally empty.
        last_end = -1
        for (tok_start, tok_end), prob in zip(offsets, token_probs):
            probs[tok_start:tok_end] = prob
            if tok_start != last_end:
                probs[last_end:tok_start] = (prob + last_prob) / 2
            last_prob, last_end = prob, tok_end

    return char_probs


def cluster_elements(xs: list) -> list:
    """Split a sequence of integers into runs of consecutive values.

    A new run starts wherever a value is not exactly one more than its
    predecessor: ``[1, 2, 3, 5, 6]`` -> ``[[1, 2, 3], [5, 6]]``. An empty
    input returns ``[[]]`` (one empty cluster).
    """
    clusters = [[]]
    if len(xs) == 0:
        return clusters

    prev_x = xs[0] - 1
    for x in xs:
        if x != prev_x + 1:
            clusters.append([])
        clusters[-1].append(x)
        prev_x = x
    return clusters


def get_results(char_probs: list, pn_histories: list, th: float = 0.5) -> list:
    """Turn per-character probabilities into ``"start end;start end"`` strings.

    Characters with probability above ``th`` are grouped into runs, and each
    run becomes one half-open span ``[start, end)``. Per run:

    * a leading space is left out of the span (except at text position 0);
    * a leading ``"\\r\\n"`` is dropped (only checked when the span does not
      start at position 0 and ends before the text's last character);
    * a trailing ``"\\n-"`` is dropped.

    Args:
        char_probs: One per-character probability array per text.
        pn_histories: The corresponding texts.
        th: Probability threshold.

    Returns:
        list of str: One label string per text; ``""`` when it has no spans.
    """
    label_strs = []
    for char_prob, pn_history in zip(char_probs, pn_histories):
        # Shift by one: each element is the exclusive end of a positive character.
        pos_char_indices = np.where(char_prob > th)[0] + 1
        if len(pos_char_indices) > 0 and pos_char_indices[0] == 1:
            # The text's first character is positive: make 0 the span start.
            pos_char_indices = np.hstack([[0], pos_char_indices])
        clusters = cluster_elements(xs=pos_char_indices)

        for i, cluster in enumerate(clusters):
            # Prepend the run's first character as the span start unless it is
            # a space (then the span starts one character later).
            if len(cluster) > 0:
                target_idx = cluster[0] - 1
                if target_idx > -1 and pn_history[target_idx] != " ":
                    cluster = np.hstack([[target_idx], cluster])
            # Drop a leading "\r\n".
            if len(cluster) > 0:
                start = cluster[0]
                if (
                    start > 0
                    and start + 2 < len(pn_history)
                    and pn_history[start : start + 2] == "\r\n"
                ):
                    cluster = cluster[2:]
            # Drop a trailing "\n-".
            if len(cluster) > 0:
                target_idx = cluster[-1] - 2
                if target_idx > 0 and pn_history[target_idx : target_idx + 2] == "\n-":
                    cluster = cluster[:-2]
            clusters[i] = cluster

        # Keep every non-empty cluster. Checking only clusters[0] would drop all
        # spans of a text whose first run was trimmed away by the rules above.
        pos_char_spans = [[c[0], c[-1]] for c in clusters if len(c) > 0]
        label_strs.append(";".join(f"{s} {e}" for s, e in pos_char_spans))

    return label_strs


def get_predictions(results: list) -> list:
    """Parse ``"start end;start end"`` strings into lists of ``[start, end]`` ints.

    Args:
        results: Label strings as produced by ``get_results``; ``""`` means
            no spans.

    Returns:
        list: One list of integer pairs per input string.
    """

    def parse(label_str: str) -> list:
        if label_str == "":
            return []
        pairs = []
        for fragment in label_str.split(";"):
            fields = fragment.split()
            pairs.append([int(fields[0]), int(fields[1])])
        return pairs

    return [parse(result) for result in results]


def get_result(
    df_oof: DataFrame, tokenizer: PreTrainedTokenizer, max_len: int
) -> tuple:
    """Score out-of-fold predictions and search for the best threshold.

    Token probabilities are read from the integer-named columns
    ``0 .. max_len - 1`` of ``df_oof`` and projected onto characters.
    Thresholds from 0.3 up to (but excluding) 0.7 in steps of 0.005 are each
    evaluated with ``get_score``.

    Args:
        df_oof: OOF frame with ``location``, ``pn_history`` and per-token
            probability columns. Note: ``create_labels_for_scoring`` adds a
            column to it.
        tokenizer: Tokenizer used to recover character offsets.
        max_len: Number of token-probability columns.

    Returns:
        tuple: ``(best_score, best_threshold)``.
    """
    labels = create_labels_for_scoring(df_oof)
    token_probs = df_oof[list(range(max_len))].to_numpy()
    char_probs = get_char_probs(df_oof["pn_history"].to_numpy(), token_probs, tokenizer)
    pn_histories = df_oof["pn_history"].to_list()

    best_score = -100
    for th in np.arange(0.3, 0.7, 0.005):
        th = np.round(th, 4)
        candidate_preds = get_predictions(get_results(char_probs, pn_histories, th=th))
        candidate_score = get_score(labels, candidate_preds)
        if candidate_score > best_score:
            best_th = th
            best_score = candidate_score

    return best_score, best_th

# Data Loading

In [None]:
from pandas import DataFrame
import pandas as pd


def preprocess_features(features: DataFrame) -> None:
    """Overwrite the ``feature_text`` of row label 27 in place.

    The row is addressed by index label, so ``features`` is expected to keep
    the index it was loaded with.
    """
    row_label, column = 27, "feature_text"
    features.loc[row_label, column] = "Last-Pap-smear-1-year-ago"

In [None]:
INPUT_DIR = Path("../../input/")

# train.csv stores the annotation/location lists as their string repr; parse them.
df_train = pd.read_csv(INPUT_DIR / "train.csv")
df_train["annotation"] = df_train["annotation"].map(ast.literal_eval)
df_train["location"] = df_train["location"].map(ast.literal_eval)

features = pd.read_csv(INPUT_DIR / "features.csv")
preprocess_features(features)

patient_notes = pd.read_csv(INPUT_DIR / "patient_notes.csv")

In [None]:
# Attach each row's feature text (by feature_num/case_num) and its patient-note
# text (by pn_num/case_num); left joins keep every training row.
df_train = df_train.merge(features, on=['feature_num', 'case_num'], how='left')
df_train = df_train.merge(patient_notes, on=['pn_num', 'case_num'], how='left')
display(df_train.head())

In [None]:
from pandas import DataFrame

def correct_annotation(df_train:DataFrame) -> None:
    """Overwrite known-bad ``annotation`` / ``location`` entries of ``df_train`` in place.

    Rows are addressed by index label (``.loc``), so this presumably expects
    ``df_train`` to still carry the default 0..n-1 index in train.csv order
    (the left merges above presumably preserve it) — TODO confirm before reuse.

    Each new value is a nested list built with ``ast.literal_eval`` and
    assigned to a single cell via ``.loc``; the stored cell content therefore
    depends on how pandas unpacks the outer list — NOTE(review): verify the
    resulting cell values with the pandas version in use.

    NOTE(review): this function is not called anywhere in the code visible in
    this notebook chunk, so these corrections are currently not applied —
    confirm whether that is intended.
    """
    df_train.loc[338, 'annotation'] = ast.literal_eval('[["father heart attack"]]')
    df_train.loc[338, 'location'] = ast.literal_eval('[["764 783"]]')

    df_train.loc[621, 'annotation'] = ast.literal_eval('[["for the last 2-3 months"]]')
    df_train.loc[621, 'location'] = ast.literal_eval('[["77 100"]]')

    df_train.loc[655, 'annotation'] = ast.literal_eval('[["no heat intolerance"], ["no cold intolerance"]]')
    df_train.loc[655, 'location'] = ast.literal_eval('[["285 292;301 312"], ["285 287;296 312"]]')

    df_train.loc[1262, 'annotation'] = ast.literal_eval('[["mother thyroid problem"]]')
    df_train.loc[1262, 'location'] = ast.literal_eval('[["551 557;565 580"]]')

    df_train.loc[1265, 'annotation'] = ast.literal_eval('[[\'felt like he was going to "pass out"\']]')
    df_train.loc[1265, 'location'] = ast.literal_eval('[["131 135;181 212"]]')

    df_train.loc[1396, 'annotation'] = ast.literal_eval('[["stool , with no blood"]]')
    df_train.loc[1396, 'location'] = ast.literal_eval('[["259 280"]]')

    df_train.loc[1591, 'annotation'] = ast.literal_eval('[["diarrhoe non blooody"]]')
    df_train.loc[1591, 'location'] = ast.literal_eval('[["176 184;201 212"]]')

    df_train.loc[1615, 'annotation'] = ast.literal_eval('[["diarrhea for last 2-3 days"]]')
    df_train.loc[1615, 'location'] = ast.literal_eval('[["249 257;271 288"]]')

    df_train.loc[1664, 'annotation'] = ast.literal_eval('[["no vaginal discharge"]]')
    df_train.loc[1664, 'location'] = ast.literal_eval('[["822 824;907 924"]]')

    df_train.loc[1714, 'annotation'] = ast.literal_eval('[["started about 8-10 hours ago"]]')
    df_train.loc[1714, 'location'] = ast.literal_eval('[["101 129"]]')

    df_train.loc[1929, 'annotation'] = ast.literal_eval('[["no blood in the stool"]]')
    df_train.loc[1929, 'location'] = ast.literal_eval('[["531 539;549 561"]]')

    df_train.loc[2134, 'annotation'] = ast.literal_eval('[["last sexually active 9 months ago"]]')
    df_train.loc[2134, 'location'] = ast.literal_eval('[["540 560;581 593"]]')

    df_train.loc[2191, 'annotation'] = ast.literal_eval('[["right lower quadrant pain"]]')
    df_train.loc[2191, 'location'] = ast.literal_eval('[["32 57"]]')

    df_train.loc[2553, 'annotation'] = ast.literal_eval('[["diarrhoea no blood"]]')
    df_train.loc[2553, 'location'] = ast.literal_eval('[["308 317;376 384"]]')

    df_train.loc[3124, 'annotation'] = ast.literal_eval('[["sweating"]]')
    df_train.loc[3124, 'location'] = ast.literal_eval('[["549 557"]]')

    df_train.loc[3858, 'annotation'] = ast.literal_eval('[["previously as regular"], ["previously eveyr 28-29 days"], ["previously lasting 5 days"], ["previously regular flow"]]')
    df_train.loc[3858, 'location'] = ast.literal_eval('[["102 123"], ["102 112;125 141"], ["102 112;143 157"], ["102 112;159 171"]]')

    df_train.loc[4373, 'annotation'] = ast.literal_eval('[["for 2 months"]]')
    df_train.loc[4373, 'location'] = ast.literal_eval('[["33 45"]]')

    df_train.loc[4763, 'annotation'] = ast.literal_eval('[["35 year old"]]')
    df_train.loc[4763, 'location'] = ast.literal_eval('[["5 16"]]')

    df_train.loc[4782, 'annotation'] = ast.literal_eval('[["darker brown stools"]]')
    df_train.loc[4782, 'location'] = ast.literal_eval('[["175 194"]]')

    df_train.loc[4908, 'annotation'] = ast.literal_eval('[["uncle with peptic ulcer"]]')
    df_train.loc[4908, 'location'] = ast.literal_eval('[["700 723"]]')

    df_train.loc[6016, 'annotation'] = ast.literal_eval('[["difficulty falling asleep"]]')
    df_train.loc[6016, 'location'] = ast.literal_eval('[["225 250"]]')

    df_train.loc[6192, 'annotation'] = ast.literal_eval('[["helps to take care of aging mother and in-laws"]]')
    df_train.loc[6192, 'location'] = ast.literal_eval('[["197 218;236 260"]]')

    df_train.loc[6380, 'annotation'] = ast.literal_eval('[["No hair changes"], ["No skin changes"], ["No GI changes"], ["No palpitations"], ["No excessive sweating"]]')
    df_train.loc[6380, 'location'] = ast.literal_eval('[["480 482;507 519"], ["480 482;499 503;512 519"], ["480 482;521 531"], ["480 482;533 545"], ["480 482;564 582"]]')

    df_train.loc[6562, 'annotation'] = ast.literal_eval('[["stressed due to taking care of her mother"], ["stressed due to taking care of husbands parents"]]')
    df_train.loc[6562, 'location'] = ast.literal_eval('[["290 320;327 337"], ["290 320;342 358"]]')

    df_train.loc[6862, 'annotation'] = ast.literal_eval('[["stressor taking care of many sick family members"]]')
    df_train.loc[6862, 'location'] = ast.literal_eval('[["288 296;324 363"]]')

    df_train.loc[7022, 'annotation'] = ast.literal_eval('[["heart started racing and felt numbness for the 1st time in her finger tips"]]')
    df_train.loc[7022, 'location'] = ast.literal_eval('[["108 182"]]')

    df_train.loc[7422, 'annotation'] = ast.literal_eval('[["first started 5 yrs"]]')
    df_train.loc[7422, 'location'] = ast.literal_eval('[["102 121"]]')

    df_train.loc[8876, 'annotation'] = ast.literal_eval('[["No shortness of breath"]]')
    df_train.loc[8876, 'location'] = ast.literal_eval('[["481 483;533 552"]]')

    df_train.loc[9027, 'annotation'] = ast.literal_eval('[["recent URI"], ["nasal stuffines, rhinorrhea, for 3-4 days"]]')
    df_train.loc[9027, 'location'] = ast.literal_eval('[["92 102"], ["123 164"]]')

    df_train.loc[9938, 'annotation'] = ast.literal_eval('[["irregularity with her cycles"], ["heavier bleeding"], ["changes her pad every couple hours"]]')
    df_train.loc[9938, 'location'] = ast.literal_eval('[["89 117"], ["122 138"], ["368 402"]]')

    df_train.loc[9973, 'annotation'] = ast.literal_eval('[["gaining 10-15 lbs"]]')
    df_train.loc[9973, 'location'] = ast.literal_eval('[["344 361"]]')

    df_train.loc[10513, 'annotation'] = ast.literal_eval('[["weight gain"], ["gain of 10-16lbs"]]')
    df_train.loc[10513, 'location'] = ast.literal_eval('[["600 611"], ["607 623"]]')

    df_train.loc[11551, 'annotation'] = ast.literal_eval('[["seeing her son knows are not real"]]')
    df_train.loc[11551, 'location'] = ast.literal_eval('[["386 400;443 461"]]')

    df_train.loc[11677, 'annotation'] = ast.literal_eval('[["saw him once in the kitchen after he died"]]')
    df_train.loc[11677, 'location'] = ast.literal_eval('[["160 201"]]')

    df_train.loc[12124, 'annotation'] = ast.literal_eval('[["tried Ambien but it didnt work"]]')
    df_train.loc[12124, 'location'] = ast.literal_eval('[["325 337;349 366"]]')

    df_train.loc[12279, 'annotation'] = ast.literal_eval('[["heard what she described as a party later than evening these things did not actually happen"]]')
    df_train.loc[12279, 'location'] = ast.literal_eval('[["405 459;488 524"]]')

    df_train.loc[12289, 'annotation'] = ast.literal_eval('[["experienced seeing her son at the kitchen table these things did not actually happen"]]')
    df_train.loc[12289, 'location'] = ast.literal_eval('[["353 400;488 524"]]')

    df_train.loc[13238, 'annotation'] = ast.literal_eval('[["SCRACHY THROAT"], ["RUNNY NOSE"]]')
    df_train.loc[13238, 'location'] = ast.literal_eval('[["293 307"], ["321 331"]]')

    df_train.loc[13297, 'annotation'] = ast.literal_eval('[["without improvement when taking tylenol"], ["without improvement when taking ibuprofen"]]')
    df_train.loc[13297, 'location'] = ast.literal_eval('[["182 221"], ["182 213;225 234"]]')

    df_train.loc[13299, 'annotation'] = ast.literal_eval('[["yesterday"], ["yesterday"]]')
    df_train.loc[13299, 'location'] = ast.literal_eval('[["79 88"], ["409 418"]]')

    df_train.loc[13845, 'annotation'] = ast.literal_eval('[["headache global"], ["headache throughout her head"]]')
    df_train.loc[13845, 'location'] = ast.literal_eval('[["86 94;230 236"], ["86 94;237 256"]]')

    df_train.loc[14083, 'annotation'] = ast.literal_eval('[["headache generalized in her head"]]')
    df_train.loc[14083, 'location'] = ast.literal_eval('[["56 64;156 179"]]')

In [None]:
# Number of annotated spans per row (0 for an empty annotation list).
df_train['annotation_length'] = df_train['annotation'].map(len)

# Pseudo-Labeled Data

In [None]:
pl_train = pd.read_csv('./pl_train_with_folds.csv')
# Missing locations become "[]" before the stringified lists are parsed.
pl_train['location'] = pl_train['location'].fillna('[]').map(ast.literal_eval)

# CV split

In [None]:
from sklearn.model_selection import GroupKFold

# Group folds by patient note so the same note never appears in both the train
# and validation split of a fold.
kf = GroupKFold(n_splits=config.n_folds)
groups = df_train['pn_num'].to_numpy()
df_train.loc[:, 'fold'] = -1
# GroupKFold ignores the y argument; it is passed only to satisfy the signature.
for n, (train_index, val_index) in enumerate(kf.split(df_train, df_train['location'], groups)):
    # val_index is positional but used as a .loc label, which presumes df_train
    # has a 0..n-1 RangeIndex here — TODO confirm if the loading code changes.
    df_train.loc[val_index, 'fold'] = n
display(df_train.groupby('fold').size())

In [None]:
# Debug mode: keep a random 500-row subset. Folds were assigned above, so each
# sampled row keeps its fold; the counts before and after are shown.
if config.debug:
    display(df_train.groupby('fold').size())
    df_train = df_train.sample(n=500, random_state=0).reset_index(drop=True)
    display(df_train.groupby('fold').size())

# Tokenizer

In [None]:
from transformers import AutoTokenizer

# IPython magic: sets the TOKENIZERS_PARALLELISM environment variable.
%env TOKENIZERS_PARALLELISM=true
# trim_offsets=False asks the tokenizer not to trim whitespace from offset
# mappings (only meaningful for tokenizers that support the option) — the
# offset-based labelling in this notebook presumably relies on it; confirm.
tokenizer = AutoTokenizer.from_pretrained(
    config.tokenizer,
    trim_offsets=False
)
# Saved next to the run outputs; concatenation assumes output_dir ends with "/".
tokenizer.save_pretrained(config.output_dir+'tokenizer/')

In [None]:
# Longest patient note, in tokens without special tokens.
pn_history_lengths = []
for text in patient_notes["pn_history"].fillna("").to_list():
    length = len(tokenizer(text, add_special_tokens=False)["input_ids"])
    pn_history_lengths.append(length)
pn_history_max_len = max(pn_history_lengths)
logger.info(f"pn_history max(lengths): {pn_history_max_len}")

# Longest feature text, in tokens without special tokens.
features_lengths = []
for text in features["feature_text"].fillna("").to_list():
    length = len(tokenizer(text, add_special_tokens=False)["input_ids"])
    features_lengths.append(length)
feature_text_max_len = max(features_lengths)
logger.info(f"feature_text max(lengths): {feature_text_max_len}")

# +3 special tokens: TrainDataset keeps both special tokens around the note
# (+2) and drops the first special token of the feature text (+1 remains).
config.max_len = pn_history_max_len + feature_text_max_len + 3
logger.info(f"max_len: {config.max_len}")

# Dataset

In [None]:
from pandas import DataFrame
import torch
from torch import Tensor
from torch.utils.data import Dataset
from transformers.tokenization_utils import PreTrainedTokenizer


class TrainDataset(Dataset):
    """Token-classification dataset pairing a patient note with one feature text.

    Each item is ``(inputs, label)``:

    * ``inputs``: dict of ``input_ids`` / ``attention_mask`` long tensors in a
      fixed-position layout (see ``prepare_input_with_fixed_position``).
    * ``label``: float tensor with one entry per position. It is 1 for tokens
      inside an annotated span, 0 for other patient-note tokens, and -1 for
      positions to ignore (special tokens, padding, everything after the note).

    Args:
        tokenizer: Hugging Face tokenizer; offset mappings are required by
            ``create_label``.
        max_len: Length of the returned tensors.
        feature_text_max_len: Max feature-text length in tokens, presumably
            measured without special tokens (the +2 below leaves room for them).
        pn_history_max_len: Max patient-note length in tokens, same convention.
        df: Frame with ``feature_text``, ``pn_history``, ``annotation_length``
            and ``location`` columns.
    """

    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        max_len: int,
        feature_text_max_len: int,
        pn_history_max_len: int,
        df: DataFrame,
    ) -> None:

        self.tokenizer = tokenizer
        self.max_len = max_len
        self.feature_text_max_len = feature_text_max_len
        self.pn_history_max_len = pn_history_max_len
        self.feature_texts = df["feature_text"].to_numpy()
        self.pn_historys = df["pn_history"].to_numpy()
        self.annotation_lengths = df["annotation_length"].to_numpy()
        self.locations = df["location"].to_numpy()

    def prepare_input_with_fixed_position(
        self, pn_history: str, feature_text: str
    ) -> dict:
        """Encode note and feature text so the feature always starts at the same position.

        The note (with its special tokens) is padded to ``pn_history_max_len + 2``
        tokens, so its token positions match the separate encoding used by
        ``create_label``. The feature text is padded to
        ``feature_text_max_len + 2`` and its first token (presumably the
        leading [CLS]) is dropped before concatenation. The result is cut to
        ``max_len``. With ``max_len = pn_history_max_len + feature_text_max_len + 3``
        (as computed elsewhere in this notebook) the concatenation is exactly
        ``max_len`` long. No ``truncation`` is requested, so texts longer than
        those maxima are not shortened until the final ``[: self.max_len]`` cut.

        Returns:
            dict: ``input_ids`` and ``attention_mask`` as long tensors.
        """

        pn_history_token = self.tokenizer(
            pn_history,
            add_special_tokens=True,
            max_length=self.pn_history_max_len + 2,
            padding="max_length",
            return_offsets_mapping=False,
        )

        feature_text_token = self.tokenizer(
            feature_text,
            add_special_tokens=True,
            max_length=self.feature_text_max_len + 2,
            padding="max_length",
            return_offsets_mapping=False,
        )
        # Drop the feature text's first token so only one leading special
        # token (the note's) starts the combined sequence.
        for k, v in feature_text_token.items():
            feature_text_token[k] = v[1:]

        token = {
            "input_ids": pn_history_token["input_ids"]
            + feature_text_token["input_ids"],
            "attention_mask": pn_history_token["attention_mask"]
            + feature_text_token["attention_mask"],
            # token_type_ids are intentionally not returned (disabled code kept below).
            #             'token_type_ids': pn_history_token['token_type_ids']+list(torch.ones_like(
            #                 torch.tensor(feature_text_token['token_type_ids'], dtype=torch.long)
            #                 ))
        }
        for k, v in token.items():
            token[k] = torch.tensor(v[: self.max_len], dtype=torch.long)
        return token

    def create_label(
        self, text: str, annotation_length: int, location_list: list
    ) -> Tensor:
        """Build per-token targets for ``text`` from its character spans.

        Args:
            text: Patient note; tokenized alone, padded to ``max_len``.
            annotation_length: Number of annotations; 0 means no positive tokens.
            location_list: Strings of ``"start end"`` fragments joined by ``;``.

        Returns:
            Tensor: float labels of length ``max_len``. Tokens whose sequence id
            is not 0 (special tokens and padding, whose id is None) are -1.
            Tokens overlapping a span are 1, all others 0.
        """

        encoded = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            padding="max_length",
            return_offsets_mapping=True,
        )
        offset_mapping = encoded["offset_mapping"]
        # sequence_ids() is None for special/padding tokens; None != 0 -> ignored.
        ignore_idxes = np.where(np.array(encoded.sequence_ids()) != 0)[0]
        label = np.zeros(len(offset_mapping))
        label[ignore_idxes] = -1
        if annotation_length != 0:
            for location in location_list:
                for loc in [s.split() for s in location.split(";")]:
                    start_idx = -1
                    end_idx = -1
                    start, end = int(loc[0]), int(loc[1])
                    for idx in range(len(offset_mapping)):
                        # First token starting after the span start -> the
                        # previous token is the one containing the start.
                        if (start_idx == -1) & (start < offset_mapping[idx][0]):
                            start_idx = idx - 1
                        # First token whose end reaches the span end; +1 makes
                        # the slice below include it.
                        if (end_idx == -1) & (end <= offset_mapping[idx][1]):
                            end_idx = idx + 1
                    # NOTE(review): when no token starts after the span start
                    # (span begins in the last real token), start_idx becomes
                    # end_idx and the slice below is empty, so that span gets
                    # no positive label — confirm this is intended.
                    if start_idx == -1:
                        start_idx = end_idx
                    if (start_idx != -1) & (end_idx != -1):
                        label[start_idx:end_idx] = 1

        return torch.tensor(label[: self.max_len], dtype=torch.float)

    def __len__(self) -> int:
        return len(self.feature_texts)

    def __getitem__(self, item: int) -> tuple:
        """Return ``(inputs, label)`` for row ``item``."""

        inputs = self.prepare_input_with_fixed_position(
            self.pn_historys[item], self.feature_texts[item]
        )
        label = self.create_label(
            self.pn_historys[item], self.annotation_lengths[item], self.locations[item]
        )

        return inputs, label

# Model

In [None]:
from torch import Tensor
from torch import nn
from torch.nn import Module
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from transformers import AutoModel, AutoConfig


class CustomModel(Module):
    """Transformer backbone with a linear head producing one logit per token.

    Args:
        model_name: Hugging Face model id or path of the backbone.
        config_path: Optional path to a model config saved with ``torch.save``.
            When given, it is used instead of ``AutoConfig.from_pretrained``.
        pretrained: If True, load pretrained backbone weights from
            ``model_name``. Otherwise build the backbone architecture from the
            config with freshly initialized weights.
    """

    def __init__(
        self, model_name: str, config_path: str = None, pretrained: bool = False
    ) -> None:
        super().__init__()
        if config_path is None:
            self.config = AutoConfig.from_pretrained(
                model_name, output_hidden_states=True
            )
        else:
            # torch.load unpickles arbitrary objects: only load trusted files.
            self.config = torch.load(config_path)
        if pretrained:
            # Load from the model_name argument, not a notebook-global config.
            self.model = AutoModel.from_pretrained(model_name, config=self.config)
        else:
            # AutoModel cannot be instantiated directly; from_config builds the
            # architecture described by the config without loading weights.
            self.model = AutoModel.from_config(self.config)
        self.initializer_range = 0.1
        self.fc = nn.Linear(self.config.hidden_size, 1)
        self._init_weights(self.fc)

    def _init_weights(self, module: Module) -> None:
        """Initialize Linear/Embedding weights ~ N(0, initializer_range); reset LayerNorm."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def feature(self, inputs: dict) -> Tensor:
        """Run the backbone on tokenizer outputs and return its first output (last hidden state)."""
        outputs = self.model(**inputs)
        last_hidden_states = outputs[0]
        return last_hidden_states

    def forward(self, inputs: dict) -> Tensor:
        """Return per-token logits with a trailing dimension of size 1."""
        feature = self.feature(inputs)
        output = self.fc(feature)
        return output


class AWP:
    """Adversarial Weight Perturbation (AWP) helper.

    ``attack_backward`` temporarily perturbs every trainable parameter whose
    name contains ``adv_param`` in the direction of its current gradient,
    back-propagates the loss at the perturbed weights, and then restores the
    original weights. The optimizer's gradients are zeroed right before the
    adversarial backward, so afterwards ``.grad`` holds the adversarial
    gradients only.

    Args:
        model: Module called as ``model(inputs)``.
        criterion: Loss returning an element-wise (unreduced) loss, so
            positions labelled -1 can be masked out.
        optimizer: Optimizer whose gradients are zeroed before the adversarial backward.
        adv_param: Substring selecting which parameters to perturb.
        adv_lr: Step size relative to the parameter norm; 0 disables AWP.
        adv_eps: Perturbation bound relative to ``|param|`` (element-wise).
        start_epoch: First epoch in which the attack is applied.
        adv_step: Number of perturb-and-backward steps per call.
        scaler: Optional AMP grad scaler. When None, the adversarial loss is
            back-propagated without scaling.
    """

    def __init__(
        self,
        model: Module,
        criterion: _Loss,
        optimizer: Optimizer,
        adv_param: str = "weight",
        adv_lr: float = 1,
        adv_eps: float = 0.2,
        start_epoch: int = 0,
        adv_step: int = 1,
        scaler=None,
    ) -> None:
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.adv_param = adv_param
        self.adv_lr = adv_lr
        self.adv_eps = adv_eps
        self.start_epoch = start_epoch
        self.adv_step = adv_step
        # name -> copy of the original parameter data (filled by _save).
        self.backup = {}
        # name -> (lower, upper) clamp bounds for the perturbed parameter.
        self.backup_eps = {}
        self.scaler = scaler

    def attack_backward(self, inputs: Tensor, labels: Tensor, epoch: int) -> None:
        """Back-propagate the loss at adversarially perturbed weights.

        Does nothing when ``adv_lr`` is 0 or ``epoch < start_epoch``.

        Args:
            inputs: Passed unchanged to ``model(inputs)``.
            labels: Targets; entries equal to -1 are excluded from the loss.
            epoch: Current epoch, compared with ``start_epoch``.
        """
        if (self.adv_lr == 0) or (epoch < self.start_epoch):
            return None
        self._save()
        for _ in range(self.adv_step):
            self._attack_step()
            with torch.cuda.amp.autocast():
                y_preds = self.model(inputs)
                adv_loss = self.criterion(y_preds.view(-1, 1), labels.view(-1, 1))
                # Mean over the non-ignored (label != -1) positions only.
                adv_loss = torch.masked_select(
                    adv_loss, labels.view(-1, 1) != -1
                ).mean()
            self.optimizer.zero_grad()
            if self.scaler is None:
                adv_loss.backward()
            else:
                self.scaler.scale(adv_loss).backward()
        self._restore()

    def _attack_step(self) -> None:
        """Step each saved parameter by ``adv_lr * grad/||grad|| * ||param||`` and clamp it.

        The clamp bounds are those recorded by ``_save``. Parameters with a
        zero or NaN gradient norm are left untouched.
        """
        e = 1e-6
        for name, param in self.model.named_parameters():
            if (
                param.requires_grad
                and param.grad is not None
                and self.adv_param in name
            ):
                norm1 = torch.norm(param.grad)
                norm2 = torch.norm(param.data.detach())
                if norm1 != 0 and not torch.isnan(norm1):
                    r_at = self.adv_lr * param.grad / (norm1 + e) * (norm2 + e)
                    param.data.add_(r_at)
                    param.data = torch.min(
                        torch.max(param.data, self.backup_eps[name][0]),
                        self.backup_eps[name][1],
                    )

    def _save(self) -> None:
        """Back up selected parameters and record bounds ``param ± adv_eps * |param|``."""
        for name, param in self.model.named_parameters():
            if (
                param.requires_grad
                and param.grad is not None
                and self.adv_param in name
            ):
                if name not in self.backup:
                    self.backup[name] = param.data.clone()
                    grad_eps = self.adv_eps * param.abs().detach()
                    self.backup_eps[name] = (
                        self.backup[name] - grad_eps,
                        self.backup[name] + grad_eps,
                    )

    def _restore(self) -> None:
        """Put the backed-up parameter data back and clear the backups."""
        for name, param in self.model.named_parameters():
            if name in self.backup:
                param.data = self.backup[name]
        self.backup = {}
        self.backup_eps = {}

# Logging

In [None]:
import time
from math import floor
from torch import inference_mode


class AverageMeter:
    """Keep the most recent value and a count-weighted running mean."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        """Forget every recorded value."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val: float, n=1) -> None:
        """Record ``val`` as the mean of ``n`` new observations."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


def asMinutes(s) -> str:
    """Format a duration in seconds as ``"<m>m <s>s"`` (seconds truncated by %d)."""
    minutes = floor(s / 60)
    leftover_seconds = s - minutes * 60
    return "%dm %ds" % (minutes, leftover_seconds)


def timeSince(since, percent) -> str:
    """Return elapsed time since ``since`` plus a linear estimate of time remaining.

    Args:
        since: Start timestamp taken with ``time.time()``.
        percent: Fraction of the work completed so far (must be non-zero).
    """
    elapsed = time.time() - since
    remaining = elapsed / (percent) - elapsed
    return "%s (remain %s)" % (asMinutes(elapsed), asMinutes(remaining))

# Trainer

In [None]:
import time
from logging import Logger
import joblib
from torch.utils.data import DataLoader
from torch import nn
from torch.nn import Module
from torch.optim import Optimizer, AdamW
from torch.optim.lr_scheduler import _LRScheduler
from torch import cuda
from transformers import (
    get_linear_schedule_with_warmup,
    get_cosine_schedule_with_warmup,
)
from wandb.sdk.wandb_config import Config


def get_optimizer_params(
    model: Module, encoder_lr: float, decoder_lr: float, weight_decay: float = 0.0
) -> list:
    """Build optimizer parameter groups with separate backbone and head settings.

    Groups, in order:

    1. backbone (``model.model``) parameters with none of the no-decay markers
       in their name: ``encoder_lr`` and ``weight_decay``;
    2. backbone biases / LayerNorm parameters: ``encoder_lr``, no weight decay;
    3. every parameter of ``model`` whose name lacks ``"model"`` (the head):
       ``decoder_lr``, no weight decay.
    """
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    decayed, not_decayed = [], []
    for param_name, param in model.model.named_parameters():
        if any(marker in param_name for marker in no_decay):
            not_decayed.append(param)
        else:
            decayed.append(param)
    head = [param for param_name, param in model.named_parameters() if "model" not in param_name]

    return [
        {"params": decayed, "lr": encoder_lr, "weight_decay": weight_decay},
        {"params": not_decayed, "lr": encoder_lr, "weight_decay": 0.0},
        {"params": head, "lr": decoder_lr, "weight_decay": 0.0},
    ]


def get_scheduler(
    scheduler: str,
    optimizer: Optimizer,
    num_warmup_steps: int,
    num_train_steps: int,
    num_cycles: int,
) -> _LRScheduler:
    """Create a warm-up learning-rate schedule by name.

    Args:
        scheduler: ``"linear"`` or ``"cosine"``.
        optimizer: Optimizer to schedule.
        num_warmup_steps: Number of warm-up steps.
        num_train_steps: Total number of training steps.
        num_cycles: Passed to the cosine schedule only.

    Raises:
        ValueError: For any other scheduler name.
    """
    if scheduler == "linear":
        return get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_train_steps,
        )
    if scheduler == "cosine":
        return get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_train_steps,
            num_cycles=num_cycles,
        )
    raise ValueError("Invalid Scheduler Name.")


class Trainer:
    """Cross-validated training/evaluation driver.

    ``run`` trains one model per fold and validates it both every
    ``config.n_eval_steps`` steps and at the end of each epoch. It keeps the
    best-scoring checkpoint per fold and assembles out-of-fold predictions.
    Metrics are sent to the logger and to wandb.
    """

    def __init__(
        self, config: Config, tokenizer: PreTrainedTokenizer, logger: Logger
    ) -> None:
        """Store collaborators and select the device.

        Args:
            config: Run configuration; hyper-parameters are read as attributes.
            tokenizer: Tokenizer used for the datasets and for mapping token
                probabilities back to characters during evaluation.
            logger: Destination for human-readable progress messages.
        """
        self.config = config
        self.tokenizer = tokenizer
        # Token-wise losses are kept (reduction="none") so positions labelled
        # -1 can be masked out before averaging.
        self.criterion = nn.BCEWithLogitsLoss(reduction="none")
        self.logger = logger
        self.device = torch.device("cuda" if cuda.is_available() else "cpu")

    def log(self, data: dict, prefix: str = "") -> None:
        """Log ``data`` compactly to the logger and unchanged to wandb.

        When there are several keys, the character prefix they all share
        (e.g. ``"[fold0] "``) is stripped from the logged text. Each entry is
        then ``str.capitalize``-d, so it reads as ``"Key: value"``.

        Args:
            data: Metric name -> value. An empty dict is ignored.
            prefix: Optional text prepended as ``"{prefix} - ..."``.
        """
        if not data:
            return
        keys = list(data)
        # With a single key the whole key is its own "common prefix"; stripping
        # it would leave only ": value", so only strip when comparing >1 key.
        n_common = len(os.path.commonprefix(keys)) if len(keys) > 1 else 0
        message = " ".join(
            f"{k}: {v}"[n_common:].capitalize() for k, v in data.items()
        )
        if prefix != "":
            message = f"{prefix} - {message}"
        self.logger.info(message)
        wandb.log(data)

    def compute_loss(
        self, y_preds: Tensor, labels: Tensor, batch_size: int, loss_th: float
    ) -> tuple:
        """Compute the batch loss as a mean of per-sample token losses.

        Each sample's loss is the mean BCE over its tokens whose label is not
        -1. When ``loss_th`` is given, samples whose loss exceeds it are left
        out of the batch average.

        Args:
            y_preds: Model logits; the trailing dim is squeezed (assumed
                ``(batch, seq_len, 1)`` — TODO confirm).
            labels: Targets laid out like ``y_preds``; -1 marks ignored tokens.
            batch_size: Number of samples in the batch.
            loss_th: Per-sample loss threshold, or ``None`` to disable removal.

        Returns:
            ``(loss, samplewise_losses)``. ``loss`` is the scalar to
            back-propagate. ``samplewise_losses`` holds the per-sample loss
            values as floats when ``loss_th`` is ``None`` and is empty otherwise.
        """
        token_labels = labels.squeeze(-1)
        token_losses = self.criterion(y_preds.squeeze(-1), token_labels)
        valid = token_labels != -1
        per_sample = torch.stack(
            [
                torch.masked_select(token_losses[i], valid[i]).mean()
                for i in range(batch_size)
            ]
        )

        keep = torch.ones_like(per_sample)
        n_masked = 0
        if loss_th is not None:
            mask = per_sample > loss_th
            n_masked = int(mask.sum().item())
            if n_masked > 0:
                self.logger.info(f"{n_masked} sample's loss was removed.")
            keep[mask] = 0.0

        # Per-sample losses are recorded only while removal is disabled; run()
        # feeds them to compute_loss_th to derive the threshold.
        samplewise_losses = per_sample.detach().tolist() if loss_th is None else []

        # If every sample was removed, all weights are 0; dividing by at least
        # 1 avoids a 0/0 NaN loss.
        n_kept = max(batch_size - n_masked, 1)
        return (per_sample * keep).sum() / n_kept, samplewise_losses

    def train_with_eval(
        self,
        model: Module,
        fold: int,
        dls: tuple,
        optimizer: Optimizer,
        epoch: int,
        scheduler: _LRScheduler,
        loss_th: float,
        valid_texts: list,
        valid_labels: ndarray,
        n_vl: int,
        best_score: float,
    ) -> tuple:
        """Train ``model`` for one epoch, validating every ``n_eval_steps`` steps.

        Whenever a step-wise validation beats ``best_score``, the model is
        saved as the fold's best checkpoint.

        Args:
            model: Model to train (switched back to train mode after evals).
            fold: Fold index, used in metric names and checkpoint paths.
            dls: ``(train_loader, valid_loader)``.
            optimizer: Optimizer stepped every ``gradient_accumulation_steps``.
            epoch: Zero-based epoch index (also passed to AWP).
            scheduler: LR scheduler, stepped per optimizer step only when
                ``config.batch_scheduler`` is true.
            loss_th: Per-sample loss threshold forwarded to ``compute_loss``.
            valid_texts: Validation texts for character-level scoring.
            valid_labels: Validation labels for scoring.
            n_vl: Number of validation samples.
            best_score: Best validation score so far for this fold.

        Returns:
            ``(avg_train_loss, best_score, samplewise_losses)``.
        """
        tr_dl, vl_dl = dls
        accum_steps = self.config.gradient_accumulation_steps
        model.train()
        scaler = cuda.amp.GradScaler(enabled=self.config.apex)
        awp = AWP(
            model,
            self.criterion,
            optimizer,
            adv_lr=self.config.adv_lr,
            adv_eps=self.config.adv_eps,
            start_epoch=self.config.adv_start_epoch,
            scaler=scaler,
        )

        am = AverageMeter()
        samplewise_losses = []
        grad_norm = float("nan")  # norm from the most recent optimizer step
        start = time.time()
        for step, (inputs, labels) in enumerate(tr_dl):
            for k, v in inputs.items():
                inputs[k] = v.to(self.device)
            labels = labels.to(self.device)
            batch_size = labels.size(0)
            # The loss is computed inside autocast, as the AMP recipe does.
            with cuda.amp.autocast(enabled=self.config.apex):
                y_preds = model(inputs)
                loss, sl = self.compute_loss(
                    y_preds=y_preds,
                    labels=labels,
                    batch_size=batch_size,
                    loss_th=loss_th,
                )
            samplewise_losses += sl
            # Track the unscaled loss so the reported average does not depend
            # on the accumulation setting.
            am.update(loss.item(), batch_size)
            if accum_steps > 1:
                loss = loss / accum_steps
            scaler.scale(loss).backward()
            awp.attack_backward(inputs, labels, epoch)
            if (step + 1) % accum_steps == 0:
                # Gradients must be unscaled before clipping. unscale_ may be
                # called only once per optimizer step, so clip at the boundary.
                scaler.unscale_(optimizer)
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), self.config.max_grad_norm
                )
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                # NOTE: this is the only scheduler.step() in this class.
                if self.config.batch_scheduler:
                    scheduler.step()
            lr = scheduler.get_last_lr()[0]
            if step % self.config.print_freq == 0 or step == (len(tr_dl) - 1):
                print(
                    "Epoch: [{0}][{1}/{2}] "
                    "Elapsed {remain:s} "
                    "Loss: {loss.val:.4f}({loss.avg:.4f}) "
                    "Grad: {grad_norm:.4f}  "
                    "LR: {lr:.8f}  ".format(
                        epoch + 1,
                        step,
                        len(tr_dl),
                        remain=timeSince(start, float(step + 1) / len(tr_dl)),
                        loss=am,
                        grad_norm=grad_norm,
                        lr=lr,
                    )
                )
            wandb.log(
                {
                    f"[fold{fold}] loss": am.val,
                    f"[fold{fold}] lr": lr,
                }
            )

            if (step + 1) % self.config.n_eval_steps == 0:
                model.eval()
                avg_vl_loss, predictions = self.infer(model, vl_dl, n_vl)
                score, best_th = self.evaluate(
                    predictions=predictions,
                    valid_texts=valid_texts,
                    valid_labels=valid_labels,
                    th_range=self.config.th_range,
                    th_step=self.config.th_step,
                )
                self.log(
                    {
                        f"[fold{fold}] epoch": epoch + 1,
                        f"[fold{fold}] step": step,
                        f"[fold{fold}] avg_val_loss": avg_vl_loss,
                        f"[fold{fold}] score": score,
                        f"[fold{fold}] best_th": best_th,
                    }
                )
                if score > best_score:
                    best_score = score
                    self.save_ckpt(fold=fold, model=model, predictions=predictions)
                model.train()

        return am.avg, best_score, samplewise_losses

    @inference_mode()
    def infer(self, model: Module, vl_dl: DataLoader, n_vl: int) -> tuple:
        """Run ``model`` over ``vl_dl`` without gradients.

        Returns:
            ``(avg_val_loss, predictions)``. ``predictions`` holds sigmoid
            probabilities reshaped to ``(n_vl, config.max_len)``.
        """
        model.eval()
        losses = AverageMeter()
        preds = []
        for inputs, labels in vl_dl:
            for k, v in inputs.items():
                inputs[k] = v.to(self.device)
            labels = labels.to(self.device)
            batch_size = labels.size(0)
            y_preds = model(inputs)
            flat_labels = labels.view(-1, 1)
            loss = self.criterion(y_preds.view(-1, 1), flat_labels)
            # Average only over tokens that carry a label (-1 = ignored).
            loss = torch.masked_select(loss, flat_labels != -1).mean()
            # Validation loss is a metric, not a gradient, so it is not
            # divided by the gradient-accumulation factor.
            losses.update(loss.item(), batch_size)
            preds.append(y_preds.sigmoid().to("cpu").numpy())
        predictions = np.concatenate(preds).reshape((n_vl, self.config.max_len))

        return losses.avg, predictions

    def create_dl(
        self,
        df: DataFrame,
        feature_text_max_len: int,
        pn_history_max_len: int,
        is_train: bool,
        seed: int,
    ) -> DataLoader:
        """Build a ``DataLoader`` over a ``TrainDataset`` built from ``df``.

        For training, the loader shuffles, drops the last partial batch and
        uses ``config.batch_size``. For validation, it keeps order and uses
        twice that batch size. The shuffle generator is seeded with
        ``seed + 1`` for training and ``seed`` for validation.
        """
        ds = TrainDataset(
            tokenizer=self.tokenizer,
            max_len=self.config.max_len,
            feature_text_max_len=feature_text_max_len,
            pn_history_max_len=pn_history_max_len,
            df=df,
        )
        g = torch.Generator()
        g.manual_seed(seed + int(is_train))

        return DataLoader(
            ds,
            batch_size=self.config.batch_size
            if is_train
            else self.config.batch_size * 2,
            shuffle=is_train,
            num_workers=self.config.num_workers,
            pin_memory=True,
            drop_last=is_train,
            generator=g,
        )

    def evaluate(
        self,
        predictions: ndarray,
        valid_texts: tuple,
        valid_labels: ndarray,
        th_range: list,
        th_step: float = 0.005,
    ) -> tuple:
        """Scan thresholds and return the best score with its threshold.

        Thresholds run over ``np.arange(th_range[0], th_range[1], th_step)``,
        rounded to 4 decimals. Each one is scored by the scoring helpers
        defined elsewhere in the file.

        Returns:
            ``(best_score, best_th)``. ``best_th`` is ``None`` if no
            threshold scored above -100 (e.g. an empty range).
        """
        char_probs = get_char_probs(valid_texts, predictions, self.tokenizer)
        best_score = -100
        best_th = None
        for th in np.arange(th_range[0], th_range[1], th_step):
            th = np.round(th, 4)
            results = get_results(char_probs, valid_texts, th=th)
            preds = get_predictions(results)
            score = get_score(valid_labels, preds)
            if best_score < score:
                best_th = th
                best_score = score
        return best_score, best_th

    def save_ckpt(
        self, fold: int, model: Module, predictions: ndarray, epoch: int = None
    ) -> None:
        """Save the model state dict and validation predictions with ``torch.save``.

        The path is ``{output_dir}{ckpt_name}_fold{fold}[_epoch{epoch}]_best.pth``.
        Per-epoch snapshots also carry the ``_best`` suffix.
        """
        if epoch is None:
            ep_suffix = ""
        else:
            ep_suffix = f"_epoch{epoch}"

        torch.save(
            {"model": model.state_dict(), "predictions": predictions},
            f"{self.config.output_dir}{self.config.ckpt_name}_fold{fold}{ep_suffix}_best.pth",
        )
        self.logger.info("model has been saved.")

    def compute_loss_th(self, samplewise_losses: list, fold: int, epoch: int) -> float:
        """Return ``mean + n_loss_removal_std * std`` of ``samplewise_losses``.

        The losses are also dumped with joblib (compress=3) to
        ``samplewise_losses_f{fold}_e{epoch}.pkl``, relative to the working
        directory rather than ``config.output_dir``.
        """
        mu_loss = np.mean(samplewise_losses)
        std_loss = np.std(samplewise_losses)
        loss_th = mu_loss + std_loss * self.config.n_loss_removal_std
        joblib.dump(
            value=samplewise_losses,
            filename=f"samplewise_losses_f{fold}_e{epoch}.pkl",
            compress=3,
        )

        return loss_th

    def run(
        self,
        df: DataFrame,
        pl_df: DataFrame,
        feature_text_max_len: int,
        pn_history_max_len: int,
    ) -> None:
        """Train and evaluate every fold, then score the out-of-fold predictions.

        For each fold ``f``:
        - train on rows with ``fold != f``, plus a ``pl_frac`` sample of
          pseudo-labelled rows from ``pl_df`` with ``fold == f`` when
          ``pl_frac > 0``;
        - validate on ``fold == f``;
        - keep the best checkpoint;
        - write its predictions into columns ``0..max_len-1`` of the
          validation frame.

        The cumulative OOF frame is pickled after every fold and once more at
        the end.
        """
        oof_parts = []
        for f in range(self.config.n_folds):

            self.logger.info(f"========== fold: {f} training ==========")

            model = CustomModel(
                model_name=self.config.model, config_path=None, pretrained=True
            ).to(self.device)

            tr_df = df[df["fold"] != f].reset_index(drop=True)
            if self.config.pl_frac > 0.0:
                tr_pl_df = pl_df.loc[pl_df["fold"] == f].sample(
                    frac=self.config.pl_frac, random_state=self.config.seed + 1
                )
                self.logger.info(f"{len(tr_pl_df)} pseudo labeled data was sampled.")
                # Reset the index so the merged frame has a clean 0..n-1 index
                # like the other frames handed to create_dl.
                tr_df = (
                    pd.concat((tr_df, tr_pl_df))
                    .sample(frac=1.0, random_state=self.config.seed)
                    .reset_index(drop=True)
                )
            tr_dl = self.create_dl(
                df=tr_df,
                feature_text_max_len=feature_text_max_len,
                pn_history_max_len=pn_history_max_len,
                is_train=True,
                seed=self.config.seed,
            )
            # The scheduler advances once per optimizer step, i.e. once every
            # gradient_accumulation_steps batches of the (drop_last) loader.
            num_train_steps = (
                len(tr_dl) // self.config.gradient_accumulation_steps
            ) * self.config.epochs

            vl_df = df[df["fold"] == f].reset_index(drop=True)
            vl_dl = self.create_dl(
                df=vl_df,
                feature_text_max_len=feature_text_max_len,
                pn_history_max_len=pn_history_max_len,
                is_train=False,
                seed=self.config.seed,
            )
            valid_texts = vl_df["pn_history"].to_numpy()
            valid_labels = create_labels_for_scoring(vl_df)

            optimizer_parameters = get_optimizer_params(
                model,
                encoder_lr=self.config.encoder_lr,
                decoder_lr=self.config.decoder_lr,
                weight_decay=self.config.weight_decay,
            )
            optimizer = AdamW(
                optimizer_parameters,
                lr=self.config.encoder_lr,
                eps=self.config.eps,
                betas=self.config.betas,
            )
            scheduler = get_scheduler(
                scheduler=self.config.scheduler,
                optimizer=optimizer,
                num_warmup_steps=self.config.num_warmup_steps,
                num_train_steps=num_train_steps,
                num_cycles=self.config.num_cycles,
            )

            # -inf guarantees the first evaluation writes the "_best" checkpoint
            # that is reloaded after training.
            best_score = float("-inf")
            samplewise_losses = []
            loss_th = None
            for epoch in range(self.config.epochs):

                # The threshold needs losses recorded in a previous epoch; with
                # none available (e.g. start epoch 0), removal stays disabled.
                if epoch == self.config.loss_removal_start_ep and samplewise_losses:
                    loss_th = self.compute_loss_th(
                        samplewise_losses=samplewise_losses, fold=f, epoch=epoch
                    )
                    self.logger.info(f"Loss th: {loss_th}")

                (
                    avg_tr_loss,
                    stepwise_best_score,
                    samplewise_losses,
                ) = self.train_with_eval(
                    model,
                    f,
                    (tr_dl, vl_dl),
                    optimizer,
                    epoch,
                    scheduler,
                    loss_th if epoch >= self.config.loss_removal_start_ep else None,
                    valid_texts,
                    valid_labels,
                    len(vl_df),
                    best_score,
                )

                avg_vl_loss, predictions = self.infer(model, vl_dl, len(vl_df))
                score, best_th = self.evaluate(
                    predictions=predictions,
                    valid_texts=valid_texts,
                    valid_labels=valid_labels,
                    th_range=self.config.th_range,
                    th_step=self.config.th_step,
                )

                self.log(
                    {
                        f"[fold{f}] epoch": epoch + 1,
                        f"[fold{f}] avg_train_loss": avg_tr_loss,
                        f"[fold{f}] avg_val_loss": avg_vl_loss,
                        f"[fold{f}] score": score,
                        f"[fold{f}] best_th": best_th,
                    }
                )

                # Carry the best score seen so far (including step-wise evals)
                # into the next epoch, and keep the "_best" checkpoint in sync.
                best_score = stepwise_best_score
                if score > best_score:
                    best_score = score
                    self.save_ckpt(fold=f, model=model, predictions=predictions)

                self.save_ckpt(
                    fold=f, model=model, predictions=predictions, epoch=epoch
                )

            # NOTE(review): newer torch releases default torch.load to
            # weights_only=True, which may reject the numpy "predictions"
            # entry — confirm when upgrading torch.
            predictions = torch.load(
                f"{self.config.output_dir}{self.config.ckpt_name}_fold{f}_best.pth",
                map_location=torch.device("cpu"),
            )["predictions"]
            vl_df[list(range(self.config.max_len))] = predictions
            oof_parts.append(vl_df)
            oof_df = pd.concat(oof_parts)

            self.logger.info(f"========== fold: {f} result ==========")

            score, th = get_result(vl_df, self.tokenizer, self.config.max_len)
            self.log(
                {f"[fold{f}] overall score": score, f"[fold{f}] overall best th": th}
            )
            oof_df.to_pickle(f"{self.config.output_dir}oof_df_fold{f}.pkl")

        oof_df = pd.concat(oof_parts).reset_index(drop=True)
        self.logger.info("========== CV ==========")
        score, th = get_result(oof_df, self.tokenizer, self.config.max_len)
        self.log({"overall score": score, "overall best th": th})
        oof_df.to_pickle(self.config.output_dir + "oof_df.pkl")

In [None]:
# Train every fold using the config/tokenizer/logger and data frames defined
# elsewhere in the notebook.
trainer = Trainer(config=config, tokenizer=tokenizer, logger=logger)
trainer.run(
    df=df_train,
    pl_df=pl_train,
    feature_text_max_len=feature_text_max_len,
    pn_history_max_len=pn_history_max_len,
)

In [None]:
# Close the wandb run started in init_wandb once all folds have been logged.
wandb.finish()