In [4]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tf_util.systems import system_from_str
from train_dsn import train_dsn
import seaborn as sns
import pandas as pd

from util import fct_integrals as integrals
from util import tf_integrals as tf_integrals

from util import fct_mf as mf



The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [5]:
# --- Normalizing-flow architecture passed to train_dsn ----------------------
TIF_flow_type = 'AffineFlowLayer'
nlayers = 1
flow_dict = {
    'latent_dynamics': None,
    'TIF_flow_type': TIF_flow_type,
    'repeats': nlayers,
}

# --- DSN training hyperparameters (forwarded to train_dsn) ------------------
n = 1000            # presumably samples per training batch — confirm in train_dsn
k_max = 1
c_init = 1.0
check_rate = 1      # NOTE(review): appears to trigger a save/log every iteration; consider raising
max_iters = 10000
lr_order = -3
random_seed = 0

def compute_bistable_mu(Sini, ics_0, ics_1,
                        Mm=3.5, Mn=1., Mi=0.,
                        Sim=1., Sin=1., Sip=1.,
                        g=0.8, eps=0.2, tol=1e-10):
    """Compute the target mean vector for the bistable rank-1 RNN behavior.

    Runs the mean-field static solver ``mf.SolveStatic`` twice, once from
    each initial condition. From each returned trajectory ``ys`` it takes
    ``ys[-1, 2]``, i.e. column 2 of the last row, and stacks the two values.

    Parameters
    ----------
    Sini : float
        Std parameter inserted between ``Sin`` and ``Sip`` in the parameter
        vector. Presumably the std of the input component along n; confirm
        against ``mf``.
    ics_0, ics_1 : array-like
        Initial conditions handed to ``mf.SolveStatic`` for the two solves.
    Mm, Mn, Mi : float, optional
        Means of m, n and I. Defaults are 3.5, 1.0 and 0.0.
    Sim, Sin : float, optional
        Stds of m and n. Defaults are 1.0.
    Sip : float, optional
        Std of the input orthogonal to m and n, along h (see Methods).
        Default is 1.0.
    g : float, optional
        Passed to ``mf.SolveStatic`` as its second argument. Default is 0.8.
    eps : float, optional
        Passed to ``mf.SolveStatic`` as its fourth argument. Default is 0.2.
    tol : float, optional
        Passed to ``mf.SolveStatic`` as its fifth argument. Default is 1e-10.

    Returns
    -------
    numpy.ndarray
        ``[ss0, ss1]``: the final column-2 values of the solves started from
        ``ics_0`` and ``ics_1`` respectively. This assumes ``SolveStatic``
        returns a 2-D array as its first output; verify against ``mf``.
    """
    # The order of this vector is what mf.SolveStatic receives; keep it stable.
    ParVec = [Mm, Mn, Mi, Sim, Sin, Sini, Sip]

    # The second return value of SolveStatic (an iteration count in the
    # original code) is not used here.
    ys0, _ = mf.SolveStatic(ics_0, g, ParVec, eps, tol)
    ys1, _ = mf.SolveStatic(ics_1, g, ParVec, eps, tol)

    ss0 = ys0[-1, 2]
    ss1 = ys1[-1, 2]
    mu = np.array([ss0, ss1])
    return mu

In [11]:
# Build a rank-1 RNN system with the 'bistable' behavior and train a DSN on it.
system_str = 'rank1_rnn'
behavior_str = 'bistable'

K = 1
M = n       # reuses the batch size `n` defined in the config cell
D = 20
T = 15
Sini = 0.5

# Each initial condition yields one entry of `mu` via compute_bistable_mu below.
ics_0 = np.array([5., 5., 5.], np.float64)
ics_1 = np.array([-5., 5., -5.], np.float64)

# Replicate each initial-condition vector into a (K, M, len(ics)) array.
Ics_0 = np.tile(ics_0, (K, M, 1))
Ics_1 = np.tile(ics_1, (K, M, 1))

system_class = system_from_str(system_str)
system = system_class(D, T, Sini, Ics_0, Ics_1, behavior_str)

# Target behavior: mean from the mean-field solver, 'Sigma' entries all 0.01.
mu = compute_bistable_mu(Sini, ics_0, ics_1)
Sigma = np.full((2,), 0.01)
behavior = {'mu': mu, 'Sigma': Sigma}

cost, phi, T_x = train_dsn(
    system, behavior, n, flow_dict,
    k_max=k_max, c_init=c_init, lr_order=lr_order, check_rate=check_rate,
    max_iters=max_iters, random_seed=random_seed,
)



c_init 1.0
results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/
0 <tf_util.flows.AffineFlowLayer object at 0x134dca550>
<tensorflow.python.ops.init_ops.VarianceScaling object at 0x134dca908>
<tensorflow.python.ops.init_ops.VarianceScaling object at 0x134dcac50>
zshapes in
connect flow
(?, ?, 20, ?)
0 AffineFlow_Layer1
(?, ?, 20, ?)
(3, ?, 1000)
(3, ?, 1000)
(?, 1000, 2, 3)
training DSN for rank1_rnn: dt=0.001, T=15
AL iteration 1
resetting optimizer
aug lag it 0
saving model at iter 0
******************************************
it = 2 
H 18.010675714871578
cost -16.24392029821312
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 3 
H 18.478431014007306
cost -16.74109200339578
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 4 
H 18.6402302536462

******************************************
******************************************
it = 38 
H 23.722898398250422
cost -22.197655916913597
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 39 
H 23.94106399256207
cost -22.418497086613733
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 40 
H 24.137817632178642
cost -22.61640181492715
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 41 
H 24.00823337579169
cost -22.488117559655496
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 42 
H 24.321831152383517
cost -22.803046

******************************************
it = 76 
H 26.692069840868086
cost -25.21202321802195
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 77 
H 26.555885012306828
cost -25.076707672170045
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 78 
H 26.714992561831323
cost -25.235911568240397
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 79 
H 26.69546209050839
cost -25.216507695094993
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 80 
H 26.82605711910464
cost -25.348191675985557
saving to results//tb/rank1_rnn_D

******************************************
it = 114 
H 28.220832189534875
cost -26.777398792478827
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 115 
H 28.341184142244064
cost -26.89976977931504
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 116 
H 28.239082500092803
cost -26.795168978940282
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 117 
H 28.4342957851806
cost -26.994320436752588
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 118 
H 28.541889405048877
cost -27.101451341357688
saving to results//tb/rank1_

******************************************
it = 152 
H 28.791876313067462
cost -27.4436677142207
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 153 
H 28.78352887341094
cost -27.43600942332567
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 154 
H 28.9587611179335
cost -27.619781352992497
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 155 
H 28.797026738300044
cost -27.47744992684861
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 156 
H 28.881949829249724
cost -27.555997889841702
saving to results//tb/rank1_rnn_

******************************************
it = 190 
H 28.30266002667473
cost -27.156340438068003
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 191 
H 28.423906735873675
cost -27.292289644372982
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 192 
H 28.38510646460362
cost -27.230025577497667
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 193 
H 28.35786661441918
cost -27.23507682185928
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 194 
H 28.482244795058484
cost -27.326060337660692
saving to results//tb/rank1_r

******************************************
it = 228 
H 28.127576909679707
cost -27.00098620998157
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 229 
H 28.419506977102767
cost -27.29497059467319
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 230 
H 28.1196394030075
cost -26.990014297407885
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 231 
H 28.211542989092866
cost -27.07999358574762
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 232 
H 28.090634460603376
cost -26.956224540596857
saving to results//tb/rank1_rn

******************************************
it = 266 
H 28.520281543018754
cost -27.39075459379322
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 267 
H 28.466747216787983
cost -27.34231657573283
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 268 
H 28.428607336782793
cost -27.304795897600403
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 269 
H 28.549225716412753
cost -27.426736068951605
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 270 
H 28.438548322627593
cost -27.31449227137956
saving to results//tb/rank1_

******************************************
it = 304 
H 28.700240440256845
cost -27.57165139465474
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 305 
H 28.3701972969775
cost -27.247413724246115
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 306 
H 28.634198619632283
cost -27.516926695630424
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 307 
H 28.725739986965568
cost -27.609640533320484
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 308 
H 28.52058053451481
cost -27.39909567277496
saving to results//tb/rank1_rn

******************************************
it = 342 
H 28.63703896096011
cost -27.518833319603196
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 343 
H 28.856026203693688
cost -27.742687671040656
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 344 
H 28.72025923324721
cost -27.596981743226234
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 345 
H 28.78546929349571
cost -27.681145698833234
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 346 
H 28.739665482806814
cost -27.62315540287498
saving to results//tb/rank1_r

******************************************
it = 380 
H 28.80828232520465
cost -27.676097018676224
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 381 
H 28.91246701365392
cost -27.79930227956073
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 382 
H 28.74168765713769
cost -27.62111130211811
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 383 
H 28.67986167624004
cost -27.552109925985107
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 384 
H 28.84967014945731
cost -27.715358134533048
saving to results//tb/rank1_rnn_

******************************************
it = 418 
H 28.92577703895547
cost -27.80455604713487
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 419 
H 28.789382836930198
cost -27.670235384756133
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 420 
H 28.90687286004106
cost -27.788342405764933
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 421 
H 28.875619359085835
cost -27.763510821433144
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 422 
H 28.77096175224014
cost -27.638922045133686
saving to results//tb/rank1_r

******************************************
it = 456 
H 29.023791905389853
cost -27.91750554388494
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 457 
H 29.009159696169498
cost -27.89074015524842
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 458 
H 29.06088496501899
cost -27.933632385795452
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 459 
H 28.903777812643487
cost -27.773046103147696
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 460 
H 28.9258240108401
cost -27.804101610213426
saving to results//tb/rank1_rn

******************************************
it = 494 
H 28.85605569149928
cost -27.736067508998584
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 495 
H 29.082776683496707
cost -27.979213257016287
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 496 
H 29.123011064104716
cost -28.010712313508403
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 497 
H 29.1653135981086
cost -28.05877554711385
saving to results//tb/rank1_rnn_D=20_T=15_flow=1A_lr_order=-3_c=0_rs=0/  ...
******************************************
******************************************
it = 498 
H 29.13960360916892
cost -28.022116016348654
saving to results//tb/rank1_rn

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2862, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-11-bbfe1718a993>", line 24, in <module>
    cost, phi, T_x = train_dsn(system, behavior, n, flow_dict,                        k_max=k_max, c_init=c_init, lr_order=lr_order, check_rate=check_rate,                        max_iters=max_iters, random_seed=random_seed);
  File "/Users/sbittner/Documents/dsn/train_dsn.py", line 230, in train_dsn
    _H, _T_x, _phi, _log_q_x = sess.run([H, T_x, phi, log_q_x], feed_dict);
  File "/Users/sbittner/Library/Python/3.6/lib/python/site-packages/tensorflow/python/client/session.py", line 900, in run
    run_metadata_ptr)
  File "/Users/sbittner/Library/Python/3.6/lib/python/site-packages/tensorflow/python/client/session.py", line 1135, in _run
    feed_dict_tensor, options, run_metadata)
  File "/Us

KeyboardInterrupt: 