In [1]:

import pandas as pd
import matplotlib
import os
# os.environ['TRANSFORMERS_OFFLINE'] = '1' # Indicating transformers for offline mode

from functools import partial

from pathlib import Path

from tokenizers import Tokenizer
from transformers import AutoTokenizer, AutoModel
import transformers
from transformers.models.auto.tokenization_auto import logger

import random

import torch
import torchmetrics
import torch.nn as nn
from torch.utils import data
from torch.utils.data import Dataset, DataLoader, random_split

import pytorch_lightning as pl





transformers.logging.set_verbosity_error()
TRANSFORMER_CACHE = Path("../resources/transformer_cache") # The cache location. Accessed when offline.

#### Model

In [2]:
class CHUNKSUMM(pl.LightningModule):
    def __init__(self, model, learning_rate=6e-5, n_classes=2, enable_chunk=False):
        super().__init__()

        self.bert = model

        # Freezing bert params
        # for param in self.bert.parameters():
        #     param.requires_grad = False
        # self.bert.eval()

        self.criterion = nn.BCEWithLogitsLoss()
        self.l1 = torch.nn.Linear(768, n_classes)
        self.learning_rate = learning_rate
        self.accuracy = torchmetrics.Accuracy()
        self.auc = torchmetrics.AUROC(num_classes=n_classes)
        self.enable_chunk = enable_chunk
       
        # self.save_hyperparameters() # Saves every *args in _init_() in checkpoint file. # Slows trainer.predict
        self.save_hyperparameters(ignore=["bert"])




    def forward(self, input_ids, attention_mask, token_type_ids, train=False):
        """Can handle more than 512 tokens"""

        embed2d = self.get_embedding(input_ids, attention_mask, token_type_ids)
        logits = self.l1(embed2d)  
        if train:
            return logits
        else: 
            return torch.softmax(logits,dim=-1)




    def get_embedding(self, input_ids, attention_mask, token_type_ids):

        if self.enable_chunk:

            batch_chunks = [
                self.chunk(batch) for batch in (input_ids, attention_mask, token_type_ids)
            ]
            handler = []
            for chunk in zip(batch_chunks[0], batch_chunks[1], batch_chunks[2]):

                chunk_hidden_states = self.bert(chunk[0], chunk[1], chunk[2], output_hidden_states=True)[2]
                chunk_embed2d = torch.stack(chunk_hidden_states)[-5:].mean(0)
                handler.append(chunk_embed2d)

            contextual_encoding = torch.cat(handler, dim=1)
            embed2d = contextual_encoding

        else:
            hidden_states = self.bert(input_ids[:, :512], attention_mask[:, :512], token_type_ids[:, :512], output_hidden_states=True)[2]
            mean_hidden_states = torch.stack(hidden_states)[-5:].mean(0)
            contextual_encoding = mean_hidden_states
            embed2d = contextual_encoding

        return embed2d

    def training_step(self, batch, batch_ids=None):



        outputs = self(batch["input_ids"], batch["attention_mask"], batch["token_type_ids"],train=True)

        labels = self.expand_targets(batch["targets"].float()) 
        labels = labels.reshape_as(outputs)


        loss = self.criterion(outputs, labels) 
        #acc = self.accuracy(outputs, labels.int())
        auc = self.auc(outputs, labels.int())


        
        self.log("Loss_train", loss, prog_bar=True, logger=True)
        self.log("Auc_train", auc, prog_bar=True, logger=True)

        return {"loss": loss, "predictions": outputs, "labels": labels}



    def validation_step(self, batch, batch_idx):

        outputs = self(
            batch["input_ids"], batch["attention_mask"], batch["token_type_ids"])
        labels = self.expand_targets(batch["targets"].float()) 
        labels = labels.reshape_as(outputs)


        loss = self.criterion(outputs,labels)   

        auc = self.auc(outputs, labels.int())

        self.log("Loss_val", loss, prog_bar=True, logger=True)
        self.log("Auc_val", auc, prog_bar=True, logger=True)

        return loss

    def test_step(self, batch, batch_idx):

        outputs = self(batch["input_ids"], batch["attention_mask"], batch["token_type_ids"])
        labels = self.expand_targets(batch["targets"].float()) 
        labels = labels.reshape_as(outputs)

        loss = self.criterion(outputs, labels)
        #acc = self.accuracy(outputs, labels.int())
        auc = self.auc(outputs, labels.int())

        self.log("Test_loss", loss, prog_bar=True, logger=True)
        self.log("Test_auc", auc, prog_bar=True, logger=True)

        return loss

    def predict_step(self, batch, batch_ids, dataloader_idx=None):

        outputs = self(
            batch["input_ids"], batch["attention_mask"], batch["token_type_ids"]
        )

        return outputs
    

    def configure_optimizers(self):

        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    @property
    def chunk(self):
        return partial(torch.split, split_size_or_sections=512, dim=1)

    def expand_targets(self,targets):
        """This returns the Two dimentional targets given the single correct label"""
        # 0 -> [0,1]  , 1 -> [1,0] =  (IN,OUT) class.
        out = torch.stack([torch.tensor([1.,0.]) if val else torch.tensor([0.,1.])  for batch in  targets.bool() for val in batch])
      
        return out.to(self.device)

In [3]:
# This is required to initialize the backend-model (bert) which is a pretrained model.
  
pretrained_model = AutoModel.from_pretrained('../resources/checkpoints/bert-base-uncased.pt')
tokenizer = AutoTokenizer.from_pretrained('../resources/checkpoints/bert-base-uncased-tokenizer.pt')
model = CHUNKSUMM(model=pretrained_model)

  rank_zero_warn(


#### Load Checkpoints

In [4]:

#loading model checkpoint

first = torch.load('../resources/checkpoints/first.pt') # initial model
second_40k = torch.load('../resources/checkpoints/second-40k.pt') # Traning with 40k sentences

model.load_state_dict(second_40k) 


def get_token_scores(model,tokenizer,text:[str]):

    tokenized_input = tokenizer(text,return_tensors='pt')
    model.eval()

    return tokenized_input['input_ids'],model(**tokenized_input) # (IN,OUT) probabilities


### Inference

In [5]:
"This is a test"
batch_text = "This is a test"
tokens,probs = get_token_scores(model,tokenizer,batch_text)

In [6]:
tokens.shape # batch * tokens   

torch.Size([1, 6])

In [7]:
probs[:,:,0] # batch * tokens * probability

tensor([[0.0006, 0.0008, 0.0007, 0.0005, 0.0007, 0.0009]],
       grad_fn=<SelectBackward0>)

In [8]:
probs

tensor([[[5.7740e-04, 9.9942e-01],
         [7.7118e-04, 9.9923e-01],
         [6.7593e-04, 9.9932e-01],
         [4.9304e-04, 9.9951e-01],
         [6.7584e-04, 9.9932e-01],
         [9.2531e-04, 9.9907e-01]]], grad_fn=<SoftmaxBackward0>)

In [9]:
probs[:,:,0].shape

torch.Size([1, 6])