In [None]:
# %% 
# Imports

import json
import os
import subprocess
import time
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import redis
import seaborn as sns

from brand.timing import timespec_to_timestamp, timeval_to_timestamp

In [None]:
# %% 
# Start Redis 

# Where the recorded RDB snapshot lives and the address the local server binds to.
SAVE_DIR = '/home/mrigott/Projects/emory-cart/Data/sim/2023-12-08/RawData/'
RDB_DIR = os.path.join(SAVE_DIR,'RDB')
RDB_FILENAME = 'sim_231208T1218_test_gemini.rdb'
REDIS_IP = '127.0.0.1'
REDIS_PORT = 18000

# Launch redis-server pointed at the snapshot (--dir / --dbfilename).
redis_command = [
    '/home/snel/Projects/emory-cart/brand/bin/redis-server',
    '--bind', REDIS_IP,
    '--port', str(REDIS_PORT),
    '--dbfilename', RDB_FILENAME,
    '--dir', RDB_DIR,
]

print('Starting redis: ' + ' '.join(redis_command))

proc = subprocess.Popen(redis_command, stdout=subprocess.PIPE)
redis_pid = proc.pid

try:
    # communicate() only returns inside the timeout if the server has already
    # exited, i.e. the launch failed; a timeout means it is still running.
    out, _ = proc.communicate(timeout=1)
except subprocess.TimeoutExpired:  # no error message received
    print('Redis-server is running.')
else:
    if out:
        print(out.decode())
    # Raise instead of calling exit(): exit() is not a reliable way to abort a
    # notebook cell, and the code below must not run against a dead server.
    if 'Address already in use' in str(out):
        raise RuntimeError(
            "Could not run redis-server (address already in use). Check if a Redis server is already running on that port. Aborting.")
    raise RuntimeError(
        "Launching redis-server failed for an unknown reason, check supervisor logs. Aborting.")

r = redis.Redis(host=REDIS_IP, port=REDIS_PORT)

# Poll until the server stops answering with BusyLoadingError, then list the keys.
busy_loading = True
while busy_loading:
    try:
        print(f"Streams in database: {r.keys('*')}")
        busy_loading = False
    except redis.exceptions.BusyLoadingError:
        print('Redis is busy loading dataset in memory')
        time.sleep(1)

In [None]:
# Print one summary line per stream: entry count and the sync value of the newest entry.
streams = ['nsp_neural', 'thresh_cross_1', 'sbp_1', 'binned_spikes', 'control']
msg = 'Stream Info'
print(msg)
print('-' * len(msg))
for stream in streams:
    n_entries = r.xlen(stream)
    has_sync = False
    if n_entries > 0:
        # xrevrange(..., count=1) yields the most recent (id, fields) pair
        entry_dict = r.xrevrange(stream, count=1)[0][1]
        has_sync = b'sync' in entry_dict

    row = f'{stream :24s}: {n_entries :6d}'
    row += f"\tsync={json.loads(entry_dict[b'sync'])}" if has_sync else '\tsync=None'
    print(row)

In [None]:

# Plot window; each stream contributes (END - START) * rate entries, always
# taken from the beginning of the stream.
START = 0.0
END = 0.2

# Rows of the decoded arrays to draw, one subplot row per channel.
crange = np.arange(90,100,1)
n_channels = crange.shape[0]

fig, axes = plt.subplots(ncols=3,
                         nrows=n_channels,
                         figsize=(30, n_channels * 2),
                         sharey='col',
                         sharex='col',
                         facecolor='w',
                         tight_layout=True)

# Decoding parameters per stream (one subplot column each):
# (entries per unit of START/END, channel count, samples per entry, dtype, payload field).
# NOTE(review): the rates presumably mean 1 kHz / 100 Hz with START/END in seconds — confirm.
stream_specs = {
    'nsp_neural': (1000, 256, 30, 'int16', b'samples'),
    'thresh_cross_1': (1000, 256, 1, 'int16', b'crossings'),
    'binned_spikes': (100, 512, 1, 'float32', b'samples'),
}

for isig, (sig, spec) in enumerate(stream_specs.items()):
    entry_rate, N_CHANNELS, samp_per_entry, dtype, dfield = spec
    # round() rather than int(): float products such as 0.29 * 100 land just
    # below the intended integer and would otherwise lose an entry.
    N_SAMP = int(round((END - START) * entry_rate))

    stream = sig
    # Only fetch the entries that are plotted instead of the whole stream.
    entries = r.xrange(stream, count=N_SAMP) if N_SAMP > 0 else []
    if not entries:
        print(f'No entries to plot for stream {stream}, skipping.')
        continue

    # Each payload decodes to a (channels, samples-per-entry) block; blocks are
    # laid side by side along the sample axis. Fewer entries than requested
    # simply produce a shorter trace.
    data = np.hstack([
        np.reshape(np.frombuffer(entry_data[dfield], dtype=dtype),
                   (N_CHANNELS, samp_per_entry))
        for _, entry_data in entries
    ])

    # Sample index within the plotted window, matching the 'Sample #' x label.
    t = np.arange(data.shape[1])

    for ich, ch in enumerate(crange):
        ax = axes[ich, isig]
        ax.plot(t, data[ch])
        if isig == 0:
            ax.set_ylabel(f'Ch {ch}')
        if ich == 0:
            ax.set_title(f'{sig}')
        if ich == len(crange) - 1:
            ax.set_xlabel('Sample #')

In [None]:
# utility function

def scalarfrombuffer(*args, **kwargs):
    """Decode a buffer with ``np.frombuffer`` and return only its first element.

    All positional and keyword arguments are forwarded to ``np.frombuffer``.
    """
    decoded = np.frombuffer(*args, **kwargs)
    return decoded[0]

def get_sync_val(sync_json, field='nsp_idx_1'):
    """Parse a JSON-encoded sync object and return the value stored under ``field``.

    Raises ``KeyError`` if the decoded object has no such key.
    """
    return json.loads(sync_json)[field]

streams = ['nsp_neural', 'thresh_cross_1', 'binned_spikes', 'control']

# Lookup tables indexed by a stream's position in `streams`.
# NOTE(review): each list has five entries while `streams` has four; only the
# first len(streams) positions are ever read.
# Redis field that holds each quantity:
fields = {
    'sync': ['timestamps', 'sync', 'sync', 'sync', 'sync'],
    'ts': ['BRANDS_time', 'ts', 'ts', 'ts', 'ts'],
}

# How each of those fields is decoded ('sync' -> JSON, 'timespec' -> helper,
# anything else -> numpy dtype name):
dtypes = {
    'sync': ['uint64', 'sync', 'sync', 'sync', 'sync'],
    'ts': ['timespec', 'uint64', 'uint64', 'uint64', 'uint64'],
}

samp_per_entry = [30, 1, 1, 1, 1]

N_CHANNELS = 256

# build dataframe with data

df_dict = {}

for s, stream in enumerate(streams):
    entries = r.xrange(stream)
    # one record per entry, holding the raw bytes of the sync and ts fields
    data = [
        {k: entry_data[names[s].encode()] for k, names in fields.items()}
        for _, entry_data in entries
    ]

    if samp_per_entry[s] > 1:
        # keep only the leading 1/samp_per_entry slice of every raw value
        # (presumably the first sample packed into the entry — confirm)
        n_per_entry = samp_per_entry[s]
        data2 = [
            {k: raw[0:int(len(raw) / n_per_entry)] for k, raw in record.items()}
            for record in data
        ]
        df = pd.DataFrame(data2)
    else:
        df = pd.DataFrame(data)

    # decode the raw bytes column by column
    for f in fields:
        dtype = dtypes[f][s]
        column = df[f]
        if dtype == 'timespec':
            df[f] = (column.apply(timespec_to_timestamp) * 1e9).astype(np.uint64)
        elif dtype == 'sync':
            df[f] = column.apply(get_sync_val).astype(np.uint64)
        else:
            df[f] = column.apply(scalarfrombuffer, dtype=dtype)

    df_dict[stream] = df.set_index('sync', drop=True)

# suffix appended to every column of a stream's frame, keyed by stream name,
# so identically named columns from different streams do not collide
suffixes = dict(
    nsp_neural='_ca',
    thresh_cross_1='_tc',
    sbp_1='_sbp',
    binned_spikes='_bs',
    control='_co',
)

def add_column_suffix(df, suffix=''):
    """Return a copy of ``df`` whose column names all end with ``suffix``.

    The input frame is left untouched; an empty suffix keeps the names as they are.
    """
    renamed = {}
    for name in df.columns:
        renamed[name] = name + suffix
    return df.rename(columns=renamed)


# Rename the columns of every stream's frame with its suffix. Streams without
# a configured suffix are left as they are and reported with a warning
# (warnings.warn: the bare name `warn` was never imported and raised NameError).
for key in df_dict:
    if key in suffixes:
        df_dict[key] = add_column_suffix(df_dict[key], suffixes[key])
    else:
        warnings.warn(f'No suffix defined for {key} stream')

df_dict

In [None]:
# Join every stream's frame onto the first one, in dict order, on their shared index.
stream_frames = list(df_dict.values())
df = stream_frames[0]
for df_i in stream_frames[1:]:
    df = df.join(df_i)

df

In [None]:
## Check inter-sample intervals

ts_fields = list(df.columns)

# one stacked panel per timestamp column, sharing the x axis
fig, axs = plt.subplots(nrows=len(ts_fields),
                        ncols=1,
                        figsize=(12,16),
                        sharex=True,
                        tight_layout=True,
                        facecolor='w')

for i, field in enumerate(ts_fields):
    # successive differences of the non-null values, scaled by 1e-6
    isi = df[field].dropna().diff() * 1e-6
    ax = axs[i]
    ax.plot(isi,'.')
    ax.set(title=f'{field}\n{isi.mean():2.4f} +- {isi.std():2.4f}',
           ylabel='ISI (ms)')
axs[-1].set_xlabel('sync id')

plt.show()
        

In [None]:
# Build per-node and cumulative latency tables from the joined timestamp frame.

# The index of `df` must advance by one constant step; that step is kept as
# the base unit for the resampling period below.
nsp_step_sizes = np.unique(np.diff(df.index))
assert len(nsp_step_sizes) == 1, (
    'Multiple step sizes found in the NSP stream data')
nsp_step_size = nsp_step_sizes[0]

# Rows in which every column has a value (presumably the rows where the
# slower, binned streams also produced an entry — confirm).
df_bin = df.dropna()

# Those complete rows must also be evenly spaced; their spacing is the bin size.
bdf_binsizes = np.unique(np.diff(df_bin.index))
assert len(bdf_binsizes) == 1, (
    'Multiple bin sizes found in the binned stream data')
binsize = bdf_binsizes[0]

first_bin_index = df_bin.index[0]

# Start at the first complete row so resampling bins line up with it.
df_td = df.loc[first_bin_index:, :]
# make a copy of the 1kHz df that has timedelta index
# (index values are interpreted as nanoseconds here)
df_td.index = pd.to_timedelta(df_td.index, unit='ns')

# Period = bin size expressed in base steps, used as a number of milliseconds.
# NOTE(review): this assumes one base step corresponds to 1 ms — confirm.
resampler = df_td.resample(f'{binsize / nsp_step_size}ms')

# Keep the last non-null value of each column per bin, then turn the
# timedelta index back into integer nanoseconds.
end_df = resampler.last()
end_df.index = (end_df.index.total_seconds() * 1e9).astype(np.uint64)

# Column-to-column differences give each node's latency relative to the
# previous column; the first column has no predecessor and is dropped.
# The 1e-6 factor presumably converts ns to ms (the plots label it ms).
latency_df = end_df.diff(axis=1).iloc[:, 1:].dropna() * 1e-6
cs_latency_df = latency_df.cumsum(axis=1)

cs_latency_df.max()

In [None]:
# reuse matplotlib's active colour cycle as the seaborn palette so both
# plotting libraries draw series in the same colours
prop_cycle = plt.rcParams['axes.prop_cycle']
colors = prop_cycle.by_key()['color']
sns_palette = sns.color_palette(colors)

In [None]:
# Display label for each timestamp column of the latency table.
ts_labels = {
    'ts_ca': 'NSP Packet Reception', 
    'ts_tc': 'Feature Ext.\n& Buffering', 
    'ts_bs': 'Binning', 
    'ts_co': 'Auto-cued\nControl'
}

# Take the plotted columns from latency_df itself so histograms, violins and
# tick labels always describe the same columns in the same order; a column
# without a configured label falls back to its own name instead of raising.
latency_fields = list(latency_df.columns)
labels = [ts_labels.get(field, field) for field in latency_fields]

# Make plots
fig, axes = plt.subplots(ncols=2, nrows=1, figsize=(12, 4), facecolor='w')

# plot per-node latency as a histogram
step = 10e-3  # histogram bin width, in the units of latency_df
for label, field in zip(labels, latency_fields):
    latency = latency_df[field].values
    bins = np.arange(latency.max() + step, step=step)
    axes[0].hist(latency, bins=bins, histtype='step', label=label)
axes[0].set_xlabel('Node Latency (ms)')
axes[0].set_ylabel('Samples')
axes[0].set_yscale('log')
axes[0].legend(fontsize=8,
               ncol=1,
               frameon=False,
               loc='best')

# plot cumulative latency as a horizontal violin plot
sns.violinplot(data=latency_df.cumsum(axis=1),
               scale='width',
               linewidth=0.2,
               orient='h',
               palette=sns_palette,
               ax=axes[1])
axes[1].set_xlabel('Cumulative Latency (ms)')
axes[1].set_yticks(ticks=np.arange(len(labels)), labels=labels)

# make the x-axes match: both panels run from 0 to the larger right-hand limit
shared_xmax = max(ax.get_xlim()[1] for ax in axes)
for ax in axes:
    ax.set_xlim(0, shared_xmax)

plt.tight_layout()
plt.show()

In [None]:
# Ask the server to exit without writing a new snapshot.
r.shutdown(nosave=True)

# Reap the child so it does not linger as a zombie; only kill it if it has
# not exited on its own after the SHUTDOWN request.
try:
    proc.wait(timeout=5)
except subprocess.TimeoutExpired:
    proc.kill()
    proc.wait()
finally:
    # the stdout pipe opened at launch is never closed anywhere else
    if proc.stdout is not None:
        proc.stdout.close()