<a href="https://colab.research.google.com/github/profteachkids/STEMUnleashed2023/blob/main/tSNE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [61]:
import jax.numpy as jnp
import jax
import numpy as np
import pandas as pd
from plotly.subplots import make_subplots
import plotly.express as px

In [402]:
rng = np.random.RandomState(1234)
Ndim=30
Nc=15
cs = rng.uniform(low=-10,high=10,size=(Nc,Ndim))
Ns = rng.randint(low=10,high=20,size=Nc)
Ntot=np.sum(Ns)
stds=rng.uniform(1,2.,size=Nc)
dflist=[]
for c,std,N in zip(cs,stds,Ns):
    df=pd.DataFrame(np.random.normal(loc=c.reshape(Ndim,1),scale=std,size=(Ndim,N)).T)
    dflist.append(df)

In [403]:
df=pd.concat(dflist,keys=range(len(dflist)),names=['label'])
df=df.reset_index('label')
df['label']=df['label'].astype('category')

In [404]:
colors= px.colors.named_colorscales()
fig=px.scatter(df,x=0,y=1,color='label')
fig.update_layout(width=500,height=500,template='plotly_dark')

In [405]:
sigma=2.
twosigmasq = 2*sigma**2
x=df.iloc[:,1:Ndim+1].values
expnegdelta=np.exp(-np.sum((x[:,None,:]-x[None,:,:])**2,axis=-1)/twosigmasq)+1e-15
np.fill_diagonal(expnegdelta,0.)
pij=expnegdelta/np.sum(expnegdelta,axis=1)
pij=jnp.asarray((pij+pij.T)/2/Ntot)



In [406]:
def KL(y):
    qij=1/(1+jnp.sum((y[:,None,:]-y[None,:,:])**2,axis=-1))
    qij=qij.at[jnp.diag_indices(Ntot)].set(0.)
    qij = qij/jnp.sum(qij)
    return jnp.sum(pij[nodiagmask]*jnp.log(pij[nodiagmask]/qij[nodiagmask]))

KLjit = jax.jit(KL)
KLgrad=jax.jit(jax.grad(KL))

In [407]:
y0=jnp.asarray(rng.uniform(0,1,size=(Ntot,2)))
nodiagmask=np.full((Ntot,Ntot),True)
np.fill_diagonal(nodiagmask,False)
gsum=0
yorig=y0
ydefmax=1e-3

In [408]:
#Distance over Gradient - Stochastic Gradient Descent
#https://arxiv.org/pdf/2302.12022.pdf

for i in range(int(5e4)):
    g=KLgrad(y0)
    gsum+=jnp.sum(g**2)
    eta= ydefmax/jnp.sqrt(gsum)
    y1 = y0 - eta*g
    ydelta=jnp.linalg.norm(y1-yorig)
    ydefmax = ydelta if ydelta > ydefmax else ydefmax
    y0=y1
    if i % 10000 == 0:
        print(i, KL(y0), eta)

0 4.1997366 0.017376909
10000 0.5369339 354.83954
20000 0.51746815 345.79004
30000 0.496873 331.9311
40000 0.4673843 321.01035


In [409]:
df2=pd.DataFrame(y0)
df2['label']=df['label'].values

In [410]:
fig2=px.scatter(df2,x=0,y=1,color='label')
fig2.update_layout(width=500,height=500,template='plotly_dark')

In [337]:
y0

Array([[  0.8941398 , -26.122467  ],
       [ -1.6809868 , -25.629942  ],
       [ -1.6942173 , -25.267303  ],
       [ -1.7179736 , -26.5961    ],
       [ -0.459761  , -25.7367    ],
       [ -0.85274315, -25.768806  ],
       [ -1.3618202 , -25.700533  ],
       [ -2.1767547 , -27.214396  ],
       [ -1.5239995 , -26.094269  ],
       [ -0.82403564, -25.869087  ],
       [ -2.2899334 , -26.065245  ],
       [ -1.2552135 , -25.168135  ],
       [ -1.5099115 , -25.76734   ],
       [ -2.0895772 , -27.012444  ],
       [ -0.84424114, -25.240519  ],
       [ -0.6503361 , -25.911236  ],
       [  5.3151345 , -26.35796   ],
       [  5.3510585 , -25.968826  ],
       [  6.5600452 , -25.247509  ],
       [  6.9648147 , -26.317793  ],
       [  6.193748  , -25.611673  ],
       [  6.8740215 , -26.302065  ],
       [  6.2595997 , -26.274067  ],
       [  6.344212  , -26.594807  ],
       [  6.1390567 , -26.476425  ],
       [  5.310778  , -26.088179  ],
       [  6.8251815 , -25.698042  ],
 