## Setup

In [None]:
# VAF cutoff passed to apply_filters below; also embedded in input/output file names
vaf_thres = 0.20

In [None]:
# When True, the cells below additionally write figures/alignments to disk
save_plots = False

In [None]:
from pathlib import Path
import re
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib as mpl
from tqdm import tqdm
from Bio import AlignIO
from sgutils import (
    COLUMN_NAMES,
    data_factory, polyp_factory, SampleSet,
    get_all_queries,
    apply_filters, size_filter, coverage_filter,
    plot_tree, annotate_tree_with_targets,
    df_filter_targets, load_driver_genes,
    plot_cmpr_pars
)

In [None]:
from functools import partial
from itables import show as itables_show
# itables `show` with column_filters pre-set to "footer"
ishow = partial(itables_show, column_filters = "footer")

In [None]:
import pyranges as pr
import requests

In [None]:
# IPython magic: set the inline backend's figure format to 'retina'
%config InlineBackend.figure_format = 'retina'

In [None]:
# Model results of both patients (FAP01, FAP03) for the chosen VAF threshold,
# concatenated column-wise; columns are looked up by sample name further below
model_results = pd.concat(
    [
        pd.read_json(f"output/FAP01_model_results_vaf-{vaf_thres}.json"),
        pd.read_json(f"output/FAP03_model_results_vaf-{vaf_thres}.json"),
    ],
    axis=1,
)

In [None]:
def clone_mixture_filter(query):
    """Tell whether sample ``query`` passes the clonality filter.

    Names containing ``"_N"`` always pass. Any other sample returns its
    ``is_clonal_truncated`` entry from ``model_results``.

    Raises:
        ValueError: if ``query`` is not a column of ``model_results``
            (checked before the ``"_N"`` shortcut).
    """
    known = query in model_results.columns
    if not known:
        raise ValueError(f"{query} not in model_results")
    is_normal = "_N" in query
    return True if is_normal else model_results[query].is_clonal_truncated

In [None]:
# Query identifiers and the "full" data table, both provided by sgutils
all_queries = get_all_queries()
tbl = data_factory("full")

In [None]:
# Load the sample blacklist: a headerless table read into a single
# `sample_id` column, plus a set of those IDs for membership tests
blacklist = pd.read_table("./data/sgWGS_black_list.txt", header=None, names=["sample_id"])
blacklist_set = set(blacklist["sample_id"])

In [None]:
def get_tree(query):
    """Build the neighbour-joining tree for one lesion after sample filtering.

    The filtered table (``apply_filters`` at ``vaf_thres``) is wrapped in a
    ``SampleSet`` for ``query`` including normals. Blacklisted samples are
    removed first, then samples rejected by ``clone_mixture_filter``. The
    alignments are reset and the tree is built on the ``'vaf'`` alignment
    (``'binary'`` is the alternative that was tried here before).

    Returns:
        tuple: ``(tree, store)`` - the NJ tree and the filtered ``SampleSet``.
    """
    prefixes, col_prefixes, colors = polyp_factory(query, with_normal=True)
    filtered_tbl = apply_filters(tbl, vaf_thres=vaf_thres)
    store = SampleSet(filtered_tbl, prefixes)
    store.set_colordict(col_prefixes, colors)

    # Step 1: remove blacklisted samples
    store.sample_df = store.sample_df.query('~Tumor_Sample_Barcode.isin(@blacklist_set)')

    # Step 2: keep only samples accepted by the clonality filter
    ids = store.sample_df.Tumor_Sample_Barcode.unique()
    passes = np.array(list(map(clone_mixture_filter, ids)))
    keep = ids[passes]
    store.sample_df = store.sample_df.query('Tumor_Sample_Barcode.isin(@keep)')
    store.reset_alignments()

    # Step 3: tree on the filtered samples
    store.set_alignment_type('vaf')
    return store.seq_nj_tree(), store

## Default Trees

In [None]:
# Filtered SampleSet per lesion, keyed by query name; reused by later sections
stores = {}

### FAP03_P2

In [None]:
# Build the filtered tree for this lesion and keep its store for later cells
query = "FAP03_P2"
tree, store = get_tree(query)
stores[query] = store

In [None]:
# Figure height scales with store.get_dims()[-1]
# (presumably the number of samples - TODO confirm)
fig, ax = plt.subplots(
    1, 1, figsize=(5, (int(np.floor(store.get_dims()[-1] / 3)) + 2)))

plot_tree(tree, store.colordict, ax)
annotate_tree_with_targets(tree, ax, store.sample_df)
if save_plots:
    plt.savefig(f"figures/{query}_tree_vaf-{vaf_thres}.pdf")
    # Blank every record description before writing the alignment as FASTA
    for record in store.alignment:
        record.description = ""
    with open(f"output/{query}_align_vaf-{vaf_thres}.fasta", "w") as output_handle:
        AlignIO.write(store.alignment, output_handle, "fasta")

### FAP01_T3

In [None]:
# Build the filtered tree for this lesion and keep its store for later cells
query = "FAP01_T3"
tree, store = get_tree(query)
stores[query] = store

In [None]:
# Figure height scales with store.get_dims()[-1]
# (presumably the number of samples - TODO confirm)
fig, ax = plt.subplots(
    1, 1, figsize=(9, (int(np.floor(store.get_dims()[-1] / 3)) + 3)))

plot_tree(tree, store.colordict, ax)
annotate_tree_with_targets(tree, ax, store.sample_df)
if save_plots:
    plt.savefig(f"figures/{query}_tree_vaf-{vaf_thres}.pdf")
    # Blank every record description before writing the alignment as FASTA
    for record in store.alignment:
        record.description = ""
    with open(f"output/{query}_align_vaf-{vaf_thres}.fasta", "w") as output_handle:
        AlignIO.write(store.alignment, output_handle, "fasta")

### Remaining samples that passed filters

In [None]:
# Per-lesion clonality summary, used to select the remaining lesions to plot
filtered = (
    model_results.T.reset_index()
    # drop entries whose name contains "_N"
    .query('~index.str.contains("_N")')
    .assign(
        # lesion name = first two "_"-separated fields of the sample name
        sample_name=lambda df: df['index'].str.split('_').str[:2].str.join('_'),
        blacklisted=lambda df: df['index'].isin(blacklist_set),
        clonal_and_not_blacklisted=lambda df: df['is_clonal_truncated'] & ~df['blacklisted']
    )
    .groupby('sample_name')
    .agg(
        total_samples=('index', 'size'),
        clonal_samples=('clonal_and_not_blacklisted', 'sum')
    )
    .assign(
        # NOTE: the fraction is rounded to 2 decimals before the >= 0.5 cutoff below
        clonal_fraction=lambda df: (df['clonal_samples'] / df['total_samples']).astype(float).round(2)
    )
    # keep lesions where at least half of the samples, and at least 5 samples,
    # are clonal and not blacklisted
    .query('clonal_fraction >= 0.5')
    .query('clonal_samples >= 5')
    # these two lesions were already plotted above
    .query('~sample_name.isin(["FAP03_P2", "FAP01_T3"])')
)
filtered

In [None]:
# Build, store and plot the tree of every remaining lesion selected in `filtered`
for query in tqdm(filtered.index):
    tree, store = get_tree(query)
#    tree.ladderize(reverse=True)
    stores[query] = store
    # Figure height scales with store.get_dims()[-1]
    fig, ax = plt.subplots(
    1, 1, figsize=(7, (int(np.floor(store.get_dims()[-1] / 3)) + 2)))

    plot_tree(tree, store.colordict, ax)
    annotate_tree_with_targets(tree, ax, store.sample_df)
    ax.set_title(query)
    if save_plots:
        plt.savefig(f"figures/{query}_tree_vaf-{vaf_thres}.pdf")
        # Blank every record description before writing the alignment as FASTA
        for record in store.alignment:
            record.description = ""
        with open(f"output/{query}_align_vaf-{vaf_thres}.fasta", "w") as output_handle:
            AlignIO.write(store.alignment, output_handle, "fasta")

### Summary

In [None]:
# Create detailed sample-level data: one row per model_results entry whose
# name does not contain "_N", tagged with lesion name, blacklist status and
# the plotting group it ended up in
sample_data = (
    model_results.T.reset_index()
    .query('~index.str.contains("_N")')
    .assign(
        sample_name=lambda df: df['index'].str.split('_').str[:2].str.join('_'),
        blacklisted=lambda df: df['index'].isin(blacklist_set),
        # group = one of the two explicitly plotted lesions, a lesion selected
        # in `filtered`, or "Other"
        sample_group=lambda df: df['index'].str.split('_').str[:2].str.join('_').map(lambda x: 
            'FAP03_P2' if x == 'FAP03_P2' else
            'FAP01_T3' if x == 'FAP01_T3' else
            'Filtered' if x in filtered.index else
            'Other'
        )
    )
)

# Basic overall summary
total_samples = len(sample_data)
clonal_samples = sample_data['is_clonal_truncated'].sum()
non_clonal_samples = total_samples - clonal_samples
blacklisted_samples = sample_data['blacklisted'].sum()

print("Overall summary:")
print(f"Total samples: {total_samples}")
print(f"Clonal samples: {clonal_samples}")
print(f"Non-clonal samples: {non_clonal_samples}")
print(f"Blacklisted samples: {blacklisted_samples}")

# Create cross-tabulation of clonal vs blacklisted status
# (one combined label per sample, e.g. "clonal & not_blacklisted")
sample_data['status'] = (
    sample_data['is_clonal_truncated'].map(lambda x: 'clonal' if x else 'not_clonal') + 
    ' & ' + 
    sample_data['blacklisted'].map(lambda x: 'blacklisted' if x else 'not_blacklisted')
)

print("\nCross-tabulation (clonal vs blacklisted):")
overall_crosstab = sample_data['status'].value_counts().sort_index()
display(overall_crosstab)

# Same counts, split by sample group (rows) and status (columns)
print("\nBreakdown by sample group:")
group_crosstab = sample_data.groupby('sample_group')['status'].value_counts().unstack(fill_value=0)
display(group_crosstab)

## Parsimony comparison for supplement


In [None]:
# The comparison figures are only produced (and saved) when save_plots is set
if save_plots:
    for query, store in stores.items():
        print(query)
        plot_cmpr_pars(store)
        plt.savefig(f"figures/suppl/{query}_cmpr_pars.pdf")
        plt.show()

## Investigating Rescued Drivers and CNVs

In [None]:
def get_gene_coordinates(hugo_symbol, timeout=30):
    """Get chromosome, start, and end coordinates for a gene.

    Queries the Ensembl REST ``lookup/symbol`` endpoint for the human gene
    ``hugo_symbol``.

    Args:
        hugo_symbol: Gene symbol to look up.
        timeout: Seconds to wait for the server before ``requests`` raises a
            timeout error. Without it a stalled connection blocks forever.

    Returns:
        ``("chr" + seq_region_name, start, end)`` on success, or ``None``
        when the response status is not 200 or the payload carries no
        ``seq_region_name``.
    """
    server = "https://rest.ensembl.org"
    ext = f"/lookup/symbol/human/{hugo_symbol}?content-type=application/json"

    response = requests.get(server + ext, timeout=timeout)
    if response.status_code != 200:
        return None

    data = response.json()
    seq_region = data.get('seq_region_name')
    if seq_region is None:
        # "chr" + None would raise TypeError; treat as "not found" instead
        return None
    return ("chr" + seq_region, data.get('start'), data.get('end'))

def get_gene_coords_df(genes):
    """Collect gene coordinates into a DataFrame, one row per gene.

    Each row holds ``gene``, ``Chromosome``, ``Start`` and ``End``. Genes for
    which ``get_gene_coordinates`` returns nothing get ``None`` in all three
    coordinate columns.
    """
    coord_fields = ('Chromosome', 'Start', 'End')
    rows = []
    for symbol in genes:
        coords = get_gene_coordinates(symbol)
        if not coords:
            coords = (None, None, None)
        rows.append({'gene': symbol, **dict(zip(coord_fields, coords))})
    return pd.DataFrame(rows)

In [None]:
# Coordinates of every driver gene (one Ensembl REST request per gene)
gene_locations = get_gene_coords_df(load_driver_genes().to_list())

In [None]:
# "v2" mutation tables (MAF) per patient
fn_fap1 = "./data/v2/FAP01.all_maf_v2.txt"
fn_fap3 = "./data/v2/FAP03.all_maf_v2.txt"

# "v2" copy-number segment tables per patient
fn_fap1_cnv = "./data/v2/FAP01_sgWGS_cnv_v2.bed"
fn_fap3_cnv = "./data/v2/FAP03_sgWGS_cnv_v2.bed"

# Stack both patients (FAP03 first) and renumber the rows
v2_maf = pd.concat(
    (
        pd.read_table(fn_fap3, low_memory=False),
        pd.read_table(fn_fap1, low_memory=False),
    )
).reset_index(drop=True)

v2_cnv = pd.concat(
    (
        pd.read_table(fn_fap3_cnv),
        pd.read_table(fn_fap1_cnv),
    )
).reset_index(drop=True)

In [None]:
# Driver-gene intervals as PyRanges, joined against CNV segments below
gene_pr = pr.PyRanges(gene_locations)

def combine_cnv_mutations(custom_annot):
    """Merge annotations that share sample, mutation label and source.

    ``mut_str`` values have the form ``"GENE:MUT"``. Rows with the same
    ``Tumor_Sample_Barcode``, ``MUT`` and ``source`` are collapsed into one
    row whose ``mut_str`` lists all genes, e.g. ``"APC,TP53:LOH"``.

    Args:
        custom_annot: DataFrame with ``Tumor_Sample_Barcode``, ``mut_str``
            and ``source`` columns.

    Returns:
        DataFrame with columns ``Tumor_Sample_Barcode``, ``mut``, ``source``,
        ``genes`` and ``mut_str``, sorted by the three grouping keys. An
        empty input yields an empty frame with the same columns.
    """
    return (
        custom_annot
        .assign(gene=lambda df: df.mut_str.str.split(":").str[0])
        .assign(mut=lambda df: df.mut_str.str.split(":").str[1])
        # Aggregate only the gene column: DataFrameGroupBy.apply over the
        # grouping columns is deprecated and returns a DataFrame (breaking
        # reset_index(name=...)) when there are no groups.
        .groupby(["Tumor_Sample_Barcode", "mut", "source"])["gene"]
        .agg(",".join)
        .reset_index(name="genes")
        .assign(mut_str=lambda df: df.genes + ":" + df.mut)
    )

def get_cnv_overlaps(v2_cnv, samples, gene_pr):
    """Join the CNV segments of ``samples`` with the gene intervals.

    Segments are selected by ``id``. The chromosome gets a ``"chr"`` prefix,
    and ``start.pos`` / ``end.pos`` are renamed to ``Start`` / ``End`` before
    the table is turned into a PyRanges and joined with ``gene_pr``.

    Returns:
        The joined PyRanges (segment/gene overlaps).
    """
    sample_segments = v2_cnv.query("id.isin(@samples)")
    sample_segments = sample_segments.assign(
        Chromosome=lambda df: "chr" + df.chromosome.astype(str)
    )
    sample_segments = sample_segments.rename(
        columns={'start.pos': 'Start', 'end.pos': 'End'}
    )
    return pr.PyRanges(sample_segments).join(gene_pr)

def get_cnv_annot(v2_cnv, samples, gene_pr):
    """Label gene-overlapping CNV segments as LOH, DEL or AMP.

    Rules, applied independently to the segment/gene overlaps:
      * LOH: ``CNt.adj == 2``, ``A.adj != 1`` and ``Bf <= 0.2``
      * DEL: ``CNt.adj < 2``
      * AMP: ``CNt.adj > 2``

    Returns:
        DataFrame with ``Tumor_Sample_Barcode``, ``mut_str`` (``"GENE:LABEL"``)
        and ``source`` (always ``"cnv"``); LOH rows first, then DEL, then AMP.
    """
    overlaps_df = get_cnv_overlaps(v2_cnv, samples, gene_pr).as_df()
    out_cols = ['Tumor_Sample_Barcode', 'mut_str', 'source']
    rules = (
        ("(`CNt.adj` == 2) & (`A.adj` != 1) & (`Bf` <= 0.2)", ":LOH"),
        ("(`CNt.adj` < 2)", ":DEL"),
        ("(`CNt.adj` > 2)", ":AMP"),
    )

    def _label(condition, suffix):
        # Select the matching segments and tag each with "<gene><suffix>"
        hits = overlaps_df.query(condition).rename(columns={"id": "Tumor_Sample_Barcode"})
        hits = hits.assign(mut_str=lambda df: df['gene'] + suffix)
        hits = hits.assign(source="cnv")
        return hits[out_cols]

    return pd.concat(tuple(_label(condition, suffix) for condition, suffix in rules))

In [None]:
def get_longer_mutations_annot(samples):
    """Collect driver-gene variants whose alleles are not single bases.

    Rows of ``tbl`` are kept when they belong to ``samples``, hit a driver
    gene, have a reference or alternate allele of length != 1, and have a
    variant class containing ``"Mutation"`` or ``"Frame_Shift"``.

    Returns:
        DataFrame with ``Tumor_Sample_Barcode``, ``mut_str``
        (``"GENE:aa_change"``) and ``source`` (``'longer'``), or ``None``
        when no row qualifies.
    """
    genes = load_driver_genes()
    ref_col = COLUMN_NAMES['ref_allele']
    alt_col = COLUMN_NAMES['alt_allele']
    aa_col = COLUMN_NAMES['aa_change']
    class_col = COLUMN_NAMES["variant_class"]
    detail_cols = [
        'Tumor_Sample_Barcode', 'Hugo_Symbol', 'ref_len', 'alt_len',
        aa_col, class_col, 'VAF', COLUMN_NAMES['total_depth'],
    ]

    candidates = tbl.query("Tumor_Sample_Barcode.isin(@samples)")
    candidates = candidates.query("Hugo_Symbol.isin(@genes)")
    candidates = candidates.assign(ref_len=lambda df: df[ref_col].str.len())
    candidates = candidates.assign(alt_len=lambda df: df[alt_col].str.len())
    candidates = candidates.query('(ref_len != 1) | (alt_len != 1)')
    candidates = candidates.query(
        f'{class_col}.str.contains("Mutation") | {class_col}.str.contains("Frame_Shift")'
    )

    missing = candidates[detail_cols]
    missing = missing.assign(mut_str=lambda df: df['Hugo_Symbol'] +":"+ df[aa_col])
    missing = missing.assign(source='longer')
    missing = missing[['Tumor_Sample_Barcode', 'mut_str', 'source']]

    return missing if len(missing) > 0 else None

In [None]:
def filter_v2_tbl(v2_tbl, store_tbl):
    """Find v2 calls at store-known locations in samples that lack them.

    A location is the tuple ``(Chromosome, Start_Position, End_Position,
    Hugo_Symbol)``. A row of ``v2_tbl`` is kept when its location occurs in
    ``store_tbl`` but the (location, ``Tumor_Sample_Barcode``) pair does not.

    Args:
        v2_tbl: DataFrame with the four location columns and
            ``Tumor_Sample_Barcode``.
        store_tbl: DataFrame with the same columns plus ``mut_str``.

    Returns:
        Copy of the kept ``v2_tbl`` rows (original index preserved) with a
        ``mut_str`` column holding the first non-null ``mut_str`` that
        ``store_tbl`` lists for the row's location. Neither input is
        modified. Empty inputs yield an empty result.
    """
    key_cols = ['Chromosome', 'Start_Position', 'End_Position', 'Hugo_Symbol']
    sample_col = 'Tumor_Sample_Barcode'

    def location_keys(df):
        # Zip the columns instead of DataFrame.apply(axis=1): on an empty
        # frame apply returns a DataFrame rather than a Series of tuples,
        # which made the row-wise version fail for empty inputs.
        return list(zip(*(df[col] for col in key_cols)))

    v2_keys = location_keys(v2_tbl)
    store_keys = location_keys(store_tbl)

    store_locations = set(store_keys)
    store_full_keys = set(zip(store_keys, store_tbl[sample_col]))

    # Location exists in store_tbl, but not for this sample
    mask = np.array(
        [
            loc in store_locations and (loc, sample) not in store_full_keys
            for loc, sample in zip(v2_keys, v2_tbl[sample_col])
        ],
        dtype=bool,
    )

    # First non-null mut_str per location, in store_tbl row order
    mut_str_by_location = {}
    for loc, mut_str in zip(store_keys, store_tbl['mut_str']):
        if loc not in mut_str_by_location and pd.notna(mut_str):
            mut_str_by_location[loc] = mut_str

    result = v2_tbl[mask].copy()
    result['mut_str'] = [
        mut_str_by_location.get(loc, np.nan)
        for loc, keep in zip(v2_keys, mask)
        if keep
    ]
    return result


def get_v2_muts(v2_maf, annot_df, samples, genes):
    """Find v2 MAF calls at annotated driver locations missing from the store.

    ``v2_maf`` is restricted to ``samples`` and ``genes``. Both tables are
    then sorted by gene and total depth (descending) and handed to
    ``filter_v2_tbl``.

    Returns:
        The DataFrame produced by ``filter_v2_tbl``.
    """
    v2_columns = [
        "Chromosome",
        "Start_Position",
        "End_Position",
        "Hugo_Symbol",
        "aaChange",
        "Tumor_Sample_Barcode",
        "Total_allele_depth",
        "VAF",
    ]
    v2_tbl = v2_maf.query("Tumor_Sample_Barcode.isin(@samples)")
    v2_tbl = v2_tbl.query("Hugo_Symbol.isin(@genes)")
    v2_tbl = v2_tbl[v2_columns].sort_values(
        by=["Hugo_Symbol", "Total_allele_depth"], ascending=False
    )

    gene_col = COLUMN_NAMES["gene"]
    depth_col = COLUMN_NAMES["total_depth"]
    store_columns = [
        "Chromosome",
        "Start_Position",
        "End_Position",
        COLUMN_NAMES["sample_id"],
        gene_col,
        COLUMN_NAMES["aa_change"],
        depth_col,
        "VAF",
        "mut_str",
    ]
    store_tbl = annot_df[store_columns].sort_values([gene_col, depth_col], ascending=False)

    return filter_v2_tbl(v2_tbl, store_tbl)

def add_custom_annot(tree, ax, annot_df, fontsize=8):
    """Write each sample's annotations next to its leaf label on a tree plot.

    Args:
        tree: Tree drawn on ``ax``; ``tree.find_any(name=...)`` supplies the
            branch length of each labelled leaf.
        ax: Matplotlib axes holding the tree; its text artists are matched
            against sample names.
        annot_df: DataFrame with the sample-ID column
            (``COLUMN_NAMES['sample_id']``), ``mut_str`` and ``source``.
        fontsize: Font size of the annotation text.

    Each annotation line is coloured by its ``source`` (``'v2'`` is gray,
    everything else black) and lines are stacked 15 pixels apart.
    """
    source_colors = {
        'cnv': 'black',
        'store': 'black',
        'v2': 'gray',
        'longer': 'black',
    }

    # Iterate the groups directly instead of groupby(...).apply(...): apply
    # over the grouping column is deprecated in recent pandas and returns a
    # DataFrame (so .to_dict() is keyed by column) when annot_df is empty.
    annot_dict = {
        sample: list(zip(group['mut_str'], group['source']))
        for sample, group in annot_df.groupby(COLUMN_NAMES['sample_id'])
    }

    # Snapshot the text artists first; ax.text() below adds new ones
    text_labels = [text for text in ax.get_children() if isinstance(text, plt.Text)]
    for text in text_labels:
        leaf_label = text.get_text().strip()
        if leaf_label not in annot_dict:
            continue

        x, y = text.get_position()
        # Shift left by the leaf's branch length
        x = x - tree.find_any(name=leaf_label).branch_length
        pixel_x, pixel_y = ax.transData.transform((x, y))
        pixel_x_offset = pixel_x + 5

        # One text line per (mutation, source) pair of this sample
        for i, (mut_str, source) in enumerate(annot_dict[leaf_label]):
            # Unknown sources default to black
            color = source_colors.get(source, 'black')

            # Offsets are computed in pixels, then mapped back to data coordinates
            pixel_y_offset = pixel_y + 2 + (i * 15)
            x_offset, y_offset = ax.transData.inverted().transform((pixel_x_offset, pixel_y_offset))

            ax.text(x_offset, y_offset, mut_str, fontsize=fontsize,
                    verticalalignment='bottom', color=color)

### Plotting Trees with additional Drivers

In [None]:
# Hand-tuned figure sizes (width, height) per lesion
dim_dict = {
    'FAP03_P2': (5, 12),
    'FAP01_T3': (8, 12),
    'FAP01_P6': (8, 8),
    'FAP03_P1': (8, 5),
}
# `stores` is filled from the data-driven `filtered` table, so a lesion may be
# missing from dim_dict; fall back to this size instead of raising KeyError
default_dims = (8, 8)

# Collected across lesions for the oncoplot section below
all_annots = []
all_samples = []
for query, store in stores.items():
    annot_df = df_filter_targets(store.sample_df)
    genes = annot_df.Hugo_Symbol.unique()
    store_ids = set(store.sample_df.Tumor_Sample_Barcode.unique())

    tree = store.seq_nj_tree()
    # Sample IDs in tree-leaf order, restricted to samples present in the store
    samples = [term.name for term in tree.get_terminals() if term.name in store_ids]

    v2_annot = get_v2_muts(v2_maf, annot_df, samples, genes)

    # Annotation sources: store calls, rescued v2 calls, CNV labels
    custom_annot = pd.concat((
        annot_df[['Tumor_Sample_Barcode', 'mut_str']].assign(source='store'),
        v2_annot[['Tumor_Sample_Barcode', 'mut_str']].assign(source='v2'),
        get_cnv_annot(v2_cnv, samples, gene_pr)
    ))

    # ... plus variants with non-single-base alleles, when there are any
    longer_annot = get_longer_mutations_annot(samples)
    if longer_annot is not None:
        custom_annot = pd.concat((custom_annot, longer_annot))

    all_annots.append(custom_annot)
    # Samples whose name contains "_N" are left out of the oncoplot sample list
    all_samples.extend([sample for sample in samples if "_N" not in sample])

    fig, ax = plt.subplots(1, 1, figsize=dim_dict.get(query, default_dims))

    plot_tree(tree, store.colordict, ax)
    for label in ax.texts:
        label.set_fontsize(16)
#    add_custom_annot(tree, ax, custom_annot)
    add_custom_annot(tree, ax, combine_cnv_mutations(custom_annot))

    if save_plots:
        plt.savefig(f"figures/{query}_tree_vaf-{vaf_thres}_full-annot.pdf", bbox_inches='tight')

#### More on driver rescue

In [None]:
# Lesion inspected in the cells below
query = "FAP01_T3"
#query = "FAP03_P1"
store = stores[query]

annot_df = df_filter_targets(store.sample_df)
genes = annot_df.Hugo_Symbol.unique()
# Sample IDs of this lesion, without those whose name contains "_N"
samples = store.sample_df.query("~Tumor_Sample_Barcode.str.contains('_N')").Tumor_Sample_Barcode.unique()

# v2 calls at annotated driver locations that the store lacks for a sample
v2_annot = get_v2_muts(v2_maf, annot_df, samples, genes)


In [None]:
# Driver calls of the selected lesion from both sources ("store" and rescued
# "v2"), harmonised to the same column names; CTNNB1 is left out
drivers = (
    pd.concat(
        (
            annot_df[
                [
                    COLUMN_NAMES["sample_id"],
                    COLUMN_NAMES["total_depth"],
                    COLUMN_NAMES["gene"],
                    "mut_str",
                ]
            ]
            .assign(source="store")
            .rename(columns={COLUMN_NAMES["total_depth"]: "Total_allele_depth"}),
            v2_annot[
                ["Tumor_Sample_Barcode", "Total_allele_depth", "Hugo_Symbol", "mut_str"]
            ].assign(source="v2"),
        )
    )
    .sort_values(by=["Hugo_Symbol", "Total_allele_depth"], ascending=False)
    .query("Hugo_Symbol != 'CTNNB1'")
)

# Histogram settings
bins = 50
alpha = 0.7

# Calculate grid dimensions (at most 4 columns, one panel per sample)
n_samples = len(samples)
n_cols = min(4, n_samples)
n_rows = (n_samples + n_cols - 1) // n_cols

fig, axes = plt.subplots(n_rows, n_cols, figsize=(5*n_cols, 4*n_rows))
# plt.subplots returns a bare Axes for one panel and a 1-D array for one row;
# normalise so the flattening below works in every case
if n_samples == 1:
    axes = [axes]
elif n_rows == 1:
    axes = axes.reshape(1, -1)

# Flatten axes for easier iteration
axes_flat = axes.flatten() if n_samples > 1 else axes

# Create color map for Hugo_Symbols, but force APC to red and SMAD4 to orange
unique_genes = drivers['Hugo_Symbol'].unique()
gene_colors = {}
for gene in unique_genes:
    if gene == "APC":
        gene_colors[gene] = "red"
    elif gene == "SMAD4":
        gene_colors[gene] = "orange"
    else:
        # Use Set1 colormap for other genes, skipping first two colors (red, orange)
        # so as not to duplicate APC/SMAD4
        # Find the index of this gene among the non-APC/SMAD4 genes
        other_genes = [g for g in unique_genes if g not in ("APC", "SMAD4")]
        color_idx = other_genes.index(gene) if gene in other_genes else 0
        # Set1 has 9 colors, skip 0 (red) and 1 (orange)
        cmap = plt.cm.Set1
        # Use color_idx+2 to skip red/orange, wrap around if needed
        # NOTE(review): cmap is called with a float fraction here, not an
        # integer color index - confirm the resulting colors are as intended
        color = cmap((color_idx + 2) / max(3, len(other_genes) + 2))
        gene_colors[gene] = color

# One coverage histogram per sample, in sorted sample order
for i, sample_name in enumerate(sorted(samples)):
    ax = axes_flat[i]

    # Subset data for this sample
    data = store.sample_df.query("Tumor_Sample_Barcode == @sample_name")

    title = sample_name
    title_color = 'black'

    coverage = data[COLUMN_NAMES['total_depth']]
    ax.hist(coverage, bins=bins, alpha=alpha, edgecolor='black', linewidth=0.5)
    ax.set_xlabel('Total Depth')
    ax.set_ylabel('Count')
    ax.set_title(title, color=title_color)
    ax.grid(True, alpha=0.3)
    ax.set_xlim(left=0)

    # Add vertical lines for driver mutations in this sample
    # (drawn at the total depth of each driver call, colored by gene)
    sample_drivers = drivers.query("Tumor_Sample_Barcode == @sample_name")
    for _, row in sample_drivers.iterrows():
        gene = row['Hugo_Symbol']
        depth = row['Total_allele_depth']
        color = gene_colors[gene]
        ax.axvline(x=depth, color=color, linestyle='--', linewidth=2, alpha=0.8, label=gene)

    # Add statistics
    mean_cov = coverage.mean()
    median_cov = coverage.median()
    n_variants = len(coverage)

    stats_text = f'n={n_variants}\nMean: {mean_cov:.1f}x\nMedian: {median_cov:.1f}x'
    ax.text(0.7, 0.8, stats_text, transform=ax.transAxes, 
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8),
            verticalalignment='top')

    # Add legend if there are driver mutations for this sample
    if len(sample_drivers) > 0:
        ax.legend(loc='upper right', bbox_to_anchor=(0.95, 0.5))

# Hide unused subplots
for i in range(n_samples, len(axes_flat)):
    axes_flat[i].set_visible(False)

plt.tight_layout()

if save_plots:
    plt.savefig(f"figures/{query}_driver_rescue_coverage.pdf", bbox_inches='tight')

In [None]:
# Inspect the APC driver rows of the selected lesion, ordered by sample
drivers.query("Hugo_Symbol == 'APC'").sort_values('Tumor_Sample_Barcode')

In [None]:
apc_depth_path = Path("./data/apc_depth/")

# Depth files of the selected lesion: name starts with "<query>_R" and the
# stem (minus ".apc_depth") is one of the retained samples
apc_depth_files = [fn for fn in (sorted(apc_depth_path.glob("*.txt"))) if fn.name.startswith(f"{query}_R") and fn.stem.replace(".apc_depth", "") in samples]
apc_depth_files

In [None]:
# The distinct APC variant (position and type) annotated for this lesion;
# the assert below requires that there is exactly one
apc_var = (
    annot_df.query("Hugo_Symbol == 'APC'")[
        ["Start_Position", "End_Position", "Variant_Type"]
    ]
    .drop_duplicates()
)
assert len(apc_var) == 1
apc_var = apc_var.squeeze()
# For deletions the position one base before Start_Position is used
if apc_var.Variant_Type == "DEL":
    apc_loc = apc_var.Start_Position - 1
else:
    apc_loc = apc_var.Start_Position

# Subset drivers to only APC
apc_drivers = drivers.query("Hugo_Symbol == 'APC'")

# Set up for two columns
n_files = len(apc_depth_files)
n_cols = 2
n_rows = (n_files + n_cols - 1) // n_cols  # ceiling division

fig, axes = plt.subplots(n_rows, n_cols, figsize=(16, 2 * n_rows), sharex=True)
axes = np.array(axes).reshape(-1, n_cols)  # Ensure axes is always 2D

# One depth profile per sample file
for i, fn in enumerate(apc_depth_files):
    depth_df = pd.read_table(fn, names=["chr", "pos", "depth"])
    stride = max(1, len(depth_df) // 10000)  # target ~10,000 points max
    plot_df = depth_df.iloc[::stride, :]

    # Remove the .apc_depth suffix to get the sample name
    sample_name = fn.stem.replace(".apc_depth", "")

    # Grab the coverage at the APC locus, if present
    # (looked up in the full table, not the thinned plot_df)
    apc_cov_row = depth_df[depth_df["pos"] == apc_loc]
    if not apc_cov_row.empty:
        apc_cov = apc_cov_row["depth"].iloc[0]
        apc_cov_str = f"APC@{apc_loc}: {apc_cov}x"
    else:
        apc_cov = None
        apc_cov_str = f"APC@{apc_loc}: N/A"

    # Check if sample is in APC drivers
    # Title color: green = "store" call, orange = "v2" call, gray = other
    # source, red = no APC driver row for this sample
    driver_row = apc_drivers[apc_drivers["Tumor_Sample_Barcode"] == sample_name]
    if not driver_row.empty:
        # Get source if present
        source = driver_row["source"].iloc[0] if "source" in driver_row.columns else None
        if source == "store":
            title_color = "green"
        elif source == "v2":
            title_color = "orange"
        else:
            title_color = "gray"
        title_text = f"{sample_name} ({apc_cov_str})"
    else:
        # Not in drivers
        title_color = "red"
        title_text = f"{sample_name} (absent, {apc_cov_str})"

    # Panels are filled row by row
    row = i // n_cols
    col = i % n_cols
    ax = axes[row, col]
    ax.axvline(apc_loc, color='red', linestyle='--', linewidth=1.0, label='mutation')
    ax.plot(plot_df["pos"], plot_df["depth"], lw=0.7, color='tab:blue')
    ax.set_ylabel("Sequencing depth")
    ax.set_title(
        f"Depth across locus for {title_text}",
        color=title_color,
    )
    ax.set_ylim(bottom=0)  # Set lower ylim to 0
    # Optionally, add legend only to the first subplot
    if i == 0:
        ax.legend(loc='upper right')

# Hide unused subplots if any
for j in range(n_files, n_rows * n_cols):
    row = j // n_cols
    col = j % n_cols
    axes[row, col].set_visible(False)

# Set xlabel for bottom row
for col in range(n_cols):
    axes[-1, col].set_xlabel("Genomic position (chr5)")

plt.tight_layout()

if save_plots:
    plt.savefig(f"figures/{query}_APC_depth_panel.pdf", bbox_inches='tight')

plt.show()

### Oncoplot-style plots for Driver annotations

In [None]:
# Integer code per alteration category, used as cell values of the onco matrix
category_to_value = {
    'LOH': 1,
    'AMP': 2,
    'DEL': 3,
    'FS': 4,
    'STOP': 5,
    'MISS': 6
}

# Reverse lookup (code -> category); code 0 marks cells without an alteration
value_to_category = {
    v: k for k, v in category_to_value.items()
}
value_to_category[0] = 'neutral'

# colors for value 0, 1, 2, 3, 4, 5, 6
category_colors = ['#e0e0e0', '#aed8e6ff', '#ed2224ff', '#3a53a4ff', '#a68027ff', '#707f8fff', '#008001ff']

In [None]:
# Long-format table of all annotations collected per lesion, split into gene
# and mutation label, with the lesion name and - for non-CNV rows whose label
# matches p.<letter><digits> - the residue position (-1 otherwise)
tmp = (
    pd.concat(all_annots)
#    .query("source == 'cnv'")
    .assign(gene=lambda df: df.mut_str.str.split(":").str[0])
    .assign(mut=lambda df: df.mut_str.str.split(":").str[1])
    .assign(lesion=lambda df: df.Tumor_Sample_Barcode.str.split("_").str[0:2].str.join("_"))
    .assign(
        variant_pos=lambda df: df.apply(
            lambda row: (
                int(re.search(r'p\.[A-Z](\d+)', row['mut']).group(1))
                if (
                    row['source'] != 'cnv'
                    and isinstance(row['mut'], str)
                    and re.search(r'p\.[A-Z](\d+)', row['mut'])
                ) else -1
            ),
            axis=1
        ).astype(int)
    )
)

# Gene x sample matrix of category codes. The three relabelling steps are
# order-dependent: "fs*" is checked before the more general "*", and any
# remaining "p." label becomes MISS; labels not in category_to_value map to NaN
onco_matrix = (
    tmp
    .assign(mut=lambda df: df.mut.map(lambda x: x if not "fs*" in x else "FS"))
    .assign(mut=lambda df: df.mut.map(lambda x: x if not "*" in x else "STOP"))
    .assign(mut=lambda df: df.mut.map(lambda x: x if not "p." in x else "MISS"))
    .assign(mut=lambda df: df.mut.map(category_to_value))
    .pivot_table(
        index='gene',
        columns='Tumor_Sample_Barcode',
        values='mut',
        aggfunc=lambda x: ( # prioritize STOP, FS, MISS, DEL, AMP, LOH
            x[x.isin([5,4,6,3,2,1])].sort_values(key=lambda s: s.map({5:0,4:1,6:2,3:3,2:4,1:5})).iloc[0]
            if x.isin([5,4,6,3,2,1]).any() else 0
        ),
        fill_value=0
    )
)

In [None]:
# Order the rows by number of non-zero values (most non-zeros on top)
row_order = onco_matrix.astype(bool).sum(axis=1).sort_values(ascending=False).index
onco_matrix_sorted = onco_matrix.loc[row_order]

# Desired sample group order
desired_group_order = ['FAP01_P6', 'FAP01_T3', 'FAP03_P1', 'FAP03_P2', ]

# Ensure all samples in all_samples are present as columns, filling missing ones with 0s
missing_samples = [s for s in all_samples if s not in onco_matrix_sorted.columns]
for s in missing_samples:
    onco_matrix_sorted[s] = 0

# Reorder columns: within each group, keep the order from all_samples, but order groups as desired
def group_key(sample):
    # Sort key: (position of the matching lesion prefix, position in all_samples)
    for i, prefix in enumerate(desired_group_order):
        if str(sample).startswith(prefix):
            return (i, all_samples.index(sample))
    # If not matched, put at the end
    return (len(desired_group_order), all_samples.index(sample))

# Selecting by ordered_samples also drops matrix columns not listed in all_samples
ordered_samples = sorted(all_samples, key=group_key)
onco_matrix_sorted = onco_matrix_sorted[ordered_samples]

# --- Annotate APC FS and STOP variants with variant_pos ---
onco_matrix_display = onco_matrix_sorted.copy().astype(object)

# Cell text for the heatmap; empty except for APC FS/STOP cells
annotations = pd.DataFrame('', index=onco_matrix_display.index, columns=onco_matrix_display.columns)
if 'APC' in onco_matrix_display.index:
    for col in onco_matrix_display.columns:
        cell_val = onco_matrix_display.loc['APC', col]
        # 4 = FS, 5 = STOP (see category_to_value)
        if cell_val in (4, 5):
            # Find mutations that match the displayed mutation type
            apc_rows = tmp[(tmp['gene'] == 'APC') & (tmp['Tumor_Sample_Barcode'] == col)]
            target_mut_type = "FS" if cell_val == 4 else "STOP"

            # Filter to mutations that would map to the same category as displayed
            matching_rows = apc_rows[
                apc_rows['mut'].apply(lambda x: 
                    (target_mut_type == "FS" and "fs*" in str(x)) or
                    (target_mut_type == "STOP" and "*" in str(x) and "fs*" not in str(x))
                )
            ]

            if not matching_rows.empty:
                # Use the first matching mutation's variant_pos
                variant_pos = matching_rows.iloc[0]['variant_pos']
                if pd.notnull(variant_pos) and variant_pos != -1:
                    annotations.loc['APC', col] = str(int(variant_pos))

# Identify individuals by sample column prefixes, in the new order
individual_prefixes = desired_group_order
columns = onco_matrix_display.columns

# Find the column indices where a new individual starts (except the first)
individual_starts = []
for prefix in individual_prefixes:
    indices = [i for i, col in enumerate(columns) if str(col).startswith(prefix)]
    if indices:
        individual_starts.append(indices[0])
individual_starts = sorted(set(individual_starts))
separator_indices = [i for i in individual_starts if i != 0]

# Create a colormap and norm for the heatmap
# (one color per integer code 0..6; bounds are the bin edges around each code)
cmap = mpl.colors.ListedColormap(category_colors)
bounds = [0, 1, 2, 3, 4, 5, 6, 7]
norm = mpl.colors.BoundaryNorm(bounds, cmap.N)

fig_width = 12
fig_height = 5

fig, ax = plt.subplots(figsize=(fig_width, fig_height))
sns.heatmap(
    onco_matrix_sorted,
    cmap=cmap,
    norm=norm,
    cbar=False,
    ax=ax,
    annot=annotations,
    fmt='',
    annot_kws={"fontsize": 6, "va": "center", "ha": "center"}
)

# Remove axis labels
ax.set_xlabel("")
ax.set_ylabel("")

# Set custom x-axis tick labels: show only the suffix after the lesion name (e.g., R5_G1)
def get_sample_suffix(sample):
    # sample is e.g. FAP01_P6_R5_G1
    parts = str(sample).split("_")
    if len(parts) > 2:
        return "_".join(parts[2:])
    else:
        return str(sample)

# Ticks sit at the cell centers (column index + 0.5)
ax.set_xticks(np.arange(len(onco_matrix_sorted.columns)) + 0.5)
ax.set_xticklabels([get_sample_suffix(s) for s in onco_matrix_sorted.columns], rotation=90)

# Add vertical separator lines between individuals
for idx in separator_indices:
    ax.axvline(idx, color='black', linewidth=2)

# --- Add lesion/individual names atop the relevant heatmap sections ---
# For each group, place the group name centered above its section
for prefix in desired_group_order:
    # Find all columns for this group
    indices = [i for i, col in enumerate(onco_matrix_sorted.columns) if str(col).startswith(prefix)]
    if indices:
        start = indices[0]
        end = indices[-1]
        center = (start + end) / 2 + 0.5  # +0.5 to align with heatmap grid
        ax.text(
            center,
            -0.2,  # slightly above the heatmap (tweak as needed)
            prefix,
            ha='center',
            va='bottom',
            fontsize=12,
            fontweight='bold',
            transform=ax.transData,
            clip_on=False
        )

# Create custom legend
# (only categories whose code occurs in the unsorted onco_matrix get an entry)
present_values = set(onco_matrix.values.flatten())
handles = [
    mpl.patches.Patch(color=category_colors[i], label=label)
    for i, label in value_to_category.items() if i in present_values
]
ax.legend(
    handles=handles,
    title="Mutation Type",
    bbox_to_anchor=(1.01, 1),
    loc='upper left',
    borderaxespad=0.
)
fig.tight_layout()
if save_plots:
    fig.savefig('figures/onco_matrix.pdf', bbox_inches='tight')

## Mutation sharing plots

In [None]:
def plot_mutation_sharing(filtered_counts, ax):
    """Bar-plot how many mutations are shared by how many samples.

    Args:
        filtered_counts: Series whose index is plotted on x and whose values
            (mutation counts) are plotted on a log-scaled y axis.
        ax: Matplotlib axes to draw on.
    """
    n_mutations = filtered_counts.sum()
    sns.barplot(
        x=filtered_counts.index,
        y=filtered_counts.values,
        color="steelblue",
        ax=ax,
    )
    ax.set_yscale('log')
    ax.set_xlabel("Number of samples sharing mutation")
    ax.set_ylabel("Number of mutations")
    ax.set_title(f"Distribution of filtered mutation sharing across samples (n: {n_mutations})")


In [None]:
def plot_stacked_mutation_sharing(filtered_counts, filtered_out_counts, n_samples, ax):
    """Stacked, log-scaled bar plot of mutation sharing across samples.

    Bars sit at x = 1..``n_samples``. ``filtered_counts`` forms the base
    (steelblue) and ``filtered_out_counts`` is stacked on top (gray). The
    title reports the sum of both.
    """
    sharing_levels = range(1, n_samples + 1)
    kept = filtered_counts.values
    removed = filtered_out_counts.values
    n_total = (filtered_counts + filtered_out_counts).sum()

    # Base layer first, then the filtered-out layer on top of it
    ax.bar(sharing_levels, kept, color="steelblue", label="Filtered mutations")
    ax.bar(sharing_levels, removed,
           bottom=kept, color="gray", label="Filtered out mutations")

    ax.set_yscale('log')
    ax.set_xlabel("Number of samples sharing mutation")
    ax.set_ylabel("Number of mutations")
    ax.set_title(f"Distribution of mutation sharing across samples (n: {n_total})")
    ax.legend()

    # One tick per sharing level, with half a bar of padding on both sides
    ax.set_xticks(sharing_levels)
    ax.set_xlim(0.5, n_samples + 0.5)
    ax.grid(axis='x', visible=False)


In [None]:
def plot_stacked_mutation_sharing_simple(filtered_counts, filtered_out_counts, n_samples, ax):
    """Draw kept and filtered-out mutation counts as stacked bars on a linear y-axis.

    Bars are placed at x = 1..n_samples and take their heights positionally
    from ``.values``, so both inputs must hold exactly ``n_samples`` entries.

    Args:
        filtered_counts: Series of counts that form the lower (blue) segment.
        filtered_out_counts: Series of counts stacked on top (gray segment).
        n_samples: Number of bars / largest x position.
        ax: Matplotlib axes that is drawn on and labelled in place.
    """
    bar_x = range(1, n_samples + 1)
    kept_heights = filtered_counts.values
    # Total shown in the title covers both segments
    n_total = (filtered_counts + filtered_out_counts).sum()

    # Lower segment first, then the filtered-out segment on top of it
    ax.bar(bar_x, kept_heights, color="steelblue", label="Filtered mutations")
    ax.bar(
        bar_x,
        filtered_out_counts.values,
        bottom=kept_heights,
        color="gray",
        label="Filtered out mutations",
    )

    ax.set_xlabel("Number of samples sharing mutation")
    ax.set_ylabel("Number of mutations")
    ax.set_title(f"Distribution of mutation sharing across samples (n: {n_total})")
    ax.legend()

    # One integer-labelled tick per bar, half a unit of padding, no vertical grid lines
    ax.set_xticks(bar_x)
    ax.set_xticklabels(list(map(int, bar_x)))
    ax.set_xlim(0.5, n_samples + 0.5)
    ax.grid(axis='x', visible=False)


In [None]:
def plot_broken_axis_mutation_sharing(filtered_counts, filtered_out_counts, n_samples, ax1, ax2, split_point=325):
    """Draw the stacked sharing bars on two axes that form a broken y-axis.

    The same stacked bars are drawn on both axes; ``ax1`` shows them on a log
    scale from ``split_point`` upwards, ``ax2`` on a linear scale from 0 to
    ``split_point``. Diagonal marks are drawn where the two panels meet.

    Args:
        filtered_counts: Series of counts that form the lower (blue) segment.
        filtered_out_counts: Series of counts stacked on top (gray segment).
        n_samples: Number of bars; bars sit at x = 1..n_samples.
        ax1: Upper axes (receives the legend and the title).
        ax2: Lower axes (receives the axis labels).
        split_point: y value at which the axis is broken.
    """
    xs = range(1, n_samples + 1)
    grand_total = filtered_counts + filtered_out_counts

    def stack_bars(axis, kept_extra, removed_extra):
        # Extra kwargs carry the legend labels; they are omitted entirely
        # (not set to None) for the panel that has no legend.
        axis.bar(xs, filtered_counts.values, color="steelblue", **kept_extra)
        axis.bar(
            xs,
            filtered_out_counts.values,
            bottom=filtered_counts.values,
            color="gray",
            **removed_extra,
        )

    def format_x(axis):
        # One tick per bar, half a unit of padding, no vertical grid lines
        axis.set_xticks(xs)
        axis.set_xlim(0.5, n_samples + 0.5)
        axis.grid(axis='x', visible=False)

    # Upper panel: high values, log scale, upper limit left to matplotlib
    stack_bars(ax1, {"label": "Filtered mutations"}, {"label": "Filtered out mutations"})
    ax1.set_yscale('log')
    ax1.set_ylim(bottom=split_point)
    ax1.set_ylabel("")
    ax1.legend()
    format_x(ax1)

    # Lower panel: low values, linear scale
    stack_bars(ax2, {}, {})
    ax2.set_ylim(0, split_point)
    ax2.set_xlabel("Number of samples sharing mutation")
    ax2.set_ylabel("Number of mutations")
    format_x(ax2)

    # Hide the spines and ticks facing the break
    ax1.spines['bottom'].set_visible(False)
    ax2.spines['top'].set_visible(False)
    ax1.tick_params(labeltop=False, bottom=False, labelbottom=False)
    ax2.xaxis.tick_bottom()

    # Diagonal break marks in axes coordinates: bottom corners of the upper
    # panel, top corners of the lower panel
    tick = 0.015
    left_span = (-tick, +tick)
    right_span = (1 - tick, 1 + tick)
    for axis, y_span in ((ax1, (-tick, +tick)), (ax2, (1 - tick, 1 + tick))):
        mark_style = dict(transform=axis.transAxes, color='k', clip_on=False)
        axis.plot(left_span, y_span, **mark_style)
        axis.plot(right_span, y_span, **mark_style)

    ax1.set_title(f"Distribution of mutation sharing across samples (n: {grand_total.sum()})")


In [None]:
def plot_mutation_sharing_comparison(filtered_counts, filtered_out_counts, n_samples, split_point):
    """Create an 18x12 inch figure comparing four renderings of the same counts.

    Layout (2x2): standard bar plot, simple stacked plot, log-scaled stacked
    plot, and a broken-axis plot built from two vertically stacked axes.
    The figure is created via ``plt.figure`` and is not returned.

    Args:
        filtered_counts: Series of kept-mutation counts.
        filtered_out_counts: Series of filtered-out mutation counts.
        n_samples: Number of bars in the stacked plots.
        split_point: y value at which the broken-axis panel is split.
    """
    fig = plt.figure(figsize=(18, 12))
    grid = fig.add_gridspec(2, 2, height_ratios=[1, 1], hspace=0.3, wspace=0.2)
    stacked_args = (filtered_counts, filtered_out_counts, n_samples)

    # Top-left: standard bar plot of the kept mutations only
    standard_ax = fig.add_subplot(grid[0, 0])
    plot_mutation_sharing(filtered_counts, standard_ax)

    # Top-right: stacked bars, linear scale
    simple_ax = fig.add_subplot(grid[0, 1])
    plot_stacked_mutation_sharing_simple(*stacked_args, simple_ax)

    # Bottom-left: stacked bars, log scale
    log_ax = fig.add_subplot(grid[1, 0])
    plot_stacked_mutation_sharing(*stacked_args, log_ax)

    # Bottom-right: broken axis, two tightly spaced axes sharing the x-axis
    broken_grid = grid[1, 1].subgridspec(2, 1, height_ratios=[1, 1], hspace=0.05)
    upper_ax = fig.add_subplot(broken_grid[0])
    lower_ax = fig.add_subplot(broken_grid[1], sharex=upper_ax)
    plot_broken_axis_mutation_sharing(*stacked_args, upper_ax, lower_ax, split_point)


In [None]:
# y value at which each polyp's broken-axis plot is split. Below, only the
# FAP03_P2 entry is used (plus all entries by the commented-out comparison call).
split_points = {
    'FAP03_P2': 325,
    'FAP01_T3': 2100,
    'FAP01_P6': 3300,
    'FAP03_P1': 1600,
}

# Excludes samples whose barcode contains "_N" (presumably the normal samples).
# Loop-invariant, so it is built once.
query_str = "~Tumor_Sample_Barcode.str.contains('_N')"


def get_mutation_sharing_counts(df, n_samples):
    """Count how many variants are present in exactly k samples, k = 1..n_samples.

    NOTE: reads the module-level ``store`` (``var_cols`` / ``sample_col``) at
    call time, so the result depends on which store is currently bound.
    The region cell further down calls this helper too.

    Args:
        df: Long-format mutation table with a ``VAF`` column.
        n_samples: Largest sharing count to report.

    Returns:
        Series indexed 1..n_samples; missing sharing counts are filled with 0.
    """
    return (
        df
        .pivot_table(
            index=store.var_cols,
            columns=store.sample_col,
            values='VAF',
            aggfunc=lambda x: 1,  # presence flag per (variant, sample)
            fill_value=0
        )
        .sum(axis=1)  # number of samples carrying each variant
        .value_counts()
        .reindex(range(1, n_samples + 1), fill_value=0)
        .sort_index()
    )


for query, store in stores.items():
    print(query)

    input_df = store.sample_df.query(query_str).copy()

    # Same samples, taken from the full table `tbl`
    samples = input_df.Tumor_Sample_Barcode.unique()
    unfiltered_input_df = tbl.query('Tumor_Sample_Barcode in @samples').copy()

    n_samples = len(samples)

    # Calculate mutation sharing counts for both datasets
    filtered_counts = get_mutation_sharing_counts(input_df, n_samples)
    unfiltered_counts = get_mutation_sharing_counts(unfiltered_input_df, n_samples)
    # NOTE(review): bin-wise difference of two histograms; a bin can go negative
    # if filtering removes a variant from only some of its samples - confirm.
    filtered_out_counts = unfiltered_counts - filtered_counts

#    plot_mutation_sharing_comparison(filtered_counts, filtered_out_counts, n_samples, split_points[query])
#    plt.show()
#    continue

    if query == 'FAP03_P2':
        # Broken y-axis: two stacked axes with almost no gap between them
        figsize = (12.5, 6.5)
        fig = plt.figure(figsize=figsize)
        gs_broken = fig.add_gridspec(2, 1, height_ratios=[1, 1], hspace=0.05)
        ax1 = fig.add_subplot(gs_broken[0])
        ax2 = fig.add_subplot(gs_broken[1])
        plot_broken_axis_mutation_sharing(filtered_counts, filtered_out_counts, n_samples, ax1, ax2, split_points[query])
    else:
        figsize = (11, 6.5) if query == 'FAP01_T3' else (5, 6.5)
        fig, ax = plt.subplots(figsize=figsize)
        plot_stacked_mutation_sharing_simple(filtered_counts, filtered_out_counts, n_samples, ax)

    if save_plots:
        # Save through the figure handle rather than the implicit current figure
        fig.savefig(f'figures/mutation_sharing_{query}.pdf', bbox_inches='tight')
    plt.show()

In [None]:
# Per-region breakdown for a single polyp. Rebinding `store` here also decides
# which var_cols / sample_col get_mutation_sharing_counts (previous cell) uses.
query = 'FAP03_P2'
store = stores[query]
# Exclude samples whose barcode contains "_N" (presumably the normal samples)
query_str = "~Tumor_Sample_Barcode.str.contains('_N')"
input_df = store.sample_df.query(query_str).copy()
# Same samples, taken from the full table `tbl`
samples = input_df.Tumor_Sample_Barcode.unique()
unfiltered_input_df = tbl.query('Tumor_Sample_Barcode in @samples').copy()

# Working copies, so the 'region' column added below does not touch the inputs
df_filtered = input_df.copy()
df_unfiltered = unfiltered_input_df.copy()

# Extract region from sample name, e.g., "FAP03_P2_R4_G7" -> "R4"
def extract_region(sample_name):
    """Return the first ``R<digits>`` token found in *sample_name*.

    Falls back to ``'Unknown'`` when the name contains no such token.
    """
    found = re.search(r'(R\d+)', sample_name)
    if not found:
        return 'Unknown'
    return found.group(1)

# Tag every row with the region parsed from its sample name ('Unknown' if none)
df_filtered['region'] = df_filtered[store.sample_col].apply(extract_region)
df_unfiltered['region'] = df_unfiltered[store.sample_col].apply(extract_region)

# Get all regions and order them as R1, R2, ..., Unknown (if present)
def region_sort_key(region):
    """Sort key placing ``R<n>`` labels first in numeric order, then other labels.

    Labels starting with ``R<digits>`` map to ``(0, n)``; anything else maps
    to ``(1, label)`` and therefore sorts after them, alphabetically.
    """
    numbered = re.match(r'R(\d+)', region)
    if numbered:
        return (0, int(numbered.group(1)))
    return (1, region)

# Regions present in the filtered data: numbered regions first, other labels last
regions = sorted(df_filtered['region'].unique(), key=region_sort_key)

# Map region to color by finding the first sample in that region and using its color from store.colordict
# (falls back to "steelblue" when none of the region's samples has an entry)
region_to_color = {}
for region in regions:
    region_samples = df_filtered.loc[df_filtered['region'] == region, store.sample_col]
    color = "steelblue"
    for sample in region_samples:
        if sample in store.colordict:
            color = store.colordict[sample]
            break
    region_to_color[region] = color

# Prepare a DataFrame to collect counts per region
all_counts = []

for region in regions:
    # Get counts for filtered and unfiltered data
    df_filtered_region = df_filtered[df_filtered['region'] == region]
    df_unfiltered_region = df_unfiltered[df_unfiltered['region'] == region]
    
    # Calculate per-region n_samples
    # (taken from the filtered data; the unfiltered histogram below is reindexed to the same 1..n range)
    region_n_samples = df_filtered_region[store.sample_col].nunique()
    
    # get_mutation_sharing_counts is defined in the previous cell and reads the
    # module-level `store`, i.e. the one rebound at the top of this cell
    region_filtered_counts = get_mutation_sharing_counts(df_filtered_region, region_n_samples)
    region_unfiltered_counts = get_mutation_sharing_counts(df_unfiltered_region, region_n_samples)
    
    # Calculate filtered out counts
    # NOTE(review): bin-wise difference of two histograms; a bin can go negative
    # if filtering removes a variant from only some of its samples - confirm.
    region_filtered_out_counts = region_unfiltered_counts - region_filtered_counts
    
    # One row per sharing count (1..region_n_samples) for this region
    region_df = pd.DataFrame({
        'num_samples': region_filtered_counts.index,
        'filtered_count': region_filtered_counts.values,
        'filtered_out_count': region_filtered_out_counts.values,
        'unfiltered_count': region_unfiltered_counts.values,
        'region': region
    })
    all_counts.append(region_df)

plot_df = pd.concat(all_counts, ignore_index=True)

# Ensure region column is categorical with the desired order for plotting
plot_df['region'] = pd.Categorical(plot_df['region'], categories=regions, ordered=True)

# Plot: do NOT share x-axis, so each region can have its own x range
def region_stacked_barplot(data, **kwargs):
    """Draw one region's stacked sharing bars on the current axes.

    Written as a callback for ``FacetGrid.map_dataframe``: ``data`` is the
    subset of ``plot_df`` for a single facet, and drawing targets
    ``plt.gca()``. Extra keyword arguments are accepted and ignored.
    The bar color is looked up in the module-level ``region_to_color``.
    """
    # All rows of a facet belong to the same region, so the first row suffices
    facet_region = data['region'].iloc[0]
    kept_color = region_to_color.get(facet_region, "steelblue")

    sharing_counts = data['num_samples']
    kept_heights = data['filtered_count']
    axis = plt.gca()

    # Kept mutations at the bottom, filtered-out mutations stacked on top
    axis.bar(sharing_counts, kept_heights, color=kept_color, label="Filtered mutations")
    axis.bar(
        sharing_counts,
        data['filtered_out_count'],
        bottom=kept_heights,
        color="gray",
        label="Filtered out mutations",
    )

    # One integer-labelled tick per bar, no vertical grid lines
    axis.set_xticks(sharing_counts)
    axis.set_xticklabels(list(map(int, sharing_counts)))
    axis.grid(axis='x', visible=False)

# One facet per region, wrapped three to a row; x and y axes are independent per facet
g = sns.FacetGrid(plot_df, col="region", col_wrap=3, sharey=False, sharex=False, height=3, aspect=1.2)
g.map_dataframe(region_stacked_barplot)
g.set_axis_labels("Number of samples with mutation", "Number of mutations")
g.set_titles("Region: {col_name}")

# Put the legend in the upper-right corner of the facet at index 3.
# NOTE(review): with col_wrap=3, index 3 is the fourth facet (first of the second row),
# not the top-right one, and indexing fails with fewer than four regions - confirm the intended panel.
g.axes[3].legend(loc='upper right')

plt.tight_layout()

# Print the actual figure size
fig = g.figure
figsize_inches = fig.get_size_inches()
print(f"Actual figure size: {figsize_inches[0]:.1f} x {figsize_inches[1]:.1f} inches")

# Only write the PDF when the notebook-level save_plots switch is on
if save_plots:
    plt.savefig(f'figures/mutation_sharing_{query}-regions.pdf', bbox_inches='tight')
plt.show()
