In [None]:
# Install the third-party packages this notebook imports; `&> /dev/null` discards all pip output.
!pip install transformers &> /dev/null
!pip install datasets &> /dev/null
!pip install stanfordnlp &> /dev/null
import nltk
# Fetch the Punkt data behind the 'tokenizers/punkt/english.pickle' sentence splitter loaded further down.
nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

In [None]:
import stanfordnlp

# The tokenizer model must exist on disk before the pipeline can be built;
# without it the pipeline aborts with "Cannot load model ..." and SystemExit.
# force=True skips the interactive confirmation prompt, which cannot be
# answered conveniently from a notebook cell.
stanfordnlp.download('en', force=True)
stanford_nlp = stanfordnlp.Pipeline(processors='tokenize', lang='en')

def stanford_tokenizer(text):
    '''
    Tokenize `text` with the module-level `stanford_nlp` pipeline.

    Returns a list with one entry per sentence of the processed document;
    each entry is the list of `word.text` strings of that sentence.
    '''
    parsed = stanford_nlp(text)
    return [[token.text for token in sent.words] for sent in parsed.sentences]

Use device: gpu
---
Loading: tokenize
With settings: 
{'model_path': '/root/stanfordnlp_resources/en_ewt_models/en_ewt_tokenizer.pt', 'lang': 'en', 'shorthand': 'en_ewt', 'mode': 'predict'}
Cannot load model from /root/stanfordnlp_resources/en_ewt_models/en_ewt_tokenizer.pt


SystemExit: ignored

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


## Import libraries

In [None]:
import torch
from torch.nn.functional import softmax
from transformers import BertForNextSentencePrediction, BertTokenizer
from transformers import AdamW
from transformers import get_scheduler
import nltk
from itertools import combinations, permutations
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import f1_score, accuracy_score, confusion_matrix
import pprint
import statistics as stat
import numpy as np
import random
from collections import Counter


# Punkt sentence splitter; used further down to split the summary files into sentences.
sent_detector = nltk.data.load('tokenizers/punkt/english.pickle') 
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


## Preparing Dataset for Fine-tuning the model

In [None]:
# Mount Google Drive so that the dataset folder below can be read from this runtime.
from google.colab import drive
drive.mount('/content/drive')
import os
import json
# Dataset folder, given relative to the current working directory.
path = 'drive/MyDrive/pubmed-dataset'

Mounted at /content/drive


In [None]:
# Read every split file of the dataset folder into memory, keyed by split name.
file_dict = {}
path = 'drive/MyDrive/pubmed-dataset'
# 'vocab' is the only folder entry that is not loaded as a split file.
files = [f for f in os.listdir(path) if f != 'vocab']
for fname in files:
    # Key by the file name without its extension. splitext is used instead of
    # slicing off the last four characters so that the key stays correct for
    # extensions that are not exactly three letters long.
    split_name = os.path.splitext(fname)[0]
    with open(os.path.join(path, fname), 'r', encoding='utf-8') as f:
        # One whitespace-stripped string per line of the file.
        file_dict[split_name] = [line.strip() for line in f]

In [None]:
def process_sc_data(txt_file):
    '''
    Decode paper records and split them into abstract and full-text lists.

    Args:
        txt_file: iterable of JSON strings, one per paper. Every record must
            contain an 'abstract_text' list of strings and an 'article_text'
            entry.

    Returns:
        (master_abs_list, master_full_list), two lists aligned by paper:
        - master_abs_list[i] is the paper's 'abstract_text' with the '<S>' and
          '</S>' markers removed and surrounding whitespace stripped;
        - master_full_list[i] is the paper's 'article_text' exactly as stored.

    Raises:
        json.JSONDecodeError: if a line is not valid JSON.
        KeyError: if a record lacks one of the two fields.
    '''
    master_abs_list = []
    master_full_list = []
    for paper in txt_file:
        # Decode each record once; both fields are read from the same object.
        record = json.loads(paper)
        abstract_list = [
            sent.replace('<S>', '').replace('</S>', '').strip()
            for sent in record['abstract_text']
        ]
        master_abs_list.append(abstract_list)
        master_full_list.append(record['article_text'])

    return master_abs_list, master_full_list


In [None]:
def pre_sentpairs(abs,full):
    '''
    Pool the abstracts with abstract-sized windows cut from the full texts.

    `abs` and `full` are parallel lists of sentence lists (one entry per
    paper). For every paper:
      * if the full text is longer than the abstract, a contiguous window with
        as many sentences as the abstract is taken at a random offset;
      * otherwise the whole full text is kept, provided it has more than one
        sentence;
      * otherwise nothing is added for that paper.

    Returns a single list: all abstracts first, followed by the selected
    full-text windows.
    '''
    assert len(abs) == len(full)

    body_windows = []
    for abstract_sents, body_sents in zip(abs, full):
        window_size = len(abstract_sents)
        surplus = len(body_sents) - window_size
        if surplus > 0:
            # Offsets are drawn from 0 .. surplus-1.
            offset = random.sample(range(surplus), 1)[0]
            body_windows.append(body_sents[offset:offset + window_size])
            continue
        if len(body_sents) > 1:
            body_windows.append(body_sents)

    return abs + body_windows

In [None]:
def extract_neg_intersample(docs_sents_list,current_doc_index):

    '''
    Build one negative (label 1) pair whose second sentence comes from
    another document.

    A random sentence is drawn from the document at `current_doc_index`, and
    a random sentence is drawn from a different, randomly chosen document.

    Args:
        docs_sents_list: list of documents, each a list of sentences. The
            list is only read, never modified.
        current_doc_index: index of the document that provides the first
            sentence.

    Returns:
        (sent1, sent2, 1, info), where `info` records the sentence and
        document indexes of both sentences together with the label.

    Raises:
        ValueError: if fewer than two documents are available (there is no
            "other" document to sample from), or if the current document or
            the drawn document contains no sentence.
    '''
    n_docs = len(docs_sents_list)
    if n_docs < 2:
        # Without this guard the resampling loop below could never find a
        # different document and would spin forever.
        raise ValueError('extract_neg_intersample needs at least two documents')

    # Random sentence from the current document.
    current_doc = docs_sents_list[current_doc_index]
    intra_sent_index = random.randrange(len(current_doc))
    intra_sent = current_doc[intra_sent_index]

    # Random other document: resample until the index differs.
    inter_doc_index = random.randrange(n_docs)
    while inter_doc_index == current_doc_index:
        inter_doc_index = random.randrange(n_docs)
    inter_doc = docs_sents_list[inter_doc_index]

    # Random sentence from that other document.
    inter_sent_index = random.randrange(len(inter_doc))
    inter_sent = inter_doc[inter_sent_index]

    info = {
        'sent1-index':intra_sent_index,
        'sent1-docindex':current_doc_index,
        'sent2-index': inter_sent_index,
        'sent2-docindex':inter_doc_index,
        'label':1}

    return intra_sent, inter_sent, 1, info

In [None]:
def _make_intra_pair(sents, first, second, doc_index, label):
    '''Build one (sent1, sent2, label, info) record from two sentences of the same document.'''
    info = {'sent1-index':first,
            'sent1-docindex':doc_index,
            'sent2-index':second,
            'sent2-docindex':doc_index,
            'label':label}
    return (sents[first], sents[second], label, info)


def process_sentpairs(docs_sents_list, interdoc_bias=0.5,negative_sample_amount = 1):
    '''
    Turn a list of sentence lists into next-sentence-prediction pairs.

    For every document, each pair of directly consecutive sentences becomes a
    positive sample (label 0). Negative samples (label 1) are then drawn
    until there are `negative_sample_amount` negatives per positive:
      * with probability `1 - interdoc_bias`, an ordered pair of
        non-consecutive sentences from the same document (drawn without
        replacement);
      * otherwise a pair produced by `extract_neg_intersample`, whose second
        sentence comes from another document.

    When one of the two sources is unavailable (the intra-document pool is
    used up, or there is no other document), the other source is used
    instead; when both are unavailable the document simply receives fewer
    negatives.

    Args:
        docs_sents_list: list of documents, each a list of sentences.
        interdoc_bias: value in [0, 1]; the higher, the more negatives are
            taken from other documents.
        negative_sample_amount: negatives requested per positive sample.

    Returns:
        (previous_sents, later_sents, labels, infos): four parallel lists.
        All four are empty when no pair could be produced.
    '''
    datasets = []
    # Inter-document negatives need at least one other document to draw from.
    inter_available = len(docs_sents_list) > 1

    for doc_index, sents in enumerate(docs_sents_list):
        ##  Create data from same doc
        pos_samples = []
        neg_intrasamples = []
        for first, second in permutations(range(len(sents)), 2):
            if first + 1 == second:
                pos_samples.append(_make_intra_pair(sents, first, second, doc_index, 0))
            else:
                neg_intrasamples.append(_make_intra_pair(sents, first, second, doc_index, 1))

        # Create negative samples from same doc and other docs
        desired_neg_amount = negative_sample_amount*len(pos_samples)
        neg_samples = []
        while len(neg_samples) < desired_neg_amount:
            # If larger than interdoc_bias then we sample from intradoc
            use_intra = random.uniform(0, 1) > interdoc_bias
            # Fall back to the other source when the preferred one is unavailable.
            if use_intra and not neg_intrasamples:
                use_intra = False
            elif not use_intra and not inter_available:
                use_intra = True

            if use_intra:
                if not neg_intrasamples:
                    break  # nothing left to draw from for this document
                # Uniform draw without replacement: move a random element to
                # the end and pop it, instead of reshuffling the whole pool.
                pick = random.randrange(len(neg_intrasamples))
                neg_intrasamples[pick], neg_intrasamples[-1] = neg_intrasamples[-1], neg_intrasamples[pick]
                neg_samples.append(neg_intrasamples.pop())
            else:
                if not inter_available:
                    break  # nothing left to draw from for this document
                neg_samples.append(extract_neg_intersample(docs_sents_list, doc_index))

        datasets += pos_samples
        datasets += neg_samples

    if not datasets:
        # zip(*[]) yields nothing, so the four-way unpacking below would fail.
        return [], [], [], []

    # Unzip and assign variables for readability and usability
    previous_sents, later_sents, labels, infos = map(list, zip(*datasets))

    return previous_sents, later_sents, labels, infos

In [None]:
# Build the sentence-pair data of each split. Every *_data value is the
# (previous_sents, later_sents, labels, infos) tuple returned by process_sentpairs.
test_abs, test_full = process_sc_data(file_dict['test'])
test_data = process_sentpairs(pre_sentpairs(test_abs,test_full))

val_abs, val_full = process_sc_data(file_dict['val'])
val_data = process_sentpairs(pre_sentpairs(val_abs,val_full))


train_abs, train_full = process_sc_data(file_dict['train'])
train_data = process_sentpairs(pre_sentpairs(train_abs,train_full))

In [None]:
class ScienceDataset(Dataset):
  """
  Dataset of sentence pairs that are tokenized on access.

  Args:
    previous_sents: first sentence of every pair.
    next_sents: second sentence of every pair.
    labels: one label per pair.
    paper_index: one metadata entry per pair; returned untouched.
    tokenizer: object exposing `encode_plus` (e.g. a BertTokenizer).
    max_length: upper bound, in tokens, for an encoded pair. Longer pairs are
      truncated. Defaults to 512, the size of the position-embedding table of
      the BERT model used in this notebook. Pass None to leave the limit to
      the tokenizer.
  """
  def __init__(self, previous_sents, next_sents, labels, paper_index, tokenizer, max_length=512):
    self.previous_sents = previous_sents
    self.next_sents = next_sents
    self.labels = labels
    self.paper_index = paper_index
    self.tokenizer = tokenizer
    self.max_length = max_length

  def __len__(self):
    return len(self.labels)

  def __getitem__(self, idx):
    """
    Encode pair `idx`.

    Returns (encoded, label, length, index), where `encoded` holds the 1-D
    'input_ids', 'token_type_ids' and 'attention_mask' tensors of the pair
    and `length` is the number of tokens in it.
    """
    sent1 = self.previous_sents[idx]
    sent2 = self.next_sents[idx]
    label = self.labels[idx]
    index = self.paper_index[idx]

    # truncation=True only takes effect when a maximum length is known, so it
    # is passed explicitly. No padding is requested here: a single pair has
    # nothing to be padded against, batches are padded by collate_fn.
    encoded = self.tokenizer.encode_plus(sent1, text_pair=sent2, is_split_into_words=False,
                   truncation=True, max_length=self.max_length, return_tensors="pt")
    length = encoded['input_ids'].shape[1]
    # return_tensors="pt" adds a leading batch dimension of size 1; drop it.
    encoded = dict(
        input_ids=encoded['input_ids'][0],
        token_type_ids=encoded['token_type_ids'][0],
        attention_mask=encoded['attention_mask'][0]
    )
    return encoded, label, length, index

In [None]:
def collate_fn(data):
    """
    Pad a list of dataset items into one batch.

    Args:
        data: list of (encoded, label, length, index) tuples, where `encoded`
            is a dict holding 1-D 'input_ids', 'token_type_ids' and
            'attention_mask' tensors and `length` is the number of tokens.

    Returns:
        (encoded, labels, indexs):
        - encoded: dict with the same three keys, each a long tensor of shape
          (batch, longest length in the batch), right-padded with zeros;
        - labels: long tensor with one label per item;
        - indexs: list with the per-item `index` entries, unchanged.

    Based on:
        https://stackoverflow.com/questions/65279115/how-to-use-collate-fn-with-dataloaders
    """
    _, labels, lengths, indexs = zip(*data)
    max_len = max(lengths)
    batch_size = len(data)

    # Allocate integer buffers directly so token ids never pass through a
    # float tensor. Zero is the padding value for all three fields.
    encoded = {
        name: torch.zeros((batch_size, max_len), dtype=torch.long)
        for name in ('input_ids', 'token_type_ids', 'attention_mask')
    }
    for row, item in enumerate(data):
        example = item[0]
        for name, buffer in encoded.items():
            values = example[name]
            buffer[row, :values.size(0)] = values

    return encoded, torch.tensor(labels, dtype=torch.long), list(indexs)

In [None]:
# Load the pretrained model (with its next-sentence-prediction head) and the tokenizer.
# Both are loaded from the same 'allenai/scibert_scivocab_uncased' checkpoint.
model = BertForNextSentencePrediction.from_pretrained('allenai/scibert_scivocab_uncased')
tokenizer = BertTokenizer.from_pretrained('allenai/scibert_scivocab_uncased')

Some weights of the model checkpoint at allenai/scibert_scivocab_uncased were not used when initializing BertForNextSentencePrediction: ['cls.predictions.decoder.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForNextSentencePrediction from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForNextSentencePrediction from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
model

BertForNextSentencePrediction(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(31090, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [None]:
# load pretrained model and a pretrained tokenizer
#model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
#tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [None]:
# Wrap every split in a ScienceDataset and batch it with collate_fn.
# Each *_data tuple is unpacked as (previous_sents, later_sents, labels, infos).
# NOTE(review): shuffle=True is also set for the val and test loaders; the
# aggregate metrics should not depend on batch order, so this presumably only
# affects which batch `next(iter(...))` returns — confirm it is intended.
## train
train_sc = ScienceDataset(train_data[0],train_data[1],train_data[2],train_data[3],tokenizer)
train_loader = DataLoader(train_sc, batch_size=16, shuffle=True, collate_fn=collate_fn)


## val
val_sc = ScienceDataset(val_data[0],val_data[1],val_data[2],val_data[3],tokenizer)
val_loader = DataLoader(val_sc, batch_size=16, shuffle=True, collate_fn=collate_fn)

## test
test_sc = ScienceDataset(test_data[0],test_data[1],test_data[2],test_data[3],tokenizer)
test_loader = DataLoader(test_sc, batch_size=16, shuffle=True, collate_fn=collate_fn)

### Code for evaluating results

In [None]:
def evaluator(y_true, y_pred):
  """
  Score predictions against gold labels.

  Returns a dict with the keys 'accuracy', 'macro_f1' and 'weighted_f1',
  computed with accuracy_score and f1_score (macro / weighted averaging).
  """
  scores = {'accuracy': accuracy_score(y_true, y_pred)}
  for averaging in ('macro', 'weighted'):
    scores[averaging + '_f1'] = f1_score(y_true, y_pred, average=averaging)
  return scores

In [None]:
def fluency(y_pred, y_keys):
  """
  Average, over papers, of the share of true next-sentence pairs predicted as such.

  Only the positions whose key has 'label' == 0 are considered. They are
  grouped by the key's 'sent1-docindex'; for every group the fraction of
  predictions equal to 0 is computed, and the mean of these per-paper
  fractions is returned.

  Returns the string 'No relevant scores' when no key has label 0.
  """
  preds_by_paper = {}
  for position, key in enumerate(y_keys):
    if key['label'] != 0:
      continue
    preds_by_paper.setdefault(key['sent1-docindex'], []).append(y_pred[position])

  if not preds_by_paper:
    return 'No relevant scores'

  per_paper = {
      paper: np.mean(np.array(preds) == 0)
      for paper, preds in preds_by_paper.items()
  }
  return stat.mean(per_paper.values())
  

In [None]:
# Smoke test: push a single test batch through the model and exercise the metric helpers.
# NOTE(review): the batch is not moved to `device` and there is no torch.no_grad() here;
# this cell relies on the model still being on the CPU (model.to(device) runs in a later cell).
features, label, keys = next(iter(test_loader))
# Element 0 of the model output holds the per-class scores.
out = model(**features)[0]
pred = torch.softmax(out, dim=1)

# Index of the highest-scoring class for every pair.
_, class_prediction = torch.max(pred, 1)

print('Label', label)
print('Prediction',class_prediction)
print(evaluator(label, class_prediction))
print('Keys',keys)
print(fluency(class_prediction, keys))

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


## Functions to evaluate data

In [None]:
def eval_data(eval_loader):
    """
    Evaluate the module-level `model` on every batch of `eval_loader`.

    The loader must yield (features, label, keys) batches as produced by
    collate_fn. Gradients are disabled while predicting; switching the model
    to eval mode is left to the caller.

    Prints, then returns, a 3-tuple:
      - Counter of the predicted classes,
      - the dict returned by evaluator (accuracy / F1 scores),
      - the value returned by fluency.
    """
    # Imported here because the notebook only imports tqdm in the training
    # cell, which comes after this definition.
    from tqdm.auto import tqdm

    val_label = []
    val_pred = []
    val_keys = []
    with torch.no_grad():
        for features, label, keys in tqdm(eval_loader):
            features = {k:v.to(device) for k,v in features.items()}
            logits = model(**features)[0]
            # Softmax is monotonic, so the arg-max of the logits is the
            # predicted class. Labels are only collected, so they can stay
            # on the CPU.
            val_label += label.tolist()
            val_pred += logits.argmax(dim=1).tolist()
            val_keys += keys

    prediction_counts = Counter(val_pred)
    accuracy_f1 = evaluator(val_label,val_pred)
    eval_fluency = fluency(val_pred, val_keys)
    print('Prediction Counter:', prediction_counts)
    print('Accuracy and F1-score: ', accuracy_f1)
    print('Fluency of validation dataset', eval_fluency)

    return prediction_counts, accuracy_f1, eval_fluency
  

In [None]:
# Folder holding one summary file per summarization system.
SUMM_DIR = 'drive/MyDrive/Summarization/PUBMED_SUMM'


def load_summary_sentences(filename):
    """
    Read one summary file from SUMM_DIR and split every line into sentences.

    Each line is stripped, every '<n> ' marker is removed, and the remaining
    text is split with the Punkt `sent_detector`.

    Returns a list with one sentence list per line of the file.
    """
    with open(os.path.join(SUMM_DIR, filename),'r',encoding='utf-8') as f:
        return [sent_detector.tokenize(l.strip().replace('<n> ','')) for l in f]


pegasus = load_summary_sentences('pegasus_pubmed.txt')
bigbird = load_summary_sentences('bigbirdpegasus_pubmed.txt')
hiporank = load_summary_sentences('hiporank_extsumm.txt')

# Same pair construction as for the dataset splits, applied to the summaries.
pegasus_data = process_sentpairs(pegasus)
bigbird_data = process_sentpairs(bigbird)
hiporank_data = process_sentpairs(hiporank)


In [None]:

# One dataset / loader per summarization system, built the same way as the split loaders.
## Pegasus loader
pegasus_sc = ScienceDataset(pegasus_data[0],pegasus_data[1],pegasus_data[2],pegasus_data[3],tokenizer)
pegasus_loader = DataLoader(pegasus_sc, batch_size=16, shuffle=True, collate_fn=collate_fn)

## BigBird loader
bigbird_sc = ScienceDataset(bigbird_data[0],bigbird_data[1],bigbird_data[2],bigbird_data[3],tokenizer)
bigbird_loader = DataLoader(bigbird_sc, batch_size=16, shuffle=True, collate_fn=collate_fn)


## HipoRank loader
hiporank_sc = ScienceDataset(hiporank_data[0],hiporank_data[1],hiporank_data[2],hiporank_data[3],tokenizer)
hiporank_loader = DataLoader(hiporank_sc, batch_size=16, shuffle=True, collate_fn=collate_fn)

## Train Model

In [None]:
# Number of trainable parameters: only tensors with requires_grad=True are counted.
sum(p.numel() for p in model.parameters() if p.requires_grad)

109920002

In [None]:
# AdamW over all model parameters.
optimizer = AdamW(model.parameters(), lr=5e-5)

num_epochs = 5
# The scheduler is stepped once per training batch, so the schedule spans
# every batch of every epoch.
num_training_steps = num_epochs * len(train_loader)
# "linear" schedule with no warm-up steps.
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps
)

# NOTE(review): CrossEntropyLoss applies log-softmax internally, so it expects
# raw logits rather than probabilities.
criterion = torch.nn.CrossEntropyLoss()

In [None]:
# Move the model to the selected device (the GPU when CUDA is available).
model.to(device)
# End the cell on a print — presumably so the long module repr is not shown as the cell output.
print('')

In [None]:
from tqdm.auto import tqdm

# Number of optimisation steps between two running-loss / metric reports.
LOG_EVERY = 19999

# (title, loader) pairs evaluated at the start of every epoch, in this order.
EVAL_SETS = [
    ('Model performance on Validation Data Human Abstract', val_loader),
    ('Model performance on Test Data Human Abstract', test_loader),
    ('Model performance on Pegasus summary', pegasus_loader),
    ('Model performance on Big Bird summary', bigbird_loader),
    ('Model performance on HipoRank summary', hiporank_loader),
]

progress_bar = tqdm(range(num_training_steps))
counter = 0
total_loss = 0

for epoch in range(num_epochs):
    print('EPOCH: ', epoch)
    # Evaluate model
    # NOTE(review): evaluation happens before the epoch's updates, so the
    # weights obtained after the last epoch are never evaluated in this cell.
    model.eval()
    for title, loader in EVAL_SETS:
        print(title)
        print('-'*80)
        eval_data(loader)

    master_label = []
    master_class_pred = []

    print("Start training for epoch:", epoch)
    model.train()
    for batch in train_loader:
        features, label, _ = batch
        features = {k:v.to(device) for k,v in features.items()}
        label = label.to(device)

        logits = model(**features)[0]
        # CrossEntropyLoss applies log-softmax itself, so it must be given the
        # raw logits; passing softmax output would apply softmax twice.
        loss = criterion(logits, label)
        loss.backward()

        # Optimizer - training step
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)

        # Class prediction: index of the largest logit
        class_prediction = logits.argmax(dim=1)
        # Store in list
        master_label += label.tolist()
        master_class_pred += class_prediction.tolist()
        # Running loss
        total_loss += loss.item()

        ## Show total loss and evaluation metric, then reset both accumulators
        if counter % LOG_EVERY == 0:
            print(total_loss)
            total_loss = 0
            print(evaluator(master_label, master_class_pred))
            master_label = []
            master_class_pred = []
        # counter
        counter += 1
