# Paper Figures: Figure 2 - Photometry (Dopamine) Analysis

This notebook generates publication-ready figures for photometry data (dopamine responses) using data assembled by `src/assemble_all_data.py`.

**Figure 2: Photometry Analysis** — Neural responses showing dopamine heatmaps and summary plots for replete and deplete conditions across sodium concentrations.

In [None]:
%load_ext autoreload
%autoreload 2

import pathlib
from pathlib import Path
import sys
sys.path.insert(0, str(Path("../src").resolve()))
from pickle_compat import enable_dill_pathlib_compat
enable_dill_pathlib_compat()

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import dill


from figure_config import (
    configure_matplotlib, COLORS, HEATMAP_CMAP_DIV, 
    DATAFOLDER, RESULTSFOLDER, FIGSFOLDER,
    HEATMAP_VLIM_PHOTO, YLIMS_PHOTO,
    PHOTO_SMOOTH_WINDOW, SAVE_FIGS
)
from figure_plotting import (
    smooth_array, get_heatmap_data, get_mean_snips, get_auc,
    init_heatmap_figure, init_snips_figure, make_heatmap,
    plot_snips, plot_auc_summary, save_figure_atomic, print_auc_stats,
    scale_vlim_to_data, calculate_ylims, make_correlation_plot_da
)

# Apply the project-wide matplotlib settings provided by figure_config.
configure_matplotlib()
# Short notebook-level aliases for the shared palette and heatmap colormap.
# Later cells index `colors[0]`..`colors[3]` and pass `custom_cmap` to make_heatmap.
colors = COLORS  # Use shared color palette
custom_cmap = HEATMAP_CMAP_DIV  # Use shared colormap (presumably diverging, per its name)

## Load Assembled Data

Load the complete dataset from the pickle file generated by the assembly script.

In [None]:
# Location of the pickle produced by the assembly script
assembled_data_path = DATAFOLDER / "assembled_data.pickle"

# NOTE: dill/pickle can execute arbitrary code while loading — only open trusted files.
with open(assembled_data_path, "rb") as f:
    data = dill.load(f)

# Pieces of the assembled dataset used by the rest of the notebook
x_array, snips_photo = data["x_array"], data["snips_photo"]
params = data.get("params", {})
metadata = data.get("metadata", {})

print(f"Loaded assembled data from {assembled_data_path}")
print("\nData structure:")
print(f"  - x_array shape: {x_array.shape}")
print(f"  - snips_photo shape: {snips_photo.shape}")
print(f"  - x_array columns: {x_array.columns.tolist()}")
print(f"  - Number of trials: {len(x_array)}")

# Report how the data were processed upstream: (label, metadata key, default shown if absent)
print("\nData processing metadata:")
if metadata:
    for label, key, default in (
        ("Behaviour metric", "behav_metric", "unknown"),
        ("Behaviour smoothed", "behav_smoothed", False),
        ("Behaviour z-scored", "behav_zscored", False),
        ("Photometry smoothed", "photo_smoothed", False),
        ("Photometry z-scored", "photo_zscored", False),
    ):
        print(f"  {label}: {metadata.get(key, default)}")
else:
    print("  No metadata found; assuming standard processing")

# Keep the smoothing flag available for the plotting cells
photo_already_smoothed = metadata.get('photo_smoothed', False)

## Figure 2: Photometry Analysis — Dopamine Responses

Analysis of neural dopamine responses showing heatmaps, snip time series, and summary AUC metrics across replete and deplete sodium conditions, separated by infusion type (10NaCl vs 45NaCl).

In [None]:
# Photometry (dopamine) data is NOT smoothed during assembly
# Do not apply any smoothing to maintain signal fidelity
# `snips_photo_smooth` is therefore just another name for `snips_photo`, not a smoothed copy.
# NOTE(review): the message below is unconditional; it does not consult the
# `photo_smoothed` metadata flag (`photo_already_smoothed`) read when the data
# was loaded — confirm the two agree.
snips_photo_smooth = snips_photo  # Use as-is, unsmoothed
print(f"Using unsmoothed photometry data (photometry is NOT smoothed during assembly)")

# Parameters for visualization
# Data-driven colour limits are computed and printed for reference only:
# the hard-coded assignment further down replaces them before any heatmap is drawn.
vlim = scale_vlim_to_data(snips_photo_smooth, percentile=80)
print("Based on data, vlim are:", vlim)
# ACTIVE override: the heatmaps use these fixed limits.
# Comment this line out to use the data-driven limits computed above instead.
vlim = (-1, 1)  # Hard-coded colour limits (overrides the auto-scaled value)


In [None]:
### 2A. Heatmaps — Replete Condition
# Two heatmaps on one figure: 10NaCl on ax1 and 45NaCl on ax2, both drawn with
# the same colour limits (`vlim`) and colormap. The colorbar axis is passed only
# with the second heatmap; the infusion bar (`inf_bar=True`) only with the first.

f, ax1, ax2, cbar_ax = init_heatmap_figure()

# Replete + 10NaCl
heatmap_data_rep_10 = get_heatmap_data(snips_photo_smooth, x_array, "replete", "10NaCl")
# NOTE(review): replete_10_auc is not read again in the visible cells
# (the AUC summary recomputes AUCs from the mean snips).
replete_10_auc = get_auc(heatmap_data_rep_10)
make_heatmap(heatmap_data_rep_10, ax1, vlim, inf_bar=True, cmap=custom_cmap)

# Replete + 45NaCl
heatmap_data_rep_45 = get_heatmap_data(snips_photo_smooth, x_array, "replete", "45NaCl")
# NOTE(review): replete_45_auc is likewise unused in the visible cells.
replete_45_auc = get_auc(heatmap_data_rep_45)
make_heatmap(heatmap_data_rep_45, ax2, vlim, cmap=custom_cmap, cbar_ax=cbar_ax)

# Only write the figure to FIGSFOLDER when saving is enabled in figure_config
if SAVE_FIGS:
    save_figure_atomic(f, "fig2_heatmap_dopamine_replete", FIGSFOLDER)


In [None]:
### 2B. Heatmaps — Deplete Condition
# Same layout as the replete heatmaps: 10NaCl on ax1 (with the infusion bar),
# 45NaCl on ax2 (with the colorbar axis), sharing `vlim` and the colormap.

f, ax1, ax2, cbar_ax = init_heatmap_figure()

# Deplete + 10NaCl
heatmap_data_dep_10 = get_heatmap_data(snips_photo_smooth, x_array, "deplete", "10NaCl")
# NOTE(review): deplete_10_auc is not read again in the visible cells.
deplete_10_auc = get_auc(heatmap_data_dep_10)
make_heatmap(heatmap_data_dep_10, ax1, vlim, cmap=custom_cmap, inf_bar=True)

# Deplete + 45NaCl
heatmap_data_dep_45 = get_heatmap_data(snips_photo_smooth, x_array, "deplete", "45NaCl")
# NOTE(review): deplete_45_auc is likewise unused in the visible cells.
deplete_45_auc = get_auc(heatmap_data_dep_45)
make_heatmap(heatmap_data_dep_45, ax2, vlim, cmap=custom_cmap, cbar_ax=cbar_ax)

# Only write the figure to FIGSFOLDER when saving is enabled in figure_config
if SAVE_FIGS:
    save_figure_atomic(f, "fig2_heatmap_dopamine_deplete", FIGSFOLDER)


In [None]:

# Data-driven y-limits are computed and printed for reference only;
# the fixed tuple assigned below is what the snips plots actually use.
ylims = calculate_ylims(snips_photo_smooth)  # Calculate ylims based on data
print("Based on data, ylims are:", ylims)
# ACTIVE override: replaces the computed limits so both snips panels share one y-axis range.
# Comment this line out to use the data-driven limits instead.
ylims = (-0.6, 3)  # snips plot limits

In [None]:
### 2C. Time Series Snips — Replete Condition

# Get animal-averaged snips for replete
# get_mean_snips returns a pair that is unpacked as (10NaCl, 45NaCl);
# both are reused later by the AUC summary cell.
snips_rep_10, snips_rep_45 = get_mean_snips(snips_photo_smooth, x_array, "replete")

f, ax = init_snips_figure()
# colors[0] / colors[1] are the replete 10NaCl / 45NaCl colours; `ylims` is the fixed range set earlier
plot_snips(snips_rep_10, snips_rep_45, ax, colors[0], colors[1], ylims)

# Only write the figure to FIGSFOLDER when saving is enabled in figure_config
if SAVE_FIGS:
    save_figure_atomic(f, "fig2_snips_dopamine_replete", FIGSFOLDER)


In [None]:
### 2D. Time Series Snips — Deplete Condition

# Get animal-averaged snips for deplete
# get_mean_snips returns a pair that is unpacked as (10NaCl, 45NaCl);
# both are reused later by the AUC summary cell.
snips_dep_10, snips_dep_45 = get_mean_snips(snips_photo_smooth, x_array, "deplete")

f, ax = init_snips_figure()
# colors[2] / colors[3] are the deplete 10NaCl / 45NaCl colours.
# Unlike the replete panel, this call also passes scalebar=True.
plot_snips(snips_dep_10, snips_dep_45, ax, colors[2], colors[3], ylims, scalebar=True)

# Only write the figure to FIGSFOLDER when saving is enabled in figure_config
if SAVE_FIGS:
    save_figure_atomic(f, "fig2_snips_dopamine_deplete", FIGSFOLDER)


In [None]:
### 2E. AUC Summary — Bar Plot with Individual Data Points

# One AUC array per infusion type (10NaCl, then 45NaCl), grouped by condition
replete_aucs = [get_auc(snips) for snips in (snips_rep_10, snips_rep_45)]
deplete_aucs = [get_auc(snips) for snips in (snips_dep_10, snips_dep_45)]
aucs = [replete_aucs, deplete_aucs]

f, ax = plot_auc_summary(
    aucs,
    colors,
    ylabel="Dopamine (AUC)",
    figsize=(2.2, 2.8),
)

if SAVE_FIGS:
    save_figure_atomic(f, "fig2_auc_dopamine_summary", FIGSFOLDER)

# Console summary: labels carry the group sizes, arrays are flattened in the
# same order (replete 10/45, then deplete 10/45)
auc_labels = [
    f"Replete + 10NaCl (n={len(snips_rep_10)})",
    f"Replete + 45NaCl (n={len(snips_rep_45)})",
    f"Deplete + 10NaCl (n={len(snips_dep_10)})",
    f"Deplete + 45NaCl (n={len(snips_dep_45)})",
]
auc_arrays = [*replete_aucs, *deplete_aucs]
print_auc_stats(auc_arrays, auc_labels, title="Figure 2 — Dopamine Summary Statistics")

In [None]:
def make_auc_data_per_trial(x_array, condition, infusiontype):
    """Return the mean ``auc_snips`` value per trial for one experimental group.

    Parameters
    ----------
    x_array : pandas.DataFrame
        Table with at least the columns ``condition``, ``infusiontype``,
        ``trial`` and ``auc_snips``.
    condition : value compared against the ``condition`` column.
    infusiontype : value compared against the ``infusiontype`` column.

    Returns
    -------
    Array with one entry per distinct ``trial`` in the selected rows, ordered
    by trial (groupby sorts its keys); empty when no row matches.
    """
    in_group = (x_array["condition"] == condition) & (x_array["infusiontype"] == infusiontype)
    selected = x_array.loc[in_group]
    mean_auc_by_trial = selected.groupby("trial")["auc_snips"].mean()
    return mean_auc_by_trial.values
    
# Per-trial mean AUC for each condition × infusion-type combination.
# These four arrays are also the inputs to the model-fitting cell further down.
replete_10 = make_auc_data_per_trial(x_array, "replete", "10NaCl")
replete_45 = make_auc_data_per_trial(x_array, "replete", "45NaCl")
deplete_10 = make_auc_data_per_trial(x_array, "deplete", "10NaCl")
deplete_45 = make_auc_data_per_trial(x_array, "deplete", "45NaCl")

# Replete panel (10NaCl vs 45NaCl); yaxis=True is passed only for this panel.
f = make_correlation_plot_da(replete_10, replete_45, colors[0], colors[1], yaxis=True)
if SAVE_FIGS:
    save_figure_atomic(f, "fig2_correlation_da_replete", FIGSFOLDER)
    
# Deplete panel, using the deplete colours and the helper's default for `yaxis`.
f = make_correlation_plot_da(deplete_10, deplete_45, colors[2], colors[3])
if SAVE_FIGS:
    save_figure_atomic(f, "fig2_correlation_da_deplete", FIGSFOLDER)

In [None]:
# 2F. Model fits on AUC-by-trial data with fit-quality checks
import numpy as np
import pandas as pd
from scipy.optimize import curve_fit
from scipy.stats import pearsonr

def _linear_model(x, m, b):
    return m * x + b

def _exp_model(x, a, b, c):
    return a * np.exp(b * x) + c

def _sigmoid_model(x, L, k, x0, b):
    z = np.clip(-k * (x - x0), -60, 60)
    return L / (1 + np.exp(z)) + b

def _safe_pearson(y_true, y_pred):
    if np.allclose(np.std(y_true), 0) or np.allclose(np.std(y_pred), 0):
        return np.nan, np.nan
    return pearsonr(y_true, y_pred)

def _model_metrics(y_true, y_pred, n_params):
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    n = len(y_true)
    residuals = y_true - y_pred
    sse = np.nansum(residuals ** 2)
    sse = max(float(sse), 1e-12)
    rmse = np.sqrt(sse / n) if n > 0 else np.nan
    aic = n * np.log(sse / n) + 2 * n_params if n > 0 else np.nan
    bic = n * np.log(sse / n) + n_params * np.log(n) if n > 1 else np.nan
    aicc = aic + (2 * n_params * (n_params + 1)) / (n - n_params - 1) if n > (n_params + 1) else np.nan
    return rmse, aic, aicc, bic

def _fit_linear(x, y):
    try:
        m, b = np.polyfit(x, y, 1)
        yhat = _linear_model(x, m, b)
        r, p = _safe_pearson(y, yhat)
        rmse, aic, aicc, bic = _model_metrics(y, yhat, n_params=2)
        return {"r": r, "p": p, "rmse": rmse, "aic": aic, "aicc": aicc, "bic": bic, "params": np.array([m, b]), "pcov": None}
    except Exception:
        return {"r": np.nan, "p": np.nan, "rmse": np.nan, "aic": np.nan, "aicc": np.nan, "bic": np.nan, "params": np.array([np.nan, np.nan]), "pcov": None}

def _fit_curve(model_name, model_func, x, y):
    try:
        if model_name == "exponential":
            p0 = [np.ptp(y) if np.ptp(y) != 0 else 1.0, 0.05, np.min(y)]
            bounds = ([-np.inf, -np.inf, -np.inf], [np.inf, np.inf, np.inf])
            n_params = 3
        elif model_name == "sigmoidal":
            p0 = [np.ptp(y) if np.ptp(y) != 0 else 1.0, 0.2, np.median(x), np.min(y)]
            bounds = ([-np.inf, -np.inf, np.min(x), -np.inf], [np.inf, np.inf, np.max(x), np.inf])
            n_params = 4
        else:
            raise ValueError(f"Unknown model: {model_name}")

        params, pcov = curve_fit(model_func, x, y, p0=p0, bounds=bounds, maxfev=30000)
        yhat = model_func(x, *params)
        r, p = _safe_pearson(y, yhat)
        rmse, aic, aicc, bic = _model_metrics(y, yhat, n_params=n_params)
        return {"r": r, "p": p, "rmse": rmse, "aic": aic, "aicc": aicc, "bic": bic, "params": params, "pcov": pcov}
    except Exception:
        n_params = 3 if model_name == "exponential" else 4
        return {"r": np.nan, "p": np.nan, "rmse": np.nan, "aic": np.nan, "aicc": np.nan, "bic": np.nan, "params": np.full(n_params, np.nan), "pcov": None}

def _sigmoid_quality_checks(x, params, pcov):
    if params is None or len(params) != 4 or np.any(~np.isfinite(params)):
        return False, "fit_failed", {"x0_interior": False, "k_plausible": False, "ci_finite": False, "asymptotes_covered": False}

    L, k, x0, b = params
    x_min, x_max = float(np.min(x)), float(np.max(x))
    x_range = max(x_max - x_min, 1.0)
    edge_margin = 0.15 * x_range
    x0_interior = (x_min + edge_margin) <= x0 <= (x_max - edge_margin)
    k_plausible = np.isfinite(k) and (0.02 <= abs(k) <= 2.5)

    ci_finite = False
    if pcov is not None:
        diag = np.diag(pcov)
        if np.all(np.isfinite(diag)) and np.all(diag >= 0):
            se = np.sqrt(diag)
            ci_finite = np.all(np.isfinite(se)) and np.all(se > 0)

    amplitude = abs(L)
    if amplitude < 1e-8:
        asymptotes_covered = False
    else:
        lower_asym = min(b, b + L)
        upper_asym = max(b, b + L)
        y_start = _sigmoid_model(np.array([x_min]), L, k, x0, b)[0]
        y_end = _sigmoid_model(np.array([x_max]), L, k, x0, b)[0]

        tol = 0.2
        start_low = abs(y_start - lower_asym) / amplitude <= tol
        start_high = abs(y_start - upper_asym) / amplitude <= tol
        end_low = abs(y_end - lower_asym) / amplitude <= tol
        end_high = abs(y_end - upper_asym) / amplitude <= tol

        start_side = "low" if start_low else ("high" if start_high else None)
        end_side = "low" if end_low else ("high" if end_high else None)
        asymptotes_covered = (start_side is not None) and (end_side is not None) and (start_side != end_side)

    checks = {
        "x0_interior": bool(x0_interior),
        "k_plausible": bool(k_plausible),
        "ci_finite": bool(ci_finite),
        "asymptotes_covered": bool(asymptotes_covered),
    }

    failed = [name for name, ok in checks.items() if not ok]
    is_valid = len(failed) == 0
    reasons = "ok" if is_valid else ";".join(failed)
    return is_valid, reasons, checks

# Per-trial mean AUC arrays keyed by the label shown in the result tables.
fit_inputs = {
    "Replete + 10NaCl": replete_10,
    "Replete + 45NaCl": replete_45,
    "Deplete + 10NaCl": deplete_10,
    "Deplete + 45NaCl": deplete_45,
}

# Accumulators: one row per (condition, model) and one sigmoid-diagnostics row per condition.
rows = []
sigmoid_diagnostics = []

for condition_label, y in fit_inputs.items():
    y = np.asarray(y, dtype=float)
    # Predictor is the 1-based position within the array, not a trial label.
    # NOTE(review): this assumes the entries are consecutive trials in order — confirm.
    x = np.arange(1, len(y) + 1, dtype=float)

    lin = _fit_linear(x, y)
    rows.append({
        "Condition": condition_label,
        "Model": "Linear",
        "r": lin["r"],
        "p": lin["p"],
        "RMSE": lin["rmse"],
        "AICc": lin["aicc"],
        "BIC": lin["bic"],
        # Validity columns only apply to the sigmoid; left as NaN / "" for other models.
        "Sigmoid_valid": np.nan,
        "Sigmoid_flags": "",
    })

    exp_fit = _fit_curve("exponential", _exp_model, x, y)
    rows.append({
        "Condition": condition_label,
        "Model": "Exponential",
        "r": exp_fit["r"],
        "p": exp_fit["p"],
        "RMSE": exp_fit["rmse"],
        "AICc": exp_fit["aicc"],
        "BIC": exp_fit["bic"],
        "Sigmoid_valid": np.nan,
        "Sigmoid_flags": "",
    })

    sig_fit = _fit_curve("sigmoidal", _sigmoid_model, x, y)
    # The sigmoid fit is additionally screened for plausibility.
    is_valid, reasons, checks = _sigmoid_quality_checks(x, sig_fit["params"], sig_fit["pcov"])
    rows.append({
        "Condition": condition_label,
        "Model": "Sigmoidal",
        "r": sig_fit["r"],
        "p": sig_fit["p"],
        "RMSE": sig_fit["rmse"],
        "AICc": sig_fit["aicc"],
        "BIC": sig_fit["bic"],
        "Sigmoid_valid": is_valid,
        "Sigmoid_flags": reasons,
    })

    # Fall back to NaNs if the parameter vector does not have the expected four entries.
    L, k, x0, b = sig_fit["params"] if len(sig_fit["params"]) == 4 else [np.nan, np.nan, np.nan, np.nan]
    sigmoid_diagnostics.append({
        "Condition": condition_label,
        "L": L,
        "k": k,
        "x0": x0,
        "b": b,
        "Sigmoid_valid": is_valid,
        "Sigmoid_flags": reasons,
        "x0_interior": checks["x0_interior"],
        "k_plausible": checks["k_plausible"],
        "ci_finite": checks["ci_finite"],
        "asymptotes_covered": checks["asymptotes_covered"],
    })

fit_results = pd.DataFrame(rows)
# Rank models within each condition by AICc (lowest AICc gets rank 1).
# This ranking uses AICc alone and does not take Sigmoid_valid into account.
fit_results["AICc_rank"] = fit_results.groupby("Condition")["AICc"].rank(method="dense")
fit_results["Best_by_AICc"] = fit_results["AICc_rank"] == 1

# Rounding is applied to the stored table itself, so later cells that read
# `fit_results` see the rounded values (very small p-values become 0.0).
for col in ["r", "p", "RMSE", "AICc", "BIC", "AICc_rank"]:
    fit_results[col] = fit_results[col].round(4)

sigmoid_diagnostics_table = pd.DataFrame(sigmoid_diagnostics)
sigmoid_diagnostics_table[["L", "k", "x0", "b"]] = sigmoid_diagnostics_table[["L", "k", "x0", "b"]].round(4)

print("Model fit results (AUC by trial) with quality checks:")
display(fit_results)

In [None]:
# 2G. Sigmoidal fit parameters and validity checks by condition
sigmoid_params_table = sigmoid_diagnostics_table[[
    "Condition", "L", "k", "x0", "b",
    "Sigmoid_valid", "Sigmoid_flags",
    "x0_interior", "k_plausible", "ci_finite", "asymptotes_covered",
]].copy()

print("Sigmoidal fit parameters (AUC by trial) with validity checks:")
display(sigmoid_params_table)

In [None]:
# 2H. Best valid model per condition (compact summary)
summary_rows = []

for condition_label, group in fit_results.groupby("Condition"):
    group = group.copy()
    group["Is_valid_for_selection"] = np.where(
        group["Model"] == "Sigmoidal",
        group["Sigmoid_valid"].fillna(False),
        True,
    )

    valid_group = group[group["Is_valid_for_selection"] == True]
    if len(valid_group) > 0 and valid_group["AICc"].notna().any():
        best = valid_group.sort_values("AICc", ascending=True).iloc[0]
        selection_reason = "best_valid_by_AICc"
    else:
        best = group.sort_values("AICc", ascending=True).iloc[0]
        selection_reason = "fallback_no_valid_model"

    summary_rows.append({
        "Condition": condition_label,
        "Selected_Model": best["Model"],
        "Selection_Reason": selection_reason,
        "r": best["r"],
        "p": best["p"],
        "RMSE": best["RMSE"],
        "AICc": best["AICc"],
        "BIC": best["BIC"],
        "Sigmoid_valid": best["Sigmoid_valid"],
        "Sigmoid_flags": best["Sigmoid_flags"],
    })

best_model_summary = pd.DataFrame(summary_rows)
for col in ["r", "p", "RMSE", "AICc", "BIC"]:
    best_model_summary[col] = best_model_summary[col].round(4)

print("Best valid model per condition (AICc-based):")
display(best_model_summary)

In [None]:
# Exploratory check: show the columns available in x_array
# (earlier cells rely on `condition`, `infusiontype`, `trial` and `auc_snips`).
# NOTE(review): looks like a leftover inspection cell — confirm it is still wanted in the final notebook.
x_array.columns

## Organization

This notebook generates **Figure 2 (Photometry Analysis)** only. Each figure has its own dedicated notebook:

- **figure_1_paper.ipynb**: Movement analysis
- **figure_2_paper.ipynb**: Photometry (dopamine) analysis (current)
- **figure_3_paper.ipynb**: Neural-behavioral correlation
- **figure_4_paper.ipynb**: Transition analysis
- **figure_5_paper.ipynb**: Cluster analysis

All notebooks share common settings and functions from:
- `src/figure_config.py` — Colors, paths, parameters
- `src/figure_plotting.py` — Data extraction and plotting functions

This keeps each figure focused and manageable, while reducing code duplication.

In [None]:
# Figure-saving status
# ───────────────────────────────────────────────────────────────────────
# SAVE_FIGS comes from figure_config.py and gates every save call above.
# Figures are saved in two formats:
#   - PDF for publication (vector format, smaller file size)
#   - PNG for presentations (raster format, high DPI for screen)
#
# File naming convention:
#   fig{number}_{description}.{pdf|png}
#
# Example:
#   fig2_heatmap_dopamine_replete.pdf
#   fig2_snips_dopamine_replete.png
# ───────────────────────────────────────────────────────────────────────

saving_state = "ENABLED" if SAVE_FIGS else "DISABLED"
saving_hint = (
    "All generated figures will be saved in both PDF and PNG formats."
    if SAVE_FIGS
    else "To save figures, set SAVE_FIGS = True in src/figure_config.py"
)

print(f"\nFigure saving is currently: {saving_state}")
print(f"Figure output folder: {FIGSFOLDER}")
print(saving_hint)