In [15]:
# pretraining
DEBUG = False

In [16]:
EXP = 'PLpret1'

In [17]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
import pytorch_lightning as pl

import pandas as pd
import numpy as np
import random
import re
import itertools
import argparse

from torch.utils.data import Dataset
import spacy
import ast
from sklearn.model_selection import StratifiedKFold, GroupKFold, KFold

import pickle
import transformers
from transformers import AutoConfig, AutoModel, AutoTokenizer, get_linear_schedule_with_warmup
import warnings
from torch.optim import Adam, SGD, AdamW

import wandb
from pytorch_lightning.loggers import WandbLogger

In [18]:
pl.seed_everything(42, workers=True)

Global seed set to 42


42

In [19]:
class MyDataset(Dataset):
    def __init__(self, df, tokenizer, max_len=1024, stage='train', rand_prob=0.1, lowup_proba=0.0, swap_proba=0.0):
        self.df = df
        self.max_len = max_len
        self.tokenizer = tokenizer
        self.mask_token = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
        self.stage = stage
        self.rand_prob = rand_prob
        self.lowup_proba = lowup_proba
        self.swap_proba = swap_proba

        self.essay_id = df['essay_id'].values
        self.input_ids = df['input_ids'].values
        self.attention_mask = df['attention_mask'].values
        self.token_class_labels = df['token_class_labels'].values
        self.token_scores_labels = df['token_scores_labels'].values
        self.token_examples_mapping = df['token_examples_mapping'].values
        self.examples_scores = df['examples_scores'].values
        self.examples_classes = df['examples_classes'].values    
        
    def __getitem__(self, idx):
        essay_id = self.essay_id[idx]

        token_examples_mapping = self.token_examples_mapping[idx]
        examples_scores = self.examples_scores[idx]
        examples_classes = self.examples_classes[idx]

#         token_examples_mapping = torch.tensor(token_examples_mapping, dtype=torch.long)
#         examples_scores = torch.tensor(examples_scores + [-1] * (40 - len(examples_scores)), dtype=torch.long)
#         examples_classes = torch.tensor(examples_classes + [-1] * (40 - len(examples_classes)), dtype=torch.long)
        
        input_ids = self.input_ids[idx]
        attention_mask = self.attention_mask[idx]
        token_class_labels = self.token_class_labels[idx]
        token_scores_labels = self.token_scores_labels[idx]

#         input_ids = torch.tensor(input_ids, dtype=torch.long)
#         attention_mask = torch.tensor(attention_mask, dtype=torch.long)
#         token_class_labels = torch.tensor(token_class_labels, dtype=torch.long)
#         token_scores_labels = torch.tensor(token_scores_labels, dtype=torch.long)

#         if self.stage == 'train':
#             ix = torch.rand(size=(len(input_ids),)) < self.rand_prob
#             input_ids[ix] = self.mask_token
        
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_class_labels": token_class_labels,
            "token_scores_labels": token_scores_labels,
            "token_examples_mapping": token_examples_mapping,
            "examples_scores": examples_scores,
            "examples_classes": examples_classes
        }

    def __len__(self):
        return len(self.df)

In [20]:
emb_dim = 32
pooler_dropout = 0.1

In [21]:
from transformers.activations import GELUActivation

class Pooler(nn.Module):
    def __init__(self, hidden_size, dropout):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.act = GELUActivation()
        self.dropout = nn.Dropout(dropout)

    def forward(self, features):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.

        features = self.dropout(features)
        features = self.dense(features)
        features = self.act(features)
        return features

In [22]:
class MyModule(pl.LightningModule):
    def __init__(self, lr, model_checkpoint, num_classes, num_classes_class, emb_dim):
        super().__init__()
        self.lr = lr
        self.num_classes = num_classes
        self.num_classes_class = num_classes_class
        self.emb_dim = emb_dim
        self.name = model_checkpoint
        self.pad_idx = 1 if "roberta" in self.name else 0
        config = AutoConfig.from_pretrained(model_checkpoint, output_hidden_states=True)
        self.longformer = AutoModel.from_pretrained(model_checkpoint, config=config)
        self.nb_features = config.hidden_size
        self.logits = nn.Linear(self.nb_features, num_classes)
        self.pooler_class = Pooler(self.nb_features, pooler_dropout)
        self.pooler_scores = Pooler(self.nb_features, pooler_dropout)
        self.pooler = Pooler(self.nb_features + self.emb_dim, pooler_dropout)
        self.example_logits = nn.Linear(self.nb_features + self.emb_dim, num_classes)
        self.class_logits = nn.Linear(self.nb_features, num_classes_class)  
        transformers.logging.set_verbosity_error()
        self.embedding = nn.Embedding(num_classes_class, emb_dim, max_norm=True)
    
    def load_model(self, path):
        self.load_state_dict(torch.load(path, map_location='cuda:0'), strict=False)
        print('Model Loaded!')
    
    def configure_optimizers(self):
        no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
        optimizer_parameters = [
            {'params': [p for n, p in self.named_parameters() if (not any(nd in n for nd in no_decay)) and \
                       ('deberta' in n)],
             'lr': self.lr, 'weight_decay': 0.01},
            {'params': [p for n, p in self.named_parameters() if (any(nd in n for nd in no_decay)) and \
                       ('deberta' in n)],
             'lr': self.lr, 'weight_decay': 0.0},
            {'params': [p for n, p in self.named_parameters() if (not any(nd in n for nd in no_decay)) and \
                       ('deberta' not in n)],
             'lr': self.lr * 10, 'weight_decay': 0.01},
            {'params': [p for n, p in self.named_parameters() if (any(nd in n for nd in no_decay)) and \
                       ('deberta' not in n)],
             'lr': self.lr * 10, 'weight_decay': 0.0},
        ]
        optimizer = AdamW(optimizer_parameters, lr=self.lr)
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=100,
            num_training_steps=self.trainer.estimated_stepping_batches,
        )
        scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
        return [optimizer], [scheduler]
                
    def training_step(self, train_batch, batch_idx):
        input_ids, attention_mask, token_scores_labels, token_examples_mapping, \
        examples_scores, examples_classes, token_class_labels = \
            train_batch["input_ids"], train_batch["attention_mask"], train_batch["token_scores_labels"], \
            train_batch['token_examples_mapping'], train_batch['examples_scores'], \
            train_batch['examples_classes'], train_batch['token_class_labels']
        
        hidden_states = self.longformer(
            input_ids,
            attention_mask=attention_mask,
        )[-1]
        features = hidden_states[-1]
        class_logits = self.class_logits(self.pooler_class(features))
        class_loss = F.cross_entropy(class_logits.view(-1, self.num_classes_class), token_class_labels.view(-1))
        self.log('train_classes_loss', class_loss)
        return class_loss
        
    def validation_step(self, val_batch, batch_idx):
        input_ids, attention_mask, token_scores_labels, token_examples_mapping, \
        examples_scores, examples_classes, token_class_labels = \
            val_batch["input_ids"], val_batch["attention_mask"], val_batch["token_scores_labels"], \
            val_batch['token_examples_mapping'], val_batch['examples_scores'], val_batch['examples_classes'], \
            val_batch['token_class_labels']
        hidden_states = self.longformer(
            input_ids,
            attention_mask=attention_mask,
        )[-1]
        features = hidden_states[-1]
        class_logits = self.class_logits(features)
        class_loss = F.cross_entropy(class_logits.view(-1, self.num_classes_class), token_class_labels.view(-1))
        self.log('val_class_loss', class_loss)
        return {"val_class_loss": class_loss}   
        

In [23]:
class Collate:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, batch):
        output = dict()
        output["input_ids"] = [sample["input_ids"] for sample in batch]
        output["attention_mask"] = [sample["attention_mask"] for sample in batch]
        output["token_class_labels"] = [sample["token_class_labels"] for sample in batch]
        output["token_scores_labels"] = [sample["token_scores_labels"] for sample in batch]
        output["token_examples_mapping"] = [sample["token_examples_mapping"] for sample in batch]
        output["examples_scores"] = [sample["examples_scores"] for sample in batch]
        output["examples_classes"] = [sample["examples_classes"] for sample in batch]

        # calculate max token length of this batch
        batch_max = max([len(ids) for ids in output["input_ids"]])
        batch_max_ex = max([len(sco) for sco in output["examples_scores"]])

        # add padding
        if self.tokenizer.padding_side == "right":
            output["input_ids"] = [s + (batch_max - len(s)) * [self.tokenizer.pad_token_id] for s in output["input_ids"]]
            output["attention_mask"] = [s + (batch_max - len(s)) * [0] for s in output["attention_mask"]]
            output["token_class_labels"] = [s + (batch_max - len(s)) * [-100] for s in output["token_class_labels"]]
            output["token_scores_labels"] = [s + (batch_max - len(s)) * [-100] for s in output["token_scores_labels"]]
            output["token_examples_mapping"] = [s + (batch_max - len(s)) * [-1] for s in output["token_examples_mapping"]]
            output["examples_scores"] = [s + (batch_max_ex - len(s)) * [-1] for s in output["examples_scores"]]
            output["examples_classes"] = [s + (batch_max_ex - len(s)) * [-1] for s in output["examples_classes"]]

        else:
            output["input_ids"] = [(batch_max - len(s)) * [self.tokenizer.pad_token_id] + s for s in output["input_ids"]]
            output["attention_mask"] = [(batch_max - len(s)) * [0] + s for s in output["attention_mask"]]
            output["token_class_labels"] = [(batch_max - len(s)) * [-100] + s for s in output["token_class_labels"]]
            output["token_scores_labels"] = [(batch_max - len(s)) * [-100] + s for s in output["token_scores_labels"]]
            output["token_examples_mapping"] = [(batch_max - len(s)) * [-1] + s for s in output["token_examples_mapping"]]
            output["examples_scores"] = [(batch_max_ex - len(s)) * [-1] + s for s in output["examples_scores"]]
            output["examples_classes"] = [(batch_max_ex - len(s)) * [-1] + s for s in output["examples_classes"]]

        # convert to tensors
        output["input_ids"] = torch.tensor(output["input_ids"], dtype=torch.long)
        output["attention_mask"] = torch.tensor(output["attention_mask"], dtype=torch.long)
        output["token_class_labels"] = torch.tensor(output["token_class_labels"], dtype=torch.long)
        output["token_scores_labels"] = torch.tensor(output["token_scores_labels"], dtype=torch.long)
        output["token_examples_mapping"] = torch.tensor(output["token_examples_mapping"], dtype=torch.long)
        output["examples_scores"] = torch.tensor(output["examples_scores"], dtype=torch.long)
        output["examples_classes"] = torch.tensor(output["examples_classes"], dtype=torch.long)

        return output


In [24]:
from pytorch_lightning.callbacks import ModelCheckpoint

model_checkpoint = 'microsoft/deberta-v3-large'
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
collate_fn = Collate(tokenizer)
bs = 2
project = 'fbck'
seed = 42
randmask_proba = 0.1
lr = 3e-6
epochs = 1 if DEBUG else 2
num_classes = 3
num_classes_class = 8

In [25]:
import pickle
with open('pretraining-deberta-v3-large-nbroad.pickle', 'rb') as handle:
    pdf = pickle.load(handle)

In [26]:
if DEBUG: pdf = pdf.sample(n=200, random_state=4)

In [27]:
def pretrain(fold=0):
    
    print()
    print('*' * 100)
    print(f'Training fold {fold}')
    print('*' * 100)
    
    df_train = pdf[pdf.fold == -1].reset_index(drop=True)
    df_valid = pdf[pdf.fold != -1].reset_index(drop=True)

    tags = ['debug'] if DEBUG else ['PL', 'train', f'fold_{fold}']

    run = wandb.init(project=project, 
                       name=f"{EXP}_fold_{fold}",
                       tags=tags,
                       group=f"{EXP}")
    run.log_code()

    pl.seed_everything(seed)


    train_dataset = MyDataset(
        df_train,
        tokenizer,
        stage='train',
        rand_prob=randmask_proba
    )

    valid_dataset = MyDataset(
        df_valid,
        tokenizer,
        stage='valid',
        rand_prob=randmask_proba
    )

    train_loader = DataLoader(train_dataset,
                              batch_size=bs,
                              shuffle=True,
                              collate_fn=collate_fn,
                              num_workers=4, pin_memory=True, drop_last=True)

    val_loader = DataLoader(valid_dataset,
                              batch_size=1,
                              shuffle=False,
                              collate_fn=collate_fn,
                              num_workers=4, pin_memory=True, drop_last=False)

    warnings.filterwarnings("ignore")

    model = MyModule(lr=lr,
                     model_checkpoint=model_checkpoint, 
                     num_classes=num_classes,
                     num_classes_class=num_classes_class,
                     emb_dim=emb_dim,
                    )

    wandb_logger = WandbLogger(project=project)

    checkpoint_callback = ModelCheckpoint(
        save_top_k=1,
        monitor="val_class_loss",
        mode="min",
        dirpath=f"../output/{EXP}/{fold}",
        filename="feedback-{epoch:02d}-{val_class_loss:.2f}",
    )

    trainer = pl.Trainer(precision=16, 
                         accelerator="gpu", devices=1, max_epochs=epochs,
                         log_every_n_steps=100, logger=wandb_logger,
                         default_root_dir=f"../output/{EXP}",
                         callbacks=[checkpoint_callback],
                         accumulate_grad_batches=2,
                         )

    trainer.fit(model, train_loader, val_loader)

    wandb.finish()

In [None]:
pretrain()


****************************************************************************************************
Training fold 0
****************************************************************************************************


Global seed set to 42
Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.

  | Name           | Type           | Params
--------------------------------------------------
0 | longformer     | DebertaV2Model | 434 M 
1 | logits         | Linear         | 3.1 K 
2 | pooler_class   | Pooler         | 1.0 M 
3 | pooler_scores  | Pooler         | 1.0 M 
4 | pooler         | Pooler         | 1.1 M 
5 | example_logits | Linear         | 3.2 K 
6 | class_logits   | Linear         | 8.2 K 
7 | embedding      | Embedding      | 256   
--------------------------------------------------
437 M     Trainable params
0         Non-trainable params
437 M     Total params
874.485   Total estimated model params size (MB)


Epoch 0:  58%|███████████████▌           | 5701/9892 [24:18<17:51,  3.91it/s, loss=0.265, v_num=nk27]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                                           | 0/4191 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                              | 0/4191 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                      | 1/4191 [00:00<15:14,  4.58it/s][A
Epoch 0:  58%|███████████████▌           | 5702/9892 [24:18<17:52,  3.91it/s, loss=0.265, v_num=nk27][A
Validation DataLoader 0:   0%|                                      | 2/4191 [00:00<10:29,  6.65it/s][A
Epoch 0:  58%|███████████████▌           | 5703/9892 [24:18<17:51,  3.91it/s, loss=0.265, v_num=nk27][A
Epoch 0:  58%|███████████████▌           | 5704/9892 [24:19<17:51,  3.91it/s, loss=0.265, v_num=nk27][A
Validation DataLoader 0:   0%|                                      | 4/4191 [00:00<10:50,  6.44it/s][A
Epoch 0:  58%|██████████