In [321]:
import argparse
import bisect
import json
import os
import re
import sys
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from zoneinfo import ZoneInfo
from decimal import Decimal
from enum import Enum
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
from aind_data_schema.core.session import (
    DetectorConfig,
    FieldOfView,
    LaserConfig,
    Modality,
    Session,
    Stream, 
    TriggerType,
    RewardDeliveryConfig,
    RewardSolution,
    RewardSpoutConfig,
    SpoutSide,
    LaserConfig,
    Stack,
    StackChannel,
    StimulusEpoch,
    StimulusModality
)
from aind_data_schema.imaging.tile import Channel
from aind_data_schema.models.devices import Calibration, Software
from aind_data_schema.models.stimulus import (
    PhotoStimulation,
    PhotoStimulationGroup,#    StimulusEpoch,
)
from aind_data_schema.models.coordinates import Translation3dTransform, Rotation3dTransform, RelativePosition, Axis, AxisName
from aind_data_schema.models.units import PowerUnit, SizeUnit, TimeUnit
from pydantic import Field
from pydantic_settings import BaseSettings
from ScanImageTiffReader import ScanImageTiffReader




In [322]:
# Information that can't be parsed from the tif files needs to be a field here.
# I'm not sure whether things can be set as default values. Please update them if the
# defaults I'm using don't make sense, or just remove the defaults.
# 
class JobSettings(BaseSettings):
    """Data that needs to be input by user. Can be pulled from env vars with
    BERGAMO prefix or set explicitly.

    Fields declared without a default are mandatory. Several defaults are
    device names that are expected to match the rig json; they are marked
    with a comment below.
    """

    input_source: Path = Field(
        ..., description="Directory of files that need to be parsed."
    )
    output_directory: Optional[Path] = Field(
        default=None,
        description=(
            "Directory where to save the json file to. If None, then json"
            " contents will be returned in the Response message."
        ),
    )

    # Mandatory fields
    experimenter_full_name: List[str]
    subject_id: str
    imaging_laser_wavelength: int  # user defined
    fov_imaging_depth: int
    fov_targeted_structure: str
    notes: str

    # Fields with default values
    mouse_platform_name: str = "tube"  # should match rig json
    active_mouse_platform: bool = False
    session_type: str = "BCI"
    iacuc_protocol: str = "2109"
    rig_id: str = "Bergamo photostim rig"  # should match rig json
    behavior_camera_names: List[str] = [
        "Side Face Camera",
        "Bottom Face Camera",
    ]  # should match rig json
    imaging_laser_name: str = (
        "Chameleon tunable pulsing laser"  # should match rig json
    )

    photostim_laser_name: str = (
        "Monaco 1040nm pulsing laser"  # should match rig json
    )
    photostim_laser_wavelength: int = 1040
    # Both FOV coordinates are Decimal so that the annotation agrees with the
    # default value and with the sibling coordinate (the AP coordinate used to
    # be annotated as float while defaulting to a Decimal).
    fov_coordinate_ml: Decimal = Decimal("1.5")
    fov_coordinate_ap: Decimal = Decimal("1.5")
    fov_reference: str = "Bregma"

    # in mm from face of the mouse
    starting_lickport_position: List[float] = [0, -6, 0]
    behavior_task_name: str = "single neuron BCI conditioning"
    timezone: ZoneInfo = ZoneInfo("US/Pacific")

    class Config:
        """Config to set env var prefix to BERGAMO"""

        env_prefix = "BERGAMO_"
        


In [323]:
# This class makes it easier to flag which tif files are which expected type

class TifFileGroup(str, Enum):
    """Labels used to flag which expected type a group of tif files is.

    Members are also plain strings, so they compare equal to their values.
    """

    BEHAVIOR = "behavior"
    PHOTOSTIM = "photostim"
    SPONTANEOUS = "spontaneous"
    STACK = "stack"
    
# This class will hold the metadata information pulled from the tif files with minimal parsing.

@dataclass(frozen=True)
class RawImageInfo:
    """Immutable container for the metadata read from one tif file."""

    # Header section of the reader metadata as a (nested) dictionary
    reader_metadata_header: dict
    # JSON section of the reader metadata, already decoded
    reader_metadata_json: dict
    # One dictionary per reader description, as built by
    # extract_raw_info_from_file (one entry per index in range(len(reader)))
    reader_descriptions: List[dict]
    # Shape reported by the reader, e.g. [620, 800, 800]
    # NOTE(review): presumably [num_of_frames, pixel_width, pixel_height];
    # the axis order has not been confirmed.
    reader_shape: List[int]
    
# This class is a container to hold only the tif file metadata information needed to build the 
# Session.json file
# More stuff can be added if necessary.

@dataclass(frozen=True)
class ParsedMetadataInfo:
    """Immutable subset of the tif file metadata that is used downstream."""

    # Which kind of acquisition the file group was flagged as
    tif_file_group: TifFileGroup
    # This should correspond to the number of trials
    number_of_tif_files: int
    # Header sub-dictionaries, named after the header keys they mirror
    # (hPhotostim, hRoiManager, hBeams, hFastZ)
    h_photostim: dict
    h_roi_manager: dict
    h_beams: dict
    h_fast_z: dict
    # ROI group information (imagingRoiGroup / photostimRoiGroups)
    imaging_roi_group: dict
    photostim_roi_groups: List[dict]
    # Last reader description of the parsed file
    reader_description_last: dict
    # Shape reported by the reader
    reader_shape: List[int]
    
# The following functions will be used to translate the tif file information into information
# needed to build the Session.json file. They need to be re-usable and preferably modular.
# I can bundle them into a class to ensure they all process the same job_settings class.
# I'll keep them independent for purposes of sharing a jupyter notebook.

def get_tif_file_locations(job_settings: JobSettings) -> Dict[str, List[Path]]:
    """Scans the input source directory and returns a dictionary of file
    groups in an ordered list. For example, if the directory had
    [neuron2_00001.tif, neuron2_00002.tif, stackPost_00001.tif,
    stackPost_00002.tif, stackPost_00003.tif], then it will return
    { "neuron2": [neuron2_00001.tif, neuron2_00002.tif],
     "stackPost":
       [stackPost_00001.tif, stackPost_00002.tif, stackPost_00003.tif]
    }

    Only the top level of ``job_settings.input_source`` is scanned;
    sub-directories are ignored. A missing or unreadable directory yields an
    empty dictionary.

    Parameters
    ----------
    job_settings : JobSettings
      Only ``input_source`` is read.

    Returns
    -------
    Dict[str, List[Path]]
      File stem (everything before the last underscore) mapped to the
      matching file paths, ordered by the number that precedes the
      extension (ties broken by path).
    """
    # The dot before the extension is escaped so that names such as
    # "x_1atif" are not picked up; "tif+" accepts both .tif and .tiff.
    tif_name_pattern = re.compile(r"^(.*)_.*?(\d+)\.tif+$")

    # os.walk yields the top-level directory first; taking just that first
    # entry restricts the scan to the top level and, like the bare walk,
    # silently yields nothing when the directory cannot be listed.
    root, _, file_names = next(
        os.walk(job_settings.input_source),
        (job_settings.input_source, [], []),
    )

    numbered_files = {}
    for name in file_names:
        matched = tif_name_pattern.match(name)
        if matched is None:
            continue
        file_stem, tif_number = matched.groups()
        numbered_files.setdefault(file_stem, []).append(
            (int(tif_number), Path(os.path.join(root, name)))
        )

    # Sort numerically so that e.g. _9 comes before _10 even when the
    # numbers are not zero padded to the same width.
    return {
        file_stem: [path for _, path in sorted(entries)]
        for file_stem, entries in numbered_files.items()
    }

def flat_dict_to_nested(flat: dict, key_delim: str = ".") -> dict:
    """
    Utility method to expand delimited keys into nested dictionaries.
    Modified from https://stackoverflow.com/a/50607551
    Parameters
    ----------
    flat : dict
      Example {"a.b.c": 1, "a.b.d": 2, "e.f": 3}
    key_delim : str
      Delimiter on dictionary keys. Default is '.'.

    Returns
    -------
    dict
      A nested dictionary like {"a": {"b": {"c": 1, "d": 2}}, "e": {"f": 3}}
    """
    nested = {}
    for delimited_key, value in flat.items():
        # Every segment except the last names a dictionary to descend into
        # (created on first use); the last segment receives the value.
        *parent_names, leaf_name = delimited_key.split(key_delim)
        node = nested
        for parent_name in parent_names:
            node = node.setdefault(parent_name, {})
        node[leaf_name] = value
    return nested

# This method parses a single file into a RawImageInfo dataclass
def _parse_key_value_lines(text: str) -> Dict[str, str]:
    """Turn newline-separated 'key = value' lines into a dictionary.

    Each line is split on the first ' = ' only, so values may themselves
    contain ' = '. Lines without the separator (e.g. blank lines) are
    skipped instead of raising.
    """
    parsed = {}
    for line in text.split("\n"):
        key, separator, value = line.partition(" = ")
        if separator:
            parsed[key] = value
    return parsed


def extract_raw_info_from_file(file_path: Path) -> RawImageInfo:
    """Read the metadata of a single tif file into a RawImageInfo.

    The reader metadata is treated as two sections separated by the first
    blank line: 'key = value' header lines, followed by a JSON document.

    Parameters
    ----------
    file_path : Path
      Location of the tif file to open with ScanImageTiffReader.

    Returns
    -------
    RawImageInfo
      The reader shape, the header as a nested dictionary (with the
      contents of a top-level "SI" entry moved up one level), the decoded
      JSON section (an empty dict when the section is absent) and one
      dictionary per reader description.

    Raises
    ------
    json.JSONDecodeError
      If a JSON section is present but cannot be decoded.
    """
    with ScanImageTiffReader(str(file_path)) as reader:
        reader_metadata = reader.metadata()
        reader_shape = reader.shape()
        reader_descriptions = [
            _parse_key_value_lines(reader.description(index).strip())
            for index in range(len(reader))
        ]

    # Split once, on the first blank line only, so blank lines inside the
    # JSON section cannot truncate it.
    header_text, _, json_text = reader_metadata.partition("\n\n")
    metadata_dict = flat_dict_to_nested(_parse_key_value_lines(header_text))
    reader_metadata_json = json.loads(json_text) if json_text.strip() else {}

    # Move SI dictionary up one level
    if "SI" in metadata_dict:
        metadata_dict.update(metadata_dict.pop("SI"))

    return RawImageInfo(
        reader_shape=reader_shape,
        reader_metadata_header=metadata_dict,
        reader_metadata_json=reader_metadata_json,
        reader_descriptions=reader_descriptions,
    )
# vvvvvvvvvvvv MARTON HAS REMOVED THIS FUNCTION AS IT GENERATED NEW CONVENTIONS THAT MAKE IT HARDER TO FOLLOW vvvvvvvvvvvvvv

# This method transforms a RawImageInfo class into an intermediate ParsedMetadataInfo class.
# It should be easier to deal with the more focused ParsedMetadataInfo class than the RawImageInfo.
# If other fields from the RawImageInfo are needed downstream, we can update things here.

# def parse_raw_metadata(
#     raw_image_info: RawImageInfo, number_of_files: int
# ) -> ParsedMetadataInfo:
#     h_roi_manager = raw_image_info.reader_metadata_header.get(
#         "hRoiManager", {}
#     )
#     h_beams = raw_image_info.reader_metadata_header.get("hBeams", {})
#     h_fast_z = raw_image_info.reader_metadata_header.get("hFastZ", {})
#     h_photostim = raw_image_info.reader_metadata_header.get(
#         "hPhotostim", {}
#     )
#     roi_groups = raw_image_info.reader_metadata_json.get("RoiGroups", {})
#     imaging_roi_group = roi_groups.get("imagingRoiGroup", {})
#     photostim_roi_groups = roi_groups.get("photostimRoiGroups", [])

#     reader_description_last = raw_image_info.reader_descriptions[-1]

#     tif_file_group = map_raw_image_info_to_tif_file_group(
#         raw_image_info=raw_image_info
#     )

#     return ParsedMetadataInfo(
#         tif_file_group=tif_file_group,
#         number_of_tif_files=number_of_files,
#         h_photostim=h_photostim,
#         h_roi_manager=h_roi_manager,
#         h_beams=h_beams,
#         h_fast_z=h_fast_z,
#         imaging_roi_group=imaging_roi_group,
#         photostim_roi_groups=photostim_roi_groups,
#         reader_description_last=reader_description_last,
#         reader_shape=raw_image_info.reader_shape,
#     )

# This method maps a RawImageInfo dataclass into a TifFileGroup type
# ^^^^^^^^^MARTON HAS REMOVED THIS FUNCTION AS IT GENERATED NEW CONVENTIONS THAT MAKE IT HARDER TO FOLLOW ^^^^^^^^^^^^^^^^^^


def map_raw_image_info_to_tif_file_group(
    raw_image_info: RawImageInfo,
) -> TifFileGroup:
    """Classify a tif file from its metadata header.

    The checks are applied in order and the first match wins:
    photostim (hPhotostim status is Running), behavior (integration ROI
    manager enabled with output channels enabled and external trigger
    enabled), stack (hStackManager enabled); anything else is flagged as
    spontaneous. Header values are compared as strings.
    """
    header = raw_image_info.reader_metadata_header

    photostim_status = header.get("hPhotostim", {}).get("status")
    if photostim_status in ("'Running'", "Running"):
        return TifFileGroup.PHOTOSTIM

    integration_roi_manager = header.get("hIntegrationRoiManager", {})
    looks_like_behavior = (
        integration_roi_manager.get("enable") == "true"
        and integration_roi_manager.get("outputChannelsEnabled") == "true"
        and header.get("extTrigEnable", {}) == "1"
    )
    if looks_like_behavior:
        return TifFileGroup.BEHAVIOR

    if header.get("hStackManager", {}).get("enable") == "true":
        return TifFileGroup.STACK

    return TifFileGroup.SPONTANEOUS
# Loops through tif file locations and transforms them into a dictionary of ParsedMetadataInfo

def extract_parsed_metadata_info_from_files(
    tif_file_locations: Dict[str, List[Path]]
) -> Dict[Tuple[str, TifFileGroup], list]:
    """Read the first and last tif file of every file group.

    Parameters
    ----------
    tif_file_locations : Dict[str, List[Path]]
      File stem mapped to the ordered tif files of that group.

    Returns
    -------
    Dict[Tuple[str, TifFileGroup], list]
      Keyed by (file stem, group type of the last file). Each value is the
      three element list
      [raw info of the last file, [list of files], raw info of the first file].
      Note that the file list is wrapped in an outer list. Groups without
      any files are skipped.
    """
    parsed_map = {}
    for file_stem, files in tif_file_locations.items():
        if not files:
            # Nothing to read for this stem.
            continue
        raw_info = extract_raw_info_from_file(files[-1])
        # Avoid opening the same file twice when the group has one file.
        if len(files) == 1:
            raw_info_first_file = raw_info
        else:
            raw_info_first_file = extract_raw_info_from_file(files[0])
        tif_file_group = map_raw_image_info_to_tif_file_group(
            raw_image_info=raw_info
        )
        parsed_map[(file_stem, tif_file_group)] = [
            raw_info,
            [files],
            raw_info_first_file,
        ]
    return parsed_map

In [325]:
# Directory holding the tif files of the session to be described
tif_file_path = Path(
    "/home/jupyter/bucket/Data/Calcium_imaging/raw/Bergamo-2P-Photostim/BCI_29/052422"
)

# User supplied values; everything not listed here keeps its JobSettings default
job_settings = JobSettings(
    input_source=tif_file_path,
    experimenter_full_name=["John Apple"],
    subject_id="061022",
    imaging_laser_wavelength=920,  # nm
    fov_imaging_depth=200,  # microns
    fov_targeted_structure='Primary Motor Cortex',
    notes='test upload',
)

# Separate the files by basename
tif_file_locations = get_tif_file_locations(job_settings=job_settings)

# Parse the metadata of every basename
parsed_metadata = extract_parsed_metadata_info_from_files(
    tif_file_locations=tif_file_locations
)


In [326]:
# Split the parsed entries by acquisition type. The keys of parsed_metadata
# are (file stem, TifFileGroup) tuples; every list holds (key, value) pairs.
stack_file_info = [
    (group_key, group_info)
    for group_key, group_info in parsed_metadata.items()
    if group_key[1] == TifFileGroup.STACK
]
spont_file_info = [
    (group_key, group_info)
    for group_key, group_info in parsed_metadata.items()
    if group_key[1] == TifFileGroup.SPONTANEOUS
]
behavior_file_info = [
    (group_key, group_info)
    for group_key, group_info in parsed_metadata.items()
    if group_key[1] == TifFileGroup.BEHAVIOR
]
photo_stim_file_info = [
    (group_key, group_info)
    for group_key, group_info in parsed_metadata.items()
    if group_key[1] == TifFileGroup.PHOTOSTIM
]

# Header of the first parsed group; used to get the scanimage version
first_tiff_metadata_header = list(parsed_metadata.values())[0][
    0
].reader_metadata_header


In [328]:

# here key is channel in scanimage
# Per-channel settings, keyed by the channel number.
# Entries marked "FROM RIG JSON" are placeholders that have not been filled in.
channel_dict = {
    1: {
        'channel_name': 'Ch1',
        'light_source_name': job_settings.imaging_laser_name,
        'filter_names': [],  # FROM RIG JSON
        'detector_name': '',  # FROM RIG JSON
        'excitation_wavelength': job_settings.imaging_laser_wavelength,
        'daq_name': '',  # FROM RIG JSON
    },
    2: {
        'channel_name': 'Ch2',
        'light_source_name': job_settings.imaging_laser_name,
        'filter_names': [],  # FROM RIG JSON
        'detector_name': '',  # FROM RIG JSON
        'excitation_wavelength': job_settings.imaging_laser_wavelength,
        'daq_name': '',  # FROM RIG JSON
    },
}

# Position of each laser's entry within the hBeams 'powers' header value
laser_dict = {
    'imaging_laser': {'power_index': 0},
    'photostim_laser': {'power_index': 1},
}

# Used as: fov_scale_factor = (FOV_1x_micron / scanZoomFactor) / linesPerFrame
FOV_1x_micron = 1000



# Starting position of the lickport: the translation comes from the job
# settings (this is the standard position for the BCI task) and no rotation
# is applied.
lickportposition = RelativePosition(
    device_position_transformations=[
        Translation3dTransform(
            translation=job_settings.starting_lickport_position
        ),
        # "No rotation" is the 3x3 identity matrix. An all-zero matrix is not
        # a rotation: applied to a position it would collapse it to the origin.
        Rotation3dTransform(rotation=[1, 0, 0, 0, 1, 0, 0, 0, 1]),
    ],
    device_origin='tip of the lickspout',
    device_axes=[
        Axis(name=AxisName.X, direction='lateral motion'),
        Axis(
            name=AxisName.Y,
            direction='rostro-caudal motion positive is towards mouse, negative is away',
        ),
        Axis(name=AxisName.Z, direction='up/down'),
    ],
)

reward_spout_config = RewardSpoutConfig(
    side=SpoutSide.CENTER,
    starting_position=lickportposition,
    variable_position=True,
)
reward_delivery = RewardDeliveryConfig(
    reward_solution=RewardSolution.WATER,
    reward_spouts=[reward_spout_config],
)

behavior_software = Software(
    name='pyBpod',
    version='1.8.2',  # hard coded
    url='https://github.com/pybpod/pybpod',
)
# NOTE(review): the name below says pybpod_basic.py while the url points to
# bci_basic.py - confirm which file name is the intended one.
pybpod_script = Software(
    name='pybpod_basic.py',  # file name
    version='1',  # commit#
    url='https://github.com/rozmar/BCI-motor-control/blob/main/BCI-pybpod-protocols/bci_basic.py',
    parameters={},  # can I do this?
)
photostim_software = Software(
    name='ScanImage',
    # The version is read from the header of the first parsed tif file
    version='{}.{}.{}'.format(
        first_tiff_metadata_header['VERSION_MAJOR'],
        first_tiff_metadata_header['VERSION_MINOR'],
        first_tiff_metadata_header['VERSION_UPDATE'],
    ),
    url='https://www.mbfbioscience.com/products/scanimage/',  # hard coded
)

def _parse_matlab_array(raw_value, dtype=float):
    """Parse an array written as text in a tif header into a 1-D numpy array.

    Accepts a bare scalar ('5'), space separated values with any amount of
    whitespace ('[0  5 10]') and ';' or ',' separated values ('[1;2]').
    """
    text = str(raw_value).strip().strip('[]')
    tokens = [token for token in re.split(r'[\s;,]+', text) if token]
    return np.array([dtype(token) for token in tokens], dtype=dtype)


def _stream_time_span(frame_description, tz):
    """Return (start, end) of a stream from a frame description.

    'epoch' is read as [year month day hour minute seconds] (seconds may be
    fractional) and interpreted in timezone ``tz``; the end is the start plus
    'frameTimestamps_sec' seconds.
    """
    epoch = _parse_matlab_array(frame_description['epoch'])
    if epoch.size != 6:
        raise ValueError(
            'Expected 6 values in epoch, got: {!r}'.format(frame_description['epoch'])
        )
    year, month, day, hour, minute = (int(value) for value in epoch[:5])
    # Adding the seconds as a timedelta keeps the fractional part without
    # string manipulation of the seconds field.
    start = datetime(year, month, day, hour, minute, tzinfo=tz) + timedelta(
        seconds=float(epoch[5])
    )
    end = start + timedelta(seconds=float(frame_description['frameTimestamps_sec']))
    return start, end


def _imaging_laser_power(tiff_header):
    """Power of the imaging laser, read from hBeams 'powers' in the header."""
    powers = _parse_matlab_array(tiff_header['hBeams']['powers'])
    return powers[laser_dict['imaging_laser']['power_index']]


def _imaging_light_source(tiff_header, settings):
    """LaserConfig of the imaging laser for one stream."""
    return LaserConfig(
        name=settings.imaging_laser_name,  # from rig json
        wavelength=settings.imaging_laser_wavelength,  # user set value
        excitation_power=_imaging_laser_power(tiff_header),  # from tiff header
        excitation_power_unit=PowerUnit.PERCENT,
    )


def _fov_geometry(tiff_header, settings):
    """Keyword arguments shared by Stack and FieldOfView."""
    roi_manager = tiff_header['hRoiManager']
    lines_per_frame = int(roi_manager['linesPerFrame'])
    return dict(
        fov_coordinate_ml=settings.fov_coordinate_ml,
        fov_coordinate_ap=settings.fov_coordinate_ap,
        # NOTE(review): job_settings.fov_reference exists but is not used
        # here - confirm whether it should replace this placeholder.
        fov_reference='there is no reference',
        fov_width=int(roi_manager['pixelsPerLine']),
        fov_height=lines_per_frame,
        magnification=str(roi_manager['scanZoomFactor']),
        # microns per pixel
        fov_scale_factor=(FOV_1x_micron / float(roi_manager['scanZoomFactor']))
        / lines_per_frame,
        frame_rate=float(roi_manager['scanFrameRate']),
        targeted_structure=settings.fov_targeted_structure,
    )


def _build_fov(tiff_header, settings):
    """Single-plane FieldOfView; multi-plane would need one per plane."""
    return FieldOfView(
        index=0,
        imaging_depth=settings.fov_imaging_depth,  # in microns
        **_fov_geometry(tiff_header, settings),
    )


def _daq_names(channel_numbers):
    """DAQ name of every saved channel (from channel_dict)."""
    return [channel_dict[channel_num]['daq_name'] for channel_num in channel_numbers]


all_stream_start_times = []
all_stream_end_times = []
streams = []
stim_epochs = []

for stack_file_info_now in stack_file_info:  # ONLY 2P STREAM DURING STACKS
    tiff_header = stack_file_info_now[1][0].reader_metadata_header
    last_frame_description = stack_file_info_now[1][0].reader_descriptions[-1]
    stream_start_time, stream_end_time = _stream_time_span(
        last_frame_description, job_settings.timezone
    )
    all_stream_start_times.append(stream_start_time)
    all_stream_end_times.append(stream_end_time)

    # Depth range of the stack, centred on the user supplied imaging depth
    z_list = _parse_matlab_array(tiff_header['hStackManager']['zs'])
    z_start = np.min(z_list) - np.median(z_list) + job_settings.fov_imaging_depth
    z_end = np.max(z_list) - np.median(z_list) + job_settings.fov_imaging_depth
    z_step = float(tiff_header['hStackManager']['stackZStepSize'])

    channel_nums = _parse_matlab_array(tiff_header['hChannels']['channelSave'], int)
    daq_names = _daq_names(channel_nums)
    channels = [
        StackChannel(
            start_depth=z_start,
            end_depth=z_end,
            channel_name=channel_dict[channel_num]['channel_name'],
            light_source_name=channel_dict[channel_num]['light_source_name'],
            filter_names=channel_dict[channel_num]['filter_names'],
            detector_name=channel_dict[channel_num]['detector_name'],
            excitation_wavelength=channel_dict[channel_num]['excitation_wavelength'],
            excitation_power=_imaging_laser_power(tiff_header),  # from tiff header
            excitation_power_unit=PowerUnit.PERCENT,
            filter_wheel_index=0,
        )
        for channel_num in channel_nums
    ]
    zstack = Stack(
        channels=channels,
        number_of_planes=int(tiff_header['hStackManager']['numSlices']),
        step_size=z_step,
        number_of_plane_repeats_per_volume=int(
            tiff_header['hStackManager']['framesPerSlice']
        ),
        number_of_volume_repeats=int(tiff_header['hStackManager']['numVolumes']),
        **_fov_geometry(tiff_header, job_settings),
    )
    stream_stack = Stream(
        stream_start_time=stream_start_time,
        stream_end_time=stream_end_time,
        daq_names=daq_names,
        light_sources=[_imaging_light_source(tiff_header, job_settings)],
        stack_parameters=zstack,
        stream_modalities=[Modality.POPHYS],
    )
    streams.append(stream_stack)

for spont_file_info_now in spont_file_info:  # ONLY 2P STREAM DURING SPONT
    tiff_header = spont_file_info_now[1][0].reader_metadata_header
    last_frame_description = spont_file_info_now[1][0].reader_descriptions[-1]
    stream_start_time, stream_end_time = _stream_time_span(
        last_frame_description, job_settings.timezone
    )
    all_stream_start_times.append(stream_start_time)
    all_stream_end_times.append(stream_end_time)

    channel_nums = _parse_matlab_array(tiff_header['hChannels']['channelSave'], int)
    daq_names = _daq_names(channel_nums)
    fov_2p = _build_fov(tiff_header, job_settings)
    stream_2p = Stream(
        stream_start_time=stream_start_time,  # each basename is a separate stream
        stream_end_time=stream_end_time,
        daq_names=daq_names,  # from the rig json
        light_sources=[_imaging_light_source(tiff_header, job_settings)],
        ophys_fovs=[fov_2p],  # multiple planes come here
        stream_modalities=[Modality.POPHYS],
    )
    streams.append(stream_2p)

    stim_epoch_spont = StimulusEpoch(
        stimulus_start_time=stream_start_time,  # basenames are separate
        stimulus_end_time=stream_end_time,
        stimulus_name='spontaneous activity',  # user defined in script
        stimulus_modalities=[StimulusModality.NONE],
        notes='absence of any kind of stimulus',
    )
    stim_epochs.append(stim_epoch_spont)

for behavior_file_info_now in behavior_file_info:  # 2P + behavior + behavior video STREAM DURING BEHAVIOR
    tiff_header = behavior_file_info_now[1][0].reader_metadata_header
    last_frame_description = behavior_file_info_now[1][0].reader_descriptions[-1]
    stream_start_time, stream_end_time = _stream_time_span(
        last_frame_description, job_settings.timezone
    )
    all_stream_start_times.append(stream_start_time)
    all_stream_end_times.append(stream_end_time)

    channel_nums = _parse_matlab_array(tiff_header['hChannels']['channelSave'], int)
    daq_names = _daq_names(channel_nums)
    fov_2p = _build_fov(tiff_header, job_settings)
    stream_2p = Stream(
        stream_start_time=stream_start_time,  # each basename is a separate stream
        stream_end_time=stream_end_time,
        daq_names=daq_names,  # from the rig json
        light_sources=[_imaging_light_source(tiff_header, job_settings)],
        ophys_fovs=[fov_2p],  # multiple planes come here
        stream_modalities=[Modality.POPHYS, Modality.BEHAVIOR],
    )
    streams.append(stream_2p)
    if len(job_settings.behavior_camera_names) > 0:
        stream_facecameras = Stream(
            stream_start_time=stream_start_time,
            stream_end_time=stream_end_time,
            camera_names=job_settings.behavior_camera_names,  # from rig json
            stream_modalities=[Modality.BEHAVIOR_VIDEOS],
        )
        streams.append(stream_facecameras)

    stim_epoch_behavior = StimulusEpoch(
        stimulus_start_time=stream_start_time,  # basenames are separate
        stimulus_end_time=stream_end_time,
        stimulus_name=job_settings.behavior_task_name,  # user defined in script
        software=[behavior_software],
        script=pybpod_script,
        stimulus_modalities=[StimulusModality.AUDITORY],  # ,StimulusModality.TACTILE],# tactile not in this version yet
        stimulus_parameters=[],  # opticalBCI class to be added in future
        stimulus_device_names=[],  # from json file, to be added (speaker, bpod ID, )
        output_parameters={},  # hit rate, time to reward, ...?
        # One tif file per trial: count the files of this basename
        trials_total=len(behavior_file_info_now[1][1][0]),
        # trials_rewarded = ,  # not using BPOD info yet
    )
    stim_epochs.append(stim_epoch_behavior)


# Photostimulation acquisitions: every entry of `photo_stim_file_info` contributes
# one 2-photon imaging Stream and one photostimulation StimulusEpoch.
for photo_stim_file_info_now in photo_stim_file_info:
    # Header of the acquisition and the description of its last recorded frame.
    tiff_header = photo_stim_file_info_now[1][0].reader_metadata_header
    last_frame_description = photo_stim_file_info_now[1][0].reader_descriptions[-1]

    # ---- Stream metadata (marked in the original as repeated for every stream) ----
    # z positions are re-centred on their median and offset by the user-set imaging depth.
    # NOTE(review): z_start / z_end / z_step and `channels` are not used again inside this
    # loop; they are kept so the notebook's global state is unchanged.
    z_list = np.asarray(tiff_header['hStackManager']['zs'].strip('[]').split(' '), float)
    z_start = np.min(z_list) - np.median(z_list) + job_settings.fov_imaging_depth
    z_end = np.max(z_list) - np.median(z_list) + job_settings.fov_imaging_depth
    z_step = float(tiff_header['hStackManager']['stackZStepSize'])
    channel_nums = np.asarray(tiff_header['hChannels']['channelSave'].strip('[]').split(' '), int)
    daq_names = [channel_dict[channel_num]['daq_name'] for channel_num in channel_nums]
    channels = []

    # The 'epoch' entry is a bracketed, space-separated date/time whose last field is
    # fractional seconds. Double spaces are replaced by ' 0' (presumably to zero-pad
    # single-digit fields -- confirm), and the seconds field is split into whole seconds
    # and microseconds so that strptime can parse it.
    epoch_fields = last_frame_description['epoch'].strip('[]').replace('  ', ' 0').split(' ')
    seconds_float = float(epoch_fields[-1])
    whole_seconds = str(int(np.floor(seconds_float))).zfill(2)
    microseconds = str(int(1000000 * (seconds_float % 1))).zfill(6)
    start_time_corrected = ' '.join(epoch_fields[:-1] + [whole_seconds, microseconds])
    stream_start_time = datetime.strptime(start_time_corrected, '%Y %m %d %H %M %S %f')
    stream_start_time = stream_start_time.replace(tzinfo=job_settings.timezone)
    # End time = start time + timestamp (seconds) of the last frame.
    stream_end_time = stream_start_time + timedelta(seconds=float(last_frame_description['frameTimestamps_sec']))
    # ---- end of the per-stream metadata ----

    all_stream_start_times.append(stream_start_time)
    all_stream_end_times.append(stream_end_time)

    fov_2p = FieldOfView(index=0,  # multi-plane will have multiple - in a list
                         imaging_depth=job_settings.fov_imaging_depth,  # in microns
                         fov_coordinate_ml=job_settings.fov_coordinate_ml,
                         fov_coordinate_ap=job_settings.fov_coordinate_ap,
                         fov_reference='there is no reference',
                         fov_width=int(tiff_header['hRoiManager']['pixelsPerLine']),
                         fov_height=int(tiff_header['hRoiManager']['linesPerFrame']),
                         magnification=str(tiff_header['hRoiManager']['scanZoomFactor']),
                         # microns per pixel: field of view at 1x divided by zoom, then by lines per frame
                         fov_scale_factor=(FOV_1x_micron / float(tiff_header['hRoiManager']['scanZoomFactor'])) / float(tiff_header['hRoiManager']['linesPerFrame']),
                         frame_rate=float(tiff_header['hRoiManager']['scanFrameRate']),
                         targeted_structure=job_settings.fov_targeted_structure,
                         )
    stream_2p = Stream(stream_start_time=stream_start_time,
                       stream_end_time=stream_end_time,
                       daq_names=daq_names,  # from the rig json
                       light_sources=[LaserConfig(name=job_settings.imaging_laser_name,  # from rig json
                                                  wavelength=job_settings.imaging_laser_wavelength,  # user set value
                                                  # entry of the header's beam powers selected by the imaging laser's power index
                                                  excitation_power=np.asarray(tiff_header['hBeams']['powers'].strip('[]').split(' '), float)[laser_dict['imaging_laser']['power_index']],
                                                  excitation_power_unit=PowerUnit.PERCENT)],
                       ophys_fovs=[fov_2p],  # multiple planes come here
                       stream_modalities=[Modality.POPHYS],
                       )
    streams.append(stream_2p)

    # ---- Photostimulation groups ----
    photostim_groups = []
    # Header sequence is 1-based; shift to 0-based group indices.
    stim_sequence = np.asarray(tiff_header['hPhotostim']['sequenceSelectedStimuli'].strip('[]').split(' '), int) - 1
    # presumably one entry per delivered photostim trial -- TODO confirm against the file-info builder
    num_total_repetitions = len(photo_stim_file_info_now[1][1][0])
    # Cycle the sequence to exactly the number of repetitions. np.resize repeats the
    # sequence as often as needed, so no fixed repeat count can truncate long sessions.
    group_order = np.resize(stim_sequence, num_total_repetitions)
    group_powers = []
    for photostim_group_i, photostim_group in enumerate(photo_stim_file_info_now[1][0].reader_metadata_json['RoiGroups']['photostimRoiGroups']):
        # rois[1] supplies the pattern, power, repetitions and duration; the durations of
        # rois[0] and rois[2] are summed into the inter-spiral interval.
        # NOTE(review): looks like rois[0]/rois[2] are pauses around the stimulus -- confirm.
        stim_scanfields = photostim_group["rois"][1]["scanfields"]
        number_of_neurons = int(np.array(stim_scanfields["slmPattern"]).shape[0])
        stimulation_laser_power = Decimal(str(stim_scanfields["powers"]))
        number_spirals = int(stim_scanfields["repetitions"])
        spiral_duration = Decimal(str(stim_scanfields["duration"]))
        inter_spiral_interval = Decimal(str(photostim_group["rois"][2]["scanfields"]["duration"] + photostim_group["rois"][0]["scanfields"]["duration"]))

        # Number of trials on which this group was selected, as a plain Python int.
        number_of_trials = int(np.count_nonzero(group_order == photostim_group_i))
        photostim_groups.append(PhotoStimulationGroup(
                                group_index=photostim_group_i + 1,
                                number_of_neurons=number_of_neurons,
                                stimulation_laser_power=stimulation_laser_power,
                                stimulation_laser_power_unit=PowerUnit.PERCENT,
                                number_trials=number_of_trials,
                                number_spirals=number_spirals,
                                spiral_duration=spiral_duration,
                                inter_spiral_interval=inter_spiral_interval,
                            ))
        group_powers.append(stimulation_laser_power)

    # Go through str() so the Decimal holds the short decimal value rather than the
    # binary expansion of the float, matching the other Decimal fields above.
    inter_trial_interval = Decimal(str(float(photo_stim_file_info_now[1][2].reader_descriptions[-1]['nextFileMarkerTimestamps_sec'])))
    photostim = PhotoStimulation(stimulus_name='2p photostimulation',
                                 number_groups=len(photostim_groups),
                                 groups=photostim_groups,
                                 inter_trial_interval=inter_trial_interval)  # seconds

    stim_epoch_photostim = StimulusEpoch(stimulus_start_time=stream_start_time,
                                         stimulus_end_time=stream_end_time,
                                         stimulus_name='2p photostimulation',  # user defined in script
                                         software=[photostim_software],
                                         stimulus_modalities=[StimulusModality.OPTOGENETICS],
                                         stimulus_parameters=[photostim],
                                         stimulus_device_names=[],  # from json file, to be added
                                         light_source_config=LaserConfig(name=job_settings.photostim_laser_name,  # from rig json
                                                                         wavelength=job_settings.photostim_laser_wavelength,  # user set value
                                                                         # mean of the per-group stimulation powers
                                                                         excitation_power=np.nanmean(group_powers),
                                                                         excitation_power_unit=PowerUnit.PERCENT),
                                         )
    stim_epochs.append(stim_epoch_photostim)
# Assemble the Session record from the user-supplied job settings plus the
# streams and stimulus epochs collected by the loops above.
s = Session(
    # --- entered by the user ---
    experimenter_full_name=job_settings.experimenter_full_name,
    subject_id=job_settings.subject_id,
    session_type=job_settings.session_type,
    iacuc_protocol=job_settings.iacuc_protocol,
    notes=job_settings.notes,
    # --- rig description (from rig json) ---
    rig_id=job_settings.rig_id,
    mouse_platform_name=job_settings.mouse_platform_name,
    active_mouse_platform=job_settings.active_mouse_platform,
    # --- the session spans the earliest stream start to the latest stream end ---
    session_start_time=min(all_stream_start_times),
    session_end_time=max(all_stream_end_times),
    # --- collected content ---
    reward_delivery=reward_delivery,
    data_streams=streams,
    stimulus_epochs=stim_epochs,
    # TODO: calibrations are not filled in yet. The intended entry is a laser
    # Calibration (calibration_date, device_name from rig json,
    # description='laser calibration', input={'power_percent': []},
    # output={'power_mW': []}).
)


In [312]:
# Round-trip check: dump the session to a JSON string, validate that string back
# into a Session, then write the re-validated copy with the "ophys" prefix.
# NOTE(review): no output directory is passed here, so the file presumably lands in
# the current working directory -- confirm.
serialized = s.model_dump_json()
deserialized = Session.model_validate_json(serialized)
deserialized.write_standard_file(prefix="ophys")

In [317]:
# Save the session as an indented JSON document.
# `serialized` is already a JSON string (it comes from model_dump_json), so dumping it
# directly would write one quoted, escaped string and ignore `indent`. Parse it back
# into Python objects first so the file holds a real JSON object.
# `json` is imported in the notebook's import cell.
# NOTE(review): absolute, machine-specific path -- consider deriving it from
# job_settings.output_directory.
with open('/home/jupyter/temp/session.json', 'w', encoding='utf-8') as f:
    json.dump(json.loads(serialized), f, indent=4)

In [318]:
# Write the session through the schema's own writer.
# NOTE(review): the positional argument is presumably the output directory -- confirm
# against the installed aind-data-schema version. The path is absolute and machine-specific.
s.write_standard_file('/home/jupyter/temp/')