# Demonstration of the C-SPIKES algorithm for calcium imaging spike inference
## The algorithm demonstration here follows the exposition in "Precise calcium-to-spike inference using biophysical generative models" by Broussard et al.
### The first two cells in this notebook must be run prior to running the PGAS algorithm cells (labeled below)
### Other blocks can be run in any order desired as long as the flags in the following code cell are properly set. See below for details

First, we import the required libraries, including the Python-bound C++ library in which the PGAS algorithm is implemented. We also define a few helper functions for loading and handling data.

In [None]:
#Importing project packages and required libraries
import numpy as np
import build.pgas_bound as pgas
from src.c_spikes.syn_gen import synth_gen
import matplotlib.pyplot as plt
import scipy.io as sio
import os

# Flags selecting which stages of the notebook are (re)computed on this run:
#   recalc_pgas       - run the particle Gibbs sampler to extract spike times and cell parameters
#   recalc_Cparams    - run the particle Gibbs sampler with known spike times to extract cell parameters
#   recalc_synth      - run the synthetic data generation code to create new synthetic data
#   retrain_and_infer - run the cascade training and inference code to train a new model
#                       and infer spikes on the original data
recalc_pgas = recalc_Cparams = recalc_synth = retrain_and_infer = True

# Utility-type methods
## Code to convert spike times to a binary vector
def spike_times_2_binary(spike_times,time_stamps):
    """Bin event (spike) times onto a grid of time stamps.

    Each spike time ``t`` inside ``[time_stamps[0], time_stamps[-1]]`` is
    counted in the bin of the last time stamp strictly earlier than ``t``,
    i.e. bin ``i`` collects spikes in ``(time_stamps[i], time_stamps[i+1]]``.
    A spike falling exactly on ``time_stamps[0]`` has no earlier stamp and is
    therefore dropped; spikes outside the stamp range are ignored.

    Parameters
    ----------
    spike_times : array_like
        Spike times in the same units as ``time_stamps``. Any shape is
        accepted (e.g. 2-D arrays as returned by ``scipy.io.loadmat``); the
        values are flattened.
    time_stamps : array_like
        1-D sample times, assumed sorted in ascending order.

    Returns
    -------
    numpy.ndarray
        Integer array with ``len(time_stamps)`` entries holding the number of
        spikes assigned to each time bin.
    """
    time_stamps = np.asarray(time_stamps)
    spike_times = np.ravel(spike_times)
    n_bins = len(time_stamps)
    if n_bins == 0 or spike_times.size == 0:
        return np.zeros(n_bins, dtype=int)

    # keep only events within the span of the time stamps
    in_range = (spike_times >= time_stamps[0]) & (spike_times <= time_stamps[-1])
    good_spike_times = spike_times[in_range]

    # For sorted stamps, searchsorted(side="left") - 1 is the index of the last
    # stamp strictly below each event; this replaces a per-spike linear scan.
    bin_idx = np.searchsorted(time_stamps, good_spike_times, side="left") - 1
    bin_idx = bin_idx[bin_idx >= 0]  # events on time_stamps[0] have no earlier stamp
    return np.bincount(bin_idx, minlength=n_bins).astype(int, copy=False)

# For opening the janelia datasets
def open_Janelia_1(j_path):
    """Read one Janelia ground-truth recording from a MATLAB ``.mat`` file.

    Parameters
    ----------
    j_path : str or path-like
        Location of the ``.mat`` file.

    Returns
    -------
    tuple
        ``(time_stamps, dff, spike_times)``: the arrays stored under the keys
        ``'time_stamps'``, ``'dff'`` and ``'ap_times'``, exactly as
        ``scipy.io.loadmat`` returns them (no reshaping or conversion).
    """
    mat = sio.loadmat(j_path)
    dff, stamps, ap_times = (mat[key] for key in ('dff', 'time_stamps', 'ap_times'))
    return stamps, dff, ap_times

# For working with the PGAS output state trajectories
def unroll_dat_files(dat_file):
    '''Load a PGAS trajectory-sample ``.dat`` file and split it into columns.

    The file is read as comma-separated numbers with one header line. The
    (0-based) columns returned are: 0 -> index, 2 -> B, 3 -> S, 4 -> C and,
    if present, 5 -> Y. Column 1 is not returned.

    PGAS output variables:
    -B = baseline drift, Brownian
    -S = discretized spike number per time bin
    -C = "calcium" value - really more akin to a DFF-like metric
    -Y = original data (not included in PGBAR output)

    Returns
    -------
    tuple
        ``(index, B, S, C, Y)``. Each entry is a 1-D array, except ``Y``, which
        is ``np.nan`` when the file has no sixth column (e.g. PGBAR output).
    '''
    # atleast_2d keeps column indexing valid when the file holds a single row
    data = np.atleast_2d(np.genfromtxt(dat_file, delimiter=',', skip_header=1))

    # Dealing out data
    index = data[:, 0]
    B = data[:, 2]
    S = data[:, 3]
    C = data[:, 4]

    # Files produced by PGBAR (rather than the more general PGAS) lack Y.
    # Test the column count explicitly instead of using a bare ``except``,
    # which would also hide unrelated errors (and KeyboardInterrupt).
    Y = data[:, 5] if data.shape[1] > 5 else np.nan

    return index, B, S, C, Y

# For calculating noise levels in the data
# This method is based on one from Rupprecht et al. 2021 CASCADE paper
def calculate_standardized_noise(dff,frame_rate):
    """Return the frame-rate-standardized noise level of ``dff``, in percent.

    The noise is the median absolute difference between consecutive samples
    along the last axis (NaNs ignored), divided by ``sqrt(frame_rate)`` and
    scaled by 100. A 2-D input yields one value per row.
    """
    abs_steps = np.abs(np.diff(dff, axis=-1))
    median_step = np.nanmedian(abs_steps, axis=-1)
    standardized = median_step / np.sqrt(frame_rate)
    return standardized * 100  # express as percent

Next, we load sample jGCaMP8f data reported in Zhang et al. 2023, retrieved from Márton Rózsa's DANDI dataset 000168.

In [None]:
# Load a sample jGCaMP8f recording from the gt_data folder as numpy arrays.
# Example cells: "jGCaMP8f_ANM471993_cell03" (high SNR excitatory) and
# "jGCaMP8f_ANM478349_cell06" (low SNR inhibitory).
janelia_file = "jGCaMP8f_ANM471993_cell03"
filename = os.path.join("gt_data", f"{janelia_file}.mat")

time, data, spike_times = open_Janelia_1(filename)

# Analyse a 1000-sample excerpt (samples 1000-1999 of the first row) as fresh
# float64 copies -- presumably the C++ bindings want owned buffers; confirm.
time1 = np.array(time[0, 1000:2000], dtype=np.float64)
data1 = np.array(data[0, 1000:2000], dtype=np.float64)
# Ground-truth spike counts per time bin of the excerpt, as float64
binary_spikes = spike_times_2_binary(spike_times, time1).astype(np.float64)

## Setting up parameters for the particle gibbs sampler
tag = "test"  # label appended to output file names by the later cells
Gparam_file = "src/spike_find/pgas/20230525_gold.dat"

# 1. PGAS for spike time estimates
## Now we are ready to run the PGAS algorithm with the biophysical GCaMP model as its generative kernel
## This cell runs the Particle Gibbs sampler for 300 iterations
## The output file is a viewable plot of the average spike state trajectory (trace) compared to ground truth spike times (depicted as vertical dashed lines)

In [None]:
# Run PGAS on the data excerpt (only when requested by the recalc_pgas flag),
# then compare the averaged spike trajectory with the ground-truth spikes.
if recalc_pgas:
    # Set up the parameters for the particle gibbs sampler
    analyzer = pgas.Analyzer(
        time=time1,
        data=data1,
        constants_file="parameter_files/constants_GCaMP8_soma.json",
        output_folder="results",
        column=1,
        tag=tag,
        niter=300,
        append=False,
        verbose=1,
        gtSpikes=binary_spikes,
        has_gtspikes=False,
        maxlen=1000,
        Gparam_file=Gparam_file,
        seed=2
    )

    ## Run the sampler; the trajectory file read below is expected in output_folder
    analyzer.run()

## Load files containing the results to plot
pgas_out_file = os.path.join('results', 'traj_samples_' + tag + '.dat')
index, B, S, C, Y = unroll_dat_files(pgas_out_file)

## Average the spike state over the last 100 sampler iterations (earlier ones
## are treated as burn-in). Assumes column 0 is the sample number and rows are
## grouped per sample, so the count of rows with index == 0 is the trajectory
## length -- TODO confirm against the PGAS writer.
traj_len = int(np.sum(index == 0))
spks = [S.reshape((-1, traj_len)).T]               # (time, sample)
avg_spks = [np.mean(spks[0][:, -100:], axis=1)]    # -100: keeps all of the last 100

## Plotting the results
fig, axes = plt.subplots(1, 1, figsize=(3, 3))
axes.plot(time1, data1, label='Data')
axes.plot(time1, avg_spks[0], label='PGAS')
axes.set_xlabel("Time (s)")
# adding gt spikes: loadmat returns at least 2-D arrays, so flatten first and
# only mark spikes that fall inside the plotted window
gt_spikes = np.ravel(spike_times)
for spike in gt_spikes[(gt_spikes >= time1[0]) & (gt_spikes <= time1[-1])]:
    axes.axvline(spike, ls='--', alpha=0.5, linewidth=0.8)
axes.legend()

## Save figure to results before showing it, so the file is written even if
## the backend closes the figure on show
os.makedirs('results', exist_ok=True)
fig.savefig(os.path.join('results', tag + '_trajs.pdf'), bbox_inches='tight')
plt.show()

# 2. PGAS for Cparam extraction
## Our PGAS implementation can be run either with no knowledge of spike times, or (as in this cell) with spike times given to improve estimates of the underlying cell parameters
## If you wish to run later cells without performing cell parameter estimation, set the flag "recalc_Cparams" to False in the first code cell of this notebook; section 3 then loads previously saved parameter samples instead.

In [None]:
if recalc_Cparams:
    # Re-run PGAS with the ground-truth spikes supplied (has_gtspikes=True) to
    # improve the estimates of the cell parameters.
    analyzer = pgas.Analyzer(
        time=time1,
        data=data1,
        # Same GCaMP8 constants as the spike-inference cell: the data are
        # jGCaMP8f recordings (this previously pointed at the GCaMP3 file).
        constants_file="parameter_files/constants_GCaMP8_soma.json",
        output_folder="pgas_output",
        column=1,
        tag=tag,
        niter=2,  # very short run, suitable for testing only; raise for real estimates
        append=False,
        verbose=1,
        gtSpikes=binary_spikes,
        has_gtspikes=True,
        maxlen=1000,
        Gparam_file=Gparam_file,
        seed=2
    )

    ## Run the sampler
    analyzer.run()

    ## Return cell parameter estimate distributions
    parameter_samples = analyzer.get_parameter_estimates()

# 3. Synthetic data generation
## Cell parameters (Cparams) extracted using PGAS can be used to run the biophysical cell model forward.
## That process is demonstrated here. 
### To run this section, "recalc_synth" should be set to True
### Note that if you set "recalc_Cparams" as True, you must run cell 2 prior to this one to allow PGAS to extract cell parameters

In [None]:
# Without a fresh parameter-estimation run, fall back to saved parameter samples.
if not recalc_Cparams:
    param_sample_file = os.path.join("sample_pgas_output", f"param_samples_{tag}.dat")
    parameter_samples = np.loadtxt(param_sample_file, delimiter=',', skiprows=1)

# Drop the first 100 samples as burn-in unless there are too few of them
# (short test runs), and keep the first six columns, which are averaged into
# Cparams below.
if np.size(parameter_samples, 0) > 100:
    burnin = 100
else:
    burnin = 0
parameter_samples = parameter_samples[burnin:, :6]
print("mean of samples")
print(np.mean(parameter_samples, axis=0))

# Construct synthetic dataset
if recalc_synth:
    # GCaMP model driven by the mean PGAS cell parameters (Cparams) and the
    # parameters loaded from Gparam_file (Gparams).
    Cparams = np.mean(parameter_samples, axis=0)
    Gparams = np.loadtxt(Gparam_file)
    gcamp = pgas.GCaMP(Gparams, Cparams)

    # One synthetic dataset per nominal spike rate
    nominal_rates = np.array([1, 1.1, 1.5, 2, 2.5, 3, 3.5, 4, 4.5, 5])
    for rate in nominal_rates:
        synth = synth_gen.synth_gen(
            plot_on=False,
            GCaMP_model=gcamp,
            spike_rate=rate,
            cell_params=Cparams,
            tag=tag,
            use_noise=True,
        )
        synth.generate()


# 4. Using synthetic data to train a CASCADE model for spike inference
## Either the output from section 3 or the included sample synthetic data can be used as the training dataset
## Inference is run on sample data from Rózsa's jGCaMP8f dataset.
### To run this cell's contents, the "retrain_and_infer" flag should be set to True

In [None]:
if retrain_and_infer:
    ## Package checks for cascade
    print("Current directory: {}".format(os.getcwd()))
    from src.spike_find.cascade2p import checks, utils, cascade
    print("\nChecks for packages:")
    checks.check_packages()

    ## Recording we run inference on, with its frame rate and noise level
    time_stamps = time[0, :]
    fluo_data = data[0, :]
    # The frame rate comes from the spacing of the time stamps, not of the
    # fluorescence values.
    frame_rate = 1 / np.mean(np.diff(time_stamps))
    noise_level = utils.calculate_noise_levels(fluo_data, frame_rate)
    print(f"Estimated frame rate: {frame_rate:.2f} Hz; noise level: {noise_level}")

    # Training configuration for the synthetic dataset named synth_<tag>
    synthetic_training_dataset = f"synth_{tag}"
    cfg = dict(
        model_name=tag,          # Model name (and name of the save folder)
        # NOTE(review): hard-coded; confirm it matches frame_rate above and the
        # rate of the synthetic training data.
        sampling_rate=30,        # Sampling rate in Hz (round to next integer)
        training_datasets=[synthetic_training_dataset],
        noise_levels=list(range(2, 8)),
        smoothing=0.05,          # std of Gaussian smoothing in time (sec)
        causal_kernel=0,         # causal ground truth smoothing kernel
        verbose=1,
    )
    print(cfg['noise_levels'])

    # Save the configuration into a new model folder.
    # TODO: make cascade overwrite existing configs on this call
    cascade.create_model_folder(cfg)

    ## Train a model based on config contents
    model_name = cfg['model_name']
    cascade.train_model(model_name)

    ## Use trained model to perform inference on the original dataset
    spike_prob = cascade.predict(model_name, fluo_data.reshape(1, -1))

    ## Saving routine
    save_dir = "results"
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, f"{tag}_syn_trained_CASCADE_output.mat")
    sio.savemat(save_path, {'spike_prob': spike_prob, 'time_stamps': time_stamps, 'dff': fluo_data, 'cfg': cfg})
