# Simulating TASEP Models
___

In [None]:
"""
MicroLive Notebook
==================
This notebook requires MicroLive to be installed:
    pip install tasep_models

For development mode:
    pip install -e /path/to/tasep_models
"""
# MicroLive imports
# from microlive import microscopy as mi
# from microlive.utils.device import check_gpu_status

# Verify GPU support
check_gpu_status()

# Standard scientific imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path


## Loading and plotting the plasmid sequence
___

In [None]:
# Modeling: select the epitope tag and the matching plasmid file.

import tasep_models as tm
from tasep_models import *

# NOTE(review): `tag_dict` and `current_dir` are not defined in this notebook; they are
# presumably provided by `from tasep_models import *` — confirm.
tag_file = 'UTAG'  # must be one of the keys of TAG_OPTIONS below

SEQUENCES_DIR = Path('/Users/nzlab-la/Desktop/micro/gene_sequences/utag_project')
# tag_file -> (key into tag_dict, plasmid file name inside SEQUENCES_DIR)
TAG_OPTIONS = {
    'SUN': ('SUN', 'pNZ266(pUB-24xGCN4-KDM5B-MS2).dna'),            # SunTag
    'UTAG': ('U', 'pNZ208(pUB-24xUTagFullLength-KDM5B-MS2).dna'),   # UTAG
    'ALFA': ('ALFA', 'pNZ267 (pUB-24xALFAtag-KDM5B-MS2).dna'),      # AlfaTag
}
if tag_file not in TAG_OPTIONS:
    # Fail fast here; otherwise `dna_file_path` would be undefined and a later cell
    # would crash with a confusing NameError.
    raise ValueError(f"Unknown tag_file {tag_file!r}; expected one of {sorted(TAG_OPTIONS)}")
tag_key, plasmid_file_name = TAG_OPTIONS[tag_file]
tag_sequence = tag_dict[tag_key]
dna_file_path = SEQUENCES_DIR / plasmid_file_name

# File name up to the first '.', with parentheses replaced by underscores; used in output names.
plasmid_name = dna_file_path.name.split('.')[0].replace('(', '_').replace(')', '_')

# Creating a results folder for outputs
results_folder = current_dir.joinpath('results_simulations')
results_folder.mkdir(exist_ok=True)


In [None]:
# Parse the plasmid: protein/RNA/DNA sequences, tag and pause indexes, and the record
# and features used to draw the plasmid map.
(protein, rna, dna, indexes_tags, indexes_pauses,
 seq_record, graphic_features) = read_sequence(
    seq=dna_file_path,
    min_protein_length=50,
    TAG=[tag_sequence],
)
plasmid_figure = plot_plasmid(seq_record, graphic_features, figure_width=25, figure_height=3)

# One extra position for the stop codon.
gene_length = len(protein) + 1

# First tag group is always present; a second one is optional.
tag_positions_first_probe_vector = indexes_tags[0]
if len(indexes_tags) > 1:
    tag_positions_second_probe_vector = indexes_tags[1]
else:
    tag_positions_second_probe_vector = None

first_probe_position_vector = create_probe_vector(tag_positions_first_probe_vector, gene_length)
if tag_positions_second_probe_vector is None:
    second_probe_position_vector = None
else:
    second_probe_position_vector = create_probe_vector(tag_positions_second_probe_vector, gene_length)


## Simulation parameters 

___

In [None]:
# Initial conditions / simulation parameters
ki = 0.03                    # initiation rate (presumably 1/s)
# Single global elongation rate; `calculate_codon_elongation_rates` turns it into the
# per-codon rates `ke` (units presumably codons/s). An earlier value was 3.7.
global_elongation_rate = 5
number_repetitions = 200     # number of SSA repetitions (trajectories)
burnin_time = 2500           # burn-in time passed to both the ODE and SSA simulators
t_max = 3600                 # maximum simulated time (360 * 10)
step_size_in_sec = 5         # sampling interval of the simulated signals
time_array = np.arange(0, t_max, step_size_in_sec)
number_tested_parameters = 5
downsample = False
downsample_factor = 3
MAD_THRESHOLD_FACTOR = 4
# Probe binding efficiency.
# NOTE(review): a single entry is given although two probes may exist — confirm.
efficiency_list = [1]

In [None]:
# Per-codon elongation rates computed from the mRNA sequence and the global rate.
ke = calculate_codon_elongation_rates(
    rna,
    global_elongation_rate=global_elongation_rate,
)


In [None]:
# Deterministic ODE model of the probe intensity signals.
intensity_vector_first_signal_ode, intensity_vector_second_signal_ode = simulate_TASEP_ODE(
    ki,
    ke,
    gene_length,
    t_max,
    first_probe_position_vector,
    second_probe_position_vector,
    burnin_time,
    time_interval_in_seconds=step_size_in_sec,
)


In [None]:
# Display the first probe's position vector (built with create_probe_vector above).
first_probe_position_vector

In [None]:
# `None` tells the SSA to use the sequence-dependent per-codon rates `ke`
# rather than one constant elongation rate.
constant_elongation_rate = None
(
    list_ribosome_trajectories,
    list_occupancy_output,
    matrix_intensity_first_signal_RT,
    matrix_intensity_second_signal_RT,
) = simulate_TASEP_SSA(
    ki,
    ke,
    gene_length,
    t_max,
    time_interval_in_seconds=step_size_in_sec,
    number_repetitions=number_repetitions,
    first_probe_position_vector=first_probe_position_vector,
    second_probe_position_vector=second_probe_position_vector,
    burnin_time=burnin_time,
    constant_elongation_rate=constant_elongation_rate,
    efficiency_list=efficiency_list,
    fast_output=False,
)
                         

In [None]:
# Analytical mean number of ribosomes per mRNA:
# initiation rate * transit time, with transit time = gene_length / global_elongation_rate.
ribosomal_density_exact = (gene_length / global_elongation_rate) * ki
ribosomal_density = np.round(ribosomal_density_exact, 1)
print(f'Ribosomal density: {ribosomal_density} ribosomes per gene length ({gene_length} codons)')

# Fraction of the gene (in codons) covered by ribosome footprints.
ribosomal_footprint = 10
# Use the unrounded density so the display rounding above does not propagate.
ribosomal_density_coverage = (ribosomal_density_exact * ribosomal_footprint) / gene_length
percentage_coverage = ribosomal_density_coverage * 100
print(f'Ribosomal density coverage: {np.round(ribosomal_density_coverage, 2)} fraction of gene length covered '
      f'({gene_length} codons, {percentage_coverage:.2f}% coverage)')

# Each element of list_occupancy_output is an array with shape [gene_location, frame].
# Count occupied positions per frame (column), then average across frames.
# NOTE(review): this equals the ribosome count only if each ribosome marks a single position — confirm.
occupancy_array = np.array([
    np.mean(np.count_nonzero(occupancy, axis=0)) for occupancy in list_occupancy_output
])

# Validate the analytical density against the simulated data.
print(f'Ribosomal density with simulated data: {np.round(np.mean(occupancy_array), 2)} ribosomes per gene length ({gene_length} codons)')


In [None]:
# Unrounded analytical ribosome density (initiation rate * transit time).
ki * gene_length / global_elongation_rate

In [None]:
# Hard-coded coverage check in percent, presumably ~20 ribosomes x 10-codon footprint
# over a 1970-codon gene (same formula as the coverage computed above).
20 * 10 / 1970 * 100

In [None]:
# Display the shape of the first occupancy array (documented above as [gene_location, frame]).
list_occupancy_output[0].shape

In [None]:
# Per-time-point mean and standard error of the mean (SEM) of the SSA intensities
# across repetitions (rows of the matrices).
def _mean_and_sem(matrix):
    """Return (mean, sem) of ``matrix`` over its rows (axis 0).

    The SEM uses the sample standard deviation (ddof=1) and the actual number of rows.
    With a single row the SEM is undefined and is returned as NaN.
    """
    matrix = np.asarray(matrix, dtype=float)
    n_rows = matrix.shape[0]
    mean = np.mean(matrix, axis=0)
    if n_rows < 2:
        return mean, np.full_like(mean, np.nan)
    sem = np.std(matrix, axis=0, ddof=1) / np.sqrt(n_rows)
    return mean, sem


mean_first_signal_RT, sem_first_signal_RT = _mean_and_sem(matrix_intensity_first_signal_RT)
if second_probe_position_vector is not None:
    mean_second_signal_RT, sem_second_signal_RT = _mean_and_sem(matrix_intensity_second_signal_RT)
else:
    # Keep the names defined so later cells can test for None instead of hitting NameError.
    mean_second_signal_RT = sem_second_signal_RT = None

In [None]:
def plot_trajectories(matrix_intensity_first_signal_RT, intensity_vector_first_signal_ode, time_array, number_repetitions, plot_color = 'orangered'):
    """Plot SSA intensity trajectories against the ODE solution, plus an intensity histogram.

    Left panel: the SSA trajectories (rows of ``matrix_intensity_first_signal_RT``) versus
    ``time_array``. The first trajectory is drawn bold and labelled 'SSA'; the ODE curve
    is drawn in black. Right panel: horizontal histogram of all SSA intensity values.
    Both panels share the same y-range.

    Args:
        matrix_intensity_first_signal_RT: 2D array, one trajectory per row, one column
            per time point. NaNs (e.g. simulated missing data) are tolerated.
        intensity_vector_first_signal_ode: 1D ODE intensities evaluated at ``time_array``.
        time_array: 1D array of time points (s).
        number_repetitions: Number of trajectories to draw; capped at the number of rows.
        plot_color: Matplotlib color for the SSA traces and the histogram.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    # Global style (updates matplotlib's rcParams, so later figures are affected too).
    plt.rcParams.update({
        "font.family": "Arial",
        "figure.facecolor": "white",
        "axes.facecolor": "white",
        "axes.edgecolor": "black",
        "axes.labelcolor": "black",
        "xtick.color": "black",
        "ytick.color": "black",
    })

    ssa_matrix = np.asarray(matrix_intensity_first_signal_RT, dtype=float)
    ode_vector = np.asarray(intensity_vector_first_signal_ode, dtype=float)
    # Never index past the trajectories that actually exist.
    n_to_plot = min(int(number_repetitions), ssa_matrix.shape[0])

    # NaN-aware shared intensity range: plain min/max would return NaN for data with
    # missing values, and set_ylim rejects NaN limits.
    global_min = min(np.nanmin(ssa_matrix), np.nanmin(ode_vector))
    global_max = max(np.nanmax(ssa_matrix), np.nanmax(ode_vector))

    # Left panel for trajectories, right (narrower) panel for the histogram.
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(13, 3), gridspec_kw={'width_ratios': [4, 1]})

    # --- Left panel: trajectories ---
    for i in range(n_to_plot):
        if i == 0:
            ax1.plot(time_array, ssa_matrix[i, :],
                     label='SSA', color=plot_color, alpha=1, linewidth=2)
        else:
            ax1.plot(time_array, ssa_matrix[i, :],
                     color=plot_color, alpha=0.1, linewidth=0.4)
    ax1.plot(time_array, ode_vector, label='ODE', color='k', linewidth=3)
    ax1.set_xlabel('Time (s)', fontsize=20)
    ax1.set_ylabel('Intensity (a.u.)', fontsize=20)
    ax1.set_ylim(global_min, global_max)

    # Legend in the upper right corner with a black border.
    legend1 = ax1.legend(loc='upper right', fontsize=14)
    legend1.get_frame().set_edgecolor('black')
    legend1.get_frame().set_linewidth(1.5)

    # --- Right panel: horizontal histogram of all (non-NaN) SSA values ---
    ssa_values = ssa_matrix.ravel()
    ssa_values = ssa_values[~np.isnan(ssa_values)]
    ax2.hist(ssa_values, bins=100, orientation='horizontal',
             color=plot_color, alpha=0.7)
    ax2.set_xlabel('Counts', fontsize=20)
    ax2.set_ylabel('Intensity (a.u.)', fontsize=20)
    ax2.set_ylim(global_min, global_max)

    # Shared frame styling: thick black border, no grid, larger tick labels.
    for ax in (ax1, ax2):
        for spine in ax.spines.values():
            spine.set_visible(True)
            spine.set_linewidth(2)
            spine.set_color('black')
        ax.grid(False)
        ax.tick_params(axis='both', which='major', labelsize=16)

    plt.tight_layout()
    plt.show()



In [None]:
# SSA trajectories (grey) versus the ODE solution.
plot_trajectories(
    matrix_intensity_first_signal_RT,
    intensity_vector_first_signal_ode,
    time_array,
    number_repetitions,
    plot_color='grey',
)

In [None]:
# Pick one SSA repetition to animate below.
selected_trajectory = 0
ribosome_trajectories = list_ribosome_trajectories[selected_trajectory]
intensity_vector_first_signal = matrix_intensity_first_signal_RT[selected_trajectory, :]
# The second signal exists only when a second probe is defined.
if second_probe_position_vector is not None:
    intensity_vector_second_signal = matrix_intensity_second_signal_RT[selected_trajectory, :]
else:
    intensity_vector_second_signal = None


In [None]:
# GIF name encodes the plasmid and the rates, with '.' replaced by '_' in the numbers.
str_ki = str(ki).replace('.', '_')
str_k = str(global_elongation_rate).replace('.', '_')
fileNameGif = f'simulation_{plasmid_name}_ke_{str_k}_ki_{str_ki}'
color = 'lightgreen'
plot_RibosomeMovement_and_Microscope(
    ribosome_trajectories,
    intensity_vector_first_signal,
    tag_positions_first_probe_vector,
    SecondIntensityVector=intensity_vector_second_signal,
    second_probePositions=tag_positions_second_probe_vector,
    FrameVelocity=10,
    fileNameGif=fileNameGif,
    color=color,
)

In [None]:
# Optionally blank out part of the SSA intensities (replaced with NaN) and then
# apply `shift_trajectories`; otherwise analyse the intensities unchanged.
simulate_missing_data = False
if not simulate_missing_data:
    array_simulated = matrix_intensity_first_signal_RT
else:
    array_simulated, _ = mi.Utilities().simulate_missing_data(
        matrix_intensity_first_signal_RT,
        matrix2=None,
        percentage_to_remove_data=5,
        replace_with='nan',
    )
    array_simulated = mi.Utilities().shift_trajectories(array_simulated)

In [None]:

def print_parameters(gene_length, tag_positions_first_probe_vector, g_0, dwell_time):
    """Print kinetic estimates derived from an autocorrelation's G(0) and dwell time.

    Printed quantities:

    * elongation rate ``gene_length / dwell_time`` and a tag-corrected version
      ``(gene_length - max(tag positions)) / dwell_time``;
    * initiation rate ``1 / (g_0 * dwell_time)`` (printed in 1/sec);
    * mean ribosome load ``(gene_length / corrected elongation rate) * initiation rate``;
    * mean time between initiations ``1 / initiation rate`` (printed in seconds).

    Args:
        gene_length: Gene length (codons).
        tag_positions_first_probe_vector: Tag positions of the first probe, on the same
            scale as ``gene_length``.
        g_0: Correlation amplitude at lag 0; must be positive.
        dwell_time: Dwell time; must be positive (presumably seconds, given the units
            printed for the initiation rate).

    Returns:
        None. Results are only printed.

    Raises:
        ValueError: If ``g_0`` or ``dwell_time`` is not positive, if no tag positions are
            given, or if the last tag position is not smaller than ``gene_length``.
    """
    if g_0 <= 0 or dwell_time <= 0:
        raise ValueError(f"g_0 and dwell_time must be positive (got g_0={g_0}, dwell_time={dwell_time}).")
    tag_positions = np.asarray(tag_positions_first_probe_vector)
    if tag_positions.size == 0:
        raise ValueError("tag_positions_first_probe_vector must contain at least one position.")
    last_tag_position = np.max(tag_positions)
    if last_tag_position >= gene_length:
        # The corrected elongation rate would be zero or negative.
        raise ValueError("The last tag position must be smaller than gene_length.")

    separator = '------------------------------------'
    print(separator)
    ke_calculated_ch0 = np.round(gene_length / dwell_time, 2)
    # Correction: only the region downstream of the last tag contributes to the dwell time.
    ke_calculated_ch0_corrected = np.round((gene_length - last_tag_position) / dwell_time, 2)
    print(separator)
    print('Elongation rates: ')
    print('Calculated ch0', ke_calculated_ch0, ' Corrected: ', ke_calculated_ch0_corrected)
    print(separator)

    # Initiation rate
    print('Initiation rates: ')
    ki_calculated = np.round(1 / (g_0 * dwell_time), 3)
    print('Calculated', ki_calculated, ' 1/sec')
    print(separator)

    # Ribosomal density: initiation rate * transit time (gene_length / corrected rate)
    print('Ribosomal density: ')
    ribosomal_density = np.round((gene_length / ke_calculated_ch0_corrected) * ki_calculated, 1)
    print('Calculated', ribosomal_density, ' average number of ribosomes per RNA')
    print(separator)

    # Mean time between consecutive initiation events
    print('Ribosomal occurrence: ')
    seconds_between_initiations = 1 / ki_calculated
    print('Calculated', np.round(seconds_between_initiations, 2), ' seconds between ribosome initiation')
    print(separator)
    return None

# g_0 and dwell_time are hard-coded here; they are not taken from the correlation cell below.
print_parameters(
    gene_length,
    tag_positions_first_probe_vector,
    g_0=0.08,
    dwell_time=279,
)

In [None]:
# Autocorrelation of the SSA intensity trajectories (returns the dwell time as well).
# `array_simulated` equals `matrix_intensity_first_signal_RT` unless missing data was
# simulated in the cell above, in which case the degraded data is analysed.
mean_correlation, std_correlation, lags, correlations_array, dwell_time = mi.Correlation(
    primary_data=array_simulated,
    nan_handling='forward_fill',
    shift_data=True,
    return_full=False,
    # Simulated frames are `step_size_in_sec` apart; passing 1 would express lags (and the
    # dwell time) in frames rather than seconds.
    # NOTE(review): assumes Correlation scales lags by this interval — confirm.
    time_interval_between_frames_in_seconds=step_size_in_sec,
    use_bootstrap=True,
    show_plot=True,
    start_lag=0,
    fit_type='linear',
    de_correlation_threshold=0.005,
    correct_baseline=True,
    use_linear_projection_for_lag_0=False,
    save_plots=False,
    use_global_mean=False,
    remove_outliers=False,
    MAD_THRESHOLD_FACTOR=6,
    plot_individual_trajectories=False,
    multi_tau=False,
    plot_title=None,
).run()

In [None]:
# Deliberate stop: halts "Run All" here so the cells below (a local definition of the
# Correlation class) are not executed automatically.
raise Exception('Stop here')

In [None]:
class Correlation:
    """
    A class for calculating the autocorrelation or cross-correlation of datasets.

    Attributes:
        primary_data (np.ndarray): Primary dataset for autocorrelation, shape [sample, time].
        secondary_data (np.ndarray, optional): Secondary dataset for cross-correlation, same shape as primary.
        max_lag (int, optional): Maximum lag to compute correlation, defaults to half time series length.
        nan_handling (str, optional): Strategy to handle NaN values. Options: 'zeros', 'mean', 'forward_fill', 'ignore'.
        return_full (bool, optional): Whether to return the full correlation array or only positive lags.
        use_bootstrap (bool, optional): Whether to use bootstrap for error estimation.
        shift_data (bool, optional): Whether to shift data based on leading NaNs.
        show_plot (bool, optional): Whether to display the plot.
        save_plots (bool, optional): Whether to save the plots.
        plot_name (str, optional): Name of the plot file.
        time_interval_between_frames_in_seconds (int, optional): Time interval between frames.
        index_max_lag_for_fit (int, optional): Index for maximum lag for fitting.
        color_channel (int, optional): Color channel for plotting.
        start_lag (int, optional): Starting lag for plateau finding.
        line_color (str, optional): Color of the plot line.
        plot_title (str, optional): Title of the plot.
        fit_type (str, optional): Type of fit for the plot.
        de_correlation_threshold (float, optional): Threshold for decorrelation.
        use_linear_projection_for_lag_0 (bool, optional): Whether to use linear projection for lag 0.
        correct_baseline (bool, optional): If True, subtract baseline from mean correlation.
        use_global_mean (bool, optional): If True, use a global mean for correlation normalization.
        use_normalization_factor (bool, optional): If True, multiply correlation by 1/(global_means * #time_points).
        remove_outliers (bool, optional): If True, remove “extreme” outlier trajectories before computing mean.
        MAD_THRESHOLD_FACTOR (float, optional): Threshold factor for outlier removal.
        multi_tau (bool, optional): If True, use multi-tau algorithm for correlation (non-uniform, positive lags only).

    Methods:
        run():
            Executes the correlation computation based on initialized settings.

            Returns:
                (mean_correlation, error_correlation, lags, correlations_array, dwell_time)
    """
    def __init__(
        self,
        primary_data,
        secondary_data=None,
        max_lag=None,
        nan_handling='zeros',
        return_full=True,
        use_bootstrap=True,
        shift_data=False,
        show_plot=False,
        save_plots=False,
        plot_name='temp_AC.png',
        time_interval_between_frames_in_seconds=1,
        index_max_lag_for_fit=None,
        color_channel=0,
        start_lag=0,
        line_color='blue',
        correct_baseline=False,
        baseline_offset=None,
        use_global_mean=False,
        plot_title=None,
        fit_type='linear',
        de_correlation_threshold=0.01,
        use_linear_projection_for_lag_0=True,
        normalize_plot_with_g0=False,
        remove_outliers=True,
        MAD_THRESHOLD_FACTOR=6.0,
        plot_individual_trajectories=False,
        y_axes_min_max_list_values=None,
        x_axes_min_max_list_values=None,
        multi_tau=False
    ):
        def shift_and_fill(data1, data2=None, min_nan_threshold=3, fill_with_nans=True):
            """
            Remove leading NaNs beyond a threshold and shift arrays left, filling the end with NaNs or zeros.
            """
            if data1.ndim != 1:
                raise ValueError("Both data1 and data2 must be 1D arrays.")
            nan_count = 0
            for value in data1:
                if np.isnan(value):
                    nan_count += 1
                else:
                    break
            if nan_count >= min_nan_threshold:
                fill_value = np.nan if fill_with_nans else 0
                new_data1 = np.full_like(data1, fill_value)
                new_data1[: len(data1) - nan_count] = data1[nan_count:]
                if data2 is not None:
                    new_data2 = np.full_like(data2, fill_value)
                    new_data2[: len(data2) - nan_count] = data2[nan_count:]
                else:
                    new_data2 = None
                return new_data1, new_data2
            return data1, data2

        if shift_data:
            primary_data_shifted = np.zeros_like(primary_data)
            secondary_data_shifted = np.zeros_like(secondary_data) if secondary_data is not None else None
            for i in range(primary_data.shape[0]):
                if secondary_data is None:
                    primary_data_shifted[i, :], _ = shift_and_fill(primary_data[i, :], None, min_nan_threshold=2)
                else:
                    primary_data_shifted[i, :], secondary_data_shifted[i, :] = shift_and_fill(primary_data[i, :], secondary_data[i, :], min_nan_threshold=2)
            primary_data = primary_data_shifted
            if secondary_data is not None:
                secondary_data = secondary_data_shifted

        # Store attributes
        self.primary_data = primary_data
        self.secondary_data = secondary_data
        self.max_lag = max_lag
        self.nan_handling = nan_handling
        self.return_full = return_full
        self.use_bootstrap = use_bootstrap
        self.BOOTSTRAP_ITERATIONS = 1000
        self.time_interval_between_frames_in_seconds = float(time_interval_between_frames_in_seconds)
        self.index_max_lag_for_fit = index_max_lag_for_fit
        self.plot_name = plot_name
        self.save_plots = save_plots
        self.show_plot = show_plot
        self.color_channel = color_channel
        self.start_lag = start_lag
        self.line_color = line_color
        self.plot_title = plot_title
        self.fit_type = fit_type
        self.de_correlation_threshold = de_correlation_threshold
        self.use_linear_projection_for_lag_0 = use_linear_projection_for_lag_0
        self.normalize_plot_with_g0 = normalize_plot_with_g0
        self.correct_baseline = correct_baseline
        if baseline_offset is None:
            # Default: use half the time series length as baseline offset for fitting baseline
            self.baseline_offset = int(primary_data.shape[1] // 2)
        else:
            self.baseline_offset = baseline_offset
        if correct_baseline:
            plot_individual_trajectories = False
            print('Baseline correction is enabled. Plotting individual trajectories is disabled due to baseline correction.')
        self.use_global_mean = use_global_mean
        self.remove_outliers = remove_outliers
        self.MAD_THRESHOLD_FACTOR = MAD_THRESHOLD_FACTOR
        self.plot_individual_trajectories = plot_individual_trajectories
        self.y_axes_min_max_list_values = y_axes_min_max_list_values
        self.x_axes_min_max_list_values = x_axes_min_max_list_values
        self.multi_tau = multi_tau

    def run(self):
        """
        Execute the correlation calculations with optional bootstrap error estimation.
        """
        if self.max_lag is None:
            self.max_lag = self.primary_data.shape[1] - 1
        else:
            if self.max_lag >= self.primary_data.shape[1]:
                raise ValueError("Max lag cannot be greater than the length of the time series.")

        # Helper functions
        def trim_nans_from_edges(data):
            mask = ~np.isnan(data)
            if not np.any(mask):
                return np.array([])
            start_idx = np.argmax(mask)
            end_idx = len(mask) - np.argmax(mask[::-1])
            return data[start_idx:end_idx]

        # Prepare forward fill function if needed
        if self.nan_handling == "forward_fill":
            def forward_fill_func(data):
                not_nan = ~np.isnan(data)
                if not np.any(not_nan):
                    return np.array([])
                first_valid_index = np.argmax(not_nan)
                last_valid_index = len(data) - np.argmax(not_nan[::-1]) - 1
                trimmed = data[first_valid_index:last_valid_index + 1]
                mask_ff = np.isnan(trimmed)
                idx = np.where(~mask_ff, np.arange(len(trimmed)), 0)
                np.maximum.accumulate(idx, out=idx)
                filled = trimmed[idx]
                result = np.full_like(data, np.nan)
                result[first_valid_index:last_valid_index + 1] = filled
                return result
            local_forward_fill = forward_fill_func
        else:
            local_forward_fill = lambda arr: arr

        # Global means for normalization if needed
        global_mean_data1 = np.nanmean(self.primary_data)
        global_mean_data2 = np.nanmean(self.secondary_data) if self.secondary_data is not None else global_mean_data1

        if not self.multi_tau:
            # Linear correlation for each sample (symmetric output)
            def process_sample_linear(i):
                try:
                    data1 = trim_nans_from_edges(self.primary_data[i, :])
                    data2 = trim_nans_from_edges(self.secondary_data[i, :]) if self.secondary_data is not None else None
                    if data2 is None:
                        data2 = data1
                    # Handle NaNs according to strategy
                    if self.nan_handling == "mean":
                        mean_val1 = np.nanmean(data1) if len(data1) > 0 else 0.0
                        mean_val2 = np.nanmean(data2) if len(data2) > 0 else 0.0
                        data1 = np.nan_to_num(data1, nan=mean_val1)
                        data2 = np.nan_to_num(data2, nan=mean_val2)
                    elif self.nan_handling == "forward_fill":
                        data1 = local_forward_fill(data1)
                        data2 = local_forward_fill(data2)
                    elif self.nan_handling == "ignore":
                        valid_mask = ~np.isnan(data1) & ~np.isnan(data2)
                        data1 = data1[valid_mask]
                        data2 = data2[valid_mask]
                    elif self.nan_handling == "zeros":
                        data1 = np.nan_to_num(data1)
                        data2 = np.nan_to_num(data2)
                    effective_points = np.sum(~np.isnan(data1))
                    if effective_points < 1:
                        return np.full(2 * self.max_lag + 1, np.nan)
                    # Center data by mean
                    if self.use_global_mean:
                        local_mean1 = global_mean_data1
                        local_mean2 = global_mean_data2
                    else:
                        local_mean1 = np.nanmean(data1)
                        local_mean2 = np.nanmean(data2)
                    cdata1 = data1 - local_mean1
                    cdata2 = data2 - local_mean2
                    # Remove any residual NaNs
                    cdata1 = cdata1[~np.isnan(cdata1)]
                    cdata2 = cdata2[~np.isnan(cdata2)]
                    if len(cdata1) == 0 or len(cdata2) == 0:
                        return np.full(2 * self.max_lag + 1, np.nan)
                    N = len(cdata1)
                    # Compute raw full cross-correlation
                    raw_corr = np.correlate(cdata1, cdata2, mode="full")
                    mid = N - 1
                    min_overlap = max(5, int(0.2 * N))
                    final_corr = np.empty_like(raw_corr, dtype=np.float64)
                    for j in range(len(raw_corr)):
                        lag = j - mid
                        overlap = N - abs(lag)
                        if overlap < min_overlap:
                            final_corr[j] = np.nan
                            continue
                        if lag >= 0:
                            seg1 = data1[:N - lag]
                            seg2 = data2[lag:]
                        else:
                            seg1 = data1[-lag:]
                            seg2 = data2[:N + lag]
                        local_norm = np.nanmean(seg1) * np.nanmean(seg2)
                        if local_norm == 0:
                            final_corr[j] = np.nan
                        else:
                            final_corr[j] = (raw_corr[j] / overlap) / local_norm
                    mid_point = len(final_corr) // 2
                    desired_length = 2 * self.max_lag + 1
                    current_length = len(final_corr)
                    if current_length < desired_length:
                        # Pad with NaNs if needed
                        out = np.full(desired_length, np.nan)
                        start_idx = (desired_length - current_length) // 2
                        out[start_idx:start_idx + current_length] = final_corr
                        return out
                    else:
                        start_idx = mid_point - self.max_lag
                        end_idx = mid_point + self.max_lag + 1
                        return final_corr[start_idx:end_idx]
                except Exception as e:
                    print(f"Error in process_sample_linear for sample {i}: {e}")
                    return np.full(2 * self.max_lag + 1, np.nan)

            correlations_array = np.array(
                Parallel(n_jobs=-1)(delayed(process_sample_linear)(i) for i in range(self.primary_data.shape[0])),
                dtype=np.float64
            )
        else:
            # Multi-tau correlation for each sample (positive lags, non-uniform spacing)
            # Set parameter m (channels per stage after initial). m=8 gives 16 initial points, then groups of 8.
            m = 8
            N0 = self.primary_data.shape[1]
            global_lags_idx = []
            current_length = N0
            dt_factor = 1
            stage = 0
            while True:
                if stage == 0:
                    start_i = 0
                    end_i = min(2 * m - 1, self.max_lag, current_length - 1)
                else:
                    start_i = m
                    end_i = min(2 * m - 1, int(self.max_lag // dt_factor), current_length - 1)
                if start_i > end_i:
                    break
                for i_val in range(start_i, end_i + 1):
                    global_lags_idx.append(i_val * dt_factor)
                if end_i < 2 * m - 1 or (end_i * dt_factor) >= self.max_lag or current_length < 2:
                    break
                new_length = current_length // 2
                if new_length < 1:
                    break
                current_length = new_length
                dt_factor *= 2
                stage += 1
            global_lags_idx = sorted(set(global_lags_idx))
            global_lags_idx = np.array(global_lags_idx, dtype=int)
            idx_map = {lag: idx for idx, lag in enumerate(global_lags_idx)}

            def process_sample_multi_tau(i):
                try:
                    data1 = trim_nans_from_edges(self.primary_data[i, :])
                    data2 = trim_nans_from_edges(self.secondary_data[i, :]) if self.secondary_data is not None else None
                    if data2 is None:
                        data2 = data1
                    # Handle NaNs according to strategy
                    if self.nan_handling == "mean":
                        mean_val1 = np.nanmean(data1) if len(data1) > 0 else 0.0
                        mean_val2 = np.nanmean(data2) if len(data2) > 0 else 0.0
                        data1 = np.nan_to_num(data1, nan=mean_val1)
                        data2 = np.nan_to_num(data2, nan=mean_val2)
                    elif self.nan_handling == "forward_fill":
                        data1 = local_forward_fill(data1)
                        data2 = local_forward_fill(data2)
                    elif self.nan_handling == "ignore":
                        valid_mask = ~np.isnan(data1) & ~np.isnan(data2)
                        data1 = data1[valid_mask]
                        data2 = data2[valid_mask]
                    elif self.nan_handling == "zeros":
                        data1 = np.nan_to_num(data1)
                        data2 = np.nan_to_num(data2)
                    effective_points = np.sum(~np.isnan(data1))
                    if effective_points < 1:
                        # Return all-NaN array of global length
                        return np.full(len(global_lags_idx), np.nan)
                    # Center data by mean
                    if self.use_global_mean:
                        local_mean1 = global_mean_data1
                        local_mean2 = global_mean_data2
                    else:
                        local_mean1 = np.nanmean(data1)
                        local_mean2 = np.nanmean(data2)
                    cdata1 = data1 - local_mean1
                    cdata2 = data2 - local_mean2
                    # Remove any residual NaNs
                    mask_valid = ~np.isnan(cdata1) & ~np.isnan(cdata2)
                    cdata1 = cdata1[mask_valid]
                    cdata2 = cdata2[mask_valid]
                    data1_valid = data1[mask_valid]
                    data2_valid = data2[mask_valid]
                    current_N = len(cdata1)
                    if current_N == 0:
                        return np.full(len(global_lags_idx), np.nan)
                    # Multi-tau loop
                    output_corr = np.full(len(global_lags_idx), np.nan, dtype=np.float64)
                    current_data_raw1 = data1_valid.copy()
                    current_data_raw2 = data2_valid.copy()
                    current_cdata1 = cdata1.copy()
                    current_cdata2 = cdata2.copy()
                    dt_factor = 1
                    stage = 0
                    while True:
                        if stage == 0:
                            start_i = 0
                            end_i = min(2 * m - 1, self.max_lag, current_N - 1)
                        else:
                            start_i = m
                            end_i = min(2 * m - 1, int(self.max_lag // dt_factor), current_N - 1)
                        if start_i > end_i:
                            break
                        min_overlap = max(5, int(0.2 * current_N))
                        for j in range(start_i, end_i + 1):
                            overlap = current_N - j
                            if overlap < min_overlap:
                                continue
                            raw_sum = np.nansum(current_cdata1[:current_N - j] * current_cdata2[j:current_N])
                            seg1 = current_data_raw1[:current_N - j]
                            seg2 = current_data_raw2[j:current_N]
                            local_norm = np.nanmean(seg1) * np.nanmean(seg2)
                            if local_norm == 0:
                                corr_val = np.nan
                            else:
                                corr_val = (raw_sum / overlap) / local_norm
                            lag_index = j * dt_factor
                            if lag_index in idx_map:
                                output_corr[idx_map[lag_index]] = corr_val
                        if end_i < 2 * m - 1 or (end_i * dt_factor) >= self.max_lag or current_N < 2:
                            break
                        new_length = current_N // 2
                        if new_length < 1:
                            break
                        # Downsample data
                        if self.secondary_data is None:
                            new_data_raw1 = 0.5 * (current_data_raw1[:2 * new_length:2] + current_data_raw1[1:2 * new_length:2])
                            new_data_raw2 = new_data_raw1  # autocorrelation
                        else:
                            new_data_raw1 = 0.5 * (current_data_raw1[:2 * new_length:2] + current_data_raw1[1:2 * new_length:2])
                            new_data_raw2 = 0.5 * (current_data_raw2[:2 * new_length:2] + current_data_raw2[1:2 * new_length:2])
                        if self.use_global_mean:
                            new_mean1 = global_mean_data1
                            new_mean2 = global_mean_data2
                        else:
                            new_mean1 = np.nanmean(new_data_raw1)
                            new_mean2 = np.nanmean(new_data_raw2) if self.secondary_data is not None else new_mean1
                        new_cdata1 = new_data_raw1 - new_mean1
                        new_cdata2 = new_data_raw2 - new_mean2
                        # Update for next stage
                        current_data_raw1 = new_data_raw1
                        current_data_raw2 = new_data_raw2
                        current_cdata1 = new_cdata1 = new_cdata1  # (just to use consistent naming; not needed separately)
                        current_cdata2 = new_cdata2 = new_cdata2
                        current_N = new_length
                        dt_factor *= 2
                        stage += 1
                    return output_corr
                except Exception as e:
                    print(f"Error in process_sample_multi_tau for sample {i}: {e}")
                    return np.full(len(global_lags_idx), np.nan)

            correlations_array = np.array(
                Parallel(n_jobs=-1)(delayed(process_sample_multi_tau)(i) for i in range(self.primary_data.shape[0])),
                dtype=np.float64
            )

        # Remove outlier trajectories if required
        if self.remove_outliers and correlations_array.size > 0:
            traj_means = np.nanmean(correlations_array, axis=1)
            median_mean = np.nanmedian(traj_means)
            mad = np.nanmedian(np.abs(traj_means - median_mean))
            if mad == 0 or np.isnan(mad):
                keep_mask = np.ones_like(traj_means, dtype=bool)
            else:
                keep_mask = np.abs(traj_means - median_mean) < self.MAD_THRESHOLD_FACTOR * mad
            num_removed = np.sum(~keep_mask)
            num_total = len(traj_means)
            if num_removed > 0:
                print(f"Warning: Removed {num_removed} outlier trajectories (out of {num_total}) based on a threshold of {self.MAD_THRESHOLD_FACTOR} MAD from the median mean correlation.")
            correlations_array = correlations_array[keep_mask, :]

        # If all data removed or no valid points
        if correlations_array.shape[0] == 0:
            length = correlations_array.shape[1] if correlations_array.ndim > 1 else (len(global_lags_idx) if self.multi_tau else (2 * self.max_lag + 1))
            mean_correlation = np.full(length, np.nan)
            error_correlation = np.full_like(mean_correlation, np.nan)
            if not self.multi_tau:
                lags = np.arange(-self.max_lag, self.max_lag + 1) * self.time_interval_between_frames_in_seconds
            else:
                lags = (global_lags_idx if 'global_lags_idx' in locals() else np.arange(0, self.max_lag + 1)) * self.time_interval_between_frames_in_seconds
            return mean_correlation, error_correlation, lags, correlations_array, None

        # Compute mean correlation and error bars
        mean_correlation = np.nanmean(correlations_array, axis=0)
        # Correct baseline via exponential fit if requested
        if self.correct_baseline:
            L = len(mean_correlation) - 1
            start_idx_fit = 2
            time_range = int(L * 0.99)
            if time_range <= start_idx_fit:
                time_range = start_idx_fit + 1
            if not self.multi_tau:
                # Construct positive lags array (assuming mean_correlation corresponds to lag 0..max_lag if return_full=False, or symmetrical otherwise)
                if self.return_full:
                    # If still symmetric, convert to positive lags portion for fitting
                    lags_array = np.arange(0, L + 1) * self.time_interval_between_frames_in_seconds
                else:
                    lags_array = np.arange(0, len(mean_correlation)) * self.time_interval_between_frames_in_seconds
            else:
                lags_array = global_lags_idx * self.time_interval_between_frames_in_seconds
            y_fit = mean_correlation[start_idx_fit:time_range]
            t_fit = lags_array[start_idx_fit:time_range]
            B_guess = y_fit[-1] if len(y_fit) > 0 else 0.0
            A_guess = mean_correlation[0] - B_guess
            tau_guess = (t_fit[-1] - t_fit[0]) / 2.0 if len(t_fit) > 1 else 1.0
            initial_guess = [A_guess, tau_guess, B_guess]
            lower_bounds = [0, 1e-6, np.min(y_fit) if len(y_fit) > 0 else 0.0]
            upper_bounds = [np.inf, np.inf, mean_correlation[0]]
            mask_fit = np.isfinite(y_fit) & np.isfinite(t_fit)
            t_clean = t_fit[mask_fit]
            y_clean = y_fit[mask_fit]
            if len(t_clean) < 3:
                warnings.warn(f"Too few valid points ({len(t_clean)}) for exponential fit—using fallback.")
                fitted_B = np.nanpercentile(y_fit, 10) if len(y_fit) > 0 else 0.0
            else:
                try:
                    popt, _ = curve_fit(lambda t, A, tau, B: A * np.exp(-t / tau) + B,
                                        t_clean, y_clean, p0=initial_guess,
                                        bounds=(lower_bounds, upper_bounds),
                                        maxfev=10000)
                    fitted_B = popt[2]
                    print(f"Exponential fit parameters: A={popt[0]:.3g}, τ={popt[1]:.3g}, B={popt[2]:.3g}")
                except Exception as e:
                    warnings.warn(f"Exp fit failed ({type(e).__name__}: {e}) → using 10th percentile fallback.")
                    fitted_B = np.nanpercentile(y_fit, 10) if len(y_fit) > 0 else 0.0
            mean_correlation = mean_correlation - fitted_B

        num_kept = correlations_array.shape[0]
        if self.use_bootstrap and num_kept > 1:
            def single_bootstrap_iteration(_):
                rng = np.random.default_rng()
                indices = rng.choice(num_kept, size=num_kept, replace=True)
                sample = correlations_array[indices, :]
                m = np.nanmean(sample, axis=0)
                if self.correct_baseline:
                    if not self.multi_tau:
                        center_idx = self.max_lag
                        offset = min(self.baseline_offset, center_idx)
                        neg_region = m[center_idx - offset : center_idx]
                        pos_region = m[center_idx + 1 : center_idx + 1 + offset]
                        baseline_value = np.nanpercentile(np.concatenate([neg_region, pos_region]), 10)
                        m = m - baseline_value
                    else:
                        offset = min(self.baseline_offset, len(m))
                        tail_region = m[-offset:] if offset > 0 else m
                        baseline_value = np.nanpercentile(tail_region, 10)
                        m = m - baseline_value
                return m
            all_means = np.array(
                Parallel(n_jobs=-1)(delayed(single_bootstrap_iteration)(_) for _ in range(self.BOOTSTRAP_ITERATIONS)),
                dtype=np.float64
            )
            error_correlation = np.nanstd(all_means, axis=0)
        else:
            error_correlation = np.nanstd(correlations_array, axis=0) / np.sqrt(num_kept)

        # Construct lags array (in seconds)
        if not self.multi_tau:
            lags = np.arange(-self.max_lag, self.max_lag + 1) * self.time_interval_between_frames_in_seconds
        else:
            lags = global_lags_idx * self.time_interval_between_frames_in_seconds

        # Linear projection adjustment for lag=0
        if self.use_linear_projection_for_lag_0:
            if not self.multi_tau:
                center_idx = self.max_lag
                if self.secondary_data is None:
                    # Autocorrelation: use negative side to project to 0
                    if center_idx - 6 >= 0 and center_idx - 1 >= 0:
                        x = lags[center_idx - 6 : center_idx - 1]
                        y = mean_correlation[center_idx - 5 : center_idx]
                        if len(x) > 1 and np.all(np.isfinite(y)):
                            _, intercept, _, _, _ = linregress(x, y)
                            mean_correlation[center_idx] = intercept
                    if center_idx < len(error_correlation):
                        error_correlation[center_idx] = 0
                else:
                    # Cross-correlation: project both sides and take max
                    if center_idx - 6 >= 0 and center_idx - 1 >= 0:
                        x_bef = lags[center_idx - 6 : center_idx - 1]
                        y_bef = mean_correlation[center_idx - 5 : center_idx]
                        corr_before = linregress(x_bef, y_bef).intercept if len(x_bef) > 1 and np.all(np.isfinite(y_bef)) else mean_correlation[center_idx]
                    else:
                        corr_before = mean_correlation[center_idx]
                    if center_idx + 6 < len(mean_correlation):
                        x_aft = lags[center_idx + 1 : center_idx + 6]
                        y_aft = mean_correlation[center_idx + 1 : center_idx + 6]
                        corr_after = linregress(x_aft, y_aft).intercept if len(x_aft) > 1 and np.all(np.isfinite(y_aft)) else mean_correlation[center_idx]
                    else:
                        corr_after = mean_correlation[center_idx]
                    mean_correlation[center_idx] = np.nanmax([corr_before, corr_after])
                    if center_idx < len(error_correlation):
                        error_correlation[center_idx] = 0
            else:
                if self.secondary_data is None:
                    # Autocorrelation multi-tau: use first few lags to project to 0
                    if len(lags) > 5:
                        x = lags[1:6]
                        y = mean_correlation[1:6]
                        if len(x) > 1 and np.all(np.isfinite(y)):
                            _, intercept, _, _, _ = linregress(x, y)
                            mean_correlation[0] = intercept
                    if len(error_correlation) > 0:
                        error_correlation[0] = 0
                else:
                    # Cross-correlation multi-tau: no adjustment (lack negative lags)
                    pass

        # For linear correlation, handle return_full flag (for positive lags only)
        if not self.multi_tau and not self.return_full:
            mean_correlation = mean_correlation[self.max_lag:]
            error_correlation = error_correlation[self.max_lag:]
            correlations_array = correlations_array[:, self.max_lag:]
            lags = lags[self.max_lag:]

        dwell_time = None
        if self.show_plot:
            if self.secondary_data is None:
                dwell_time = mi.Plots().plot_autocorrelation(
                    mean_correlation=mean_correlation,
                    error_correlation=error_correlation,
                    lags=lags,
                    correlations_array=correlations_array,
                    time_interval_between_frames_in_seconds=self.time_interval_between_frames_in_seconds,
                    index_max_lag_for_fit=self.index_max_lag_for_fit,
                    start_lag=self.start_lag,
                    plot_name=self.plot_name,
                    save_plots=self.save_plots,
                    line_color=self.line_color,
                    plot_title=self.plot_title,
                    fit_type=self.fit_type,
                    de_correlation_threshold=self.de_correlation_threshold,
                    normalize_plot_with_g0=self.normalize_plot_with_g0,
                    plot_individual_trajectories=self.plot_individual_trajectories,
                    y_axes_min_max_list_values=self.y_axes_min_max_list_values,
                    x_axes_min_max_list_values=self.x_axes_min_max_list_values,
                )
            else:
                dwell_time = mi.Plots().plot_crosscorrelation(
                    intensity_array_ch0=self.primary_data,
                    intensity_array_ch1=self.secondary_data,
                    mean_correlation=mean_correlation,
                    error_correlation=error_correlation,
                    lags=lags,
                    time_interval_between_frames_in_seconds=self.time_interval_between_frames_in_seconds,
                    plot_name=self.plot_name,
                    save_plots=self.save_plots,
                    line_color=self.line_color,
                    plot_title=self.plot_title,
                    normalize_plot_with_g0=self.normalize_plot_with_g0,
                    y_axes_min_max_list_values=self.y_axes_min_max_list_values,
                    x_axes_min_max_list_values=self.x_axes_min_max_list_values,
                )
        return mean_correlation, error_correlation, lags, correlations_array, dwell_time

## Testing the temporal resolution
___

In [None]:
# Global elongation rates to sweep, evenly spaced in [2, 10].
# Round to the nearest integer before casting: astype(int) alone truncates
# toward zero, flooring intermediate grid points (e.g. 4.67 -> 4) and turning
# values that land just below an integer due to float error (5.9999 -> 5) down.
list_ke = np.rint(np.linspace(2, 10, number_tested_parameters)).astype(int)


In [None]:
# Sweep the global elongation rate k_e: for every tested value, simulate the
# TASEP trajectories and compute the multi-tau autocorrelation of the
# resulting intensity signal. Results are collected per k_e in the lists below.
list_mean_correlation, list_std_correlation, list_lags = [], [], []

for i, ke_constant in enumerate(list_ke):
    ke = calculate_codon_elongation_rates(rna, global_elongation_rate=ke_constant)
    # Only the third element of the simulator output is kept; it is the
    # signal passed to the correlation analysis below.
    ssa_array = simulate_TASEP_SSA(
        ki,
        ke,
        gene_length,
        t_max,
        time_interval_in_seconds=step_size_in_sec,
        number_repetitions=number_repetitions,
        first_probe_position_vector=first_probe_position_vector,
        second_probe_position_vector=second_probe_position_vector,
        burnin_time=burnin_time,
        constant_elongation_rate=ke_constant,
        fast_output=True,
    )[2]

    # Autocorrelation of the intensity signal (multi-tau, positive lags only).
    (mean_correlation, std_correlation, lags,
     correlations_array, dwell_time) = Correlation(
        primary_data=ssa_array,
        multi_tau=True,
        max_lag=None,
        return_full=False,
        time_interval_between_frames_in_seconds=step_size_in_sec,
        start_lag=0,
        # Pre-processing of the raw trajectories.
        nan_handling='forward_fill',
        shift_data=True,
        use_global_mean=False,
        remove_outliers=True,
        MAD_THRESHOLD_FACTOR=MAD_THRESHOLD_FACTOR,
        # Error estimation and corrections.
        use_bootstrap=True,
        correct_baseline=True,
        use_linear_projection_for_lag_0=True,
        fit_type='linear',
        de_correlation_threshold=0.01,
        # No plotting inside the sweep.
        show_plot=False,
        save_plots=False,
        plot_individual_trajectories=False,
        y_axes_min_max_list_values=None,
        x_axes_min_max_list_values=None,
        plot_title=None,
    ).run()

    if downsample:
        mean_correlation, std_correlation, lags = (
            values[::downsample_factor]
            for values in (mean_correlation, std_correlation, lags)
        )

    list_mean_correlation.append(mean_correlation)
    list_std_correlation.append(std_correlation)
    list_lags.append(lags)


In [None]:
# Min-max normalise every mean autocorrelation curve onto [0, 1]
# (NaN entries are ignored when locating the extremes).
list_mean_correlation_normalized = [
    (curve - np.nanmin(curve)) / (np.nanmax(curve) - np.nanmin(curve))
    for curve in list_mean_correlation
]



In [None]:
# Theoretical decorrelation time for each tested elongation rate:
# gene_length / k_e.
list_theoretical_decorrelation = [gene_length / rate for rate in list_ke]
print(list_theoretical_decorrelation)


In [None]:
# Plot the min-max normalised autocorrelation for every tested k_e together
# with its theoretical decorrelation time.
#
# Style settings are applied *before* the figure is created: properties such
# as the axes facecolor, spine width and tick label sizes are read from
# rcParams when the figure/axes are constructed, so updating them afterwards
# does not affect this figure.
plt.rcParams.update({
    'figure.facecolor': 'white',    # figure background is white
    'axes.facecolor': 'white',      # axes (plot area) background is white
    'axes.edgecolor': 'black',      # axis spines (box) are black
    'axes.linewidth': 1.5,          # thicker border lines for a clear box
    'font.family': 'sans-serif',
    'font.sans-serif': 'Arial',
    'axes.labelsize': 14,
    'axes.titlesize': 14,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    # black tick labels, axis labels and text
    'axes.labelcolor': 'black',
    'text.color': 'black',
    'xtick.color': 'black',
    'ytick.color': 'black',
})
fig, ax = plt.subplots(figsize=(5, 3))

curve_colors = []
for i, mean_correlation in enumerate(list_mean_correlation_normalized):
    # The curve was divided by (max - min) of the raw mean during min-max
    # normalisation; the error band must be scaled by the same factor or it
    # is drawn on the wrong scale around a curve that lives on [0, 1].
    raw_curve = list_mean_correlation[i]
    curve_range = np.nanmax(raw_curve) - np.nanmin(raw_curve)
    scaled_std = list_std_correlation[i] / curve_range
    # Each simulation has its own lag vector; do not reuse the global `lags`
    # left over from the last iteration of the sweep.
    curve_lags = list_lags[i]
    line, = ax.plot(curve_lags, mean_correlation, label='$k_e$=' + str(list_ke[i]), linewidth=2)
    curve_colors.append(line.get_color())
    ax.fill_between(curve_lags, mean_correlation - scaled_std, mean_correlation + scaled_std, alpha=0.2)
ax.set_xlabel(r'$\tau (s)$', fontsize=14)
ax.set_ylabel(r'$G(\tau)/G(0)$', fontdict={'fontsize': 14})
ax.legend()

# Theoretical decorrelation times: a dashed vertical line and a circle at y=0,
# both in the colour of the matching simulated curve.
for color, theoretical_decorrelation in zip(curve_colors, list_theoretical_decorrelation):
    ax.axvline(x=theoretical_decorrelation, color=color, linestyle='--', lw=1)
    ax.plot(theoretical_decorrelation, 0, markersize=10, marker='o', color=color)

ax.set_xlim([-10, 1500])
for spine in ax.spines.values():
    spine.set_edgecolor('black')
    spine.set_linewidth(1.5)
plt.show()


### Testing the initiation rate
____

In [None]:
# Initiation rates to sweep, evenly spaced in [0.01, 0.1] and rounded to
# three decimals.
list_ki = np.linspace(0.01, 0.1, number_tested_parameters).round(3)
print(list_ki)

In [None]:
# Sweep the initiation rate k_i at a fixed elongation rate: for every tested
# value, simulate the TASEP trajectories and compute the multi-tau
# autocorrelation of the resulting intensity signal.
list_mean_correlation_ki = []
list_std_correlation_ki = []
list_lags_ki = []
fixed_ke = 5
# The per-codon elongation rates depend only on `fixed_ke`, not on the loop
# variable, so compute them once (assumes calculate_codon_elongation_rates is
# deterministic — TODO confirm).
ke = calculate_codon_elongation_rates(rna, global_elongation_rate=fixed_ke)
for i, ki_tested in enumerate(list_ki):
    # Only the third element of the simulator output is passed on to the
    # correlation analysis.
    ssa_array = simulate_TASEP_SSA(ki_tested, ke, gene_length, t_max,
                                   time_interval_in_seconds=step_size_in_sec,
                                   number_repetitions=number_repetitions,
                                   first_probe_position_vector=first_probe_position_vector,
                                   second_probe_position_vector=second_probe_position_vector,
                                   burnin_time=burnin_time,
                                   constant_elongation_rate=fixed_ke,
                                   fast_output=True)[2]
    # Calculating the autocorrelation of the intensity signal
    mean_correlation, std_correlation, lags, correlations_array, dwell_time = Correlation(
        primary_data=ssa_array,
        max_lag=None,
        nan_handling='forward_fill',
        shift_data=True,
        return_full=False,
        time_interval_between_frames_in_seconds=step_size_in_sec,
        use_bootstrap=True,
        show_plot=False,
        start_lag=0,
        fit_type='linear',
        de_correlation_threshold=0.01,
        correct_baseline=True,
        use_linear_projection_for_lag_0=False,
        save_plots=False,
        use_global_mean=False,
        remove_outliers=True,
        # Use the same outlier threshold as the k_e sweep so both sweeps
        # discard trajectories by the same criterion (the class default would
        # otherwise be used silently here).
        MAD_THRESHOLD_FACTOR=MAD_THRESHOLD_FACTOR,
        plot_individual_trajectories=False,
        y_axes_min_max_list_values=None,
        x_axes_min_max_list_values=None,
        multi_tau=True,
        plot_title=None,
    ).run()

    if downsample:
        mean_correlation = mean_correlation[::downsample_factor]
        std_correlation = std_correlation[::downsample_factor]
        lags = lags[::downsample_factor]

    list_mean_correlation_ki.append(mean_correlation)
    list_std_correlation_ki.append(std_correlation)
    list_lags_ki.append(lags)


In [None]:
# Theoretical G(0) for each tested initiation rate: 1 / (k_i * T), where
# T = gene_length / fixed_ke is the theoretical decorrelation time at the
# fixed elongation rate.
theoretical_decorrelation_fixed_ke = gene_length / fixed_ke

list_theoretical_G0 = [
    1 / (ki_tested * theoretical_decorrelation_fixed_ke) for ki_tested in list_ki
]

In [None]:
# Show the theoretical G(0) values (the notebook echoes the cell's last expression).
list_theoretical_G0

In [None]:
# Plot the autocorrelation for every tested k_i together with its
# theoretical G(0).
#
# Style settings are applied *before* the figure is created: properties such
# as the axes facecolor, spine width and tick label sizes are read from
# rcParams when the figure/axes are constructed, so updating them afterwards
# does not affect this figure.
plt.rcParams.update({
    'figure.facecolor': 'white',    # figure background is white
    'axes.facecolor': 'white',      # axes (plot area) background is white
    'axes.edgecolor': 'black',      # axis spines (box) are black
    'axes.linewidth': 1.5,          # thicker border lines for a clear box
    'font.family': 'sans-serif',
    'font.sans-serif': 'Arial',
    'axes.labelsize': 14,
    'axes.titlesize': 14,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    # black tick labels, axis labels and text
    'axes.labelcolor': 'black',
    'text.color': 'black',
    'xtick.color': 'black',
    'ytick.color': 'black',
})
fig, ax = plt.subplots(figsize=(5, 3))

curve_colors = []
for i, mean_correlation in enumerate(list_mean_correlation_ki):
    # Each simulation has its own lag vector; do not reuse the global `lags`
    # left over from the last iteration of the sweep.
    curve_lags = list_lags_ki[i]
    # Label with the k_i value exactly as it appears in list_ki (already
    # rounded to 3 decimals); re-rounding to 2 decimals can merge labels.
    line, = ax.plot(curve_lags, mean_correlation, label=f'$k_i$={list_ki[i]:g}', linewidth=2)
    curve_colors.append(line.get_color())
    ax.fill_between(curve_lags,
                    mean_correlation - list_std_correlation_ki[i],
                    mean_correlation + list_std_correlation_ki[i],
                    alpha=0.2)
ax.set_xlabel(r'$\tau (s)$', fontsize=14)
ax.set_ylabel(r'$G(\tau)$', fontsize=14)
ax.legend(loc='upper right', fontsize=12)

# Theoretical G(0): a dashed horizontal line and a circle at tau = 0, both in
# the colour of the matching simulated curve.
for color, theoretical_G0 in zip(curve_colors, list_theoretical_G0):
    ax.axhline(y=theoretical_G0, color=color, linestyle='--', lw=1)
    ax.plot(0, theoretical_G0, markersize=10, marker='o', color=color)

for spine in ax.spines.values():
    spine.set_edgecolor('black')
    spine.set_linewidth(1.5)
ax.set_xlim([-20, 1500])

plt.show()


In [None]:
# Deliberate stop so that "Run All" does not continue into the cells below.
# A bare `raise` with no active exception only produces the misleading
# "RuntimeError: No active exception to reraise"; raise it explicitly.
raise RuntimeError("Execution intentionally stopped here; run the cells below manually.")

## Multi-tau Algorithm 
___

In [None]:
class Correlation:
    """
    Mean autocorrelation G(tau) of a set of trajectories via the multiple-tau algorithm.

    Only the multi-tau autocorrelation path is implemented in this copy of the class:
    ``run()`` requires ``multi_tau=True`` and ``secondary_data=None`` and raises
    ``NotImplementedError`` for any other configuration (the linear-lag and
    cross-correlation branch is not included in this notebook).

    For one trajectory ``I`` of length ``L`` (after NaN trimming/handling) and a
    lag of ``k`` samples::

        G(k) = [sum_t dI(t) * dI(t+k) / (L - k)] / (mean(I[:L-k]) * mean(I[k:]))

    with ``dI = I - mean`` (the trajectory mean, or the global mean when
    ``use_global_mean``). Level 0 evaluates ``k = 0..m`` on the raw series. Every
    further level averages neighbouring samples pairwise (half the length, twice
    the lag step) and evaluates ``k = m//2 + 1 .. m`` on the coarser series, so each
    output lag is produced by exactly one level (quasi-logarithmic lag grid).
    Lags whose overlap ``L - k`` is below ``max(5, int(0.2 * L))`` are skipped, and
    refinement stops once the next level would have ``<= m`` samples.

    Args:
        primary_data (np.ndarray): 2D array indexed ``[trajectory, time]``.
        secondary_data (np.ndarray, optional): Second signal with the same layout.
            Only shifted together with ``primary_data`` when ``shift_data``; any
            non-None value makes ``run()`` raise ``NotImplementedError``.
        max_lag (int, optional): Largest lag in frames. Defaults to
            ``primary_data.shape[1] - 1``; must be smaller than the series length.
        nan_handling (str): ``'zeros'``, ``'mean'``, ``'forward_fill'`` or
            ``'ignore'`` (drop NaN samples). Applied after trimming leading and
            trailing NaNs from each trajectory.
        return_full (bool): Stored only; the multi-tau path always returns
            non-negative lags.
        use_bootstrap (bool): If True (and more than one trajectory is kept), the
            error is the std of ``BOOTSTRAP_ITERATIONS`` bootstrap means; otherwise
            it is ``nanstd / sqrt(N)``.
        shift_data (bool): If True, trajectories that start with >= 2 NaNs are
            shifted left to drop them and padded at the end with NaN (rows of
            ``secondary_data`` are shifted by the same amount).
        time_interval_between_frames_in_seconds (float): Frame interval; returned
            lags are in seconds.
        correct_baseline (bool): Subtract a plateau estimated by fitting
            ``A * exp(-t / tau) + B`` to the mean curve (falls back to a tail
            percentile if the fit fails). Also disables individual-trajectory plots.
        use_global_mean (bool): Center every trajectory with the mean of all data.
        remove_outliers (bool): Drop trajectories whose mean correlation deviates
            from the median by at least ``MAD_THRESHOLD_FACTOR`` MADs.
        MAD_THRESHOLD_FACTOR (float): Outlier threshold in MAD units.
        multi_tau (bool): Use the multi-tau algorithm (autocorrelation only).
        multi_tau_m (int): Lag channels per multi-tau level (``m``), default 16.
            Stored as the ``m`` attribute, which may also be changed before ``run()``.
        show_plot, save_plots, plot_name, line_color, plot_title, fit_type,
        de_correlation_threshold, normalize_plot_with_g0, index_max_lag_for_fit,
        start_lag, plot_individual_trajectories, y_axes_min_max_list_values,
        x_axes_min_max_list_values: Forwarded to the plotting call when
            ``show_plot`` is True.
        color_channel, baseline_offset, use_linear_projection_for_lag_0: Stored
            only; not used by the multi-tau path.
    """

    def __init__(
        self,
        primary_data,
        secondary_data=None,
        max_lag=None,
        nan_handling='zeros',
        return_full=True,
        use_bootstrap=True,
        shift_data=False,
        show_plot=False,
        save_plots=False,
        plot_name='temp_AC.png',
        time_interval_between_frames_in_seconds=1,
        index_max_lag_for_fit=None,
        color_channel=0,
        start_lag=0,
        line_color='blue',
        correct_baseline=False,
        baseline_offset=None,
        use_global_mean=False,
        plot_title=None,
        fit_type='linear',
        de_correlation_threshold=0.01,
        use_linear_projection_for_lag_0=True,
        normalize_plot_with_g0=False,
        remove_outliers=True,
        MAD_THRESHOLD_FACTOR=6.0,
        plot_individual_trajectories=False,
        y_axes_min_max_list_values=None,
        x_axes_min_max_list_values=None,
        multi_tau=False,
        multi_tau_m=16,
    ):
        if shift_data:
            primary_data_shifted = np.zeros_like(primary_data)
            secondary_data_shifted = None if secondary_data is None else np.zeros_like(secondary_data)
            for i in range(primary_data.shape[0]):
                row2 = None if secondary_data is None else secondary_data[i, :]
                new1, new2 = self._shift_and_fill(primary_data[i, :], row2, min_nan_threshold=2)
                primary_data_shifted[i, :] = new1
                if secondary_data_shifted is not None:
                    secondary_data_shifted[i, :] = new2
            primary_data = primary_data_shifted
            if secondary_data is not None:
                secondary_data = secondary_data_shifted

        # Store attributes
        self.primary_data = primary_data
        self.secondary_data = secondary_data
        self.max_lag = max_lag
        self.nan_handling = nan_handling
        self.return_full = return_full
        self.use_bootstrap = use_bootstrap
        self.BOOTSTRAP_ITERATIONS = 1000
        self.time_interval_between_frames_in_seconds = float(time_interval_between_frames_in_seconds)
        self.index_max_lag_for_fit = index_max_lag_for_fit
        self.plot_name = plot_name
        self.save_plots = save_plots
        self.show_plot = show_plot
        self.color_channel = color_channel
        self.start_lag = start_lag
        self.line_color = line_color
        self.plot_title = plot_title
        self.fit_type = fit_type
        self.de_correlation_threshold = de_correlation_threshold
        self.use_linear_projection_for_lag_0 = use_linear_projection_for_lag_0
        self.normalize_plot_with_g0 = normalize_plot_with_g0
        self.correct_baseline = correct_baseline
        if baseline_offset is None:
            # Use half the time series length as default baseline offset
            self.baseline_offset = int(primary_data.shape[1] // 2)
        else:
            self.baseline_offset = baseline_offset
        self.use_global_mean = use_global_mean
        if correct_baseline:
            plot_individual_trajectories = False
            print('Baseline correction is enabled. Plotting individual trajectories is disabled due to baseline correction.')
        self.remove_outliers = remove_outliers
        self.MAD_THRESHOLD_FACTOR = MAD_THRESHOLD_FACTOR
        self.plot_individual_trajectories = plot_individual_trajectories
        self.y_axes_min_max_list_values = y_axes_min_max_list_values
        self.x_axes_min_max_list_values = x_axes_min_max_list_values
        self.multi_tau = multi_tau
        self.m = multi_tau_m  # lag channels per multi-tau level

    @staticmethod
    def _shift_and_fill(data1, data2=None, min_nan_threshold=3, fill_with_nans=True):
        """
        Drop a run of >= ``min_nan_threshold`` leading NaNs from ``data1`` by shifting it left.

        ``data2`` (if given) is shifted by the same amount. The freed tail is filled
        with NaN (or 0 when ``fill_with_nans`` is False) so shapes are preserved.
        Arrays are returned unchanged when the leading-NaN run is shorter.
        """
        if data1.ndim != 1 or (data2 is not None and data2.ndim != 1):
            raise ValueError("Both data1 and data2 must be 1D arrays.")

        nan_count = 0
        for value in data1:
            if np.isnan(value):
                nan_count += 1
            else:
                break
        if nan_count < min_nan_threshold:
            return data1, data2

        fill_value = np.nan if fill_with_nans else 0
        new_data1 = np.full_like(data1, fill_value)
        new_data1[: len(data1) - nan_count] = data1[nan_count:]
        new_data2 = None
        if data2 is not None:
            new_data2 = np.full_like(data2, fill_value)
            new_data2[: len(data2) - nan_count] = data2[nan_count:]
        return new_data1, new_data2

    def run(self):
        """
        Execute the multi-tau autocorrelation.

        Returns:
            tuple: ``(mean_correlation, error_correlation, lags, correlations_array, dwell_time)``
                - mean_correlation: nan-mean of G over kept trajectories
                  (baseline-subtracted when ``correct_baseline``).
                - error_correlation: bootstrap std of the mean, or ``nanstd / sqrt(N)``.
                - lags: non-negative lag times in seconds.
                - correlations_array: per-trajectory G, shape ``(n_kept, n_lags)``.
                - dwell_time: value returned by the plotting call when ``show_plot``,
                  otherwise None.

        Raises:
            ValueError: If ``max_lag`` is not smaller than the series length.
            NotImplementedError: If ``multi_tau`` is False or ``secondary_data`` is given.
        """
        # Determine maximum lag if not set
        if self.max_lag is None:
            self.max_lag = self.primary_data.shape[1] - 1
        elif self.max_lag >= self.primary_data.shape[1]:
            raise ValueError("Max lag cannot be greater than the length of the time series.")

        if not (self.multi_tau and self.secondary_data is None):
            # The standard (linear-lag) / cross-correlation branch is not part of this
            # copy of the class; fail loudly instead of implicitly returning None.
            raise NotImplementedError(
                "Only multi-tau autocorrelation (multi_tau=True, secondary_data=None) "
                "is implemented in this version of Correlation."
            )
        return self._run_multi_tau()

    def _prepare_series(self, row):
        """Trim leading/trailing NaNs and apply ``nan_handling``; return None if no data is left."""
        data = row.astype(float)
        valid = ~np.isnan(data)
        if not np.any(valid):
            return None
        start_idx = np.argmax(valid)
        end_idx = len(valid) - np.argmax(valid[::-1])
        data = data[start_idx:end_idx]

        if self.nan_handling == "mean":
            data = np.nan_to_num(data, nan=np.nanmean(data))
        elif self.nan_handling == "forward_fill":
            # After trimming, data[0] is valid, so a running index of the last
            # valid sample fills every interior gap.
            idx = np.where(~np.isnan(data), np.arange(len(data)), 0)
            np.maximum.accumulate(idx, out=idx)
            data = data[idx]
        elif self.nan_handling == "ignore":
            data = data[~np.isnan(data)]
        elif self.nan_handling == "zeros":
            data = np.nan_to_num(data)
        return data if data.size > 0 else None

    def _multi_tau_single(self, data, local_mean, m):
        """Multi-tau G for one prepared series; returns (values, lags in frames), each lag once."""
        series = data.copy()  # processed series, used for the segment-mean normalization
        centered = np.nan_to_num(series - local_mean)
        traj_corr = []
        traj_lags = []
        factor = 1  # lag step (in original frames) at the current level
        level = 0
        while len(centered) > 0:
            L = len(centered)
            if level == 0:
                first_k = 0
                last_k = min(L - 1, m, self.max_lag)
            else:
                # The previous level covered lags up to m * factor / 2, so start just
                # beyond it to avoid recomputing (and overwriting) finer estimates.
                first_k = m // 2 + 1
                last_k = min(L - 1, m, self.max_lag // factor)
            if last_k < first_k:
                break
            # Require a minimum overlap to avoid very noisy points.
            min_overlap = max(5, int(0.2 * L))
            for k in range(first_k, last_k + 1):
                orig_lag = k * factor
                if orig_lag > self.max_lag:
                    break
                overlap = L - k
                if overlap < min_overlap:
                    continue
                # Sum of products at lag k (only the needed lags, not a full correlate).
                raw_val = np.dot(centered[:overlap], centered[k:L])
                norm_factor = np.nanmean(series[:overlap]) * np.nanmean(series[k:L])
                if norm_factor == 0:
                    corr_val = np.nan
                else:
                    corr_val = (raw_val / overlap) / norm_factor
                traj_corr.append(corr_val)
                traj_lags.append(orig_lag)

            new_length = L // 2  # an odd trailing sample is dropped
            if new_length < 1 or new_length <= m:
                break  # next level would have too few points
            # Down-sample by averaging neighbouring pairs.
            cut = 2 * new_length
            centered = (centered[0:cut:2] + centered[1:cut:2]) / 2.0
            series = (series[0:cut:2] + series[1:cut:2]) / 2.0
            factor *= 2
            level += 1
        return traj_corr, traj_lags

    def _estimate_baseline(self, mean_correlation, lags):
        """Plateau B of an ``A*exp(-t/tau)+B`` fit to the curve tail (10th percentile fallback)."""
        from scipy.optimize import curve_fit  # local import: not imported at file level

        last_idx = len(mean_correlation) - 1
        start_idx_fit = 2
        end_idx_fit = int(last_idx * 0.99)
        if end_idx_fit <= start_idx_fit:
            end_idx_fit = last_idx
        y_fit = mean_correlation[start_idx_fit:end_idx_fit]
        # Use the real lag times: the multi-tau grid is not uniformly spaced.
        t_fit = np.asarray(lags[start_idx_fit:end_idx_fit], dtype=float)

        g0 = mean_correlation[0]
        if y_fit.size > 0:
            B_guess = y_fit[-1] if not np.isnan(y_fit[-1]) else np.nanmedian(y_fit)
        else:
            B_guess = 0.0
        A_guess = (g0 - B_guess) if not np.isnan(g0) else 0.0
        tau_guess = (t_fit[-1] - t_fit[0]) / 2.0 if t_fit.size > 1 else 1.0
        lower_bounds = [0, 1e-6, np.nanmin(y_fit) if y_fit.size > 0 else 0.0]
        upper_bounds = [np.inf, np.inf, g0 if not np.isnan(g0) else np.inf]
        finite = np.isfinite(t_fit) & np.isfinite(y_fit)
        try:
            popt, _ = curve_fit(
                lambda t, A, tau, B: A * np.exp(-t / tau) + B,
                t_fit[finite],
                y_fit[finite],
                p0=[A_guess, tau_guess, B_guess],
                bounds=(lower_bounds, upper_bounds),
                maxfev=10000,
            )
            return popt[2]
        except (RuntimeError, ValueError, TypeError, np.linalg.LinAlgError):
            # Fit did not converge or the inputs/guesses were infeasible:
            # use the 10th percentile of the fitted region as the baseline.
            return np.nanpercentile(y_fit, 10) if y_fit.size > 0 else 0.0

    def _run_multi_tau(self):
        """Multi-tau autocorrelation over all trajectories (see ``run`` for the return tuple)."""
        m = getattr(self, 'm', 16)
        global_mean_val = np.nanmean(self.primary_data) if self.use_global_mean else None

        correlations_list = []
        lags_list = []
        for row in self.primary_data:
            data = self._prepare_series(row)
            if data is None:
                # No valid samples: keep an empty (all-NaN) row for this trajectory.
                correlations_list.append([])
                lags_list.append([])
                continue
            local_mean = global_mean_val if global_mean_val is not None else np.nanmean(data)
            traj_corr, traj_lags = self._multi_tau_single(data, local_mean, m)
            correlations_list.append(traj_corr)
            lags_list.append(traj_lags)

        # Align all trajectories on the union of their lags.
        all_lags = sorted({lag for traj_lags in lags_list for lag in traj_lags})
        if not all_lags:
            mean_correlation = np.full(1, np.nan)
            error_correlation = np.full_like(mean_correlation, np.nan)
            return mean_correlation, error_correlation, np.array([]), np.empty((0, 0)), None
        lag_to_col = {lag: j for j, lag in enumerate(all_lags)}
        correlations_array = np.full((len(correlations_list), len(all_lags)), np.nan, dtype=float)
        for row_idx, (corr_vals, lag_vals) in enumerate(zip(correlations_list, lags_list)):
            for val, lag in zip(corr_vals, lag_vals):
                correlations_array[row_idx, lag_to_col[lag]] = val
        lags = np.array(all_lags, dtype=float) * self.time_interval_between_frames_in_seconds

        # Remove outlier trajectories (median +/- MAD rule on per-trajectory means).
        if self.remove_outliers and correlations_array.shape[0] > 0:
            traj_means = np.nanmean(correlations_array, axis=1)
            median_mean = np.nanmedian(traj_means)
            deviation = np.abs(traj_means - median_mean)
            mad = np.nanmedian(deviation)
            if mad == 0:
                keep_mask = np.ones_like(traj_means, dtype=bool)
            else:
                keep_mask = deviation < self.MAD_THRESHOLD_FACTOR * mad
            num_removed = np.sum(~keep_mask)
            num_total = len(traj_means)
            if num_removed > 0:
                print(f"Warning: Removed {num_removed} outlier trajectories (out of {num_total}) based on a threshold of {self.MAD_THRESHOLD_FACTOR} MAD from the median mean correlation.")
            correlations_array = correlations_array[keep_mask, :]

        if correlations_array.size == 0 or correlations_array.shape[0] == 0:
            mean_correlation = np.full((len(all_lags),), np.nan)
            error_correlation = np.full_like(mean_correlation, np.nan)
            return mean_correlation, error_correlation, lags, correlations_array, None

        mean_correlation = np.nanmean(correlations_array, axis=0)
        if self.correct_baseline:
            mean_correlation = mean_correlation - self._estimate_baseline(mean_correlation, lags)

        # Standard error of the mean via bootstrap or analytic formula.
        num_kept = correlations_array.shape[0]
        if self.use_bootstrap and num_kept > 1:
            rng = np.random.default_rng()
            bootstrap_means = []
            for _ in range(self.BOOTSTRAP_ITERATIONS):
                indices = rng.choice(num_kept, size=num_kept, replace=True)
                boot_mean = np.nanmean(correlations_array[indices, :], axis=0)
                if self.correct_baseline:
                    # Cheaper baseline per resample: 10th percentile of the last quarter.
                    tail = max(1, len(boot_mean) // 4)
                    boot_mean = boot_mean - np.nanpercentile(boot_mean[-tail:], 10)
                bootstrap_means.append(boot_mean)
            error_correlation = np.nanstd(np.array(bootstrap_means, dtype=float), axis=0)
        else:
            error_correlation = np.nanstd(correlations_array, axis=0) / np.sqrt(num_kept)

        # Only non-negative lags are returned, regardless of return_full.
        dwell_time = None
        if self.show_plot:
            # NOTE(review): `mi` must be imported (the microlive import is commented
            # out at the top of this notebook) for show_plot=True to work.
            dwell_time = mi.Plots().plot_autocorrelation(
                mean_correlation=mean_correlation,
                error_correlation=error_correlation,
                lags=lags,
                correlations_array=correlations_array,
                time_interval_between_frames_in_seconds=self.time_interval_between_frames_in_seconds,
                index_max_lag_for_fit=self.index_max_lag_for_fit,
                start_lag=self.start_lag,
                plot_name=self.plot_name,
                save_plots=self.save_plots,
                line_color=self.line_color,
                plot_title=self.plot_title,
                fit_type=self.fit_type,
                de_correlation_threshold=self.de_correlation_threshold,
                normalize_plot_with_g0=self.normalize_plot_with_g0,
                plot_individual_trajectories=self.plot_individual_trajectories,
                y_axes_min_max_list_values=self.y_axes_min_max_list_values,
                x_axes_min_max_list_values=self.x_axes_min_max_list_values,
            )
        return mean_correlation, error_correlation, lags, correlations_array, dwell_time

In [None]:
# k_i sweep at a fixed elongation rate: simulate TASEP traces for each tested
# initiation rate and compute their multi-tau autocorrelation.
list_mean_correlation_ki_multi = []
list_std_correlation_ki_multi = []
list_lags_ki_multi = []
fixed_ke = 5

# Correlation options shared by every k_i in the sweep (only primary_data varies).
ki_sweep_multi_tau_options = {
    'max_lag': None,
    'nan_handling': 'forward_fill',  # forward_fill, 'ignore'
    'shift_data': True,
    'return_full': False,
    'time_interval_between_frames_in_seconds': step_size_in_sec,
    'use_bootstrap': True,
    'show_plot': False,
    'start_lag': 0,
    'fit_type': 'linear',
    'de_correlation_threshold': 0.01,
    'correct_baseline': True,
    'use_linear_projection_for_lag_0': False,
    'save_plots': False,
    'use_global_mean': False,
    'remove_outliers': True,
    'plot_individual_trajectories': False,
    'y_axes_min_max_list_values': None,
    'x_axes_min_max_list_values': None,
    'multi_tau': True,  # multi-tau autocorrelation
    'plot_title': None,
}

for i, ki_tested in enumerate(list_ki):
    ke = calculate_codon_elongation_rates(rna, global_elongation_rate=fixed_ke)
    ssa_array = simulate_TASEP_SSA(
        ki_tested, ke, gene_length, t_max,
        time_interval_in_seconds=step_size_in_sec,
        number_repetitions=number_repetitions,
        first_probe_position_vector=first_probe_position_vector,
        second_probe_position_vector=second_probe_position_vector,
        burnin_time=burnin_time,
        constant_elongation_rate=fixed_ke,
        fast_output=True,
    )[2]

    # Autocorrelation of the simulated intensity signal.
    (mean_correlation_multi, std_correlation_multi, lags_multi,
     correlations_array_multi, dwell_time_multi) = Correlation(
        primary_data=ssa_array, **ki_sweep_multi_tau_options).run()

    if downsample:
        mean_correlation_multi, std_correlation_multi, lags_multi = (
            values[::downsample_factor]
            for values in (mean_correlation_multi, std_correlation_multi, lags_multi)
        )

    list_mean_correlation_ki_multi.append(mean_correlation_multi)
    list_std_correlation_ki_multi.append(std_correlation_multi)
    list_lags_ki_multi.append(lags_multi)



In [None]:
# Plot the multi-tau autocorrelation G(tau) for every tested k_i against that
# curve's own lag grid, with the theoretical G(0) drawn as a dashed horizontal
# line and a marker at tau = 0 in the matching color.
fig, ax = plt.subplots(figsize=(5, 3))

curve_colors = []
for i, mean_correlation in enumerate(list_mean_correlation_ki_multi):
    # Each simulation has its own lag grid; do not reuse the last loop's lags_multi.
    curve_lags = list_lags_ki_multi[i]
    (line,) = ax.plot(curve_lags, mean_correlation, label='$k_i$='+str(np.round(list_ki[i],2)), linewidth=2)
    curve_colors.append(line.get_color())
    ax.fill_between(curve_lags, mean_correlation - list_std_correlation_ki_multi[i], mean_correlation + list_std_correlation_ki_multi[i], alpha=0.2)
ax.set_xlabel(r'$\tau (s)$', fontsize=14)
ax.set_ylabel(r'$G(\tau)$', fontsize=14)
# legend location top right
ax.legend(loc='upper right', fontsize=12)

# Theoretical G(0) as a dashed horizontal line in the color of the matching curve.
for theoretical_G0, color in zip(list_theoretical_G0, curve_colors):
    ax.axhline(y=theoretical_G0, color=color, linestyle='--', lw=1)
# Circle at x = 0 marking the theoretical G(0) value.
for theoretical_G0, color in zip(list_theoretical_G0, curve_colors):
    ax.plot(0, theoretical_G0, markersize=10, marker='o', color=color)

for spine in ax.spines.values():
    spine.set_edgecolor('black')
    spine.set_linewidth(1.5)
ax.set_xlim([-20, 1500])

plt.show()


In [None]:
# k_e sweep at a fixed initiation rate: simulate TASEP traces for each constant
# elongation rate and compute their multi-tau autocorrelation.
list_mean_correlation = []
list_std_correlation = []
list_lags = []

# Correlation options shared by every k_e in the sweep (only primary_data varies).
ke_sweep_multi_tau_options = {
    'max_lag': None,
    'nan_handling': 'forward_fill',  # forward_fill, 'ignore'
    'shift_data': True,
    'return_full': False,
    'time_interval_between_frames_in_seconds': step_size_in_sec,
    'use_bootstrap': True,
    'show_plot': False,
    'start_lag': 0,
    'fit_type': 'linear',
    'de_correlation_threshold': 0.01,
    'correct_baseline': True,
    'use_linear_projection_for_lag_0': True,
    'save_plots': False,
    'use_global_mean': False,
    'remove_outliers': True,
    'MAD_THRESHOLD_FACTOR': MAD_THRESHOLD_FACTOR,
    'plot_individual_trajectories': False,
    'y_axes_min_max_list_values': None,
    'x_axes_min_max_list_values': None,
    'multi_tau': True,  # multi-tau autocorrelation (not the standard linear-lag one)
    'plot_title': None,
}

for i, ke_constant in enumerate(list_ke):
    ke = calculate_codon_elongation_rates(rna, global_elongation_rate=ke_constant)
    ssa_array = simulate_TASEP_SSA(
        ki, ke, gene_length, t_max,
        time_interval_in_seconds=step_size_in_sec,
        number_repetitions=number_repetitions,
        first_probe_position_vector=first_probe_position_vector,
        second_probe_position_vector=second_probe_position_vector,
        burnin_time=burnin_time,
        constant_elongation_rate=ke_constant,
        fast_output=True,
    )[2]

    # Autocorrelation of the simulated intensity signal.
    (mean_correlation, std_correlation, lags,
     correlations_array, dwell_time) = Correlation(
        primary_data=ssa_array, **ke_sweep_multi_tau_options).run()

    if downsample:
        mean_correlation, std_correlation, lags = (
            values[::downsample_factor]
            for values in (mean_correlation, std_correlation, lags)
        )

    list_mean_correlation.append(mean_correlation)
    list_std_correlation.append(std_correlation)
    list_lags.append(lags)


In [None]:
# Min-max normalize every mean autocorrelation curve to [0, 1]; NaNs are ignored
# when locating the extremes. The source list itself is left unchanged.
list_mean_correlation_normalized = [
    (curve - np.nanmin(curve)) / (np.nanmax(curve) - np.nanmin(curve))
    for curve in list_mean_correlation
]


In [None]:
# Plot the min-max-normalized autocorrelation curve for each tested k_e against
# its own lag grid, marking the theoretical decorrelation time with a dashed
# vertical line and a circle at y = 0 in the matching color.
fig, ax = plt.subplots(figsize=(5, 3))
curve_colors = []
for i, mean_correlation in enumerate(list_mean_correlation_normalized):
    # Each k_e run has its own lag grid; do not reuse the last loop's `lags`.
    curve_lags = list_lags[i]
    # list_mean_correlation_normalized holds (curve - min) / (max - min), so the
    # error band is divided by the same range to stay on the normalized scale.
    raw_curve = list_mean_correlation[i]
    std_normalized = list_std_correlation[i] / (np.nanmax(raw_curve) - np.nanmin(raw_curve))
    (line,) = ax.plot(curve_lags, mean_correlation, label='$k_e$='+str(list_ke[i]), linewidth=2)
    curve_colors.append(line.get_color())
    ax.fill_between(curve_lags, mean_correlation - std_normalized, mean_correlation + std_normalized, alpha=0.2)
ax.set_xlabel(r'$\tau (s)$', fontsize=14)
ax.set_ylabel(r'$G(\tau)/G(0)$', fontdict={'fontsize': 14})
ax.legend()

# Theoretical decorrelation time as a dashed vertical line in the curve's color.
for theoretical_decorrelation, color in zip(list_theoretical_decorrelation, curve_colors):
    ax.axvline(x=theoretical_decorrelation, color=color, linestyle='--', lw=1)
# Circle at y = 0 marking the theoretical decorrelation time.
for theoretical_decorrelation, color in zip(list_theoretical_decorrelation, curve_colors):
    ax.plot(theoretical_decorrelation, 0, markersize=10, marker='o', color=color)

ax.set_xlim([-10, 1500])
for spine in ax.spines.values():
    spine.set_edgecolor('black')
    spine.set_linewidth(1.5)
plt.show()


In [None]:
# Intentional stop so that "Run All" halts here; run the following cells manually.
# (A bare `raise` outside an except block only produces the confusing
# "No active exception to reraise" RuntimeError.)
raise RuntimeError("Intentional stop: run the following cells manually.")

In [None]:
from scipy.stats import linregress

def estimate_decorrelation_time_normalized(lags, acf, fit_threshold=0.01):
    """
    Estimate the decorrelation time from a linear fit to the early decay of a normalized ACF.

    The ACF is rescaled so that lag 0 maps to 1 and the long-lag plateau (mean of the
    last 30 % of the points) maps to 0. A straight line is fitted from lag 0 up to, but
    excluding, the first point where the normalized ACF drops below ``fit_threshold``
    (always at least two points). The decorrelation time is the lag at which that line
    crosses zero.

    Parameters:
        lags (np.ndarray): Lag times, same length as ``acf``.
        acf (np.ndarray): Raw autocorrelation function values.
        fit_threshold (float): Normalized-ACF value below which the fitting region
                               ends. (Default: 0.01)

    Returns:
        tau_d (float or None): Estimated decorrelation time (lag where the fitted line
            reaches zero). None when it cannot be estimated: fewer than two usable
            points, acf[0] equal to the plateau (normalization undefined), or a fitted
            slope that is not negative.
        norm_acf (np.ndarray): Normalized ACF used for fitting (all NaN when the
            normalization is undefined).
        fit_idx (int): Exclusive end index of the fitting region.
    """
    lags = np.asarray(lags, dtype=float)
    acf = np.asarray(acf, dtype=float)
    n = len(acf)
    if n < 2:
        print("Warning: At least two ACF points are required to estimate the decorrelation time.")
        return None, np.full(n, np.nan), n

    # Baseline = long-lag plateau, estimated from the last 30% of the data.
    baseline = np.nanmean(acf[int(0.7 * n):])
    scale = acf[0] - baseline
    if not np.isfinite(scale) or scale == 0:
        # acf[0] equals the plateau (or is NaN): the 0..1 rescaling is undefined.
        print("Warning: ACF at lag 0 equals the long-lag baseline; cannot normalize the ACF.")
        return None, np.full(n, np.nan), 0

    # Normalize so that lag 0 is 1 and the plateau is 0.
    norm_acf = (acf - baseline) / scale

    # Fitting region: from lag 0 until norm_acf first drops below fit_threshold
    # (the whole curve if it never does).
    below = np.flatnonzero(norm_acf < fit_threshold)
    fit_idx = int(below[0]) if below.size else n
    # linregress needs at least two points; an early drop (index 1) would otherwise crash.
    fit_idx = max(fit_idx, 2)

    x_fit = lags[:fit_idx]
    y_fit = norm_acf[:fit_idx]
    # Ignore NaN samples so one missing value does not turn the whole fit into NaN.
    finite = np.isfinite(x_fit) & np.isfinite(y_fit)
    if np.count_nonzero(finite) < 2:
        print("Warning: Not enough finite ACF points in the fitting region.")
        return None, norm_acf, fit_idx

    slope, intercept, _r_value, _p_value, _std_err = linregress(x_fit[finite], y_fit[finite])

    # `not slope < 0` also rejects a NaN slope.
    if not slope < 0:
        print("Warning: The fitted slope is non-negative. Check your ACF or fitting region.")
        return None, norm_acf, fit_idx

    # Fitted line: y = intercept + slope * x; it reaches y = 0 at tau_d = -intercept / slope.
    tau_d = -intercept / slope
    return tau_d, norm_acf, fit_idx

In [None]:

# Estimate the initiation rate from G(0) = 1 / (ki * tau), i.e. ki = 1 / (G(0) * tau),
# where tau is the decorrelation time from estimate_decorrelation_time_normalized
# (linear fit to the early decay of the normalized ACF).
# NOTE(review): `lags` is the lag axis left over from the correlation loop above; confirm it
# matches the lag grid of list_mean_correlation_ki.
calculated_ki = []
max_x_range_lags = 150  # only the first 150 lags are used for the fit
for i, mean_correlation in enumerate(list_mean_correlation_ki):
    tau_estimated = estimate_decorrelation_time_normalized(lags[:max_x_range_lags], mean_correlation[:max_x_range_lags])[0]
    print("Estimated decorrelation time:", tau_estimated)
    G0 = mean_correlation[0]
    if tau_estimated is None or not tau_estimated > 0 or G0 == 0:
        # The fit failed or is unphysical: record NaN instead of raising on 1/(G0*None).
        ki = np.nan
    else:
        ki = 1/(G0*tau_estimated)
    calculated_ki.append(ki)
print(np.round(calculated_ki,3))
print(list_ki)

In [None]:
raise  # Intentional stop: a bare `raise` with no active exception raises RuntimeError, halting "Run All" here.

In [None]:
# # Initial conditions
ki = 0.057  # Initiation rate (1/s)
global_elongation_rate = 5.33  # Single global elongation rate (codons/s) used to build the per-codon ke
number_repetitions = 100  # Number of independent SSA trajectories
folding_delay = 0 #240  # Delay (s) before the second-probe signal appears; converted to codons below
added_folding_delay = 0 
burnin_time = 800  # Burn-in time passed to simulate_TASEP_SSA
timePerturbationApplication = None #5*60
t_max = 360*5 #timePerturbationApplication + 25*60  # Maximum simulated time
inhibitor_effectiveness=0.0
evaluatingInhibitor = 0


time_interval_between_frames_in_seconds = 1 # seconds
burnin = 500  # Number of initial time points dropped from the SSA intensity matrices before the correlation analysis
downsample_time = 5  # Keep every 5th time point
downsample_replicates = 1  # Keep every replicate
percentage_to_remove_data = 0  # Passed to simulate_missing_data; presumably the % of values replaced by NaN (0 = none)
shift_data = True  # Apply mi.Utilities().shift_trajectories before the correlation analysis
simulate_photobleacing = True  # (sic) name used by later cells: simulate photobleaching of the trajectories
correct_photobleaching = True  # Correct the trajectories for photobleaching afterwards
decay_rate = 0.9  # Fraction of intensity left after 100 original frames in the bleaching model (~10 % loss) — assumes exponential decay; TODO confirm


In [None]:
# Folding delay converted to a distance in codons: delay (s) × elongation rate (codons/s).
time_delay_to_codons = folding_delay * global_elongation_rate
time_delay_to_codons

In [None]:

file_path = pathlib.Path('/Users/nzlab-la/Desktop/Advanced_Microscopy/modeling/TASEP/pNZ208_pUB-24xUTagFullLength-KDM5B-MS2.dna')

# Read the plasmid sequence, locate the tag (probe) positions and draw the plasmid map.
# NOTE(review): `TAGS` is not defined in this part of the notebook; confirm it is provided by
# `from tasep_models import *` or an earlier cell.
protein, rna, dna, indexes_tags, _, seq_record, graphic_features  = read_sequence(seq=file_path, min_protein_length=50,TAG=TAGS)
plasmid_figure = plot_plasmid(seq_record, graphic_features,figure_width=25, figure_height=3)

gene_length = len(protein)+1 # adding 1 to account for the stop codon
tag_positions_first_probe_vector = indexes_tags[0]
tag_positions_second_probe_vector = indexes_tags[1] if len(indexes_tags) > 1 else None


if folding_delay > 0:
    # Model the folding delay by moving the second probe's tag positions downstream by the
    # number of codons translated during the delay.
    # NOTE(review): the shifted positions are floats when the delay is not a whole number of
    # codons; confirm create_probe_vector accepts non-integer positions.
    time_delay_to_codons = global_elongation_rate* folding_delay
    if tag_positions_second_probe_vector is not None:
        shifted_positions = [tag_position + time_delay_to_codons for tag_position in tag_positions_second_probe_vector]
        # Drop tags pushed beyond the end of the gene; no tags left means no second probe.
        shifted_positions = [tag_position for tag_position in shifted_positions if tag_position <= gene_length]
        tag_positions_second_probe_vector = shifted_positions or None
    print(tag_positions_second_probe_vector)

first_probe_position_vector = create_probe_vector(tag_positions_first_probe_vector, gene_length)
second_probe_position_vector = create_probe_vector(tag_positions_second_probe_vector, gene_length) if tag_positions_second_probe_vector is not None else None


In [None]:
# Filesystem-friendly plasmid name: the file name up to its first '.', with both
# parentheses replaced by underscores.
plasmid_name = file_path.name.split('.', 1)[0].translate(str.maketrans('()', '__'))
plasmid_name

In [None]:
# Elongation-rate vector computed from the mRNA sequence `rna` and the global elongation rate
# (presumably one entry per codon, since a single codon is indexed below; see
# calculate_codon_elongation_rates for the exact model).
ke = calculate_codon_elongation_rates (rna, global_elongation_rate=global_elongation_rate)

In [None]:
use_pause = False
if use_pause:
    # Add a slow (pause) codon 5 positions before the end of the ke vector.
    # The rate assigned is 1/(global_elongation_rate + 5), ≈0.097 for a global rate of 5.33,
    # i.e. roughly global_elongation_rate + 5 seconds at that codon if ke is in 1/s.
    # NOTE(review): confirm the intended pause strength (e.g. 1/60 for a 60 s pause).
    ke[-5] = 1/(global_elongation_rate+5)

## Deterministic modeling
____

In [None]:
# intensity_vector_first_signal_ode,intensity_vector_second_signal_ode = simulate_TASEP_ODE(ki, ke, gene_length, t_max, first_probe_position_vector,second_probe_position_vector,burnin_time)
# plt.plot(intensity_vector_first_signal_ode/np.max(intensity_vector_first_signal_ode))
# plt.plot(intensity_vector_second_signal_ode/np.max(intensity_vector_second_signal_ode))
# plt.show()


# Modeling TASEP SSA
____

In [None]:
# Stochastic (SSA) TASEP simulation of `number_repetitions` independent mRNAs.
# Intensities are sampled every `time_interval_between_frames_in_seconds` seconds, which is
# the frame interval the correlation cells assume, so the two cannot drift apart.
# NOTE(review): `constant_elongation_rate=global_elongation_rate` is passed instead of None;
# confirm whether this is meant to override the codon-dependent `ke`.
list_ribosome_trajectories, list_occupancy_output, matrix_intensity_first_signal_RT, matrix_intensity_second_signal_RT = simulate_TASEP_SSA(
    ki=ki,
    ke=ke,
    gene_length=gene_length,
    t_max=t_max,
    time_interval_in_seconds=time_interval_between_frames_in_seconds,
    number_repetitions=number_repetitions,
    first_probe_position_vector=first_probe_position_vector,
    second_probe_position_vector=second_probe_position_vector,
    constant_elongation_rate=global_elongation_rate,  # alternative: None
    fast_output=True,
    burnin_time=burnin_time,
)

# Runtime notes from earlier implementations:
# improved same version = 30 s
# improved version 2 = 30 s
# rethought version = 30 s
# rib position 37
# old 44 - 38

In [None]:
def _mean_and_sem(intensity_matrix):
    """Return the per-time-point mean and standard error of the mean across rows.

    The SEM uses the sample standard deviation (ddof=1) and the actual number of rows.
    With a single row the spread is undefined, so ddof=0 is used and the SEM is 0.
    """
    intensity_matrix = np.asarray(intensity_matrix, dtype=float)
    n_rows = intensity_matrix.shape[0]
    ddof = 1 if n_rows > 1 else 0
    mean = np.mean(intensity_matrix, axis=0)
    sem = np.std(intensity_matrix, axis=0, ddof=ddof) / np.sqrt(n_rows)
    return mean, sem


# Mean intensity and its SEM across the simulated repetitions, for each signal.
mean_first_signal_RT, sem_first_signal_RT = _mean_and_sem(matrix_intensity_first_signal_RT)
if second_probe_position_vector is not None:
    mean_second_signal_RT, sem_second_signal_RT = _mean_and_sem(matrix_intensity_second_signal_RT)

In [None]:
# Overlay one simulated trajectory per signal, each scaled by its own maximum.
selected_trajectory = 0
signals_to_plot = [('first signal', matrix_intensity_first_signal_RT)]
if second_probe_position_vector is not None:
    signals_to_plot.append(('second signal', matrix_intensity_second_signal_RT))

plt.figure()
for signal_label, signal_matrix in signals_to_plot:
    trace = signal_matrix[selected_trajectory, :]
    plt.plot(trace / np.max(trace), label=signal_label)
plt.legend()
plt.show()

In [None]:
# Overlay the first-signal intensity of every simulated repetition (semi-transparent).
plt.figure()
for row_index in range(number_repetitions):
    plt.plot(matrix_intensity_first_signal_RT[row_index], label='first signal', color='blue', alpha=0.1)
plt.show()

In [None]:
# Autocorrelation of the full-resolution SSA intensities of the first signal.
# The frame interval is the same variable used for sampling in simulate_TASEP_SSA, so the
# lag axis stays consistent if the sampling interval is changed.
mean_correlation_ssa, std_correlation_ssa, lags_ssa, correlations_array_ssa, dwell_time_ssa = mi.Correlation(
    primary_data=matrix_intensity_first_signal_RT,
    max_lag=None,
    nan_handling='forward_fill',  # alternatives: 'forward_fill', 'ignore'
    shift_data=True,
    return_full=False,
    time_interval_between_frames_in_seconds=time_interval_between_frames_in_seconds,
    use_bootstrap=True,
    show_plot=True,
    start_lag=0,
    fit_type='linear',
    index_max_lag_for_fit=300,
    de_correlation_threshold=0.01,
    correct_baseline=True,
    use_linear_projection_for_lag_0=True,
    save_plots=False,
    remove_outliers=False,
    plot_title=None,
).run()

In [None]:

def print_parameters(gene_length, tag_positions_first_probe_vector, g_0, dwell_time):
    """Print translation-kinetics estimates derived from an autocorrelation amplitude and dwell time.

    Computed quantities:

    * elongation rate      ke = gene_length / dwell_time, and the corrected
      ke = (gene_length - max(tag positions)) / dwell_time
    * initiation rate      ki = 1 / (g_0 * dwell_time)
    * ribosomal density    (gene_length / corrected ke) * ki
    * initiation spacing   1 / ki

    Every quantity is computed from unrounded intermediate values. Rounding is applied only
    to the printed numbers, so display precision does not propagate into later estimates.

    Parameters
    ----------
    gene_length : int
        Gene length (in codons in this notebook).
    tag_positions_first_probe_vector : sequence of int
        Tag positions of the first probe; only the maximum is used.
    g_0 : float
        Autocorrelation amplitude at lag 0; must be non-zero.
    dwell_time : float
        Dwell time in seconds; must be non-zero.

    Returns
    -------
    None
        The results are only printed.
    """
    separator = '------------------------------------'
    print(separator)
    last_tag_position = np.max(tag_positions_first_probe_vector)
    ke = gene_length / dwell_time
    # Corrected ke: only the stretch after the last tag contributes to the signal decay.
    ke_corrected = (gene_length - last_tag_position) / dwell_time
    print(separator)
    print('Elongation rates: ')
    print('Calculated ch0', np.round(ke, 2), ' Corrected: ', np.round(ke_corrected, 2))
    print(separator)
    print('Initiation rates: ')
    ki = 1 / (g_0 * dwell_time)
    print('Calculated', np.round(ki, 3), ' 1/sec')
    print(separator)
    print('Ribosomal density: ')
    # Average ribosome load = initiation rate × transit time (gene_length / corrected ke).
    ribosomal_density = (gene_length / ke_corrected) * ki if ke_corrected != 0 else np.nan
    print('Calculated', np.round(ribosomal_density, 1), ' average number of ribosomes per RNA')
    print(separator)
    print('Ribosomal occurrence: ')
    initiation_spacing = 1 / ki
    print('Calculated', np.round(initiation_spacing, 2), ' seconds between ribosome initiation')
    print(separator)
    return None

# Example with hard-coded inputs (G(0) = 0.08, dwell time = 279 s); these are not read from the
# correlation results computed above. NOTE(review): confirm where these values come from.
print_parameters(gene_length, tag_positions_first_probe_vector, g_0=0.08, dwell_time=279)

In [None]:
# Plot the SSA mean intensity ± SEM across repetitions and, when available, the ODE solution
# sampled every `downsample` time points.
downsample = 50
downsampled_time = np.arange(0,t_max,downsample)

# The ODE intensities come from the deterministic-model cell, which is currently commented out.
# Skip the ODE overlay when that cell has not been run instead of failing with NameError.
ode_first_signal = globals().get('intensity_vector_first_signal_ode')
ode_second_signal = globals().get('intensity_vector_second_signal_ode')


def _plot_ode_samples(ode_values, color, label):
    """Plot every `downsample`-th ODE value; x positions come from the ODE's own length."""
    ode_values = np.asarray(ode_values)
    sample_positions = np.arange(len(ode_values))[::downsample]
    plt.plot(sample_positions, ode_values[::downsample], color=color, linestyle='dashed', marker='o', label=label)


plt.figure(figsize=(5,4))
plt.plot(mean_first_signal_RT, color='k', linewidth=2, label='SSA first signal')
plt.fill_between(np.arange(len(mean_first_signal_RT)), mean_first_signal_RT-sem_first_signal_RT, mean_first_signal_RT+sem_first_signal_RT, color='k', alpha=0.2)
if ode_first_signal is not None:
    _plot_ode_samples(ode_first_signal, 'blue', 'ODE first signal')
if second_probe_position_vector is not None:
    plt.plot(mean_second_signal_RT, color='gray', linewidth=2, label='SSA second signal')
    plt.fill_between(np.arange(len(mean_second_signal_RT)), mean_second_signal_RT-sem_second_signal_RT, mean_second_signal_RT+sem_second_signal_RT, color='gray', alpha=0.2)
    if ode_second_signal is not None:
        _plot_ode_samples(ode_second_signal, 'red', 'ODE second signal')

plt.xlabel('Time')
plt.ylabel('Intensity')
plt.legend()
plt.show()

In [None]:
raise  # Intentional stop: a bare `raise` with no active exception raises RuntimeError, halting "Run All" here.

# Plot ribosome movement
___

In [None]:
# Pick one simulated repetition; its ribosome positions and intensities feed the
# ribosome-movement plot / GIF export below.
selected_trajectory = 0
ribosome_trajectories = list_ribosome_trajectories[selected_trajectory][:, :]
intensity_vector_first_signal = matrix_intensity_first_signal_RT[selected_trajectory, :]
intensity_vector_second_signal = (
    matrix_intensity_second_signal_RT[selected_trajectory, :]
    if second_probe_position_vector is not None
    else None
)
#plot_RibosomeMovement(ribosome_trajectories, intensity_vector_first_signal ,tag_positions_first_probe_vector,SecondIntensityVector=intensity_vector_second_signal,second_probePositions=tag_positions_second_probe_vector,timePerturbationApplication=timePerturbationApplication) # intensity_vector_second_signal

In [None]:
# The GIF file name encodes the plasmid, the rates (with '.' replaced by '_') and the
# inhibitor effectiveness.
str_ki = str(ki).replace('.', '_')
str_k = str(global_elongation_rate).replace('.', '_')
fileNameGif = f"simulation_{plasmid_name}_ke_{str_k}_ki_{str_ki}_inhibitor_effectiveness_{inhibitor_effectiveness!s}"
plot_RibosomeMovement_and_Microscope(
    ribosome_trajectories,
    intensity_vector_first_signal,
    tag_positions_first_probe_vector,
    SecondIntensityVector=intensity_vector_second_signal,
    second_probePositions=tag_positions_second_probe_vector,
    FrameVelocity=20,
    timePerturbationApplication=timePerturbationApplication,
    fileNameGif=fileNameGif,
)

In [None]:
# Cumulative probe count along the gene: every tag inside the gene adds 1 from its position
# to the end of the vector.
tag_positions = tag_positions_first_probe_vector
probe_vector = np.zeros(gene_length)
tags_inside_gene = [tag for tag in tag_positions if tag < gene_length]
for tag in tags_inside_gene:
    probe_vector[tag:] += 1

In [None]:
raise  # Intentional stop: a bare `raise` with no active exception raises RuntimeError, halting "Run All" here.

## Calculating Correlations
____

In [None]:
# importlib.reload(mi)

# Drop the first `burnin` time points, then keep every `downsample_replicates`-th repetition
# and every `downsample_time`-th time point.
matrix_intensity_first_signal_RT_downsampled = matrix_intensity_first_signal_RT[:,burnin:][::downsample_replicates,::downsample_time]
# Report the resulting size for single- and two-probe constructs alike.
print('number of replicates : ', matrix_intensity_first_signal_RT_downsampled.shape[0], '\nnumber of time points : ', matrix_intensity_first_signal_RT_downsampled.shape[1])
if second_probe_position_vector is not None:
    matrix_intensity_second_signal_RT_downsampled = matrix_intensity_second_signal_RT[:,burnin:][::downsample_replicates,::downsample_time]
    mi.Plots().plot_matrix_sample_time(matrix_intensity_first_signal_RT_downsampled, matrix_intensity_second_signal_RT_downsampled)
else:
    mi.Plots().plot_matrix_sample_time(matrix_intensity_first_signal_RT_downsampled)



In [None]:
if simulate_photobleacing:
    # Per-frame decay constant k = -ln(decay_rate) / (100 / downsample_time): a fraction
    # `decay_rate` of the intensity remains after 100 original frames, assuming
    # simulate_photobleaching_in_trajectories applies exp(-k * frame). TODO confirm.
    decay_rate_first_signal = -np.log(decay_rate) / (100/downsample_time)
    decay_rate_second_signal = -np.log(decay_rate) / (100/downsample_time)
    # Bleach the first signal in every case, so the correction cell that follows (which always
    # corrects the first signal) does not amplify an unbleached trace.
    matrix_intensity_first_signal_RT_downsampled = simulate_photobleaching_in_trajectories(matrix_intensity_first_signal_RT_downsampled, decay_rate_first_signal)
    if second_probe_position_vector is not None:
        matrix_intensity_second_signal_RT_downsampled = simulate_photobleaching_in_trajectories(matrix_intensity_second_signal_RT_downsampled, decay_rate_second_signal)
        mi.Plots().plot_matrix_sample_time(matrix_intensity_first_signal_RT_downsampled, matrix_intensity_second_signal_RT_downsampled)
    else:
        mi.Plots().plot_matrix_sample_time(matrix_intensity_first_signal_RT_downsampled)


In [None]:
if correct_photobleaching:
    # Undo photobleaching with the per-frame constant k = -ln(decay_rate) / (100 / downsample_time),
    # the same value for both signals.
    decay_rate_first_signal = -np.log(decay_rate) / (100/downsample_time)
    decay_rate_second_signal = decay_rate_first_signal
    matrix_intensity_first_signal_RT_downsampled = correct_photobleaching_in_trajectories(
        matrix_intensity_first_signal_RT_downsampled, decay_rate_first_signal)
    if second_probe_position_vector is None:
        mi.Plots().plot_matrix_sample_time(matrix_intensity_first_signal_RT_downsampled)
    else:
        matrix_intensity_second_signal_RT_downsampled = correct_photobleaching_in_trajectories(
            matrix_intensity_second_signal_RT_downsampled, decay_rate_second_signal)
        mi.Plots().plot_matrix_sample_time(
            matrix_intensity_first_signal_RT_downsampled, matrix_intensity_second_signal_RT_downsampled)


In [None]:
# Simulate missing observations: simulate_missing_data is called with
# `percentage_to_remove_data` (presumably the % of values to blank out) and replace_with='nan'.
if second_probe_position_vector is not None:
    (matrix_intensity_first_signal_RT_downsampled,
     matrix_intensity_second_signal_RT_downsampled) = simulate_missing_data(
        matrix_intensity_first_signal_RT_downsampled,
        matrix_intensity_second_signal_RT_downsampled,
        percentage_to_remove_data,
        replace_with='nan',
    )
    mi.Plots().plot_matrix_sample_time(matrix_intensity_first_signal_RT_downsampled, matrix_intensity_second_signal_RT_downsampled)
else:
    (matrix_intensity_first_signal_RT_downsampled,
     matrix_intensity_second_signal_RT_downsampled) = simulate_missing_data(
        matrix_intensity_first_signal_RT_downsampled,
        None,
        percentage_to_remove_data,
        replace_with='nan',
    )
    mi.Plots().plot_matrix_sample_time(matrix_intensity_first_signal_RT_downsampled)


In [None]:
# Shift the trajectories to the left with mi.Utilities().shift_trajectories.
# With two probes, both matrices are shifted in one joint call. Previously the first matrix was
# first shifted on its own and then shifted again together with the second, which can leave
# the channels misaligned (depending on shift_trajectories' semantics — TODO confirm).
# importlib.reload(mi)

if shift_data:
    if second_probe_position_vector is not None:
        matrix_intensity_first_signal_RT_downsampled, matrix_intensity_second_signal_RT_downsampled = mi.Utilities().shift_trajectories(matrix_intensity_first_signal_RT_downsampled, matrix_intensity_second_signal_RT_downsampled)
        mi.Plots().plot_matrix_sample_time(matrix_intensity_first_signal_RT_downsampled, matrix_intensity_second_signal_RT_downsampled)
    else:
        matrix_intensity_first_signal_RT_downsampled = mi.Utilities().shift_trajectories(matrix_intensity_first_signal_RT_downsampled)
        mi.Plots().plot_matrix_sample_time(matrix_intensity_first_signal_RT_downsampled)


In [None]:
# importlib.reload(mi)
# Autocorrelation of the processed first signal. The frame interval is the original interval
# multiplied by `downsample_time`, because only every `downsample_time`-th frame was kept.
mean_correlation_ch0, std_correlation_ch0, lags_ch0, correlations_array_ch0, dwell_time_ch0 = mi.Correlation(
    primary_data=matrix_intensity_first_signal_RT_downsampled,
    max_lag=None,
    nan_handling='ignore',
    shift_data=True,
    return_full=False,
    time_interval_between_frames_in_seconds=time_interval_between_frames_in_seconds*downsample_time,
    show_plot=True,
    start_lag=0,
    fit_type='linear',
    de_correlation_threshold=0.05,
).run()

In [None]:
# Plot the individual and mean autocorrelations of the first signal (save_plots=True with
# plot_name='temp_AC.png').
mi.Plots().plot_autocorrelation(
    correlations_array_ch0,
    mean_correlation_ch0,
    std_correlation_ch0,
    lags_ch0,
    5,  # NOTE(review): fifth positional argument; meaning defined by plot_autocorrelation
    plot_name='temp_AC.png',
    save_plots=True,
)

In [None]:
# Autocorrelation of the second signal (two-probe constructs only).
# NOTE(review): de_correlation_threshold is 0.001 here but 0.05 for the first signal; confirm.
if second_probe_position_vector is not None:
    (mean_correlation_ch1, std_correlation_ch1, lags_ch1,
     correlations_array_ch1, dwell_time_ch1) = mi.Correlation(
        primary_data=matrix_intensity_second_signal_RT_downsampled,
        max_lag=None,
        nan_handling='ignore',
        shift_data=True,
        return_full=False,
        time_interval_between_frames_in_seconds=time_interval_between_frames_in_seconds*downsample_time,
        show_plot=True,
        start_lag=0,
        fit_type='linear',
        de_correlation_threshold=0.001,
    ).run()

In [None]:
# importlib.reload(mi)
# Cross-correlation between the first and second signals (two-probe constructs only).
if second_probe_position_vector is not None:
    (mean_cross_correlation, std_cross_correlation, lags_cross_correlation,
     cross_correlations_array, delay_cross_correlation) = mi.Correlation(
        primary_data=matrix_intensity_first_signal_RT_downsampled,
        secondary_data=matrix_intensity_second_signal_RT_downsampled,
        max_lag=None,
        nan_handling='ignore',
        shift_data=True,
        return_full=True,
        time_interval_between_frames_in_seconds=time_interval_between_frames_in_seconds*downsample_time,
        show_plot=True,
    ).run()

In [None]:
if second_probe_position_vector is not None:
    # Derivative dG/dτ of the mean cross-correlation. np.gradient uses central differences
    # scaled by the actual lag spacing, so the result is a true derivative (not a per-sample
    # difference) and has the same length as, and is aligned with, lags_cross_correlation.
    first_derivative = np.gradient(mean_cross_correlation, lags_cross_correlation)
    plt.figure(figsize=(5,3))
    plt.plot(lags_cross_correlation, first_derivative, label='first derivative', color='r', linewidth=4)
    plt.xlim(-500,500)

    plt.axhline(y=0, color='k', linestyle='--', linewidth=0.5)
    plt.axvline(x=0, color='k', linestyle='--', linewidth=0.5)
    plt.ylabel('dG/dt')
    # Lags are in seconds: the correlation was computed with a frame interval in seconds.
    plt.xlabel(r'$\tau$'+ ' (s)')

    # Reference line at the simulated folding delay.
    plt.axvline(x=-folding_delay, color='k', linestyle='--', linewidth=1)

    plt.show()



In [None]:
# Compare the parameters recovered from the correlation analysis with the simulation inputs.
separator = '------------------------------------'

print(separator)
print('Decorrelation times: ')
last_tag_first_probe = np.max(tag_positions_first_probe_vector)
estimated_decorrelation_time = np.round((gene_length - last_tag_first_probe // 2) / global_elongation_rate, 2)
print('Estimated decorrelaiton time', estimated_decorrelation_time)

# Elongation rate = gene length / dwell time; the corrected form only counts codons after the last tag.
ke_calculated_ch0 = np.round(gene_length / dwell_time_ch0, 2)
ke_calculated_ch0_corrected = np.round((gene_length - last_tag_first_probe) / dwell_time_ch0, 2)
print(separator)
print('Elongation rates: ')
print('Calculated ch0', ke_calculated_ch0, ' Corrected: ', ke_calculated_ch0_corrected, ', True: ', global_elongation_rate)
if second_probe_position_vector is not None:
    ke_calculated_ch1 = np.round(gene_length / dwell_time_ch1, 2)
    ke_calculated_ch1_corrected = np.round((gene_length - np.max(tag_positions_second_probe_vector)) / dwell_time_ch1, 2)
    print('Calculated ch1', ke_calculated_ch1, ', True: ', global_elongation_rate)

# Initiation rate from G(0) = 1 / (ki * dwell_time).
print(separator)
print('Initiation rates: ')
ki_calculated = np.round(1 / (mean_correlation_ch0[0] * dwell_time_ch0), 3)
print('Calculated', ki_calculated, ', True: ', ki)

# Occupancy: count of positive entries of the selected trajectory along axis 0, averaged over axis 1.
print(separator)
print('Ribosomal occupancy: ')
binary_matrix_ribosomal_occupancy = ribosome_trajectories > 0
ribosomal_occupancy = binary_matrix_ribosomal_occupancy.sum(axis=0)
theoretical_occupancy = np.round(gene_length / global_elongation_rate * ki, 2)
print('Calculated: ', np.round(ribosomal_occupancy.mean(), 2), ', Theoretical: ', theoretical_occupancy)
print(separator)

if second_probe_position_vector is not None:
    print('Folding delay: ')
    print('Calculated:', delay_cross_correlation, ', True: ', folding_delay)
print(separator)


In [None]:
def detect_minima(array, threshold_percentage, smoothing_window=10):
    """Find local minima in every row of ``array``.

    Each row is smoothed with a ``smoothing_window``-point moving average
    (``np.convolve(..., mode='same')``), min-max scaled to 0-100 and inverted.
    ``scipy.signal.find_peaks`` then keeps the local minima whose scaled value is at most
    ``threshold_percentage`` % of the scaled maximum.

    Parameters
    ----------
    array : 2-D array-like
        One intensity trace per row.
    threshold_percentage : float
        Upper bound, in % of the scaled range, for a minimum to be reported.
    smoothing_window : int, optional
        Moving-average length (default 10). Values below 2 disable smoothing.

    Returns
    -------
    list of np.ndarray
        Integer indices of the detected minima, one array per row. A row whose smoothed
        trace is constant (or contains NaN) yields an empty array, because the 0-100
        scaling is undefined for it.
    """
    # Imported here so the function does not rely on a notebook-level import.
    from scipy.signal import find_peaks

    minima_indices = []
    for sample in array:
        if smoothing_window is not None and smoothing_window > 1:
            sample = np.convolve(sample, np.ones(smoothing_window) / smoothing_window, mode='same')
        value_range = np.max(sample) - np.min(sample)
        if not value_range > 0:
            minima_indices.append(np.array([], dtype=int))
            continue
        # Scale the trace to 0..100.
        sample_normalized = (sample - np.min(sample)) / value_range * 100
        threshold = (threshold_percentage / 100.0) * np.max(sample_normalized)
        # Minima of the trace are peaks of the inverted trace; height >= -threshold keeps
        # only minima whose scaled value is <= threshold.
        peaks, _properties = find_peaks(-sample_normalized, height=-threshold)
        minima_indices.append(peaks)
    return minima_indices

def extract_windows(array, indices_list, window_size):
    """Cut windows of ±``window_size`` samples around each given index, per row.

    Parameters
    ----------
    array : 2-D np.ndarray
        One trace per row.
    indices_list : sequence of sequences of int
        For each row, the window centres.
    window_size : int
        Half-width of the window; each window has ``2 * window_size + 1`` samples.

    Returns
    -------
    list of np.ndarray
        Per row, a 2-D array stacking the windows that fit entirely inside the row, or an
        empty array when none do.
    """
    windows = []
    for row, centres in enumerate(indices_list):
        segments = [
            array[row, centre - window_size:centre + window_size + 1]
            for centre in centres
            if centre - window_size >= 0 and centre + window_size + 1 <= array.shape[1]
        ]
        windows.append(np.array(segments) if segments else np.array([]))
    return windows

def calculate_average_profiles(windows):
    """Average each sample's stack of windows over the window axis.

    Parameters
    ----------
    windows : list of np.ndarray
        Per-sample arrays of shape (n_windows, window_length), or empty arrays.

    Returns
    -------
    list
        Per sample, the mean profile (``np.mean(..., axis=0)``), or ``None`` when the sample
        has no windows.
    """
    return [None if stack.size == 0 else np.mean(stack, axis=0) for stack in windows]

def generate_control_indices(array_shape, n_control_points, window_size, rng=None):
    """Draw random window centres per sample for the control (baseline) profiles.

    Centres are drawn without replacement from the indices where a full
    ±``window_size`` window fits inside the time axis.

    Parameters
    ----------
    array_shape : tuple
        ``(num_samples, num_timepoints, ...)``; only the first two entries are used.
    n_control_points : int
        Number of centres per sample (capped at the number of admissible indices).
    window_size : int
        Half-width of the windows that will be cut around the centres.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness. Defaults to the global ``np.random`` state (the original
        behaviour); pass e.g. ``np.random.default_rng(seed)`` for reproducible controls.

    Returns
    -------
    list of np.ndarray
        One array of centre indices per sample; empty arrays when the series is too short
        to hold a single full window.
    """
    sampler = np.random if rng is None else rng
    num_samples = array_shape[0]
    num_timepoints = array_shape[1]
    min_index = window_size
    max_index = num_timepoints - window_size - 1
    # The admissible range does not depend on the sample, so check and build it once.
    if max_index < min_index:
        return [np.array([], dtype=int) for _ in range(num_samples)]
    possible_indices = np.arange(min_index, max_index + 1)
    n_sample = min(n_control_points, len(possible_indices))
    return [sampler.choice(possible_indices, size=n_sample, replace=False) for _ in range(num_samples)]

def plot_results(avg_profiles_1, avg_profiles_2, control_profiles_1, control_profiles_2, window_size):
    """Plot the sample-averaged signal-2 profile around signal-1 minima, with its control.

    Per-sample profiles that are ``None`` (samples without windows) are skipped, and the
    remaining ones are averaged. Only signal 2 and its control are drawn; the signal-1
    lists are used only to check that data is available. A dashed line marks time 0, and
    another marks the notebook-level ``folding_delay``. A red line marks the minimum of the
    averaged signal-2 profile, and its position is printed.

    Parameters
    ----------
    avg_profiles_1, avg_profiles_2 : list of (np.ndarray or None)
        Per-sample mean profiles around the detected minima, for signals 1 and 2.
    control_profiles_1, control_profiles_2 : list of (np.ndarray or None)
        Per-sample mean profiles around random control positions.
    window_size : int
        Half-width of the windows; profiles have ``2 * window_size + 1`` samples.
    """
    time_points = np.arange(-window_size, window_size + 1)

    def _valid(profiles):
        # Drop samples that had no windows.
        return [profile for profile in profiles if profile is not None]

    valid_avg_profiles_1 = _valid(avg_profiles_1)
    valid_avg_profiles_2 = _valid(avg_profiles_2)
    valid_control_profiles_1 = _valid(control_profiles_1)
    valid_control_profiles_2 = _valid(control_profiles_2)
    # Every list must be non-empty; np.vstack of an empty list would raise.
    if not (valid_avg_profiles_1 and valid_avg_profiles_2 and valid_control_profiles_1 and valid_control_profiles_2):
        print("No valid profiles to plot.")
        return

    # Average over samples.
    overall_avg_profile_2 = np.mean(np.vstack(valid_avg_profiles_2), axis=0)
    overall_control_profile_2 = np.mean(np.vstack(valid_control_profiles_2), axis=0)

    plt.figure(figsize=(7, 5))
    plt.plot(time_points, overall_avg_profile_2, label='Signal 2', color='red', linewidth=4)
    plt.plot(time_points, overall_control_profile_2, label='Control Signal 2', linestyle=':', color='red', alpha=0.5)
    # Reference lines at time 0 and at the simulated folding delay.
    plt.axvline(x=0, color='black', linestyle='--', linewidth=1)
    plt.axvline(x=folding_delay, color='k', linestyle='--')
    # Mark where the averaged signal-2 profile reaches its minimum.
    min_index = np.argmin(overall_avg_profile_2)
    plt.axvline(x=time_points[min_index], color='red', linestyle='--', linewidth=1)
    print('Intensity surronding minima, Delay: ', time_points[min_index])
    plt.title('Intensity Profiles')
    plt.xlabel('Time Relative to Minima')
    plt.ylabel('Intensity')
    plt.legend(loc='upper right')  # inside the axes, top-right corner
    plt.grid(True)
    plt.show()

def analyze_delay(array_1, array_2, threshold_percentage, window_size, n_control_points):
    """Minima-triggered averaging of two signals against a random-position control.

    Minima are detected in ``array_1`` only. Windows of ±``window_size`` samples around
    those minima are cut from both arrays and averaged per sample. The same is done around
    ``n_control_points`` random centres per sample to build the control, and everything
    is handed to ``plot_results``.
    """
    def profiles_around(array, centres):
        # Per-sample mean profile of the windows centred on `centres`.
        return calculate_average_profiles(extract_windows(array, centres, window_size))

    minima_centres = detect_minima(array_1, threshold_percentage)
    minima_profiles = [profiles_around(array, minima_centres) for array in (array_1, array_2)]

    control_centres = generate_control_indices(array_1.shape, n_control_points, window_size)
    control_profiles = [profiles_around(array, control_centres) for array in (array_1, array_2)]

    plot_results(*minima_profiles, *control_profiles, window_size)


In [None]:
# Parameters for the minima-triggered averaging analysis
threshold_percentage = 20  # keep minima at most 20 % of the scaled (0-100) trace range
window_size = 300            # Number of values before and after the minima (full-resolution frames)
n_control_points = 10      # Number of random positions per sample for control data

# The analysis compares the first signal with the second one, so it needs two probes.
if second_probe_position_vector is not None:
    analyze_delay(matrix_intensity_first_signal_RT, matrix_intensity_second_signal_RT, threshold_percentage, window_size, n_control_points)
    # The downsampled matrices keep only every `downsample_time`-th frame. Scale the window so
    # it spans the same time; otherwise no window fits in the shorter traces.
    downsampled_window_size = max(1, window_size // downsample_time)
    analyze_delay(matrix_intensity_first_signal_RT_downsampled, matrix_intensity_second_signal_RT_downsampled, threshold_percentage, downsampled_window_size, n_control_points)
else:
    print('Delay analysis skipped: it requires a second probe.')
