<a href="https://colab.research.google.com/github/obarnstedt/CEBRA-demos/blob/colab/Demo_hippocampus_multisession.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Technical: Training models across animals

* This notebook demonstrates how to use the multi-session CEBRA implementation.
* We compare the embeddings obtained on four rat datasets when training four single-session models vs. one multi-session model.
* We use CEBRA-Behavior (``conditional='time_delta'``) for both the single-session and the multi-session implementation.
* Each single-session output embedding is aligned to the embedding of the first session.


**How multi-session training works:**

* For flexibility, the implementation does not fit one model for all sessions. Consequently, sessions do not need to have the same number of data features (e.g., the number of neurons can differ from one session to the other). The number of samples can also vary (this is true for single-session training too).
* It fits one model per session, but the positive/negative sampling is performed across all sessions, which makes the models invariant to the labels across all sessions.
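
The idea of sampling across sessions can be illustrated with a small NumPy sketch. This is a simplified illustration under stated assumptions, not CEBRA's actual sampler: here the positive for each reference sample is simply the sample in another session whose label is closest. The toy sessions, their sizes, and the helper `nearest_label_index` are hypothetical.

```python
import numpy as np

def nearest_label_index(query_labels, candidate_labels):
    """For each query label, return the index of the closest candidate label."""
    # Pairwise Euclidean distances, shape (n_query, n_candidates)
    dist = np.linalg.norm(
        query_labels[:, None, :] - candidate_labels[None, :, :], axis=-1)
    return dist.argmin(axis=1)

rng = np.random.default_rng(0)

# Two toy sessions with different numbers of samples AND neurons,
# but a shared label space (e.g., position on a 1.6 m track).
session_a = {"neural": rng.normal(size=(500, 120)),
             "label": rng.uniform(0.0, 1.6, size=(500, 1))}
session_b = {"neural": rng.normal(size=(300, 46)),
             "label": rng.uniform(0.0, 1.6, size=(300, 1))}

# Draw reference samples in session A, then look up positives in session B
# using only the labels.
ref_idx = rng.integers(0, 500, size=8)
pos_idx = nearest_label_index(session_a["label"][ref_idx], session_b["label"])

ref_batch = session_a["neural"][ref_idx]  # shape (8, 120)
pos_batch = session_b["neural"][pos_idx]  # shape (8, 46)
label_gap = np.abs(session_a["label"][ref_idx] - session_b["label"][pos_idx])
print(ref_batch.shape, pos_batch.shape, float(label_gap.max()))
```

Because the pairing only looks at the labels, the two batches can have different feature dimensions, which is why each session needs its own model.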


**When to use multi-session training:**

Check the following list to verify whether our multi-session implementation is the right tool for your needs.

* I have multiple sessions or animals that I want to consider as a *pseudo-subject* and use jointly for training CEBRA.
* That is the case when access to simultaneously recorded neurons is limited, or when looking for animal-invariant features in the neural data.
* I want to get more consistent embeddings from one session or animal to the other.
* I want to be able to use CEBRA for a new session that is fully *unseen* during training (e.g., this could be useful for brain-machine interface applications).
* **WARNING:** I am not interested in how individual variations of the label features from one session/animal to the other influence the resulting embedding. I do not need that session- or animal-specific information.


**Install note**

Be sure you have cebra and the demo dependencies installed to use this notebook:

In [1]:
!pip install 'cebra[datasets,demos]'

Collecting cebra[datasets,demos]
  Downloading cebra-0.2.0-py2.py3-none-any.whl (171 kB)
Collecting literate-dataclasses (from cebra[datasets,demos])
  Downloading literate_dataclasses-0.0.6-py3-none-any.whl (5.0 kB)
Collecting hdf5storage (from cebra[datasets,demos])
  Downloading hdf5storage-0.1.19-py2.py3-none-any.whl (53 kB)
Collecting jupyter (from cebra[datasets,demos])
  Downloading jupyter-1.0.0-py2.py3-none-any.whl (2.7 kB)
Collecting qtconsole (from jupyter->cebra[datasets,demos])
  Downloading qtconsole-5.4.4-py3-none-any.whl (121 kB)
Collecting jedi>=0.16 (from ipython>=5.0.0->ipykernel->cebra[datasets,demos])
  Downlo

In [2]:
import sys
import os

import cebra.data
import cebra.datasets
import cebra.integrations
from cebra import CEBRA
import pickle

import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
from plotly.subplots import make_subplots
import plotly.graph_objects as go

# Check whether the notebook is running in Colab:
try:
  import google.colab
  IN_COLAB = True
  %matplotlib inline
except ImportError:
  # google.colab is only importable inside Colab
  IN_COLAB = False
  %matplotlib notebook



## Load the data

- The data will be automatically downloaded into a `/data` folder.

In [3]:
if IN_COLAB and not os.path.exists('/content/data/rat_hippocampus'):
    import locale
    locale.getpreferredencoding = lambda: "UTF-8"  # to prevent locale errors
    #for google colab only, run this cell to download and extract data:
    !wget --content-disposition https://figshare.com/ndownloader/files/36869049?private_link=60adb075234c2cc51fa3
    !mkdir data
    !tar -xvf "/content/data.tgz" -C "/content/data"

hippocampus_a = cebra.datasets.init('rat-hippocampus-single-achilles')
hippocampus_b = cebra.datasets.init('rat-hippocampus-single-buddy')
hippocampus_c = cebra.datasets.init('rat-hippocampus-single-cicero')
hippocampus_g = cebra.datasets.init('rat-hippocampus-single-gatsby')

--2023-09-16 13:01:35--  https://figshare.com/ndownloader/files/36869049?private_link=60adb075234c2cc51fa3
Resolving figshare.com (figshare.com)... 52.19.197.201, 176.34.206.43, 2a05:d018:1f4:d003:ca7c:7587:e6e5:533c, ...
Connecting to figshare.com (figshare.com)|52.19.197.201|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://s3-eu-west-1.amazonaws.com/pfigshare-u-files/36869049/data.tgz?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIYCQYOYV5JSSROOA/20230916/eu-west-1/s3/aws4_request&X-Amz-Date=20230916T130136Z&X-Amz-Expires=10&X-Amz-SignedHeaders=host&X-Amz-Signature=7d92d0115c858d41d3933b0ccaecde9c5133299564b38ea61ad673b56a0c22ba [following]
--2023-09-16 13:01:36--  https://s3-eu-west-1.amazonaws.com/pfigshare-u-files/36869049/data.tgz?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIYCQYOYV5JSSROOA/20230916/eu-west-1/s3/aws4_request&X-Amz-Date=20230916T130136Z&X-Amz-Expires=10&X-Amz-SignedHeaders=host&X-Amz-Signature=7d92d0115c858d4

In [4]:
names = ["achilles", "buddy", "cicero", "gatsby"]
datas = [hippocampus_a.neural.numpy(), hippocampus_b.neural.numpy(), hippocampus_c.neural.numpy(), hippocampus_g.neural.numpy()]
labels = [hippocampus_a.continuous_index.numpy(), hippocampus_b.continuous_index.numpy(), hippocampus_c.continuous_index.numpy(), hippocampus_g.continuous_index.numpy()]
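
Because the sessions are passed around as parallel lists, it can help to confirm that each neural array and its label array agree on the number of samples, while the number of neurons is free to differ. The following is a minimal sketch that uses small synthetic arrays in place of the rat recordings; `summarize_sessions` is a hypothetical helper and not part of CEBRA.

```python
import numpy as np

def summarize_sessions(names, datas, labels):
    """Return {name: (n_samples, n_neurons, n_label_dims)} after a sanity check."""
    summary = {}
    for name, X, y in zip(names, datas, labels):
        if X.shape[0] != y.shape[0]:
            raise ValueError(
                f"{name}: {X.shape[0]} neural samples vs. {y.shape[0]} labels")
        summary[name] = (X.shape[0], X.shape[1], y.shape[1])
    return summary

# Synthetic stand-ins: two sessions with different sample and neuron counts.
rng = np.random.default_rng(0)
toy_names = ["rat_a", "rat_b"]
toy_datas = [rng.normal(size=(100, 12)), rng.normal(size=(80, 7))]
toy_labels = [rng.uniform(size=(100, 3)), rng.uniform(size=(80, 3))]

toy_summary = summarize_sessions(toy_names, toy_datas, toy_labels)
for name, shape in toy_summary.items():
    print(name, shape)
```

With the notebook variables, the equivalent call would be `summarize_sessions(names, datas, labels)`.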

## Single-session training

Create and fit one single-session model per dataset.

In [5]:
max_iterations = 15000 # by default, 5000 iterations

In [None]:
embeddings = dict()

# Single session training
for name, X, y in zip(names, datas, labels):
    # Fit one CEBRA model per session (i.e., per rat)
    print(f"Fitting CEBRA for {name}")
    cebra_model = CEBRA(model_architecture='offset10-model',
                        batch_size=512,
                        learning_rate=3e-4,
                        temperature=1,
                        output_dimension=3,
                        max_iterations=max_iterations,
                        distance='cosine',
                        conditional='time_delta',
                        device='cuda_if_available',
                        verbose=True,
                        time_offsets=10)

    cebra_model.fit(X, y)
    embeddings[name] = cebra_model.transform(X)


# Align the single session embeddings to the first rat
alignment = cebra.data.helper.OrthogonalProcrustesAlignment()
first_rat = list(embeddings.keys())[0]

for j, rat_name in enumerate(list(embeddings.keys())[1:]):
    # labels[0] belongs to the reference rat, labels[j+1] to the rat being aligned
    embeddings[rat_name] = alignment.fit_transform(
        embeddings[first_rat], embeddings[rat_name], labels[0], labels[j+1])


# Save embeddings in current folder
with open('embeddings.pkl', 'wb') as f:
    pickle.dump(embeddings, f)


Fitting CEBRA for achilles


pos:  0.1626 neg:  5.4230 total:  5.5856 temperature:  1.0000:   5%|▍         | 746/15000 [02:26<42:35,  5.58it/s]
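
The alignment step in the cell above relies on an orthogonal Procrustes fit. As a rough illustration of that technique, the sketch below solves the plain orthogonal Procrustes problem with NumPy, assuming both embeddings have the same number of samples and that their rows already correspond. The CEBRA helper additionally receives the labels of both sessions, which this simplified version does not use; `orthogonal_procrustes_fit` is a hypothetical name.

```python
import numpy as np

def orthogonal_procrustes_fit(source, reference):
    """Return the orthogonal matrix R minimizing ||source @ R - reference||_F."""
    # The optimum is U @ Vt, where U, Vt come from the SVD of source.T @ reference.
    u, _, vt = np.linalg.svd(source.T @ reference)
    return u @ vt

rng = np.random.default_rng(0)

# Toy 3D "embedding" and a copy of it transformed by a random orthogonal matrix.
reference = rng.normal(size=(200, 3))
true_rotation, _ = np.linalg.qr(rng.normal(size=(3, 3)))
source = reference @ true_rotation.T

R = orthogonal_procrustes_fit(source, reference)
aligned = source @ R

print("max alignment error:", float(np.abs(aligned - reference).max()))
```

An orthogonal map only rotates or reflects the embedding, so distances between points within a session are preserved by the alignment.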

## Multisession training

Create and fit one multi-session model on all datasets.

In [None]:
multi_embeddings = dict()

# Multisession training
multi_cebra_model = CEBRA(model_architecture='offset10-model',
                    batch_size=512,
                    learning_rate=3e-4,
                    temperature=1,
                    output_dimension=3,
                    max_iterations=max_iterations,
                    distance='cosine',
                    conditional='time_delta',
                    device='cuda_if_available',
                    verbose=True,
                    time_offsets=10)

# Provide a list of data, i.e. datas = [data_a, data_b, ...]
multi_cebra_model.fit(datas, labels)

# Transform each session with the right model, by providing the corresponding session ID
for i, (name, X) in enumerate(zip(names, datas)):
    multi_embeddings[name] = multi_cebra_model.transform(X, session_id=i)

# Save embeddings in current folder
with open('multi_embeddings.pkl', 'wb') as f:
    pickle.dump(multi_embeddings, f)

## Compare embeddings

Also see [Extended Data Figure 7](https://cebra.ai/docs/cebra-figures/figures/ExtendedDataFigure7.html) of Schneider, Lee, Mathis.

In [None]:
with open('embeddings.pkl', 'rb') as f:
    embeddings = pickle.load(f)
with open('multi_embeddings.pkl', 'rb') as f:
    multi_embeddings = pickle.load(f)

In [None]:
def plot_hippocampus(ax, embedding, label, gray = False, idx_order = (0,1,2)):
    r_ind = label[:,1] == 1
    l_ind = label[:,2] == 1

    if not gray:
        r_cmap = 'cool'
        l_cmap = 'viridis'
        r_c = label[r_ind, 0]
        l_c = label[l_ind, 0]
    else:
        r_cmap = None
        l_cmap = None
        r_c = 'gray'
        l_c = 'gray'

    idx1, idx2, idx3 = idx_order
    r = ax.scatter(embedding[r_ind, idx1],
                   embedding[r_ind, idx2],
                   embedding[r_ind, idx3],
                   c=r_c,
                   cmap=r_cmap, s=0.05, alpha=0.75)
    l = ax.scatter(embedding[l_ind, idx1],
                   embedding[l_ind, idx2],
                   embedding[l_ind, idx3],
                   c=l_c,
                   cmap=l_cmap, s=0.05, alpha=0.75)

    ax.grid(False)
    ax.xaxis.pane.fill = False
    ax.yaxis.pane.fill = False
    ax.zaxis.pane.fill = False
    ax.xaxis.pane.set_edgecolor('w')
    ax.yaxis.pane.set_edgecolor('w')
    ax.zaxis.pane.set_edgecolor('w')

    return ax

def plot_hippocampus_plotly(fig, col, embedding, label, row=0, gray=False, idx_order=(0,1,2)):
    r_ind = label[:,1] == 1
    l_ind = label[:,2] == 1

    if not gray:
        # r_cmap = 'Plotly3'  # 'cool' does not exist in plotly
        r_cmap = ["#00fdff", "#ff40ff"]
        l_cmap = 'viridis'
        r_c = label[r_ind, 0]
        l_c = label[l_ind, 0]
    else:
        r_cmap = None
        l_cmap = None
        r_c = 'gray'
        l_c = 'gray'

    idx1, idx2, idx3 = idx_order

    fig.add_trace(
        go.Scatter3d(
            x=embedding[r_ind, idx1],
            y=embedding[r_ind, idx2],
            z=embedding[r_ind, idx3],
            mode='markers',
            marker=dict(
                size=2,
                opacity=0.6,
                color=r_c,   # set color to an array/list of desired values
                colorscale=r_cmap,   # choose a colorscale
                line=dict(width=0)
            ),
        ), col=col+1, row=row+1)

    fig.add_trace(
        go.Scatter3d(
            x=embedding[l_ind, idx1],
            y=embedding[l_ind, idx2],
            z=embedding[l_ind, idx3],
            mode='markers',
            marker=dict(
                size=2,
                opacity=0.6,
                color=l_c,   # set color to an array/list of desired values
                colorscale=l_cmap,   # choose a colorscale
                line=dict(width=0)
            ),
        ), col=col+1, row=row+1)

    return fig
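
Both plotting helpers above read the label array the same way: column 0 supplies the color value, and columns 1 and 2 act as flags selecting the two groups of samples (the `r_` and `l_` prefixes presumably stand for the right and left running direction). A minimal sketch of that split on a hand-made label array follows; the values are made up for illustration and `split_by_direction` is a hypothetical helper.

```python
import numpy as np

def split_by_direction(label):
    """Return boolean masks for the two groups flagged in columns 1 and 2."""
    r_ind = label[:, 1] == 1
    l_ind = label[:, 2] == 1
    return r_ind, l_ind

# Columns: [color value (e.g., position), flag for group r, flag for group l]
toy_label = np.array([
    [0.1, 1, 0],
    [0.5, 1, 0],
    [1.2, 0, 1],
    [0.8, 0, 1],
    [0.3, 1, 0],
])

r_ind, l_ind = split_by_direction(toy_label)
print("r positions:", toy_label[r_ind, 0])
print("l positions:", toy_label[l_ind, 0])
```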

In [None]:
plotly = True
if plotly:
    fig = make_subplots(rows=2, cols=4,
                        # one 3D scene per subplot, for both rows
                        specs=[[{'is_3d': True}]*4 for _ in range(2)],
                        print_grid=False,
                        subplot_titles=([f'Single-{n}' for n in names]+
                                        [f'Multi-{n}' for n in names]),
                        horizontal_spacing = 0.05,
    )
    # labels is a list ordered like names, so index it by position, not by name
    for i, n in enumerate(names):
        plot_hippocampus_plotly(fig, i, embeddings[n], labels[i])
    for i, n in enumerate(names):
        plot_hippocampus_plotly(fig, i, multi_embeddings[n], labels[i], row=1)
    fig.update_layout(template="plotly_white", showlegend=False, height=400)
    fig.show()

else:
    fig = plt.figure(figsize=(10,6))

    ax1 = plt.subplot(241, projection='3d')
    ax2 = plt.subplot(242, projection='3d')
    ax3 = plt.subplot(243, projection='3d')
    ax4 = plt.subplot(244, projection='3d')
    axs = [ax1, ax2, ax3, ax4]

    ax5 = plt.subplot(245, projection='3d')
    ax6 = plt.subplot(246, projection='3d')
    ax7 = plt.subplot(247, projection='3d')
    ax8 = plt.subplot(248, projection='3d')
    axs_multi = [ax5, ax6, ax7, ax8]

    for name, label, ax, ax_multi in zip(names, labels, axs, axs_multi):
        ax = plot_hippocampus(ax, embeddings[name], label)
        ax.set_title(f'Single-{name}', y=1, pad=-20)
        ax.axis('off')
        ax_multi = plot_hippocampus(ax_multi, multi_embeddings[name], label)
        ax_multi.set_title(f'Multi-{name}', y=1, pad=-20)
        ax_multi.axis('off')


    plt.subplots_adjust(wspace=0,
                        hspace=0)
    plt.show()
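
The visual comparison can be complemented with a number. The sketch below is one possible way to quantify how similar two embeddings are, under stated assumptions: it averages each embedding within position bins and reports the R² of a linear fit from one set of bin means to the other. It runs on synthetic data, and the helper names (`binned_means`, `linear_r2`, `position_consistency`) are hypothetical and not part of CEBRA.

```python
import numpy as np

def binned_means(embedding, position, edges):
    """Average the embedding within each position bin; also flag non-empty bins."""
    n_bins = len(edges) - 1
    bin_idx = np.clip(np.digitize(position, edges) - 1, 0, n_bins - 1)
    means = np.zeros((n_bins, embedding.shape[1]))
    filled = np.zeros(n_bins, dtype=bool)
    for b in range(n_bins):
        in_bin = bin_idx == b
        if in_bin.any():
            means[b] = embedding[in_bin].mean(axis=0)
            filled[b] = True
    return means, filled

def linear_r2(source, target):
    """R^2 of a least-squares linear map (with intercept) from source to target."""
    X = np.hstack([source, np.ones((len(source), 1))])
    coef, *_ = np.linalg.lstsq(X, target, rcond=None)
    ss_res = ((target - X @ coef) ** 2).sum()
    ss_tot = ((target - target.mean(axis=0)) ** 2).sum()
    return 1.0 - ss_res / ss_tot

def position_consistency(emb_a, pos_a, emb_b, pos_b, n_bins=20):
    """Compare two embeddings through their position-binned means."""
    lo = max(pos_a.min(), pos_b.min())
    hi = min(pos_a.max(), pos_b.max())
    edges = np.linspace(lo, hi, n_bins + 1)
    means_a, ok_a = binned_means(emb_a, pos_a, edges)
    means_b, ok_b = binned_means(emb_b, pos_b, edges)
    both = ok_a & ok_b  # only use bins populated in both sessions
    return linear_r2(means_a[both], means_b[both])

rng = np.random.default_rng(0)

def toy_embedding(n):
    """Synthetic 3D embedding that depends smoothly on a 1D position."""
    pos = rng.uniform(0.0, 1.6, size=n)
    angle = 2 * np.pi * pos / 1.6
    emb = np.stack([np.cos(angle), np.sin(angle), pos], axis=1)
    return emb + 0.01 * rng.normal(size=emb.shape), pos

emb_a, pos_a = toy_embedding(2000)
emb_b, pos_b = toy_embedding(1500)
rotation, _ = np.linalg.qr(rng.normal(size=(3, 3)))

score_rotated = position_consistency(emb_a, pos_a, emb_b @ rotation, pos_b)
score_noise = position_consistency(
    emb_a, pos_a, rng.normal(size=emb_b.shape), pos_b)
print("rotated copy:", score_rotated, "pure noise:", score_noise)
```

With the notebook variables, one could pass `embeddings[name]` or `multi_embeddings[name]` together with column 0 of the matching label array, assuming that column holds the position.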