# Script to replay a trial and save it as a gif.

This notebook will generate a gif in `./gifs`. This gif will render a trial as
the monkey sees it, except the monkey's gaze position is rendered post-hoc as a
green cross.

By default, this notebook runs on a session included in our OSF cache. To run
it on trials from another session, first download the behavior DANDI file for
that session by navigating to the root directory (`../..`) and running
```
$ python3 download_dandi_data.py --modality=behavior --subject=Perle --session=$your_session
```
and then set `SESSION` in the data-loading cell below to that session.

In [3]:
"""Imports."""

import ast
from pathlib import Path
from pynwb import NWBHDF5IO

import gif_writer as gif_writer_lib
import task_state as task_state_lib
import renderer as renderer_lib

In [4]:
"""Load data from NWB file."""

SUBJECT = "Perle"
SESSION = "2022-06-01"

# Path of the cached behavior NWB file for the chosen subject and session.
nwb_file_path = Path(
    f"../cache/dandi_data/behavior/sub-{SUBJECT}/"
    f"sub-{SUBJECT}_ses-{SESSION}_behavior+task.nwb"
)

# Fail early with an actionable message when the session has not been
# downloaded yet, instead of surfacing a low-level file-open error.
if not nwb_file_path.is_file():
    raise FileNotFoundError(
        f"NWB file not found: {nwb_file_path}. Download it by running "
        f"`python3 download_dandi_data.py --modality=behavior "
        f"--subject={SUBJECT} --session={SESSION}` from the root directory "
        f"(`../..`)."
    )

# Read the trials table and the 'display' intervals table. Both dataframes
# are built inside the `with` block; later cells use them after the file has
# been closed.
with NWBHDF5IO(nwb_file_path, "r") as io:
    print(f"Processing {nwb_file_path}")
    nwbfile = io.read()
    all_trials_df = nwbfile.trials.to_dataframe()
    all_display_df = nwbfile.intervals["display"].to_dataframe()


Processing ../cache/dandi_data/behavior/sub-Perle/sub-Perle_ses-2022-06-01_behavior+task.nwb


In [5]:
"""Render trial."""

TRIAL_NUMBER = 200  # Trial number to render
RENDER_GAZE = True  # Whether to render the green gaze cross
GIF_FPS = 20  # Frames per second of the saved gif

# Opacities assigned at every timestep of each task phase, keyed by
# task_state layer. 'prey' is applied to every prey sprite; any other layer is
# applied to its first sprite only. Layers not listed for a phase keep
# whatever opacity they already had.
PHASE_OPACITIES = {
    'fixation': {'prey': 0},
    'stimulus': {'prey': 255},
    'delay': {'prey': 0},
    'cue': {'fixation': 0, 'prey': 0, 'cue': 255},
    'response': {'fixation': 0, 'prey': 0, 'cue': 150, 'fw': 255},
    'reveal': {'fixation': 0, 'prey': 145, 'cue': 150, 'fw': 255},
}
# Phases rendered with blank=trial_df.delay_object_blanks; every other phase
# is rendered with blank=False.
BLANKABLE_PHASES = ('delay', 'cue', 'response')

# Extract data for the selected trial. Only display intervals lying strictly
# inside the trial's (start_time, stop_time) window are kept.
trial_df = all_trials_df.loc[TRIAL_NUMBER]
in_trial = (
    (all_display_df['start_time'] > trial_df['start_time']) &
    (all_display_df['stop_time'] < trial_df['stop_time'])
)
display_df = all_display_df[in_trial].reset_index(drop=True)
num_timesteps = display_df.shape[0]
broke_fixation = trial_df["broke_fixation"]
print(f'Number of timesteps: {num_timesteps}')
print(f'Broke fixation: {broke_fixation}')

# Get task state and renderer
stim_positions = ast.literal_eval(trial_df.stimulus_object_positions)
stim_identities = ast.literal_eval(trial_df.stimulus_object_identities)
# The target flags are stored as a bracketed, comma-separated string of
# true/false tokens. Parse them tolerantly with respect to whitespace and
# letter case.
stim_target_str = trial_df.stimulus_object_target
stim_target = [
    token.strip().lower() == 'true'
    for token in stim_target_str.strip()[1:-1].split(',')
]
task_state = task_state_lib.get_task_state(
    stim_positions, stim_identities, stim_target)
for p in task_state['prey']:
    p.opacity = 255
background_indices = trial_df.background_indices
renderer = renderer_lib.Renderer()
if RENDER_GAZE:
    task_state['eye'][0].opacity = 255

# Iterate through timesteps and render frames
frames = []
for timestep in range(num_timesteps):
    fixation_cross_scale = display_df.loc[timestep, 'fixation_cross_scale']
    eye_position = display_df.loc[timestep, 'closed_loop_eye_position']
    task_phase = display_df.loc[timestep, 'task_phase']

    # Modify state according to timestep
    task_state['fixation'][0].scale = fixation_cross_scale
    task_state['eye'][0].position = eye_position
    task_state['fw'][0].position = eye_position

    if task_phase not in PHASE_OPACITIES:
        raise ValueError(f'Unknown task_phase: {task_phase}')
    for layer, opacity in PHASE_OPACITIES[task_phase].items():
        if layer == 'prey':
            for p in task_state['prey']:
                p.opacity = opacity
        else:
            task_state[layer][0].opacity = opacity
    if task_phase in BLANKABLE_PHASES:
        blank = trial_df.delay_object_blanks
    else:
        blank = False

    # Render the current frame
    image = renderer(task_state, background_indices, blank=blank)
    frames.append(image)

# Handle broken fixation screen: append one extra frame showing it
if broke_fixation:
    task_state['broke_fixation_screen'][0].opacity = 255
    frames.append(renderer(task_state, background_indices))

# Save frames as gif, creating the output directory first so a fresh checkout
# without ./gifs does not fail at write time.
write_path = f'./gifs/{SUBJECT}_{SESSION}_trial_{TRIAL_NUMBER}.gif'
Path(write_path).parent.mkdir(parents=True, exist_ok=True)
gif_writer = gif_writer_lib.GifWriter(gif_file=write_path, fps=GIF_FPS)
for frame in frames:
    gif_writer.add(frame)
gif_writer.close()

Number of timesteps: 244
Broke fixation: False
Writing gif with 244 images to file ./gifs/Perle_2022-06-01_trial_201.gif
