In [59]:
import json
import numpy as np
import random
from tqdm.auto import tqdm
import itertools
import os
from copy import deepcopy

In [60]:
def compare(a,b):
    a = int(a.strip("<>"))
    b = int(b.strip("<>"))
    if a<b:
        return 0
    if a==b:
        return 1
    if a>b:
        return 2
    assert False
    
def build_dicts(entities):
    entity2ind = dict()
    ind2entity = []
    for i in range(len(entities)):
        entity = entities[i]
        if not (entity in ind2entity):
            ind2entity.append(entity)
            entity2ind[entity] = len(ind2entity) - 1
    return ind2entity, entity2ind

In [61]:
vocab = []

num_entities = 1000
entities = ["<e_{}>".format(i) for i in range(num_entities)]
vocab = vocab + entities
ind2entity, entity2ind = build_dicts(entities)

num_attributes = 20
attributes = ["<attr_{}>".format(i) for i in range(num_attributes)]
vocab = vocab + attributes
ind2attribute, attribute2ind = build_dicts(attributes)

num_vals_per_attr = 21  # values range from [0, num_vals_per_attr-1]
values = ["<{}>".format(i) for i in range(num_vals_per_attr)]

vocab = vocab + values

# make controlled value
random.seed(0)
np.random.seed(0)

unseen_value = values.pop(-1)
print(f"unseen_value: {unseen_value}")

all_value_pairs = list(itertools.combinations_with_replacement(values, 2))
print(f"all_value_pairs: {len(all_value_pairs)} : {all_value_pairs}")
unseen_value_pairs = random.sample(all_value_pairs, k=int(0.1*len(all_value_pairs)))
print(f"unseen_value_pairs: {len(unseen_value_pairs)} : {unseen_value_pairs}")
seen_value_pairs = [pair for pair in all_value_pairs if pair not in unseen_value_pairs]
print(f"seen_value_pairs: {len(seen_value_pairs)} : {seen_value_pairs}")
unflip_value_pairs = []
i = 0
while len(unflip_value_pairs) < int(0.1*len(seen_value_pairs)):
    if i == len(seen_value_pairs):
        i = 0
    if seen_value_pairs[i][0] == seen_value_pairs[i][1]:
        i += 1
        continue
    if np.random.uniform() < 0.1:
        unflip_value_pairs.append(seen_value_pairs[i])
    i += 1
for i in range(len(unflip_value_pairs)):
    if np.random.uniform() < 0.5:
        unflip_value_pairs[i] = (unflip_value_pairs[i][1], unflip_value_pairs[i][0])
print(f"unflip_value_pairs: {len(unflip_value_pairs)} : {unflip_value_pairs}")

# randomly assign values
atomic_KB = np.random.randint(low=0, high=num_vals_per_attr-1, size=(num_entities, num_attributes))     #  [entity id, attribute id] -> value
# change some value to unseen_value to make ood case
change_ent_indices = np.random.randint(low=0, high=atomic_KB.shape[0], size=atomic_KB.shape[1])
print(change_ent_indices)
atomic_KB[change_ent_indices, np.arange(atomic_KB.shape[1])] = num_vals_per_attr-1

unseen_value: <20>
all_value_pairs: 210 : [('<0>', '<0>'), ('<0>', '<1>'), ('<0>', '<2>'), ('<0>', '<3>'), ('<0>', '<4>'), ('<0>', '<5>'), ('<0>', '<6>'), ('<0>', '<7>'), ('<0>', '<8>'), ('<0>', '<9>'), ('<0>', '<10>'), ('<0>', '<11>'), ('<0>', '<12>'), ('<0>', '<13>'), ('<0>', '<14>'), ('<0>', '<15>'), ('<0>', '<16>'), ('<0>', '<17>'), ('<0>', '<18>'), ('<0>', '<19>'), ('<1>', '<1>'), ('<1>', '<2>'), ('<1>', '<3>'), ('<1>', '<4>'), ('<1>', '<5>'), ('<1>', '<6>'), ('<1>', '<7>'), ('<1>', '<8>'), ('<1>', '<9>'), ('<1>', '<10>'), ('<1>', '<11>'), ('<1>', '<12>'), ('<1>', '<13>'), ('<1>', '<14>'), ('<1>', '<15>'), ('<1>', '<16>'), ('<1>', '<17>'), ('<1>', '<18>'), ('<1>', '<19>'), ('<2>', '<2>'), ('<2>', '<3>'), ('<2>', '<4>'), ('<2>', '<5>'), ('<2>', '<6>'), ('<2>', '<7>'), ('<2>', '<8>'), ('<2>', '<9>'), ('<2>', '<10>'), ('<2>', '<11>'), ('<2>', '<12>'), ('<2>', '<13>'), ('<2>', '<14>'), ('<2>', '<15>'), ('<2>', '<16>'), ('<2>', '<17>'), ('<2>', '<18>'), ('<2>', '<19>'), ('<3>', '<3>'),

In [62]:
print(atomic_KB[1,0], atomic_KB[25,0])
print(atomic_KB[1,0], atomic_KB[27,0])
print(atomic_KB[256,13], atomic_KB[41,13])
print(atomic_KB[3,0], atomic_KB[65,0])

4 7
4 7
3 12
2 18


In [63]:
def rand_flip(tup):
    tup_l = list(tup)
    random.shuffle(tup_l)
    return tuple(tup_l)
    
def choose(arr, ratio_or_count):
    if type(ratio_or_count) == float:
        num = round(ratio_or_count*len(arr))
    elif type(ratio_or_count) == int:
        num = ratio_or_count
    else:
         assert False
    if num >= len(arr):
        return arr
    rand_inds = np.random.choice(len(arr), num, replace=False).tolist()
    return [arr[i] for i in rand_inds]
    
def split(arr, ratio):
    train, test = [], []
    rand_inds = np.random.choice(len(arr), round(ratio*len(arr)), replace=False).tolist()
    for i in range(len(arr)):
        if i in rand_inds:
            train.append(arr[i])
        else:
            test.append(arr[i])
    return [train, test]

In [64]:
# special tokens
vocab = vocab + ["<mask>", "<sep>", "<a>", "</a>", "<q>", "</q>"]

print(vocab)

comp_q_tokens = attributes
comp2labels = dict()
for comp_q_token in comp_q_tokens:
    comp2labels[comp_q_token] = ["<"+comp_q_token.strip("<>")+"_{}>".format(i) for i in range(3)]
    vocab = vocab + comp2labels[comp_q_token]

['<e_0>', '<e_1>', '<e_2>', '<e_3>', '<e_4>', '<e_5>', '<e_6>', '<e_7>', '<e_8>', '<e_9>', '<e_10>', '<e_11>', '<e_12>', '<e_13>', '<e_14>', '<e_15>', '<e_16>', '<e_17>', '<e_18>', '<e_19>', '<e_20>', '<e_21>', '<e_22>', '<e_23>', '<e_24>', '<e_25>', '<e_26>', '<e_27>', '<e_28>', '<e_29>', '<e_30>', '<e_31>', '<e_32>', '<e_33>', '<e_34>', '<e_35>', '<e_36>', '<e_37>', '<e_38>', '<e_39>', '<e_40>', '<e_41>', '<e_42>', '<e_43>', '<e_44>', '<e_45>', '<e_46>', '<e_47>', '<e_48>', '<e_49>', '<e_50>', '<e_51>', '<e_52>', '<e_53>', '<e_54>', '<e_55>', '<e_56>', '<e_57>', '<e_58>', '<e_59>', '<e_60>', '<e_61>', '<e_62>', '<e_63>', '<e_64>', '<e_65>', '<e_66>', '<e_67>', '<e_68>', '<e_69>', '<e_70>', '<e_71>', '<e_72>', '<e_73>', '<e_74>', '<e_75>', '<e_76>', '<e_77>', '<e_78>', '<e_79>', '<e_80>', '<e_81>', '<e_82>', '<e_83>', '<e_84>', '<e_85>', '<e_86>', '<e_87>', '<e_88>', '<e_89>', '<e_90>', '<e_91>', '<e_92>', '<e_93>', '<e_94>', '<e_95>', '<e_96>', '<e_97>', '<e_98>', '<e_99>', '<e_100>'

In [65]:
assert len(vocab) == len(set(vocab))
print("vocab size:", len(vocab))

vocab size: 1107


In [None]:
def format_atomic(entity, attr, val, t):
    val = "<{}>".format(val)
    input_text = "".join([entity, attr])
    target_text = input_text + "".join([val, "</a>"])
    return {
        "input_text": input_text,
        "target_text": target_text,
        "type": t,
    }

def format_comp(comp_q_token, ent_1, ent_2, label, t):
    input_text = "".join([comp_q_token, "<q>", ent_1, "<mask>", ent_2])
    target_text = input_text + "".join([label, "</a>"])
    return {
        "input_text": input_text,
        "target_text": target_text,
        "type": t,
    }

num_id_entities_ratio = 0.9

def compare_ent(ent_1, ent_2, attr):
    val_1, val_2 = atomic_KB[entity2ind[ent_1], attribute2ind[attr]], atomic_KB[entity2ind[ent_2], attribute2ind[attr]]
    return compare("<{}>".format(val_1), "<{}>".format(val_2))

id_atomic_facts, ood_atomic_facts, ood_atomic_unseen_value_facts = [], [], []
train_inferred, test_inferred_id, test_inferred_id_unflip_pair, test_inferred_id_unseen_pair, test_inferred_ood, test_inferred_ood_unseen_pair, test_inferred_ood_unflip_pair, test_inferred_ood_unseen_value = [], [], [], [], [], [], [], []

for comp_q_token in tqdm(comp_q_tokens):
    copied_entities = deepcopy(entities)
    ent_for_unseen_value = copied_entities.pop(change_ent_indices[attribute2ind[comp_q_token]])
    assert len(copied_entities) == len(entities) - 1
    id_entities, ood_entities = split(copied_entities, num_id_entities_ratio)

    # add atomic facts
    for entity in id_entities:
        val = atomic_KB[entity2ind[entity], attribute2ind[comp_q_token]]
        assert val != num_vals_per_attr-1
        id_atomic_facts.append(format_atomic(entity, comp_q_token, val, t='id_atomic'))

    for entity in ood_entities:
        val = atomic_KB[entity2ind[entity], attribute2ind[comp_q_token]]
        assert val != num_vals_per_attr-1
        ood_atomic_facts.append(format_atomic(entity, comp_q_token, val, t='ood_atomic'))
    
    val = atomic_KB[entity2ind[ent_for_unseen_value], attribute2ind[comp_q_token]]
    assert val == num_vals_per_attr-1
    ood_atomic_unseen_value_facts.append(format_atomic(ent_for_unseen_value, comp_q_token, val, t='ood_atomic_unseen'))
    
    # add inferred facts
    all_pairs = list(itertools.combinations(entities, 2))
    for (ent_1, ent_2) in all_pairs:
        val_1, val_2 = atomic_KB[entity2ind[ent_1], attribute2ind[comp_q_token]], atomic_KB[entity2ind[ent_2], attribute2ind[comp_q_token]]
        if f"<{val_1}>" == unseen_value or f"{val_2}" == unseen_value:
            ty = 'test_inferred_ood_unseen_value'
            label = comp2labels[comp_q_token][compare_ent(ent_1, ent_2, comp_q_token)]
            test_inferred_ood_unseen_value.append(format_comp(comp_q_token, ent_1, ent_2, label, t=ty))
            # flip
            label = comp2labels[comp_q_token][compare_ent(ent_2, ent_1, comp_q_token)]
            test_inferred_ood_unseen_value.append(format_comp(comp_q_token, ent_2, ent_1, label, t=ty))
        else:
            value_pair = (f"<{val_1}>", f"<{val_2}>")
            sorted_value_pair = tuple(sorted(value_pair, key=lambda x: int(x.strip("><"))))
            
            if ent_1 in ood_entities and ent_2 in ood_entities:
                if sorted_value_pair in unseen_value_pairs:
                    ty = 'test_inferred_ood_unseen_pair'
                    label = comp2labels[comp_q_token][compare_ent(ent_1, ent_2, comp_q_token)]
                    test_inferred_ood_unseen_pair.append(format_comp(comp_q_token, ent_1, ent_2, label, t=ty))
                    # flip
                    label = comp2labels[comp_q_token][compare_ent(ent_2, ent_1, comp_q_token)]
                    test_inferred_ood_unseen_pair.append(format_comp(comp_q_token, ent_2, ent_1, label, t=ty))
                elif value_pair in unflip_value_pairs:
                    ty = 'test_inferred_ood_unflip_pair'
                    label = comp2labels[comp_q_token][compare_ent(ent_1, ent_2, comp_q_token)]
                    test_inferred_ood_unflip_pair.append(format_comp(comp_q_token, ent_1, ent_2, label, t=ty))
                else:
                    ty = 'test_inferred_ood'
                    label = comp2labels[comp_q_token][compare_ent(ent_1, ent_2, comp_q_token)]
                    test_inferred_ood.append(format_comp(comp_q_token, ent_1, ent_2, label, t=ty))
                    # flip
                    label = comp2labels[comp_q_token][compare_ent(ent_2, ent_1, comp_q_token)]
                    test_inferred_ood.append(format_comp(comp_q_token, ent_2, ent_1, label, t=ty))
            elif ent_1 in id_entities and ent_2 in id_entities:
                if sorted_value_pair in unseen_value_pairs:
                    ty = 'test_inferred_id_unseen_pair'
                    label = comp2labels[comp_q_token][compare_ent(ent_1, ent_2, comp_q_token)]
                    test_inferred_id_unseen_pair.append(format_comp(comp_q_token, ent_1, ent_2, label, t=ty))
                    # flip
                    label = comp2labels[comp_q_token][compare_ent(ent_2, ent_1, comp_q_token)]
                    test_inferred_id_unseen_pair.append(format_comp(comp_q_token, ent_2, ent_1, label, t=ty))
                elif value_pair in unflip_value_pairs:
                    ty = 'test_inferred_id_unflip_pair'
                    label = comp2labels[comp_q_token][compare_ent(ent_1, ent_2, comp_q_token)]
                    test_inferred_id_unflip_pair.append(format_comp(comp_q_token, ent_1, ent_2, label, t=ty))
                else:                
                    if np.random.uniform() < 0.1:
                        ty = 'test_inferred_id'
                        label = comp2labels[comp_q_token][compare_ent(ent_1, ent_2, comp_q_token)]
                        test_inferred_id.append(format_comp(comp_q_token, ent_1, ent_2, label, t=ty))
                        # flip
                        label = comp2labels[comp_q_token][compare_ent(ent_2, ent_1, comp_q_token)]
                        test_inferred_id.append(format_comp(comp_q_token, ent_2, ent_1, label, t=ty))
                    else:
                        ty = 'train_inferred'
                        label = comp2labels[comp_q_token][compare_ent(ent_1, ent_2, comp_q_token)]
                        train_inferred.append(format_comp(comp_q_token, ent_1, ent_2, label, t=ty))
                        # flip
                        label = comp2labels[comp_q_token][compare_ent(ent_2, ent_1, comp_q_token)]
                        train_inferred.append(format_comp(comp_q_token, ent_2, ent_1, label, t=ty))
            else:
                pass

 80%|████████  | 16/20 [02:25<00:36,  9.04s/it]

In [None]:
print(len(id_atomic_facts), 
      len(ood_atomic_facts), 
      len(ood_atomic_unseen_value_facts), 
      len(train_inferred), 
      len(test_inferred_id), 
      len(test_inferred_id_unflip_pair), 
      len(test_inferred_id_unseen_pair), 
      len(test_inferred_ood), 
      len(test_inferred_ood_unflip_pair), 
      len(test_inferred_ood_unseen_pair), 
      len(test_inferred_ood_unseen_value))

17980 2000 20 12408672 1381088 367092 1622096 169992 4437 19134 17790


In [None]:
test_size = 3000


test_inferred_ood_ds = choose(test_inferred_ood, test_size)

probes = []
probes = probes + test_inferred_ood_ds
probes = probes + choose(id_atomic_facts, test_size)
probes = probes + choose(ood_atomic_facts, test_size)
probes = probes + choose(ood_atomic_unseen_value_facts, test_size)
probes = probes + choose(test_inferred_id, test_size)
probes = probes + choose(test_inferred_id_unflip_pair, test_size)
probes = probes + choose(test_inferred_id_unseen_pair, test_size)
probes = probes + choose(test_inferred_ood_unflip_pair, test_size)
probes = probes + choose(test_inferred_ood_unseen_pair, test_size)
probes = probes + choose(test_inferred_ood_unseen_value, test_size)

all_atomics = id_atomic_facts + ood_atomic_facts

In [None]:
# downsampling inferred facts included in training
for inf_atom_ratio in [7.2]:
    dataset_name = "comparison.{}.{}-controlled".format(num_entities, inf_atom_ratio)
    os.makedirs("data/{}".format(dataset_name), exist_ok=True)

    train_inferred_ds = choose(train_inferred, round(inf_atom_ratio*len(id_atomic_facts)))

    probes_ = probes + choose(train_inferred_ds, test_size)

    print("train/test atomic, # train inferred:", len(id_atomic_facts), len(ood_atomic_facts), len(train_inferred_ds))
    with open("data/{}/atomic_facts.json".format(dataset_name), "w", encoding='utf-8') as f:
        json.dump(all_atomics, f)
    with open("data/{}/train.json".format(dataset_name), "w", encoding='utf-8') as f:
        json.dump(id_atomic_facts + ood_atomic_facts + train_inferred_ds, f)
    with open("data/{}/valid.json".format(dataset_name), "w", encoding='utf-8') as f:
        json.dump(test_inferred_ood_ds, f)
    with open("data/{}/test.json".format(dataset_name), "w", encoding='utf-8') as f:
        json.dump(probes_, f)
    with open("data/{}/unseen_value_pairs.json".format(dataset_name), "w", encoding='utf-8') as f:
        json.dump(unseen_value_pairs, f)
    with open("data/{}/unflip_value_pairs.json".format(dataset_name), "w", encoding='utf-8') as f:
        json.dump(unflip_value_pairs, f)
    # add vocab
    with open("data/{}/vocab.json".format(dataset_name), "w", encoding='utf-8') as f:
        json.dump(vocab, f)

train/test atomic, # train inferred: 17980 2000 129456
