In [22]:
import os
import sys
import math
import csv
from pathlib import Path

# plotting libs
try:
    import matplotlib.pyplot as plt
    from matplotlib import rcParams
    from matplotlib_venn import venn2
except Exception:
    import subprocess
    subprocess.check_call([sys.executable, "-m", "pip", "install", "matplotlib_venn"])
    import matplotlib.pyplot as plt
    from matplotlib import rcParams
    from matplotlib_venn import venn2

# Notebook-wide matplotlib defaults: let figures auto-adjust their layout and
# use a 10 pt base font for every plot created below.
rcParams.update({'figure.autolayout': True, 'font.size': 10})

In [23]:
#!/usr/bin/env python3
"""
compare_counts_per_sample_venns_with_side_mutation_counts.py

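Compare variant count files between two result roots (DS = DeepSomatic, FB = freebayes)
and draw per-sample Venn diagrams. For each SNV subplot, mutation-subtype tallies of the
non-shared SNVs are displayed in the white space to the left (DS-only) and right (FB-only).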

Usage:
    python compare_counts_per_sample_venns_with_side_mutation_counts.py \
        --rootA /path/to/results_ds \
        --rootB /path/to/results \
        --outdir /path/to/save/venns \
        [--include-mito]
"""


# Collapsed substitution classes: each key lists the change itself followed by
# the complementary change that is folded into the same class.
mutation_map = {
    "C>A": ["C>A", "G>T"],
    "C>G": ["C>G", "G>C"],
    "C>T": ["C>T", "G>A"],
    "T>A": ["T>A", "A>T"],
    "T>C": ["T>C", "A>G"],
    "T>G": ["T>G", "A>C"],
}


def assign_class_and_cpg(ref, alt):
    """
    Map a ref/alt pair onto a collapsed mutation class and a CpG flag.

    ref and alt are converted with str(). A ref starting with 'CpG' or 'GpC'
    sets the CpG flag; in every case only the first character of ref is used
    as the reference base. The upper-cased "<base>><alt>" change is then looked
    up in mutation_map, and the first class listing it wins.

    Returns (mut_class, is_cpg). mut_class is None when no class lists the
    change; is_cpg still reflects the ref prefix in that case.
    """
    ref_text, alt_text = str(ref), str(alt)

    # Context-labelled refs ('CpG...', 'GpC...') are the only ones flagged as CpG.
    is_cpg = ref_text.startswith(("CpG", "GpC"))

    # Slicing keeps an empty ref empty and reduces any longer ref to its first base.
    change = f"{ref_text[:1].upper()}>{alt_text.upper()}"

    mut_class = next(
        (
            label
            for label, members in mutation_map.items()
            if any(change == member.upper() for member in members)
        ),
        None,
    )
    return mut_class, is_cpg


def variant_type_from_filename(fn):
    """
    Classify a count file as 'INDEL' or 'SNV' from its file name.

    A name containing 'indel' (case-insensitive) yields 'INDEL'; every other
    name, whether or not it mentions 'snv', yields 'SNV'.
    """
    return 'INDEL' if 'indel' in fn.lower() else 'SNV'


def parse_count_file(path, forced_variant_type=None):
    """
    Read a whitespace-delimited count file into variant IDs.

    Blank lines, lines starting with '#', and lines with fewer than four
    fields are ignored. The first four fields of every remaining line are
    joined as "CHROM:POS:REF:ALT" to form the variant ID.

    Returns (variants, typed_map): the set of variant IDs and a dict mapping
    each ID to forced_variant_type, or to 'SNV' when that argument is falsy.
    """
    variant_type = forced_variant_type or 'SNV'
    typed_map = {}
    with open(path, 'rt') as handle:
        for raw in handle:
            record = raw.strip()
            if record == '' or record[0] == '#':
                continue
            fields = record.split()
            if len(fields) < 4:
                continue
            typed_map[':'.join(fields[:4])] = variant_type
    # Every ID is recorded in the map, so its keys are exactly the variant set.
    return set(typed_map), typed_map


def find_donors(root):
    """
    Return the names of the immediate sub-directories of root as a set.

    An empty set is returned when root is not an existing directory.
    """
    if not os.path.isdir(root):
        return set()
    return {
        name
        for name in os.listdir(root)
        if os.path.isdir(os.path.join(root, name))
    }


def index_files_in_donor(donor_dir):
    """
    Recursively index the count files below donor_dir by file name.

    A file qualifies when its lower-cased name ends with '.txt' and contains
    'count'. If the same name occurs more than once, the first path yielded by
    os.walk is kept. Paths are built by joining onto donor_dir, so they are
    absolute only when donor_dir is.
    """
    found = {}
    for folder, _subdirs, names in os.walk(donor_dir):
        for name in names:
            lowered = name.lower()
            if not lowered.endswith('.txt') or 'count' not in lowered:
                continue
            # setdefault keeps the earliest path seen for a repeated name.
            found.setdefault(name, os.path.join(folder, name))
    return found


def extract_grouping_from_filename(fn):
    """
    Derive grouping labels for a count file from substrings of its name.

    Returns a dict with keys 'sample' (the text before the first '.'),
    'region', 'clonality' and 'var_class'; the last three default to
    'unknown'. Matching is by plain case-insensitive substring, so a token
    embedded in a longer word also matches.
    """
    lowered = fn.lower()
    grouping = {
        'sample': fn.split('.')[0],
        'region': 'unknown',
        'clonality': 'unknown',
        'var_class': 'unknown',
    }
    # Rules apply in order and later matches overwrite earlier ones, which is
    # how 'subclonal' takes precedence over the 'clonal' it contains.
    rules = (
        ('mito', 'region', 'mito'),
        ('nuclear', 'region', 'nuclear'),
        ('all', 'clonality', 'all'),
        ('clonal', 'clonality', 'clonal'),
        ('subclonal', 'clonality', 'subclonal'),
        ('snv', 'var_class', 'SNV'),
        ('indel', 'var_class', 'INDEL'),
    )
    for token, field, label in rules:
        if token in lowered:
            grouping[field] = label
    return grouping


def make_subplot_text(ax, onlyA_cnt, onlyB_cnt, both_cnt, bd_onlyA, bd_onlyB, bd_both):
    """
    Write a three-line count caption below a subplot.

    Each line shows one category (DS only, FB only, Both) with its total and
    the SNV/INDEL split read from the matching breakdown dict. The text is
    positioned in axes coordinates just beneath the axes.
    """
    rows = (
        ("DS only", onlyA_cnt, bd_onlyA),
        ("FB only", onlyB_cnt, bd_onlyB),
        ("Both  ", both_cnt, bd_both),
    )
    caption = "\n".join(
        f"{label}: {count}  (SNV:{bd['SNV']} INDEL:{bd['INDEL']})"
        for label, count, bd in rows
    )
    # A negative y in axes coordinates places the caption under the diagram.
    ax.text(
        0.5, -0.18, caption,
        transform=ax.transAxes,
        fontsize=8, ha='center', va='top', family='monospace',
    )


def breakdown_counts(variant_ids, tmapA, tmapB):
    """
    Count how many of the given variant IDs are SNVs versus INDELs.

    The type of each ID is taken from tmapA, then tmapB, then defaults to
    'SNV'. Anything not typed exactly 'INDEL' is counted as an SNV.

    Returns {'SNV': n_snv, 'INDEL': n_indel}.
    """
    counts = {'SNV': 0, 'INDEL': 0}
    for vid in variant_ids:
        label = tmapA.get(vid) or tmapB.get(vid) or 'SNV'
        counts['INDEL' if label == 'INDEL' else 'SNV'] += 1
    return counts


def tally_snv_mutation_map(variant_ids, tmapA, tmapB):
    """
    Tally SNVs among variant_ids by collapsed mutation class, with CpG sub-counts.

    Parameters:
        variant_ids: iterable of "CHROM:POS:REF:ALT" strings.
        tmapA, tmapB: dicts mapping variant ID -> variant type; the type is
            looked up in tmapA, then tmapB, and defaults to 'SNV'.

    Returns a dict such as {'C>T': 5, 'CpG_C>T': 2, 'UNKNOWN': 1}. A CpG variant
    is counted both under its class and under the 'CpG_'-prefixed key. Variants
    whose change matches no class are counted under 'UNKNOWN'. IDs with fewer
    than four fields and variants not typed 'SNV' are skipped.
    """
    tally = {}
    for vid in variant_ids:
        # Split from the right: REF and ALT are always the last two fields, while
        # the contig name may itself contain ':' (e.g. HLA alt contigs), which a
        # left-anchored split would mistake for field separators.
        fields = vid.rsplit(':', 3)
        if len(fields) != 4:
            # skip malformed
            continue
        ref, alt = fields[2], fields[3]

        # only consider SNVs (typed maps may mark INDEL)
        vt = tmapA.get(vid) or tmapB.get(vid) or 'SNV'
        if vt != 'SNV':
            continue

        mut_class, is_cpg = assign_class_and_cpg(ref, alt)
        if mut_class is None:
            tally['UNKNOWN'] = tally.get('UNKNOWN', 0) + 1
            continue

        # count the class, and additionally the CpG-prefixed class when flagged
        tally[mut_class] = tally.get(mut_class, 0) + 1
        if is_cpg:
            cpg_key = f"CpG_{mut_class}"
            tally[cpg_key] = tally.get(cpg_key, 0) + 1
    return tally


def tally_to_compact_str(tally):
    """
    Serialise a tally dict as "key:count" pairs joined by ';'.

    Keys are emitted in sorted order so the output is deterministic, e.g.
    "C>T:5;CpG_C>T:1;T>C:2". An empty or missing tally gives "".
    """
    if not tally:
        return ""
    return ";".join(f"{key}:{tally[key]}" for key in sorted(tally))


def format_tally_lines_for_side(tally, max_lines=6):
    """
    Render a tally as a newline-separated block for display beside a subplot.

    Entries are ordered by descending count (ties keep the dict's order) and
    at most max_lines of them are shown as "key:count". When entries are left
    over, a final "...+N" line gives the sum of their counts. An empty tally
    gives the placeholder "(no SNV non-shared)".
    """
    if not tally:
        return "(no SNV non-shared)"
    ranked = sorted(tally.items(), key=lambda kv: kv[1], reverse=True)
    shown, hidden = ranked[:max_lines], ranked[max_lines:]
    rendered = [f"{name}:{count}" for name, count in shown]
    if hidden:
        rendered.append(f"...+{sum(count for _, count in hidden)}")
    return "\n".join(rendered)


def create_sample_figure(donor, sample, comparisons, out_png):
    """
    Draw one figure per sample with a DS-vs-FB Venn diagram per comparison and save it.

    Parameters:
        donor, sample: labels used in the figure title and the summary rows.
        comparisons: list of dicts, one per subplot:
          { 'fname': fname, 'grouping': grouping, 'variantsA': set, 'variantsB': set,
            'tmapA': dict, 'tmapB': dict }
        out_png: destination path of the PNG; parent directories are created.

    Returns:
        None when comparisons is empty (nothing is drawn or written); otherwise a
        list with one summary dict per comparison (counts, SNV/INDEL breakdowns,
        compact mutation tallies for the non-shared variants, and out_png).

    The figure is always closed, even if plotting or saving raises.
    """
    n = len(comparisons)
    if n == 0:
        return None

    # layout: 2 cols preferred
    ncols = 2
    nrows = math.ceil(n / ncols)
    fig = plt.figure(figsize=(10, max(4, 3 * nrows)))
    summary_for_rows = []

    try:
        # Use an explicit escape for the em dash so the title cannot be corrupted
        # by a wrong source-file encoding.
        fig.suptitle(f"{donor} \u2014 {sample}", fontsize=14)

        for i, comp in enumerate(comparisons):
            ax = fig.add_subplot(nrows, ncols, i + 1)
            A = comp['variantsA']
            B = comp['variantsB']
            tmapA = comp['tmapA']
            tmapB = comp['tmapB']

            venn2([A, B], set_labels=("DS", "FB"), ax=ax)

            # header for this subplot from grouping, e.g., "Clonal SNV"
            clon = comp['grouping'].get('clonality', 'unknown')
            varc = comp['grouping'].get('var_class', 'unknown')
            title = f"{clon.title()} {varc}" if clon != 'unknown' else f"{varc}"
            ax.set_title(title, fontsize=11)

            inter = A & B
            onlyA = A - B
            onlyB = B - A

            bd_onlyA = breakdown_counts(onlyA, tmapA, tmapB)
            bd_onlyB = breakdown_counts(onlyB, tmapA, tmapB)
            bd_both = breakdown_counts(inter, tmapA, tmapB)

            make_subplot_text(ax, len(onlyA), len(onlyB), len(inter), bd_onlyA, bd_onlyB, bd_both)

            # mutation-class tallies for non-shared SNVs
            onlyA_mut_tally = tally_snv_mutation_map(onlyA, tmapA, tmapB)
            onlyB_mut_tally = tally_snv_mutation_map(onlyB, tmapA, tmapB)

            # side text only for SNV subplots
            if varc == 'SNV':
                # x < 0 in axes coordinates places the text outside, left of the venn
                ax.text(-0.32, 0.5, format_tally_lines_for_side(onlyA_mut_tally, max_lines=6),
                        transform=ax.transAxes,
                        fontsize=8, ha='left', va='center', family='monospace')
                # x > 1 places the text outside, right of the venn
                ax.text(1.02, 0.5, format_tally_lines_for_side(onlyB_mut_tally, max_lines=6),
                        transform=ax.transAxes,
                        fontsize=8, ha='left', va='center', family='monospace')

            # store summary row (include compact semicolon strings for TSV)
            summary_for_rows.append({
                'donor': donor,
                'sample': sample,
                'grouping': title,
                'DS_count': len(A),
                'FB_count': len(B),
                'onlyDS': len(onlyA),
                'onlyFB': len(onlyB),
                'both': len(inter),
                'DS_SNV_only': bd_onlyA['SNV'],
                'DS_INDEL_only': bd_onlyA['INDEL'],
                'FB_SNV_only': bd_onlyB['SNV'],
                'FB_INDEL_only': bd_onlyB['INDEL'],
                'both_SNV': bd_both['SNV'],
                'both_INDEL': bd_both['INDEL'],
                'onlyDS_mut_map': tally_to_compact_str(onlyA_mut_tally),
                'onlyFB_mut_map': tally_to_compact_str(onlyB_mut_tally),
                'out_png': out_png
            })

        # if there are empty subplot slots, hide them
        for j in range(n, nrows * ncols):
            fig.add_subplot(nrows, ncols, j + 1).axis('off')

        outp = Path(out_png)
        outp.parent.mkdir(parents=True, exist_ok=True)
        # Operate on this figure explicitly rather than on pyplot's implicit
        # "current figure". The tight bbox on save keeps the left/right side
        # texts from being clipped.
        fig.tight_layout(rect=[0, 0, 1, 0.96])
        fig.savefig(outp, dpi=200, bbox_inches='tight')
    finally:
        # Release the figure even on failure so a long run does not accumulate
        # open figures.
        plt.close(fig)
    return summary_for_rows


def main(rootA, rootB, outdir, include_mito=False):
    """
    Compare count files between two result roots and write per-sample Venn figures.

    Parameters:
        rootA: root directory of the first result set (labelled DS in the figures).
        rootB: root directory of the second result set (labelled FB in the figures).
        outdir: directory receiving <donor>/<sample>.venns.png files and the
            summary TSV 'venn_summary_per_sample.tsv'; created if missing.
        include_mito: when False, files with 'mito' in their name are skipped.

    Donors are the sub-directories present under both roots. Within a donor,
    only count files whose names occur under both roots are compared, grouped
    by sample (the file name up to its first '.'). Returns None; if the roots
    share no donor, a message is printed and nothing is written.
    """
    rootA = os.path.abspath(rootA)
    rootB = os.path.abspath(rootB)
    outdir = os.path.abspath(outdir)
    summary_rows = []

    donors_common = sorted(find_donors(rootA) & find_donors(rootB))
    if not donors_common:
        print("No common donors found between the two roots. Exiting.")
        return

    for donor in donors_common:
        filesA = index_files_in_donor(os.path.join(rootA, donor))
        filesB = index_files_in_donor(os.path.join(rootB, donor))
        # only consider exact filename matches present in both
        common_files = sorted(set(filesA) & set(filesB))
        if not common_files:
            continue

        # group matches by sample (sample = filename before first dot)
        grouped = {}
        for fname in common_files:
            if (not include_mito) and ('mito' in fname.lower()):
                continue
            grouped.setdefault(fname.split('.')[0], []).append(fname)

        # for each sample, create comparisons list and then a single figure
        for sample, fnames in grouped.items():
            comparisons = []
            for fname in sorted(fnames):
                # The variant type comes from the file name and is applied to
                # every record of both files.
                vt = variant_type_from_filename(fname)
                variantsA, tmapA = parse_count_file(filesA[fname], forced_variant_type=vt)
                variantsB, tmapB = parse_count_file(filesB[fname], forced_variant_type=vt)
                comparisons.append({
                    'fname': fname,
                    'grouping': extract_grouping_from_filename(fname),
                    'variantsA': variantsA,
                    'variantsB': variantsB,
                    'tmapA': tmapA,
                    'tmapB': tmapB
                })

            if not comparisons:
                continue
            out_png = os.path.join(outdir, donor, f"{sample}.venns.png")
            summary_for_sample = create_sample_figure(donor, sample, comparisons, out_png)
            if summary_for_sample:
                summary_rows.extend(summary_for_sample)
                print(f"Wrote sample venns: {out_png}")

    # outdir is otherwise only created as a side effect of saving a figure, so
    # create it here: the summary must be writable even when no figure was made.
    os.makedirs(outdir, exist_ok=True)

    # write compact summary TSV
    summary_path = os.path.join(outdir, "venn_summary_per_sample.tsv")
    with open(summary_path, 'w', newline='') as tsvf:
        fieldnames = [
            'donor', 'sample', 'grouping', 'DS_count', 'FB_count',
            'onlyDS', 'onlyFB', 'both',
            'DS_SNV_only', 'DS_INDEL_only', 'FB_SNV_only', 'FB_INDEL_only', 'both_SNV', 'both_INDEL',
            'onlyDS_mut_map', 'onlyFB_mut_map', 'out_png'
        ]
        writer = csv.DictWriter(tsvf, fieldnames=fieldnames, delimiter='\t')
        writer.writeheader()
        writer.writerows(summary_rows)
    print(f"Summary written to {summary_path}")
    print("Done.")





In [None]:
if __name__ == "__main__":
    import argparse
    import sys

    # Inside a Jupyter kernel __name__ is "__main__" too, but sys.argv then holds
    # the kernel launcher's own arguments: argparse would reject them and raise
    # SystemExit, aborting "Run All". Only parse the command line when no
    # ipykernel is loaded, i.e. when this code runs as a plain script.
    if "ipykernel" not in sys.modules:
        ap = argparse.ArgumentParser(description="Compare count.txt files between two roots and produce per-sample Venn figures with side mutation-type tallies.")
        ap.add_argument("--rootA", required=True, help="Path to results_ds (DeepSomatic) root")
        ap.add_argument("--rootB", required=True, help="Path to results (freebayes) root")
        ap.add_argument("--outdir", required=True, help="Directory to save per-sample venn PNGs and summary TSV")
        ap.add_argument("--include-mito", action='store_true', help="Include files with 'mito' in the filename (default: skip mito)")
        args = ap.parse_args()
        main(args.rootA, args.rootB, args.outdir, include_mito=args.include_mito)

In [24]:
# Notebook run configuration: the two result roots to compare and the output
# directory for the figures and summary TSV. These are machine-specific
# absolute paths; adjust them before running elsewhere.
rootA = os.path.abspath("/uufs/chpc.utah.edu/common/HIPAA/u1264408/u1264408/Git/SEMIColon/data/output/CellCut/results_ds")
rootB = os.path.abspath("/uufs/chpc.utah.edu/common/HIPAA/u1264408/u1264408/Git/SEMIColon/data/output/CellCut/results")
outdir = os.path.abspath("/uufs/chpc.utah.edu/common/HIPAA/u1264408/u1264408/Git/SEMIColon/data/output/CellCut/overlap")

In [25]:
# Run the comparison on the configured roots, skipping files with 'mito' in their name.
main(rootA, rootB, outdir, include_mito=False)

Wrote sample venns: /uufs/chpc.utah.edu/common/HIPAA/u1264408/u1264408/Git/SEMIColon/data/output/CellCut/overlap/GB103/103_RE.venns.png
Wrote sample venns: /uufs/chpc.utah.edu/common/HIPAA/u1264408/u1264408/Git/SEMIColon/data/output/CellCut/overlap/GB103/A6__103_CE.venns.png
Wrote sample venns: /uufs/chpc.utah.edu/common/HIPAA/u1264408/u1264408/Git/SEMIColon/data/output/CellCut/overlap/GB103/B8_103_SI.venns.png
Wrote sample venns: /uufs/chpc.utah.edu/common/HIPAA/u1264408/u1264408/Git/SEMIColon/data/output/CellCut/overlap/GB103/C1_103_TR.venns.png
Wrote sample venns: /uufs/chpc.utah.edu/common/HIPAA/u1264408/u1264408/Git/SEMIColon/data/output/CellCut/overlap/GB103/C3_103_DE.venns.png
Wrote sample venns: /uufs/chpc.utah.edu/common/HIPAA/u1264408/u1264408/Git/SEMIColon/data/output/CellCut/overlap/GB115/B18_115_TC.venns.png
Wrote sample venns: /uufs/chpc.utah.edu/common/HIPAA/u1264408/u1264408/Git/SEMIColon/data/output/CellCut/overlap/GB115/B1_115_SI.venns.png
Wrote sample venns: /uufs/ch

In [None]:
# NOTE(review): this cell needs `args` from the argparse cell, which looks unusable
# inside a notebook kernel; it appears to be a leftover of the
# main(rootA, rootB, outdir, ...) call above — confirm and remove if unused.
main(args.rootA, args.rootB, args.outdir, include_mito=args.include_mito)