In [2]:
import pickle
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from os.path import join as joinpath

# %matplotlib inline
%matplotlib qt
# Pickled list of background predicate names for the GQA data.
# NOTE: pickle.load can execute arbitrary code; only open trusted files here.
pred_path = './Data/gqa/bg_pred.pickle'
with open(pred_path, 'rb') as pred_file:
    bg_pred_ls = pickle.load(pred_file)


In [3]:
# Reverse lookup: predicate name -> its position in bg_pred_ls.
pred2idx = {pred: idx for idx, pred in enumerate(bg_pred_ls)}

In [4]:
def plot_dist_mat(embs, emb_name):
    """Show, for each embedding, how every embedding ranks by distance to it.

    Pairwise Euclidean distances between the rows of ``embs`` are computed and
    each row of the distance matrix is turned into ranks: the column holding
    the largest distance in a row gets rank 0, the next largest rank 1, and so
    on. The resulting integer rank matrix is drawn with ``matshow`` together
    with a colorbar, using ``emb_name`` as tick labels on both axes.

    Args:
        embs: 2-D floating-point tensor with one embedding per row.
        emb_name: sequence of labels, one per row of ``embs``.

    Raises:
        AssertionError: if ``len(emb_name)`` differs from the number of rows.
    """
    cnt_emb, _ = embs.shape
    assert len(emb_name) == cnt_emb

    # Vectorized pairwise L2 distances. The direct (non-matmul) mode keeps the
    # diagonal exactly zero, matching a per-pair norm of the difference.
    dist = torch.cdist(
        embs.detach(), embs.detach(),
        compute_mode='donot_use_mm_for_euclid_dist',
    )

    # rank_id[row] lists column indices from farthest to nearest. The argsort
    # of that permutation is its inverse, i.e. the rank of every column.
    rank_id = torch.argsort(dist, dim=1, descending=True)
    order = torch.argsort(rank_id, dim=1)

    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(order)
    fig.colorbar(cax)

    plt.xticks(range(cnt_emb), emb_name)
    plt.yticks(range(cnt_emb), emb_name)
    plt.show()

In [5]:
def get_closest_pairs(embs, emb_name, cnt, show_dist=False, dist_mode='cosine', norm='none'):
    """Print the ``cnt`` closest pairs of embeddings.

    Every unordered pair of distinct rows is scored once, using the entry
    ``[i, j]`` with ``i > j`` of the pairwise score matrix.

    Args:
        embs: 2-D floating-point tensor with one embedding per row.
        emb_name: sequence of labels, one per row of ``embs``.
        cnt: maximum number of pairs to print.
        show_dist: when True, print ``(score, (name_i, name_j))`` tuples;
            otherwise print only the ``(name_i, name_j)`` tuples.
        dist_mode: ``'cosine'`` scores pairs by cosine similarity (larger is
            closer); any other value scores them by Euclidean distance
            (smaller is closer).
        norm: ``'l2'`` L2-normalizes each row of the score matrix before the
            pairs are ranked; any other value leaves the scores unchanged.

    Raises:
        ValueError: if ``len(emb_name)`` differs from the number of rows.
    """
    cnt_emb, _ = embs.shape
    if len(emb_name) != cnt_emb:
        raise ValueError(
            'emb_name has %d entries but embs has %d rows' % (len(emb_name), cnt_emb)
        )

    embs = embs.detach()
    use_cosine = dist_mode == 'cosine'
    if use_cosine:
        # Cosine similarity of all pairs: dot products of unit-length rows.
        unit = nn.functional.normalize(embs, p=2, dim=1)
        score = unit @ unit.T
    else:
        score = torch.cdist(embs, embs, compute_mode='donot_use_mm_for_euclid_dist')

    if norm == 'l2':
        score = nn.functional.normalize(score, p=2, dim=1)

    rows = score.tolist()
    pairs = [
        (rows[i][j], (emb_name[i], emb_name[j]))
        for i in range(cnt_emb)
        for j in range(i)
    ]
    # Similarities rank best-first in descending order, distances in ascending
    # order; sorting distances descending would list the farthest pairs.
    pairs.sort(key=lambda item: item[0], reverse=use_cosine)

    cnt = min(cnt, len(pairs))
    if show_dist:
        print(pairs[:cnt])
    else:
        print([names for _, names in pairs[:cnt]])

In [6]:
# Pretrained GPT-2 provides the tokenizer and the token-embedding table (wte).
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
nlp_model = GPT2LMHeadModel.from_pretrained('gpt2')  # or any other checkpoint

# Probe one word to read off the embedding width.
temp_index = tokenizer.encode('car', add_prefix_space=True)
temp_feat = nlp_model.transformer.wte.weight[temp_index, :].shape[1]

wn_pred_ls = []    # predicates kept: those the tokenizer does not split into several tokens
wn_pred_id = []    # position of each kept predicate in bg_pred_ls
wn_embs = []       # one embedding-table slice per kept predicate
skip_pred_ls = []  # predicates split into more than one token
for i, pred in enumerate(bg_pred_ls):
    # TODO: embedding pooling
    wn_idx = tokenizer.encode(pred, add_prefix_space=True)
    if len(wn_idx) <= 1:
        wn_pred_ls.append(pred)
        wn_pred_id.append(i)

        t_emb = nlp_model.transformer.wte.weight[wn_idx, :]
        wn_embs.append(t_emb)
    else:
        skip_pred_ls.append(pred)

# Stack the per-predicate slices into one 2-D tensor.
wn_embs = torch.cat(wn_embs, dim=0)
# temp_predicates = PCA(n_components = self.num_feat-2*int(args.add_p0)).fit_transform(temp_predicates.detach().numpy())

In [7]:
# Quick check: number of kept embeddings, per-embedding shape, skipped predicates.
print(wn_embs.shape[0])
print(wn_embs.shape[1:])
print(skip_pred_ls)

191
torch.Size([768])
['cell_phone', 'covered_by', 'donut', 'faucet', 'giraffe', 'hanging_on', 'in_front_of', 'kite', 'license_plate', 'lying_on', 'mane', 'napkin', 'sitting_on', 'skateboard', 'skis', 'standing_in', 'standing_on', 'street_light', 'surfboard', 't-shirt', 'vase', 'walking_on', 'zebra']


In [8]:
# print(wn_embs.shape)
# Plot every kept embedding; take the count from the tensor instead of hard-coding it.
tot = wn_embs.shape[0]
# Row k of wn_embs belongs to wn_pred_ls[k]. bg_pred_ls still contains the
# skipped multi-token predicates, so slicing it would mislabel the axes.
plot_dist_mat(wn_embs[:tot, :].detach(), wn_pred_ls[:tot])


In [9]:
def load_model(model_dir, pred_name, use_gpu):
    """Load the serialized model stored at ``model_dir/pred_name``.

    With ``use_gpu`` truthy the object is loaded with torch's default device
    mapping and then moved with ``.cuda()``; otherwise it is loaded with all
    storages mapped to the CPU. In both cases ``model.args.use_gpu`` is set to
    the matching boolean before the model is returned.

    Note: ``torch.load`` unpickles the file, so only load trusted checkpoints.
    """
    checkpoint = joinpath(model_dir, pred_name)
    if not use_gpu:
        model = torch.load(checkpoint, map_location=torch.device('cpu'))
        model.args.use_gpu = False
        return model

    model = torch.load(checkpoint)
    model.cuda()
    model.args.use_gpu = True
    return model

In [10]:
# Trained model checkpoint directory; the 'person' file inside it is loaded on CPU.
model_dir = './Data/gqa/modelts4_es4_template102_md3_recnone_iterPerRound5_numRound3000_mtmax_filterConstants80_feat30_gpuTrue_lrbg0.001_lri0.01_lrr0.01_tgls0_filterIndirectFalse_embWN_randomIPP5_splitD2'
model = load_model(model_dir, 'person', False)
# Keep only the rows for predicates that also have a GPT-2 embedding.
model_embs = model.embeddings_bgs[wn_pred_id, :]
# Rows were selected with wn_pred_id, so the matching labels are wn_pred_ls.
# bg_pred_ls includes the skipped predicates and would shift the labels.
plot_dist_mat(model_embs[:tot, :].detach(), wn_pred_ls[:tot])

In [11]:
cnt = 20   # number of pairs to list
tot = 10   # number of leading embeddings to compare
tot = min(wn_embs.shape[0], tot)

# First pass prints scores with the pairs, second pass prints names only.
# Labels come from wn_pred_ls, which is row-aligned with both embedding sets;
# bg_pred_ls also holds the skipped predicates and goes out of step once tot
# reaches the first of them.
for show_dist in (True, False):
    for embs in (wn_embs, model_embs):
        get_closest_pairs(embs[:tot, :].detach(), wn_pred_ls[:tot], cnt, show_dist=show_dist)

[(0.5161124467849731, ('bag', 'backpack')), (0.438531756401062, ('arrow', 'apple')), (0.40217262506484985, ('airplane', 'air')), (0.3558805286884308, ('backpack', 'airplane')), (0.3506677746772766, ('apple', 'airplane')), (0.34332045912742615, ('animal', 'airplane')), (0.34238290786743164, ('arrow', 'animal')), (0.3387574851512909, ('at', 'above')), (0.3335779309272766, ('apple', 'animal')), (0.3329431712627411, ('bag', 'airplane')), (0.3198569715023041, ('arrow', 'airplane')), (0.2955162525177002, ('arm', 'airplane')), (0.27776288986206055, ('arm', 'apple')), (0.27376842498779297, ('animal', 'air')), (0.27369213104248047, ('arrow', 'arm')), (0.26726627349853516, ('backpack', 'arm')), (0.2658688724040985, ('bag', 'air')), (0.2648478150367737, ('arrow', 'above')), (0.2648349106311798, ('bag', 'arrow')), (0.261461079120636, ('bag', 'animal'))]
[(0.8734731674194336, ('bag', 'backpack')), (0.8344952464103699, ('bag', 'airplane')), (0.8223341107368469, ('arrow', 'apple')), (0.81214743852615