# Converting MICrONS to NWB

This notebook converts the two-photon (2p) imaging data from [the MICrONS dataset](https://www.microns-explorer.org/) to the Neurodata Without Borders (NWB) format.

## Setup
The notebook must be run on https://codebook.datajoint.io/ using the "IARPA MICrONS Program" server option.


## Progress
- [x] Eye tracking
    - [x] Minor and major radius
    - [x] Eye position (x, y)
- [x] Treadmill velocity
- [x] Trials
    - [x] Start and stop times
    - [x] Stimulus type
    - [x] Condition hash
- [x] ROI masks
- [x] Fluorescence traces
- [x] Summary images
- [x] Microscope metadata
- [ ] Raw 2p data
- [ ] Stimulus movie (needs more RAM)
- [ ] Mapping to EM data
- [ ] Subject and session metadata
    - [ ] datetime of session
    - [ ] age, sex, and ID of mouse

In [None]:
!pip install pynwb

In [None]:
from datetime import datetime, timezone
from uuid import uuid4

import datajoint as dj
import matplotlib.pyplot as plt
import numpy as np
from hdmf.backends.hdf5.h5tools import H5DataIO
from pynwb.base import Images, TimeSeries
from pynwb.behavior import PupilTracking, EyeTracking, SpatialSeries
from pynwb.device import Device
from pynwb.file import NWBFile, Subject
from pynwb.image import GrayscaleImage, OpticalSeries
from pynwb.ophys import RoiResponseSeries, Fluorescence, ImageSegmentation, OpticalChannel

from phase3 import nda, func, utils


def check_module(nwb, name, description=None):
    """Return the processing module called *name*, creating it if absent.

    Parameters
    ----------
    nwb : NWBFile
        File whose ``processing`` modules are searched.
    name : str
        Name of the processing module.
    description : str, optional
        Used only when a new module is created. Falls back to *name* when
        not given (or empty).

    Returns
    -------
    ProcessingModule
        The existing module, or the newly created one.
    """
    modules = nwb.processing
    if name not in modules:
        return nwb.create_processing_module(name, description or name)
    return modules[name]
    

def start_nwb(scan_key):
    """Create an in-memory NWBFile holding the session-level metadata.

    Parameters
    ----------
    scan_key : dict
        DataJoint key of the scan; only ``scan_key["session"]`` is used here
        (as the NWB ``session_id``).

    Returns
    -------
    NWBFile
        A new file with a random identifier and placeholder subject/session
        metadata (the real values are still TODO in the progress list).
    """
    nwb = NWBFile(
        identifier=str(uuid4()),
        session_description="unknown",
        subject=Subject(subject_id="001", species="Mus musculus", age="P75D/P81D", sex="M"),
        # Placeholder until the real session datetime is known. NWB requires a
        # timezone-aware datetime. With a naive one, pynwb silently assumes the
        # local timezone of the machine running the conversion and emits a warning.
        session_start_time=datetime(1900, 1, 1, tzinfo=timezone.utc),
        session_id=str(scan_key["session"]),
    )
    return nwb

    
def add_stimulus(scan_key, nwb):
    """Attach the visual-stimulus movie of *scan_key* to ``nwb.stimulus``.

    The movie is stored as a compressed OpticalSeries named
    ``"visual stimulus"``. Viewing distance and field of view are unknown
    and recorded as NaN.

    Parameters
    ----------
    scan_key : dict
        DataJoint key restricting ``nda.FrameTimes`` and ``nda.Stimulus``.
    nwb : NWBFile
        File that receives the stimulus series.
    """
    # Per the original author these are the timestamps of the stimulus images.
    frame_times = (nda.FrameTimes() & scan_key).fetch1('frame_times')
    stimulus_movie = (nda.Stimulus & scan_key).fetch1('movie')

    # Move the last axis to the front so the series is indexed frame-first.
    # Assumes the movie is (height, width, frames); TODO confirm.
    frames_first = stimulus_movie.transpose(2, 0, 1)

    nwb.add_stimulus(
        OpticalSeries(
            name="visual stimulus",
            distance=np.nan,  # unknown
            field_of_view=[np.nan, np.nan],
            orientation="0 is up",
            data=H5DataIO(frames_first, compression=True),
            timestamps=H5DataIO(frame_times, compression=True),
            unit="n.a.",
        )
    )
    
    
def add_eye_tracking(scan_key, nwb):
    """Add pupil-ellipse size and eye position to ``nwb.acquisition``.

    Fetches one ``nda.RawManualPupil`` row for *scan_key* and writes:

    * a ``PupilTracking`` container holding the ``pupil_min_r`` and
      ``pupil_maj_r`` TimeSeries;
    * an ``EyeTracking`` container holding the ``eye_position`` SpatialSeries,
      with columns (x, y).

    All three series share one set of timestamps (``pupil_times``).

    Parameters
    ----------
    scan_key : dict
        DataJoint key selecting exactly one RawManualPupil row.
    nwb : NWBFile
        File that receives the acquisition data.
    """
    min_r, maj_r, pupil_x, pupil_y, pupil_times = (
        nda.RawManualPupil() & scan_key
    ).fetch1("pupil_min_r", "pupil_maj_r", "pupil_x", "pupil_y", "pupil_times")

    # NOTE(review): the field names end in "_r" and the notebook checklist says
    # "radius", but the descriptions say "axis". Confirm which one is stored.
    pupil_min_r = TimeSeries(
        name="pupil_min_r",
        description="minor axis of pupil tracking ellipse",
        data=H5DataIO(min_r, compression=True),
        timestamps=H5DataIO(pupil_times, compression=True),
        unit="unknown",
    )

    # The remaining series link to pupil_min_r for their timestamps. pynwb
    # accepts a TimeSeries here, so the timestamps are stored only once.
    pupil_maj_r = TimeSeries(
        name="pupil_maj_r",
        description="major axis of pupil tracking ellipse",
        data=H5DataIO(maj_r, compression=True),
        timestamps=pupil_min_r,
        unit="unknown",
    )
    nwb.add_acquisition(PupilTracking([pupil_min_r, pupil_maj_r]))

    eye_position = SpatialSeries(
        name="eye_position",
        description="x,y position of eye",
        data=H5DataIO(np.c_[pupil_x, pupil_y], compression=True),
        timestamps=pupil_min_r,
        unit="unknown",
        reference_frame="unknown",
    )
    nwb.add_acquisition(EyeTracking(eye_position))
    
    
def add_treadmill(scan_key, nwb):
    """Add the treadmill velocity of *scan_key* as an acquisition TimeSeries.

    Parameters
    ----------
    scan_key : dict
        DataJoint key selecting exactly one ``nda.RawTreadmill`` row.
    nwb : NWBFile
        File that receives the ``treadmill_velocity`` series.
    """
    treadmill_query = nda.RawTreadmill & scan_key
    velocity, velocity_times = treadmill_query.fetch1(
        "treadmill_velocity", "treadmill_timestamps"
    )

    nwb.add_acquisition(
        TimeSeries(
            name="treadmill_velocity",
            data=H5DataIO(velocity, compression=True),
            timestamps=H5DataIO(velocity_times, compression=True),
            description="velocity of treadmill",
            unit="unknown",
        )
    )
    
    
def add_summary_images(field_key, nwb):
    """Store the correlation and average summary images of one imaging field.

    Both images are wrapped in an ``Images`` container named
    ``SummaryImages<field>`` inside the ``ophys`` processing module. The
    module is created if it does not exist yet.

    Parameters
    ----------
    field_key : dict
        DataJoint key of one field (must contain ``"field"``).
    nwb : NWBFile
        File that receives the images.

    Returns
    -------
    Images
        The container that was added to the ``ophys`` module.
    """
    # Reuse the module when present. add_ophys creates "ophys" before calling
    # this once per field, so create_processing_module would raise on the
    # duplicate name.
    ophys = check_module(nwb, "ophys", "processed 2p data")

    correlation, average = (
        nda.SummaryImages & field_key
    ).fetch1("correlation", "average")

    field = field_key["field"]
    # GrayscaleImage's first positional argument is its name, not the pixels.
    # Images cannot be added to a processing module directly; they must sit in
    # an Images container. One container per field keeps names unique.
    summary_images = Images(
        name=f"SummaryImages{field}",
        images=[
            GrayscaleImage(name="correlation", data=correlation),
            GrayscaleImage(name="average", data=average),
        ],
        description=f"summary images for field {field}",
    )
    ophys.add(summary_images)
    return summary_images


def add_trials(scan_key, nwb):
    """Fill ``nwb.trials`` with one row per ``nda.Trial`` entry of *scan_key*.

    Each trial keeps its DataJoint ``trial_idx`` as the NWB row id and uses the
    start/end frame times as its interval. Two extra columns are stored:
    ``stimulus_type`` and ``condition_hash``.

    Parameters
    ----------
    scan_key : dict
        DataJoint key restricting ``nda.Trial``.
    nwb : NWBFile
        File whose trials table is populated.
    """
    # Custom columns must be declared before any trial row fills them.
    nwb.add_trial_column("condition_hash", "condition hash")
    nwb.add_trial_column("stimulus_type", "stimulus type")

    trial_columns = (nda.Trial & scan_key).fetch(
        "type", "start_frame_time", "end_frame_time", "condition_hash", "trial_idx",
        order_by="trial_idx",
    )
    for kind, t_start, t_stop, cond_hash, idx in zip(*trial_columns):
        nwb.add_trial(
            id=idx,
            start_time=t_start,
            stop_time=t_stop,
            stimulus_type=kind,
            condition_hash=cond_hash,
        )

        
def add_plane_segmentation(field_key, nwb, imaging_plane, image_segmentation):
    """Create the PlaneSegmentation of one field and fill it with its ROI masks.

    Masks come from ``nda.Segmentation`` joined with
    ``nda.MaskClassification``. They are ordered by ``mask_id`` and expanded
    to full-size images with ``func.reshape_masks``. Each mask becomes one
    ROI row whose id is its ``mask_id``, with an extra ``mask_type`` column.

    Parameters
    ----------
    field_key : dict
        DataJoint key of one field (must contain ``"field"``).
    nwb : NWBFile
        Unused here; kept for a uniform signature with the other helpers.
    imaging_plane : ImagingPlane
        Plane the segmentation refers to.
    image_segmentation : ImageSegmentation
        Container the new PlaneSegmentation is created in.

    Returns
    -------
    PlaneSegmentation
        The filled table, with rows in ``mask_id`` order.

    Raises
    ------
    ValueError
        If the number of reshaped masks does not match the number of mask ids.
    """
    field = field_key["field"]
    ps = image_segmentation.create_plane_segmentation(
        name=f"PlaneSegmentation{field}",
        description=f"ROI masks for field {field}",
        imaging_plane=imaging_plane,
    )
    ps.add_column("mask_type", "type of ROI")

    image_height, image_width = (nda.Field & field_key).fetch1("px_height", "px_width")

    mask_pixels, mask_weights, mask_ids, mask_types = (
        nda.Segmentation * nda.MaskClassification & field_key
    ).fetch("pixels", "weights", "mask_id", "mask_type", order_by="mask_id")

    masks = np.asarray(
        func.reshape_masks(mask_pixels, mask_weights, image_height, image_width)
    )
    # NOTE(review): reshape_masks may stack masks along the LAST axis, i.e.
    # (height, width, n_masks); confirm. Iterating such an array yields image
    # rows rather than masks, so move the mask axis to the front when that
    # layout is detected. This is ambiguous only if height == width == n_masks.
    if (
        masks.ndim == 3
        and masks.shape[:2] == (image_height, image_width)
        and masks.shape[2] == len(mask_ids)
    ):
        masks = np.moveaxis(masks, -1, 0)

    # zip() would silently truncate a mismatch; fail loudly instead.
    if len(masks) != len(mask_ids):
        raise ValueError(
            f"field {field}: got {len(masks)} reshaped masks for {len(mask_ids)} mask ids"
        )

    for image_mask, mask_id, mask_type in zip(masks, mask_ids, mask_types):
        ps.add_roi(
            image_mask=image_mask,
            id=mask_id,
            mask_type=mask_type,
        )

    return ps

        
def add_roi_response_series(field_key, nwb, plane_segmentation):
    """Add the fluorescence traces of one field to the ``ophys`` module.

    Traces from ``nda.Fluorescence`` are ordered by ``mask_id``, which matches
    the row order of add_plane_segmentation. They are stored as a
    RoiResponseSeries named ``RoiResponseSeries<field>``. The series goes into
    a single ``Fluorescence`` container shared by all fields, created on first use.

    Parameters
    ----------
    field_key : dict
        DataJoint key of one field (must contain ``"field"``).
    nwb : NWBFile
        File that receives the traces.
    plane_segmentation : PlaneSegmentation
        Segmentation whose rows the traces refer to, one per trace.

    Returns
    -------
    RoiResponseSeries
        The series that was added.
    """
    field = field_key["field"]

    # Restrict by this function's own key instead of the notebook-global
    # scan_key. DataJoint ignores dict keys (e.g. "field") that a table lacks.
    frame_times = (nda.FrameTimes & field_key).fetch1("frame_times")

    # Stack per-mask traces into rows, then transpose so time is the first
    # axis, as NWB expects. Assumes each trace is 1-D.
    data = np.vstack((nda.Fluorescence() & field_key).fetch("trace", order_by="mask_id")).T

    rt_region = plane_segmentation.create_roi_table_region(
        region=list(range(data.shape[1])),
        description=f"all rois in field {field}",
    )

    roi_response_series = RoiResponseSeries(
        name=f"RoiResponseSeries{field}",
        description=f"traces for field {field}",
        data=H5DataIO(data, compression=True),
        rois=rt_region,
        timestamps=H5DataIO(frame_times, compression=True),
        unit="n/a",
    )

    # Every field shares one Fluorescence container. Creating a fresh one per
    # field would collide on its default name, and the original never attached
    # it to the file at all.
    ophys = check_module(nwb, "ophys")
    if "Fluorescence" in ophys.data_interfaces:
        fluorescence = ophys.data_interfaces["Fluorescence"]
    else:
        fluorescence = Fluorescence()
        ophys.add(fluorescence)
    fluorescence.add_roi_response_series(roi_response_series)

    return roi_response_series
        
        
def add_ophys(scan_key, nwb):
    """Add the microscope, one imaging plane per field, and per-field ophys data.

    Creates the ``Microscope`` device and an ``ImageSegmentation`` in the
    ``ophys`` module. Then, for every ``nda.Field`` row of *scan_key*, it
    creates an ImagingPlane and calls add_plane_segmentation,
    add_roi_response_series and add_summary_images for that field.

    Parameters
    ----------
    scan_key : dict
        DataJoint key of the scan.
    nwb : NWBFile
        File that receives the optical-physiology data.
    """
    device = nwb.create_device(
        name="Microscope",
        description="two-photon random access mesoscope",
    )
    ophys = check_module(nwb, "ophys")
    image_segmentation = ImageSegmentation()
    ophys.add(image_segmentation)

    um_to_m = 1e-6
    all_field_data = (nda.Field & scan_key).fetch(as_dict=True)
    for field_data in all_field_data:
        field = field_data["field"]
        optical_channel = OpticalChannel(
            name="OpticalChannel",
            description="an optical channel",
            emission_lambda=500.,
        )
        imaging_plane = nwb.create_imaging_plane(
            name=f"ImagingPlane{field}",
            optical_channel=optical_channel,
            imaging_rate=np.nan,
            description="no description",
            device=device,
            excitation_lambda=920.,
            indicator="GCaMP6",
            location="unknown",
            # Pixel pitch = field size (um) / size (px), converted to meters.
            # Previously only the width pitch was converted, so the height pitch
            # stayed in microns despite grid_spacing_unit="meters".
            grid_spacing=[
                field_data["um_width"] / field_data["px_width"] * um_to_m,
                field_data["um_height"] / field_data["px_height"] * um_to_m,
            ],
            grid_spacing_unit="meters",
            # NOTE(review): field_x/y/z are passed through unconverted. If
            # nda.Field stores them in microns (like um_width/um_height), this
            # unit is wrong; confirm and scale by um_to_m if so.
            origin_coords=[field_data["field_x"], field_data["field_y"], field_data["field_z"]],
            origin_coords_unit="meters",
        )

        field_key = {**scan_key, "field": field}

        plane_segmentation = add_plane_segmentation(field_key, nwb, imaging_plane, image_segmentation)
        add_roi_response_series(field_key, nwb, plane_segmentation)
        add_summary_images(field_key, nwb)
        
def build_nwb(scan_key):
    """Assemble an in-memory NWBFile for one scan.

    Adds eye tracking, treadmill velocity, trials and optical-physiology data,
    in that order. The stimulus movie (add_stimulus) is not included yet; the
    progress list notes it needs more RAM.

    Parameters
    ----------
    scan_key : dict
        DataJoint key of the scan, e.g. ``{'session': 4, 'scan_idx': 7}``.

    Returns
    -------
    NWBFile
        The populated file. It is not written to disk here.
    """
    nwb = start_nwb(scan_key)
    for add_section in (add_eye_tracking, add_treadmill, add_trials, add_ophys):
        add_section(scan_key, nwb)
    return nwb
    
        
# def full_conversion():
#     nwb = build_nwb(scan_key)
        
        

In [None]:
# Convert a single scan (session 4, scan 7) into an in-memory NWBFile.
scan_key = dict(session=4, scan_idx=7)

nwbfile = build_nwb(scan_key)

In [None]:
# Scratch cells from developing the converter. They inspect DataJoint tables
# interactively and are not used by build_nwb.
# Same scan as the conversion cell, redefined so these cells can run on their own.
scan_key = {'session': 4, 'scan_idx': 7} 


In [None]:
nda.Scan()

In [None]:
nda.Trial()

In [None]:
# Build the key for field 1 of the scan and preview its fluorescence traces.
field_data = dict(field=1)

field_key = {**scan_key, **dict(field=field_data["field"])}

nda.Fluorescence() & field_key

In [None]:
# IPython `??` shows the source of the mask-reshaping helper.
func.reshape_masks??

In [None]:
image_height, image_width = (nda.Field & field_key).fetch1(
        "px_height", "px_width"
    )

In [None]:
image_height

In [None]:
nda.MaskClassification()

In [None]:
func.get_all_masks??

In [None]:
scan_key = {'session': 4, 'scan_idx': 7} 

(nda.Field & scan_key).fetch(as_dict=True)

In [None]:
# NOTE(review): pupil_x / pupil_y are only defined by the next cell; run that
# cell first, otherwise this raises NameError.
np.c_[pupil_x, pupil_y]

In [None]:
# fetch (not fetch1) returns one array per attribute, with one element per
# matching row; hence the pupil_maj_r[0] indexing below.
pupil_min_r, pupil_maj_r, pupil_x, pupil_y = (
    nda.RawManualPupil() & scan_key
).fetch("pupil_min_r", "pupil_maj_r", "pupil_x", "pupil_y")

In [None]:
nda.RawManualPupil()

In [None]:
len(pupil_maj_r[0])

In [None]:
movie_times = (nda.FrameTimes() & scan_key).fetch1('frame_times') # timestamps of stimulus images
movie_times

In [None]:
# NOTE(review): `movie` is never defined at notebook scope (it is a local
# inside add_stimulus). Fetch it first, e.g. from nda.Stimulus, before running.
plt.imshow(movie[...,3400])