# Brain-Machine Interface Spike Sorting with PCA

### EE 16B: Designing Information Devices and Systems II, Adapted from Original SVD Lab 

## Table of Contents
* [Task 1: Two Neuron Spike Sorting](#Task-1:-Two-Neuron-Spike-Sorting)
* [Task 2: Three Neuron Spike Sorting](#Task-2:-Three-Neuron-Spike-Sorting)
* [Task 3: Determining Neurons](#Task-3:-Determining-Neurons)

## Task 1: Two Neuron Spike Sorting 

## Part a)

We load several recorded spike waveforms for each of three neurons, which we refer to as neurons 1, 2, and 3.

More specifically:
* Neuron 1 corresponds to the waveforms stored under the key `sig118a_wf`: an array of many spike waveforms recorded from neuron 1, one waveform per row, each with 32 samples.
* Neuron 2 corresponds to the waveforms stored under `sig118b_wf`, with the same layout.
* Neuron 3 corresponds to the waveforms stored under `sig118c_wf`, with the same layout.
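As a sketch of this layout, assuming each `sig118*_wf` entry is a 2-D NumPy array with one spike per row and 32 columns, the snippet below uses small synthetic arrays (made-up sizes, not the real recordings) to show how stacking with `np.concatenate` keeps the 32 sample columns while adding up the number of spikes.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-ins for the recorded arrays: (number of spikes, 32 samples)
neuron_a = rng.normal(size=(50, 32))
neuron_b = rng.normal(size=(80, 32))
neuron_c = rng.normal(size=(70, 32))

# Concatenating along axis 0 appends rows (spikes); the 32 columns are unchanged
two_neurons = np.concatenate([neuron_b, neuron_c])
three_neurons = np.concatenate([neuron_a, neuron_b, neuron_c])

print(two_neurons.shape)    # (150, 32)
print(three_neurons.shape)  # (200, 32)
```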

In [None]:
%matplotlib notebook
import matplotlib
import numpy as np
import scipy.io
import scipy.cluster
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
matplotlib.rcParams["figure.max_open_warning"] = 40

# Load data
presorted = {k: v for k, v in scipy.io.loadmat('spike_waveforms').items() \
             if k in ('sig118a_wf', 'sig118b_wf', 'sig118c_wf')}

# We are defining this list for plotting convenience later. Safe to ignore for the most part.
presorted_list = [presorted['sig118a_wf'], presorted['sig118b_wf'], presorted['sig118c_wf']]

In [None]:
# We can append all these waveforms into a huge matrix with 32 columns (1 column for each sample)
presorted_two_neurons = np.concatenate([presorted['sig118b_wf'], presorted['sig118c_wf']])
presorted_three_neurons = np.concatenate([presorted['sig118a_wf'], presorted['sig118b_wf'], presorted['sig118c_wf']])

In [None]:
def _make_training_set(data):
    """ Randomly split the data set into a training set and a test set.
    1/6 of the rows go to the training set and the rest to the test set.
    Parameters:
        data: waveform data, one spike per row (width = number of samples per spike)
    Returns:
        training_set, test_set
    """
    n = data.shape[0]
    idx_training = np.random.choice(n, n // 6, replace=False)
    training_set = data[idx_training]
    # All remaining row indices, sorted; avoids scanning idx_training once per row
    idx_test = np.setdiff1d(np.arange(n), idx_training)
    test_set = data[idx_test]

    return training_set, test_set

# Create training and testing dataset
two_neurons_training, two_neurons_test = _make_training_set(presorted_two_neurons)
three_neurons_training, three_neurons_test = _make_training_set(presorted_three_neurons)

print("Training data shape:", three_neurons_training.shape)
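The split above draws from the global `np.random` state, so the training and test sets change on every run. If you want reproducible splits, one option (an assumption, not part of the lab) is to use a seeded `np.random.default_rng` generator. `make_training_set_seeded` is a hypothetical name used only for this sketch.

```python
import numpy as np

def make_training_set_seeded(data, seed):
    """Hypothetical helper: reproducible split, 1/6 of rows for training, the rest for testing."""
    rng = np.random.default_rng(seed)
    n = data.shape[0]
    idx_training = rng.choice(n, n // 6, replace=False)
    idx_test = np.setdiff1d(np.arange(n), idx_training)  # remaining rows, sorted
    return data[idx_training], data[idx_test]

data = np.arange(120 * 32).reshape(120, 32)
train_1, test_1 = make_training_set_seeded(data, seed=42)
train_2, test_2 = make_training_set_seeded(data, seed=42)

print(train_1.shape, test_1.shape)       # (20, 32) (100, 32)
print(np.array_equal(train_1, train_2))  # True: same seed, same split
```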

To get an idea of what the data looks like, let us print and plot the first row of the `two_neurons_training` matrix.

In [None]:
first_row = two_neurons_training[0,:]
print(f"First Row of 2 Neuron Training Set: {first_row}")
plt.title("First Row of 2 Neuron Training Set")
plt.xlabel('Samples')
plt.ylabel('Electrode Reading')
plt.plot(first_row)

So, this is what one reading looks like. The rows of `two_neurons_training` hold spike waveforms from two different neurons, mixed together. Similarly, the rows of `three_neurons_training` hold waveforms from three neurons.

In [None]:
# Note: This may take several minutes to run

# Plot 100 random spikes
plt.figure()
plt.plot(three_neurons_training[:100].T)
plt.xlim((0, 31))
plt.title('100 random spikes')

# Plot the 3 spike shapes based on the presorted data
for i, waveforms in enumerate(presorted_list):
    plt.figure()
    plt.title(f"All waveforms of neuron #{i+1}")
    plt.plot(waveforms.T)
    plt.xlabel('Samples')
    plt.ylabel('Electrode Reading')
    plt.figure()
    plt.title(f"Average waveform of neuron #{i+1} spike shape")
    plt.plot(np.mean(waveforms, axis=0))
    plt.xlabel('Samples')
    plt.ylabel('Electrode Reading')
    
plt.figure()   
for i, waveforms in enumerate(presorted_list):
    plt.plot(np.mean(waveforms, axis=0), label=f'neuron #{i+1}')
    plt.xlabel('Samples')
    plt.ylabel('Electrode Reading')
plt.xlim((0, 31))
plt.title('Averaged presorted 3 neuron spikes')
plt.legend()

You will be using <a href="http://docs.scipy.org/doc/numpy/reference/generated/numpy.linalg.svd.html">np.linalg.svd</a> in your PCA function. Read the documentation for this function to figure out how to choose the principal components used as the basis for the lower-dimensional space. (Note: in the docs, `a.H` means the conjugate transpose of `a`.)
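To see what `np.linalg.svd` returns, here is a small check on an arbitrary random matrix (not lab data). With `full_matrices=False` it returns `U`, the singular values `s` in descending order, and `Vh` (for real data, simply $V^\top$) such that `a == U @ np.diag(s) @ Vh`. Because each waveform is a row of the data matrix, the principal directions live in the 32-dimensional sample space, so they are the rows of `Vh`, i.e., the columns of $V$.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.normal(size=(200, 32))  # 200 "waveforms", 32 samples each

U, s, Vh = np.linalg.svd(a, full_matrices=False)
print(U.shape, s.shape, Vh.shape)  # (200, 32) (32,) (32, 32)

# The three factors reproduce the matrix
assert np.allclose(a, U @ np.diag(s) @ Vh)

# Singular values come sorted from largest to smallest
assert np.all(np.diff(s) <= 0)

# Rows of Vh are orthonormal vectors in sample space
assert np.allclose(Vh @ Vh.T, np.eye(32))

# The first k principal components as columns: shape (32, k)
k = 2
V_k = Vh[:k].T
print(V_k.shape)  # (32, 2)
```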

## Part e)

In [None]:
def PCA_train(training_set, n_components):
    """ Use np.linalg.svd to perform PCA
    Parameters:
        training_set: the data set to perform PCA on (MxN), one waveform per row
        n_components: the dimensionality of the basis to return (e.g., 2 or 3)
    Returns:
        The n_components principal components with the largest singular values,
        as the columns of an (N x n_components) array
    """

    # Our data (each signal waveform) is stored in the rows of the training set.
    # For that reason, our principal components will be the columns of V (rows of Vt).
    # Note: the data are not mean-centered before the SVD.
    # full_matrices=False skips computing the unused M x M matrix U.
    U, s, Vt = np.linalg.svd(training_set, full_matrices=False)
    # The first n_components columns of V are the first n_components rows of Vt, transposed.
    basis_components = Vt[:n_components].T

    return basis_components

def PCA_project(data, principal_components):
    """ Project the data set onto the new basis vectors.
    In other words, express each row of data in the new K-dimensional basis coordinates
    and return the projection coefficients.
    Parameters:
        data: data to project (MxN)
        principal_components: our K principal components as orthonormal column vectors (NxK)
    Returns:
        Data projected onto the new basis (MxK), i.e., the projection coefficients
    """
    # YOUR CODE HERE
    # Each coefficient is the dot product of a row with one unit-length basis vector
    return data @ principal_components
    # END YOUR CODE
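Because the columns of the basis returned by `PCA_train` are orthonormal, the projection coefficient of a waveform $x$ onto component $v_j$ is just the dot product $x^\top v_j$, so all coefficients for all rows come from one matrix product, $(M\times N)(N\times K) = M\times K$. The sketch below checks this on synthetic data; `project` and `reconstruct` are hypothetical names for illustration only.

```python
import numpy as np

def project(data, basis):
    """Coefficients of each row of data (M x N) in the orthonormal basis columns (N x K)."""
    return data @ basis

def reconstruct(coeffs, basis):
    """Map coefficients (M x K) back to the original N-dimensional space."""
    return coeffs @ basis.T

rng = np.random.default_rng(0)
data = rng.normal(size=(100, 32))
_, _, Vh = np.linalg.svd(data, full_matrices=False)

basis_2 = Vh[:2].T
coeffs = project(data, basis_2)
print(coeffs.shape)  # (100, 2)

# Each coefficient is a plain dot product with one basis vector
assert np.isclose(coeffs[0, 1], data[0] @ basis_2[:, 1])

# With all 32 components the reconstruction is exact
full = Vh.T
assert np.allclose(reconstruct(project(data, full), full), data)
```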

## Part f)

In [None]:
# Perform PCA on two neurons training data and plot the first 2 principal components.

# YOUR CODE HERE
two_new_basis_2pcs = PCA_train(two_neurons_training, 2)
# END YOUR CODE

# Plot the basis components
plt.figure()
for i, comp in enumerate(two_new_basis_2pcs.T):
    plt.plot(comp, label=f"Principal Component #{i+1}")
plt.title("First 2 Principal Components")
plt.legend()

## Part g)

In [None]:
# Random directions for comparison: each entry is 0 or 1 (not unit length, not orthogonal)
random_directions = np.random.randint(2, size=(32, 2))

two_projected = PCA_project(two_neurons_test, random_directions)

# Plot the projected neurons
plt.figure()
plt.scatter(*two_projected.T, s=1)
plt.title('two_neurons_test onto random directions')
plt.xlabel('Projection onto 1st random direction')
plt.ylabel('Projection onto 2nd random direction')

# Project the presorted data and plot it
plt.figure()
presorted_two_projected = [PCA_project(spikes, random_directions) for spikes in presorted_list[1:]]
colors = ['#0000ff', '#00ff00']
for dat, color in zip(presorted_two_projected, colors):
    plt.scatter(*dat.T, c=color, alpha=0.2, s=1)
plt.title('Presorted data onto random directions')
plt.xlabel('Projection onto 1st random direction')
plt.ylabel('Projection onto 2nd random direction')
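`np.random.randint(2, size=(32, 2))` produces two vectors with entries 0 or 1; they are generally neither unit length nor orthogonal, so the projections in Part g are dot products with arbitrary vectors rather than coordinates in an orthonormal basis. If you want a fairer random baseline, one option (an assumption, not part of the lab) is to orthonormalize random Gaussian vectors with a QR decomposition:

```python
import numpy as np

rng = np.random.default_rng(0)

# Random 0/1 directions like the lab's: column lengths are generally not 1
random_directions = rng.integers(0, 2, size=(32, 2))
print(np.linalg.norm(random_directions, axis=0))

# Orthonormal random directions: the Q factor of a random Gaussian matrix
Q, _ = np.linalg.qr(rng.normal(size=(32, 2)))
print(Q.shape)  # (32, 2)
assert np.allclose(Q.T @ Q, np.eye(2))
```

`Q` could then be passed to `PCA_project` in place of `random_directions`.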

## Part h)

In [None]:
# Project the test data two_neurons_test onto the basis you found earlier

# YOUR CODE HERE
two_projected = PCA_project(two_neurons_test, two_new_basis_2pcs)
# END YOUR CODE


# Plot the projected neurons
plt.figure()
plt.scatter(*two_projected.T,s=1)
plt.title('two_neurons_test onto PCA directions')
plt.xlabel('Projection onto 1st Principal Component')
plt.ylabel('Projection onto 2nd Principal Component')

# Project the presorted data and plot it
plt.figure()
presorted_two_projected = [PCA_project(spikes, two_new_basis_2pcs) for spikes in presorted_list[1:]]
colors = ['#0000ff', '#00ff00']
for dat, color in zip(presorted_two_projected, colors):
    plt.scatter(*dat.T, c=color, s=1, alpha=0.2)
plt.title('Presorted data onto PCA directions')
plt.xlabel('Projection onto 1st Principal Component')
plt.ylabel('Projection onto 2nd Principal Component')

## Part i)

We have projected the 2-neuron set onto 2 principal components. Let's see what happens if we project the 2-neuron set onto 3 principal components.

In [None]:
# Perform PCA and plot the first 3 principal components.
two_new_basis_3pcs = PCA_train(two_neurons_training, 3)

# Plot the basis components
plt.figure()
for i, comp in enumerate(two_new_basis_3pcs.T):
    plt.plot(comp, label=f"Principal Component #{i+1}")
plt.title("First 3 Principal Components")
plt.legend()
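One way to judge how many components are worth keeping is to look at the singular values. The squared singular value $\sigma_j^2$ equals the sum of squared projections of all rows onto the $j$-th component, so $\sigma_j^2 / \sum_i \sigma_i^2$ is the fraction of the total sum of squares that component captures. Below is a sketch on synthetic two-template data (for the lab you would pass `two_neurons_training` instead); `energy_fraction` is a hypothetical helper. Since `PCA_train` does not subtract the mean, this measures energy about the origin rather than variance in the strict statistical sense.

```python
import numpy as np

def energy_fraction(data):
    """Hypothetical helper: fraction of the sum of squares captured by each right singular vector."""
    s = np.linalg.svd(data, compute_uv=False)
    return s**2 / np.sum(s**2)

rng = np.random.default_rng(0)
# Synthetic data: two made-up spike templates plus noise, 300 rows of 32 samples
t = np.arange(32)
template_a = np.exp(-((t - 10) / 3.0) ** 2)
template_b = -np.exp(-((t - 14) / 4.0) ** 2)
labels = rng.integers(0, 2, size=300)
data = np.where(labels[:, None] == 0, template_a, template_b) + 0.05 * rng.normal(size=(300, 32))

frac = energy_fraction(data)
print(np.round(frac[:4], 3))
print("cumulative:", np.round(np.cumsum(frac)[:4], 3))
```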

In [None]:
def plot_3D(data, view_from_top=False):
    """ Scatter-plot a list of (M x 3) arrays of (x, y, z) coordinates in 3D.
    Each array gets its own color. Returns the 3D axes.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    colors = ['#0000ff', '#00ff00', '#ff0000']
    for dat, color in zip(data, colors):
        ax.scatter(*dat.T, s=1, c=color, alpha=0.2)
    if view_from_top:
        ax.view_init(elev=90., azim=0)                # Move perspective to view from top
    return ax

        
# Project the two_neurons_test data onto the 3-component basis computed above
two_projected_3pcs = PCA_project(two_neurons_test, two_new_basis_3pcs)

# Plot the resulting projection
ax = plot_3D([two_projected_3pcs], False)
plt.title('two_neurons_test projected to 3 principal components')
ax.set_xlabel('Projection onto 1st Principal Component')
ax.set_ylabel('Projection onto 2nd Principal Component')
ax.set_zlabel('Projection onto 3rd Principal Component')

presorted_projected = [PCA_project(spikes, two_new_basis_3pcs) for spikes in presorted_list[1:]]
ax = plot_3D(presorted_projected, False)
plt.title('Presorted data projected to 3 principal components')
ax.set_xlabel('Projection onto 1st Principal Component')
ax.set_ylabel('Projection onto 2nd Principal Component')
ax.set_zlabel('Projection onto 3rd Principal Component')

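Note that, for this data, the first principal component alone separates the two neurons along the $x$-axis, so technically we only need 1 principal component to tell them apart. The first principal component is the unit vector $v$ that maximizes $\sum_i (x_i^\top v)^2$, the sum of squared dot products with the signal waveforms $x_i$. Squaring discards the sign, so large negative dot products count as much as large positive ones; for these two neurons, the maximizing direction gives large positive dot products with one neuron's spikes and large negative dot products with the other's.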
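A small numerical check of this, on synthetic data (two made-up templates of opposite polarity, not the recorded spikes): the first right singular vector $v_1$ scores at least as high on $\sum_i (x_i^\top v)^2$ as any of 1000 random unit vectors, and the signs of $x_i^\top v_1$ split the two groups. Whether real recordings behave this way depends on the waveforms; here it holds because the two templates point in roughly opposite directions.

```python
import numpy as np

rng = np.random.default_rng(0)
t = np.arange(32)
template_a = np.exp(-((t - 10) / 3.0) ** 2)          # positive-going spike
template_b = -0.8 * np.exp(-((t - 12) / 3.0) ** 2)   # negative-going spike

group_a = template_a + 0.05 * rng.normal(size=(100, 32))
group_b = template_b + 0.05 * rng.normal(size=(100, 32))
X = np.concatenate([group_a, group_b])

_, _, Vh = np.linalg.svd(X, full_matrices=False)
v1 = Vh[0]

def score(v):
    """Sum of squared dot products of every row of X with v."""
    return np.sum((X @ v) ** 2)

# v1 maximizes the objective, so it beats every random unit vector
random_dirs = rng.normal(size=(1000, 32))
random_dirs /= np.linalg.norm(random_dirs, axis=1, keepdims=True)
print(score(v1) >= max(score(v) for v in random_dirs))  # True

# Projections onto v1: one group positive, the other negative (the sign of v1 is arbitrary)
proj_a, proj_b = group_a @ v1, group_b @ v1
print(np.sign(proj_a[:5]), np.sign(proj_b[:5]))
```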

## Task 2: Three Neuron Spike Sorting

## Part j)

In [None]:
# Train with three neuron data, producing 2 principal components

# YOUR CODE HERE
three_new_basis_2pcs = PCA_train(three_neurons_training, 2)
# END YOUR CODE

# Plot the resulting basis
plt.figure()
for i, comp in enumerate(three_new_basis_2pcs.T):
    plt.plot(comp, label=f"Principal Component #{i+1}")
plt.title("First 2 Principal Components")
plt.legend()

In [None]:
# Project datapoints on the principal components
basis = three_new_basis_2pcs[:,:2]
three_projected_2pcs = PCA_project(three_neurons_test, basis)
presorted_projected_2pcs = [PCA_project(spikes, basis) for spikes in presorted_list]

# Plot the resulting projection
plt.figure()
plt.title("three_neurons_test projected to 2 principal components")
plt.xlabel('Projection onto 1st Principal Component')
plt.ylabel('Projection onto 2nd Principal Component')
plt.scatter(three_projected_2pcs.T[0], three_projected_2pcs.T[1], s=1)

plt.figure()
plt.title("Presorted data projected to 2 principal components")
plt.xlabel('Projection onto 1st Principal Component')
plt.ylabel('Projection onto 2nd Principal Component')
for p in presorted_projected_2pcs:
    plt.scatter(p.T[0], p.T[1], s=1, alpha=0.2)

## Part k)

In [None]:
# Repeat training with three neuron data, producing 3 principal components

# YOUR CODE HERE
three_new_basis_3pcs = PCA_train(three_neurons_training, 3)
# END YOUR CODE

# Plot the resulting basis
plt.figure()
for i, comp in enumerate(three_new_basis_3pcs.T):
    plt.plot(comp, label=f"Principal Component #{i+1}")
plt.title("First 3 Principal Components")
plt.legend()

In [None]:
def plot_3D(data, view_from_top=False):
    """ Scatter-plot a list of (M x 3) arrays of (x, y, z) coordinates in 3D.
    Each array gets its own color. Returns the 3D axes.
    """
    fig = plt.figure(figsize=(10, 7))
    ax = fig.add_subplot(111, projection='3d')
    colors = ['#0000ff', '#00ff00', '#ff0000']
    for dat, color in zip(data, colors):
        ax.scatter(*dat.T, s=1, c=color, alpha=0.2)
    if view_from_top:
        ax.view_init(elev=90., azim=0)                # Move perspective to view from top
    return ax

        
# Project the three_neurons_test data onto the 3-component basis computed above
# YOUR CODE HERE
three_projected_3pcs = PCA_project(three_neurons_test, three_new_basis_3pcs)
# END YOUR CODE

# Plot the resulting projection
ax = plot_3D([three_projected_3pcs], False)
plt.title('three_neurons_test projected to 3 principal components')
ax.set_xlabel('Projection onto 1st Principal Component')
ax.set_ylabel('Projection onto 2nd Principal Component')
ax.set_zlabel('Projection onto 3rd Principal Component')

presorted_projected_3pcs = [PCA_project(spikes, three_new_basis_3pcs) for spikes in presorted_list]
ax = plot_3D(presorted_projected_3pcs, False)
plt.title('Presorted data projected to 3 principal components')
ax.set_xlabel('Projection onto 1st Principal Component')
ax.set_ylabel('Projection onto 2nd Principal Component')
ax.set_zlabel('Projection onto 3rd Principal Component')

Change the second argument of the `plot_3D` calls above to `True`, or rotate the plots in the interactive viewer, to view them "from the top" (i.e., looking straight down onto the $x$-$y$ plane from above, along the $z$-axis).
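Viewing the 3-D scatter from the top simply hides the third coordinate. Since `PCA_train(data, 3)` and `PCA_train(data, 2)` take leading rows of the same `Vt`, their first two columns are identical, so the top view should match the 2-component plot from Part j (both bases are trained on `three_neurons_training`). A quick check on synthetic data, with the SVD written out inline:

```python
import numpy as np

rng = np.random.default_rng(0)
train = rng.normal(size=(60, 32))
test = rng.normal(size=(40, 32))

_, _, Vt = np.linalg.svd(train, full_matrices=False)
basis_2, basis_3 = Vt[:2].T, Vt[:3].T

proj_2 = test @ basis_2  # 2-PC coordinates
proj_3 = test @ basis_3  # 3-PC coordinates

# Dropping the z coordinate of the 3-PC projection gives the 2-PC projection
assert np.allclose(proj_3[:, :2], proj_2)
```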

## Task 3: Determining Neurons

## Part m)

In [None]:
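# Find the centroids of the presorted data
# HINT: Use presorted_two_projected for the presorted data.
# YOUR CODE HERE
# Mean of each presorted cluster in the 2-PC plane:
# centroid1 is for sig118b_wf, centroid2 for sig118c_wf (the two neurons in the 2-neuron set)
centroid1 = np.mean(presorted_two_projected[0], axis=0)
centroid2 = np.mean(presorted_two_projected[1], axis=0)
# END YOUR CODE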

print("Centroid 1:", centroid1)
print("Centroid 2:", centroid2)
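A centroid here is the mean of a cluster's projected points, i.e., the average over the rows of each array in `presorted_two_projected`. A sketch with made-up 2-D points standing in for those projections:

```python
import numpy as np

# Hypothetical projected points for two presorted neurons (rows are [PC1, PC2])
cluster_1 = np.array([[1.0, 2.0], [3.0, 2.0], [2.0, 5.0]])
cluster_2 = np.array([[-4.0, 0.0], [-2.0, 1.0]])

# Averaging over rows (axis=0) gives one [PC1, PC2] pair per cluster
centroid_1 = np.mean(cluster_1, axis=0)
centroid_2 = np.mean(cluster_2, axis=0)
print(centroid_1)
print(centroid_2)
```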

In [None]:
def which_neuron(data_point, centroid1, centroid2):
    """ Determine which centroid is closest to the data point
    Inputs:
        data_point: 1x2 array containing x/y coordinates of data point
        centroid1: 1x2 array containing x/y coordinates of centroid 1
        centroid2: 1x2 array containing x/y coordinates of centroid 2
    Returns:
        1 or 2: the number of the centroid closest to the data point (ties go to centroid 2)
    """
    dist1 = np.linalg.norm(data_point - centroid1)
    dist2 = np.linalg.norm(data_point - centroid2)

    # YOUR CODE HERE
    if dist1 >= dist2:
        return 2  # closer to (or equidistant from) centroid 2
    else:
        return 1  # closer to centroid 1
    # END YOUR CODE
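The same nearest-centroid rule extends to any number of neurons, e.g., the three-neuron projections from Task 2. A vectorized sketch using NumPy broadcasting; `nearest_centroid_labels` is a hypothetical name, and here ties go to the lower-numbered centroid (via `np.argmin`), whereas `which_neuron` above sends ties to centroid 2.

```python
import numpy as np

def nearest_centroid_labels(points, centroids):
    """Hypothetical helper: 1-based index of the closest centroid for each point.

    points: (M, D) array, centroids: (K, D) array.
    """
    # (M, 1, D) - (1, K, D) -> (M, K, D) differences, then Euclidean norms over D
    dists = np.linalg.norm(points[:, None, :] - centroids[None, :, :], axis=2)
    return np.argmin(dists, axis=1) + 1  # +1 so neuron numbers start at 1

centroids = np.array([[0.0, 0.0], [10.0, 0.0], [0.0, 10.0]])
points = np.array([[1.0, 1.0], [9.0, -1.0], [-1.0, 8.0], [6.0, 0.0]])
print(nearest_centroid_labels(points, centroids))  # [1 2 3 2]
```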

In [None]:
num_of_firings = np.zeros(2, dtype=np.int64)
label_arr = []
for datapoint in two_projected:
    neuron_number = which_neuron(datapoint, centroid1, centroid2)
    num_of_firings[neuron_number - 1] += 1
    label_arr.append(neuron_number)

print(f'Neuron 1 identified {num_of_firings[0]:d} times')
print(f'Neuron 2 identified {num_of_firings[1]:d} times')

plt.figure()
plt.title('Centroid Mean Classification')
plt.xlabel('Projection onto 1st Principal Component')
plt.ylabel('Projection onto 2nd Principal Component')
plt.scatter(*two_projected.T,  c=label_arr, marker='.')
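The counting loop above can also be written with `np.bincount`, which counts how many times each non-negative integer occurs. A sketch with a made-up label list in the same 1/2 format as `label_arr`:

```python
import numpy as np

label_arr = [1, 2, 2, 1, 2, 2, 2]  # hypothetical labels from which_neuron

# minlength=3 reserves slots for labels 0, 1, 2; slot 0 is dropped since neurons are numbered from 1
counts = np.bincount(label_arr, minlength=3)[1:]
print(counts)  # [2 5]
```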