# Notebook for visualizing streaming data

In [1]:
import fastplotlib as fpl 
import numpy as np
import zmq
import scipy

Detected skylake derivative running on mesa i915. Clears to srgb textures will use manual shader clears.
Detected skylake derivative running on mesa i915. Clears to srgb textures will use manual shader clears.
Detected skylake derivative running on mesa i915. Clears to srgb textures will use manual shader clears.


Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x01,\x00\x00\x007\x08\x06\x00\x00\x00\xb6\x1bw\x99\x…

Valid,Device,Type,Backend,Driver
✅,Intel(R) Arc(tm) Graphics (MTL),IntegratedGPU,Vulkan,Mesa 25.0.4
✅ (default),NVIDIA GeForce RTX 4060 Laptop GPU,DiscreteGPU,Vulkan,565.77
❗ limited,"llvmpipe (LLVM 19.1.7, 256 bits)",CPU,Vulkan,Mesa 25.0.4 (LLVM 19.1.7)
❌,Mesa Intel(R) Arc(tm) Graphics (MTL),IntegratedGPU,OpenGL,4.6 (Core Profile) Mesa 25.0.4


Detected skylake derivative running on mesa i915. Clears to srgb textures will use manual shader clears.
Detected skylake derivative running on mesa i915. Clears to srgb textures will use manual shader clears.


# Set up the ZeroMQ subscriber

In [2]:
# Create a SUB socket that receives frames published by the actor process.
context = zmq.Context()
sub = context.socket(zmq.SUB)
# an empty prefix subscribes to every message, regardless of topic
sub.setsockopt(zmq.SUBSCRIBE, b"")

# keep only the most recent message
# NOTE(review): per the ZeroMQ docs, CONFLATE keeps a single queued message and
# does not support multi-part messages; socket options generally only apply to
# connections made afterwards, so this must stay ahead of connect().
sub.setsockopt(zmq.CONFLATE, 1)

# address must match publisher in actor
sub.connect("tcp://127.0.0.1:5557")

<SocketContext(connect='tcp://127.0.0.1:5557')>

In [3]:
def get_buffer():
    """
    Poll the subscriber socket once without blocking.

    Returns the raw message received from the publisher, or ``None`` when
    no message is currently waiting.
    """
    try:
        return sub.recv(zmq.NOBLOCK)
    except zmq.Again:
        # nothing queued on the socket right now
        return None

# Helper functions

In [4]:
def get_spike_events(data: np.ndarray, n_deviations: int = 4):
    """
    Find threshold crossings on every channel of a 2D array.

    For each row (channel) the median and the median absolute deviation (MAD)
    are computed along axis 1, and the per-channel threshold is
    ``median + n_deviations * MAD``. A sample counts as a crossing when its
    absolute value is strictly greater than that threshold.

    Parameters
    ----------
    data : np.ndarray
        2D array, one channel per row.
    n_deviations : int, default 4
        Number of MADs added to the median to form the threshold.

    Returns
    -------
    list of np.ndarray
        One integer index array per channel holding the column indices of the
        crossings (empty when the channel has none).
    """
    median = np.median(data, axis=1)
    mad = scipy.stats.median_abs_deviation(data, axis=1)

    thresh = (n_deviations * mad) + median

    # Take the absolute value once and compare every channel against its own
    # threshold in one broadcast, instead of recomputing np.abs(data) per channel.
    crossings = np.abs(data) > thresh[:, np.newaxis]

    indices = [np.where(channel_mask)[0] for channel_mask in crossings]

    return indices

In [5]:
def make_raster(ixs):
    """
    Turn per-channel threshold crossings into raster-plot points and colors.

    Parameters
    ----------
    ixs : list of array-like
        One 1D array of crossing indices per channel.

    Returns
    -------
    spikes : list of np.ndarray
        One ``(n_spikes, 2)`` array per channel. Column 0 is the crossing
        index, column 1 is ``2 * channel`` (rows are spaced two units apart).
    colors : np.ndarray
        ``(total_spikes, 4)`` RGBA array. Every point of a channel shares one
        randomly drawn, fully opaque color. The shape is ``(0, 4)`` when there
        are no spikes at all.
    """
    spikes = list()
    colors = list()

    for channel, spike_ixs in enumerate(ixs):
        spike_ixs = np.asarray(spike_ixs)

        # x = crossing index, y = constant row height for this channel
        ys = np.full(spike_ixs.shape, channel * 2)
        spikes.append(np.vstack([spike_ixs, ys]).T)

        # randomly select a color; drawn for every channel, even silent ones
        rgba = np.append(np.random.rand(3), 1)
        colors += [rgba] * len(spike_ixs)

    # reshape keeps the (n, 4) layout even when no channel produced a spike,
    # where np.array([]) alone would collapse to shape (0,)
    return spikes, np.array(colors, dtype=float).reshape(-1, 4)

# Create figure

In [6]:
# One rect per subplot, listed in the same order as `names` below.
# NOTE(review): presumably (x, y, width, height) as fractions of the canvas —
# confirm against the fastplotlib Figure documentation.
rects = [
    (0, 0, 0.5, 0.7),  # "filtered spikes"
    (0.5, 0, 0.5, 0.7),  # "raster"
    (0, 0.7, 1, .3),  # "smoothed spikes"
]

figure = fpl.Figure(rects=rects, size=(1000, 900), names=["filtered spikes", "raster", "smoothed spikes"])

# hide the axes and let x and y scale independently in every subplot
for subplot in figure:
    subplot.axes.visible = False
    subplot.camera.maintain_aspect = False

RFBOutputContext()

Detected skylake derivative running on mesa i915. Clears to srgb textures will use manual shader clears.


In [7]:
def update_figure(p):
    """
    Update the frame using data received from the socket.

    Fetches the newest message, detects threshold crossings per channel and
    refreshes the three subplots: the stacked traces ("filtered spikes"), the
    Gaussian-smoothed spike trains ("smoothed spikes") and the raster
    ("raster"). Returns immediately when no message is waiting.

    Parameters
    ----------
    p
        Object indexed by subplot name and iterable over its subplots; this
        is the figure the callback is registered on.
    """
    buff = get_buffer()
    if buff is None:
        return

    # Deserialize the buffer into a NumPy array: float64 values laid out as
    # 384 rows (one per channel) by 150 samples.
    # NOTE(review): dtype and shape are hard-coded and must match the publisher.
    data = np.frombuffer(buff, dtype=np.float64).reshape(384, 150)

    ixs = get_spike_events(data)

    spikes, colors = make_raster(ixs)
    spikes = np.concatenate(spikes)

    # --- filtered spikes: one stacked line per channel ---
    filtered = p["filtered spikes"]
    if len(filtered.graphics) == 0:
        # first frame: create the line stack
        lg = filtered.add_line_stack(data, colors="gray", thickness=0.5, separation=35, name="lg")
    else:
        # later frames: reset the colors and overwrite the y-values in place
        lg = filtered["lg"]
        lg.colors = "gray"
        for i in range(lg.data[:].shape[0]):
            lg[i].data[:, 1] = data[i]

    # color each spike event orange
    for i, spike_times in enumerate(ixs):
        if spike_times.shape[0] == 0:
            continue
        lg[i].colors[spike_times] = "orange"

    # --- smoothed spikes: Gaussian-filtered binary spike train per channel ---
    smoothed = p["smoothed spikes"]
    n_samples = data.shape[1]
    x = np.arange(n_samples)  # same x-axis for every channel
    for i, spike_times in enumerate(ixs):
        y = np.zeros((n_samples,))
        if spike_times.shape[0] > 0:
            y[spike_times] = 1
            y = scipy.ndimage.gaussian_filter1d(y, 5)

        if len(smoothed.graphics) < data.shape[0]:
            # lines are still being created (first frame): add one per channel
            smoothed.add_line(np.vstack([x, y]).T, colors=np.append(np.random.rand(3), 1), thickness=1)
        else:
            smoothed.graphics[i].data[:, 1] = y

    # --- raster: rebuilt from scratch every frame ---
    raster = p["raster"]
    raster.clear()
    if spikes.shape[0] > 0:
        # skip the scatter entirely when no channel crossed its threshold
        raster.add_scatter(spikes, sizes=5, colors=colors)

    for subplot in p:
        subplot.auto_scale()

In [8]:
# Display the figure; as the last expression of the cell, the returned object is rendered as the cell output.
figure.show()

In [9]:
# Add the animation update function.
# NOTE(review): presumably update_figure is then invoked on every render with
# the figure as its argument — confirm against the fastplotlib docs.
figure.add_animations(update_figure)