<a href="https://colab.research.google.com/github/dudyu/neural_segmenation_tutorial/blob/main/Neural_Segmenation_of_Motor_Output.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Neural Segmentation of Motor Output

## 1. Overview
A prominent idea in motor neuroscience is that of motor primitives: complex movements are composed of a set of simple, atomic elements. The kinematic (or dynamic) nature of these hypothesized elements is, however, unknown. Nonetheless, if such elements originate in the CNS, we expect them to correspond to a temporal structure in the neural dynamics of the motor areas. Therefore, one possible approach to decomposing movement into its primitives is to segment the motor-neural activity in an unsupervised manner, and then analyse how the resulting segments correspond to different movement features.
Here we demonstrate this approach, roughly following the methodology presented in [https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7302741/pdf/bhy060.pdf], using a dataset from the Neural Latents Benchmark [https://neurallatents.github.io/datasets#mcrtt].

## 2. Data
We will use the Random Target Task (RTT) dataset, provided by Joseph O'Doherty and Philip Sabes (UCSF). The data consist of Utah array recordings from M1, together with end-effector (finger) position and velocity, from a macaque performing reaches between random targets on a plane (see https://zenodo.org/record/3854034).

## 3. Method Outline
The first step in our analysis is to obtain an unsupervised segmentation based on the neural dynamics. To this end, we fit a Hidden Markov Model (HMM) to the smoothed spike-count data. Because of limitations of the Python package used here (`hmmlearn`), we use a Gaussian emission model, although a Poisson (or discrete, tabular) emission model is generally better suited to this kind of data. Once the HMM is fitted, we use it to decode the neural data, i.e. to recover the most probable sequence of hidden states that gave rise to it. A *segment* is then defined as a maximal run of consecutive time bins assigned to the same state, so segments start and end wherever the decoded state changes.
In the second step, the segments are projected onto the kinematic data, to test how they correspond to different kinematic features.
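The decoding step can be illustrated with the Viterbi algorithm, which finds the single most probable state path given the start, transition, and emission probabilities. `hmmlearn`'s `model.predict` performs this decoding internally (by default with Viterbi); the sketch below is an independent numpy illustration, not `hmmlearn`'s code. The helper names `viterbi` and `segments_from_states` are hypothetical. It works in log space to avoid numerical underflow on long recordings, and then splits the decoded path into segments as defined above.

```python
import numpy as np


def viterbi(log_start, log_trans, log_emit):
    """Most probable hidden-state path (Viterbi algorithm, log space).

    log_start : (K,)   log initial-state probabilities
    log_trans : (K, K) log transition probabilities, row = from-state
    log_emit  : (T, K) log-likelihood of each observation under each state
    """
    T, K = log_emit.shape
    score = log_start + log_emit[0]            # best log-prob of a path ending in each state at t=0
    backptr = np.zeros((T, K), dtype=int)      # best predecessor of each state at each t
    for t in range(1, T):
        cand = score[:, None] + log_trans      # cand[i, j]: best path ending in i, then i -> j
        backptr[t] = np.argmax(cand, axis=0)
        score = cand[backptr[t], np.arange(K)] + log_emit[t]
    # Trace the best path backwards from the best final state
    path = np.empty(T, dtype=int)
    path[-1] = np.argmax(score)
    for t in range(T - 1, 0, -1):
        path[t - 1] = backptr[t, path[t]]
    return path


def segments_from_states(states):
    """Split a state sequence into (start, stop, state) runs; stop is exclusive."""
    states = np.asarray(states)
    change = 1 + np.flatnonzero(states[1:] != states[:-1])
    bounds = np.concatenate([[0], change, [len(states)]])
    return [(int(a), int(b), int(states[a])) for a, b in zip(bounds[:-1], bounds[1:])]


# Toy example: two "sticky" states; observations favour state 0, then state 1
log_start = np.log([0.5, 0.5])
log_trans = np.log([[0.8, 0.2], [0.2, 0.8]])
p_emit = np.array([[0.9, 0.1]] * 3 + [[0.1, 0.9]] * 3)
path = viterbi(log_start, log_trans, np.log(p_emit))
print(path.tolist())                # [0, 0, 0, 1, 1, 1]
print(segments_from_states(path))   # [(0, 3, 0), (3, 6, 1)]
```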
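Besides the plots and histograms used below, a simple numeric way to project states onto a kinematic feature is to summarise the feature within each state. The following is a minimal sketch; `statewise_summary` is a hypothetical helper, and the toy "speed" values are made up for illustration.

```python
import numpy as np


def statewise_summary(feature, states, num_states):
    """Mean and standard deviation of a scalar kinematic feature within each HMM state.

    States that never occur in `states` get NaN.
    """
    feature = np.asarray(feature, dtype=float)
    states = np.asarray(states)
    means = np.full(num_states, np.nan)
    stds = np.full(num_states, np.nan)
    for k in range(num_states):
        vals = feature[states == k]
        if vals.size:
            means[k] = vals.mean()
            stds[k] = vals.std()
    return means, stds


# Toy usage: "speed" is low in state 0, high in state 1, and state 2 is never visited
speed = np.array([1.0, 2.0, 9.0, 11.0, 10.0, 3.0])
states = np.array([0, 0, 1, 1, 1, 0])
means, stds = statewise_summary(speed, states, num_states=3)
print(means.tolist())  # [2.0, 10.0, nan]
```

A large spread of state means relative to the within-state standard deviations suggests that the neural segments carry information about that feature.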


In [None]:
## Packages and Data

!pip install git+https://github.com/neurallatents/nlb_tools.git
!pip install hmmlearn
!pip install dandi
!dandi download https://gui.dandiarchive.org/#/dandiset/000129


In [47]:
## Imports

from nlb_tools.nwb_interface import NWBDataset
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from hmmlearn import hmm
from scipy.ndimage import gaussian_filter1d

plt.rcParams["figure.figsize"] = (14, 7)
plt.rcParams['xtick.labelsize'] = 6
plt.rcParams['ytick.labelsize'] = 6

In [40]:
## Params

BIN_SIZE_MS = 10            # data bin size
SPIKE_SMOOTH_MS = 50        # smoothing window for spike counts (Gaussian sigma = half of this)
LAG_MS = 100                # lag between neural activity and motor output
NUM_HMM_STATES = 5          # number of states for HMM
MAX_DATA_DURATION_SEC = 60  # take subset of the data, for quicker analysis

## Helper functions

cmap = plt.get_cmap('tab20')  # cm.get_cmap is deprecated/removed in recent matplotlib versions

def project_states_on_kinematics(x, states):
    """ Plot a kinematic feature, coloring each segment by its HMM state
      x - np array: either 2D data (e.g. x-y position), or a 1D sequence of scalars
          (plotted against the time-bin index)
    """
    assert x.ndim in (1, 2)
    is_sequence = x.ndim == 1
    if is_sequence:
        x = np.stack([np.arange(len(x)), x], axis=1)
    # segment boundaries: first bin, every bin where the state changes, end of data
    transitions = [0] + list(1 + np.nonzero(states[:-1] != states[1:])[0]) + [len(states)]
    plt.figure()
    for i in range(len(transitions) - 1):
        ix_from, ix_to = transitions[i], transitions[i + 1]
        color = cmap(states[ix_from] / NUM_HMM_STATES)
        # include the next bin so consecutive segments join up and single-bin segments are drawn
        plt.plot(x[ix_from: ix_to + 1, 0], x[ix_from: ix_to + 1, 1], color=color)
    if not is_sequence:
        plt.axis('equal')

def statewise_angular_hist(x, states, hist_bins=10):
    """ Histogram per state, for angular data (radians, in [0, 2*pi)) """
    nrows = int(np.sqrt(NUM_HMM_STATES))
    ncols = int(np.ceil(NUM_HMM_STATES / nrows))
    # hist_bins + 1 edges -> hist_bins equal sectors covering the full circle
    theta = np.linspace(0.0, 2 * np.pi, hist_bins + 1)
    width = (2 * np.pi) / hist_bins
    _, axs = plt.subplots(nrows, ncols, subplot_kw=dict(projection='polar'), squeeze=False)
    for ax in axs.flat[NUM_HMM_STATES:]:
        ax.set_visible(False)  # hide unused panels
    for state in range(NUM_HMM_STATES):
        radii, _ = np.histogram(x[states == state], bins=theta)
        plt.sca(axs.flat[state])
        color = cmap(state / NUM_HMM_STATES)
        # align='edge': each bar starts at the lower edge of its sector
        plt.bar(theta[:-1], radii, bottom=0.0, width=width, color=color, align='edge')
        plt.gca().set_rticks([])
        plt.title(f'State {state}')


def statewise_hist(x, states, hist_bins=10):
    """ Histogram per state (shared bin edges, so panels are comparable) """
    nrows = int(np.sqrt(NUM_HMM_STATES))
    ncols = int(np.ceil(NUM_HMM_STATES / nrows))
    _, axs = plt.subplots(nrows, ncols, squeeze=False)
    for ax in axs.flat[NUM_HMM_STATES:]:
        ax.set_visible(False)  # hide unused panels
    edges = np.linspace(x.min(), x.max(), hist_bins + 1)
    for state in range(NUM_HMM_STATES):
        plt.sca(axs.flat[state])
        color = cmap(state / NUM_HMM_STATES)
        plt.hist(x[states == state], edges, color=color)
        plt.title(f'State {state}')
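For the angular histograms, the bin edges must span the whole circle so that no angle is dropped. A quick numpy check, assuming toy angles wrapped to [0, 2*pi) the same way the notebook wraps `direction` and `pos_theta`: with `hist_bins + 1` edges from `np.linspace`, every sample is counted (note that `np.histogram` includes the right edge in its last bin, so an angle that rounds to exactly 2*pi is still counted). When drawing the counts with matplotlib's `bar`, bars are centred on the given positions by default, so placing them at the lower edges needs `align='edge'` (or passing the bin centres instead).

```python
import numpy as np

hist_bins = 10
# hist_bins + 1 edges -> hist_bins equal sectors spanning [0, 2*pi] exactly
edges = np.linspace(0.0, 2 * np.pi, hist_bins + 1)

# Toy angles, wrapped to [0, 2*pi) as in the notebook
rng = np.random.default_rng(0)
angles = np.mod(rng.normal(scale=3.0, size=1000), 2 * np.pi)

counts, _ = np.histogram(angles, bins=edges)
print(len(counts), counts.sum())  # 10 1000: every angle lands in one of the 10 sectors
```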


In [12]:
## Load and preprocess

# Load + resample to the desired bin size
dataset = NWBDataset("./000129/sub-Indy", "*train", split_heldout=False)
dataset.resample(BIN_SIZE_MS)

# Take a subset of the time bins (for performance)
data = dataset.data.iloc[:int(MAX_DATA_DURATION_SEC * 1000 / dataset.bin_width)]

# Binned spike counts at time t, paired with kinematics at time t + lag
lag_bins = int(LAG_MS / dataset.bin_width)
n = len(data) - lag_bins  # explicit length (rather than [:-lag_bins]) also works when lag_bins == 0
rates = data.spikes.to_numpy()[:n]
pos = data.finger_pos.to_numpy()[lag_bins:lag_bins + n]
vel = data.finger_vel.to_numpy()[lag_bins:lag_bins + n]

# Drop time bins with missing values in any of the arrays
valid_ixs = ~(np.isnan(rates).any(axis=1) | np.isnan(pos).any(axis=1) | np.isnan(vel).any(axis=1))
rates = rates[valid_ixs]
pos = pos[valid_ixs]
vel = vel[valid_ixs]
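The lag alignment pairs neural activity in bin t with kinematics in bin t + lag, reflecting the assumption that motor cortex activity precedes the movement it drives. A toy illustration with hypothetical stand-in arrays shows the pairing, and also a slicing pitfall: `x[:-lag_bins]` returns an empty array when `lag_bins` is 0 (e.g. if `LAG_MS` were smaller than `BIN_SIZE_MS`), which is why an explicit length is the safer way to slice.

```python
import numpy as np

lag_bins = 2
neural = np.arange(10)          # stand-in for neural activity in bins 0..9
kin = 100 * np.arange(10)       # stand-in for kinematics in bins 0..9

n = len(neural) - lag_bins      # number of aligned (neural, kinematic) pairs
x = neural[:n]                  # neural activity at bin t
y = kin[lag_bins:lag_bins + n]  # kinematics at bin t + lag_bins
print(x[:3].tolist(), y[:3].tolist())  # [0, 1, 2] [200, 300, 400]

# Pitfall: with a zero lag, a negative-index slice is empty rather than "everything"
print(len(neural[:-0]))  # 0
```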

In [13]:
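# compute kinematic features
speed = np.linalg.norm(vel, axis=1)
direction = np.mod(np.arctan2(vel[:, 1], vel[:, 0]), 2 * np.pi)  # velocity direction, radians in [0, 2*pi)

# position relative to the mean position, in polar coordinates
pos -= pos.mean(axis=0)
pos_r = np.linalg.norm(pos, axis=1)
pos_theta = np.mod(np.arctan2(pos[:, 1], pos[:, 0]), 2 * np.pi)

# signed acceleration along the path: change in speed per ms
signed_acc = np.concatenate([[0], np.diff(speed)]) / dataset.bin_width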
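To make the feature definitions concrete, here is the same computation on a few made-up velocity samples (a toy input, not data from the recording; `bin_width = 10` mirrors `BIN_SIZE_MS`). Direction is wrapped from `arctan2`'s (-pi, pi] range into [0, 2*pi), and the signed acceleration is the change in *speed* per ms, i.e. acceleration along the direction of motion rather than the full 2D acceleration vector.

```python
import numpy as np

bin_width = 10  # ms, as with BIN_SIZE_MS in the notebook
vel = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0], [3.0, 4.0]])  # toy samples

speed = np.linalg.norm(vel, axis=1)
direction = np.mod(np.arctan2(vel[:, 1], vel[:, 0]), 2 * np.pi)  # radians in [0, 2*pi)
signed_acc = np.concatenate([[0], np.diff(speed)]) / bin_width    # change in speed per ms

print(speed.tolist())                                    # [1.0, 1.0, 1.0, 1.0, 5.0]
print(np.round(np.degrees(direction[:4]), 6).tolist())   # [0.0, 90.0, 180.0, 270.0]
print(signed_acc.tolist())                               # [0.0, 0.0, 0.0, 0.0, 0.4]
```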

In [14]:
# smooth spikes with a Gaussian kernel (sigma = half the smoothing window, in bins)
smooth_sigma_bins = .5 * SPIKE_SMOOTH_MS / dataset.bin_width
rates = gaussian_filter1d(rates.astype('float32'), sigma=smooth_sigma_bins, axis=0)
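With the default parameters, the smoothing width converts to sigma = 0.5 * 50 / 10 = 2.5 bins (25 ms). The sketch below builds a comparable Gaussian kernel in plain numpy to show what the smoothing does to a single spike. It assumes truncation at 4 sigma (scipy's default `truncate=4.0`) and zero-pads at the edges, whereas `gaussian_filter1d` reflects at the boundaries by default (`mode='reflect'`), so values near the ends of a recording differ.

```python
import numpy as np

BIN_SIZE_MS, SPIKE_SMOOTH_MS = 10, 50
sigma_bins = 0.5 * SPIKE_SMOOTH_MS / BIN_SIZE_MS   # 2.5 bins, i.e. 25 ms
radius = int(4 * sigma_bins + 0.5)                 # truncate the kernel at ~4 sigma
t = np.arange(-radius, radius + 1)
kernel = np.exp(-0.5 * (t / sigma_bins) ** 2)
kernel /= kernel.sum()                             # unit area: total spike count is preserved

counts = np.zeros(41)
counts[20] = 1.0                                   # a single spike in bin 20
smoothed = np.convolve(counts, kernel, mode="same")
print(sigma_bins, len(kernel))                     # 2.5 21
```

The single spike is spread over roughly +/- 10 bins around bin 20 while its total mass stays 1, which is why the smoothed counts can be read as a firing-rate estimate.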

In [15]:
## HMM

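# Train HMM on the smoothed spike counts (fixed seed for reproducibility;
# more EM iterations than hmmlearn's default of 10)
model = hmm.GaussianHMM(n_components=NUM_HMM_STATES, covariance_type="full",
                        n_iter=100, random_state=0)
model.fit(rates)

# Decode: most probable (Viterbi) sequence of HMM states
states = model.predict(rates)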
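Before relating states to kinematics, it can help to check the temporal statistics of the decoded sequence: how often each state follows another, and how long segments of each state last in milliseconds. Below is a minimal sketch with hypothetical helper names; the empirical matrix can be compared with the fitted `model.transmat_`. One caveat: because bins with missing values are removed before fitting, a few decoded "transitions" may span such gaps.

```python
import numpy as np


def empirical_transitions(states, num_states):
    """Row-normalised counts of state(t) -> state(t+1) in a decoded sequence."""
    states = np.asarray(states)
    counts = np.zeros((num_states, num_states))
    np.add.at(counts, (states[:-1], states[1:]), 1)
    rows = counts.sum(axis=1, keepdims=True)
    # rows of never-left states stay all-zero instead of dividing by zero
    return np.divide(counts, rows, out=np.zeros_like(counts), where=rows > 0)


def mean_dwell_ms(states, num_states, bin_size_ms):
    """Average segment duration (ms) for each state; NaN for unvisited states."""
    states = np.asarray(states)
    change = 1 + np.flatnonzero(states[1:] != states[:-1])
    bounds = np.concatenate([[0], change, [len(states)]])
    seg_state = states[bounds[:-1]]   # state of each segment
    seg_len = np.diff(bounds)         # length of each segment, in bins
    out = np.full(num_states, np.nan)
    for k in range(num_states):
        if np.any(seg_state == k):
            out[k] = seg_len[seg_state == k].mean() * bin_size_ms
    return out


toy_states = np.array([0, 0, 0, 1, 1, 0, 0, 2])
print(empirical_transitions(toy_states, 3).tolist())
# [[0.6, 0.2, 0.2], [0.5, 0.5, 0.0], [0.0, 0.0, 0.0]]
print(mean_dwell_ms(toy_states, 3, bin_size_ms=10).tolist())  # [25.0, 20.0, 10.0]
```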

In [None]:
## Display results

project_states_on_kinematics(pos, states)
plt.suptitle('Trajectory')

project_states_on_kinematics(speed, states)
plt.suptitle('Speed')

statewise_hist(signed_acc, states)
plt.suptitle('Acceleration')

statewise_hist(speed, states)
plt.suptitle('Speed')

statewise_hist(pos_r, states)
plt.suptitle('Radial Position')

statewise_angular_hist(direction, states)
plt.suptitle('Velocity Direction')

statewise_angular_hist(pos_theta, states)
plt.suptitle('Angular Position')
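The angular histograms can be complemented by a per-state circular mean and resultant length R (between 0 and 1, with values near 1 meaning the angles in that state are tightly concentrated). An ordinary arithmetic mean is misleading for angles near the 0 / 2*pi wrap-around. A minimal sketch with a hypothetical helper name, which could be applied to `direction` or `pos_theta` together with `states`:

```python
import numpy as np


def statewise_circular_mean(angles, states, num_states):
    """Circular mean angle (radians, in [0, 2*pi)) and resultant length R per state.

    Unvisited states get NaN.
    """
    angles = np.asarray(angles, dtype=float)
    states = np.asarray(states)
    mean_angle = np.full(num_states, np.nan)
    resultant = np.full(num_states, np.nan)
    for k in range(num_states):
        a = angles[states == k]
        if a.size:
            z = np.exp(1j * a).mean()   # average of unit vectors on the circle
            mean_angle[k] = np.mod(np.angle(z), 2 * np.pi)
            resultant[k] = np.abs(z)
    return mean_angle, resultant


# Two angles on either side of 0 rad: 0.3 and 2*pi - 0.1
angles = np.array([0.3, 2 * np.pi - 0.1])
mean_angle, R = statewise_circular_mean(angles, np.array([0, 0]), num_states=2)
# mean_angle[0] is about 0.1 rad, whereas np.mean(angles) is about 3.24 rad
```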
