# Get modified activity (binned spike counts) by subtracting possible speed (and acceleration) modulation

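We estimate the modulation with a linear regression of the measured activity on the regressors (including speed and acceleration), and then subtract the fitted speed and acceleration contributions.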

In [None]:
import os
import glob
import re
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt

# from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LinearRegression

# My module for the decoding analysis
#import decoding
# from decoding import *
from decoding import *

## Set paths and load data

In [None]:
# Data directory: the "data" folder inside the current working directory.
base_p = Path.cwd()
start_p = base_p.joinpath('data')
print(start_p)

# Load linear modelling data. The file is loaded with allow_pickle=True and
# unwrapped with .item(), which yields a dict-like object (see .keys() below).
# NOTE(review): unpickling can execute arbitrary code -- only load trusted files.
fname = 'linear_modelling_of_SOMI_activity_during_reward.npy'

loaded_arr = np.load(start_p / fname, allow_pickle=True)
data_d = loaded_arr.item()
print(data_d.keys())

## Set parameters from data

In [None]:
# Time windows for regression and simulated data,
# before and after the start of reward delivery:
time_before = 3.0 # in s
time_after = time_before # in s

# Offset of the decoding start; further below it is looked up on the bin
# time axis as -decoding_offset_time:
decoding_offset_time = 1.5 # in s
# decoding_offset = round(decoding_offset_time * sample_rate_Hz)

# maximum time window for binning before and after decoding start time:
window_max = 1.5 # in s

# Dict keys for expert and non-expert sessions:
expert_key = 'high'
non_expert_key = 'low'

# Binning parameters that were stored together with the data:
params_d = data_d['params']
print(params_d)

num_bins = params_d['# bins']
binwidth_steps = params_d['binwidth']
# NOTE(review): neither sample_rate_Hz nor sample_rate is defined in the
# visible part of this notebook; presumably both come from
# `from decoding import *` (sample_rate possibly in kHz, so that bw_ms is
# a bin width in ms) -- TODO confirm that both names exist and their units.
binwidth = binwidth_steps/sample_rate_Hz # binwidth in s
bw_ms = round(binwidth_steps/sample_rate)

print('Bin width:', binwidth)
ibw = 1/binwidth # inverse bin width; turns mean counts per bin into a rate

num_bins_before = round(time_before/binwidth) # should be 1/2 of num_bins
print('Number of bins before:', num_bins_before)

In [None]:
# Aligned time axis of the bins: `ts` starts at -time_before and excludes
# time_after (endpoint=False), `t_bin` is the spacing between entries and
# `ts_center` is shifted by half a spacing (the bin centers).
ts, t_bin = np.linspace(
    start=-time_before,
    stop=time_after,
    num=num_bins,
    endpoint=False,
    retstep=True,
)
print(ts)
ts_center = ts + 0.5 * t_bin
# print(ts_center)

# Index of the first entry of ts that is >= -decoding_offset_time; it is used
# as the number of bins on each side of the decoding start.
num_bins_decoding = ts.searchsorted(-decoding_offset_time)
print(
    'Number of bins for decoding:',
    num_bins_decoding,
    ts[num_bins_decoding],
    ts[2 * num_bins_decoding],
)

## Select expert or non-expert sessions/animals

In [None]:
# Choose the group of sessions to analyse here (expert_s or non_expert_s).
session_s = expert_s

# Get data from Expert or Non-expert sessions:
if session_s.startswith(expert_s):
    session_status = expert_s
    lm_d = data_d[expert_key]
elif session_s.startswith(non_expert_s):
    session_status = non_expert_s
    lm_d = data_d[non_expert_key]
else:
    # Fail loudly: otherwise lm_d would be undefined, or (in a notebook that
    # was run before) silently keep the data of the previously selected group.
    raise ValueError(
        f'Unknown session status {session_s!r}; '
        f'expected a string starting with {expert_s!r} or {non_expert_s!r}.'
    )

lm_d.keys()

### Relevant lists in the dictionary

In [None]:
# Identifiers of the cells; copied, so the entry in lm_d is not modified
# through this name.
cell_ids = lm_d['cell ids'].copy()

# Data as lists of the same length as cell_ids
# (entry i of X_list / y_list belongs to cell_ids[i]):
X_list = lm_d['X'] # just references to the list in the dict lm_d
y_list = lm_d['y']

# print(len(cell_ids), len(X_list), len(y_list))
cell_ids

## Iterate over the lists and fit the linear model

We need to store and save the results.

In [None]:
# Names of the regression variables; the regression loop below uses acc_ind
# and speed_ind both as row indices into X and as indices into the fitted
# coefficients.
labels = lm_d['coef legend']
print(labels)
acc_ind = labels.index('acceleration') # should be == 3
print(acc_ind)
speed_ind = labels.index('speed')

In [None]:
def stochastic_round(x, generator=None):
    """Round values stochastically to a neighbouring integer.

    Each value is rounded up with a probability equal to its fractional part
    and rounded down otherwise, so the result equals ``x`` on average.
    Values that are already integers are returned unchanged.

    Parameters
    ----------
    x : array_like
        Values to round.
    generator : numpy.random.Generator, optional
        Source of randomness. If omitted, the module-level ``rng`` is used,
        which must have been created before the call.

    Returns
    -------
    numpy.ndarray
        Integer values with the same shape as ``x``.
    """
    if generator is None:
        generator = rng  # module-level generator of the notebook
    x = np.asarray(x, dtype=float)  # also accept lists and scalars
    # floor(x + u) with u uniform in [0, 1) rounds up exactly when u >= 1 - frac(x)
    return np.floor(x + generator.random(size=x.shape)).astype(int)

In [None]:
# Random generator used by stochastic_round(). No seed is given, so the
# random numbers differ between runs.
rng = np.random.default_rng()
# Snapshot of the initial generator state; assigning it back with
# `rng.bit_generator.state = rng_state` replays the same random stream.
rng_state = rng.bit_generator.state

In [None]:
# For every cell: build its identifiers, clean the regressors, fit a linear
# model to the binned spike counts and subtract the fitted speed and
# acceleration contributions. The results are collected in cells_data_d.
cells_data_d = {} # container for loaded and modified data
# from all cells under consideration

# Date pattern (20YY-MM-DD) in the session name; compiled once for all cells
date_pattern = re.compile(r"20\d{2}-\d{2}-\d{2}")

num_cells = 0
for i, cell in enumerate(cell_ids):
    unit_data_d = {} # dictionary for data of a unit

    #print(i, cell)
    # Match date pattern in session name:
    res = date_pattern.search(cell)
    if res is None:
        raise ValueError(f'No date of the form 20YY-MM-DD found in cell id {cell!r}.')
    # Session ID: everything in the cell id up to and including the date
    session_id_s = session_s + '__' + cell[:res.end()] #.replace('/','__')
    #print('Session ID:', session_id_s)
    # NOTE(review): if 'shank' does not occur in the cell id, find() returns -1
    # and the unit ID silently becomes the last character of the id -- confirm
    # that every cell id contains 'shank'.
    unit_id_s = cell[cell.find('shank'):].removesuffix('.txt')
    #print('Session ID:', session_id_s, 'Unit ID:', unit_id_s)
    cell_id_s = '__'.join([session_id_s, unit_id_s])
    print('Cell:', cell_id_s)

    unit_data_d['cell_id_s'] = cell_id_s
    unit_data_d['unit_id_s'] = unit_id_s
    unit_data_d['session_id_s'] = session_id_s

    # Regression data (copies, so the lists in lm_d stay untouched)
    Xo = X_list[i].copy()
    yo = y_list[i].copy()
    # in original shape
    #print(Xo.shape, yo.shape)
    # yts = np.reshape(yo,(-1,num_bins))

    num_trials = yo.shape[0]//2
    #print('Number of trials:', num_trials)
    unit_data_d['num_trials'] = num_trials
    unit_data_d['num_reward_start'] = num_trials

    # Clean up nans in X, one regressor (row of Xo) at a time.
    # The row index is deliberately not called `i`, which is the cell index.
    num_nans = np.zeros(Xo.shape[0], dtype=int)
    X_mean = np.zeros(Xo.shape[0])
    for row, x in enumerate(Xo):
        num_nans[row] = np.isnan(x).sum()
        if num_nans[row] > 0:
            X_mean[row] = np.nanmean(x)
            # x is a view into Xo, so copy=False replaces the nan entries
            # in Xo itself with the mean of the non-nan entries.
            # Then, we can continue as before.
            # TODO: check if there are more than a few nans in the array
            # (say more than 6, or so?).
            np.nan_to_num(x, nan=X_mean[row], copy=False)
        else:
            X_mean[row] = np.mean(x)
    #print(num_nans)
    #print('Cleaned up average of X:', X_mean)

    # Further preparation for regression:
    Xo[acc_ind] *= sample_rate_Hz # Change units of acceleration, in-place
    X_mean[acc_ind] *= sample_rate_Hz # keep the mean in the same units
    X_center = Xo.T - X_mean # center predictor variables
    ys = yo.flatten('C') # makes a copy!
    #print(ys.shape) # should be equal: Xo.shape[1] == ys.shape[0]

    # average firing rate of binned spike counts
    pooled_rate = ibw*np.mean(ys) # in Hz
    if pooled_rate < min_pooled_rate:
        print(f'Pooled average rate of unit {unit_id_s} is only {pooled_rate:5.2f} Hz.')
        print('Exclude from decoding analysis.')
        continue # to next cell
    unit_data_d['pooled_rate'] = pooled_rate

    # Finally, fit the linear model:
    lm = LinearRegression()
    # Use centered X for the regression:
    lm.fit(X_center, ys)
    lm_coeffs = lm.coef_
    lm_intercept = lm.intercept_ # equal to np.mean(ys) due to centering
    unit_data_d['lm_coeffs'] = lm_coeffs
    unit_data_d['lm_intercept'] = lm_intercept

    #y_pred = lm.predict(X_center)

    # fig = plt.figure(dpi=150)
    # plt.plot(ys)
    # plt.plot(y_pred)
    # plt.plot(X_center[:,acc_ind]*lm_coeffs[acc_ind], color='tab:red') # acceleration contribution
    # #plt.xlim([0, 200])
    # plt.show()
    # plt.close(fig)

    # Subtract modulation of ys due to speed and acceleration variation
    # as estimated by the fitted linear (regression) model:
    speed_part = X_center[:,speed_ind]*lm_coeffs[speed_ind]
    acc_part = X_center[:,acc_ind]*lm_coeffs[acc_ind]
    ys_red = ys - (speed_part + acc_part)

    # Add preprocessed data to the dicts in cells_data_d
    unit_data_d['bw_ms'] = bw_ms
    unit_data_d['bin_width'] = binwidth
    unit_data_d['yts'] = np.reshape(ys, (-1, num_bins))
    unit_data_d['yts_red'] = np.reshape(ys_red, (-1, num_bins))
    # the reduced spike counts include some randomness:
    #unit_data_d['yts_red_counts'] = np.reshape(ys_red_counts, (-1, num_bins))

    # Put the loaded data into the container
    cells_data_d[cell_id_s] = unit_data_d
    num_cells += 1

print(f'Data from {num_cells} cells have been loaded.', len(cells_data_d))
print('Now, (almost) all relevant date should be in the dict cells_data_d.')

## Prepare spike count data for population decoding

First, we generate the spike count data to be used for the decoding analysis: the measured counts, and stochastically rounded counts after subtracting the fitted modulation.

In [None]:
num_repeat = 1 # number of repetitions for stochastic rounding

# Only the first 2*num_bins_decoding bins of every row are kept for decoding
# (presumably num_bins_decoding bins on each side of the decoding start).
num_bins_kept = 2*num_bins_decoding

num_cells = 0
for cell_id_s, cell_d in cells_data_d.items():
    print('Cell:', cell_id_s)
    print('Pooled rate:', cell_d['pooled_rate'])
    num_cells += 1

    # True labels: 0 for "before" and 1 for "after"
    num_trials = cell_d['num_trials']
    true_labels = get_true_labels(num_trials)
    #print(num_trials, cell_d['yts'].shape)

    # Number of bins on each side as derived from the time axis
    # (an earlier version used num_bins_before//2 here).
    cell_d['num_bins'] = num_bins_decoding
    cell_d['true_labels'] = true_labels

    # So far missing for decoding from multiple units
    #data_d['spike_counts_aligned'] = spike_counts_aligned
    #data_d['rate_counts_aligned'] = rate_counts_aligned

    # Entry 0 holds the measured counts; entries 1..num_repeat hold the
    # modulation-subtracted activity, clipped at zero and stochastically rounded.
    spike_counts_list = []
    # rng_state = rng.bit_generator.state
    for i in range(num_repeat + 1):
        if i > 0:
            # rng.bit_generator.state = rng_state
            yts_red = cell_d['yts_red'][:, :num_bins_kept]
            spike_counts_aligned = stochastic_round(np.maximum(yts_red, 0.0))
        else:
            spike_counts_aligned = cell_d['yts'][:, :num_bins_kept].astype(int)
        #print(i, spike_counts_aligned.shape)
        pooled_rate_rep = ibw*np.mean(spike_counts_aligned) # in Hz
        print(pooled_rate_rep)
        spike_counts_list.append(spike_counts_aligned)

    cell_d['spike_counts_list'] = spike_counts_list
    cell_d['num_repeat'] = num_repeat

print(f'Data from {num_cells} cells has been modified.', len(cells_data_d))

In [None]:
# Write the per-cell data dict into the data directory; the session status
# is part of the file name.
cells_data_fname = f'activity_subtracted_data__{session_s}.npy'
fname = start_p.joinpath(cells_data_fname)

np.save(fname, cells_data_d)

## Decoding analysis for all cells that were not excluded.

In [None]:
# Window lengths (in seconds) for the decoding analysis:
win_start = 0.25
win_stop = 1.5 # should be less than (or equal to) window_max
win_step = binwidth # one bin per step
win_arr = create_windows(win_start, win_stop, win_step)
print(win_arr)

# Variant of the trial split, looked up by index in split_variant_names:
split_variant = 0
split_variant_name = split_variant_names[split_variant]
print('Using split variant:', split_variant_name)

# Suffix that encodes bin width and split variant
decoding_s = f'_bw{bw_ms}_split_{split_variant_name}'
# print(decoding_s)

In [None]:
# Decoding analysis for every cell in cells_data_d: for each prepared set of
# spike counts (entry 0 and the num_repeat stochastic variants) the decoding
# is run over all window lengths in win_arr, stored, and plotted per cell.
# rng.bit_generator.state = rng_state

print_output = False

cells_results_d = {} # container for results of decoding analysis
num_cells = 0
for cell_id_s, cell_d in cells_data_d.items():
    print('Cell:', cell_id_s)
    print('Pooled rate:', cell_d['pooled_rate'])
    num_cells += 1

    num_trials = cell_d['num_trials']
    # True labels: 0 for "before" and 1 for "after"
    true_labels = cell_d['true_labels'] #get_true_labels(num_trials)

    # Split into even and odd trials
    # (the only split used here, independent of split_variant):
    inds_split = split_even_odd(num_trials)

    results_list = [] # collect the decoding results
    for i in range(num_repeat+1):
        # Spike counts prepared in the previous step
        spike_counts_aligned = cell_d['spike_counts_list'][i]

        pooled_rate = ibw*np.mean(spike_counts_aligned) # in Hz
        #print(pooled_rate)

        # Perform decoding analysis for a collection of time windows:
        # 3 variants in this order: frac_correct, frac_correct_rate_before, frac_correct_train
        fc_arr, pv_arr = decoding_windows(win_arr, binwidth, num_bins_decoding,
                                          spike_counts_aligned, spike_counts_aligned,
                                          true_labels, inds_split, output=print_output)

        results_d = dict(repeat=i, win_arr=win_arr, fc_arr=fc_arr, pv_arr=pv_arr,
                         bw_ms=bw_ms, split_variant=split_variant,
                         bin_width=binwidth,
                         session_rate=cell_d['pooled_rate'], pooled_rate=pooled_rate)

        results_list.append(results_d)

    # Store the results in a container
    cells_results_d[cell_id_s] = results_list

    # Plot column 0 of fc_arr against the window length, one line per repeat
    fig, ax = plt.subplots(1,1, figsize=(3,2.5), dpi=150)
    for res in results_list:
        ax.plot(res['win_arr'], res['fc_arr'][:,0], marker='.', label=res['repeat'])

    ax.axhline(0.65, linestyle=':', linewidth=0.5, color='k')
    ax.set_xlabel('window length (s)')
    ax.set_title(cell_id_s)
    ax.set_ylabel('fraction correct')
    ax.legend()
    # pyplot.show() takes no figure argument; show all open figures and then
    # close this one explicitly so figures do not pile up over the loop.
    plt.show()
    plt.close(fig)

print(f'Data from {num_cells} cells has been analyzed.', len(cells_results_d))

In [None]:
# Show the IDs of all cells for which decoding results are stored
list(cells_results_d)

## Store the dictionary, indicating the session status.

In [None]:
# Write the decoding results into the data directory; the session status
# is part of the file name.
cells_results_fname = f'activity_subtracted_results__{session_s}.npy'
fname = start_p.joinpath(cells_results_fname)

np.save(fname, cells_results_d)