In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES']='1'
#import wget
import sys
sys.path.append('../')
import argparse
import json
import pandas as pd
import random
import numpy as np
import string
import nltk
from functools import partial
import re
from  tqdm import tqdm
import torch
nltk.download('punkt')
from functools import partial
import nltk
from src.dataset_processor import load_all_data
from src.utils import SmartCollator, get_args, setuptokenizer
from src.dataset_processor import (
    Multi_taskQuestionGenerationDataset as QuestionGenerationDataset,
)
from src.model_utils import CustomTrainer, get_training_arguments, model_init
from src.config import DATASET_PATH, GenerationTasks
from transformers.trainer_callback import EarlyStoppingCallback


[nltk_data] Downloading package punkt to /home/nlplab/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [2]:
# Base checkpoint name; it is reused further down when the model is built.
model_base = "t5-base"

# The three task markers and the two section markers are handed to the
# tokenizer factory through its `special_tokens` argument.
tokenizer = setuptokenizer(
    model_base=model_base,
    special_tokens=[
        GenerationTasks.vanilla_question_gen,
        GenerationTasks.context_question_gen,
        GenerationTasks.question_paraphrase,
        "<section>",
        "</section>",
    ],
)

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


In [3]:
from src.dataset_processor import Multi_taskQuestionGenerationDataset

# One-record dataset wrapper; the interactive generation cells use its
# `procesTexts` method and its `tokenizer` attribute.
dataset = Multi_taskQuestionGenerationDataset(tokenizer=tokenizer, nb_records=1)
dataset.change_data_mode(1)

In [6]:
from src.model_utils import model_init

# Checkpoint directory; generated outputs are later written next to the weights.
saved_model_path = '../trained_models_mtl/t5_base_model_2/checkpoint-120790/'
#'../trained_models_mtl/bart_base_model_1/checkpoint-25524//'

# map_location='cpu' keeps deserialisation independent of whichever device the
# tensors were saved from (CUDA_VISIBLE_DEVICES remaps device indices in this
# notebook); load_state_dict then copies the values into the model's own
# parameters. os.path.join avoids the doubled '/' that plain concatenation
# produces because saved_model_path already ends with a slash.
trained_weights = torch.load(
    os.path.join(saved_model_path, 'pytorch_model.bin'),
    map_location='cpu',
)

# The embedding size must match the tokenizer after the special tokens were added.
generator = model_init(model_base=model_base,vocab_size=len(tokenizer))
generator.load_state_dict(trained_weights)
# Every later cell moves its input tensors to this device.
device = generator.device

In [7]:
from src.utils import get_default_sentence_split
# Callable applied by `factgenerator` to a stripped document; presumably a
# sentence splitter, judging by its name -- confirm against src.utils.
default_sentence_split = get_default_sentence_split()

In [8]:
from nltk.util import ngrams
import wikipedia
def factgenerator(document,n):
    """Return every run of ``n`` consecutive pieces of ``document`` as a list.

    The document is stripped of surrounding whitespace and split with the
    notebook-level ``default_sentence_split`` (presumably into sentences --
    confirm against ``src.utils``); the resulting pieces are then grouped
    into overlapping n-grams.
    """
    pieces = default_sentence_split(document.strip())
    grouped = ngrams(pieces, n)
    return list(grouped)

In [9]:
# Decoding settings read by the generation cells further down.
sample_too = True
# Sampling knobs are only supplied when sampling is switched on.
sampling_helper = dict(top_k=30, top_p=0.95) if sample_too else {}
max_length = 80
length_penalty = 2.6
beam_size = 4
repetition_penalty = 1.56
# Ten sequences when sampling, otherwise one per beam.
return_top_beams = 10 if sample_too else beam_size

In [10]:
# Rebinds the DATASET_PATH name imported from src.config to the local data folder.
DATASET_PATH = '../curated_data/'
#train_data_packet = load_all_data(DATASET_PATH, mode="train")
# Evaluation records are the "dev" records followed by the "test" records.
test_data_packet = load_all_data(DATASET_PATH, mode="dev")+load_all_data(DATASET_PATH, mode="test")
# Dataset sized to hold every evaluation record, with section highlighting off.
test_dataset = QuestionGenerationDataset(
        tokenizer=tokenizer, nb_records=len(test_data_packet), highlight_section=False
    )
# NOTE(review): the meaning of mode 1 is defined in src.dataset_processor -- confirm.
test_dataset.change_data_mode(1)
# Attach the loaded records to the dataset.
test_dataset.set_record(test_data_packet)

processing files:  ['../curated_data/squad_dev.csv', '../curated_data/drop_dev.csv', '../curated_data/rope_dev.csv', '../curated_data/sci_dev.csv']
processing files:  ['../curated_data/sci_test.csv']


In [41]:
# Size of the evaluation dataset built from the dev + test records.
len(test_dataset)

43468

In [11]:
from torch.utils.data import DataLoader,SequentialSampler, Dataset


In [15]:
def writeToFile(content, filename):
    """Write every item of ``content`` on its own line to ``<filename>.txt``.

    An existing file at that path is replaced. The file is written as UTF-8,
    matching the encoding ``readSentences`` uses when reading such files back,
    so non-ASCII text round-trips regardless of the platform default.

    Args:
        content: iterable of items; each is formatted with ``%s``.
        filename: output path without the ``.txt`` suffix.

    Returns:
        None. Prints ``Done`` once the file has been written.
    """
    fil = filename + '.txt'
    # Mode 'w' truncates an existing file, which replaces the earlier
    # remove-then-exclusive-create sequence in a single step.
    with open(fil, 'w', encoding='utf-8') as fwrite:
        fwrite.writelines("%s\n" % s for s in content)
    print('Done')

In [12]:
def generateOutput(dataset: Dataset,beam_size = 4,batch_size=20):
    """Generate one decoded text per record of ``dataset`` with beam search.

    Relies on notebook-level state: ``generator`` (the model), ``tokenizer``
    (for the padding id), ``device``, and the decoding settings
    ``repetition_penalty``, ``length_penalty`` and ``max_length``.

    Args:
        dataset: dataset to decode; it must expose a ``tokenizer`` attribute,
            which is used for the EOS id and for decoding.
        beam_size: number of beams passed to ``generate``.
        batch_size: number of records per batch.

    Returns:
        List of decoded strings in dataset order (the loader uses a
        sequential sampler and one sequence is returned per input).
    """
    dataset_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=SmartCollator(tokenizer.pad_token_id, is_inference=True),
        sampler=SequentialSampler(dataset),
    )

    # Loop-invariant decoding setup, computed once instead of per batch.
    # Sampling is disabled here, so exactly one sequence is returned per input.
    sample_too = False
    sampling_helper = dict(top_k=25, top_p=0.95) if sample_too else {}
    return_top_beams = 25 if sample_too else 1

    # seed_everything(2982)
    generator.eval()
    generated_texts = []
    with torch.no_grad():
        for batch in tqdm(dataset_loader):
            b_input_ids = batch['input_ids'].to(device)
            b_input_mask = batch['attention_mask'].to(device)

            sample_outputs = generator.generate(
                input_ids=b_input_ids,
                attention_mask=b_input_mask,
                num_beams=beam_size,
                repetition_penalty=repetition_penalty,
                length_penalty=length_penalty,
                early_stopping=False,
                use_cache=True,
                max_length=max_length,
                no_repeat_ngram_size=2,
                num_return_sequences=return_top_beams,
                do_sample=sample_too,
                eos_token_id=dataset.tokenizer.eos_token_id,
                **sampling_helper,
            )
            generated_texts.extend(
                dataset.tokenizer.decode(
                    s,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=True,
                )
                for s in sample_outputs
            )
    return generated_texts


In [13]:
# Gold outputs, taken from the same records that were attached to `test_dataset`.
reference_sentences = [record.output_text for record in test_data_packet]
# Full decoding pass over the evaluation set (the logged run took about 39 minutes).
generated_text = generateOutput(dataset=test_dataset, batch_size=20)

  0%|          | 6/2174 [00:06<37:48,  1.05s/it]Token indices sequence length is longer than the specified maximum sequence length for this model (639 > 512). Running this sequence through the model will result in indexing errors
100%|██████████| 2174/2174 [39:23<00:00,  1.09s/it]


In [16]:
# Persist the generations as `beam_size4.txt` inside the checkpoint directory.
writeToFile(generated_text, f'{saved_model_path}beam_size4')

Done


Evaluate Performance

In [17]:
import evaluate

In [18]:
# Load the BERTScore metric module through the `evaluate` library.
bertscore = evaluate.load('bertscore',lang='en',)

In [23]:
# BLEU and METEOR bundled so both are computed with a single call.
scorer = evaluate.combine(['bleu', 'meteor'])

[nltk_data] Downloading package wordnet to /home/nlplab/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt to /home/nlplab/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /home/nlplab/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


In [24]:
# BLEU / METEOR of the single-output generations against the gold outputs.
scorer.compute(predictions=generated_text,
               references=reference_sentences,lang="en")

{'bleu': 0.15058681184538658,
 'precisions': [0.35794498467860497,
  0.16658570130722253,
  0.1089699736515914,
  0.07913833443006507],
 'brevity_penalty': 1.0,
 'length_ratio': 1.3022202659090865,
 'translation_length': 677484,
 'reference_length': 520253,
 'meteor': 0.40187798663956564}

In [19]:
# BERTScore of the single-output generations against the gold outputs.
bscores=bertscore.compute(predictions=generated_text,
               references=reference_sentences,lang="en")

In [20]:
# Mean F1 over all examples, shown next to the keys available in the result.
np.mean(bscores['f1']),bscores.keys()

(0.9056642513708126, dict_keys(['precision', 'recall', 'f1', 'hashcode']))

In [32]:
def readSentences(file,lower=False):
    """Read a UTF-8 text file and return its lines without surrounding whitespace.

    Args:
        file: path of the text file to read.
        lower: when True, each returned line is lower-cased. The default
            (False) leaves the text unchanged.

    Returns:
        List with one string per line of the file; blank lines are kept as
        empty strings so the result stays aligned with the file's line count.
    """
    with open(file, 'r', encoding="utf-8") as o_file:
        # Iterating the handle yields the same lines as readlines() without
        # first building an intermediate list.
        sentences = [line.strip() for line in o_file]
    if lower:
        sentences = [sentence.lower() for sentence in sentences]
    return sentences

In [34]:
# NOTE(review): no cell visible in this notebook writes beam_4.txt, so it must
# already exist in the checkpoint directory before this cell is run.
mut_gens = readSentences(f'{saved_model_path}/beam_4.txt')

In [35]:
# Repeat each gold output four times back to back so it can be paired with
# `mut_gens` (presumably four generations per input -- confirm how that file
# was produced).
expanded_ref = [reference for reference in reference_sentences for _ in range(4)]

In [55]:
# BERTScore of the generations read from file against the repeated references.
# This rebinds `bscores`, replacing the earlier single-output result.
bscores=bertscore.compute(predictions=mut_gens,
               references=expanded_ref,lang="en")

In [59]:
# Mean recall over all prediction/reference pairs.
np.mean(bscores['recall'])

0.9111872966340173

In [39]:
# Predictions and references are paired by position, so a length mismatch
# would make the scores meaningless. The bare tuple that used to sit here was
# never displayed (only a cell's last expression is), so check it explicitly.
assert len(expanded_ref) == len(mut_gens), (
    f'reference/prediction count mismatch: {len(expanded_ref)} != {len(mut_gens)}'
)
# BLEU / METEOR of the generations read from file against the repeated references.
scorer.compute(predictions=mut_gens,
               references=expanded_ref)

{'bleu': 0.14399644886631266,
 'precisions': [0.35924705974114024,
  0.16233583376790084,
  0.10248129819907306,
  0.07193742552989153],
 'brevity_penalty': 1.0,
 'length_ratio': 1.2814784345308916,
 'translation_length': 2666772,
 'reference_length': 2081012,
 'meteor': 0.3933352651188569}

In [25]:
# Fetch the summary of the "Earth crust" page; auto_suggest is disabled so the
# title is used exactly as written. Requires network access.
article = wikipedia.summary("Earth crust",auto_suggest=False)

"""
Metals represent approximately 25% of the elemental makeup of the Earth's crust. 
The bulk of these metals, primarily aluminum, iron, calcium, sodium, potassium, and magnesium, are typically found in combined form. The most abundant metal is aluminum, which occurs almost exclusively as the ionic mineral bauxite. The other most common metals, including iron, sodium, potassium, magnesium, and calcium, are also found primarily as the cationic portion of an ionic compound. Very few metals actually occur naturally as pure substances.
The ones that do are often referred to as precious or semi-precious metals.
"""
#wikipedia.summary('Dragon')
# Number of consecutive split pieces joined into one fact.
n=3
# '.T' is rewritten to '. T' before splitting (a missing space after a period
# in front of a capital T); each n-gram is then joined with spaces and newlines
# are removed.
facts = [' '.join(s).replace('\n','').strip() for s in factgenerator(article.replace('.T','. T'),n=n)]

In [34]:
# Inspect the third generated fact.
facts[2]

'The lithosphere is broken into tectonic plates whose motion allows heat to escape from the interior of the Earth into space. The crust lies on top of the mantle, a configuration that is stable because the upper mantle is made of peridotite and is therefore significantly denser than the crust. The boundary between the crust and mantle is conventionally placed at the Mohorovičić discontinuity, a boundary defined by a contrast in seismic velocity.'

In [29]:
from src.dataset_processor import QuestionGenerationData

# Generate questions for the first fact extracted from the article.
task_id = 0
target_fact = facts[0]
data = QuestionGenerationData(
    task=GenerationTasks.vanilla_question_gen,
    input_text=target_fact,
    output_text='',
    contextual_text='',
)

# Encode the single record and reshape it into a batch of one.
batch = dataset.procesTexts(data)
b_input_ids = batch.input_ids.view(1, -1).to(device)
b_input_mask = batch.attention_mask.view(1, -1).to(device)

# Sampling on: 25 candidates with top-k / top-p; off: one candidate per beam.
sample_too = True
if sample_too:
    sampling_helper = dict(top_k=25, top_p=0.95)
    return_top_beams = 25
else:
    sampling_helper = {}
    return_top_beams = beam_size

# seed_everything(2982)
generator.eval()
with torch.no_grad():
    sample_outputs = generator.generate(
        input_ids=b_input_ids,
        attention_mask=b_input_mask,
        num_beams=beam_size,
        repetition_penalty=repetition_penalty,
        length_penalty=length_penalty,
        early_stopping=False,
        use_cache=True,
        max_length=max_length,
        no_repeat_ngram_size=2,
        num_return_sequences=return_top_beams,
        do_sample=sample_too,
        eos_token_id=dataset.tokenizer.eos_token_id,
        **sampling_helper,
    )

# Decode every returned sequence; the set drops duplicate questions.
oop = {
    dataset.tokenizer.decode(
        sample_outputs[idx],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    for idx in range(return_top_beams)
}

print(f'Article Section: {data.input_text}')
print('Questions Generated')
for q in oop:
    print(q)

Article Section: Earth's crust is Earth's thin outer shell of rock, regarding for less than 1% of Earth's radius and volume. It is the top component of the lithosphere, a division of Earth's layers that includes the crust and the upper part of the mantle. The lithosphere is broken into tectonic plates whose motion allows heat to escape from the interior of the Earth into space.
Questions Generated
What is the top component of the earth's lithosphere?
What is the top component of the lithosphere?


In [13]:
What is the earth's thin outer shell of rock?
What is the earth's thin outer shell of rock known as?
What is the top component of the lithosphere?
What is the earth's thin outer shell of rock referred to as?

Article Section: Metals represent approximately 25% of the elemental makeup of the Earth's crust. The bulk of these metals, primarily aluminum, iron, calcium, sodium, potassium, and magnesium, are typically found in combined form. The most abundant metal is aluminum, which occurs almost exclusively as the ionic mineral bauxite. The other most common metals, including iron, sodium, potassium, magnesium, and calcium, are also found primarily as the cationic portion of an ionic compound. Very few metals actually occur naturally as pure substances. The ones that do are often referred to as precious or semi-precious metals.
Questions Generated
What is the most abundant metal that occurs almost directly as the ionic mineral bauxite?
What is the most abundant metal that occurs almost exclusively as the ionic mineral bauxite?
What is the most abundant metal in the earth's crust?
What is the most abundant metal?
What is the most abundant metal of the earth's crust, which occurs almost exclusive

In [10]:
from src.dataset_processor import QuestionGenerationData

# Ask the model to paraphrase a hand-written question.
task_id = 0
target_fact = "What is the name of a reptile-like legendary creature that appears in the folklore of many cultures worldwide?"
data = QuestionGenerationData(
    task=GenerationTasks.question_paraphrase,
    input_text=target_fact,
    output_text='',
    contextual_text='',
)

# Encode the single record and reshape it into a batch of one.
batch = dataset.procesTexts(data)
b_input_ids = batch.input_ids.view(1, -1).to(device)
b_input_mask = batch.attention_mask.view(1, -1).to(device)

# Sampling on: 25 candidates with top-k / top-p; off: one candidate per beam.
sample_too = True
if sample_too:
    sampling_helper = dict(top_k=25, top_p=0.95)
    return_top_beams = 25
else:
    sampling_helper = {}
    return_top_beams = beam_size

# seed_everything(2982)
generator.eval()
with torch.no_grad():
    sample_outputs = generator.generate(
        input_ids=b_input_ids,
        attention_mask=b_input_mask,
        num_beams=beam_size,
        repetition_penalty=repetition_penalty,
        length_penalty=length_penalty,
        early_stopping=False,
        use_cache=True,
        max_length=max_length,
        no_repeat_ngram_size=2,
        num_return_sequences=return_top_beams,
        do_sample=sample_too,
        eos_token_id=dataset.tokenizer.eos_token_id,
        **sampling_helper,
    )

# Decode every returned sequence; the set drops duplicate outputs.
oop = {
    dataset.tokenizer.decode(
        sample_outputs[idx],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    for idx in range(return_top_beams)
}

print(f'Article Section: {data.input_text}')
print('Questions Generated')
for q in oop:
    print(q)

Article Section: What is the name of a reptile-like legendary creature that appears in the folklore of many cultures worldwide?
Questions Generated
What is the name of a reptile-like legendary creature that appears in the folklore of many cultures worldwide?


In [4]:
# Rebinds the DATASET_PATH name imported from src.config to the local data folder.
DATASET_PATH = '../curated_data/'
# Load the training records.
train_data_packet = load_all_data(DATASET_PATH, mode="train")
# NOTE(review): this binds test_data_packet to the "dev" records only, whereas
# the evaluation cell builds it from "dev" + "test"; which value is live
# depends on the order the cells were run in.
test_data_packet = load_all_data(DATASET_PATH, mode="dev")

In [5]:
# NOTE(review): repeats the two load calls of the preceding cell, so the files
# are processed a second time; it relies on DATASET_PATH already being set.
train_data_packet = load_all_data(DATASET_PATH, mode="train")
test_data_packet = load_all_data(DATASET_PATH, mode="dev")

processing files:  ['../curated_data/drop_train.csv', '../curated_data/squad_train.csv', '../curated_data/rope_train.csv', '../curated_data/extra_data_train.csv', '../curated_data/sci_train.csv']
processing files:  ['../curated_data/squad_dev.csv', '../curated_data/drop_dev.csv', '../curated_data/rope_dev.csv', '../curated_data/sci_dev.csv']


In [16]:
# Training records whose target text contains the capitalised substring 'How'.
vv = [record for record in train_data_packet if 'How' in record.output_text]
len(vv)

71196

In [12]:
# Total number of training records loaded.
len(train_data_packet)

408372