# Results for a single cell in a bath of delta

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from types import SimpleNamespace
from matplotlib.animation import FuncAnimation
from matplotlib.colors import LogNorm
from scipy.signal import find_peaks, correlate
from IPython.display import HTML
from tqdm import tqdm
import pickle
import matplotlib.cm as cm
import visualisation as vis
from delta_hes_model import *

In [None]:
def evaluate_fixed_delta(num_tsteps, dt, external_delta, lattice, params):
    """
    Integrate the Hes/Delta model on `lattice` while the coupling input is held fixed.

    Instead of reading Delta from neighbouring cells, the coupling term handed to
    `dmh_dt` is the constant `external_delta / w_coupling`.

    Parameters:
    - num_tsteps: number of time steps to store (including the initial state)
    - dt: step size passed to `Euler`
    - external_delta: fixed external Delta level used to build the coupling term
    - lattice: lattice object; `lattice.P` and `lattice.Q` set the array shape
    - params: parameter namespace; `params.T_h` is read and `params.T_h_steps`
      is written (delay expressed in number of steps)

    Returns:
    - h, m_h, d, m_d: arrays of shape (num_tsteps, lattice.P, lattice.Q)
    """
    # Only the 'uniform' initial condition is used; an earlier 'checkerboard'
    # call whose result was overwritten immediately has been dropped.
    h_init, m_h_init, d_init, m_d_init = get_initial(lattice, params, initial_type = 'uniform')

    # intrinsic oscillator components
    shape = (num_tsteps, lattice.P, lattice.Q)
    h = np.zeros(shape)
    m_h = np.zeros(shape)
    d = np.zeros(shape)
    m_d = np.zeros(shape)

    h[0] = h_init
    m_h[0] = m_h_init
    d[0] = d_init
    m_d[0] = m_d_init

    # Loop invariants: the delay in steps and the coupling term do not depend on i.
    params.T_h_steps = np.round(params.T_h / dt).astype(int)
    # NOTE(review): `w_coupling` is read from the enclosing (notebook-global) scope,
    # not from `params` — confirm this is intended.
    couple_component = external_delta/w_coupling

    # iterate through the time steps and calculate the values of the next time step
    for i in tqdm(range(int(num_tsteps-1))):

        # calculate delayed values of Hes
        h_delay = get_delayed_value(h, i, params.T_h_steps, lattice, params)

        # Delta and Delta mRNA
        d[i+1,:,:] = Euler(d[i,:,:], dd_dt(d[i,:,:], m_d[i,:,:], params), dt)
        m_d[i+1,:,:] = Euler(m_d[i,:,:], dmd_dt(m_d[i,:,:], h_delay[i,:,:], params), dt)

        # Hes and Hes mRNA
        h[i+1,:,:] = Euler(h[i,:,:], dh_dt(h[i,:,:], m_h[i,:,:], params), dt)
        m_h[i+1,:,:] = Euler(m_h[i,:,:], dmh_dt(m_h[i,:,:], h_delay[i,:,:], couple_component, params, lattice), dt)

    return h, m_h, d, m_d

# def estimate_amplitude_from_peaks(signal, height=None, distance=None, prominence=None):
#     """
#     Estimate amplitude by detecting peaks and computing average peak height.

#     Parameters:
#     - signal: 1D numpy array
#     - height, distance, prominence: Optional arguments for peak detection

#     Returns:
#     - amplitude: Estimated amplitude
#     - peak_values: Values of detected peaks
#     - trough_values: Values of detected troughs
#     """
#     peaks, _ = find_peaks(signal, height=height, distance=distance, prominence=prominence)
#     troughs, _ = find_peaks(-signal, height=height, distance=distance, prominence=prominence)

#     peak_values = signal[peaks]
#     trough_values = signal[troughs]

#     if len(peak_values) == 0 or len(trough_values) == 0:
#         return None, peak_values, trough_values  # Not enough data

#     avg_peak = np.mean(peak_values)
#     avg_trough = np.mean(trough_values)
 
#     amplitude = 0.5 * (avg_peak - avg_trough)
#     amplitude = np.nan_to_num(amplitude, nan=0.0)  # Handle NaN values
#     return amplitude, peak_values, trough_values

def estimate_amplitude_from_peaks(signal, height=None, distance=None, prominence=None, ignore_initial_outlier=True):
    """
    Estimate the amplitude of an oscillating signal using peak and trough detection.

    The amplitude is half the mean difference between the retained peak values and
    the retained trough values, using equally many of each (the first `min_len`).

    Parameters:
    - signal: 1D array-like of the signal values (numpy arrays and plain lists both work)
    - height, distance, prominence: Optional arguments passed to find_peaks, both for
      the peak search on `signal` and for the trough search on `-signal`
    - ignore_initial_outlier: If True, the first peak is dropped when it exceeds
      1.5x the median of the remaining peaks, and the first trough is dropped when
      it is below 3x the median of the remaining troughs

    Returns:
    - amplitude: Estimated average amplitude (float), or None if no peak or no trough is available
    - peak_indices: Indices of the used peaks
    - trough_indices: Indices of the used troughs
    """
    # Coerce once so that lists/tuples support negation and fancy indexing below.
    signal = np.asarray(signal, dtype=float)

    peak_indices, _ = find_peaks(signal, height=height, distance=distance, prominence=prominence)
    trough_indices, _ = find_peaks(-signal, height=height, distance=distance, prominence=prominence)

    if len(peak_indices) < 1 or len(trough_indices) < 1:
        return None, peak_indices, trough_indices

    peak_vals = signal[peak_indices]
    trough_vals = signal[trough_indices]

    if ignore_initial_outlier:
        # Optionally remove large initial outlier peak
        if len(peak_vals) > 2:
            med = np.median(peak_vals[1:])  # Skip first
            if peak_vals[0] > 1.5 * med:
                peak_indices = peak_indices[1:]
                peak_vals = peak_vals[1:]

        if len(trough_vals) > 2:
            med = np.median(trough_vals[1:])
            # NOTE(review): for a positive signal this condition holds for almost any
            # first trough, so the first trough is nearly always dropped. Behaviour is
            # kept as-is; confirm whether a stricter test was intended.
            if trough_vals[0] < 3 * med:
                trough_indices = trough_indices[1:]
                trough_vals = trough_vals[1:]

    # Use the same number of peaks and troughs.
    min_len = min(len(peak_vals), len(trough_vals))
    if min_len < 1:
        return None, peak_indices, trough_indices

    amplitude = 0.5 * np.mean(peak_vals[:min_len] - trough_vals[:min_len])

    return amplitude, peak_indices[:min_len], trough_indices[:min_len]



# def estimate_period_from_peaks(signal, time=None, height=None, distance=None, prominence=None):
#     """
#     Estimate the period of an oscillating signal using peak detection.

#     Parameters:
#     - signal: 1D numpy array of the signal values
#     - time: Optional 1D array of time values (same length as signal). If None, assume uniform time steps.
#     - height, distance, prominence: Optional arguments passed to find_peaks for filtering.

#     Returns:
#     - period: Estimated average period (float)
#     - peak_times: Time values of the detected peaks
#     """
#     # Find peaks
#     peaks, _ = find_peaks(signal, height=height, distance=distance, prominence=prominence)
#     if time is None:
#         # Assume uniform spacing
#         time = np.arange(len(signal))

#     peak_times = time[peaks]

#     if len(peak_times) < 4:
#         period = -1000  # Not enough peaks to estimate period

#     # Calculate differences between consecutive peaks
#     peak_diffs = np.diff(peak_times)
#     period = np.mean(peak_diffs)

#     return period, peak_times

def estimate_period_from_peaks(signal, time=None, height=None, distance=None, prominence=None, ignore_initial_outlier=True):
    """
    Estimate the period of an oscillating signal using peak detection.

    The period is the mean spacing between consecutive retained peaks.

    Parameters:
    - signal: 1D array-like of the signal values (numpy arrays and plain lists both work)
    - time: Optional 1D array-like of time values (same length as signal). If None, assume
      uniform time steps (sample indices are used as times).
    - height, distance, prominence: Optional arguments passed to find_peaks
    - ignore_initial_outlier: If True, ignores the first peak if it is more than 3x the
      median of the remaining peaks

    Returns:
    - period: Estimated average period (float), or None if fewer than two peaks are available
    - peak_times: Time values of the peaks used (an empty array if fewer than two peaks were detected)
    """
    # Coerce once so that lists/tuples support fancy indexing below.
    signal = np.asarray(signal, dtype=float)

    # Find peaks
    peaks, _ = find_peaks(signal, height=height, distance=distance, prominence=prominence)

    if len(peaks) < 2:
        return None, np.array([])  # Not enough peaks to estimate period

    # Optionally remove initial outlier peak
    if ignore_initial_outlier and len(peaks) > 2:
        peak_vals = signal[peaks]
        if peak_vals[0] > 3 * np.median(peak_vals[1:]):
            peaks = peaks[1:]

    # Assume uniform spacing if time is not given
    if time is None:
        time = np.arange(len(signal))
    else:
        time = np.asarray(time)

    peak_times = time[peaks]

    if len(peak_times) < 2:
        return None, peak_times  # Still not enough peaks

    # Average difference between consecutive peaks
    period = np.mean(np.diff(peak_times))

    return period, peak_times

In [None]:
# Lattice of P x Q cells
P, Q = 3, 1
lattice = get_lattice(P, Q)

# Reaction parameters handed to get_params
# (meaning of each symbol is defined in delta_hes_model — verify there)
gamma_h = gamma_d = gamma_m = 0.03
p_h, p_d = 100, 50
T_h, T_coupling = 20, 0
w_h, w_coupling = 1, 0.5
l, n = 5, 3

params = get_params(
    gamma_h, gamma_d, gamma_m,
    p_h, p_d,
    T_h, T_coupling,
    w_h, w_coupling,
    l, n,
    lattice,
    grad_hes = False,
    grad_coup = False,
    grad_hes_strength = 0.2,
    grad_coup_strength = 0,
)
print(params)

# Time discretisation: number of stored steps and step size
num_tsteps, dt = 20000, 0.2

In [None]:
# Sweep over fixed external Delta levels and keep, for each level, the traces of the
# cell at first lattice index 1: channel 0 holds h, channel 1 holds d.
delta_fixed = np.linspace(0, 0.1, 26)
results_fixed = np.zeros((delta_fixed.size, num_tsteps, 2))

for i, delta in enumerate(delta_fixed):
    h, m_h, d, m_d = evaluate_fixed_delta(num_tsteps, dt, delta, lattice, params)

    print(i, delta)

    # One column per species, written in a single assignment
    results_fixed[i] = np.column_stack((h[:, 1].ravel(), d[:, 1].ravel()))


In [None]:
# Index of the first stored time step used for the statistics; earlier steps are discarded.
start_time = 10000

# Amplitudes at or below this value are treated as non-oscillating: no period is estimated.
min_amplitude = 0.5

# Time axis of the analysed window; it is the same for every delta value.
time_window = np.arange(num_tsteps - start_time) * dt

h_mean = np.mean(results_fixed[:,start_time:,0], axis=1)
d_mean = np.mean(results_fixed[:,start_time:,1], axis=1)

amplitude_h = np.zeros(len(delta_fixed))
amplitude_d = np.zeros(len(delta_fixed))
# Periods stay NaN wherever no oscillation is detected.
period_h = np.full(len(delta_fixed), np.nan)
period_d = np.full(len(delta_fixed), np.nan)

# Channel 0 of results_fixed holds h, channel 1 holds d.
for channel, amplitude, period in ((0, amplitude_h, period_h), (1, amplitude_d, period_d)):
    for i in range(len(delta_fixed)):
        trace = results_fixed[i, start_time:, channel]

        # The estimator returns None when no peak/trough pair exists; store that as 0.
        amp = estimate_amplitude_from_peaks(trace)[0]
        amplitude[i] = 0.0 if amp is None else np.nan_to_num(amp, nan=0.0)

        if amplitude[i] > min_amplitude:
            est = estimate_period_from_peaks(trace, time=time_window)[0]
            if est is not None:
                period[i] = est

print(amplitude_h.shape)
print(amplitude_d)

# (values, label, restrict x-axis to [0, max(delta_fixed)]) for each panel, in subplot order
panels = [
    (h_mean, 'h mean', False),
    (d_mean, 'd mean', False),
    (period_h, 'h period', True),
    (period_d, 'd period', True),
    (amplitude_h, 'h amplitude', False),
    (amplitude_d, 'd amplitude', False),
]

plt.figure(figsize=(10, 10))
for position, (values, name, fix_xlim) in enumerate(panels, start=1):
    plt.subplot(3, 2, position)
    plt.plot(delta_fixed, values, label=name)
    plt.xlabel('delta fixed')
    if fix_xlim:
        plt.xlim(0, max(delta_fixed))
    plt.ylabel(name)
    plt.legend()
plt.tight_layout()