In [1]:
import numpy as np
import json#, jsonlines
import matplotlib.pyplot as plt
from eval_qa import eval_file, eval_items
import os
from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
from tqdm import tqdm
import torch
import torch.nn.functional as F
import pandas as pd
from copy import deepcopy
import random

In [2]:
def return_rank(hd, word_embedding_, token, metric='dot', token_list=None):
    """Rank ``token`` among candidate tokens by similarity to each hidden state.

    Every row of ``hd`` is scored against the rows of ``word_embedding_``.
    Candidates are ordered by descending score and the 0-based position of
    ``token`` in that ordering is reported, so 0 means ``token`` scored highest.
    Ties keep candidate order (same result as a stable descending sort): a
    token tied with candidates listed before it ranks after them.

    Args:
        hd: 2-D tensor with one hidden state per row.
        word_embedding_: 2-D tensor with one embedding per row; the row index
            is the token id.
        token: id of the token whose rank is wanted.
        metric: ``'dot'`` scores against the raw embeddings; ``'cos'``
            L2-normalises the embedding rows first. ``hd`` itself is not
            normalised; that only rescales each row's scores by a positive
            constant, which leaves the ordering unchanged.
        token_list: optional iterable of token ids restricting the candidate
            set; ``None`` means every row of ``word_embedding_``.

    Returns:
        A list of ints, one rank per row of ``hd``.

    Raises:
        ValueError: if ``metric`` is unknown or ``token`` is not a candidate.
    """
    if metric == 'dot':
        word_embedding = word_embedding_
    elif metric == 'cos':
        word_embedding = F.normalize(word_embedding_, p=2, dim=1)
    else:
        # An explicit raise survives `python -O`, unlike `assert False`.
        raise ValueError("unknown metric {!r}; expected 'dot' or 'cos'".format(metric))

    # detach() so the function also works on tensors that track gradients.
    logits_ = torch.matmul(hd, word_embedding.T).detach()

    if token_list is None:
        if not 0 <= token < logits_.shape[1]:
            raise ValueError("token {!r} is not a valid vocabulary index".format(token))
        scores = logits_
        position = int(token)
    else:
        candidates = [int(i) for i in token_list]
        if token not in candidates:
            raise ValueError("token {!r} is not in token_list".format(token))
        position = candidates.index(token)
        column_ids = torch.as_tensor(candidates, dtype=torch.long, device=logits_.device)
        scores = logits_[:, column_ids]

    # Rank under a stable descending sort ==
    #   (# candidates with a strictly higher score)
    # + (# candidates with an equal score that appear earlier in candidate order).
    # Counting avoids sorting the whole vocabulary in Python for every position.
    target_scores = scores[:, position:position + 1]
    strictly_better = (scores > target_scores).sum(dim=1)
    tied_earlier = (scores[:, :position] == target_scores).sum(dim=1)
    return (strictly_better + tied_earlier).tolist()


In [3]:
# Dataset / checkpoint locations and the device used for all model calls.
dataset = r'composition1.2000.200.12.6'

# directory = os.path.join(model_dir, "{}_{}_{}".format(dataset, args.wd, args.num_layer))
directory = model_dir = r'output/3.6_0.1_2use_cot'

# Use the first GPU when one is present; otherwise fall back to CPU so the
# notebook still runs on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Pass 1: collect atomic facts. A target that splits into exactly 4 pieces
# is treated as an atomic fact and its first three pieces are taken as
# (head, relation, tail).
all_atomic = set()     # (h,r,t)
atomic_dict = dict()   # (h,r) -> t
with open("data/{}/train.json".format(dataset)) as f:
    train_items = json.load(f)
for item in tqdm(train_items):
    temp = item['target_text'].strip("><").split("><")
    if len(temp) != 4:
        continue
    h, r, t = temp[:3]
    atomic_dict[(h, r)] = t
    all_atomic.add((h, r, t))

# Pass 2: every other training item is treated as a two-hop example. The two
# atomic facts it is built from are recorded as "in-distribution" atomic facts.
id_atomic = set()
for item in tqdm(train_items):
    temp = item['target_text'].strip("><").split("><")
    if len(temp) == 4:
        continue
    h, r1, r2, b, t = temp[:5]
    # The bridge entity is re-derived from the atomic table rather than taken
    # from the text; a missing first hop still raises KeyError as before.
    b = atomic_dict[(h, r1)]
    # Skip examples whose second hop is unknown or disagrees with the stated
    # tail. An explicit check replaces `assert` inside a bare `except`, which
    # swallowed every exception and was a no-op filter under `python -O`.
    if atomic_dict.get((b, r2)) != t:
        continue
    id_atomic.add((h, r1, b))
    id_atomic.add((b, r2, t))

ood_atomic = all_atomic - id_atomic
print("# id_atomic, # ood_atomic:", len(id_atomic), len(ood_atomic))

# head -> list of (relation, tail) over the in-distribution atomic facts.
h2rt_train = dict()
for (h, r, t) in id_atomic:
    h2rt_train.setdefault(h, []).append((r, t))

# Test items grouped by their 'type' field.
with open("data/{}/test.json".format(dataset)) as f:
    pred_data = json.load(f)
d = dict()
for item in pred_data:
    d.setdefault(item['type'], []).append(item)

100%|██████████| 518800/518800 [00:00<00:00, 890091.50it/s]
100%|██████████| 518800/518800 [00:01<00:00, 296232.71it/s]


# id_atomic, # ood_atomic: 38000 2000


In [7]:
# Read the two 3-hop evaluation splits for the selected dataset, then show
# the first record of the second split as a quick look at the item layout.
iid_path = "data/{}/test_3hop_iid.json".format(dataset)
with open(iid_path) as iid_file:
    hop_iid_3 = json.load(iid_file)

ood_path = "data/{}/test_3hop_ood.json".format(dataset)
with open(ood_path) as ood_file:
    hop_ood_3 = json.load(ood_file)

hop_ood_3[0]

{'input_text': '<e_3><r_109><r_112><r_145>',
 'target_text': '<e_3><r_109><r_112><r_145><e_951><e_1482><e_1111></a>',
 'hop1': ['<e_3><r_109><r_112>', '<e_3><r_109><r_112><e_951><e_1482></a>'],
 'hop2': ['<e_951><r_112><r_145>',
  '<e_951><r_112><r_145><e_1482><e_1111></a>'],
 'type': 'test_ood_3hop'}

In [12]:
all_checkpoints = [checkpoint for checkpoint in os.listdir(directory) if checkpoint.startswith("checkpoint")]
all_checkpoints.sort(key=lambda var: int(var.split("-")[1]))

results = []

np.random.seed(0)
target_layer = 2
# Number of leading items of hop_ood_3 that are evaluated per checkpoint.
max_eval_items = 100


def _split_tokens(text):
    """Split a '<a><b>...' string into its bracket-free pieces."""
    return text.strip("><").split("><")


def _last_position_rank(model, tokenizer, word_embedding, prompt, entity, layer, device):
    """Return the rank of '<entity>' at the last position of `prompt`.

    Runs `prompt` through `model`, takes the hidden states at index `layer`
    of the returned `hidden_states`, and uses `return_rank` to score them
    against `word_embedding`. Only the rank at the final sequence position is
    returned; 0 means the entity token had the highest score there.
    """
    encoded = tokenizer([prompt], return_tensors="pt", padding=True)
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
    hidden = outputs['hidden_states'][layer][0, :, :]
    entity_id = tokenizer("<" + entity + ">")['input_ids'][0]
    return return_rank(hidden, word_embedding, entity_id)[-1]


# Each 3-hop item carries two overlapping 2-hop sub-queries ('hop1', 'hop2').
# An item counts as correct when the model ranks the expected entity first at
# every checked step of both sub-queries.
for checkpoint in tqdm(all_checkpoints[-1:]):
    print("now checkpoint", checkpoint)
    model_path = os.path.join(directory, checkpoint)

    torch.cuda.empty_cache()

    model = GPT2LMHeadModel.from_pretrained(model_path).to(device)
    word_embedding = model.lm_head.weight.data
    tokenizer = GPT2Tokenizer.from_pretrained(model_path)
    tokenizer.padding_side = "left"
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
    model.config.pad_token_id = model.config.eos_token_id

    iid_correct = 0   # both tail predictions (given the bridge entity) are rank 0
    all_correct = 0   # additionally, the first bridge entity itself is rank 0
    eval_subset = hop_ood_3[:max_eval_items]
    for item in tqdm(eval_subset):
        # first sub-query: predict the bridge entity, then the tail given the bridge
        h, r1, r2, b, t, _eos = _split_tokens(item['hop1'][1])
        b_rank1 = _last_position_rank(
            model, tokenizer, word_embedding, item['hop1'][0], b, target_layer, device)
        t_rank1 = _last_position_rank(
            model, tokenizer, word_embedding,
            "<"+h+">"+"<"+r1+">"+"<"+r2+">"+"<"+b+">", t, target_layer, device)

        # second sub-query: only the tail given its bridge entity is checked
        h, r1, r2, b, t, _eos = _split_tokens(item['hop2'][1])
        t_rank2 = _last_position_rank(
            model, tokenizer, word_embedding,
            "<"+h+">"+"<"+r1+">"+"<"+r2+">"+"<"+b+">", t, target_layer, device)

        print(b_rank1,t_rank1,t_rank2)
        if t_rank1 == 0 and t_rank2 == 0:
            iid_correct += 1
            if b_rank1 == 0:
                all_correct += 1

        # Report against the split that is actually being iterated
        # (this previously printed the size of hop_iid_3).
        print(all_correct,iid_correct,len(hop_ood_3))
        print("-----------------------------")

  0%|          | 0/1 [00:00<?, ?it/s]

now checkpoint checkpoint-30000


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


0 0 0
1 1 3000
-----------------------------




0 0 0
2 2 3000
-----------------------------




0 0 0
3 3 3000
-----------------------------




0 0 0
4 4 3000
-----------------------------




0 1 0
4 4 3000
-----------------------------




0 1 0
4 4 3000
-----------------------------




0 0 0
5 5 3000
-----------------------------




0 0 0
6 6 3000
-----------------------------




0 0 0
7 7 3000
-----------------------------




0 0 0
8 8 3000
-----------------------------




0 0 0
9 9 3000
-----------------------------




0 0 0
10 10 3000
-----------------------------




0 0 0
11 11 3000
-----------------------------




0 0 0
12 12 3000
-----------------------------




0 0 0
13 13 3000
-----------------------------




0 0 0
14 14 3000
-----------------------------




0 0 0
15 15 3000
-----------------------------




0 0 0
16 16 3000
-----------------------------




0 0 0
17 17 3000
-----------------------------




0 0 0
18 18 3000
-----------------------------




0 0 0
19 19 3000
-----------------------------




0 0 0
20 20 3000
-----------------------------




0 0 0
21 21 3000
-----------------------------




0 0 0
22 22 3000
-----------------------------




0 0 4
22 22 3000
-----------------------------




0 0 0
23 23 3000
-----------------------------




0 0 0
24 24 3000
-----------------------------




0 0 0
25 25 3000
-----------------------------




0 0 0
26 26 3000
-----------------------------




0 0 0
27 27 3000
-----------------------------




0 0 0
28 28 3000
-----------------------------




0 0 0
29 29 3000
-----------------------------




0 0 0
30 30 3000
-----------------------------




0 0 0
31 31 3000
-----------------------------




0 0 0
32 32 3000
-----------------------------




0 0 0
33 33 3000
-----------------------------




0 0 0
34 34 3000
-----------------------------




0 0 0
35 35 3000
-----------------------------




0 0 0
36 36 3000
-----------------------------




0 0 0
37 37 3000
-----------------------------




0 0 0
38 38 3000
-----------------------------




0 0 0
39 39 3000
-----------------------------




0 0 0
40 40 3000
-----------------------------




0 0 0
41 41 3000
-----------------------------




0 0 0
42 42 3000
-----------------------------




0 0 0
43 43 3000
-----------------------------




0 0 0
44 44 3000
-----------------------------




0 0 0
45 45 3000
-----------------------------




0 0 0
46 46 3000
-----------------------------




0 0 0
47 47 3000
-----------------------------




0 0 0
48 48 3000
-----------------------------




0 0 0
49 49 3000
-----------------------------




0 0 0
50 50 3000
-----------------------------




0 0 0
51 51 3000
-----------------------------




0 0 0
52 52 3000
-----------------------------




0 0 0
53 53 3000
-----------------------------




0 0 0
54 54 3000
-----------------------------




0 0 0
55 55 3000
-----------------------------




0 0 0
56 56 3000
-----------------------------




0 0 0
57 57 3000
-----------------------------




0 0 0
58 58 3000
-----------------------------




0 0 0
59 59 3000
-----------------------------




0 0 0
60 60 3000
-----------------------------




0 0 0
61 61 3000
-----------------------------




0 0 0
62 62 3000
-----------------------------




0 0 0
63 63 3000
-----------------------------




0 0 0
64 64 3000
-----------------------------




0 0 0
65 65 3000
-----------------------------




0 0 0
66 66 3000
-----------------------------




0 0 0
67 67 3000
-----------------------------




0 0 0
68 68 3000
-----------------------------




0 0 0
69 69 3000
-----------------------------




0 0 0
70 70 3000
-----------------------------




0 0 0
71 71 3000
-----------------------------




0 0 0
72 72 3000
-----------------------------




0 0 0
73 73 3000
-----------------------------




0 0 0
74 74 3000
-----------------------------




0 0 0
75 75 3000
-----------------------------




0 0 0
76 76 3000
-----------------------------




0 0 0
77 77 3000
-----------------------------




0 0 0
78 78 3000
-----------------------------




0 0 0
79 79 3000
-----------------------------




0 0 0
80 80 3000
-----------------------------




0 0 0
81 81 3000
-----------------------------




0 0 0
82 82 3000
-----------------------------




0 0 0
83 83 3000
-----------------------------




0 0 0
84 84 3000
-----------------------------




0 0 0
85 85 3000
-----------------------------




0 0 0
86 86 3000
-----------------------------




0 0 0
87 87 3000
-----------------------------




0 0 0
88 88 3000
-----------------------------




0 0 0
89 89 3000
-----------------------------




0 0 0
90 90 3000
-----------------------------




0 0 0
91 91 3000
-----------------------------




0 0 0
92 92 3000
-----------------------------




0 0 0
93 93 3000
-----------------------------




0 0 2
93 93 3000
-----------------------------




0 0 0
94 94 3000
-----------------------------




0 0 0
95 95 3000
-----------------------------


99it [01:52,  1.13s/it]
100%|██████████| 1/1 [01:52<00:00, 112.66s/it]

0 0 0





In [16]:
# Fraction of evaluated 3-hop items for which every checked step was ranked
# first. The evaluation loop stops after 100 items, but covers fewer when
# hop_ood_3 is shorter, so divide by the number actually evaluated
# (max(1, ...) keeps an empty split from dividing by zero).
all_correct / max(1, min(100, len(hop_ood_3)))

0.96