In [1]:
from sklearn.manifold import TSNE, MDS
from matplotlib import pyplot as plt
import matplotlib
import pickle
import numpy as np
import os
import pandas as pd
import seaborn as sns
import torch
import json
import pprint
import re
from adjustText import adjust_text

In [3]:
# Path layout shared by both runs below:
#   ../output/<run>/patterns   -> *_PAT_OUTDIR
#   ../output/<run>/kb_rels    -> *_KB_OUTDIR
#   ../models/<run>-<timestamp>/ep50/target_emb.pt -> *_TARGET_EMB
_OUTPUT_DIR_TMPL = '../output/{run}/{sub}'
_TARGET_EMB_TMPL = "../models/{run}-{stamp}/ep50/target_emb.pt"

# MILESTONE
_MILESTONE_RUN = ('run31_vbasis_nornds11_rare_l3_autosgdlrpt1_autow_pt2_tgtlr1'
                  '_autoavg_preavg_maxopt_autofixed-b5-kb11')
MILESTONE_PAT_OUTDIR = _OUTPUT_DIR_TMPL.format(run=_MILESTONE_RUN, sub='patterns')
MILESTONE_KB_OUTDIR = _OUTPUT_DIR_TMPL.format(run=_MILESTONE_RUN, sub='kb_rels')
MILESTONE_TARGET_EMB = _TARGET_EMB_TMPL.format(run=_MILESTONE_RUN, stamp='20200618-214714')

# BASELINE (superseded; kept for reference)
# BASELINE_TARGET_EMB = "../models/run31_vbasis_nornds11_rare_l3_trans_encdrop_pt3_autosgdlrpt1_autow_pt2_tgtlr1_autoavg_preavg_maxopt_autofixed-b1-kb1-20200802-171731/ep50/target_emb.pt"
# BASELINE_PAT_OUTDIR = '../output/run31_vbasis_nornds11_rare_l3_trans_encdrop_pt3_autosgdlrpt1_autow_pt2_tgtlr1_autoavg_preavg_maxopt_autofixed-b1-kb1/patterns'
# BASELINE_KB_OUTDIR = '../output/run31_vbasis_nornds11_rare_l3_trans_encdrop_pt3_autosgdlrpt1_autow_pt2_tgtlr1_autoavg_preavg_maxopt_autofixed-b1-kb1/kb_rels'

# NEW BASELINE AFTER HYPERPARAM TUNING
_BASELINE_RUN = ('run31_vbasis_nornds11_rare_l3_autosgdlrpt1_autow_pt2_tgtlr1'
                 '_autoavg_preavg_encdrop_pt35_maxopt_autofixed-b1-kb1')
BASELINE_TARGET_EMB = _TARGET_EMB_TMPL.format(run=_BASELINE_RUN, stamp='20200920-003007')
BASELINE_PAT_OUTDIR = _OUTPUT_DIR_TMPL.format(run=_BASELINE_RUN, sub='patterns')
BASELINE_KB_OUTDIR = _OUTPUT_DIR_TMPL.format(run=_BASELINE_RUN, sub='kb_rels')

# File names looked up inside the per-run output directories.
KB_BASIS_PRED_FILE = 'basis_pred.npy'
KB_EMB_FILE = 'emb.npy'
IDX2KB_FILE = 'idx2kb_dict.pkl'
PAT_BASIS_PRED_FILE = 'pat_basis_pred.pt'
IDX2PAT = 'idx2pat_dict.pkl'

# Run-independent data files.
ENTPAIR_NEW_VOCAB_FILE = '../data/var_basis_wo_test/entpair-new-vocab.txt'
ENTPAIR_DICT = '../data/var_basis_wo_test/entpair_dictionary_index'
FREEBASE_MAP = "../data/en-freebase_wiki_cat_title_map.txt"

In [4]:
def prepare_data(PAT_OUTDIR, KB_OUTDIR, TARGET_EMB, closest_pairs_file="../output/closest_pairs.json"):
    """Load every artifact needed to plot one run and bundle it in a dict.

    Args:
        PAT_OUTDIR: directory holding PAT_BASIS_PRED_FILE and IDX2PAT.
        KB_OUTDIR: directory holding KB_BASIS_PRED_FILE and IDX2KB_FILE.
        TARGET_EMB: path of the target-embedding tensor, read with torch.load.
        closest_pairs_file: path of the JSON file listing the closest entity
            pairs per run / pattern / basis. Defaults to the path that was
            previously hard-coded.

    Returns:
        dict with keys 'pat_basis_preds', 'pat2idx', 'kb_basis_preds',
        'kb2idx', 'closest_pairs', 'target_embs' and 'target2idx'.
        In 'closest_pairs' every per-basis string from the JSON file is
        replaced by a list of (target_index, target_label) tuples, where
        target_index is -1 when the pair is not in the entity-pair vocab.

    Raises:
        AssertionError: if a line of FREEBASE_MAP does not consist of exactly
            two tab-separated fields.

    NOTE(review): the two idx2* maps are read with pickle.load, which can
    execute arbitrary code; only point this at run directories you trust.
    """
    # Basis predictions for patterns.
    pat_basis_preds = torch.load(os.path.join(PAT_OUTDIR, PAT_BASIS_PRED_FILE), map_location='cpu')
    print("Pattern basis predictions:", pat_basis_preds.shape)

    # idx -> pattern map, inverted to pattern -> row index.
    with open(os.path.join(PAT_OUTDIR, IDX2PAT), 'rb') as fin:
        idx2pat = pickle.load(fin)
    # NOTE(review): the row index is the position in idx2pat's value order,
    # not the dict key -- presumably the two coincide; confirm against the
    # code that writes the pickle.
    pat2idx = {pat: i for i, pat in enumerate(idx2pat.values())}
    print('Finished constructing pat2idx')

    # Basis predictions for kb relations.
    kb_basis_preds = np.load(os.path.join(KB_OUTDIR, KB_BASIS_PRED_FILE))
    print('KB Relation basis predictions:', kb_basis_preds.shape)

    # idx -> kb relation map, inverted the same way. ' <eos>' is appended so
    # the keys have the same suffix as the pattern strings used for lookup.
    with open(os.path.join(KB_OUTDIR, IDX2KB_FILE), 'rb') as fin:
        idx2kb = pickle.load(fin)
    kb2idx = {kb + ' <eos>': i for i, kb in enumerate(idx2kb.values())}
    print("Finished constructing kb2idx")

    # Entity pair vocab: each line is split at its first ':' into key / value.
    entpair_vocab_map = {}
    with open(ENTPAIR_NEW_VOCAB_FILE, "r") as fin:
        for line in fin:
            line = line.rstrip()
            index = line.find(":")
            entpair_vocab_map[line[:index]] = line[index+1:]
    print("Finished reading entity pair vocab")

    # Freebase map file: two tab-separated fields per line. Stored keyed by
    # the second field, then inverted so lookups go first field -> second.
    freebase_map = {}
    with open(FREEBASE_MAP, 'r') as fin:
        for line in fin:
            parts = line.strip().split('\t')
            # Validate BEFORE indexing: a malformed line now reports its field
            # count instead of raising IndexError or being stored half-parsed.
            assert len(parts) == 2, "Got length: {}".format(len(parts))
            freebase_map[parts[1]] = parts[0]
    freebase_map_rev = {v: k for k, v in freebase_map.items()}
    print("Finished reading freebase code to entity map")

    # Entity pair -> integer index. The vocab key's first two characters are
    # dropped before int() -- presumably a fixed two-letter tag; TODO confirm.
    target2idx = {target: int(idx[2:]) for idx, target in entpair_vocab_map.items()}

    # Closest entity pairs for patterns / kb rels (json constructed manually).
    with open(closest_pairs_file, "r") as fin:
        closest_pairs = json.load(fin)

    def parse_targets(targets):
        """Split one per-basis string into its entity-pair labels.

        Everything up to and including the first ':' is discarded; the rest is
        split on numbers of the form 0.<digits> (with surrounding whitespace)
        and empty pieces are dropped.
        """
        targets_str = targets[targets.find(':')+1:].strip()
        return [piece for piece in re.split(r'\s*0\.\d+\s*', targets_str) if piece != ""]

    # Replace each per-basis string, in place, by (index, label) tuples.
    for pat2targets in closest_pairs.values():
        for basis_list in pat2targets.values():
            for basis, targets in enumerate(basis_list):
                target_list = parse_targets(targets)
                target_idx_list = []
                for target in target_list:
                    # Translate each tab-separated entity through the reversed
                    # freebase map (unchanged when absent) before the lookup.
                    vocab_key = "\t".join(freebase_map_rev.get(entity, entity)
                                          for entity in target.split("\t"))
                    target_idx_list.append(target2idx.get(vocab_key, -1))
                basis_list[basis] = list(zip(target_idx_list, target_list))

    target_embs = torch.load(TARGET_EMB, map_location='cpu')
    print("Targets:", target_embs.shape)

    return {
        'pat_basis_preds': pat_basis_preds,
        'pat2idx': pat2idx,
        'kb_basis_preds': kb_basis_preds,
        'kb2idx': kb2idx,
        'closest_pairs': closest_pairs,
        'target_embs': target_embs,
        'target2idx': target2idx
    }

In [5]:
# Load all artifacts for the milestone run; the shapes are printed as they are read.
milestone_data = prepare_data(MILESTONE_PAT_OUTDIR, MILESTONE_KB_OUTDIR, MILESTONE_TARGET_EMB)

Pattern basis predictions: torch.Size([1261610, 11, 100])
Finished constructing pat2idx
KB Relation basis predictions: (41, 11, 100)
Finished constructing kb2idx
Finished reading entity pair vocab
Finished reading freebase code to entity map
Targets: torch.Size([549761, 100])


In [6]:
# Load the same artifacts for the baseline run, for side-by-side plotting.
baseline_data = prepare_data(BASELINE_PAT_OUTDIR, BASELINE_KB_OUTDIR, BASELINE_TARGET_EMB)

Pattern basis predictions: torch.Size([1261610, 1, 100])
Finished constructing pat2idx
KB Relation basis predictions: (41, 1, 100)
Finished constructing kb2idx
Finished reading entity pair vocab
Finished reading freebase code to entity map
Targets: torch.Size([549761, 100])


## Headquarters example for the paper

In [7]:
# Patterns / KB relation shown in the figure. These strings are lookup keys
# (pat2idx, kb2idx and the closest-pairs JSON), so the spellings are kept
# exactly as they are; every entry ends in ' <eos>'.
patterns = [
    body + ' <eos>'
    for body in (
        '$ARG1 headoffice in $ARG2',
        '$ARG1 headqarters in $ARG2',
        'org:city_of_headquarters',
        '$ARG1 is now at $ARG2',
    )
]

In [8]:
def get_reduced_embs(patterns, run, data, use_random=False, std=0.01, seed=20):
    """Project basis predictions and their closest entity pairs down to 2-D.

    Args:
        patterns: pattern / KB-relation strings. Each is looked up in
            data['pat2idx'] first and in data['kb2idx'] otherwise.
        run: key selecting one run inside data['closest_pairs'].
        data: dict as returned by prepare_data.
        use_random: when True, add Gaussian noise (scaled by ``std``) to the
            basis predictions before reduction.
        std: scale of that noise.
        seed: seeds NumPy's global RNG and both reducers.

    Returns:
        (reduced_tsne, reduced_mds, labels, uniq_idx). Both reduced arrays
        hold the basis-prediction rows first, followed by one row per unique
        entity-pair embedding. ``labels`` is the NON-deduplicated list of
        entity-pair labels and ``uniq_idx`` indexes into it, in the same order
        as the entity-pair rows of the reduced arrays.
    """
    np.random.seed(seed)
    pat_basis_preds = data['pat_basis_preds']
    pat2idx = data['pat2idx']
    kb_basis_preds = data['kb_basis_preds']
    kb2idx = data['kb2idx']
    closest_pairs = data['closest_pairs']
    target_embs = data['target_embs']

    pattern_pred_embs = np.concatenate(
        [pat_basis_preds[pat2idx[pat]].numpy() if pat in pat2idx
         else kb_basis_preds[kb2idx[pat]]
         for pat in patterns],
        axis=0)
    if use_random:
        pattern_pred_embs += np.random.randn(*pattern_pred_embs.shape)*std
    print("Basis predictions:", pattern_pred_embs.shape)

    # Flatten run -> pattern -> basis -> (index, label) once, so embeddings
    # and labels are guaranteed to stay aligned.
    all_entpairs = [entpair
                    for pat in patterns
                    for basis_list in closest_pairs[run][pat]
                    for entpair in basis_list]
    # prepare_data stores index -1 for pairs missing from the vocabulary.
    # Indexing target_embs with -1 would silently return the LAST embedding
    # row, so such pairs are dropped instead of being plotted at a wrong spot.
    entpairs = [entpair for entpair in all_entpairs if entpair[0] >= 0]
    if len(entpairs) != len(all_entpairs):
        print("Skipped entity pairs without a vocabulary index:",
              len(all_entpairs) - len(entpairs))

    target_embs_for_pats = np.concatenate(
        [target_embs[entpair[0]].detach().cpu().numpy().reshape(1, -1)
         for entpair in entpairs],
        axis=0)
    target_embs_for_pats_uniq, target_embs_for_pats_idx = np.unique(target_embs_for_pats,
                                                                    return_index=True,
                                                                    axis=0)
    print("Unique entity pairs:", target_embs_for_pats_uniq.shape)
    target_embs_for_pats_labels = [entpair[1] for entpair in entpairs]

    combined = np.concatenate([pattern_pred_embs, target_embs_for_pats_uniq], axis=0)
    print("Combined:", combined.shape)
    # t-SNE requires perplexity < n_samples (recent scikit-learn raises
    # otherwise), and the small baseline plot has fewer than 25 points.
    perplexity = min(25, max(combined.shape[0] - 1, 1))
    # random_state is passed explicitly, like for MDS, so the layout depends
    # only on ``seed`` and not on how many global random draws came before.
    reduced_tsne = TSNE(n_components=2, perplexity=perplexity,
                        random_state=seed).fit_transform(combined)
    reduced_mds = MDS(n_components=2, random_state=seed).fit_transform(combined)
    print("Reduced:", 'tsne-:', reduced_tsne.shape, 'mds-:', reduced_mds.shape)
    return reduced_tsne, reduced_mds, target_embs_for_pats_labels, target_embs_for_pats_idx

In [9]:
# milestone_reduced_tsne, milestone_reduced_mds, milestone_labels, milestone_uniq_idx = get_reduced_embs(patterns, 'basis 5, kb basis 11', milestone_data, True)

# baseline_reduced_tsne, baseline_reduced_mds, baseline_labels, baseline_uniq_idx = get_reduced_embs(patterns, 'basis 1, kb basis 1', baseline_data, std=1)

In [10]:
def set_font_size(small, medium, large):
    """Set matplotlib's global rc font sizes from three size levels.

    ``small`` is used for default text, axes titles, tick labels and the
    legend; ``medium`` for the x/y axis labels; ``large`` for the figure
    title.
    """
    rc_table = (
        ('font', 'size', small),           # default text sizes
        ('axes', 'titlesize', small),      # axes title
        ('axes', 'labelsize', medium),     # x and y labels
        ('xtick', 'labelsize', small),     # x tick labels
        ('ytick', 'labelsize', small),     # y tick labels
        ('legend', 'fontsize', small),     # legend
        ('figure', 'titlesize', large),    # figure title
    )
    for group, option, value in rc_table:
        plt.rc(group, **{option: value})

In [11]:
def plot_basis_and_eps(patterns,
                       is_kb,
                       reduced,
                       target_embs_for_pats_labels,
                       target_embs_for_pats_idx_uniq,
                       figsize,
                       n_basis,
                       n_basis_kb,
                       colors,
                       file,
                       label,
                       has_ep_labels,
                       eps_to_label=None,
                       markersize=60,
                       expand_points=(1.25, 1.5),
                       force_points=(1.2, 1.2)
                      ):
    """Scatter-plot 2-D basis points and entity pairs and save the figure.

    Args:
        patterns: pattern strings; the last six characters (the ' <eos>'
            suffix on this notebook's patterns) are cut off for the legend.
        is_kb: per-pattern flags; True plots ``n_basis_kb`` points for that
            pattern, False plots ``n_basis``.
        reduced: 2-D coordinates. Every pattern owns ``n_basis_kb``
            consecutive rows; all rows after the last pattern are entity
            pairs.
        target_embs_for_pats_labels: entity-pair labels (not deduplicated).
        target_embs_for_pats_idx_uniq: indices into the labels, one per
            entity-pair row of ``reduced``.
        figsize: matplotlib figure size.
        n_basis, n_basis_kb: number of points drawn per pattern / KB relation.
        colors: one colour per pattern.
        file: output path passed to ``plt.savefig``.
        label: figure title.
        has_ep_labels: when True, annotate every plotted entity pair.
        eps_to_label: labels of the entity pairs to plot. If None or if none
            of them is present, ALL entity pairs are plotted.
        markersize: marker area for both scatter calls.
        expand_points, force_points: forwarded to ``adjust_text``.
    """
    plt.figure(figsize=figsize)
    set_font_size(15, 15, 15)

    start = 0
    for i, pat in enumerate(patterns):
        blen = n_basis_kb if is_kb[i] else n_basis
        # A pattern's block is always n_basis_kb rows long; only the first
        # ``blen`` rows of it are drawn.
        rows = reduced[start:start+n_basis_kb][:blen]
        # '\\$' is the same runtime text the original produced with '\$'
        # (backslash + dollar) without relying on an invalid escape sequence;
        # the backslash keeps matplotlib from treating '$' as a math delimiter.
        plt.scatter(x=rows[:, 0],
                    y=rows[:, 1],
                    color=colors[i], marker='o', s=markersize,
                    label=pat[:-6].replace('$', '\\$'))
        start += n_basis_kb

    eps_to_label = set() if eps_to_label is None else set(eps_to_label)
    target_embs_for_pats_labels_uniq = np.take(target_embs_for_pats_labels,
                                               target_embs_for_pats_idx_uniq)

    indices = [i for i, ep_label in enumerate(target_embs_for_pats_labels_uniq)
               if ep_label in eps_to_label]
    if len(indices) == 0:
        # Nothing requested (or nothing matched): fall back to every pair.
        indices = list(range(len(target_embs_for_pats_labels_uniq)))
    # Structured array so the points can be ordered by (y, x) below.
    # Builtin float replaces np.float, which NumPy 1.24 removed.
    pruned_eps = np.array([tuple(ep) for ep in np.take(reduced[start:], indices, axis=0)],
                          dtype=[('x', float), ('y', float)])
    pruned_labels = np.take(target_embs_for_pats_labels_uniq, indices, axis=0)
    sorted_indices = np.argsort(pruned_eps, axis=0, order=['y', 'x'])
    plt.scatter(x=pruned_eps['x'], y=pruned_eps['y'], color='purple',
                marker='x', s=markersize, label='Entity Pairs')
    print("Number of pruned entity pairs", len(pruned_eps))

    # Requested labels that could not be found among the entity pairs.
    print(eps_to_label.difference(set(pruned_labels)))

    if has_ep_labels:
        def simplify_text(s):
            """Turn 'E_SLUG_Some_Name_langEN' into 'Some Name'; leave other text alone."""
            if s.startswith("E_SLUG_") and s.endswith("_langEN"):
                return s[len("E_SLUG_"):-len("_langEN")].replace('_', ' ')
            return s

        texts = []
        for index in sorted_indices:
            point = pruned_eps[index]
            entities = [simplify_text(entity) for entity in pruned_labels[index].split('\t')]
            text = '("' + '", "'.join(entities) + '")'
            texts.append(plt.text(point['x'], point['y'], text))
        adjust_text(texts, expand_points=expand_points, force_points=force_points,
                    arrowprops=dict(arrowstyle="-|>", color='teal', lw=0.8),
                    save_steps=False)

    plt.xlabel('x')
    plt.ylabel('y')
    plt.title(label)
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()
    plt.savefig(file, dpi=300)
    plt.close()

In [12]:
# Entity pairs to draw and annotate in the milestone figure. Each label is
# two entities joined by a tab, the format matched against the closest-pair
# labels in plot_basis_and_eps.
milestone_eps_to_plot = ["\t".join(pair) for pair in (
    ("Caltech", "Pasadena , California"),
    ("Soochow University", "Taipei City"),
    ("Martin Schempp", "Stuttgart"),
    ("Paul Dini", "New York City"),
    ("E_SLUG_International_Olympic_Committee_langEN", "2007"),
    ("E_SLUG_Marc_Ravalomanana_langEN", "2002"),
    ("Tenzing Norgay", "1986"),
    ("Coretta Scott King", "January 2006"),
    ("Interlake High School", "Bellevue"),
    ("Southern High School", "Durham"),
    ("Eastern Kentucky University", "Richmond"),
    ("Oakland University", "Rochester , Michigan"),
    ("World Tourism Organization", "Madrid"),
    ("E_SLUG_Lockheed_Martin_langEN", "Bethesda"),
    ("Iomega", "San Diego"),
    ("McDonald's", "Oak Brook"),
    ("Gallas", "Arsenal"),
    ("Malouda", "Chelsea"),
    ("Kolo Toure", "Arsenal"),
    ("Arsenal", "English Premier League"),
    ("Stephen Roach", "Morgan Stanley"),
    ("Commodore Voreqe Bainimarama", "E_SLUG_Fiji_langEN"),
    ("Derek Abbott", "University of Adelaide"),
    ("Robert F. Goheen", "Princeton University"),
    ("E_SLUG_Arabinda_Rajkhowa_langEN", "E_SLUG_United_Liberation_Front_of_Assam_langEN"),
)]

# Same, for the baseline figure.
baseline_eps_to_plot = ["\t".join(pair) for pair in (
    ("E_SLUG_Bowe_Bergdahl_langEN", "E_SLUG_United_States_langEN"),
    ("Interlake High School", "Bellevue"),
    ("Southern High School", "Durham"),
    ("IOC", "Lausanne"),
    ("Samsung", "Seoul"),
    ("Peter Moore", "British Embassy"),
)]

In [14]:
# Render the MDS figure for the milestone and the baseline run once per seed
# and save each as a PDF under ./figs.
current_dir = os.getcwd()
imgdir = os.path.join(current_dir, 'figs')
# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs(imgdir, exist_ok=True)

# Shared by both plots: only the third pattern is a KB relation, and every
# pattern keeps the same colour in the milestone and baseline figures.
is_kb = [False, False, True, False]
pattern_colors = ['r', 'g', 'b', 'orange']

random_seeds = [32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384]
for random_seed in random_seeds:
    print("\nRandom seed: {}\n".format(random_seed))

    # The t-SNE projection (first return value) is not used for these figures.
    _, milestone_reduced_mds, milestone_labels, milestone_uniq_idx = get_reduced_embs(
        patterns, 'basis 5, kb basis 11', milestone_data, use_random=True, seed=random_seed)

    _, baseline_reduced_mds, baseline_labels, baseline_uniq_idx = get_reduced_embs(
        patterns, 'basis 1, kb basis 1', baseline_data, use_random=True, seed=random_seed)

    milestone_file = os.path.join(imgdir, 'milestone_mds_with_eplabels_{}.pdf'.format(random_seed))
    plot_basis_and_eps(patterns, is_kb, milestone_reduced_mds, milestone_labels,
                       milestone_uniq_idx, (18, 12), 5, 11, pattern_colors,
                       milestone_file, 'Pattern Basis=5, KB Relation Basis=11',
                       True, eps_to_label=milestone_eps_to_plot)

    baseline_file = os.path.join(imgdir, 'baseline_mds_with_eplabels_{}.pdf'.format(random_seed))
    plot_basis_and_eps(patterns, is_kb, baseline_reduced_mds, baseline_labels,
                       baseline_uniq_idx, (12, 10), 1, 1, pattern_colors,
                       baseline_file, 'Pattern Basis=1, KB Relation Basis=1',
                       True, eps_to_label=baseline_eps_to_plot,
                       expand_points=(1.25, 1.5), force_points=(1.2, 3))


Random seed: 32

Basis predictions: (44, 100)
Unique entity pairs: (64, 100)
Combined: (108, 100)
Reduced: tsne-: (108, 2) mds-: (108, 2)
Basis predictions: (4, 100)
Unique entity pairs: (15, 100)
Combined: (19, 100)
Reduced: tsne-: (19, 2) mds-: (19, 2)
Number of pruned entity pairs 25
set()
Number of pruned entity pairs 6
set()

Random seed: 64

Basis predictions: (44, 100)
Unique entity pairs: (64, 100)
Combined: (108, 100)
Reduced: tsne-: (108, 2) mds-: (108, 2)
Basis predictions: (4, 100)
Unique entity pairs: (15, 100)
Combined: (19, 100)
Reduced: tsne-: (19, 2) mds-: (19, 2)
Number of pruned entity pairs 25
set()
Number of pruned entity pairs 6
set()

Random seed: 128

Basis predictions: (44, 100)
Unique entity pairs: (64, 100)
Combined: (108, 100)
Reduced: tsne-: (108, 2) mds-: (108, 2)
Basis predictions: (4, 100)
Unique entity pairs: (15, 100)
Combined: (19, 100)
Reduced: tsne-: (19, 2) mds-: (19, 2)
Number of pruned entity pairs 25
set()
Number of pruned entity pairs 6
set()


In [79]:
# plot_basis_and_eps(patterns, [False, False, True, False], milestone_reduced_mds, milestone_labels, \
#                    milestone_uniq_idx, (20, 15), 5, 11,['r', 'g', 'b', 'orange'], \
#                    'milestone_mds_with_eplabels.png', 'b5kb11-mds-with-eplabels', \
#                    True, eps_to_label=milestone_eps_to_plot)

In [80]:
# plot_basis_and_eps(patterns, [False, False, True, False], milestone_reduced_mds, milestone_labels, \
#                    milestone_uniq_idx, (20, 15), 5, 11, \
#                    ['r', 'g', 'b', 'orange'], 'milestone_mds_without_eplabels.png', \
#                    'b5kb11-mds-without-eplabels', False, eps_to_label=milestone_eps_to_plot)

In [None]:
# plot_basis_and_eps(patterns, [False, False, True, False], baseline_reduced_mds, baseline_labels, \
#                    baseline_uniq_idx, (20, 15), 1, 1, \
#                    ['r', 'g', 'b', 'orange'], 'baseline_mds_with_eplabels.png', \
#                    'b1kb1-mds-with-eplabels', True, eps_to_label=baseline_eps_to_plot)

In [None]:
# plot_basis_and_eps(patterns, [False, False, True, False], milestone_reduced_tsne, milestone_labels, milestone_uniq_idx, (9, 8), 5, 11, \
#                    ['r', 'g', 'b', 'orange'], 'milestone_tsne_with_eplabels.png', \
#                    'b5kb11-tsne-with-eplabels', True, label_every=1, eps_to_label=milestone_eps_to_plot)

# plot_basis_and_eps(patterns, [False, False, True, False], milestone_reduced_tsne, milestone_labels, milestone_uniq_idx, (9, 8), 5, 11, \
#                    ['r', 'g', 'b', 'orange'], 'milestone_tsne_without_eplabels.png', \
#                    'b5kb11-tsne-without-eplabels', False, label_every=1, eps_to_label=milestone_eps_to_plot)

# plot_basis_and_eps(patterns, [False, False, True, False], baseline_reduced_tsne, baseline_labels, \
#                    baseline_uniq_idx, (18, 12), 1, 1, \
#                    ['r', 'g', 'b', 'orange'], 'baseline_tsne_with_eplabels.png', \
#                    'b1kb1-tsne-with-eplabels', True, eps_to_label=baseline_eps_to_plot)