In [16]:
import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from allensdk.brain_observatory.ecephys.ecephys_project_cache import EcephysProjectCache

In [17]:
# Root directory for the AllenSDK cache: downloaded data and the manifest live here.
# Override with the ALLEN_DATA_DIR environment variable instead of editing the notebook;
# the original hard-coded path remains the default.
output_dir = os.environ.get("ALLEN_DATA_DIR", '/home/marcush/Data/AllenData')

In [None]:
# The manifest path tells the cache where downloaded data will be stored (under output_dir).
manifest_path = os.path.join(output_dir, "manifest.json")
cache = EcephysProjectCache.from_warehouse(manifest=manifest_path)

# List the session types the warehouse offers.
available_session_types = cache.get_all_session_types()
print(available_session_types)

In [None]:
sessions = cache.get_session_table()

# Keep only sessions that used the "brain_observatory_1.1" stimulus set.
is_bo_session = sessions["session_type"] == "brain_observatory_1.1"
brain_observatory_type_sessions = sessions.loc[is_bo_session]

brain_observatory_type_sessions.tail()

In [None]:
# Precomputed unit counts per recording and (presumably) per brain area, stored in the
# cache directory. Deriving the path from output_dir keeps it in sync with the cache
# location instead of repeating the absolute path.
# NOTE(review): pd.read_pickle can execute arbitrary code; only load trusted files.
file_path = os.path.join(output_dir, "PerSessionUnitYield.pkl")
loaded_all_unit_counts = pd.read_pickle(file_path)


print("Recorded Areas:")
print(loaded_all_unit_counts.keys())

print("Number of units per recording, per area:")
print(loaded_all_unit_counts)

In [None]:
# Load the single session analyzed in the rest of the notebook, via the project cache.
session_id = 791319847
session = cache.get_session_data(session_id)

In [None]:
# Full filtered session table (the earlier cell only shows the last rows); useful for
# choosing a session_id. NOTE(review): consider .head()/.sample() for a shared notebook.
brain_observatory_type_sessions

In [None]:
# One-row table: each recorded structure becomes a column holding its unit count.
unit_count_column = session.structurewise_unit_counts.to_frame()
row_format = unit_count_column.transpose()

row_format

In [None]:
# Session-level metadata for the loaded session.
session.metadata


In [None]:
# Units per recorded structure, in column form (row_format shows the same counts as one row).
session.structurewise_unit_counts


In [None]:
presentations = session.get_stimulus_table("flashes")
units = session.units[session.units["ecephys_structure_acronym"] == 'VISp']

# Peristimulus window in seconds (-0.1 s to 0.5 s around onset) with 10 ms bins.
time_step = 0.01
window_start, window_stop = -0.1, 0.5
# np.arange with a float step can gain or lose its last edge to rounding error;
# linspace pins the number of edges exactly and always includes window_stop.
n_bins = int(round((window_stop - window_start) / time_step))
time_bins = np.linspace(window_start, window_stop, n_bins + 1)

histograms = session.presentationwise_spike_counts(
    stimulus_presentation_ids=presentations.index.values,
    bin_edges=time_bins,
    unit_ids=units.index.values,
)

histograms.coords

In [None]:
# Dimension sizes of the binned spike counts. The expected order (presentation, time bin,
# unit) comes from the original note below; confirm via histograms.dims / histograms.coords.
histograms.shape # trial, time, unit. use 'histograms.coords' to confirm

In [None]:
# Average over flash presentations -> one peristimulus time histogram per unit.
mean_histograms = histograms.mean(dim="stimulus_presentation_id")
# Put units on rows and time on columns by name rather than relying on .T and the
# DataArray's implicit dimension order.
psth = mean_histograms.transpose("unit_id", "time_relative_to_stimulus_onset")

fig, ax = plt.subplots(figsize=(8, 8))
mesh = ax.pcolormesh(
    psth["time_relative_to_stimulus_onset"],
    np.arange(psth["unit_id"].size),
    psth,
    vmin=0,
    vmax=1,
    # The coordinates have one entry per data column/row (not N+1 edges), so let
    # matplotlib infer cell edges instead of depending on the version's default shading.
    shading="auto",
)
fig.colorbar(mesh, ax=ax, label="mean spike count per bin")

ax.set_ylabel("unit", fontsize=24)
ax.set_xlabel("time relative to stimulus onset (s)", fontsize=24)
ax.set_title("peristimulus time histograms for VISp units on flash presentations", fontsize=24)

plt.show()

In [24]:
# Plain NumPy copy of the binned spike counts (same shape and dimension order as
# `histograms`) for code that does not need xarray labels. The leftover bare
# `type(histograms)` statement was removed: as a non-final expression it was never displayed.
new_hist = np.array(histograms)

In [None]:
# Should match histograms.shape: np.array keeps the DataArray's dimension order.
new_hist.shape

# Image Classification

In [None]:
# Natural-scene presentations and the VISp units whose spikes we will decode from.
scene_presentations = session.get_stimulus_table("natural_scenes")
visp_mask = session.units["ecephys_structure_acronym"] == "VISp"
visp_units = session.units.loc[visp_mask]

# One row per spike, tagged with its presentation, unit and time since presentation onset.
spikes = session.presentationwise_spike_times(
    stimulus_presentation_ids=scene_presentations.index.values,
    unit_ids=visp_units.index.values,
)

spikes

In [None]:
# Map (stimulus_presentation_id, unit_id) -> list of that unit's spike times relative to
# presentation onset. Only pairs with at least one spike appear as keys.
grouped = spikes.groupby(['stimulus_presentation_id', 'unit_id'])
times_per_pair = grouped['time_since_stimulus_presentation_onset'].apply(list)
spike_times_dict = dict(times_per_pair.items())
# Example lookup: see the next cell.

In [None]:
# Example lookup: spike times of the 4th VISp unit on the first natural-scene presentation.
trial_unit_key = (scene_presentations.index.values[0], visp_units.index.values[3])
# spike_times_dict only contains (presentation, unit) pairs that had at least one spike,
# so a silent unit must map to an empty list rather than raise KeyError.
spike_times = spike_times_dict.get(trial_unit_key, [])

print(spike_times)

In [None]:
# How many times each stimulus condition was presented.
condition_ids = scene_presentations['stimulus_condition_id'].to_numpy()
unique_elements, counts = np.unique(condition_ids, return_counts=True)

for summary in (unique_elements, counts):  # sorted condition ids, then their counts
    print(summary)

In [None]:
# Spikes per (presentation, unit), computed into a new frame instead of mutating `spikes`.
# The original cell overwrote `spikes`, so re-running it produced counts of 1 and broke
# the spike_times_dict cell.
spike_counts = (
    spikes.groupby(["stimulus_presentation_id", "unit_id"])
    .size()
    .rename("count")
    .reset_index()
)

# Presentations x units design matrix; pairs with no spikes are filled with 0.
# NOTE(review): presentations where no VISp unit spiked at all get no row here.
design = pd.pivot_table(
    spike_counts,
    values="count",
    index="stimulus_presentation_id",
    columns="unit_id",
    fill_value=0.0,
    aggfunc="sum",  # string alias; passing np.sum is deprecated in pandas 2.x
)

design

In [None]:
# Decoding target: the scene frame shown on each presentation, in design-matrix row order.
targets = scene_presentations["frame"].loc[design.index.values]
targets

In [15]:
from sklearn import svm
from sklearn.model_selection import KFold
from sklearn.metrics import confusion_matrix

In [16]:
# NumPy inputs for scikit-learn: float spike-count features and integer frame labels.
design_arr = design.to_numpy(dtype=float)
targets_arr = targets.to_numpy().astype(int)

# Sorted distinct frame labels; fixes the row/column order of every confusion matrix.
labels = np.unique(targets_arr)

In [None]:
# Number of natural-scene presentations, for comparison with design_arr.shape[0] below.
# The original used `presentations`, which holds the *flash* stimulus table and is
# unrelated to the decoding data.
scene_presentations.index.size

In [None]:
# (presentations kept in the design matrix, units); rows can be fewer than the number of
# natural-scene presentations if some presentations had no VISp spikes.
design_arr.shape

In [None]:
# Decode the natural-scene frame with a kernelized (RBF) support vector classifier,
# scored with 5-fold cross-validation (KFold without shuffling -> contiguous folds).
# Replaces a malformed header string literal ("""" ... """"") that left a stray quote
# in the string.

accuracies = []
confusions = []

for train_indices, test_indices in KFold(n_splits=5).split(design_arr):

    clf = svm.SVC(gamma="scale", kernel="rbf")
    clf.fit(design_arr[train_indices], targets_arr[train_indices])

    test_targets = targets_arr[test_indices]
    test_predictions = clf.predict(design_arr[test_indices])

    # Fraction of held-out presentations whose frame was predicted correctly.
    accuracy = np.mean(test_predictions == test_targets)
    print(accuracy)

    accuracies.append(accuracy)
    confusions.append(confusion_matrix(y_true=test_targets, y_pred=test_predictions, labels=labels))

# The logistic-regression cell reuses `accuracies` / `confusions`; keep the SVM results too.
svm_accuracies = list(accuracies)
svm_confusions = list(confusions)

In [None]:
# Decode the natural-scene frame with logistic regression, using the same 5-fold split as
# the SVM cell (KFold without shuffling is deterministic on the same data), so the two are
# comparable. Replaces a malformed header string literal ("""" ... """"").
# Note: this overwrites `accuracies` / `confusions`, so the summary cells report whichever
# classifier cell ran last.
from sklearn.linear_model import LogisticRegression

accuracies = []
confusions = []

for train_indices, test_indices in KFold(n_splits=5).split(design_arr):

    # max_iter raised above sklearn's default, presumably because raw (unscaled) spike
    # counts make the solver slow to converge.
    clf = LogisticRegression(max_iter=1000)
    clf.fit(design_arr[train_indices], targets_arr[train_indices])

    test_targets = targets_arr[test_indices]
    test_predictions = clf.predict(design_arr[test_indices])

    # Fraction of held-out presentations whose frame was predicted correctly.
    accuracy = np.mean(test_predictions == test_targets)
    print(accuracy)

    accuracies.append(accuracy)
    confusions.append(confusion_matrix(y_true=test_targets, y_pred=test_predictions, labels=labels))

# Keep the logistic-regression results under their own names as well.
logreg_accuracies = list(accuracies)
logreg_confusions = list(confusions)

In [None]:
# Cross-validated accuracy of the most recently run classifier vs. uniform guessing.
mean_accuracy = np.mean(accuracies)
chance_level = 1 / labels.size
print(f"mean accuracy: {mean_accuracy}")
print(f"chance: {chance_level}")

In [None]:
# Average of the per-fold confusion matrices. Rows are true frames and columns are predicted
# frames, both ordered as in `labels`. Entries are mean counts per fold, not rates.
mean_confusion = np.mean(confusions, axis=0)

fig, ax = plt.subplots(figsize=(8, 8))

img = ax.imshow(mean_confusion)
fig.colorbar(img, ax=ax, label="mean count per fold")

ax.set_ylabel("actual")
ax.set_xlabel("predicted")
ax.set_title("cross-validated confusion matrix (mean over folds)")

plt.show()

In [None]:
# Per-frame recall = correctly decoded presentations / all presentations of that frame.
# Normalizing the diagonal by the row sum (true-label totals) prevents frames that happened
# to appear more often in the test folds from looking more decodable.
hits = np.diag(mean_confusion)
shown = mean_confusion.sum(axis=1)
recall = np.divide(hits, shown, out=np.zeros_like(hits, dtype=float), where=shown > 0)

# NOTE(review): negative frame labels (e.g. -1) are assumed to mark the blank screen, which
# has no template image; only non-negative frames are eligible. Confirm against the table.
candidates = np.flatnonzero(labels >= 0)
if candidates.size == 0:
    raise ValueError("no decodable natural-scene frames with a template image")
best = labels[candidates[np.argmax(recall[candidates])]]
worst = labels[candidates[np.argmin(recall[candidates])]]

fig, ax = plt.subplots(1, 2, figsize=(16, 8))

best_image = cache.get_natural_scene_template(best)
ax[0].imshow(best_image, cmap=plt.cm.gray)
ax[0].set_title(f"most decodable (frame {best})", fontsize=24)

worst_image = cache.get_natural_scene_template(worst)
ax[1].imshow(worst_image, cmap=plt.cm.gray)
ax[1].set_title(f"least decodable (frame {worst})", fontsize=24)


plt.show()