In [1]:
import os
import sys
from collections import defaultdict
from pathlib import Path
import numpy as np
from scipy.io import loadmat
from quantities import ms

from elephant.statistics import instantaneous_rate
import neo
from elephant.kernels import GaussianKernel
from MARBLE import utils
import mat73
import pickle
import MARBLE

# Make the folder that contains macaque_reaching_helpers.py importable.
# NOTE(review): this is a machine-specific absolute path; it has to be adapted before the
# notebook can run anywhere else.
target_folder = os.path.abspath('/media/ubuntu/sda/MARBLE/MARBLE-main/examples/macaque_reaching')

# Append only once so re-running the cell does not add duplicate sys.path entries.
if target_folder not in sys.path:
    sys.path.append(target_folder)
# NOTE(review): the star import hides where names come from; later cells call fit_pca and
# format_data, which are presumably provided by this module -- confirm.
from macaque_reaching_helpers import *


  from .autonotebook import tqdm as notebook_tqdm


In [38]:
def spikes_to_rates(data, d, sampling_period=20):
    """
    Turn the spike rasters of one recording session into smoothed instantaneous rates.

    Args:
        data: nested per-session container loaded from the matlab file; ``data[d]`` holds
            one entry per condition, each a sequence of trials.
        d: index of the session to convert.
        sampling_period: spacing of the output rate samples, in milliseconds.

    Returns:
        dict mapping each condition name to an array laid out as
        (trials, channels, time samples). Trials whose entry is falsy (e.g. None)
        are left out.
    """
    # condition names follow the ordering used by the matlab script that compiled the data
    condition_names = ["DownLeft", "Left", "UpLeft", "Up", "UpRight", "Right", "DownRight"]

    session = data[d]

    # smoothing kernel; a larger width gives smoother signals (previously used auto)
    kernel = GaussianKernel(100 * ms)

    rates_by_condition = {}
    for cond_idx, cond_name in enumerate(condition_names):
        per_trial_rates = []

        for entry in session[cond_idx]:
            # some trial slots are empty (None); nothing to convert for those
            if not entry:
                continue

            # each existing trial is wrapped in a single-element list
            raster = entry[0]

            # one rate trace per channel: indices of non-zero raster samples are used as
            # spike times in ms within a 1200 ms trial
            channel_rates = [
                instantaneous_rate(
                    neo.SpikeTrain(np.where(raster[ch, :])[0], units="ms", t_stop=1200),
                    kernel=kernel,
                    sampling_period=sampling_period * ms,
                ).magnitude.flatten()
                for ch in range(raster.shape[0])
            ]

            # columns are channels here, i.e. (time, channels)
            per_trial_rates.append(np.stack(channel_rates, axis=1))

        # (time, channels, trials) -> (trials, channels, time)
        rates_by_condition[cond_name] = np.dstack(per_trial_rates).transpose(2, 1, 0)

    return rates_by_condition

In [None]:
def get_spikes(data, d):
    """
    Collect the raw spike rasters of one recording session, grouped by condition.

    Args:
        data: nested per-session container loaded from the matlab file; ``data[d]`` holds
            one entry per condition, each a sequence of trials. An existing trial is a
            single-element list wrapping its raster; a missing trial is falsy (e.g. None).
        d: index of the session to extract.

    Returns:
        dict mapping each condition name to the rasters of its existing trials stacked
        along a new leading axis, i.e. (trials, channels, time) for rasters shaped
        (channels, time).

    Raises:
        ValueError: from ``np.stack`` if a condition has no existing trial.
    """
    conditions = ["DownLeft", "Left", "UpLeft", "Up", "UpRight", "Right", "DownRight"]

    data_day = data[d]  # daily session

    # Build the result in a local dict: relying on a module-level `rates` would raise
    # NameError on a fresh kernel or silently overwrite unrelated notebook state.
    spikes = {}

    # loop over the 7 conditions
    for c, cond in enumerate(conditions):
        # unwrap every existing trial, skipping the empty slots
        trial_data = [trial[0] for trial in data_day[c] if trial]

        # np.stack keeps each raster's own axes and prepends the trial axis, giving
        # trial x channels x time
        spikes[cond] = np.stack(trial_data)

    return spikes

In [3]:
# Fetch the spiking data unless it is already on disk (wget -nc does not re-download)
spiking_data_url = "https://dataverse.harvard.edu/api/access/datafile/6963157"
data_file = "data/conditions_spiking_data.mat"
Path("data").mkdir(exist_ok=True)
download_cmd = f"wget -nc {spiking_data_url} -O {data_file}"
os.system(download_cmd)

# the file holds the data compiled into a matlab cell array under "all_results"
mat_contents = mat73.loadmat(data_file)
data = mat_contents["all_results"]

File ‘data/conditions_spiking_data.mat’ already there; not retrieving.


In [4]:
# Pick the first session (index 0) for interactive inspection.
data_day = data[0]

In [7]:
# Number of trial slots stored for the first condition of this session; this counts
# every entry, including slots that the extraction loops skip as empty.
len(data_day[0])

34

In [25]:
# Condition names, in the order the conditions are stored within each session.
conditions = ["DownLeft", "Left", "UpLeft", "Up", "UpRight", "Right", "DownRight"]

# spike_train[session][condition] -> rasters of the existing trials stacked along a new
# leading trial axis. Sessions are keyed starting from 1.
spike_train = {}
for idx, data_day in enumerate(data, start=1):
    # An empty session gets no entry at all. The condition loop must be skipped too,
    # otherwise it would index a missing spike_train[idx] (or an empty session).
    if not data_day:
        continue

    spike_train[idx] = {}

    for c, cond in enumerate(conditions):
        # an existing trial is a single-element list wrapping its raster; empty slots
        # are falsy and are dropped
        trial_data = [trial[0] for trial in data_day[c] if trial]

        spike_train[idx][cond] = np.stack(trial_data)

In [None]:

# Stack, condition by condition, the spike rasters of the session currently bound to
# `data_day`. Nothing is stored per condition, so only the last condition's
# `trial_data` / `trial_data_array` remain bound once the loop finishes.
conditions = ["DownLeft", "Left", "UpLeft", "Up", "UpRight", "Right", "DownRight"]

for c, cond in enumerate(conditions):
    # unwrap each existing trial (a single-element list); empty slots are skipped
    trial_data = [entry[0] for entry in data_day[c] if entry]

    trial_data_array = np.stack(trial_data)

    

In [21]:
# Stack whatever list of rasters is currently bound to `trial_data` along a new first axis.
# NOTE(review): depends on notebook state left behind by a previously executed loop cell.
a = np.stack(trial_data)

In [22]:
# Inspect the stacked array's shape; the recorded output of this cell is (33, 24, 1200).
a.shape

(33, 24, 1200)

In [None]:
# Fetch the spiking data unless it is already on disk (wget -nc does not re-download)
spiking_data_url = "https://dataverse.harvard.edu/api/access/datafile/6963157"
data_file = "data/conditions_spiking_data.mat"
Path("data").mkdir(exist_ok=True)
os.system(f"wget -nc {spiking_data_url} -O {data_file}")

# load data compiled into matlab cell array
data = mat73.loadmat(data_file)["all_results"]

# Convert the sessions to rates; one session index is handed out per entry of `data`
# (presumably spikes_to_rates(data, d) is called for each index -- confirm against
# utils.parallel_proc).
rates = utils.parallel_proc(
    spikes_to_rates,
    range(len(data)),
    data,
    processes=-1,
    desc="Converting spikes to rates...",
)

# key each session's result by its position in `rates`
all_rates = dict(enumerate(rates))

with open("data/rate_data_20ms.pkl", "wb") as handle:
    pickle.dump(all_rates, handle, protocol=pickle.HIGHEST_PROTOCOL)

File ‘data/conditions_spiking_data.mat’ already there; not retrieving.
Converting spikes to rates...:  52%|█████▏    | 23/44 [01:35<01:11,  3.39s/it]

In [32]:
# Download the trial-id metadata next to the other data files.
data_file = "data/trial_ids.pkl"
os.system(f"wget -nc  https://dataverse.harvard.edu/api/access/datafile/6963200  -O {data_file}")

# NOTE(security): unpickling can execute arbitrary code; only load files from a trusted source.
# The context manager closes the file handle, and reusing `data_file` keeps the download
# target and the loaded path from drifting apart.
with open(data_file, "rb") as handle:
    trial_ids = pickle.load(handle)

--2025-08-26 20:51:04--  https://dataverse.harvard.edu/api/access/datafile/6963200
Resolving dataverse.harvard.edu (dataverse.harvard.edu)... 3.237.56.51, 52.205.237.141, 52.6.5.183
Connecting to dataverse.harvard.edu (dataverse.harvard.edu)|3.237.56.51|:443... connected.
HTTP request sent, awaiting response... 303 See Other
Location: https://dvn-cloud.s3.amazonaws.com/10.7910/DVN/KTE4PC/186a3225852-fd9e8941cbf5?response-content-disposition=attachment%3B%20filename%2A%3DUTF-8%27%27trial_ids.pkl&response-content-type=application%2Foctet-stream&X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Date=20250826T125106Z&X-Amz-SignedHeaders=host&X-Amz-Credential=AKIAIEJ3NV7UYCSRJC7A%2F20250826%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Expires=3600&X-Amz-Signature=2f83d58eb3c52cc42888d975c67f5729a0919899c122ccadacb4c364b80c1b34 [following]
--2025-08-26 20:51:06--  https://dvn-cloud.s3.amazonaws.com/10.7910/DVN/KTE4PC/186a3225852-fd9e8941cbf5?response-content-disposition=attachment%3B%20filename%2A%3DUTF-8%27%

In [10]:
data_file = "data/rate_data_20ms.pkl"
metadata_file = "data/trial_ids.pkl"
output_file = "data/marble_embeddings_20ms_out3.pkl"

# NOTE(security): unpickling can execute arbitrary code; only load files from a trusted source.
# Context managers make sure both file handles are closed again.
with open(data_file, "rb") as handle:
    rates = pickle.load(handle)
with open(metadata_file, "rb") as handle:
    trial_ids = pickle.load(handle)

# defining the set of conditions
conditions = ["DownLeft", "Left", "UpLeft", "Up", "UpRight", "Right", "DownRight"]

# list of days
days = rates.keys()

# define some parameters
pca_n = 5
filter_data = True

# per-session results, one entry appended per processed day
embeddings = []
distance_matrices = []
times = []  # to store the time point of each node in the trajectory
all_condition_labels = []  # to store the condition label for each node
all_trial_ids = []  # trial ids for each node
all_sampled_ids = []  # to store all the nodes sampled by marble


def _save_results():
    """Pickle everything accumulated so far to `output_file`, overwriting any earlier dump."""
    with open(output_file, "wb") as handle:
        pickle.dump(
            [
                distance_matrices,
                embeddings,
                times,
                all_condition_labels,
                all_trial_ids,
                all_sampled_ids,
            ],
            handle,
            protocol=pickle.HIGHEST_PROTOCOL,
        )


# loop over each day
for day in days:

    # first stack all trials from that day together and fit pca
    print(day)
    pca = fit_pca(rates, day, conditions, filter_data=filter_data, pca_n=pca_n)
    pos, vel, timepoints, condition_labels, trial_indexes = format_data(
        rates,
        trial_ids,
        day,
        conditions,
        pca=pca,
        filter_data=filter_data,
    )

    # construct data for marble
    # NOTE(review): `data` is rebound here, so the raw spiking data loaded under the same
    # name in earlier cells is no longer reachable after this cell has run.
    data = MARBLE.construct_dataset(
        anchor=pos,
        vector=vel,
        k=30,
        spacing=0.0,
        delta=1.5,
    )

    # built inside the loop so every session's model receives a fresh dict
    params = {
        "epochs": 50,  # optimisation epochs
        "order": 2,  # order of derivatives
        "hidden_channels": 100,  # number of internal dimensions in MLP
        "out_channels": 3,  # embedding dimension (matches the "_out3" output file name)
        "inner_product_features": False,
        "diffusion": True,
        "batch_size": 1024
    }

    model = MARBLE.net(data, params=params)

    model.fit(data, outdir="data/session_{}_20ms".format(day))
    data = model.transform(data)

    n_clusters = 50
    data = MARBLE.distribution_distances(data, n_clusters=n_clusters)

    embeddings.append(data.emb)
    distance_matrices.append(data.dist)
    times.append(np.hstack(timepoints))
    all_condition_labels.append(data.y)
    all_trial_ids.append(np.hstack(trial_indexes))
    all_sampled_ids.append(data.sample_ind)

    # save over after each session (in case computations crash)
    _save_results()

# final save
_save_results()

0

---- Embedding dimension: 5
---- Signal dimension: 5
---- Computing kernels ... 
---- Computing full spectrum ...
              (if this takes too long, then run construct_dataset()
              with number_of_eigenvectors specified) 
---- Settings: 

epochs : 50
order : 2
hidden_channels : [100]
out_channels : 3
inner_product_features : False
diffusion : True
batch_size : 1024
lr : 0.01
momentum : 0.9
dropout : 0.0
batch_norm : batch_norm
bias : True
frac_sampled_nb : -1
include_positions : False
include_self : True
vec_norm : False
emb_norm : False
seed : 0
dim_signal : 5
dim_emb : 5
n_sampled_nb : -1

---- Number of features to pass to the MLP:  155
---- Total number of parameters:  16104

Using device cuda:0

---- Training network ...

---- Timestamp: 20250829-205334

Epoch: 0, Training loss: 1.381683, Validation loss: 1.3653, lr: 0.0100 *
Epoch: 1, Training loss: 1.373892, Validation loss: 1.3768, lr: 0.0100
Epoch: 2, Training loss: 1.368482, Validation loss: 1.3725, lr: 0.010

In [3]:
import seaborn as sns
import pickle

In [4]:
# Reload previously saved results from disk (presumably the list pickled by the training
# cell -- confirm, since this absolute path differs from the relative one written there).
embeddings_file = "/media/ubuntu/sda/MARBLE/data/marble_embeddings_20ms_out3.pkl"
with open(embeddings_file, "rb") as handle:
    a = pickle.load(handle)

  a = pickle.load(handle)


In [5]:
# Unpack the loaded results by position. All six entries are taken from `a`, including
# the sampled node ids at index 5.
(
    distance_matrices,
    embeddings,
    times,
    all_condition_labels,
    all_trial_ids,
    all_sampled_ids,
) = a[:6]

In [18]:
import plotly.graph_objects as go
import numpy as np
import matplotlib.pyplot as plt

fig = go.Figure()

# Session to visualise (position in the loaded per-session lists); named so that it does
# not shadow the builtin `id`.
session_idx = 9
session_trial_ids = np.asarray(all_trial_ids[session_idx])
session_embedding = embeddings[session_idx]
session_labels = all_condition_labels[session_idx]

# colour palette, looked up once instead of once per trial
palette = plt.get_cmap('tab10')

# Iterate over the trial ids that actually occur in this session: this includes the
# largest id and never produces an empty selection for an id that is absent.
for trial_id in np.unique(session_trial_ids):
    index = np.where(session_trial_ids == trial_id)[0]

    # All nodes of a trial are coloured by the condition label of the trial's first node.
    # Converting to a plain int makes the colormap treat it as a palette index; a float
    # would be read as a fraction of the colormap instead.
    color_index = int(session_labels[index[0]])

    # get the colour from the tab10 colormap
    color = palette(color_index)
    color_rgba = f'rgba({int(color[0]*255)}, {int(color[1]*255)}, {int(color[2]*255)}, {color[3]})'

    # add the 3D scatter trace for this trial
    fig.add_trace(go.Scatter3d(
        x=session_embedding[index, 0],
        y=session_embedding[index, 1],  # second embedding dimension on the y axis
        z=session_embedding[index, 2],  # third embedding dimension on the z axis
        # markers only; the `line` styling below takes effect only if the mode is
        # changed to include 'lines'
        mode='markers',
        marker=dict(
            size=4,
            color=color_rgba,
        ),
        line=dict(
            color=color_rgba,
            width=2
        ),
        name=f'Trial {int(trial_id)}'
    ))

# update the layout
fig.update_layout(
    title='3D Embedding Visualization',
    scene=dict(
        xaxis_title='Dimension 1',
        yaxis_title='Dimension 2',
        zaxis_title='Dimension 3'
    ),
    width=800,
    height=600
)

# show the figure
fig.show()