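# Generate frames for the task schematic

This notebook renders every display frame of a single trial (selected by `SUBJECT`, `SESSION` and `TRIAL_NUMBER`) and saves one PDF image per task phase (fixation, stimulus, delay, cue, response, reveal) to the current directory.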

In [1]:
"""Imports."""

from pathlib import Path
from pynwb import NWBHDF5IO
import ast

from PIL import Image
import sys
sys.path.append('../../../task_visualization')
import task_state as task_state_lib
import renderer as renderer_lib

In [None]:
"""Load data for one trial: its trials-table row and the display intervals inside it."""

SUBJECT = "Perle"
SESSION = "2022-06-01"
TRIAL_NUMBER = 33
# Root of the local DANDI cache; change this if the data lives elsewhere.
DATA_DIR = Path("../../../cache/dandi_data")

# Path to the behavior+task NWB file for this subject/session
nwb_file_path = (
    DATA_DIR / "behavior" / f"sub-{SUBJECT}"
    / f"sub-{SUBJECT}_ses-{SESSION}_behavior+task.nwb"
)

# Read data from the NWB file. The tables are converted to DataFrames inside the
# `with` block because the file is closed when the block exits.
with NWBHDF5IO(nwb_file_path, "r") as io:
    print(f"Processing {nwb_file_path}")
    nwbfile = io.read()
    all_trials_df = nwbfile.trials.to_dataframe()
    all_display_df = nwbfile.intervals['display'].to_dataframe()

# Compute trial and display data for the selected trial.
# Note: `.loc[TRIAL_NUMBER]` with a scalar label selects a single row, so
# `trial_df` is one trials-table row, not a full DataFrame.
trial_df = all_trials_df.loc[TRIAL_NUMBER]
# NOTE(review): strict inequalities drop display intervals whose start/stop
# coincide exactly with the trial bounds; the frame indices used in the render
# cell depend on this choice, so change it only together with those indices.
display_df = all_display_df[
    (all_display_df['start_time'] > trial_df['start_time']) &
    (all_display_df['stop_time'] < trial_df['stop_time'])
].reset_index(drop=True)
num_timesteps = display_df.shape[0]
broke_fixation = trial_df["broke_fixation"]

# Fail here with context instead of with an IndexError while rendering frames.
assert num_timesteps > 0, (
    f"No display intervals found inside trial {TRIAL_NUMBER} of {nwb_file_path}"
)

Processing ../../../cache_osf/data_for_figures/dandi_data/behavior/sub-Perle/sub-Perle_ses-2022-06-01_behavior+task.nwb


In [3]:
"""Render every display frame of the selected trial and save one PDF per task phase."""

# Display-row index of the frame saved for each task phase. These indices were
# presumably hand-picked for the trial selected above, so they are validated
# against `display_df` before rendering: a different trial fails loudly instead
# of silently saving frames from the wrong phase.
FRAME_INDEX_PER_PHASE = {
    "fixation": 0,
    "stimulus": 40,
    "delay": 100,
    "cue": 160,
    "response": 180,
    "reveal": 198,
}
for phase_name, frame_index in FRAME_INDEX_PER_PHASE.items():
    if frame_index >= num_timesteps:
        raise IndexError(
            f"Frame {frame_index} chosen for phase {phase_name!r} is out of "
            f"range: the trial has {num_timesteps} display frames"
        )
    actual_phase = display_df.loc[frame_index, 'task_phase']
    if actual_phase != phase_name:
        raise ValueError(
            f"Frame {frame_index} was chosen for phase {phase_name!r} but is in "
            f"phase {actual_phase!r}"
        )

# Get task state and renderer
stim_positions = ast.literal_eval(trial_df.stimulus_object_positions)
stim_identities = ast.literal_eval(trial_df.stimulus_object_identities)
# The target flags are stored as a bracketed list of lowercase `true`/`false`
# tokens, which `ast.literal_eval` cannot parse. Stripping each item tolerates
# any spacing after commas, and skipping empty items makes "[]" an empty list
# (a plain split would yield a spurious [False]).
stim_target_str = trial_df.stimulus_object_target
stim_target = [
    item.strip() == 'true'
    for item in stim_target_str.strip()[1:-1].split(',')
    if item.strip()
]
task_state = task_state_lib.get_task_state(
    stim_positions, stim_identities, stim_target)
for p in task_state['prey']:
    p.opacity = 255
# Opacity 0 keeps the sprite that follows the eye position out of the rendered
# frames; set it to 255 to draw it.
task_state['eye'][0].opacity = 0
background_indices = trial_df.background_indices
renderer = renderer_lib.Renderer()

# Iterate through timesteps and render frames.
# `task_state` is mutated in place and never reset, so opacities set in one
# phase carry over to later timesteps (e.g. the fixation cross stays hidden
# from the 'cue' phase onward).
frames = []
for timestep in range(num_timesteps):
    fixation_cross_scale = display_df.loc[timestep, 'fixation_cross_scale']
    eye_position = display_df.loc[timestep, 'closed_loop_eye_position']
    task_phase = display_df.loc[timestep, 'task_phase']

    # Modify state according to timestep
    task_state['fixation'][0].scale = fixation_cross_scale
    task_state['eye'][0].position = eye_position
    task_state['fw'][0].position = eye_position
    blank = False
    if task_phase == 'fixation':
        for p in task_state['prey']:
            p.opacity = 0
    elif task_phase == 'stimulus':
        for p in task_state['prey']:
            p.opacity = 255
    elif task_phase == 'delay':
        for p in task_state['prey']:
            p.opacity = 0
        blank = trial_df.delay_object_blanks
    elif task_phase == 'cue':
        task_state['fixation'][0].opacity = 0
        for p in task_state['prey']:
            p.opacity = 0
        task_state['cue'][0].opacity = 255
        blank = trial_df.delay_object_blanks
    elif task_phase == 'response':
        task_state['fixation'][0].opacity = 0
        for p in task_state['prey']:
            p.opacity = 0
        task_state['cue'][0].opacity = 150
        task_state['fw'][0].opacity = 255
        blank = trial_df.delay_object_blanks
    elif task_phase == 'reveal':
        task_state['fixation'][0].opacity = 0
        for p in task_state['prey']:
            p.opacity = 145
        task_state['cue'][0].opacity = 150
        task_state['fw'][0].opacity = 255
    else:
        raise ValueError(f'Unknown task_phase: {task_phase}')

    image = renderer(task_state, background_indices, blank=blank)
    frames.append(image)

# Handle broken fixation screen.
# NOTE(review): this extra frame is appended after the display frames but is not
# included in the saved frames below; save it explicitly if it is wanted.
if broke_fixation:
    task_state['broke_fixation_screen'][0].opacity = 255
    frames.append(renderer(task_state, background_indices))

# Save one frame per phase (indices were validated at the top of this cell)
frames_to_save = [
    (phase_name, frames[frame_index])
    for phase_name, frame_index in FRAME_INDEX_PER_PHASE.items()
]
for name, frame in frames_to_save:
    Image.fromarray(frame).save(f"./{name}.pdf")