In [1]:
#https://github.com/amazon-science/efficient-longdoc-classification
import sys
sys.path.append('../')
from functools import partial
import nltk
import pickle as pk
import torch
from context_enforcement.models.context_enforcer import compute_context_boundary
from context_enforcement.trainers.train_bart3 import model_init
from context_enforcement.data.common import create_text_tokenizer, SmartCollator
from context_enforcement.trainers.common import get_dataset_specified_tasks
from pytorch_lightning import seed_everything

import sys
import os
# Restrict this process to a single visible GPU and fix the global RNG seed.
VISIBLE_GPU = "0"
SEED = 1376
os.environ["CUDA_VISIBLE_DEVICES"] = VISIBLE_GPU
seed_everything(SEED)

  from .autonotebook import tqdm as notebook_tqdm
Global seed set to 1376
[nltk_data] Downloading package punkt to /home/nlplab/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
Global seed set to 1376


1376

In [2]:
# Load the argument object that was pickled next to the trained model.
# NOTE: pickle.load can execute arbitrary code; only load files you trust.
# The context manager guarantees the file handle is closed after loading.
with open("../trained_models/xsum/bart-base-xsum-context-enforcer-baseline/train_args.ap", 'rb') as train_args_file:
    configs = pk.load(train_args_file)

# Context settings reused when the model is built further down.
context_max_len = configs.context_max_len
context_max_len_list = [context_max_len]  # other values tried previously: 300, 450
context_sampling_bounds = (0.1, 0.45)

In [3]:
# Tokenizer for the base model named in the saved training configuration.
tokenizer = create_text_tokenizer(configs.model_base)

# Look up the dataset factory for the configured task. Everything below needs the
# datasets, so stop here with a clear message instead of failing later with an
# AttributeError on None.
task_dataset_gen = get_dataset_specified_tasks(configs.task_type)
if task_dataset_gen is None:
    raise ValueError(
        f"get_dataset_specified_tasks returned None for task type {configs.task_type!r}"
    )

raw_dataset = task_dataset_gen(tokenizer=tokenizer)
train_dataset = raw_dataset['train']
eval_dataset = raw_dataset['validation']
test_dataset = raw_dataset['test']

# model_init returns a callable; it is invoked later (model_builder()) to create the model.
model_builder = model_init(
        vocab_size=len(train_dataset.tokenizer),
        model_base=configs.model_base,
        context_max_len = context_max_len,
        context_sampling_bounds = context_sampling_bounds,
        context_max_len_list= context_max_len_list,
        is_enforcement_baseline=configs.is_enforcement_baseline
    )

Found cached dataset xsum (/home/nlplab/.cache/huggingface/datasets/xsum/default/1.2.0/082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71)
100%|██████████| 3/3 [00:00<00:00, 469.09it/s]


In [4]:
# Build the model and restore the fine-tuned weights from the saved checkpoint.
generator = model_builder()
train_model_path = "../trained_models/xsum/bart-base-xsum-context-enforcer-baseline/checkpoint-12753/pytorch_model.bin"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# map_location="cpu" lets the checkpoint load on machines without the GPU it was
# saved from; load_state_dict then copies the values into the model's parameters.
state_dict = torch.load(train_model_path, map_location="cpu")
generator.load_state_dict(state_dict)

# Later cells move their input tensors to `device`, so the model has to live there
# too. This notebook only runs generation, so also switch to eval mode here.
# (Assigning the result keeps the cell from printing the full module repr.)
generator = generator.to(device).eval()

In [25]:
# Take one test example (index 690) and turn its 1-row tensors into a batch of
# size one on the target device.
te1 = test_dataset[690]
b_input_ids, b_input_mask = (
    field.view(1, -1).to(device)
    for field in (te1.input_ids, te1.attention_mask)
)

In [8]:
# Hand-written document for ad-hoc experiments.
# It is stored under its own names so that it does not overwrite te1 /
# b_input_ids / b_input_mask from the test-set example: the outputs recorded in
# the following cells were produced from the test-set example, and a
# top-to-bottom run should reproduce them. To try this document instead, pass
# custom_input_ids / custom_input_mask to generator.generate.
example= {'document':"""
           Have you been rocking "The Rachel" for the last twenty years? If so, then you're overdue for a trendier haircut that can make the most of your youthful face and features. 
           Check out some stylish magazines or even some celebrity gossip magazines and see what hairstyles are popular these days. 
           You don't have to go for something ultra-trendy if that's not your thing, but getting a haircut that suits you better than your old one can make you look a decade younger.
           """}
custom_te = test_dataset._process_data(example)
custom_input_ids = custom_te.input_ids.view(1, -1).to(device)
custom_input_mask = custom_te.attention_mask.view(1, -1).to(device)

In [44]:
# Compute the context boundary for the current example (te1).
# Optional re-seed before computing the boundary (presumably the boundary is
# sampled at random -- confirm in compute_context_boundary):
# seed_everything(1376)
boundary_sample = (0.15, 0.35)
boundary_width = 40  # alternative that was tried: int(0.95*seq_len)

# Only the first 1024 positions are fed to the model below, so measure that slice.
seq_len = len(te1.input_ids[:1024])

context_boundary = compute_context_boundary(
    seq_len,
    context_sampling_bounds=boundary_sample,
    context_max_len=boundary_width,
)

# Displayed for inspection. Note that context_max_len is the value loaded from
# the training config, not the boundary_width used in the call above.
(context_boundary, seq_len, context_max_len)

((39, 79), 241, 200)

In [45]:
# Beam-search a summary for the single prepared example and show the decoded text.
generator.eval()

decoding_options = dict(
    eos_token_id=test_dataset.tokenizer.eos_token_id,
    max_length=189,
    early_stopping=True,
    use_cache=True,
    num_beams=10,
)

with torch.no_grad():
    # Inputs are truncated to the first 1024 positions, matching the length used
    # when the context boundary was computed.
    bb = generator.generate(
        input_ids=b_input_ids[:, :1024],
        attention_mask=b_input_mask[:, :1024],
        context_boundary=context_boundary,
        **decoding_options,
    )

test_dataset.tokenizer.batch_decode(
    bb,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=True,
)

["Kenya's authorities have deported Rumba singer Yolanda Olomide from the Kenyan capital, Nairobi."]

In [46]:
# Decode the reference summary of the current example. Only the last expression
# of a cell is displayed, so the decoded text is printed explicitly; otherwise
# it would be computed and silently dropped.
reference_summary = test_dataset.tokenizer.decode(
    te1.labels,
    clean_up_tokenization_spaces=True,
    skip_special_tokens=True,
)
print(reference_summary)
context_sampling_bounds

(0.1, 0.45)

In [47]:
from torch.utils.data import DataLoader,SequentialSampler
import tqdm
from context_enforcement.data.common import write_to_file
import evaluate
# One combined evaluator for the three metrics reported by this notebook.
metric_names = ['bleu', 'meteor', "rouge"]
metrics = evaluate.combine(metric_names)
def generate(context_max_len:int, batch_size:int=12, num_beams:int=10, max_length:int=160):
    """Generate one summary per example of the module-level ``test_dataset``.

    The test set is read in order (sequential sampler) and batched with
    ``SmartCollator``; the ``"boundary"`` entry of every collated batch is
    forwarded to the model as ``context_boundary``.

    Args:
        context_max_len: Passed to ``SmartCollator`` as ``context_max_len``.
        batch_size: Number of examples per batch (default 12).
        num_beams: Beam width used for decoding (default 10).
        max_length: ``max_length`` passed to ``generator.generate`` (default 160).

    Returns:
        list[str]: Decoded summaries, in the same order as ``test_dataset``.

    Relies on the notebook globals ``test_dataset``, ``train_dataset``,
    ``configs``, ``context_sampling_bounds``, ``generator`` and ``device``.
    """
    test_data_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        sampler=SequentialSampler(test_dataset),
        collate_fn=SmartCollator(
            pad_token_id=train_dataset.tokenizer.pad_token_id,
            max_len=configs.max_seq_len,
            context_max_len=context_max_len,
            context_sampling_bounds=context_sampling_bounds,
        ),
    )

    # Do not depend on an earlier cell having switched the model to eval mode.
    generator.eval()

    output_summaries = []
    with torch.no_grad():  # inference only: no autograd bookkeeping needed
        for batch in tqdm.tqdm(test_data_loader):
            b_input_ids = batch['input_ids'].to(device)
            b_input_mask = batch['attention_mask'].to(device)
            boundary_mask = batch["boundary"]

            bb = generator.generate(
                input_ids=b_input_ids,
                attention_mask=b_input_mask,
                context_boundary=boundary_mask,
                num_beams=num_beams,
                do_sample=False,
                early_stopping=True,
                use_cache=True,
                num_return_sequences=1,
                max_length=max_length,
            )
            sentences = test_dataset.tokenizer.batch_decode(
                bb,
                clean_up_tokenization_spaces=True,
                skip_special_tokens=True,
            )
            output_summaries.extend(sentences)
    return output_summaries

[nltk_data] Downloading package wordnet to /home/nlplab/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt to /home/nlplab/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /home/nlplab/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


In [48]:
def _decode_reference(example):
    """Decode one test example's label ids back into plain reference text."""
    return tokenizer.decode(
        example.labels,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )


# Reference summaries for the whole test set, in dataset order.
targets = list(map(_decode_reference, test_dataset))

Token indices sequence length is longer than the specified maximum sequence length for this model (1327 > 1024). Running this sequence through the model will result in indexing errors


In [49]:
# Create the folder that receives the generated summaries (no error if it already exists).
output_root = "outputs/xsum/context-enforcer-baseline/"
os.makedirs(output_root, exist_ok=True)

In [None]:
# Sweep over context lengths: generate summaries for the whole test set, save
# them to disk and score them against the reference summaries.
context_lens = [70,120,220,320,420,520,620,720]

# Single definition of the output folder. It is created here as well so this
# cell does not depend on another cell having made it, and the file path is
# built with os.path.join instead of a hand-written (double-slash) literal.
output_dir = "outputs/xsum/context-enforcer-baseline"
os.makedirs(output_dir, exist_ok=True)

outputs = {}  # context length -> list of generated summaries
results = {}  # context length -> metric scores
for cl in context_lens:
    print(f'Generating for the context length: {cl}')
    rbase_output = generate(context_max_len=cl)
    outputs[cl] = rbase_output

    write_to_file(rbase_output[:len(test_dataset)],
                  os.path.join(output_dir, f"best_base_final_{cl}"))

    scores = metrics.compute(predictions=rbase_output,references=targets)
    print(scores)

    results[cl]= scores
    
    
    