In [1]:
import requests

#import trex
import json
import pprint
from IPython.display import clear_output

import pickle

import pandas as pd
import numpy as np

from datasets import Dataset 

from transformers import AutoTokenizer
from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import DataCollatorForSeq2Seq

import evaluate

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import gc

import torch

# Free cached CUDA memory and run the garbage collector before loading the model.
torch.cuda.empty_cache()
gc.collect()

# Prefer the first CUDA device when one is available; otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
device


device(type='cuda', index=0)

In [3]:
# Model to load. The second assignment overrides the first, so the local
# checkpoint directory is the one actually used; keep only the line you need.
path = "facebook/bart-base"
# Raw string: the Windows path contains backslash sequences ("\I", "\E", "\c")
# that are invalid escapes in a normal string literal and raise a
# DeprecationWarning/SyntaxWarning on current Python versions.
path = r"E:\Internship\EntityLinking\ELinking2\checkpoint-1500"
tokenizer = AutoTokenizer.from_pretrained(path)
model = AutoModelForSeq2SeqLM.from_pretrained(path)
# Gradient checkpointing recomputes activations in the backward pass to save memory.
model.gradient_checkpointing_enable()

In [None]:
# Make every later `torch.cuda.is_available()` call report that no GPU exists.
# NOTE: values already derived from the real function (such as `device`) are
# not updated by this patch.
def _cuda_unavailable():
    """Stand-in for `torch.cuda.is_available` that always returns False."""
    return False


torch.cuda.is_available = _cuda_unavailable

In [None]:
# API request

def expand(text, timeout=30):
    """Fetch Wikidata entity candidates for ``text`` from the QAnswer entity linker.

    Parameters
    ----------
    text : str
        Text to send to the entity-linking service.
    timeout : float, optional
        Seconds to wait for the HTTP response before giving up (default 30).
        Without a timeout a stalled connection would block the notebook forever.

    Returns
    -------
    dict
        ``{'input': text + " ", 'input_entities': [...]}``. Each list item is a
        string ``"(<text> <id>)"`` built from one response entry whose ``uri``
        contains ``http://www.wikidata.org/entity/``; ``<id>`` is the URI with
        that prefix removed.

    Raises
    ------
    requests.RequestException
        If the request times out, fails, or the service answers with an HTTP
        error status.
    """
    wikidata_prefix = 'http://www.wikidata.org/entity/'
    response = requests.get(
        "https://qanswer-core1.univ-st-etienne.fr/api/entitylinker",
        params={'text': text, 'language': 'en', 'knowledgebase': 'wikidata'},
        timeout=timeout,
    )
    # Fail loudly on 4xx/5xx instead of treating an error payload as
    # "no entities found", which would silently corrupt the training data.
    response.raise_for_status()

    input_entities = [
        "(" + candidate['text'] + " " + candidate['uri'].replace(wikidata_prefix, '') + ")"
        for candidate in response.json()
        if 'uri' in candidate and wikidata_prefix in candidate['uri']
    ]
    return {'input': text + " ", 'input_entities': input_entities}

expand("Rome is the capital of Italy")

In [None]:
# Read the raw records. The encoding is given explicitly because the platform
# default (e.g. cp1252 on Windows) can mis-decode a UTF-8 JSON file.
with open('./dataset/re-nlg_0-10000.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

# generate the training data
# Each example pairs the text plus the linker's candidate entities ("text")
# with the gold entities that also appear among those candidates ("summary").
WIKIDATA_ENTITY_PREFIX = 'http://www.wikidata.org/entity/'
training_data = []
for count, d in enumerate(data[0:10000], start=1):
    clear_output(wait=True)
    print (count)
    if len(d['text'])<2500:     # skip texts of 2500 characters or more
        # `expanded` instead of `input`, so the builtin is not shadowed at notebook scope.
        expanded = expand(d['text'])
        candidate_entities = expanded['input_entities']
        candidate_lookup = set(candidate_entities)   # O(1) membership tests

        gold_entities = [
            "(" + entity['surfaceform'] + " " + entity['uri'].replace(WIKIDATA_ENTITY_PREFIX, '') + ")"
            for entity in d['entities']
            if 'uri' in entity and WIKIDATA_ENTITY_PREFIX in entity['uri']
        ]
        # Keep only gold entities that the linker also proposed as candidates.
        output_entities = [x for x in gold_entities if x in candidate_lookup]

        training_data.append({
            'text': expanded['input'] + " " + ''.join(candidate_entities),
            'summary': ''.join(output_entities),
        })
#pprint.pprint(training_data)

In [None]:
# with open('training_data.pickle', 'wb') as handle:
#     pickle.dump(training_data, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [4]:
# Load the previously generated examples from disk.
# NOTE: pickle can execute arbitrary code on load; only open files you created yourself.
with open('training_data_witoutfilter.pickle', 'rb') as pickle_file:
    training_data = pickle.load(pickle_file)
print("Length ", len(training_data))

Length  9228


In [5]:
# Keep, for every example, only the summary entities that literally occur in
# its text (case-insensitive), and join them into one lower-cased target string.
# assumes example["summary"] is a list of "(<surface form> <ID>)" strings — TODO confirm
# against the code that wrote the pickle.
summaryFiltered = []
for example in training_data:
    text_lower = example["text"].lower()   # loop-invariant, computed once per example
    summary = []
    for entity in example["summary"]:
        lowered = entity.lower()
        if lowered in text_lower:
            # Restore the upper-case ID that .lower() destroyed. Only the last
            # space-separated token is touched, so surface forms containing
            # " q" (e.g. "don quixote") are no longer turned into "don Quixote".
            surface, sep, entity_id = lowered.rpartition(' ')
            summary.append(surface + sep + entity_id.upper() if sep else lowered)
    # Build a new dict instead of mutating the loaded record, so `training_data`
    # stays intact and re-running this cell gives the same result.
    summaryFiltered.append({**example, "summary": "".join(summary)})

In [6]:
# Drop examples whose filtered summary ended up empty.
summaryFiltered = [example for example in summaryFiltered if len(example["summary"]) != 0]

In [7]:
# Tabulate the filtered examples, then make them the working `training_data`
# and release the temporary name.
df = pd.DataFrame.from_records(summaryFiltered)
training_data = summaryFiltered
del summaryFiltered

In [8]:
# Hold out 10% of the examples for evaluation. The seed is fixed so the same
# rows land in the test split on every run; with a random split, a model
# reloaded from an earlier checkpoint could be evaluated on rows it was trained on.
dataset = Dataset.from_pandas(df).train_test_split(test_size=.10, seed=42)
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['text', 'summary'],
        num_rows: 7881
    })
    test: Dataset({
        features: ['text', 'summary'],
        num_rows: 876
    })
})


In [9]:
def preprocess_function(examples):
    """Tokenize one batch for seq2seq training.

    The "text" column is encoded as the model input (truncated to 1024 tokens)
    and the "summary" column as the target (truncated to 128 tokens). The
    target token ids are stored under "labels" in the returned encoding.
    """
    model_inputs = tokenizer(list(examples["text"]), truncation=True, max_length=1024)
    target_encoding = tokenizer(text_target=examples["summary"], truncation=True, max_length=128)
    model_inputs["labels"] = target_encoding["input_ids"]
    return model_inputs

In [10]:
# Tokenize both splits, passing whole batches to `preprocess_function`.
tokenized_data = dataset.map(function=preprocess_function, batched=True)
print(tokenized_data)

                                                               

DatasetDict({
    train: Dataset({
        features: ['text', 'summary', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 7881
    })
    test: Dataset({
        features: ['text', 'summary', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 876
    })
})




In [14]:
# define evaluation metric
rouge = evaluate.load("rouge")

def compute_metrics(eval_pred):
    """Compute ROUGE scores and the mean generated length for an evaluation pass.

    Parameters
    ----------
    eval_pred : tuple
        ``(predictions, labels)`` as passed by the trainer: arrays of token ids.

    Returns
    -------
    dict
        The values returned by ``rouge.compute`` plus ``gen_len`` (mean number
        of non-padding tokens per prediction), each rounded to 4 decimals.
    """
    predictions, labels = eval_pred
    # Some transformers versions hand over a tuple whose first item holds the token ids.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    # -100 is a padding/ignore marker, not a vocabulary id, and cannot be decoded.
    # Predictions can be padded with it as well as labels, so map it back to the
    # pad token in both.
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    print('Pred ', decoded_preds)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    print('Label ', decoded_labels)

    result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)

    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)

    return {k: round(v, 4) for k, v in result.items()}

In [12]:
# Pads inputs and labels per batch for the seq2seq model.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

training_args = Seq2SeqTrainingArguments(
    output_dir="ELinking2",
    evaluation_strategy="epoch",
    learning_rate=7e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps = 2,   # effective train batch size: 8 * 2 = 16
    weight_decay=0.01,
    save_total_limit=1,                # keep only the most recent checkpoint on disk
    num_train_epochs=4,
    predict_with_generate=True,
    # Mixed precision requires a CUDA device. Enabling it only when one is
    # available keeps this cell usable with the notebook's CPU fallback
    # instead of failing while the arguments are constructed.
    fp16=torch.cuda.is_available(),
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_data["train"],
    eval_dataset=tokenized_data["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    # compute_metrics=compute_metrics
)


Using cuda_amp half precision backend


In [13]:
# Fine-tune `model` using the arguments and datasets configured on `trainer`.
# The returned training summary is displayed as the cell output.
trainer.train()

The following columns in the training set don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: text, summary. If text, summary are not expected by `BartForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 7881
  Num Epochs = 4
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 2
  Total optimization steps = 1972
  Number of trainable parameters = 139420416
  0%|          | 0/1972 [00:00<?, ?it/s]You're using a BartTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
 25%|██▌       | 493/1972 [2:22:56<6:00:17, 14.62s/it] The following columns in the evaluation set don't have a corresponding argument in `BartForConditionalGeneration

{'eval_loss': 1.543066143989563, 'eval_runtime': 230.3041, 'eval_samples_per_second': 3.804, 'eval_steps_per_second': 0.478, 'epoch': 1.0}


 25%|██▌       | 500/1972 [2:28:46<10:09:24, 24.84s/it]Saving model checkpoint to ELinking2\checkpoint-500
Configuration saved in ELinking2\checkpoint-500\config.json
Configuration saved in ELinking2\checkpoint-500\generation_config.json


{'loss': 2.0772, 'learning_rate': 5.232251521298174e-05, 'epoch': 1.01}


Model weights saved in ELinking2\checkpoint-500\pytorch_model.bin
tokenizer config file saved in ELinking2\checkpoint-500\tokenizer_config.json
Special tokens file saved in ELinking2\checkpoint-500\special_tokens_map.json
 50%|█████     | 986/1972 [4:55:04<4:16:56, 15.63s/it] The following columns in the evaluation set don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: text, summary. If text, summary are not expected by `BartForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 876
  Batch size = 8
                                                      
 50%|█████     | 986/1972 [4:58:56<4:16:56, 15.63s/it]

{'eval_loss': 1.3452848196029663, 'eval_runtime': 231.7395, 'eval_samples_per_second': 3.78, 'eval_steps_per_second': 0.475, 'epoch': 2.0}


 51%|█████     | 1000/1972 [5:02:54<4:45:12, 17.61s/it]Saving model checkpoint to ELinking2\checkpoint-1000
Configuration saved in ELinking2\checkpoint-1000\config.json
Configuration saved in ELinking2\checkpoint-1000\generation_config.json


{'loss': 1.5203, 'learning_rate': 3.4574036511156184e-05, 'epoch': 2.03}


Model weights saved in ELinking2\checkpoint-1000\pytorch_model.bin
tokenizer config file saved in ELinking2\checkpoint-1000\tokenizer_config.json
Special tokens file saved in ELinking2\checkpoint-1000\special_tokens_map.json
Deleting older checkpoint [ELinking2\checkpoint-500] due to args.save_total_limit
 75%|███████▌  | 1479/1972 [7:42:44<2:00:06, 14.62s/it]The following columns in the evaluation set don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: text, summary. If text, summary are not expected by `BartForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 876
  Batch size = 8
                                                       
 75%|███████▌  | 1479/1972 [7:46:38<2:00:06, 14.62s/it]

{'eval_loss': 1.2577414512634277, 'eval_runtime': 234.3183, 'eval_samples_per_second': 3.739, 'eval_steps_per_second': 0.469, 'epoch': 3.0}


 76%|███████▌  | 1500/1972 [7:52:31<2:12:16, 16.81s/it] Saving model checkpoint to ELinking2\checkpoint-1500
Configuration saved in ELinking2\checkpoint-1500\config.json
Configuration saved in ELinking2\checkpoint-1500\generation_config.json


{'loss': 1.3022, 'learning_rate': 1.682555780933063e-05, 'epoch': 3.04}


Model weights saved in ELinking2\checkpoint-1500\pytorch_model.bin
tokenizer config file saved in ELinking2\checkpoint-1500\tokenizer_config.json
Special tokens file saved in ELinking2\checkpoint-1500\special_tokens_map.json
Deleting older checkpoint [ELinking2\checkpoint-1000] due to args.save_total_limit
100%|██████████| 1972/1972 [10:04:31<00:00, 14.50s/it] The following columns in the evaluation set don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: text, summary. If text, summary are not expected by `BartForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 876
  Batch size = 8
                                                      
100%|██████████| 1972/1972 [10:08:24<00:00, 14.50s/it]

Training completed. Do not forget to share your model on huggingface.co/models =)


100%|██████████| 1972/1972 [10:08:24<00:00, 18.51s/it]

{'eval_loss': 1.2243396043777466, 'eval_runtime': 232.6742, 'eval_samples_per_second': 3.765, 'eval_steps_per_second': 0.473, 'epoch': 4.0}
{'train_runtime': 36504.3983, 'train_samples_per_second': 0.864, 'train_steps_per_second': 0.054, 'train_loss': 1.5278853905853829, 'epoch': 4.0}





TrainOutput(global_step=1972, training_loss=1.5278853905853829, metrics={'train_runtime': 36504.3983, 'train_samples_per_second': 0.864, 'train_steps_per_second': 0.054, 'train_loss': 1.5278853905853829, 'epoch': 4.0})

In [15]:
# Second trainer over the same model, arguments, datasets and collator as
# `trainer`; the only difference is that `compute_metrics` is passed, so
# evaluation also reports the ROUGE-based metrics.
trainer2 = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_data["train"],
    eval_dataset=tokenized_data["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

Using cuda_amp half precision backend


In [16]:
# Evaluate on the test split given to `trainer2` as `eval_dataset`;
# the metrics returned by the call are shown as the cell output.
trainer2.evaluate()

The following columns in the evaluation set don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: summary, text. If summary, text are not expected by `BartForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 876
  Batch size = 8
You're using a BartTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "no_repeat_ngram_size": 3,
  "num_beams": 4,
  "pad_token_id": 1,
  "transformers_version": "4.26.1"
}

  0%|          | 0/110 [00:00<?, ?it/s]Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_

Pred  ['(delaware general corporation law Q5253299)(delaware Q1393)(', '(funk Q164444)(music genre Q188451)(soul music', '(dorado Q8837)(constellation Q8928)(dolphinfish', '(charles Q526128)(american Q30)(actor Q33999)(', '(barium Q1112)(chemical element Q11344)(atomic number Q23', '(ezra Q131635)(hebrew bible Q732870)(ne', '(armenia Q399)(landlocked country Q123480)(transc', '(apollo Q37340)(homeric greek Q35497)(ap', '(eritrea Q986)', '(hemiparesis Q2291130)(hemiplegia Q1901', '(mathematics Q395)(mersenne prime Q186875)(prime number', '(civil rights memorial Q2974874)(civil rights movement Q191600)(s', '(punjabi Q172656)(popular music Q373342)(pun', '(frederick i Q79789)(ansbach Q14859)(margrave Q', '(iona Q610)(small island Q7542813)(inner hebrides Q', '(einhard Q154526)(courtier Q615452)(ein', '(lombards Q33754)(longobardi Q130900)(italian', '(dalnet Q1151684)(internet relay chat Q73)(irc Q73', '(edward lear Q309759)(english Q1860)(illustrator Q', '(don Quixote Q480)(spanish Q1321)

100%|██████████| 110/110 [06:29<00:00,  3.54s/it]


{'eval_loss': 1.0846048593521118,
 'eval_rouge1': 0.3313,
 'eval_rouge2': 0.2462,
 'eval_rougeL': 0.3255,
 'eval_rougeLsum': 0.3259,
 'eval_gen_len': 19.1393,
 'eval_runtime': 400.1349,
 'eval_samples_per_second': 2.189,
 'eval_steps_per_second': 0.275}

In [17]:
# Generate predictions for the first ten test examples and collect them next
# to the source text and the reference summary for manual inspection.
results = []
for predict_index in range(10):
    example = tokenized_data["test"][predict_index]
    input_ids = torch.IntTensor(example["input_ids"]).unsqueeze(0).to(device)  # add a leading batch dimension of size 1
    summary_ids = model.generate(
        input_ids,
        num_beams=10,
        max_new_tokens=100,
        early_stopping=True,
        no_repeat_ngram_size=2,
    )
    decoded = tokenizer.batch_decode(
        summary_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    results.append({
        "text": example["text"],
        "Actual Summary": example["summary"],
        "Predicted": decoded[0],
    })

Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "no_repeat_ngram_size": 3,
  "num_beams": 4,
  "pad_token_id": 1,
  "transformers_version": "4.26.1"
}

Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "no_repeat_ngram_size": 3,
  "num_beams": 4,
  "pad_token_id": 1,
  "transformers_version": "4.26.1"
}

Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "no_repeat_ngram_size": 3,
  "num_beams": 4,
  "pad_token_id": 1,
  "transformers_version": "4.26.1"
}

Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,

In [18]:
import pprint

# Pretty-print the collected text / reference / prediction records.
pprint.pprint(results)

[{'Actual Summary': '(delaware general corporation law Q5253299)(delaware '
                    'Q1393)(statute Q820655)(corporate law Q2609670)(delaware '
                    'Q1393)(fortune 500 Q76615)',
  'Predicted': '(delaware general corporation law Q5253299)(the delaware '
               'Q1393)',
  'text': 'The Delaware General Corporation Law (Title 8, Chapter 1 of the '
          'Delaware Code) is the statute governing corporate law in the U.S. '
          'state of Delaware. Over 50% of publicly traded corporations in the '
          'United States and 60% of the Fortune 500 are incorporated in the '
          'state.  (the delaware Q19084518)(the delaware Q19084523)(delaware '
          'Q1393)(delaware Q37477128)(delaware Q986183)(delaware '
          'Q82048)(delaware Q622910)(delaware Q2639827)(delaware '
          'Q2665761)(delaware Q2571194)(delaware Q3708302)(delaware '
          'Q3473057)(delaware Q5253209)(delaware Q5253214)(delaware '
          'Q730314)(delawar

In [None]:
# Reload the saved examples for the exploratory analysis in the following cells.
# NOTE: pickle can execute arbitrary code on load; only open files you created yourself.
with open('training_data.pickle', 'rb') as pickle_file:
    training_data = pickle.load(pickle_file)
print("Length ", len(training_data))

In [None]:
# Examples that have a non-empty summary, followed by how many there are.
training_data_sum = [example for example in training_data if len(example["summary"]) != 0]
len(training_data_sum)

In [None]:
# Number of ")("-separated pieces per example, for the summary and for the text.
# Counting the separator and adding one gives the same value as splitting on it.
summaryEntities = [example["summary"].count(')(') + 1 for example in training_data_sum]
textEntities = [example["text"].count(')(') + 1 for example in training_data_sum]

In [None]:
import seaborn as sns

# Distribution of the per-example counts computed above. The x-axis labels say
# which of the two otherwise identical-looking histograms is which.
sns.displot(summaryEntities,  kde=True).set_axis_labels("')('-separated entries per summary")
sns.displot(textEntities,  kde=True).set_axis_labels("')('-separated entries per text")

In [None]:
# Show a small sample only: displaying the whole list would dump every record
# into the notebook output.
training_data_sum[:5]

In [None]:
import pprint
# The encoding is given explicitly because the platform default (e.g. cp1252
# on Windows) can mis-decode a UTF-8 JSON file.
with open('./dataset/re-nlg_0-10000.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

# Smoke-test the entity linker on the first three records.
for d in data[0:3]:
    # `expanded` instead of `input`, so the builtin is not shadowed at notebook scope.
    expanded = expand(d["text"])
    pprint.pprint(expanded)

In [7]:
# Raw string: the Windows path contains backslash sequences ("\I", "\E", "\c")
# that are invalid escapes in a normal string literal and raise a
# DeprecationWarning/SyntaxWarning on current Python versions.
model = AutoModelForSeq2SeqLM.from_pretrained(r"E:\Internship\EntityLinking\ELinking\checkpoint-2000")
# Gradient checkpointing recomputes activations in the backward pass to save memory.
model.gradient_checkpointing_enable()