# Data: load the Area2 bump datasets

In [3]:
import cebra
from cebra.datasets import init as dataset_init

# Active session.
# Four views of the Area2 bump recording are loaded:
#   dataset        - base dataset (no session suffix in its name)
#   pos_dataset    - dataset labelled with hand position
#   target_dataset - dataset labelled with target direction
#   posdir_dataset - dataset labelled with both position and direction
dataset = dataset_init("area2-bump")
pos_dataset = dataset_init("area2-bump-pos-active")
target_dataset = dataset_init("area2-bump-target-active")
posdir_dataset = dataset_init("area2-bump-posdir-active")

# Basic size information about the base dataset.
print("Dataset input dimension:", dataset.input_dimension)
print("Dataset length:", len(dataset))

# Either label tensor may be missing (None); only report those that exist.
continuous_labels = dataset.continuous_index
if continuous_labels is not None:
    print("Continuous labels shape:", continuous_labels.shape)

discrete_labels = dataset.discrete_index
if discrete_labels is not None:
    print("Discrete labels shape:", discrete_labels.shape)

# Train / validation / test splits of the base dataset.
# NOTE(review): confirm that `split` returns a new dataset rather than
# mutating `dataset` in place; if it mutates, these calls are order-dependent.
train_dataset = dataset.split("train")
valid_dataset = dataset.split("valid")
test_dataset = dataset.split("test")

Dataset input dimension: 65
Dataset length: 115800
Continuous labels shape: torch.Size([115800, 2])
Discrete labels shape: torch.Size([115800])


In [None]:
# passive
# Passive-session counterparts of the labelled datasets from the active cell:
#   passive_pos_dataset    - labelled with hand position
#   passive_target_dataset - labelled with target direction
#   passive_posdir_dataset - labelled with both position and direction
#
# Fix: this cell previously requested the "-active" dataset names (a copy of
# the active cell), so it never loaded any passive data. The results are now
# bound to `passive_*` names. Running this cell therefore no longer overwrites
# the active-session `pos_dataset` / `target_dataset` / `posdir_dataset` that
# the training cells below read. To analyse the passive session downstream,
# point those cells at the `passive_*` datasets.
# The base `dataset_init("area2-bump")` call has no session suffix and was
# identical in both cells. Reloading and re-splitting it here only recreated
# the same objects, so it is omitted.
passive_pos_dataset = dataset_init("area2-bump-pos-passive")
passive_target_dataset = dataset_init("area2-bump-target-passive")
passive_posdir_dataset = dataset_init("area2-bump-posdir-passive")

# Summarise each passive dataset. Label tensors may be absent (None), so only
# the ones present are reported.
for passive_name, passive_ds in (("pos", passive_pos_dataset),
                                 ("target", passive_target_dataset),
                                 ("posdir", passive_posdir_dataset)):
    print(f"[passive {passive_name}] Dataset input dimension:",
          passive_ds.input_dimension)
    print(f"[passive {passive_name}] Dataset length:", len(passive_ds))
    passive_continuous = getattr(passive_ds, "continuous_index", None)
    if passive_continuous is not None:
        print(f"[passive {passive_name}] Continuous labels shape:",
              passive_continuous.shape)
    passive_discrete = getattr(passive_ds, "discrete_index", None)
    if passive_discrete is not None:
        print(f"[passive {passive_name}] Discrete labels shape:",
              passive_discrete.shape)

# Model definition

In [None]:
from cebra import CEBRA
import cebra.models

max_iterations = 15_000  # training-iteration budget shared by the three CEBRA models below

In [None]:
# CEBRA-behavior model. It is fitted further down on the continuous position
# labels (pos_dataset.continuous_index).
cebra_pos_model = CEBRA(
    # network / embedding
    model_architecture='offset10-model',
    output_dimension=3,
    time_offsets=10,
    # sampling / loss settings
    conditional='time_delta',
    distance='cosine',
    temperature=1,
    # optimisation
    batch_size=512,
    learning_rate=0.0001,
    max_iterations=max_iterations,
    # runtime
    device='cuda_if_available',
    verbose=True,
)

In [None]:
# CEBRA-behavior model. It is fitted further down on the discrete
# target-direction labels (target_dataset.discrete_index).
cebra_target_model = CEBRA(
    # network / embedding
    model_architecture='offset10-model',
    output_dimension=3,
    time_offsets=10,
    # sampling / loss settings
    conditional='time_delta',
    distance='cosine',
    temperature=1,
    # optimisation
    batch_size=512,
    learning_rate=0.0001,
    max_iterations=max_iterations,
    # runtime
    device='cuda_if_available',
    verbose=True,
)

In [None]:
# CEBRA-time model (conditional='time'). It is fitted further down on the
# neural data only, without labels.
# NOTE(review): time_offsets=5 here vs 10 for the two behavior models;
# confirm this difference is intentional.
cebra_time_model = CEBRA(
    # network / embedding
    model_architecture='offset10-model',
    output_dimension=3,
    time_offsets=5,
    # sampling / loss settings
    conditional='time',
    distance='cosine',
    temperature=1,
    # optimisation
    batch_size=512,
    learning_rate=0.0001,
    max_iterations=max_iterations,
    # runtime
    device='cuda_if_available',
    verbose=True,
)

# Model training

## Position (CEBRA-behavior, continuous hand-position label)

In [None]:
# Fit on the neural data with the continuous position labels as a NumPy array,
# then embed the same neural data.
pos_labels = pos_dataset.continuous_index.numpy()
cebra_pos_model.fit(pos_dataset.neural, pos_labels)
cebra_pos = cebra_pos_model.transform(pos_dataset.neural)

In [None]:
# Quick look at the position-trained embedding using cebra's default styling.
# The trailing ';' stops Jupyter from echoing the call's return value
# (presumably the Axes) underneath the figure.
cebra.plot_embedding(cebra_pos);

In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
# CEBRA-behavior (position label): the same 3-D embedding is drawn twice,
# coloured by position-label column 0 ('x', left) and column 1 ('y', right),
# on a shared diverging scale from -15 to 15 (colourbar titled "(cm)").
fig = plt.figure(figsize=(12, 5))
plt.suptitle('CEBRA-behavior trained with position label', fontsize=20)

for panel, (coord_name, coord_col) in enumerate((('x', 0), ('y', 1)), start=1):
    ax = fig.add_subplot(1, 2, panel, projection='3d')
    ax.set_title(coord_name, fontsize=20, y=0)
    points = ax.scatter(cebra_pos[:, 0],
                        cebra_pos[:, 1],
                        cebra_pos[:, 2],
                        c=pos_dataset.continuous_index[:, coord_col],
                        cmap='seismic',
                        s=0.05,
                        vmin=-15,
                        vmax=15)
    ax.axis('off')

# Both panels share vmin/vmax, so one colourbar next to the right-hand panel
# serves for both.
cbar = fig.colorbar(points, ax=ax, fraction=0.03, pad=0.05,
                    ticks=np.linspace(-15, 15, 7))
cbar.ax.tick_params(labelsize=15)
cbar.ax.set_title("(cm)", fontsize=10)
plt.show()

In [None]:
# Plot the loss recorded for the position model (cebra.plot_loss).
# The trailing ';' stops Jupyter from echoing the call's return value.
cebra.plot_loss(cebra_pos_model);

## Target (CEBRA-behavior, discrete target-direction label)

In [None]:
# Fit on the neural data with the discrete target-direction labels as a NumPy
# array, then embed the same neural data.
target_labels_np = target_dataset.discrete_index.numpy()
cebra_target_model.fit(target_dataset.neural, target_labels_np)
cebra_target = cebra_target_model.transform(target_dataset.neural)

In [None]:
# Quick look at the target-trained embedding using cebra's default styling.
# The trailing ';' stops Jupyter from echoing the call's return value.
cebra.plot_embedding(cebra_target);

In [None]:

# CEBRA-behavior embedding trained on the discrete target-direction label.
# Left: every sample, coloured by its direction label.
# Right: per-direction average over trials, scaled to unit length.
TRIAL_LENGTH = 600  # samples per trial assumed by the reshape below - TODO confirm against the dataset

target_labels = target_dataset.discrete_index.numpy()
directions = np.unique(target_labels)
n_directions = len(directions)
# Map each label to its index in `directions`, so both panels use one colour
# scale. hsv is cyclic (both ends of the map are red), so normalise to
# [0, n_directions] and use k / n_directions on the right.
# Fix: the left panel used to autoscale to [min, max] label. The last direction
# then landed on hsv(1.0), the same red as the first, and disagreed with the
# right panel's hsv(k / 8).
direction_idx = np.searchsorted(directions, target_labels)

fig = plt.figure(figsize=(4, 2), dpi=300)
plt.suptitle('CEBRA-behavior trained with target label',
             fontsize=5)
ax = plt.subplot(121, projection='3d')
ax.set_title('All trials embedding', fontsize=5, y=-0.1)
ax.scatter(cebra_target[:, 0],
           cebra_target[:, 1],
           cebra_target[:, 2],
           c=direction_idx,
           cmap=plt.cm.hsv,
           vmin=0,
           vmax=n_directions,
           s=0.01)
ax.axis('off')

ax = plt.subplot(122, projection='3d')
ax.set_title('direction-averaged embedding', fontsize=5, y=-0.1)
embedding_dim = cebra_target.shape[1]
for k in range(n_directions):
    direction_trial = direction_idx == k
    # Assumes the samples of one direction form whole trials of TRIAL_LENGTH
    # samples each, in order; reshape raises if the count is not a multiple.
    trial_avg = cebra_target[direction_trial].reshape(
        -1, TRIAL_LENGTH, embedding_dim).mean(axis=0)
    # Project each averaged point onto the unit sphere.
    trial_avg_normed = trial_avg / np.linalg.norm(trial_avg, axis=1, keepdims=True)
    ax.scatter(trial_avg_normed[:, 0],
               trial_avg_normed[:, 1],
               trial_avg_normed[:, 2],
               color=plt.cm.hsv(k / n_directions),
               s=0.01)
ax.axis('off')
plt.show()

In [None]:
# Plot the loss recorded for the target model (cebra.plot_loss).
# The trailing ';' stops Jupyter from echoing the call's return value.
cebra.plot_loss(cebra_target_model);

## Time (CEBRA-time, trained without labels)

In [None]:
# CEBRA-time: fit on the neural data alone (no labels), then embed it.
# NOTE(review): the plot further down colours this embedding with pos_dataset
# labels; that assumes target_dataset and pos_dataset wrap the same neural data.
time_neural = target_dataset.neural
cebra_time_model.fit(time_neural)
cebra_time = cebra_time_model.transform(time_neural)

In [None]:
# Quick look at the CEBRA-time embedding using cebra's default styling.
# The trailing ';' stops Jupyter from echoing the call's return value.
cebra.plot_embedding(cebra_time);

In [None]:
# CEBRA-time embedding (fitted without labels), coloured afterwards by
# position-label column 0 ('x', left) and column 1 ('y', right) of pos_dataset,
# on a shared scale from -15 to 15 (colourbar titled "(cm)").
# Guard: the embedding was computed from target_dataset.neural while the colours
# come from pos_dataset, so fail early with a clear message if the sample
# counts differ.
# NOTE(review): equal length does not prove identical sample order.
assert len(cebra_time) == len(pos_dataset.continuous_index), (
    "cebra_time and pos_dataset.continuous_index have different lengths")

fig = plt.figure(figsize=(4, 2), dpi=300)
plt.suptitle('CEBRA-time', fontsize=5)
ax = plt.subplot(121, projection='3d')
# Fix: both panel titles now use fontsize=5 (previously 4 for 'x', 5 for 'y').
ax.set_title('x', fontsize=5, y=-0.1)
x = ax.scatter(cebra_time[:, 0],
               cebra_time[:, 1],
               cebra_time[:, 2],
               c=pos_dataset.continuous_index[:, 0],
               cmap='seismic',
               s=0.05,
               vmin=-15,
               vmax=15)
ax.axis('off')
ax = plt.subplot(122, projection='3d')
y = ax.scatter(cebra_time[:, 0],
               cebra_time[:, 1],
               cebra_time[:, 2],
               c=pos_dataset.continuous_index[:, 1],
               cmap='seismic',
               s=0.05,
               vmin=-15,
               vmax=15)
ax.axis('off')
ax.set_title('y', fontsize=5, y=-0.1)
# Both panels share vmin/vmax, so one colourbar serves both.
yc = plt.colorbar(y, fraction=0.03, pad=0.05, ticks=np.linspace(-15, 15, 7))
yc.ax.tick_params(labelsize=3)
yc.ax.set_title("(cm)", fontsize=5)
plt.show()

In [None]:
# Plot the loss recorded for the CEBRA-time model (cebra.plot_loss).
# The trailing ';' stops Jupyter from echoing the call's return value.
cebra.plot_loss(cebra_time_model);

In [None]:
import cebra

# Compare the three trained models with cebra.compare_models (presumably their
# loss curves - see the cebra docs). A single name -> model mapping keeps each
# legend label paired with its model, so the two lists cannot drift apart.
models_to_compare = {
    "pos": cebra_pos_model,
    "target": cebra_target_model,
    "time": cebra_time_model,
}
# Legend labels for the plot (optional argument).
labels = list(models_to_compare)
# The trailing ';' stops Jupyter from echoing the call's return value.
cebra.compare_models(list(models_to_compare.values()), labels);