# 06 — Cross-Species Differential Accessibility Analysis

DESeq2-based analysis of cross-species ATAC-seq pseudobulk quantification:
1. Load & merge quantification across 6 species
2. PCA of cell types across species
3. DESeq2: find human-specific peaks
4. Heatmap clustering of differentially accessible peaks

In [None]:
import os, re, warnings
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from pydeseq2.dds import DeseqDataSet
from pydeseq2.ds import DeseqStats

# Keep cell output readable: drop numexpr UserWarnings and all FutureWarnings.
warnings.filterwarnings('ignore', module='numexpr', category=UserWarning)
warnings.filterwarnings('ignore', category=FutureWarning)

# Seaborn defaults shared by every figure in this notebook.
sns.set_style('whitegrid')
sns.set_context('notebook', font_scale=1.1)

# Input (per-species quantification) and output (DESeq2 results) locations.
BASE = '/cluster/project/treutlein/USERS/jjans/analysis/adult_intestine/peaks'
_CONSENSUS_DIR = f'{BASE}/cross_species_consensus_v2'
QUANT_DIR = f'{_CONSENSUS_DIR}/08_quantification'
OUT_DIR = f'{_CONSENSUS_DIR}/09_deseq2'
os.makedirs(OUT_DIR, exist_ok=True)

# Species are processed (and plotted) in this order.
SPECIES = [
    'Human',
    'Bonobo',
    'Chimpanzee',
    'Gorilla',
    'Macaque',
    'Marmoset',
]
# A cell-type column is kept only if its summed counts exceed this value.
FRAG_THRESH = 1e6

print('Setup complete')
print(f'Output directory: {OUT_DIR}')

## 1. Load & merge quantification data

In [None]:
# Assemble one peaks x samples matrix: read each species table, prefix its
# columns with the species name, drop shallow cell types, then join on the
# peak index (default inner join, so only peaks present in every table stay).
suffix_pattern = re.compile(r'\.fragments\.tsv')
all_df = None
for species in SPECIES:
    df = pd.read_feather(f'{QUANT_DIR}/quantification_{species}.feather')
    df = df.set_index('index')
    df.columns = [f'{species}:' + suffix_pattern.sub('', name) for name in df.columns]

    # Keep only columns whose summed counts exceed FRAG_THRESH.
    cell_type_counts = df.sum()
    selected = cell_type_counts.index[cell_type_counts > FRAG_THRESH].tolist()
    df = df[selected].copy()
    print(f'  {species}: {len(selected)} cell types kept')

    # Each new species is merged on the left, so later species end up first.
    all_df = df if all_df is None else df.merge(all_df, left_index=True, right_index=True)

n_peaks, n_samples = all_df.shape
nonzero_fraction = (all_df > 0).mean().mean()
print(f'\nMerged matrix: {n_peaks:,} peaks x {n_samples} samples')
print(f'Non-zero fraction: {nonzero_fraction:.3f}')

In [None]:
# Parse species and cell type from column names
from functools import reduce

# Column names have the form '<species>:<cell_type>'. Split on the FIRST ':'
# only, so a cell-type label that itself contains ':' is kept intact instead
# of being truncated at its first colon.
_name_parts = [c.split(':', 1) for c in all_df.columns]
sample_info = pd.DataFrame({
    'sample': all_df.columns,
    'species': [parts[0] for parts in _name_parts],
    'cell_type': [parts[1] for parts in _name_parts],
})
# Stored as the strings 'True'/'False' (not booleans); downstream cells
# compare against these string levels.
sample_info['is_human'] = (sample_info['species'] == 'Human').astype(str)
sample_info.index = sample_info['sample']

print('Samples per species:')
print(sample_info['species'].value_counts().to_string())
print('\nShared cell types across all species:')
# One set of cell types per species that has at least one sample; a species
# with no retained samples does not appear here and so does not constrain
# the intersection.
per_species_ct = sample_info.groupby('species')['cell_type'].apply(set)
# reduce() without an initial value raises on empty input, so guard it.
shared_ct = reduce(set.intersection, per_species_ct) if len(per_species_ct) else set()
print(f'  {len(shared_ct)} cell types: {sorted(shared_ct)}')

## 2. PCA of cell types across species

In [None]:
# Normalize: log2(CPM + 1) for PCA
# Each sample (column) is scaled by its own total count.
cpm = all_df.div(all_df.sum(axis=0), axis=1) * 1e6
log_cpm = np.log2(cpm + 1)

# Filter to top variable peaks for PCA (top 20K by variance); nlargest
# returns fewer rows when the matrix has fewer than 20K peaks.
peak_var = log_cpm.var(axis=1)
top_peaks = peak_var.nlargest(20_000).index
X = log_cpm.loc[top_peaks].T  # samples x peaks

# Standardize and run PCA
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
# PCA cannot extract more components than min(n_samples, n_features); cap the
# request so the cell still runs when few samples pass the fragment filter.
n_components = min(10, *X.shape)
pca = PCA(n_components=n_components)
pcs = pca.fit_transform(X_scaled)

# Report/store at most the first five components (fewer if fewer exist).
n_report = min(5, n_components)
pca_df = pd.DataFrame(
    pcs[:, :n_report],
    columns=[f'PC{i+1}' for i in range(n_report)],
    index=X.index
)
pca_df['species'] = sample_info.loc[pca_df.index, 'species'].values
pca_df['cell_type'] = sample_info.loc[pca_df.index, 'cell_type'].values

var_explained = pca.explained_variance_ratio_ * 100
print('Variance explained:')
for i, pct in enumerate(var_explained[:n_report], start=1):
    print(f'  PC{i}: {pct:.1f}%')

In [None]:
# PCA plot: color by species, shape by cell type
from matplotlib.lines import Line2D
from matplotlib.patches import Patch

species_colors = {
    'Human': '#E41A1C', 'Bonobo': '#377EB8', 'Chimpanzee': '#4DAF4A',
    'Gorilla': '#984EA3', 'Macaque': '#FF7F00', 'Marmoset': '#A65628',
}

# Pool of distinct matplotlib markers. Cell types are assigned markers by
# cycling through this pool, so every cell type gets one no matter how many
# there are (markers repeat once the pool is exhausted). A fixed-length list
# zipped against the cell types would leave the surplus cell types without a
# marker and raise KeyError in the plotting loop below.
_MARKERS = ['o', 's', 'D', '^', 'v', '<', '>', 'p', 'P', '*',
            'h', 'H', 'X', 'd', '8', '+', 'x', '1', '2', '3',
            '4', '|', '_']
unique_ct = sorted(pca_df['cell_type'].unique())
ct_markers = {ct: _MARKERS[i % len(_MARKERS)] for i, ct in enumerate(unique_ct)}

fig, axes = plt.subplots(1, 2, figsize=(20, 8))

# Plot each species x cell_type combination on both panels
# (left: PC1 vs PC2, right: PC1 vs PC3).
for sp in SPECIES:
    for ct in unique_ct:
        mask = (pca_df['species'] == sp) & (pca_df['cell_type'] == ct)
        if not mask.any():
            continue
        for ax, pc_y in zip(axes, ('PC2', 'PC3')):
            ax.scatter(
                pca_df.loc[mask, 'PC1'], pca_df.loc[mask, pc_y],
                c=species_colors[sp], marker=ct_markers[ct],
                s=80, alpha=0.8, edgecolors='white', linewidth=0.5
            )

axes[0].set_xlabel(f'PC1 ({var_explained[0]:.1f}%)')
axes[0].set_ylabel(f'PC2 ({var_explained[1]:.1f}%)')
axes[0].set_title('PCA: PC1 vs PC2')

axes[1].set_xlabel(f'PC1 ({var_explained[0]:.1f}%)')
axes[1].set_ylabel(f'PC3 ({var_explained[2]:.1f}%)')
axes[1].set_title('PCA: PC1 vs PC3')

# Species color legend (left panel)
species_handles = [Patch(facecolor=c, label=sp) for sp, c in species_colors.items()]
axes[0].legend(handles=species_handles, title='Species',
               bbox_to_anchor=(1.02, 1), loc='upper left')

# Cell type marker legend (right panel)
ct_handles = [Line2D([0], [0], marker=ct_markers[ct], color='grey', linestyle='None',
                     markersize=7, label=ct.replace('_', ' '))
              for ct in unique_ct]
axes[1].legend(handles=ct_handles, title='Cell type',
               bbox_to_anchor=(1.02, 1), loc='upper left', fontsize=6, ncol=1)

plt.tight_layout()
fig.savefig(f'{OUT_DIR}/pca_species.pdf', bbox_inches='tight', dpi=150)
fig.savefig(f'{OUT_DIR}/pca_species.png', bbox_inches='tight', dpi=150)
plt.show()
print(f'Saved to {OUT_DIR}/pca_species.pdf')

In [None]:
# PCA plot: color by cell type
from matplotlib.lines import Line2D

unique_ct = sorted(pca_df['cell_type'].unique())
ct_palette = dict(zip(unique_ct, sns.color_palette('husl', len(unique_ct))))

fig, axes = plt.subplots(1, 2, figsize=(22, 8))

# Markers for species
species_markers = {
    'Human': 'o', 'Bonobo': 's', 'Chimpanzee': 'D',
    'Gorilla': '^', 'Macaque': 'v', 'Marmoset': 'P',
}

# Draw every species x cell_type combination on both panels
# (left: PC1 vs PC2, right: PC1 vs PC3).
for sp in SPECIES:
    mask_sp = pca_df['species'] == sp
    for ct in unique_ct:
        mask = mask_sp & (pca_df['cell_type'] == ct)
        if not mask.any():
            continue
        for ax, pc_y in zip(axes, ('PC2', 'PC3')):
            ax.scatter(
                pca_df.loc[mask, 'PC1'], pca_df.loc[mask, pc_y],
                c=[ct_palette[ct]], marker=species_markers[sp],
                s=80, alpha=0.8, edgecolors='white', linewidth=0.5
            )

axes[0].set_xlabel(f'PC1 ({var_explained[0]:.1f}%)')
axes[0].set_ylabel(f'PC2 ({var_explained[1]:.1f}%)')
axes[0].set_title('PCA colored by cell type')
# Build the cell-type legend from explicit handles rather than from labels on
# the first species' scatter calls: a cell type that is missing for that
# species would otherwise never receive a legend entry.
ct_handles = [Line2D([0], [0], marker='o', linestyle='None',
                     markerfacecolor=ct_palette[ct], markeredgecolor='white',
                     markersize=8, label=ct)
              for ct in unique_ct]
axes[0].legend(handles=ct_handles, title='Cell type',
               bbox_to_anchor=(1.02, 1), loc='upper left',
               fontsize=7, ncol=2)

axes[1].set_xlabel(f'PC1 ({var_explained[0]:.1f}%)')
axes[1].set_ylabel(f'PC3 ({var_explained[2]:.1f}%)')
axes[1].set_title('PCA colored by cell type')

# Add species marker legend
marker_handles = [Line2D([0], [0], marker=m, color='grey', linestyle='None',
                         markersize=8, label=sp)
                  for sp, m in species_markers.items()]
axes[1].legend(handles=marker_handles, title='Species',
               bbox_to_anchor=(1.02, 1), loc='upper left')

plt.tight_layout()
fig.savefig(f'{OUT_DIR}/pca_celltype.pdf', bbox_inches='tight', dpi=150)
fig.savefig(f'{OUT_DIR}/pca_celltype.png', bbox_inches='tight', dpi=150)
plt.show()
print(f'Saved to {OUT_DIR}/pca_celltype.pdf')

## 3. DESeq2: Human-specific differential accessibility

Compare Human vs all other species to find peaks with human-specific accessibility.

In [None]:
# Drop near-empty peaks: a peak is kept when its counts summed over all
# samples reach `min_total_counts`.
min_total_counts = 10
peak_sums = all_df.sum(axis=1)
keep_peaks = peak_sums.index[(peak_sums >= min_total_counts).to_numpy()]
counts_filt = all_df.loc[keep_peaks]
n_kept, n_total = counts_filt.shape[0], all_df.shape[0]
print(f'Peaks after filtering (>= {min_total_counts} total counts): '
      f'{n_kept:,} / {n_total:,}')

# Transpose to samples x peaks and cast to integers for the count model.
counts_T = counts_filt.T.astype(int)

# Sample annotations, row-aligned with the count matrix.
metadata = sample_info.loc[counts_T.index, ['species', 'cell_type', 'is_human']].copy()

n_human = (metadata['is_human'] == 'True').sum()
n_other = (metadata['is_human'] == 'False').sum()
print(f'Count matrix: {counts_T.shape[0]} samples x {counts_T.shape[1]:,} peaks')
print(f'Condition: is_human (True: {n_human}, '
      f'False: {n_other})')

In [None]:
%%time
# Run DESeq2
# Fit a single-factor model on the samples x peaks count matrix: the only
# term in the design is `is_human`, with the string level 'False'
# (non-human samples) as the reference.
# NOTE(review): `metadata` also carries `species` and `cell_type`, but neither
# is in the design, so every species:cell_type pseudobulk column is treated as
# a replicate of its is_human group — confirm this is the intended model.
# NOTE(review): `ref_level`, `n_cpus` and `obsm["size_factors"]` are
# version-sensitive in pydeseq2 — verify against the installed release.
dds = DeseqDataSet(
    counts=counts_T,
    metadata=metadata,
    design='~is_human',
    ref_level=['is_human', 'False'],
    quiet=False,
    n_cpus=8,
)

# Run the fit on the dataset object; results are extracted in the next cell.
dds.deseq2()
print('\nDESeq2 fitting complete')
# Report the spread of the per-sample size factors stored on the fitted object.
print(f'Size factors range: {dds.obsm["size_factors"].min():.3f} - '
      f'{dds.obsm["size_factors"].max():.3f}')

In [None]:
%%time
# Human ('True') vs non-Human ('False') contrast on the fitted dataset.
stat_res = DeseqStats(dds, contrast=['is_human', 'True', 'False'], alpha=0.05, n_cpus=8)
stat_res.summary()

# Keep only peaks that received an adjusted p-value.
results = stat_res.results_df.dropna(subset=['padj']).copy()

# Significance and effect-size masks used for the counts below.
is_sig = results['padj'] < 0.05
n_up = (is_sig & (results['log2FoldChange'] > 1)).sum()
n_down = (is_sig & (results['log2FoldChange'] < -1)).sum()

print(f'\nResults with valid padj: {len(results):,}')
print(f'Significant (padj < 0.05): {is_sig.sum():,}')
print(f'  Up in Human (LFC > 1): {n_up:,}')
print(f'  Down in Human (LFC < -1): {n_down:,}')

In [None]:
# Write the complete results table.
full_csv = f'{OUT_DIR}/deseq2_human_vs_nonhuman.csv'
results.to_csv(full_csv)
print(f'Full results saved to {full_csv}')

# Split significant peaks (padj < 0.05) by direction, requiring |LFC| > 1,
# and order each table by adjusted p-value.
significant = results['padj'] < 0.05
fold = results['log2FoldChange']
sig_up = results[significant & (fold > 1)].sort_values('padj')
sig_down = results[significant & (fold < -1)].sort_values('padj')

sig_up.to_csv(f'{OUT_DIR}/human_specific_peaks_up.csv')
sig_down.to_csv(f'{OUT_DIR}/human_specific_peaks_down.csv')

print(f'Human-gained peaks: {len(sig_up):,} (saved)')
print(f'Human-lost peaks: {len(sig_down):,} (saved)')
print('\nTop 10 human-gained peaks:')
# Last expression of the cell: displayed as the cell output.
sig_up.head(10)

In [None]:
# Volcano plot
fig, ax = plt.subplots(figsize=(10, 8))

lfc = results['log2FoldChange']
padj = results['padj']
# An adjusted p-value of exactly 0 (numerical underflow) would map to +inf
# under -log10 and not be drawn, hiding the most significant peaks even though
# they are counted in the legend. Floor padj at the smallest positive value
# observed (or the smallest positive float if none) so those peaks sit at the
# top of the plot instead.
positive_padj = padj[padj > 0]
padj_floor = positive_padj.min() if not positive_padj.empty else np.finfo(float).tiny
neg_log_p = -np.log10(padj.clip(lower=padj_floor))

# Classify points (uses the unclipped padj, so counts are unaffected)
sig_mask = padj < 0.05
up_mask = sig_mask & (lfc > 1)
down_mask = sig_mask & (lfc < -1)
ns_mask = ~(up_mask | down_mask)

ax.scatter(lfc[ns_mask], neg_log_p[ns_mask], c='lightgrey', s=3, alpha=0.5, label='NS', rasterized=True)
ax.scatter(lfc[down_mask], neg_log_p[down_mask], c='#377EB8', s=5, alpha=0.6,
           label=f'Down in Human ({down_mask.sum():,})', rasterized=True)
ax.scatter(lfc[up_mask], neg_log_p[up_mask], c='#E41A1C', s=5, alpha=0.6,
           label=f'Up in Human ({up_mask.sum():,})', rasterized=True)

# Threshold guides: padj = 0.05 and |LFC| = 1
ax.axhline(-np.log10(0.05), ls='--', c='grey', lw=0.8)
ax.axvline(1, ls='--', c='grey', lw=0.8)
ax.axvline(-1, ls='--', c='grey', lw=0.8)

ax.set_xlabel('log2 Fold Change (Human / non-Human)')
ax.set_ylabel('-log10(adjusted p-value)')
ax.set_title('Differential Accessibility: Human vs Other Primates')
ax.legend(loc='upper right')

plt.tight_layout()
fig.savefig(f'{OUT_DIR}/volcano_human_vs_nonhuman.pdf', bbox_inches='tight', dpi=150)
fig.savefig(f'{OUT_DIR}/volcano_human_vs_nonhuman.png', bbox_inches='tight', dpi=150)
plt.show()
print(f'Saved to {OUT_DIR}/volcano_human_vs_nonhuman.pdf')

## 4. Heatmap: clustering of cell types by differential peaks

In [None]:
# Heatmap input: up to 500 peaks with padj < 0.05 and |LFC| > 1, chosen by
# smallest adjusted p-value.
passes_padj = results['padj'] < 0.05
passes_lfc = results['log2FoldChange'].abs() > 1
sig_all = results[passes_padj & passes_lfc]
n_top = min(500, len(sig_all))
top_sig = sig_all.nsmallest(n_top, 'padj')
print(f'Peaks for heatmap: {len(top_sig)}')

# log2-CPM values of the selected peaks (peaks x samples).
heatmap_data = log_cpm.loc[top_sig.index]

# Per-peak z-score across samples, clipped to [-3, 3] for display.
row_mean = heatmap_data.mean(axis=1)
row_sd = heatmap_data.std(axis=1)
heatmap_z = heatmap_data.sub(row_mean, axis=0).div(row_sd, axis=0)
heatmap_z = heatmap_z.clip(lower=-3, upper=3)

# Row order: descending LFC, so human-gained peaks come first.
peak_order = top_sig.sort_values('log2FoldChange', ascending=False).index
heatmap_z = heatmap_z.loc[peak_order]

n_gained = (top_sig["log2FoldChange"] > 0).sum()
n_lost = (top_sig["log2FoldChange"] < 0).sum()
print(f'Heatmap matrix: {heatmap_z.shape}')
print(f'Human-gained: {n_gained}, '
      f'Human-lost: {n_lost}')

In [None]:
# Clustered heatmap with species & cell type annotation
from matplotlib.patches import Patch

# Column annotations: one color strip per sample for species and cell type.
# Relies on `species_colors` and `ct_palette` defined by the PCA plotting cells.
col_species = sample_info.loc[heatmap_z.columns, 'species']
col_ct = sample_info.loc[heatmap_z.columns, 'cell_type']

species_lut = species_colors
ct_lut = ct_palette

col_colors = pd.DataFrame({
    'Species': col_species.map(species_lut),
    'Cell type': col_ct.map(ct_lut),
})

# Row annotation: direction of change (red = up in Human, blue = down).
UP_COLOR, DOWN_COLOR = '#E41A1C', '#377EB8'
direction = top_sig.loc[peak_order, 'log2FoldChange'].apply(
    lambda x: UP_COLOR if x > 0 else DOWN_COLOR
)
row_colors = pd.Series(direction.values, index=peak_order, name='Direction')

g = sns.clustermap(
    heatmap_z,
    col_colors=col_colors,
    row_colors=row_colors,
    row_cluster=False,  # keep sorted by LFC
    col_cluster=True,   # cluster samples
    cmap='RdBu_r',
    center=0,
    vmin=-3, vmax=3,
    figsize=(18, 12),
    xticklabels=True,
    yticklabels=False,
    cbar_kws={'label': 'Z-score (log2 CPM)'},
    dendrogram_ratio=(0.1, 0.1),
    colors_ratio=0.03,
)
g.ax_heatmap.set_xticklabels(g.ax_heatmap.get_xticklabels(), rotation=90, fontsize=6)
g.fig.suptitle('Top Differentially Accessible Peaks: Human vs Other Primates',
               y=1.02, fontsize=14)

# Add legends
species_handles = [Patch(facecolor=c, label=s) for s, c in species_lut.items()]
dir_handles = [Patch(facecolor=UP_COLOR, label='Up in Human'),
               Patch(facecolor=DOWN_COLOR, label='Down in Human')]

# Figure-level legends accumulate (each Figure.legend call registers its own
# legend), so no add_artist() is needed to keep the first one; re-adding it
# would draw the species legend twice.
g.fig.legend(handles=species_handles, title='Species',
             loc='upper right', bbox_to_anchor=(0.98, 0.95), fontsize=8)
g.fig.legend(handles=dir_handles, title='Direction',
             loc='upper right', bbox_to_anchor=(0.98, 0.75), fontsize=8)

g.fig.savefig(f'{OUT_DIR}/heatmap_human_specific.pdf', bbox_inches='tight', dpi=150)
g.fig.savefig(f'{OUT_DIR}/heatmap_human_specific.png', bbox_inches='tight', dpi=150)
plt.show()
print(f'Saved to {OUT_DIR}/heatmap_human_specific.pdf')

In [None]:
# Sample-to-sample Pearson correlation, computed on the log2-CPM values of the
# top variable peaks selected for the PCA.
top_var_data = log_cpm.loc[top_peaks]
corr = top_var_data.corr(method='pearson')

# Same species / cell-type color strips on both axes (reuses `col_colors`
# built for the differential-peak heatmap).
g2 = sns.clustermap(
    corr,
    row_colors=col_colors,
    col_colors=col_colors,
    cmap='RdYlBu_r',
    vmin=0,
    vmax=1,
    figsize=(16, 14),
    xticklabels=True,
    yticklabels=True,
    cbar_kws={'label': 'Pearson correlation'},
    dendrogram_ratio=0.1,
    colors_ratio=0.02,
)
heat_ax = g2.ax_heatmap
heat_ax.set_xticklabels(heat_ax.get_xticklabels(), rotation=90, fontsize=5)
heat_ax.set_yticklabels(heat_ax.get_yticklabels(), fontsize=5)
g2.fig.suptitle('Sample Correlation Heatmap (top 20K variable peaks)', y=1.01, fontsize=14)

# Vector and raster copies of the same figure.
for ext in ('pdf', 'png'):
    g2.fig.savefig(f'{OUT_DIR}/heatmap_correlation.{ext}', bbox_inches='tight', dpi=150)
plt.show()
print(f'Saved to {OUT_DIR}/heatmap_correlation.pdf')

## 5. Summary

In [None]:
# Final recap: input size, settings, DESeq2 counts, and the files written.
n_input_peaks, n_input_samples = all_df.shape
species_list = ", ".join(SPECIES)

print('Analysis complete')
print('=' * 60)
print(f'Input: {n_input_peaks:,} peaks x {n_input_samples} samples')
print(f'Species: {species_list}')
print(f'Fragment threshold: {FRAG_THRESH:,.0f}')
print()
print('DESeq2 results (Human vs non-Human):')
print(f'  Tested peaks: {len(results):,}')
print('  Significant (padj < 0.05, |LFC| > 1):')
print(f'    Human-gained: {len(sig_up):,}')
print(f'    Human-lost: {len(sig_down):,}')
print()
print(f'Output files in {OUT_DIR}/:')
# List every file in the output directory with a size in MB (above 1e6 bytes)
# or KB (otherwise).
for f in sorted(os.listdir(OUT_DIR)):
    size = os.path.getsize(f'{OUT_DIR}/{f}')
    if size > 1e6:
        print(f'  {f} ({size/1e6:.1f} MB)')
    else:
        print(f'  {f} ({size/1e3:.0f} KB)')