In [None]:
import numpy as np
import seaborn as sn
import pandas as pd
import matplotlib.pyplot as plt
from collections import defaultdict

# Load the artifacts written by the training run. Each file is read with
# np.load and unwrapped with .item(); the results are used as dicts below.
# allow_pickle=True is required because NumPy >= 1.16.3 refuses to load
# pickled objects by default.
# NOTE(security): unpickling can execute arbitrary code -- only load files
# that come from a trusted source.
viz_dict = np.load("viz_dict.npy", allow_pickle=True).item()
word2idx = np.load("word2idx.npy", allow_pickle=True).item()
idx2word = np.load("idx2word.npy", allow_pickle=True).item()
idx2title = np.load("movid_to_name_dict.npy", allow_pickle=True).item()

# Keys are later unpacked as (epoch, id1, id2); sorting orders them by epoch first.
sortedKeys = sorted(viz_dict)
print(len(sortedKeys))
print(sortedKeys)


In [None]:
# Re-index viz_dict into the lookup tables used by the plotting cells.
# Each viz_dict value is read positionally:
#   [0]      value stored for the (epoch, id1, id2) pair
#   [1]      value stored for (epoch, id1);  [2], [3] two 2-D points for id1
#   [4]      value stored for (epoch, id2);  [5], [6] two 2-D points for id2
# NOTE(review): [2]/[5] are treated as the lower corner and [3]/[6] as the
# upper corner of a box -- confirm against the code that writes viz_dict.
marginal = {}
conditional = {}
embed = defaultdict(dict)  # embed[epoch][id] -> (lower_corner, upper_corner)
mn, mx = float("-inf"), float("inf")
# Running bounds start at (+inf, -inf, +inf, -inf) so the first real value wins.
embed_min_max = defaultdict(lambda: (mx, mn, mx, mn))  # (xmin, xmax, ymin, ymax)
adjList = defaultdict(set)  # id1 -> set of id2 seen together with it
ids = set()
epochs = -1  # largest epoch seen; stays -1 when there is no data

for k in sortedKeys:
    epoch, id1, id2 = k
    entry = viz_dict[k]
    lo1, hi1 = entry[2], entry[3]
    lo2, hi2 = entry[5], entry[6]

    conditional[(epoch, id1, id2)] = entry[0]
    marginal[(epoch, id1)] = entry[1]
    embed[epoch][id1] = (lo1, hi1)
    marginal[(epoch, id2)] = entry[4]
    embed[epoch][id2] = (lo2, hi2)

    # Grow this epoch's bounds: minima come from the lower corners, maxima
    # from the upper corners (same rule as before, applied to both boxes).
    xmin, xmax, ymin, ymax = embed_min_max[epoch]
    for lo, hi in ((lo1, hi1), (lo2, hi2)):
        xmin, xmax = min(xmin, lo[0]), max(xmax, hi[0])
        ymin, ymax = min(ymin, lo[1]), max(ymax, hi[1])
    embed_min_max[epoch] = (xmin, xmax, ymin, ymax)

    adjList[id1].add(id2)
    ids.add(id1)
    ids.add(id2)
    epochs = max(epochs, epoch)

# Normalize the coordinates of every epoch that actually has data so each
# epoch's bounds map onto [0, 1] x [0, 1]. Iterating over embed's own keys
# (instead of range(epoch + 1) on the leaked loop variable) avoids a NameError
# on empty input and avoids creating sentinel entries for missing epochs.
for ep in list(embed):
    xmin, xmax, ymin, ymax = embed_min_max[ep]
    xdelta, ydelta = xmax - xmin, ymax - ymin
    # A zero span means every coordinate equals the minimum; dividing by 1
    # maps them all to 0 instead of dividing by zero.
    if xdelta == 0:
        xdelta = 1.0
    if ydelta == 0:
        ydelta = 1.0
    embed[ep] = {
        key: (
            ((lo[0] - xmin) / xdelta, (lo[1] - ymin) / ydelta),
            ((hi[0] - xmin) / xdelta, (hi[1] - ymin) / ydelta),
        )
        for key, (lo, hi) in embed[ep].items()
    }
        

In [None]:
# Heatmaps of the stored conditional values, one figure every 50 epochs.
n = len(ids)
# Sort so the row/column order is the same on every run (set order is not
# guaranteed); the ids are already compared when viz_dict's keys are sorted.
ids = sorted(ids)
# plt.cm.get_cmap was removed in Matplotlib 3.9; plt.get_cmap takes the same
# (name, lut) arguments.
my_cmap = plt.get_cmap("Blues", n)
# id -> row/column position in the n x n matrix.
data_idx_map = {node_id: pos for pos, node_id in enumerate(ids)}
# Axis labels: titles are looked up with a +1 offset on the id.
idx = [idx2title[node_id + 1] for node_id in ids]

for e in range(0, epochs+1, 50):
    plt.figure(figsize = (10, 7))
    data = [[0]*n for _ in range(n)]
    for node_id in ids:
        row = data_idx_map[node_id]
        # The diagonal is fixed at 1; the marginal is not plotted here.
        data[row][row] = 1
        # adjList and conditional are keyed by id, so look them up with the
        # id itself and use the mapped position only to address the matrix.
        # .get avoids inserting empty sets into the defaultdict.
        for nbr in adjList.get(node_id, ()):
            if (e, node_id, nbr) in conditional:
                data[row][data_idx_map[nbr]] = conditional[(e, node_id, nbr)]
    df_cm = pd.DataFrame(data, index = idx, columns = idx)
    sn.heatmap(df_cm, cmap=my_cmap, annot=True)
    plt.title("$Prob(id_y|id_x)$, Epoch:%d"%e, fontsize=15)
    plt.xlabel(r"$id_x$", fontsize=15)
    plt.ylabel(r"$id_y$", fontsize=15)
    plt.tight_layout()
    plt.show()


In [None]:

# Draw each id's normalized box as a translucent rectangle, one figure every
# 50 epochs.
# plt.cm.get_cmap was removed in Matplotlib 3.9; plt.get_cmap takes the same
# (name, lut) arguments.
my_cmap = plt.get_cmap("hsv", n)
for e in range(0, epochs+1, 50):
    plt.figure(figsize = (10, 7))
    ax = plt.gca()
    for idx in embed[e]:
        bottom, top = embed[e][idx]
        x1, y1, x2, y2 = bottom[0], bottom[1], top[0], top[1]
        # Pick the colour by the id's position (0..n-1) in the n-entry
        # colormap; indexing with the raw id clips to the last colour for any
        # id >= n.
        ax.add_patch(plt.Rectangle((x1, y1), x2-x1, y2-y1, fill=True,
                                   facecolor=my_cmap(data_idx_map[idx]),
                                   alpha=0.5, label=idx2title[idx+1]))
    # Set once per figure rather than once per rectangle.
    plt.title("Epoch: %d"%e)
    plt.legend(loc=0, bbox_to_anchor=(1, 1), fontsize=15)
    plt.tight_layout()
    plt.show()

In [None]:
# Read the tab-separated training file into grnd_cpd, keyed by the word2idx
# indices of the tokens in columns 1 and 2, with column 3 parsed as a float.
path = "path to movie train file"  # prefix; concatenated directly with the file name
grnd_cpd = {}
with open(path+"movie_train.txt", "r") as f:
    for raw in f:
        fields = raw.split("\t")
        # Column 0 is not used here.
        tok1, tok2, prob = fields[1], fields[2], fields[3]
        pair = (word2idx[tok1], word2idx[tok2])
        grnd_cpd[pair] = float(prob)


In [None]:
# Heatmap of the values loaded into grnd_cpd, laid out with the same
# id -> position mapping (data_idx_map) as the per-epoch heatmaps.
plt.figure(figsize = (10, 7))
data = [[0]*n for _ in range(n)]

for node_id in ids:
    # Address the matrix by mapped position; the raw id is only valid as a
    # lookup key for adjList and grnd_cpd.
    row = data_idx_map[node_id]
    data[row][row] = 1
    # .get avoids inserting empty sets into the defaultdict.
    for nbr in adjList.get(node_id, ()):
        if (node_id, nbr) in grnd_cpd:
            data[row][data_idx_map[nbr]] = grnd_cpd[(node_id, nbr)]
# Axis labels: titles are looked up with a +1 offset on the id.
idx = [idx2title[node_id + 1] for node_id in ids]
df_cm = pd.DataFrame(data, index = idx, columns = idx)
sn.heatmap(df_cm, annot=True)
plt.title("$Prob(id_y|id_x)$,  Ground Truth", fontsize=15)
plt.xlabel(r"$id_x$", fontsize=15)
plt.ylabel(r"$id_y$", fontsize=15)
plt.tight_layout()
plt.show()

