In [None]:
import numpy as np
import yaml
import shutil
from pathlib import Path
from experanto.dataloaders import get_multisession_dataloader
from omegaconf import OmegaConf

# 1. Setup paths
# NOTE: absolute, machine-specific location; adjust it when running elsewhere.
base_path = Path("/mnt/vast-nhr/projects/nix00014/goirik/data/dummy2")
spikes_path = base_path / "responses"
# Start from an empty "responses" folder so files from earlier runs cannot leak in.
if spikes_path.exists():
    shutil.rmtree(spikes_path)
# parents=True also creates base_path when it is missing; a plain mkdir() would
# raise FileNotFoundError on a machine where the dataset root does not exist yet.
spikes_path.mkdir(parents=True)

# 2. Create Dummy Data (Spikes + Meta)
n_neurons = 500
duration = 100.0
n_spikes = int(duration * 20)  # 20 Hz; identical for every neuron, so computed once

# Fixed seed: re-running the notebook regenerates exactly the same dummy spikes.
rng = np.random.default_rng(0)

all_spikes = []
# Cumulative offsets: indices[i]:indices[i + 1] is neuron i's slice of flat_spikes.
indices = [0]

for _ in range(n_neurons):
    # Sorted spike times drawn uniformly from [0, duration).
    spikes = np.sort(rng.uniform(0, duration, n_spikes))
    all_spikes.append(spikes)
    indices.append(indices[-1] + len(spikes))

flat_spikes = np.concatenate(all_spikes)
# NOTE(review): ndarray.tofile writes the raw array bytes without a .npy header, so
# np.load cannot open this file despite its extension. Presumably the reader expects
# a raw float64 buffer - confirm, otherwise switch to np.save.
flat_spikes.tofile(spikes_path / "spikes.npy")

# Metadata stored next to the spike file.
meta = {
    "modality": "spikes",
    "n_signals": n_neurons,
    "spike_indices": indices,
    "start_time": 0.0,
    "end_time": duration,
    "sampling_rate": 1000.0
}

with open(spikes_path / "meta.yml", "w") as f:
    yaml.dump(meta, f)

# 3. Define Configs
# Two sections: 'dataset' (presumably forwarded to ChunkDataset by
# get_multisession_dataloader - TODO confirm) and 'dataloader' (batching options).
config = dict(
    dataset=dict(
        global_chunk_size=8,
        global_sampling_rate=30,
        add_behavior_as_channels=False,
        out_keys=["responses"],
        # One entry per modality, each with its own chunking and interpolation settings.
        modality_config=dict(
            spike=dict(
                chunk_size=40,
                sampling_rate=20,
                interpolation=dict(
                    interpolation_window=0.5,
                    interpolation_align="center",
                ),
            ),
            screen=dict(
                chunk_size=60,
                sampling_rate=30,
                sample_stride=1,
                interpolation=dict(
                    interpolation_mode="nearest_neighbor",
                ),
            ),
        ),
    ),
    dataloader=dict(
        batch_size=16,
        shuffle=True,
        num_workers=2,
        drop_last=True,
    ),
)


# Wrap the plain dict in an OmegaConf object before handing it to the dataloader factory.
config_obj = OmegaConf.create(config)

# 4. Create Dataloader
# One path and one config per session; this notebook uses a single session.
# The result is expected to be a LongCycler wrapping the per-session dataloaders - TODO confirm.
# In the recorded run this call raised KeyError: 'screen'.
print("Creating DataLoader...")
loader = get_multisession_dataloader(paths=[str(base_path)], configs=[config_obj])

# 5. Iterate
print("Iterating...")
for batch in loader:
    # Only the first batch is inspected; the loop is left right after it.
    # NOTE(review): this assumes `batch` is dict-like (has .keys()) - confirm against
    # what the loader actually yields.
    print(f"Loaded batch keys: {batch.keys()}")
    # Show the shape of the 'inputs' entry when present, otherwise fall back to the batch itself.
    if 'inputs' in batch:
        shape_or_batch = batch['inputs'].shape
    else:
        shape_or_batch = batch
    print(f"Data shape: {shape_or_batch}")
    break

# Cleanup
# shutil.rmtree(base_path)
# NOTE: creating the dataloader above fails (KeyError: 'screen') until the dummy dataset also contains screen data.

Creating DataLoader...
No metadata file found at /mnt/vast-nhr/projects/nix00014/goirik/data/dummy2/meta.json


KeyError: 'screen'

# Initializing experiments works as expected

In [36]:
from experanto.experiment import Experiment

In [41]:
# Dataset and dataloader settings; only the 'modality_config' part is passed to Experiment below.
config = dict(
    dataset=dict(
        global_chunk_size=8,
        global_sampling_rate=30,
        add_behavior_as_channels=False,
        out_keys=["responses"],
        # One entry per modality, each with its own chunking and interpolation settings.
        modality_config=dict(
            spike=dict(
                chunk_size=40,
                sampling_rate=20,
                interpolation=dict(
                    interpolation_window=0.5,
                    interpolation_align="center",
                ),
            ),
            screen=dict(
                chunk_size=60,
                sampling_rate=30,
                sample_stride=1,
                interpolation=dict(
                    interpolation_mode="nearest_neighbor",
                ),
            ),
        ),
    ),
    dataloader=dict(
        batch_size=16,
        shuffle=True,
        num_workers=2,
        drop_last=True,
    ),
)


# OmegaConf view of the same settings; the Experiment below receives the plain dict.
config_obj = OmegaConf.create(config)


# Build the Experiment directly from the dataset root, bypassing the dataloader factory.
experiment = Experiment(
    root_folder=base_path,
    modality_config=config["dataset"]["modality_config"],
    cache_data=True,
    interpolate_precision=5,
)

In [42]:
# Display which devices the Experiment registered (the recorded run shows only 'spike').
experiment.device_names

('spike',)