# Cross-Species Consensus Peak Pipeline (v2)

**Goal:** Build a unified cross-species ATAC-seq consensus peak set across 6 primate species,
with full provenance tracking and gene annotation.

**Pipeline overview:**
1. Lift all non-human species peaks to hg38
2. Merge lifted peaks + human peaks into a unified consensus, tracking which species contributed each peak
3. Identify **human-specific peaks** (human peaks not overlapping any non-human lifted peak)
4. Lift unified consensus back to each species genome
5. Identify **species-specific peaks** (original species peaks not covered by any liftback peak)
6. Annotate all peaks with closest gene (from species-appropriate GTF) and distance

**Outputs:**
- `unified_peak_NNNNNN` -- consensus peaks in hg38 with species-detection annotation
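- `human_peak_NNNNNN` -- human-specific peaks (hg38): human peaks that do not overlap any lifted non-human peak
- `{species}_peak_NNNNNN` -- species-specific peaks in native coordinates: original species peaks not covered by any peak lifted back from the unified consensus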
- `peak_annotation.tsv` -- master annotation: peak_id, type, species detected, closest gene, distance

## 1. Load Packages and Define Configuration

In [None]:
import sys
import os
import pandas as pd
import numpy as np
from pathlib import Path

# Add the src directory to the path
sys.path.insert(0, os.path.abspath(".."))

from src.cross_species import (
    cross_species_consensus_pipeline,
    merge_with_species_tracking,
    find_human_specific_peaks,
    find_species_specific_peaks,
    create_peak_annotation,
    extract_gene_bed_from_gtf,
    annotate_with_closest_gene,
    add_peak_ids,
    build_master_annotation,
    cross_map_species_specific_peaks,
    DEFAULT_GTF_FILES,
    REVERSE_CHAIN_FILES,
    CROSS_SPECIES_ROUTES,
)
from src.liftover import DEFAULT_CHAIN_DIR

print("All imports loaded successfully")

In [None]:
# =============================================================================
# Configuration -- edit paths here
# =============================================================================

# Root directory that holds the per-species consensus peak calls
PEAKS_BASE = "/cluster/project/treutlein/USERS/jjans/analysis/adult_intestine/peaks"

# Human consensus peaks (kept separate from the non-human species)
HUMAN_BED = f"{PEAKS_BASE}/consensus_peak_calling_Human/Consensus_Peaks_Filtered_500.bed"

# Non-human species, in the order used by the per-species path mappings below
_NON_HUMAN_SPECIES = ("Bonobo", "Chimpanzee", "Gorilla", "Macaque", "Marmoset")

# Non-human species consensus peaks (one BED per species, same file layout)
SPECIES_BEDS = {
    name: f"{PEAKS_BASE}/consensus_peak_calling_{name}/Consensus_Peaks_Filtered_500.bed"
    for name in _NON_HUMAN_SPECIES
}

# Pre-lifted BED files (already in hg38 coords from previous liftover run)
# Set to None to re-run liftover from scratch
PRE_LIFTED_BEDS = {
    name: f"{PEAKS_BASE}/lifted_consensus_peaks/Consensus_Peaks_Filtered_500.hg38_{name}.bed"
    for name in _NON_HUMAN_SPECIES
}

# Where the pipeline writes its results
OUTPUT_DIR = f"{PEAKS_BASE}/cross_species_consensus_v2"

# Path to the liftOver executable
LIFTOVER_PATH = "/cluster/project/treutlein/jjans/software/miniforge3/envs/genomes/bin/liftOver"

# Directory containing the chain files
CHAIN_DIR = DEFAULT_CHAIN_DIR

# GTF files for gene annotation; copied so edits here leave the defaults untouched
GTF_FILES = DEFAULT_GTF_FILES.copy()

# Per-species minimum match ratios for liftOver
# Great apes are closer to human -> higher match; more distant species -> lower
MIN_MATCH = {
    "Bonobo":      0.9,
    "Chimpanzee":  0.9,
    "Gorilla":     0.9,
    "Macaque":     0.8,
    "Marmoset":    0.6,
}

MERGE_DISTANCE = 0  # Max distance between peaks to merge (0 = must overlap)
NCPU = 16           # Parallel workers for liftover + cross-mapping

# Echo the effective configuration so it is recorded with the notebook output
pre_lifted_note = "Yes (skipping step 1)" if PRE_LIFTED_BEDS else "No (will run liftover)"
print(f"Output directory: {OUTPUT_DIR}")
print(f"Chain file dir:   {CHAIN_DIR}")
print(f"liftOver binary:  {LIFTOVER_PATH}")
print(f"Species:          {list(SPECIES_BEDS)}")
print(f"Min match rates:  {MIN_MATCH}")
print(f"Pre-lifted beds:  {pre_lifted_note}")

## 2. Validate Input Files

Check that all input peak files, chain files, and GTF files exist before running the pipeline.

In [None]:
# =============================================================================
# Validate all input files exist
# =============================================================================
from src.liftover import CHAIN_FILES, get_chain_file


def _count_bed_records(bed_path):
    """Count the non-blank lines of *bed_path* that do not start with '#'.

    The file is read inside a ``with`` block so the handle is closed as soon
    as counting finishes.
    """
    with open(bed_path) as handle:
        return sum(1 for line in handle if line.strip() and not line.startswith("#"))


def _check_peak_files(beds):
    """Print one OK/MISSING line per entry of *beds* (species -> BED path).

    Returns True only when every listed file exists.
    """
    found_all = True
    for species, bed in beds.items():
        if os.path.exists(bed):
            print(f"  OK  {species}: {_count_bed_records(bed):,} peaks")
        else:
            print(f"  MISSING  {species}: {bed}")
            found_all = False
    return found_all


def _check_path(label, path, shown):
    """Print '  OK|MISSING  label: shown' for *path*; return True if it exists."""
    exists = os.path.exists(path)
    print(f"  {'OK' if exists else 'MISSING'}  {label}: {shown}")
    return exists


all_ok = True

# Check human BED
print("--- Human peaks ---")
if os.path.exists(HUMAN_BED):
    print(f"  OK  Human: {_count_bed_records(HUMAN_BED):,} peaks")
else:
    print(f"  MISSING  {HUMAN_BED}")
    all_ok = False

# Check non-human BED files
print("\n--- Non-human species peaks ---")
if not _check_peak_files(SPECIES_BEDS):
    all_ok = False

# Check pre-lifted BED files (if provided)
if PRE_LIFTED_BEDS:
    print("\n--- Pre-lifted BED files (hg38 coords) ---")
    if not _check_peak_files(PRE_LIFTED_BEDS):
        all_ok = False

# Check chain files (forward: species -> hg38)
# Only needed if we're actually running liftover (no pre-lifted beds, or a
# pre-lifted file is absent for at least one species)
needs_liftover = not PRE_LIFTED_BEDS or any(
    sp not in PRE_LIFTED_BEDS or not os.path.exists(PRE_LIFTED_BEDS[sp])
    for sp in SPECIES_BEDS
)

if needs_liftover:
    print("\n--- Forward chain files (species -> hg38) ---")
    for species in SPECIES_BEDS:
        # Marmoset is looked up under two CHAIN_FILES keys (step1/step2);
        # every other species under its own name.
        if species == "Marmoset":
            chain_keys = ["Marmoset_step1", "Marmoset_step2"]
        else:
            chain_keys = [species]
        for chain_key in chain_keys:
            chain_name = CHAIN_FILES[chain_key]
            if not _check_path(chain_key, os.path.join(CHAIN_DIR, chain_name), chain_name):
                all_ok = False
else:
    print("\n--- Forward chain files --- SKIPPED (using pre-lifted files)")

# Check reverse chain files (hg38 -> species, always needed for liftback)
print("\n--- Reverse chain files (hg38 -> species) ---")
for key, chain in REVERSE_CHAIN_FILES.items():
    if not _check_path(key, os.path.join(CHAIN_DIR, chain), chain):
        all_ok = False

# Check GTF files
print("\n--- GTF files for gene annotation ---")
for species, gtf in GTF_FILES.items():
    if not _check_path(species, gtf, os.path.basename(gtf)):
        all_ok = False

# Check liftOver binary
print(f"\n--- liftOver binary ---")
status = "OK" if os.path.exists(LIFTOVER_PATH) else "MISSING"
print(f"  {status}  {LIFTOVER_PATH}")
if status == "MISSING":
    all_ok = False

print(f"\n{'All inputs validated successfully' if all_ok else 'WARNING: Some inputs are missing!'}")

## 3. Run Cross-Species Consensus Pipeline

This executes all 6 steps:
1. **Lift to hg38** -- liftOver each non-human species to human genome (skipped if `PRE_LIFTED_BEDS` are provided)
2. **Merge with tracking** -- merge all species (incl. human) with bedtools, recording which species contributed
3. **Human-specific** -- human peaks not overlapping any lifted non-human peak
4. **Lift back** -- lift unified consensus back to each species' genome (uses per-species `MIN_MATCH`)
5. **Species-specific** -- original species peaks not covered by any liftback peak
6. **Annotate** -- closest gene and distance for every peak

In [None]:
# =============================================================================
# Run the full pipeline
# =============================================================================
# Collect every argument in one mapping so the configuration handed to the
# pipeline can be inspected before (or after) the call.
pipeline_kwargs = {
    "species_beds": SPECIES_BEDS,
    "human_bed": HUMAN_BED,
    "output_dir": OUTPUT_DIR,
    "chain_dir": CHAIN_DIR,
    "liftover_path": LIFTOVER_PATH,
    "min_match": MIN_MATCH,
    "merge_distance": MERGE_DISTANCE,
    "peak_prefix": "unified",
    "gtf_files": GTF_FILES,
    "pre_lifted_beds": PRE_LIFTED_BEDS,
    "verbose": True,
    "ncpu": NCPU,
}
results = cross_species_consensus_pipeline(**pipeline_kwargs)

# The returned mapping carries a status flag and a human-readable message
print(f"\nPipeline status: {results['status']}")
print(f"Message: {results['message']}")

## 4. Inspect Output Files

Check what was produced in each output subdirectory.

In [None]:
# =============================================================================
# List all output files with sizes
# =============================================================================
import subprocess


def _format_size(n_bytes):
    """Render a byte count as 'X.Y MB' above 1e6 bytes, otherwise as whole KB."""
    if n_bytes > 1e6:
        return f"{n_bytes / 1e6:.1f} MB"
    return f"{n_bytes / 1e3:.0f} KB"


def _count_lines(path):
    """Count the lines of *path*, replacing undecodable bytes instead of failing.

    The handle is closed on exit from the ``with`` block.
    """
    with open(path, errors="replace") as handle:
        return sum(1 for _ in handle)


print("Output directory structure:")
print("=" * 70)

for dirpath, dirnames, filenames in os.walk(OUTPUT_DIR):
    # Sorting in place makes the top-down walk visit sub-directories in
    # alphabetical order, so each directory is listed right above its contents.
    dirnames.sort()

    # Depth comes from the path relative to the root, which stays correct even
    # if OUTPUT_DIR carries a trailing separator.
    reldir = os.path.relpath(dirpath, OUTPUT_DIR)
    level = 0 if reldir == os.curdir else reldir.count(os.sep) + 1
    indent = "  " * level
    print(f"{indent}{os.path.basename(os.path.normpath(dirpath))}/")

    subindent = "  " * (level + 1)
    for f in sorted(filenames):
        fpath = os.path.join(dirpath, f)
        size_str = _format_size(os.path.getsize(fpath))
        # Best effort: a file that cannot be read as text is reported as binary
        try:
            n_lines = _count_lines(fpath)
        except (OSError, ValueError):
            print(f"{subindent}{f:<50s} {size_str:>10s}  (binary)")
        else:
            print(f"{subindent}{f:<50s} {size_str:>10s}  ({n_lines:,} lines)")

## 5. Explore the Unified Consensus Peaks

The unified BED has columns: `chr, start, end, peak_id, species_detected`. The species_detected column is a comma-separated list.

In [None]:
# =============================================================================
# Load and explore unified consensus peaks
# =============================================================================
unified_bed = results["output_files"]["unified_consensus"]
unified_columns = ["chr", "start", "end", "peak_id", "species_detected"]
unified_df = pd.read_csv(unified_bed, sep="\t", header=None, names=unified_columns)
n_unified = len(unified_df)

print(f"Unified consensus peaks: {n_unified:,}")
print("\nFirst 10 peaks:")
print(unified_df.head(10).to_string(index=False))

# Derive one boolean column per species from the species_detected string
all_species = ["Bonobo", "Chimpanzee", "Gorilla", "Human", "Macaque", "Marmoset"]
detection_cols = [f"detected_in_{sp}" for sp in all_species]
for sp, flag_col in zip(all_species, detection_cols):
    unified_df[flag_col] = unified_df["species_detected"].str.contains(sp, na=False)

# Number of species in which each peak was detected
unified_df["n_species"] = unified_df[detection_cols].sum(axis=1)

print("\nSpecies detection distribution:")
print(unified_df["n_species"].value_counts().sort_index().to_string())

# Share of unified peaks detected in each species
print("\nPer-species detection:")
for sp, flag_col in zip(all_species, detection_cols):
    n_detected = unified_df[flag_col].sum()
    share = n_detected / n_unified * 100
    print(f"  {sp:<15s}: {n_detected:>10,} peaks ({share:5.1f}%)")

## 5b. UpSet Plot — Species Detection Overlap

UpSet plot showing which species combinations share peaks.
Each column in the intersection matrix represents a unique combination
of species in which a peak was detected.

In [None]:
# =============================================================================
# UpSet plot: species detection overlap
# =============================================================================
from src.visualization import plot_upset

# Species order used for both the plot and the printed intersections
all_species = ["Human", "Bonobo", "Chimpanzee", "Gorilla", "Macaque", "Marmoset"]
species_cols = [f"detected_in_{sp}" for sp in all_species]
plot_file = os.path.join(OUTPUT_DIR, "upset_species_detection.png")

upset_options = {
    "set_columns": species_cols,
    "set_labels": all_species,
    "top_n": 30,
    "color": "steelblue",
    "title": "Peak Detection Across Primate Species",
    "saveas": plot_file,
}
fig = plot_upset(unified_df, **upset_options)

# --- print top intersections ---
# Each peak is reduced to a 0/1 tuple (one entry per species); identical tuples
# are counted to obtain the size of every species combination.
upset_df = unified_df[species_cols].copy()
upset_df.columns = all_species
pattern = upset_df.astype(int).apply(tuple, axis=1)
combo_counts = pattern.value_counts().sort_values(ascending=False)

n_peaks = len(unified_df)
print("\nTop 15 intersections:")
for membership, n_hits in combo_counts.head(15).items():
    members = ", ".join(sp for sp, present in zip(all_species, membership) if present)
    share = n_hits / n_peaks * 100
    print(f"  {members:<55s} {n_hits:>8,} peaks ({share:5.1f}%)")

## 6. Explore Human-Specific and Species-Specific Peaks

In [None]:
# =============================================================================
# Human-specific peaks
# =============================================================================
peak_bed_columns = ["chr", "start", "end", "peak_id"]

hs_bed = results["output_files"]["human_specific"]
hs_df = pd.read_csv(hs_bed, sep="\t", header=None, names=peak_bed_columns)

print(f"Human-specific peaks: {len(hs_df):,}")
print("First 5:")
print(hs_df.head().to_string(index=False))

# =============================================================================
# Species-specific peaks
# =============================================================================
banner = "=" * 70
print(f"\n{banner}")
print("Species-specific peaks:")
print(banner)

# species -> number of species-specific peaks, filled only for files that exist
species_specific_counts = {}
for species in SPECIES_BEDS:
    key = f"species_specific_{species}"
    if key not in results["output_files"]:
        continue
    sp_bed = results["output_files"][key]
    if not os.path.exists(sp_bed):
        continue

    sp_df = pd.read_csv(sp_bed, sep="\t", header=None, names=peak_bed_columns)
    species_specific_counts[species] = len(sp_df)
    example_ids = ", ".join(sp_df["peak_id"].head(3))
    print(f"\n  {species}: {len(sp_df):,} specific peaks")
    print(f"  Example IDs: {example_ids}")

# Summary bar
print(f"\n{banner}")
print("Summary:")
print(f"  Unified consensus:   {len(unified_df):>10,}")
print(f"  Human-specific:      {len(hs_df):>10,}")
for sp, n in species_specific_counts.items():
    print(f"  {sp}-specific: {n:>10,}")
total = len(unified_df) + len(hs_df) + sum(species_specific_counts.values())
print(f"  {'TOTAL':>22s}: {total:>10,}")

## 7. Explore Peak Annotation File

The master annotation file contains: `peak_id, chr, start, end, peak_type, species_detected, closest_gene, distance_to_gene`.

In [None]:
# =============================================================================
# Load and explore the master peak annotation
# =============================================================================
annotation_file = results["output_files"]["annotation"]
annot_df = pd.read_csv(annotation_file, sep="\t")

print(f"Total annotated peaks: {len(annot_df):,}")
print(f"\nColumns: {list(annot_df.columns)}")
print("\nPeak types:")
print(annot_df["peak_type"].value_counts().to_string())

# Only non-negative distances enter the summary statistics
print("\nDistance to nearest gene (summary):")
valid_dist = annot_df.loc[annot_df["distance_to_gene"] >= 0, "distance_to_gene"]
print(f"  Median: {valid_dist.median():,.0f} bp")
print(f"  Mean:   {valid_dist.mean():,.0f} bp")
print(f"  At TSS (0 bp): {(valid_dist == 0).sum():,}")
# Cumulative counts below each distance cutoff (label text carries its own padding)
for cutoff_label, cutoff_bp in (("  < 1 kb:  ", 1000), ("  < 10 kb: ", 10000), ("  < 100 kb: ", 100000)):
    print(f"{cutoff_label}{(valid_dist < cutoff_bp).sum():,}")

# Show the first three rows of every peak type
sample_columns = ["peak_id", "chr", "start", "end", "species_detected", "closest_gene", "distance_to_gene"]
print("\nSample rows from each peak type:")
for peak_type in annot_df["peak_type"].unique():
    first_rows = annot_df[annot_df["peak_type"] == peak_type].head(3)
    print(f"\n  {peak_type}:")
    print(first_rows[sample_columns].to_string(index=False))

## 7b. Gene Conservation After Liftback

For each unified peak we know the closest human gene (hg38 annotation).
After lifting those same peaks back to each species' genome, we can ask:
**is the closest gene in the species genome the same gene?**

This checks whether the syntenic neighbourhood is preserved across the liftover round-trip.
Gene name matching is case-insensitive and uses the gene symbol (not Ensembl ID),
so conservation rates are a lower bound (orthologs with different names will be counted as mismatches).

In [None]:
# =============================================================================
# Gene conservation analysis: human annotation vs species liftback annotation
# =============================================================================
import subprocess
import tempfile

annotation_dir = os.path.join(OUTPUT_DIR, "06_annotation")

# Closest-gene annotation of the unified peaks
human_annot = pd.read_csv(
    os.path.join(annotation_dir, "unified_gene_annotation.tsv"),
    sep="\t",
)
n_human_total = len(human_annot)
print(f"Human annotation: {n_human_total:,} unified peaks with closest gene")

# Gene names starting with "ENS" are counted as Ensembl IDs, the rest as symbols
n_human_ens = human_annot["closest_gene"].str.startswith("ENS").sum()
n_human_named = n_human_total - n_human_ens
print(f"  Human gene symbols: {n_human_named:,} ({n_human_named/n_human_total*100:.0f}%)")
print(f"  Human Ensembl IDs:  {n_human_ens:,} ({n_human_ens/n_human_total*100:.0f}%)")

# Gene BED directory (already extracted by pipeline) and liftback directory
gene_bed_dir = os.path.join(annotation_dir, "gene_beds")
liftback_dir = os.path.join(OUTPUT_DIR, "04_lifted_back")

# species -> dict of conservation statistics, filled by the loop that follows
conservation_results = {}

def _closest_gene_per_liftback_peak(liftback_file, gene_bed):
    """Return the closest gene for every peak in *liftback_file*.

    Both inputs are coordinate-sorted into a temporary directory and passed to
    ``bedtools closest -d -t first``. If only one of the two files uses a
    ``chr`` prefix on its first line, a harmonized copy of the gene BED is
    written so that chromosome names match the liftback file.

    External commands are run with argument lists (no shell), so paths that
    contain spaces or shell metacharacters are handled correctly.

    Returns:
        DataFrame with columns ``peak_id`` (column 4 of the liftback BED),
        ``sp_gene`` (column 4 of the gene BED) and ``sp_distance`` (last
        bedtools column; -1 when bedtools printed ``.``). The columns exist
        even when no row could be parsed, so a later merge on ``peak_id``
        does not fail on empty results.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        # Detect chr prefix mismatch from the first line of each file
        with open(liftback_file) as f:
            lb_has_chr = f.readline().split('\t')[0].startswith("chr")
        with open(gene_bed) as f:
            gene_has_chr = f.readline().split('\t')[0].startswith("chr")

        gene_bed_use = gene_bed
        if lb_has_chr and not gene_has_chr:
            gene_bed_use = os.path.join(tmpdir, "genes_chr.bed")
            with open(gene_bed) as fin, open(gene_bed_use, 'w') as fout:
                for line in fin:
                    if line.strip():
                        fout.write("chr" + line)
        elif not lb_has_chr and gene_has_chr:
            gene_bed_use = os.path.join(tmpdir, "genes_nochr.bed")
            with open(gene_bed) as fin, open(gene_bed_use, 'w') as fout:
                for line in fin:
                    if line.strip():
                        fout.write(line[3:] if line.startswith("chr") else line)

        # Sort both files by chromosome, then numerically by start
        lb_sorted = os.path.join(tmpdir, "lb_sorted.bed")
        gene_sorted = os.path.join(tmpdir, "gene_sorted.bed")
        for unsorted_path, sorted_path in ((liftback_file, lb_sorted), (gene_bed_use, gene_sorted)):
            with open(sorted_path, 'w') as sorted_out:
                subprocess.run(["sort", "-k1,1", "-k2,2n", unsorted_path],
                               stdout=sorted_out, check=True)

        # bedtools closest
        closest_out = os.path.join(tmpdir, "closest.tsv")
        with open(closest_out, 'w') as closest_fh:
            subprocess.run(
                ["bedtools", "closest", "-a", lb_sorted, "-b", gene_sorted, "-d", "-t", "first"],
                stdout=closest_fh, check=True,
            )

        # The liftback column count tells us where the gene BED columns start
        with open(lb_sorted) as f:
            n_lb_cols = len(f.readline().strip().split('\t'))

        rows = []
        with open(closest_out) as f:
            for line in f:
                parts = line.strip().split('\t')
                # Expect liftback columns + 6 gene BED columns + distance
                if len(parts) < n_lb_cols + 7:
                    continue
                rows.append({
                    "peak_id": parts[3],
                    "sp_gene": parts[n_lb_cols + 3],
                    "sp_distance": int(parts[-1]) if parts[-1] != '.' else -1,
                })

    return pd.DataFrame(rows, columns=["peak_id", "sp_gene", "sp_distance"])


for species in SPECIES_BEDS:
    liftback_file = os.path.join(liftback_dir, f"unified_consensus_{species}.bed")
    gene_bed = os.path.join(gene_bed_dir, f"{species}_genes.bed")

    if not os.path.exists(liftback_file) or not os.path.exists(gene_bed):
        print(f"  Skipping {species} - missing files")
        continue

    # Count genes in species gene BED (one line per entry)
    with open(gene_bed) as f:
        sp_n_genes = sum(1 for _ in f)

    # Closest species gene for every liftback peak
    sp_annot = _closest_gene_per_liftback_peak(liftback_file, gene_bed)

    # Merge with human annotation; copy so the new columns are added to an
    # independent frame rather than to a slice of the merge result
    merged = human_annot.merge(sp_annot, on="peak_id", how="inner")
    merged = merged[merged["sp_distance"] >= 0].copy()  # drop unmatched

    # Flag Ensembl IDs
    merged["human_is_ens"] = merged["closest_gene"].str.startswith("ENS")
    merged["sp_is_ens"] = merged["sp_gene"].str.startswith("ENS")

    # Compare gene names (case-insensitive)
    merged["same_gene"] = merged["closest_gene"].str.upper() == merged["sp_gene"].str.upper()

    # Categories
    both_named = ~merged["human_is_ens"] & ~merged["sp_is_ens"]
    both_ens = merged["human_is_ens"] & merged["sp_is_ens"]
    mixed = (merged["human_is_ens"] & ~merged["sp_is_ens"]) | (~merged["human_is_ens"] & merged["sp_is_ens"])

    # Restrict to peaks where BOTH annotations are near a gene (<10kb)
    near_gene = (merged["distance_to_gene"] < 10000) & (merged["sp_distance"] < 10000)
    near_both_named = both_named & near_gene

    n_total = len(merged)
    n_both_named = both_named.sum()
    n_same_named = merged.loc[both_named, "same_gene"].sum()
    n_near_both_named = near_both_named.sum()
    n_same_near = merged.loc[near_both_named, "same_gene"].sum()

    conservation_results[species] = {
        "total_peaks": n_total,
        "sp_n_genes": sp_n_genes,
        "both_named": n_both_named,
        "same_named": n_same_named,
        "pct_named": n_same_named / n_both_named * 100 if n_both_named else 0,
        "near_both_named": n_near_both_named,
        "same_near": n_same_near,
        "pct_near": n_same_near / n_near_both_named * 100 if n_near_both_named else 0,
        "both_ens": both_ens.sum(),
        "mixed": mixed.sum(),
    }

    print(f"\n{species} ({sp_n_genes:,} gene TSSs in GTF):")
    print(f"  Peaks compared:                      {n_total:>10,}")
    print(f"  Both have gene symbol:               {n_both_named:>10,}  -> same gene: {n_same_named:,} ({n_same_named/n_both_named*100:.1f}%)" if n_both_named else "")
    print(f"  Both named + both <10kb from gene:   {n_near_both_named:>10,}  -> same gene: {n_same_near:,} ({n_same_near/n_near_both_named*100:.1f}%)" if n_near_both_named else "")
    print(f"  Both Ensembl IDs (can't compare):    {both_ens.sum():>10,}")
    print(f"  Mixed (Ensembl vs symbol):            {mixed.sum():>10,}")

# Summary table: one row per species collected in conservation_results
rule = "=" * 80
print(f"\n{rule}")
print("GENE CONSERVATION SUMMARY")
print(rule)

header = (
    f"  {'Species':<15s} {'Genes':>7s} {'Both named':>12s} {'Same':>8s} {'%':>6s}"
    f"  {'Near(<10kb)':>12s} {'Same':>8s} {'%':>6s}"
)
print(header)
print(f"  {'-'*78}")
for species, r in conservation_results.items():
    all_distance_part = (
        f"{r['both_named']:>12,} {r['same_named']:>8,} {r['pct_named']:>5.1f}%"
    )
    near_gene_part = (
        f"{r['near_both_named']:>12,} {r['same_near']:>8,} {r['pct_near']:>5.1f}%"
    )
    print(f"  {species:<15s} {r['sp_n_genes']:>7,} {all_distance_part}  {near_gene_part}")

# Footnotes explaining the two column groups
footnotes = (
    "\n  Note: 'Both named' = both human and species annotation are gene symbols (not Ensembl IDs).",
    "  Note: 'Near(<10kb)' = further restricted to peaks <10kb from gene on BOTH sides.",
    "  Peaks far from any gene in the species genome often get a different 'closest gene'",
    "  simply because the species GTF has fewer annotated genes, not because synteny changed.",
)
for footnote in footnotes:
    print(footnote)

In [None]:
# =============================================================================
# Detailed breakdown: mismatched genes -- what happened?
# =============================================================================

# Re-run the merge for one species to inspect mismatches in detail
example_species = "Marmoset"  # most divergent, interesting to inspect

liftback_file = os.path.join(liftback_dir, f"unified_consensus_{example_species}.bed")
gene_bed = os.path.join(gene_bed_dir, f"{example_species}_genes.bed")

with tempfile.TemporaryDirectory() as tmpdir:
    # Detect chr prefix mismatch from the first line of each file
    with open(liftback_file) as f:
        lb_chrom = f.readline().split('\t')[0]
    with open(gene_bed) as f:
        gene_chrom = f.readline().split('\t')[0]

    lb_has_chr = lb_chrom.startswith("chr")
    gene_has_chr = gene_chrom.startswith("chr")
    gene_bed_use = gene_bed

    # Write a harmonized copy of the gene BED when only one side uses "chr"
    if lb_has_chr and not gene_has_chr:
        gene_bed_use = os.path.join(tmpdir, "genes_chr.bed")
        with open(gene_bed) as fin, open(gene_bed_use, 'w') as fout:
            for line in fin:
                if line.strip():
                    fout.write("chr" + line)
    elif not lb_has_chr and gene_has_chr:
        gene_bed_use = os.path.join(tmpdir, "genes_nochr.bed")
        with open(gene_bed) as fin, open(gene_bed_use, 'w') as fout:
            for line in fin:
                if line.strip():
                    fout.write(line[3:] if line.startswith("chr") else line)

    # Sort both files. Commands are passed as argument lists (no shell) so
    # paths containing spaces or shell metacharacters cannot break them.
    lb_sorted = os.path.join(tmpdir, "lb_sorted.bed")
    gene_sorted = os.path.join(tmpdir, "gene_sorted.bed")
    for unsorted_path, sorted_path in ((liftback_file, lb_sorted), (gene_bed_use, gene_sorted)):
        with open(sorted_path, 'w') as sorted_out:
            subprocess.run(["sort", "-k1,1", "-k2,2n", unsorted_path],
                           stdout=sorted_out, check=True)

    closest_out = os.path.join(tmpdir, "closest.tsv")
    with open(closest_out, 'w') as closest_fh:
        subprocess.run(
            ["bedtools", "closest", "-a", lb_sorted, "-b", gene_sorted, "-d", "-t", "first"],
            stdout=closest_fh, check=True,
        )

    # The liftback column count tells us where the gene BED columns start
    with open(lb_sorted) as f:
        n_lb_cols = len(f.readline().strip().split('\t'))

    rows = []
    with open(closest_out) as f:
        for line in f:
            parts = line.strip().split('\t')
            # Expect liftback columns + 6 gene BED columns + distance
            if len(parts) < n_lb_cols + 7:
                continue
            rows.append({
                "peak_id": parts[3],
                "sp_gene": parts[n_lb_cols + 3],
                "sp_distance": int(parts[-1]) if parts[-1] != '.' else -1,
            })

# Explicit columns keep the merge below working even if no row was parsed
sp_annot = pd.DataFrame(rows, columns=["peak_id", "sp_gene", "sp_distance"])
detail = human_annot.merge(sp_annot, on="peak_id", how="inner")
# Copy so the new columns are added to an independent frame, not a slice
detail = detail[detail["sp_distance"] >= 0].copy()
detail["same_gene"] = detail["closest_gene"].str.upper() == detail["sp_gene"].str.upper()

# Categories
detail["human_is_ens"] = detail["closest_gene"].str.startswith("ENS")
detail["sp_is_ens"] = detail["sp_gene"].str.startswith("ENS")
both_named = ~detail["human_is_ens"] & ~detail["sp_is_ens"]

mismatch = detail[both_named & ~detail["same_gene"]]

print(f"=== Gene mismatch breakdown for {example_species} (gene-symbol peaks only) ===")
print(f"Both have gene symbol: {both_named.sum():,}")
print(f"Same gene:             {detail.loc[both_named, 'same_gene'].sum():,} ({detail.loc[both_named, 'same_gene'].mean()*100:.1f}%)")
print(f"Different gene:        {len(mismatch):,}")

# How many mismatches are due to large distance?
near_h = mismatch["distance_to_gene"] < 10000
near_sp = mismatch["sp_distance"] < 10000
print(f"\n  Of {len(mismatch):,} mismatches (both gene symbols, different name):")
print(f"    Both <10kb from gene:  {(near_h & near_sp).sum():>8,}  (true neighbourhood change or naming diff)")
print(f"    Human >10kb:           {(~near_h).sum():>8,}  (intergenic in human)")
print(f"    Species >10kb:         {(~near_sp).sum():>8,}  (intergenic in species, sparse GTF?)")

# Show examples of true nearby mismatches (fixed seed for a reproducible sample)
true_mm = mismatch[near_h & near_sp]
if len(true_mm) > 0:
    print(f"\n  Examples of nearby mismatches (both <10kb):")
    sample = true_mm.sample(min(10, len(true_mm)), random_state=42)
    print(sample[["peak_id", "closest_gene", "distance_to_gene", "sp_gene", "sp_distance"]].to_string(index=False))

# =============================================================================
# Visualization
# =============================================================================
# Imported here because no earlier cell imports matplotlib; without it this
# cell fails with NameError when the notebook is run top to bottom.
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 3, figsize=(16, 5))

# 1. Bar chart: conservation rate by species (named genes, <10kb)
species_names = list(conservation_results.keys())
pct_named = [conservation_results[s]["pct_named"] for s in species_names]
pct_near = [conservation_results[s]["pct_near"] for s in species_names]

x = np.arange(len(species_names))
w = 0.35
axes[0].bar(x - w/2, pct_named, w, label="Both named (all dist.)", color="steelblue")
axes[0].bar(x + w/2, pct_near, w, label="Both named + both <10kb", color="coral")
axes[0].set_xticks(x)
axes[0].set_xticklabels(species_names, rotation=30, ha="right")
axes[0].set_ylabel("% same closest gene")
axes[0].set_title("Gene conservation after liftback\n(gene-symbol peaks only)")
axes[0].legend(fontsize=8)
axes[0].set_ylim(0, 100)

# 2. Human distance vs species distance scatter for example species
# Restrict to both-named; sample at most 5000 points with a fixed seed
detail_named = detail[both_named].copy()
sample_scatter = detail_named.sample(min(5000, len(detail_named)), random_state=42)
colors = ["steelblue" if s else "firebrick" for s in sample_scatter["same_gene"]]
# Distances are clipped to >= 1 so that log10 is defined for distance 0
axes[1].scatter(
    np.log10(sample_scatter["distance_to_gene"].clip(lower=1)),
    np.log10(sample_scatter["sp_distance"].clip(lower=1)),
    c=colors, alpha=0.15, s=3, rasterized=True,
)
axes[1].plot([0, 6], [0, 6], 'k--', alpha=0.5, lw=0.8)
axes[1].set_xlabel("log10(human distance to gene)")
axes[1].set_ylabel(f"log10({example_species} distance to gene)")
axes[1].set_title(f"TSS distance: human vs {example_species}\n(blue=same gene, red=different, symbol only)")
axes[1].set_xlim(-0.5, 6.5)
axes[1].set_ylim(-0.5, 6.5)

# 3. Conservation rate by distance bin (both named)
bins = [0, 100, 1000, 5000, 10000, 50000, 100000, 500000, np.inf]
labels = ["0-100", "100-1k", "1k-5k", "5k-10k", "10k-50k", "50k-100k", "100k-500k", ">500k"]
# include_lowest=True keeps distance-0 peaks in the "0-100" bin; with the
# default right-closed bins the first interval is (0, 100] and drops them.
detail_named["dist_bin"] = pd.cut(
    detail_named["distance_to_gene"].clip(lower=0),
    bins=bins, labels=labels, include_lowest=True,
)
bin_stats = detail_named.groupby("dist_bin", observed=True)["same_gene"].agg(["mean", "count"])
axes[2].bar(range(len(bin_stats)), bin_stats["mean"] * 100, color="steelblue", edgecolor="white")
axes[2].set_xticks(range(len(bin_stats)))
axes[2].set_xticklabels(bin_stats.index, rotation=45, ha="right", fontsize=8)
axes[2].set_ylabel("% same gene")
axes[2].set_xlabel("Distance to gene in human (bp)")
axes[2].set_title(f"Conservation rate by distance ({example_species})\n(gene-symbol peaks only)")
# Annotate every bar with the number of peaks in its bin
for i, (pct, n) in enumerate(zip(bin_stats["mean"], bin_stats["count"])):
    axes[2].text(i, pct * 100 + 1, f"n={n:,}", ha="center", fontsize=6, rotation=0)

plt.tight_layout()
plot_file = os.path.join(OUTPUT_DIR, "gene_conservation_liftback.png")
plt.savefig(plot_file, dpi=150, bbox_inches="tight")
print(f"\nSaved plot: {plot_file}")
plt.show()

## Master Annotation Table

Comprehensive per-peak annotation with:
- **Coordinates** in each species (from liftback for unified, original for species-specific)
- **Nearest gene** and **distance to gene** in each species
- **Binary detection columns** (`Human_det`, `Bonobo_det`, ...) for UpSet plotting
- **Peak type**: `unified`, `human_specific`, or `{species}_specific`

In [None]:
# Load master annotation (produced by pipeline Step 7)
master_file = os.path.join(OUTPUT_DIR, "07_master_annotation", "master_annotation.tsv")

if not os.path.exists(master_file):
    # Nothing to show until the pipeline has written the table
    print(f"Master annotation not found at {master_file}")
    print("Run the pipeline first (it now includes Steps 7 and 8)")
else:
    master = pd.read_csv(master_file, sep="\t", index_col="peak_id")
    peak_type_counts = master["peak_type"].value_counts()
    print(f"Master annotation: {len(master):,} peaks x {len(master.columns)} columns")
    print(f"\nPeak types:\n{peak_type_counts}")

    # Detection summary: total of each per-species "*_det" column
    det_cols = [c for c in master.columns if c.endswith("_det")]
    print("\nDetection summary:")
    for det_col in det_cols:
        print(f"  {det_col}: {master[det_col].sum():,} peaks")

    # Number of species per peak
    print("\nPeaks by number of species detected:")
    print(master["n_species"].value_counts().sort_index())

    # Show a few rows with key columns
    show_cols = ["peak_type", "n_species", *det_cols]
    print("\nSample rows:")
    display(master[show_cols].head(10))

In [None]:
# UpSet-style visualization of species detection patterns
import matplotlib.pyplot as plt

if os.path.exists(master_file):
    # Re-read the table so this cell does not depend on the previous one.
    master = pd.read_csv(master_file, sep="\t", index_col="peak_id")

    det_cols = [c for c in master.columns if c.endswith("_det")]
    species_names = [c.replace("_det", "") for c in det_cols]

    # Detection pattern per peak: comma-joined names of the species whose
    # *_det flag equals 1 (an empty string when no flag is set).
    master["pattern"] = master[det_cols].apply(
        lambda row: ",".join([s for s, v in zip(species_names, row) if v == 1]),
        axis=1,
    )

    pattern_counts = master["pattern"].value_counts().head(20)

    fig, ax = plt.subplots(figsize=(12, 6))
    pattern_counts.plot(kind="barh", ax=ax, color="steelblue")
    ax.set_xlabel("Number of peaks")
    ax.set_ylabel("Species combination")
    ax.set_title("Top 20 Species Detection Patterns (for UpSet plotting)")
    ax.invert_yaxis()  # most frequent pattern at the top
    for i, v in enumerate(pattern_counts):
        # Count label just right of each bar (fixed offset of 50 in data units)
        ax.text(v + 50, i, f"{v:,}", va="center", fontsize=8)
    plt.tight_layout()
    plt.show()

    # Gene distance comparison across species
    gene_dist_cols = [c for c in master.columns if c.endswith("_gene_dist")]
    if gene_dist_cols:
        fig, ax = plt.subplots(figsize=(10, 5))
        n_drawn = 0
        for col in gene_dist_cols:
            sp = col.replace("_gene_dist", "")
            vals = master[col].dropna()
            # Only non-negative distances are plotted.
            # NOTE(review): negative values are presumably a "no gene found"
            # sentinel -- confirm against the annotation step.
            vals_kb = vals[vals >= 0] / 1000
            if vals_kb.empty:
                # Nothing left to draw after filtering; skip rather than add
                # an empty histogram with a misleading legend entry.
                continue
            # Distances above 500 kb are clipped into the last bin. The legend
            # count refers to the values actually plotted.
            ax.hist(vals_kb.clip(upper=500), bins=100, alpha=0.5, label=f"{sp} (n={len(vals_kb):,})")
            n_drawn += 1
        ax.set_xlabel("Distance to nearest gene (kb)")
        ax.set_ylabel("Count")
        ax.set_title("Distance to Nearest Gene by Species")
        if n_drawn:
            # legend() without any labelled artist only emits a warning
            ax.legend()
        plt.tight_layout()
        plt.show()

## Cross-Mapping Species-Specific Peaks

Species-specific peaks are original species peaks (in native coordinates) that are not covered by any unified consensus peak lifted back from hg38. Routing them through hg38 again would be pointless, since that route has already failed to match them.

**Strategy:** Instead, we use **direct inter-species chain files** from UCSC, chaining through intermediate assemblies where needed. For example:
- Bonobo(panPan2) -> Chimp(panTro5): `panPan2 -> panTro4 -> panTro5` (2 hops)
- Gorilla(gorGor4) -> Bonobo(panPan2): `gorGor4 -> gorGor5 -> panPan2` (2 hops)
- Macaque(rheMac10) -> Gorilla(gorGor4): `rheMac10 -> panTro6 -> panTro5 -> gorGor5 -> gorGor4` (4 hops)

This tells us whether a "species-specific" peak actually has a corresponding open chromatin region in another species, via an alternative liftover path that bypasses hg38.

In [None]:
# Load cross-mapping results (produced by pipeline Step 8)
cross_map_file = os.path.join(OUTPUT_DIR, "08_cross_mapping", "species_specific_cross_mapping.tsv")
cross_matrix_file = os.path.join(OUTPUT_DIR, "08_cross_mapping", "cross_mapping_matrix_pct.tsv")
cross_counts_file = os.path.join(OUTPUT_DIR, "08_cross_mapping", "cross_mapping_matrix_counts.tsv")


def _annotated_heatmap(ax, table, *, cmap, title, xlabel, ylabel, cbar_label, fmt, white_above):
    """Draw *table* as a heatmap on *ax* and write each cell's value into it.

    Args:
        ax: Matplotlib axes to draw on.
        table: Numeric DataFrame; index labels the rows, columns label the x axis.
        cmap: Colormap name passed to ``imshow``.
        title, xlabel, ylabel: Axes title and axis labels.
        cbar_label: Label of the colorbar.
        fmt: Callable turning one cell value into its text label.
        white_above: Cells with a value above this threshold get white text,
            all others black (keeps labels readable on dark cells).

    NaN cells are left without a text label instead of printing "nan" or
    failing in an integer conversion.
    """
    values = table.to_numpy(dtype=float)
    im = ax.imshow(values, cmap=cmap, aspect="auto")
    ax.set_xticks(range(len(table.columns)))
    ax.set_xticklabels(table.columns, rotation=45, ha="right")
    ax.set_yticks(range(len(table.index)))
    ax.set_yticklabels(table.index)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    for (i, j), val in np.ndenumerate(values):
        if np.isnan(val):
            continue
        text_color = "white" if val > white_above else "black"
        ax.text(j, i, fmt(val), ha="center", va="center", color=text_color, fontsize=9)
    ax.figure.colorbar(im, ax=ax, label=cbar_label)


if os.path.exists(cross_map_file):
    cross_map = pd.read_csv(cross_map_file, sep="\t")
    print(f"Cross-mapping results: {len(cross_map)} source-target pairs\n")

    # Show summary with route info
    show_cols = ["source", "target", "source_specific", "n_hops",
                 "lifted_to_target", "overlap_target_peaks", "pct_overlap"]
    display(cross_map[show_cols])

    # Heatmap of cross-mapping percentages
    if os.path.exists(cross_matrix_file):
        import matplotlib.pyplot as plt

        matrix = pd.read_csv(cross_matrix_file, sep="\t", index_col=0)
        print(f"\nCross-mapping matrix (% of source-specific peaks overlapping target):")
        display(matrix.round(1))

        fig, axes = plt.subplots(1, 2, figsize=(16, 6))

        # Percentage heatmap
        _annotated_heatmap(
            axes[0],
            matrix,
            cmap="YlOrRd",
            title="% of Species-Specific Peaks\nOverlapping Target Peaks",
            xlabel="Target species",
            ylabel="Source species (species-specific peaks)",
            cbar_label="% overlap",
            fmt=lambda v: f"{v:.1f}%",
            white_above=50,
        )

        # Count heatmap
        if os.path.exists(cross_counts_file):
            counts = pd.read_csv(cross_counts_file, sep="\t", index_col=0)
            count_values = counts.to_numpy(dtype=float)
            # Threshold is 60% of the largest count, ignoring NaN cells; a
            # plain max() would be NaN as soon as one cell is missing.
            if np.isfinite(count_values).any():
                count_threshold = np.nanmax(count_values) * 0.6
            else:
                count_threshold = 0
            _annotated_heatmap(
                axes[1],
                counts,
                cmap="Blues",
                title="Absolute Count of Overlapping Peaks",
                xlabel="Target species",
                ylabel="Source species",
                cbar_label="count",
                fmt=lambda v: f"{int(v):,}",
                white_above=count_threshold,
            )
        else:
            # No counts matrix: hide the unused panel instead of showing an empty frame
            axes[1].axis("off")

        plt.tight_layout()
        plt.show()
else:
    print(f"Cross-mapping not found at {cross_map_file}")
    print("Run the pipeline first (it now includes Steps 7 and 8)")

## 8. Build Combined BED File

Write a single BED file with all peak categories: unified (hg38), human-specific (hg38), and species-specific (native coords). Also write per-species BED files with liftback unified + species-specific peaks.

In [None]:
# =============================================================================
# Build combined BED file (all peak categories)
# =============================================================================
def _bed_records(bed_path):
    """Yield the tab-split fields of each data line in *bed_path*.

    Blank lines and lines starting with '#' are skipped.
    """
    with open(bed_path) as handle:
        for raw in handle:
            if not raw.strip() or raw.startswith('#'):
                continue
            yield raw.strip().split('\t')


def _append_bed4(bed_path, sink, extra=""):
    """Copy the first four columns of every record in *bed_path* to *sink*.

    *extra* is appended verbatim after the fourth column (it must carry its
    own leading tab). Returns the number of records written.
    """
    written = 0
    for fields in _bed_records(bed_path):
        sink.write(f"{fields[0]}\t{fields[1]}\t{fields[2]}\t{fields[3]}{extra}\n")
        written += 1
    return written


def _existing_output(key):
    """Return the pipeline output path stored under *key* if that file exists, else None."""
    registry = results["output_files"]
    if key in registry and os.path.exists(registry[key]):
        return registry[key]
    return None


combined_dir = os.path.join(OUTPUT_DIR, "07_combined")
os.makedirs(combined_dir, exist_ok=True)

combined_bed = os.path.join(combined_dir, "all_peaks_combined.bed")

# Genome assembly written next to each species-specific peak; species missing
# from this map fall back to the species name itself.
assembly_map = {
    "Bonobo": "panPan2", "Chimpanzee": "panTro5", "Gorilla": "gorGor4",
    "Macaque": "rheMac10", "Marmoset": "calJac1",
}

with open(combined_bed, 'w') as fout:
    # Header as a comment
    fout.write("#chr\tstart\tend\tpeak_id\tcategory\tgenome_assembly\n")

    # 1. Unified peaks (hg38)
    n_unified = _append_bed4(results["output_files"]["unified_consensus"], fout, "\tunified\thg38")

    # 2. Human-specific peaks (hg38)
    n_human_spec = _append_bed4(results["output_files"]["human_specific"], fout, "\thuman_specific\thg38")

    # 3. Species-specific peaks (native coordinates)
    n_sp_spec = 0
    for species in SPECIES_BEDS:
        native_file = _existing_output(f"species_specific_{species}")
        if native_file is None:
            continue
        assembly = assembly_map.get(species, species)
        n_sp_spec += _append_bed4(native_file, fout, f"\t{species.lower()}_specific\t{assembly}")

print(f"Combined BED written: {combined_bed}")
print(f"  Unified:          {n_unified:,}")
print(f"  Human-specific:   {n_human_spec:,}")
print(f"  Species-specific: {n_sp_spec:,}")
print(f"  Total:            {n_unified + n_human_spec + n_sp_spec:,}")

# =============================================================================
# Per-species BED files (liftback unified + species-specific, in native coords)
# =============================================================================
print("\nPer-species complete peak sets:")

for species in SPECIES_BEDS:
    sp_combined = os.path.join(combined_dir, f"all_peaks_{species}.bed")
    n = 0

    with open(sp_combined, 'w') as fout:
        # Liftback unified peaks first, then the species-specific peaks; only
        # the first four columns (chr, start, end, peak_id) are kept.
        for output_key in (f"liftback_{species}", f"species_specific_{species}"):
            source_file = _existing_output(output_key)
            if source_file is not None:
                n += _append_bed4(source_file, fout)

    print(f"  {species}: {n:,} peaks -> {sp_combined}")

## 9. Summary Statistics and Validation

In [None]:
# =============================================================================
# Summary statistics and validation
# =============================================================================
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

banner = "=" * 70
print(banner)
print("VALIDATION")
print(banner)

# 1. Every peak ID in the annotation table should occur exactly once
all_ids = annot_df["peak_id"].tolist()
n_total = len(all_ids)
n_unique = len(set(all_ids))
print(f"\nPeak ID uniqueness: {n_unique:,} unique / {n_total:,} total", end="")
if n_unique != n_total:
    # Collect every ID that occurs more than once and show up to ten of them
    is_repeated = annot_df["peak_id"].duplicated(keep=False)
    dup_ids = annot_df.loc[is_repeated, "peak_id"].unique()
    print(f" -- WARNING: {n_total - n_unique} duplicates!")
    print(f"  Duplicate IDs: {list(dup_ids[:10])}")
else:
    print(" -- OK")

# Normalise the different liftover result layouts into one (input, lifted) pair
def _get_liftover_counts(res):
    """Return ``(input_count, lifted_count)`` for a liftover result dict.

    The lifted count is ``res["lifted"]`` (0 when absent). The input count is
    taken from the first rule that applies:

    1. ``res["original"]`` when present (two_step results),
    2. ``lifted + res["unmapped"]`` when present (liftover_peaks results),
    3. the lifted count itself (pre-lifted data carries no unmapped info).
    """
    n_lifted = res.get("lifted", 0)
    if "original" in res:
        return res["original"], n_lifted
    if "unmapped" in res:
        return n_lifted + res["unmapped"], n_lifted
    return n_lifted, n_lifted

# 2. Liftover success rates, one table per direction
print("\nLiftover success rates (species -> hg38):")
print(f"  {'Species':<15s} {'Input':>10s} {'Lifted':>10s} {'Rate':>8s}  {'Note'}")
print("  " + "-" * 58)
for species, res in results["lift_to_human"].items():
    inp, lifted = _get_liftover_counts(res)
    # Guard against division by zero for species without input peaks
    if inp > 0:
        rate = lifted / inp * 100
    else:
        rate = 0
    # Flag species whose peaks were supplied already lifted to hg38
    if res.get("source") == "pre_lifted":
        note = "(pre-lifted)"
    else:
        note = ""
    print(f"  {species:<15s} {inp:>10,} {lifted:>10,} {rate:>7.1f}%  {note}")

print("\nLiftback success rates (hg38 -> species):")
print(f"  {'Species':<15s} {'Input':>10s} {'Lifted':>10s} {'Rate':>8s}")
print("  " + "-" * 48)
for species, res in results["lift_back"].items():
    inp, lifted = _get_liftover_counts(res)
    if inp > 0:
        rate = lifted / inp * 100
    else:
        rate = 0
    print(f"  {species:<15s} {inp:>10,} {lifted:>10,} {rate:>7.1f}%")

# 3. Species detection distribution plot
rule = "=" * 70
print(f"\n{rule}")
print("SPECIES DETECTION IN UNIFIED PEAKS")
print(rule)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))
ax_detect, ax_category = axes

# Left panel: number of unified peaks per number of detecting species
n_species_counts = unified_df["n_species"].value_counts().sort_index()
ax_detect.bar(n_species_counts.index, n_species_counts.values, color="steelblue", edgecolor="white")
ax_detect.set_xlabel("Number of species detected")
ax_detect.set_ylabel("Number of peaks")
ax_detect.set_title("Distribution of species detection\n(unified consensus)")
for n_detected, bar_height in zip(n_species_counts.index, n_species_counts.values):
    ax_detect.text(n_detected, bar_height, f"{bar_height:,}", ha="center", va="bottom", fontsize=8)

# Right panel: peak totals per category (unified, human-specific, then one
# entry per species in species_specific_counts)
categories = {"Unified": len(unified_df), "Human-specific": len(hs_df)}
categories.update({f"{sp}-specific": n for sp, n in species_specific_counts.items()})

cats = list(categories)
vals = list(categories.values())
# One colour per category group, in the same order as the bars
colors = ["steelblue", "firebrick"] + ["seagreen"] * len(species_specific_counts)
ax_category.barh(cats, vals, color=colors, edgecolor="white")
ax_category.set_xlabel("Number of peaks")
ax_category.set_title("Peak counts by category")
for row, total in enumerate(vals):
    ax_category.text(total, row, f"  {total:,}", ha="left", va="center", fontsize=8)

plt.tight_layout()
plot_file = os.path.join(OUTPUT_DIR, "peak_summary.png")
plt.savefig(plot_file, dpi=150, bbox_inches="tight")
print(f"\nSaved summary plot: {plot_file}")
plt.show()

# 4. Save summary stats: assemble the report lines, then write them in one go
summary_file = os.path.join(OUTPUT_DIR, "pipeline_summary.txt")

summary_lines = [
    "Cross-Species Consensus Pipeline v2 -- Summary\n",
    "=" * 60 + "\n\n",
    f"Unified consensus peaks: {len(unified_df):,}\n",
    f"Human-specific peaks:    {len(hs_df):,}\n",
]
# One line per species with species-specific peaks
summary_lines.extend(
    f"{sp}-specific peaks:  {n:,}\n" for sp, n in species_specific_counts.items()
)
summary_lines.append("\nSpecies detection distribution (unified):\n")
# One line per "number of species detected" bucket
summary_lines.extend(
    f"  {n_sp} species: {count:,} peaks\n" for n_sp, count in n_species_counts.items()
)
summary_lines.append(f"\nAnnotation file: {annotation_file}\n")
summary_lines.append(f"Combined BED:    {combined_bed}\n")

with open(summary_file, 'w') as f:
    f.writelines(summary_lines)

print(f"Saved summary: {summary_file}")