In [1]:
import argparse
import h5py
import logging
import os

import numpy as np
import pandas as pd


from temporaldata import (
    Data,
    RegularTimeSeries,
    IrregularTimeSeries,
    Interval,
    ArrayDict,
)
from brainsets.descriptions import (
    BrainsetDescription,
    SessionDescription,
    SubjectDescription,
    DeviceDescription,
)
from brainsets.taxonomy import Species, Sex, Cre_line, RecordingTech
from brainsets.taxonomy.mice import BrainRegion
from brainsets.taxonomy.allen import (
    TEMPORAL_FREQ_5_map,
    SPATIAL_FREQ_5_map,
    ORIENTATION_12_CLASSES_map,
    ORIENTATION_8_CLASSES_map,
    PHASE_4_map,
)
from brainsets import serialize_fn_map

from experanto.experiment import Experiment
from experanto.configs import DEFAULT_MODALITY_CONFIG, DEFAULT_CONFIG

logging.basicConfig(level=logging.INFO)
from datetime import datetime
from experanto.dataloaders import get_multisession_dataloader
from collections import Counter
import json



In [2]:
def get_valid_screen_intervals(valid_screen_times, to_add, start_add=0, max_gap=0.33 * 1.5):
    """Group valid screen timestamps into contiguous intervals.

    Consecutive timestamps separated by more than ``max_gap`` are treated as
    belonging to different blocks. Each block becomes one interval
    ``[first_time + start_add, last_time + to_add]``.

    Args:
        valid_screen_times: 1-D array-like of valid screen timestamps; assumed
            to be sorted ascending (TODO confirm against experanto).
        to_add: Offset added to the last timestamp of each block to form the
            interval end (e.g. the duration of one chunk).
        start_add: Offset added to the first timestamp of each block to form
            the interval start.
        max_gap: Gap (same units as the timestamps) above which a new block
            starts. Defaults to the previously hard-coded ``0.33 * 1.5``.

    Returns:
        Interval with one entry per contiguous block.

    Raises:
        ValueError: if ``valid_screen_times`` is empty.
        AssertionError: if a gap between two blocks is not larger than
            ``to_add``, which would make consecutive intervals overlap.
    """
    times = np.asarray(valid_screen_times)
    if times.size == 0:
        raise ValueError('valid_screen_times is empty; cannot build intervals')

    # Index i marks a block boundary between times[i] and times[i + 1].
    break_idx = np.flatnonzero(np.diff(times) > max_gap)
    gaps = times[break_idx + 1] - times[break_idx]
    assert np.all(gaps > to_add), 'wrong gap between two intervals'

    # Blocks start right after each boundary (plus the very first sample) and
    # end right before each boundary (plus the very last sample).
    start_intervals = np.concatenate(([times[0]], times[break_idx + 1])) + start_add
    end_intervals = np.concatenate((times[break_idx], [times[-1]])) + to_add
    return Interval(start=start_intervals, end=end_intervals)

In [3]:
def get_intervals_for_tier(experanto_base_folder, mouse_id_folder, calcium_traces, tier, DEFAULT_CONFIG):
    """Build the valid-screen intervals for one stimulus tier of a session.

    A single-session dataloader restricted to ``tier`` is built from a *copy*
    of the given config. Its valid screen times are checked against the
    calcium-trace domain and then grouped with ``get_valid_screen_intervals``.

    Args:
        experanto_base_folder: Root folder containing experanto sessions.
        mouse_id_folder: Session folder name inside ``experanto_base_folder``.
        calcium_traces: Neural time series whose ``domain`` must cover every
            valid screen time plus one screen chunk.
        tier: Value written to ``modality_config.screen.valid_condition.tier``
            (e.g. ``'train'``, ``'validation'``, ``'final_test_main'``).
        DEFAULT_CONFIG: Experanto dataloader config. It is deep-copied, so the
            caller's (often module-global) config is never modified. The
            parameter name is kept for backward compatibility.

    Returns:
        Interval of contiguous valid-screen blocks for ``tier``.
    """
    import copy

    # Work on a copy: writing the tier into the shared config used to leak
    # between calls and into later sessions' Experiment construction.
    config = copy.deepcopy(DEFAULT_CONFIG)
    screen_config = config['dataset']['modality_config']['screen']
    screen_config['valid_condition']['tier'] = tier

    dl = get_multisession_dataloader(
        paths=[f'{experanto_base_folder}/{mouse_id_folder}'],
        configs=config,
        shuffle_keys=False,
    )
    assert len(dl.loaders) == 1, 'too long dataloader'
    valid_screen_times = next(iter(dl.loaders.values())).dataset._valid_screen_times

    # One screen chunk expressed in seconds, taken from the same config the
    # dataloader was built with; presumably so that a chunk starting at the
    # last valid time still fits inside the interval.
    to_add = screen_config['chunk_size'] / screen_config['sampling_rate']

    assert valid_screen_times.min() >= np.min(calcium_traces.domain.start), "calcium traces start after the visual stimuli"
    assert np.max(calcium_traces.domain.end) >= (valid_screen_times.max() + to_add), 'calcium traces end before visual stimuli'

    # TODO: hard-coded based on the interpolation period used in preprocessing;
    # non-train tiers start 5 screen samples after each block's first time.
    start_add = 0 if tier == 'train' else 5 / screen_config['sampling_rate']
    return get_valid_screen_intervals(valid_screen_times, to_add, start_add)
    

In [4]:
def get_fs(experiment, dev_name):
    """Return the sampling rate to use for device ``dev_name``.

    If ``experiment.modality_config[dev_name]`` has a ``"sampling_rate"`` key,
    its value is returned (even when that value is None). Otherwise the
    device object's ``sampling_rate`` attribute is used, or None if it has none.
    """
    modality_settings = experiment.modality_config[dev_name]
    device = experiment.devices[dev_name]
    device_rate = getattr(device, "sampling_rate", None)
    return modality_settings.get("sampling_rate", device_rate)

In [5]:
def extract_calcium_traces(experiment):
    """Normalize the ``'responses'`` device and wrap it as a RegularTimeSeries.

    Args:
        experiment: Experanto ``Experiment`` (loaded with ``cache_data=True``
            so that ``devices['responses']._data`` is populated).

    Returns:
        RegularTimeSeries with a ``df_over_f`` field holding the normalized
        traces, sampled at the configured rate and starting at the device's
        ``start_time``.

    Raises:
        ValueError: if the responses device contains no samples.
    """
    responses = experiment.devices['responses']
    traces = responses._data
    responses.normalize_init()
    traces = responses.normalize_data(traces)

    fs = experiment.modality_config["responses"].get("sampling_rate")
    if fs is None:
        # Fallback: use the device's own attribute if present, else 30 Hz.
        fs = getattr(responses, "sampling_rate", 30.0)

    # Unpacking enforces a 2-D array; presumably (time, neurons) — TODO confirm.
    n_samples, _ = traces.shape
    if n_samples == 0:
        raise ValueError("responses device contains no samples")

    # The first sample sits exactly at start_time, so the full timestamp
    # vector does not need to be materialized just to read its first entry.
    return RegularTimeSeries(
        df_over_f=np.array(traces),
        sampling_rate=fs,
        domain="auto",
        domain_start=float(responses.start_time),
    )

def extract_units(experanto_base_folder, mouse_id_folder):
    """Load per-unit metadata for one session as an ArrayDict.

    Reads ``responses/meta/unit_ids.npy`` and
    ``responses/meta/cell_motor_coordinates.npy`` from the session folder.
    Unit ids are converted to strings. Coordinate columns 0-1 are stored as
    ``imaging_plane_xy`` and column 2 as ``imaging_plane_height``.
    """
    meta_dir = f'{experanto_base_folder}/{mouse_id_folder}/responses/meta'
    unit_ids = np.load(f'{meta_dir}/unit_ids.npy')
    coords = np.load(f'{meta_dir}/cell_motor_coordinates.npy')

    plane_xy, plane_height = coords[:, :2], coords[:, 2]
    return ArrayDict(
        id=unit_ids.astype(str),
        imaging_plane_xy=plane_xy,
        imaging_plane_height=plane_height,
    )

In [6]:
def extract_running_speed(experiment):
    """Normalize the treadmill signal and wrap it as an IrregularTimeSeries.

    Samples with a NaN (after normalization) are dropped, so the remaining
    timestamps are not necessarily regular.

    Returns:
        IrregularTimeSeries with a float32 ``running_speed`` field of shape
        (n_samples, 1).

    Raises:
        ValueError: if no treadmill sampling rate is configured on the
            modality config or the device.
    """
    logging.info("Processing Treadmill...")
    treadmill = experiment.devices["treadmill"]
    running_speed = treadmill._data
    treadmill.normalize_init()
    running_speed = treadmill.normalize_data(running_speed)

    tm_fs = get_fs(experiment, "treadmill")
    if tm_fs is None:
        raise ValueError(
            "No sampling rate for 'treadmill' in modality_config or on the device"
        )
    timestamps = treadmill.start_time + np.arange(len(running_speed)) / tm_fs

    # Drop samples with a NaN in any column. Reducing over the non-time axes
    # (instead of squeeze()) keeps the mask 1-D even for a single sample.
    nan_mask = np.isnan(running_speed).any(axis=tuple(range(1, np.ndim(running_speed))))
    running_speed = running_speed[~nan_mask]
    timestamps = timestamps[~nan_mask]

    assert len(running_speed) == len(timestamps)

    return IrregularTimeSeries(
        timestamps=timestamps,
        # continuous values need to be 2-dimensional: (n_samples, 1)
        running_speed=running_speed.astype(np.float32).reshape(-1, 1),
        domain="auto",
    )

In [7]:
def extract_pupil_info(experiment):
    """Normalize eye-tracker data and wrap it as an IrregularTimeSeries.

    Columns are documented as 0 = radius, 1 = radius_derivative, 2 = x,
    3 = y. Columns 2: become ``location`` and columns :2 become ``size``.

    Returns:
        IrregularTimeSeries with float32 ``location`` and ``size`` fields, or
        None if no sample survives NaN filtering.

    Raises:
        ValueError: if no eye-tracker sampling rate is configured on the
            modality config or the device.
    """
    logging.info("Processing Eye Tracker...")
    eye_tracker = experiment.devices["eye_tracker"]
    eye_data = eye_tracker._data
    eye_tracker.normalize_init()
    eye_data = eye_tracker.normalize_data(eye_data)

    eye_fs = get_fs(experiment, "eye_tracker")
    if eye_fs is None:
        raise ValueError(
            "No sampling rate for 'eye_tracker' in modality_config or on the device"
        )
    eye_time = eye_tracker.start_time + np.arange(len(eye_data)) / eye_fs

    # The behaviour variables are stored together, so drop any sample where
    # at least one of them is NaN. Check both the raw and the normalized data
    # so NaNs introduced by normalization are also removed (as is done for
    # the treadmill).
    nan_mask = np.isnan(eye_tracker._data).any(axis=1) | np.isnan(eye_data).any(axis=1)
    eye_time = eye_time[~nan_mask]
    eye_data = eye_data[~nan_mask]

    assert len(eye_data) == len(eye_time)
    if len(eye_data) == 0:
        return None

    # '0': radius
    # '1': radius_derivative
    # '2': x
    # '3': y
    return IrregularTimeSeries(
        timestamps=eye_time,
        location=eye_data[:, 2:].astype(np.float32),
        size=eye_data[:, :2].astype(np.float32),
        domain="auto",
    )
    

In [8]:
# experanto_base_folder = '/mnt/vast-react/projects/neural_foundation_model/upsampling_without_hamming_30.0Hz' 
# mouse_id_folder = 'dynamic17797-4-7-Video-021a75e56847d574b9acbcc06c675055_30hz'
# experiment = Experiment(f'{experanto_base_folder}/{mouse_id_folder}', DEFAULT_MODALITY_CONFIG, cache_data=True)

In [9]:
def prepare_experanto_session(
    experanto_base_folder, mouse_id_folder, output_dir, dataset_name="sensorium_data"
):
    """Convert one Experanto session into a Brainsets HDF5 file for POYO.

    Loads the session with experanto and extracts the following:
    - calcium traces and unit metadata;
    - running speed and pupil data, when those devices exist.

    It then attaches train / validation / (optional) test domains derived
    from the screen tiers and writes ``<output_dir>/<mouse_id_folder>.h5``.

    Args:
        experanto_base_folder: Root folder containing experanto sessions.
        mouse_id_folder: Session folder name. It is also used as the subject
            and session id, which POYO validation metrics need.
        output_dir: Destination folder; created if it does not exist.
        dataset_name: Brainset id. It should match the output folder name.

    Note:
        Every call overwrites ``sampling_rate`` / ``chunk_size`` for
        ``responses``, ``eye_tracker`` and ``treadmill`` in the module-level
        ``DEFAULT_MODALITY_CONFIG``. It also re-attaches that config to
        ``DEFAULT_CONFIG['dataset']['modality_config']``.
    """
    session_path = f'{experanto_base_folder}/{mouse_id_folder}'

    # 1. Load Experanto Data
    # ---------------------------------------------------------
    logging.info(f"Loading Experanto: {session_path}")

    # Sampling rate (Hz) and chunk size (samples) per modality.
    for modality, (sampling_rate, chunk_size) in {
        'responses': (30, 60),
        'eye_tracker': (20, 40),
        'treadmill': (20, 40),
    }.items():
        DEFAULT_MODALITY_CONFIG[modality]['sampling_rate'] = sampling_rate
        DEFAULT_MODALITY_CONFIG[modality]['chunk_size'] = chunk_size
    DEFAULT_CONFIG['dataset']['modality_config'] = DEFAULT_MODALITY_CONFIG

    # cache_data=True: with the default (False), accessing ._data on devices
    # often returns None or incomplete data.
    experiment = Experiment(session_path, DEFAULT_MODALITY_CONFIG, cache_data=True)

    # 2. Define Metadata
    # ---------------------------------------------------------
    brainset_desc = BrainsetDescription(
        id=dataset_name,
        origin_version="1.0.0",
        derived_version="1.0.0",
        source="local_experanto",
        description="Converted from Experanto dataset",
    )

    subject_desc = SubjectDescription(
        id=mouse_id_folder,  # needed for poyo validation metrics to work
        species=Species.MUS_MUSCULUS,
        sex=Sex.UNKNOWN,
    )

    session_desc = SessionDescription(
        id=mouse_id_folder,  # needed for poyo validation metrics to work
        # NOTE(review): this is the conversion time, not the real recording date.
        recording_date=datetime.now(),
    )

    device_desc = DeviceDescription(
        # same as in https://github.com/neuro-galaxy/brainsets/blob/main/brainsets_pipelines/allen_visual_coding_ophys_2016/prepare_data.py#L420
        # e.g. 'dynamic29234-6-9-Video-...' -> '29234'
        id=mouse_id_folder.split('-')[0].split('dynamic')[-1],
        recording_tech=RecordingTech.TWO_PHOTON_IMAGING,
    )

    # 3. Extract neural activity, stimuli and behavior
    # ---------------------------------------------------------
    calcium_traces = extract_calcium_traces(experiment)
    units = extract_units(experanto_base_folder, mouse_id_folder)

    stimuli_and_behavior_dict = {}
    if "treadmill" in experiment.devices:
        stimuli_and_behavior_dict["running"] = extract_running_speed(experiment)

    if "eye_tracker" in experiment.devices:
        pupil = extract_pupil_info(experiment)
        if pupil is not None:
            stimuli_and_behavior_dict["pupil"] = pupil

    data = Data(
        brainset=brainset_desc,
        subject=subject_desc,
        session=session_desc,
        device=device_desc,
        # neural activity
        calcium_traces=calcium_traces,
        units=units,
        # stimuli and behavior
        **stimuli_and_behavior_dict,
        # domain
        domain=calcium_traces.domain,
    )

    # 4. Splits
    # ---------------------------------------------------------
    logging.info("Creating Splits...")

    with open(f"{session_path}/screen/combined_meta.json", "r", encoding="utf-8") as file:
        screen_meta = json.load(file)
    all_tiers = {entry['tier'] for entry in screen_meta.values() if 'tier' in entry}

    data.set_train_domain(
        get_intervals_for_tier(experanto_base_folder, mouse_id_folder, calcium_traces, 'train', DEFAULT_CONFIG)
    )
    data.set_valid_domain(
        get_intervals_for_tier(experanto_base_folder, mouse_id_folder, calcium_traces, 'validation', DEFAULT_CONFIG)
    )

    # Held-out test tier: the first of these names present in the metadata.
    test_tier = next(
        (tier for tier in ('final_test_main', 'final_test_1', 'final_test') if tier in all_tiers),
        None,
    )
    if test_tier is not None:
        data.set_test_domain(
            get_intervals_for_tier(experanto_base_folder, mouse_id_folder, calcium_traces, test_tier, DEFAULT_CONFIG)
        )

    # 5. Save data to disk
    # ---------------------------------------------------------
    os.makedirs(output_dir, exist_ok=True)
    path = os.path.join(output_dir, f"{mouse_id_folder}.h5")
    with h5py.File(path, "w") as file:
        data.to_hdf5(file, serialize_fn_map=serialize_fn_map)
    


In [10]:
# Default export folder and experanto dataset root for interactive use.
output_dir, experanto_base_folder = (
    '/mnt/vast-react/projects/neural_foundation_model/torch_brain_polly_export/',
    '/mnt/vast-react/projects/neural_foundation_model/upsampling_without_hamming_30.0Hz',
)
# mouse_id_folder = 'dynamic17797-4-7-Video-021a75e56847d574b9acbcc06c675055_30hz'

In [12]:
# Sessions to convert, as '<experanto dataset folder>/<session folder>'.
# Every session folder shares the '-Video-021a75e56847d574b9acbcc06c675055_30hz'
# suffix, so only the dataset folder and the numeric id are listed here.
target_mice = [
    f'{dataset_folder}/dynamic{session_id}-Video-021a75e56847d574b9acbcc06c675055_30hz'
    for dataset_folder, session_id in (
        ('upsampling_without_hamming_30.0Hz', '29234-6-9'),
        ('upsampling_without_hamming_30.0Hz', '29514-2-9'),
        ('upsampling_without_hamming_30.0Hz', '29513-3-5'),
        ('upsampling_without_hamming_30.0Hz', '29156-11-10'),
        ('upsampling_without_hamming_30.0Hz', '17797-8-5'),
        ('upsampling_without_hamming_30.0Hz', '29228-2-10'),
        ('test_upsampling_without_hamming_30.0Hz', '26872-17-20'),
        ('test_upsampling_without_hamming_30.0Hz', '27204-5-13'),
    )
]

In [None]:
output_dir = '/mnt/vast-react/projects/neural_foundation_model/torch_brain_polly_export/normalized/'
# Create the export folder once, up front, so the first session cannot fail on it.
os.makedirs(output_dir, exist_ok=True)

# Convert every target session. A failure on one session is logged (with
# traceback) and the loop continues; all failures are raised together at the
# end so none go unnoticed.
failed_sessions = {}
for t in target_mice:
    print(f'started with {t}')
    dataset_folder, mouse_id_folder = t.split('/', 1)
    experanto_base_folder = f'/mnt/vast-react/projects/neural_foundation_model/{dataset_folder}'
    try:
        prepare_experanto_session(
            experanto_base_folder, mouse_id_folder, output_dir,
            dataset_name="normalized",  # dataset name should match the folder name!
        )
    except Exception as err:
        logging.exception(f'failed to convert {t}')
        failed_sessions[t] = err
    print('\n\n')

if failed_sessions:
    raise RuntimeError(f'{len(failed_sessions)} session(s) failed: {sorted(failed_sessions)}')