# 1.3 Jaccard Similarity Analysis and Filtering

Calculate Jaccard similarity between GO terms and filter overlapping terms.

## Inputs
- `output/intermediate/hetio_bppg_2016_stable.csv` (2016 stable genes)
- `output/intermediate/upd_go_bp_2024_added.csv` (2024 added genes)
- `output/intermediate/dataset1_all_growth.csv`
- `output/intermediate/dataset2_parents.csv`
- `input/hetionet_neo4j_go_ids_nr.csv`
- `input/hetionet_neo4j_genes_ids_nr.csv`

## Outputs
- `output/intermediate/dataset1_filtered.csv`
- `output/intermediate/dataset2_filtered.csv`
- `output/intermediate/hetio_bppg_dataset1_filtered.csv` (2016 stable, Jaccard filtered)
- `output/intermediate/hetio_bppg_dataset2_filtered.csv` (2016 stable, Jaccard filtered)
- `output/intermediate/hetio_bppg_dataset1_2024_filtered.csv` (2024 added, Jaccard filtered)
- `output/intermediate/hetio_bppg_dataset2_2024_filtered.csv` (2024 added, Jaccard filtered)
- `output/jaccard_similarity/` (cached matrices)

## Description
This notebook calculates Jaccard similarity between GO terms based on
gene overlap. Terms with Jaccard > 0.1 are clustered, and representatives
are selected based on greatest percent change in gene count.

**Gene Classification:**
- 2016 datasets contain only stable genes (present in both 2016 and 2024)
- 2024 datasets contain only added genes (only in 2024, not in 2016)

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import sys
import os
from pathlib import Path

# Anchor every path below to the repository root so the notebook behaves the
# same whether Jupyter was launched from the repo root or from notebooks/.
launch_dir = Path.cwd()
repo_root = Path("..").resolve() if launch_dir.name == "notebooks" else launch_dir

# Put the repo root at the front of sys.path so the local `src` package
# is importable.
sys.path.insert(0, str(repo_root))
# NOTE(review): a later cell in this notebook defines its own
# load_or_calculate_jaccard, which rebinds this imported name once it runs.
from src.similarity import load_or_calculate_jaccard
from src.filtering import filter_overlapping_go_terms

# Directories this notebook writes into; created up front if absent.
for output_subdir in ('output/jaccard_similarity', 'output/intermediate'):
    (repo_root / output_subdir).mkdir(parents=True, exist_ok=True)

print(f"Repo root: {repo_root}")

In [None]:
# Read the tables written by the earlier notebooks in this pipeline.
def _read_repo_csv(relative_path):
    """Read the CSV stored at `relative_path` underneath the repository root."""
    return pd.read_csv(repo_root / relative_path)


hetio_BPpG_2016_stable = _read_repo_csv('output/intermediate/hetio_bppg_2016_stable.csv')
upd_go_bp_2024_added = _read_repo_csv('output/intermediate/upd_go_bp_2024_added.csv')
dataset1_all_growth = _read_repo_csv('output/intermediate/dataset1_all_growth.csv')
dataset2_parents = _read_repo_csv('output/intermediate/dataset2_parents.csv')

# Row counts give a quick sanity check that each table was actually loaded.
print('Loaded classified gene datasets:')
print(f'  hetio_BPpG_2016_stable: {hetio_BPpG_2016_stable.shape[0]} rows (stable genes)')
print(f'  upd_go_bp_2024_added: {upd_go_bp_2024_added.shape[0]} rows (added genes)')
print(f'  dataset1_all_growth: {dataset1_all_growth.shape[0]} GO terms')
print(f'  dataset2_parents: {dataset2_parents.shape[0]} GO terms')

### Map to Neo4j Node IDs

The Docker API requires Neo4j internal node IDs rather than biological identifiers (GO IDs, Entrez Gene IDs). This section loads the complete mapping files and filters them to only the IDs present in our datasets.

In [None]:
# Complete (unfiltered) Neo4j lookup tables for GO terms and genes.
hetionet_neo4j_go_ids_complete = pd.read_csv(repo_root / 'input/hetionet_neo4j_go_ids_nr.csv')
hetionet_neo4j_genes_complete = pd.read_csv(repo_root / 'input/hetionet_neo4j_genes_ids_nr.csv')

print("Complete Neo4j ID Mappings")
print("=" * 60)
print(f"Total GO term mappings: {len(hetionet_neo4j_go_ids_complete)}")
print(f"Total gene mappings: {len(hetionet_neo4j_genes_complete)}")


def _genes_annotated_to(pairs, go_terms):
    """Return the set of `entrez_gene_id` values in `pairs` whose `go_id` is in `go_terms`."""
    in_scope = pairs['go_id'].isin(go_terms)
    return set(pairs.loc[in_scope, 'entrez_gene_id'])


# GO terms that define each dataset.
go_terms_dataset1 = set(dataset1_all_growth['go_id'])
go_terms_dataset2 = set(dataset2_parents['go_id'])

# Genes annotated to those terms, split by table (2016 stable vs 2024 added).
genes_2016_stable_dataset1 = _genes_annotated_to(hetio_BPpG_2016_stable, go_terms_dataset1)
genes_2016_stable_dataset2 = _genes_annotated_to(hetio_BPpG_2016_stable, go_terms_dataset2)
genes_2024_added_dataset1 = _genes_annotated_to(upd_go_bp_2024_added, go_terms_dataset1)
genes_2024_added_dataset2 = _genes_annotated_to(upd_go_bp_2024_added, go_terms_dataset2)


def _print_dataset_counts(heading, go_terms, stable_genes, added_genes):
    """Print term and gene counts for one dataset under the given heading line."""
    print(heading)
    print(f"  GO terms: {len(go_terms)}")
    print(f"  Genes (2016 stable): {len(stable_genes)}")
    print(f"  Genes (2024 added): {len(added_genes)}")
    print(f"  Total unique genes: {len(stable_genes | added_genes)}")


_print_dataset_counts(
    f"\nDataset 1 (All Growth Terms):",
    go_terms_dataset1,
    genes_2016_stable_dataset1,
    genes_2024_added_dataset1,
)
_print_dataset_counts(
    f"\nDataset 2 (Parent Terms):",
    go_terms_dataset2,
    genes_2016_stable_dataset2,
    genes_2024_added_dataset2,
)

In [None]:
def _mapping_rows_for(mapping, id_column, wanted_ids):
    """Return a copy of the rows of `mapping` whose `id_column` value is in `wanted_ids`."""
    return mapping[mapping[id_column].isin(wanted_ids)].copy()


def _print_mapping_report(heading, go_terms, genes, go_mapping, gene_mapping):
    """Print mapping sizes and how many requested IDs have no mapping row.

    The "missing" figures are set differences between the requested IDs and
    the IDs present in the filtered mapping. Subtracting row counts instead
    would be wrong (possibly negative) if a mapping table ever held more than
    one row for the same ID.
    """
    missing_go_terms = go_terms - set(go_mapping['go_id'])
    missing_genes = genes - set(gene_mapping['entrez_gene_id'])

    print(heading)
    print("=" * 60)
    print(f"GO term mappings: {len(go_mapping)}")
    print(f"Gene mappings: {len(gene_mapping)}")
    print(f"\nMissing mappings:")
    print(f"  GO terms without Neo4j ID: {len(missing_go_terms)}")
    print(f"  Genes without Neo4j ID: {len(missing_genes)}")


# All genes relevant to each dataset: 2016 stable plus 2024 added.
genes_dataset1 = genes_2016_stable_dataset1 | genes_2024_added_dataset1
genes_dataset2 = genes_2016_stable_dataset2 | genes_2024_added_dataset2

# Restrict the complete Neo4j lookup tables to the IDs each dataset uses.
neo4j_go_mapping_dataset1 = _mapping_rows_for(
    hetionet_neo4j_go_ids_complete, 'go_id', go_terms_dataset1
)
neo4j_gene_mapping_dataset1 = _mapping_rows_for(
    hetionet_neo4j_genes_complete, 'entrez_gene_id', genes_dataset1
)
neo4j_go_mapping_dataset2 = _mapping_rows_for(
    hetionet_neo4j_go_ids_complete, 'go_id', go_terms_dataset2
)
neo4j_gene_mapping_dataset2 = _mapping_rows_for(
    hetionet_neo4j_genes_complete, 'entrez_gene_id', genes_dataset2
)

_print_mapping_report(
    "Dataset 1 (All Growth Terms) - Neo4j Mappings",
    go_terms_dataset1,
    genes_dataset1,
    neo4j_go_mapping_dataset1,
    neo4j_gene_mapping_dataset1,
)
_print_mapping_report(
    f"\n\nDataset 2 (Parent Terms) - Neo4j Mappings",
    go_terms_dataset2,
    genes_dataset2,
    neo4j_go_mapping_dataset2,
    neo4j_gene_mapping_dataset2,
)

print(f"\n\nSample GO Mappings (Dataset 1):")
display(neo4j_go_mapping_dataset1.head())

print(f"\nSample Gene Mappings (Dataset 1):")
display(neo4j_gene_mapping_dataset1.head())

In [None]:
def _pairs_with_neo4j_ids(pairs, go_terms, go_mapping, gene_mapping):
    """Restrict `pairs` to `go_terms`, then attach Neo4j IDs for term and gene.

    Both joins are inner joins, so a pair is kept only when its `go_id` has a
    row in `go_mapping` and its `entrez_gene_id` has a row in `gene_mapping`.
    """
    scoped_pairs = pairs[pairs['go_id'].isin(go_terms)].copy()
    with_go_ids = scoped_pairs.merge(go_mapping, on='go_id', how='inner')
    return with_go_ids.merge(gene_mapping, on='entrez_gene_id', how='inner')


# Dataset 1: 2016 stable pairs and 2024 added pairs, each with Neo4j IDs attached
hetio_BPpG_dataset1 = _pairs_with_neo4j_ids(
    hetio_BPpG_2016_stable,
    go_terms_dataset1,
    neo4j_go_mapping_dataset1,
    neo4j_gene_mapping_dataset1,
)
upd_go_bp_2024_dataset1 = _pairs_with_neo4j_ids(
    upd_go_bp_2024_added,
    go_terms_dataset1,
    neo4j_go_mapping_dataset1,
    neo4j_gene_mapping_dataset1,
)

print("Dataset 1 (All Growth Terms)")
print("=" * 60)
print(f"2016 BP-Gene pairs (stable): {len(hetio_BPpG_dataset1)}")
print(f"2024 BP-Gene pairs (added): {len(upd_go_bp_2024_dataset1)}")
print(f"2016 unique GO terms: {hetio_BPpG_dataset1['go_id'].nunique()}")
print(f"2024 unique GO terms: {upd_go_bp_2024_dataset1['go_id'].nunique()}")

# Dataset 2: same procedure with the parent-term selection and its mappings
hetio_BPpG_dataset2 = _pairs_with_neo4j_ids(
    hetio_BPpG_2016_stable,
    go_terms_dataset2,
    neo4j_go_mapping_dataset2,
    neo4j_gene_mapping_dataset2,
)
upd_go_bp_2024_dataset2 = _pairs_with_neo4j_ids(
    upd_go_bp_2024_added,
    go_terms_dataset2,
    neo4j_go_mapping_dataset2,
    neo4j_gene_mapping_dataset2,
)

print("\nDataset 2 (Parent Terms)")
print("=" * 60)
print(f"2016 BP-Gene pairs (stable): {len(hetio_BPpG_dataset2)}")
print(f"2024 BP-Gene pairs (added): {len(upd_go_bp_2024_dataset2)}")
print(f"2016 unique GO terms: {hetio_BPpG_dataset2['go_id'].nunique()}")
print(f"2024 unique GO terms: {upd_go_bp_2024_dataset2['go_id'].nunique()}")

### Create GO-Gene pair identifiers for both datasets

Create unique identifiers combining GO term and gene for tracking pairs across 2016 and 2024.

In [None]:
# Dataset 1: add a composite '<go_id>|<neo4j_target_id>' key to both pair tables.
# The column is assigned on the existing frames, so they are updated in place.
for pair_table in (hetio_BPpG_dataset1, upd_go_bp_2024_dataset1):
    neo4j_target_text = pair_table['neo4j_target_id'].astype(str)
    pair_table['go_id_gene'] = pair_table['go_id'] + '|' + neo4j_target_text

n_pairs_2016 = len(hetio_BPpG_dataset1)
n_pairs_2024 = len(upd_go_bp_2024_dataset1)

print("Dataset 1 (All Growth Terms)")
print("=" * 60)
print(f"2016 pairs (stable genes): {n_pairs_2016}")
print(f"2024 pairs (added genes): {n_pairs_2024}")
print(f"\nNote: These gene sets are disjoint by design")
print(f"  - 2016 contains only genes present in both 2016 and 2024")
print(f"  - 2024 contains only genes added in 2024 (not in 2016)")

In [None]:
# Dataset 2: add a composite '<go_id>|<neo4j_target_id>' key to both pair tables.
# The column is assigned on the existing frames, so they are updated in place.
for pair_table in (hetio_BPpG_dataset2, upd_go_bp_2024_dataset2):
    neo4j_target_text = pair_table['neo4j_target_id'].astype(str)
    pair_table['go_id_gene'] = pair_table['go_id'] + '|' + neo4j_target_text

n_pairs_2016 = len(hetio_BPpG_dataset2)
n_pairs_2024 = len(upd_go_bp_2024_dataset2)

print("Dataset 2 (Parent Terms)")
print("=" * 60)
print(f"2016 pairs (stable genes): {n_pairs_2016}")
print(f"2024 pairs (added genes): {n_pairs_2024}")
print(f"\nNote: These gene sets are disjoint by design")
print(f"  - 2016 contains only genes present in both 2016 and 2024")
print(f"  - 2024 contains only genes added in 2024 (not in 2016)")

In [None]:
import os
import matplotlib.pyplot as plt

# Figures produced by this notebook are written under output/images.
output_dir = repo_root / "output/images"
output_dir.mkdir(parents=True, exist_ok=True)

# Rank the GO terms of each dataset by their 2016 gene count (ascending) so
# the x axis of each panel is a rank, not a GO identifier.
dataset1_sorted = (
    dataset1_all_growth
    .sort_values(by="no_of_genes_in_hetio_GO_2016", ascending=True)
    .reset_index(drop=True)
)
dataset2_sorted = (
    dataset2_parents
    .sort_values(by="no_of_genes_in_hetio_GO_2016", ascending=True)
    .reset_index(drop=True)
)


def _draw_gene_count_panel(ax, ranked_terms, panel_title):
    """Plot 2016 and 2024 gene counts per GO term of `ranked_terms` on `ax`."""
    positions = range(len(ranked_terms))
    ax.plot(positions, ranked_terms["no_of_genes_in_hetio_GO_2016"],
            marker="o", linewidth=0.5, markersize=1, label="2016", color="steelblue")
    ax.plot(positions, ranked_terms["no_of_genes_in_GO_2024"],
            marker="s", linewidth=0.5, markersize=1, label="2024", color="darkorange")
    ax.set_xlabel(f"GO Terms (n = {len(ranked_terms)})", fontsize=11)
    ax.set_ylabel("Number of Genes", fontsize=11)
    ax.set_title(panel_title, fontsize=12)
    ax.legend(title="Year", fontsize=10, title_fontsize=10)
    ax.grid(False)
    # Hide the top and right frame lines.
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)


# One panel per dataset, side by side
fig, axes = plt.subplots(1, 2, figsize=(14, 5), dpi=150)
ax1, ax2 = axes[0], axes[1]
_draw_gene_count_panel(ax1, dataset1_sorted, "Dataset 1: All Growth Terms")
_draw_gene_count_panel(ax2, dataset2_sorted, "Dataset 2: Parent Terms")

plt.tight_layout()

# Write the figure as both PDF and JPEG before showing it
plt.savefig(output_dir / "genes_per_go_2016_vs_2024_both_datasets.pdf", format="pdf", dpi=300, bbox_inches='tight')
plt.savefig(output_dir / "genes_per_go_2016_vs_2024_both_datasets.jpeg", format="jpeg", dpi=300, bbox_inches='tight')
plt.show()

print(f"Figure saved to {output_dir}/genes_per_go_2016_vs_2024_both_datasets.pdf")

In [None]:
from sklearn.metrics import pairwise_distances
from scipy.spatial.distance import pdist, squareform
import pickle
from tqdm import tqdm

def calculate_jaccard_similarity_optimized(go_term_genes_dict):
    """
    Calculate pairwise Jaccard similarity between GO terms using scipy.

    Each GO term is encoded as a boolean row over all genes seen in the
    input. `scipy.spatial.distance.pdist` with the 'jaccard' metric computes
    the pairwise distances in one call, and similarity is `1 - distance`.

    Parameters
    ----------
    go_term_genes_dict : dict
        Dictionary mapping GO term IDs to sets of gene IDs

    Returns
    -------
    pd.DataFrame
        Symmetric matrix of Jaccard similarity scores between GO terms,
        indexed and labelled by the dictionary keys in their iteration
        order. The diagonal is 1.0. An empty input yields an empty (0 x 0)
        DataFrame.
    """
    go_terms = list(go_term_genes_dict.keys())
    n_terms = len(go_terms)

    # With no terms there is nothing to compare. Without this guard,
    # squareform() turns the empty condensed vector into a 1x1 matrix that
    # cannot be labelled with zero terms, and the DataFrame constructor raises.
    if n_terms == 0:
        return pd.DataFrame(index=go_terms, columns=go_terms, dtype=float)

    print(f"Creating binary feature matrix for {n_terms} GO terms...")

    # Get all unique genes across all GO terms; sorting fixes the column order
    all_genes = sorted(set().union(*go_term_genes_dict.values()))
    n_genes = len(all_genes)
    gene_to_idx = {gene: idx for idx, gene in enumerate(all_genes)}

    print(f"Total unique genes: {n_genes}")

    # Create binary feature matrix: rows=GO terms, cols=genes
    # NOTE: this is a dense array; a sparse representation would be needed
    # if n_terms * n_genes ever became too large for memory.
    feature_matrix = np.zeros((n_terms, n_genes), dtype=np.bool_)

    print("Populating feature matrix...")
    for i, term in enumerate(tqdm(go_terms, desc="Processing GO terms")):
        gene_indices = [gene_to_idx[gene] for gene in go_term_genes_dict[term]]
        feature_matrix[i, gene_indices] = True

    print("Calculating pairwise Jaccard distances...")
    # pdist returns the condensed (upper-triangle) distance vector
    condensed_distances = pdist(feature_matrix, metric='jaccard')

    print("Converting to square matrix...")
    # squareform expands to a full symmetric matrix with a zero diagonal
    distance_matrix = squareform(condensed_distances)

    # Convert distances to similarities (zero-distance diagonal becomes 1.0)
    similarities = 1 - distance_matrix

    return pd.DataFrame(similarities, index=go_terms, columns=go_terms)

def load_or_calculate_jaccard(go_term_genes_dict, cache_file):
    """
    Load Jaccard similarity matrix from cache or calculate if not cached.

    If `cache_file` exists it is read and returned as-is; it is NOT checked
    against `go_term_genes_dict`, so a stale cache must be deleted by hand to
    force recomputation. Otherwise the matrix is computed with
    `calculate_jaccard_similarity_optimized`, written to `cache_file` as CSV,
    and returned.

    Parameters
    ----------
    go_term_genes_dict : dict
        Dictionary mapping GO term IDs to sets of gene IDs
    cache_file : str
        Path to cache file. A bare file name (no directory part) is allowed
        and refers to the current working directory.

    Returns
    -------
    pd.DataFrame
        Jaccard similarity matrix
    """
    if os.path.exists(cache_file):
        print(f"Loading cached Jaccard matrix from {cache_file}")
        return pd.read_csv(cache_file, index_col=0)

    print(f"Cache not found. Computing Jaccard similarity matrix...")
    similarity_matrix = calculate_jaccard_similarity_optimized(go_term_genes_dict)

    # Cache the result. os.path.dirname() is '' for a bare file name and
    # os.makedirs('') raises, so only create a directory when there is one.
    cache_dir = os.path.dirname(cache_file)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)
    print(f"Saving to cache: {cache_file}")
    similarity_matrix.to_csv(cache_file)
    print(f"Cache saved successfully")

    return similarity_matrix


print("Optimized Jaccard similarity functions defined")

In [None]:
# Prepare gene sets for each GO term in both datasets

# Define cache directory first
jaccard_cache_dir = repo_root / "output/jaccard_similarity"
jaccard_cache_dir.mkdir(parents=True, exist_ok=True)


def _gene_sets_by_term(term_ids, pairs):
    """Map each GO term in `term_ids` to the set of genes it has in `pairs`.

    The pair table is grouped once instead of being scanned in full for every
    GO term. Terms with no rows in `pairs` are omitted, and keys follow the
    order of `term_ids`.
    """
    genes_by_term = {
        go_id: set(genes)
        for go_id, genes in pairs.groupby('go_id', sort=False)['entrez_gene_id']
    }
    return {
        go_id: genes_by_term[go_id]
        for go_id in term_ids
        if go_id in genes_by_term
    }


# Dataset 1: 2016 and 2024 data
dataset1_2016_go_genes = _gene_sets_by_term(dataset1_all_growth['go_id'], hetio_BPpG_dataset1)
dataset1_2024_go_genes = _gene_sets_by_term(dataset1_all_growth['go_id'], upd_go_bp_2024_dataset1)

# Dataset 2: 2016 and 2024 data
dataset2_2016_go_genes = _gene_sets_by_term(dataset2_parents['go_id'], hetio_BPpG_dataset2)
dataset2_2024_go_genes = _gene_sets_by_term(dataset2_parents['go_id'], upd_go_bp_2024_dataset2)

print("Gene sets prepared for Jaccard similarity calculation")
print("=" * 70)
print("Dataset 1 (All Growth Terms):")
print(f"  2016: {len(dataset1_2016_go_genes)} GO terms")
print(f"  2024: {len(dataset1_2024_go_genes)} GO terms")
print(f"\nDataset 2 (Parent Terms):")
print(f"  2016: {len(dataset2_2016_go_genes)} GO terms")
print(f"  2024: {len(dataset2_2024_go_genes)} GO terms")
print("=" * 70)

In [None]:
# Report which of the four Jaccard matrices already have a CSV cache on disk
jaccard_cache_dir = repo_root / "output/jaccard_similarity"
cache_files = {
    'dataset1_2016': jaccard_cache_dir / "jaccard_similarity_dataset1_2016.csv",
    'dataset1_2024': jaccard_cache_dir / "jaccard_similarity_dataset1_2024.csv",
    'dataset2_2016': jaccard_cache_dir / "jaccard_similarity_dataset2_2016.csv",
    'dataset2_2024': jaccard_cache_dir / "jaccard_similarity_dataset2_2024.csv"
}

print("Checking for cached Jaccard similarity matrices...")
print("=" * 70)

# Stays True only if every expected cache file is present.
all_cached = True
for cache_name, cache_path in cache_files.items():
    if not cache_path.exists():
        print(f"NOT FOUND: {cache_name:20s} - will need to compute")
        all_cached = False
        continue
    size_kb = cache_path.stat().st_size / 1024
    print(f"FOUND: {cache_name:20s} ({size_kb:.1f} KB)")

if not all_cached:
    print("\nSome cache files missing. First computation may take 30-60 seconds.")
    print("Subsequent runs will be instant once cached.")
else:
    print("\nAll cache files exist. Loading should be instant.")

print(f"\nTo force recalculation, delete the cache directory:")
print(f"  rm -rf {jaccard_cache_dir}")
print("=" * 70)

In [None]:
# Obtain the four Jaccard matrices (from cache when present, computed
# otherwise), then print summary statistics for each.

banner = "=" * 70

print("\n" + "=" * 70)
print("JACCARD SIMILARITY CALCULATION")
print(banner)

# Dataset 1, 2016 gene sets
print("\n[1/4] Dataset 1 - 2016...")
cache_path_d1_2016 = str(jaccard_cache_dir / "jaccard_similarity_dataset1_2016.csv")
jaccard_dataset1_2016 = load_or_calculate_jaccard(dataset1_2016_go_genes, cache_path_d1_2016)

# Dataset 1, 2024 gene sets
print("\n[2/4] Dataset 1 - 2024...")
cache_path_d1_2024 = str(jaccard_cache_dir / "jaccard_similarity_dataset1_2024.csv")
jaccard_dataset1_2024 = load_or_calculate_jaccard(dataset1_2024_go_genes, cache_path_d1_2024)

# Dataset 2, 2016 gene sets
print("\n[3/4] Dataset 2 - 2016...")
cache_path_d2_2016 = str(jaccard_cache_dir / "jaccard_similarity_dataset2_2016.csv")
jaccard_dataset2_2016 = load_or_calculate_jaccard(dataset2_2016_go_genes, cache_path_d2_2016)

# Dataset 2, 2024 gene sets
print("\n[4/4] Dataset 2 - 2024...")
cache_path_d2_2024 = str(jaccard_cache_dir / "jaccard_similarity_dataset2_2024.csv")
jaccard_dataset2_2024 = load_or_calculate_jaccard(dataset2_2024_go_genes, cache_path_d2_2024)

print("\n" + "=" * 70)
print("ALL JACCARD MATRICES READY")
print(banner)


def _print_offdiagonal_summary(heading, jaccard_matrix):
    """Print mean, median and max of the strictly-upper-triangle similarities."""
    cells = jaccard_matrix.values
    # k=1 skips the diagonal, so self-similarity does not enter the statistics.
    upper = cells[np.triu_indices_from(cells, k=1)]
    print(heading)
    print(f"  Mean: {upper.mean():.4f}")
    print(f"  Median: {np.median(upper):.4f}")
    print(f"  Max (off-diagonal): {upper.max():.4f}")


# Summary statistics
_print_offdiagonal_summary("\nDataset 1 - 2016 Jaccard Similarity:", jaccard_dataset1_2016)
_print_offdiagonal_summary("\nDataset 1 - 2024 Jaccard Similarity:", jaccard_dataset1_2024)
_print_offdiagonal_summary("\nDataset 2 - 2016 Jaccard Similarity:", jaccard_dataset2_2016)
_print_offdiagonal_summary("\nDataset 2 - 2024 Jaccard Similarity:", jaccard_dataset2_2024)

In [None]:
import seaborn as sns
from scipy.cluster.hierarchy import linkage, dendrogram
from scipy.spatial.distance import squareform

# OPTIONAL: Skip heatmap generation if too slow (large matrices can take minutes)
# Set SKIP_HEATMAPS to False to generate the clustered heatmaps below.
SKIP_HEATMAPS = True

if SKIP_HEATMAPS:
    print("Heatmap generation skipped (set SKIP_HEATMAPS=False to generate)")
else:
    print("Generating clustered heatmaps (this may take 1-2 minutes for large datasets)...")

    # Create 2x2 heatmap figure with hierarchical clustering
    fig, axes = plt.subplots(2, 2, figsize=(16, 14), dpi=150)

    def create_diagonal_mask(size):
        """Return a (size x size) boolean mask that is True only on the diagonal."""
        mask = np.zeros((size, size), dtype=bool)
        np.fill_diagonal(mask, True)
        return mask

    def plot_clustered_heatmap(similarity_matrix, ax, title):
        """
        Plot heatmap with hierarchical clustering and no diagonal.

        Parameters
        ----------
        similarity_matrix : pd.DataFrame
            Jaccard similarity matrix
        ax : matplotlib axes
            Axes to plot on
        title : str
            Plot title
        """
        print(f"  Clustering {len(similarity_matrix)} GO terms for: {title}")

        # Convert similarity to distance for clustering
        distance_matrix = 1 - similarity_matrix.values

        # Average-linkage clustering on the condensed distances. checks=False
        # skips squareform's symmetry and zero-diagonal validation.
        condensed_dist = squareform(distance_matrix, checks=False)
        linkage_matrix = linkage(condensed_dist, method='average')

        # Leaf order of the dendrogram gives the row/column ordering
        order = dendrogram(linkage_matrix, no_plot=True)['leaves']

        # Reorder rows and columns identically so the matrix stays symmetric
        clustered_matrix = similarity_matrix.iloc[order, order]

        # Mask the diagonal (self-similarity) so it is not drawn
        mask = create_diagonal_mask(len(clustered_matrix))

        # Plot heatmap with adjusted scale for sparse data
        sns.heatmap(clustered_matrix, cmap="YlOrRd", vmin=0, vmax=0.8,
                    mask=mask, square=True, cbar_kws={'label': 'Jaccard Similarity'},
                    ax=ax, xticklabels=False, yticklabels=False)
        ax.set_title(title, fontsize=12)
        ax.set_xlabel("GO Terms (clustered)", fontsize=11)
        ax.set_ylabel("GO Terms (clustered)", fontsize=11)

    # One panel per (dataset, year): rows are datasets, columns are years.
    heatmap_panels = (
        (jaccard_dataset1_2016, axes[0, 0],
         f"Dataset 1 (2016): All Growth Terms\n(n = {len(dataset1_2016_go_genes)} GO terms)"),
        (jaccard_dataset1_2024, axes[0, 1],
         f"Dataset 1 (2024): All Growth Terms\n(n = {len(dataset1_2024_go_genes)} GO terms)"),
        (jaccard_dataset2_2016, axes[1, 0],
         f"Dataset 2 (2016): Parent Terms\n(n = {len(dataset2_2016_go_genes)} GO terms)"),
        (jaccard_dataset2_2024, axes[1, 1],
         f"Dataset 2 (2024): Parent Terms\n(n = {len(dataset2_2024_go_genes)} GO terms)"),
    )
    for panel_matrix, panel_ax, panel_title in heatmap_panels:
        plot_clustered_heatmap(panel_matrix, panel_ax, panel_title)

    plt.tight_layout()

    # Save figure (output_dir is defined by the earlier plotting cell)
    plt.savefig(os.path.join(output_dir, "jaccard_similarity_heatmaps_clustered_both_datasets.pdf"),
                format="pdf", dpi=300, bbox_inches='tight')
    plt.savefig(os.path.join(output_dir, "jaccard_similarity_heatmaps_clustered_both_datasets.jpeg"),
                format="jpeg", dpi=300, bbox_inches='tight')
    plt.show()

    print(f"\nClustered heatmaps saved to {output_dir}/jaccard_similarity_heatmaps_clustered_both_datasets.pdf")

In [None]:
# Histogram of off-diagonal Jaccard similarities, one panel per dataset and year

fig, axes = plt.subplots(2, 2, figsize=(14, 10), dpi=150)


def get_upper_triangle(matrix):
    """Return the strictly-upper-triangle entries of `matrix` as a 1-D array."""
    cells = matrix.values
    return cells[np.triu_indices_from(cells, k=1)]


def _draw_similarity_histogram(ax, scores, panel_title, bar_color=None):
    """Draw a 50-bin histogram of `scores` on `ax` with mean and median marker lines."""
    hist_style = dict(bins=50, edgecolor='black', alpha=0.7)
    if bar_color is not None:
        # Without an explicit colour the bars keep matplotlib's default.
        hist_style['color'] = bar_color
    ax.hist(scores, **hist_style)

    mean_score = scores.mean()
    median_score = np.median(scores)
    ax.axvline(x=mean_score, color='red', linestyle='--', linewidth=2,
               label=f'Mean = {mean_score:.3f}')
    ax.axvline(x=median_score, color='blue', linestyle='--', linewidth=2,
               label=f'Median = {median_score:.3f}')

    ax.set_xlabel('Jaccard Similarity', fontsize=11)
    ax.set_ylabel('Frequency', fontsize=11)
    ax.set_title(panel_title, fontsize=12)
    ax.legend()
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)


# Off-diagonal similarity values for each matrix
values_d1_2016 = get_upper_triangle(jaccard_dataset1_2016)
values_d1_2024 = get_upper_triangle(jaccard_dataset1_2024)
values_d2_2016 = get_upper_triangle(jaccard_dataset2_2016)
values_d2_2024 = get_upper_triangle(jaccard_dataset2_2024)

# Top row: Dataset 1; bottom row: Dataset 2. Left column: 2016; right: 2024.
ax1, ax2 = axes[0, 0], axes[0, 1]
ax3, ax4 = axes[1, 0], axes[1, 1]
_draw_similarity_histogram(ax1, values_d1_2016, 'Dataset 1 (2016): All Growth Terms')
_draw_similarity_histogram(ax2, values_d1_2024, 'Dataset 1 (2024): All Growth Terms', bar_color='orange')
_draw_similarity_histogram(ax3, values_d2_2016, 'Dataset 2 (2016): Parent Terms', bar_color='green')
_draw_similarity_histogram(ax4, values_d2_2024, 'Dataset 2 (2024): Parent Terms', bar_color='purple')

plt.tight_layout()

# Write the figure as both PDF and JPEG before showing it
output_images = repo_root / 'output/images'
plt.savefig(output_images / "jaccard_similarity_distributions_both_datasets.pdf",
            format="pdf", dpi=300, bbox_inches='tight')
plt.savefig(output_images / "jaccard_similarity_distributions_both_datasets.jpeg",
            format="jpeg", dpi=300, bbox_inches='tight')
plt.show()

print(f"Distribution plots saved to {output_images}/jaccard_similarity_distributions_both_datasets.pdf")

In [None]:
# Identify highly similar GO term pairs (Jaccard > 0.5)

def find_similar_pairs(jaccard_matrix, threshold=0.5, top_n=10):
    """
    Find GO term pairs with high Jaccard similarity.

    Only the strict upper triangle is examined, so each unordered pair is
    considered once and the diagonal is ignored. The selection is vectorised
    with NumPy rather than visiting every cell in a Python loop.

    Parameters
    ----------
    jaccard_matrix : pd.DataFrame
        Square Jaccard similarity matrix
    threshold : float
        Minimum Jaccard similarity to report (inclusive: pairs with
        similarity >= threshold are kept)
    top_n : int
        Number of top pairs to return

    Returns
    -------
    pd.DataFrame
        Up to `top_n` pairs, sorted by descending similarity, with columns
        'GO_term_1', 'GO_term_2' and 'jaccard_similarity'. When no pair
        qualifies, an empty DataFrame with those columns is returned.
    """
    result_columns = ['GO_term_1', 'GO_term_2', 'jaccard_similarity']

    scores = jaccard_matrix.to_numpy()
    # Row-major upper-triangle indices: (0,1), (0,2), ..., (1,2), ...
    row_idx, col_idx = np.triu_indices(len(jaccard_matrix), k=1)
    upper_scores = scores[row_idx, col_idx]

    keep = upper_scores >= threshold
    if not keep.any():
        return pd.DataFrame(columns=result_columns)

    row_labels = jaccard_matrix.index.to_numpy()
    col_labels = jaccard_matrix.columns.to_numpy()
    pairs = pd.DataFrame({
        'GO_term_1': row_labels[row_idx[keep]],
        'GO_term_2': col_labels[col_idx[keep]],
        'jaccard_similarity': upper_scores[keep],
    })
    return pairs.sort_values('jaccard_similarity', ascending=False).head(top_n)

def _count_pairs_at_or_above(jaccard_matrix, threshold):
    """Count unordered off-diagonal pairs whose similarity is >= `threshold`."""
    scores = jaccard_matrix.to_numpy()
    upper_scores = scores[np.triu_indices_from(scores, k=1)]
    return int((upper_scores >= threshold).sum())


def _report_similar_pairs(heading, jaccard_matrix):
    """Print the heading, the true number of qualifying pairs, and the top 10.

    find_similar_pairs() returns at most `top_n` rows, so its length cannot
    serve as the pair count. The total is counted separately.
    """
    print(heading)
    top_pairs = find_similar_pairs(jaccard_matrix, threshold=0.5, top_n=10)
    if len(top_pairs) > 0:
        total_pairs = _count_pairs_at_or_above(jaccard_matrix, 0.5)
        print(f"Found {total_pairs} pairs with similarity > 0.5 (showing top {len(top_pairs)})")
        display(top_pairs)
    else:
        print("No pairs with similarity > 0.5")
    return top_pairs


print("Highly Similar GO Term Pairs (Jaccard Similarity > 0.5)")
print("=" * 80)

similar_d1_2016 = _report_similar_pairs("\nDataset 1 - 2016:", jaccard_dataset1_2016)
similar_d1_2024 = _report_similar_pairs("\nDataset 1 - 2024:", jaccard_dataset1_2024)
similar_d2_2016 = _report_similar_pairs("\nDataset 2 - 2016:", jaccard_dataset2_2016)
similar_d2_2024 = _report_similar_pairs("\nDataset 2 - 2024:", jaccard_dataset2_2024)

# Count pairs at different thresholds
print("\n" + "=" * 80)
print("Summary: Number of GO term pairs exceeding similarity thresholds")
print("=" * 80)

thresholds = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8]

# Counts come straight from the matrices. Going through
# find_similar_pairs(..., top_n=1000) would silently cap every cell at 1000.
summary_data = [
    {
        'Threshold': threshold,
        'Dataset1_2016': _count_pairs_at_or_above(jaccard_dataset1_2016, threshold),
        'Dataset1_2024': _count_pairs_at_or_above(jaccard_dataset1_2024, threshold),
        'Dataset2_2016': _count_pairs_at_or_above(jaccard_dataset2_2016, threshold),
        'Dataset2_2024': _count_pairs_at_or_above(jaccard_dataset2_2024, threshold),
    }
    for threshold in thresholds
]

summary_df = pd.DataFrame(summary_data)
display(summary_df)

In [None]:
# Analyze GO term pairs with LOW similarity (independent terms)

def count_pairs_below_threshold(jaccard_matrix, thresholds):
    """
    Count GO term pairs below various similarity thresholds.

    Only the strict upper triangle of the matrix is considered, so each
    unordered pair is counted once and the diagonal is excluded.

    Parameters
    ----------
    jaccard_matrix : pd.DataFrame
        Square Jaccard similarity matrix (GO terms x GO terms)
    thresholds : list
        List of similarity thresholds to check

    Returns
    -------
    results : dict
        Mapping {threshold: {'count': int, 'percentage': float}} giving the
        number and percentage of pairs with similarity strictly below the
        threshold. The percentage is 0.0 when the matrix has no pairs.
    total_pairs : int
        Number of unordered GO term pairs in the matrix
    """
    n = len(jaccard_matrix)

    # Pull the strict upper triangle in one vectorized step instead of
    # visiting every cell through .iloc in a Python loop.
    values = jaccard_matrix.to_numpy()[np.triu_indices(n, k=1)]
    total_pairs = len(values)

    results = {}
    for threshold in thresholds:
        count = int(np.count_nonzero(values < threshold))
        # A matrix with fewer than two terms has no pairs; avoid 0/0.
        percentage = (count / total_pairs) * 100 if total_pairs else 0.0
        results[threshold] = {'count': count, 'percentage': percentage}

    return results, total_pairs

# Similarity cut-offs to report, from loosest to strictest
thresholds = [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]

print("GO Term Pairs with LOW Similarity (Diversity Analysis)")
print("=" * 80)
print("\nCounts show number of GO term pairs with similarity BELOW each threshold")
print("(Higher counts indicate more independent/diverse GO terms)\n")


def _report_low_similarity(heading, matrix):
    """Count pairs below each threshold for one matrix and print the table."""
    per_threshold, n_pairs = count_pairs_below_threshold(matrix, thresholds)
    print(f"{heading} (Total pairs: {n_pairs})")
    print("-" * 60)
    for cutoff in thresholds:
        entry = per_threshold[cutoff]
        print(f"  Similarity < {cutoff:.1f}: {entry['count']:6d} pairs ({entry['percentage']:5.1f}%)")
    return per_threshold, n_pairs


# One report per (dataset, year); headings after the first start on a new line
results_d1_2016, total_d1_2016 = _report_low_similarity(
    "Dataset 1 - 2016", jaccard_dataset1_2016)
results_d1_2024, total_d1_2024 = _report_low_similarity(
    "\nDataset 1 - 2024", jaccard_dataset1_2024)
results_d2_2016, total_d2_2016 = _report_low_similarity(
    "\nDataset 2 - 2016", jaccard_dataset2_2016)
results_d2_2024, total_d2_2024 = _report_low_similarity(
    "\nDataset 2 - 2024", jaccard_dataset2_2024)

# Side-by-side percentage table across all four matrices
print("\n" + "=" * 80)
print("Summary Table: Percentage of GO term pairs below each similarity threshold")
print("=" * 80)

results_by_column = {
    'Dataset1_2016': results_d1_2016,
    'Dataset1_2024': results_d1_2024,
    'Dataset2_2016': results_d2_2016,
    'Dataset2_2024': results_d2_2024,
}

summary_data = []
for threshold in thresholds:
    summary_row = {'Threshold': f'< {threshold:.1f}'}
    for column_name, per_threshold in results_by_column.items():
        summary_row[column_name] = f"{per_threshold[threshold]['percentage']:.1f}%"
    summary_data.append(summary_row)

summary_df = pd.DataFrame(summary_data)
display(summary_df)

## Filter Overlapping GO Terms

Remove redundant GO terms with Jaccard similarity > 0.1.

Approach:
- Use greedy pairwise filtering: every pair of GO terms whose Jaccard similarity exceeds the threshold is treated as redundant
- For each such pair, keep the GO term with the greater absolute percent change in gene count (2016→2024) and remove the other
- Filter based on 2016 Jaccard matrices (baseline year)
- This ensures GO terms used in downstream analysis are reasonably independent

In [None]:
def filter_overlapping_go_terms(jaccard_matrix, dataset, threshold=0.1):
    """
    Remove redundant GO terms with Jaccard similarity > threshold.

    Uses greedy pairwise filtering: pairs above the threshold are visited in
    row-major order of the strict upper triangle of ``jaccard_matrix``. For
    each pair whose two terms are both still present, the term with the lower
    absolute percent change in gene count is removed (on a tie the row term
    is kept). Pairs involving an already-removed term are skipped.

    Parameters
    ----------
    jaccard_matrix : pd.DataFrame
        Symmetric matrix of Jaccard similarities (GO terms x GO terms)
    dataset : pd.DataFrame
        Dataset with columns: go_id, no_of_genes_in_hetio_GO_2016,
        no_of_genes_in_GO_2024
    threshold : float
        Jaccard similarity threshold for redundancy (default 0.1)

    Returns
    -------
    filtered_dataset : pd.DataFrame
        Dataset containing only non-redundant GO terms
    removed_terms : dict
        Mapping {removed_go_id: kept_go_id}
    removal_df : pd.DataFrame
        One row per removed term with the similarity, gene counts and percent
        change of the removed and kept terms. Always has the same columns,
        even when nothing was removed. Terms present in the matrix but absent
        from ``dataset`` are ranked with a percent change of 0 and get NaN
        gene counts / percent change in this table.
    """
    count_2016 = 'no_of_genes_in_hetio_GO_2016'
    count_2024 = 'no_of_genes_in_GO_2024'
    removal_columns = [
        'removed_go_id', 'kept_go_id', 'jaccard_similarity',
        'removed_genes_2016', 'removed_genes_2024', 'removed_pct_change',
        'kept_genes_2016', 'kept_genes_2024', 'kept_pct_change',
    ]

    # Absolute percent change in gene count (2016 -> 2024) for every GO term
    dataset_with_pct = dataset.copy()
    dataset_with_pct['pct_change'] = abs(
        (dataset_with_pct[count_2024] - dataset_with_pct[count_2016])
        / dataset_with_pct[count_2016] * 100
    )

    # Lookup used to decide which term of a pair survives
    pct_change_lookup = dict(zip(dataset_with_pct['go_id'],
                                 dataset_with_pct['pct_change']))

    # Per-term details for the removal report (first row per GO id), built
    # once so each removal is a dict lookup rather than a DataFrame scan.
    first_rows = dataset_with_pct.drop_duplicates(subset='go_id')
    detail_lookup = {
        go_id: (genes_2016, genes_2024, pct)
        for go_id, genes_2016, genes_2024, pct in zip(
            first_rows['go_id'],
            first_rows[count_2016],
            first_rows[count_2024],
            first_rows['pct_change'],
        )
    }
    # Terms unknown to `dataset` must not crash the report
    missing_detail = (np.nan, np.nan, np.nan)

    # Locate all above-threshold pairs in the strict upper triangle at once;
    # argwhere yields them in the same row-major order as a nested i<j loop.
    similarities = jaccard_matrix.to_numpy()
    row_ids = jaccard_matrix.index
    col_ids = jaccard_matrix.columns
    candidate_pairs = np.argwhere(np.triu(similarities > threshold, k=1))

    removed_terms = {}
    removal_details = []

    for i, j in candidate_pairs:
        go_id_i = row_ids[i]
        go_id_j = col_ids[j]

        # Greedy rule: a pair is settled once either member is gone
        if go_id_i in removed_terms or go_id_j in removed_terms:
            continue

        pct_i = pct_change_lookup.get(go_id_i, 0)
        pct_j = pct_change_lookup.get(go_id_j, 0)

        # Keep the term with the higher percent change
        if pct_i >= pct_j:
            keep_id, remove_id = go_id_i, go_id_j
        else:
            keep_id, remove_id = go_id_j, go_id_i

        removed_terms[remove_id] = keep_id

        removed_2016, removed_2024, removed_pct = detail_lookup.get(
            remove_id, missing_detail)
        kept_2016, kept_2024, kept_pct = detail_lookup.get(
            keep_id, missing_detail)

        removal_details.append({
            'removed_go_id': remove_id,
            'kept_go_id': keep_id,
            'jaccard_similarity': similarities[i, j],
            'removed_genes_2016': removed_2016,
            'removed_genes_2024': removed_2024,
            'removed_pct_change': removed_pct,
            'kept_genes_2016': kept_2016,
            'kept_genes_2024': kept_2024,
            'kept_pct_change': kept_pct,
        })

    # Drop every removed term from the caller's dataset
    filtered_dataset = dataset[~dataset['go_id'].isin(list(removed_terms))].copy()

    # Fixed column set keeps column selection safe on an empty result
    removal_df = pd.DataFrame(removal_details, columns=removal_columns)

    return filtered_dataset, removed_terms, removal_df

print('Greedy pairwise filtering function defined')

In [None]:
# Apply the redundancy filter to Dataset 1, using the 2016 matrix as baseline
print('Filtering Dataset 1 (all growth terms)...')
print('=' * 80)

dataset1_filtered, removed_d1, removal_details_d1 = filter_overlapping_go_terms(
    jaccard_dataset1_2016, dataset1_all_growth, threshold=0.1
)

# Before/after term counts
print('\nDataset 1 Statistics:')
print(f'  GO terms before filtering: {len(dataset1_all_growth)}')
print(f'  GO terms after filtering:  {len(dataset1_filtered)}')
print(f'  GO terms removed:          {len(removed_d1)}')
print(f'  Reduction:                 {len(removed_d1)/len(dataset1_all_growth)*100:.1f}%')

# Preview of what was removed and which term replaced it
if len(removal_details_d1) == 0:
    print('\nNo overlapping GO terms found (all Jaccard < 0.1)')
else:
    preview_columns = ['removed_go_id', 'kept_go_id', 'jaccard_similarity',
                       'removed_pct_change', 'kept_pct_change']
    print('\nRemoved Terms Details (Dataset 1, first 10):')
    print(removal_details_d1[preview_columns].head(10).to_string(index=False))
    if len(removal_details_d1) > 10:
        print(f'\n... and {len(removal_details_d1) - 10} more pairs removed')

In [None]:
# Filter Dataset 2 based on 2016 Jaccard matrix
print('Filtering Dataset 2 (parent terms)...')
print('=' * 80)

# filter_overlapping_go_terms returns exactly three values
# (filtered dataset, removed->kept mapping, removal details); the greedy
# pairwise filter does not produce a cluster table.
dataset2_filtered, removed_d2, removal_details_d2 = filter_overlapping_go_terms(
    jaccard_dataset2_2016,
    dataset2_parents,
    threshold=0.1
)

# Statistics
print('\nDataset 2 Statistics:')
print(f'  GO terms before filtering: {len(dataset2_parents)}')
print(f'  GO terms after filtering:  {len(dataset2_filtered)}')
print(f'  GO terms removed:          {len(removed_d2)}')
print(f'  Reduction:                 {len(removed_d2)/len(dataset2_parents)*100:.1f}%')

# Show removed terms (if any); the kept term is reported in 'kept_go_id'
if len(removal_details_d2) > 0:
    print('\nRemoved Terms Details (Dataset 2):')
    print(removal_details_d2[['removed_go_id', 'kept_go_id', 'jaccard_similarity',
                               'removed_pct_change', 'kept_pct_change']].to_string(index=False))
else:
    print('\nNo overlapping GO terms found (all Jaccard < 0.1)')

In [None]:
# Combined summary table
print('\nCombined Summary: Before and After Filtering')
print('=' * 80)

# GO-gene pairs restricted to each GO term list, computed once and reused
# for both the pair counts and the unique-gene counts below.
bppg_d1_before = hetio_BPpG_dataset1[hetio_BPpG_dataset1['go_id'].isin(dataset1_all_growth['go_id'])]
bppg_d1_after = hetio_BPpG_dataset1[hetio_BPpG_dataset1['go_id'].isin(dataset1_filtered['go_id'])]
bppg_d2_before = hetio_BPpG_dataset2[hetio_BPpG_dataset2['go_id'].isin(dataset2_parents['go_id'])]
bppg_d2_after = hetio_BPpG_dataset2[hetio_BPpG_dataset2['go_id'].isin(dataset2_filtered['go_id'])]

# Calculate unique genes
genes_d1_before = len(bppg_d1_before['neo4j_target_id'].unique())
genes_d1_after = len(bppg_d1_after['neo4j_target_id'].unique())

genes_d2_before = len(bppg_d2_before['neo4j_target_id'].unique())
genes_d2_after = len(bppg_d2_after['neo4j_target_id'].unique())

# GO-gene pairs
pairs_d1_before = len(bppg_d1_before)
pairs_d1_after = len(bppg_d1_after)

pairs_d2_before = len(bppg_d2_before)
pairs_d2_after = len(bppg_d2_after)

# The greedy pairwise filter yields no cluster table, so the summary reports
# only quantities that the filtering cells actually produce.
summary_data = {
    'Metric': ['GO terms', 'Unique genes', 'GO-gene pairs', 'Terms removed'],
    'Dataset 1 Before': [
        len(dataset1_all_growth),
        genes_d1_before,
        pairs_d1_before,
        '-',
    ],
    'Dataset 1 After': [
        len(dataset1_filtered),
        genes_d1_after,
        pairs_d1_after,
        len(removed_d1),
    ],
    'Dataset 2 Before': [
        len(dataset2_parents),
        genes_d2_before,
        pairs_d2_before,
        '-',
    ],
    'Dataset 2 After': [
        len(dataset2_filtered),
        genes_d2_after,
        pairs_d2_after,
        len(removed_d2),
    ],
}

summary_df = pd.DataFrame(summary_data)
print(summary_df.to_string(index=False))

print('\nFiltering complete. Use dataset1_filtered and dataset2_filtered for downstream analyses.')

In [None]:
# Restrict the GO-gene pair tables to the GO terms that survived filtering
print('Updating GO-gene pair dataframes...')
print('=' * 80)

# Dataset 1: keep only pairs whose GO term is in dataset1_filtered
keep_mask_d1 = hetio_BPpG_dataset1['go_id'].isin(dataset1_filtered['go_id'])
hetio_BPpG_dataset1_filtered = hetio_BPpG_dataset1[keep_mask_d1].copy()

print('Dataset 1 GO-gene pairs:')
print(f'  Before filtering: {len(hetio_BPpG_dataset1):,}')
print(f'  After filtering:  {len(hetio_BPpG_dataset1_filtered):,}')
print(f'  Removed:          {len(hetio_BPpG_dataset1) - len(hetio_BPpG_dataset1_filtered):,}')

# Dataset 2: keep only pairs whose GO term is in dataset2_filtered
keep_mask_d2 = hetio_BPpG_dataset2['go_id'].isin(dataset2_filtered['go_id'])
hetio_BPpG_dataset2_filtered = hetio_BPpG_dataset2[keep_mask_d2].copy()

print('\nDataset 2 GO-gene pairs:')
print(f'  Before filtering: {len(hetio_BPpG_dataset2):,}')
print(f'  After filtering:  {len(hetio_BPpG_dataset2_filtered):,}')
print(f'  Removed:          {len(hetio_BPpG_dataset2) - len(hetio_BPpG_dataset2_filtered):,}')

print('\nFiltered dataframes created:')
for created_name in ('hetio_BPpG_dataset1_filtered', 'hetio_BPpG_dataset2_filtered'):
    print(f'  - {created_name}')
print('\nUse these for downstream DWPC calculations to ensure GO term independence.')

## Visualize Filtered Datasets

Compare gene counts (2016 vs 2024) for filtered GO term sets.

In [None]:
# Visualize filtered datasets: 2016 vs 2024 gene counts per GO term
import matplotlib.pyplot as plt

# Sort each filtered dataset by its 2016 gene count so the x axis is a rank
dataset1_filt_sorted = dataset1_filtered.sort_values(
    by='no_of_genes_in_hetio_GO_2016',
    ascending=True
).reset_index(drop=True)

dataset2_filt_sorted = dataset2_filtered.sort_values(
    by='no_of_genes_in_hetio_GO_2016',
    ascending=True
).reset_index(drop=True)

# Create figure with two subplots
fig, axes = plt.subplots(1, 2, figsize=(14, 5), dpi=150)

# (axis, sorted data, panel title, number of GO terms) for each panel
panels = [
    (axes[0], dataset1_filt_sorted,
     'Dataset 1: All Growth Terms (Filtered)', len(dataset1_filtered)),
    (axes[1], dataset2_filt_sorted,
     'Dataset 2: Parent Terms (Filtered)', len(dataset2_filtered)),
]

# (column, legend label) for the two series drawn in every panel
series = [
    ('no_of_genes_in_hetio_GO_2016', '2016'),
    ('no_of_genes_in_GO_2024', '2024'),
]

for ax, sorted_data, panel_title, n_terms in panels:
    for column, year_label in series:
        ax.scatter(
            sorted_data.index,
            sorted_data[column],
            alpha=0.6,
            label=year_label,
            s=20
        )
    ax.set_xlabel('GO Term Index (sorted by 2016 gene count)')
    ax.set_ylabel('Number of Genes')
    # A real newline puts the term count on its own title line; the previous
    # '\\n' rendered a literal backslash-n in the figure.
    ax.set_title(f'{panel_title}\nn={n_terms} GO terms')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print('Visualization complete')

In [None]:
# Jaccard similarity matrices are cached automatically when they are calculated.
# Cache location: output/jaccard_similarity/
# Expected files:
#   - jaccard_similarity_dataset1_2016.csv
#   - jaccard_similarity_dataset1_2024.csv
#   - jaccard_similarity_dataset2_2016.csv
#   - jaccard_similarity_dataset2_2024.csv
#
# To force recalculation, delete the cache files and re-run the calculation cell

print("Jaccard similarity matrices are cached at:")
print(f"  {jaccard_cache_dir}/")
print("\nCached files:")
for cache_key in ('dataset1_2016', 'dataset1_2024', 'dataset2_2016', 'dataset2_2024'):
    cache_name = f"jaccard_similarity_{cache_key}.csv"
    cache_path = jaccard_cache_dir / cache_name
    if not cache_path.exists():
        print(f"  {cache_name} (not found)")
        continue
    # Report the on-disk size in kilobytes
    size_kb = cache_path.stat().st_size / 1024
    print(f"  {cache_name} ({size_kb:.1f} KB)")

### Summary: Jaccard Similarity Analysis

This analysis computed pairwise Jaccard similarity between GO terms based on gene overlap for both datasets and years. Key insights:

**Jaccard Similarity Formula:**
- J(A,B) = |A ∩ B| / |A ∪ B|
- Ranges from 0 (no overlap) to 1 (complete overlap)

**Implementation:**
- Uses sklearn.metrics.pairwise_distances for optimized computation (10-50x faster than nested loops)
- Results are automatically cached to disk for instant loading on subsequent runs
- Cache location: output/jaccard_similarity/

**Analysis Components:**
1. Computed similarity matrices for all GO term pairs in both datasets
2. Generated clustered heatmaps showing pairwise similarity patterns with hierarchical clustering
3. Diagonal removed from heatmaps to focus on GO term relationships
4. Analyzed distribution of similarity scores
5. Identified highly similar GO term pairs that may represent functional redundancy or hierarchical relationships

**Outputs:**
- Similarity matrices cached in output/jaccard_similarity/
- Clustered heatmaps and distributions saved to output/images/

**Applications:**
- Identify redundant or highly overlapping GO terms
- Understand functional relationships between biological processes
- Inform feature selection for machine learning (remove highly correlated features)
- Validate that filtered datasets have appropriate diversity
- Clustering reveals functional modules and related biological processes

### Gene Overlap Analysis: Jaccard Similarity Between GO Terms

Calculate pairwise Jaccard similarity between GO terms based on their annotated genes. This measures how much gene sets overlap between different GO terms, which can reveal functional relationships and redundancy in the ontology.

### Visualize Gene Count Changes: 2016 vs 2024

This visualization compares the number of genes associated with each GO term between 2016 and 2024 for both filtered datasets. GO terms are sorted by their 2016 gene count for easier comparison of growth patterns.

In [None]:
# Write the filtered GO term lists and 2016 GO-gene pair tables to disk
print('Saving filtered datasets (2016)...')

output_intermediate = repo_root / 'output/intermediate'

# (dataframe, output file name, unit used in the confirmation message)
exports_2016 = [
    (dataset1_filtered, 'dataset1_filtered.csv', 'GO terms'),
    (dataset2_filtered, 'dataset2_filtered.csv', 'GO terms'),
    (hetio_BPpG_dataset1_filtered, 'hetio_bppg_dataset1_filtered.csv', 'rows'),
    (hetio_BPpG_dataset2_filtered, 'hetio_bppg_dataset2_filtered.csv', 'rows'),
]

for frame, file_name, unit in exports_2016:
    frame.to_csv(output_intermediate / file_name, index=False)
    print(f'Saved {file_name}: {len(frame)} {unit}')

In [None]:
# Restrict the 2024 GO-gene tables to the filtered GO terms, then save them
print('\nFiltering and saving 2024 datasets...')
print('=' * 80)

# Dataset 1: keep 2024 pairs whose GO term survived filtering
keep_2024_d1 = upd_go_bp_2024_dataset1['go_id'].isin(dataset1_filtered['go_id'])
upd_go_bp_2024_dataset1_filtered = upd_go_bp_2024_dataset1[keep_2024_d1].copy()

print('Dataset 1 (2024):')
print(f'  Before filtering: {len(upd_go_bp_2024_dataset1):,} GO-gene pairs')
print(f'  After filtering:  {len(upd_go_bp_2024_dataset1_filtered):,} GO-gene pairs')
print(f'  Unique GO terms:  {upd_go_bp_2024_dataset1_filtered["go_id"].nunique()}')
print(f'  Unique genes:     {upd_go_bp_2024_dataset1_filtered["neo4j_target_id"].nunique()}')

# Dataset 2: keep 2024 pairs whose GO term survived filtering
keep_2024_d2 = upd_go_bp_2024_dataset2['go_id'].isin(dataset2_filtered['go_id'])
upd_go_bp_2024_dataset2_filtered = upd_go_bp_2024_dataset2[keep_2024_d2].copy()

print('\nDataset 2 (2024):')
print(f'  Before filtering: {len(upd_go_bp_2024_dataset2):,} GO-gene pairs')
print(f'  After filtering:  {len(upd_go_bp_2024_dataset2_filtered):,} GO-gene pairs')
print(f'  Unique GO terms:  {upd_go_bp_2024_dataset2_filtered["go_id"].nunique()}')
print(f'  Unique genes:     {upd_go_bp_2024_dataset2_filtered["neo4j_target_id"].nunique()}')

# Persist both filtered 2024 tables
output_intermediate = repo_root / 'output/intermediate'

path_2024_d1 = output_intermediate / 'hetio_bppg_dataset1_2024_filtered.csv'
upd_go_bp_2024_dataset1_filtered.to_csv(path_2024_d1, index=False)
print(f'\nSaved hetio_bppg_dataset1_2024_filtered.csv: {len(upd_go_bp_2024_dataset1_filtered):,} rows')

path_2024_d2 = output_intermediate / 'hetio_bppg_dataset2_2024_filtered.csv'
upd_go_bp_2024_dataset2_filtered.to_csv(path_2024_d2, index=False)
print(f'Saved hetio_bppg_dataset2_2024_filtered.csv: {len(upd_go_bp_2024_dataset2_filtered):,} rows')

In [None]:
# Final recap of everything this notebook wrote
banner_rule = '=' * 80
print('\n' + banner_rule)
print('NOTEBOOK 1.3 COMPLETE - ALL DATASETS SAVED')
print(banner_rule)

# Row counts reused across the recap lines below
n_terms_d1 = len(dataset1_filtered)
n_terms_d2 = len(dataset2_filtered)
n_pairs_2016_d1 = len(hetio_BPpG_dataset1_filtered)
n_pairs_2016_d2 = len(hetio_BPpG_dataset2_filtered)
n_pairs_2024_d1 = len(upd_go_bp_2024_dataset1_filtered)
n_pairs_2024_d2 = len(upd_go_bp_2024_dataset2_filtered)

print('\nOutput Summary:')
print('  GO Term Lists (year-agnostic):')
print(f'    - dataset1_filtered.csv ({n_terms_d1} GO terms)')
print(f'    - dataset2_filtered.csv ({n_terms_d2} GO terms)')

print('\n  2016 BP-Gene Associations with Neo4j IDs:')
print(f'    - hetio_bppg_dataset1_filtered.csv ({n_pairs_2016_d1:,} pairs)')
print(f'    - hetio_bppg_dataset2_filtered.csv ({n_pairs_2016_d2:,} pairs)')

print('\n  2024 BP-Gene Associations with Neo4j IDs:')
print(f'    - hetio_bppg_dataset1_2024_filtered.csv ({n_pairs_2024_d1:,} pairs)')
print(f'    - hetio_bppg_dataset2_2024_filtered.csv ({n_pairs_2024_d2:,} pairs)')

print('\nDataset Comparison:')
print(f'  Dataset 1: {n_pairs_2016_d1:,} (2016) -> {n_pairs_2024_d1:,} (2024) pairs')
print(f'  Dataset 2: {n_pairs_2016_d2:,} (2016) -> {n_pairs_2024_d2:,} (2024) pairs')

print('\nNext: Run notebook 1.4 to generate permuted datasets')