# Heavy-tailed kernels reveal a finer cluster structure in t-SNE visualisations

#### Dmitry Kobak, George Linderman, Stefan Steinerberger, Yuval Kluger, Philipp Berens

##### http://arxiv.org/abs/1902.05804

#### This notebook was written by Dmitry Kobak (Jan--Feb 2019; updated in Mar)

In [1]:
%matplotlib notebook

import numpy as np
import pylab as plt
import seaborn as sns
import pickle
import pandas as pd
import matplotlib

def sns_styleset():
    """Apply the plotting style used for the figures in this notebook.

    Selects seaborn's 'paper' context and 'ticks' style, then overrides a
    set of matplotlib rcParams: 0.75-wide axes frame and major ticks,
    tick lengths of 3 (major) and 2 (minor), and size 7 for all text.
    """
    sns.set_context('paper')
    sns.set_style('ticks')

    overrides = {}
    # Widths of the axes frame and of the major ticks
    for key in ('axes.linewidth', 'xtick.major.width', 'ytick.major.width'):
        overrides[key] = .75
    # Tick lengths on both axes
    for axis in ('xtick', 'ytick'):
        overrides[axis + '.major.size'] = 3
        overrides[axis + '.minor.size'] = 2
    # Same size for body text, titles, axis labels, legend and tick labels
    for key in ('font.size', 'axes.titlesize', 'axes.labelsize',
                'legend.fontsize', 'xtick.labelsize', 'ytick.labelsize'):
        overrides[key] = 7

    matplotlib.rcParams.update(overrides)

sns_styleset()

# Version 1.1.0
import sys; sys.path.append('/home/localadmin/github/FIt-SNE')
from fast_tsne import fast_tsne

col = np.array(['#a6cee3','#1f78b4','#b2df8a','#33a02c','#fb9a99',
                '#e31a1c','#fdbf6f','#ff7f00','#cab2d6','#6a3d9a', '#ffff99'])

## Toy simulations

### Toy example 1

In [None]:
n = 100  # sample size per class
p = 10   # dimensionality
k = 10   # number of classes
d = 4    # distance between each class mean and 0

# Standard normal noise for all k*n points; rows are stored class by class
# (n consecutive rows per class).
np.random.seed(42)
X = np.random.randn(k*n, p)

# Shift class c by d along coordinate axis c: row j belongs to class j // n.
rows = np.arange(k*n)
X[rows, rows // n] += d

In [94]:
%%time

dfs = [100, 1, .5]   # alpha values to use

Z_toy1 = []
for df in dfs:
    Z = fast_tsne(X, perplexity=50, seed=42, df=df, theta=0)
    Z_toy1.append(Z)

CPU times: user 12 ms, sys: 64 ms, total: 76 ms
Wall time: 2min 41s


In [110]:
# Figure for toy example 1: one panel per embedding in Z_toy1.
fig = plt.figure(figsize=(4.8, 2))
# Three panels placed side by side, each .33 wide and .9 tall
ax = [plt.axes([left, 0, .33, .9]) for left in (0, .33, .66)]

# Rows are stored class by class (n per class), so row j has class j // n
classes = np.arange(n*k) // n

for i,Z in enumerate(Z_toy1):
    panel = ax[i]
    plt.sca(panel)
    plt.axis('equal', adjustable='box')
    plt.scatter(Z[:,0], Z[:,1], s=.2, c=col[classes])
    plt.title(r'$\alpha={}$'.format(dfs[i]))
    panel.get_xaxis().set_visible(False)
    panel.get_yaxis().set_visible(False)

sns.despine(left=True, bottom=True)
plt.savefig('figures/fig-toy1.pdf')

# Scale bars: left end (hpos) and length (hwidth) per panel, in data units
hpos = [2, 15, 40]
hwidth = [1, 10, 10]

for i in range(3):
    plt.sca(ax[i])
    # Freeze the limits so that the bar and its label do not rescale the panel
    ax[i].autoscale(False)
    ybottom, ytop = plt.ylim()
    yspan = ytop - ybottom
    bar_start = hpos[i]
    bar_end = hpos[i] + hwidth[i]
    # Bar at 15% of the panel height, its length label at 8%
    plt.plot([bar_start, bar_end], [ybottom + yspan * .15]*2, 'k', linewidth=1)
    plt.text(bar_start + hwidth[i]/2, ybottom + yspan * .08, str(hwidth[i]), ha='center')

# Panel letters, positioned in each panel's own axes coordinates
for letter, panel in zip('ABC', plt.gcf().get_axes()):
    plt.text(0, 1.05, letter, transform = panel.transAxes, fontsize=8, fontweight='bold')

sns.despine(left=True, bottom=True)
plt.savefig('figures/fig-toy1.pdf')

<IPython.core.display.Javascript object>

### Toy example 2

In [190]:
n = 100  # sample size per class
p = 20   # dimensionality
k = 10   # number of classes
d = 4    # distance between each class mean and 0
dw = 4   # distance between the two sub-clusters of each dumbbell

np.random.seed(42)
X = np.random.randn(k*n, p)

# Class i is shifted by d along axis i; along axis k+i its first half is
# moved by +dw/2 and its second half by -dw/2, which creates the dumbbell.
half = int(n/2)
for i in range(k):
    first, middle, last = i*n, i*n + half, (i+1)*n
    X[first:last, i] += d
    X[first:middle, k+i] += dw/2
    X[middle:last, k+i] -= dw/2

In [111]:
%%time

dfs = [100, 1, .5]   # alpha values to use

Z_toy2 = []
for df in dfs:
    Z = fast_tsne(X, perplexity=50, seed=42, df=df, theta=0)
    Z_toy2.append(Z)

CPU times: user 0 ns, sys: 60 ms, total: 60 ms
Wall time: 2min 25s


In [112]:
fig = plt.figure(figsize=(4.8, 2))
ax = []
ax.append(plt.axes([0,   0, .33, .9]))
ax.append(plt.axes([.33, 0, .33, .9]))
ax.append(plt.axes([.66, 0, .33, .9]))

for i,Z in enumerate(Z_toy2):
    plt.sca(ax[i])
    plt.axis('equal', adjustable='box')
    plt.scatter(Z[:,0], Z[:,1], s=.2, c=col[np.floor(np.arange(n*k)/n).astype(int)])
    plt.title(r'$\alpha={}$'.format(dfs[i]))
    plt.gca().get_xaxis().set_visible(False)
    plt.gca().get_yaxis().set_visible(False)

sns.despine(left=True, bottom=True)
plt.savefig('figures/fig-toy2.pdf')

hpos = [1.5, 13, 40]
hwidth = [1, 10, 10]

for i in range(3):
    plt.sca(ax[i])
    ax[i].autoscale(False)
    yl = plt.ylim()
    vpos      = yl[0] + (yl[1]-yl[0]) * .15
    vposlabel = yl[0] + (yl[1]-yl[0]) * .08
    plt.plot([hpos[i], hpos[i]+hwidth[i]], [vpos,vpos], 'k', linewidth=1)
    plt.text(hpos[i] + hwidth[i]/2, vposlabel, str(hwidth[i]), ha='center')
    
plt.text(0, 1.05, 'A', transform = plt.gcf().get_axes()[0].transAxes, fontsize=8, fontweight='bold')
plt.text(0, 1.05, 'B', transform = plt.gcf().get_axes()[1].transAxes, fontsize=8, fontweight='bold')
plt.text(0, 1.05, 'C', transform = plt.gcf().get_axes()[2].transAxes, fontsize=8, fontweight='bold')
    
plt.savefig('figures/fig-toy2.pdf')

<IPython.core.display.Javascript object>

#### Supplementary Figure -- what happens when alpha is very low

In [191]:
%%time 

Z1 = fast_tsne(X, perplexity=50, seed=42, df=.1, theta=0)
Z2 = fast_tsne(X, perplexity=50, seed=42, df=.1, theta=0, max_iter=5000)
Z_toy2_cont = [Z1,Z2,Z2,Z2]

CPU times: user 16 ms, sys: 196 ms, total: 212 ms
Wall time: 6min 36s


In [195]:
titles = [r'$\alpha=0.1$'+'\n1000 iterations (default)', r'$\alpha=0.1$'+'\n5000 iterations',
          r'$\alpha=0.1$'+'\n5000 iterations\nZoom-in without outliers',
          r'$\alpha=0.1$'+'\n5000 iterations\nZoom-in on one "dumbbell"']

zoomins = [None, None, [-120,120,-120,120], [76.1,76.7,31.35,31.95]]

fig = plt.figure(figsize=(4.8, 4.8))
for i,Z in enumerate(Z_toy2_cont):
    plt.subplot(2,2,i+1)
    plt.axis('equal', adjustable='box')
    plt.scatter(Z[:,0], Z[:,1], s=5, c=col[np.floor(np.arange(n*k)/n).astype(int)], alpha=.5, edgecolors='none')
    plt.title(titles[i])
    if zoomins[i] is not None:
        plt.axis(zoomins[i])

sns.despine()
plt.tight_layout()
plt.savefig('figures/suppl-fig-lowalpha.pdf')

<IPython.core.display.Javascript object>

### Toy example 3 (separation as a function of alpha)

In [196]:
%%time

def pdist2(A,B):
    """Return the matrix of squared Euclidean distances between rows of A and B.

    Uses the expansion |a-b|^2 = |a|^2 + |b|^2 - 2 a.b, so entry (i, j) is
    the squared distance between A[i] and B[j]. Both inputs are 2D arrays
    with the same number of columns.
    """
    sq_norms_a = np.sum(A**2, axis=1)[:, np.newaxis]   # column vector
    sq_norms_b = np.sum(B**2, axis=1)[np.newaxis, :]   # row vector
    cross = (2*A) @ B.T
    return sq_norms_a + sq_norms_b - cross

def stnr(Z,k):
    """Return the between/within distance ratio of an embedding.

    Z is a 2D array whose rows are taken class by class: the first
    int(Z.shape[0]/k) rows form class 0, the next ones class 1, and so on.
    Class means and class-centred points are computed, and the result is
    sqrt(D2between / D2within), where each D2 is the sum of all pairwise
    squared distances (among the class means, resp. among the centred
    points) divided by the number of off-diagonal pairs m**2 - m.
    """
    per_class = int(Z.shape[0]/k)
    centroids = np.zeros((k, Z.shape[1]))
    centred = np.zeros_like(Z)

    for c in range(k):
        members = slice(c*per_class, (c+1)*per_class)
        centroids[c,:] = Z[members, :].mean(axis=0)
        centred[members, :] = Z[members, :] - centroids[c,:]

    def mean_sq_dist(P):
        # The diagonal is included in the sum but not in the pair count
        m = P.shape[0]
        return np.sum(pdist2(P, P)) / (m**2 - m)

    return np.sqrt(mean_sq_dist(centroids) / mean_sq_dist(centred))


n = 100 # sample size per class
p = 10  # dimensionality
k = 2   # number of classes
d = 5   # distance between each class mean and 0

np.random.seed(42)
X = np.random.randn(k*n, p)
for i in range(k):
    X[i*n:(i+1)*n, i] += d
    
dfs = np.arange(.2, 3.1, .1)
stnrs = np.zeros_like(dfs)
for i,df in enumerate(dfs):
    print('.', end='')
    Z = fast_tsne(X, perplexity=50, seed=42, df=df, theta=0)
    stnrs[i] = stnr(Z,k)
print('')

.............................
CPU times: user 164 ms, sys: 2.19 s, total: 2.35 s
Wall time: 1min 26s


In [204]:
plt.figure(figsize=(3, 2))
plt.plot(dfs,  stnrs,  'ko-', linewidth=1, markersize=2, zorder=1)
plt.xlabel(r'$\alpha$')
plt.ylabel('Between/within distance ratio')
sns.despine()
plt.tight_layout()

# plt.yscale('log')
# plt.xscale('log')

plt.savefig('figures/fig-separation.pdf')

<IPython.core.display.Javascript object>

## MNIST

In [2]:
# Load MNIST data

from keras.datasets import mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train = x_train.reshape(60000, 784)
x_test = x_test.reshape(10000, 784)
x_train = x_train.astype('float64')
x_test = x_test.astype('float64')
x_train /= 255
x_test /= 255
X = np.concatenate((x_train, x_test))
y = np.concatenate((y_train, y_test))
print(X.shape)

U, s, V = np.linalg.svd(X - X.mean(axis=0), full_matrices=False)
X50 = np.dot(U, np.diag(s))[:,:50]
PCAinit = X50[:,:2] / np.std(X50[:,0]) * 0.0001

Using TensorFlow backend.


(70000, 784)


In [117]:
Z_mnist = []
%time Z = fast_tsne(X50, perplexity=50, learning_rate=1000, initialization=PCAinit, df=100)
Z_mnist.append(Z)
%time Z = fast_tsne(X50, perplexity=50, learning_rate=1000, initialization=PCAinit, df=1)
Z_mnist.append(Z)
%time Z = fast_tsne(X50, perplexity=50, learning_rate=1000, initialization=PCAinit, df=.5)
Z_mnist.append(Z)

CPU times: user 80 ms, sys: 60 ms, total: 140 ms
Wall time: 1min 25s
CPU times: user 80 ms, sys: 52 ms, total: 132 ms
Wall time: 1min 26s
CPU times: user 88 ms, sys: 52 ms, total: 140 ms
Wall time: 1min 35s


In [129]:
dfs = [100, 1, .5]
fig = plt.figure(figsize=(4.8, 1.5))
w = .28
ax = []
ax.append(plt.axes([0,0,w,.88]))
ax.append(plt.axes([w,0,w,.88]))
ax.append(plt.axes([2*w,0,w,.88]))

for i,Z in enumerate(Z_mnist):
    plt.sca(ax[i])
    plt.axis('equal', adjustable='box')
    plt.scatter(Z[:,0], Z[:,1], s=.5, c=col[y], rasterized=True, alpha=.5, edgecolors='none')
    plt.title(r'$\alpha={}$'.format(dfs[i]))
    plt.gca().get_xaxis().set_visible(False)
    plt.gca().get_yaxis().set_visible(False)
    
    if i==0:
        for digit in range(10):
            plt.text(np.mean(Z[y==digit,0]), np.mean(Z[y==digit,1]), digit, ha='center', va='center')

sns.despine(left=True, bottom=True)
plt.savefig('figures/fig-mnist.pdf')

hpos = [7, 40, 30]
hwidth = [1, 10, 10]

for i in range(3):
    plt.sca(ax[i])
    ax[i].autoscale(False)
    yl = plt.ylim()
    vpos      = yl[0] + (yl[1]-yl[0]) * .15
    vposlabel = yl[0] + (yl[1]-yl[0]) * .08
    plt.plot([hpos[i], hpos[i]+hwidth[i]], [vpos,vpos], 'k', linewidth=1)
    plt.text(hpos[i] + hwidth[i]/2, vposlabel, str(hwidth[i]), ha='center')
    
plt.text(0, 1.05, 'A', transform = plt.gcf().get_axes()[0].transAxes, fontsize=8, fontweight='bold')
plt.text(0, 1.05, 'B', transform = plt.gcf().get_axes()[1].transAxes, fontsize=8, fontweight='bold')
plt.text(0, 1.05, 'C', transform = plt.gcf().get_axes()[2].transAxes, fontsize=8, fontweight='bold')
plt.text(.95, 1.05, 'D', transform = plt.gcf().get_axes()[2].transAxes, fontsize=8, fontweight='bold')
    
plt.savefig('figures/fig-mnist.pdf', dpi=600)

from sklearn.cluster import DBSCAN
Z = Z_mnist[-1]
digits = [1, 4, 2]
for digitnum, digit in enumerate(digits):
    clustering = DBSCAN().fit(Z[y==digit,:])
    labels, counts = np.unique(clustering.labels_, return_counts=True)
    counts = counts[labels>=0]
    labels = labels[labels>=0]
    order = np.argsort(counts)[::-1]
    counts = counts[order]
    labels = labels[order]

    for i in np.where(counts>100)[0][:5]:
        plt.axes([.85 + .045*digitnum, .7 - .15*i, .12*2.3/7, .12])
        ind = clustering.labels_ == labels[i]
        plt.imshow(np.mean(X[y==digit,:][ind,:],axis=0).reshape(28,28), 
                   interpolation="nearest", vmin=0, vmax=1)
        plt.gray()
        plt.axis('off')

plt.savefig('figures/fig-mnist.pdf', dpi=600)

<IPython.core.display.Javascript object>

### UMAP comparison for MNIST  (Supplementary Figures)

In [206]:
import umap
print('UMAP version: ' + umap.__version__)

# UMAP code for computing $a$ and $b$ given min_dist
# Adapted from the UMAP source code (same computation, restyled)
from scipy.optimize import curve_fit
def find_ab_params(spread, min_dist):
    """Fit the parameters a, b of the curve 1 / (1 + a * x^(2b)).

    The target is an offset exponential decay sampled at 300 points on
    [0, 3*spread]: it equals 1 for x < min_dist and
    exp(-(x - min_dist) / spread) for x >= min_dist. Returns the fitted
    (a, b) pair found by scipy's curve_fit.
    """

    def smooth_kernel(x, a, b):
        return 1.0 / (1.0 + a * x ** (2 * b))

    grid = np.linspace(0, spread * 3, 300)
    target = np.zeros(grid.shape)

    plateau = grid < min_dist
    tail = grid >= min_dist
    target[plateau] = 1.0
    target[tail] = np.exp(-(grid[tail] - min_dist) / spread)

    fitted, _ = curve_fit(smooth_kernel, grid, target)
    a, b = fitted[0], fitted[1]
    return a, b

UMAP version: 0.2.3


In [248]:
%%time

# min_dist: This controls how tightly the embedding is allowed to compress points together.
# Larger values ensure embedded points are more evenly distributed, while smaller values allow 
# the algorithm to optimise more accurately with regard to local structure. Sensible values 
# are in the range 0.001 to 0.5, with 0.1 being a reasonable default.

# min_dist values to try. The trailing np.nan marks one extra run in which
# min_dist is not used and the kernel parameters a, b are set directly.
dists = [.5, .1, .01, np.nan]

Zumap = []    # one embedding per entry of dists
umap_b = []   # b parameter recorded for each run
umap_a = []   # a parameter recorded for each run
for d in dists:
    print('.', flush=True, end='')
    # `not` rather than bitwise `~`: the latter only acts as a logical
    # negation on NumPy booleans (for a Python bool, ~True == -2 is truthy).
    if not np.isnan(d):
        Z = umap.UMAP(min_dist=d).fit_transform(X50)
        # a, b fitted by find_ab_params for this min_dist with spread=1;
        # fitted once instead of running the curve fit twice
        a, b = find_ab_params(1, d)
    else:
        a, b = 1, 0.3
        Z = umap.UMAP(a=a, b=b).fit_transform(X50)
    umap_a.append(a)
    umap_b.append(b)
    Zumap.append(Z)
print('')

....CPU times: user 6min 41s, sys: 816 ms, total: 6min 41s
Wall time: 5min 21s


In [259]:
Zumap[0][:,0] = -Zumap[0][:,0]
Zumap[2][:,0] = -Zumap[2][:,0]

fig = plt.figure(figsize=(4.8, 4.8))
for i,Z in enumerate(Zumap):
    plt.subplot(2,2,i+1)
    plt.scatter(Z[:,0], Z[:,1], s=1, c=col[y], rasterized=True, alpha=.5, edgecolors='none')
    if ~np.isnan(dists[i]):
        plt.title('min_dist={}\na={:.2f}, b={:.2f}'.format(dists[i], umap_a[i], umap_b[i]))
    else:
        plt.title('\na={:.2f}, b={:.2f}'.format(umap_a[i], umap_b[i]))            
    plt.gca().get_xaxis().set_visible(False)
    plt.gca().get_yaxis().set_visible(False)
sns.despine(left=True, bottom=True)
plt.tight_layout()

plt.savefig('figures/suppl-fig-umap.pdf', dpi=600)

<IPython.core.display.Javascript object>

### Longer optimisation, lower alpha, and isolated digits (Supplementary Figures)

In [11]:
%%time

Z10k = fast_tsne(X50, perplexity=50, learning_rate=1000, initialization=PCAinit, df=.5, max_iter=10000)
pickle.dump(Z10k, open('pickles/Zmnist10k.pickle', 'wb'))

# final loss: 3.615

CPU times: user 456 ms, sys: 164 ms, total: 620 ms
Wall time: 4h 30min 28s


In [260]:
Z10k = pickle.load(open('pickles/Zmnist10k.pickle', 'rb'))
%time Z = fast_tsne(X50, perplexity=50, learning_rate=1000, initialization=PCAinit, df=.5)

CPU times: user 88 ms, sys: 112 ms, total: 200 ms
Wall time: 1min 38s


In [261]:
# Supplementary figure: MNIST embeddings with alpha=0.5 after 1000
# iterations (Z) and after 10000 iterations (Z10k), each shown in full
# (top row) and as a zoom-in on digit 4 (bottom row).
plt.figure(figsize=(4.8,4.8))

plt.subplot(221)
plt.axis('equal', adjustable='box')
plt.scatter(Z[:,0], Z[:,1], s=1, c=col[y], rasterized=True, alpha=.5, edgecolors='none')
plt.title(r'$\alpha=0.5$'+'\n1000 iterations')

plt.subplot(222)
plt.axis('equal', adjustable='box')
plt.scatter(Z10k[:,0], Z10k[:,1], s=1, c=col[y], rasterized=True, alpha=.5, edgecolors='none')
plt.title(r'$\alpha=0.5$'+'\n10000 iterations')

plt.subplot(223)
plt.axis('equal', adjustable='box')
# Digit-4 points inside the box -5 < x < 20, 15 < y < 45.
# The upper y bound was previously applied to column 0 (a copy-paste slip),
# which left the box unbounded from above.
ind = (y==4) & (Z[:,0]>-5) & (Z[:,0]<20) & (Z[:,1]>15) & (Z[:,1]<45)
plt.scatter(Z[ind,0], Z[ind,1], s=1, c=col[4], rasterized=True, alpha=.5, edgecolors='none')
plt.title(r'$\alpha=0.5$'+'\n1000 iterations\nZoom-in on digit 4')

plt.subplot(224)
plt.axis('equal', adjustable='box')
# Digit-4 points inside the box -10 < x < 50, 80 < y < 200 (same fix as above)
ind = (y==4) & (Z10k[:,0]>-10) & (Z10k[:,0]<50) & (Z10k[:,1]>80) & (Z10k[:,1]<200)
plt.scatter(Z10k[ind,0], Z10k[ind,1], s=1, c=col[4], rasterized=True, alpha=.5, edgecolors='none')
plt.title(r'$\alpha=0.5$'+'\n10000 iterations\nZoom-in on digit 4')

sns.despine()
plt.tight_layout()
plt.savefig('figures/suppl-fig-mnist10k.pdf', dpi=600)

<IPython.core.display.Javascript object>

In [267]:
%%time

# Isolated digit of MNIST

print('Number of images of digit 4: {}'.format(np.sum(y==4)))

Zisolated = []
Z = fast_tsne(X50[y==4], perplexity=50, df=.5, max_iter=1000, initialization=PCAinit[y==4,:])
Zisolated.append(Z)

Xext = [X50[y==4]]
for i in range(3):
    Z = np.random.randn(7000, 50)
    Z[:,i] += 10
    Xext.append(Z)
Xext = np.concatenate(Xext)
Z = fast_tsne(Xext, perplexity=50, df=.5, seed=42)
Zisolated.append(Z)

Z = fast_tsne(X50[y==4], perplexity=50, df=.5, max_iter=1000, late_exag_coeff=1.75, start_late_exag_iter=250,
             initialization=PCAinit[y==4,:])
Zisolated.append(Z)

Number of images of digit 4: 6824
CPU times: user 112 ms, sys: 256 ms, total: 368 ms
Wall time: 1min 25s


In [268]:
plt.figure(figsize=(4.8, 2))
titles = [r'$\alpha=0.5$', r'$\alpha=0.5$'+' with extra\nGaussian clusters', r'$\alpha=0.5$, exag. 1.75']
for i,Z in enumerate(Zisolated):
    plt.subplot(1,3,i+1)
    plt.axis('equal', adjustable='box')
    plt.scatter(Z[:,0], Z[:,1], s=.2, rasterized=True)
    plt.title(titles[i])
sns.despine()
plt.tight_layout()
plt.savefig('figures/suppl-fig-mnist-isolated.pdf', dpi=600)

<IPython.core.display.Javascript object>

### MNIST scan of the alpha range (animations etc.)

In [47]:
%%time 

# This takes a lot of time because it prepares three different animations:
# 1. Each alpha is initialized with PCA and optimized with default parameters
# 2. Each alpha is initialized with the previous one scaled to std=0.0001 and 
#    optimized without early exaggeration
# 3. Each alpha is initialized with the previous one without scaling and
#    optimized without exaggeration for 500 iterations. This one is the slowest
#    because the size of the embedding keeps growing and FIt-SNE slows down.

_ = fast_tsne(X50, perplexity=50, max_iter=100, load_affinities='save')
dfs = np.concatenate((np.arange(100,5,-1), np.arange(5,1,-.25), np.arange(1,.499,-.01)))

Zmovie1 = []   # each alpha initialized with PCA
kls1 = []
Zmovie2  = []  # each alpha initialized with the scaled previous one (no exaggeration)
kls2  = []
Zmovie3 = []   # each alpha initialized with the previous one (500 iter, no exaggeration)
kls3 = []
for i,df in enumerate(dfs):
    print('.', end='')
        
    Z,kl = fast_tsne(X50, perplexity=50, learning_rate=1000, initialization=PCAinit, 
                     load_affinities='load', df=df, return_loss=True)
    Zmovie1.append(Z)
    kls1.append(kl[-1])
    
    if i>0:
        Z,kl = fast_tsne(X50, perplexity=50, learning_rate=1000, 
                         initialization=Zmovie2[-1]/np.std(Zmovie2[-1][:,0])*0.0001, 
                         load_affinities='load', df=df, return_loss=True, early_exag_coeff=1)
    else:
        Z1 = fast_tsne(X50, perplexity=50, learning_rate=1000, initialization=PCAinit, 
                       load_affinities='load', df=100)
        Z,kl = fast_tsne(X50, perplexity=50, learning_rate=1000, 
                         initialization=Z1/np.std(Z1[:,0])*0.0001, 
                         load_affinities='load', df=df, return_loss=True, early_exag_coeff=1)
    Zmovie2.append(Z)
    kls2.append(kl[-1])
    
    if i>0:
        Z,kl = fast_tsne(X50, perplexity=50, learning_rate=1000, max_iter=500,
                         initialization=Zmovie3[-1], 
                         load_affinities='load', df=df, return_loss=True, early_exag_coeff=1)
    else:
        Z,kl = fast_tsne(X50, perplexity=50, learning_rate=1000, initialization=PCAinit, 
                         load_affinities='load', df=100, return_loss=True)
    Zmovie3.append(Z)
    kls3.append(kl[-1])
    
    pickle.dump([Zmovie1, Zmovie2, Zmovie3, kls1, kls2, kls3, dfs], open('pickles/Zmovie.pickle', 'wb'))
print('')

..................................................................................................................................................................
CPU times: user 5min 46s, sys: 1min 35s, total: 7min 22s
Wall time: 14h 37min 37s


In [216]:
Zmovie1, Zmovie2, Zmovie3, kls1, kls2, kls3, dfs = pickle.load(open('pickles/Zmovie.pickle', 'rb'))

In [53]:
Zmovie = Zmovie1

sns.set()
sns.set_style('white')
fig = plt.figure(figsize=(3,3))
ll = plt.text(0, .9, r'$\alpha={}$'.format(100), transform = plt.gca().transAxes)
sc = plt.scatter(Zmovie[0][:,0], Zmovie[0][:,1], c=col[y], s=1, alpha=.5)
tt = []
for digit in range(10):
    t = plt.text(np.mean(Zmovie[0][y==digit,0]), np.mean(Zmovie[0][y==digit,1]), digit, ha='center', va='center')
    tt.append(t)
sns.despine(left=True, bottom=True)
m = np.max(np.abs(Zmovie[0]))
plt.xlim([-m, m])
plt.ylim([-m, m])
plt.gca().get_xaxis().set_visible(False)
plt.gca().get_yaxis().set_visible(False)
plt.tight_layout()

pauses = [100, 1, 0.5]
pauselength = 10
frames = []
for i,df in enumerate(dfs):
    if np.round(df,3) in pauses:
        frames += [i]*pauselength
    else:
        frames += [i]

def update_plot(i):
    """Redraw the animation for frame number i.

    frames[i] selects which embedding in Zmovie (and which alpha in dfs) to
    show. The scatter points and the digit labels are moved, the axis
    limits are reset to a symmetric square around 0 that contains all
    points, and the alpha label is updated.
    """
    step = frames[i]
    embedding = Zmovie[step]
    alpha = dfs[step]

    sc.set_offsets(embedding)

    extent = np.max(np.abs(embedding))
    plt.xlim([-extent, extent])
    plt.ylim([-extent, extent])

    # Put each digit label at the mean position of that digit's points
    for digit,label in enumerate(tt):
        members = y==digit
        label.set_position((np.mean(embedding[members,0]),
                            np.mean(embedding[members,1])))

    # Fewer decimals are shown for larger alpha
    if alpha>=5:
        caption = r'$\alpha={}$'.format(int(alpha))
    elif alpha>=1:
        caption = r'$\alpha={:.1f}$'.format(alpha)
    else:
        caption = r'$\alpha={:.2f}$'.format(alpha)
    ll.set_text(caption)

from matplotlib import animation
myanim = animation.FuncAnimation(fig, update_plot, frames=len(frames), 
                                 interval=100, repeat=False)

<IPython.core.display.Javascript object>

In [49]:
myanim.save('animations/mnist.gif', dpi=150, writer='imagemagick')

# to compress gif (https://github.com/kornelski/giflossy)
# /usr/local/bin/gifsicle --lossy=200 --colors 64 -i mnist.gif -o mnist-sm.gif

# to convert to mp4
# ffmpeg -f gif -i mnist.gif mnist.mp4

# direct export to mp4 does not work very well for some reason
# myanim.save('animations/test.mp4', dpi=150, writer=animation.writers['ffmpeg'](fps=10))

# undo seaborn style set for animation
sns_styleset()

In [185]:
# Nearest neighbour preservation

from annoy import AnnoyIndex

def mnn(X, Zs, knn=10):
    """Fraction of preserved nearest neighbours for each embedding in Zs.

    Builds an Annoy index (Euclidean metric, `build(10)`) on X and on each
    embedding Z in Zs. For every point, the knn+1 nearest items are queried
    in both indices and the size of their overlap minus one is counted
    (presumably discounting the query point itself -- this assumes Annoy
    returns it in both lists). Returns an array with one value per
    embedding: the total count divided by the number of points and by knn.
    A '.' is printed per embedding as a progress indicator.
    """
    npoints = X.shape[0]

    def build_index(data):
        # Index the first npoints rows of `data`
        index = AnnoyIndex(data.shape[1], metric='euclidean')
        for row in range(npoints):
            index.add_item(row, data[row])
        index.build(10)
        return index

    reference = build_index(X)

    preserved = np.zeros(len(Zs))
    for num, Z in enumerate(Zs):
        print('.', end='')
        embedded = build_index(Z)

        shared = sum(
            len(set(reference.get_nns_by_item(row, knn+1))
                & set(embedded.get_nns_by_item(row, knn+1))) - 1
            for row in range(npoints)
        )
        preserved[num] = shared / npoints / knn
    print('')

    return preserved

%time mnns10 = mnn(X50, Zmovie1, knn=10)
%time mnns15 = mnn(X50, Zmovie1, knn=15)
%time mnns20 = mnn(X50, Zmovie1, knn=20)

pickle.dump([mnns10, mnns15, mnns20], open('pickles/mnns.pickle', 'wb'))

..................................................................................................................................................................
CPU times: user 14min 33s, sys: 36 ms, total: 14min 33s
Wall time: 14min 33s
..................................................................................................................................................................
CPU times: user 16min 13s, sys: 28 ms, total: 16min 14s
Wall time: 16min 13s
..................................................................................................................................................................
CPU times: user 18min 10s, sys: 16 ms, total: 18min 10s
Wall time: 18min 10s


In [3]:
Zmovie1, Zmovie2, Zmovie3, kls1, kls2, kls3, dfs = pickle.load(open('pickles/Zmovie.pickle', 'rb'))
[mnns10, mnns15, mnns20] = pickle.load(open('pickles/mnns.pickle', 'rb'))

In [9]:
from sklearn.neighbors import NearestNeighbors

def mnn(X, Zs, knn=10):
    """Fraction of preserved k nearest neighbours, using exact search.

    The knn nearest neighbours of every point are found with scikit-learn's
    NearestNeighbors, once in X and once in each embedding Z of Zs.
    `kneighbors` is called without query points (NOTE(review): per the
    scikit-learn docs this excludes each point from its own neighbour list
    -- confirm for the installed version). Returns an array with one value
    per embedding: the number of shared neighbours summed over all points,
    divided by the number of points and by knn. A '.' is printed per
    embedding as a progress indicator.
    """
    npoints = X.shape[0]
    reference = NearestNeighbors(n_neighbors=knn).fit(X).kneighbors(return_distance=False)

    preserved = np.zeros(len(Zs))
    for num, Z in enumerate(Zs):
        print('.', end='')
        embedded = NearestNeighbors(n_neighbors=knn).fit(Z).kneighbors(return_distance=False)

        shared = sum(len(set(reference[row]) & set(embedded[row]))
                     for row in range(npoints))
        preserved[num] = shared / npoints / knn
    print('')

    return preserved

%time mnns10 = mnn(X50, Zmovie1, knn=10)
%time mnns15 = mnn(X50, Zmovie1, knn=15)
%time mnns20 = mnn(X50, Zmovie1, knn=20)

pickle.dump([mnns10, mnns15, mnns20], open('pickles/mnns-exact.pickle', 'wb'))

..................................................................................................................................................................
CPU times: user 6min 57s, sys: 56 ms, total: 6min 57s
Wall time: 6min 56s
..................................................................................................................................................................
CPU times: user 7min 47s, sys: 32 ms, total: 7min 47s
Wall time: 7min 45s
..................................................................................................................................................................
CPU times: user 8min 27s, sys: 64 ms, total: 8min 27s
Wall time: 8min 25s


In [18]:
%time mnns50  = mnn(X50, Zmovie1, knn=50)

..................................................................................................................................................................
CPU times: user 11min 43s, sys: 2.41 s, total: 11min 46s
Wall time: 11min 42s


In [20]:
%time mnns100 = mnn(X50, Zmovie1, knn=100)

..................................................................................................................................................................
CPU times: user 16min 42s, sys: 12.7 s, total: 16min 54s
Wall time: 16min 48s


In [37]:
# KL divergence (left axis, black) and KNN preservation (right axis, blue)
# as functions of alpha, with a marker at the optimum of each curve.
plt.figure(figsize=(3.5, 2.2))
plt.plot(dfs,  kls1,  'k-', linewidth=1, markersize=1)
ax1 = plt.gca()
plt.xlabel(r'$\alpha$')
plt.ylabel('KL divergence')
plt.text(60, 3.75, 'KL', color='k')
# Mark the minimum of the KL curve at the alpha where it is attained
# (previously hard-coded to x=1.5), consistent with the KNN markers below.
plt.plot(dfs[np.argmin(kls1)], np.min(kls1), 'ko', markersize=3)

# Dashed vertical line at alpha=1; the y limits are restored afterwards
# so that drawing the line does not change them
yl = plt.ylim()
plt.plot([1,1], yl, 'k--', linewidth=0.75, zorder=0)
plt.ylim(yl)

ax2 = ax1.twinx()
plt.plot(dfs,  mnns10, 'b-', linewidth=1, markersize=1)
# plt.plot(dfs,  mnns15, 'b-', linewidth=1, markersize=1)
# plt.plot(dfs,  mnns20, 'b-', linewidth=1, markersize=1)
plt.plot(dfs,  mnns50, 'b-', linewidth=1, markersize=1)
plt.plot(dfs,  mnns100, 'b-', linewidth=1, markersize=1)
# plt.plot(dfs,  mnns1, 'r-', linewidth=1, markersize=1)
ax2.tick_params(axis='y', labelcolor='b')
plt.ylabel('KNN preservation', color='b')
# Mark the maximum of each KNN preservation curve
ax2.plot(dfs[np.argmax(mnns10)], np.max(mnns10), 'bo', markersize=3)
ax2.plot(dfs[np.argmax(mnns50)],  np.max(mnns50), 'bo', markersize=3)
ax2.plot(dfs[np.argmax(mnns100)],  np.max(mnns100), 'bo', markersize=3)

plt.text(50, 0.06, 'K=10', color='b')
plt.text(50, 0.13, 'K=50', color='b')
plt.text(50, 0.18,  'K=100', color='b')

plt.xscale('log')
sns.despine(ax=ax1)
sns.despine(ax=ax2, right=False, bottom=True)
ax2.spines['left'].set_color('none')
plt.tight_layout()
plt.savefig('figures/fig-kl.pdf')
# plt.savefig('figures/fig-kl-fixed.png', dpi=200)
# plt.savefig('figures/fig-kl-fixed.png', dpi=200)

<IPython.core.display.Javascript object>

## Tasic et al.

Download the data from here: http://celltypes.brain-map.org/rnaseq and unpack. Direct links:
 * VISp: http://celltypes.brain-map.org/api/v2/well_known_file_download/694413985
 * ALM: http://celltypes.brain-map.org/api/v2/well_known_file_download/694413179

To get the information about cluster colors and labels (`sample_heatmap_plot_data.csv`), open the interactive data browser http://celltypes.brain-map.org/rnaseq/mouse, go to "Sample Heatmaps", click "Build Plot!" and then "Download data as CSV"

In [None]:
# %%time

# # Load the Allen institute data. This takes a bit of time but saves the result 
# # to a .pickle file that can be quickly loaded afterwards

# # This function is needed because using Pandas to load these files in one go 
# # can eat up a lot of RAM. So we are doing it in chunks, and converting each
# # chunk to the sparse matrix format on the fly.
# def sparseload(filename):
#     with open(filename) as file:
#         genes = []
#         sparseblocks = []
#         for i,chunk in enumerate(pd.read_csv(filename, chunksize=1000, index_col=0)):
#             print('.', end='', flush=True)
#             if i==0:
#                 cells = np.array(chunk.columns)
#             genes.extend(list(chunk.index))
#             sparseblock = sparse.csr_matrix(chunk.values.astype(float))
#             sparseblocks.append([sparseblock])
#         counts = sparse.bmat(sparseblocks)
#         print(' done')
#     return (counts.T, np.array(genes), cells)

# filename = 'data/allen-visp-alm/mouse_VISp_2018-06-14_exon-matrix.csv'
# counts1, genes, cells1 = sparseload(filename)

# filename = 'data/allen-visp-alm/mouse_ALM_2018-06-14_exon-matrix.csv'
# counts2, genes2, cells2 = sparseload(filename)

# counts = sparse.vstack((counts1, counts2), format='csc')
# counts1, counts2 = [], []
# cells = np.concatenate((cells1, cells2))
# assert(np.all(genes==genes2))

# genesDF = pd.read_csv('data/allen-visp-alm/mouse_VISp_2018-06-14_genes-rows.csv')
# ids     = genesDF['gene_entrez_id'].tolist()
# symbols = genesDF['gene_symbol'].tolist()
# id2symbol = dict(zip(ids, symbols))
# genes = np.array([id2symbol[g] for g in genes])

# clusterInfo = pd.read_csv('data/allen-visp-alm/sample_heatmap_plot_data.csv')
# goodCells  = clusterInfo['sample_name'].values
# ids        = clusterInfo['cluster_id'].values
# labels     = clusterInfo['cluster_label'].values
# colors     = clusterInfo['cluster_color'].values

# clusterNames  = np.array([labels[ids==i+1][0] for i in range(np.max(ids))])
# clusterColors = np.array([colors[ids==i+1][0] for i in range(np.max(ids))])
# clusters   = np.copy(ids) - 1

# ind = np.array([np.where(cells==c)[0][0] for c in goodCells])
# counts = counts[ind, :]

# areas = (ind < cells1.size).astype(int)

# RNAseqData = namedtuple('RNAseqData', 'counts genes areas clusters clusterColors clusterNames')
# tasic2018 = RNAseqData(counts=counts, genes=genes, clusters=clusters, areas=areas, 
#                        clusterColors=clusterColors, clusterNames = clusterNames)

# pickle.dump(tasic2018, open('data/allen-visp-alm/tasic2018.pickle', 'wb'))

In [135]:
%%time

# Load the Tasic et al. 2018 data set, reduce it to 50 principal components
# and run t-SNE with two different values of df.
#
# The pre-processing follows Kobak & Berens (2018) preprint
# https://github.com/berenslab/rna-seq-tsne/
import rnaseqTools

# Define the record type needed to unpickle the data, then load the data.
# The pickle file can be generated by running the script in the linked
# GitHub page.
# NOTE(review): unpickling executes code stored in the file; only load
# pickle files that you generated yourself.
from collections import namedtuple
RNAseqData = namedtuple('RNAseqData', 'counts genes areas clusters clusterColors clusterNames')
# Context manager, so that the file handle is closed right after loading.
with open('../../Dropbox (Machens Lab)/Link to rnaseq/data/allen-visp-alm/tasic2018.pickle', 'rb') as pickle_file:
    tasic2018 = pickle.load(pickle_file)

# Mean of `areas` over the cells of every cluster, and a marker shape chosen
# from it ('D' below .05, 's' above .95, 'o' otherwise).
# presumably `areas` is a 0/1 brain-area indicator per cell -- TODO confirm
fr = np.array([np.mean(tasic2018.areas[tasic2018.clusters==c]) for c in range(tasic2018.clusterNames.size)])
markers = ['D' if f<.05 else 's' if f>.95 else 'o' for f in fr]

# Gene selection (rnaseqTools helper; the call prints "Chosen offset: ...").
# presumably returns an index/mask over the columns of `counts` -- it is used
# as a column selector below.
importantGenesTasic2018 = rnaseqTools.geneSelection(tasic2018.counts, n=3000, threshold=32,
                                                    decay=1.5, plot=False)

# Divide by the row sums, scale by 1e+6, add 1 and take log2.
librarySizes = np.sum(tasic2018.counts, axis=1)
X = np.log2(tasic2018.counts[:, importantGenesTasic2018] / librarySizes * 1e+6 + 1)
X = np.array(X)

# PCA via SVD of the column-centred matrix.
X = X - X.mean(axis=0)
U,s,V = np.linalg.svd(X, full_matrices=False)
# Fix the sign of each component: flip those whose row of V sums to a
# negative value.
U[:,np.sum(V,axis=1)<0] *= -1
X = np.dot(U, np.diag(s))
# Order the components by decreasing singular value and keep the first 50.
X = X[:, np.argsort(s)[::-1]][:,:50]

# t-SNE initialisation: first two components, rescaled so that the first one
# has standard deviation .0001.
PCAinit = X[:,:2]/np.std(X[:,0])*.0001

# Z_tasic[0]: df=1, Z_tasic[1]: df=.6; the third entry is an empty placeholder.
Z_tasic = []
Z = fast_tsne(X, perplexity = 50, initialization = PCAinit, df=1)
Z_tasic.append(Z)
Z = fast_tsne(X, perplexity = 50, initialization = PCAinit, df=.6)
Z_tasic.append(Z)
Z_tasic.append([])

Chosen offset: 6.56
CPU times: user 1min 47s, sys: 7.78 s, total: 1min 55s
Wall time: 2min 5s


In [136]:
# Figure with three panels: the two embeddings stored in Z_tasic, plus a copy
# of the second one in which every cell outside the Vip clusters is masked out.

# Third embedding: copy of Z_tasic[1] with all non-Vip cells set to NaN
ind = np.array([name.split()[0] == 'Vip'
                for name in tasic2018.clusterNames[tasic2018.clusters]])
Z_tasic[2] = Z_tasic[1].copy()
Z_tasic[2][~ind, :] = np.nan

dfs = [1, .6, .6]
fig = plt.figure(figsize=(4.8, 2))
w = .33
# Three side-by-side axes, each of width w
ax = [plt.axes([j*w, 0, w, .9]) for j in range(3)]

# One colour per cell, looked up through the cell's cluster
tasic_cell_colors = tasic2018.clusterColors[tasic2018.clusters]

for i, Z in enumerate(Z_tasic):
    plt.sca(ax[i])
    plt.axis('equal', adjustable='box')
    plt.scatter(Z[:,0], Z[:,1], s=.2, c=tasic_cell_colors, rasterized=True)
    if i == 2:
        plt.title(r'$\alpha=0.6$, Vip clusters')
    else:
        plt.title(r'$\alpha={}$'.format(dfs[i]))
    ax[i].get_xaxis().set_visible(False)
    ax[i].get_yaxis().set_visible(False)

# Plot means in subplot 3
# (per-cluster medians of the embedding; clusters that are entirely NaN are skipped)
plt.sca(ax[2])
Z = Z_tasic[2]
K = tasic2018.clusterNames.size
Zmeans = np.full((K, 2), np.nan)
for c in range(K):
    tasic_members = tasic2018.clusters == c
    if np.isnan(Z[tasic_members, 0]).all():
        continue
    Zmeans[c, :] = np.median(Z[tasic_members, :2], axis=0)
nonans = ~np.isnan(Zmeans[:, 0])
plt.scatter(
    Zmeans[nonans, 0],
    Zmeans[nonans, 1],
    color=tasic2018.clusterColors[nonans],
    s=20,
    edgecolor='k',
    linewidth=.5,
)

sns.despine(left=True, bottom=True)
plt.savefig('figures/fig-tasic.pdf', dpi=600)

# Scale bars: a horizontal line starting at x=hpos[i] of length hwidth[i],
# drawn at 15% of the y-range, with its length printed underneath (at 8%).
hpos = [25, 20, -10]
hwidth = [10, 10, 1]

for i, (bar_start, bar_length) in enumerate(zip(hpos, hwidth)):
    plt.sca(ax[i])
    ax[i].autoscale(False)
    yl = plt.ylim()
    yrange = yl[1] - yl[0]
    vpos      = yl[0] + yrange * .15
    vposlabel = yl[0] + yrange * .08
    plt.plot([bar_start, bar_start + bar_length], [vpos, vpos], 'k', linewidth=1)
    plt.text(bar_start + bar_length/2, vposlabel, str(bar_length), ha='center')

# Panel letters, positioned in the coordinates of the respective axes
for panel_index, panel_letter in enumerate('ABC'):
    plt.text(0, 1.05, panel_letter,
             transform=plt.gcf().get_axes()[panel_index].transAxes,
             fontsize=8, fontweight='bold')

# Save again, now including the scale bars and panel letters
plt.savefig('figures/fig-tasic.pdf', dpi=600)

<IPython.core.display.Javascript object>

In [143]:
# Rebind the name to an empty list; presumably done to release the memory
# held by the loaded data set before the next analysis.
tasic2018 = []

## HathiTrust

In [144]:
%%time

# Build the data matrix X for the Russian-language books: one row per book,
# filled with the vector stored for the book's id in the word2vec-format file.
# Data files are from https://zenodo.org/record/1477018
from gensim.models.keyedvectors import KeyedVectors

meta = pd.read_csv(
    '/home/localadmin/hathi/data_shuffled.tsv',
    sep='\t',
    error_bad_lines=False,
)
# Ids of the rows whose language field is 'Russian'
idsRus = np.array(meta['id'])[np.array(meta['language']) == 'Russian']

model = KeyedVectors.load_word2vec_format(
    '/home/localadmin/hathi/hathi_pca.bin',
    binary=True,
)

# Look the vectors up one id at a time (X has 100 columns)
X = np.zeros((idsRus.size, 100))
for i, ident in enumerate(idsRus):
    X[i] = model[ident]

print(X.shape)
# Rebind the name; presumably to let the loaded vectors be freed
model = []

b'Skipping line 12799837: expected 8 fields, saw 9\n'


(408291, 100)
CPU times: user 2min 38s, sys: 6.61 s, total: 2min 44s
Wall time: 2min 51s


In [145]:
# Extract metadata for the Russian-language books: publication year,
# Library of Congress category and the indices of books by selected poets.

# Rows of `meta` whose language is 'Russian'; computed once and reused for
# every column below.
isRussian = np.array(meta['language']) == 'Russian'

# Getting publication dates
# Missing dates are stored as the string 'None'; they become NaN.
dates = np.array(meta['date'])[isRussian]
datesnum = np.zeros(dates.size)
datesnum[dates=='None'] = np.nan
datesnum[dates!='None'] = dates[dates!='None'].astype(int)
# Years after 2018 or before 1700 are set to NaN as well. Comparisons
# with NaN evaluate to False, so entries that are already NaN are left alone;
# errstate silences the "invalid value" warning some NumPy versions emit.
with np.errstate(invalid='ignore'):
    outOfRange = (datesnum > 2018) | (datesnum < 1700)
datesnum[outOfRange] = np.nan

# Most frequent LoC categories
# (missing values turn into the string 'nan' through astype('str'))
lcs = np.array(meta['lc1'])[isRussian].astype('str')
types, counts = np.unique(lcs, return_counts=True)
print('LIBRARY OF CONGRESS CATEGORIES')
mostFrequent = np.argsort(counts)[::-1][:10]
for t,c in zip(types[mostFrequent], counts[mostFrequent]):
    print(t + '\t {}'.format(c))
print('...\n')

# Selecting some poetry
# A book is selected if its first author's name contains any of these substrings.
authors = np.array(meta['first_author_name'])[isRussian].astype('str')
poetNames = ('Fet, A. A.', 'tchev, F. I.', 'Mayakovsky, Vladimir',
             'Akhmatova, Anna', 'Pushkin, Aleksandr',
             'Blok, Aleksandr', 'Mandelʹshtam, Osip',
             'Brodsky, Joseph')
poets = [i for i,s in enumerate(authors) if any(name in s for name in poetNames)]

print('Poetry books selected: {}'.format(len(poets)))

LIBRARY OF CONGRESS CATEGORIES
nan	 220669
PG	 30352
DK	 23876
Z	 7935
HD	 6932
QA	 6490
HC	 6116
AP	 5676
QC	 5104
D	 4181
...

Poetry books selected: 1369


In [146]:
%%time

# Two t-SNE embeddings of X with identical settings, except that the second
# run passes df=.5 (the first one leaves df at the fast_tsne default).
# NOTE(review): load_affinities='save' / 'load' presumably stores the
# affinities during the first run and reuses them in the second -- confirm
# against the fast_tsne version in use.
Z_hathi = []
Z = fast_tsne(X, perplexity=50, learning_rate=10000, seed=42, load_affinities='save')
Z_hathi.append(Z)
Z = fast_tsne(X, perplexity=50, learning_rate=10000, seed=42, df=.5, load_affinities='load')
Z_hathi.append(Z)

# Store the embeddings together with the metadata used for plotting. The
# context manager guarantees that the file is flushed and closed.
with open('pickles/hathi.pickle', 'wb') as pickle_file:
    pickle.dump([Z_hathi, datesnum, lcs, poets], pickle_file)

CPU times: user 1.08 s, sys: 1.28 s, total: 2.36 s
Wall time: 16min 53s


In [186]:
# Reload the stored embeddings and plotting metadata; the context manager
# closes the file handle as soon as the data is read.
# NOTE(review): unpickling executes code stored in the file; only load
# pickle files that you generated yourself.
with open('pickles/hathi.pickle', 'rb') as pickle_file:
    [Z_hathi, datesnum, lcs, poets] = pickle.load(pickle_file)

In [187]:
# Figure with the two embeddings in Z_hathi side by side, coloured by
# `datesnum`, with density contours for two subsets of books and a shared
# colorbar. The code relies on pyplot's "current axes" state, so the order of
# the statements matters.
fig = plt.figure(figsize=(4.8, 2.5))
ax = []
ax.append(plt.axes([0, 0, .45, .9]))
ax.append(plt.axes([.55, 0, .45,.9]))
# Narrow axes between the two panels that will hold the colorbar
cax   = plt.axes([.46,.18,.02,.5])

# Values shown in the panel titles, one per embedding in Z_hathi
dfs = [1, .5]
for i,Z in enumerate(Z_hathi):
    plt.sca(ax[i])
    plt.axis('equal', adjustable='box')
    plt.scatter(Z[:,0], Z[:,1], s=1, alpha=.5, c=datesnum, rasterized=True, edgecolors='none', cmap='inferno')
    plt.title(r'$\alpha={}$'.format(dfs[i]))
    plt.gca().get_xaxis().set_visible(False)
    plt.gca().get_yaxis().set_visible(False)

# contours
# Kernel density contours of the books with LoC category 'QA' and of the
# books indexed by `poets`, drawn on top of each embedding.
# NOTE(review): a list is passed as n_levels, presumably interpreted as
# explicit contour levels by the seaborn version in use -- confirm.
Z = Z_hathi[0]
plt.sca(ax[0])
sns.kdeplot(Z[lcs=='QA',0], Z[lcs=='QA',1], n_levels=[.002,.003,.004], linewidths=.75, bw=2, color='k')
sns.kdeplot(Z[poets,0], Z[poets,1], n_levels=[.0006,.001,.002], linewidths=.75, bw=2, color='k')

Z = Z_hathi[1]
plt.sca(ax[1])
sns.kdeplot(Z[lcs=='QA',0], Z[lcs=='QA',1], n_levels=[.004, .008, .015], linewidths=.75, bw=2, color='k')
sns.kdeplot(Z[poets,0], Z[poets,1], n_levels=[.001, .002, .005], linewidths=.75, bw=2, color='k')

# colorbar
# Drawn into `cax` after making ax[0] current; presumably it picks up the
# scatter plot of ax[0] as its mappable -- TODO confirm.
# The scatter uses alpha=.5, while the colorbar is set to alpha 1.
plt.sca(ax[0])
cbar = plt.colorbar(cax=cax)
cbar.set_alpha(1)
cbar.solids.set_edgecolor('face')
cbar.solids.set_rasterized(True)
cbar.draw_all()
cbar.set_ticks([1700,1750,1800,1850,1900,1950,2000])

sns.despine(left=True, bottom=True)
# First save; the same file is written again at the end of the cell
plt.savefig('figures/fig-hathi.pdf', dpi=600)

# Scale bars: a horizontal line starting at x=hpos[i] of length hwidth[i],
# drawn at 15% of the y-range, with its length printed below it (at 10%).
hpos = [50, 20]
hwidth = [10, 10]

for i in range(2):
    plt.sca(ax[i])
    # Switch autoscaling off before adding the scale bar
    ax[i].autoscale(False)
    yl = plt.ylim()
    vpos      = yl[0] + (yl[1]-yl[0]) * .15
    vposlabel = yl[0] + (yl[1]-yl[0]) * .10
    plt.plot([hpos[i], hpos[i]+hwidth[i]], [vpos,vpos], 'k', linewidth=1)
    plt.text(hpos[i] + hwidth[i]/2, vposlabel, str(hwidth[i]), ha='center')

# Panel letters, positioned in the coordinates of the first two axes
plt.text(0, 1.05, 'A', transform = plt.gcf().get_axes()[0].transAxes, fontsize=8, fontweight='bold')
plt.text(0, 1.05, 'B', transform = plt.gcf().get_axes()[1].transAxes, fontsize=8, fontweight='bold')

# Final save, now including the scale bars and panel letters
plt.savefig('figures/fig-hathi.pdf', dpi=600)

<IPython.core.display.Javascript object>