# Script to replay trials with neural activity and save as an mp4 with audio.

This notebook generates an mp4 in `./videos_with_neurons/`. The video renders a
sequence of trials as the monkey saw them, except that the monkey's gaze position
is rendered post hoc as a green cross. It also displays the trial phase at the
top and the spike times of three neurons at the bottom. The audio contains a
click for each spike from these neurons, with a different pitch for each neuron.

By default, this notebook runs on a session included in our OSF cache. To run it
on trials from other sessions, first download the behavior and spike-sorting
DANDI files for that session by navigating to the root directory (`../..`) and
running
```
$ python3 download_dandi_data.py --subject=Perle --session=$your_session
```
Then set `SESSION` (and, as needed, `UNIT_IDS` and `TRIALS`) in the
configuration cell below.

In [1]:
"""Imports."""

import ast
from moviepy import CompositeVideoClip, AudioFileClip, CompositeAudioClip, ImageSequenceClip
import numpy as np
from pathlib import Path
from pynwb import NWBHDF5IO
from PIL import Image, ImageDraw, ImageFont
import task_state as task_state_lib
import renderer as renderer_lib

In [2]:
"""Units."""

# Input data: the DANDI files fetched by download_dandi_data.py.
_DANDI_CACHE_DIR = Path("../cache/dandi_data")
BEHAVIOR_DATA_DIR = _DANDI_CACHE_DIR / "behavior"
NEURAL_DATA_DIR = _DANDI_CACHE_DIR / "spikesorting"

# Output directory for the rendered videos (created if missing).
WRITE_DIR = Path("./videos_with_neurons")
WRITE_DIR.mkdir(parents=True, exist_ok=True)

# Session to render.
SUBJECT, SESSION = "Perle", "2022-05-31"

# Indices into the spike-sorted units table; each unit gets its own raster row
# and click sound.
UNIT_IDS = [93, 131, 115]

# Trial ids (index of the original trials table) to render, in playback order.
# NOTE(review): trial 187 appears twice; confirm that is intentional.
TRIALS = [94, 187, 1287, 963, 187, 465, 1132, 1382, 698]

In [3]:
"""Load data."""

# One short click per unit. The video cell pairs these with UNIT_IDS by position,
# so any unit beyond the third gets no sound.
CLICK_SOUNDS = [
    AudioFileClip("./resources/audio/click_0.mp3").subclipped(0.06, 0.1),
    AudioFileClip("./resources/audio/click_1.mp3").subclipped(0.06, 0.1),
    AudioFileClip("./resources/audio/click_2.mp3").subclipped(0.06, 0.1),
]
# Spikes are kept from FIRST_SPIKE_BUFFER before the first display frame to
# LAST_SPIKE_BUFFER after the last one. This lets the raster window show spikes
# that fall just outside the trial.
FIRST_SPIKE_BUFFER = -2
LAST_SPIKE_BUFFER = 2

# File paths
behavior_nwb_file_path = BEHAVIOR_DATA_DIR / f"sub-{SUBJECT}/sub-{SUBJECT}_ses-{SESSION}_behavior+task.nwb"
neural_nwb_file_path = NEURAL_DATA_DIR / f"sub-{SUBJECT}/sub-{SUBJECT}_ses-{SESSION}_spikesorting.nwb"

# Load neural data.
# The reader is deliberately left open: spike times are read through `units`
# in the video cell below, so the backing file must stay open.
spikesorting_read_io = NWBHDF5IO(neural_nwb_file_path, mode="r", load_namespaces=True)
spikesorting_nwbfile = spikesorting_read_io.read()
ecephys = spikesorting_nwbfile.processing["ecephys"]
units = ecephys.data_interfaces["units"]
electrodes_df = spikesorting_nwbfile.electrodes.to_dataframe()

# Load behavior data. The dataframes are used after the file is closed.
with NWBHDF5IO(behavior_nwb_file_path, "r") as io:
    print(f"Processing {behavior_nwb_file_path}")
    nwbfile = io.read()
    all_trials_df = nwbfile.trials.to_dataframe()
    all_display_df = nwbfile.intervals['display'].to_dataframe()

# Object identities are stored as Python-literal strings; parse them and count
# the objects per trial.
all_trials_df["stimulus_object_identities"] = all_trials_df.stimulus_object_identities.apply(
    ast.literal_eval)
all_trials_df["num_objects"] = all_trials_df.stimulus_object_identities.apply(len)

# Keep correct 3-object trials: fixation held and reward delivered.
# `.copy()` makes the filtered frame independent, so adding `trial_id` below
# does not write into a slice (avoids pandas' SettingWithCopyWarning).
is_selected = (
    (all_trials_df.num_objects == 3)
    & all_trials_df.broke_fixation.eq(False)
    & (all_trials_df.reward_duration > 0)
)
all_trials_df = all_trials_df[is_selected].copy()

# Keep the original trial index as `trial_id` (TRIALS refers to it), then renumber rows.
all_trials_df["trial_id"] = all_trials_df.index
all_trials_df = all_trials_df.reset_index(drop=True)
trial_ids = all_trials_df.trial_id.values

Processing ../cache/dandi_data/behavior/sub-Perle/sub-Perle_ses-2022-05-31_behavior+task.nwb


In [4]:
"""Make video."""

# Constants
pixels_per_unit = 64
top_pad = 16
left_pad = 16
start_t = -0.3
end_t = 0.7
start_pixel = 256
end_pixel = 1024
spike_y_pad = 8
spike_height = 40
spike_width = 4
past_color = (128, 64, 0, 255)
future_color = (255, 128, 0, 255)
frames_between_trials = 20
_SPEED = 0.5
IMAGE_SIZE = 1024

def _render_task_phase(phase, height=128, font_size=84):
    """Render a task-phase name as a text banner.

    The text is white on a black background and centered horizontally.

    Args:
        phase: Task-phase name; it is capitalized before drawing.
        height: Banner height in pixels. The banner width is IMAGE_SIZE.
        font_size: Size of Pillow's default font used for the text.

    Returns:
        numpy array of shape (height, IMAGE_SIZE, 3), i.e. RGB with the alpha
        channel dropped.
    """
    font = ImageFont.load_default(size=font_size)
    banner = Image.new('RGBA', (IMAGE_SIZE, height), (0, 0, 0, 255))
    draw = ImageDraw.Draw(banner)
    text = phase.capitalize()
    # Horizontal centering uses the rendered text width; vertical centering uses
    # the nominal font size. `x` is a local name so the module-level `left_pad`
    # constant is not shadowed.
    x = (IMAGE_SIZE - draw.textlength(text, font=font)) // 2
    y = (height - font.size) // 2
    draw.text((x, y), text, fill=(255, 255, 255, 255), font=font)
    return np.array(banner)[:, :, :3]

def _render_trial(trial_df, display_df):
    """Render one trial as the monkey saw it, one frame per row of `display_df`.

    Args:
        trial_df: Single-row trials dataframe. Reads `stimulus_object_positions`
            (a Python-literal string), `stimulus_object_identities`,
            `stimulus_object_target` (a bracketed list of `true`/`false` flags)
            and `delay_object_blanks`.
        display_df: Per-frame display dataframe indexed 0..len-1. Reads the
            `fixation_cross_scale`, `closed_loop_eye_position` and `task_phase`
            columns.

    Returns:
        List of image arrays, one per display frame, each with the task-phase
        banner stacked on top of the rendered scene.

    Raises:
        ValueError: If a frame has an unrecognized task phase.
    """
    # Build the task state from the trial's stimulus description.
    stim_positions = ast.literal_eval(trial_df.stimulus_object_positions.values[0])
    stim_identities = trial_df.stimulus_object_identities.values[0]
    # Flags look like "[true, false, false]"; splitting on "," plus strip()
    # also tolerates other spacing.
    stim_target_str = trial_df.stimulus_object_target.values[0].strip()
    stim_target = [flag.strip() == 'true' for flag in stim_target_str[1:-1].split(',')]
    task_state = task_state_lib.get_task_state(
        stim_positions, stim_identities, stim_target)
    for p in task_state['prey']:
        p.opacity = 255
    task_state['eye'][0].opacity = 255
    background_indices = (0, 0)
    renderer = renderer_lib.Renderer(image_size=(IMAGE_SIZE, IMAGE_SIZE))

    # Per-trial values that do not change across frames.
    delay_blank = trial_df.delay_object_blanks.values[0]
    # A trial has only a few distinct phases, so each banner is rendered once
    # and reused. np.concatenate below copies it, so the cached array is never
    # mutated.
    phase_banners = {}

    frames = []
    for timestep in range(len(display_df)):
        fixation_cross_scale = display_df.loc[timestep, 'fixation_cross_scale']
        eye_position = display_df.loc[timestep, 'closed_loop_eye_position']
        task_phase = display_df.loc[timestep, 'task_phase']

        # Update the state for this frame; the eye marker and `fw` track the gaze.
        task_state['fixation'][0].scale = fixation_cross_scale
        task_state['eye'][0].position = eye_position
        task_state['fw'][0].position = eye_position
        blank = False
        if task_phase == 'fixation':
            for p in task_state['prey']:
                p.opacity = 0
        elif task_phase == 'stimulus':
            for p in task_state['prey']:
                p.opacity = 255
        elif task_phase == 'delay':
            for p in task_state['prey']:
                p.opacity = 0
            blank = delay_blank
        elif task_phase == 'cue':
            task_state['fixation'][0].opacity = 0
            for p in task_state['prey']:
                p.opacity = 0
            task_state['cue'][0].opacity = 255
            blank = delay_blank
        elif task_phase == 'response':
            task_state['fixation'][0].opacity = 0
            for p in task_state['prey']:
                p.opacity = 0
            task_state['cue'][0].opacity = 150
            task_state['fw'][0].opacity = 255
            blank = delay_blank
        elif task_phase == 'reveal':
            task_state['fixation'][0].opacity = 0
            for p in task_state['prey']:
                p.opacity = 145
            task_state['cue'][0].opacity = 150
            task_state['fw'][0].opacity = 255
        else:
            raise ValueError(f'Unknown task_phase: {task_phase}')

        image = renderer(task_state, background_indices, blank=blank)

        # Stack the task-phase banner on top of the scene.
        if task_phase not in phase_banners:
            phase_banners[task_phase] = _render_task_phase(task_phase)
        frames.append(np.concatenate((phase_banners[task_phase], image), axis=0))

    return frames

# Get videos per trial.

# x position of the "present" line. It is fixed because the raster window
# [start_t, end_t] is always placed relative to the current frame.
present_pixel = int(
    start_pixel + (end_pixel - start_pixel) * (-start_t / (end_t - start_t)))


def _select_trial(trial_index):
    """Return (trial_df, display_df) for a trial id in the filtered trials table.

    The display frames run from 0.25 s before `phase_stimulus_time` up to the
    trial's `stop_time`, and are re-indexed 0..len-1.

    Raises:
        ValueError: If `trial_index` is not among the filtered trials.
    """
    trial_df = all_trials_df[all_trials_df.trial_id == trial_index]
    if trial_df.empty:
        raise ValueError(
            f"Trial {trial_index} is not among the selected 3-object correct trials.")
    display_df = all_display_df[
        (all_display_df['start_time'] > trial_df['phase_stimulus_time'].values[0] - 0.25) &
        (all_display_df['stop_time'] < trial_df['stop_time'].values[0])
    ].reset_index(drop=True)
    return trial_df, display_df


def _unit_spike_times(first_frame_time, last_frame_time):
    """Return, per unit in UNIT_IDS, spike times relative to the first frame.

    Spikes are kept from FIRST_SPIKE_BUFFER before the first frame to
    LAST_SPIKE_BUFFER after the last frame.
    """
    spike_times_per_unit = []
    for unit_id in UNIT_IDS:
        spike_times = units.spike_times_index[unit_id]
        in_window = (
            (spike_times >= first_frame_time + FIRST_SPIKE_BUFFER) &
            (spike_times <= last_frame_time + LAST_SPIKE_BUFFER)
        )
        spike_times_per_unit.append(spike_times[in_window] - first_frame_time)
    return spike_times_per_unit


def _render_label_panel(width):
    """Return (row y offsets, RGBA array) for the raster panel with neuron labels only."""
    num_units = len(UNIT_IDS)
    y_values = (top_pad + pixels_per_unit * np.arange(num_units)).astype(int)
    font = ImageFont.load_default(size=40)
    height = top_pad + pixels_per_unit * num_units
    panel = Image.new('RGBA', (width, height), (0, 0, 0, 255))
    draw = ImageDraw.Draw(panel)
    for unit_id, y in zip(UNIT_IDS, y_values):
        draw.text((left_pad, y), f"Neuron {unit_id}", fill=(255, 255, 255, 255), font=font)
    return y_values, np.array(panel)


def _render_spike_panel(label_panel, y_values, spike_times_per_unit, frame_time):
    """Draw the present line and the spikes in the raster window around `frame_time`.

    Drawing happens on a copy of `label_panel`; spikes before `frame_time` use
    `past_color`, and the rest use `future_color`.
    """
    panel = np.copy(label_panel)
    panel[:, present_pixel:present_pixel + 1] = (128, 128, 128, 255)
    for y, spike_times in zip(y_values, spike_times_per_unit):
        spikes = spike_times[
            (spike_times >= frame_time + start_t) &
            (spike_times <= frame_time + end_t)
        ]
        is_past = spikes < frame_time
        fractions = (spikes - (frame_time + start_t)) / (end_t - start_t)
        spike_pixels = (start_pixel + (end_pixel - start_pixel) * fractions).astype(int)
        for spike_pixel, past in zip(spike_pixels, is_past):
            color = past_color if past else future_color
            # Clamp so the slice start never goes negative (which would wrap around).
            tick_left = max(spike_pixel - spike_width, 0)
            panel[y + spike_y_pad:y + spike_y_pad + spike_height, tick_left:spike_pixel] = color
    return panel


def _make_spike_audio(spike_times_per_unit, trial_duration):
    """Return one click per in-trial spike, time-scaled by _SPEED, or None if there are none.

    CompositeAudioClip cannot be built from an empty list, so None is returned
    when no spike falls within [0, trial_duration].
    """
    clicks = []
    for click, spike_times in zip(CLICK_SOUNDS, spike_times_per_unit):
        for t in spike_times:
            if 0 <= t <= trial_duration:
                clicks.append(click.with_start(t / _SPEED))
    return CompositeAudioClip(clicks) if clicks else None


video_per_trial = []
for trial_index in TRIALS:
    trial_df, display_df = _select_trial(trial_index)
    first_frame_time = display_df.start_time.values[0]
    last_frame_time = display_df.stop_time.values[-1]
    spike_times_per_unit = _unit_spike_times(first_frame_time, last_frame_time)

    # Render task frames (one per row of display_df) and the raster panel under them.
    frames = _render_trial(trial_df, display_df)
    y_values, label_panel = _render_label_panel(frames[0].shape[1])

    frames_with_spikes = []
    for frame, frame_start in zip(frames, display_df.start_time.values):
        spike_panel = _render_spike_panel(
            label_panel, y_values, spike_times_per_unit, frame_start - first_frame_time)
        frames_with_spikes.append(np.concatenate((frame, spike_panel[:, :, :3]), axis=0))

    # Label-only blank frames separate consecutive trials.
    empty_frame = np.zeros_like(frames[0], dtype=np.uint8)
    frame_with_buffer = np.concatenate([empty_frame, label_panel[:, :, :3]], axis=0)
    frames_with_spikes.extend([frame_with_buffer] * frames_between_trials)

    # int(_SPEED * 60) assumes display frames were shown at 60 Hz (TODO confirm).
    video_clip = ImageSequenceClip(frames_with_spikes, fps=int(_SPEED * 60))
    spike_audio = _make_spike_audio(spike_times_per_unit, last_frame_time - first_frame_time)
    if spike_audio is not None:
        video_clip.audio = spike_audio

    video_per_trial.append(video_clip)
    
# Write the result: play the per-trial clips back to back in a single video.
if not video_per_trial:
    raise ValueError("TRIALS is empty; there is nothing to render.")
clips = [video_per_trial[0]]
for v in video_per_trial[1:]:
    # Each trial starts where the previous one ended.
    clips.append(v.with_start(clips[-1].end))

composite = CompositeVideoClip(clips)
write_path = WRITE_DIR / f"{SUBJECT}_{SESSION}.mp4"
composite.write_videofile(write_path)

MoviePy - Building video videos_with_neurons/Perle_2022-05-31.mp4.
MoviePy - Writing audio in Perle_2022-05-31TEMP_MPY_wvf_snd.mp3


                                                                      

MoviePy - Done.
MoviePy - Writing video videos_with_neurons/Perle_2022-05-31.mp4



                                                                          

MoviePy - Done !
MoviePy - video ready videos_with_neurons/Perle_2022-05-31.mp4
