In [1]:
# MANUAL tSNE on MNIST dataset: 
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.metrics import pairwise_distances

In [2]:
# MNIST training split, stored under ./data; download=True asks torchvision to
# fetch the files when needed. Images are exposed as tensors via ToTensor().
mnist = datasets.MNIST(
    root='data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)


In [None]:
# Data preparation: flatten MNIST, draw a random subset, rescale and centre it.

# load mnist into matrix X: one flattened 28*28 image per row
x = mnist.data.numpy().reshape(-1, 28*28)
# labels of the FULL dataset; index with `idx` (y[idx]) to align them with x
y = mnist.targets.numpy()

# sample n distinct images
n = 20000
np.random.seed(0)  # fixed seed so the same subset is drawn on every run
# replace=False: np.random.choice samples WITH replacement by default, which
# would put duplicate images (zero-distance pairs) into the subset.
# Raises ValueError if n exceeds the number of available images.
idx = np.random.choice(len(x), n, replace=False)
x = x[idx]
print(x.shape)

# prep X: rescale pixel values by 1/255, then make every feature zero-mean
x = x / 255
x = x - x.mean(axis=0)

# Dimensionality reduction: keep the 50 leading principal components of x.
pca = PCA(n_components=50)
x_pca = pca.fit_transform(X=x)

# Pairwise distance matrix D between all rows of the PCA-reduced data.
# The metric is spelled out: these are plain Euclidean distances (NOT squared).
D = pairwise_distances(X=x_pca, metric="euclidean")

# compute similarities p_ij
#
# For every point i, find the Gaussian precision beta_i = 1 / (2 * sigma_i^2)
# for which the conditional distribution
#     p_{j|i} = exp(-beta_i * d_ij^2) / sum_{k != i} exp(-beta_i * d_ik^2)
# has the requested perplexity, i.e. Shannon entropy H = log(perplexity).
# The conditional rows are then symmetrised into the joint distribution P.
perplexity = 30.0                    # effective number of neighbours per point
tol = 1e-5                           # accepted |H - log(perplexity)|, in nats
max_trials = 50                      # cap on search steps per row
target_entropy = np.log(perplexity)


def conditional_probabilities(sq_dists, beta):
    """Gaussian conditional distribution over neighbours for one point.

    Parameters
    ----------
    sq_dists : 1-D array of squared distances to the other points
        (the point itself must already be excluded).
    beta : float, Gaussian precision 1 / (2 * sigma^2).

    Returns
    -------
    (entropy, p) : Shannon entropy of p in nats, and p normalised to sum to 1.
    """
    # Subtracting the smallest distance changes neither p nor its entropy, but
    # pins the largest exponent at 0 so the normaliser can never underflow to 0.
    shifted = sq_dists - sq_dists.min()
    p = np.exp(-shifted * beta)
    total = p.sum()
    entropy = np.log(total) + beta * np.sum(shifted * p) / total
    return entropy, p / total


P = np.zeros((n, n))
for i in range(n):
    # D holds Euclidean distances; the Gaussian kernel needs them squared.
    # The self-distance is dropped so that p_{i|i} stays 0.
    sq_dists = np.delete(D[i], i) ** 2

    # Search for beta(i): double/halve until the target is bracketed, then bisect.
    beta = 0.6
    betamin = -np.inf
    betamax = np.inf
    h, P_i = conditional_probabilities(sq_dists, beta)
    hdiff = h - target_entropy
    trials = 0
    while np.abs(hdiff) > tol and trials < max_trials:
        if hdiff > 0:
            # entropy too high -> kernel too wide -> raise the precision
            betamin = beta
            beta = beta * 2 if np.isinf(betamax) else (beta + betamax) / 2
        else:
            # entropy too low -> kernel too narrow -> lower the precision
            betamax = beta
            beta = beta / 2 if np.isinf(betamin) else (beta + betamin) / 2
        h, P_i = conditional_probabilities(sq_dists, beta)
        hdiff = h - target_entropy
        trials += 1

    # write the row back around the excluded diagonal entry
    P[i, :i] = P_i[:i]
    P[i, i + 1:] = P_i[i:]

# make sure the conditional rows are correct before symmetrising
assert np.allclose(P.sum(axis=1), 1.0), "a conditional row does not sum to 1"
assert not np.diagonal(P).any(), "a diagonal element is non-zero"

# Symmetrise into the joint distribution p_ij = (p_{j|i} + p_{i|j}) / (2n).
P = P + P.T
P = P / P.sum()
# Floor tiny and zero entries (including the diagonal) so that log(P / Q) is finite.
P = np.maximum(P, 1e-12)

# initialize tSNE
#
# Gradient descent with momentum and per-coordinate adaptive gains on the 2-D
# embedding Y. P is expected to be the symmetric joint-probability matrix of
# the high-dimensional data (n x n); Q is its low-dimensional counterpart
# built from a Student-t kernel with one degree of freedom.
max_iter = 400
epsilon = 500             # learning rate
min_gain = 0.01           # lower bound for the adaptive gains
max_gain = 20             # upper bound for the adaptive gains
initial_momentum = 0.5    # momentum during the first iterations
final_momentum = 0.8      # momentum afterwards
momentum_switch_iter = 20
Y = np.random.randn(n, 2)
iY = np.zeros((n, 2))     # velocity (previous update step)
gains = np.ones((n, 2))

# run tSNE
# NOTE: the loop counter must not be reused inside the body; the logging and
# plotting conditions below depend on it.
for iteration in range(max_iter):
    # compute Q: unnormalised Student-t kernel num_ij = 1 / (1 + ||y_i - y_j||^2)
    sum_Y2 = np.sum(Y ** 2, axis=1)
    dist = sum_Y2[:, None] - 2 * (Y @ Y.T) + sum_Y2[None, :]
    num = 1 / (1 + dist)
    np.fill_diagonal(num, 0)              # q_ii is defined to be 0
    Q = np.maximum(num / num.sum(), 1e-12)

    # compute dY: dY_i = sum_j (p_ij - q_ij) * num_ij * (y_i - y_j)
    # (the constant factor 4 of the exact gradient is absorbed into epsilon)
    W = (P - Q) * num
    dY = W.sum(axis=1)[:, None] * Y - W @ Y

    # perform the update
    if iteration < momentum_switch_iter:
        momentum = initial_momentum
    else:
        momentum = final_momentum
    # Shrink the gain where the gradient keeps the sign of the last step,
    # grow it where the sign differs.
    same_sign = np.sign(dY) == np.sign(iY)
    gains = np.where(same_sign, gains * 0.8, gains + 0.2)
    gains = np.clip(gains, min_gain, max_gain)
    iY = momentum * iY - epsilon * (gains * dY)
    Y = Y + iY
    Y = Y - Y.mean(axis=0)                # keep the embedding centred

    # compute the KL divergence
    if (iteration % 10) == 0:
        C = np.sum(P * np.log(P / Q))
        print("Iteration", iteration, "KL divergence is", C)
    if (iteration % 100) == 0:
        # colour each point by the digit label of the sampled image
        plt.scatter(Y[:, 0], Y[:, 1], c=y[idx], cmap='tab10')
        plt.show()


(20000, 784)
