# Lecture 7
Analyzing neural data from a Neuropixels experiment

Adapted from Neuromatch Academy material: https://github.com/NeuromatchAcademy/course-content/tree/main/tutorials

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats

In [None]:
import io
import requests
# Download the spike-time dataset and load it directly from memory.
# Any failure is reported with a message instead of raising; in that case
# `spike_times` is left undefined.
try:
    r = requests.get('https://osf.io/sy5xt/download')
except requests.ConnectionError:
    # An unreachable server is reported the same way as a bad status code.
    print('Could not download data')
else:
    if r.status_code != 200:
        print('Could not download data')
    else:
        # NOTE(review): allow_pickle=True lets np.load unpickle Python objects
        # from the downloaded bytes; only keep it for a source you trust.
        spike_times = np.load(io.BytesIO(r.content), allow_pickle=True)['spike_times']

In [None]:
def plot_isis(single_neuron_isis):
  """Histogram one neuron's inter-spike intervals and mark their mean.

  Draws on the current matplotlib axes: a 50-bin filled-step histogram, an
  orange vertical line at the mean labelled "Mean ISI", axis labels and a
  legend.

  Args:
    single_neuron_isis: array-like of ISI values that provides `.mean()`;
      the x-axis label presents them as durations in seconds.
  """
  ax = plt.gca()
  ax.hist(single_neuron_isis, bins=50, histtype="stepfilled")
  mean_isi = single_neuron_isis.mean()
  ax.axvline(mean_isi, color="orange", label="Mean ISI")
  ax.set_xlabel("ISI duration (s)")
  ax.set_ylabel("Number of spikes")
  ax.legend()

In [None]:
# Check what kind of container holds the spike-time data.
type(spike_times)

In [None]:
# Display the whole spike_times object.
spike_times

In [None]:
# Look at the first entry of spike_times.
spike_times[0]

In [None]:
# Shape of the first entry of spike_times.
spike_times[0].shape

In [None]:
# Neurons to inspect, and the slice selecting the first five spike times.
i_neurons = [0, 321]
i_print = slice(0, 5)

# For each chosen neuron print a header, the dtype of its spike-time array
# and its first few spike times, one item per line.
for i in i_neurons:
  neuron_spikes = spike_times[i]
  report = (f"Neuron {i}:", neuron_spikes.dtype, neuron_spikes[i_print], "\n")
  print(*report, sep="\n")

In [None]:
# One entry of spike_times per neuron; the length of each entry is that
# neuron's total spike count.
n_neurons = len(spike_times)
total_spikes_per_neuron = list(map(len, spike_times))

print(f"Number of neurons: {n_neurons}")
print(f"Number of spikes for first five neurons: {total_spikes_per_neuron[:5]}")

In [None]:
# Distribution of total spike counts across neurons.
ax = plt.gca()
ax.hist(total_spikes_per_neuron, bins=50, histtype="stepfilled")
ax.set_xlabel("Total spikes per neuron")
ax.set_ylabel("Number of neurons");

In [None]:
# Histogram of total spike counts with the mean count marked in orange.
ax = plt.gca()
ax.hist(total_spikes_per_neuron, bins=50, histtype="stepfilled")
mean_spike_count = np.mean(total_spikes_per_neuron)
ax.axvline(mean_spike_count, color="orange", label="Mean neuron")
ax.set(xlabel="Total spikes per neuron", ylabel="Number of neurons")
ax.legend();

In [None]:
# Exercise: Add in a line for the median spike count

## Plotting rasters

In [None]:
# Pool every neuron's spikes into one array, then print the earliest and
# latest spike time on separate lines.
spike_times_flat = np.concatenate(spike_times)
print(spike_times_flat.min(), spike_times_flat.max(), sep="\n")

In [None]:
def restrict_spike_times(spike_times, interval):
  """Given a spike_time dataset, restrict to spikes within given interval.

  Args:
    spike_times (sequence of np.ndarray): List or array of arrays,
      each inner array has spike times for a single neuron.
    interval (tuple): Min, max time values; keep min <= t < max.

  Returns:
    np.ndarray: 1-D object array with one entry per neuron, each entry
      being that neuron's spike times that fall within `interval`.
  """
  t_min, t_max = interval[0], interval[1]

  kept = []
  for spikes in spike_times:
    spikes = np.asarray(spikes)  # also accept plain lists of spike times
    kept.append(spikes[(spikes >= t_min) & (spikes < t_max)])

  # Fill a pre-allocated object array element by element. Calling
  # np.array(kept, object) would instead build a 2-D object array whenever
  # all filtered arrays share a length (one neuron, all empty, equal counts).
  restricted = np.empty(len(kept), dtype=object)
  for idx, neuron_spikes in enumerate(kept):
    restricted[idx] = neuron_spikes
  return restricted

In [None]:
# Raster of one neuron's spikes, restricted to the window 5 <= t < 15.
neuron_idx = 1
interval_spike_times = restrict_spike_times(spike_times, interval=(5, 15))
plt.eventplot(interval_spike_times[neuron_idx], color=".2")
plt.xlabel(xlabel="Time (s)")
plt.yticks(ticks=[]);

In [None]:
# Raster of every fifth neuron (indices 0, 5, 10, ...) in the same window.
neuron_idx = np.arange(start=0, stop=len(spike_times), step=5)
selected_spike_times = interval_spike_times[neuron_idx]
plt.eventplot(selected_spike_times, color=".2")
plt.xlabel(xlabel="Time (s)")
plt.yticks(ticks=[]);

In [None]:
# Histogram of one neuron's spike times, split into n_bins equal-width
# bins over that neuron's own range of spike times.
n_bins = 200
neuron_idx = 0
plt.hist(spike_times[neuron_idx], bins=n_bins)

In [None]:
# What is the average firing rate?
# What is the average firing rate from time 100 to 150?

In [None]:
# Bin edges 0, 10, ..., 2700 (presumably seconds); spikes outside that range
# are not counted.
bins = np.arange(0, 2710, 10)
# One row of per-bin spike counts for each neuron.
psths = np.asarray([
  np.histogram(neuron_spikes, bins=bins)[0] for neuron_spikes in spike_times
])

In [None]:
from sklearn.decomposition import PCA
# Fit a 3-component PCA on psths and project each of its rows onto those
# components.
pca = PCA(n_components = 3)
neurons_decomp = pca.fit_transform(psths)

In [None]:
# Scatter the first principal component against the second, one point per
# row of neurons_decomp.
plt.scatter(x=neurons_decomp[:, 0], y=neurons_decomp[:, 1])

## Getting per-trial spike rates

In [None]:
#@title Data retrieval and loading
import os
import requests
import hashlib

# Where the dataset lives, the local file name, and the checksum the
# downloaded bytes must match.
url = "https://osf.io/r9gh8/download"
fname = "W1D4_steinmetz_data.npz"
expected_md5 = "d19716354fed0981267456b80db07ea8"

# Download only when no local copy exists. Every failure is reported with a
# printed message instead of raising, and nothing is written unless the
# MD5 checksum matches.
if not os.path.isfile(fname):
  try:
    r = requests.get(url)
  except requests.ConnectionError:
    print("!!! Failed to download data !!!")
  else:
    if r.status_code == requests.codes.ok:
      if hashlib.md5(r.content).hexdigest() == expected_md5:
        with open(fname, "wb") as fid:
          fid.write(r.content)
      else:
        print("!!! Data download appears corrupted !!!")
    else:
      print("!!! Failed to download data !!!")

def load_steinmetz_data(data_fname=fname):
  """Read every array stored in an .npz archive into a plain dict.

  Args:
    data_fname (str): Path of the .npz file to open; defaults to the
      module-level `fname`.

  Returns:
    dict: maps each name stored in the archive to its loaded array.
  """
  # Materialise all members while the archive is still open.
  with np.load(data_fname) as dobj:
    return {key: dobj[key] for key in dobj.files}

In [None]:
# Load the dataset and list each stored array's name next to its shape.
data = load_steinmetz_data()
for key in data:
  val = data[key]
  print(key, val.shape)

In [None]:
# Targets (choices) and features (spikes) for the classifier.
y, X = data["choices"], data["spikes"]

In [None]:
from sklearn.linear_model import LogisticRegression
# Define the model. penalty=None requests an unregularized fit; the older
# string spelling "none" is no longer accepted by current scikit-learn.
log_reg = LogisticRegression(penalty=None)

# Fit it to data
log_reg.fit(X, y)

In [None]:
# Predict a label for every row of X with the fitted model.
y_pred = log_reg.predict(X)

In [None]:
# TODO: Calculate the accuracy of your logistic regression

In [None]:
from sklearn.model_selection import cross_val_score
# Score an unregularized logistic regression with 8-fold cross-validation.
# penalty=None replaces the string "none", which current scikit-learn rejects.
accuracies = cross_val_score(LogisticRegression(penalty=None), X, y, cv=8)  # k=8 cross validation