In [None]:
import numpy as np
import torch

import sys
sys.path.append('..')
from bliss.data import Language

import matplotlib.pyplot as plt
%matplotlib inline
import gudhi
import time
DATA_PATH='../muse_data' #Data path to the wiki word vectors.

In [None]:
def distance_matrix(embed, n=5000):
    """Pairwise Euclidean distances between the first ``n`` embedding rows.

    The rows are expected to be unit-normalised (the notebook loads the
    languages with ``unit_norm=True``). Under that assumption
    ``||a - b||^2 = 2 - 2 <a, b>``, so the whole matrix follows from a single
    Gram-matrix product.

    Args:
        embed: 2-D torch tensor, or anything ``torch.as_tensor`` accepts
            (e.g. a numpy array), of shape (N, M) with one embedding per row.
            Rows are assumed to be sorted by word frequency, so ``embed[:n]``
            are the ``n`` most frequent words.
        n: Number of leading rows to use. If ``n > N``, all ``N`` rows are used.

    Returns:
        numpy array of shape (min(n, N), min(n, N)) holding the pairwise
        Euclidean distances, with exact zeros on the diagonal.
    """
    embed = torch.as_tensor(embed)[:n]
    gram = torch.mm(embed, torch.t(embed))
    # Clamp the cosine similarities to [-1, 1] so rounding error can never
    # push the argument of sqrt below zero.
    dist = torch.sqrt(2 - 2 * torch.clamp(gram, -1., 1.))
    # A point's distance to itself is exactly 0; rounding in the Gram product
    # could otherwise leave tiny positive values on the diagonal.
    dist.fill_diagonal_(0.)
    return dist.cpu().numpy()

In [None]:
def compute_diagram(x, homo_dim=1, max_edge_length=float('inf')):
    """Persistence diagrams of the Vietoris-Rips filtration of a distance matrix.

    Args:
        x: Square (or lower-triangular) matrix of pairwise distances between
            word embeddings, e.g. the output of ``distance_matrix``.
        homo_dim: Number of homology dimensions to return. Diagrams are
            returned for dimensions ``0 .. homo_dim - 1``. The simplex tree is
            expanded up to dimension ``homo_dim`` because ``(k+1)``-simplices
            are required to kill ``k``-dimensional cycles.
        max_edge_length: Rips threshold; longer edges are not added. The
            default (infinity) keeps every edge, which is also GUDHI's default.

    Returns:
        List of ``homo_dim`` persistence-interval collections; element ``k``
        holds the (birth, death) intervals for homology dimension ``k``.
    """
    # Pass the matrix by keyword. A bare positional argument is interpreted as
    # ``points``, which treats each row of the distance matrix as a coordinate
    # vector. GUDHI versions with keyword-only parameters reject it outright.
    rips = gudhi.RipsComplex(distance_matrix=x, max_edge_length=max_edge_length)
    rips_tree = rips.create_simplex_tree(max_dimension=homo_dim)
    # persistence() must run before per-dimension intervals can be queried.
    rips_tree.persistence()
    return [rips_tree.persistence_intervals_in_dimension(dim) for dim in range(homo_dim)]

In [None]:
def compute_distance(x, y, homo_dim=1):
    """Bottleneck distance between the Rips persistence diagrams of two distance matrices.

    Diagrams come from ``compute_diagram`` for homology dimensions
    ``0 .. homo_dim - 1``. The bottleneck distance is taken per dimension and
    the smallest of these values is returned. The wall-clock time spent
    building both filtrations is printed.

    Args:
        x, y: Pairwise distance matrices (e.g. from ``distance_matrix``).
        homo_dim: Number of homology dimensions to compare; must be >= 1.

    Returns:
        Minimum over the compared dimensions of the exact (``e=0``)
        bottleneck distance.

    Raises:
        ValueError: if ``homo_dim < 1``, since there would be no diagrams to compare.
    """
    # Fail fast, before the expensive diagram computation.
    if homo_dim < 1:
        raise ValueError("homo_dim must be >= 1, got %r" % (homo_dim,))
    start_time = time.time()
    diag_x = compute_diagram(x, homo_dim=homo_dim)
    diag_y = compute_diagram(y, homo_dim=homo_dim)
    print("Filteration graph: %.3f" % (time.time() - start_time))
    # Distinct names keep the per-dimension diagrams from shadowing the
    # distance-matrix arguments x and y.
    return min(gudhi.bottleneck_distance(dx, dy, e=0) for dx, dy in zip(diag_x, diag_y))

In [None]:
# Languages whose wiki word vectors are loaded. Each code must appear only once;
# a duplicate repeats the expensive load and every later per-language step.
langs = ['en', 'es', 'et', 'fi', 'el', 'hu', 'pl', 'tr']
assert len(set(langs)) == len(langs), "duplicate language codes in langs"
l = {}  # language code -> loaded Language object
for lang in langs:
    l[lang] = Language(name=lang, gpu=True, mode='rand', mean_center=True, unit_norm=True)
    l[lang].load('wiki.%s.vec' % lang, DATA_PATH)

In [None]:
# Numbers of most frequent words to keep when computing each bottleneck distance.
r = [5000, 7000]
n = len(langs)
# d[k][lang] -> k x k distance matrix over the k most frequent words of `lang`.
d = {}
for k in r:
    d[k] = {lang: distance_matrix(l[lang].embeddings, n=k) for lang in langs}

In [None]:
# Language pairs to compare. This is a list, not a set, so the pairs are always
# processed and printed in the same order; set iteration order for string
# tuples depends on the per-process hash seed.
pairs = [('en', 'es'), ('en', 'et'), ('en', 'fi'), ('en', 'el'),
         ('en', 'hu'), ('en', 'pl'), ('en', 'tr'), ('et', 'fi')]
for k in r:
    for src, tgt in pairs:
        print('%s-%s for %d points: %.4f' % (src, tgt, k, compute_distance(d[k][src], d[k][tgt])))