In [9]:
import json
from copy import deepcopy
import re
from rouge import Rouge
from nltk.translate.bleu_score import sentence_bleu
from nltk.translate.bleu_score import SmoothingFunction

# Relative folder for generated-result files. No visible cell reads this
# variable; the evaluation cell spells its own path out in full.
result_folder = "../generated/"

class Metric:
    """Running precision / recall / F1 counter.

    Counts are accumulated across calls to ``count_instance`` and turned into
    percentages by ``compute_f1``.
    """

    def __init__(self):
        # Float counters: true positives, gold items seen, predicted items seen.
        self.tp = 0.
        self.gold_num = 0.
        self.pred_num = 0.

    @staticmethod
    def safe_div(a, b):
        """Return ``a / b``, or ``0.`` when ``b`` is zero."""
        return 0. if b == 0. else a / b

    def compute_f1(self, prefix=''):
        """Return counts and P/R/F1 (as percentages) in a dict.

        Every key is ``prefix`` followed by one of
        ``tp``, ``gold``, ``pred``, ``P``, ``R``, ``F1`` (in that order).
        """
        precision = self.safe_div(self.tp, self.pred_num)
        recall = self.safe_div(self.tp, self.gold_num)
        f1 = self.safe_div(2 * precision * recall, precision + recall)
        fields = (
            ('tp', self.tp),
            ('gold', self.gold_num),
            ('pred', self.pred_num),
            ('P', precision * 100),
            ('R', recall * 100),
            ('F1', f1 * 100),
        )
        return {prefix + label: value for label, value in fields}

    def count_instance(self, gold_list, pred_list, verbose=False):
        """Add one example's gold and predicted items to the running counts.

        Each predicted item can match at most one equal gold item, so
        duplicated predictions are not credited more often than the gold
        list contains them.
        """
        if verbose:
            print("Gold:", gold_list)
            print("Pred:", pred_list)

        self.gold_num += len(gold_list)
        self.pred_num += len(pred_list)

        # Work on a copy so matched items can be consumed without touching
        # the caller's list.
        unmatched_gold = deepcopy(gold_list)
        for candidate in pred_list:
            if candidate not in unmatched_gold:
                continue
            self.tp += 1
            unmatched_gold.remove(candidate)


def load_data(file_path, data_type):
    """Load gold and generated continuations for one data split.

    Args:
        file_path: Path to a JSON file containing a list of record dicts.
        data_type: Only records whose ``'data_type'`` equals this value are kept.

    Returns:
        ``(gt_continuation, pred_continuation)``: two parallel lists of strings
        taken from ``record['reference_continuation_text']`` and
        ``record['generated_result']['0']['continuation']``. In both, the
        detached possessive ``" 's"`` is rewritten to ``"'s"``.

    Raises:
        KeyError: If a selected record lacks one of the fields above.
    """
    # Context manager so the handle is closed even if parsing fails.
    with open(file_path, encoding="utf-8") as handle:
        records = json.load(handle)

    gt_continuation = []
    pred_continuation = []
    for record in records:
        if record['data_type'] != data_type:
            continue
        gt_continuation.append(
            record['reference_continuation_text'].replace(" 's", "'s"))
        pred_continuation.append(
            record['generated_result']['0']['continuation'].replace(" 's", "'s"))
    return gt_continuation, pred_continuation


def compute_bleu(gt_continuation, pred_continuation):
    """Print and return the mean unigram sentence-BLEU of predictions vs. gold.

    Args:
        gt_continuation: Gold continuation strings (used as BLEU references).
        pred_continuation: Generated continuation strings (used as BLEU
            hypotheses), parallel to ``gt_continuation``.

    Returns:
        The summed sentence scores divided by ``len(gt_continuation)``;
        ``0.0`` when ``gt_continuation`` is empty.
    """
    # Guard the final division: an empty split has no score to average.
    if len(gt_continuation) == 0:
        print(' bleu score: {}'.format(0.0))
        return 0.0

    smoothie = SmoothingFunction().method2
    bleu_s = 0
    # The gold text is the reference and the generated text is the candidate.
    # The previous code passed them to sentence_bleu the other way round.
    for ref, can in zip(gt_continuation, pred_continuation):
        bleu_s += sentence_bleu([ref.split()], can.split(),
                                weights=(1, 0, 0, 0),
                                smoothing_function=smoothie)
    score = bleu_s / len(gt_continuation)
    print(' bleu score: {}'.format(score))
    return score


def compute_f1(gt_continuation, pred_continuation):
    """Parse gold and generated event strings into flat per-example records.

    Each continuation is stripped of ``" <|endoftext|>"`` and split on
    ``" <|sep|>"`` into event strings. In every event string, the text
    between ``<`` and ``>`` is a special token and the text between ``"> "``
    and ``" <"`` is an extracted span. The first span supplies the trigger
    word (its first space-separated piece) and the event label (its second
    piece). The remaining spans are paired with the special tokens after the
    first one and split on ``'[and]'`` into arguments.

    Args:
        gt_continuation: Gold continuation strings.
        pred_continuation: Generated continuation strings, parallel to the gold.

    Returns:
        ``(all_results, more_pred, less_pred)`` where

        * ``all_results`` maps the example index to a dict with the keys
          ``format``, ``gt_triggerwords``, ``gt_trigger_event``,
          ``gt_arguments``, ``gt_arg_role`` and the matching ``pred_*`` keys.
          Arguments and roles of all events of an example are concatenated
          into one flat list each.
        * ``more_pred`` / ``less_pred`` list the indices whose prediction has
          more / fewer ``<|sep|>``-separated events than the gold text.
    """
    # Key groups are named explicitly instead of being sliced out of the
    # result dict's key order.
    gt_keys = ('gt_triggerwords', 'gt_trigger_event',
               'gt_arguments', 'gt_arg_role')
    pred_keys = ('pred_triggerwords', 'pred_trigger_event',
                 'pred_arguments', 'pred_arg_role')

    def get_tokens(out):
        """Return (special tokens, non-blank extracted spans, bracketed labels)."""
        # The appended ' <' lets the last span match the "> ... <" pattern.
        out = out + ' <'
        event_type = re.findall(r"\[(.*?)\]", out)
        special_tokens = re.findall(r"<(.*?)>", out)
        extracted = [span for span in re.findall(r"> (.*?) <", out)
                     if span not in ('', ' ')]
        return special_tokens, extracted, event_type

    def update(to_update, keys, special_tokens, extracted):
        """Fold one parsed event into ``to_update`` under the given key group."""
        name1, name2, name3, name4 = keys
        try:
            head = extracted[0].split(" ")
            to_update[name1].append(head[0])
            # A trigger span without a second piece raises IndexError here:
            # the trigger word stays recorded, the trigger/event pair is skipped.
            to_update[name2].append({head[1]: head[0]})
        except IndexError:
            pass

        if 'pred' in name1:
            # A predicted event whose token and span counts differ marks the
            # most recent format flag (or a new one, if none exists) False.
            if len(special_tokens) != len(extracted):
                if len(to_update['format']) == 0:
                    to_update['format'].append(False)
                else:
                    to_update['format'][-1] = False
        elif 'gt' in name1:
            # One flag per gold event.
            to_update['format'].append(True)

        # Skip the first token/span pair (the trigger); the rest are role/args.
        for role, tokens in zip(special_tokens[1:], extracted[1:]):
            all_args = tokens.split('[and]')
            to_update[name3] += all_args
            for arg in all_args:
                to_update[name4].append({"role": role[:], "arg": arg})

        return to_update

    all_results = {}
    more_pred = []
    less_pred = []
    for index, (gt_i, pred_i) in enumerate(zip(gt_continuation, pred_continuation)):
        result_i = {'format': [],
                    'gt_triggerwords': [], 'gt_trigger_event': [],
                    'gt_arguments': [], 'gt_arg_role': [],
                    'pred_triggerwords': [], 'pred_trigger_event': [],
                    'pred_arguments': [], 'pred_arg_role': []}

        # Gold events.
        gt_i_e = gt_i.replace(" <|endoftext|>", "").split(' <|sep|>')
        for event in gt_i_e:
            special_tokens, extracted, _ = get_tokens(event)
            if extracted:
                update(result_i, gt_keys, special_tokens, extracted)

        # Generated events.
        pred_i_e = pred_i.replace(" <|endoftext|>", "").split(' <|sep|>')
        if len(pred_i_e) > len(gt_i_e):
            more_pred.append(index)
        if len(pred_i_e) < len(gt_i_e):
            less_pred.append(index)
        for event in pred_i_e:
            special_tokens, extracted, _ = get_tokens(event)
            if extracted:
                update(result_i, pred_keys, special_tokens, extracted)

        all_results[index] = result_i
    return all_results, more_pred, less_pred


In [10]:
# JSON file with the generated results that this notebook evaluates.
filepath = '../generated/ace_all_tod_shifted_template_aw2_greedy.json'
# Keep only the records whose 'data_type' is 'test'.
gt_data, pred_data = load_data(filepath, 'test')

# Parse gold and generated continuations into per-example records.
all_results, more_pred, less_pred = compute_f1(gt_data, pred_data)


In [4]:
# Example indices where the prediction has more (first list) or fewer (second
# list) <|sep|>-separated events than the gold continuation.
print(more_pred)
print(50*'-')
print(less_pred)

[]
--------------------------------------------------
[]


In [13]:
# Load the raw records; the context manager closes the handle after parsing.
with open(filepath) as handle:
    file = json.load(handle)

# For each record whose prefix contains a detached possessive (" 's"), print
# its index, then the reference and generated full texts with the possessive
# re-attached.
for i, record in enumerate(file):
    if " 's" in record['prefix_text']:
        print(i)
        print(record['reference_text'].replace(" 's", "'s"))
        print(record['generated_result']['0']['full_text'].replace(" 's", "'s"))

3
<|endoftext|> He's now national director of Win Without War , and former Congressman Bob Dornan , Republican of California . <|triggerword|> former <|template|> <|Person|> [None] <|Entity|> [None] <|Place|> [None] <|endoftemplate|> <|Person|> Bob Dornan <|Entity|> California <|Place|> [None] <|endoftext|>
<|endoftext|> He's now national director of Win Without War , and former Congressman Bob Dornan , Republican of California . <|triggerword|> former <|template|> <|Person|> [None] <|Entity|> [None] <|Place|> [None] <|endoftemplate|> <|Person|> Bob Dornan <|Entity|> [None] <|Place|> [None] <|endoftext|>
5
<|endoftext|> Our president has repeatedly , for example , relied on a man whom you 're aware , Hussein Kamel , Saddam Hussein's son - in - law , leader of the Iraq arms program who defected for a time . <|triggerword|> defected <|template|> <|Person|> [None] <|Entity|> [None] <|Place|> [None] <|endoftemplate|> <|Person|> leader <|Entity|> Iraq <|Place|> [None] <|endoftext|>
<|endoft

In [40]:
# Inspect the parsed record of example 7 (flat argument lists from compute_f1).
all_results[7]

{'format': [True, True],
 'gt_triggerwords': ['war', 'war'],
 'gt_trigger_event': [{'[Conflict_Attack]': 'war'},
  {'[Conflict_Attack]': 'war'}],
 'gt_arguments': [],
 'gt_arg_role': [],
 'pred_triggerwords': ['war'],
 'pred_trigger_event': [{'[Conflict_Attack]': 'war'}],
 'pred_arguments': [],
 'pred_arg_role': []}

In [44]:
def compute_f1_2(gt_continuation, pred_continuation):
    """Parse gold and generated event strings into per-event nested records.

    Each continuation is stripped of ``" <|endoftext|>"`` and split on
    ``" <|sep|>"`` into event strings. In every event string, the text
    between ``<`` and ``>`` is a special token and the text between ``"> "``
    and ``" <"`` is an extracted span. The first span supplies the trigger
    word (its first space-separated piece) and the event label (its second
    piece). The remaining spans are paired with the special tokens after the
    first one and split on ``'[and]'`` into arguments.

    Args:
        gt_continuation: Gold continuation strings.
        pred_continuation: Generated continuation strings, parallel to the gold.

    Returns:
        ``(all_results, more_pred, less_pred)`` where

        * ``all_results`` maps the example index to a dict with the keys
          ``format``, ``gt_triggerwords``, ``gt_trigger_event``,
          ``gt_arguments``, ``gt_arg_role`` and the matching ``pred_*`` keys.
          ``*_arguments`` and ``*_arg_role`` hold one inner list per parsed
          event, so arguments stay grouped by event.
        * ``more_pred`` / ``less_pred`` list the indices whose prediction has
          more / fewer ``<|sep|>``-separated events than the gold text.
    """
    # Key groups are named explicitly instead of being sliced out of the
    # result dict's key order.
    gt_keys = ('gt_triggerwords', 'gt_trigger_event',
               'gt_arguments', 'gt_arg_role')
    pred_keys = ('pred_triggerwords', 'pred_trigger_event',
                 'pred_arguments', 'pred_arg_role')

    def get_tokens(out):
        """Return (special tokens, non-blank extracted spans, bracketed labels)."""
        # The appended ' <' lets the last span match the "> ... <" pattern.
        out = out + ' <'
        event_type = re.findall(r"\[(.*?)\]", out)
        special_tokens = re.findall(r"<(.*?)>", out)
        extracted = [span for span in re.findall(r"> (.*?) <", out)
                     if span not in ('', ' ')]
        return special_tokens, extracted, event_type

    def update(to_update, keys, special_tokens, extracted):
        """Fold one parsed event into ``to_update`` under the given key group."""
        name1, name2, name3, name4 = keys
        try:
            head = extracted[0].split(" ")
            to_update[name1].append(head[0])
            # A trigger span without a second piece raises IndexError here:
            # the trigger word stays recorded, the trigger/event pair is skipped.
            to_update[name2].append({head[1]: head[0]})
        except IndexError:
            pass

        # Open a fresh per-event bucket for arguments and for role/arg pairs.
        to_update[name3].append([])
        to_update[name4].append([])

        if 'pred' in name1:
            # A predicted event whose token and span counts differ marks the
            # most recent format flag (or a new one, if none exists) False.
            if len(special_tokens) != len(extracted):
                if len(to_update['format']) == 0:
                    to_update['format'].append(False)
                else:
                    to_update['format'][-1] = False
        elif 'gt' in name1:
            # One flag per gold event.
            to_update['format'].append(True)

        # Skip the first token/span pair (the trigger); the rest are role/args.
        for role, tokens in zip(special_tokens[1:], extracted[1:]):
            all_args = tokens.split('[and]')
            to_update[name3][-1] += all_args
            for arg in all_args:
                to_update[name4][-1].append({"role": role[:], "arg": arg})

        return to_update

    all_results = {}
    more_pred = []
    less_pred = []
    for index, (gt_i, pred_i) in enumerate(zip(gt_continuation, pred_continuation)):
        result_i = {'format': [],
                    'gt_triggerwords': [], 'gt_trigger_event': [],
                    'gt_arguments': [], 'gt_arg_role': [],
                    'pred_triggerwords': [], 'pred_trigger_event': [],
                    'pred_arguments': [], 'pred_arg_role': []}

        # Gold events.
        gt_i_e = gt_i.replace(" <|endoftext|>", "").split(' <|sep|>')
        for event in gt_i_e:
            special_tokens, extracted, _ = get_tokens(event)
            if extracted:
                update(result_i, gt_keys, special_tokens, extracted)

        # Generated events.
        pred_i_e = pred_i.replace(" <|endoftext|>", "").split(' <|sep|>')
        if len(pred_i_e) > len(gt_i_e):
            more_pred.append(index)
        if len(pred_i_e) < len(gt_i_e):
            less_pred.append(index)
        for event in pred_i_e:
            special_tokens, extracted, _ = get_tokens(event)
            if extracted:
                update(result_i, pred_keys, special_tokens, extracted)

        all_results[index] = result_i
    return all_results, more_pred, less_pred

In [45]:
# Same inputs as before, parsed with the per-event (nested) variant.
all_results2, more_pred2, less_pred2 = compute_f1_2(gt_data, pred_data)

In [46]:
# Example indices where the prediction has more (first list) or fewer (second
# list) <|sep|>-separated events than the gold continuation.
print(more_pred2)
print(50*'-')
print(less_pred2)

[13, 16, 23, 25, 47, 60, 62, 77, 82, 83, 89, 99, 135, 137, 150, 157, 159, 164, 165, 168, 177, 181, 188, 200, 204, 205, 207, 210, 214, 217, 220, 221, 231, 233, 234, 235, 238, 240, 241, 242, 248, 263, 279, 280, 281, 282, 286, 288, 291, 297, 298, 299, 300, 302, 307]
--------------------------------------------------
[1, 5, 6, 7, 14, 19, 20, 21, 27, 28, 31, 40, 46, 58, 65, 74, 75, 76, 81, 90, 91, 94, 98, 117, 119, 121, 128, 130, 131, 134, 140, 148, 156, 171, 172, 182, 184, 194, 195, 201, 206, 212, 215, 216, 218, 219, 222, 223, 226, 228, 246, 253, 261, 264, 268, 269, 270, 272, 273, 276, 283, 284, 295]


In [36]:
# Inspect the nested (per-event) record of example 14.
all_results2[14]

{'format': [True],
 'gt_triggerwords': ['captured'],
 'gt_trigger_event': [{'[Transaction_Transfer-Ownership]': 'captured'}],
 'gt_arguments': [['forces', 'weapons']],
 'gt_arg_role': [[{'role': '|Beneficiary|', 'arg': 'forces'},
   {'role': '|Artifact|', 'arg': 'weapons'}]],
 'pred_triggerwords': ['combat', 'captured', 'rid'],
 'pred_trigger_event': [{'[Conflict_Attack]': 'combat'},
  {'[Transaction_Transfer-Ownership]': 'captured'},
  {'[Business_End-Org]': 'rid'}],
 'pred_arguments': [[], ['forces', 'weapons'], ['forces', 'weapons']],
 'pred_arg_role': [[],
  [{'role': '|Buyer|', 'arg': 'forces'},
   {'role': '|Artifact|', 'arg': 'weapons'}],
  [{'role': '|Org|', 'arg': 'forces'},
   {'role': '|Artifact|', 'arg': 'weapons'}]]}

In [34]:
# Walk the predicted and gold per-event argument lists of example 14 in
# parallel; zip stops at the shorter of the two lists. The loop variables are
# named after the lists they come from (the earlier names were swapped), and
# the printed order is unchanged: prediction first, then gold.
for pred_args, gt_args in zip(all_results2[14]['pred_arguments'], all_results2[14]['gt_arguments']):
    print(pred_args, gt_args)
    print('-'*20)

[] ['forces', 'weapons']
--------------------


In [14]:
from transformers import (
    AdamW,
    T5ForConditionalGeneration,
    T5Tokenizer,
    get_linear_schedule_with_warmup
)

In [63]:
# Load the pretrained 't5-base' checkpoint.
t5_model = T5ForConditionalGeneration.from_pretrained('t5-base')


In [66]:
# Display the module tree of the loaded model.
t5_model

T5ForConditionalGeneration(
  (shared): Embedding(32128, 768)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 768)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=768, out_features=768, bias=False)
              (k): Linear(in_features=768, out_features=768, bias=False)
              (v): Linear(in_features=768, out_features=768, bias=False)
              (o): Linear(in_features=768, out_features=768, bias=False)
              (relative_attention_bias): Embedding(32, 12)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseReluDense(
              (wi): Linear(in_features=768, out_features=3072, bias=False)
              (wo): Linear(in_features=3072, out_features=768, bias=False)
              (dropout): Dr

In [67]:
# Load the tokenizer that matches the 't5-base' checkpoint.
tokenizer = T5Tokenizer.from_pretrained('t5-base')

In [68]:
# Number of tokens the tokenizer knows, to compare with the model's embedding size.
len(tokenizer)

32100

In [107]:
# One example record: a content sentence followed by its trigger/argument annotation.
text = "<|endoftext|> <|content|> Orders went out today to deploy 17,000 U.S. Army soldiers in the Persian Gulf region . <|endofcontent|> <|trigger|> deploy [Movement_Transport] <|Artifact|> soldiers <|Destination|> region <|endoftext|>"
# Model input: the text between "<|content|> " and " <|endofcontent|>", prefixed with "falsify: ".
input_text = "falsify: " + text.split("<|content|> ")[1].split(" <|endofcontent|>")[0]
# Target: the text after "<|endofcontent|> " up to the next " <|endoftext|>".
target_text = text.split("<|endofcontent|> ")[1].split(" <|endoftext|>")[0]
print(input_text, "\n", target_text)
# Tokenize the input to PyTorch tensors, truncating at 256+2 tokens.
tokens = tokenizer(input_text, return_tensors='pt', padding=True, max_length=256+2, truncation=True)
# The tokenized target is only shown as the cell output; it is not stored.
tokenizer(target_text, return_tensors='pt', padding=True, max_length=256+2, truncation=True)

falsify: Orders went out today to deploy 17,000 U.S. Army soldiers in the Persian Gulf region . 
 <|trigger|> deploy [Movement_Transport] <|Artifact|> soldiers <|Destination|> region


{'input_ids': tensor([[    3,     2,  9175,  1788,  6938,  9175,  3155, 17274,   784,   329,
            32,   162,   297,   834, 18474,  1493,   908,     3,     2,  9175,
          7754,    23,  8717,  9175,  3155, 10838,     3,     2,  9175,   308,
           222,    77,   257,  9175,  3155,  1719,     1]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}

In [106]:
# Re-check the tokenizer size.
len(tokenizer)

32100

In [105]:
# Probe how the <|Role|> markers and "[None]" fillers are tokenized by the stock
# tokenizer. The recorded ids contain 2 repeatedly, which is the id that
# tokenizer.encode('<unk>') returns in this notebook, so at least part of each
# marker is not in the vocabulary.
tokenizer(" <|Artifact|> soldiers <|Destination|> region <|Origin|> [None] <|Vehicle|> [None] <|Agent|> [None] <|Place|> [None] ")

{'input_ids': [3, 2, 9175, 7754, 23, 8717, 9175, 3155, 10838, 3, 2, 9175, 308, 222, 77, 257, 9175, 3155, 1719, 3, 2, 9175, 667, 3380, 77, 9175, 3155, 784, 567, 782, 908, 3, 2, 9175, 553, 15, 107, 23, 2482, 9175, 3155, 784, 567, 782, 908, 3, 2, 9175, 188, 5560, 9175, 3155, 784, 567, 782, 908, 3, 2, 9175, 345, 11706, 9175, 3155, 784, 567, 782, 908, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [71]:
import torch

In [109]:
# Beam-search generation (4 beams, at most 512 tokens) from the tokenized input.
# view(1, -1) reshapes ids and mask to a single-row batch.
# NOTE(review): the checkpoint is loaded as-is in this notebook; presumably it
# has not been fine-tuned for the "falsify: " prefix, which would explain the
# degenerate output recorded for this cell. Confirm before relying on it.
gen_text = t5_model.generate(input_ids= torch.LongTensor(tokens["input_ids"]).view(1, -1),
                attention_mask = tokens["attention_mask"].view(1, -1),
                max_length=512,
                num_beams=4,
                forced_bos_token_id=None)
# Decode the first returned sequence, dropping special tokens.
tokenizer.decode(gen_text[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)


'........................................................'

In [111]:
# Vocabulary size from the model config, to compare with len(tokenizer).
t5_model.config.vocab_size

32128

In [112]:
# Decode id 32128, which equals the model's config.vocab_size and is therefore
# one past the last row of a 32128-row embedding table.
tokenizer.decode([32128])

'<extra_id_-29>'

In [18]:
# Entity-type marker tokens to add to the tokenizer vocabulary.
special_tokens = []
# Opening markers.
special_tokens += ['[PER]', '[ORG]', '[FAC]', '[LOC]', '[WEA]', '[GPE]', '[VEH]']
# Second set of markers; each string contains a single literal backslash, e.g. [\PER].
# NOTE(review): presumably these are closing markers; confirm the backslash
# (rather than a forward slash) is intended.
special_tokens += ['[\\PER]', '[\\ORG]', '[\\FAC]', '[\\LOC]', '[\\WEA]', '[\\GPE]', '[\\VEH]']

# Register the 14 strings as additional tokens; the call's return value is the cell output.
# NOTE(review): no visible cell resizes the model's embeddings after this call —
# confirm the new ids fit the embedding table before training with them.
tokenizer.add_tokens(special_tokens)

14

In [27]:
# Look up the ids produced for the literal string '<unk>'.
tokenizer.encode('<unk>')

[2, 1]

In [36]:
from transformers import AutoTokenizer
# Load the GPT-2 tokenizer to repeat the add_tokens experiment on it.
gpt2_tokenizer = AutoTokenizer.from_pretrained('gpt2')

In [38]:
# Print the tokenizer before and after adding the marker tokens. In the
# recorded output both reprs report vocab_size=50257, so the repr does not
# reflect the added tokens; len(gpt2_tokenizer) is checked in the next cell.
print(gpt2_tokenizer)
gpt2_tokenizer.add_tokens(special_tokens)
print(gpt2_tokenizer)


PreTrainedTokenizerFast(name_or_path='gpt2', vocab_size=50257, model_max_len=1024, is_fast=True, padding_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>'})
PreTrainedTokenizerFast(name_or_path='gpt2', vocab_size=50257, model_max_len=1024, is_fast=True, padding_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>'})


In [47]:
# Size of the GPT-2 tokenizer including the added tokens.
len(gpt2_tokenizer)

50271

In [35]:
# Decode id 32113 to check which added marker token it maps to.
tokenizer.decode([32113])

'[\\VEH]'

In [48]:
# Display the T5 tokenizer's repr (name, sizes, special tokens).
tokenizer

PreTrainedTokenizer(name_or_path='t5-base', vocab_size=32100, model_max_len=512, is_fast=False, padding_side='right', special_tokens={'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '<pad>', 'additional_special_tokens': ['<extra_id_0>', '<extra_id_1>', '<extra_id_2>', '<extra_id_3>', '<extra_id_4>', '<extra_id_5>', '<extra_id_6>', '<extra_id_7>', '<extra_id_8>', '<extra_id_9>', '<extra_id_10>', '<extra_id_11>', '<extra_id_12>', '<extra_id_13>', '<extra_id_14>', '<extra_id_15>', '<extra_id_16>', '<extra_id_17>', '<extra_id_18>', '<extra_id_19>', '<extra_id_20>', '<extra_id_21>', '<extra_id_22>', '<extra_id_23>', '<extra_id_24>', '<extra_id_25>', '<extra_id_26>', '<extra_id_27>', '<extra_id_28>', '<extra_id_29>', '<extra_id_30>', '<extra_id_31>', '<extra_id_32>', '<extra_id_33>', '<extra_id_34>', '<extra_id_35>', '<extra_id_36>', '<extra_id_37>', '<extra_id_38>', '<extra_id_39>', '<extra_id_40>', '<extra_id_41>', '<extra_id_42>', '<extra_id_43>', '<extra_id_44>', '<extra_id_45>',