In [7]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import scipy
import matplotlib.pyplot as plt
import numpy as np
import time
import tensorflow as tf
import dsn.lib.LowRank.Fig1_Spontaneous.fct_mf as mf
from dsn.util.systems import LowRankRNN
from dsn.util.tf_DMFT_solvers import rank2_CDD_static_solve_np
from dsn.util.system import get_warm_start_dir

In [33]:
def warm_start(system):
    """Solve and cache a grid of rank-2 "CDD" static solutions for warm starting.

    Builds a parameter grid over ``system.all_params``:
    - Every free parameter sweeps
      ``np.arange(min(system.a), max(system.b) + step, step)``.
    - Every fixed parameter takes its single value from ``system.fixed_params``.

    The function then calls ``rank2_CDD_static_solve_np`` once on all grid
    points. It saves the grid and the stacked ``(kappa1, kappa2, delta_0)``
    solutions with ``np.savez`` to the path returned by
    ``get_warm_start_dir(system)``.

    Args:
        system: a ``LowRankRNN`` system object. Reads ``name``,
            ``model_opts['rank']``, ``behavior['type']``, ``a``, ``b``,
            ``warm_start_grid_step``, ``all_params``, ``free_params``,
            ``fixed_params`` and ``solve_eps``.

    Returns:
        np.ndarray: the ``kappa1`` solution at every grid point, shape ``(m,)``.

    Raises:
        AssertionError: if ``system.name`` is not ``'LowRankRNN'``.
        NotImplementedError: if the system is not rank 2 with behavior type
            "CDD". Only that solver is wired up here.
    """
    assert system.name == 'LowRankRNN'
    warm_start_solve_its = 10
    rank = system.model_opts['rank']
    behavior_type = system.behavior['type']
    # Fail loudly for unsupported configurations instead of falling through
    # to an undefined-variable error further down.
    if not (rank == 2 and behavior_type == "CDD"):
        raise NotImplementedError(
            'Warm start is only implemented for rank 2 "CDD" systems '
            '(got rank=%r, behavior type=%r)' % (rank, behavior_type)
        )

    ws_filename = get_warm_start_dir(system)

    # NOTE(review): every free parameter shares one axis spanning the global
    # min of `a` to the global max of `b`, not its own bounds.
    # Also, because `arange` uses a float step, the last point can land past
    # max_bound when the range is not a multiple of `step`. Confirm intended.
    min_bound = np.min(system.a)
    max_bound = np.max(system.b)
    step = system.warm_start_grid_step
    grid_vals = np.arange(min_bound, max_bound + step, step)

    # One axis per parameter, in all_params order: the full sweep for free
    # parameters, a single point for fixed ones.
    grid_vals_list = [
        grid_vals.copy()
        if param in system.free_params
        else np.array([system.fixed_params[param]])
        for param in system.all_params
    ]

    # Total number of grid points is the product of the axis lengths.
    # This is robust even if system.D differs from the number of free parameters.
    m = int(np.prod([vals.shape[0] for vals in grid_vals_list]))
    print('Rank2 CDD static warm start, %d total queries' % m)

    # Flatten the meshgrid: row i holds parameter i's value at each of m points.
    grid = np.array(np.meshgrid(*grid_vals_list))
    grid = np.reshape(grid, (len(system.all_params), m))

    cA = np.zeros((m,))
    cB = np.zeros((m,))

    # Initial values passed to the solver at every grid point.
    kappa1_init = -5.0 * np.ones((m,))
    kappa2_init = -5.0 * np.ones((m,))
    delta0_init = 5.0 * np.ones((m,))
    kappa1, kappa2, delta_0, z = rank2_CDD_static_solve_np(kappa1_init,
                                                           kappa2_init,
                                                           delta0_init,
                                                           cA,
                                                           cB,
                                                           grid[0],
                                                           grid[1],
                                                           grid[2],
                                                           grid[3],
                                                           grid[4],
                                                           grid[5],
                                                           grid[6],
                                                           warm_start_solve_its,
                                                           system.solve_eps,
                                                           db=False)

    # Shape (m, 3): one (kappa1, kappa2, delta_0) row per grid point.
    solution_grid = np.stack((kappa1, kappa2, delta_0), axis=1)

    np.savez(ws_filename, param_grid=grid, solution_grid=solution_grid)
    return kappa1

In [34]:
# Build a rank-2 LowRankRNN system with the "CDD" behavior.

# Parameters held at fixed values (presumably the rest are free — confirm
# against LowRankRNN).
fixed_params = {
    'g': 0.8,
    'gammaLO': -0.14,
    'gammaHI': 0.08,
}

# Behavior spec passed to LowRankRNN: its type plus 'means' and 'variances'.
behavior_type = "CDD"
means = np.asarray([0.3])
variances = np.asarray([0.0001])
behavior = dict(type=behavior_type, means=means, variances=variances)

# Model options: rank 2 with the "input" input type.
model_opts = dict(rank=2, input_type="input")

# Solver settings forwarded to the LowRankRNN constructor.
solve_its = 500
solve_eps = 0.2

system = LowRankRNN(
    fixed_params,
    behavior,
    model_opts=model_opts,
    solve_its=solve_its,
    solve_eps=solve_eps,
)


In [37]:
# Run the warm-start sweep (it also saves the grid via np.savez).
# x is the kappa1 solution at each grid point.
x = warm_start(system)
print(x)

Rank2 CDD static warm start, 81 total queries
[-1.20038031 -0.72459362 -0.34695227 -0.75773954 -0.75773954 -0.75773954
 -0.34695227 -0.72459362 -1.20038031 -0.97059225 -0.53687091 -0.2337172
 -0.53687091 -0.53687091 -0.53687091 -0.2337172  -0.53687091 -0.97059225
 -0.7555112  -0.38304103 -0.15665055 -0.36508035 -0.36508035 -0.36508035
 -0.15665055 -0.38304103 -0.7555112  -1.04141191 -0.51728886 -0.18415687
 -0.51229822 -0.51229822 -0.51229822 -0.18415687 -0.51728886 -1.04141191
 -1.05804576 -0.53687091 -0.20639246 -0.53687091 -0.53687091 -0.53687091
 -0.20639246 -0.53687091 -1.05804576 -1.07465208 -0.5564253  -0.22860277
 -0.56140679 -0.56140679 -0.56140679 -0.22860277 -0.5564253  -1.07465208
 -0.72569087 -0.35049938 -0.12070013 -0.32764435 -0.32764435 -0.32764435
 -0.12070013 -0.35049938 -0.72569087 -0.96829095 -0.53687091 -0.23468875
 -0.53687091 -0.53687091 -0.53687091 -0.23468875 -0.53687091 -0.96829095
 -1.22342323 -0.75640527 -0.38638002 -0.79539232 -0.79539232 -0.79539232
 -0.38