# Open-Loop Calibration with Radial-8

## 1. Run the calibration graph

In [None]:
import json
import os
import pickle
import time
from datetime import datetime

import redis
import yaml

DURATION = None  # seconds to run the graph; a falsy value waits for ENTER
GRAPH = 'sim_graph_ol.yaml'
REDIS_IP = '192.168.30.6'
REDIS_PORT = 6379
test_dir = os.getcwd()


def _scan_streams(client):
    """Return the names of all Redis keys of type 'stream'.

    Parameters
    ----------
    client : redis.Redis
        Connected Redis client

    Returns
    -------
    list
        Stream names in the order SCAN reported them, without duplicates
    """
    cursor, names = client.scan(0, _type='stream')
    while cursor != 0:
        cursor, batch = client.scan(cursor, _type='stream')
        names += batch
    # SCAN may report the same key more than once; keep first occurrences
    return list(dict.fromkeys(names))


with open(GRAPH, 'r') as f:
    graph = yaml.safe_load(f)

r = redis.Redis(host=REDIS_IP, port=REDIS_PORT)

# Streams that already exist before the graph is started
start_streams = _scan_streams(r)

# get the most recent ID from each stream
start_id = {}
for stream in start_streams:
    replies = r.xrevrange(stream, count=1)
    if replies:
        start_id[stream] = replies[0][0]

print(f'Starting graph from {GRAPH} as JSON')
r.xadd('supervisor_ipstream', {
    'commands': 'startGraph',
    'graph': json.dumps(graph)
})

if DURATION:
    print(f'Waiting {DURATION} seconds')
    time.sleep(DURATION)
else:
    input('Hit ENTER to stop graph...')

# Stop the graph
print('Stopping graph')
r.xadd('supervisor_ipstream', {'commands': 'stopGraph'})

stop_streams = _scan_streams(r)

# Streams that were created while the graph was running
known_streams = set(start_streams)
new_streams = [
    stream for stream in stop_streams if stream not in known_streams
]

# Read from the beginning of every stream that has no recorded start ID:
# new streams, and streams that existed at the start but were empty
for stream in stop_streams:
    start_id.setdefault(stream, 0)

# Save streams
# NOTE: the lower bound passed to XRANGE is inclusive, so the last entry
# that predates the run is saved along with the new entries
all_data = {
    stream: r.xrange(stream, min=start_id[stream])
    for stream in stop_streams
}

date_str = datetime.now().strftime(r'%y%m%dT%H%M')
graph_name = os.path.splitext(os.path.basename(GRAPH))[0]
data_dir = os.path.join(test_dir, 'data')
os.makedirs(data_dir, exist_ok=True)
save_path = os.path.join(data_dir, f'{date_str}_{graph_name}.pkl')
with open(save_path, 'wb') as f:
    pickle.dump(all_data, f)
print(f'Saved streams: {sorted(all_data)}')

# Remove saved data from Redis
# delete any streams created while the graph was running; repeat until all
# of them report a length of zero
if new_streams:
    while any(r.xlen(stream) for stream in new_streams):
        for stream in new_streams:
            r.delete(stream)
r.memory_purge()
print(f'Deleted streams: {new_streams}')

## 2. Analyze the block

In [None]:
import json
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import yaml
from brand.timing import timespecs_to_timestamps, timevals_to_timestamps
from scipy.signal import butter, sosfiltfilt
from sklearn.linear_model import RidgeCV
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm

In [None]:
# Locations used by the analysis, all under the current working directory
test_dir = os.getcwd()
data_dir, fig_dir = (
    os.path.join(test_dir, subdir) for subdir in ('data', 'figures'))
# Analyze the file referenced by `save_path`
data_file = save_path

# Make sure the figure directory exists
os.makedirs(fig_dir, exist_ok=True)

In [None]:
# Read the pickled stream dump back into memory
data_path = os.path.join(data_dir, data_file)
with open(data_path, 'rb') as pkl_file:
    graph_data = pickle.load(pkl_file)

In [None]:
# Show the keys of the loaded data in sorted order
sorted(graph_data)

In [None]:
# Collect every graph definition found on the 'booter' stream and keep the
# most recent one
graphs = []
for _, entry in graph_data[b'booter']:
    if b'graph' in entry:
        graphs.append(json.loads(entry[b'graph']))
graph = graphs[-1]

In [None]:
# Read the description of the fields contained in each stream
with open('stream_spec.yaml') as spec_file:
    stream_spec = yaml.safe_load(spec_file)

In [None]:
# Load and parse stream data
streams = [
    b'targetData', b'cursorData', b'mouse_vel', b'binned_spikes',
    b'firing_rates', b'nsp_neural_1', b'nsp_neural_2', b'control'
]


def _decode_entry(entry_data, spec):
    """Decode the fields of one stream entry that are listed in `spec`.

    Parameters
    ----------
    entry_data : dict
        Raw entry with encoded field names as keys and encoded values
    spec : dict
        Maps a field name to its type: 'str', 'sync', 'timeval',
        'timespec', or any other value, which is passed to `np.frombuffer`
        as the dtype

    Returns
    -------
    dict
        Decoded values keyed by the decoded field name. Fields that are not
        listed in `spec` are dropped.
    """
    decoded = {}
    for raw_key, val in entry_data.items():
        key = raw_key.decode()
        if key not in spec:
            continue
        dtype = spec[key]
        if dtype == 'str':
            decoded[key] = val.decode()
        elif dtype == 'sync':
            decoded[key] = json.loads(val)['nsp_idx']
        elif dtype == 'timeval':
            decoded[key] = timevals_to_timestamps(val)
        elif dtype == 'timespec':
            decoded[key] = timespecs_to_timestamps(val)
        else:
            dat = np.frombuffer(val, dtype=dtype)
            # unwrap single-element buffers into a scalar
            decoded[key] = dat[0] if dat.size == 1 else dat
    return decoded


decoded_streams = {}
for stream in streams:
    stream_name = stream.decode()
    print(f'Processing {stream_name} stream')
    spec = stream_spec[stream_name]
    # tqdm wraps the list itself so the progress bar knows the total
    decoded_streams[stream_name] = [
        _decode_entry(entry_data, spec)
        for _, entry_data in tqdm(graph_data[stream])
    ]

In [None]:
# Load data at the binned spikes sample rate
def _sync_indexed_frame(stream_name):
    """Return a stream's decoded entries as a frame indexed by 'sync'."""
    frame = pd.DataFrame(decoded_streams[stream_name])
    return frame.set_index('sync', drop=False)


def _tag_columns(frame, suffix):
    """Append `suffix` to every column name of `frame` and return it."""
    frame.columns = [col + suffix for col in frame.columns]
    return frame


# FSM: cursor and target state
cd_df = _tag_columns(_sync_indexed_frame('cursorData'), '_cd')

td_df = _sync_indexed_frame('targetData')
# the angle is computed from the X and Y columns, in degrees
td_df['angle'] = np.degrees(np.arctan2(td_df['Y'], td_df['X']))
td_df = _tag_columns(td_df, '_td')

# binning
bs_df = _tag_columns(_sync_indexed_frame('binned_spikes'), '_bs')

# autocue
ac_df = _tag_columns(_sync_indexed_frame('control'), '_ac')

# join the dataframes on their shared 'sync' index
bin_df = cd_df.join(td_df).join(bs_df).join(ac_df)

In [None]:
# Check the NSP data
tslice = slice(0, 4000)  # range of samples to inspect (units: sample index)
plt_channels = 9  # maximum number of channels to plot
for nsp_id in [1, 2]:
    in_params = graph['nodes'][f'nsp_in_{nsp_id}']['parameters']
    n_channels = in_params['chan_per_stream'][0]
    samp_per_stream = in_params['samp_per_stream'][0]
    nsp_df = pd.DataFrame(decoded_streams[f'nsp_neural_{nsp_id}'])
    # Reshape every entry to (n_channels, samp_per_stream), concatenate the
    # entries along the sample axis, then transpose to
    # (n_samples, n_channels). The shape is passed positionally so this does
    # not depend on the name of np.reshape's shape keyword.
    nsp_data = np.hstack([
        np.reshape(chunk, (n_channels, samp_per_stream))
        for chunk in nsp_df['samples']
    ]).T

    # Truncate data to the requested time slice
    nsp_data = nsp_data[tslice, :]

    # Load filter parameters
    tc_params = graph['nodes'][f'thresh_cross_{nsp_id}']['parameters']

    but_low = tc_params['butter_lowercut']
    but_high = tc_params['butter_uppercut']
    but_order = tc_params['butter_order']

    if but_low and but_high:
        filt_type = 'bandpass'
        Wn = [but_low, but_high]
    elif but_high:
        filt_type = 'lowpass'
        Wn = but_high
    elif but_low:
        filt_type = 'highpass'
        Wn = but_low
    else:
        raise ValueError("Must specify 'butter_lowercut' or 'butter_uppercut'")

    fs = tc_params['input_stream']['samp_freq']
    sos = butter(but_order, Wn, btype=filt_type, output='sos', fs=fs)

    if tc_params['enable_CAR']:
        # Common average reference: at every sample, subtract the mean
        # taken across channels (axis 1 of the (n_samples, n_channels) array)
        nsp_ref = nsp_data - nsp_data.mean(1, keepdims=True)
    else:
        nsp_ref = nsp_data

    # Zero-phase filtering along the sample axis
    nsp_data_filt = sosfiltfilt(sos, nsp_ref, axis=0)

    # Load spike thresholds from the last entry of the thresholds stream
    threshold_stream = f'thresh_cross_{nsp_id}_thresholds'
    thresholds = np.frombuffer(
        graph_data[threshold_stream.encode()][-1][1][b'thresholds'],
        dtype=stream_spec[threshold_stream]['thresholds'])

    # Check filtering
    # never ask for more panels than there are channels
    n_plots = min(plt_channels, nsp_data.shape[1])
    ncols = int(np.ceil(np.sqrt(n_plots)))
    nrows = int(np.ceil(n_plots / ncols))
    # squeeze=False keeps `axes` a 2-D array even for a single panel
    fig, axes = plt.subplots(ncols=ncols,
                             nrows=nrows,
                             figsize=(ncols * 3, nrows * 2),
                             sharey=False,
                             squeeze=False)
    n_samples = nsp_data.shape[0]
    for iax in range(n_plots):
        ax = axes.flat[iax]
        ax.plot(nsp_data[:, iax], alpha=0.5, label='original')
        ax.plot(nsp_data_filt[:, iax], alpha=0.5, label='filtered')
        ax.plot(np.full(n_samples, thresholds[iax]), label='threshold')
        ax.set_title(f'Ch. {iax}')

    # hide the panels of the grid that were not used
    for ax in axes.ravel()[n_plots:]:
        ax.set_axis_off()

    axes.flat[0].legend()
    plt.tight_layout()
    plt.subplots_adjust(top=0.95)
    plt.suptitle(f'NSP {nsp_id}')
    plt.show()

## 3. Train a decoder

In [None]:
# Train a decoder
SEQ_LEN = 10  # sequence length for the Wiener filter


def get_lagged_features(data, n_history: int = 4):
    """
    Lag the data along the time axis. Stack the lagged versions of the data
    along the feature axis.

    Parameters
    ----------
    data : array-like of shape (n_samples, n_features)
        Data to be lagged
    n_history : int, optional
        Number of bins of history to include in the lagged data, by default 4

    Returns
    -------
    lagged_features : array of shape (n_samples, (n_history + 1) * n_features)
        The original data followed by copies delayed by 1, 2, ..., n_history
        bins. Rows for which no history is available are filled with zeros.

    Raises
    ------
    ValueError
        If `n_history` is negative
    """
    if n_history < 0:
        raise ValueError('n_history must be greater than or equal to 0')
    data = np.asarray(data)
    n_samples = data.shape[0]
    lags = []
    for lag in range(n_history + 1):
        shifted = np.zeros_like(data)
        # rows [lag:] receive the data from `lag` bins earlier; the clamp
        # keeps the source empty when the lag exceeds the number of samples
        shifted[lag:, :] = data[:max(n_samples - lag, 0), :]
        lags.append(shifted)
    lagged_features = np.hstack(lags)
    return lagged_features


neural_stream, kin_stream = 'binned_spikes', 'control'

# Features: the stacked 'samples_bs' column plus SEQ_LEN - 1 bins of history
binned_samples = np.vstack(bin_df['samples_bs'])
neural_data = get_lagged_features(binned_samples, n_history=SEQ_LEN - 1)

# Targets: the first two columns of the stacked 'samples_ac' column
control_samples = np.vstack(bin_df['samples_ac'])
kin_data = control_samples[:, :2]

In [None]:
# Split without shuffling; 25% of the samples are held out for testing
X_train, X_test, y_train, y_test = train_test_split(
    neural_data, kin_data, test_size=0.25, shuffle=False)

# Ridge regression: the weight of the L2 penalty is chosen from the
# candidates below using 3-fold cross-validation
ridge_alphas = (0.1, 1.0, 10.0)
mdl = RidgeCV(alphas=ridge_alphas, cv=3)
mdl.fit(X_train, y_train)
y_test_pred = mdl.predict(X_test)

In [None]:
# Score the held-out predictions against the true targets. Passing
# y_test_pred here would compare the predictions with themselves and
# always report a perfect score.
mdl.score(X_test, y_test)

In [None]:
# Save the trained model
# data_file may include directories, so take the prefix (the text before the
# first underscore) from the file name only. Splitting the whole path would
# keep the directory part, and os.path.join would then discard model_dir
# whenever that part is absolute.
file_desc = os.path.basename(data_file).split('_')[0]

model_dir = 'models'
os.makedirs(model_dir, exist_ok=True)
model_path = os.path.join(model_dir, f'{file_desc}_wf_seq_len_{SEQ_LEN}.pkl')

with open(model_path, 'wb') as f:
    pickle.dump(mdl, f)

In [None]:
# Point the closed-loop graph config at the saved model
cl_graph_path = 'sim_graph_cl.yaml'
with open(cl_graph_path, 'rb') as src:
    cl_graph = yaml.safe_load(src)

# Find the 'wiener_filter' node by name
node_names = [node['name'] for node in cl_graph['nodes']]
wf_idx = node_names.index('wiener_filter')

wf_params = cl_graph['nodes'][wf_idx]['parameters']
wf_params['model_path'] = os.path.abspath(model_path)
wf_params['seq_len'] = SEQ_LEN

# Write the edited config to a new file: '_gen' goes before the extension
path_root, path_ext = os.path.splitext(cl_graph_path)
cl_graph_gen_path = f'{path_root}_gen{path_ext}'

with open(cl_graph_gen_path, 'w') as dst:
    yaml.dump(cl_graph, dst)