In [None]:
#%% Transitive Inference - Single model analysis  #########################

from __future__ import division
import os

import numpy as np

import task_train_test      as tk
import nn_models            as nm
import nn_analyses          as nna
import ti_functions         as ti
import postproc             as pp

# Directory holding trained model files, relative to the current working directory.
datadir = 'model_files/'
# Run mode: 1 = local (train, optionally save), 2 = batch (load a pre-trained model).
Location = 2

In [None]:
#%%##### Train (or load) model  #######
# Choose the model source from `Location`:
#   1 (local): train a model here (Load_model / Save_model control file reuse).
#   2 (batch): load a pre-trained model identified by (jobname, TID).
# Then build the model-directory path and hand everything to tk.get_model.

if Location == 1:       # local

    # file settings #
    Load_model              = False     # if file is available, load
    Save_model              = True      # save model to file after training

    # optional settings #
    jobname                 = None
    TID                     = -1        # -1 is dummy value

elif Location == 2:     # batch

    # file settings #
    # (pre-trained models available at datadryad.org)
    # Exactly one branch below runs: the first `if 1` / `elif 1` wins, so switch
    # models by flipping the 0/1 flags.
    if 1:                # f-RNN, simple
        jobname                 = 'ti350/'
        TID                     = 205        # Fig. 2, performance
        # TID                   = 249        # Fig. 5 trajectory
    elif 0:             # f-RNN, complex
        jobname                 = 'ti350/'
        TID                     = 987        # Fig. 2, performance
    elif 1:               # r-RNN
        if 1:
            jobname             = 'ti351/'   # Fig. 2. (performance)
            TID                 = 10
        elif 1:
            jobname             = 'ti352/'  # Fig. S6 (trajectory)   only ff untrained
            TID                 = 32
    elif 0:             # r-RNN, complex
        jobname                 = 'ti351/'
        TID                     = 447       # Fig. 2 (performance)
    elif 0:
        jobname                 = 'ti253/'  # LR
        TID                     = 0
    elif 0:
        jobname                 = 'ti254/'  # MLP
        TID                     = 0

    Load_model              = True      # if file available, load
    Save_model              = False     # save model to file

else:
    # Fail early: otherwise Load_model / Save_model / jobname / TID are undefined
    # (or silently left over from an earlier run) when tk.get_model is called.
    raise ValueError('Location must be 1 (local) or 2 (batch); got %r' % (Location,))

# os.getcwd() returns the same string as the IPython `%pwd` magic, but it also works
# when these `#%%` cells are run as plain Python outside IPython.
currentdir              = os.getcwd()
localdir                = '%s/%s' % (currentdir,datadir)

p, F, filename, Dataset, fps, Fps  = tk.get_model( Location, localdir, Load_model, Save_model, jobname, TID)

# Parameter objects without a Task_version attribute (presumably saved by older
# code) are treated as version 0.
if not hasattr(p,'Task_version'):
    p.Task_version = 0


In [None]:
#%% Model summary ######################################
# Select one trained checkpoint and plot its training loss and performance curves.

# Which checkpoint to analyse:
#   -1: model with lowest loss
#   -4: earliest trained model showing full generalization
Model_select            = -4

# select model #
model, model_epoch, model_ind, model_title = tk.select_model(p, F, Model_select)
model_string = '%s \n %s' % (filename, model_title)   # title reused by later plots

# Loss trajectory #
fig, axes = tk.plot_training_losses(F, p, filename, model_epoch)

# Performance curve; computed on demand when the file lacks it and p.Z == 3 #
needs_performance = 'performance' not in F.keys() and p.Z == 3
if needs_performance:
    F = pp.run_Performance(p, F)       # returns F updated to have 'performance'
tk.plot_training_performance(F, p, fig, axes, model_epoch)

In [None]:
#%% Report model behavior #############################
# Print the selected model's performance (trial variant TT = 0, verbose output),
# then plot its performance on grid.
tk.print_model_performance(p, model, F, TT=0, silent=False)
ti.plot_model_summary(p, F, model, TID, model_epoch, model_string)

In [None]:
#%% (if Feedforward model) Plot unit responses  ################
# Single-timestep models only (p.T == 1). p.Model selects between plotting
# hidden-unit responses (7 or 9) and output-unit responses (8).
if p.T == 1:
    Out, Hid = nna.Unit_responses(p, F, model)
    if p.Model in (7, 9):
        nna.Plot_hidden_unit_responses(p, Hid, model)
        # nna.Plot_hidden_pca_responses(p, Hid)     # optional PCA view of hidden units
    elif p.Model == 8:
        nna.Plot_output_unit_responses(p, Out)

In [None]:
#%% Get neural activity (data) from model (and calculate PCA) ###############
# Multi-timestep models only (p.T > 1): simulate trial variants TT = 0..16,
# project the activity with PCA, and plot variance explained.
if p.T > 1:
    # Which activity the PCA is fitted on:
    #   0: all conditions and times
    #   1: XCM-subtracted (cross-condition mean)
    #   3: Spatial PCA, 1st item timestep
    PCA_option = 0
    Data = tk.Simulate_model(p, F, model, TT_run=range(17))
    Data = tk.Project_Data_PCA(p, F, Data, PCA_option, 0)
    nna.PCA_plot(Data)   # PCA variance explained


In [None]:
#%% Readout plot #################################################
YSCALE  = 10      # y-axis scale
TT_plot = [0]     # trial variants (TT) to plot; 0 is the default trial

# The readout is computed from simulated activity, which exists only for p.T > 1.
# NOTE(review): meaning of the positional argument 3 is not documented here.
if p.T > 1:
    ti.Readout_plot(p, Data, TT_plot, YSCALE, 3, model_string)

In [None]:
#%% PC 1D #########################################################
# `Data` (the PCA-projected activity) is only simulated for multi-timestep models
# (p.T > 1), so each plot in this cell is guarded the same way as the readout plot.
# Unguarded, a feedforward model raises NameError here, or silently plots a stale
# `Data` left in the kernel by a previously analysed model.

YSCALE  = 5 #10
TT_plot  = [0]

if p.T > 1:
    ti.PCA_1D_plot( p, Data, TT_plot, YSCALE , model_string )


#%% PC 2D  ########################################################

groups_to_plot          = [2,1,0]  # (trial types) 0: Train, 1: Edges (contains end items), 2: Probes (inner trial types)
pcplist                 = np.array([[0,1],[0,2],[1,2]])   # List of PC pairs to plot
SCALE                   = 10
TT                      = 0

if p.T > 1:
    ti.PCA_2D(p, Data, SCALE, groups_to_plot, TT, pcplist, None)


#%% PC 3D  ########################################################

AXSCALE         = [6]
groups_toplot   = [0,1,2]    # (trial types) 0: Train, 1: Edges (contains end items), 2: Probes (inner trial types)
pcs             = [0,1,2]    # PC axes to plot
COLOR_EMBED     = 6          # color of final timestep (choice)
TT_plot         = 0          # timepoints to plot (0 is default trial)

if p.T > 1:
    ax          = ti.PCA_3D(p, Data, AXSCALE, groups_toplot, TT_plot, pcs, COLOR_EMBED)
else:
    ax          = None       # no 3D axes without simulated activity


In [None]:
#%% Infer Linear Dynamics (OLS) from delay period activity ############
# Multi-timestep models only: fit linear dynamics to delay-period activity and
# visualise the fit through its oscillatory basis and eigenvalues.
if p.T > 1:

    # Fit dynamics on trial variant TT = 0, using the top 10 dimensions #
    LD = ti.Linear_Dynamics_Delay_Period(p, F, Data, TT=0, topdims=10)
    osc_basis = LD['oscbasis']

    # Project simulated activity onto the fitted oscillatory basis #
    nna.Project_Simulated_Data(Data, osc_basis, ('data_emp', 'DATA_emp', 'basis_emp'))

    # Eigenvalues of the fitted dynamics matrix #
    ti.Plot_Eigenvalues_Comparison(p, None, [LD['A']])

    # Activity trajectories in the oscillatory mode #
    groups_to_plot = [0, 1, 2]
    scales         = [2, 0.8, .32]   # general
    ti.Plot_Oscillatory_2D(p, Data, None, scales, groups_to_plot, 0, None, None, Sys_to_plot=[2])

In [None]:
#%% Fixed point analyses ##############################################
# Fixed/slow points of the RNN. If the loaded file carried no fixed-point
# candidates (`fps is None`), search for them first. In every case, (re)detect
# fixed points from the candidates with the criteria at the bottom of this cell.

# (if not calculated in file) Find Fixed points (FPF) #
if fps is None:     # `is`, not `==`: `==` may be overloaded or elementwise on array-like objects

    # 1.  Find Fixed/Slow points #

        # search parms #
    add_rand_seeds     = False #True    # (optional) get random trials for seeds
    rand_numbatch      = 2       #  " " no. of batches to run
    rand_batchsize     = 50      #  " " no. of seeds / batch
    numeps             = 10000  # no. of optimization epochs per batch
    numeps_stop        = 4000   # no. optimization epochs for improvement before stopping
        # detection parms #
    tol_q              = 4   #5.95      # speed in -log10
    tol_unique         = 0.000001
    detect_baseline    = True   # (optional) detect FP nearest the baseline, i.e. time of item 1
    detect_delay       = False  # (optional) presumably detects FP nearest the delay-period state -- TODO confirm

    # 2. Run FPF #
    fps   = nna.FPF_search(model, p, F['X_input'], numepochs=numeps, numepochs_stop=numeps_stop,
                            add_rand_seeds=add_rand_seeds, rand_numbatch=rand_numbatch, rand_batchsize=rand_batchsize)
    Fps   = nna.FPF_detect(fps, model, p, F['X_input'], pca = None, tol_q=tol_q, tol_unique=tol_unique,
                            detect_baseline = detect_baseline, detect_delay = detect_delay)
    nna.plot_histogram_fp_speeds(fps,Fps,3,10)      # plot speeds
    nna.plot_eigenvalues(p,Fps)                       # plot spectrum

    # 3.   For each FP, identify unstable modes #
    Fps         = nna.find_unstable_modes(Fps)      # identify unstable modes + update Fps object
    # `pca` was never defined in this notebook (NameError); project with the PCA
    # fitted on the simulated activity, as the detection step below does.
    fps, Fps    = nna.pca_project_fp(fps,Fps,Data['pca'])   # update Fps object w/ PCA projected


# Define Fixed Point criteria #
tol_q              = 5  #5.95      # speed in -log10
tol_unique         = 1e-5
detect_baseline    = True   # (optional) detect FP nearest the baseline, i.e. time of item 1
detect_delay       = False  # (optional) presumably detects FP nearest the delay-period state -- TODO confirm
Fps   = nna.FPF_detect( fps, model, p, F['X_input'], Data['pca'], tol_q=tol_q, tol_unique=tol_unique,
                                detect_baseline = detect_baseline, detect_delay = detect_delay)
nna.plot_histogram_fp_speeds( fps, Fps, 3, 10)      # plot speeds
nna.plot_eigenvalues(p, Fps )                       # plot spectrum


In [None]:
#%%  PC 3D  ####### with fixed points plotted ##################################################
# 3D PCA trajectories, optionally overlaid with fixed-point candidates, detected
# fixed points, and the modes of selected fixed points.

TT                 = 0
qmin, qmax         = (3, 9)      # range passed to plot_fps_3D (presumably candidate speed, -log10)
plot_fps           = 0           # overlay fixed-point candidates
groups_toplot      = [0, 1, 2]   # Trial type selection (0: Train, 1: Edges, 2: Probes)
pcp                = [0, 1, 2]   # PC axes to plot
COLOR_EMBED_CHOICE = 6           # 6: symbolic distance

# Figure presets: name -> (plot_FP, plot_FP_modes, AXSCALE, FP_to_plot_mode)
FP_PLOT_PRESETS = {
    'f-RNN, Fig. 4a':             (1, 1, [10], [2]),
    'r-RNN, Fig. S6a':            (1, 0, [6],  [2]),
    'f-RNN, Fig. 4a (FP mode 3)': (1, 1, [10], [3]),
}
plot_FP, plot_FP_modes, AXSCALE, FP_to_plot_mode = FP_PLOT_PRESETS['f-RNN, Fig. 4a (FP mode 3)']

ax = ti.PCA_3D(p, Data, AXSCALE, groups_toplot, TT, pcp, COLOR_EMBED_CHOICE)

if plot_fps:        # fixed-point candidates
    nna.plot_fps_3D(ax, qmin, qmax, fps, Data['pca'], pcp)
if plot_FP:         # detected fixed points
    nna.plot_FP_3D(ax, Fps, Data['pca'], pcp)
if plot_FP_modes:   # modes of the detected fixed points listed in FP_to_plot_mode
    nna.plot_FP_modes_3D(ax, Fps, FP_to_plot_mode, Data['pca'], pcp)
        

In [None]:
#%%  Construct Linearized RNN  #################
# Linearize the RNN around selected fixed points, build a 2D basis from one
# oscillatory mode, simulate the linearized model, and project both the original
# (Data) and linearized (Data_lz) activity into that basis and into the
# empirical (OLS) basis.

# Defaults for presets below that do not set them. Without these, choosing the
# "Switching LTVDS" or "All fixed points" preset raised NameError further down
# (or silently reused FP_osc / Filter_mode from an earlier run in the kernel).
# NOTE(review): 0 / unfiltered are assumed to be neutral choices -- confirm for those presets.
FP_osc                      = 0         # for 2D linearbasis
Filter_mode                 = 0         # 0: unfiltered, 1: osci-filtered

# Define linearized system parameters #
if 1:         # Fast Oscillation + 2D filtered  (Fig. 6d, top row)
    FP_indices              = [0]       # Index of FP to plot
    FP_linoption            = [3]       # 0: blank, 1: unstable, 2: osci only, 3: fastest osci, 4: largest osci, largest real
    FP_osc                  = 0         # for 2D linearbasis
    Filter_mode             = 1         # 0: unfiltered, 1: osci-filtered
elif 0:         # Fast Oscillation + ND unfiltered (Fig. 6d, bottom row)
    FP_indices              = [0]
    FP_linoption            = [3]
    FP_osc                  = 0
    Filter_mode             = 0
elif 0:           # Switching LTVDS
    FP_indices              = [0,3,1,2]
    FP_linoption            = [1,0,0,0]
else:           # All fixed points
    FP_indices              = list(range(Fps.num))            # a list, like the other presets
    FP_linoption            = np.zeros(Fps.num, dtype=int)    # integer option codes, like the other presets

# Get Linearized system parameters #
Lsp             = nna.Define_linearized_system( p, Fps, FP_indices, FP_linoption, plot_eigspectra=False )
# Make Oscillatory basis: identify FP (xstar) and real osci axes (v1, v2) #
xstar, v1, v2   = nna.Get_Oscillatory_Mode( Lsp, FP_osc )
basis_lz        = nna.Linearbasis_2D( xstar, v1, v2 )
# Make Linearized model #
model_lin       = nm.Model11( p, Lsp, basis_lz, Data['pca'], model , Filter_mode )

# Run Linearized Model #
TT_run           =  [0]
Data_lz          =  tk.Simulate_model( p, F, model_lin, TT_run=TT_run )
# project existing data into oscillatory (linearized) basis too
nna.Project_Simulated_Data( Data,    basis_lz, ('data_lz','DATA_lz','basis_lz') )
nna.Project_Simulated_Data( Data_lz, basis_lz, ('data_lz','DATA_lz','basis_lz') )
# Refit the empirical (OLS) delay-period dynamics and project the linearized run onto its basis.
LD =  ti.Linear_Dynamics_Delay_Period(p, F, Data, TT = 0, topdims = 10 )
nna.Project_Simulated_Data( Data_lz, LD['oscbasis'], ('data_emp','DATA_emp','basis_emp') )


In [None]:
# %%  Plot eigenspectra from Linear Dynamics vs. Linearized FP ###############
# Fit OLS linear dynamics to the delay period of the nonlinear RNN (Data) and of
# its linearization (Data_lz), then compare both eigenspectra with the Jacobian
# of fixed point FP_osc.

# Reload ti_functions to pick up edits without restarting the kernel. `reload` is
# a builtin only in Python 2; Python 3 moved it to importlib (NameError otherwise).
try:
    from importlib import reload
except ImportError:
    pass            # Python 2: use the builtin reload
ti = reload(ti)

A = []              # dynamics matrices: [non-linear, linearized]
for Data_sys in (Data, Data_lz):
    ld = ti.Linear_Dynamics_Delay_Period(p, F, Data_sys, TT=0, topdims=10)
    A.append(ld['A'])

# NOTE(review): FP_osc is also passed to Get_Oscillatory_Mode(Lsp, ...). If it indexes
# the linearized systems (i.e. FP_indices) rather than Fps, this should be
# Fps.Jac[FP_indices[FP_osc]]. The two agree for the default FP_indices = [0].
ti.Plot_Eigenvalues_Comparison(p, Fps.Jac[FP_osc,:,:].squeeze(), A)
