## Setup and Imports


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import json
from pathlib import Path

from sgutils import (
    data_factory, SampleSet, size_filter, apply_filters,
    get_all_queries,
    plot_mixture_model, COLUMN_NAMES
)

# Set up plotting
# Global plotting defaults applied to every figure drawn in this notebook.
# 'seaborn-v0_8' is the style name used by matplotlib >= 3.6 (older versions call it 'seaborn').
plt.style.use('seaborn-v0_8')
sns.set_context('notebook')
plt.rcParams['figure.dpi'] = 150
# IPython magic (only valid inside Jupyter/IPython): render inline figures at retina resolution.
%config InlineBackend.figure_format = 'retina'

# Analysis parameters
# VAF threshold: forwarded to the VAF filter for the "filtered" plots and embedded in the
# model-result and figure file names (e.g. "vaf-0.2").
vaf_thres = 0.2
print(f"VAF threshold: {vaf_thres}")

# Get all lesion names
# Names returned by sgutils.get_all_queries(); each one is turned into a SampleSet in the
# data-loading cell (lesions with no data are skipped there).
all_lesions = get_all_queries()
print(f"Found {len(all_lesions)} lesions: {all_lesions}")


## Function Definitions


In [None]:
def create_sample_set(lesion_name):
    """Build the SampleSet for a single lesion.

    The "full" dataset is requested from ``data_factory`` on every call and the
    SampleSet is constructed from it for ``lesion_name``. Whether
    ``data_factory`` caches its result is not visible here.
    """
    return SampleSet(data_factory("full"), lesion_name)

def get_lesion_data(sample_set, apply_size_filter=True, apply_vaf_filter=False, vaf_thres=0.2):
    """Return a copy of a lesion's variant table with the requested filters applied.

    Args:
        sample_set: Object exposing a ``sample_df`` DataFrame (a SampleSet).
        apply_size_filter: If True, run the table through ``size_filter`` first.
        apply_vaf_filter: If True, then run it through ``apply_filters`` with
            ``vaf_thres``.
        vaf_thres: Threshold forwarded to ``apply_filters``.

    Returns:
        The filtered table. Filters only ever receive a copy, so
        ``sample_set.sample_df`` itself is never passed to them.
    """
    # Collect the enabled filters in the order they must run.
    pipeline = []
    if apply_size_filter:
        pipeline.append(size_filter)
    if apply_vaf_filter:
        pipeline.append(lambda frame: apply_filters(frame, vaf_thres=vaf_thres))

    result = sample_set.sample_df.copy()
    for step in pipeline:
        result = step(result)
    return result

def plot_vaf_histogram(data, title, ax, bins=50, alpha=0.7):
    """Draw a density-normalised VAF histogram on ``ax`` plus a summary box.

    Args:
        data: Variant rows; the VAF column is looked up via ``COLUMN_NAMES['vaf']``.
        title: Axes title.
        ax: Matplotlib Axes to draw on.
        bins: Number of histogram bins.
        alpha: Bar transparency.

    The summary box (upper right, axes coordinates) shows the variant count and
    the mean and median VAF.
    """
    values = data[COLUMN_NAMES['vaf']]
    ax.hist(values, bins=bins, alpha=alpha, density=True,
            edgecolor='black', linewidth=0.5)
    ax.set(xlabel='VAF', ylabel='Density', title=title)
    ax.grid(True, alpha=0.3)

    summary = '\n'.join([
        f'n={len(values)}',
        f'Mean: {values.mean():.3f}',
        f'Median: {values.median():.3f}',
    ])
    ax.text(0.7, 0.8, summary, transform=ax.transAxes, verticalalignment='top',
            bbox={'boxstyle': 'round', 'facecolor': 'white', 'alpha': 0.8})

def plot_coverage_histogram(data, title, ax, bins=50, alpha=0.7):
    """Draw a count histogram of total read depth on ``ax`` plus a summary box.

    Args:
        data: Variant rows; depth is read from ``COLUMN_NAMES['total_depth']``.
        title: Axes title.
        ax: Matplotlib Axes to draw on.
        bins: Number of histogram bins.
        alpha: Bar transparency.

    The summary box (upper right, axes coordinates) shows the variant count and
    the mean and median depth, suffixed with "x".
    """
    depths = data[COLUMN_NAMES['total_depth']]
    ax.hist(depths, bins=bins, alpha=alpha, edgecolor='black', linewidth=0.5)
    ax.set(xlabel='Total Depth', ylabel='Count', title=title)
    ax.grid(True, alpha=0.3)

    summary = '\n'.join([
        f'n={len(depths)}',
        f'Mean: {depths.mean():.1f}x',
        f'Median: {depths.median():.1f}x',
    ])
    ax.text(0.7, 0.8, summary, transform=ax.transAxes, verticalalignment='top',
            bbox={'boxstyle': 'round', 'facecolor': 'white', 'alpha': 0.8})

def plot_vaf_with_model(data, model_params, title, ax):
    """Draw a sample's VAF data with its truncated binomial mixture fit overlaid.

    Args:
        data: Variant rows for one sample. The alt-depth and total-depth
            columns (named via ``COLUMN_NAMES``) are handed to ``plot_mixture_model``.
        model_params: Per-sample model-result mapping. Keys read here:
            ``truncated_model``, ``truncated_variants_per_component``,
            ``is_clonal_truncated`` and ``dominant_vaf_truncated``.
        title: Base title; clonality and dominant VAF are appended on a second line.
        ax: Matplotlib Axes to draw on.
    """
    columns = COLUMN_NAMES
    alt_reads = data[columns['alt_depth']].to_numpy()
    depths = data[columns['total_depth']].to_numpy()

    # Shallow copy so attaching the component counts does not modify the caller's stored model.
    fitted = model_params['truncated_model'].copy()
    fitted['variants_per_component'] = model_params['truncated_variants_per_component']

    plot_mixture_model(alt_reads, depths, fitted, ax=ax)

    # Title is set after plotting, overriding any title plot_mixture_model may have set.
    label = "Non-clonal"
    if model_params['is_clonal_truncated']:
        label = "Clonal"
    dominant = model_params['dominant_vaf_truncated']
    ax.set_title(f"{title}\n{label}, Dominant VAF: {dominant:.3f}")

def load_model_results(vaf_thres, patients=('FAP01', 'FAP03'), output_dir="output"):
    """Merge per-patient model-result JSON files for one VAF threshold.

    For each patient, reads
    ``<output_dir>/<patient>_model_results_vaf-<vaf_thres>.json`` and merges its
    top-level mapping into a single dict. Missing files are reported with a
    warning and skipped.

    Args:
        vaf_thres: Threshold the results were produced with. It is formatted
            into the file name as-is (e.g. 0.2 -> "vaf-0.2").
        patients: Patient identifiers whose result files should be loaded.
            Defaults to the two patients this notebook was written for.
        output_dir: Directory that holds the result files.

    Returns:
        dict: The merged results. If two files contain the same key, the value
        from the later patient in ``patients`` wins.

    Raises:
        json.JSONDecodeError: If a result file exists but is not valid JSON.
    """
    model_results = {}
    for patient in patients:
        json_path = Path(output_dir) / f"{patient}_model_results_vaf-{vaf_thres}.json"
        if not json_path.exists():
            print(f"Warning: {json_path} not found. Model overlays will be skipped.")
            continue
        # JSON is UTF-8 by spec; do not depend on the platform's default encoding.
        with json_path.open('r', encoding='utf-8') as f:
            model_results.update(json.load(f))
    return model_results

def get_samples_for_lesion(lesion_sample_set):
    """Return the distinct sample IDs in a lesion's table, sorted ascending.

    The sample-ID column is looked up via ``COLUMN_NAMES['sample_id']``.
    """
    sample_ids = lesion_sample_set.sample_df[COLUMN_NAMES['sample_id']]
    distinct = sample_ids.unique().tolist()
    return sorted(distinct)


In [None]:
def plot_raw_vaf_for_sample(sample_name, ax, lesion_data):
    """Draw the VAF histogram of one sample taken from a lesion's table.

    If the lesion has no rows for ``sample_name``, a centred "No data" message
    is drawn instead, and the sample name is still used as the title.
    """
    rows = lesion_data[lesion_data[COLUMN_NAMES['sample_id']] == sample_name]
    if len(rows) == 0:
        ax.text(0.5, 0.5, 'No data', ha='center', va='center', transform=ax.transAxes)
        ax.set_title(sample_name)
        return
    plot_vaf_histogram(rows, sample_name, ax)

def plot_filtered_vaf_for_sample(sample_name, ax, lesion_data, model_results=None):
    """Draw one sample's filtered VAF histogram, with the model fit when available.

    Args:
        sample_name: Sample to select from ``lesion_data``.
        ax: Matplotlib Axes to draw on.
        lesion_data: Filtered variant table for the lesion.
        model_results: Optional mapping of sample name -> model results. When
            it contains ``sample_name``, the fit is overlaid and the title is
            coloured green (clonal) or red (non-clonal).

    Samples with no rows get a centred "No data" message instead.
    """
    rows = lesion_data[lesion_data[COLUMN_NAMES['sample_id']] == sample_name]
    if len(rows) == 0:
        ax.text(0.5, 0.5, 'No data', ha='center', va='center', transform=ax.transAxes)
        ax.set_title(sample_name)
        return

    has_model = bool(model_results) and sample_name in model_results
    if not has_model:
        plot_vaf_histogram(rows, sample_name, ax)
        return

    params = model_results[sample_name]
    plot_vaf_with_model(rows, params, sample_name, ax)
    ax.title.set_color('green' if params['is_clonal_truncated'] else 'red')

def plot_coverage_for_sample(sample_name, ax, lesion_data, model_results=None):
    """Draw one sample's depth histogram, annotated with its model verdict if known.

    When ``model_results`` has an entry for the sample, the title becomes
    "<sample> - Passed" (green) if ``is_clonal_truncated`` is truthy, or
    "<sample> - Failed" (red) otherwise. Without an entry, the title is the
    plain sample name in black. Samples with no rows get a "No data" message.
    """
    rows = lesion_data[lesion_data[COLUMN_NAMES['sample_id']] == sample_name]
    if len(rows) == 0:
        ax.text(0.5, 0.5, 'No data', ha='center', va='center', transform=ax.transAxes)
        ax.set_title(sample_name)
        return

    title, colour = sample_name, 'black'
    if model_results and sample_name in model_results:
        passed = model_results[sample_name]['is_clonal_truncated']
        title = f"{sample_name} - {'Passed' if passed else 'Failed'}"
        colour = 'green' if passed else 'red'

    plot_coverage_histogram(rows, title, ax)
    # Re-apply the title so it carries the status colour.
    ax.set_title(title, color=colour)


In [None]:
def create_faceted_plot_for_lesion(samples, plot_func, lesion_name, save_filename=None, samples_per_row=5, figsize_per_subplot=(5, 4)):
    """Draw one subplot per sample in a grid, then save or show the figure.

    Args:
        samples: Sample names. Each gets one subplot, filled row by row.
        plot_func: Callable ``plot_func(sample, ax)`` that draws a single sample.
        lesion_name: Used as the figure super-title and in the "no samples" message.
        save_filename: If given, the figure is written to
            ``./figures/suppl/<save_filename>`` at 300 dpi and then closed.
            Otherwise it is displayed with ``plt.show()``.
        samples_per_row: Maximum number of subplots per row.
        figsize_per_subplot: ``(width, height)`` in inches allotted to each subplot.

    Returns:
        None. Nothing is drawn when ``samples`` is empty.
    """
    n_samples = len(samples)
    if n_samples == 0:
        print(f"No samples found for {lesion_name}")
        return

    n_rows = (n_samples + samples_per_row - 1) // samples_per_row
    n_cols = min(n_samples, samples_per_row)
    width, height = figsize_per_subplot[0], figsize_per_subplot[1]

    # squeeze=False always returns a 2-D array of Axes, so single-row,
    # single-column and single-subplot grids need no special handling.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(width * n_cols, height * n_rows), squeeze=False)
    flat_axes = axes.ravel()

    for ax, sample in zip(flat_axes, samples):
        plot_func(sample, ax)

    # Hide unused cells at the end of the last row.
    for ax in flat_axes[n_samples:]:
        ax.set_visible(False)

    # Lay out first, then add the suptitle so it is not squeezed by tight_layout.
    # With y slightly above 1, the title sits just above the axes.
    fig.tight_layout()
    fig.suptitle(lesion_name, fontsize=16, y=1.01)

    if save_filename:
        out_dir = Path("./figures/suppl")
        out_dir.mkdir(parents=True, exist_ok=True)
        # bbox_inches='tight' expands the saved area to include the raised suptitle.
        fig.savefig(out_dir / save_filename, dpi=300, bbox_inches='tight',
                    facecolor='white', edgecolor='none')
        print(f"Saved plot: ./figures/suppl/{save_filename}")
        # Close this exact figure (not merely the "current" one) so batch
        # saving across many lesions does not accumulate open figures.
        plt.close(fig)
    else:
        plt.show()

def plot_lesion_raw_vaf_histograms(lesion_sample_sets, lesion_names=None, save_plots=False):
    """Draw size-filtered (no VAF filter) VAF histograms, one faceted figure per lesion.

    Args:
        lesion_sample_sets: Mapping of lesion name -> SampleSet.
        lesion_names: Lesions to plot. Defaults to every key of ``lesion_sample_sets``.
        save_plots: If True, save each figure as ``<lesion>_raw_vaf_histograms.png``
            instead of displaying it.
    """
    targets = list(lesion_sample_sets) if lesion_names is None else lesion_names
    for name in targets:
        sample_set = lesion_sample_sets[name]
        data = get_lesion_data(sample_set, apply_size_filter=True, apply_vaf_filter=False)
        samples = get_samples_for_lesion(sample_set)

        def draw(sample, ax, _data=data):
            plot_raw_vaf_for_sample(sample, ax, _data)

        save_as = None
        if save_plots:
            save_as = f"{name}_raw_vaf_histograms.png"
        create_faceted_plot_for_lesion(samples, draw, name, save_filename=save_as)

def plot_lesion_filtered_vaf_histograms(lesion_sample_sets, model_results, vaf_thres, lesion_names=None, save_plots=False):
    """Draw size- and VAF-filtered histograms with model overlays, one figure per lesion.

    Args:
        lesion_sample_sets: Mapping of lesion name -> SampleSet.
        model_results: Mapping of sample name -> model results. Samples
            without an entry get a plain histogram.
        vaf_thres: Threshold forwarded to the VAF filter and used in file names.
        lesion_names: Lesions to plot. Defaults to every key of ``lesion_sample_sets``.
        save_plots: If True, save each figure as
            ``<lesion>_filtered_vaf_histograms_vaf-<vaf_thres>.png`` instead of displaying it.
    """
    targets = list(lesion_sample_sets) if lesion_names is None else lesion_names
    for name in targets:
        sample_set = lesion_sample_sets[name]
        data = get_lesion_data(sample_set, apply_size_filter=True, apply_vaf_filter=True, vaf_thres=vaf_thres)
        samples = get_samples_for_lesion(sample_set)

        def draw(sample, ax, _data=data):
            plot_filtered_vaf_for_sample(sample, ax, _data, model_results)

        save_as = None
        if save_plots:
            save_as = f"{name}_filtered_vaf_histograms_vaf-{vaf_thres}.png"
        create_faceted_plot_for_lesion(samples, draw, name, save_filename=save_as)

def plot_lesion_raw_coverage_histograms(lesion_sample_sets, lesion_names=None, save_plots=False):
    """Draw size-filtered (no VAF filter) depth histograms, one faceted figure per lesion.

    Args:
        lesion_sample_sets: Mapping of lesion name -> SampleSet.
        lesion_names: Lesions to plot. Defaults to every key of ``lesion_sample_sets``.
        save_plots: If True, save each figure as ``<lesion>_raw_coverage_histograms.png``
            instead of displaying it.
    """
    targets = list(lesion_sample_sets) if lesion_names is None else lesion_names
    for name in targets:
        sample_set = lesion_sample_sets[name]
        data = get_lesion_data(sample_set, apply_size_filter=True, apply_vaf_filter=False)
        samples = get_samples_for_lesion(sample_set)

        def draw(sample, ax, _data=data):
            plot_coverage_for_sample(sample, ax, _data)

        save_as = None
        if save_plots:
            save_as = f"{name}_raw_coverage_histograms.png"
        create_faceted_plot_for_lesion(samples, draw, name, save_filename=save_as)

def plot_lesion_filtered_coverage_histograms(lesion_sample_sets, model_results, vaf_thres, lesion_names=None, save_plots=False):
    """Draw size- and VAF-filtered depth histograms with pass/fail titles, one figure per lesion.

    Args:
        lesion_sample_sets: Mapping of lesion name -> SampleSet.
        model_results: Mapping of sample name -> model results, used to label
            each subplot "Passed"/"Failed".
        vaf_thres: Threshold forwarded to the VAF filter and used in file names.
        lesion_names: Lesions to plot. Defaults to every key of ``lesion_sample_sets``.
        save_plots: If True, save each figure as
            ``<lesion>_filtered_coverage_histograms_vaf-<vaf_thres>.png`` instead of displaying it.
    """
    targets = list(lesion_sample_sets) if lesion_names is None else lesion_names
    for name in targets:
        sample_set = lesion_sample_sets[name]
        data = get_lesion_data(sample_set, apply_size_filter=True, apply_vaf_filter=True, vaf_thres=vaf_thres)
        samples = get_samples_for_lesion(sample_set)

        def draw(sample, ax, _data=data):
            plot_coverage_for_sample(sample, ax, _data, model_results)

        save_as = None
        if save_plots:
            save_as = f"{name}_filtered_coverage_histograms_vaf-{vaf_thres}.png"
        create_faceted_plot_for_lesion(samples, draw, name, save_filename=save_as)


## Data Loading and Preparation


In [None]:
# Load model results. Any failure is reported and leaves model_results empty,
# so the plots below simply skip the model overlays.
print("Loading model results...")
try:
    model_results = load_model_results(vaf_thres)
except Exception as e:
    print(f"Error loading model results: {e}")
    model_results = {}
else:
    print(f"Loaded model results for {len(model_results)} samples")

# Build one SampleSet per lesion. Lesions with no rows, or whose construction
# fails, are reported and left out of lesion_sample_sets.
print("\nCreating sample sets for each lesion...")
lesion_sample_sets = {}
for lesion_name in tqdm(all_lesions, desc="Processing lesions"):
    try:
        sample_set = create_sample_set(lesion_name)
        if len(sample_set.sample_df) == 0:
            print(f"  {lesion_name}: No data found, skipping")
            continue
        lesion_sample_sets[lesion_name] = sample_set
        samples = get_samples_for_lesion(sample_set)
        print(f"  {lesion_name}: {len(samples)} samples")
    except Exception as e:
        print(f"  {lesion_name}: Error creating sample set - {e}")

print(f"\nCreated sample sets for {len(lesion_sample_sets)} lesions")
print(f"Available lesions: {list(lesion_sample_sets.keys())}")


In [None]:
# Generate raw VAF histograms for all lesions
# plot_lesion_raw_vaf_histograms(lesion_sample_sets)

# Or generate for specific lesions:
# plot_lesion_raw_vaf_histograms(lesion_sample_sets, ['FAP01_P1', 'FAP01_P2'])

# To save plots:
# plot_lesion_raw_vaf_histograms(lesion_sample_sets, save_plots=True)


In [None]:
# Generate filtered VAF histograms with models for all lesions
# plot_lesion_filtered_vaf_histograms(lesion_sample_sets, model_results, vaf_thres)

# Or generate for specific lesions:
# plot_lesion_filtered_vaf_histograms(lesion_sample_sets, model_results, vaf_thres, ['FAP01_P1', 'FAP01_P2'])

# To save plots:
# plot_lesion_filtered_vaf_histograms(lesion_sample_sets, model_results, vaf_thres, save_plots=True)


In [None]:
# Generate raw coverage histograms for all lesions
# plot_lesion_raw_coverage_histograms(lesion_sample_sets)

# Or generate for specific lesions:
# plot_lesion_raw_coverage_histograms(lesion_sample_sets, ['FAP01_P1', 'FAP01_P2'])

# To save plots:
# plot_lesion_raw_coverage_histograms(lesion_sample_sets, save_plots=True)


In [None]:
# Generate filtered coverage histograms for all lesions
# plot_lesion_filtered_coverage_histograms(lesion_sample_sets, model_results, vaf_thres)

# Or generate for specific lesions:
# plot_lesion_filtered_coverage_histograms(lesion_sample_sets, model_results, vaf_thres, ['FAP01_P1', 'FAP01_P2'])

# To save plots:
# plot_lesion_filtered_coverage_histograms(lesion_sample_sets, model_results, vaf_thres, save_plots=True)


## Save All Plots to Figures Directory

To save plots to the `./figures/suppl/` directory, pass `save_plots=True` to any of the plotting functions above. Each call then writes one PNG per lesion (creating the directory if needed) instead of displaying the figures inline.


In [None]:
# Save every plot type for every lesion to ./figures/suppl/.
# NOTE: these calls are live. Running this cell immediately writes all four
# plot types for every lesion in lesion_sample_sets. Comment them out if you
# only want the interactive cells above.

print("Saving all plots...")
plot_lesion_raw_vaf_histograms(lesion_sample_sets, save_plots=True)
plot_lesion_filtered_vaf_histograms(lesion_sample_sets, model_results, vaf_thres, save_plots=True)
plot_lesion_raw_coverage_histograms(lesion_sample_sets, save_plots=True)
plot_lesion_filtered_coverage_histograms(lesion_sample_sets, model_results, vaf_thres, save_plots=True)
print("All plots saved!")

# Files are saved as (the "vaf-" suffix is the current value of vaf_thres):
# ./figures/suppl/{lesion_name}_raw_vaf_histograms.png
# ./figures/suppl/{lesion_name}_filtered_vaf_histograms_vaf-{vaf_thres}.png
# ./figures/suppl/{lesion_name}_raw_coverage_histograms.png
# ./figures/suppl/{lesion_name}_filtered_coverage_histograms_vaf-{vaf_thres}.png
#
# For example, with vaf_thres = 0.2:
# ./figures/suppl/FAP01_P1_raw_vaf_histograms.png
# ./figures/suppl/FAP01_P1_filtered_vaf_histograms_vaf-0.2.png
# etc.
