In [1]:
import numpy as np
import scipy.io as sio
import os
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from scipy.stats import ttest_rel, f_oneway

In [3]:
# ---- Session selection -------------------------------------------------------
# These three values determine which .mat file is opened and which task
# sub-struct is read from it.
subject = "zarya"
date = "20250411"
task = "tuning"

# NOTE(review): absolute, machine-specific path; consider moving it to a
# configurable data directory so the notebook runs on other machines.
base_path = "/Users/j1155665/Desktop/Chris lab/Data Analysis"
file_name = f"@neural_{subject}/{subject}{date}dots3DMP.mat"
matfile_path = os.path.join(base_path, file_name)

# ---- Load the MATLAB struct and unpack the fields used downstream --------------
# On success this cell defines: data, data_spkrate, events, unit, depth,
# cluster_id, modality, headingInd, coherence (all tasks) and, for non-tuning
# tasks only, stimOn_spkrate, saccOnset_spkrate, postTargHold_spkrate, choice
# and PDW.  On failure a message is printed and those names are left undefined.
try:
    mat_data = sio.loadmat(matfile_path)

    if 'dataStruct' in mat_data and 'data' in mat_data['dataStruct'].dtype.names:
        # loadmat wraps MATLAB structs in 2-D arrays, hence the [0, 0] unwrapping.
        dataStruct = mat_data['dataStruct'][0, 0]
        data_field = dataStruct['data'][0, 0]

        # Select the task struct FIRST: every later lookup reads from `data`.
        # (Previously the non-tuning branch used `data` before assigning it.)
        if task == "tuning":
            task_field = 'dots3DMPtuning'
        else:
            # assumes the main-task struct is stored under 'dots3DMP' -- TODO confirm
            task_field = 'dots3DMP'

        available_fields = data_field.dtype.names
        if task_field not in available_fields:
            raise ValueError(
                f"field '{task_field}' not found in dataStruct.data; "
                f"available fields: {available_fields}"
            )
        data = data_field[task_field][0, 0]

        data_spkrate = data['data_spkrate']
        events = data['events'][0, 0]
        unit = data['unit'][0, 0]

        if task != "tuning":
            # Event-aligned spike rates plus the behavioural reports that only
            # the non-tuning task provides.
            stimOn_spkrate = data_spkrate['stimOn']
            saccOnset_spkrate = data_spkrate['saccOnset']
            postTargHold_spkrate = data_spkrate['postTargHold']
            choice = events['choice'].flatten().reshape(-1, 1)
            PDW = events['PDW'].flatten().reshape(-1, 1)

        # Per-unit and per-trial metadata, each reshaped to a column vector.
        depth = unit['depth'].flatten().reshape(-1, 1)
        cluster_id = unit['cluster_id'].flatten().reshape(-1, 1)
        modality = events['modality'].flatten().reshape(-1, 1)
        headingInd = events['headingInd'].flatten().reshape(-1, 1)
        coherence = events['coherenceInd'].flatten().reshape(-1, 1)

        print("Data loaded successfully.")
    else:
        print("Error: 'dataStruct.data' not found in the .mat file.")

except FileNotFoundError:
    print(f"Error: The file {matfile_path} was not found.")
except Exception as e:
    # Best-effort reporting: any other problem (missing field, malformed
    # struct, ...) is printed rather than raised, as in the original cell.
    print(f"An error occurred while loading the file: {e}")

Data loaded successfully.


In [4]:
# ---- Pack spike rates into a dense (trials, units, time) array -----------------
# `data_spkrate` is indexed as [trial, unit]; each cell is read as an array
# whose second dimension is time (see `.shape[1]` below).
X = np.copy(data_spkrate)
n_trials = len(X)
n_units = len(X[0])
# NOTE(review): the bin count is taken from trial index 1 (not 0), so this
# requires at least two trials -- confirm that choice is intentional.
n_timepoints = X[1, 0].shape[1]

X_array = np.empty((n_trials, n_units, n_timepoints))

for i in range(n_trials):
    for j in range(n_units):
        X_array[i, j, :] = X[i, j]

######## Task Info  ########
if task == "tuning":
    event_info = {
        # offset / bin_size / center_* are in SECONDS: the time axis below is
        # multiplied by 1000 to obtain milliseconds.
        "offset": 0.05,  # s
        "bin_size": 0.02,  # s
        "num_time_points": n_timepoints,  # bins
        "align_event": "stimOn",
        "center_start": -0.5,
        "center_stop": 0.5,
        "heading_angle": [-45, -21.5, -10, -3.9, 3.9, 10, 21.5, 45]
#         "heading_angle":[-90, -45, -22.5, -10, -3.9, -1.5, 0, 1.5, 3.9, 10, 22.5, 45, 90]
    }
    # Build the axis as start + bin * k for k = 0..n-1.  np.arange with a
    # floating-point step and a computed stop can return one element too many,
    # which would make masks derived from this axis mismatch X_array's time
    # dimension; this form always yields exactly `num_time_points` samples.
    time_axis_heading = (
        event_info["center_start"]
        + event_info["bin_size"] * np.arange(event_info["num_time_points"])
    ) * 1000  # Convert to ms
    print("loading tuning event info")
else:
    # NOTE(review): this branch does not define `time_axis_heading`.
    event_info = {
        ### trial conditions
        "modality": ["vestibular", "visual", "combined"],
        "coherence": [0.2, 0.7],
        "heading_angle": [-10, -3.9, -1.5, 0, 1.5, 3.9, 10],
        "stimulus": ["left", "right"],
        "choice": ["left", "right"],
        "confidence": ["low bet", "high bet"],
        "name": ["modality", "coherenceInd", "headingInd", "choice", "PDW"],
        "class_1": [None, None, [1, 2, 3], 1, 0],
        "class_2": [None, None, [5, 6, 7], 2, 1],
        ### time info
        # presumably seconds, like the tuning branch -- confirm
        "offset": 0.05,
        "bin_size": 0.02,
        "align_event": ["stimOn", "saccOnset", "postTargHold"],
        "center_start": [-0.1, -0.8, -0.3],
        "center_stop": [0.8, 0.1, 0.1],
        ### for plotting
        "event_names": ['stimOn', 'Choice', 'PDW'],
        "spkrates": [stimOn_spkrate, saccOnset_spkrate, postTargHold_spkrate],
        "mods": [1, 2, 2, 3, 3],
        "cohs": [1, 1, 2, 1, 2],
        "mod_names": [
            "vestibular",
            "visual, low coh",
            "visual, high coh",
            "combine, low coh",
            "combine, high coh"
        ]
    }
    print("loading dots3DMP event info")

######## Time Info ########
# Placeholder: currently empty.
time_info = {

}

######## Recording Info ########
# Depth bands and matching labels for specific recording dates.
# NOTE(review): for any other `date` (including the one configured above),
# `depth_ranges` and `depth_labels` are not defined by this cell.

if date == "20250501":
    depth_ranges = [
        (0, 900),          # below MST
        (1000, 3080),      # MST
        (3900, 6460)       # VPS
    ]

    depth_labels = [
        "0–900 μm (below MST)",
        "1000–3080 μm (MST)",
        "3900–6460 μm (VPS)"
    ]

elif date == "20250306":
    depth_ranges = [
        (0, 1300),             # below MST
        (1300, 3180),          # MST
        (3900, 7460),          # VPS
        (7660, np.max(depth))  # above VPS
    ]

    depth_labels = [
        "0–1300 μm (below MST)",
        "1300–3180 μm (MST)",
        "3900–7460 μm (VPS)",
        "7660 - μm (above VPS)"
    ]

loading tuning event info


In [5]:
# ---- Jittered depths -----------------------------------------------------------
# Units that share an identical depth value are spread evenly over [-10, +10]
# around that value so they no longer coincide.
depth_jittered = depth.flatten().astype(float).copy()

unique_depths, counts = np.unique(depth_jittered, return_counts=True)
for d in unique_depths[counts > 1]:
    idxs = np.where(depth_jittered == d)[0]
    jitter = np.linspace(-10, 10, len(idxs))
    depth_jittered[idxs] += jitter

# ---- Analysis windows (ms) -----------------------------------------------------
time_axis_ms = time_axis_heading
baseline_mask = time_axis_ms < 0
stim_window_mask = (time_axis_ms >= 0) & (time_axis_ms <= 2300)

ALPHA = 0.05  # significance threshold applied to every per-bin test below

# One entry per modality (1, 2, 3); each entry holds, per unit, the number of
# stimulus-window bins with p < ALPHA.
responsive_units = []
selective_units = []

for mod in range(3):
    # Trials of this modality, excluding trials where unit 1 / time bin 1 is NaN.
    # NOTE(review): only that single entry is checked for NaN -- confirm that it
    # is a sufficient marker for an invalid trial.
    selected_trials = (modality.ravel() == mod + 1) & ~np.isnan(X_array[:, 1, 1])
    valid_trials = selected_trials
    X_plot = X_array[valid_trials.flatten(), :]
    modheadingInd = headingInd[valid_trials.flatten()]
    modheadingInd = modheadingInd.flatten()

    # Per-trial, per-unit baseline statistics (time bins before 0 ms).
    baseline_mean = np.nanmean(X_plot[:, :, baseline_mask], axis=2)
    baseline_std = np.nanstd(X_plot[:, :, baseline_mask], axis=2)
    # Nudge exactly-zero baseline means off zero.  This also feeds the paired
    # t-test below, so it is kept to leave the statistics unchanged.
    baseline_mean[baseline_mean == 0] += 1e-10

    stim_fr = X_plot[:, :, stim_window_mask]
    n_stim_units, n_stim_bins = stim_fr.shape[1], stim_fr.shape[2]

    # Responsiveness: paired t-test of each stimulus bin against the same
    # trial's baseline mean, per unit and per bin.
    p_values = np.full((n_stim_units, n_stim_bins), np.nan)

    for u in range(n_stim_units):
        for t in range(n_stim_bins):
            p_value = ttest_rel(stim_fr[:, u, t], baseline_mean[:, u]).pvalue
            # NaN and exactly-zero p-values fail this comparison and stay NaN,
            # so they are never counted as significant.
            if p_value > 0:
                p_values[u, t] = p_value

    responsive_unit = np.sum(p_values < ALPHA, axis=1)
    responsive_units.append(responsive_unit)

    # Selectivity: one-way ANOVA across heading groups, per unit and per bin.
    anova_pvals = np.full((n_stim_units, n_stim_bins), np.nan)

    # The trial masks per heading do not depend on unit or bin: build them once.
    heading_masks = [modheadingInd == h for h in np.unique(modheadingInd)]

    for u in range(n_stim_units):
        for t in range(n_stim_bins):
            groups = [stim_fr[mask, u, t] for mask in heading_masks]
            f_stat, p_val = f_oneway(*groups)
            # Same filter as above: NaN / zero p-values are left as NaN.
            if p_val > 0:
                anova_pvals[u, t] = p_val

    selective_unit = np.sum(anova_pvals < ALPHA, axis=1)
    selective_units.append(selective_unit)

# Name the outputs after the configured session date instead of a hard-coded
# one, so changing `date` cannot overwrite another session's results.
np.save(f"{date}_responsive_unit.npy", responsive_units)
np.save(f"{date}_selective_unit.npy", selective_units)

  p_value = ttest_rel(stim_fr[:, u, t], baseline_mean[:, u]).pvalue


[nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan]
