# Paper Figures: Figure 4B - Dopamine AUC Aligned to Sigmoidal Transitions

This notebook generates extended Figure 4, showing the dopamine (photometry) response aligned to the transition point at which each animal shifts cluster membership. It uses the sigmoidal transition points calculated in `src/assemble_all_data.py`.

**Figure 4B: Transition-Aligned Dopamine Response** — Photometry AUC heatmaps and time series for each rat, centered at their individual transition points.

In [None]:
%load_ext autoreload
%autoreload 2

from pathlib import Path
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import dill

# Add src to path for importing local modules
sys.path.insert(0, str(Path("../src").resolve()))

from figure_config import (
    configure_matplotlib, COLORS, HEATMAP_CMAP,
    DATAFOLDER, RESULTSFOLDER, FIGSFOLDER,
    SAVE_FIGS
)
from figure_plotting import (
    save_figure, scale_vlim_to_data, make_heatmap
)

# Apply the project-wide matplotlib settings, then alias the shared palette and
# heatmap colormap so every figure in this notebook uses the same styling.
configure_matplotlib()
colors, custom_cmap = COLORS, HEATMAP_CMAP

## Load Assembled Data
Load the complete dataset with transition points and trial-aligned indices.

In [None]:
assembled_data_path = DATAFOLDER / "assembled_data.pickle"

# NOTE: dill (like pickle) can execute arbitrary code while loading — only open
# files produced by this project's own pipeline.
with open(assembled_data_path, "rb") as handle:
    data = dill.load(handle)

# Unpack the components used downstream; "metadata" is optional in the pickle.
x_array, snips_photo, snips_behav, fits_df = (
    data[key] for key in ("x_array", "snips_photo", "snips_behav", "fits_df")
)
metadata = data.get("metadata", {})

print("\n".join([
    f"Loaded assembled data from {assembled_data_path}",
    "\nData structure:",
    f"  - x_array shape: {x_array.shape}",
    f"  - x_array has trial_aligned column: {'trial_aligned' in x_array.columns}",
    f"  - snips_photo shape: {snips_photo.shape}",
    "\nDeplete + 45NaCl subset:",
]))

# Deplete + 45NaCl trials, before and after dropping rows without a transition
# alignment (NaN trial_aligned).
subset_full = x_array.query("condition == 'deplete' & infusiontype == '45NaCl'")
subset_aligned = subset_full.dropna(subset=['trial_aligned'])

print("\n".join([
    f"  - Total trials: {len(subset_full)}",
    f"  - Trials with valid alignment: {len(subset_aligned)}",
    f"  - Number of unique animals: {subset_aligned['id'].nunique()}",
]))

## Prepare Realigned Data

In [None]:
# Deplete + 45NaCl trials from animals with a valid transition fit, ordered by
# animal, then trial.
#
# The index is deliberately NOT renumbered after filtering. The heatmap cell
# slices `snips_photo` with `animal_data.index.values`, so each row's index must
# equal its row position in x_array. Previously `.reset_index(drop=True)` ran
# *after* filtering, which renumbered the subset 0..n-1, so the heatmaps pulled
# the wrong photometry rows. Resetting x_array's index *before* filtering makes
# index == position even if x_array carries a non-default index.
# NOTE(review): assumes snips_photo rows are aligned 1:1 with x_array rows —
# confirm against src/assemble_all_data.py.
subset_aligned = (
    x_array
    .reset_index(drop=True)  # index == row position in x_array / snips_photo
    .query("condition == 'deplete' & infusiontype == '45NaCl'")
    .dropna(subset=['trial_aligned'])
    .sort_values(['id', 'trial'])
)

# Animals contributing at least one aligned trial
animals = sorted(subset_aligned['id'].unique())
print(f"Animals with both transitions and deplete+45NaCl trials: {animals}")
print(f"Number of animals: {len(animals)}")

# Trial counts per animal, computed once with a groupby rather than one query per animal
trial_counts = subset_aligned.groupby('id').size()
print("\nTrial counts per animal:")
for animal in animals:
    print(f"  {animal}: {trial_counts.get(animal, 0)} trials")

## Figure 4B: Transition-Aligned Heatmaps

In [None]:
# One transition-aligned heatmap per animal, laid out on a grid with n_cols columns.
n_animals = len(animals)
n_cols = 4
# Keep at least one row: plt.subplots(0, n_cols) raises when no animal has aligned trials.
n_rows = max(1, (n_animals + n_cols - 1) // n_cols)

# squeeze=False always returns a 2-D array, so ravel() is valid for any grid shape.
f, axes = plt.subplots(n_rows, n_cols, figsize=(12, 2.5 * n_rows), squeeze=False)
axes = axes.ravel()

for idx, animal in enumerate(animals):
    ax = axes[idx]
    # Rows ordered by trial relative to this animal's transition (negative = before).
    animal_data = subset_aligned.query("id == @animal").sort_values('trial_aligned')
    # Index values are used as row positions into snips_photo. This assumes the
    # preparation cell kept x_array's positional index — TODO confirm alignment.
    animal_indices = animal_data.index.values

    if len(animal_indices) == 0:
        print(f"Warning: {animal} has no valid trials")
        ax.set_axis_off()
        continue

    # Photometry snips for this animal, one row per trial, sorted by trial_aligned
    heatmap_data = snips_photo[animal_indices, :]

    # Per-animal colour limits (percentile=98; see figure_plotting.scale_vlim_to_data)
    vlim = scale_vlim_to_data(heatmap_data, percentile=98)
    im = ax.imshow(heatmap_data, aspect='auto', cmap=custom_cmap, vmin=vlim[0], vmax=vlim[1])

    # Mark the transition on the boundary just above the first row with trial_aligned >= 0.
    # Animals whose trials are all pre-transition get no line instead of an IndexError.
    post_rows = np.flatnonzero(animal_data['trial_aligned'].to_numpy() >= 0)
    if post_rows.size:
        transition_idx = post_rows[0]
        ax.axhline(transition_idx - 0.5, color='white', linestyle='--', linewidth=1.5, alpha=0.8)
        title_note = f'transition at row {transition_idx}'
    else:
        title_note = 'no post-transition trials'

    ax.set_xlabel('Time (bins)', fontsize=9)
    ax.set_ylabel('Trial', fontsize=9)
    ax.set_title(f'{animal}\n({title_note})', fontsize=10)
    ax.set_xticks([])

    cbar = f.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label('AUC', fontsize=8)

# Remove unused grid cells after the last animal
for idx in range(n_animals, len(axes)):
    axes[idx].remove()

f.suptitle('Photometry AUC Aligned to Transition Points\n(Deplete + 45NaCl, white line = transition)',
           fontsize=12, y=1.00)
plt.tight_layout()
if SAVE_FIGS:
    save_figure(f, "fig4b_transitions_aligned_heatmaps", FIGSFOLDER)
plt.show()

## Figure 4B-2: Summary Statistics - AUC Before/After Transition

In [None]:
# Per-animal mean photometry AUC on each side of the transition
# (trial_aligned < 0 is before, trial_aligned >= 0 is after), plus the trial
# count on each side. A side with no trials yields a NaN mean.
def _auc_split(trials):
    """Return AUC mean and trial count for one animal's pre- and post-transition trials."""
    pre = trials.loc[trials['trial_aligned'] < 0, 'auc_snips']
    post = trials.loc[trials['trial_aligned'] >= 0, 'auc_snips']
    return {
        'auc_before': pre.mean(),
        'auc_after': post.mean(),
        'n_before': len(pre),
        'n_after': len(post),
    }


auc_summary = [
    {'animal': animal, **_auc_split(subset_aligned[subset_aligned['id'] == animal])}
    for animal in animals
]
auc_df = pd.DataFrame(auc_summary)

print("\nPhotometry AUC Before/After Transition:")
print(auc_df.to_string(index=False))
# Group mean ± sample SD across animals
for label, column in (("\nMean AUC before:", "auc_before"), ("Mean AUC after: ", "auc_after")):
    values = auc_df[column]
    print(f"{label} {values.mean():.3f} ± {values.std():.3f}")

## Figure 4B-3: Dopamine Change Across Transition

In [None]:
# Two panels: per-animal mean AUC before vs. after the transition (left), and
# the after-minus-before change (right), with bars coloured by direction of change.
f, axes = plt.subplots(1, 2, figsize=(5, 3))
ax_pair, ax_delta = axes
x_pos = np.arange(len(auc_df))
width = 0.35
bar_kw = dict(alpha=0.7, edgecolor='k', linewidth=1.5)


def _finish_animal_axis(panel, ylabel, title):
    """Apply the shared labels and one-tick-per-animal x axis to a panel."""
    panel.set_xlabel('Animal', fontsize=10)
    panel.set_ylabel(ylabel, fontsize=10)
    panel.set_title(title, fontsize=11)
    panel.set_xticks(x_pos)
    panel.set_xticklabels(auc_df['animal'], rotation=45, ha='right', fontsize=9)


# Left: grouped bars, before (colors[0]) beside after (colors[1])
bars1 = ax_pair.bar(x_pos - width / 2, auc_df['auc_before'], width,
                    label='Before Transition', color=colors[0], **bar_kw)
bars2 = ax_pair.bar(x_pos + width / 2, auc_df['auc_after'], width,
                    label='After Transition', color=colors[1], **bar_kw)
_finish_animal_axis(ax_pair, 'Mean Photometry AUC', 'Photometry AUC Before/After Transition')
ax_pair.legend(fontsize=9)
sns.despine(ax=ax_pair)

# Right: change in AUC. The column is stored on auc_df because the export cell reads it.
auc_df['auc_change'] = auc_df['auc_after'] - auc_df['auc_before']
rise, fall = colors[1], colors[0]
colors_change = [rise if delta > 0 else fall for delta in auc_df['auc_change']]
ax_delta.bar(x_pos, auc_df['auc_change'], color=colors_change, **bar_kw)
ax_delta.axhline(0, color='k', linestyle='-', linewidth=0.5, alpha=0.5)
_finish_animal_axis(ax_delta, 'AUC Change (After - Before)', 'Dopamine Response Change at Transition')
sns.despine(ax=ax_delta)

plt.tight_layout()
if SAVE_FIGS:
    save_figure(f, "fig4b_auc_before_after_transition", FIGSFOLDER)
plt.show()

## Export Results

In [None]:
# Export the per-animal AUC summary.
# auc_change is (re)derived here, so the CSV and the printout no longer depend on
# the plotting cell having run first. If that cell already added the column,
# assign() overwrites it in place with the identical values.
auc_export = auc_df.assign(auc_change=auc_df['auc_after'] - auc_df['auc_before'])

summary_path = RESULTSFOLDER / "transition_aligned_auc_summary.csv"
summary_path.parent.mkdir(parents=True, exist_ok=True)  # fresh checkouts may lack the folder
auc_export.to_csv(summary_path, index=False)
print(f"Exported AUC summary to {summary_path}")

print("\nFigure 4B generation complete!")
print("Summary:")
print(f"  - {len(animals)} animals with valid transition fits")
print(f"  - {len(subset_aligned)} total trials used for alignment")
print(f"  - Mean AUC change: {auc_export['auc_change'].mean():.3f} ± {auc_export['auc_change'].std():.3f}")