# CEBRA example

This notebook is for development and experimentation.
It is based on https://stes.io/NeuroDataReHack2023/

See also `submit_cebra.py`.

In [None]:
import pynwb
import dendro.client as den
import remfile
import h5py


# Dendro project D-000129; its files live under "imported/<dandiset_id>/".
project = den.load_project("9638e926")
dandiset_id = "000129"

# Relative path of the NWB asset to analyze within the dandiset.
asset_path = "sub-Indy/sub-Indy_desc-train_behavior+ecephys.nwb"

# Lazily open the NWB file: wrap the project file with remfile, open it as a
# read-only HDF5 file, and hand that handle to pynwb.
# NOTE: `file` and `io` are deliberately left open, because `nwbfile` is read
# lazily by later cells.
remote_path = f"imported/{dandiset_id}/{asset_path}"
file = remfile.File(project.get_file(remote_path))
h5_handle = h5py.File(file, "r")
io = pynwb.NWBHDF5IO(file=h5_handle, mode="r")
nwbfile = io.read()

In [None]:
import numpy as np
from nlb_tools.nwb_interface import NWBDataset

class Dataset(NWBDataset):
    """NLB ``NWBDataset`` that is resampled on load and exposes signals as arrays.

    After construction, every top-level column group of ``self.data`` is
    available as an attribute holding that group's ``.values`` array (e.g.
    ``self.spikes``). ``self.target_pos_idx`` assigns each row an integer id
    for its ``target_pos`` value.
    """

    def __init__(self, nwbfile, prefix="*train", bin_size=20):
        """Load ``nwbfile``, resample it, and attach per-signal arrays.

        Args:
            nwbfile: NWB source, forwarded unchanged to ``NWBDataset``.
            prefix: File-name pattern forwarded as ``NWBDataset``'s second
                positional argument (defaults to the original ``"*train"``).
            bin_size: Target bin width passed to ``resample``. The default of 20
                gives 20 ms bins, which makes the downstream computations faster.
        """
        super().__init__(nwbfile, prefix, split_heldout=False)
        self.resample(target_bin=bin_size)

        # `Index.unique()` keeps column order. Iterating a `set` of strings
        # depends on hash randomization, so the printout order changed per run.
        signal_types = self.data.columns.get_level_values(level=0).unique()
        for signal_type in signal_types:
            print(signal_type, self.data[signal_type].shape)
            setattr(self, signal_type, self.data[signal_type].values)

        # NOTE(review): assumes a `target_pos` column group exists in the data.
        self.target_pos_idx = self._index_rows(self.target_pos)

    @staticmethod
    def _index_rows(rows):
        """Label each row by the rank of its value among distinct NaN-free rows.

        Rows are compared as tuples. Distinct NaN-free rows are sorted, and each
        row gets its position in that sorted list. A row containing any NaN is
        labelled -1.

        Returns:
            A 1-D integer ``np.ndarray`` with one label per input row.
        """
        values = [tuple(v) for v in rows]
        unique_values = sorted({v for v in values if not np.isnan(v).any()})
        # Dict lookup replaces `list.index`, which rescanned the list per row.
        lookup = {v: i for i, v in enumerate(unique_values)}
        return np.array(
            [-1 if np.isnan(v).any() else lookup[v] for v in values],
            dtype=int,
        )

# Build the binned dataset (the constructor prints each signal's shape).
dataset = Dataset(nwbfile)

print("Loaded dataset:")
# `display` is not imported here; presumably the IPython/Jupyter builtin.
preview = dataset.data.head()
display(preview)

In [None]:
import cebra

# Default number of training steps. It is kept small for quick experiments.
# An example run with 10_000 steps produced a usable embedding, and training
# longer might improve it further.
MAX_ITERATIONS = 500


def init_model(
    *,
    max_iterations=None,
    model_architecture="offset10-model",
    batch_size=1000,
    output_dimension=8,
    verbose=True,
    **cebra_kwargs,
):
    """Build an unfitted ``cebra.CEBRA`` model with this notebook's defaults.

    ``init_model()`` with no arguments produces the default configuration.
    Every setting can be overridden per call by keyword.

    Args:
        max_iterations: Number of training steps. ``None`` (the default) reads
            the module-level ``MAX_ITERATIONS`` at call time, so reassigning that
            constant in a later cell still takes effect.
        model_architecture: CEBRA architecture name. "offset10-model" uses
            10 time bins (200 ms) as its input.
        batch_size: Mini-batch size for optimization. Pick a value greater than
            512 in general; larger values (if they fit into memory) are usually
            better.
        output_dimension: Number of output features. The optimal number depends
            on the complexity of the dataset.
        verbose: Show a progress bar during training.
        **cebra_kwargs: Any further ``cebra.CEBRA`` parameters; see
            https://cebra.ai/docs/api/sklearn/cebra.html

    Returns:
        A new ``cebra.CEBRA`` instance.
    """
    if max_iterations is None:
        max_iterations = MAX_ITERATIONS
    return cebra.CEBRA(
        model_architecture=model_architecture,
        batch_size=batch_size,
        max_iterations=max_iterations,
        output_dimension=output_dimension,
        verbose=verbose,
        **cebra_kwargs,
    )


model = init_model()

In [None]:
# Drop time bins where either the neural input or the behavioral label
# contains a NaN. The original mask only checked `spikes`, so NaN labels in
# `cursor_pos` could still reach `fit`. The name `is_nan` is kept because the
# transform cell reuses this same mask.
is_nan = (
    np.isnan(dataset.spikes).any(axis=1)
    | np.isnan(dataset.cursor_pos).any(axis=1)
)
model.fit(
    dataset.spikes[~is_nan],
    dataset.cursor_pos[~is_nan],
)

In [None]:
# Embed the same masked (NaN-free) rows that were passed to `model.fit`.
valid_spikes = dataset.spikes[~is_nan]
embedding = model.transform(valid_spikes)

In [None]:
# Plot the loss curve of the fitted model; this is left as the cell's last
# expression so the notebook shows its result.
cebra.plot_loss(model, label="Loss curve")