# Imports and Data Loading

In [None]:
%reload_ext autoreload
%matplotlib inline 
import os
from datetime import datetime

import h5py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
import scipy.stats
from mpl_toolkits.mplot3d import Axes3D

In [None]:
# Load data and display the sessions
animal = "FN"  # "FN" or "WI"

# The data directory defaults to the original local path, but can be redirected
# with the BRAINHACK_DATA_DIR environment variable so the notebook can run on
# another machine without editing this cell.
data_dir = os.environ.get(
    "BRAINHACK_DATA_DIR", "/Users/withercp/Documents/dev/brainhack/data"
)
data_file = os.path.join(data_dir, f"TH_task_{animal}_singleprobe_500ms.h5")
assert os.path.exists(data_file), f"File not found: {data_file}"

# The top-level keys of the HDF5 file are used as session names by later cells.
with h5py.File(data_file, 'r') as f:
    sessions = list(f.keys())

# Show the available sessions as the cell output.
sessions

In [None]:
# Index into `sessions` selecting which session to load. It is no longer
# reused as a loop variable below, so it keeps identifying the loaded session.
i = 0

# Load data for a particular session and view the format of the data
session = sessions[i]
# convert mm-dd-yyyy to mm/dd/yyyy
date_str = session.replace('-', '/')

with h5py.File(data_file, 'r') as f:
    session_group = f[session]
    print(session_group.keys())
    # Copy each dataset into memory so it stays usable after the file closes.
    epoc_np = np.array(session_group['epoch'])
    rem_rec_np = np.array(session_group['rem_rec'])
    trial_np = np.array(session_group['trial'])
    binned_spike_np = np.array(session_group['binned_spike'])
    meta = np.array(session_group['meta'])

# 1 where the rem_rec entry equals b'Remote', 0 otherwise.
rem_rec_bin = np.array([1 if label == b'Remote' else 0 for label in rem_rec_np])

# Number consecutive runs of identical labels: the first entry is block 0 and
# the block number increases by one wherever a label differs from the one
# before it. Prepending the first element makes the first difference zero, and
# an empty input simply yields an empty result.
label_changed = np.diff(rem_rec_bin, prepend=rem_rec_bin[:1]) != 0
block_nums = np.cumsum(label_changed).astype(int)

# Index of the last block, kept under its original name.
block_num = int(block_nums[-1]) if block_nums.size else 0

unique_blocks = np.unique(block_nums)

In [None]:
# Scatter the mean of binned_spike_np over axis 0 against the axis-1 index.
# x and y stay module-level because the correlation cell further down reads them.
fig, ax = plt.subplots(figsize=(10, 8))

x = np.arange(binned_spike_np.shape[1])
y = np.mean(binned_spike_np, axis=0)
ax.scatter(x, y)

# Label the figure so it can be read on its own when the notebook is skimmed.
ax.set_xlabel("Index along axis 1 of binned_spike_np")
ax.set_ylabel("Mean of binned_spike_np over axis 0")
ax.set_title(f"Session {session}: mean of binned_spike_np over axis 0")
plt.show()

In [None]:
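# NOTE: this cell is disabled (every line is commented out). It reads `animal`
# and `date_str`, which are set in earlier cells.
# # Get correct and incorrect trials per screen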
# df = pd.read_csv(f"/Users/withercp/Documents/dev/brainhack/data/{animal}_TT_ALL_refined.csv")
# df = df[df['Date'] == date_str]
# print(df.head())

# screen_correct = np.zeros((4, df['TrialNum'].max()))
# screen_correct[0, :] = df["Correct"][::4]
# screen_correct[1, :] = df["Correct"][1::4]
# screen_correct[2, :] = df["Correct"][2::4]
# screen_correct[3, :] = df["Correct"][3::4]
# screen_correct_drift = screen_correct[:, :-1]

In [None]:
# Pearson correlation between the axis-1 index (x) and the axis-0 mean (y)
# computed in the plotting cell; the result is shown as the cell output.
# This relies on the explicit `import scipy.stats` in the imports cell, because
# a bare `import scipy` does not load the stats submodule on every SciPy version.
scipy.stats.pearsonr(x, y)

In [None]:
# Keep only the rows of binned_spike_np (one row per neuron, going by the
# original note on this cell) whose summed count reaches MIN_SPIKES_PER_BLOCK
# in at least one block, where a block is one value of block_nums.
# binned_spike_np is overwritten below, so cells above that read it (the
# scatter plot) will see the filtered rows if they are re-run afterwards.
MIN_SPIKES_PER_BLOCK = 500

keep_mask = np.zeros(binned_spike_np.shape[0], dtype=bool)
for block_id in unique_blocks:
    # Per-row total over the columns that belong to this block.
    block_totals = np.sum(binned_spike_np[:, block_nums == block_id], axis=1)
    keep_mask |= block_totals >= MIN_SPIKES_PER_BLOCK

# Sorted integer row indices. Building them from a boolean mask keeps the
# dtype integral even when no row qualifies, so the indexing below cannot fail
# on an empty selection.
rows_to_keep = np.flatnonzero(keep_mask)

binned_spike_np = binned_spike_np[rows_to_keep, :]

# Normalize binned spike data: z-score each remaining row along axis 1.
binned_spike_np_z = scipy.stats.zscore(binned_spike_np, axis=1)

# Classifier

In [None]:
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.linear_model import LogisticRegression

def logistic_regression_classifier(data, classification, test_size=0.2, random_state=42):
    """Fit a logistic regression on a train split and return held-out accuracy.

    Simple classifier to predict behavior (``classification``) from low
    dimensional data (``data``). It uses ``train_test_split``,
    ``LogisticRegression`` and ``accuracy_score``, imported from scikit-learn
    at the top of this cell.

    Parameters
    ----------
    data
        Feature matrix passed to ``train_test_split`` as the samples to split.
    classification
        Labels aligned with ``data``, split the same way.
    test_size
        Forwarded to ``train_test_split``. Defaults to 0.2, the value that was
        previously hard-coded.
    random_state
        Seed forwarded to ``train_test_split``. Defaults to 42, the value that
        was previously hard-coded, so existing calls give the same split.

    Returns
    -------
    The value of ``accuracy_score`` for the predictions on the test split.
    """
    # Split the data into training and testing sets
    X_train, X_test, y_train, y_test = train_test_split(
        data, classification, test_size=test_size, random_state=random_state
    )

    # Fit the model with scikit-learn's default settings
    model = LogisticRegression()
    model.fit(X_train, y_train)

    # Score the predictions on the held-out split
    y_pred = model.predict(X_test)
    return accuracy_score(y_test, y_pred)

# Tracking representations

In [None]:
# Preview the first 100 entries of the 'epoch' dataset loaded for this session.
epoc_np[0:100]