# Paper Figures: Figure 1 - Movement Analysis

This notebook generates publication-ready figures for the Bazzino & Roitman sodium appetite manuscript using data assembled by `src/assemble_all_data.py`.

**Figure 1: Behavioral Analysis** — Movement analysis showing heatmaps and summary plots for replete and deplete conditions across sodium concentrations.

In [None]:
%load_ext autoreload
%autoreload 2

from pathlib import Path
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import dill

# Add src to path for importing local modules
sys.path.insert(0, str(Path("../src").resolve()))

from figure_config import (
    configure_matplotlib, COLORS, HEATMAP_CMAP, 
    DATAFOLDER, RESULTSFOLDER, FIGSFOLDER,
    HEATMAP_VLIM_BEHAV, YLIMS_BEHAV,
    BEHAV_SMOOTH_WINDOW, SAVE_FIGS
)
from figure_plotting import (
    smooth_array, get_heatmap_data, get_mean_snips, get_auc,
    init_heatmap_figure, init_snips_figure, make_heatmap,
    plot_snips, plot_auc_summary, save_figure, print_auc_stats,
    scale_vlim_to_data, calculate_ylims
)

# Apply the shared plotting style, then alias the shared palette under a short local name
configure_matplotlib()
colors = COLORS

# Movement data are not z-scored, so a one-directional (sequential) colormap is used:
# white at the low end, running to colors[3] (the palette's dark red) at the high end.
cmap_endpoints = ["white", colors[3]]
custom_cmap = mcolors.LinearSegmentedColormap.from_list("white_to_darkred", cmap_endpoints)

## Load Assembled Data

Load the complete dataset from the pickle file generated by the assembly script.

In [None]:
# Load the assembled dataset and unpack the pieces used by the rest of the notebook.
assembled_data_path = DATAFOLDER / "assembled_data.pickle"

# NOTE(security): dill (like pickle) can execute arbitrary code while loading.
# Only open assembled_data.pickle files produced by this project's own pipeline.
with open(assembled_data_path, "rb") as f:
    data = dill.load(f)

# Extract main components
x_array = data["x_array"]
snips_photo = data["snips_photo"]

# Handle new data structure with separate movement and angular velocity
if "snips_movement" in data and "snips_angvel" in data:
    snips_movement = data["snips_movement"]
    snips_angvel = data["snips_angvel"]
    print("Loaded new structure with separate snips_movement and snips_angvel")
else:
    # Fallback for old structure
    snips_movement = data.get("snips_behav", None)
    snips_angvel = None
    print("WARNING: Old data structure detected (snips_behav). Consider regenerating assembled data.")

# Fail fast with a clear message instead of an opaque AttributeError on `.shape`
# below when neither the new nor the old movement key is present.
if snips_movement is None:
    raise KeyError(
        f"No movement snips found in {assembled_data_path}: expected 'snips_movement' "
        "(with 'snips_angvel') or the legacy 'snips_behav' key. "
        "Regenerate the file with src/assemble_all_data.py."
    )

fits_df = data["fits_df"]
# Optional entries: fall back to neutral defaults when absent
z_dep45 = data.get("z_dep45", None)
params = data.get("params", {})
metadata = data.get("metadata", {})

print(f"Loaded assembled data from {assembled_data_path}")
print("\nData structure:")
print(f"  - x_array shape: {x_array.shape}")
print(f"  - snips_photo shape: {snips_photo.shape}")
print(f"  - snips_movement shape: {snips_movement.shape}")
if snips_angvel is not None:
    print(f"  - snips_angvel shape: {snips_angvel.shape}")
print(f"  - x_array columns: {x_array.columns.tolist()}")
print(f"  - Number of trials: {len(x_array)}")

# Store metadata flags for use during plotting
photo_already_pzscored = metadata.get('photo_zscored', False)

## Figure 1: Behavioral Analysis — Movement Metrics

Analysis of behavioral responses (movement) showing heatmaps, snip time series, and summary AUC metrics across replete and deplete sodium conditions, separated by infusion type (10NaCl vs 45NaCl).

In [None]:
# Snips were smoothed when the data were assembled, so they are aliased here
# rather than smoothed a second time.
snips_movement_smooth = snips_movement
if snips_angvel is not None:
    snips_angvel_smooth = snips_angvel

# --- Movement: heatmap colour limits -------------------------------------
# Limits come from the 5th/95th percentiles of the data themselves, so the
# range is asymmetric rather than forced to be symmetric around zero.
vmin, vmax = np.nanpercentile(snips_movement_smooth, [5, 95])
vlim = (vmin, vmax)
print(f"Movement - Calculated asymmetric vlims: vmin={vmin:.4f}, vmax={vmax:.4f}")

# --- Movement: y-limits for the time-series panels -----------------------
rule = "=" * 60
print("\n" + rule)
print("CALCULATING Y-LIMITS FOR TIME SERIES PLOTS")
print(rule)

# Mean snips for every condition/infusion combination feed the limit calculation
rep_10, rep_45 = get_mean_snips(snips_movement_smooth, x_array, "replete")
dep_10, dep_45 = get_mean_snips(snips_movement_smooth, x_array, "deplete")

# 5% padding is requested from calculate_ylims
calc_ylims = calculate_ylims([rep_10, rep_45, dep_10, dep_45], pad_percentage=5)

movement_min = np.nanmin(snips_movement_smooth)
movement_max = np.nanmax(snips_movement_smooth)
print(f"  Data range (min, max): ({movement_min:.4f}, {movement_max:.4f})")
print(f"  Calculated y-limits: {calc_ylims}")
print(rule + "\n")

# Swap in fixed values here if the calculated limits are not wanted
ylims = calc_ylims

# --- Angular velocity (only when present in the assembled data) ----------
if snips_angvel is not None:
    vmin_angvel, vmax_angvel = np.nanpercentile(snips_angvel_smooth, [5, 95])
    vlim_angvel = (vmin_angvel, vmax_angvel)
    print(f"Angular Velocity - Calculated asymmetric vlims: vmin={vmin_angvel:.4f}, vmax={vmax_angvel:.4f}")

    rep_10_av, rep_45_av = get_mean_snips(snips_angvel_smooth, x_array, "replete")
    dep_10_av, dep_45_av = get_mean_snips(snips_angvel_smooth, x_array, "deplete")
    calc_ylims_angvel = calculate_ylims(
        [rep_10_av, rep_45_av, dep_10_av, dep_45_av],
        pad_percentage=5,
    )
    print(f"Angular Velocity y-limits: {calc_ylims_angvel}")
    ylims_angvel = calc_ylims_angvel

In [None]:
### 1A. Heatmaps — Replete Condition
# Two side-by-side movement heatmaps (10NaCl and 45NaCl infusions) for the
# replete condition, sharing one colour scale (`vlim`) and one colourbar axis.

f, ax1, ax2, cbar_ax = init_heatmap_figure()

# Replete + 10NaCl
heatmap_data_rep_10 = get_heatmap_data(snips_movement_smooth, x_array, "replete", "10NaCl")
replete_10_auc = get_auc(heatmap_data_rep_10)
make_heatmap(heatmap_data_rep_10, ax1, vlim, inf_bar=True, cmap=custom_cmap)

# Replete + 45NaCl (this panel also receives the colourbar axis)
heatmap_data_rep_45 = get_heatmap_data(snips_movement_smooth, x_array, "replete", "45NaCl")
replete_45_auc = get_auc(heatmap_data_rep_45)
make_heatmap(heatmap_data_rep_45, ax2, vlim, cbar_ax=cbar_ax, cmap=custom_cmap)

ax1.set_title("Replete + 10NaCl", fontsize=10)
ax2.set_title("Replete + 45NaCl", fontsize=10)

# Apply the layout BEFORE saving so the file on disk matches what is displayed
f.tight_layout()

if SAVE_FIGS:
    save_figure(f, "fig1_heatmap_movement_replete", FIGSFOLDER)

plt.show()

In [None]:
### 1B. Heatmaps — Deplete Condition
# Two side-by-side movement heatmaps (10NaCl and 45NaCl infusions) for the
# deplete condition, sharing one colour scale (`vlim`) and one colourbar axis.

f, ax1, ax2, cbar_ax = init_heatmap_figure()

# Deplete + 10NaCl
heatmap_data_dep_10 = get_heatmap_data(snips_movement_smooth, x_array, "deplete", "10NaCl")
deplete_10_auc = get_auc(heatmap_data_dep_10)
make_heatmap(heatmap_data_dep_10, ax1, vlim, inf_bar=True, cmap=custom_cmap)

# Deplete + 45NaCl (this panel also receives the colourbar axis)
heatmap_data_dep_45 = get_heatmap_data(snips_movement_smooth, x_array, "deplete", "45NaCl")
deplete_45_auc = get_auc(heatmap_data_dep_45)
make_heatmap(heatmap_data_dep_45, ax2, vlim, cbar_ax=cbar_ax, cmap=custom_cmap)

ax1.set_title("Deplete + 10NaCl", fontsize=10)
ax2.set_title("Deplete + 45NaCl", fontsize=10)

# Apply the layout BEFORE saving so the file on disk matches what is displayed
f.tight_layout()

if SAVE_FIGS:
    save_figure(f, "fig1_heatmap_movement_deplete", FIGSFOLDER)

plt.show()

In [None]:
### 1C. Time Series Snips — Replete Condition
# Movement time series for both infusion types in the replete condition,
# drawn with the shared y-limits computed in the parameters cell.

# Get animal-averaged snips for replete
snips_rep_10, snips_rep_45 = get_mean_snips(snips_movement_smooth, x_array, "replete")

f, ax = init_snips_figure()
plot_snips(snips_rep_10, snips_rep_45, ax, colors[0], colors[1], ylims)
ax.set_ylabel("Movement", fontsize=10)
ax.set_title("Replete Condition", fontsize=10)

# Apply the layout BEFORE saving so the file on disk matches what is displayed
f.tight_layout()

if SAVE_FIGS:
    save_figure(f, "fig1_snips_movement_replete", FIGSFOLDER)

plt.show()

In [None]:
### 1D. Time Series Snips — Deplete Condition
# Movement time series for both infusion types in the deplete condition,
# drawn with the shared y-limits; this panel also requests the scale bar.

# Get animal-averaged snips for deplete
snips_dep_10, snips_dep_45 = get_mean_snips(snips_movement_smooth, x_array, "deplete")

f, ax = init_snips_figure()
plot_snips(snips_dep_10, snips_dep_45, ax, colors[2], colors[3], ylims, scalebar=True)
ax.set_ylabel("Movement", fontsize=10)
ax.set_title("Deplete Condition", fontsize=10)

# Apply the layout BEFORE saving so the file on disk matches what is displayed
f.tight_layout()

if SAVE_FIGS:
    save_figure(f, "fig1_snips_movement_deplete", FIGSFOLDER)

plt.show()

In [None]:
### 1E. AUC Summary — Bar Plot with Individual Data Points
# Uses the animal-averaged snips from cells 1C/1D (snips_rep_*, snips_dep_*).

# Organize AUCs by condition: [10NaCl, 45NaCl] within each condition
replete_aucs = [get_auc(snips_rep_10), get_auc(snips_rep_45)]
deplete_aucs = [get_auc(snips_dep_10), get_auc(snips_dep_45)]
aucs = [replete_aucs, deplete_aucs]

f, ax = plot_auc_summary(aucs, colors, ylabel="Movement (AUC)")
f.suptitle("Movement — AUC Summary", fontsize=10)

# Apply the layout BEFORE saving so the file on disk matches what is displayed
f.tight_layout()

if SAVE_FIGS:
    save_figure(f, "fig1_auc_movement_summary", FIGSFOLDER)

plt.show()

# Print summary statistics; labels and arrays are kept in the same order
auc_labels = [
    f"Replete + 10NaCl (n={len(snips_rep_10)})",
    f"Replete + 45NaCl (n={len(snips_rep_45)})",
    f"Deplete + 10NaCl (n={len(snips_dep_10)})",
    f"Deplete + 45NaCl (n={len(snips_dep_45)})"
]
auc_arrays = [*replete_aucs, *deplete_aucs]
print_auc_stats(auc_arrays, auc_labels, title="Figure 1 — Movement Summary Statistics")

## Organization

This notebook generates **Figure 1 (Movement Analysis)** only. Each figure has its own dedicated notebook:

- **figure_1_paper.ipynb**: Movement analysis (current)
- **figure_2_paper.ipynb**: Photometry (dopamine) analysis  
- **figure_3_paper.ipynb**: Neural-behavioral correlation
- **figure_4_paper.ipynb**: Transition analysis
- **figure_5_paper.ipynb**: Cluster analysis

All notebooks share common settings and functions from:
- `src/figure_config.py` — Colors, paths, parameters
- `src/figure_plotting.py` — Data extraction and plotting functions

This keeps each figure focused and manageable, while reducing code duplication.

In [None]:
# Figure-saving status report
# ───────────────────────────────────────────────────────────────────────
# SAVE_FIGS comes from figure_config.py. Saved figures are written in two
# formats:
#   - PDF for publication (vector format, smaller file size)
#   - PNG for presentations (raster format, high DPI for screen)
#
# File naming convention:
#   fig{number}_{description}.{pdf|png}
#
# For instance:
#   fig1_heatmap_movement_replete.pdf
#   fig1_snips_movement_replete.png
# ───────────────────────────────────────────────────────────────────────

save_state = "ENABLED" if SAVE_FIGS else "DISABLED"
print(f"\nFigure saving is currently: {save_state}")
print(f"Figure output folder: {FIGSFOLDER}")

if not SAVE_FIGS:
    print("To save figures, set SAVE_FIGS = True in src/figure_config.py")
else:
    print("All generated figures will be saved in both PDF and PNG formats.")

## Figure 1 (Bonus): Movement vs Angular Velocity Comparison

Comparison of head rotation (angular velocity) with body movement to explore the relationship between these metrics across conditions.

In [None]:
# Report whether angular-velocity snips were found in the assembled data;
# the comparison cells below only run when they were.
if snips_angvel is None:
    print("✗ Angular velocity data not available in assembled data")
    print("  Make sure assembled_data.pickle was generated with the updated assemble_all_data.py")
else:
    n_movement_trials = snips_movement_smooth.shape[0]
    n_angvel_trials = snips_angvel_smooth.shape[0]
    print("✓ Angular velocity data available for comparison")
    print(f"  snips_angvel shape: {snips_angvel_smooth.shape}")
    print(f"  Comparing {n_movement_trials} movement trials with {n_angvel_trials} angular velocity trials")


In [None]:
if snips_angvel is not None:
    ### Comparison 1: Replete Condition — Movement vs Angular Velocity Heatmaps
    # 2x2 grid: rows are infusion types (10NaCl, 45NaCl); columns are metrics
    # (movement on the left, angular velocity on the right). Each metric uses
    # its own colour limits (vlim / vlim_angvel) and colormap.
    
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    
    # Replete + 10NaCl — Movement
    heatmap_data_rep_10_mov = get_heatmap_data(snips_movement_smooth, x_array, "replete", "10NaCl")
    make_heatmap(heatmap_data_rep_10_mov, axes[0, 0], vlim, inf_bar=True, cmap=custom_cmap)
    axes[0, 0].set_title("Replete + 10NaCl — Movement", fontsize=11, fontweight='bold')
    
    # Replete + 10NaCl — Angular Velocity
    heatmap_data_rep_10_av = get_heatmap_data(snips_angvel_smooth, x_array, "replete", "10NaCl")
    make_heatmap(heatmap_data_rep_10_av, axes[0, 1], vlim_angvel, inf_bar=True, cmap='RdYlBu_r')
    axes[0, 1].set_title("Replete + 10NaCl — Angular Velocity", fontsize=11, fontweight='bold')
    
    # Replete + 45NaCl — Movement
    heatmap_data_rep_45_mov = get_heatmap_data(snips_movement_smooth, x_array, "replete", "45NaCl")
    make_heatmap(heatmap_data_rep_45_mov, axes[1, 0], vlim, cmap=custom_cmap)
    axes[1, 0].set_title("Replete + 45NaCl — Movement", fontsize=11, fontweight='bold')
    
    # Replete + 45NaCl — Angular Velocity
    heatmap_data_rep_45_av = get_heatmap_data(snips_angvel_smooth, x_array, "replete", "45NaCl")
    make_heatmap(heatmap_data_rep_45_av, axes[1, 1], vlim_angvel, cmap='RdYlBu_r')
    axes[1, 1].set_title("Replete + 45NaCl — Angular Velocity", fontsize=11, fontweight='bold')
    
    fig.suptitle("Replete Condition: Movement vs Angular Velocity", fontsize=13, fontweight='bold', y=1.00)
    
    # Apply the layout BEFORE saving so the file on disk matches what is displayed
    fig.tight_layout()
    
    if SAVE_FIGS:
        save_figure(fig, "fig1_compare_replete_movement_vs_angvel", FIGSFOLDER)
    
    plt.show()


In [None]:
if snips_angvel is not None:
    ### Comparison 2: Deplete Condition — Movement vs Angular Velocity Heatmaps
    # 2x2 grid: rows are infusion types (10NaCl, 45NaCl); columns are metrics
    # (movement on the left, angular velocity on the right). Each metric uses
    # its own colour limits (vlim / vlim_angvel) and colormap.
    
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    
    # Deplete + 10NaCl — Movement
    heatmap_data_dep_10_mov = get_heatmap_data(snips_movement_smooth, x_array, "deplete", "10NaCl")
    make_heatmap(heatmap_data_dep_10_mov, axes[0, 0], vlim, inf_bar=True, cmap=custom_cmap)
    axes[0, 0].set_title("Deplete + 10NaCl — Movement", fontsize=11, fontweight='bold')
    
    # Deplete + 10NaCl — Angular Velocity
    heatmap_data_dep_10_av = get_heatmap_data(snips_angvel_smooth, x_array, "deplete", "10NaCl")
    make_heatmap(heatmap_data_dep_10_av, axes[0, 1], vlim_angvel, inf_bar=True, cmap='RdYlBu_r')
    axes[0, 1].set_title("Deplete + 10NaCl — Angular Velocity", fontsize=11, fontweight='bold')
    
    # Deplete + 45NaCl — Movement
    heatmap_data_dep_45_mov = get_heatmap_data(snips_movement_smooth, x_array, "deplete", "45NaCl")
    make_heatmap(heatmap_data_dep_45_mov, axes[1, 0], vlim, cmap=custom_cmap)
    axes[1, 0].set_title("Deplete + 45NaCl — Movement", fontsize=11, fontweight='bold')
    
    # Deplete + 45NaCl — Angular Velocity
    heatmap_data_dep_45_av = get_heatmap_data(snips_angvel_smooth, x_array, "deplete", "45NaCl")
    make_heatmap(heatmap_data_dep_45_av, axes[1, 1], vlim_angvel, cmap='RdYlBu_r')
    axes[1, 1].set_title("Deplete + 45NaCl — Angular Velocity", fontsize=11, fontweight='bold')
    
    fig.suptitle("Deplete Condition: Movement vs Angular Velocity", fontsize=13, fontweight='bold', y=1.00)
    
    # Apply the layout BEFORE saving so the file on disk matches what is displayed
    fig.tight_layout()
    
    if SAVE_FIGS:
        save_figure(fig, "fig1_compare_deplete_movement_vs_angvel", FIGSFOLDER)
    
    plt.show()


In [None]:
### Comparison 3: Time Series Comparison

def _plot_trace_with_band(ax, snips, color, label, title, ylabel, xlabel=None):
    """Draw one trace with a shaded +/- SD band on ``ax``.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    snips : array-like, 1-D or 2-D.
        A 2-D input is treated as rows x time bins: the line is the NaN-aware
        mean across rows and the band is the per-time-bin SD across rows.
        NOTE(review): rows are presumed to be animals, because this notebook
        reports ``len(get_mean_snips(...)[i])`` as n — confirm against
        figure_plotting.get_mean_snips.
        A 1-D input is drawn as-is, with a constant-width band equal to the SD
        of the trace over time (the previous behaviour of this cell).
    color, label : passed to the line; ``color`` is also used for the band.
    title, ylabel : axis title and y-axis label.
    xlabel : optional x-axis label; omitted when None.
    """
    values = np.asarray(snips, dtype=float)
    if values.ndim > 1:
        # Collapse rows so exactly one line is drawn. Plotting the 2-D array
        # directly would draw one line per column, and fill_between needs 1-D input.
        trace = np.nanmean(values, axis=0)
        spread = np.nanstd(values, axis=0)
    else:
        trace = values
        spread = np.std(values)

    bins = np.arange(trace.shape[0])
    ax.plot(bins, trace, linewidth=2, label=label, color=color)
    ax.fill_between(bins, trace - spread, trace + spread, alpha=0.2, color=color)
    ax.set_title(title, fontsize=11, fontweight='bold')
    ax.set_ylabel(ylabel, fontsize=10)
    if xlabel is not None:
        ax.set_xlabel(xlabel, fontsize=10)
    ax.grid(True, alpha=0.3)
    ax.legend()


if snips_angvel is not None:
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # Recompute the movement snips here so this cell depends only on the
    # setup cells, not on cells 1C/1D having been run first.
    mov_rep_10, mov_rep_45 = get_mean_snips(snips_movement_smooth, x_array, "replete")
    mov_dep_10, mov_dep_45 = get_mean_snips(snips_movement_smooth, x_array, "deplete")
    snips_rep_10_av, snips_rep_45_av = get_mean_snips(snips_angvel_smooth, x_array, "replete")
    snips_dep_10_av, snips_dep_45_av = get_mean_snips(snips_angvel_smooth, x_array, "deplete")

    # Top row: Replete + 10NaCl
    _plot_trace_with_band(
        axes[0, 0], mov_rep_10, colors[0], '10NaCl',
        'Replete + 10NaCl — Movement', 'Movement',
    )
    _plot_trace_with_band(
        axes[0, 1], snips_rep_10_av, colors[0], '10NaCl',
        'Replete + 10NaCl — Angular Velocity', 'Angular Velocity (deg/frame)',
    )

    # Bottom row: Deplete + 45NaCl
    _plot_trace_with_band(
        axes[1, 0], mov_dep_45, colors[3], '45NaCl',
        'Deplete + 45NaCl — Movement', 'Movement', xlabel='Time bins',
    )
    _plot_trace_with_band(
        axes[1, 1], snips_dep_45_av, colors[3], '45NaCl',
        'Deplete + 45NaCl — Angular Velocity', 'Angular Velocity (deg/frame)',
        xlabel='Time bins',
    )

    fig.suptitle("Time Series Comparison: Movement vs Angular Velocity", fontsize=13, fontweight='bold')

    # Apply the layout BEFORE saving so the file on disk matches what is displayed
    fig.tight_layout()

    if SAVE_FIGS:
        save_figure(fig, "fig1_timeseries_movement_vs_angvel", FIGSFOLDER)

    plt.show()
