# Baysor Segmentation Diagnostics

This notebook provides comprehensive diagnostics for Baysor segmentation outputs, including:
- Summary statistics of all segmentation runs
- Cell-level QC metrics (transcript counts, gene diversity)
- Spatial visualization of segmentation results
- Unassigned transcript analysis
- Comparison across multiple segmentation parameters

## 1. Setup and Configuration

In [None]:
import pandas as pd
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import os
from pathlib import Path
import warnings

# Silence warnings for the rest of the session. No category or message is
# given, so this filter matches every warning, not only plotting noise.
warnings.filterwarnings('ignore')

print("Libraries loaded successfully!")

In [None]:
# Root folder that is searched recursively for Baysor outputs
output_root = "/Users/christoffer/Downloads/new_spinal_cord_data_CG"

# Report whether the configured root is present on disk; this only prints,
# it does not stop execution when the folder is missing
if not os.path.isdir(output_root):
    print(f"✗ Output root directory not found: {output_root}")
    print("Please update the output_root variable above")
else:
    print(f"✓ Output root directory exists: {output_root}")

## 2. Find and Scan Baysor Output Directories

In [None]:
def find_segmentation_outputs(root_dir):
    """
    Collect every directory under ``root_dir`` that holds a Baysor result.

    A directory qualifies when it directly contains a file named
    ``segmentation.csv``. The tree is traversed with ``os.walk``, so
    ``root_dir`` itself is included if it qualifies.

    Returns the matching directory paths as a sorted list (empty when
    nothing matches).
    """
    marker = 'segmentation.csv'
    matches = [
        folder
        for folder, _subdirs, filenames in os.walk(root_dir)
        if marker in filenames
    ]
    matches.sort()
    return matches

# Discover every Baysor run below the configured root and list them, numbered from 1
outputs = find_segmentation_outputs(output_root)
print(f"Found {len(outputs)} segmentation output(s)\n")

for position, run_dir in enumerate(outputs, start=1):
    print(f"{position}. {run_dir}")

## 3. Load and Summarize All Outputs

In [None]:
def load_segmentation_data(output_dir):
    """
    Read whichever Baysor result tables are present in ``output_dir``.

    Looks for three files and parses each one found with ``pd.read_csv``:
      - ``segmentation.csv``            -> key ``'segmentation'``
      - ``segmentation_cell_stats.csv`` -> key ``'cell_stats'``
      - ``segmentation_counts.tsv``     -> key ``'counts'`` (tab-separated)

    Returns a dict of DataFrames; a key is simply absent when its file
    does not exist.
    """
    # (result key, file name, extra read_csv keyword arguments)
    expected_files = (
        ('segmentation', 'segmentation.csv', {}),
        ('cell_stats', 'segmentation_cell_stats.csv', {}),
        ('counts', 'segmentation_counts.tsv', {'sep': '\t'}),
    )

    loaded = {}
    for key, filename, read_kwargs in expected_files:
        file_path = os.path.join(output_dir, filename)
        if not os.path.isfile(file_path):
            continue
        loaded[key] = pd.read_csv(file_path, **read_kwargs)
    return loaded

# Build a {directory: loaded tables} mapping covering every discovered run
all_data = {run_dir: load_segmentation_data(run_dir) for run_dir in outputs}

print(f"Loaded data for {len(all_data)} segmentation run(s)")

In [None]:
def _summarize_assignments(seg):
    """
    Compute transcript-to-cell assignment metrics for one segmentation table.

    Returns a dict with ``n_transcripts`` plus ``n_cells``,
    ``n_assigned_transcripts``, ``n_unassigned_transcripts`` and
    ``assigned_fraction`` (all NaN when no cell column is present).
    """
    metrics = {'n_transcripts': len(seg)}

    # Find cell column (various naming conventions, matched case-insensitively)
    cell_aliases = ('cell', 'cell_id', 'cellid', 'cell_index')
    cell_cols = [col for col in seg.columns if col.lower() in cell_aliases]
    if not cell_cols:
        metrics.update(dict.fromkeys(
            ('n_cells', 'n_assigned_transcripts',
             'n_unassigned_transcripts', 'assigned_fraction'),
            np.nan,
        ))
        return metrics

    # Values that cannot be parsed as numbers become NaN and are therefore
    # treated as unassigned, as are IDs <= 0.
    # NOTE(review): if the cell column holds non-numeric labels, every
    # transcript is reported as unassigned -- confirm the ID format written
    # by the Baysor version in use.
    cell_ids = pd.to_numeric(seg[cell_cols[0]], errors='coerce')
    assigned = cell_ids[cell_ids > 0]
    n_total = len(cell_ids)
    n_assigned = len(assigned)

    metrics['n_cells'] = assigned.nunique()
    metrics['n_assigned_transcripts'] = n_assigned
    # Everything that is not assigned is unassigned (missing/unparseable IDs
    # included), so assigned + unassigned always equals n_transcripts.
    metrics['n_unassigned_transcripts'] = n_total - n_assigned
    metrics['assigned_fraction'] = n_assigned / n_total if n_total > 0 else 0
    return metrics


def _summarize_cell_stats(stats):
    """
    Compute per-cell transcript-count metrics from a cell-stats table.

    Always returns ``cell_stats_n_cells``; the mean/median/min/max keys are
    added only when a transcript-count column is found.
    """
    metrics = {'cell_stats_n_cells': len(stats)}

    # Only columns that count transcripts qualify. A gene-count column is not
    # a candidate, so its position in the file cannot hide a valid column.
    count_aliases = ('n_transcripts', 'n_molecules', 'n_counts')
    count_cols = [col for col in stats.columns if col.lower() in count_aliases]
    if count_cols:
        per_cell = stats[count_cols[0]]
        metrics['mean_transcripts_per_cell'] = per_cell.mean()
        metrics['median_transcripts_per_cell'] = per_cell.median()
        metrics['min_transcripts_per_cell'] = per_cell.min()
        metrics['max_transcripts_per_cell'] = per_cell.max()
    return metrics


def summarize_outputs(all_data):
    """
    Create a summary dataframe of all Baysor outputs with key QC metrics.

    Parameters
    ----------
    all_data : dict
        Maps each output directory path to the dict of tables loaded for it
        (optional keys ``'segmentation'`` and ``'cell_stats'``).

    Returns
    -------
    pandas.DataFrame
        One row per output directory. Metric columns appear only if at least
        one run provided the corresponding table; the presence and size of
        ``segmentation_counts.tsv`` is read from disk.
    """
    rows = []

    for output_dir, data in all_data.items():
        summary = {
            'output_dir': output_dir,
            'output_name': os.path.basename(output_dir),
        }

        if 'segmentation' in data:
            summary.update(_summarize_assignments(data['segmentation']))

        if 'cell_stats' in data:
            summary.update(_summarize_cell_stats(data['cell_stats']))

        # Check for counts file (only its existence and size are reported)
        counts_path = os.path.join(output_dir, 'segmentation_counts.tsv')
        summary['has_counts_matrix'] = os.path.isfile(counts_path)
        if summary['has_counts_matrix']:
            file_size_mb = os.path.getsize(counts_path) / (1024**2)
            summary['counts_matrix_size_mb'] = round(file_size_mb, 2)

        rows.append(summary)

    return pd.DataFrame(rows)

# Build the per-run QC table and render it beneath a banner line
summary_df = summarize_outputs(all_data)
print("\n=== SEGMENTATION OUTPUT SUMMARY ===")
display(summary_df)

## 4. Detailed Analysis - Select an Output to Inspect

In [None]:
# Select which output to analyze in detail (index 0 for first output)
selected_idx = 0

# Start from an empty selection. The later cells all test
# `'segmentation' in selected_data` / `'cell_stats' in selected_data`, so an
# empty dict makes them report "not found" instead of raising NameError when
# nothing could be selected.
selected_output = None
selected_data = {}

if len(outputs) == 0:
    print("No segmentation outputs found!")
elif not -len(outputs) <= selected_idx < len(outputs):
    # Negative indices keep their usual from-the-end meaning
    print(f"selected_idx={selected_idx} is out of range; found {len(outputs)} segmentation output(s)")
else:
    selected_output = outputs[selected_idx]
    selected_data = all_data[selected_output]
    print(f"Selected output: {selected_output}")

In [None]:
# Preview the raw transcript table of the selected run
if 'segmentation' not in selected_data:
    print("No segmentation data found")
else:
    seg_df = selected_data['segmentation']
    table_shape = seg_df.shape
    column_names = list(seg_df.columns)
    print(f"Segmentation data shape: {table_shape}")
    print(f"\nColumns: {column_names}")
    print("\nFirst 5 rows:")
    display(seg_df.head())

## 5. Cell-Level Statistics

In [None]:
if 'segmentation' in selected_data:
    seg_df = selected_data['segmentation']

    # Locate the cell-assignment column (case-insensitive match on known names)
    cell_aliases = ['cell', 'cell_id', 'cellid', 'cell_index']
    cell_cols = [name for name in seg_df.columns if name.lower() in cell_aliases]

    if cell_cols:
        cell_col = cell_cols[0]
        cell_ids = pd.to_numeric(seg_df[cell_col], errors='coerce')

        # IDs > 0 count as assigned; everything else (including values that
        # could not be parsed as numbers) counts as unassigned
        assigned_mask = cell_ids > 0
        assigned_ids = cell_ids[assigned_mask]

        n_total = len(seg_df)
        n_assigned = assigned_mask.sum()
        n_unassigned = (~assigned_mask).sum()

        print("Cell Assignment Summary:")
        print(f"  Total transcripts: {n_total:,}")
        print(f"  Assigned transcripts: {n_assigned:,}")
        print(f"  Unassigned transcripts: {n_unassigned:,}")
        print(f"  % Assigned: {100 * n_assigned / n_total:.1f}%")
        print(f"  Number of cells: {assigned_ids.nunique():,}")

        # One row per cell with the number of transcripts assigned to it
        per_cell_stats = (
            seg_df[assigned_mask]
            .groupby(cell_col)
            .size()
            .reset_index(name='n_transcripts')
        )
        tx_per_cell = per_cell_stats['n_transcripts']

        print("\n  Transcripts per cell:")
        print(f"    Mean: {tx_per_cell.mean():.1f}")
        print(f"    Median: {tx_per_cell.median():.1f}")
        print(f"    Min: {tx_per_cell.min()}")
        print(f"    Max: {tx_per_cell.max()}")
        print(f"    Std: {tx_per_cell.std():.1f}")

## 6. Spatial Distribution of Cells

In [None]:
if 'segmentation' in selected_data:
    seg_df = selected_data['segmentation']

    # Known spellings of the coordinate and cell columns, matched case-insensitively
    x_aliases = ['x', 'global_x', 'x_coord', 'x_position']
    y_aliases = ['y', 'global_y', 'y_coord', 'y_position']
    cell_aliases = ['cell', 'cell_id', 'cellid', 'cell_index']
    x_cols = [name for name in seg_df.columns if name.lower() in x_aliases]
    y_cols = [name for name in seg_df.columns if name.lower() in y_aliases]
    cell_cols = [name for name in seg_df.columns if name.lower() in cell_aliases]

    if not (x_cols and y_cols and cell_cols):
        print("Spatial columns (x, y, cell) not found in segmentation file")
    else:
        x_col, y_col, cell_col = x_cols[0], y_cols[0], cell_cols[0]

        # Work on a copy so the flag column is not added to the loaded table
        plot_df = seg_df[[x_col, y_col, cell_col]].copy()
        numeric_ids = pd.to_numeric(plot_df[cell_col], errors='coerce')
        plot_df['cell_assigned'] = numeric_ids > 0

        # One point per transcript, coloured by assignment status
        scatter_options = dict(
            x=x_col,
            y=y_col,
            color='cell_assigned',
            color_discrete_map={True: 'blue', False: 'red'},
            opacity=0.6,
            labels={'cell_assigned': 'Assigned'},
            title='Spatial Distribution of Assigned vs Unassigned Transcripts',
            height=600,
        )
        fig = px.scatter(plot_df, **scatter_options)
        fig.update_layout(hovermode='closest')
        fig.show()

## 7. Transcript Count Distribution per Cell

In [None]:
if 'segmentation' in selected_data:
    seg_df = selected_data['segmentation']
    cell_aliases = ['cell', 'cell_id', 'cellid', 'cell_index']
    cell_cols = [name for name in seg_df.columns if name.lower() in cell_aliases]

    if cell_cols:
        cell_col = cell_cols[0]
        cell_ids = pd.to_numeric(seg_df[cell_col], errors='coerce')
        assigned_mask = cell_ids > 0

        # Number of transcripts in each cell (rows without a positive ID are excluded)
        per_cell_counts = seg_df[assigned_mask].groupby(cell_col).size()

        # Histogram of those per-cell counts
        histogram = go.Histogram(
            x=per_cell_counts.values,
            nbinsx=50,
            name='Transcript count',
            marker_color='rgba(55, 83, 109, 0.7)',
        )
        layout_options = dict(
            title='Distribution of Transcripts per Cell',
            xaxis_title='Transcripts per Cell',
            yaxis_title='Number of Cells',
            height=400,
            showlegend=False,
        )

        fig = go.Figure()
        fig.add_trace(histogram)
        fig.update_layout(**layout_options)
        fig.show()

## 8. Gene Distribution Analysis

In [None]:
if 'segmentation' in selected_data:
    seg_df = selected_data['segmentation']

    # Find gene column (case-insensitive match on known names)
    gene_cols = [col for col in seg_df.columns if col.lower() in ['gene', 'gene_name', 'gene_id']]

    if gene_cols:
        gene_col = gene_cols[0]

        # Get top genes: the 20 most frequent values, in descending order
        gene_counts = seg_df[gene_col].value_counts().head(20)

        fig = go.Figure()
        fig.add_trace(go.Bar(
            y=gene_counts.index,
            x=gene_counts.values,
            orientation='h',
            marker_color='rgba(55, 83, 109, 0.7)'
        ))

        fig.update_layout(
            title='Top 20 Most Abundant Genes',
            xaxis_title='Transcript Count',
            yaxis_title='Gene',
            height=500,
            showlegend=False
        )
        fig.show()

        n_unique_genes = seg_df[gene_col].nunique()
        n_rows = len(seg_df)

        print("\nGene Statistics:")
        print(f"  Total unique genes: {n_unique_genes:,}")
        print(f"  Total transcripts: {n_rows:,}")
        # nunique() is 0 for an empty table or an all-missing gene column;
        # guard the division so the cell does not fail with ZeroDivisionError
        if n_unique_genes > 0:
            print(f"  Mean transcripts per gene: {n_rows / n_unique_genes:.1f}")
        else:
            print("  Mean transcripts per gene: n/a (no genes found)")
    else:
        print("Gene column not found")

## 9. Cell Statistics Summary (if available)

In [None]:
# Describe the per-cell statistics table of the selected run, when one was loaded
if 'cell_stats' not in selected_data:
    print("No cell statistics file found for this output")
else:
    cell_stats_df = selected_data['cell_stats']
    stats_columns = list(cell_stats_df.columns)
    print(f"Cell statistics shape: {cell_stats_df.shape}")
    print(f"\nColumns: {stats_columns}")
    print("\nStatistics summary:")
    display(cell_stats_df.describe())

## 10. Export Summary Report

In [None]:
# Save the summary dataframe to CSV inside the output root.
# The setup cell only prints a message when the root is missing, so check
# again here: writing into a non-existent directory would raise and abort
# the report below.
summary_path = os.path.join(output_root, 'baysor_outputs_summary.csv')
if os.path.isdir(output_root):
    summary_df.to_csv(summary_path, index=False)
    print(f"✓ Summary saved to: {summary_path}")
else:
    print(f"✗ Summary not saved; output root directory not found: {output_root}")

# Print summary (each metric is shown only if its column exists, i.e. at
# least one run provided the underlying data)
print("\n=== FINAL SUMMARY ===")
print(f"Total segmentation outputs found: {len(outputs)}")
print("\nKey metrics:")
if 'n_transcripts' in summary_df.columns and len(summary_df) > 0:
    print(f"  Total transcripts (all runs): {summary_df['n_transcripts'].sum():,}")
if 'n_cells' in summary_df.columns:
    print(f"  Total cells (all runs): {summary_df['n_cells'].sum():,.0f}")
if 'assigned_fraction' in summary_df.columns:
    print(f"  Average assignment rate: {summary_df['assigned_fraction'].mean():.1%}")