In [1]:
import ipywidgets as widgets
from IPython.display import display, clear_output
from skimage import io
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import json
from datetime import datetime

# =============================================================================
# LOAD DATA
# =============================================================================
print("Loading data...")
df = pd.read_csv('pvn_cell_analysis.csv')
masks = io.imread('/home/lq53/mir_repos/histology_for_minji/masks_cellpose.tif')
dapi_rgb = io.imread('PVN, MD, RE image61_R 91_Merged_ch00_SV.tif')
th_rgb = io.imread('PVN, MD, RE image61_R 91_Merged_ch01_SV.tif')
oxy_rgb = io.imread('PVN, MD, RE image61_R 91_Merged_ch02_SV.tif')
cfos_rgb = io.imread('PVN, MD, RE image61_R 91_Merged_ch03_SV.tif')
roi_img = io.imread('ch00-1_new.png')

dapi = dapi_rgb[..., 2]
th = th_rgb[..., 1]
oxy = oxy_rgb[..., 0]
cfos = cfos_rgb[..., 0]

# Downsampled for display
scale_display = 4
dapi_small = dapi[::scale_display, ::scale_display]
masks_small = masks[::scale_display, ::scale_display]

# =============================================================================
# EXTRACT ROI REGIONS
# =============================================================================
print("Extracting ROI regions...")
r, g, b = roi_img[..., 0], roi_img[..., 1], roi_img[..., 2]
high, low = 200, 100

mask_red = (r > high) & (g < low) & (b < low)
mask_green = (g > high) & (r < low) & (b < low)
mask_blue = (b > high) & (r < low) & (g < low)

print(f"  Red region:   {mask_red.sum():,} pixels")
print(f"  Green region: {mask_green.sum():,} pixels")
print(f"  Blue region:  {mask_blue.sum():,} pixels")

# =============================================================================
# ASSIGN CELLS TO REGIONS
# =============================================================================
print("Assigning cells to regions...")
scale_y = roi_img.shape[0] / dapi.shape[0]
scale_x = roi_img.shape[1] / dapi.shape[1]

def assign_region(row):
    y_roi = int(row['y'] * scale_y)
    x_roi = int(row['x'] * scale_x)
    if 0 <= y_roi < roi_img.shape[0] and 0 <= x_roi < roi_img.shape[1]:
        if mask_red[y_roi, x_roi]: return 'PVN'      # Red = PVN
        if mask_green[y_roi, x_roi]: return 'MD'     # Green = MD
        if mask_blue[y_roi, x_roi]: return 'RE'      # Blue = RE
    return 'outside'

df['region'] = df.apply(assign_region, axis=1)
print("Cells per region:")
print(df['region'].value_counts())

# =============================================================================
# COLORS AND STATE
# =============================================================================
MARKER_COLORS = {'TH': '#00FF00', 'Oxy': '#FF00FF', 'cFos': '#00FFFF'}
REGION_COLORS = {'PVN': '#FF4444', 'MD': '#44FF44', 'RE': '#4444FF', 'outside': '#888888', 'All': '#FFFFFF'}
REGION_ORDER = ['PVN', 'MD', 'RE']

current_state = {}

# =============================================================================
# HELPER FUNCTIONS
# =============================================================================
def get_column(marker, use_max):
    suffix = '_max' if use_max else '_mean'
    return f'{marker}{suffix}'

def show_threshold_examples(marker, thresh, n_examples, use_max, region_filter):
    """Show example cells near threshold boundary"""
    col = get_column(marker, use_max)
    img_dict = {'TH': th, 'Oxy': oxy, 'cFos': cfos}
    img = img_dict[marker]
    color = MARKER_COLORS[marker]
    
    # Filter by region
    if region_filter == 'All':
        df_filt = df[df['region'] != 'outside']
    else:
        df_filt = df[df['region'] == region_filter]
    
    if len(df_filt) == 0:
        print(f"No cells in {region_filter} region")
        return None
    
    n_per_row = n_examples // 2
    
    above = df_filt[df_filt[col] > thresh].nsmallest(n_per_row, col)
    below = df_filt[df_filt[col] <= thresh].nlargest(n_per_row, col)
    
    # Handle case where we don't have enough examples
    n_above = len(above)
    n_below = len(below)
    
    if n_above == 0 and n_below == 0:
        print(f"No cells near threshold in {region_filter}")
        return None
    
    fig, axes = plt.subplots(2, n_per_row, figsize=(2.5 * n_per_row, 5))
    metric_label = "max" if use_max else "mean"
    region_label = region_filter if region_filter != 'All' else 'All regions'
    fig.suptitle(f'{marker} threshold = {thresh:.0f} ({metric_label}) - {region_label}', fontsize=14)
    
    pad = 50
    
    # Positive examples
    for i in range(n_per_row):
        ax = axes[0, i]
        if i < n_above:
            cell = above.iloc[i]
            cy, cx = int(cell['y']), int(cell['x'])
            cell_id = int(cell['cell_id'])
            
            y1, y2 = max(0, cy - pad), min(img.shape[0], cy + pad)
            x1, x2 = max(0, cx - pad), min(img.shape[1], cx + pad)
            
            crop_img = img[y1:y2, x1:x2]
            crop_mask = masks[y1:y2, x1:x2]
            
            ax.imshow(crop_img, cmap='gray', vmin=0, vmax=150)
            cell_mask = (crop_mask == cell_id)
            if cell_mask.any():
                ax.contour(cell_mask, colors=[color], linewidths=2, levels=[0.5])
            
            ax.set_title(f"âœ“ {cell[col]:.0f}\n({cell['region']})", color='green', fontsize=9, fontweight='bold')
        else:
            ax.axis('off')
            ax.set_visible(False)
        ax.axis('off')
    axes[0, 0].set_ylabel('POSITIVE', fontsize=12, color='green', fontweight='bold')
    
    # Negative examples
    for i in range(n_per_row):
        ax = axes[1, i]
        if i < n_below:
            cell = below.iloc[i]
            cy, cx = int(cell['y']), int(cell['x'])
            cell_id = int(cell['cell_id'])
            
            y1, y2 = max(0, cy - pad), min(img.shape[0], cy + pad)
            x1, x2 = max(0, cx - pad), min(img.shape[1], cx + pad)
            
            crop_img = img[y1:y2, x1:x2]
            crop_mask = masks[y1:y2, x1:x2]
            
            ax.imshow(crop_img, cmap='gray', vmin=0, vmax=150)
            cell_mask = (crop_mask == cell_id)
            if cell_mask.any():
                ax.contour(cell_mask, colors=['gray'], linewidths=1.5, linestyles='--', levels=[0.5])
            
            ax.set_title(f"âœ— {cell[col]:.0f}\n({cell['region']})", color='gray', fontsize=9)
        else:
            ax.axis('off')
            ax.set_visible(False)
        ax.axis('off')
    axes[1, 0].set_ylabel('NEGATIVE', fontsize=12, color='gray')
    
    plt.tight_layout()
    return fig

# =============================================================================
# MAIN UPDATE FUNCTION
# =============================================================================
def update_display(th_thresh, oxy_thresh, cfos_thresh, show_marker, n_examples, 
                   th_use_max, oxy_use_max, cfos_use_max, region_filter):
    clear_output(wait=True)
    
    # Store current state
    current_state['th_thresh'] = th_thresh
    current_state['oxy_thresh'] = oxy_thresh
    current_state['cfos_thresh'] = cfos_thresh
    current_state['th_use_max'] = th_use_max
    current_state['oxy_use_max'] = oxy_use_max
    current_state['cfos_use_max'] = cfos_use_max
    
    # Get appropriate columns
    th_col = get_column('TH', th_use_max)
    oxy_col = get_column('Oxy', oxy_use_max)
    cfos_col = get_column('cFos', cfos_use_max)
    
    # Exclude 'outside' cells for analysis
    df_analysis = df[df['region'] != 'outside'].copy()
    
    # Apply thresholds
    df_analysis['TH_pos'] = df_analysis[th_col] > th_thresh
    df_analysis['Oxy_pos'] = df_analysis[oxy_col] > oxy_thresh
    df_analysis['cFos_pos'] = df_analysis[cfos_col] > cfos_thresh
    df_analysis['TH_cFos'] = df_analysis['TH_pos'] & df_analysis['cFos_pos']
    df_analysis['Oxy_cFos'] = df_analysis['Oxy_pos'] & df_analysis['cFos_pos']
    
    # =========================================================================
    # FIGURE 1: Histograms by region (3 markers x 3 regions + All)
    # =========================================================================
    fig1, axes1 = plt.subplots(3, 4, figsize=(18, 12))
    fig1.suptitle('Intensity Distributions by Region', fontsize=16, fontweight='bold')
    
    markers = ['TH', 'Oxy', 'cFos']
    cols = [th_col, oxy_col, cfos_col]
    threshs = [th_thresh, oxy_thresh, cfos_thresh]
    use_maxs = [th_use_max, oxy_use_max, cfos_use_max]
    
    for row, (marker, col, thresh, use_max) in enumerate(zip(markers, cols, threshs, use_maxs)):
        color = MARKER_COLORS[marker]
        metric = "max" if use_max else "mean"
        
        # Per-region histograms
        for col_idx, region in enumerate(REGION_ORDER):
            ax = axes1[row, col_idx]
            df_region = df_analysis[df_analysis['region'] == region]
            
            if len(df_region) > 0:
                ax.hist(df_region[col], bins=50, alpha=0.7, color=color, edgecolor='black')
                ax.axvline(thresh, color='white', lw=3)
                ax.axvline(thresh, color='black', lw=1.5, linestyle='--')
                
                n_pos = (df_region[col] > thresh).sum()
                pct = n_pos / len(df_region) * 100
                ax.set_title(f'{region}: {n_pos}/{len(df_region)} ({pct:.1f}%)', 
                           fontsize=11, color=REGION_COLORS[region])
            else:
                ax.set_title(f'{region}: 0 cells', fontsize=11, color='gray')
            
            ax.set_facecolor('#1a1a1a')
            ax.set_xlim(0, 255 if use_max else 150)
            
            if col_idx == 0:
                ax.set_ylabel(f'{marker} ({metric})', fontsize=12, fontweight='bold')
        
        # All regions combined
        ax = axes1[row, 3]
        ax.hist(df_analysis[col], bins=50, alpha=0.7, color=color, edgecolor='black')
        ax.axvline(thresh, color='white', lw=3)
        ax.axvline(thresh, color='black', lw=1.5, linestyle='--')
        
        n_pos = (df_analysis[col] > thresh).sum()
        pct = n_pos / len(df_analysis) * 100
        ax.set_title(f'ALL: {n_pos}/{len(df_analysis)} ({pct:.1f}%)', fontsize=11, fontweight='bold')
        ax.set_facecolor('#1a1a1a')
        ax.set_xlim(0, 255 if use_max else 150)
    
    plt.tight_layout()
    plt.show()
    
    # =========================================================================
    # FIGURE 2: Spatial maps - cells colored by positivity within each region
    # =========================================================================
    fig2, axes2 = plt.subplots(2, 3, figsize=(18, 12))
    fig2.suptitle('Spatial Distribution of Positive Cells', fontsize=16, fontweight='bold')
    
    # Row 1: Each marker separately
    for col_idx, (marker, col, thresh) in enumerate(zip(markers, cols, threshs)):
        ax = axes2[0, col_idx]
        ax.imshow(dapi_small, cmap='gray', vmin=0, vmax=100)
        
        pos_ids = set(df_analysis[df_analysis[col] > thresh]['cell_id'].astype(int))
        overlay = np.isin(masks_small, list(pos_ids))
        if overlay.any():
            ax.contour(overlay, colors=[MARKER_COLORS[marker]], linewidths=0.5, levels=[0.5])
        
        n_pos = len(pos_ids)
        ax.set_title(f'{marker}+ cells (n={n_pos})', color=MARKER_COLORS[marker], fontsize=14, fontweight='bold')
        ax.axis('off')
    
    # Row 2: Colocalization maps
    # TH+/cFos+
    ax = axes2[1, 0]
    ax.imshow(dapi_small, cmap='gray', vmin=0, vmax=100)
    coloc_ids = set(df_analysis[df_analysis['TH_cFos']]['cell_id'].astype(int))
    overlay = np.isin(masks_small, list(coloc_ids))
    if overlay.any():
        ax.contour(overlay, colors=['yellow'], linewidths=0.8, levels=[0.5])
    ax.set_title(f'TH+/cFos+ (n={len(coloc_ids)})', color='yellow', fontsize=14, fontweight='bold')
    ax.axis('off')
    
    # Oxy+/cFos+
    ax = axes2[1, 1]
    ax.imshow(dapi_small, cmap='gray', vmin=0, vmax=100)
    coloc_ids = set(df_analysis[df_analysis['Oxy_cFos']]['cell_id'].astype(int))
    overlay = np.isin(masks_small, list(coloc_ids))
    if overlay.any():
        ax.contour(overlay, colors=['orange'], linewidths=0.8, levels=[0.5])
    ax.set_title(f'Oxy+/cFos+ (n={len(coloc_ids)})', color='orange', fontsize=14, fontweight='bold')
    ax.axis('off')
    
    # Region overlay
    ax = axes2[1, 2]
    ax.imshow(dapi_small, cmap='gray', vmin=0, vmax=100)
    for region, color in [('PVN', 'red'), ('MD', 'lime'), ('RE', 'dodgerblue')]:
        region_ids = set(df_analysis[df_analysis['region'] == region]['cell_id'].astype(int))
        overlay = np.isin(masks_small, list(region_ids))
        if overlay.any():
            ax.contour(overlay, colors=[color], linewidths=0.3, levels=[0.5], alpha=0.7)
    ax.set_title('Regions: Red=PVN, Green=MD, Blue=RE', fontsize=12)
    ax.axis('off')
    
    plt.tight_layout()
    plt.show()
    
    # =========================================================================
    # FIGURE 3: Example cells near threshold
    # =========================================================================
    show_threshold_examples(show_marker, threshs[markers.index(show_marker)], 
                           n_examples, use_maxs[markers.index(show_marker)], region_filter)
    plt.show()
    
    # =========================================================================
    # SUMMARY TABLE
    # =========================================================================
    print("\n" + "="*80)
    print("SUMMARY BY REGION")
    print("="*80)
    
    # Header
    print(f"{'Region':<10} {'Total':>8} â”‚ {'TH+':>8} {'%':>6} â”‚ {'Oxy+':>8} {'%':>6} â”‚ {'cFos+':>8} {'%':>6} â”‚ {'TH/cFos':>8} {'Oxy/cFos':>8}")
    print("-"*80)
    
    for region in REGION_ORDER + ['ALL']:
        if region == 'ALL':
            df_r = df_analysis
        else:
            df_r = df_analysis[df_analysis['region'] == region]
        
        n_total = len(df_r)
        if n_total == 0:
            continue
            
        n_th = df_r['TH_pos'].sum()
        n_oxy = df_r['Oxy_pos'].sum()
        n_cfos = df_r['cFos_pos'].sum()
        n_th_cfos = df_r['TH_cFos'].sum()
        n_oxy_cfos = df_r['Oxy_cFos'].sum()
        
        pct_th = n_th / n_total * 100
        pct_oxy = n_oxy / n_total * 100
        pct_cfos = n_cfos / n_total * 100
        
        print(f"{region:<10} {n_total:>8} â”‚ {n_th:>8} {pct_th:>5.1f}% â”‚ {n_oxy:>8} {pct_oxy:>5.1f}% â”‚ {n_cfos:>8} {pct_cfos:>5.1f}% â”‚ {n_th_cfos:>8} {n_oxy_cfos:>8}")
    
    print("="*80)
    
    # Activation rates
    print("\nACTIVATION RATES (% of marker+ cells that are cFos+):")
    print("-"*60)
    for region in REGION_ORDER + ['ALL']:
        if region == 'ALL':
            df_r = df_analysis
        else:
            df_r = df_analysis[df_analysis['region'] == region]
        
        n_th = df_r['TH_pos'].sum()
        n_oxy = df_r['Oxy_pos'].sum()
        n_th_cfos = df_r['TH_cFos'].sum()
        n_oxy_cfos = df_r['Oxy_cFos'].sum()
        
        th_act = (n_th_cfos / n_th * 100) if n_th > 0 else 0
        oxy_act = (n_oxy_cfos / n_oxy * 100) if n_oxy > 0 else 0
        
        print(f"  {region:<8}: TH activation = {th_act:5.1f}% ({n_th_cfos}/{n_th}), Oxy activation = {oxy_act:5.1f}% ({n_oxy_cfos}/{n_oxy})")
    
    print("="*80)

# =============================================================================
# SAVE FUNCTION
# =============================================================================
def save_results(b):
    """Save thresholds and export final results"""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    
    th_col = get_column('TH', current_state['th_use_max'])
    oxy_col = get_column('Oxy', current_state['oxy_use_max'])
    cfos_col = get_column('cFos', current_state['cfos_use_max'])
    
    # Apply thresholds to all data
    df_export = df.copy()
    df_export['TH_positive'] = df_export[th_col] > current_state['th_thresh']
    df_export['Oxy_positive'] = df_export[oxy_col] > current_state['oxy_thresh']
    df_export['cFos_positive'] = df_export[cfos_col] > current_state['cfos_thresh']
    df_export['TH_cFos'] = df_export['TH_positive'] & df_export['cFos_positive']
    df_export['Oxy_cFos'] = df_export['Oxy_positive'] & df_export['cFos_positive']
    df_export['TH_Oxy'] = df_export['TH_positive'] & df_export['Oxy_positive']
    df_export['triple'] = df_export['TH_positive'] & df_export['Oxy_positive'] & df_export['cFos_positive']
    
    # Exclude outside cells for summary
    df_analysis = df_export[df_export['region'] != 'outside']
    
    # Build results dict
    results_by_region = {}
    for region in REGION_ORDER + ['ALL']:
        if region == 'ALL':
            df_r = df_analysis
        else:
            df_r = df_analysis[df_analysis['region'] == region]
        
        results_by_region[region] = {
            'total_cells': len(df_r),
            'TH_positive': int(df_r['TH_positive'].sum()),
            'Oxy_positive': int(df_r['Oxy_positive'].sum()),
            'cFos_positive': int(df_r['cFos_positive'].sum()),
            'TH_cFos': int(df_r['TH_cFos'].sum()),
            'Oxy_cFos': int(df_r['Oxy_cFos'].sum()),
        }
    
    thresholds = {
        'timestamp': timestamp,
        'TH': {'threshold': current_state['th_thresh'], 'use_max': current_state['th_use_max']},
        'Oxy': {'threshold': current_state['oxy_thresh'], 'use_max': current_state['oxy_use_max']},
        'cFos': {'threshold': current_state['cfos_thresh'], 'use_max': current_state['cfos_use_max']},
        'results_by_region': results_by_region
    }
    
    # Save JSON
    json_file = f'thresholds_with_regions_{timestamp}.json'
    with open(json_file, 'w') as f:
        json.dump(thresholds, f, indent=2)
    
    # Save CSV
    csv_file = f'pvn_analysis_with_regions_{timestamp}.csv'
    df_export.to_csv(csv_file, index=False)
    
    # Create summary figure
    fig, axes = plt.subplots(3, 4, figsize=(20, 15))
    fig.suptitle(f'Final Analysis by Region - {timestamp}', fontsize=16, fontweight='bold')
    
    markers = ['TH', 'Oxy', 'cFos']
    cols = [th_col, oxy_col, cfos_col]
    threshs = [current_state['th_thresh'], current_state['oxy_thresh'], current_state['cfos_thresh']]
    
    for row, (marker, col, thresh) in enumerate(zip(markers, cols, threshs)):
        color = MARKER_COLORS[marker]
        
        for col_idx, region in enumerate(REGION_ORDER):
            ax = axes[row, col_idx]
            df_region = df_analysis[df_analysis['region'] == region]
            
            if len(df_region) > 0:
                ax.hist(df_region[col], bins=50, alpha=0.7, color=color, edgecolor='black')
                ax.axvline(thresh, color='red', lw=2, ls='--')
                
                n_pos = (df_region[col] > thresh).sum()
                pct = n_pos / len(df_region) * 100
                ax.set_title(f'{region} {marker}: {n_pos} ({pct:.1f}%)', fontsize=10)
            ax.set_facecolor('#f0f0f0')
        
        # All combined
        ax = axes[row, 3]
        ax.hist(df_analysis[col], bins=50, alpha=0.7, color=color, edgecolor='black')
        ax.axvline(thresh, color='red', lw=2, ls='--')
        n_pos = (df_analysis[col] > thresh).sum()
        pct = n_pos / len(df_analysis) * 100
        ax.set_title(f'ALL {marker}: {n_pos} ({pct:.1f}%)', fontsize=10, fontweight='bold')
        ax.set_facecolor('#f0f0f0')
    
    plt.tight_layout()
    
    fig_file = f'pvn_analysis_with_regions_{timestamp}.png'
    plt.savefig(fig_file, dpi=300, bbox_inches='tight')
    plt.show()
    
    print("\n" + "="*60)
    print("âœ“ SAVED FILES:")
    print(f"  Thresholds: {json_file}")
    print(f"  Data:       {csv_file}")
    print(f"  Figure:     {fig_file}")
    print("="*60)

# =============================================================================
# WIDGETS
# =============================================================================
th_slider = widgets.IntSlider(value=80, min=30, max=200, step=5, description='TH:')
oxy_slider = widgets.IntSlider(value=70, min=30, max=200, step=5, description='Oxy:')
cfos_slider = widgets.IntSlider(value=70, min=30, max=200, step=5, description='cFos:')

th_max_toggle = widgets.Checkbox(value=False, description='use max')
oxy_max_toggle = widgets.Checkbox(value=False, description='use max')
cfos_max_toggle = widgets.Checkbox(value=True, description='use max')

marker_dropdown = widgets.Dropdown(options=['TH', 'Oxy', 'cFos'], value='cFos', description='Examples:')
region_dropdown = widgets.Dropdown(options=['All', 'PVN', 'MD', 'RE'], value='All', description='Region:')
n_examples_slider = widgets.IntSlider(value=10, min=4, max=20, step=2, description='# shown:')

save_button = widgets.Button(description='ðŸ’¾ Save Results', button_style='success')
save_button.on_click(save_results)

out = widgets.interactive_output(update_display, {
    'th_thresh': th_slider, 'oxy_thresh': oxy_slider, 'cfos_thresh': cfos_slider,
    'show_marker': marker_dropdown, 'n_examples': n_examples_slider,
    'th_use_max': th_max_toggle, 'oxy_use_max': oxy_max_toggle, 'cfos_use_max': cfos_max_toggle,
    'region_filter': region_dropdown
})

# Layout
controls = widgets.VBox([
    widgets.HTML("<h3>ðŸ”¬ Threshold Tuning with Regions</h3>"),
    widgets.HTML("<b>Thresholds:</b>"),
    widgets.HBox([th_slider, th_max_toggle]),
    widgets.HBox([oxy_slider, oxy_max_toggle]),
    widgets.HBox([cfos_slider, cfos_max_toggle]),
    widgets.HTML("<hr>"),
    widgets.HTML("<b>Example cells:</b>"),
    marker_dropdown, 
    region_dropdown,
    n_examples_slider,
    widgets.HTML("<hr>"),
    save_button
])

print("\nâœ“ Ready! Run the cell below to display the interactive widget.")
print("="*60)

Loading data...
Extracting ROI regions...
  Red region:   12,012 pixels
  Green region: 6,791 pixels
  Blue region:  15,763 pixels
Assigning cells to regions...
Cells per region:
region
outside    19862
RE          2295
PVN         1632
MD           923
Name: count, dtype: int64



âœ“ Ready! Run the cell below to display the interactive widget.


In [2]:
display(widgets.HBox([controls, out]))

HBox(children=(VBox(children=(HTML(value='<h3>ðŸ”¬ Threshold Tuning with Regions</h3>'), HTML(value='<b>Thresholdâ€¦