# Notebook for visualizing streaming data and calculating seeded vs non-seeded median

In [1]:
import fastplotlib as fpl 
import numpy as np
import zmq
import tifffile
import scipy

Detected skylake derivative running on mesa i915. Clears to srgb textures will use manual shader clears.
Detected skylake derivative running on mesa i915. Clears to srgb textures will use manual shader clears.
Detected skylake derivative running on mesa i915. Clears to srgb textures will use manual shader clears.


Image(value=b'version https://git-lfs.github.com/spec/...', height='55', width='300')

Valid,Device,Type,Backend,Driver
✅,Intel(R) Arc(tm) Graphics (MTL),IntegratedGPU,Vulkan,Mesa 25.0.4
✅ (default),NVIDIA GeForce RTX 4060 Laptop GPU,DiscreteGPU,Vulkan,565.77
❗ limited,"llvmpipe (LLVM 19.1.7, 256 bits)",CPU,Vulkan,Mesa 25.0.4 (LLVM 19.1.7)
❌,Mesa Intel(R) Arc(tm) Graphics (MTL),IntegratedGPU,OpenGL,4.6 (Core Profile) Mesa 25.0.4


Detected skylake derivative running on mesa i915. Clears to srgb textures will use manual shader clears.
Detected skylake derivative running on mesa i915. Clears to srgb textures will use manual shader clears.


# Calculate seeded median

In [2]:
# Recorded voltage chunk used to seed the median below.
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
file_path = "/home/clewis/repos/holo-nbs/rb26_20240111/raw_voltage_chunk.tif"
# Memory-map the TIFF rather than reading it eagerly into RAM.
data = tifffile.memmap(file_path)
# Shown as the cell output; presumably (channels, samples) — TODO confirm axis order.
data.shape

(384, 3000001)

In [3]:
# define filter functions
def butter(cutoff, fs, order=5, btype='high'):
    """
    Design a digital Butterworth filter.

    The cutoff is normalized by the Nyquist frequency (half of ``fs``) and
    handed to ``scipy.signal.butter``.

    Parameters
    ----------
    cutoff : float
        Cutoff frequency, in the same units as ``fs``.
    fs : float
        Sampling frequency.
    order : int, default 5
        Filter order.
    btype : str, default 'high'
        Band type forwarded to ``scipy.signal.butter``.

    Returns
    -------
    tuple
        ``(b, a)`` numerator and denominator coefficient arrays.
    """
    nyquist = fs * 0.5
    numerator, denominator = scipy.signal.butter(
        order, cutoff / nyquist, btype=btype, analog=False
    )
    return numerator, denominator


def butter_filter(data, cutoff, fs, order=5, axis=-1, btype='high'):
    """
    Apply a Butterworth filter forward and backward along ``axis``.

    The coefficients come from ``butter(cutoff, fs, order, btype)`` and are
    applied with ``scipy.signal.filtfilt``.

    Parameters
    ----------
    data : array-like
        Signal to filter.
    cutoff : float
        Cutoff frequency, in the same units as ``fs``.
    fs : float
        Sampling frequency.
    order : int, default 5
        Filter order.
    axis : int, default -1
        Axis of ``data`` along which to filter.
    btype : str, default 'high'
        Band type forwarded to ``butter``.

    Returns
    -------
    np.ndarray
        The filtered signal as returned by ``scipy.signal.filtfilt``.
    """
    coefficients = butter(cutoff, fs, order=order, btype=btype)
    return scipy.signal.filtfilt(*coefficients, data, axis=axis)

In [4]:
# Seed the per-row median from the first 4000 samples of the recording,
# filtered with cutoff=1_000 and fs=30_000 (presumably Hz — TODO confirm).
median = np.median(
    butter_filter(data[:, :4000], cutoff=1_000, fs=30_000),
    axis=1,
)

# Setup zmq subscriber

In [5]:
# Subscriber socket used by get_buffer() to receive the streamed frames.
context = zmq.Context()
sub = context.socket(zmq.SUB)
# An empty prefix subscribes to every message.
sub.setsockopt(zmq.SUBSCRIBE, b"")

# keep only the most recent message (set before connect() is called below)
sub.setsockopt(zmq.CONFLATE, 1)

# address must match publisher in actor
sub.connect("tcp://127.0.0.1:5557")

<SocketContext(connect='tcp://127.0.0.1:5557')>

In [6]:
def get_buffer():
    """
    Poll the subscriber socket once without blocking.

    Returns
    -------
    bytes or None
        The received message, or ``None`` when ``recv`` raises ``zmq.Again``
        because nothing is waiting.
    """
    try:
        return sub.recv(zmq.NOBLOCK)
    except zmq.Again:
        # nothing queued right now
        return None

# Helper functions

In [7]:
def get_spike_events(data: np.ndarray, n_deviations: int = 4, use_global: bool = True):
    """
    Find threshold crossings on every row (channel) of ``data``.

    The per-row threshold is ``center + n_deviations * MAD``, where MAD is the
    median absolute deviation of that row. A sample is a crossing when its
    absolute value is strictly greater than the threshold.

    Parameters
    ----------
    data : np.ndarray
        2-D array; statistics are computed along axis 1.
    n_deviations : int, default 4
        Multiplier applied to the MAD.
    use_global : bool, default True
        If True, ``center`` is the module-level seeded ``median``.
        If False, ``center`` is the median of each row of ``data``.

    Returns
    -------
    list of np.ndarray
        One integer index array per row of ``data`` holding the positions of
        the crossings (empty when there are none).
    """
    # Use a local name instead of rebinding the module-level `median`: rebinding
    # it for use_global=False would replace the seeded value, so every later
    # use_global=True call would silently use the last frame's own median.
    if use_global:
        center = median
    else:
        center = np.median(data, axis=1)

    mad = scipy.stats.median_abs_deviation(data, axis=1)
    thresh = (n_deviations * mad) + center

    abs_data = np.abs(data)

    # Indices where the threshold is crossed, one array per row
    return [
        np.flatnonzero(row > row_thresh)
        for row, row_thresh in zip(abs_data, thresh)
    ]

# Create figure

In [8]:
# One rect per subplot, given as (x, y, width, height); these two split the
# canvas into a left half and a right half.
rects = [
    (0, 0, 0.5, 1),  # left panel
    (0.5, 0, 0.5, 1),  # right panel
]

figure = fpl.Figure(
    rects=rects,
    size=(1000, 900),
    names=["seeded median", "non-seeded median"],
)

# Same camera/axes settings for both panels
for subplot in figure:
    subplot.camera.maintain_aspect = False
    subplot.axes.visible = False

RFBOutputContext()

Detected skylake derivative running on mesa i915. Clears to srgb textures will use manual shader clears.


In [9]:
def _refresh_panel(subplot, data, spike_ixs):
    """Show ``data`` as a gray line stack in ``subplot`` and color ``spike_ixs`` orange."""
    if len(subplot.graphics) == 0:
        # First frame for this panel: create the stack, one line per row of data
        stack = subplot.add_line_stack(
            data, colors="gray", thickness=0.5, separation=35, name="lg"
        )
    else:
        stack = subplot["lg"]
        # Clear the highlights left over from the previous frame
        stack.colors = "gray"
        # Overwrite the y-values of every line with the new frame
        for i in range(stack.data[:].shape[0]):
            stack[i].data[:, 1] = data[i]

    # Color each spike event orange
    for i, ixs in enumerate(spike_ixs):
        if ixs.shape[0] == 0:
            continue
        stack[i].colors[ixs] = "orange"


def update_figure(p):
    """
    Animation callback: redraw both panels from the newest frame on the socket.

    Does nothing when no message is waiting. Otherwise the buffer is decoded
    as float64, reshaped to (384, 150), and drawn in the "seeded median" and
    "non-seeded median" subplots with the corresponding spike events
    highlighted.

    Parameters
    ----------
    p
        Object passed in by the animation machinery; it is indexed by subplot
        name and iterated to reach every subplot (the figure, when registered
        through ``figure.add_animations``).
    """
    buff = get_buffer()
    if buff is None:
        return

    # Deserialize the buffer into a NumPy array; a frame must hold exactly
    # 384 * 150 values, otherwise reshape raises ValueError.
    data = np.frombuffer(buff, dtype=np.float64)
    data = data.reshape(384, 150)

    ixs_seeded = get_spike_events(data)
    ixs_unseeded = get_spike_events(data, use_global=False)

    _refresh_panel(p["seeded median"], data, ixs_seeded)
    _refresh_panel(p["non-seeded median"], data, ixs_unseeded)

    for subplot in p:
        subplot.auto_scale()

In [10]:
# Display the figure in the notebook output.
figure.show()

In [11]:
# Register update_figure as an animation function so each new frame from the
# socket is drawn.
figure.add_animations(update_figure)

In [12]:
# Snapshot the figure as a NumPy array (saved as an RGBA image in the next cell).
arr = figure.export_numpy()

In [13]:
from PIL import Image
# Interpret the exported array as RGBA pixels and write it out as a PNG.
# NOTE(review): absolute, machine-specific output path — adjust before running elsewhere.
im = Image.fromarray(arr,  mode='RGBA')
im.save("/home/clewis/Desktop/median.png")