In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tf_util.systems import system_from_str
from train_dsn import train_dsn
import seaborn as sns
import pandas as pd

from util import fct_integrals as integrals
from util import tf_integrals as tf_integrals

from util import fct_mf as mf



In [5]:
# --- Flow architecture ------------------------------------------------------
# The flow is described by a plain dict that is handed to train_dsn.
TIF_flow_type = 'PlanarFlowLayer'
nlayers = 10
flow_dict = dict(
    latent_dynamics=None,
    TIF_flow_type=TIF_flow_type,
    repeats=nlayers,          # the chosen layer type is repeated `nlayers` times
    scale_layer=False,
)

# --- Training configuration -------------------------------------------------
# These values are forwarded to train_dsn; their precise meaning is defined
# there (e.g. lr_order presumably sets the learning rate as a power of ten --
# TODO confirm against train_dsn).
n = 1000
k_max = 1
c_init = 1e4
lr_order = -3
max_iters = 10000
check_rate = 100
random_seed = 0

def compute_bistable_mu(Sini, ics_0, ics_1,
                        Mm=3.5, Mn=1., Mi=0.,
                        Sim=1., Sin=1., Sip=1.,
                        g=0.8, eps=0.2, tol=1e-10):
    """Build the two-element behavior mean from two static mean-field solves.

    ``mf.SolveStatic`` is run once from each initial condition with the same
    parameter vector ``[Mm, Mn, Mi, Sim, Sin, Sini, Sip]``. For each solve the
    entry at the last row, column index 2 of the returned array is taken, and
    the two values are stacked into a length-2 array.

    Parameters
    ----------
    Sini : float
        Sixth entry of the parameter vector (presumably the std of the input
        I, by analogy with ``Sim``/``Sin`` -- TODO confirm against ``mf``).
    ics_0, ics_1 : array-like
        Initial conditions handed to ``mf.SolveStatic`` for the first and
        second solve respectively.
    Mm, Mn, Mi : float, optional
        Mean of m, mean of n, and mean of I.
    Sim, Sin : float, optional
        Std of m and std of n.
    Sip : float, optional
        Std of input orthogonal to m and n, along h (see Methods).
    g : float, optional
        Passed as the second argument of ``mf.SolveStatic`` (presumably the
        random-connectivity strength -- TODO confirm).
    eps, tol : float, optional
        Passed as the fourth and fifth arguments of ``mf.SolveStatic``
        (presumably an update step and a convergence tolerance -- TODO
        confirm).

    Returns
    -------
    numpy.ndarray
        ``np.array([ys0[-1, 2], ys1[-1, 2]])`` where ``ys0``/``ys1`` are the
        first return values of the two solves.

    Notes
    -----
    The keyword defaults reproduce the values that were previously fixed
    inside the function body, so existing three-argument calls are unchanged.
    """
    ParVec = [Mm, Mn, Mi, Sim, Sin, Sini, Sip]

    # SolveStatic returns a pair; only the first element (the solution array)
    # is needed here, so the second one is discarded.
    ys0, _ = mf.SolveStatic(ics_0, g, ParVec, eps, tol)
    ys1, _ = mf.SolveStatic(ics_1, g, ParVec, eps, tol)

    # Last row, column 2 of each solution is the quantity used as the target.
    ss0 = ys0[-1, 2]
    ss1 = ys1[-1, 2]
    return np.array([ss0, ss1])

In [None]:
# --- System / behavior identifiers -------------------------------------------
system_str = 'rank1_rnn'
behavior_str = 'bistable'

# --- Problem sizes and input std ---------------------------------------------
K = 1
M = n          # one tiled copy of the initial condition per sample
D = 10
T = 10
Sini = 0.5

# --- Initial conditions -------------------------------------------------------
# Two 3-vectors; they differ in the sign of the first and third entries.
ics_0 = np.array([5., 5., 5.], dtype=np.float64)
ics_1 = np.array([-5., 5., -5.], dtype=np.float64)

# Add two leading singleton axes and tile so each array has shape (K, M, 3).
Ics_0 = np.tile(ics_0[np.newaxis, np.newaxis, :], (K, M, 1))
Ics_1 = np.tile(ics_1[np.newaxis, np.newaxis, :], (K, M, 1))

# --- Build the system ---------------------------------------------------------
system_class = system_from_str(system_str)
system = system_class(D, T, Sini, Ics_0, Ics_1, behavior_str)

# --- Target behavior ----------------------------------------------------------
# The mean comes from the static mean-field solves started at the two
# (untiled) initial conditions; both entries of Sigma are set to 0.01.
mu = compute_bistable_mu(Sini, ics_0, ics_1)
Sigma = 0.01 * np.ones((2,))
behavior = {'mu': mu, 'Sigma': Sigma}

# --- Train --------------------------------------------------------------------
cost, phi, T_x = train_dsn(
    system, behavior, n, flow_dict,
    k_max=k_max,
    c_init=c_init,
    lr_order=lr_order,
    check_rate=check_rate,
    max_iters=max_iters,
    random_seed=random_seed,
)

