Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Logic for design matrix creation #6

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion constants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
# useful constants during analysis

# The twelve WCST stimulus features: four shapes, four colors, four patterns.
FEATURES = [
    'CIRCLE', 'SQUARE', 'STAR', 'TRIANGLE',
    'CYAN', 'GREEN', 'MAGENTA', 'YELLOW',
    'ESCHER', 'POLKADOT', 'RIPPLE', 'SWIRL'
]

# Number of recorded units in the session being analyzed.
NUM_UNITS = 59

# Design-matrix column names: features + feedback outcome, optionally
# followed by one spike-count column per unit.
COLUMN_NAMES_W_UNITS = FEATURES + ["CORRECT", "INCORRECT"] + [f"unit_{i}" for i in range(NUM_UNITS)]
COLUMN_NAMES = FEATURES + ["CORRECT", "INCORRECT"]
99 changes: 99 additions & 0 deletions create_design_matrix.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Notebook to create and store a design matrix of behavior and spikes "
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"from spike_tools import (\n",
" general as spike_general,\n",
" analysis as spike_analysis,\n",
")\n",
"import data_utils\n",
"from constants import FEATURES, COLUMN_NAMES_W_UNITS\n",
"\n",
"species = 'nhp'\n",
"subject = 'SA'\n",
"exp = 'WCST'\n",
"session = 20180802 # this is the session for which there are spikes at the moment. \n",
"\n",
"tau_pre = 20\n",
"tau_post = 0"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"spikes_by_bins = pd.read_pickle('/data/processed/sub-SA_sess-20180802_spike_counts_binsize_50.pickle')\n",
"beh_by_bins = pd.read_pickle('/data/processed/sub-SA_sess-20180802_behavior_binsize_50.pickle')\n",
"intervals = pd.read_pickle(\"/data/processed/sub-SA_sess-20180802_interval_1500_fb_1500_binsize_50.pickle\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"design_mat = data_utils.get_design_matrix(spikes_by_bins, beh_by_bins, COLUMN_NAMES_W_UNITS, tau_pre, tau_post)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"design_mat.to_pickle(\"/data/processed/sub-SA_sess-20180802_design_mat_taupre_20_taupost_0_binsize_50.pickle\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
128 changes: 128 additions & 0 deletions data_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
from constants import FEATURES
pqz317 marked this conversation as resolved.
Show resolved Hide resolved
import numpy as np
import pandas as pd

def get_behavior_by_bins(bin_size, beh):
    """
    Build a per-time-bin one-hot encoding of behavioral events.

    Args:
        bin_size: bin width in milliseconds
        beh: dataframe of behavioral data from the object-features csv

    Returns:
        dataframe with one float column per feature plus CORRECT/INCORRECT,
        one row per time bin, and a trailing bin_idx column
    """
    num_bins = int(np.max(beh["TrialEnd"].values) / bin_size) + 1
    col_names = FEATURES + ["CORRECT", "INCORRECT"]
    record_dtype = np.dtype([(name, "f4") for name in col_names])
    encoded = np.zeros(num_bins, dtype=record_dtype)

    for _, trial in beh.iterrows():
        chosen = int(trial["ItemChosen"])
        # The choice is marked a fixed 800 ms before feedback onset.
        selection_bin = int((trial["FeedbackOnset"] - 800) / bin_size)
        # One-hot the chosen item's color, shape, and pattern in that bin.
        for attribute in ("Color", "Shape", "Pattern"):
            encoded[selection_bin][trial[f"Item{chosen}{attribute}"]] = 1

        # Feedback outcome is one-hot encoded at feedback onset.
        outcome_bin = int(trial["FeedbackOnset"] / bin_size)
        outcome = "CORRECT" if trial["Response"] == "Correct" else "INCORRECT"
        encoded[outcome_bin][outcome] = 1

    result = pd.DataFrame(encoded)
    result["bin_idx"] = np.arange(len(result))
    return result


def get_spikes_by_bins(bin_size, spike_times):
    """Given a bin_size and a series of spike times, return spike counts by bin.

    Args:
        bin_size: size of bins in milliseconds
        spike_times: dataframe with UnitID, SpikeTime columns
    Returns:
        df with bin_idx, unit_* as columns, filled with spike counts per bin
    """
    units = np.unique(spike_times.UnitID.values)
    time_stamp_max = int(spike_times.SpikeTime.max()) + 1

    num_time_bins = int(time_stamp_max / bin_size) + 1
    # num_time_bins bins need num_time_bins + 1 edges. The previous edge
    # array stopped one edge short, so np.histogram silently dropped all
    # spikes after the last edge and the final bin row was discarded.
    bin_edges = np.arange(num_time_bins + 1) * bin_size

    df = pd.DataFrame(data={'bin_idx': np.arange(num_time_bins)})
    for unit in units:
        unit_spike_times = spike_times[spike_times.UnitID == unit].SpikeTime.values
        unit_spike_counts, _ = np.histogram(unit_spike_times, bins=bin_edges)
        df[f'unit_{unit}'] = unit_spike_counts
    return df

def get_trial_intervals(behavioral_data, event="FeedbackOnset", pre_interval=0, post_interval=0, bin_size=50):
    """Per trial, finds time interval surrounding some event in the behavioral data

    Args:
        behavioral_data: Dataframe describing each trial, must contain
            columns: TrialNumber, whatever 'event' param describes
        event: name of event to align around, must be present as a
            column name in behavioral_data Dataframe
        pre_interval: number of milliseconds before event
        post_interval: number of milliseconds after event
        bin_size: bin width in milliseconds used to convert interval
            times into bin indices

    Returns:
        DataFrame with num_trials length, columns: TrialNumber,
        IntervalStartTime, IntervalEndTime, IntervalStartBin, IntervalEndBin
    """
    trial_event_times = behavioral_data[["TrialNumber", event]]

    # Note: the previous version also built an unused np.empty array of the
    # same intervals; that dead computation has been removed.
    intervals_df = pd.DataFrame(columns=["TrialNumber", "IntervalStartTime", "IntervalEndTime"])
    intervals_df["TrialNumber"] = trial_event_times["TrialNumber"].astype(int)
    intervals_df["IntervalStartTime"] = trial_event_times[event] - pre_interval
    intervals_df["IntervalEndTime"] = trial_event_times[event] + post_interval
    # astype(int) truncates toward zero; callers should keep start times non-negative
    intervals_df["IntervalStartBin"] = (intervals_df["IntervalStartTime"] / bin_size).astype(int)
    intervals_df["IntervalEndBin"] = (intervals_df["IntervalEndTime"] / bin_size).astype(int)
    return intervals_df


def get_design_matrix(spikes_by_bins, beh_by_bins, columns, tau_pre, tau_post):
    """
    Reformats data as a design matrix dataframe: each requested column is
    expanded into one time-lagged copy per offset tau in [-tau_pre, tau_post).

    Args:
        spikes_by_bins: df with bin_idx, unit_* as columns
        beh_by_bins: df with bin_idx, behavioral vars of interest as columns
        columns: columns to include, must be present in either input df
        tau_pre: number of bins to look in the past
        tau_post: number of bins to look in the future
    Returns:
        df with bin_idx plus one f"{col}_{tau}" column per (col, tau) pair

    NOTE(review): the range excludes tau_post itself, so with tau_post=0 the
    current bin is never included — confirm this is the intended convention.
    """
    merged = pd.merge(spikes_by_bins, beh_by_bins, on="bin_idx", how="inner")
    design = pd.DataFrame()
    for tau in range(-tau_pre, tau_post):
        # A negative tau pulls values from -tau bins in the past.
        shifted = merged.shift(-tau)
        design[[f"{col}_{tau}" for col in columns]] = shifted[columns]
        design["bin_idx"] = merged["bin_idx"]
    return design


def get_interval_bins(intervals):
    """
    Collect every bin index covered by any interval.

    Args:
        intervals: df with TrialNumber, IntervalStartBin, IntervalEndBin columns
    Returns:
        np array of all bin indices in [IntervalStartBin, IntervalEndBin)
        for every row, concatenated in row order (end bin exclusive)
    """
    per_row_bins = intervals.apply(
        lambda row: np.arange(row.IntervalStartBin, row.IntervalEndBin).astype(int),
        axis=1,
    )
    return np.concatenate(per_row_bins.to_numpy())
91 changes: 14 additions & 77 deletions format_beh.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,20 @@
"cells": [
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"from spike_tools import (\n",
" general as spike_general,\n",
" analysis as spike_analysis,\n",
")\n",
"import data_utils\n",
"from constants import FEATURES\n",
"\n",
"species = 'nhp'\n",
Expand All @@ -22,7 +26,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -33,60 +37,20 @@
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {},
"outputs": [],
"source": [
"def get_X_by_bins(bin_size, data):\n",
" max_time = np.max(valid_beh[\"TrialEnd\"].values)\n",
" max_bin_idx = int(max_time / bin_size) + 1\n",
" columns = FEATURES + [\"CORRECT\", \"INCORRECT\"]\n",
" types = [\"f4\" for _ in columns]\n",
" zipped = list(zip(columns, types))\n",
" dtype = np.dtype(zipped)\n",
" arr = np.zeros((max_bin_idx), dtype=dtype)\n",
"\n",
" for _, row in data.iterrows():\n",
" # grab features of item chosen\n",
" item_chosen = int(row[\"ItemChosen\"])\n",
" color = row[f\"Item{item_chosen}Color\"]\n",
" shape = row[f\"Item{item_chosen}Shape\"]\n",
" pattern = row[f\"Item{item_chosen}Pattern\"]\n",
"\n",
" chosen_time = row[\"FeedbackOnset\"] - 800\n",
" chosen_bin = int(chosen_time / bin_size)\n",
" arr[chosen_bin][color] = 1\n",
" arr[chosen_bin][shape] = 1\n",
" arr[chosen_bin][pattern] = 1\n",
"\n",
" feedback_bin = int(row[\"FeedbackOnset\"] / bin_size)\n",
" # print(feedback_bin)\n",
" if row[\"Response\"] == \"Correct\":\n",
" arr[feedback_bin][\"CORRECT\"] = 1\n",
" else:\n",
" arr[feedback_bin][\"INCORRECT\"] = 1\n",
" df = pd.DataFrame(arr)\n",
" df[\"bin_idx\"] = np.arange(len(df))\n",
" return df\n",
" \n"
]
},
{
"cell_type": "code",
"execution_count": 75,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"res = get_X_by_bins(50, valid_beh)"
"behavior_by_bins = data_utils.get_behavior_by_bins(50, valid_beh)"
]
},
{
"cell_type": "code",
"execution_count": 76,
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"res.to_pickle('/data/processed/sub-SA_sess-20180802_behavior_binsize_50.pickle')"
"behavior_by_bins.to_pickle('/data/processed/sub-SA_sess-20180802_behavior_binsize_50.pickle')"
]
},
{
Expand All @@ -99,47 +63,20 @@
},
{
"cell_type": "code",
"execution_count": 79,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"def get_trial_intervals(behavioral_data, event=\"FeedbackOnset\", pre_interval=0, post_interval=0, bin_size=50):\n",
" \"\"\"Per trial, finds time interval surrounding some event in the behavioral data\n",
"\n",
" Args:\n",
" behavioral_data: Dataframe describing each trial, must contain\n",
" columns: TrialNumber, whatever 'event' param describes\n",
" event: name of event to align around, must be present as a\n",
" column name in behavioral_data Dataframe\n",
" pre_interval: number of miliseconds before event\n",
" post_interval: number of miliseconds after event\n",
"\n",
" Returns:\n",
" DataFrame with num_trials length, columns: TrialNumber,\n",
" IntervalStartTime, IntervalEndTime\n",
" \"\"\"\n",
" trial_event_times = behavioral_data[[\"TrialNumber\", event]]\n",
"\n",
" intervals = np.empty((len(trial_event_times), 3))\n",
" intervals[:, 0] = trial_event_times[\"TrialNumber\"]\n",
" intervals[:, 1] = trial_event_times[event] - pre_interval\n",
" intervals[:, 2] = trial_event_times[event] + post_interval\n",
" intervals_df = pd.DataFrame(columns=[\"TrialNumber\", \"IntervalStartTime\", \"IntervalEndTime\"])\n",
" intervals_df[\"TrialNumber\"] = trial_event_times[\"TrialNumber\"].astype(int)\n",
" intervals_df[\"IntervalStartTime\"] = trial_event_times[event] - pre_interval\n",
" intervals_df[\"IntervalEndTime\"] = trial_event_times[event] + post_interval\n",
" intervals_df[\"IntervalStartBin\"] = (intervals_df[\"IntervalStartTime\"] / bin_size).astype(int)\n",
" intervals_df[\"IntervalEndBin\"] = (intervals_df[\"IntervalEndTime\"] / bin_size).astype(int)\n",
" return intervals_df\n"
"intervals = data_utils.get_trial_intervals(valid_beh, pre_interval=1500, post_interval=1500, bin_size=50)"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 82,
"metadata": {},
"outputs": [],
"source": [
"get_trial_intervals(valid_beh, pre_interval=1500, post_interval=1500, bin_size=50)"
"intervals.to_pickle(\"/data/processed/sub-SA_sess-20180802_interval_1500_fb_1500_binsize_50.pickle\")"
]
}
],
Expand Down