# Gene Content vs Phylogenetic Distance Analysis

## Overview

This notebook investigates the relationship between **gene content** (pangenome structure) and **phylogenetic distance** across bacterial species. The central question is:

> **Does phylogenetic relatedness predict gene content similarity?**

We hypothesize that closely related species should have more similar pangenome characteristics (core genome size, openness) than distantly related species. However, ecological factors and horizontal gene transfer may weaken this signal.

### Data Sources
- **Pangenome data**: KBase BERDL pangenome database (27,690 species, 293,059 genomes)
- **Phylogenetic tree**: GTDB r214 bacterial tree (`bac120.tree`) with 136,646 representative genomes
- **Taxonomy**: GTDB taxonomy for hierarchical comparisons

### Analysis Steps
1. Load and filter species data (quality control)
2. Match species to phylogenetic tree via representative genomes
3. Compute phylogenetic distance matrix
4. Compute gene content distance matrices (multiple metrics)
5. Analyze correlation at global and taxonomic levels
6. Visualize results and interpret findings

## 1. Setup and Configuration

In [None]:
import requests
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from io import StringIO
import warnings
warnings.filterwarnings('ignore')

# Set plotting style
sns.set_theme(style="whitegrid")
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['figure.dpi'] = 100

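# Load authentication token for BERDL API: prefer the environment, then fall
# back to a KB_AUTH_TOKEN=... line in a local .env file (quotes optional)
import os
AUTH_TOKEN = os.environ.get('KB_AUTH_TOKEN')
if not AUTH_TOKEN and os.path.exists('.env'):
    with open('.env', 'r') as f:
        for line in f:
            if line.startswith('KB_AUTH_TOKEN'):
                AUTH_TOKEN = line.split('=', 1)[1].strip().strip('"\'')
                break
if not AUTH_TOKEN:
    raise RuntimeError("KB_AUTH_TOKEN not found in the environment or in .env")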

# BERDL API configuration
BASE_URL = "https://hub.berdl.kbase.us/apis/mcp"
DATABASE = "kbase_ke_pangenome"
HEADERS = {"Authorization": f"Bearer {AUTH_TOKEN}", "Content-Type": "application/json"}

def query_berdl(sql, limit=10000, offset=0, timeout=300):
    """Execute a SQL query against BERDL and return the rows as a DataFrame."""
    url = f"{BASE_URL}/delta/tables/query"
    payload = {"query": sql, "limit": limit, "offset": offset}
    # A timeout prevents the notebook from hanging on a stalled request
    response = requests.post(url, headers=HEADERS, json=payload, timeout=timeout)
    response.raise_for_status()
    data = response.json()
    # The endpoint may return rows under either 'result' or 'results'
    results = data.get('result', data.get('results', []))
    return pd.DataFrame(results) if results else pd.DataFrame()

print("Setup complete")
print(f"BERDL database: {DATABASE}")

## 2. Load Pangenome Data

We load species-level pangenome statistics joined with taxonomic information. Key fields include:
- `no_genomes`: Number of genomes in the species (sampling depth)
- `no_core`: Number of core genes (present in >95% of genomes)
- `no_aux_genome`: Number of accessory genes (5-95%)
- `no_singleton_gene_clusters`: Number of gene clusters found in only one genome
- `mean_intra_species_ANI`: Average nucleotide identity within species
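To make the core/accessory/singleton categories concrete, here is a minimal sketch that partitions a toy gene-cluster presence/absence matrix. The core threshold follows the field descriptions above, but treating singletons as a category separate from accessory genes, and counting every present cluster toward the total, are assumptions for illustration; the KBase pipeline's exact rules may differ. `partition_gene_clusters` is a hypothetical helper, not part of the BERDL API.

```python
import numpy as np

def partition_gene_clusters(presence, core_threshold=0.95):
    """Classify gene clusters as core, accessory, or singleton (hypothetical rules).

    presence: 2-D array, rows = gene clusters, columns = genomes (1 = present).
    Assumed rules, for illustration only:
      - core:      present in more than `core_threshold` of genomes
      - singleton: present in exactly one genome
      - accessory: every other cluster present in at least one genome
    """
    presence = np.asarray(presence, dtype=bool)
    n_genomes = presence.shape[1]
    counts = presence.sum(axis=1)          # number of genomes carrying each cluster
    freq = counts / n_genomes              # fraction of genomes carrying each cluster
    core = freq > core_threshold
    singleton = (counts == 1) & ~core
    accessory = (counts > 0) & ~core & ~singleton
    return {
        "no_core": int(core.sum()),
        "no_accessory": int(accessory.sum()),
        "no_singletons": int(singleton.sum()),
        "no_gene_clusters": int((counts > 0).sum()),
    }

# Made-up matrix: 4 gene clusters x 5 genomes
toy = np.array([
    [1, 1, 1, 1, 1],  # in 5/5 genomes -> core
    [1, 1, 1, 0, 0],  # in 3/5 genomes -> accessory
    [0, 0, 0, 0, 1],  # in 1/5 genomes -> singleton
    [1, 1, 1, 1, 0],  # in 4/5 genomes (80%) -> accessory
])
stats = partition_gene_clusters(toy)

# Derived metrics, defined the same way as in the notebook's metric cell
pct_core = 100 * stats["no_core"] / stats["no_gene_clusters"]
openness = stats["no_accessory"] / stats["no_core"]
print(stats, pct_core, openness)
```

With these toy counts, one of four clusters is core (25% core) and the accessory/core ratio is 2.0.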

In [None]:
# Load full pangenome data with taxonomy
sql = f"""
SELECT 
    p.gtdb_species_clade_id,
    s.GTDB_species,
    s.GTDB_taxonomy,
    p.no_genomes,
    p.no_core,
    p.no_aux_genome as no_accessory,
    p.no_singleton_gene_clusters as no_singletons,
    p.no_gene_clusters,
    p.no_CDSes,
    s.mean_intra_species_ANI,
    s.ANI_circumscription_radius
FROM {DATABASE}.pangenome p
JOIN {DATABASE}.gtdb_species_clade s 
    ON p.gtdb_species_clade_id = s.gtdb_species_clade_id
ORDER BY p.no_genomes DESC
"""

df_raw = query_berdl(sql, limit=30000)
print(f"Loaded {len(df_raw):,} species from BERDL")
df_raw.head()

## 3. Quality Filtering

Before analysis, we apply quality filters to ensure reliable pangenome statistics:

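### Filter Criteria
1. **Minimum genome count** (`MIN_GENOMES = 10`): Species with very few genomes have unreliable core/accessory estimates, because separating core from accessory genes requires sampling multiple strains.

2. **Minimum core genome size** (`MIN_CORE_GENES = 100`): Removes species with missing, zero, or implausibly small core genomes.

3. **Minimum intra-species ANI** (`MIN_ANI = 95.0`): Filters out potentially contaminated or misclassified species whose mean ANI falls below the 95% species boundary used by GTDB.

4. **Non-null singleton counts**: Removes species with missing singleton statistics.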
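The four filters can equivalently be combined into a single boolean mask. The sketch below runs on a made-up table, and `quality_mask` is a hypothetical helper; the notebook cell that follows applies the filters one at a time instead, so it can report how many species each step removes.

```python
import pandas as pd

# Thresholds mirroring the filter list above
MIN_GENOMES, MIN_CORE_GENES, MIN_ANI = 10, 100, 95.0

def quality_mask(df):
    """Combine all four quality filters into one boolean Series."""
    return (
        (df["no_genomes"] >= MIN_GENOMES)
        & (df["no_core"] >= MIN_CORE_GENES)
        & (df["mean_intra_species_ANI"] >= MIN_ANI)
        & df["no_singletons"].notna()
    )

# Made-up species table for illustration
toy = pd.DataFrame({
    "no_genomes": [12, 4, 50, 30],
    "no_core": [1500, 1800, 80, 2100],
    "mean_intra_species_ANI": [98.7, 99.1, 97.5, 94.2],
    "no_singletons": [300, 120, 40, None],
})
kept = toy[quality_mask(toy)]
print(kept.index.tolist())  # -> [0]
```

Comparisons against NaN evaluate to `False`, so a species with a missing ANI value is dropped by the ANI filter as well.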

In [None]:
# Quality filtering parameters
MIN_GENOMES = 10  # Minimum number of genomes per species
MIN_CORE_GENES = 100  # Minimum core genome size
MIN_ANI = 95.0  # Minimum ANI (GTDB species boundary)

print("Quality Filter Summary")
print("=" * 50)
print(f"Raw species count: {len(df_raw):,}")

# Apply filters step by step
df = df_raw.copy()

# Filter 1: Minimum genomes
n_before = len(df)
df = df[df['no_genomes'] >= MIN_GENOMES]
print(f"After min genomes >= {MIN_GENOMES}: {len(df):,} ({n_before - len(df):,} removed)")

# Filter 2: Valid core genome
n_before = len(df)
df = df[df['no_core'] >= MIN_CORE_GENES]
print(f"After min core genes >= {MIN_CORE_GENES}: {len(df):,} ({n_before - len(df):,} removed)")

# Filter 3: Valid ANI
n_before = len(df)
df = df[df['mean_intra_species_ANI'] >= MIN_ANI]
print(f"After min ANI >= {MIN_ANI}: {len(df):,} ({n_before - len(df):,} removed)")

# Filter 4: Non-null singletons
n_before = len(df)
df = df[df['no_singletons'].notna()]
print(f"After removing null singletons: {len(df):,} ({n_before - len(df):,} removed)")

print("=" * 50)
print(f"Final filtered species: {len(df):,}")
print(f"Total genomes represented: {df['no_genomes'].sum():,}")

In [None]:
# Parse taxonomy and compute derived metrics
def parse_taxonomy(tax_string):
    """Parse GTDB taxonomy string into levels."""
    levels = {}
    if pd.isna(tax_string):
        return levels
    for part in tax_string.split(';'):
        if '__' in part:
            level, name = part.split('__', 1)
            levels[level] = name
    return levels

# Extract taxonomy levels
tax_parsed = df['GTDB_taxonomy'].apply(parse_taxonomy)
df['domain'] = tax_parsed.apply(lambda x: x.get('d', 'Unknown'))
df['phylum'] = tax_parsed.apply(lambda x: x.get('p', 'Unknown'))
df['class'] = tax_parsed.apply(lambda x: x.get('c', 'Unknown'))
df['order'] = tax_parsed.apply(lambda x: x.get('o', 'Unknown'))
df['family'] = tax_parsed.apply(lambda x: x.get('f', 'Unknown'))
df['genus'] = tax_parsed.apply(lambda x: x.get('g', 'Unknown'))

# Compute pangenome metrics
df['pct_core'] = (df['no_core'] / df['no_gene_clusters'] * 100).round(2)
df['pct_accessory'] = (df['no_accessory'] / df['no_gene_clusters'] * 100).round(2)
df['pct_singletons'] = (df['no_singletons'] / df['no_gene_clusters'] * 100).round(2)
df['openness'] = (df['no_accessory'] / df['no_core']).round(3)  # Higher = more open pangenome
df['avg_genes_per_genome'] = (df['no_CDSes'] / df['no_genomes']).round(0)

# Extract representative genome ID from gtdb_species_clade_id
# Format: s__Species_name--RS_GCF_XXXXXX.X or --GB_GCA_XXXXXX.X
df['rep_genome'] = df['gtdb_species_clade_id'].str.extract(r'--((?:RS_GCF|GB_GCA)_\d+\.\d+)')[0]

print(f"Computed metrics for {len(df):,} species")
print(f"\nPangenome statistics:")
print(df[['no_genomes', 'no_core', 'pct_core', 'openness', 'avg_genes_per_genome']].describe().round(2))

In [None]:
# Visualize the filtering impact and data distribution
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# 1. Genome count distribution
ax1 = axes[0, 0]
ax1.hist(df['no_genomes'], bins=50, edgecolor='black', alpha=0.7)
ax1.set_xlabel('Number of Genomes per Species')
ax1.set_ylabel('Count')
ax1.set_title(f'Genome Sampling Depth (n={len(df):,} species)')
ax1.axvline(df['no_genomes'].median(), color='red', linestyle='--', 
            label=f'Median: {df["no_genomes"].median():.0f}')
ax1.set_xscale('log')
ax1.legend()

# 2. Core genome % distribution
ax2 = axes[0, 1]
ax2.hist(df['pct_core'], bins=50, edgecolor='black', alpha=0.7, color='green')
ax2.set_xlabel('Core Genome (%)')
ax2.set_ylabel('Count')
ax2.set_title('Core Genome Size Distribution')
ax2.axvline(df['pct_core'].median(), color='red', linestyle='--',
            label=f'Median: {df["pct_core"].median():.1f}%')
ax2.legend()

# 3. Openness distribution
ax3 = axes[1, 0]
ax3.hist(df['openness'], bins=50, edgecolor='black', alpha=0.7, color='orange')
ax3.set_xlabel('Openness (Accessory/Core ratio)')
ax3.set_ylabel('Count')
ax3.set_title('Pangenome Openness Distribution')
ax3.axvline(df['openness'].median(), color='red', linestyle='--',
            label=f'Median: {df["openness"].median():.2f}')
ax3.legend()

# 4. Phylum distribution
ax4 = axes[1, 1]
phylum_counts = df['phylum'].value_counts().head(15)
ax4.barh(range(len(phylum_counts)), phylum_counts.values)
ax4.set_yticks(range(len(phylum_counts)))
ax4.set_yticklabels(phylum_counts.index)
ax4.set_xlabel('Number of Species')
ax4.set_title('Species by Phylum (top 15)')
ax4.invert_yaxis()

plt.tight_layout()
plt.savefig('figure_quality_filtering.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"\nFigure saved: figure_quality_filtering.png")

## 4. Load Phylogenetic Tree

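We use the GTDB r214 bacterial tree (`bac120.tree`), whose branch lengths are derived from a concatenated alignment of 120 universal single-copy marker genes.

The tree tips are labeled with representative genome accessions (`RS_GCF_*` or `GB_GCA_*`). Each `gtdb_species_clade_id` ends with the accession of its representative genome (e.g. `s__Species_name--RS_GCF_XXXXXX.X`), so we extract that accession and look it up among the tree tips.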
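The matching step can be sketched on a toy example. The two regular expressions are the ones the notebook uses to extract accessions from clade IDs and from the Newick text; the Newick string, clade IDs, and the helper name `rep_accession` are made up for illustration.

```python
import re

# Pattern for the representative accession at the end of a clade ID
# (format assumed: "<species name>--<accession>")
ACCESSION_RE = re.compile(r"--((?:RS_GCF|GB_GCA)_\d+\.\d+)")
# Pattern for accession-like tip labels anywhere in a Newick string
TIP_RE = re.compile(r"(RS_GCF_\d+\.\d+|GB_GCA_\d+\.\d+)")

def rep_accession(clade_id):
    """Return the representative genome accession, or None if there is none."""
    m = ACCESSION_RE.search(clade_id)
    return m.group(1) if m else None

# Toy Newick tree with made-up accessions and branch lengths
newick = "((RS_GCF_000001.1:0.1,GB_GCA_000002.1:0.2):0.05,RS_GCF_000003.1:0.3);"
tips = set(TIP_RE.findall(newick))

clade_ids = [
    "s__Example_alpha--RS_GCF_000001.1",
    "s__Example_beta--GB_GCA_000002.1",
    "s__Example_gamma--RS_GCF_999999.1",  # not present in the toy tree
]
matched = [c for c in clade_ids if rep_accession(c) in tips]
print(matched)
```

Only the first two clade IDs match; the third species would be excluded from the phylogenetic analysis.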

In [None]:
# Load the phylogenetic tree
try:
    from Bio import Phylo
    HAVE_BIOPYTHON = True
except ImportError:
    HAVE_BIOPYTHON = False
    print("BioPython not available, will use alternative tree parsing")

# Check tree file
import os
tree_file = 'bac120.tree'

if os.path.exists(tree_file):
    with open(tree_file, 'r') as f:
        tree_content = f.read()
    
    # Extract genome IDs from tree
    import re
    tree_genome_ids = set(re.findall(r'(RS_GCF_\d+\.\d+|GB_GCA_\d+\.\d+)', tree_content))
    
    print(f"Tree file: {tree_file}")
    print(f"File size: {len(tree_content):,} characters")
    print(f"Genome tips in tree: {len(tree_genome_ids):,}")
else:
    print(f"Tree file not found: {tree_file}")
    print("Will need to download GTDB tree")
    tree_genome_ids = set()

In [None]:
# Match species to tree via representative genomes
df['in_tree'] = df['rep_genome'].isin(tree_genome_ids)

n_matched = df['in_tree'].sum()
n_total = len(df)

print(f"Species with representative genome in tree: {n_matched:,} / {n_total:,} ({100*n_matched/n_total:.1f}%)")

# Filter to matched species
df_tree = df[df['in_tree']].copy().reset_index(drop=True)
print(f"\nFiltered dataset: {len(df_tree):,} species for phylogenetic analysis")

# Check taxonomic coverage
print(f"\nTaxonomic coverage:")
print(f"  Phyla: {df_tree['phylum'].nunique()}")
print(f"  Classes: {df_tree['class'].nunique()}")
print(f"  Orders: {df_tree['order'].nunique()}")
print(f"  Families: {df_tree['family'].nunique()}")
print(f"  Genera: {df_tree['genus'].nunique()}")

## 5. Compute Phylogenetic Distance Matrix

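We compute pairwise phylogenetic distances between species using the branch lengths from the GTDB tree. The distance between two species is the patristic distance: the sum of branch lengths on the path connecting their representative genomes.

For computational efficiency with a tree of this size, we:
1. Sample up to `MAX_SPECIES` species, stratified by phylum
2. Parse the Newick tree with BioPython (`Bio.Phylo`, required) and prune it to the sampled species
3. Compute pairwise distances in parallel worker processes
4. Cache the matrix and species list to disk (`phylo_distance_matrix_{MAX_SPECIES}.npy`, `phylo_matched_species_{MAX_SPECIES}.csv`) to avoid recomputation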
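In the notebook, `Bio.Phylo`'s `tree.distance(a, b)` computes the patristic distance. The dependency-free sketch below computes the same quantity on a tiny made-up tree stored as child-to-parent links, by walking both tips up to their lowest common ancestor (LCA). The tree and helper names are hypothetical.

```python
# Toy rooted tree stored as child -> (parent, branch_length).
# Topology (made up): root -> X (0.05), X -> A (0.1), X -> B (0.2), root -> C (0.3)
PARENT = {
    "X": ("root", 0.05),
    "A": ("X", 0.1),
    "B": ("X", 0.2),
    "C": ("root", 0.3),
}

def distances_to_ancestors(node):
    """Map `node` and each of its ancestors to their path length from `node`."""
    dist = {node: 0.0}
    total = 0.0
    while node in PARENT:
        node, length = PARENT[node]
        total += length
        dist[node] = total
    return dist

def patristic_distance(a, b):
    """Sum of branch lengths on the path a -> LCA(a, b) -> b."""
    up_a = distances_to_ancestors(a)
    # Climb from b until reaching the first node that is also an ancestor of a
    node, from_b = b, 0.0
    while node not in up_a:
        node, length = PARENT[node]
        from_b += length
    return up_a[node] + from_b

print(round(patristic_distance("A", "B"), 6))  # -> 0.3
print(round(patristic_distance("A", "C"), 6))  # -> 0.45
```

This view also explains why pruning should not change distances among retained tips, provided that when an internal node is left with a single child, its branch length is added to that child's.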

In [None]:
# Sample species for phylogenetic distance computation
# With tree pruning + multiprocessing, we can efficiently handle 1000+ species

MAX_SPECIES = 1000  # Configurable: try 500, 1000, or more
USE_PRECOMPUTED = False  # Set to True to reuse a legacy species list (phylo_matched_species.csv)

# Check for pre-computed matrix matching our target size
matrix_file = f'phylo_distance_matrix_{MAX_SPECIES}.npy'
species_file = f'phylo_matched_species_{MAX_SPECIES}.csv'

if os.path.exists(matrix_file) and os.path.exists(species_file):
    print(f"Found pre-computed {MAX_SPECIES}-species matrix!")
    phylo_dist = np.load(matrix_file)
    df_sample = pd.read_csv(species_file)
    USE_PRECOMPUTED = True
elif USE_PRECOMPUTED and os.path.exists('phylo_matched_species.csv'):
    df_precomputed = pd.read_csv('phylo_matched_species.csv')
    df_sample = df_precomputed[df_precomputed['no_genomes'] >= MIN_GENOMES].reset_index(drop=True)
    print(f"Using pre-computed data: {len(df_sample)} species (after quality filter)")
else:
    if len(df_tree) > MAX_SPECIES:
        print(f"Stratified sampling {MAX_SPECIES} species from {len(df_tree)}...")
        
        # Stratified sampling by phylum to ensure diversity
        df_sample = df_tree.groupby('phylum', group_keys=False).apply(
            lambda x: x.sample(
                min(len(x), max(1, int(MAX_SPECIES * len(x) / len(df_tree)))), 
                random_state=42
            )
        ).reset_index(drop=True)
        
        # Top up if needed
        if len(df_sample) < MAX_SPECIES:
            remaining = df_tree[~df_tree['rep_genome'].isin(df_sample['rep_genome'])]
            n_extra = min(len(remaining), MAX_SPECIES - len(df_sample))
            extra = remaining.sample(n_extra, random_state=42)
            df_sample = pd.concat([df_sample, extra]).reset_index(drop=True)
    else:
        df_sample = df_tree.copy()

print(f"Working with {len(df_sample)} species")
print(f"Phyla represented: {df_sample['phylum'].nunique()}")
print(f"Genome count range: {df_sample['no_genomes'].min()} - {df_sample['no_genomes'].max()}")
print(f"Expected pairs: {len(df_sample) * (len(df_sample)-1) // 2:,}")

In [None]:
# Compute phylogenetic distances using a PRUNED subtree for efficiency.
# Uses process-based parallelism: threads give little speedup here because
# Python's GIL serializes CPU-bound code such as tree traversal.
import numpy as np
import os
from itertools import combinations
import time
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor
from io import StringIO

# Worker initialization for multiprocessing
def init_worker_with_tree(tree_str, genomes_list):
    """Initialize each worker process with its own tree copy."""
    global worker_tree, worker_genomes
    from io import StringIO
    from Bio import Phylo
    worker_tree = Phylo.read(StringIO(tree_str), 'newick')
    worker_genomes = genomes_list

def compute_distance_mp(args):
    """Compute one pairwise patristic distance in a worker process."""
    i, j = args
    try:
        return (i, j, worker_tree.distance(worker_genomes[i], worker_genomes[j]))
    except Exception:
        # Record NaN instead of crashing the pool if a tip lookup fails
        return (i, j, np.nan)

# Configuration
N_WORKERS = min(16, mp.cpu_count())  # Use up to 16 workers

# Skip if already loaded from pre-computed
if 'USE_PRECOMPUTED' in dir() and USE_PRECOMPUTED and 'phylo_dist' in dir():
    print(f"Using pre-computed distance matrix: {phylo_dist.shape}")
    n_species = len(df_sample)
else:
    # Check for existing matrix matching our MAX_SPECIES
    matrix_file = f'phylo_distance_matrix_{MAX_SPECIES}.npy'
    species_file = f'phylo_matched_species_{MAX_SPECIES}.csv'

    if os.path.exists(matrix_file) and os.path.exists(species_file):
        print(f"Loading pre-computed {MAX_SPECIES}-species matrix...")
        phylo_dist = np.load(matrix_file)
        df_sample = pd.read_csv(species_file)
        n_species = len(df_sample)
        print(f"Loaded {n_species} species")
    else:
        # Check for BioPython
        try:
            from Bio import Phylo
            HAVE_BIOPYTHON = True
        except ImportError:
            raise RuntimeError("BioPython required. Install with: pip install biopython")

        if not os.path.exists(tree_file):
            raise RuntimeError(f"Tree file not found: {tree_file}")

        # Parse tree
        print(f"Parsing phylogenetic tree...")
        start_time = time.time()
        tree = Phylo.read(tree_file, 'newick')
        parse_time = time.time() - start_time
        print(f"Tree parsed in {parse_time:.1f} seconds ({len(list(tree.get_terminals())):,} taxa)")

        # Get target genomes
        genome_list = df_sample['rep_genome'].tolist()
        target_set = set(genome_list)

        # Find matching terminals
        all_terminals = list(tree.get_terminals())
        matching_terminals = [t for t in all_terminals if t.name in target_set]
        found_genomes = [t.name for t in matching_terminals]
        print(f"Found {len(found_genomes)}/{len(genome_list)} genomes in tree")

        # Update df_sample
        df_sample = df_sample[df_sample['rep_genome'].isin(found_genomes)].reset_index(drop=True)
        genome_list = df_sample['rep_genome'].tolist()
        n_species = len(df_sample)

        # PRUNE TREE to only include our species (massive speedup!)
        print(f"Pruning tree to {n_species} species (this speeds up distance computation)...")
        start_time = time.time()

        # Create a pruned copy
        from copy import deepcopy
        pruned_tree = deepcopy(tree)

        # Get names to keep
        keep_names = set(genome_list)

        # Remove terminals not in our set
        terminals_to_remove = [t for t in pruned_tree.get_terminals() if t.name not in keep_names]
        for terminal in terminals_to_remove:
            try:
                pruned_tree.prune(terminal)
            except ValueError:
                pass  # Terminal may have been removed with parent

        prune_time = time.time() - start_time
        print(f"Pruned tree in {prune_time:.1f} seconds ({len(list(pruned_tree.get_terminals()))} taxa remaining)")

        # Serialize pruned tree for multiprocessing workers
        tree_io = StringIO()
        Phylo.write(pruned_tree, tree_io, 'newick')
        tree_string = tree_io.getvalue()

        # Generate all pairs
        pairs = list(combinations(range(n_species), 2))
        n_pairs = len(pairs)
        print(f"\nComputing {n_pairs:,} pairwise distances using {N_WORKERS} CPU cores...")

        # Initialize distance matrix
        phylo_dist = np.zeros((n_species, n_species))

        # Use multiprocessing for true parallelism (bypasses Python's GIL)
        start_time = time.time()

        # 'fork' lets workers inherit functions defined in this notebook; it is
        # unavailable on Windows and considered unsafe on macOS
        ctx = mp.get_context('fork')
        with ProcessPoolExecutor(max_workers=N_WORKERS,
                                 mp_context=ctx,
                                 initializer=init_worker_with_tree,
                                 initargs=(tree_string, genome_list)) as executor:
            # Process in chunks for progress reporting
            chunk_size = max(1000, n_pairs // 20)  # ~20 progress updates
            results = []

            for chunk_start in range(0, n_pairs, chunk_size):
                chunk_end = min(chunk_start + chunk_size, n_pairs)
                chunk_pairs = pairs[chunk_start:chunk_end]

                chunk_results = list(executor.map(compute_distance_mp, chunk_pairs, chunksize=100))
                results.extend(chunk_results)

                elapsed = time.time() - start_time
                rate = len(results) / elapsed
                eta = (n_pairs - len(results)) / rate if rate > 0 else 0
                print(f"  Progress: {len(results):,}/{n_pairs:,} ({100*len(results)/n_pairs:.1f}%) - "
                      f"{rate:.0f} pairs/sec - ETA: {eta:.1f} sec")

        # Fill distance matrix from results
        for i, j, dist in results:
            phylo_dist[i, j] = dist
            phylo_dist[j, i] = dist

        total_time = time.time() - start_time
        print(f"\nCompleted in {total_time:.1f} seconds ({n_pairs/total_time:.0f} pairs/sec)")
        print(f"Workers used: {N_WORKERS} (no sequential baseline was timed, so no speedup is reported)")

        # Save for future use
        np.save(matrix_file, phylo_dist)
        df_sample.to_csv(species_file, index=False)
        print(f"Saved: {matrix_file}, {species_file}")

print(f"\nFinal dataset: {n_species} species with phylogenetic distances")

In [None]:
# Distance matrix statistics
upper_tri = phylo_dist[np.triu_indices_from(phylo_dist, k=1)]
valid_dists = upper_tri[~np.isnan(upper_tri)]

print("Phylogenetic Distance Statistics")
print("=" * 50)
print(f"Species: {n_species}")
print(f"Valid pairs: {len(valid_dists):,}")
print(f"Mean distance: {valid_dists.mean():.4f}")
print(f"Median distance: {np.median(valid_dists):.4f}")
print(f"Std deviation: {valid_dists.std():.4f}")
print(f"Range: {valid_dists.min():.4f} - {valid_dists.max():.4f}")

# Histogram of distances
fig, ax = plt.subplots(figsize=(10, 5))
ax.hist(valid_dists, bins=50, edgecolor='black', alpha=0.7)
ax.axvline(valid_dists.mean(), color='red', linestyle='--', label=f'Mean: {valid_dists.mean():.3f}')
ax.axvline(np.median(valid_dists), color='orange', linestyle='--', label=f'Median: {np.median(valid_dists):.3f}')
ax.set_xlabel('Phylogenetic Distance')
ax.set_ylabel('Count')
ax.set_title(f'Distribution of Pairwise Phylogenetic Distances (n={len(valid_dists):,} pairs)')
ax.legend()
plt.tight_layout()
plt.savefig('figure_phylo_distance_distribution.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"\nFigure saved: figure_phylo_distance_distribution.png")

## 6. Compute Gene Content Distance Matrices

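We compute pairwise "gene content distance" as the absolute difference between two species in each of several pangenome summary statistics:

1. **Core genome %** (`pct_core`): Difference in the proportion of gene clusters that are core
2. **Openness** (`openness`): Difference in the accessory/core ratio
3. **Core gene count** (`no_core`): Absolute difference in the number of core genes
4. **Pangenome size** (`no_gene_clusters`): Difference in total gene clusters
5. **Genes per genome** (`avg_genes_per_genome`): Difference in the mean number of CDSs per genome

Each metric captures a different aspect of pangenome structure. Note that these compare species-level summary statistics; they do not measure how many gene families two species actually share.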
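As a worked example, consider three hypothetical species with made-up core-genome percentages. The distance matrix for a metric is just the table of pairwise absolute differences, which NumPy broadcasting produces in one step.

```python
import numpy as np

# Made-up core-genome percentages for three hypothetical species
pct_core = np.array([80.0, 60.0, 75.0])

# Broadcasting an (n, 1) column against a (1, n) row yields all pairwise
# differences; abs() makes the matrix symmetric with a zero diagonal.
delta = np.abs(pct_core[:, None] - pct_core[None, :])
print(delta)
```

Species 1 and 3 differ by only 5 percentage points, so they count as "close" in Δ%core regardless of how far apart they sit on the tree. Only the upper triangle is used in the correlation analysis, since the matrix is symmetric.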

In [None]:
def compute_distance_matrix(values):
    """Compute the pairwise absolute difference matrix |v_i - v_j|."""
    values = np.asarray(values, dtype=float)
    # Broadcasting (n, 1) against (1, n) gives all n x n differences at once
    return np.abs(values[:, None] - values[None, :])

# Compute gene content distance matrices
metrics = {
    'pct_core': df_sample['pct_core'].values,
    'openness': df_sample['openness'].values,
    'no_core': df_sample['no_core'].values,
    'no_gene_clusters': df_sample['no_gene_clusters'].values,
    'avg_genes_per_genome': df_sample['avg_genes_per_genome'].values
}

gene_content_dists = {}
for name, values in metrics.items():
    gene_content_dists[name] = compute_distance_matrix(values)
    print(f"Computed {name} distance matrix")

print(f"\nAll gene content distance matrices computed: {list(gene_content_dists.keys())}")

## 7. Correlation Analysis

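We now test whether phylogenetic distance predicts gene content distance using:

1. **Pearson correlation**: Linear relationship
2. **Spearman correlation**: Rank-based (robust to outliers and to monotonic but non-linear trends)
3. **Mantel test**: Permutation-based significance test for the correlation between two distance matrices

The Mantel test is needed because the n(n-1)/2 pairwise distances are not independent observations (each species appears in n-1 pairs), so the standard p-values attached to Pearson or Spearman coefficients would be overly optimistic. The Mantel test instead builds its null distribution by permuting species labels, reordering rows and columns of one matrix together.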
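Distance vectors often contain ties (for example, integer differences in core-gene counts). Spearman's ρ is the Pearson correlation of ranks, and with ties each tied value should receive the average of the ranks it spans. The sketch below, using made-up data and a hypothetical helper name, computes ρ this way with `pandas.Series.rank` (whose default method is `'average'`).

```python
import numpy as np
import pandas as pd

def spearman_with_ties(x, y):
    """Spearman's rho as the Pearson correlation of tie-averaged ranks."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    mask = ~(np.isnan(x) | np.isnan(y))        # drop pairs with a missing value
    rx = pd.Series(x[mask]).rank().to_numpy()  # ties share their average rank
    ry = pd.Series(y[mask]).rank().to_numpy()
    return np.corrcoef(rx, ry)[0, 1]

# Made-up data with ties in both vectors
x = [0.1, 0.2, 0.2, 0.4, 0.5]
y = [3, 1, 1, 4, 4]
rho = spearman_with_ties(x, y)
print(round(rho, 4))  # -> 0.6489
```

The shortcut formula 1 - 6Σd²/(n(n²-1)) gives the same value only when there are no ties, which is why rank averaging matters for these distance data.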

In [None]:
def spearman_corr(x, y):
    """Spearman rank correlation, using average ranks for ties."""
    x, y = np.asarray(x, dtype=float), np.asarray(y, dtype=float)
    mask = ~(np.isnan(x) | np.isnan(y))
    x, y = x[mask], y[mask]
    if len(x) < 3:
        return np.nan
    # Pearson correlation of tie-averaged ranks; the 1 - 6*sum(d^2)/(n(n^2-1))
    # shortcut is only exact when there are no ties
    x_ranks = pd.Series(x).rank().to_numpy()
    y_ranks = pd.Series(y).rank().to_numpy()
    return np.corrcoef(x_ranks, y_ranks)[0, 1]

def mantel_test(mat1, mat2, n_permutations=999):
    """Mantel test (Pearson) with a two-sided permutation p-value."""
    idx = np.triu_indices_from(mat1, k=1)
    vec1_all = mat1[idx]
    vec2_all = mat2[idx]
    
    # Observed correlation on pairs where both distances are defined
    mask = ~(np.isnan(vec1_all) | np.isnan(vec2_all))
    obs_corr = np.corrcoef(vec1_all[mask], vec2_all[mask])[0, 1]
    
    # Permutation test: relabel species by permuting rows and columns of mat1
    # together; recompute the NaN mask because NaNs move with the labels
    n = len(mat1)
    perm_corrs = np.empty(n_permutations)
    for k in range(n_permutations):
        perm = np.random.permutation(n)
        vec1_perm = mat1[np.ix_(perm, perm)][idx]
        m = ~(np.isnan(vec1_perm) | np.isnan(vec2_all))
        perm_corrs[k] = np.corrcoef(vec1_perm[m], vec2_all[m])[0, 1]
    
    p_value = (np.sum(np.abs(perm_corrs) >= np.abs(obs_corr)) + 1) / (n_permutations + 1)
    
    return obs_corr, p_value

# Flatten phylogenetic distances
phylo_flat = phylo_dist[np.triu_indices_from(phylo_dist, k=1)]

print("Correlation: Phylogenetic Distance vs Gene Content Distance")
print("=" * 70)
print(f"{'Metric':<25} {'Pearson r':>12} {'Spearman ρ':>12} {'Mantel p':>12}")
print("-" * 70)

np.random.seed(42)
results = []

for name, gene_dist in gene_content_dists.items():
    gene_flat = gene_dist[np.triu_indices_from(gene_dist, k=1)]
    
    # Remove NaN pairs
    mask = ~(np.isnan(phylo_flat) | np.isnan(gene_flat))
    
    pearson_r = np.corrcoef(phylo_flat[mask], gene_flat[mask])[0, 1]
    spearman_r = spearman_corr(phylo_flat, gene_flat)
    mantel_r, mantel_p = mantel_test(phylo_dist, gene_dist, n_permutations=499)
    
    sig = '*' if mantel_p < 0.05 else ''
    print(f"{name:<25} {pearson_r:>+12.4f} {spearman_r:>+12.4f} {mantel_p:>11.4f}{sig}")
    
    results.append({
        'metric': name,
        'pearson': pearson_r,
        'spearman': spearman_r,
        'mantel_r': mantel_r,
        'mantel_p': mantel_p
    })

print("-" * 70)
print("* = significant at p < 0.05")

df_results = pd.DataFrame(results)

## 8. Taxonomic Level Analysis

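The global correlation may be weak because different taxonomic levels show different patterns. We therefore stratify species pairs by the most specific taxonomic rank they share, on a 0-6 scale:

- **Same genus** (0): Most closely related
- **Same family / order / class** (1-3): Intermediate relatedness
- **Same phylum** (4): Distantly related
- **Same domain, different phyla** (5): Very distantly related
- **Different domains** (6): Most distant; this category is empty here because the `bac120` tree contains only bacteria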
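The 0-6 scale can be sketched as "6 minus the index of the first rank, from domain down to genus, at which two lineages differ". The lineages below use made-up placeholder names, and `taxonomic_distance` here is an illustrative standalone version of the idea.

```python
RANKS = ["domain", "phylum", "class", "order", "family", "genus"]

def taxonomic_distance(lin1, lin2):
    """0 = same genus ... 6 = different domain."""
    for i, rank in enumerate(RANKS):
        if lin1[rank] != lin2[rank]:
            return len(RANKS) - i  # first mismatch, counted from the top
    return 0

# Made-up lineages with placeholder taxon names
a = {"domain": "Bacteria", "phylum": "PhylumA", "class": "ClassA",
     "order": "OrderA", "family": "FamilyA", "genus": "GenusA"}
b = dict(a, genus="GenusB")                         # differs only at genus
c = dict(a, phylum="PhylumC", **{"class": "ClassC"},
         order="OrderC", family="FamilyC", genus="GenusC")  # differs from phylum down

print(taxonomic_distance(a, a))  # -> 0 (same genus)
print(taxonomic_distance(a, b))  # -> 1 (same family)
print(taxonomic_distance(a, c))  # -> 5 (same domain only)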

In [None]:
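# Taxonomic distance on a 0-6 scale: 6 minus the index of the first rank
# (from domain down to genus) at which two species differ; 0 = same genus.
tax_levels = ['domain', 'phylum', 'class', 'order', 'family', 'genus']

# Compute the taxonomic distance matrix with vectorized comparisons
# (a pairwise loop over df_sample.iloc is very slow for ~1,000 species)
n = len(df_sample)
tax_dist = np.zeros((n, n))
# Walk from genus up to domain: coarser mismatches overwrite finer ones,
# so each pair ends up with the value for its first (coarsest) mismatch
for i in reversed(range(len(tax_levels))):
    rank_labels = df_sample[tax_levels[i]].to_numpy()
    tax_dist[rank_labels[:, None] != rank_labels[None, :]] = 6 - i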

# Labels for taxonomic levels
level_names = {
    0: 'Same genus',
    1: 'Same family',
    2: 'Same order',
    3: 'Same class',
    4: 'Same phylum',
    5: 'Same domain',
    6: 'Different domain'
}

# Get flattened vectors
idx = np.triu_indices_from(tax_dist, k=1)
tax_flat = tax_dist[idx]
gene_core_flat = gene_content_dists['pct_core'][idx]

print("Gene Content Distance by Taxonomic Level")
print("=" * 60)
print(f"{'Taxonomic Level':<20} {'N pairs':>10} {'Mean Δ%core':>15} {'Std':>10}")
print("-" * 60)

for level in sorted(np.unique(tax_flat)):
    mask = tax_flat == level
    if mask.sum() > 0:
        gene_dists = gene_core_flat[mask]
        print(f"{level_names.get(int(level), f'Level {int(level)}'):<20} "
              f"{mask.sum():>10,} {gene_dists.mean():>15.2f} {gene_dists.std():>10.2f}")

In [None]:
# Correlation at each taxonomic level
print("\nCorrelation (Phylo vs Gene Content) by Taxonomic Level")
print("=" * 60)
print(f"{'Level':<20} {'N pairs':>10} {'Pearson r':>12} {'Interpretation':<20}")
print("-" * 60)

level_correlations = []

for level in sorted(np.unique(tax_flat)):
    mask = (tax_flat == level) & ~np.isnan(phylo_flat)
    if mask.sum() >= 10:
        r = np.corrcoef(phylo_flat[mask], gene_core_flat[mask])[0, 1]
        
        if abs(r) < 0.1:
            interp = "No correlation"
        elif r > 0.3:
            interp = "Positive (divergent)"
        elif r < -0.3:
            interp = "Negative (convergent)"
        else:
            interp = "Weak"
        
        print(f"{level_names.get(int(level), f'Level {int(level)}'):<20} "
              f"{mask.sum():>10,} {r:>+12.4f} {interp:<20}")
        
        level_correlations.append({
            'level': int(level),
            'name': level_names.get(int(level)),
            'n_pairs': mask.sum(),
            'correlation': r
        })

## 9. Visualization

In [None]:
fig, axes = plt.subplots(2, 3, figsize=(15, 10))

# 1. Scatter: Phylo vs Gene Content Distance
ax1 = axes[0, 0]
mask = ~np.isnan(phylo_flat)
hb = ax1.hexbin(phylo_flat[mask], gene_core_flat[mask], gridsize=30, cmap='YlOrRd', mincnt=1)
ax1.set_xlabel('Phylogenetic Distance')
ax1.set_ylabel('Gene Content Distance (Δ%core)')
r = np.corrcoef(phylo_flat[mask], gene_core_flat[mask])[0, 1]
ax1.set_title(f'Phylo vs Gene Content\nr = {r:.3f}')
plt.colorbar(hb, ax=ax1, label='Count')

# 2. Boxplot: Gene content by taxonomic level
ax2 = axes[0, 1]
data_by_level = []
labels = []
for level in sorted(np.unique(tax_flat)):
    mask = tax_flat == level
    if mask.sum() > 0:
        data_by_level.append(gene_core_flat[mask])
        labels.append(level_names.get(int(level), '').replace('Same ', '').replace('Different ', 'Diff\n'))

ax2.boxplot(data_by_level, tick_labels=labels)
ax2.set_xlabel('Taxonomic Relatedness')
ax2.set_ylabel('Gene Content Distance (Δ%core)')
ax2.set_title('Gene Content by Taxonomy')
ax2.tick_params(axis='x', rotation=45)

# 3. Boxplot: Phylo distance by taxonomic level
ax3 = axes[0, 2]
phylo_by_level = []
for level in sorted(np.unique(tax_flat)):
    mask = (tax_flat == level) & ~np.isnan(phylo_flat)
    if mask.sum() > 0:
        phylo_by_level.append(phylo_flat[mask])

ax3.boxplot(phylo_by_level, tick_labels=labels)
ax3.set_xlabel('Taxonomic Relatedness')
ax3.set_ylabel('Phylogenetic Distance')
ax3.set_title('Phylo Distance by Taxonomy')
ax3.tick_params(axis='x', rotation=45)

# 4. Bar: Correlation by taxonomic level
ax4 = axes[1, 0]
if level_correlations:
    df_corr = pd.DataFrame(level_correlations)
    colors = ['green' if r > 0 else 'red' for r in df_corr['correlation']]
    ax4.bar(range(len(df_corr)), df_corr['correlation'], color=colors, alpha=0.7)
    ax4.axhline(0, color='black', linewidth=0.5)
    ax4.set_xticks(range(len(df_corr)))
    ax4.set_xticklabels([n.replace('Same ', '').replace('Different ', 'Diff ') 
                         for n in df_corr['name']], rotation=45, ha='right')
    ax4.set_ylabel('Pearson Correlation')
    ax4.set_title('Phylo-Gene Correlation by Level')
    ax4.set_ylim(-0.5, 0.5)

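# 5. Scatter colored by taxonomy
# (recompute the mask: the boxplot loop above overwrote `mask` with its last level)
ax5 = axes[1, 1]
mask = ~np.isnan(phylo_flat)
scatter = ax5.scatter(phylo_flat[mask], gene_core_flat[mask], 
                      c=tax_flat[mask], cmap='viridis', alpha=0.3, s=3)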
ax5.set_xlabel('Phylogenetic Distance')
ax5.set_ylabel('Gene Content Distance (Δ%core)')
ax5.set_title('Colored by Taxonomic Level')
cbar = plt.colorbar(scatter, ax=ax5)
cbar.set_label('Taxonomic Distance')

# 6. Core % by phylum
ax6 = axes[1, 2]
phylum_means = df_sample.groupby('phylum')['pct_core'].agg(['mean', 'std', 'count'])
phylum_means = phylum_means[phylum_means['count'] >= 5].sort_values('mean', ascending=True).tail(12)
ax6.barh(range(len(phylum_means)), phylum_means['mean'], 
         xerr=phylum_means['std'], capsize=3, alpha=0.7)
ax6.set_yticks(range(len(phylum_means)))
ax6.set_yticklabels(phylum_means.index)
ax6.set_xlabel('Mean Core Genome (%)')
ax6.set_title('Core Genome by Phylum')

plt.tight_layout()
plt.savefig('figure_gene_content_phylogeny.png', dpi=150, bbox_inches='tight')
plt.savefig('figure_gene_content_phylogeny.pdf', bbox_inches='tight')
plt.show()

print("Figures saved: figure_gene_content_phylogeny.png, .pdf")

## 10. Summary and Conclusions

In [None]:
print("="*70)
print("SUMMARY: Gene Content vs Phylogenetic Distance")
print("="*70)

print(f"\n### Dataset")
print(f"  Species analyzed: {len(df_sample):,}")
print(f"  Pairwise comparisons: {len(phylo_flat):,}")
print(f"  Phyla represented: {df_sample['phylum'].nunique()}")

print(f"\n### Quality Filters Applied")
print(f"  Minimum genomes per species: {MIN_GENOMES}")
print(f"  Minimum core genes: {MIN_CORE_GENES}")
print(f"  Minimum ANI: {MIN_ANI}%")

print(f"\n### Global Correlation Results")
mask = ~np.isnan(phylo_flat)
r = np.corrcoef(phylo_flat[mask], gene_core_flat[mask])[0, 1]
print(f"  Pearson r (phylo vs Δ%core): {r:.4f}")
print(f"  Interpretation: {'Weak/No' if abs(r) < 0.2 else 'Moderate' if abs(r) < 0.5 else 'Strong'} correlation")

print(f"\n### Key Findings")
findings = [
    "1. Gene content (% core) shows WEAK correlation with phylogenetic distance",
    "2. Closely related species (same genus) can have highly divergent gene content",
    "3. Distantly related species can have similar pangenome structures",
    "4. Taxonomy predicts phylogenetic distance but NOT gene content similarity",
    "5. Ecological factors may be more important than ancestry for pangenome structure"
]
for f in findings:
    print(f"  {f}")

print(f"\n### Implications")
print("  - Pangenome 'openness' is not strictly inherited")
print("  - Horizontal gene transfer may homogenize distant lineages")
print("  - Niche specialization can rapidly alter gene content")
print("  - Phylogeny alone is insufficient to predict genome plasticity")

print("\n" + "="*70)

In [None]:
# Save results table
df_results.to_csv('gene_content_phylogeny_correlations.csv', index=False)
print("Results saved: gene_content_phylogeny_correlations.csv")

# Display final results
print("\nCorrelation Summary:")
display(df_results)