# CEBRA Embeddings

Generates CEBRA embeddings from cleaned spike data.

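- Input: spike data stored in a pickle file with keys `datas` (one tensor of shape (time*trials, neurons) per recording) and `recordings` (the matching names / mouse IDs).
- Output: a pickle file of CEBRA models and embeddings, keyed by recording name (mouse ID).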

Author: @emilyekstrum
<br> 11/17/25

In [None]:
import cebra
import itertools
import os
import torch
import matplotlib
import random

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pickle as pkl

from cebra import CEBRA
from pathlib import Path

plt.style.use(['default', 'seaborn-v0_8-paper'])

# Font embedding for vector exports: fonttype 42 selects TrueType (Type 42)
# embedding instead of Type 3, so text in saved PDF/PS figures stays editable.
matplotlib.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})



## Load in data

In [None]:
def load_embedding_data(filepath):
    """Read a pickled session dictionary from disk.

    Args:
        filepath : str path to the pickle file.

    Returns:
        single_session_dict : dict the unpickled object (expected to map
            session name -> per-session data such as {'model', 'embedding'}).
        session_names : list of the dict's keys, in insertion order.

    Note:
        ``pickle`` can execute arbitrary code while loading, so only open
        files you produced or otherwise trust.
    """
    with open(filepath, "rb") as handle:
        loaded = pkl.load(handle)
    names = list(loaded.keys())
    return loaded, names

In [None]:
# load in data from a pickle file; unpack after the handle is closed
with open('C:/Users/denmanlab/Desktop/Emily_rotation/data_to_run_at_home/LGNcolor_exchange.pkl', 'rb') as f:
    loaded = pkl.load(f)

datas = loaded['datas']            # tensors (time*trials, neurons)
recordings = loaded['recordings']  # list of names / mouse IDs

In [None]:
# option to load in npy data: the shared_relu array from the nick_CNN_data folder
shared_relu = np.load(Path(r"C:\Users\denmanlab\Desktop\Emily_rotation\nick_CNN_data\shared_relu.npy"))

## Train CEBRA model

In [None]:
# Train one CEBRA-Time model per recording - adapted from Juan's code
train_steps = 30000

single_session = dict()
# fall back to CPU so the cell still runs on machines without a CUDA GPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Hyperparameters shared by every per-recording model.
cebra_params = dict(
    model_architecture='offset10-model',
    batch_size=512,
    learning_rate=3e-4,
    # temperature_mode='auto',
    temperature=1,
    # min_temperature=1e-1,
    output_dimension=3,
    max_iterations=train_steps,
    num_hidden_units=128,
    distance='cosine',
    # 'time' = self-supervised CEBRA-Time (fit receives no labels here);
    # 'time_delta' would instead condition on behavior/stimulus labels
    conditional='time',
    device=device,
    verbose=True,
    time_offsets=10,
    optimizer='adam',
)

cebra_root = Path(r'G:\cebra')

for name, X in zip(recordings, datas):
    print(name)
    out_path = cebra_root / name
    out_path.mkdir(parents=True, exist_ok=True)

    # Build a fresh estimator per recording. Reusing a single instance made
    # every single_session[name]['model'] entry reference the same object,
    # so all stored "models" were really the last one fitted.
    cebra_time_model = CEBRA(**cebra_params)

    X_float = X.type(torch.FloatTensor)  # convert once, reuse for fit and transform
    cebra_time_model.fit(X_float)
    cebra_time = cebra_time_model.transform(X_float)

    cebra_time_model.save(str(out_path / f'{name}_lgn_time.pt'), backend='torch')

    single_session[name] = {
        'model': cebra_time_model,
        'embedding': cebra_time,
    }

with open('C:/Users/denmanlab/Desktop/Emily_rotation/CEBRA/test.pkl', 'wb') as f:
    pkl.dump(single_session, f)

d4


pos: -0.9582 neg:  6.3999 total:  5.4418 temperature:  1.0000: 100%|██████████| 30000/30000 [02:43<00:00, 183.66it/s]


In [None]:
# downsample embeddings - FPS or random
def downsample_embedding(embed, n_target=1000, method="fps", seed=42, *,
                         return_indices=False, preserve_order=False):
    """Downsample an embedding to ``n_target`` points using the specified method.

    Args:
        embed : array-like of shape (N, D); converted with ``np.asarray``.
        n_target : int target number of points after downsampling (>= 0).
        method : str downsampling method: "random" (uniform, without
            replacement) or "fps" (greedy farthest point sampling, Euclidean).
        seed : int seed for ``np.random.default_rng``; fixes both the random
            subset and the FPS starting point.
        return_indices : bool, keyword-only. If True, also return the row
            indices into ``embed`` that were kept, so labels or time stamps
            can be subset the same way.
        preserve_order : bool, keyword-only. If True, sort the kept indices so
            the output keeps the original row order. By default rows come out
            in selection order, as before.

    Returns:
        np.ndarray of shape (min(n_target, N), D), or a tuple
        ``(points, indices)`` when ``return_indices`` is True. When
        ``n_target >= N`` the input array itself is returned unchanged.

    Raises:
        ValueError: if ``method`` is unknown or ``n_target`` is negative.
    """
    # Validate up front so a bad method is rejected even when no
    # downsampling would be needed.
    if method not in ("random", "fps"):
        raise ValueError("method must be 'random' or 'fps'")
    if n_target < 0:
        raise ValueError("n_target must be non-negative")

    embed = np.asarray(embed)
    N = embed.shape[0]
    if n_target >= N:
        # nothing to downsample
        return (embed, np.arange(N)) if return_indices else embed

    rng = np.random.default_rng(seed)

    if n_target == 0:
        idx = np.empty(0, dtype=np.intp)
    elif method == "random":
        idx = rng.choice(N, size=n_target, replace=False)
    else:
        idx = _farthest_point_indices(embed, n_target, rng)

    if preserve_order:
        idx = np.sort(idx)
    points = embed[idx]
    return (points, idx) if return_indices else points


def _farthest_point_indices(embed, n_target, rng):
    """Return ``n_target`` distinct row indices of ``embed`` chosen by greedy FPS."""
    idxs = np.empty(n_target, dtype=np.intp)
    idxs[0] = rng.integers(embed.shape[0])
    # distance from every point to its nearest already-selected point
    dists = np.linalg.norm(embed - embed[idxs[0]], axis=1)
    # -inf marks selected points so duplicate rows in the data can never make
    # argmax return an index that was already chosen
    dists[idxs[0]] = -np.inf

    for k in range(1, n_target):
        next_idx = np.argmax(dists)
        idxs[k] = next_idx
        new_dists = np.linalg.norm(embed - embed[next_idx], axis=1)
        dists = np.minimum(dists, new_dists)
        dists[next_idx] = -np.inf

    return idxs

In [20]:
# train CEBRA for nicks CNN data

train_steps = 30000

single_session = dict()
# fall back to CPU so the cell still runs on machines without a CUDA GPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'


cebra_time_model = CEBRA(model_architecture = 'offset10-model',
                        batch_size         = 512,
                        learning_rate      = 3e-4,
                        #  temperature_mode   = "auto",
                        temperature        = 1,
                        #  min_temperature    = 1e-1,
                        output_dimension   = 32,
                        max_iterations     = train_steps,
                        num_hidden_units   = 128,
                        distance           = 'cosine',
                        # 'time' = self-supervised CEBRA-Time (fit receives no labels here);
                        # 'time_delta' would instead condition on behavior/stimulus labels
                        conditional        = 'time',
                        device             = device,
                        verbose            = True,
                        time_offsets       = 10,
                        optimizer          = 'adam',
                        )

# fit and embed the same array; assumes shared_relu is (samples, features) - TODO confirm
cebra_time_model.fit(shared_relu)
cebra_time = cebra_time_model.transform(shared_relu)



pos: -0.9973 neg:  6.2887 total:  5.2914 temperature:  1.0000: 100%|██████████| 30000/30000 [02:52<00:00, 173.44it/s]


In [21]:
# Store the shared_relu model and its embedding, then pickle the whole dict.
name = 'shared_relu'
single_session[name] = {
    'model': cebra_time_model,
    'embedding': cebra_time,
}

# Derive the "<D>d" filename tag from the embedding itself so the file name
# cannot drift out of sync with the model's output_dimension.
embed_dim = cebra_time.shape[1]
out_file = f'C:/Users/denmanlab/Desktop/Emily_rotation/nick_CNN_data/{embed_dim}d_shared_relu_CEBRA_embed.pkl'
with open(out_file, 'wb') as f:
    pkl.dump(single_session, f)

In [None]:
# plot CEBRA embedding - DON'T USE THIS FUNCTION, USE plot_3d_CEBRA BELOW
# NOTE(review): deprecated. The later `def plot_3d_CEBRA(filepath, ncols=4)`
# cell rebinds this name, so this version is only reachable if that cell is
# not run.
def plot_3d_CEBRA(ncols=4):
    """Plot the first three embedding dimensions of each session (deprecated).

    Reads the module-level globals ``names`` (iterable of session keys) and
    ``single_session`` (dict of ``{name: {'embedding': ...}}``) rather than
    taking them as arguments. ``names`` is not assigned in any of the visible
    notebook cells, so it must be defined elsewhere before calling this.

    Args:
        ncols : int intended number of grid columns. NOTE(review): unused,
            because the layout below is hard-coded to a 2 x 8 grid.
    """

    n=len(names)
    n_rows=int(np.ceil(n/ncols))  # NOTE(review): computed but never used

    fig = plt.figure(figsize=(20,4))

    # NOTE(review): reuses `n` (the session count above) as the loop index
    for n, name in enumerate(names):
        embedding = single_session[name]['embedding']

        # Use subplot2grid for better control over positioning
        # Top row: session n goes in column n of the fixed 2 x 8 grid, so more
        # than 8 sessions would fall outside the grid.
        # Points are colored by normalized sample index (0 = first row, 1 = last).
        ax = plt.subplot2grid((2, 8), (0, n), projection='3d')
        ax.scatter(embedding[:,0], embedding[:,1], embedding[:,2], s=5, alpha=0.3,
                   cmap='hsv', c=np.linspace(0,1,len(embedding)), clim=(0,1), rasterized=True)
        ax.set_title(name, y=0.8)
        
        # Bottom row: NOTE(review): repeats the identical scatter as the top
        # row; presumably a different view angle was intended. Confirm.
        ax = plt.subplot2grid((2, 8), (1, n), projection='3d')
        ax.scatter(embedding[:,0],
                embedding[:,1],
                embedding[:,2],
                s=5, alpha=0.3,
                cmap='hsv', c=np.linspace(0,1,len(embedding)), clim=(0,1), rasterized=True
                )

    # Hide axes, gridlines and the shaded 3D background panes on every panel.
    for axs in fig.get_axes():
        axs.axis('off')
        axs.grid(False)
        axs.xaxis.pane.fill = False
        axs.yaxis.pane.fill = False
        axs.zaxis.pane.fill = False
        axs.xaxis.pane.set_edgecolor('w')
        axs.yaxis.pane.set_edgecolor('w')
        axs.zaxis.pane.set_edgecolor('w')

    # create a ScalarMappable with a proper Normalize instance
    # (0-1 range matches the scatter's clim, so one shared colorbar fits all panels)
    from matplotlib import colors as mcolors
    sm = plt.cm.ScalarMappable(norm=mcolors.Normalize(vmin=0, vmax=1), cmap='hsv')
    sm.set_array([])
    cbar = fig.colorbar(sm, ax=fig.get_axes(), orientation='vertical', fraction=0.02, pad=0.1)
    cbar.set_label('Trial', rotation=270, labelpad=15)


    plt.subplots_adjust(wspace=0,
                        hspace=0)
    plt.suptitle('LGN Time')
    plt.tight_layout()
    # plt.savefig(r'G:\cebra_v1_time.pdf',transparent=True)
    plt.show()

In [1]:
def plot_3d_CEBRA(filepath, ncols=4, title='LGN Time'):
    """Plot the first three dimensions of every CEBRA embedding in a pickle file.

    Each session gets two stacked 3D scatter panels, with points colored by
    their normalized sample index (hsv, 0 = first row, 1 = last row).
    Sessions are laid out ``ncols`` per row, and extra sessions wrap onto
    further row pairs, so any number of sessions fits.

    Args:
        filepath : str path to a pickle of ``{session_name: {'embedding': array, ...}}``,
            loaded via ``load_embedding_data`` (must be defined before calling).
        ncols : int maximum number of session columns per row of the grid.
        title : str figure suptitle (defaults to the previous hard-coded 'LGN Time').

    Raises:
        ValueError: if ``ncols`` < 1 or the file contains no sessions.
    """
    from matplotlib import colors as mcolors

    if ncols < 1:
        raise ValueError("ncols must be >= 1")

    single_session, session_names = load_embedding_data(filepath)
    n_sessions = len(session_names)
    if n_sessions == 0:
        raise ValueError(f"no sessions found in {filepath}")

    # Honor ncols: two panel rows per row of sessions, wrapping as needed.
    n_cols = min(ncols, n_sessions)
    n_rows = int(np.ceil(n_sessions / n_cols))
    grid = (2 * n_rows, n_cols)

    fig = plt.figure(figsize=(5 * n_cols, 4 * n_rows))

    for i, name in enumerate(session_names):
        embedding = single_session[name]['embedding']
        colors = np.linspace(0, 1, len(embedding))
        row, col = 2 * (i // n_cols), i % n_cols

        ax = plt.subplot2grid(grid, (row, col), projection='3d')
        ax.scatter(embedding[:, 0], embedding[:, 1], embedding[:, 2], s=5, alpha=0.3,
                   cmap='hsv', c=colors, clim=(0, 1), rasterized=True)
        ax.set_title(name, y=0.8)

        # NOTE(review): the second panel repeats the identical scatter;
        # presumably a different view angle (ax.view_init) was intended. Confirm.
        ax = plt.subplot2grid(grid, (row + 1, col), projection='3d')
        ax.scatter(embedding[:, 0], embedding[:, 1], embedding[:, 2], s=5, alpha=0.3,
                   cmap='hsv', c=colors, clim=(0, 1), rasterized=True)

    # Hide axes, gridlines and the shaded 3D background panes on every panel.
    for axs in fig.get_axes():
        axs.axis('off')
        axs.grid(False)
        axs.xaxis.pane.fill = False
        axs.yaxis.pane.fill = False
        axs.zaxis.pane.fill = False
        axs.xaxis.pane.set_edgecolor('w')
        axs.yaxis.pane.set_edgecolor('w')
        axs.zaxis.pane.set_edgecolor('w')

    # Explicit 0-1 norm matches the scatter clim, so the shared colorbar is exact.
    sm = plt.cm.ScalarMappable(norm=mcolors.Normalize(vmin=0, vmax=1), cmap='hsv')
    sm.set_array([])
    cbar = fig.colorbar(sm, ax=fig.get_axes(), orientation='vertical', fraction=0.02, pad=0.1)
    cbar.set_label('Normalized Time\nAcross Trials', rotation=270, labelpad=10)

    plt.subplots_adjust(wspace=0,
                        hspace=0)
    plt.suptitle(title)
    plt.tight_layout()
    # plt.savefig(r'G:\cebra_v1_time.pdf',transparent=True)
    plt.show()

In [2]:
# Example usage - plotting 3 dims of CEBRA embeddings
# Run the load_embedding_data cell first: plot_3d_CEBRA calls it.
# Single backslashes: inside a raw string, "\\" would stay two characters.
file_path = r"C:\Users\denmanlab\Desktop\Emily_rotation\CEBRA\8d\raw\LGN_8d_CEBRA_unsup_time_chromatic_gratings.pkl"
plot_3d_CEBRA(file_path)

NameError: name 'load_embedding_data' is not defined