In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit

In [None]:
# -------------------------
# 1. Load the Data
# -------------------------
# Assumes a headerless CSV laid out as (time points x cells), e.g. shape (1110, 35).
file_name = "sgRosa26_30uM_DHPG_puff_processed.csv"  # or use the actual file name
group_name = file_name.split("_")[0]  # This extracts "sgRosa26"
data_in = pd.read_csv(file_name, header=None)

# Non-finite samples (NaN / +-inf) are handled per *cell* (column). Dropping rows
# instead would delete time points: every later sample would be assigned the wrong
# time, and the fixed puff index used in section 5 would point at the wrong frame.
data_finite = data_in.replace([np.inf, -np.inf], np.nan)
bad_cells = data_finite.columns[data_finite.isna().any(axis=0)]
if len(bad_cells) > 0:
    print(f"Dropping {len(bad_cells)} cell(s) with non-finite values: {list(bad_cells)}")
data = data_finite.drop(columns=bad_cells)

# Force float so downstream arithmetic (fits, z-scores) never inherits an int dtype.
raw = data.to_numpy(dtype=float)  # shape (time points, cells)
n_points, n_cells = raw.shape
if n_cells == 0:
    raise ValueError(f"No cells with complete, finite traces in {file_name}")

fs = 5.0  # Sampling rate in Hz
time = np.arange(n_points) / fs  # time axis in seconds; valid because no rows were removed

In [None]:
# -------------------------
# 2. Plot the Raw Signals: Mean ± SEM
# -------------------------
# Across-cell summary of the raw traces: the mean at every time point, shaded
# by +/- one standard error of the mean (sample SD / sqrt(number of cells)).
mean_raw = raw.mean(axis=1)
sem_raw = raw.std(axis=1, ddof=1) / np.sqrt(n_cells)
band_low, band_high = mean_raw - sem_raw, mean_raw + sem_raw

fig, ax = plt.subplots(figsize=(10, 6))
ax.fill_between(time, band_low, band_high, alpha=0.3, label='SEM')
ax.plot(time, mean_raw, 'k-', label='Mean Raw Signal')
ax.set_xlabel('Time (s)')
ax.set_ylabel('Calcium Fluorescence')
ax.set_title('Raw Calcium Fluorescence Signal (Mean ± SEM)')
ax.legend()
plt.show()

In [None]:
# -------------------------
# 3. Fit a One‐Phase Decay to the Entire Trace for Each Cell
# -------------------------
def one_phase_decay(t, A, k, C):
    """One-phase exponential decay, ``A * exp(-k * t) + C``.

    Parameters
    ----------
    t : array_like
        Time points (seconds in this notebook).
    A : float
        Amplitude of the decaying component (value at t = 0 minus the plateau).
    k : float
        Rate constant in 1/(units of t); positive decays, negative grows.
    C : float
        Plateau approached as t grows (for k > 0).

    Returns
    -------
    ndarray
        The curve evaluated at every element of ``t``.
    """
    return A * np.exp(-k * t) + C


# NOTE(review): the whole trace, including the post-puff response, is used for the
# fit, so the exponential can absorb part of the evoked signal and bias the
# corrected trace. Fitting only pre-puff samples may be more appropriate - confirm.

# Number of samples summarised at each end of a trace for the initial guess. A
# single endpoint sample is noisy and can steer the optimiser to a poor minimum.
n_edge = max(1, min(10, n_points // 4))

# Per-cell results. Entries stay NaN for cells whose fit fails, so a failed cell
# can never be mistaken for a fitted one downstream.
params = np.full((n_cells, 3), np.nan)           # Each row: [A, k, C]
fitted_curves = np.full((n_points, n_cells), np.nan)
failed_cells = []

for i in range(n_cells):
    y_data = raw[:, i]
    # Initial guess: amplitude = start level - end level, a slow decay rate
    # (adjust if needed), plateau = end level. Medians resist single-frame spikes.
    y_start = np.median(y_data[:n_edge])
    y_end = np.median(y_data[-n_edge:])
    A_guess = y_start - y_end
    k_guess = 0.01
    C_guess = y_end
    p0 = [A_guess, k_guess, C_guess]

    try:
        popt, _ = curve_fit(one_phase_decay, time, y_data, p0=p0, maxfev=10000)
    except RuntimeError:
        # No convergence within maxfev evaluations.
        print(f"Fit did not converge for cell {i}")
        failed_cells.append(i)
        continue
    except ValueError as err:
        # e.g. non-finite values in the trace.
        print(f"Fit failed for cell {i}: {err}")
        failed_cells.append(i)
        continue

    params[i, :] = popt
    # Evaluate the fitted curve at every time point.
    fitted_curves[:, i] = one_phase_decay(time, *popt)

In [None]:
# -------------------------
# 4. Subtract the Fitted Decay from the Raw Signal
# -------------------------
# Detrend each cell by removing its fitted decay. A cell whose fit failed has an
# all-NaN fitted curve, so its corrected trace is all-NaN as well.
corrected_signal = raw - fitted_curves

# Mean ± SEM of the fitted decay curves across cells. NaN-aware reductions keep a
# single failed fit from turning the whole average into NaN, and the SEM divides
# by the number of cells that actually contribute at each time point.
n_valid_fit = np.sum(np.isfinite(fitted_curves), axis=1)
mean_fit = np.nanmean(fitted_curves, axis=1)
sem_fit = np.nanstd(fitted_curves, axis=1, ddof=1) / np.sqrt(n_valid_fit)

fig, ax = plt.subplots(figsize=(10, 6))
ax.fill_between(time, mean_fit - sem_fit, mean_fit + sem_fit, alpha=0.3, label='SEM')
ax.plot(time, mean_fit, 'r-', label='Mean Fitted Decay')
ax.set_xlabel('Time (s)')
ax.set_ylabel('Fitted Decay')
ax.set_title('Fitted One‐Phase Decay (Mean ± SEM)')
ax.legend()
plt.show()

In [None]:
# -------------------------
# 5. Calculate Z‐Scores (Baseline: 10 s Before Puff)
# -------------------------
# Sample index of the puff onset (t = puff_idx / fs = 101 s at 5 Hz).
# NOTE(review): an older comment here said index 500 - confirm against the acquisition log.
puff_idx = 505
baseline_s = 10.0                                  # baseline window length in seconds
baseline_start = puff_idx - int(baseline_s * fs)   # 10 s * 5 Hz = 50 samples before the puff
baseline_end = puff_idx
if not (0 <= baseline_start < baseline_end < n_points):
    raise ValueError(
        f"Baseline window [{baseline_start}, {baseline_end}) and puff index "
        f"{puff_idx} must lie inside a trace of {n_points} samples"
    )

# Z-score every cell against its own pre-puff baseline of the corrected signal.
# Vectorised over cells; the result is always float (np.zeros_like(raw) would
# inherit an integer dtype from an integer CSV and silently truncate z-scores).
baseline_window = corrected_signal[baseline_start:baseline_end, :]
baseline_mean = baseline_window.mean(axis=0)
baseline_sd = baseline_window.std(axis=0, ddof=1)
# A flat baseline (SD == 0) has no defined z-score; use NaN rather than +-inf.
baseline_sd = np.where(baseline_sd > 0, baseline_sd, np.nan)
zscores = (corrected_signal - baseline_mean) / baseline_sd

# Mean and SEM across cells; NaN-aware so cells with a failed fit or a flat
# baseline are skipped rather than turning the whole average into NaN.
n_valid_z = np.sum(np.isfinite(zscores), axis=1)
mean_z = np.nanmean(zscores, axis=1)
sem_z = np.nanstd(zscores, axis=1, ddof=1) / np.sqrt(n_valid_z)

fig, ax = plt.subplots(figsize=(10, 6))
ax.fill_between(time, mean_z - sem_z, mean_z + sem_z, alpha=0.3, label='SEM')
ax.plot(time, mean_z, 'g-', label='Mean Z-score')
ax.axvline(time[puff_idx], color='k', linestyle='--', linewidth=1, label='Puff')
ax.set_xlabel('Time (s)')
ax.set_ylabel('Z-score')
ax.set_title(f'30uM DHPG {group_name} (Mean ± SEM)')
ax.legend()
plt.show()