# Converting MICrONS to NWB

This notebook converts the 2p data from [the MICrONS dataset](https://www.microns-explorer.org/) to NWB.

## Setup
The notebook needs to be run on https://codebook.datajoint.io/ using the "IARPA MICrONS Program" Server Option.


## Progress
- [x] Eye tracking
    - [x] Minor and major radius
    - [x] Eye position (x, y)
- [x] Treadmill velocity
- [x] Trials
    - [x] Start and stop times
    - [x] Stimulus type
    - [x] Condition hash
- [x] ROI masks
- [x] Fluorescence traces
- [x] Summary images
- [x] Microscope metadata
- [ ] Raw 2p data
- [x] Stimulus movie (need more RAM)
- [ ] Mapping to EM data
- [ ] Subject and session metadata
    - [ ] datetime of session
    - [ ] age, sex, and ID of mouse

In [None]:
!pip install neuroconv[ophys]

In [None]:
from copy import deepcopy
from tqdm import tqdm
from datetime import datetime
from dateutil import tz
import numpy as np

from neuroconv.utils import dict_deep_update
from neuroconv.tools.roiextractors.roiextractors import get_default_ophys_metadata
from neuroconv.tools.nwb_helpers import (
    get_module,
    make_or_load_nwbfile,
)
from neuroconv.tools.roiextractors import add_image_segmentation, add_devices

from pynwb.ophys import (
    RoiResponseSeries,
    Fluorescence,
    PlaneSegmentation,
    OpticalChannel,
    ImagingPlane,
)
from pynwb.image import OpticalSeries, GrayscaleImage
from hdmf.backends.hdf5.h5tools import H5DataIO
from pynwb.behavior import PupilTracking, EyeTracking, SpatialSeries
from pynwb.base import TimeSeries, Images

import datajoint as dj
from phase3 import nda, func, utils


def add_stimulus(scan_key, nwb):
    """Add the stimulus movie of a scan to the NWB file.

    Fetches ``frame_times`` from ``nda.FrameTimes`` and ``movie`` from
    ``nda.Stimulus`` (both restricted by ``scan_key``), wraps them in an
    ``OpticalSeries`` named ``"visual_stimulus"`` and registers it with
    ``nwb.add_stimulus``.
    """
    frame_times = (nda.FrameTimes() & scan_key).fetch1('frame_times')
    stimulus_movie = (nda.Stimulus & scan_key).fetch1('movie')

    # Move the last axis to the front so that time is the first dimension
    # (assumes the movie is fetched as height x width x frames — TODO confirm).
    frames_first = stimulus_movie.transpose(2, 0, 1)

    nwb.add_stimulus(
        OpticalSeries(
            name="visual_stimulus",
            data=H5DataIO(frames_first, compression=True),
            timestamps=H5DataIO(frame_times, compression=True),
            unit="n.a.",
            description="Stimulus movie",
            distance=np.nan,  # unknown
            orientation="0 is up",
            field_of_view=[np.nan, np.nan],  # unknown
        )
    )


def add_eye_tracking(scan_key, nwb):
    """Add the tracked eye position of a scan to the NWB acquisition data.

    Fetches ``pupil_x``, ``pupil_y`` and ``pupil_times`` from
    ``nda.RawManualPupil`` for ``scan_key`` and stores them as a
    ``SpatialSeries`` named ``"eye_position"`` inside an ``EyeTracking``
    container.
    """
    pupil_query = nda.RawManualPupil() & scan_key
    x_position, y_position, pupil_times = pupil_query.fetch1("pupil_x", "pupil_y", "pupil_times")

    # x and y are placed side by side so each row holds one (x, y) sample.
    xy_position = np.c_[x_position, y_position]

    eye_position = SpatialSeries(
        name="eye_position",
        data=H5DataIO(xy_position, compression=True),
        timestamps=H5DataIO(pupil_times, compression=True),
        unit="n.a.",
        description="x,y position of eye",
        reference_frame="unknown",
    )

    nwb.add_acquisition(EyeTracking(eye_position))


def add_pupil_tracking(scan_key, nwb):
    """Add the pupil ellipse radii of a scan to the NWB acquisition data.

    Fetches ``pupil_min_r``, ``pupil_maj_r`` and ``pupil_times`` from
    ``nda.RawManualPupil`` for ``scan_key``, builds one ``TimeSeries`` per
    radius (both sharing the same timestamps values) and stores them in a
    ``PupilTracking`` container.
    """
    minor_radius, major_radius, pupil_times = (nda.RawManualPupil() & scan_key).fetch1(
        "pupil_min_r", "pupil_maj_r", "pupil_times"
    )

    # (series name, description, values) for each radius, minor axis first.
    radius_specs = (
        ("pupil_min_r", "minor axis of pupil tracking ellipse", minor_radius),
        ("pupil_maj_r", "major axis of pupil tracking ellipse", major_radius),
    )

    radius_series = [
        TimeSeries(
            name=series_name,
            description=series_description,
            data=H5DataIO(values, compression=True),
            timestamps=H5DataIO(pupil_times, compression=True),
            unit="n.a.",
        )
        for series_name, series_description, values in radius_specs
    ]

    nwb.add_acquisition(PupilTracking(radius_series))


def add_treadmill(scan_key, nwb):
    """Add the treadmill velocity of a scan to the NWB acquisition data.

    Fetches ``treadmill_velocity`` and ``treadmill_timestamps`` from
    ``nda.RawTreadmill`` for ``scan_key`` and stores them as a ``TimeSeries``
    named ``"treadmill_velocity"``.
    """
    velocity, velocity_timestamps = (
        nda.RawTreadmill & scan_key
    ).fetch1("treadmill_velocity", "treadmill_timestamps")

    velocity_series = TimeSeries(
        name="treadmill_velocity",
        data=H5DataIO(velocity, compression=True),
        timestamps=H5DataIO(velocity_timestamps, compression=True),
        unit="n.a.",
        # The original description text was corrupted; this string is written to the file.
        description="velocity of treadmill",
    )
    nwb.add_acquisition(velocity_series)


def add_trials(scan_key, nwb):
    """Fill the NWB trials table from ``nda.Trial`` for a scan.

    Declares the custom trial columns ``condition_hash`` and
    ``stimulus_type``, then adds one trial per fetched row, ordered by
    ``trial_idx``, using ``trial_idx`` as the row id.
    """
    custom_columns = (
        ("condition_hash", "condition hash"),
        ("stimulus_type", "stimulus type"),
    )
    for column_name, column_description in custom_columns:
        nwb.add_trial_column(column_name, column_description)

    # One array per requested attribute, all in the same trial order.
    trial_attributes = (nda.Trial & scan_key).fetch(
        "type", "start_frame_time", "end_frame_time", "condition_hash", "trial_idx", order_by="trial_idx"
    )

    for stimulus_type, start_time, stop_time, condition_hash, trial_idx in zip(*trial_attributes):
        nwb.add_trial(
            id=trial_idx,
            start_time=start_time,
            stop_time=stop_time,
            stimulus_type=stimulus_type,
            condition_hash=condition_hash,
        )


def add_summary_images(field_key, nwb):
    """Add the correlation and average summary images of one field.

    Fetches ``correlation`` and ``average`` from ``nda.SummaryImages`` for
    ``field_key`` and stores them as ``GrayscaleImage`` objects inside an
    ``Images`` container named ``SegmentationImages<field>`` in the "ophys"
    processing module.
    """
    ophys_module = get_module(nwb, "ophys")

    correlation, average = (nda.SummaryImages & field_key).fetch1("correlation", "average")

    summary_images = [
        GrayscaleImage(name=image_name, data=H5DataIO(pixels, compression=True))
        for image_name, pixels in (("correlation", correlation), ("average", average))
    ]

    field_number = field_key['field']
    ophys_module.add(
        Images(
            name=f"SegmentationImages{field_number}",
            images=summary_images,
            description=f"Summary images for field {field_number}.",
        )
    )


def get_ophys_metadata(scan_key):
    """Build the optical physiology metadata for all fields of a scan.

    Starts from ``get_default_ophys_metadata()``, fills in the device,
    optical channel and imaging plane details, and then appends one imaging
    plane entry and one plane segmentation entry for every field beyond the
    first one returned by ``nda.Field & scan_key``.

    Returns the updated metadata dictionary.
    """
    fields = (nda.Field & scan_key).fetch(as_dict=True)

    metadata = get_default_ophys_metadata()
    ophys_metadata = metadata["Ophys"]

    ophys_metadata["Device"][0].update(description="two-photon random access mesoscope")

    imaging_planes = ophys_metadata["ImagingPlane"]
    plane_segmentations = ophys_metadata["ImageSegmentation"]["plane_segmentations"]

    # The first (default) entries describe the first field.
    imaging_planes[0].update(
        optical_channel=[
            dict(
                name="OpticalChannel",
                emission_lambda=500.,
                description="An optical channel.",
            )
        ],
        excitation_lambda=920.,
        indicator="GCaMP6",
        location="unknown",
    )

    first_field = fields[0]["field"]
    plane_segmentations[0].update(description=f"Segmented ROIs for field {first_field}.")

    # Every remaining field gets a copy of the first entries with its own name.
    for plane_number, field in enumerate(fields[1:], start=1):
        extra_plane = deepcopy(imaging_planes[0])
        extra_plane.update(name=f"ImagingPlane{plane_number}")
        imaging_planes.append(extra_plane)

        extra_segmentation = deepcopy(plane_segmentations[0])
        extra_segmentation.update(
            name=f"PlaneSegmentation{plane_number}",
            description=f"The output from segmenting field {field['field']}.",
        )
        plane_segmentations.append(extra_segmentation)

    return metadata


def add_imaging_plane(scan_key, nwb, metadata, imaging_plane_index):
    """Add the imaging plane at ``imaging_plane_index`` to the NWB file.

    The device(s) from ``metadata`` are added first. If an imaging plane with
    the same name already exists in ``nwb``, nothing else is done. Otherwise
    the plane is created from ``metadata["Ophys"]["ImagingPlane"][imaging_plane_index]``
    together with the grid spacing and origin coordinates taken from the row
    at the same index of ``nda.Field & scan_key``.

    ``metadata`` is not modified; a deep copy is used for the updates.
    """
    metadata_copy = deepcopy(metadata)
    imaging_plane_metadata = metadata_copy["Ophys"]["ImagingPlane"][imaging_plane_index]

    add_devices(nwbfile=nwb, metadata=metadata)

    # Early return if an imaging plane with the same name is already added to NWB
    if imaging_plane_metadata["name"] in nwb.imaging_planes:
        return

    device_name = imaging_plane_metadata["device"]
    optical_channel = OpticalChannel(**imaging_plane_metadata["optical_channel"][0])

    field_data = (nda.Field & scan_key).fetch(as_dict=True)[imaging_plane_index]

    # Pixel size per axis: micrometers per pixel converted to meters.
    # Both axes need the conversion to agree with grid_spacing_unit below.
    microns_to_meters = 1e-6
    grid_spacing = [
        field_data["um_width"] / field_data["px_width"] * microns_to_meters,
        field_data["um_height"] / field_data["px_height"] * microns_to_meters,
    ]

    imaging_plane_metadata.update(
        device=nwb.devices[device_name],
        optical_channel=[optical_channel],
        grid_spacing=grid_spacing,
        grid_spacing_unit="meters",
        # NOTE(review): field_x/field_y/field_z are stored as fetched while the
        # unit is declared as meters — confirm the unit of these columns.
        origin_coords=[field_data["field_x"], field_data["field_y"], field_data["field_z"]],
        origin_coords_unit="meters",
    )

    nwb.add_imaging_plane(ImagingPlane(**imaging_plane_metadata))


def add_plane_segmentation(field_key, nwb, metadata, plane_segmentation_index):
    """Add the ROI masks of one field to the NWB file as a ``PlaneSegmentation``.

    Creates (or retrieves) the "ophys" processing module, the image
    segmentation container and the imaging plane at
    ``plane_segmentation_index``. It then adds a plane segmentation holding
    the image masks and mask types fetched from
    ``nda.Segmentation * nda.MaskClassification`` for ``field_key``.
    If a plane segmentation with the same name already exists, nothing is added.

    ``field_key`` is expected to be a dict holding the scan key entries plus
    a ``"field"`` entry. ``metadata`` is not modified; a deep copy is used.
    """
    metadata_copy = deepcopy(metadata)
    # Create or retrieve processing module
    ophys = get_module(nwb, "ophys")

    # Create or retrieve the image segmentation where this plane segmentation will be added
    add_image_segmentation(nwbfile=nwb, metadata=metadata_copy)
    image_segmentation_metadata = metadata_copy["Ophys"]["ImageSegmentation"]
    image_segmentation_name = image_segmentation_metadata["name"]
    image_segmentation = ophys.get_data_interface(image_segmentation_name)

    # Early return if a plane segmentation with the same name is already added to NWB
    plane_segmentation_metadata = image_segmentation_metadata["plane_segmentations"][plane_segmentation_index]
    plane_segmentation_name = plane_segmentation_metadata["name"]
    if plane_segmentation_name in image_segmentation.plane_segmentations:
        return

    # add_imaging_plane restricts nda.Field by scan and then picks the row at
    # the given index, so it needs the key without the "field" entry. Deriving
    # it here avoids relying on a module-level ``scan_key`` variable.
    scan_key = {name: value for name, value in field_key.items() if name != "field"}

    # Create or retrieve the imaging plane for this plane segmentation
    add_imaging_plane(
        scan_key=scan_key,
        nwb=nwb,
        imaging_plane_index=plane_segmentation_index,
        metadata=metadata_copy,
    )

    imaging_plane_metadata = metadata_copy["Ophys"]["ImagingPlane"][plane_segmentation_index]
    imaging_plane_name = imaging_plane_metadata["name"]
    imaging_plane = nwb.imaging_planes[imaging_plane_name]

    # Query the image height and width
    image_height, image_width = (nda.Field & field_key).fetch1("px_height", "px_width")

    # Query the image masks and image types
    mask_pixels, mask_weights, mask_ids, mask_types = (
        nda.Segmentation * nda.MaskClassification & field_key
    ).fetch("pixels", "weights", "mask_id", "mask_type", order_by="mask_id")

    plane_segmentation_metadata.update(
        id=mask_ids,
        imaging_plane=imaging_plane,
    )

    # Create the PlaneSegmentation object for the ROIs
    plane_segmentation = PlaneSegmentation(**plane_segmentation_metadata)

    # Reshape image masks
    masks = func.reshape_masks(mask_pixels, mask_weights, image_height, image_width)
    # Move the last axis to the front so that each entry along the first
    # dimension is one ROI mask (assumes reshape_masks returns
    # height x width x number of masks — TODO confirm).
    masks = masks.transpose((2, 0, 1))

    # Add image masks
    plane_segmentation.add_column(
        name="image_mask",
        description="Image masks for each ROI.",
        data=H5DataIO(masks, compression=True),
    )

    # Add type of ROIs
    plane_segmentation.add_column(
        name="mask_type",
        description="type of ROI",
        data=mask_types.astype(str),
    )

    image_segmentation.add_plane_segmentation(plane_segmentations=plane_segmentation)


def _get_fluorescence_data_interface(nwb, fluorescence_name):
    """Return the ``Fluorescence`` container named ``fluorescence_name``.

    The container is looked up in the "ophys" processing module; when it is
    not there yet, a new one is created, added to the module and returned.
    """
    ophys_module = get_module(nwbfile=nwb, name="ophys")

    if fluorescence_name not in ophys_module.data_interfaces:
        new_container = Fluorescence(name=fluorescence_name)
        ophys_module.add(new_container)
        return new_container

    return ophys_module.get(fluorescence_name)


def add_fluorescence_traces(field_key, nwb, metadata, timestamps, plane_segmentation_index):
    """Add the fluorescence traces of one field to the NWB file.

    Ensures the plane segmentation at ``plane_segmentation_index`` exists,
    fetches the traces from ``nda.Fluorescence`` for ``field_key`` (ordered by
    ``mask_id``) and stores them as a ``RoiResponseSeries`` that references
    all ROIs of that plane segmentation. The series is added to the
    ``Fluorescence`` container named in ``metadata["Ophys"]["Fluorescence"]``.

    The series is named ``"RoiResponseSeries"`` for the first plane
    segmentation and ``"RoiResponseSeries<field>"`` for the others.
    ``metadata`` is not modified.
    """
    metadata_copy = deepcopy(metadata)

    add_plane_segmentation(field_key, nwb, metadata, plane_segmentation_index)

    image_segmentation_metadata = metadata_copy["Ophys"]["ImageSegmentation"]
    plane_segmentation_metadata = image_segmentation_metadata["plane_segmentations"][plane_segmentation_index]
    plane_segmentation_name = plane_segmentation_metadata["name"]
    image_segmentation_name = image_segmentation_metadata["name"]

    ophys = get_module(nwb, "ophys")
    image_segmentation = ophys.get_data_interface(image_segmentation_name)
    plane_segmentation = image_segmentation.plane_segmentations[plane_segmentation_name]

    # Stack one trace per row, then transpose so that each column is one ROI.
    data = np.vstack((nda.Fluorescence() & field_key).fetch("trace", order_by="mask_id")).T

    field_number = field_key['field']
    rt_region = plane_segmentation.create_roi_table_region(
        region=list(range(data.shape[1])),
        description=f"all rois in field {field_number}",
    )

    # The original spelled this name "RioResponseSeries"; corrected to match the
    # RoiResponseSeries type it names. Readers of files written by the old code
    # must look the series up under the old spelling.
    if plane_segmentation_index == 0:
        roi_response_series_name = "RoiResponseSeries"
    else:
        roi_response_series_name = f"RoiResponseSeries{field_number}"

    roi_response_series = RoiResponseSeries(
        name=roi_response_series_name,
        description=f"traces for field {field_number}",
        data=H5DataIO(data, compression=True),
        rois=rt_region,
        timestamps=H5DataIO(timestamps, compression=True),
        unit="n.a.",
    )

    # Add fluorescence traces
    fluorescence_name = metadata_copy["Ophys"]["Fluorescence"]["name"]
    fluorescence = _get_fluorescence_data_interface(nwb=nwb, fluorescence_name=fluorescence_name)

    fluorescence.add_roi_response_series(roi_response_series)


def convert_microns_to_nwb(nwbfile_path, scan_key, metadata):
    """Convert one MICrONS scan to an NWB file at ``nwbfile_path``.

    The ophys metadata generated for ``scan_key`` is deep-updated with the
    user supplied ``metadata`` and the session id is set from
    ``scan_key["session"]``. The file is opened with ``overwrite=True`` and
    receives trials, treadmill velocity, eye position, pupil tracking, the
    stimulus movie and the device. For every field of the scan it also
    receives the segmentation data, fluorescence traces and summary images.

    ``metadata`` is not modified; a deep copy is used for the merge.
    """
    user_metadata = deepcopy(metadata)
    merged_metadata = dict_deep_update(get_ophys_metadata(scan_key), user_metadata)
    merged_metadata["NWBFile"].update(session_id=str(scan_key["session"]))

    # Scan-level data, added in this order: trials, treadmill velocity,
    # eye position, pupil tracking, stimulus movie.
    scan_level_adders = (
        add_trials,
        add_treadmill,
        add_eye_tracking,
        add_pupil_tracking,
        add_stimulus,
    )

    with make_or_load_nwbfile(
        nwbfile_path=nwbfile_path,
        nwbfile=None,
        metadata=merged_metadata,
        overwrite=True,
        verbose=True,
    ) as nwbfile_out:

        for add_scan_data in scan_level_adders:
            add_scan_data(scan_key, nwbfile_out)

        # Add device
        add_devices(nwbfile=nwbfile_out, metadata=merged_metadata)

        fields = (nda.Field & scan_key).fetch(as_dict=True)
        frame_times = (nda.FrameTimes & scan_key).fetch1('frame_times')

        progress = tqdm(
            iterable=enumerate(fields),
            desc="Writing segmentation data ...",
            position=0,
            total=len(fields),
            mininterval=10,
        )
        for field_index, field in progress:
            # The field key is the scan key plus the field number.
            field_key = {**scan_key, "field": field["field"]}
            # Add segmentation data (image masks, fluorescence traces)
            add_fluorescence_traces(field_key, nwbfile_out, merged_metadata, frame_times, field_index)
            # Add segmentation images
            add_summary_images(field_key, nwbfile_out)
