In [28]:
DEBUG = False

In [29]:
EXP = 'PL11'

In [30]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
import pytorch_lightning as pl

import pandas as pd
import numpy as np
import random
import re
import itertools
import argparse

from torch.utils.data import Dataset
import spacy
import ast
from sklearn.model_selection import StratifiedKFold, GroupKFold, KFold

import pickle
import transformers
from transformers import AutoConfig, AutoModel, AutoTokenizer, get_linear_schedule_with_warmup
import warnings
from torch.optim import Adam, SGD, AdamW

import wandb
from pytorch_lightning.loggers import WandbLogger

In [31]:
pl.seed_everything(42, workers=True)

Global seed set to 42


42

In [32]:
class MyDataset(Dataset):
    def __init__(self, df, tokenizer, max_len=1024, stage='train', rand_prob=0.1, lowup_proba=0.0, swap_proba=0.0):
        self.df = df
        self.max_len = max_len
        self.tokenizer = tokenizer
        self.mask_token = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
        self.stage = stage
        self.rand_prob = rand_prob
        self.lowup_proba = lowup_proba
        self.swap_proba = swap_proba

        self.essay_id = df['essay_id'].values
        self.input_ids = df['input_ids'].values
        self.attention_mask = df['attention_mask'].values
        self.offset_mapping = df['offset_mapping'].values
        self.token_class_labels = df['token_class_labels'].values
        self.token_scores_labels = df['token_scores_labels'].values
        self.token_examples_mapping = df['token_examples_mapping'].values
        self.examples_scores = df['examples_scores'].values
        self.examples_classes = df['examples_classes'].values    
        
    def __getitem__(self, idx):
        essay_id = self.essay_id[idx]
        offset_mapping = self.offset_mapping[idx]

        token_examples_mapping = self.token_examples_mapping[idx]
        examples_scores = self.examples_scores[idx]
        examples_classes = self.examples_classes[idx]

        token_examples_mapping = torch.tensor(token_examples_mapping, dtype=torch.long)
        examples_scores = torch.tensor(examples_scores + [-1] * (40 - len(examples_scores)), dtype=torch.long)
        examples_classes = torch.tensor(examples_classes + [-1] * (40 - len(examples_classes)), dtype=torch.long)
        
        input_ids = self.input_ids[idx]
        attention_mask = self.attention_mask[idx]
        token_class_labels = self.token_class_labels[idx]
        token_scores_labels = self.token_scores_labels[idx]

        input_ids = torch.tensor(input_ids, dtype=torch.long)
        attention_mask = torch.tensor(attention_mask, dtype=torch.long)
        token_class_labels = torch.tensor(token_class_labels, dtype=torch.long)
        token_scores_labels = torch.tensor(token_scores_labels, dtype=torch.long)

        if self.stage == 'train':
            ix = torch.rand(size=(self.max_len,)) < self.rand_prob
            input_ids[ix] = self.mask_token
        
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_class_labels": token_class_labels,
            "token_scores_labels": token_scores_labels,
            "token_examples_mapping": token_examples_mapping,
            "examples_scores": examples_scores,
            "examples_classes": examples_classes
        }

    def __len__(self):
        return len(self.df)

In [33]:
emb_dim = 64

In [34]:
from transformers import *
import torch.nn as nn
import torch.nn.functional as F

class ResidualLSTM(nn.Module):

    def __init__(self, d_model):
        super(ResidualLSTM, self).__init__()
        self.downsample=nn.Linear(d_model,d_model//2)
        self.LSTM=nn.GRU(d_model//2, d_model//2, num_layers=2, bidirectional=False, dropout=0.2)
        self.dropout1=nn.Dropout(0.2)
        self.norm1= nn.LayerNorm(d_model//2)
        self.linear1=nn.Linear(d_model//2, d_model*4)
        self.linear2=nn.Linear(d_model*4, d_model)
        self.dropout2=nn.Dropout(0.2)
        self.norm2= nn.LayerNorm(d_model)

    def forward(self, x):
        res=x
        x=self.downsample(x)
        x, _ = self.LSTM(x)
        x=self.dropout1(x)
        x=self.norm1(x)
        x=F.relu(self.linear1(x))
        x=self.linear2(x)
        x=self.dropout2(x)
        x=res+x
        return self.norm2(x)
    

class SlidingWindowTransformerModel(nn.Module):
#     def __init__(self, DOWNLOADED_MODEL_PATH, window_size=512):
    def __init__(self, checkpoint, window_size=512):

        super(SlidingWindowTransformerModel, self).__init__()
#         config_model = AutoConfig.from_pretrained(DOWNLOADED_MODEL_PATH+'/config.json')
        config_model = AutoConfig.from_pretrained(checkpoint)

#         self.backbone=AutoModel.from_pretrained(
#                            DOWNLOADED_MODEL_PATH+'/pytorch_model.bin',config=config_model)
        self.backbone=AutoModel.from_pretrained(checkpoint)
        self.nb_features = config_model.hidden_size

#         self.lstm=ResidualLSTM(self.nb_features)
        self.lstm=nn.GRU(self.nb_features, self.nb_features//2, batch_first=True, bidirectional=True)
        
#         self.classification_head=nn.Linear(1024,15)
        self.window_size=window_size

    def forward(self,x,attention_mask):

        B,L=x.shape

        if L>self.window_size:
            x=x.reshape(B,L//self.window_size,self.window_size).reshape(-1,self.window_size)
            attention_mask=attention_mask.reshape(B,L//self.window_size,self.window_size).reshape(-1,self.window_size)
            
        x=self.backbone(input_ids=x,attention_mask=attention_mask,return_dict=False)[0]
        x=x.reshape(B,L,-1)
        x,_=self.lstm(x.permute(1,0,2))
        x=x.permute(1,0,2)
#         x=self.lstm(x.permute(1,0,2)).permute(1,0,2)
#         x=self.classification_head(x)

        return x

In [35]:
betas = (0.9, 0.999)
eps = 1e-6

In [36]:
class MyModule(pl.LightningModule):
    def __init__(self, lr, model_checkpoint, num_classes, num_classes_class, emb_dim, betas, eps):
        super().__init__()
        self.lr = lr
        self.num_classes = num_classes
        self.num_classes_class = num_classes_class
        self.emb_dim = emb_dim
        self.name = model_checkpoint
        self.pad_idx = 1 if "roberta" in self.name else 0
        config = AutoConfig.from_pretrained(model_checkpoint, output_hidden_states=True)
        self.transformer = SlidingWindowTransformerModel(model_checkpoint)
        self.nb_features = config.hidden_size
        self.logits = nn.Linear(self.nb_features, num_classes)
        self.example_logits = nn.Linear(self.nb_features + self.emb_dim, num_classes)
        self.class_logits = nn.Linear(self.nb_features, num_classes_class)  
        transformers.logging.set_verbosity_error()
        self.embedding = nn.Embedding(num_classes_class, emb_dim, max_norm=True)
        self.betas = betas
        self.eps = eps
    
    def configure_optimizers(self):
        no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
        optimizer_parameters = [
            {'params': [p for n, p in self.named_parameters() if not any(nd in n for nd in no_decay)],
             'lr': self.lr, 'weight_decay': 0.01},
            {'params': [p for n, p in self.named_parameters() if any(nd in n for nd in no_decay)],
             'lr': self.lr, 'weight_decay': 0.0}
        ]
        optimizer = AdamW(optimizer_parameters, lr=self.lr, betas=self.betas, eps=self.eps)
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=100,
            num_training_steps=self.trainer.estimated_stepping_batches,
        )
        scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
        return [optimizer], [scheduler]
                
    def training_step(self, train_batch, batch_idx):
        input_ids, attention_mask, token_scores_labels, token_examples_mapping, \
        examples_scores, examples_classes, token_class_labels = \
            train_batch["input_ids"], train_batch["attention_mask"], train_batch["token_scores_labels"], \
            train_batch['token_examples_mapping'], train_batch['examples_scores'], \
            train_batch['examples_classes'], train_batch['token_class_labels']
        
        features = self.transformer(
            input_ids,
            attention_mask=attention_mask,
        )
        logits = self.logits(features)
        class_logits = self.class_logits(features)
        loss = F.cross_entropy(logits.view(-1, self.num_classes), token_scores_labels.view(-1))
        class_loss = F.cross_entropy(class_logits.view(-1, self.num_classes_class), token_class_labels.view(-1))

        # Convert to examples loss
        bs, ml, nc1 = logits.shape
        
        batch_preds = []
        batch_targs = []
        
        for i in range(bs):
            example_preds = []
            example_targs = []
            num_examples = token_examples_mapping[i].max()
            assert examples_scores[i,num_examples] >= 0 # and examples_scores[i,num_examples+1] < 0 # truncation breaks this
            for j in range(num_examples + 1):
                indices = token_examples_mapping[i] == j
                fts = features[i][indices].mean(dim=0)
                class_idx = examples_classes[i,j]
                emb = self.embedding(class_idx)
                preds = self.example_logits(torch.cat([emb,fts]))
                example_preds.append(preds)
                example_targs.append(examples_scores[i,j].view(1))
                
            example_preds = torch.cat(example_preds, dim=0).view(-1, nc1)
            example_targs = torch.cat(example_targs, dim=0)
            batch_preds.append(example_preds)
            batch_targs.append(example_targs)
        
        batch_preds = torch.cat(batch_preds, dim=0).view(-1, nc1)
        batch_targs = torch.cat(batch_targs, dim=0)
        
        example_loss = F.cross_entropy(batch_preds, batch_targs)
        
#         if self.current_epoch == 0:
    
        total_loss = loss + class_loss + example_loss

        self.log('train_scores_loss', loss)
        self.log('train_classes_loss', class_loss)
        self.log('train_examples_loss', example_loss)
        self.log('train_total_loss', total_loss)

        return total_loss
        
    def validation_step(self, val_batch, batch_idx):
        input_ids, attention_mask, token_scores_labels, token_examples_mapping, \
        examples_scores, examples_classes, token_class_labels = \
            val_batch["input_ids"], val_batch["attention_mask"], val_batch["token_scores_labels"], \
            val_batch['token_examples_mapping'], val_batch['examples_scores'], val_batch['examples_classes'], \
            val_batch['token_class_labels']
        features = self.transformer(
            input_ids,
            attention_mask=attention_mask,
        )
        logits = self.logits(features)        
        class_logits = self.class_logits(features)
        y_pred = F.log_softmax(logits, dim=-1)                                                
        loss = F.cross_entropy(logits.view(-1, self.num_classes), token_scores_labels.view(-1))
        class_loss = F.cross_entropy(class_logits.view(-1, self.num_classes_class), token_class_labels.view(-1))
        self.log('val_loss', loss)
        self.log('val_class_loss', class_loss)
        return {"preds": y_pred,
                "logits": logits,
                "features": features,
                "val_losses": loss,
                "token_examples_mapping": token_examples_mapping,
                "examples_scores": examples_scores,
                "examples_classes": examples_classes}   
    
    def validation_epoch_end(self, validation_step_outputs):

        bs, ml, nc1 = validation_step_outputs[0]["preds"].shape
        ml2 = validation_step_outputs[0]["examples_scores"].shape[-1]
        all_preds = torch.cat([x["logits"] for x in validation_step_outputs], dim=0).view(-1, ml, nc1)
        all_features = torch.cat([x["features"] for x in validation_step_outputs], dim=0).view(-1, ml, self.nb_features)
        all_mappings = torch.cat([x["token_examples_mapping"] for x in validation_step_outputs], dim=0).view(-1, ml)
        all_scores = torch.cat([x["examples_scores"] for x in validation_step_outputs], dim=0).view(-1, ml2)
        all_classes = torch.cat([x["examples_classes"] for x in validation_step_outputs], dim=0).view(-1, ml2)

        num_texts = all_scores.shape[0]        
        
        example_preds = []
        example_targs = []
        
        for i in range(num_texts):
            num_examples = all_mappings[i].max()
            assert all_scores[i,num_examples] >= 0 # and all_scores[i,num_examples+1] < 0 # truncation breaks this
            for j in range(num_examples + 1):
                indices = all_mappings[i] == j
                fts = all_features[i][indices].mean(dim=0)
                class_idx = all_classes[i,j]
                emb = self.embedding(class_idx)
                preds = self.example_logits(torch.cat([emb,fts]))
                example_preds.append(preds)
                example_targs.append(all_scores[i,j].view(1))              
                
        example_preds = torch.cat(example_preds, dim=0).view(-1, nc1)
        example_targs = torch.cat(example_targs, dim=0)
        
        if torch.isnan(example_preds).any() or torch.isinf(example_preds).any():
            print('invalid example_preds')            
        if torch.isnan(example_targs).any() or torch.isinf(example_targs).any():
            print('invalid example_targs')    
            
            
        example_loss = F.cross_entropy(example_preds, example_targs)
        self.log('example_loss', example_loss)
        print(example_loss)
        
    def predict_step(self, val_batch, batch_idx):
        input_ids, attention_mask, token_scores_labels, token_examples_mapping, examples_scores, examples_classes = \
            val_batch["input_ids"], val_batch["attention_mask"], val_batch["token_scores_labels"], \
            val_batch['token_examples_mapping'], val_batch['examples_scores'], val_batch['examples_classes']
        features = self.transformer(
            input_ids,
            attention_mask=attention_mask,
        )
        logits = self.logits(features)
        y_pred = F.softmax(logits, dim=-1)  
        
        bs, ml, nc1 = logits.shape
        ml2 = 40
        
        batch_preds = []
        batch_targs = []
        
        for i in range(bs):
            example_preds = []
            example_targs = []
            num_examples = token_examples_mapping[i].max()
            assert examples_scores[i,num_examples] >= 0 # and examples_scores[i,num_examples+1] < 0 # truncation breaks this
            for j in range(num_examples + 1):               
                indices = token_examples_mapping[i] == j
                fts = features[i][indices].mean(dim=0)
                class_idx = examples_classes[i,j]
                emb = self.embedding(class_idx)
                preds = self.example_logits(torch.cat([emb,fts]))
                example_preds.append(preds)
                example_targs.append(examples_scores[i,j].view(1))   
                
            example_preds = torch.cat(example_preds, dim=0).view(-1, nc1)
            example_targs = torch.cat(example_targs, dim=0)
            batch_preds.append(example_preds)
            batch_targs.append(example_targs)
        
        return batch_preds, batch_targs
        

In [37]:
import pickle
with open('processed-deberta-v3-large.pickle', 'rb') as handle:
    pdf = pickle.load(handle)

In [38]:
tags = ['debug'] if DEBUG else ['train']
if DEBUG: pdf = pdf.sample(n=100, random_state=42)

In [39]:
df_train = pdf[pdf.fold != 0].reset_index(drop=True)
df_valid = pdf[pdf.fold == 0].reset_index(drop=True)

In [40]:
project = 'fbck'
run = wandb.init(project=project, tags=tags)
run.log_code()

In [41]:
seed = 42
OUTPUT_DIR = '../output'
pl.seed_everything(seed)

Global seed set to 42


42

In [42]:
model_checkpoint = 'microsoft/deberta-v3-large'
max_length = 1024
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, max_length=max_length, padding='max_length')
bs = 4
lr = 5e-6
grad_acc = 1
epochs = 2 if DEBUG else 4
num_classes = 3
num_classes_class = 8
randmask_proba = 0.15

In [43]:
train_dataset = MyDataset(
    df_train,
    tokenizer,
    max_len=max_length,
    stage='train',
    rand_prob=randmask_proba
)

valid_dataset = MyDataset(
    df_valid,
    tokenizer,
    max_len=max_length,
    stage='valid',
    rand_prob=randmask_proba
)

In [44]:
train_loader = DataLoader(train_dataset,
                          batch_size=bs,
                          shuffle=True,
                          num_workers=4, pin_memory=True, drop_last=True)

val_loader = DataLoader(valid_dataset,
                          batch_size=bs,
                          shuffle=False,
                          num_workers=4, pin_memory=True, drop_last=False)

In [45]:
warnings.filterwarnings("ignore")

In [46]:
model = MyModule(lr=lr,
                 model_checkpoint=model_checkpoint, 
                 num_classes=num_classes,
                 num_classes_class=num_classes_class,
                 emb_dim=emb_dim,
                 betas=betas,
                 eps=eps,
                )

In [47]:
wandb_logger = WandbLogger(project=project)

In [48]:
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    save_top_k=1,
    monitor="example_loss",
    mode="min",
    dirpath=f"../output/{EXP}/",
    filename="feedback-{epoch:02d}-{example_loss:.2f}",
)

In [49]:
trainer = pl.Trainer(accelerator="gpu", 
                     precision=16, 
                     gradient_clip_val=1.,
                     devices=1, 
                     max_epochs=epochs,
                     log_every_n_steps=100, 
                     logger=wandb_logger,
                     default_root_dir=f"../output/{EXP}",
                     callbacks=[checkpoint_callback],
                     accumulate_grad_batches=grad_acc,
                     )

Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [50]:
trainer.fit(model, train_loader, val_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.

  | Name           | Type                          | Params
-----------------------------------------------------------------
0 | transformer    | SlidingWindowTransformerModel | 438 M 
1 | logits         | Linear                        | 3.1 K 
2 | example_logits | Linear                        | 3.3 K 
3 | class_logits   | Linear                        | 8.2 K 
4 | embedding      | Embedding                     | 512   
-----------------------------------------------------------------
438 M     Trainable params
0         Non-trainable params
438 M     Total params
877.504   Total estimated model params size (MB)


Sanity Checking DataLoader 0: 100%|████████████████████████████████████| 2/2 [00:00<00:00,  2.33it/s]tensor(1.0891, device='cuda:0')
Epoch 0:  80%|███████████████████████▏     | 838/1048 [05:18<01:19,  2.63it/s, loss=2.64, v_num=feep]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                                            | 0/210 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                               | 0/210 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|▏                                      | 1/210 [00:00<00:31,  6.54it/s][A
Epoch 0:  80%|███████████████████████▏     | 839/1048 [05:19<01:19,  2.63it/s, loss=2.64, v_num=feep][A
Validation DataLoader 0:   1%|▎                                      | 2/210 [00:00<00:28,  7.28it/s][A
Epoch 0:  80%|███████████████████████▏     | 840/1048 [05:19<01:19,  2.63it/s, loss=2.64, v_num=feep][A
Validation DataLoader 0:   1%|▌                                      | 3/210 [00:00<00:27,  7.52it/s]

In [51]:
wandb.finish()

0,1
epoch,▁▁▁▁▁▁▁▁▁▃▃▃▃▃▃▃▃▃▆▆▆▆▆▆▆▆▆▆█████████
example_loss,█▃▃▁
train_classes_loss,█▄█▆▄▅▆▃▃▃▃▆▄▄▄▄▂█▆▆▄▃▃▃▂▃▄▅▃▆▃▂▁
train_examples_loss,▆█▃▃▅▆▂▃▄▃▆▅▅▄▃▂▂▃▄▄▃▄▃▅▂▁▂▃▃▃▁▃▇
train_scores_loss,▆█▃▄▅▅▃▃▄▃▃▅▅▆▄▂▁▃▅▅▄▆▃▅▂▁▂▅▄▂▂▁▆
train_total_loss,█▇▆▅▅▆▄▃▄▃▄▆▅▅▄▂▁▆▅▆▄▄▃▄▁▁▂▄▃▄▁▁▄
trainer/global_step,▁▁▁▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
val_class_loss,█▄▁▁
val_loss,█▅▃▁

0,1
epoch,3.0
example_loss,0.69697
train_classes_loss,0.54289
train_examples_loss,1.13402
train_scores_loss,1.01846
train_total_loss,2.69538
trainer/global_step,3351.0
val_class_loss,0.96168
val_loss,0.66932
