https://www.kaggle.com/code/columbia2131/training-inference-code-xlm-roberta-base

### import

In [None]:
import numpy as np 
import pandas as pd 
import sys 
import os 
import logzero 
import wandb 
import pickle 
from tqdm.auto import tqdm
import matplotlib.pyplot as plt 
import seaborn as sns 
sns.set()

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset
from torch.cuda.amp import autocast, GradScaler
from pytorch_lightning.utilities.seed import seed_everything


### config

In [None]:

class Config:
    """Experiment settings shared by the notebook cells.

    Every value is a class attribute; later cells attach further attributes
    (for example ``device`` and ``n_class``) to the instance at run time.
    """

    # --- common ---
    version = "008"
    comment = "test"
    input_dir = "/home/user/work/input/we-are-all-alike-on-the-inside"
    output_dir = f"/home/user/work/output/{version}"
    seed = 42
    debug = False
    target_col = None

    # --- wandb: keyword arguments unpacked into wandb.init ---
    wandb_init = dict(
        project="debug",
        entity="kuto5046",
        group=f"exp{version}",
        dir=output_dir,
        tags=[],
        mode="disabled",
    )

    # --- cross validation ---
    n_splits = 5
    # [0] stops after a single fold; [0, 1, 2, 3, 4] runs every fold.
    use_fold = [0]

    # --- dataloader: keyword arguments per phase, unpacked into DataLoader ---
    loader_params = {
        "train": dict(batch_size=32, shuffle=True, num_workers=4),
        "valid": dict(batch_size=32, shuffle=False, num_workers=4),
        "test": dict(batch_size=32, shuffle=False, num_workers=4),
    }

    # --- model / optimisation ---
    # Checkpoint to resume from, e.g. f"{output_dir}/model_fold0_epoch=0.ckpt".
    resume_checkpoint_path = None
    # pretrained_model_path = f"{output_dir}/model_fold0_epoch=0.ckpt"  # prediction-only runs
    n_epochs = 1
    model_name = "xlm-roberta-base"
    max_len = 128
    weight_decay = 1e-3
    beta = (0.9, 0.98)
    lr = 3e-5
    num_warmup_steps_rate = 0
    # 1 means gradients are not accumulated.
    gradient_accumulation_steps = 1
# Instantiate the settings object that every later cell reads and extends.
c = Config()
# Module-level copy of the debug flag (used for the Trainer's fast_dev_run).
DEBUG = c.debug 
# c = HydraConfig.get_cnf(config_path='/home/user/work/configs/', config_name='config.yaml')
# Create the output directory first so the log file below can be written into it.
os.makedirs(c.output_dir, exist_ok=True)
# level=10 is the numeric value of logging.DEBUG.
logger = logzero.setup_logger(name='main', logfile=f'{c.output_dir}/result.log', level=10)

In [None]:
# Use CUDA when it is available, otherwise fall back to the CPU.
c.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Bare expression: shows the chosen device as the cell output.
c.device

In [None]:
# Seed the run with the configured seed via Lightning's seed_everything helper.
seed_everything(c.seed)

### read data

In [None]:
# Load the train and test tables from the configured input directory.
train = pd.read_csv(f'{c.input_dir}/train.csv')
test = pd.read_csv(f'{c.input_dir}/test.csv')
# Bare expression: shows both shapes as the cell output.
train.shape, test.shape

In [None]:
# Preview the first rows of the training table.
train.head()

### preprocess

In [None]:
from src.features.base import get_categorical_col, get_numerical_col
from src.features.encoder import pp_for_categorical_encoding

In [None]:
# Stack train and test (train rows first) with a fresh 0..n-1 index so that
# preprocessing can be applied to both at once.
whole = pd.concat([train, test]).reset_index(drop=True)

In [None]:
# Target column, number of classes, and the label <-> integer id mappings.
c.target_col = 'category'
c.n_class = 3
c.target_map = {'association': 0, 'disagreement': 1, 'unbiased': 2}
# Inverse of target_map (integer id -> label string).
c.target_map_rev = {0: 'association', 1: 'disagreement', 2: 'unbiased'}

In [None]:
# Show the result of the project helper get_categorical_col on the combined table
# (presumably the categorical column names — the helper is not visible here).
get_categorical_col(whole)

In [None]:
# Show the result of the project helper get_numerical_col on the combined table
# (presumably the numerical column names — the helper is not visible here).
get_numerical_col(whole)

In [None]:
# import ast 
# def fix_s1s2(data):
#     new_s1 = []
#     new_s2 = []
#     for idx, row in tqdm(data.iterrows(), total=len(data)):
#         if row["s1"].startswith("["):
#             try:
#                 temp_s1 = " ".join(ast.literal_eval(row["s1"]))
#             except SyntaxError:
#                 temp_s1 = row["s1"][1:-1]
#         else:
#             temp_s1 = row["s1"]

#         if row["s2"].startswith("["):
#             try:
#                 temp_s2 = " ".join(ast.literal_eval(row["s2"]))
#             except SyntaxError:
#                 temp_s2 = row["s2"][1:-1]
#         else:
#             temp_s2 = row["s2"]

#         new_s1.append(temp_s1)
#         new_s2.append(temp_s2)
#     data["s1"] = new_s1
#     data["s2"] = new_s2
#     return data

In [None]:
def _clean_s1(raw):
    """Strip list-literal markup from a raw ``s1`` string.

    The replacements run in the listed order; the remaining text is then split
    on the ``', '`` item separator and the pieces are joined with single spaces.
    """
    rules = (
        ("['", ''),
        ("']", ''),
        ('["', ''),
        ('"]', ''),
        ('[«', '«'),
        ('»]', '»'),
        ('[', ''),
        (']', ''),
    )
    for old, new in rules:
        raw = raw.replace(old, new)
    return ' '.join(raw.split("', '"))


# Only the s1 column is cleaned here; s2 is left as read from the CSV.
whole['s1'] = whole['s1'].map(_clean_s1)

In [None]:
# Split the combined table back apart: rows with a target value are train,
# rows whose target is missing are test. Both get a fresh 0..n-1 index.
train = whole[~whole[c.target_col].isna()].reset_index(drop=True)
test = whole[whole[c.target_col].isna()].reset_index(drop=True)

In [None]:
# Convert the string labels to their integer ids.
# Labels that are not keys of c.target_map become NaN (pandas Series.map semantics).
train[c.target_col] = train[c.target_col].map(c.target_map)

### model

In [None]:
import transformers
from transformers import AutoModel, AutoTokenizer
from transformers import get_linear_schedule_with_warmup
from torch.optim import AdamW 

In [None]:
class CustomDataset(Dataset):
    """Sentence-pair dataset that tokenizes ``s1``/``s2`` on access.

    Args:
        df: Frame with ``s1`` and ``s2`` columns; for the 'train' and 'valid'
            phases it must also contain ``config.target_col``.
        config: Settings object providing ``model_name``, ``max_len`` and
            ``target_col``.
        phase: One of 'train', 'valid' or 'test'.
    """

    def __init__(self, df: pd.DataFrame, config: Config, phase: str='train'):
        assert phase in ['train', 'valid', 'test']
        self.config = config
        self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
        self.phase = phase
        self.s1 = df['s1'].to_numpy()
        self.s2 = df['s2'].to_numpy()
        if phase in ('train', 'valid'):
            self.y = df[config.target_col].to_numpy()
        else:
            # The test phase has no labels; NaN placeholders keep the item shape uniform.
            self.y = np.full(len(df), np.nan)

    def __len__(self):
        return len(self.s1)

    def __getitem__(self, idx):
        """Return ``(x, y)``: the encoded pair as a dict of long tensors, and its label."""
        # Encode both sentences as one sequence, padded/truncated to max_len tokens.
        encoded = self.tokenizer.encode_plus(
            self.s1[idx],
            self.s2[idx],
            add_special_tokens=True,
            max_length=self.config.max_len,
            padding='max_length',
            truncation=True,
        )
        token_ids = torch.tensor(encoded['input_ids'], dtype=torch.long)
        attention = torch.tensor(encoded['attention_mask'], dtype=torch.long)
        return {'token1': token_ids, 'mask1': attention}, self.y[idx]

In [None]:

class CustomModel(nn.Module):
    """Transformer backbone followed by LayerNorm and a two-layer classification head.

    Args:
        model_name: Name or path passed to ``AutoModel.from_pretrained``.
        n_class: Number of output logits.
    """

    def __init__(self, model_name, n_class):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(model_name)
        # Take the feature width from the backbone config instead of hard-coding
        # 768, so backbones of other sizes work without editing the head.
        hidden_size = self.backbone.config.hidden_size
        self.ln = nn.LayerNorm(hidden_size)
        self.linear1 = nn.Sequential(
            nn.Linear(hidden_size, 256),
            nn.ReLU(),
            nn.Linear(256, n_class),
        )

    # @torch.autocast()
    def forward(self, x):
        """Return class logits for a batch.

        Args:
            x: Dict with ``'token1'`` (input ids) and ``'mask1'`` (attention mask).
        """
        hidden = self.backbone(x['token1'], attention_mask=x['mask1'])["last_hidden_state"]
        # Use the hidden state of the first token as the sequence representation.
        output = self.ln(hidden[:, 0, :])
        return self.linear1(output)

In [None]:
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import EarlyStopping
from torchmetrics import F1Score

class CustomTask(pl.LightningModule):
    """Lightning task bundling a model with its loss, metric, optimizer and LR scheduler.

    The optimizer and scheduler are built eagerly in ``__init__`` from
    ``config``; ``configure_optimizers`` only hands them to the Trainer.

    Args:
        model: Module called as ``model(x)`` on the first element of each batch.
        config: Settings object providing ``weight_decay``, ``lr``, ``beta``,
            ``len_loader``, ``n_epochs``, ``gradient_accumulation_steps`` and
            ``num_warmup_steps_rate``.
    """

    def __init__(self, model, config):
        super().__init__()
        self.model = model
        self.criterion = self.get_criterion(config)
        self.optimizer = self.get_optimizer(config)
        # The scheduler wraps self.optimizer, so it must be created after it.
        self.scheduler = self.get_scheduler(config)
        self.metric = self.get_metric(config)

    def _calculate_loss(self, batch, mode="train"):
        """Run the model on one batch, log epoch-level loss and score, return the loss."""
        x, y = batch
        output = self.model(x)
        loss = self.criterion(output, y)
        score = self.metric(output, y)

        self.log(f'Loss/{mode}', loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log(f'Score/{mode}', score, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._calculate_loss(batch, mode="train")

    def validation_step(self, batch, batch_idx):
        return self._calculate_loss(batch, mode="valid")

    def configure_optimizers(self):
        """Hand the pre-built optimizer and scheduler to Lightning.

        The schedule length is computed in optimizer steps (see
        ``get_scheduler``), so the scheduler must be stepped per optimizer
        step. A bare scheduler object would be stepped once per epoch.
        """
        return {
            "optimizer": self.optimizer,
            "lr_scheduler": {
                "scheduler": self.scheduler,
                "interval": "step",
                "frequency": 1,
            },
            "monitor": "Loss/valid",
        }

    def get_metric(self, config):
        return F1Score(average='micro')

    def get_optimizer(self, config: "Config"):
        """Build AdamW with weight decay disabled for biases and LayerNorm parameters."""
        param_optimizer = list(self.model.named_parameters())
        # Parameters whose name contains one of these substrings are not decayed.
        # 'ln.weight' covers the head's own LayerNorm, which the
        # 'LayerNorm.*' patterns do not match.
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight', 'ln.weight']
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
                'weight_decay': config.weight_decay,
            },
            {
                'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
                'weight_decay': 0.0,
            },
        ]

        # Each group sets its own weight_decay, so no optimizer-wide value is passed.
        optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=config.lr,
            betas=config.beta,
        )
        return optimizer

    def get_scheduler(self, config: "Config"):
        """Build a linear warmup/decay schedule spanning all optimizer steps of the run."""
        num_train_optimization_steps = int(
            config.len_loader * config.n_epochs // config.gradient_accumulation_steps
        )
        num_warmup_steps = int(num_train_optimization_steps * config.num_warmup_steps_rate)

        scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_train_optimization_steps,
        )
        return scheduler

    def get_criterion(self, config: "Config"):
        criterion = nn.CrossEntropyLoss()
        return criterion

### cv

In [None]:
from src.cv import get_kfold, get_stratifiedkfold, get_groupkfold
# Build the stratified folds. The fold count comes from the config rather than
# a hard-coded 5, so c.n_splits and the actual split cannot drift apart.
cv = get_stratifiedkfold(train, c.target_col, n_splits=c.n_splits)
cv

### train

In [None]:
from sklearn.metrics import f1_score, roc_auc_score
def calc_score(true, pred):
    """Return the micro-averaged F1 between ``true`` and the argmax over axis 1 of ``pred``."""
    predicted_labels = pred.argmax(axis=1)
    return f1_score(true, predicted_labels, average='micro')

In [None]:
def apply_device_to_dict(_dict, device):
    """Move every value of ``_dict`` to ``device`` in place and return the same dict."""
    for key in _dict:
        # Only existing keys are reassigned, so the dict size never changes mid-iteration.
        _dict[key] = _dict[key].to(device)
    return _dict

In [None]:
def to_np(input):
    """Detach a tensor from the autograd graph, copy it to the CPU and return a NumPy array."""
    detached = input.detach()
    on_cpu = detached.cpu()
    return on_cpu.numpy()

def worker_init_fn(worker_id):
    """Re-seed NumPy's global RNG inside a DataLoader worker.

    The new seed is the first word of the current global NumPy RNG state plus
    ``worker_id``, wrapped into the ``[0, 2**32)`` range that
    ``np.random.seed`` accepts.

    Args:
        worker_id: Integer id of the worker being initialised.
    """
    base_seed = int(np.random.get_state()[1][0])
    # Work in Python ints: adding to the raw uint32 state word can overflow or
    # produce a value above the largest valid seed when the word is near 2**32 - 1.
    np.random.seed((base_seed + worker_id) % 2**32)

def inference(model, loader, device):
    """Run ``model`` over ``loader`` and return all outputs as one float32 array.

    Args:
        model: Module called as ``model(x)`` with the batch's input dict.
        loader: Iterable of ``(x, y)`` batches with a length; ``y`` is ignored.
        device: ``torch.device`` to run on; its ``type`` selects the autocast backend.

    Returns:
        The per-batch outputs concatenated along axis 0.
    """
    model.eval()
    model.to(device)
    pred = []
    with torch.no_grad():
        # https://github.com/tqdm/tqdm/issues/746
        for batch in tqdm(loader, total=len(loader)):
            with torch.autocast(device_type=device.type):
                x, _ = batch
                x = apply_device_to_dict(x, device)
                output = model(x)
            # Autocast may yield reduced-precision outputs (bfloat16 on CPU),
            # which NumPy cannot represent; upcast before converting.
            pred.append(to_np(output.float()))
    return np.concatenate(pred)

In [None]:
def train_pipeline(train, test, cv, config, target_col):
    """Train, validate and predict for every fold listed in ``config.use_fold``.

    For each selected fold this fits a fresh model, reloads the best
    checkpoint, writes out-of-fold predictions to ``oof_{i}.csv`` and test
    predictions to ``pred_test_{i}.npy`` in ``config.output_dir``, and logs
    the fold score.

    Args:
        train: Training frame indexed 0..n-1, containing ``target_col``.
        test: Test frame.
        cv: Iterable of ``(idx_train, idx_valid)`` index pairs.
        config: Settings object; ``len_loader`` and ``best_checkpoint`` are
            written onto it as a side effect.
        target_col: Name of the label column used for scoring.

    Note:
        The module-level ``DEBUG`` flag controls ``fast_dev_run`` and whether
        the best checkpoint is reloaded.
    """
    for i, (idx_train, idx_valid) in enumerate(cv):
        if i not in config.use_fold:
            continue

        wandb.init(**config.wandb_init, name=f'exp{config.version}-fold{i}')

        _train = train.loc[idx_train].reset_index(drop=True)
        _valid = train.loc[idx_valid].reset_index(drop=True)

        loaders = {}
        loaders["train"] = DataLoader(CustomDataset(_train, config, phase="train"), **config.loader_params['train'], worker_init_fn=worker_init_fn)
        loaders["valid"] = DataLoader(CustomDataset(_valid, config, phase="valid"), **config.loader_params['valid'], worker_init_fn=worker_init_fn)
        loaders["test"] = DataLoader(CustomDataset(test, config, phase="test"), **config.loader_params['test'], worker_init_fn=worker_init_fn)

        # The scheduler built inside CustomTask needs the number of batches per epoch.
        config.len_loader = len(loaders['train'])

        model = CustomModel(config.model_name, config.n_class)
        task = CustomTask(model, config)

        # callback
        checkpoint_callback = ModelCheckpoint(
            monitor=f'Score/valid',
            mode='max',
            dirpath=config.output_dir,
            verbose=True,
            filename=f'model_fold{i}_' + '{epoch}')  # '{epoch}' is filled in by Lightning

        early_stop_callback = EarlyStopping(
            monitor='Loss/valid',
            min_delta=0.00,
            patience=3,
            verbose=True,
            mode='min')

        trainer = pl.Trainer(
            logger=[WandbLogger()],
            callbacks=[checkpoint_callback, early_stop_callback],
            max_epochs=config.n_epochs,
            devices='auto',
            accelerator='auto',
            fast_dev_run=DEBUG,
            deterministic=True,
            precision=16,
            )

        print('start train')
        # if os.path.exists(config.pretrained_model_path):
        #     logger.info(f'load pretrained model {config.pretrained_model_path} and skip train')
        #     checkpoint = torch.load(config.pretrained_model_path)
        #     model.load_state_dict(checkpoint['state_dict'], strict=False)
        # else:
        # Pass a checkpoint path via config.resume_checkpoint_path to resume training.
        trainer.fit(task, train_dataloaders=loaders['train'], val_dataloaders=loaders['valid'], ckpt_path=config.resume_checkpoint_path)
        if not DEBUG:
            best_checkpoint = checkpoint_callback.best_model_path
            logger.info(f'load best model {best_checkpoint}')
            checkpoint = torch.load(best_checkpoint, map_location='cpu')
            # The checkpoint stores the CustomTask state, whose weights live
            # under the "model." attribute prefix. Loading those keys into the
            # bare model with strict=False matches nothing and silently keeps
            # the current weights, so strip the prefix and load strictly.
            prefix = 'model.'
            state_dict = {
                key[len(prefix):]: value
                for key, value in checkpoint["state_dict"].items()
                if key.startswith(prefix)
            }
            model.load_state_dict(state_dict)
            config.best_checkpoint = best_checkpoint

        print('create oof')
        pred = inference(model, loaders['valid'], config.device)
        oof = pd.DataFrame(pred, index=idx_valid)
        # Keep the index: it is needed to concatenate the folds back in the original row order.
        oof.to_csv(f"{config.output_dir}/oof_{i}.csv", index=True)

        # evaluate
        print('evaluate valid data')
        score = calc_score(_valid[target_col], pred)
        logger.info(f'fold-{i} score: {score}')
        wandb.log({'CV': score})

        # pred
        print('inference test data')
        pred_test = inference(model, loaders['test'], config.device)
        np.save(f"{config.output_dir}/pred_test_{i}", pred_test)

        # if i != config.use_fold[-1]:
        #     wandb.finish()

In [None]:
# Run training, out-of-fold prediction and test inference for the folds in c.use_fold.
train_pipeline(train, test, cv, c, c.target_col)

In [None]:
# Rebuild the model and restore the fold-0 weights from the Lightning checkpoint.
c.best_checkpoint = f'{c.output_dir}/model_fold0_epoch=0.ckpt'
model = CustomModel(c.model_name, c.n_class)
# map_location='cpu' lets the checkpoint be read on a machine without a GPU.
checkpoint = torch.load(c.best_checkpoint, map_location='cpu')
# The checkpoint holds the CustomTask state, whose weights are stored under the
# "model." prefix. Loading those keys into the bare model with strict=False
# matches nothing and would leave the classification head at its random
# initialisation, so strip the prefix and load strictly.
model.load_state_dict({
    key[len('model.'):]: value
    for key, value in checkpoint["state_dict"].items()
    if key.startswith('model.')
})

### inference

In [None]:
# Take the first (train, valid) index pair from cv: the loop exits on its first
# iteration, leaving i, idx_train and idx_valid bound to fold 0.
for i, (idx_train, idx_valid) in enumerate(cv):
    break 

_valid = train.loc[idx_valid].reset_index(drop=True)

# Rebuild the validation and test loaders for the reloaded model.
loaders = {}
loaders["valid"] = DataLoader(CustomDataset(_valid, c, phase="valid"), **c.loader_params['valid'], worker_init_fn=worker_init_fn)
loaders["test"] = DataLoader(CustomDataset(test, c, phase="test"), **c.loader_params['test'], worker_init_fn=worker_init_fn)
# Score the reloaded model on the fold-0 validation rows (shown as the cell output).
pred_valid = inference(model, loaders['valid'], c.device)
y_valid = _valid[c.target_col].to_numpy()
calc_score(y_valid, pred_valid)

In [None]:
# Average the per-fold test outputs and take the most likely class per row.
preds = []
for i in range(len(cv)):
    _pred_path = f'{c.output_dir}/pred_test_{i}.npy'
    # Only the folds listed in c.use_fold are trained, so skip folds that have
    # no saved prediction instead of failing on the first missing file.
    if not os.path.exists(_pred_path):
        continue
    preds.append(np.load(_pred_path))
if not preds:
    raise FileNotFoundError(f'no pred_test_*.npy files found in {c.output_dir}')
pred_test = np.mean(preds, axis=0).argmax(axis=1)

In [None]:
# Overlay the distribution of the encoded training labels and of the predicted test classes.
# NOTE(review): sns.distplot is deprecated in recent seaborn releases — consider
# histplot/displot when upgrading.
sns.distplot(train[c.target_col], label='target')
sns.distplot(pred_test, label='test')
plt.legend();

### submission

In [None]:
# Fill the sample submission with the predictions and write it to the output directory.
sub = pd.read_csv(f'{c.input_dir}/sample_submission.csv')
# NOTE(review): pred_test holds integer class ids and c.target_map_rev is never
# applied — confirm whether the submission expects the original string labels.
sub[c.target_col] = pred_test
sub.to_csv(f'{c.output_dir}/submission_exp{c.version}.csv', index=False)