In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
import plotly.express as px
from pandas import DataFrame

from loguru import logger
logger.remove()

from jupyter_dash import JupyterDash as Dash
from dash import html, dcc, Input, Output

from pim.simulator import SimulationExperiment
from pim.cx import fit_cpu4, to_cartesian

In [44]:
# Settings stored under the "cx" key of the experiment description.
cx_settings = {
    "type": "weights",
    "output_layer": "motor",
    "params": {
        "noise": 0.1,
        "motor_noise": 0.1,
        "mem_initial": 0.5,
        "mem_fade": 0.15,
        "mem_gain": 0.0023,
        "cpu1_slope": 1000,
        "cpu1_bias": 2,
    },
}

# Complete experiment description handed to SimulationExperiment below.
parameters = {
    "type": "simulation",
    "seed": 100,
    "T_outbound": 1500,
    "T_inbound": 3000,
    "min_homing_distance": 300,
    "motor_factor": -0.5,
    "record": ["memory", "TB1", "Pontine", "motor", "theory", "CPU4", "CPU4.old"],
    "cx": cx_settings,
}

# Combined outbound + inbound step count, minus one.
total_time = parameters["T_outbound"] + parameters["T_inbound"] - 1

# Build the experiment and run it under the name "test".
experiment = SimulationExperiment(parameters)
results = experiment.run("test")

# Reconstructed path as a 2-D array; the first row is dropped.
reconstructed = np.array(results.reconstruct_path())
path = reconstructed[1:, :]

In [4]:
app = Dash(__name__)

# Initial slider position. Clamped to the slider's upper bound so the starting
# value stays inside [0, total_time] even when the simulated run is shorter
# than the preferred starting step.
initial_time = min(1529, total_time)

app.layout = html.Center(
    html.Div([
        # Time-step selector; its value is the input of the figure callback.
        dcc.Slider(0, total_time, value=initial_time, step=1, marks=None, id="time"),
        html.Div(id="status"),
        dcc.Graph(id="path"),
        dcc.Graph(id="activity"),
        dcc.Graph(id="pontines"),
        # NOTE(review): the callback visible in this notebook has no Output for
        # "steering", so this graph stays empty unless it is wired up elsewhere.
        dcc.Graph(id="steering"),
    ])
)

@app.callback(
    Output("status", "children"),
    Output("path", "figure"),
    Output("activity", "figure"),
    Output("pontines", "figure"),
    Input("time", "value"),
)
def set_time(time):
    """Rebuild the status text and the three figures for the selected step.

    Reads the notebook-level ``path``, ``parameters`` and ``results`` objects.

    Args:
        time: Current slider value. Used as an index into ``path`` and into
            every recording taken from ``results.recordings``.

    Returns:
        A 4-tuple in callback-output order: the status string, the path
        figure, the activity figure and the memory-balance figure.
    """
    t_outbound = parameters["T_outbound"]
    recordings = results.recordings
    # Looked up once; it feeds both the activity and the balance figures.
    memory_output = recordings["memory"]["output"][time]

    path_figure = {
        "data": [
            {
                # Rows before T_outbound are drawn as the outbound leg.
                "x": path[:t_outbound, 0],
                "y": path[:t_outbound, 1],
                # "line" is not a Plotly trace type; a line plot is a scatter
                # trace drawn in "lines" mode.
                "type": "scatter",
                "mode": "lines",
                "name": "outbound",
            },
            {
                # Remaining rows are drawn as the inbound leg.
                "x": path[t_outbound:, 0],
                "y": path[t_outbound:, 1],
                "type": "scatter",
                "mode": "lines",
                "name": "inbound",
            },
            {
                # Single marker at the position for the selected step.
                "x": [path[time, 0]],
                "y": [path[time, 1]],
                "type": "scatter",
                "mode": "markers",
                "name": "time",
            },
        ],
        "layout": {
            "width": 600,
            "height": 600,
            # The layout attribute is "autosize"; "autoscale" is not a layout
            # key and was ignored.
            "autosize": False,
            # Lock the y axis to the x axis so both use the same scale.
            "yaxis": {
                "scaleanchor": "x",
                "scaleratio": 1,
            },
        },
    }

    activity_figure = {
        "data": [
            {
                # TB1 output repeated twice - presumably so it spans the same
                # number of bars as the other traces; TODO confirm.
                "y": np.tile(recordings["TB1"]["output"][time], 2),
                "type": "bar",
                "name": "TB1",
            },
            {
                "y": recordings["CPU4"]["output"][time],
                "type": "bar",
                "name": "CPU4",
            },
            {
                "y": recordings["memory"]["internal"][time],
                "type": "bar",
                "name": "home",
            },
            {
                "y": memory_output,
                "type": "bar",
                "name": "shift towards home",
            },
        ],
        "layout": {
            "title": "Activities",
        },
    }

    # The first eight memory outputs are treated as the left half, the rest as
    # the right half. Each half is rolled by one position and compared with
    # the same half rolled by three positions in the opposite direction.
    left_half = memory_output[:8]
    right_half = memory_output[8:]

    left = np.roll(left_half, -1)
    right = np.roll(right_half, 1)

    left_pontine = np.roll(left_half, 3)
    right_pontine = np.roll(right_half, -3)

    pontine_figure = {
        "data": [
            {
                "y": left - left_pontine,
                "type": "bar",
                "name": "left",
            },
            {
                "y": right - right_pontine,
                "type": "bar",
                "name": "right",
            },
        ],
        "layout": {
            "title": "Memory Balance",
        },
    }

    return (
        f"t = {time}",
        path_figure,
        activity_figure,
        pontine_figure,
    )

# Start the app defined above, passing mode="jupyterlab" to run_server.
app.run_server(mode="jupyterlab")