In [None]:
# import packages
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib as mpl

In [None]:
# Read the metadata table; the first CSV column is used as the index.
meta_fp = "data/metadata.csv" # meta data filepath
df = pd.read_csv(meta_fp, index_col=0)

# Look up the data file of the selected mouse.
# .item() raises unless exactly one row matches idx.
idx = 0
is_selected = df["mouse_id"] == idx
data_path = df.loc[is_selected, "data_path"].item()

# Load the recording and unwrap the stored object (later accessed by key).
# NOTE(review): allow_pickle=True unpickles arbitrary objects -- only load
# files from a trusted source.
data = np.load(data_path, allow_pickle=True).item()


In [None]:
# list the entries available in the loaded data object
data.keys()

In [None]:
# save timing information
# data["stim"] unpacks into two values, used here as stimulus onset and offset
stim_on, stim_off = data["stim"]
# event markers in the order [stim on, stim off, go cue]; later drawn as
# vertical lines on the time axis (presumably in time-step units -- TODO confirm)
times = [stim_on, stim_off, data["go_cue"]]

In [None]:
# inspect the dimensions of the "L" array; later cells take axis 0 as the
# trial length and select/average trials along axis 2
data["L"].shape

# Dimensionality reduction of trial-averaged activity
For this example, I am interested in expressing the trial-averaged activity in a new set of axes.
> I want to reduce the dimensionality along the neuron axis.

In [None]:
def stratify_by_outcome(data, results, outcomes=(0, 1)):
    """Split an activity array into one sub-array per trial outcome.

    Parameters
    ----------
    data : array-like
        Activity array whose third axis (axis 2) indexes trials.
    results : array-like
        One outcome label per trial (flattened before use); its length must
        equal ``data.shape[2]``.
    outcomes : iterable, optional
        Outcome labels to extract, in the order they should be returned.

    Returns
    -------
    list of numpy.ndarray
        One array per entry of ``outcomes``, each equal to
        ``data[:, :, results == outcome]``.

    Raises
    ------
    IndexError
        If the number of labels does not match the length of axis 2.
    """
    data = np.asarray(data)
    labels = np.asarray(results).ravel()
    # A boolean mask always keeps all three axes. The previous
    # np.where + np.squeeze combination silently dropped the trial axis when
    # exactly one trial matched, which broke averaging over axis 2 downstream.
    return [data[:, :, labels == outcome] for outcome in outcomes]

In [None]:
# Split the trials of each direction by outcome.
# Outcome labels in stacking order; presumably 0 = incorrect and 1 = correct,
# as suggested by the "correct_*_trials" keys -- TODO confirm.
outcomes = [0, 1]

# Resulting list order: [L outcome 0, L outcome 1, R outcome 0, R outcome 1].
# Later cells index into this order, so keep it stable.
activity = (
    stratify_by_outcome(data["L"], data["correct_L_trials"], outcomes=outcomes)
    + stratify_by_outcome(data["R"], data["correct_R_trials"], outcomes=outcomes)
)


In [None]:
# Average every condition over axis 2 (the axis the trials were selected on),
# then stack the per-condition averages along axis 0.
axis = 2
avg_activity = [np.mean(condition, axis=axis) for condition in activity]

avg_activity_concat = np.concatenate(avg_activity)

In [None]:
# sanity check: the per-condition averages are stacked along axis 0
avg_activity_concat.shape

In [None]:
from sklearn.decomposition import PCA
# PCA with default arguments; rows of avg_activity_concat are the samples
# (stacked time points of all conditions), columns are the features
# (presumably neurons -- TODO confirm)
pca = PCA()
pca.fit(avg_activity_concat)

In [None]:
from helper import *

In [None]:
## show variance explained by the first n_pc PCs
n_pc = 20
# alternative view (cumulative variance explained), currently disabled:
# plot_cum_var_explained(pca.explained_variance_ratio_, n=n_pc)
# plot_var_explained is not defined in this notebook; presumably it comes from
# the `helper` star import -- TODO confirm
plot_var_explained(pca.explained_variance_ratio_, n=n_pc)

Now is a good time to pause and see if you understand what the PCs are. What is a good geometric understanding of the PCs we are finding here?

## Visualize neural activity in the first 2 PCs

In [None]:
## show time-evolution of average neural dynamics along 2 PCs
pc_id1, pc_id2 = 0, 1 # PCs of interest (0-indexed)

# avg_activity_concat stacks four blocks of trial_length rows each, in the
# order [L incorrect, L correct, R incorrect, R correct] (data["L"] is
# stratified first, and outcome 0 comes before outcome 1).
directions = ["Left", "Right"]
colors = ["Reds", "Blues"] # colormaps for left and right
trial_length = data["L"].shape[0] # length of a trial
time_steps = np.arange(trial_length)

fig, axes = plt.subplots(1, 2, figsize=(10, 5), sharex=True, sharey=True)

labels = ["go", "s_start", "s_end"] # timing labels (not drawn in this figure)
label_colors = ["b", "r"]
counter = 0

for i in range(len(directions)): # L, R
    for c, title in enumerate(["incorrect trials", "correct trials"]): # incorrect, correct
        # rows of this direction x outcome block
        offset = trial_length * counter
        counter += 1
        # trial average data
        avg_act = avg_activity_concat[offset:trial_length + offset, :]
        # pca transformed data
        pcaD = pca.transform(avg_act)

        ax = axes[c]
        # one vectorised scatter per block; time step t is mapped to
        # cmap(t / trial_length), i.e. light = early, dark = late
        ax.scatter(pcaD[:, pc_id1], pcaD[:, pc_id2], c=time_steps,
                   cmap=colors[i], vmin=0, vmax=trial_length, marker="o")
        ax.set_title(title)

# proxy handles so the reader can tell which colormap is which direction
direction_handles = [
    mpl.lines.Line2D([], [], ls="", marker="o",
                     color=mpl.colormaps[name](0.7), label=direction)
    for name, direction in zip(colors, directions)
]
axes[0].legend(handles=direction_handles, title="Direction")

fig.supxlabel(f"PC{pc_id1}")
fig.supylabel(f"PC{pc_id2}")

# shared colorbar for time; space is taken from both panels so they stay the
# same size (the grey map stands in for both direction colormaps)
cmap = mpl.colormaps["Greys"]
norm = mpl.colors.Normalize(vmin=0, vmax=trial_length)

fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap),
             ax=axes, label='Time Steps', location="right")

### A 3D plot just for fun

In [None]:
# Dependencies for 3D plot
import plotly
import plotly.graph_objects as go
from plotly.subplots import make_subplots

PCS = [0, 1, 2] # PCs of interest

is_correct = True # decide to plot correct or error trials
# avg_activity_concat stacks blocks of trial_length rows in the order
# [L incorrect, L correct, R incorrect, R correct], so the (Left, Right)
# block indices are [1, 3] for correct trials and [0, 2] for error trials.
correct, wrong = [1, 3], [0, 2] # for indexing
blocks = correct if is_correct else wrong

ann = []

directions = ["Left", "Right"]
colorscales = ["Reds", "Blues"] # left = red, right = blue, as in the other figures
colorbarx = [0.95, 1]

fig = make_subplots(rows=1, cols=1, specs=[[{"type": "scene"}]])

for i in range(len(blocks)):
    offset = blocks[i] * trial_length
    pcaD = pca.transform(avg_activity_concat[offset:trial_length+offset,:])

    # configure the trace
    trace = go.Scatter3d(x=pcaD[:, PCS[0]], y=pcaD[:, PCS[1]], z=pcaD[:, PCS[2]],
                         mode='markers',
                         name=directions[i],
                         marker={'size': 5,
                                 'opacity': 1,
                                 'color': list(range(trial_length)), # color by time step
                                 'colorscale': colorscales[i],
                                 'colorbar': dict(thickness=20, x=colorbarx[i], title="Time")
                                }
    )

    fig.add_trace(trace, row=1, col=1)

fig.update_layout(
    margin=dict(l=30, r=30, t=30, b=30),
    # keep the legend away from the two colorbars on the right
    legend=dict(x=0, y=1),
    scene=dict(xaxis_title=f"PC{PCS[0]}",
               yaxis_title=f"PC{PCS[1]}",
               zaxis_title=f"PC{PCS[2]}")
)

### What else can we do with our PCA outcomes?

In [None]:
def moving_avg(d, win=2):
    """Centred moving average of a 1-D signal, same length as the input.

    Parameters
    ----------
    d : array-like, 1-D
        Signal to smooth.
    win : int, optional
        Window length in samples (>= 1).

    Returns
    -------
    numpy.ndarray
        Array with ``len(d)`` samples. Sample ``k`` is the mean of the input
        samples covered by a window centred on ``k``; near the edges the mean
        is taken over the samples that actually exist instead of padding with
        zeros.

    Raises
    ------
    ValueError
        If ``win`` is smaller than 1 (or, via ``np.convolve``, if ``d`` is
        empty).

    Notes
    -----
    The previous implementation returned the "full" convolution, which has
    ``len(d) + win - 1`` samples, lags the input and decays towards zero at
    both ends, so the result did not line up with the input's time axis.
    """
    win = int(win)
    if win < 1:
        raise ValueError("win must be a positive integer")
    d = np.asarray(d)
    kernel = np.ones(win)
    # slice of the full convolution that is centred on the input samples
    start = (win - 1) // 2
    stop = start + d.size
    window_sums = np.convolve(d, kernel)[start:stop]
    # number of real samples under the window at each position (edge-aware)
    window_counts = np.convolve(np.ones(d.size), kernel)[start:stop]
    return window_sums / window_counts

def get_smoothed_velocity(D, win=2):
    """Return the step-to-step change of ``D`` and a smoothed copy of it.

    Parameters
    ----------
    D : array-like
        Trajectory sampled along axis 0.
    win : int, optional
        Window length handed to ``moving_avg``.

    Returns
    -------
    tuple
        ``(velocity, smoothed)``: ``velocity`` is the first difference of
        ``D`` along axis 0, and ``smoothed`` is ``moving_avg(velocity, win)``.
    """
    velocity = np.diff(D, axis=0)
    smoothed = moving_avg(velocity, win)
    return velocity, smoothed

In [None]:
# look at the first n pcs
n_pcs = 3
# window (in samples) of the moving average applied to the velocity
smooth_win = 5

# one panel per PC; squeeze=False keeps `axes` indexable even for n_pcs == 1
f, axes = plt.subplots(1, n_pcs, figsize=(5 * n_pcs, 3), sharex=True, sharey=True,
                       squeeze=False)
axes = axes.ravel()
labels = ["Left", "Right"]
colors = [ "tab:red", "tab:blue"]

# check one condition at a time
is_correct = True

# avg_activity_concat stacks blocks of trial_length rows in the order
# [L incorrect, L correct, R incorrect, R correct], so the (Left, Right)
# block indices are [1, 3] for correct trials and [0, 2] for error trials.
correct, wrong = [1, 3], [0, 2] # for indexing
blocks = correct # default
cond = "Correct" # default
if (not is_correct):
    blocks = wrong
    cond = "Wrong"

for i in range(len(blocks)):
    offset = blocks[i] * trial_length
    pcaD = pca.transform(avg_activity_concat[offset:trial_length+offset,:])

    for j in range(n_pcs): # one trace per PC
        d, v = get_smoothed_velocity(pcaD[:, j], smooth_win)
        axes[j].plot(v, label=labels[i], c=colors[i])

# decorate every panel once (not once per condition)
for j, ax in enumerate(axes):
    ax.set_title("PC%d"%j)
    for t in times: # time markers
        ax.axvline(t, ls="--", c="black")
    ax.axhline(0, ls=":", c="gray")

axes[-1].legend(bbox_to_anchor=(1.05, 0.5)) # put legend outside

## set overall title, axes labels
# invisible full-figure axis that only carries the shared labels and title
big_ax = f.add_subplot(111, frameon=False)
# hide tick and tick label of the big axis
big_ax.tick_params(labelcolor='none', which='both', top=False, bottom=False, left=False, right=False)
big_ax.set_xlabel("Time", fontsize=14)
big_ax.set_ylabel("Velocity (disp. in unit time)", fontsize=14, labelpad=20)
big_ax.set_title("Velocity (smoothed) in the first %d PCs through time [%s Trials]"%(n_pcs, cond), fontsize=15, pad=35)