In [1]:
import datetime
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from numba import jit
from scipy.spatial.distance import pdist, squareform
from scipy.optimize import minimize, root_scalar
from pynndescent import NNDescent
from scipy.sparse import csr_matrix

In [2]:
@jit(cache=False)
def entropy(d, beta):
    # remove 'cache=False' (only for testing)
    x = - d * beta
    y = np.exp(x)
    ysum = y.sum()
    if ysum < 1e-50:
        return -1.
    else:
        factor = - 1/(np.log(2.) * ysum)
        return factor * ((y * x) - (y * np.log(ysum))).sum()

In [3]:
def p_i(d, beta):
    x = - d * beta
    y = np.exp(x)
    ysum = y.sum()
    return y / ysum

In [4]:
def find_beta(d, perp, upper_bound=1e6):
    return root_scalar(
        lambda b: entropy(d, b) - np.log2(perp),
        bracket=(0.,upper_bound)
    ).root

In [51]:
# @jit
# def my_cdist(x, y):
#     dists = np.empty(len(y), dtype=np.float_)
#     for i in range(len(y)):
#         dists[i] = np.sum((x - y[i])**2)            
#     return dists

In [7]:
def my_cdist(x, y):
    return np.sum((x-y)**2, axis=1)

In [67]:
def p_ij_sym(x, perp, verbose=False):
    num_pts = x.shape[0]
    k = min(num_pts - 1, int(3 * perp))
    index = NNDescent(x)
    neighbors = np.empty((num_pts, k-1), dtype=np.int)
    p_ij = np.empty((num_pts, k-1))
    for i, xi in enumerate(x):
        if verbose:
            print('Calculating probabilities: {cur}/{tot}'.format(
                cur=i, tot=num_pts), end='\r')
        nn, dists = index.query([xi], k)
        beta = find_beta(dists[0, 1:], perp)
        neighbors[i] = nn[0, 1:]
        p_ij[i] = p_i(dists[0, 1:], beta)
    row_indices = np.repeat(np.arange(num_pts), k-1)
    p = csr_matrix((p_ij.ravel(), (row_indices, neighbors.ravel())))
    return (p + p.transpose()) / (2)

In [62]:
from torchvision.datasets import MNIST
import torchvision.transforms as transforms

In [63]:
mnist = MNIST(
    './data',
    train=True,
    download=True,
    transform=transforms.ToTensor()
)

In [64]:
testdata = mnist.data.reshape(-1, 28*28).numpy() / 255.

In [68]:
%time foo = p_ij_sym(testdata[:10000], 100., verbose=True)

Wall time: 1min 55slities: 9999/10000
