In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell execution (presumably to pick up edits to the
# project-local `dataset` module without restarting the kernel).
%load_ext autoreload
%autoreload 2

In [2]:
from tqdm import tqdm
import torch
from datasets import load_dataset, load_metric
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from dataset import CitationTextGenerationDataset
from torch.utils.data import Dataset, DataLoader

In [3]:
# Device string that the model and every input tensor are moved to.
# "cuda" means the notebook needs a GPU as written.
device = "cuda"

In [4]:
# Token budgets: sources are padded/truncated to max_input_length and targets
# to max_output_length in process_data_to_model_inputs; max_input_length is
# also passed to the dataset as MAX_SENT_LEN.
max_input_length = 16384
max_output_length = 1024

In [5]:
def process_data_to_model_inputs(batch, special_tokens=['[Dominant]', '[Reference]']):
    """Tokenize a batch and attach LED model inputs to it in place.

    Relies on the module-level ``tokenizer``, ``max_input_length`` and
    ``max_output_length``.

    Args:
        batch: mapping with a ``"source"`` and a ``"target"`` entry, each a
            list of strings of the same length.
        special_tokens: additional special tokens (they must already be
            registered on the tokenizer) whose positions receive global
            attention. The tokenizer's mask token always receives it as well.
            The default list is never mutated.

    Returns:
        The same ``batch`` object with these keys added/overwritten:
        ``input_ids``, ``attention_mask``, ``global_attention_mask`` (1 at
        positions holding one of the selected special tokens, 0 elsewhere) and
        ``labels`` (target ids with padding replaced by -100).

    Raises:
        KeyError: if an entry of ``special_tokens`` is not one of the
            tokenizer's additional special tokens.
    """
    # Map the requested special-token strings to their vocabulary ids.
    additional_special_tokens_lookup = dict(
        zip(tokenizer.additional_special_tokens, tokenizer.additional_special_tokens_ids)
    )
    special_token_ids = {additional_special_tokens_lookup[token] for token in special_tokens}
    special_token_ids.add(tokenizer.mask_token_id)

    inputs = tokenizer(
        batch["source"],
        padding="max_length",
        truncation=True,
        max_length=max_input_length,
        add_special_tokens=True,
    )
    outputs = tokenizer(
        batch["target"],
        padding="max_length",
        truncation=True,
        max_length=max_output_length,
        add_special_tokens=True,
    )

    batch["input_ids"] = inputs.input_ids
    batch["attention_mask"] = inputs.attention_mask

    # Build one independent mask row per sample. The previous
    # `len(...) * [[0, ...]]` construction repeated a reference to a single
    # inner list, so a special token found in any sample switched on global
    # attention at that position for every sample in the batch.
    # NOTE(review): if the checkpoint was trained with batch size > 1 using the
    # shared-row mask, confirm this per-sample mask matches the training setup.
    batch["global_attention_mask"] = [
        [1 if token_id in special_token_ids else 0 for token_id in sample_ids]
        for sample_ids in batch["input_ids"]
    ]

    # Replace padding ids with -100 in the labels.
    batch["labels"] = [
        [-100 if token == tokenizer.pad_token_id else token for token in labels]
        for labels in outputs.input_ids
    ]
    return batch

In [6]:
# Start from the base LED tokenizer and register the task's markup tokens as
# additional special tokens. The last expression stays a bare call so the
# cell still displays its return value.
tokenizer = AutoTokenizer.from_pretrained("allenai/led-base-16384")
special_tokens = [
    '<doc>',
    '</doc>',
    '[BOS]',
    '[Dominant]',
    '[Reference]',
    '[B_Dominant]',
    '[E_Dominant]',
    '[B_Reference]',
    '[E_Reference]',
]
additional_special_tokens = dict(additional_special_tokens=special_tokens)
tokenizer.add_special_tokens(additional_special_tokens)

9

In [7]:
# Load the seq2seq model from "checkpoint-46500" (presumably a local
# fine-tuning checkpoint directory relative to the working directory).
model = AutoModelForSeq2SeqLM.from_pretrained("checkpoint-46500")

In [8]:
# Move the model to `device` and cast its parameters to half precision.
model = model.to(device).half()

In [9]:
# Auxiliary corpus files handed to the dataset by keyword.
# NOTE(review): these are absolute, machine-specific paths.
corpus_paths = dict(
    related_work_path='/home/data/XiangciLi/20200705v1/acl/selected_related_work.jsonl',
    cited_metadata_path='/home/data/XiangciLi/20200705v1/acl/selected_cited_metadata.jsonl',
    cited_paper_path="/home/data/XiangciLi/20200705v1/acl/selected_cited_pdf_parses.jsonl",
    citing_paper_path="/home/data/XiangciLi/20200705v1/acl/selected_pdf_parses.jsonl",
)

# Evaluation dataset; note that despite the name it is built from the
# "annotated_test" directory.
val_set = CitationTextGenerationDataset(
    "/home/data/XiangciLi/CORWA/annotated_test",
    tokenizer,
    MAX_SENT_LEN=max_input_length,
    **corpus_paths,
)

100%|██████████| 362/362 [00:15<00:00, 23.05it/s]


In [36]:
def run_model(batch, model):
    """Generate text for one batch and return it alongside the gold targets.

    Args:
        batch: mapping with at least ``"source"`` and ``"target"`` lists of
            strings; it is passed to ``process_data_to_model_inputs``, which
            adds the tokenized fields to it.
        model: object exposing ``generate(input_ids, attention_mask=...,
            global_attention_mask=...)``.

    Returns:
        ``(out, target)`` where ``out`` is the list of decoded generations
        (special tokens skipped) and ``target`` is ``batch["target"]``.
    """
    features = process_data_to_model_inputs(batch, special_tokens=['[Dominant]', '[Reference]'])

    # Convert the tokenized fields to tensors on the module-level `device`.
    # NOTE(review): "labels" is converted too, but generate() below never
    # receives it.
    tensor_keys = ("input_ids", "attention_mask", "global_attention_mask", "labels")
    on_device = {name: torch.tensor(features[name]).to(device) for name in tensor_keys}

    generated_ids = model.generate(
        on_device["input_ids"],
        attention_mask=on_device["attention_mask"],
        global_attention_mask=on_device["global_attention_mask"],
    )
    decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    return decoded, batch["target"]

In [4]:
# Disabled experiment (kept for reference, not executed): for each single-sample
# batch, generate once with the original source and once with the
# [Dominant]/[Reference] tag swapped via str.replace, then print both
# generations next to the gold target.
# for batch in tqdm(DataLoader(val_set, batch_size = 1, shuffle=False)):
#     if '[Dominant]' in batch["source"][0]:
#         out, target = run_model(batch, model)
#         print("Dominant: ", out)
#         batch["source"][0] = batch["source"][0].replace("[Dominant]", "[Reference]")
#         out, target = run_model(batch, model)
#         print("Reference: ", out)
#         print("Dominant label: ", target)
#         print()
#     elif '[Reference]' in batch["source"][0]:
#         out, target = run_model(batch, model)
#         print("Reference: ", out)
#         batch["source"][0] = batch["source"][0].replace("[Reference]", "[Dominant]")
#         out, target = run_model(batch, model)
#         print("Dominant: ", out)
#         print("Reference label: ", target)
#         print()

In [41]:
# Accumulators for the ROUGE cells below, split by the tag found in each
# sample's source text.
reference_predicted = []
reference_reference = []
dominant_predicted = []
dominant_reference = []

# Explicit UTF-8 so non-ASCII characters in generations or targets cannot
# abort the long-running loop on a non-UTF-8 default locale.
with open("sample_output.txt", "w", encoding="utf-8") as f:
    for batch in tqdm(DataLoader(val_set, batch_size = 6, shuffle=False)):
        # Same tokenize -> generate -> decode path as run_model, reused
        # instead of duplicated.
        out, target = run_model(batch, model)

        # Dump raw prediction/target pairs, separated by a blank line.
        for o, t in zip(out, target):
            f.write(o+"\n")
            f.write(t+"\n")
            f.write("\n")

        # Accumulate per batch. This loop used to sit outside the batch loop,
        # so only the final batch ever reached the lists and the ROUGE scores
        # were computed on at most one batch of samples.
        for pred, label, citation, source in zip(out, target, batch["citations"], batch["source"]):
            # Strip every "#"-separated citation string from both texts
            # before scoring.
            for c in citation.split("#"):
                pred = pred.replace(c, "")
                label = label.replace(c, "")
            if "[Dominant]" in source:
                dominant_predicted.append(pred)
                dominant_reference.append(label)
            elif "[Reference]" in source:
                reference_predicted.append(pred)
                reference_reference.append(label)

100%|██████████| 551/551 [3:21:32<00:00, 21.95s/it]  


In [33]:
# ROUGE scorer used by the evaluation cells below.
# NOTE(review): `datasets.load_metric` was deprecated in later `datasets`
# releases in favour of the separate `evaluate` package — confirm the pinned
# version still provides it.
rouge = load_metric("rouge")

In [1]:
# Overall ROUGE-1/2/L. The names `predicted` / `reference` used here before are
# not defined anywhere in the notebook, so the cell failed on a fresh kernel;
# the overall set is rebuilt from the two accumulated subsets instead.
# Samples whose source carries neither tag are in neither subset and are
# therefore not scored here.
rouge.compute(
    predictions=dominant_predicted + reference_predicted,
    references=dominant_reference + reference_reference,
    rouge_types=["rouge1","rouge2","rougeL"]
)

In [2]:
# ROUGE-1/2/L restricted to samples whose source contains "[Dominant]".
rouge.compute(
    rouge_types=["rouge1", "rouge2", "rougeL"],
    references=dominant_reference,
    predictions=dominant_predicted,
)

In [3]:
# ROUGE-1/2/L restricted to samples whose source contains "[Reference]"
# (and not "[Dominant]", which is checked first when accumulating).
rouge.compute(
    rouge_types=["rouge1", "rouge2", "rougeL"],
    references=reference_reference,
    predictions=reference_predicted,
)