In [None]:
from pathlib import Path
import sys

# Register dill/pathlib compatibility shim BEFORE importing dill
sys.path.insert(0, str(Path("../src").resolve()))
from pickle_compat import enable_dill_pathlib_compat
enable_dill_pathlib_compat()

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import rcParams
import seaborn as sns
import trompy as tp

from scipy import stats
from scipy.spatial.distance import cdist
from scipy.optimize import curve_fit

import dill

# Global plot styling: Arial text; saved figures at 300 dpi with a transparent background.
plt.rcParams.update({
    'font.family': 'Arial',
    'savefig.dpi': 300,
    'savefig.transparent': True,
})

# Four-colour hex palette shared by the plots in this notebook.
colors = ["#67AFD2", "#016895", "#F4795B", "#C74632"]

# Flip to True to write figures into FIGSFOLDER.
savefigs = False

# Input data and intermediate results live one level above the notebook folder.
DATAFOLDER = Path("..//data")
RESULTSFOLDER = Path("..//results")
# NOTE(review): machine-specific absolute path; figures can only be saved on a
# machine where this folder exists.
FIGSFOLDER = Path("C:/Users/jmc010/Dropbox/Publications in Progress/Bazzino Roitman_sodium/figs")

## Load Assembled Data

Load the complete dataset from `assembled_data.pickle`, which includes trial metadata, photometry snips, and cluster assignments.

In [None]:
# Location of the assembled dataset produced upstream of this notebook.
assembled_data_path = DATAFOLDER / "assembled_data.pickle"

# NOTE: unpickling executes code stored in the file; only load files you trust.
with assembled_data_path.open("rb") as f:
    data = dill.load(f)

# Required entries: a missing key raises KeyError.
x_array, snips_photo, snips_movement, fits_df = (
    data[key] for key in ("x_array", "snips_photo", "snips_movement", "fits_df")
)

# Optional entries: fall back to a default when the key is absent.
snips_angvel = data.get("snips_angvel")
params = data.get("params", {})

print(f"Loaded assembled data from {assembled_data_path}")
print(f"  - x_array shape: {x_array.shape}")
print(f"  - snips_photo shape: {snips_photo.shape}")
print(f"  - Number of unique animals: {x_array['id'].nunique()}")

In [None]:
# Load PCA-transformed photometry data (optional; only needed for the PC-based analysis)
pcafile = RESULTSFOLDER / "transformed_data_photo.pickle"

# Define every name this cell can produce up front, so that later cells can test
# for None instead of hitting a NameError when the PCA file is missing.
pca = pc1 = pca_data = None

if pcafile.exists():
    # NOTE: unpickling executes code stored in the file; only load files you trust.
    with open(pcafile, 'rb') as f:
        pca = dill.load(f)
    pc1 = pca[:, 0]        # first principal component only
    pca_data = pca[:, :3]  # first three principal components
    print(f"Loaded PCA data: {pca_data.shape}")
else:
    print(f"Warning: PCA file not found at {pcafile}")

## Calculate trial distances based on clusters and PCs

In [None]:
# Continuous "distance to cluster" metrics computed in PCA space.
# Relies on the cluster_photo assignments expected to be stored in assembled_data.pickle.

if pca_data is None:
    print("Skipping clusterness calculation - PCA data not available")
elif 'cluster_photo' not in x_array.columns:
    print("Warning: cluster_photo column not found in x_array")
    print("Cannot calculate clusterness without cluster assignments")
else:
    # Centroid of each cluster in PC space (mean over the trials assigned to it)
    cluster_0_centroid = pca_data[x_array["cluster_photo"] == 0].mean(axis=0)
    cluster_1_centroid = pca_data[x_array["cluster_photo"] == 1].mean(axis=0)

    # --- Metric 1: position along the axis joining the two centroids ---
    # Unit vector pointing from the cluster-1 centroid towards the cluster-0 centroid
    cluster_vector = cluster_0_centroid - cluster_1_centroid
    cluster_vector_norm = cluster_vector / np.linalg.norm(cluster_vector)

    # Scalar projection of every trial (measured from the cluster-1 centroid) onto that axis
    projections = (pca_data - cluster_1_centroid).dot(cluster_vector_norm)

    # Min-max rescale so the smallest projection maps to 0 and the largest to 1
    min_projection, max_projection = projections.min(), projections.max()
    projection_span = max_projection - min_projection
    normalized_projections = (projections - min_projection) / projection_span

    # --- Metric 2: difference in Euclidean distance to the two centroids ---
    # cdist returns one row per trial and one column per centroid (0, then 1)
    centroids = np.vstack([cluster_0_centroid, cluster_1_centroid])
    distances = cdist(pca_data, centroids, metric='euclidean')

    # Positive when a trial is closer to the cluster-0 centroid than to the cluster-1 centroid
    distances_diff = distances[:, 1] - distances[:, 0]

    x_array = x_array.assign(
        clusterness_photo=normalized_projections,
        euclidean_diff=distances_diff,
    )

    print("Added clusterness_photo and euclidean_diff columns to x_array")

In [None]:
## Save data with the clusterness and Euclidean-difference columns added above

with open(DATAFOLDER / "bazzino_data_with_clusters_and_dists.pickle", "wb") as f:
    dill.dump({
        "x_array": x_array,
        "snips_photo": snips_photo,
        # `snips_vel` is never defined in this notebook (NameError on a fresh kernel);
        # the movement snips are loaded as `snips_movement`. The "snips_vel" key is
        # kept so existing readers of this pickle keep working.
        # NOTE(review): confirm snips_movement is the intended replacement.
        "snips_vel": snips_movement,
        # `pca` is only guaranteed to exist when the PCA file was found (pca_data is
        # None otherwise), so store None rather than fail in that case.
        "pca": pca if pca_data is not None else None,
    }, f)

In [None]:
# Show the rats eligible for the logistic fits below: the unique ids among rows
# with condition == 'deplete' and infusiontype == '45NaCl'
x_array.query("condition == 'deplete' & infusiontype == '45NaCl'").id.unique()

In [None]:
# Robust logistic fits for binary/near-binary data (handles optional offset)

# Three-parameter logistic: lower asymptote fixed at zero
def logistic3(x, L, x0, k):
    """Logistic curve with upper asymptote ``L``, midpoint ``x0`` and slope ``k``."""
    exponent = -k * (x - x0)
    return L / (1 + np.exp(exponent))

# Four-parameter logistic: adds a free baseline A
def logistic4(x, A, L, x0, k):
    """Logistic curve running between asymptotes ``A`` and ``L`` with midpoint ``x0`` and slope ``k``."""
    span = L - A
    denominator = 1 + np.exp(-k * (x - x0))
    return A + span / denominator


def _normalize_x(x):
    x = np.asarray(x, dtype=float)
    m, s = float(np.mean(x)), float(np.std(x))
    if not np.isfinite(s) or s == 0:
        s = 1.0
    return (x - m) / s, m, s


def _clip_y(y):
    y = np.asarray(y, dtype=float)
    return np.clip(y, 1e-4, 1 - 1e-4)


def fit_logistic_per_series(y, x=None, prefer_4p=True, direction=None, maxfev=60000):
    """
    Fit a logistic curve to one binary/near-binary series using bounded least squares.

    Before fitting, x is z-scored (``_normalize_x``) and y is clipped to
    [1e-4, 1 - 1e-4] (``_clip_y``); values of y outside the unit interval are
    therefore saturated, and the parameter bounds assume data on roughly a 0-1 scale.

    Parameters
    ----------
    y : array-like
        Observations to fit.
    x : array-like, optional
        Sample positions; defaults to 0, 1, ..., len(y) - 1.
    prefer_4p : bool
        If True, try the 4-parameter model (``logistic4``) first and fall back to the
        3-parameter model (``logistic3``) only if every 4-parameter attempt fails.
        If False, only the 3-parameter model is fitted.
    direction : {None, 'increasing', 'decreasing'}
        Sets the sign of the *initial guess* for k: None infers it from corr(x, y),
        'increasing' gives a positive start, any other value a negative start.
        The bounds on k are symmetric (-10..10), so the fitted sign is not enforced.
    maxfev : int
        Forwarded to ``scipy.optimize.curve_fit``.

    Returns
    -------
    dict with keys {'model', 'params', 'y_hat', 'x0_orig', 'success', 'note'}.
    ``x0_orig`` is the midpoint converted back to the units of ``x``; ``y_hat`` is the
    fitted curve evaluated at the input positions. On failure ``model`` and ``y_hat``
    are None, ``x0_orig`` is NaN and ``success`` is False.
    """
    y = np.asarray(y, dtype=float)
    if y.size == 0:
        # Nothing to fit; report a failure instead of crashing in np.min below.
        return {'model': None, 'params': {}, 'y_hat': None, 'x0_orig': np.nan,
                'success': False, 'note': 'empty input'}

    if x is None:
        x = np.arange(len(y), dtype=float)
    x = np.asarray(x, dtype=float)

    # Normalize x for stabler k/x0 estimation
    x_norm, x_mean, x_std = _normalize_x(x)
    y_clip = _clip_y(y)

    # Initial guesses from data; x0 starts at the mean of x (0 in normalized units)
    A_init = float(np.min(y_clip))
    L_init = float(np.max(y_clip))
    x0_init = 0.0

    # Sign of the initial k from the requested direction or from the correlation
    if direction is None:
        try:
            c = float(np.corrcoef(x, y_clip)[0, 1])
        except Exception:
            c = 0.0
        if not np.isfinite(c):
            c = 0.0
        sign = 1.0 if c >= 0 else -1.0
    else:
        sign = 1.0 if direction == 'increasing' else -1.0

    k_mags = [0.5, 1.0, 2.0]

    def try_fit(func, p0_list, bounds):
        """Run curve_fit from each start point; return (popt, y_hat) with the lowest RSS, or None."""
        best = None
        best_rss = np.inf
        for p0 in p0_list:
            try:
                popt, _ = curve_fit(func, x_norm, y_clip, p0=p0, bounds=bounds, maxfev=maxfev)
            except Exception:
                # Deliberate best-effort: a start point that fails (no convergence,
                # infeasible p0, non-finite data) is skipped in favour of the others.
                continue
            y_hat = func(x_norm, *popt)
            rss = float(np.sum((y_clip - y_hat) ** 2))
            if rss < best_rss:
                best_rss, best = rss, (popt, y_hat)
        return best

    # 4-parameter attempt
    res4 = None
    if prefer_4p:
        p0s_4 = [[A_init, L_init, x0_init, sign * km] for km in k_mags]
        # Keep values sensible for binary-ish data
        bnds_4 = ([-0.1, 0.4, -3.0, -10.0],
                  [0.6, 1.6, 3.0, 10.0])
        res4 = try_fit(logistic4, p0s_4, bnds_4)

    if res4 is not None:
        popt, y_hat = res4
        # x0 is fitted in normalized units; convert to original units for reporting
        A, L, x0n, k = map(float, popt)
        x0_orig = x0n * x_std + x_mean
        return {
            'model': 'logistic4',
            'params': {'A': A, 'L': L, 'x0_norm': x0n, 'x0_orig': x0_orig, 'k': k},
            'y_hat': y_hat,
            'x0_orig': x0_orig,
            'success': True,
            'note': ''
        }

    # 3-parameter model: the fallback after a failed 4p attempt, or the primary
    # model when prefer_4p is False. Only fitted when its result can be used.
    p0s_3 = [[L_init, x0_init, sign * km] for km in k_mags]
    bnds_3 = ([0.4, -3.0, -10.0], [1.6, 3.0, 10.0])
    res3 = try_fit(logistic3, p0s_3, bnds_3)

    if res3 is not None:
        popt, y_hat = res3
        L, x0n, k = map(float, popt)
        x0_orig = x0n * x_std + x_mean
        return {
            'model': 'logistic3',
            'params': {'L': L, 'x0_norm': x0n, 'x0_orig': x0_orig, 'k': k},
            'y_hat': y_hat,
            'x0_orig': x0_orig,
            'success': True,
            # Only claim a 4p failure when a 4p fit was actually attempted
            'note': '4p failed; used 3p' if prefer_4p else ''
        }

    return {'model': None, 'params': {}, 'y_hat': None, 'x0_orig': np.nan, 'success': False, 'note': 'fit failed'}
    

In [None]:
# Logistic fits for raw cluster assignments (binary/inverted cluster_photo)
df_dep_45 = x_array.query("condition == 'deplete' & infusiontype == '45NaCl'").copy()
all_fits = []

f, ax = plt.subplots(figsize=(6,4))

for rat in df_dep_45.id.unique():
    # NOTE(review): trials are used in x_array row order (no sort by trial here) - confirm ordering
    sig = df_dep_45.loc[df_dep_45.id == rat, 'cluster_photo'].to_numpy()
    # Invert the label so that cluster_photo == 0 becomes 1 and any other value becomes 0
    y = np.logical_not(sig).astype(int)
    x = np.arange(len(y), dtype=float)

    fit = fit_logistic_per_series(y, x=x, prefer_4p=True, direction='decreasing')
    all_fits.append({ 'id': rat, **fit['params'], 'model': fit['model'], 'x0_orig': fit['x0_orig'], 'success': fit['success'], 'note': fit['note'] })

    # Only draw curves whose transition point falls inside the recorded trials
    if fit["success"] and fit['x0_orig'] > 0 and fit['x0_orig'] < len(y):
        ax.plot(x, fit["y_hat"], color=colors[2], alpha=0.5, linestyle="--")

# Keep successful fits with a positive transition point.
# NOTE(review): unlike the plotting condition above (and the `x0_orig < 50` filter used
# in the two following cells) there is no upper limit on x0_orig here - confirm intended.
fits_df = pd.DataFrame(all_fits)
fits_df = fits_df.query("success == True and x0_orig > 0").copy()

x0 = fits_df['x0_orig'].to_list()
if x0:
    # Individual transition points and their mean, drawn just above the axes
    mean_x0 = np.mean(x0)
    ax.plot(x0, [1.1]*len(x0), marker="o", linestyle="None", color=colors[2], alpha=0.5, clip_on=False)
    ax.text(np.max(x0)+2, 1.1, "Transition points", ha="left", va="center", fontsize=10, color=colors[2])

    ax.plot([mean_x0, mean_x0], [1.05, 1.15], color=colors[2], linestyle="--", alpha=0.5, clip_on=False)
    ax.text(mean_x0, 1.16, f"Mean=trial {int(mean_x0)}", ha="center", va="bottom", fontsize=10, color=colors[2])
else:
    # np.max / int(np.mean) raise on an empty list, so skip the annotations instead
    print("No fits passed the filter; skipping transition-point annotations")

sns.despine(ax=ax, offset=5)
ax.set_xlabel("Trial Number")
ax.set_ylabel("Probability of Cluster 1")

ax.set_yticks([0, 0.5, 1])
ax.set_ylim([-0.02, 1.1])

if savefigs:
    f.savefig(FIGSFOLDER / "logistic_fits_45NaCl.pdf", dpi=600, transparent=True)

# Keep a dedicated name: `fits_df` is overwritten by the following fit cells
fits_df_cluster_raw = fits_df

In [None]:
# Logistic fits for clusterness (continuous between 0 and 1, based on projection onto cluster vector)
df_dep_45 = x_array.query("condition == 'deplete' & infusiontype == '45NaCl'").copy()
all_fits = []

f, ax = plt.subplots(figsize=(6,4))

for rat in df_dep_45.id.unique():
    # NOTE(review): trials are used in x_array row order (no sort by trial here) - confirm ordering
    sig = df_dep_45.loc[df_dep_45.id == rat, 'clusterness_photo'].to_numpy()
    # clusterness_photo is already continuous, so it is fitted as-is (no inversion)
    y = sig
    x = np.arange(len(y), dtype=float)

    fit = fit_logistic_per_series(y, x=x, prefer_4p=True, direction='decreasing')
    all_fits.append({ 'id': rat, **fit['params'], 'model': fit['model'], 'x0_orig': fit['x0_orig'], 'success': fit['success'], 'note': fit['note'] })

    # Only draw curves whose transition point falls inside the recorded trials
    if fit["success"] and fit['x0_orig'] > 0 and fit['x0_orig'] < len(y):
        ax.plot(x, fit["y_hat"], color=colors[2], alpha=0.5, linestyle="--")

# Keep successful fits whose transition point lies between trial 0 and trial 50
fits_df = pd.DataFrame(all_fits)
fits_df = fits_df.query("success == True and x0_orig > 0 and x0_orig < 50").copy()

x0 = fits_df['x0_orig'].to_list()
if x0:
    # Individual transition points and their mean, drawn just above the y-limits set below
    mean_x0 = np.mean(x0)
    ax.plot(x0, [0.75]*len(x0), marker="o", linestyle="None", color=colors[2], alpha=0.5, clip_on=False)
    ax.text(np.max(x0)+2, 0.75, "Transition points", ha="left", va="center", fontsize=10, color=colors[2])

    ax.plot([mean_x0, mean_x0], [0.73, 0.77], color=colors[2], linestyle="--", alpha=0.5, clip_on=False)
    ax.text(mean_x0, 0.78, f"Mean=trial {int(mean_x0)}", ha="center", va="bottom", fontsize=10, color=colors[2])
else:
    # np.max / int(np.mean) raise on an empty list, so skip the annotations instead
    print("No fits passed the filter; skipping transition-point annotations")

sns.despine(ax=ax, offset=5)
ax.set_xlabel("Trial Number")
ax.set_ylabel("Cluster 1-ness")

ax.set_ylim([0.31, 0.7])

if savefigs:
    # Own file name: this cell used to write to "logistic_fits_45NaCl.pdf", the same
    # file as the raw-cluster figure, so one figure silently overwrote the other.
    f.savefig(FIGSFOLDER / "logistic_fits_45NaCl_clusterness.pdf", dpi=600, transparent=True)

# Keep a dedicated name: `fits_df` is overwritten by the next fit cell
fits_df_clusterness = fits_df


In [None]:
# Inspect which columns are available in the deplete / 45NaCl subset
df_dep_45.columns

In [None]:
# Logistic fits for euclidean distance difference (continuous between -inf and +inf)
df_dep_45 = x_array.query("condition == 'deplete' & infusiontype == '45NaCl'").copy()
all_fits = []

f, ax = plt.subplots(figsize=(6,4))

for rat in df_dep_45.id.unique():
    # NOTE(review): trials are used in x_array row order (no sort by trial here) - confirm ordering
    sig = df_dep_45.loc[df_dep_45.id == rat, 'euclidean_diff'].to_numpy()
    # NOTE(review): fit_logistic_per_series clips y to [1e-4, 1 - 1e-4] before fitting,
    # so every euclidean_diff value outside that interval is saturated. Confirm this is
    # intended, or rescale the series to the unit interval first.
    y = sig
    x = np.arange(len(y), dtype=float)

    fit = fit_logistic_per_series(y, x=x, prefer_4p=True, direction='decreasing')
    all_fits.append({ 'id': rat, **fit['params'], 'model': fit['model'], 'x0_orig': fit['x0_orig'], 'success': fit['success'], 'note': fit['note'] })

    # Only draw curves whose transition point falls inside the recorded trials
    if fit["success"] and fit['x0_orig'] > 0 and fit['x0_orig'] < len(y):
        ax.plot(x, fit["y_hat"], color=colors[2], linestyle="--", alpha=0.5)

# Keep successful fits whose transition point lies between trial 0 and trial 50
fits_df = pd.DataFrame(all_fits)
fits_df = fits_df.query("success == True and x0_orig > 0 and x0_orig < 50").copy()

x0 = fits_df['x0_orig'].to_list()
if x0:
    # Individual transition points and their mean, drawn at y = 1.1
    mean_x0 = np.mean(x0)
    ax.plot(x0, [1.1]*len(x0), marker="o", linestyle="None", color=colors[2], alpha=0.5, clip_on=False)
    ax.text(np.max(x0)+2, 1.1, "Transition points", ha="left", va="center", fontsize=10, color=colors[2])

    ax.plot([mean_x0, mean_x0], [1.05, 1.15], color=colors[2], linestyle="--", alpha=0.5, clip_on=False)
    ax.text(mean_x0, 1.16, f"Mean=trial {int(mean_x0)}", ha="center", va="bottom", fontsize=10, color=colors[2])
else:
    # np.max / int(np.mean) raise on an empty list, so skip the annotations instead
    print("No fits passed the filter; skipping transition-point annotations")

sns.despine(ax=ax, offset=5)
ax.set_xlabel("Trial Number")
# NOTE(review): label is identical to the clusterness figure; confirm it should not
# describe the Euclidean distance difference instead.
ax.set_ylabel("Cluster 1-ness")

if savefigs:
    # Own file name: this cell used to write to "logistic_fits_45NaCl.pdf", the same
    # file as the two previous figures, so each figure silently overwrote the last.
    f.savefig(FIGSFOLDER / "logistic_fits_45NaCl_euclidean.pdf", dpi=600, transparent=True)

fits_df_euclidean = fits_df


In [None]:
# Persist the three per-rat transition-point tables together in a single pickle
sigmoidal_fits_path = DATAFOLDER / "sigmoidal_fits.pickle"
with sigmoidal_fits_path.open("wb") as f:
    dill.dump(
        dict(
            fits_df_cluster_raw=fits_df_cluster_raw,
            fits_df_clusterness=fits_df_clusterness,
            fits_df_euclidean=fits_df_euclidean,
        ),
        f,
    )

In [None]:
# List the animal ids present in the deplete / 45NaCl subset
df_dep_45.id.unique()

In [None]:
## Quick visualization for one rat
# Uses the third id in the subset (raises IndexError if fewer than three animals are present)
example_rat = df_dep_45.id.unique()[2]
sig = df_dep_45.loc[df_dep_45.id == example_rat, 'cluster_photo'].to_numpy()
# Invert the label so that cluster_photo == 0 becomes 1 and any other value becomes 0
y = np.logical_not(sig).astype(int)
x = np.arange(len(y), dtype=float)
res = fit_logistic_per_series(y, x=x, prefer_4p=True, direction='decreasing')

plt.figure(figsize=(6, 3.5))
plt.scatter(x, y, s=24, color='#016895', alpha=0.8, label='Binary data')
if res['y_hat'] is not None:
    plt.plot(x, res['y_hat'], color='#F4795B', lw=2, label=f"{res['model']} fit")
    if np.isfinite(res['x0_orig']):
        # Mark the fitted transition point
        plt.axvline(res['x0_orig'], color='#999', ls='--', lw=1)
        # "\u2248" is the approximately-equal sign; the label previously contained the
        # mis-decoded bytes of that character, which rendered as garbage on the figure.
        plt.text(res['x0_orig'], 1.02, f"x0\u2248{res['x0_orig']:.1f}", ha='center', va='bottom', fontsize=9)
plt.ylim(-0.1, 1.1)
plt.xlabel('Trial')
plt.ylabel('Binary outcome')
sns.despine()
plt.legend(frameon=False)
plt.tight_layout()
plt.show()

In [None]:
# Transition points from the three fit families, keyed by the column name each
# one receives in the merged table (merge order follows this dict).
x0_sources = {
    'x0_cluster_raw': fits_df_cluster_raw,
    'x0_clusterness': fits_df_clusterness,
    'x0_euclidean': fits_df_euclidean,
}

# Outer-join the per-rat x0 values on id, one table at a time
merged_fits = None
for x0_col, fit_table in x0_sources.items():
    renamed = fit_table[['id', 'x0_orig']].rename(columns={'x0_orig': x0_col})
    if merged_fits is None:
        merged_fits = renamed
    else:
        merged_fits = merged_fits.merge(renamed, on='id', how='outer')

# Drop rats missing any of the three estimates, then round each x0 to a whole trial
merged_fits = merged_fits.dropna()
merged_fits = merged_fits.assign(
    **{x0_col: merged_fits[x0_col].round().astype(int) for x0_col in x0_sources}
)

merged_fits

In [None]:
## Sigmoidal Fits for time_moving (Deplete + 45NaCl)
# Fit sigmoids to normalized time_moving data for each animal with quality checks from figure_1

from scipy.stats import pearsonr
import matplotlib.pyplot as plt

def _sigmoid_model(x, L, k, x0, b):
    """4-parameter sigmoid function."""
    z = np.clip(-k * (x - x0), -60, 60)
    return L / (1 + np.exp(z)) + b

def _safe_pearson(y_true, y_pred):
    """Safe Pearson correlation handling edge cases."""
    if np.allclose(np.std(y_true), 0) or np.allclose(np.std(y_pred), 0):
        return np.nan, np.nan
    return pearsonr(y_true, y_pred)

def _model_metrics(y_true, y_pred, n_params):
    """Calculate model fit metrics (RMSE, AIC, AICc, BIC)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    n = len(y_true)
    residuals = y_true - y_pred
    sse = np.nansum(residuals ** 2)
    sse = max(float(sse), 1e-12)
    rmse = np.sqrt(sse / n) if n > 0 else np.nan
    aic = n * np.log(sse / n) + 2 * n_params if n > 0 else np.nan
    bic = n * np.log(sse / n) + n_params * np.log(n) if n > 1 else np.nan
    aicc = aic + (2 * n_params * (n_params + 1)) / (n - n_params - 1) if n > (n_params + 1) else np.nan
    return rmse, aic, aicc, bic

def _sigmoid_quality_checks(x, params, pcov):
    """
    Check sigmoid fit quality using criteria from figure_1_behaviour notebook:
    1. x0 must be interior (not near edges)
    2. k must be plausible (0.02 <= |k| <= 2.5)
    3. Confidence intervals must be finite
    4. Both asymptotes must be visible in data range
    """
    if params is None or len(params) != 4 or np.any(~np.isfinite(params)):
        return False, "fit_failed", {"x0_interior": False, "k_plausible": False, "ci_finite": False, "asymptotes_covered": False}

    L, k, x0, b = params
    x_min, x_max = float(np.min(x)), float(np.max(x))
    x_range = max(x_max - x_min, 1.0)
    edge_margin = 0.15 * x_range
    x0_interior = (x_min + edge_margin) <= x0 <= (x_max - edge_margin)
    k_plausible = np.isfinite(k) and (0.02 <= abs(k) <= 2.5)

    ci_finite = False
    if pcov is not None:
        diag = np.diag(pcov)
        if np.all(np.isfinite(diag)) and np.all(diag >= 0):
            se = np.sqrt(diag)
            ci_finite = np.all(np.isfinite(se)) and np.all(se > 0)

    amplitude = abs(L)
    if amplitude < 1e-8:
        asymptotes_covered = False
    else:
        lower_asym = min(b, b + L)
        upper_asym = max(b, b + L)
        y_start = _sigmoid_model(np.array([x_min]), L, k, x0, b)[0]
        y_end = _sigmoid_model(np.array([x_max]), L, k, x0, b)[0]

        tol = 0.2
        start_low = abs(y_start - lower_asym) / amplitude <= tol
        start_high = abs(y_start - upper_asym) / amplitude <= tol
        end_low = abs(y_end - lower_asym) / amplitude <= tol
        end_high = abs(y_end - upper_asym) / amplitude <= tol

        start_side = "low" if start_low else ("high" if start_high else None)
        end_side = "low" if end_low else ("high" if end_high else None)
        asymptotes_covered = (start_side is not None) and (end_side is not None) and (start_side != end_side)

    checks = {
        "x0_interior": bool(x0_interior),
        "k_plausible": bool(k_plausible),
        "ci_finite": bool(ci_finite),
        "asymptotes_covered": bool(asymptotes_covered),
    }

    failed = [name for name, ok in checks.items() if not ok]
    is_valid = len(failed) == 0
    reasons = "ok" if is_valid else ";".join(failed)
    return is_valid, reasons, checks

# Rows for the deplete condition with 45NaCl infusions
deplete_45_subset = x_array.query("condition == 'deplete' & infusiontype == '45NaCl'").copy()

# Animals are processed in sorted id order
animals = sorted(deplete_45_subset['id'].unique())
print(f"Fitting sigmoids for {len(animals)} animals in deplete + 45NaCl condition")
print(f"Animals: {animals}\n")

fit_results = []  # one summary record per attempted fit
fit_traces = []   # raw data and fitted curve for every fit that completed

# Values reported for an animal whose fit raised an exception
nan_fit_fields = dict.fromkeys(['L', 'k', 'x0', 'b', 'r', 'p', 'rmse', 'aicc', 'bic'], np.nan)
failed_check_fields = dict.fromkeys(['x0_interior', 'k_plausible', 'ci_finite', 'asymptotes_covered'], False)

for animal in animals:
    # This animal's trials, ordered by trial number
    animal_data = (
        deplete_45_subset
        .loc[lambda rows: rows['id'] == animal]
        .sort_values('trial')
        .copy()
    )

    # x is the 1-based position of each trial within the animal's series
    y_raw = animal_data['time_moving'].values
    x = np.arange(1, len(y_raw) + 1, dtype=float)

    # Min-max normalize to [0, 1]; a flat series cannot be fitted
    y_min, y_max = y_raw.min(), y_raw.max()
    y_range = y_max - y_min
    if y_range < 1e-8:
        print(f"  {animal}: SKIPPED (no variation in time_moving)")
        continue
    y = (y_raw - y_min) / y_range

    # Start from a unit-amplitude, zero-baseline sigmoid centred on the median trial.
    # time_moving is expected to increase, hence the positive initial k.
    initial_k = 1.0
    p0 = [1.0, initial_k, np.median(x), 0.0]  # L, k, x0, b

    # Only x0 is constrained: it must stay within the observed trial range
    lower_bounds = [-np.inf, -np.inf, np.min(x), -np.inf]
    upper_bounds = [np.inf, np.inf, np.max(x), np.inf]
    bounds = (lower_bounds, upper_bounds)

    try:
        params, pcov = curve_fit(_sigmoid_model, x, y, p0=p0, bounds=bounds, maxfev=30000)
        yhat = _sigmoid_model(x, *params)
        yhat_raw = yhat * y_range + y_min  # fitted curve back in the original units
        r, p_val = _safe_pearson(y, yhat)
        rmse, aic, aicc, bic = _model_metrics(y, yhat, n_params=4)

        # Quality checks
        is_valid, reasons, checks = _sigmoid_quality_checks(x, params, pcov)

        L, k, x0, b = params
    except Exception as e:
        # Record the failure so the animal still appears in the results table
        print(f"  {animal}: FIT FAILED ({e})")
        fit_results.append({
            'animal': animal,
            'n_trials': len(y),
            **nan_fit_fields,
            'is_valid': False,
            'reasons': 'fit_failed',
            **failed_check_fields,
        })
        continue

    fit_results.append({
        'animal': animal,
        'n_trials': len(y),
        'L': L,
        'k': k,
        'x0': x0,
        'b': b,
        'r': r,
        'p': p_val,
        'rmse': rmse,
        'aicc': aicc,
        'bic': bic,
        'is_valid': is_valid,
        'reasons': reasons,
        **{name: checks[name] for name in failed_check_fields},
    })

    fit_traces.append({
        'animal': animal,
        'x': x,
        'y_raw': y_raw,
        'y_fit': yhat_raw,
        'is_valid': is_valid,
    })

# Collect the per-animal records into one table
sigmoid_fit_df = pd.DataFrame(fit_results)

# Copy with the numeric columns rounded to 4 decimals, for printing only
display_df = sigmoid_fit_df.copy()
rounded_cols = ['L', 'k', 'x0', 'b', 'r', 'p', 'rmse', 'aicc', 'bic']
for col in (name for name in rounded_cols if name in display_df.columns):
    display_df[col] = display_df[col].round(4)

heavy_rule = "=" * 80
light_rule = "-" * 80

print("\n" + heavy_rule)
print("SIGMOID FIT RESULTS FOR time_moving (Deplete + 45NaCl)")
print(heavy_rule)
print("\nFit Parameters:")
parameter_cols = ['animal', 'n_trials', 'L', 'k', 'x0', 'b', 'r', 'p', 'rmse', 'aicc']
print(display_df[parameter_cols].to_string(index=False))

print("\n" + light_rule)
print("Quality Checks:")
quality_cols = ['animal', 'is_valid', 'reasons', 'x0_interior', 'k_plausible', 'ci_finite', 'asymptotes_covered']
print(display_df[quality_cols].to_string(index=False))

# Summary over the fits that passed every quality check
valid_fits = sigmoid_fit_df.loc[sigmoid_fit_df['is_valid'] == True]
print("\n" + heavy_rule)
print(f"SUMMARY: {len(valid_fits)}/{len(sigmoid_fit_df)} fits passed all quality checks")
print(heavy_rule)

if len(valid_fits) > 0:
    k_vals, x0_vals = valid_fits['k'], valid_fits['x0']
    r_vals, rmse_vals = valid_fits['r'], valid_fits['rmse']
    print("\nValid fits statistics:")
    print(f"  k (steepness):  mean = {k_vals.mean():.4f}, std = {k_vals.std():.4f}, range = [{k_vals.min():.4f}, {k_vals.max():.4f}]")
    print(f"  x0 (transition): mean = {x0_vals.mean():.2f}, std = {x0_vals.std():.2f}, range = [{x0_vals.min():.2f}, {x0_vals.max():.2f}]")
    print(f"  r (correlation): mean = {r_vals.mean():.4f}, std = {r_vals.std():.4f}")
    print(f"  RMSE:            mean = {rmse_vals.mean():.4f}, std = {rmse_vals.std():.4f}")

# Alias for the full results table under a descriptive name
time_moving_sigmoid_fits = sigmoid_fit_df


In [None]:
# Display `fits_df`. The name is reassigned by each logistic-fit cell above, so on a
# top-to-bottom run this shows the table from the euclidean_diff fits.
fits_df

In [None]:
## Plot per-animal sigmoid fits (time_moving)
# One figure per animal: raw time_moving values with the fitted sigmoid overlaid.
if fit_traces:
    for trace in fit_traces:
        fig, ax = plt.subplots(figsize=(6, 4))
        ax.scatter(trace['x'], trace['y_raw'], s=30, color='black', alpha=0.7, label='data')
        ax.plot(trace['x'], trace['y_fit'], color='tab:red', linewidth=2, label='sigmoid fit')
        ax.set(
            title=f"{trace['animal']} time_moving sigmoid fit",
            xlabel='trial',
            ylabel='time_moving',
        )
        ax.legend(frameon=False)
        # Hide the top and right axis lines
        for side in ('top', 'right'):
            ax.spines[side].set_visible(False)
        fig.tight_layout()
        if savefigs:
            outpath = FIGSFOLDER / f"sigmoid_fit_time_moving_{trace['animal']}.png"
            fig.savefig(outpath, dpi=300)
else:
    print("No fits to plot. Run the fitting cell first.")

In [None]:
## Leave-one-out analysis (TODO: plan incomplete)
# Plan: remove one trial at a time, fit a sigmoid, and calculate the transition point without that trial,
# and then to ???
# 1) Realign based on these transition points, calculate k (steepness), and iterate over all trials, averaging the k values?
# 2) re ... (note left unfinished)
