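## Load cleaned spike data
- Bins each recording's spikes around stimulus onsets (mostly LGN units) and pickles the per-recording tensors for each stimulus.
- stimuli: color_exchange, luminance_flash, drifting_gratings, and chromatic_gratings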

@emilyekstrum
<br> 11/17/25

In [1]:
import cebra
import itertools
import os
import torch
import matplotlib
import random

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pickle as pkl

from cebra import CEBRA
from cebra.data.helper import OrthogonalProcrustesAlignment
from glob import glob
from pathlib import Path
from dlab.psth_and_raster import trial_by_trial
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from umap import UMAP
from sklearn.manifold import Isomap
from sklearn.manifold import LocallyLinearEmbedding

# Base matplotlib look, then the compact "paper" preset layered on top
plt.style.use(['default', 'seaborn-v0_8-paper'])

# Embed TrueType (Type 42) fonts so text in exported PDF/PS figures stays editable
matplotlib.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Shared data locations on the Z: drive; the loaders below find files in them by
# mouse-ID prefix. Raw strings keep the Windows backslashes literal.
units_dir, stim_dir = (
    r'Z:\color_representation\units',  # unit tables (quality label, region, spike times)
    r'Z:\color_representation\stim',   # stimulus tables (stimulus name, start_time, ...)
)

## LGN

In [None]:
# color exchange data (LGN units)
# For every recording that has the stimulus, this appends (index-aligned):
#   datas[i]   : tensor (n_trials*n_bins, n_units), one column of flattened trial_by_trial output per unit
#   labels0[i] : tensor (n_trials*n_bins, 4), columns [M, S, L, C]
#   labels1[i] : tensor (n_trials*n_bins, 4), columns [dM, dS, dL, dC] (change into each trial)
# Relies on `units_dir` / `stim_dir` from the config cell.
recordings = ['d4','d5','d6','C153','C155','C159','C160','C161']
# recordings = ['C153','C155','C159','C160','C161']

stim       = 'color_exchange'
#stim      = 'luminance_flash'
#stim      = 'drifting_gratings'

bin_     = 0.010
pre      = 0.100
post     = 0.400
window   = pre+post
n_bins   = int(window/bin_)

max_trials = 680   # recordings with more trials are randomly subsampled to this many
seed       = 0     # makes the trial subsampling reproducible across re-runs

datas      = []
labels0    = []
labels1    = []
all_probes = []
all_ids    = []
loaded_recordings = []  # mice that actually contributed an entry to the lists above

rng = random.Random(seed)

for mouse in recordings:
    print(f'Mouse {mouse}')

    # Load the '*updated*' stimulus table (sorted: glob order is filesystem-dependent)
    stim_files = sorted(glob(os.path.join(stim_dir, f'{mouse}*updated*')))
    if not stim_files:
        print('No stim file found')
        continue
    stim_df = pd.read_json(stim_files[0])

    # Check for stimulus before loading in more data
    if stim not in stim_df.stimulus.unique():
        print(f'No {stim}')
        continue
    stim_df = stim_df.loc[stim_df.stimulus == stim].reset_index(drop=True)

    print('Loading Data....')
    # Load in unit data
    unit_files = sorted(glob(os.path.join(units_dir, f'{mouse}*')))
    if not unit_files:
        print('No units file found')
        continue
    units_df   = pd.read_json(unit_files[0])
    units_good = units_df.loc[units_df.qmLabel.isin(['GOOD', 'NON-SOMA GOOD'])]
    # units_good  = units_good.loc[(units_good.region.str.contains('Primary visual'))|(units_good.region.str.contains('lateral geniculate'))]
    units_good = units_good.loc[units_good.region.str.contains('lateral geniculate', na=False)].reset_index(drop=True)
    units_df   = None  # drop the reference to the full unit table

    if len(units_good) < 1:
        print('No good units after filtering')
        continue

    # Wrangle stimulus-relevant data
    green       = stim_df.green.values.astype(float)  # M-opsin channel
    uv          = stim_df.uv.values.astype(float)     # S-opsin channel
    stim_times  = stim_df.loc[:, 'start_time'].values
    stim_length = stim_times[-1] - stim_times[0]

    if len(stim_df) > max_trials:
        # Sample from trials 1..N-1, presumably so every kept trial has a predecessor.
        # np.diff(x)[i] = x[i+1] - x[i] is the change INTO trial i+1, which lines up
        # with stim_times[1:][i].
        trial_idx  = sorted(rng.sample(range(len(stim_df) - 1), max_trials))
        stim_times = stim_times[1:][trial_idx]
        M  = green[1:][trial_idx]
        S  = uv[1:][trial_idx]
        dM = np.diff(green)[trial_idx] / 2
        dS = np.diff(uv)[trial_idx] / 2
    else:
        M  = green
        S  = uv
        # Prepend 0 so the first trial's change is measured from 0
        dM = np.diff(np.insert(green, 0, 0)) / 2
        dS = np.diff(np.insert(uv, 0, 0)) / 2

    # One label value per time bin of each trial
    M, S, dM, dS = (np.repeat(x, n_bins) for x in (M, S, dM, dS))
    L,  C  = (M + S) / 2,   (M - S) / 2    # luminance (additive) / contrast (subtractive) opsin signal
    dL, dC = (dM + dS) / 2, (dM - dS) / 2  # same combination of the trial-to-trial changes

    # Organize spike data (y)
    probes   = []
    cIDs     = []
    rec_data = []
    for _, row in units_good.iterrows():
        spike_times = np.array(row.times)
        stim_spikes = spike_times[(spike_times > stim_times[0]) & (spike_times < stim_times[-1]+post)]

        # Firing-rate threshold over the stimulus block (0.5 per time unit of start_time)
        if len(stim_spikes)/stim_length < 0.5:
            continue

        psth, var, edges, bytrial = trial_by_trial(stim_spikes, stim_times, pre, post, bin_)
        rec_data.append(bytrial.ravel())
        probes.append(row.probe)
        cIDs.append(row.cluster_id)

    if not rec_data:
        print('No units passed firing-rate threshold')
        continue

    # Labels are appended only now so labels0/labels1 stay index-aligned with datas
    labels0.append(torch.from_numpy(np.array([M, S, L, C]).T))
    labels1.append(torch.from_numpy(np.array([dM, dS, dL, dC]).T))

    data_tensor = torch.from_numpy(np.array(rec_data).T)
    datas.append(data_tensor)
    all_probes.append(np.array(probes))
    all_ids.append(np.array(cIDs))
    loaded_recordings.append(mouse)

    print(data_tensor.shape)

Mouse d4
Loading Data....
torch.Size([34000, 44])
Mouse d5
Loading Data....
torch.Size([34000, 90])
Mouse d6
Loading Data....
torch.Size([34000, 90])
Mouse C153
Loading Data....
torch.Size([34000, 150])
Mouse C155
Loading Data....
torch.Size([34000, 15])
Mouse C159
Loading Data....
torch.Size([34000, 150])
Mouse C160
Loading Data....
torch.Size([34000, 55])
Mouse C161
Loading Data....
torch.Size([34000, 91])


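## Load data for the other stimuli
The first cell below loads chromatic gratings from both V1 and LGN units. The cells after it load each stimulus from LGN units only and save it.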

In [None]:

# Chromatic-grating responses of V1 + LGN units (no stimulus labels are built here).
stim       = 'chromatic_gratings'
recordings = ['d4','d5','d6','C153','C155','C159','C160','C161']

units_dir = r'Z:\color_representation\units'
stim_dir  = r'Z:\color_representation\stim'
res_dir   = 'G:\\'  # not used

bin_     = 0.010
pre      = 0.500
post     = 1.500
window   = pre + post
n_bins   = int(window / bin_)

max_trials = 680   # recordings with more trials are randomly subsampled to this many
seed       = 0     # makes the trial subsampling reproducible across re-runs

datas      = []  # list[torch.Tensor], each (n_trials*n_bins, n_units)
all_probes = []  # list[np.ndarray] per recording
all_ids    = []  # list[np.ndarray] per recording
loaded_recordings = []  # mice that actually contributed an entry to the lists above

rng = random.Random(seed)

for mouse in recordings:
    print(f'Mouse {mouse}')

    # Load stimulus dataframe (keep original file pattern). glob order is
    # filesystem-dependent, so sort to make the pick deterministic.
    # NOTE(review): with '<mouse>*' the sorted first match is '<mouse>_stim.json',
    # not the '*updated*' table the other cells read -- confirm which is intended.
    stim_matches = sorted(glob(os.path.join(stim_dir, f'{mouse}*')))
    if not stim_matches:
        print('No stim file found')
        continue
    stim_df = pd.read_json(stim_matches[0])

    # Ensure target stimulus present
    if stim not in stim_df.stimulus.unique():
        print(f'No {stim}')
        continue

    # Keep only this stimulus & reset index
    stim_df = stim_df.loc[stim_df.stimulus == stim].reset_index(drop=True)

    print('Loading Data....')
    # Load units
    unit_matches = sorted(glob(os.path.join(units_dir, f'{mouse}*')))
    if not unit_matches:
        print('No units file found')
        continue
    units_df = pd.read_json(unit_matches[0])

    # Good units filter (Primary visual OR LGN)
    units_good = units_df.loc[
        (units_df.qmLabel.isin(['GOOD', 'NON-SOMA GOOD'])) &
        (
            units_df.region.str.contains('Primary visual', case=False, na=False) |
            units_df.region.str.contains('lateral geniculate', case=False, na=False)
        )
    ].reset_index(drop=True)
    units_df = None

    if len(units_good) < 1:
        print('No good units after filtering')
        continue

    # Stim timing; subsample trials 1..N-1 (same scheme as the color-exchange cells)
    stim_times = stim_df.loc[:, 'start_time'].values
    if len(stim_df) > max_trials:
        trial_idx  = sorted(rng.sample(range(len(stim_df) - 1), max_trials))
        stim_times = stim_times[1:][trial_idx]

    # Compute stimulus span for firing-rate heuristic
    stim_length = stim_times[-1] - stim_times[0]

    # Build datas (flattened bytrial per unit -> columns)
    probes   = []
    cIDs     = []
    rec_data = []
    for _, row in units_good.iterrows():
        spike_times = np.array(row.times)
        stim_spikes = spike_times[
            (spike_times > stim_times[0]) & (spike_times < stim_times[-1] + post)
        ]

        if len(stim_spikes) / stim_length < 0.5:
            continue

        psth, var, edges, bytrial = trial_by_trial(
            stim_spikes, stim_times, pre, post, bin_
        )

        # Flatten trial_by_trial output into one column for this unit
        rec_data.append(bytrial.ravel())
        probes.append(row.probe)
        cIDs.append(row.cluster_id)

    if not rec_data:
        print('No units passed firing-rate threshold')
        continue

    # Shape: (n_trials*n_bins, n_units)
    data_tensor = torch.from_numpy(np.array(rec_data).T)
    datas.append(data_tensor)
    all_probes.append(np.array(probes))
    all_ids.append(np.array(cIDs))
    loaded_recordings.append(mouse)

    print(f'{mouse}: {data_tensor.shape} (time×units)')

# datas      -> list[torch.Tensor], each (n_trials*n_bins, n_units)
# all_probes -> list[np.ndarray]
# all_ids    -> list[np.ndarray]


In [11]:
#color exchange data (LGN units)
# For every recording that has the stimulus, this appends (index-aligned):
#   datas[i]   : tensor (n_trials*n_bins, n_units), one column of flattened trial_by_trial output per unit
#   labels0[i] : tensor (n_trials*n_bins, 4), columns [M, S, L, C]
#   labels1[i] : tensor (n_trials*n_bins, 4), columns [dM, dS, dL, dC] (change into each trial)
units_dir = r'Z:\color_representation\units'
stim_dir  = r'Z:\color_representation\stim'

recordings = ['d4','d5','d6','C153','C155','C159','C160','C161']
# recordings = ['C153','C155','C159','C160','C161']

stim       = 'color_exchange'
#stim      = 'chromatic_gratings'
#stim      = 'luminance_flash'
#stim      = 'drifting_gratings'

bin_     = 0.010
pre      = 0.100
post     = 0.400
window   = pre+post
n_bins   = int(window/bin_)

max_trials = 680   # recordings with more trials are randomly subsampled to this many
seed       = 0     # makes the trial subsampling reproducible across re-runs

datas      = []
labels0    = []
labels1    = []
all_probes = []
all_ids    = []
loaded_recordings = []  # mice that actually contributed an entry to the lists above

rng = random.Random(seed)

for mouse in recordings:
    print(f'Mouse {mouse}')

    # Load the '*updated*' stimulus table (sorted: glob order is filesystem-dependent)
    stim_files = sorted(glob(os.path.join(stim_dir, f'{mouse}*updated*')))
    if not stim_files:
        print('No stim file found')
        continue
    stim_df = pd.read_json(stim_files[0])

    # Check for stimulus before loading in more data
    if stim not in stim_df.stimulus.unique():
        print(f'No {stim}')
        continue
    stim_df = stim_df.loc[stim_df.stimulus == stim].reset_index(drop=True)

    print('Loading Data....')
    # Load in unit data
    unit_files = sorted(glob(os.path.join(units_dir, f'{mouse}*')))
    if not unit_files:
        print('No units file found')
        continue
    units_df   = pd.read_json(unit_files[0])
    units_good = units_df.loc[units_df.qmLabel.isin(['GOOD', 'NON-SOMA GOOD'])]
    # units_good  = units_good.loc[(units_good.region.str.contains('Primary visual'))|(units_good.region.str.contains('lateral geniculate'))]
    units_good = units_good.loc[units_good.region.str.contains('lateral geniculate', na=False)].reset_index(drop=True)
    units_df   = None  # drop the reference to the full unit table

    if len(units_good) < 1:
        print('No good units after filtering')
        continue

    # Wrangle stimulus-relevant data
    green       = stim_df.green.values.astype(float)  # M-opsin channel
    uv          = stim_df.uv.values.astype(float)     # S-opsin channel
    stim_times  = stim_df.loc[:, 'start_time'].values
    stim_length = stim_times[-1] - stim_times[0]

    if len(stim_df) > max_trials:
        # Sample from trials 1..N-1, presumably so every kept trial has a predecessor.
        # np.diff(x)[i] = x[i+1] - x[i] is the change INTO trial i+1, which lines up
        # with stim_times[1:][i].
        trial_idx  = sorted(rng.sample(range(len(stim_df) - 1), max_trials))
        stim_times = stim_times[1:][trial_idx]
        M  = green[1:][trial_idx]
        S  = uv[1:][trial_idx]
        dM = np.diff(green)[trial_idx] / 2
        dS = np.diff(uv)[trial_idx] / 2
    else:
        M  = green
        S  = uv
        # Prepend 0 so the first trial's change is measured from 0
        dM = np.diff(np.insert(green, 0, 0)) / 2
        dS = np.diff(np.insert(uv, 0, 0)) / 2

    # One label value per time bin of each trial
    M, S, dM, dS = (np.repeat(x, n_bins) for x in (M, S, dM, dS))
    L,  C  = (M + S) / 2,   (M - S) / 2    # luminance (additive) / contrast (subtractive) opsin signal
    dL, dC = (dM + dS) / 2, (dM - dS) / 2  # same combination of the trial-to-trial changes

    # Organize spike data (y)
    probes   = []
    cIDs     = []
    rec_data = []
    for _, row in units_good.iterrows():
        spike_times = np.array(row.times)
        stim_spikes = spike_times[(spike_times > stim_times[0]) & (spike_times < stim_times[-1]+post)]

        # Firing-rate threshold over the stimulus block (0.5 per time unit of start_time)
        if len(stim_spikes)/stim_length < 0.5:
            continue

        psth, var, edges, bytrial = trial_by_trial(stim_spikes, stim_times, pre, post, bin_)
        rec_data.append(bytrial.ravel())
        probes.append(row.probe)
        cIDs.append(row.cluster_id)

    if not rec_data:
        print('No units passed firing-rate threshold')
        continue

    # Labels are appended only now so labels0/labels1 stay index-aligned with datas
    labels0.append(torch.from_numpy(np.array([M, S, L, C]).T))
    labels1.append(torch.from_numpy(np.array([dM, dS, dL, dC]).T))

    data_tensor = torch.from_numpy(np.array(rec_data).T)
    datas.append(data_tensor)
    all_probes.append(np.array(probes))
    all_ids.append(np.array(cIDs))
    loaded_recordings.append(mouse)

    print(data_tensor.shape)

Mouse d4
Loading Data....
torch.Size([34000, 44])
Mouse d5
Loading Data....
torch.Size([34000, 90])
Mouse d6
Loading Data....
torch.Size([34000, 90])
Mouse C153
Loading Data....
torch.Size([34000, 150])
Mouse C155
Loading Data....
torch.Size([34000, 15])
Mouse C159
Loading Data....
torch.Size([34000, 150])
Mouse C160
Loading Data....
torch.Size([34000, 55])
Mouse C161
Loading Data....
torch.Size([34000, 91])


In [None]:
#save datas and recordings
# datas[i] must belong to recordings[i]. The loading cell skips mice that lack the
# stimulus, files, or passing units, which would silently shift that pairing, so
# refuse to write a mislabeled file.
if len(datas) != len(recordings):
    raise ValueError(
        f'{len(datas)} data tensors but {len(recordings)} recordings; '
        'a recording was skipped during loading, so the saved pairing would be wrong'
    )

out_dir = 'C:/Users/denmanlab/Desktop/Emily_rotation/data_to_run_at_home'
os.makedirs(out_dir, exist_ok=True)

to_save = {
    'datas': datas,
    'recordings': recordings
}

with open(os.path.join(out_dir, 'LGNcolor_exchange.pkl'), 'wb') as f:
    pkl.dump(to_save, f)

In [13]:
# chromatic gratings data (LGN units)
# Appends one tensor (n_trials*n_bins, n_units) per recording to `datas`.
stim       = 'chromatic_gratings'
recordings = ['C155','C159','C161']

units_dir = r'Z:\color_representation\units'
stim_dir  = r'Z:\color_representation\stim'
res_dir   = 'G:\\'  # not used

bin_     = 0.010
pre      = 0.500
post     = 1.500
window   = pre + post
n_bins   = int(window / bin_)

max_trials = 680   # recordings with more trials are randomly subsampled to this many
seed       = 0     # makes the trial subsampling reproducible across re-runs

datas      = []  # list[torch.Tensor], each (n_trials*n_bins, n_units)
all_probes = []  # list[np.ndarray] per recording
all_ids    = []  # list[np.ndarray] per recording
loaded_recordings = []  # mice that actually contributed an entry to the lists above

rng = random.Random(seed)

for mouse in recordings:
    print(f'Mouse {mouse}')

    # Load the '*updated*' stimulus table (sorted: glob order is filesystem-dependent)
    stim_files = sorted(glob(os.path.join(stim_dir, f'{mouse}*updated*')))
    if not stim_files:
        print('No stim file found')
        continue
    stim_df = pd.read_json(stim_files[0])

    # Ensure target stimulus present
    if stim not in stim_df.stimulus.unique():
        print(f'No {stim}')
        continue

    # Keep only this stimulus & reset index
    stim_df = stim_df.loc[stim_df.stimulus == stim].reset_index(drop=True)

    print('Loading Data....')
    #Load in unit data
    unit_files = sorted(glob(os.path.join(units_dir, f'{mouse}*')))
    if not unit_files:
        print('No units file found')
        continue
    units_df   = pd.read_json(unit_files[0])
    units_good = units_df.loc[units_df.qmLabel.isin(['GOOD', 'NON-SOMA GOOD'])]
    # units_good  = units_good.loc[(units_good.region.str.contains('Primary visual'))|(units_good.region.str.contains('lateral geniculate'))]
    units_good = units_good.loc[units_good.region.str.contains('lateral geniculate', na=False)].reset_index(drop=True)
    units_df   = None  # drop the reference to the full unit table

    if len(units_good) < 1:
        print('No good units after filtering')
        continue

    # Stim timing; subsample trials 1..N-1 (same scheme as the color-exchange cells)
    stim_times = stim_df.loc[:, 'start_time'].values
    if len(stim_df) > max_trials:
        trial_idx  = sorted(rng.sample(range(len(stim_df) - 1), max_trials))
        stim_times = stim_times[1:][trial_idx]

    # Compute stimulus span for firing-rate heuristic
    stim_length = stim_times[-1] - stim_times[0]

    # Build datas (flattened bytrial per unit -> columns)
    probes   = []
    cIDs     = []
    rec_data = []
    for _, row in units_good.iterrows():
        spike_times = np.array(row.times)
        stim_spikes = spike_times[
            (spike_times > stim_times[0]) & (spike_times < stim_times[-1] + post)
        ]

        if len(stim_spikes) / stim_length < 0.5:
            continue

        psth, var, edges, bytrial = trial_by_trial(
            stim_spikes, stim_times, pre, post, bin_
        )

        # Flatten trial_by_trial output into one column for this unit
        rec_data.append(bytrial.ravel())
        probes.append(row.probe)
        cIDs.append(row.cluster_id)

    if not rec_data:
        print('No units passed firing-rate threshold')
        continue

    # Shape: (n_trials*n_bins, n_units)
    data_tensor = torch.from_numpy(np.array(rec_data).T)
    datas.append(data_tensor)
    all_probes.append(np.array(probes))
    all_ids.append(np.array(cIDs))
    loaded_recordings.append(mouse)

    print(f'{mouse}: {data_tensor.shape}')




Mouse C155
Loading Data....
C155: torch.Size([136000, 15])
Mouse C159
Loading Data....
C159: torch.Size([136000, 177])
Mouse C161
Loading Data....
C161: torch.Size([136000, 95])


In [None]:
#save datas and recordings
# datas[i] must belong to recordings[i]. The loading cell skips mice that lack the
# stimulus, files, or passing units, which would silently shift that pairing, so
# refuse to write a mislabeled file.
if len(datas) != len(recordings):
    raise ValueError(
        f'{len(datas)} data tensors but {len(recordings)} recordings; '
        'a recording was skipped during loading, so the saved pairing would be wrong'
    )

out_dir = 'C:/Users/denmanlab/Desktop/Emily_rotation/data_to_run_at_home'
os.makedirs(out_dir, exist_ok=True)

to_save = {
    'datas': datas,
    'recordings': recordings
}

with open(os.path.join(out_dir, 'LGNchromatic_gratings.pkl'), 'wb') as f:
    pkl.dump(to_save, f)

In [3]:
# check stimulus recordings for a mouse
mouse = 'C161'
stim_dir  = r'Z:\color_representation\stim'

# Sort: glob order is filesystem-dependent
stim_matches = sorted(glob(os.path.join(stim_dir, f'{mouse}*')))
print(f"Stim files for {mouse}: {stim_matches}")

# Prefer the '*updated*' table (the one the loading cells read) instead of relying on
# its position in the list; fall back to the first match, or None if there is none.
updated_matches = [p for p in stim_matches if 'updated' in os.path.basename(p)]
stim_file = (updated_matches or stim_matches or [None])[0]

if stim_file is None:
    print(f"No stim files found for {mouse}")
else:
    stim_df = pd.read_json(stim_file)
    if 'stimulus' in stim_df.columns:
        print(f"Stimulus types in {stim_file}: {stim_df.stimulus.unique()}")
    print(stim_df.head())

    if 'stimulus' in stim_df.columns:
        print(f"Stimulus column found with {stim_df['stimulus'].nunique()} unique values.")

Stim files for C161: ['Z:\\color_representation\\stim\\C161_stim.json', 'Z:\\color_representation\\stim\\C161_stim_updated.json']
Stimulus types in Z:\color_representation\stim\C161_stim_updated.json: ['luminance_flash' 'spatioluminance_noise' 'spatiochromatic_noise'
 'color_exchange' 'drifting_gratings' 'chromatic_gratings' 'sweeping_bar']
          stimulus   start_time green   uv  contrast  temporal_frequency  \
0  luminance_flash  3686.609433   256  256        -1                  -1   
1  luminance_flash  3689.603167     0    0        -1                  -1   
2  luminance_flash  3692.605900   256  256        -1                  -1   
3  luminance_flash  3695.608400     0    0        -1                  -1   
4  luminance_flash  3698.612733   256  256        -1                  -1   

   spatial_frequency  orientation  stimulus_index  
0               -1.0           -1              -1  
1               -1.0           -1              -1  
2               -1.0           -1           

In [18]:
# drifting gratings data (LGN units)
# Appends one tensor (n_trials*n_bins, n_units) per recording to `datas` and builds a
# per-recording stimulus design matrix `Xall` (not stored; only the last one survives).
stim       = 'drifting_gratings'
# recordings = ['d4','d5','d6','C153','C155','C159','C160','C161']
recordings = ['d5','C155','C159','C160','C161']

units_dir = r'Z:\color_representation\units'
stim_dir  = r'Z:\color_representation\stim'

bin_     = 0.010
pre      = 0.500
post     = 1.500
window   = pre+post
n_bins   = int(window/bin_)

datas      = []
all_probes = []
all_ids    = []
loaded_recordings = []  # mice that actually contributed an entry to the lists above

for mouse in recordings:
    print(f'Mouse {mouse}')

    # Load the '*updated*' stimulus table (sorted: glob order is filesystem-dependent)
    stim_files = sorted(glob(os.path.join(stim_dir, f'{mouse}*updated*')))
    if not stim_files:
        print('No stim file found')
        continue
    stim_df = pd.read_json(stim_files[0])

    #Check for stimulus before loading in more data
    if stim not in stim_df.stimulus.unique():
        print(f'No {stim}')
        continue
    stim_df = stim_df.loc[stim_df.stimulus == stim].reset_index(drop=True)

    print('Loading Data....')
    #Load in unit data
    unit_files = sorted(glob(os.path.join(units_dir, f'{mouse}*')))
    if not unit_files:
        print('No units file found')
        continue
    units_df   = pd.read_json(unit_files[0])
    units_good = units_df.loc[units_df.qmLabel.isin(['GOOD', 'NON-SOMA GOOD'])]
    #units_good  = units_good.loc[(units_good.region.str.contains('Primary visual'))|(units_good.region.str.contains('lateral geniculate'))]
    units_good = units_good.loc[units_good.region.str.contains('lateral geniculate', na=False)].reset_index(drop=True)
    units_df   = None  # drop the reference to the full unit table

    if len(units_good) < 1:
        print('No good units after filtering')
        continue

    #Wrangle stimulus-relevant data (all trials are used; no subsampling here)
    stim_times  = stim_df.loc[:, 'start_time'].values
    stim_length = stim_times[-1] - stim_times[0]
    n_trials    = len(stim_times)

    #Organize spike data (y)
    probes   = []
    cIDs     = []
    rec_data = []
    for _, row in units_good.iterrows():
        spike_times = np.array(row.times)
        stim_spikes = spike_times[(spike_times > stim_times[0]) & (spike_times < stim_times[-1]+post)]

        # Firing-rate threshold over the stimulus block
        if len(stim_spikes)/stim_length < 0.5:
            continue

        psth, var, edges, bytrial = trial_by_trial(stim_spikes, stim_times, pre, post, bin_)
        rec_data.append(bytrial.ravel())
        probes.append(row.probe)
        cIDs.append(row.cluster_id)

    if not rec_data:
        print('No units passed firing-rate threshold')
        continue

    #Organize stimulus data: Xall is (n_trials, n_bins, n_conditions+1); the last column stays 1.
    orientation       = stim_df.orientation.values.astype(float)
    spatial_frequency = stim_df.spatial_frequency.values.astype(float)
    conditions = [orientation, spatial_frequency]
    baselines  = [-1, 0]  # value each condition takes before onset

    # Onset follows this cell's `pre` (the original hard-coded 0.100 s from the
    # color-exchange window). The +1 bin offset is kept from the original convention.
    onset_bin = int(round(pre / bin_)) + 1

    Xall = np.ones((n_trials, n_bins, len(conditions)+1))
    for c, (cond, base) in enumerate(zip(conditions, baselines)):
        Xall[:, :onset_bin, c] = base
        Xall[:, onset_bin:, c] = cond[:, None]

    data_tensor = torch.from_numpy(np.array(rec_data).T)
    datas.append(data_tensor)
    all_probes.append(np.array(probes))
    all_ids.append(np.array(cIDs))
    loaded_recordings.append(mouse)

    print(data_tensor.shape)


Mouse d5
Loading Data....
torch.Size([128000, 92])
Mouse C155
Loading Data....
torch.Size([64000, 16])
Mouse C159
Loading Data....
torch.Size([64000, 160])
Mouse C160
Loading Data....
torch.Size([64000, 51])
Mouse C161
Loading Data....
torch.Size([64000, 94])


In [None]:
#save datas and recordings
# datas[i] must belong to recordings[i]. The loading cell skips mice that lack the
# stimulus, files, or passing units, which would silently shift that pairing, so
# refuse to write a mislabeled file.
if len(datas) != len(recordings):
    raise ValueError(
        f'{len(datas)} data tensors but {len(recordings)} recordings; '
        'a recording was skipped during loading, so the saved pairing would be wrong'
    )

out_dir = 'C:/Users/denmanlab/Desktop/Emily_rotation/data_to_run_at_home'
os.makedirs(out_dir, exist_ok=True)

to_save = {
    'datas': datas,
    'recordings': recordings
}

with open(os.path.join(out_dir, 'LGNdrifting_gratings.pkl'), 'wb') as f:
    pkl.dump(to_save, f)

In [5]:
# luminance flash data (LGN units)
# Appends one tensor (n_trials*n_bins, n_units) per recording to `datas` and builds a
# per-recording stimulus design matrix `Xall` (not stored; only the last one survives).
stim       = 'luminance_flash'
# recordings = ['d6','C153','C155','C159','C160','C161']
recordings = ['d4','d5','d6','C155','C159','C160','C161']
# recordings = ['d4','d5',]

units_dir = r'Z:\color_representation\units'
stim_dir  = r'Z:\color_representation\stim'
res_dir   = 'G:\\'

bin_      = 0.010
pre       = 0.100   # `pre` argument to trial_by_trial (same role as in the other cells)
post      = 0.400   # `post` argument to trial_by_trial
window    = pre+post
n_bins    = int(window/bin_)
flash_dur = 0.050   # span of the flash condition in Xall -- TODO confirm the actual flash duration

datas      = []
all_probes = []
all_ids    = []
loaded_recordings = []  # mice that actually contributed an entry to the lists above

for mouse in recordings:
    print(f'Mouse {mouse}')

    # Load the '*updated*' stimulus table (sorted: glob order is filesystem-dependent)
    stim_files = sorted(glob(os.path.join(stim_dir, f'{mouse}*updated*')))
    if not stim_files:
        print('No stim file found')
        continue
    stim_df = pd.read_json(stim_files[0])

    #Check for stimulus before loading in more data
    if stim not in stim_df.stimulus.unique():
        print(f'No {stim}')
        continue
    stim_df = stim_df.loc[stim_df.stimulus == stim].reset_index(drop=True)

    print('Loading Data....')
    #Load in unit data
    unit_files = sorted(glob(os.path.join(units_dir, f'{mouse}*')))
    if not unit_files:
        print('No units file found')
        continue
    units_df   = pd.read_json(unit_files[0])
    units_good = units_df.loc[units_df.qmLabel.isin(['GOOD', 'NON-SOMA GOOD'])]
    #units_good  = units_good.loc[(units_good.region.str.contains('Primary visual'))|(units_good.region.str.contains('lateral geniculate'))]
    units_good = units_good.loc[units_good.region.str.contains('lateral geniculate', na=False)].reset_index(drop=True)
    units_df   = None  # drop the reference to the full unit table

    if len(units_good) < 1:
        print('No good units after filtering')
        continue

    #Wrangle stimulus-relevant data (all trials are used; no subsampling here)
    stim_times  = stim_df.loc[:, 'start_time'].values
    stim_length = stim_times[-1] - stim_times[0]
    n_trials    = len(stim_times)

    #Organize spike data (y)
    probes   = []
    cIDs     = []
    rec_data = []
    for _, row in units_good.iterrows():
        spike_times = np.array(row.times)
        # Spike window extends a full `window` past the last onset (as originally written)
        stim_spikes = spike_times[(spike_times > stim_times[0]) & (spike_times < stim_times[-1]+window)]

        # Firing-rate threshold over the stimulus block
        if len(stim_spikes)/stim_length < 0.5:
            continue

        psth, var, edges, bytrial = trial_by_trial(stim_spikes, stim_times, pre, post, bin_)
        rec_data.append(bytrial.ravel())
        probes.append(row.probe)
        cIDs.append(row.cluster_id)

    if not rec_data:
        print('No units passed firing-rate threshold')
        continue

    #Organize stimulus data
    sequence      = np.zeros(n_trials)
    sequence[::2] = 1  # NOTE(review): assumes trials alternate starting with the bright one -- verify against stim_df.green
    conditions = [sequence]

    # Baseline 0.5 everywhere, the trial's value during the flash. The original flash
    # slice ran from int(0.100/bin_)+1 to int(0.050/bin_)+1 (11:6), which is empty, so the
    # condition was never written. The +1 bin offset is kept from the original convention.
    onset_bin  = int(round(pre / bin_)) + 1
    offset_bin = int(round((pre + flash_dur) / bin_)) + 1

    Xall = np.ones((n_trials, n_bins, len(conditions)+1))
    for c, cond in enumerate(conditions):
        Xall[:, :, c] = 0.5
        Xall[:, onset_bin:offset_bin, c] = cond[:, None]

    # Shape: (n_trials*n_bins, n_units)
    data_tensor = torch.from_numpy(np.array(rec_data).T)
    datas.append(data_tensor)
    all_probes.append(np.array(probes))
    all_ids.append(np.array(cIDs))
    loaded_recordings.append(mouse)

    print(f'{mouse}: {data_tensor.shape}')
 

Mouse d4
Loading Data....
d4: torch.Size([5000, 37])
Mouse d5
Loading Data....
d5: torch.Size([5000, 65])
Mouse d6
Loading Data....
d6: torch.Size([5000, 79])
Mouse C155
Loading Data....
C155: torch.Size([5000, 12])
Mouse C159
Loading Data....
C159: torch.Size([5000, 145])
Mouse C160
Loading Data....
C160: torch.Size([5000, 62])
Mouse C161
Loading Data....
C161: torch.Size([5000, 72])


In [None]:
#save datas and recordings
# datas[i] must belong to recordings[i]. The loading cell skips mice that lack the
# stimulus, files, or passing units, which would silently shift that pairing, so
# refuse to write a mislabeled file.
if len(datas) != len(recordings):
    raise ValueError(
        f'{len(datas)} data tensors but {len(recordings)} recordings; '
        'a recording was skipped during loading, so the saved pairing would be wrong'
    )

out_dir = 'C:/Users/denmanlab/Desktop/Emily_rotation/data_to_run_at_home'
os.makedirs(out_dir, exist_ok=True)

to_save = {
    'datas': datas,
    'recordings': recordings
}

with open(os.path.join(out_dir, 'LGNluminance_flash.pkl'), 'wb') as f:
    pkl.dump(to_save, f)