# 0. Initial stuff

In [87]:
%reset -f
%matplotlib inline
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [88]:
import sys
sys.path.append('../')

In [89]:
import numpy as np
import scipy.io as sio
import sklearn as sk

from sklearn import decomposition
from sklearn import metrics
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import tensorflow as tf

# For jupyter tensorboarding
from io import BytesIO
from functools import partial
import PIL.Image
from ipywidgets import FloatProgress

from IPython.display import clear_output, Image, display, HTML

import pickle
import os
import pprint as pp

from dtw import dtw
from cdtw import pydtw

from tqdm import tqdm # progress bar

In [90]:
import cyclingrnn.geometric as geo
from cyclingrnn import sigerr
from cyclingrnn.train import *

In [91]:
# sns colormaps
cmap_rdbu = sns.color_palette('RdBu',5)[:2] + sns.color_palette('RdBu',5)[-2:]
#sns.set_palette(cmap)
sns.set_context('notebook', font_scale=1.0)

In [92]:
set1 = sns.color_palette("Set1")

In [93]:
def logify(df_, logs):
  """Replace each column named in `logs` with its natural logarithm.

  The columns are reassigned on `df_` itself (the caller's frame is
  modified) and that same object is returned.
  """
  for column_name in logs:
    logged = np.log(df_[column_name])
    df_[column_name] = logged
  return df_

In [140]:
def giant_regplot(df, cols=None, rows=None, logs=None):
  """Draw a grid of subplots, one per (row variable, column variable) pair.

  For each pair:
    * exactly one of the two variables numeric -> seaborn violin plot;
    * both numeric -> scatter plot plus a seaborn regression line, and, if the
      notebook-level dicts `m1_metrics` / `emg_metrics` exist and contain the
      row name, a horizontal reference line for each;
    * neither numeric -> the panel is left empty.

  Parameters
  ----------
  df : pandas.DataFrame
    Source data. It is not modified; column dropping below rebinds a local.
  cols : sequence of column names, optional
    Variables plotted along x. If omitted/empty, every column of `df` is
    used except those with a single unique value or with entries that cannot
    be passed through `unique()` (e.g. lists).
  rows : sequence of column names, optional
    Variables plotted along y. If omitted/empty, all (remaining) columns.
  logs : sequence of column names, optional
    Variables whose axis is switched to a base-10 log scale.

  Returns
  -------
  (f, ax) : the figure and a 2-D array of axes of shape (len(rows), len(cols)).
  """
  import numbers  # local import: not part of the notebook's import cell

  if logs is None:
    logs = []

  if cols is None or len(cols) == 0:
    for col in df.columns:
      try:
        is_constant = len(df[col].unique()) == 1
      except TypeError:
        # unique() fails on unhashable entries (e.g. lists); such columns
        # cannot be plotted, so treat them like constant columns.
        is_constant = True
      if is_constant:
        df = df.drop(col, axis=1)
    cols = df.columns  # now we have our column set
  if rows is None or len(rows) == 0:
    rows = df.columns

  def _has_numeric(name):
    # numbers.Number covers int/float/complex (and Python 2's long) without
    # naming `long`, which does not exist on Python 3.
    return any(isinstance(value, numbers.Number) for value in df[name])

  isnumeric_col = [_has_numeric(c) for c in cols]
  isnumeric_row = [_has_numeric(r) for r in rows]

  fact = 2.5  # inches per panel
  num_cols, num_rows = len(cols), len(rows)
  # squeeze=False keeps `ax` 2-D even for a single row or column, so the
  # ax[ir, ic] indexing below (and in the helper functions) always works.
  f, ax = plt.subplots(num_rows, num_cols, figsize=(fact*num_cols, fact*num_rows),
                       sharey=False, sharex=False, squeeze=False)

  # Optional reference metrics defined at notebook level; looked up by name so
  # that the function still works when they have not been computed.
  reference_lines = (('m1_metrics', set1[2]), ('emg_metrics', set1[1]))

  for ir, row in enumerate(rows):
    for ic, col in enumerate(cols):
      panel = ax[ir, ic]

      logx_bool = col in logs
      if logx_bool:
        panel.set_xscale('log')  # default base is 10
      if row in logs:
        # The original passed basex= here, which is the x-axis keyword.
        panel.set_yscale('log')

      if isnumeric_row[ir] != isnumeric_col[ic]:
        sns.violinplot(x=col, y=row, data=df, ax=panel)
      elif isnumeric_row[ir] and isnumeric_col[ic]:
        # color= (not c=) so a single RGB tuple is never mistaken for data values.
        panel.scatter(df[col], df[row], color=set1[0])
        sns.regplot(x=col, y=row, data=df, logx=logx_bool, ax=panel,
                    truncate=False, scatter=False, color="0.3")

        for metrics_name, line_color in reference_lines:
          metrics = globals().get(metrics_name)
          if metrics is not None and row in metrics:
            panel.axhline(metrics[row], color=line_color)

      # Only the bottom row keeps x labels/ticks; only the first column keeps y.
      if ir < num_rows-1:
        panel.xaxis.label.set_visible(False)
        panel.set_xticklabels([])
      if ic > 0:
        panel.yaxis.label.set_visible(False)
        panel.set_yticklabels([])

  return f, ax


In [141]:
def evaluate_run(conds, monkey, m1_or_emg):
  """Compute the geometric metrics of the recorded data for one monkey.

  Parameters
  ----------
  conds : index (e.g. list of ints)
    Selection applied along the second axis of the loaded 3-D array
    (presumably the condition axis - confirm against the .mat layout).
  monkey : {'D', 'C'}
    Which processed data file to load.
  m1_or_emg : {'m1', 'emg'}
    Which array of the file to evaluate ('M1' or 'EMG').

  Returns
  -------
  dict with keys 'sim_num', 'tangling_90_01', 'tangling_90_001',
  'tangling_95_01', 'tangling_95_001', 'path_length', 'mean_curvature',
  'mean_torsion', 'MSE' and 'R2'. 'sim_num', 'MSE' and 'R2' are fixed
  placeholders (0, 0. and 1.).

  Raises
  ------
  ValueError
    If `monkey` or `m1_or_emg` is not one of the accepted values
    (previously this surfaced later as an UnboundLocalError).
  """
  data_files = {'D': 'drakeFeb_processed.mat', 'C': 'cousFeb_processed.mat'}
  array_keys = {'m1': 'M1', 'emg': 'EMG'}

  if monkey not in data_files:
    raise ValueError("monkey must be 'D' or 'C', got %r" % (monkey,))
  if m1_or_emg not in array_keys:
    raise ValueError("m1_or_emg must be 'm1' or 'emg', got %r" % (m1_or_emg,))

  # TODO: fix, './' or '../' depending on whether running from wrapper or not
  file_name = data_files[monkey]
  try:
    data = sio.loadmat('./' + file_name)
  except IOError:
    # Not in the working directory: fall back to the parent directory.
    data = sio.loadmat('../' + file_name)

  array = data[array_keys[m1_or_emg]][:, conds, :]

  mets = dict()

  mets['sim_num'] = 0

  mets['tangling_90_01']  = geo.tangling_cdf( array, cutoff=0.90, alpha=0.1  )
  mets['tangling_90_001'] = geo.tangling_cdf( array, cutoff=0.90, alpha=0.01 )
  mets['tangling_95_01']  = geo.tangling_cdf( array, cutoff=0.95, alpha=0.1  )
  mets['tangling_95_001'] = geo.tangling_cdf( array, cutoff=0.95, alpha=0.01 )

  mets['path_length'] = np.sum(geo.get_path_length(array, filt_freq=0.25))
  mets['mean_curvature'], mets['mean_torsion'], _ = geo.mean_curvature(array, total_points=11, deg=4, normalize=True)

  # Placeholders so the dict has the same keys as the per-simulation results.
  mets['MSE'] = 0.
  mets['R2'] = 1.

  return mets


# Visualize results

In [142]:
sns.set_context('paper', font_scale=1.2)

In [143]:
df_path = '../saves/170305C/df0123.pickle'
# Open in binary mode and close the handle deterministically.
# NOTE: pickle.load can execute arbitrary code; only load files you trust.
with open(df_path, 'rb') as pickle_file:
  df = pickle.load(pickle_file)

conds = [0,1,2,3]
monkey = 'C'

# Reference metrics computed from the recorded M1 and EMG arrays; these are
# the dicts giant_regplot looks up to draw its horizontal reference lines.
m1_metrics = evaluate_run(conds, monkey, 'm1')
emg_metrics = evaluate_run(conds, monkey, 'emg')

In [144]:
# Keep only tanh runs whose tangling_95_01 is below 3000.
keep_mask = (df.activation == 'tanh') & (df.tangling_95_01 < 3000)
df = df[keep_mask]
# Optional extra filters, currently disabled:
#df = df[df.beta2 < 1e-3]
#df = df[df.mean_curvature < 400]

In [145]:
df.columns

Index([u'activation', u'beta0', u'beta1', u'beta2', u'learning_rate',
       u'monkey', u'num_neurons', u'rnn_init', u'stddev_out', u'stddev_state',
       u'MSE', u'R2', u'noise_robustness', u'noise_robustness_r2',
       u'path_length', u'percent_tangling1_001', u'percent_tangling1_01',
       u'percent_tangling2_001', u'percent_tangling2_01',
       u'percent_tangling3_001', u'percent_tangling3_01', u'sim_num',
       u'struct_robustness', u'struct_robustness_r2', u'tangling_90_001',
       u'tangling_90_01', u'tangling_95_001', u'tangling_95_01',
       u'mean_curvature', u'mean_torsion'],
      dtype='object')

In [146]:
# Variables plotted along x (one grid column each).
cols = ['beta0', 'beta1', 'beta2', 'stddev_state']
# Variables plotted along y (one grid row each).
rows = ['tangling_90_01', 'struct_robustness', 'noise_robustness', 'mean_curvature', 'mean_torsion', 'path_length']
# Columns that get log10-transformed before plotting.
logs = ['beta0', 'beta1', 'beta2', 'stddev_state', 'learning_rate', 'struct_robustness', 'noise_robustness', 'noise_robustness_r2', 'struct_robustness_r2']

In [147]:
import matplotlib as mpl
mplcmap = mpl.colors.ListedColormap(sns.color_palette().as_hex()[:2])

In [148]:
# Build a new frame instead of assigning column-by-column into `df`, which at
# this point is a filtered slice (chained assignment / SettingWithCopyWarning).
# NOTE: this cell is not idempotent - re-running it applies log10 a second time.
df = df.assign(**{log: np.log10(df[log]) for log in logs})

In [149]:
def ax_equal(ax):
  """Give every panel in a grid row the same y-limits.

  `ax` is a 2-D array of axes. For each row the shared limits are the
  largest lower bound and the smallest upper bound found among that row's
  panels, i.e. the range common to all of them.
  """
  n_rows, n_cols = ax.shape[0], ax.shape[1]

  x_lower, x_upper = np.zeros(ax.shape), np.zeros(ax.shape)
  y_lower, y_upper = np.zeros(ax.shape), np.zeros(ax.shape)

  # Collect the current limits of every panel.
  for r, c in np.ndindex(n_rows, n_cols):
    x_lower[r, c], x_upper[r, c] = ax[r, c].get_xlim()
    y_lower[r, c], y_upper[r, c] = ax[r, c].get_ylim()

  # Reduce across columns: one (bottom, top) pair per row.
  row_top = y_upper.min(axis=1)
  row_bottom = y_lower.max(axis=1)

  for r, c in np.ndindex(n_rows, n_cols):
    ax[r, c].set_ylim([row_bottom[r], row_top[r]])

In [150]:
f, ax = giant_regplot(df, cols, rows, [])

<matplotlib.figure.Figure at 0x12302ee10>

------------

In [151]:
# Display labels for the plotted variables, keyed by DataFrame column name.
# Raw strings keep the mathtext backslashes literal; in a normal string
# '\l', '\s' and '\;' are invalid escape sequences.
new_titles = {
  'tangling_90_01': 'Tangling (0.9 cdf)',
  'struct_robustness': r'Structural Robustness ($\log_{10}$)',
  'noise_robustness': r'Noise Robustness ($\log_{10}$)',
  'mean_curvature': 'Mean Curvature',
  'mean_torsion': 'Mean Torsion',
  'path_length': 'Path Length',

  'beta0': r'$\lambda_x\;(\log_{10})$',
  'beta1': r'$\lambda_A\;(\log_{10})$',
  'beta2': r'$\lambda_C\;(\log_{10})$',
  'stddev_state': r'$\sigma_w\;(\log_{10})$',
}

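See http://stackoverflow.com/questions/15538099/conversion-of-unicode-minus-sign-from-matplotlib-ticklabels for how to handle the Unicode minus sign that appears in matplotlib tick labels (relevant to `fix_logs` below).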

In [152]:
def _power_of_ten_label(text):
  """Turn a tick-label string into a mathtext power of ten; blank stays blank."""
  if not text:
    # An empty label would become '$10^$', which is not valid mathtext.
    return text
  # Tick labels may carry a Unicode minus sign; normalise it to ASCII '-'.
  exponent = text.replace(u'\u2212', '-')
  # Braces make the whole exponent the superscript, not just its first character.
  return '$10^{' + exponent + '}$'


def fix_logs(ax, logs):
  """Relabel ticks as powers of ten on axes that show log10-transformed data.

  `ax` is a 2-D array of axes. For every panel whose y (or x) axis label is
  one of the names in `logs`, each tick label `t` on that axis is replaced by
  the mathtext string `$10^{t}$`. Labels with empty text are left empty.

  The match is on the axis label text, so this has to run before
  `fix_titles` replaces the raw column names with display titles.
  """
  for r in range(ax.shape[0]):
    for c in range(ax.shape[1]):
      panel = ax[r, c]
      if panel.get_ylabel() in logs:
        panel.set_yticklabels(
          [_power_of_ten_label(item.get_text()) for item in panel.get_yticklabels()])
      if panel.get_xlabel() in logs:
        panel.set_xticklabels(
          [_power_of_ten_label(item.get_text()) for item in panel.get_xticklabels()])

In [153]:
def fix_titles(ax, new_titles):
  """Swap raw axis labels for their display titles.

  `ax` is a 2-D array of axes and `new_titles` maps an existing label text
  to its replacement. Any x or y label equal to a key of `new_titles` is
  set to the corresponding value; other labels are left alone.
  """
  n_rows, n_cols = ax.shape[0], ax.shape[1]
  panels = (ax[i, j] for i in range(n_rows) for j in range(n_cols))
  for panel in panels:
    for old_label, display_label in new_titles.items():
      if panel.get_ylabel() == old_label:
        panel.set_ylabel(display_label)
      if panel.get_xlabel() == old_label:
        panel.set_xlabel(display_label)

In [154]:
# fix y axes
ax_equal(ax)
#fix_logs(ax, logs)
fix_titles(ax, new_titles)

In [155]:
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('pdf')

In [156]:
f

<matplotlib.figure.Figure at 0x12302ee10>

In [173]:
cols = ['struct_robustness','noise_robustness']
rows = ['mean_curvature', u'tangling_90_01']

#logs = ['struct_robustness','noise_robustness', 'beta0', 'beta1', 'beta2']

In [174]:
f, ax = giant_regplot(df, cols, rows,[])
#f.suptitle('monkey D, conds 1-4')

<matplotlib.figure.Figure at 0x1388324d0>

In [175]:
# fix y axes
ax_equal(ax)
#fix_logs(ax, logs)
fix_titles(ax, new_titles)

In [176]:
f

<matplotlib.figure.Figure at 0x1388324d0>

In [None]:
# pick good RNNs
i1 = df.activation == 'tanh'
i2 = df.tangling_95_01 < m1_metrics['tangling_95_01'] + 200# and df.tangling_95_01 > (m1_metrics['tangling_95_01'] - 300)
i3 = df.tangling_95_01 > m1_metrics['tangling_95_01'] - 200

i4 = df.mean_curvature < m1_metrics['mean_curvature'] + 50# and df.tangling_95_01 > (m1_metrics['tangling_95_01'] - 300)
i5 = df.mean_curvature > m1_metrics['mean_curvature'] - 50

In [None]:
df_ = df[i1 & i2 & i3 & i4 & i5]

In [None]:
df_.sim_num