# Explore Cell-Type Specificity in ATAC-seq Peaks

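This notebook visualizes:
1. Differential peak analysis results (tau distribution, peak categories)
2. Motif baseline classifier performance
3. Cross-model specificity analysis (after DNN training), followed by a direct look at the per-cell-type logCPM values (summary statistics and pairwise scatter plots)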

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import json
from pathlib import Path

# Notebook-wide plot styling: seaborn grid theme and a default figure size.
sns.set_style(style='whitegrid')
plt.rcParams.update({'figure.figsize': (10, 6)})

# Cell types covered by this notebook. Each one is paired with a fixed colour
# from seaborn's 'Set2' palette so later cells can look colours up by name.
CELL_TYPES = ['Cardiomyocyte', 'Coronary_EC', 'Fibroblast', 'Macrophage', 'Pericytes']
CT_COLORS = {
    cell_type: colour
    for cell_type, colour in zip(CELL_TYPES, sns.color_palette('Set2', len(CELL_TYPES)))
}

## 1. Peak Annotations

In [None]:
# Read the peak annotation table and print a short summary of it.
annotations = pd.read_csv('data/peak_annotations.csv')

n_peaks = len(annotations)
category_counts = annotations['category'].value_counts()

print(f'Total peaks: {n_peaks:,}')
print('\nCategory distribution:')
print(category_counts)

# Last expression of the cell: rich preview of the first rows.
annotations.head()

In [None]:
# Two-panel overview: tau histogram (left) and peak-category shares (right).
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
ax_hist, ax_pie = axes

# Left panel: histogram of the 'tau' column with two dashed reference lines,
# labelled in the legend as the shared and specific thresholds.
ax_hist.hist(annotations['tau'], bins=50, edgecolor='black', alpha=0.7)
threshold_lines = (
    (0.3, 'red', 'Shared threshold (0.3)'),
    (0.6, 'green', 'Specific threshold (0.6)'),
)
for position, colour, legend_text in threshold_lines:
    ax_hist.axvline(position, color=colour, linestyle='--', label=legend_text)
ax_hist.set(
    xlabel='Tau specificity index',
    ylabel='Number of peaks',
    title='Tau Distribution',
)
ax_hist.legend()

# Right panel: share of peaks per category. value_counts() orders categories
# by frequency, and the three colours below are handed out in that same order,
# so a colour is tied to a category's rank rather than to its name.
cat_counts = annotations['category'].value_counts()
pie_colours = ['#2ecc71', '#e74c3c', '#f39c12']
ax_pie.pie(
    cat_counts.values,
    labels=cat_counts.index,
    autopct='%1.1f%%',
    colors=pie_colours,
)
ax_pie.set_title('Peak Categories')

plt.tight_layout()
plt.show()

In [None]:
# Specific peaks per cell type: bar chart of how many peaks in the 'specific'
# category are assigned to each value of 'specific_celltype'.
specific = annotations[annotations['category'] == 'specific']
ct_counts = specific['specific_celltype'].value_counts()

fig, ax = plt.subplots(figsize=(8, 5))
# Cell types missing from CT_COLORS fall back to gray instead of raising.
bars = ax.bar(ct_counts.index, ct_counts.values,
              color=[CT_COLORS.get(ct, 'gray') for ct in ct_counts.index])
ax.set_ylabel('Number of specific peaks')
ax.set_title('Cell-Type Specific Peaks')

# Write each count just above its bar. The gap is given in points (screen
# units) rather than as a fixed "+100" in data units, so the labels sit
# directly on top of the bars whether the counts are in the tens or in the
# hundreds of thousands.
for bar, val in zip(bars, ct_counts.values):
    ax.annotate(f'{val:,}',
                xy=(bar.get_x() + bar.get_width() / 2, bar.get_height()),
                xytext=(0, 3), textcoords='offset points',
                ha='center', va='bottom')

# Annotations do not take part in autoscaling; add headroom so the label on
# the tallest bar is not clipped by the top of the axes.
ax.margins(y=0.1)

plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.show()

## 2. Motif Baseline Results

In [None]:
# Load motif baseline results and show (a) the per-cell-type binary
# classifier AUC as a bar chart and (b) the multi-class summary. If the
# results file does not exist yet, only a hint on how to produce it is printed.
motif_results_path = Path('results/motif_baseline/results.json')
if motif_results_path.exists():
    # JSON is read as UTF-8 explicitly rather than with the platform default.
    with open(motif_results_path, encoding='utf-8') as f:
        motif_results = json.load(f)

    # Per-cell-type AUC
    binary = motif_results['binary_classifiers']
    fig, ax = plt.subplots(figsize=(8, 5))
    cts = list(binary.keys())
    aucs = [binary[ct]['auc'] for ct in cts]
    bars = ax.bar(cts, aucs, color=[CT_COLORS.get(ct, 'gray') for ct in cts])
    # Reference lines: 0.65 is labelled as the threshold, 0.5 as chance level.
    ax.axhline(0.65, color='red', linestyle='--', label='Threshold (0.65)')
    ax.axhline(0.5, color='gray', linestyle=':', label='Random')
    ax.set_ylabel('AUC')
    ax.set_title('Motif Baseline: Binary Classification AUC')
    ax.set_ylim(0, 1)
    ax.legend()
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.show()

    # Multi-class summary. Accuracy and top motifs are reported independently:
    # the accuracy is shown whenever it is present, and a results file that
    # lists top motifs without an accuracy no longer raises KeyError.
    multi = motif_results.get('multiclass', {})
    if 'accuracy' in multi:
        print(f"Multi-class accuracy: {multi['accuracy']:.3f}")
    if 'top_motifs' in multi:
        print(f"\nTop 15 motifs:")
        # Each entry is unpacked as a (name, importance) pair.
        for name, imp in multi['top_motifs'][:15]:
            print(f'  {name}: {imp:.4f}')
else:
    print('Motif baseline results not found. Run scripts/03_motif_baseline.py first.')

## 3. Cross-Model Specificity (after DNN training)

In [None]:
# Cross-model specificity results: per-model metrics, correlation heatmap and
# specificity AUC. Each part is shown only if its key is present in the JSON.
spec_results_path = Path('results/specificity/results.json')
if not spec_results_path.exists():
    print('Specificity results not found. Run scripts/07_evaluate_specificity.py first.')
else:
    with open(spec_results_path) as f:
        spec_results = json.load(f)

    # Per-model metrics, in CELL_TYPES order; cell types without an entry
    # in the results are skipped.
    print('Per-model test set metrics:')
    for ct in CELL_TYPES:
        if ct not in spec_results:
            continue
        r = spec_results[ct]
        profile_r, count_r, jsd = r['profile_r'], r['count_r'], r['jsd']
        print(f"  {ct}: profile_r={profile_r:.3f}, count_r={count_r:.3f}, jsd={jsd:.4f}")

    # Cross-model correlation heatmap
    corr_data = spec_results.get('cross_model_correlation')
    if corr_data is not None:
        corr_matrix = np.array(corr_data['matrix'])
        ct_labels = corr_data['cell_types']

        fig, ax = plt.subplots(figsize=(8, 6))
        sns.heatmap(
            corr_matrix,
            xticklabels=ct_labels,
            yticklabels=ct_labels,
            annot=True,
            fmt='.3f',
            cmap='RdYlBu_r',
            center=0.9,
            vmin=0.5,
            vmax=1.0,
            ax=ax,
        )
        ax.set_title('Cross-Model Prediction Correlation')
        plt.tight_layout()
        plt.show()

        mean_off_diag = corr_data['mean_off_diagonal']
        print(f"\nMean off-diagonal correlation: {mean_off_diag:.3f}")

    # Specificity AUC, one line per entry
    if 'specificity_auc' in spec_results:
        spec_auc = spec_results['specificity_auc']
        print("\nSpecificity AUC:")
        for ct, auc in spec_auc.items():
            print(f"  {ct}: {auc:.3f}")

In [None]:
# Load the logCPM table (first CSV column becomes the index) and replace '.'
# with '_' in the column names. Presumably this makes the columns match the
# names in CELL_TYPES (e.g. 'Coronary_EC') -- verify against the CSV header.
logcpm = pd.read_csv(
    'data/YoungSed_DownSample_Peak_logCPM_CellType.csv', index_col=0
).rename(columns=lambda col: col.replace('.', '_'))

column_means = logcpm.mean()
pairwise_corr = logcpm.corr().round(3)

print(f'Shape: {logcpm.shape}')
print('\nColumn means:')
print(column_means)
print('\nCorrelation between cell types:')
print(pairwise_corr)

In [None]:
# Pairwise scatter: check for visual separation between cell types.
# Draws a 2x3 grid; each panel plots one cell type's logCPM against another's
# for a random subsample of rows, with the Pearson r in the panel title.
fig, axes = plt.subplots(2, 3, figsize=(15, 10))

# Index pairs into CELL_TYPES, one per panel.
# NOTE(review): the last pair is (2, 4) while every other pair stays within
# indices 0-3 -- confirm this is intentional and not a typo for (2, 3).
pairs = [(0,1), (0,2), (0,3), (1,2), (1,3), (2,4)]

# Subsample for speed. The rows are drawn once from a seeded generator so that
# (a) Restart & Run All reproduces the same figure and (b) every panel shows
# the same set of rows, which keeps the panels directly comparable.
rng = np.random.default_rng(0)
idx = rng.choice(len(logcpm), min(5000, len(logcpm)), replace=False)

for ax, (i, j) in zip(axes.flat, pairs):
    ct_i, ct_j = CELL_TYPES[i], CELL_TYPES[j]
    ax.scatter(logcpm[ct_i].iloc[idx], logcpm[ct_j].iloc[idx],
              alpha=0.1, s=1, rasterized=True)
    ax.set_xlabel(ct_i)
    ax.set_ylabel(ct_j)
    # The correlation is computed on all rows, not only the plotted subsample.
    r = logcpm[[ct_i, ct_j]].corr().iloc[0,1]
    ax.set_title(f'r = {r:.3f}')

plt.suptitle('Cell-Type logCPM Pairwise Scatter', fontsize=14)
plt.tight_layout()
plt.show()