In [None]:
import numpy as np
import pandas as pd
from scipy.optimize import fmin_ncg
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import csv
import time
import argparse
from copyfile import ThreeBitRNN, genxy

# Number of hidden units in the RNN. Passed to ThreeBitRNN and used to size the
# hidden-state buffers in the cells below.
# NOTE(review): presumably has to match the checkpoint that is loaded later — confirm.
HIDDEN_SIZE = 100

In [None]:
# Build the "world tour" input sequence, drive the trained network with it and
# record the hidden state at every time step (averaged over repeated runs).
model = ThreeBitRNN(hidden_size=HIDDEN_SIZE)
model.load_state_dict(torch.load('../rnndata/sifoaij27.pkl'))

n = 101  # number of runs that are averaged together

# One row per non-zero input entry: (time step, input channel, pulse value).
pulses = np.array([
    (20, 0, 1),
    (40, 1, 1),
    (60, 0, -1),
    (80, 1, -1),
    (100, 2, 1),
    (120, 1, 1),
    (140, 0, 1),
    (160, 1, -1),
    (180, 2, -1),
    (200, 1, 1),
    (220, 2, 1),
    (240, 0, -1),
    (260, 1, -1),
    (280, 2, -1),
    (300, 2, -1),
])
traj = np.zeros((320, 3))
traj[pulses[:, 0], pulses[:, 1]] = pulses[:, 2]
# np.savetxt("rnndata/worldtour_inputs.csv", traj, delimiter=',')

hids = np.zeros((320, HIDDEN_SIZE))
for rep in range(n):  # condition averaging
    trajvar = Variable(torch.Tensor(traj), requires_grad=False)
    # every run starts from an all-zero hidden state
    model.set_hidden(Variable(torch.zeros(1, 1, HIDDEN_SIZE)))
    hids += model.all_hiddens(trajvar).data.numpy()
hids /= n

In [None]:
# Pull the recurrent weight matrix and both bias vectors out of the trained RNN
# as NumPy arrays; the optimisation helpers below read them as globals.
rnn_params = model.rnn.state_dict()
W = rnn_params['weight_hh_l0'].numpy()
b = rnn_params['bias_hh_l0'].numpy()
bi = rnn_params['bias_ih_l0'].numpy()

# Alternative: load the same arrays from disk instead of from the model.
# W = np.load('../rnndata/weight_hh_l0.npy')
# b = np.load('../rnndata/bias_hh_l0.npy')
# bi = np.load('../rnndata/bias_ih_l0.npy')

# number of hidden units (rows of the recurrent weight matrix)
N = len(W)

## optimization functions / helpers
def f(x):
    """Objective whose zeros are fixed points of the input-free RNN update.

    With ``r = tanh(W @ x + b + bi)`` and residual ``dx = r - x``, returns
    ``0.5 * ||dx||**2``.  The value is zero exactly when ``x`` maps onto itself.

    Reads the module-level arrays ``W``, ``b`` and ``bi``.

    Parameters
    ----------
    x : 1-D array with one entry per hidden unit.

    Returns
    -------
    Scalar, half the squared norm of the residual.

    Notes
    -----
    This cell used to define ``f`` twice with identical bodies; only one
    definition is kept.  An earlier formulation (``dx = -x + W @ tanh(x)``) is
    no longer used.
    """
    r = np.tanh(W @ x + b + bi)
    dx = r - x
    return 0.5 * (dx @ dx)

def grad_f(x):
    """Gradient of ``0.5 * ||tanh(W @ x + b + bi) - x||**2`` with respect to ``x``.

    With ``r = tanh(W @ x + b + bi)`` and residual ``dx = r - x`` the residual
    Jacobian is ``J = diag(1 - r**2) @ W - I``, so the gradient is
    ``J.T @ dx = W.T @ ((1 - r**2) * dx) - dx``.

    The previous implementation scaled the *rows* of ``W.T`` by ``1 - r**2``,
    which is the gradient of the older objective ``-x + W @ tanh(x)`` and does
    not match the objective currently in use.

    Reads the module-level arrays ``W``, ``b`` and ``bi``.

    Parameters
    ----------
    x : 1-D array with one entry per hidden unit.

    Returns
    -------
    1-D array of the same length as ``x``.
    """
    r = np.tanh(W @ x + b + bi)
    dx = r - x
    gain = 1 - r ** 2  # derivative of tanh at the pre-activation
    return W.T @ (gain * dx) - dx

def hess_f(x):
    """Gauss-Newton approximation ``J.T @ J`` of the Hessian of the objective.

    ``J = diag(1 - r**2) @ W - I`` is the Jacobian of the residual
    ``tanh(W @ x + b + bi) - x``, where ``r = tanh(W @ x + b + bi)``.  The term
    involving second derivatives of the residual is omitted, so the result is
    symmetric positive semi-definite.

    The previous implementation built the Jacobian of the older objective
    ``-x + W @ tanh(x)`` and also computed an unused residual.

    Reads the module-level arrays ``W``, ``b`` and ``bi``.

    Parameters
    ----------
    x : 1-D array with one entry per hidden unit.

    Returns
    -------
    Square 2-D array with one row and column per hidden unit.
    """
    r = np.tanh(W @ x + b + bi)
    gain = 1 - r ** 2  # derivative of tanh at the pre-activation
    J = gain[:, None] * W - np.identity(W.shape[0])
    return J.T @ J

In [None]:
## projection function with SVD
def PCA_project(X, modes, plot_SVs=True):
    """Project the rows of ``X`` onto its leading right-singular vectors.

    Computes the SVD ``X = U @ S @ VT`` and returns the coordinates of every
    row of ``X`` along the first ``modes`` right-singular vectors, i.e.
    ``X @ VT[:modes].T == U[:, :modes] * s[:modes]``.

    The previous implementation returned ``U @ S @ VT[:, :modes]``, which is the
    first ``modes`` raw columns of the rank-``modes`` reconstruction of ``X``
    rather than a projection onto the principal axes.

    ``X`` is used as given: no mean-centring is applied before the SVD.

    Side effect: the factors are stored in the notebook-level dict ``d`` under
    the keys ``'U'``, ``'S'`` (truncated, rectangular) and ``'VT'``.  The dict
    is created if it does not exist yet.

    Parameters
    ----------
    X : 2-D array, one observation per row.
    modes : number of leading components to keep.
    plot_SVs : when true, plot the log of the singular values.

    Returns
    -------
    Array with one row per row of ``X`` and ``modes`` columns.

    Raises
    ------
    IndexError
        If ``modes`` is negative or exceeds the number of singular values.
    """
    U, s, VT = np.linalg.svd(X)

    if not 0 <= modes <= s.shape[0]:
        raise IndexError(
            'modes=%d is outside the valid range 0..%d' % (modes, s.shape[0])
        )

    # Rectangular singular-value matrix with only the leading `modes` entries kept.
    S = np.zeros([U.shape[0], VT.shape[0]])
    kept = np.arange(modes)
    S[kept, kept] = s[kept]

    if plot_SVs:
        plt.plot(np.log(s))
        plt.show()

    # Look the dict up at call time so a later `d = {}` in the notebook is honoured,
    # and create it on demand so the function does not depend on cell order.
    cache = globals().setdefault('d', {})
    cache['U'] = U
    cache['S'] = S
    cache['VT'] = VT

    # Principal-component scores: column k is s[k] * U[:, k].
    return U[:, :modes] * s[:modes]

In [None]:
# Optional: load hidden states from disk instead of using the ones computed above.
# hids = pd.read_csv('../rnndata/bounce.csv', header=None).values

# `d` receives the SVD factors that PCA_project stores as a side effect.
d = {}
trajectory_p = PCA_project(hids, 3)  # hids

# Plot every projected component against the time-step index.
plt.figure(1)
time_steps = np.arange(trajectory_p.shape[0])
for component in trajectory_p.T:
    plt.plot(time_steps, component)

plt.show()

In [None]:
# Shape of the right-singular-vector matrix stored by the most recent PCA_project call.
d['VT'].shape

In [None]:
# Compare the shape of the projected trajectory with that of the raw hidden states.
trajectory_p.shape, hids.shape

In [None]:
## use trajectories to find nearby fixed pts
# Every recorded hidden state seeds one Newton-CG minimisation of f; the
# minimisers are stacked row-wise, in the same order as the rows of `hids`.
# xs = traj_c
xopts = np.array([
    fmin_ncg(f, start, grad_f, fhess=hess_f, avextol=1e-10)
    for start in hids
])

In [None]:
# Sanity check: shape of the hidden states stacked on top of the fixed-point candidates.
np.vstack([hids,xopts]).shape

In [None]:
## concatenate trajectories and fixed points, project into 3D, and separate
# Both point sets go through a single projection so they share the same axes;
# the first `divide` rows of the result belong to the trajectory.
divide = hids.shape[0]
stacked_states = np.vstack([hids, xopts])
proj_concat = PCA_project(stacked_states, 3)
trajectory_p, xopts_p = proj_concat[:divide], proj_concat[divide:]
print('shapes:', trajectory_p.shape, xopts_p.shape)

## alternately, SVD the trajectory alone and project the fixed points with the
## resulting factors. ## not sure this is correct.
# U, s, VT = np.linalg.svd(traj_c)
# xopts_projected2 = (U.T[:3,:]@xopts.T).T/s[:3]

In [None]:
import plotly
import plotly.plotly as py
import plotly.graph_objs as go
import pandas as pd
import numpy as np
import random
# Plotly cloud credentials are read from the environment so that no API key is
# stored in the notebook.  Set PLOTLY_API_KEY (and optionally PLOTLY_USERNAME)
# before starting the kernel; when the key is absent the credentials file is
# left untouched.
import os

_plotly_api_key = os.environ.get('PLOTLY_API_KEY')
if _plotly_api_key:
    plotly.tools.set_credentials_file(
        username=os.environ.get('PLOTLY_USERNAME', 'elbertgong'),
        api_key=_plotly_api_key,
    )

In [None]:


def r():
    """Random integer in [0, 255], used as one RGB channel."""
    return random.randint(0, 255)


def _random_hex_color():
    """Random colour formatted as '#RRGGBB' (three draws from `r`)."""
    return '#%02X%02X%02X' % (r(), r(), r())


def _axis_style():
    """Fresh dict with the styling shared by the x, y and z axes."""
    return {
        'gridcolor': 'rgb(255, 255, 255)',
        'zerolinecolor': 'rgb(255, 255, 255)',
        'showbackground': True,
        'backgroundcolor': 'rgb(230, 230,230)',
    }


# This first colour is overwritten before it is used (the fixed points are
# coloured by their z value); the draw is kept so the random stream is unchanged.
color = _random_hex_color()

# fixed points projected by U of trajectories
fp_x, fp_y, fp_z = (pd.Series(xopts_p[:, axis]) for axis in range(3))

fps = go.Scatter3d(
    x=fp_x, y=fp_y, z=fp_z,
    mode='markers',
    marker={
        'size': 12,
        'color': fp_z,             # colour each marker by its z coordinate
        'colorscale': 'Viridis',
        'opacity': 0.3,
    },
)

# colour actually used for the trajectory trace
color = _random_hex_color()

tr_x, tr_y, tr_z = (pd.Series(trajectory_p[:, axis]) for axis in range(3))

trace = go.Scatter3d(
    x=tr_x, y=tr_y, z=tr_z,
    marker={'size': 4, 'color': color, 'colorscale': 'Viridis'},
    line={'color': color, 'width': 1},
)

data = [fps, trace]
# data = [trace]

scene = {
    'xaxis': _axis_style(),
    'yaxis': _axis_style(),
    'zaxis': _axis_style(),
    'camera': {
        'up': {'x': 0, 'y': 0, 'z': 1},
        'eye': {'x': -1.7428, 'y': 1.0707, 'z': 0.7100},
    },
    'aspectratio': {'x': 1, 'y': 1, 'z': 0.7},
    'aspectmode': 'manual',
}

layout = {
    'width': 800,
    'height': 700,
    'autosize': True,
    'title': '3 bit flip flop',
    'scene': scene,
}

fig = {'data': data, 'layout': layout}

# plotly.offline.iplot(fig, filename='3bit_fps_plz')#, height=700, validate=False)
# plot(fig)

In [None]:
# init_notebook_mode(connected=True)
# NOTE(review): the wildcard import pulls every plotly.graph_objs name into the
# notebook namespace; nothing in this cell appears to need it — confirm before removing.
from plotly.graph_objs import *
# Render the figure dict built in the previous cell through plotly's offline mode.
# Presumably this writes a local HTML file named after `filename` — TODO confirm.
plotly.offline.plot(fig, filename='3bit_fps_plz')#, height=700, validate=False)
# plot(fig)

In [None]:
from mpl_toolkits.mplot3d import Axes3D
# mpl.rcParams['legend.fontsize'] = 10

# Static 3-D line plot of the projected trajectory.
fig = plt.figure()
# add_subplot(..., projection='3d') replaces fig.gca(projection='3d'): passing
# keyword arguments to gca() is no longer supported by current Matplotlib.
ax = fig.add_subplot(111, projection='3d')
# The label gives ax.legend() an entry to show; without one the legend is empty.
ax.plot(trajectory_p[:,0], trajectory_p[:,1], zs=trajectory_p[:,2], label='trajectory')
# ax.set_xlim(-.2,.2)
# ax.set_ylim(-.6,0)
# ax.set_zlim(.1,.5)
ax.legend()
plt.show()

In [None]:
from scipy.cluster.hierarchy import dendrogram, linkage, fcluster

# number of flat clusters to extract
k = 7

# Ward linkage over the fixed-point candidates: repeatedly merges the closest
# groups and records the merge history.
Z = linkage(xopts, method='ward')

# draw the merge history as a dendrogram
dendrogram(Z)
plt.show()

# flat cluster label for every original point, limited to at most k clusters
idx = fcluster(Z, t=k, criterion='maxclust')

In [None]:
# Repeat the Ward clustering on the 3-D projection of the fixed points.
# `xopts_projected` was never defined anywhere in the notebook; the projected
# fixed points are stored in `xopts_p`.
Z = linkage(xopts_p, method='ward')

# draws the resulting dendrogram from the clustering done by linkage
dendrogram(Z)
plt.show()