In [None]:
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import clear_output, display
import time
import tensorflow as tf

from cnn2snn import set_akida_version, AkidaVersion
import akida
from akida_models.tenn_spatiotemporal.eye_preprocessing import preprocess_data
from akida_models.tenn_spatiotemporal.eye_losses import process_detector_prediction 

In [None]:
def load_and_segment_npy(file_path, time_window_us=10_000, segment_duration_us=500_000):
    """
    Loads event data from a structured .npy file, splits it into consecutive
    segments of `segment_duration_us`, and converts each segment into
    model-compatible frames using preprocess_data.

    Segments are half-open intervals [start, start + segment_duration_us), so an
    event whose timestamp lies exactly on a boundary is assigned to one segment
    only. Only complete segments are produced: events after the end of the last
    complete segment (including one exactly at its end) are not used.

    Args:
        file_path (str): Path to the .npy event file (with fields: 'p', 'x', 'y', 't').
            Events must be sorted by 't' (np.searchsorted relies on it).
        time_window_us (int): Time window per frame, forwarded to preprocess_data
            as `time_window` (default: 10,000µs).
        segment_duration_us (int): Duration of each segment in microseconds
            (default: 500,000µs). Must be positive.
    Returns:
        tuple: (frames_list, segment_list) where frames_list holds the frames
        returned by preprocess_data for each segment, and segment_list holds the
        matching (4, n_events) float32 event tensors stacked as [p, x, y, t].
    Raises:
        ValueError: If segment_duration_us is not positive.
    """
    if segment_duration_us <= 0:
        # A non-positive duration would never advance to the next segment.
        raise ValueError(f"segment_duration_us must be positive, got {segment_duration_us}")

    # Load structured event array
    data = np.load(file_path)

    # Convert structured fields to float32 arrays (the dtype handed to preprocess_data)
    p = data['p'].astype('float32')
    x = data['x'].astype('float32')
    y = data['y'].astype('float32')
    t = data['t'].astype('float32')

    # Segment boundaries are computed on float64 timestamps: float32 has a
    # 24-bit mantissa, so microsecond timestamps beyond 2**24 µs (~16.7 s) get
    # rounded, and repeatedly adding segment_duration_us would accumulate error.
    t_exact = data['t'].astype(np.float64)

    frames_list = []
    segment_list = []

    if t_exact.size > 0:
        t_start = t_exact[0]
        # Number of complete segments that fit between the first and last event.
        n_segments = int((t_exact[-1] - t_start) // segment_duration_us)

        # Dummy label (e.g. center); preprocess_data is called with a label,
        # but the label it returns is discarded.
        label = tf.convert_to_tensor([[0.5, 0.5, 0]], dtype=tf.float32)

        for k in range(n_segments):
            seg_start = t_start + k * segment_duration_us
            seg_end = seg_start + segment_duration_us

            # Half-open window: side='left' on both ends so an event at seg_end
            # is not counted in both this segment and the next one.
            start_idx = np.searchsorted(t_exact, seg_start, side='left')
            end_idx = np.searchsorted(t_exact, seg_end, side='left')

            # Slice event segment
            segment = tf.stack([
                p[start_idx:end_idx],
                x[start_idx:end_idx],
                y[start_idx:end_idx],
                t[start_idx:end_idx]
            ], axis=0)

            # Preprocess segment into frames
            frames, _ = preprocess_data(
                events=segment,
                label=label,
                train_mode=False,
                frames_per_segment=1,
                spatial_downsample=(6, 6),
                time_window=time_window_us
            )

            frames_list.append(frames)
            segment_list.append(segment)

    print(f"Processed {len(frames_list)} segments of {segment_duration_us / 1000:g}ms each.")
    return frames_list, segment_list

In [None]:
# Segment the example recording and keep, for every segment, both the
# preprocessed frames and the raw event tensor it was built from.
frames_all, segment_all = load_and_segment_npy(
    file_path="eye_tracking_event_examples.npy",
)

In [None]:
# Fail with a clear message (instead of an IndexError on frames_all[0]) when
# the recording was shorter than one complete segment.
assert len(frames_all) > 0, "No complete segments were produced; check the input file and segment duration."

n_frames = len(frames_all)
# Dimensions are read from the first segment.
# NOTE(review): assumes every segment has the same (N, H, W, n_ch) shape — confirm.
N, H, W, n_ch = frames_all[0].shape
print(f"Loaded data with {n_frames} frames, {H}×{W} pixels, {n_ch} channels")

In [None]:
import akida_models
from akida_models.model_io import load_model
# Load the quantized model from its `.fbz` file. `.fbz` is Akida's serialized
# model format, which `akida.Model` loads directly; the resulting object is what
# is mapped onto hardware (`.map`) and run (`.predict`) further down.
# NOTE(review): `akida_models.model_io.load_model` is aimed at Keras/ONNX files —
# confirm it accepts `.fbz` before switching back to it.
model = akida.Model("models/tenn_spatiotemporal_eye_buffer_i8_w8_a8.fbz")
print(f"Model input shape: {model.input_shape}")

In [None]:
# Try to run on attached Akida hardware; fall back to software (CPU) inference
# when no device is found or mapping fails. `mappedDevice` records the outcome.
with set_akida_version(AkidaVersion.v2):
    devices = akida.devices()
    if len(devices) > 0:
        print(f'Available devices: {[dev.desc for dev in devices]}')
        device = devices[0]
        print(device.version)
        try:
            # Map the model loaded above onto the first device.
            model.map(device)
            print(f"Mapping to Akida device {device.desc}.")
            mappedDevice = device.version
        except Exception as e:
            # Best-effort fallback, but report why mapping failed so genuine
            # errors (e.g. a NameError) are not disguised as incompatibility.
            print("Model not compatible with FPGA. Running on CPU.")
            print(f"  Reason: {type(e).__name__}: {e}")
            mappedDevice = "CPU"
    else:
        print("No Akida devices found, running on CPU.")
        mappedDevice = "CPU"

In [None]:
# Pick one RGB colour (uint8) per event channel for the overlay.
if n_ch == 2:
    # e.g. channel 0 = red, channel 1 = blue
    colors = np.array([[255, 0, 0], [0, 0, 255]], dtype=np.uint8)
else:
    # Fallback: sample n_ch colours from matplotlib's tab10 palette.
    # `matplotlib.cm.get_cmap` was deprecated in 3.7 and removed in 3.9;
    # `pyplot.get_cmap(name, lut)` is the portable equivalent.
    cmap = plt.get_cmap('tab10', n_ch)
    colors = (cmap(range(n_ch))[:, :3] * 255).astype(np.uint8)

# Create one figure & axis and redraw it in place for every segment.
fig, ax = plt.subplots(figsize=(6, 6))

# Half-length, in pixels, of each arm of the cross marking the prediction.
cross_size = 3

# Loop over segments, updating the figure in place.
for frame_number, f in enumerate(frames_all, start=1):
    # Start from a mid-gray background.
    frame_vis = np.full((H, W, 3), 128, dtype=np.uint8)

    f_np = f.numpy() if isinstance(f, tf.Tensor) else f  # Ensure f is a numpy array
    # Only the first slice along the leading axis is drawn; the model is given
    # the whole array.
    frame = f_np[0]

    # Predict using the model, then decode the raw detector output.
    pred = model.predict(f_np)
    pred = process_detector_prediction(tf.expand_dims(pred, 0))

    # NOTE(review): assumes column 1 / column 0 of the decoded prediction are the
    # normalised x / y position; they are scaled to pixels by W / H.
    y_pred_x = pred[:, 1] * W
    y_pred_y = pred[:, 0] * H

    # Pixel coordinates of the prediction: cx is a column, cy is a row.
    cx = int(y_pred_x.numpy().flatten()[0])
    cy = int(y_pred_y.numpy().flatten()[0])

    # Paint each channel's events (>0 marks an event); where several channels
    # fire on the same pixel, the later channel's colour wins.
    for ch in range(n_ch):
        frame_vis[frame[:, :, ch] > 0] = colors[ch]

    # Build the cross once (it does not depend on the channel). Images are
    # indexed [row, col] = [y, x]. Each arm is clipped to the image so a
    # prediction on or beyond the border neither raises IndexError nor wraps
    # around through negative indices.
    n_rows, n_cols = frame.shape[:2]
    pred_mask = np.zeros((n_rows, n_cols), dtype=bool)
    if 0 <= cy < n_rows:
        pred_mask[cy, max(cx - cross_size, 0):max(cx + cross_size + 1, 0)] = True
    if 0 <= cx < n_cols:
        pred_mask[max(cy - cross_size, 0):max(cy + cross_size + 1, 0), cx] = True
    # Drawn after the events so the cross stays on top.
    frame_vis[pred_mask] = [255, 255, 0]

    # Update the image.
    ax.clear()
    ax.imshow(frame_vis)
    ax.set_title(f'Frame {frame_number}/{n_frames}')
    ax.axis('off')

    # Replace the previous output with the redrawn figure.
    clear_output(wait=True)
    display(fig)
    time.sleep(0.01)   # adjust playback speed

# Close the figure when done so it is not rendered again.
plt.close(fig)