<a href="https://colab.research.google.com/github/profteachkids/CHE5136_Fall2021/blob/main/AutoEncoderClusterAnalysis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [94]:
import numpy as np
import jax.numpy as jnp
import jax
from jax.experimental import stax, optimizers
from jax.experimental.stax import Dense, Dropout, Selu, Tanh, serial
import itertools as it
from plotly.subplots import make_subplots
import plotly.io as pio
from scipy.signal import find_peaks
pio.templates.default= 'plotly_dark'

In [95]:
def equal_area_histogram(data, nbins=None):
    """Step-outline histogram whose bins each contain ~1/nbins of the samples.

    Interior bin edges sit at evenly spaced ranks of the sorted data. Each edge
    value is smoothed as follows. A straight line is fitted to the sorted values
    in a small window of ranks around that edge: about a quarter bin on each
    side, and never fewer than two points. The line is then evaluated at the
    edge's rank.

    The outer edges are the data minimum and maximum. Each bin's height is
    1/nbins divided by its width, so every bin has the same area and the total
    area is 1.

    Parameters
    ----------
    data : array_like
        Sample values; flattened to 1-D before use.
    nbins : int, optional
        Number of bins. Defaults to ``int(2 * n**(2/5))`` for ``n`` samples.

    Returns
    -------
    bin_edges : ndarray
        x-coordinates of the outline, each edge repeated twice.
    bin_heights : ndarray
        Matching y-coordinates: 0 at both ends, with each bin height repeated
        twice. ``(bin_edges, bin_heights)`` traces the histogram as a line plot.

    Raises
    ------
    ValueError
        If ``data`` has fewer than two values or ``nbins`` is not in ``[1, n]``.

    Notes
    -----
    Heights assume the smoothed edges are strictly increasing. If a local fit
    produces out-of-order edges, the affected heights can be negative or
    infinite.
    """
    data_sorted = np.sort(np.ravel(np.asarray(data)))
    n = data_sorted.size
    if n < 2:
        raise ValueError("equal_area_histogram needs at least 2 data points")
    nbins = int(2 * n**(2/5)) if nbins is None else int(nbins)
    if not 1 <= nbins <= n:
        raise ValueError(f"nbins must be between 1 and {n}, got {nbins}")
    nperbin = n // nbins

    # Rank positions (possibly fractional) of the bin edges in the sorted data.
    edge_ranks = np.linspace(0, n, nbins + 1)
    inner_ranks = edge_ranks[1:-1]

    # Fit window of roughly +/- nperbin/4 ranks around each interior edge.
    # For small bins (nperbin < 4) that window is empty and the fit would be
    # singular, so fall back to the two ranks straddling the edge.
    lo = int(np.ceil(-nperbin / 4))
    hi = int(np.floor(nperbin / 4))
    if hi - lo < 2:
        lo, hi = -1, 1
    offsets = np.arange(lo, hi)
    ranks = (inner_ranks[:, None] + offsets[None, :]).astype(np.int64)
    values = data_sorted[ranks]

    # Per-row simple linear regression of value on rank, centred on the window
    # means (closed form; no explicit matrix inverse), evaluated at the edge rank.
    rank_mean = ranks.mean(axis=1)
    value_mean = values.mean(axis=1)
    d_rank = ranks - rank_mean[:, None]
    slope = (d_rank * (values - value_mean[:, None])).sum(axis=1) / (d_rank**2).sum(axis=1)
    inner_edges = value_mean + slope * (inner_ranks - rank_mean)

    smoothed_edges = np.concatenate(([data_sorted[0]], inner_edges, [data_sorted[-1]]))
    # Every bin carries 1/nbins of the probability mass, spread over its width.
    heights = 1 / nbins / np.diff(smoothed_edges)

    bin_edges = np.repeat(smoothed_edges, 2)
    bin_heights = np.r_[0., np.repeat(heights, 2), 0.]
    return bin_edges, bin_heights

In [96]:
np.random.seed(1234)
# Four cluster centres in 3-D, one per row.
c = np.array([[-3, -3, -3], [3, 3, 3], [-3, 3, 1], [3, -3, 2]])
# NOTE: this draws c.size (= 12) counts, but only the first c.shape[0] (= 4) are
# used below. It is left unchanged so the global random stream, and therefore
# the recorded outputs further down, stay the same.
n = np.random.randint(15, 20, c.size)
# A unit-variance Gaussian blob of n[k] points around each centre k, stacked row-wise.
p = jnp.concatenate([
    np.random.normal(loc=centre, scale=1, size=(count, 3))
    for centre, count in zip(c, n)
])

In [97]:
# Encoder: 3-feature point -> 5 hidden SELU units -> 1 code squashed into (-1, 1) by Tanh.
encoder = serial(Dense(5), Selu, Dense(1), Tanh)
# Decoder: 1-D code -> 5 hidden SELU units -> 3 linear outputs (the reconstruction).
decoder = serial(Dense(5), Selu, Dense(3))

# Full autoencoder. stax.serial keeps one parameter entry per sub-layer,
# so the model's params[0] belongs to the encoder and params[1] to the decoder.
init_model, model = serial(encoder, decoder)

In [98]:
# Adam optimizer with step size 0.1.
opt_init, opt_update, get_params = optimizers.adam(step_size=0.1)

# Initialize model parameters for inputs of shape (batch, 3); the output shape
# returned alongside them is not needed.
key = jax.random.PRNGKey(1234)
init_model_params = init_model(key, (-1, 3))[1]
opt_state = opt_init(init_model_params)

# Step counter consumed by opt_update in the training loop.
n_step = it.count(start=0)

In [99]:
def loss(params, batch):
    """Total squared reconstruction error of the autoencoder on ``batch``.

    Summed (not averaged) over every point and coordinate.
    """
    residual = model(params, batch) - batch
    return (residual ** 2).sum()

In [100]:
# JIT-compiled function returning (loss value, gradient w.r.t. params) together.
loss_grad = jax.jit(jax.value_and_grad(loss, argnums=0))

In [101]:
# 1000 full-batch Adam steps on the whole point cloud p; log the loss every 100 steps.
for i in range(1000):
    params = get_params(opt_state)
    value, grad = loss_grad(params, p)
    # n_step keeps counting across re-runs of this cell, so Adam's step index continues.
    opt_state = opt_update(next(n_step), grad, opt_state)
    if not i % 100:
        print(i, value)

0 2192.3467
100 290.24484
200 187.74948
300 154.93742
400 155.48294
500 155.28333
600 168.55212
700 149.25386
800 201.08876
900 139.98083


In [102]:
# Use the parameters after the final update. The loop's `params` was read
# before its last opt_update, so it is one Adam step behind.
final_params = get_params(opt_state)
# A stax layer is an (init_fn, apply_fn) pair; params[0] holds the encoder's parameters.
_, encoder_apply = encoder
# One scalar code per point, flattened to 1-D for the histogram.
encoded = jnp.ravel(encoder_apply(final_params[0], p))
x, y = equal_area_histogram(encoded)
fig2 = make_subplots()
fig2.add_scatter(x=x, y=y, mode='lines')
fig2.update_layout(width=800, height=600,
                   xaxis_title='encoded value', yaxis_title='density')

In [103]:
# Minimum histogram density for a bump to count as a cluster.
PEAK_MIN_HEIGHT = 0.1
idx, _ = find_peaks(y, height=PEAK_MIN_HEIGHT)
# y is a step outline: each bin appears as two equal samples, at its left and
# right edges. For such a plateau find_peaks reports the lower middle sample.
# Averaging it with the next sample gives the bin centre, instead of the left
# edge that x[idx] alone would give.
encoded_peaks = 0.5 * (x[idx] + x[idx + 1])
assert encoded_peaks.size > 0, "no histogram peaks above PEAK_MIN_HEIGHT"
# Assign every point to the nearest peak in the 1-D encoded space.
cluster = np.argmin(np.abs(encoded[:, None] - encoded_peaks[None, :]), axis=1)

In [104]:
colors = np.array(['blue', 'green', 'white', 'yellow'])
# Plain NumPy copy of the JAX point array for plotly.
points = np.asarray(p)
# Cycle the palette so that detecting more than len(colors) peaks cannot raise IndexError.
point_colors = colors[cluster % len(colors)]
fig3 = make_subplots()
fig3.add_scatter3d(x=points[:, 0], y=points[:, 1], z=points[:, 2],
                   text=np.arange(points.shape[0]), mode='text',
                   textfont_color=point_colors, textposition='middle center')
fig3.update_layout(width=800, height=800, showlegend=False)