# Lab 18: Optimization Techniques and Performance Enhancements

## Overview

In this lab, we'll explore the sophisticated optimization techniques used in Bonsai v3 to enable efficient pedigree reconstruction at scale. These techniques are essential for handling real-world genetic genealogy datasets with hundreds or thousands of individuals.

In [None]:
# Standard imports
import os
import sys
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx
from IPython.display import display, HTML, Markdown
import inspect
import importlib
import copy
import random
import math
import time
import concurrent.futures
import contextlib
from collections import defaultdict

sys.path.append(os.path.dirname(os.getcwd()))

# Cross-compatibility setup
from scripts_support.lab_cross_compatibility import setup_environment, is_jupyterlite, save_results, save_plot

# Set up environment-specific paths
DATA_DIR, RESULTS_DIR = setup_environment()

# Set visualization styles
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_context("notebook")

In [None]:
# Setup Bonsai module paths
if not is_jupyterlite():
    # In local environment, add the utils directory to system path
    utils_dir = os.getenv('PROJECT_UTILS_DIR', os.path.join(os.path.dirname(DATA_DIR), 'utils'))
    bonsaitree_dir = os.path.join(utils_dir, 'bonsaitree')
    
    # Add to path if it exists and isn't already there
    if os.path.exists(bonsaitree_dir) and bonsaitree_dir not in sys.path:
        sys.path.append(bonsaitree_dir)
        print(f"Added {bonsaitree_dir} to sys.path")
else:
    # In JupyterLite, use a simplified approach
    print("⚠️ Running in JupyterLite: Some Bonsai functionality may be limited.")
    print("This notebook is primarily designed for local execution where the Bonsai codebase is available.")

In [None]:
# Helper functions for exploring modules
def display_module_classes(module_name):
    """Display classes and their docstrings from a module"""
    try:
        # Import the module
        module = importlib.import_module(module_name)
        
        # Find all classes
        classes = inspect.getmembers(module, inspect.isclass)
        
        # Filter classes defined in this module (not imported)
        classes = [(name, cls) for name, cls in classes if cls.__module__ == module_name]
        
        # Print info for each class
        for name, cls in classes:
            print(f"\n## {name}")
            
            # Get docstring
            doc = inspect.getdoc(cls)
            if doc:
                print(f"Docstring: {doc}")
            else:
                print("No docstring available")
            
            # Get methods
            methods = inspect.getmembers(cls, inspect.isfunction)
            if methods:
                print("\nMethods:")
                for method_name, method in methods:
                    if not method_name.startswith('_'):  # Skip private methods
                        print(f"- {method_name}")
    except ImportError as e:
        print(f"Error importing module {module_name}: {e}")
    except Exception as e:
        print(f"Error processing module {module_name}: {e}")

def display_module_functions(module_name):
    """Display functions and their docstrings from a module"""
    try:
        # Import the module
        module = importlib.import_module(module_name)
        
        # Find all functions
        functions = inspect.getmembers(module, inspect.isfunction)
        
        # Filter functions defined in this module (not imported)
        functions = [(name, func) for name, func in functions if func.__module__ == module_name]
        
        # Print info for each function
        for name, func in functions:
            if name.startswith('_'):  # Skip private functions
                continue
                
            print(f"\n## {name}")
            
            # Get signature
            sig = inspect.signature(func)
            print(f"Signature: {name}{sig}")
            
            # Get docstring
            doc = inspect.getdoc(func)
            if doc:
                print(f"Docstring: {doc}")
            else:
                print("No docstring available")
    except ImportError as e:
        print(f"Error importing module {module_name}: {e}")
    except Exception as e:
        print(f"Error processing module {module_name}: {e}")

def view_source(obj):
    """Display the source code of an object (function or class)"""
    try:
        source = inspect.getsource(obj)
        display(Markdown(f"```python\n{source}\n```"))
    except Exception as e:
        print(f"Error retrieving source: {e}")

## Check Bonsai Installation

Let's verify that the Bonsai v3 module is available for import:

In [None]:
try:
    from utils.bonsaitree.bonsaitree import v3
    print("✅ Successfully imported Bonsai v3 module")
except ImportError as e:
    print(f"❌ Failed to import Bonsai v3 module: {e}")
    print("This lab requires access to the Bonsai v3 codebase.")
    print("Make sure you've properly set up your environment with the Bonsai repository.")

## Lab Roadmap

The techniques below are what let Bonsai v3 scale to real-world genetic genealogy datasets with hundreds or thousands of individuals.

Key optimization areas we'll explore include:

1. **Search Space Pruning**: Reducing the number of pedigree configurations to evaluate
2. **Parallel Processing**: Leveraging multiple CPU cores for faster computation
3. **Adaptive Parameter Selection**: Optimizing algorithm parameters based on dataset characteristics
4. **Specialized Data Structures**: Using memory-efficient representations of genetic and pedigree data
5. **Early Termination and Lazy Evaluation**: Avoiding unnecessary computation

We'll implement simplified versions of these optimization techniques to understand how they work and why they're important for large-scale pedigree reconstruction.
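Items 4 and 5 are easiest to see in a toy setting. The sketch below is illustrative only: `pair_score` is a hypothetical stand-in for an expensive pairwise likelihood, and the candidate configurations are made up. It combines memoization (`functools.lru_cache`, so each pair is scored at most once), a lazy generator of candidates, and early abandonment of a candidate as soon as its partial score can no longer beat the best one found so far.

```python
from functools import lru_cache

# Counts how many times the (hypothetical) expensive score is actually computed
CALLS = {"pair_score": 0}

@lru_cache(maxsize=None)
def pair_score(id1, id2):
    """Hypothetical expensive pairwise score; always <= 0 in this toy example."""
    CALLS["pair_score"] += 1
    return -abs(id1 - id2)

def best_configuration(candidates):
    """Return the best-scoring candidate, abandoning hopeless candidates early."""
    best, best_score = None, float("-inf")
    for config in candidates:  # candidates can be a lazy generator
        score = 0.0
        for a, b in config:
            score += pair_score(min(a, b), max(a, b))
            # Pair scores are <= 0, so a partial sum below the best can only get worse
            if score < best_score:
                break
        else:
            # Only reached when the candidate was scored completely
            if score > best_score:
                best, best_score = config, score
    return best, best_score

candidates = [[(1, 2), (2, 3)], [(1, 5), (2, 3)], [(1, 3)]]
best, score = best_configuration(c for c in candidates)
print(best, score, CALLS["pair_score"])
```

The second candidate is dropped after its first pair, so its `(2, 3)` term is never requested from it. That shortcut is only safe because every pair score is non-positive (as log-likelihoods are), which is an assumption of this sketch.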

## Part 1: Search Space Pruning

Pedigree reconstruction from genetic data involves exploring a vast combinatorial space of possible pedigree configurations. The number of configurations grows at least exponentially with the number of individuals, so exhaustive search is infeasible for real-world datasets. Search space pruning is therefore essential for efficient pedigree reconstruction.
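To make the growth concrete, the sketch below compares three illustrative counts. None of them is Bonsai's actual count of candidate pedigrees: `2**n` is the crude proxy used at the end of this part, `n**(n-2)` is Cayley's formula for labeled trees on `n` nodes (real pedigrees are not trees, since everyone has two parents), and `k**C(n, 2)` counts the ways to give each pair one of `k` relationship labels, with `k = 4` matching the assumption used in Part 2.

```python
import math

def num_labeled_trees(n: int) -> int:
    """Cayley's formula: number of distinct labeled trees on n nodes."""
    return n ** (n - 2) if n >= 2 else 1

def num_pair_labelings(n: int, k: int = 4) -> int:
    """Ways to assign one of k relationship labels to each of the C(n, 2) pairs."""
    return k ** math.comb(n, 2)

for n in (5, 10, 20, 30):
    print(f"n={n:>2}  2^n={2 ** n:.2e}  "
          f"labeled trees={num_labeled_trees(n):.2e}  "
          f"pair labelings (k=4)={num_pair_labelings(n):.2e}")
```

Even the smallest of these proxies doubles with every added individual, which is why the rest of this part focuses on splitting the problem into independent pieces.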

In [ ]:
# Import Bonsai search space pruning functions if available
if not is_jupyterlite():
    try:
        from utils.bonsaitree.bonsaitree.v3.connections import prune_search_space
        
        # Display the source code if available
        print("Source code for prune_search_space:")
        view_source(prune_search_space)
    except (ImportError, AttributeError) as e:
        print(f"Could not import function: {e}")
else:
    print("Cannot display source code in JupyterLite environment.")

### 1.1 Connectivity-Based Clustering

One of the most effective search space pruning techniques in Bonsai v3 is connectivity-based clustering. This approach groups individuals into clusters based on their IBD connectivity, allowing Bonsai to focus on reconstructing each cluster separately before merging them. Let's implement a simplified version of this approach:

In [ ]:
def build_connectivity_graph(id_to_shared_ibd, min_ibd_threshold=50):
    """
    Build a graph representing IBD connectivity between individuals.
    
    Args:
        id_to_shared_ibd: Dict mapping ID pairs to their IBD segments
        min_ibd_threshold: Minimum IBD amount (in cM) to consider for connectivity
        
    Returns:
        G: NetworkX graph with edges representing IBD connections
    """
    G = nx.Graph()
    
    # Add edges for each pair with IBD sharing above threshold
    for (id1, id2), segments in id_to_shared_ibd.items():
        # Calculate total IBD
        total_cm = sum(seg.get('length_cm', 0) for seg in segments)
        
        # Add edge if above threshold
        if total_cm >= min_ibd_threshold:
            G.add_edge(id1, id2, weight=total_cm)
    
    return G

def identify_related_clusters(connectivity_graph, min_size=2):
    """
    Identify clusters of related individuals based on IBD connectivity.
    
    Args:
        connectivity_graph: NetworkX graph with edges representing IBD connections
        min_size: Minimum cluster size to return
        
    Returns:
        clusters: List of sets containing individual IDs in each cluster
    """
    # Find connected components in the graph
    components = list(nx.connected_components(connectivity_graph))
    
    # Filter by minimum size
    clusters = [comp for comp in components if len(comp) >= min_size]
    
    # Sort by size (descending)
    clusters.sort(key=len, reverse=True)
    
    return clusters

def visualize_clusters(connectivity_graph, clusters, title="IBD Connectivity Clusters"):
    """
    Visualize IBD connectivity clusters.
    
    Args:
        connectivity_graph: NetworkX graph with edges representing IBD connections
        clusters: List of sets containing individual IDs in each cluster
        title: Title for the visualization
    """
    plt.figure(figsize=(12, 8))
    plt.title(title)
    
    # Position nodes using force-directed layout
    pos = nx.spring_layout(connectivity_graph, seed=42)
    
    # Generate colors for clusters
    colors = plt.cm.tab10(np.linspace(0, 1, len(clusters)))
    
    # Draw each cluster with a different color
    for i, cluster in enumerate(clusters):
        nx.draw_networkx_nodes(
            connectivity_graph, 
            pos, 
            nodelist=list(cluster),
            node_color=[colors[i]],
            node_size=200,
            alpha=0.8
        )
    
    # Draw edges with width proportional to weight
    edge_weights = [connectivity_graph[u][v]['weight'] / 500 for u, v in connectivity_graph.edges()]
    nx.draw_networkx_edges(
        connectivity_graph, 
        pos, 
        width=edge_weights, 
        alpha=0.5
    )
    
    # Draw labels
    nx.draw_networkx_labels(connectivity_graph, pos, font_size=10)
    
    plt.axis('off')
    plt.tight_layout()
    plt.show()
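`identify_related_clusters` relies on NetworkX to find connected components. The same clusters can be computed without building a graph object by using a disjoint-set (union-find) structure, which stores only one parent pointer and one set size per individual. This is a sketch of that standard technique for the same `length_cm` segment format; it is not taken from the Bonsai codebase, and the names `UnionFind` and `clusters_by_union_find` are hypothetical.

```python
from collections import defaultdict

class UnionFind:
    """Disjoint-set forest with path compression and union by size."""

    def __init__(self):
        self.parent = {}
        self.size = {}

    def find(self, x):
        # Register unseen elements lazily as singleton sets
        if x not in self.parent:
            self.parent[x] = x
            self.size[x] = 1
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
        # Path compression: point every node on the path directly at the root
        while self.parent[x] != root:
            self.parent[x], x = root, self.parent[x]
        return root

    def union(self, a, b):
        ra, rb = self.find(a), self.find(b)
        if ra == rb:
            return
        # Union by size: attach the smaller tree under the larger one
        if self.size[ra] < self.size[rb]:
            ra, rb = rb, ra
        self.parent[rb] = ra
        self.size[ra] += self.size[rb]

def clusters_by_union_find(id_to_shared_ibd, min_ibd_threshold=50, min_size=2):
    """Connected components of the thresholded IBD graph, largest first."""
    uf = UnionFind()
    for (id1, id2), segments in id_to_shared_ibd.items():
        total_cm = sum(seg.get('length_cm', 0) for seg in segments)
        if total_cm >= min_ibd_threshold:
            uf.union(id1, id2)
    groups = defaultdict(set)
    for node in list(uf.parent):
        groups[uf.find(node)].add(node)
    clusters = [g for g in groups.values() if len(g) >= min_size]
    clusters.sort(key=len, reverse=True)
    return clusters

example_ibd = {(1, 2): [{'length_cm': 200}], (2, 3): [{'length_cm': 150}],
               (4, 5): [{'length_cm': 80}], (5, 6): [{'length_cm': 10}]}
print(clusters_by_union_find(example_ibd, min_ibd_threshold=50))
```

For any threshold, the returned sets should match `identify_related_clusters(build_connectivity_graph(...))`; only the order of equally sized clusters may differ.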

In [ ]:
# Demonstrate connectivity-based clustering
def generate_sample_ibd_data(num_individuals=30, num_families=3, family_size=10):
    """
    Generate sample IBD data with a family structure.
    
    Args:
        num_individuals: Total number of individuals
        num_families: Number of distinct families
        family_size: Approximate size of each family
        
    Returns:
        id_to_shared_ibd: Dict mapping ID pairs to their IBD segments
    """
    id_to_shared_ibd = {}
    
    # Ensure family_size * num_families <= num_individuals
    actual_family_size = min(family_size, num_individuals // num_families)
    
    # Create families
    families = []
    remaining_ids = list(range(1, num_individuals + 1))
    
    for i in range(num_families):
        # Determine this family's size
        if i == num_families - 1:
            # Last family gets all remaining individuals
            size = len(remaining_ids)
        else:
            # Randomize family size slightly: at least 3 members,
            # but never more than the individuals still unassigned
            size = actual_family_size + random.randint(-2, 2)
            size = min(max(3, size), len(remaining_ids))
        
        # Select members for this family
        family = random.sample(remaining_ids, size)
        families.append(family)
        
        # Remove selected IDs from remaining
        for id_val in family:
            remaining_ids.remove(id_val)
    
    # Add any remaining individuals to random families
    for id_val in remaining_ids:
        family_idx = random.randint(0, num_families - 1)
        families[family_idx].append(id_val)
    
    # Generate IBD sharing within families
    for family in families:
        # Create close relationships within the family
        for i in range(len(family)):
            for j in range(i + 1, len(family)):
                id1, id2 = family[i], family[j]
                
                # Assign a simulated relationship strength based on position in the family list
                pos_diff = abs(i - j)
                if pos_diff == 1:
                    # Adjacent positions: very close relatives (e.g. parent-child or full siblings)
                    ibd_amount = random.uniform(1700, 2600)  # Illustrative values, not calibrated to real cM ranges
                elif pos_diff == 2:
                    # Two steps apart: second-degree relatives (e.g. grandparent or aunt/uncle)
                    ibd_amount = random.uniform(700, 1600)  # Illustrative values, not calibrated to real cM ranges
                else:
                    # More distant relationship: IBD shrinks as positions get further apart
                    ibd_amount = random.uniform(200, 700) / pos_diff
                
                # Create a segment
                segment = {
                    'chrom': 1,
                    'start_cm': 0,
                    'end_cm': ibd_amount,
                    'length_cm': ibd_amount
                }
                
                # Ensure id1 < id2 for consistent keys
                pair = (min(id1, id2), max(id1, id2))
                if pair not in id_to_shared_ibd:
                    id_to_shared_ibd[pair] = []
                id_to_shared_ibd[pair].append(segment)
    
    # Add some cross-family relationships (more distant)
    num_cross_relationships = num_families * 2
    
    for _ in range(num_cross_relationships):
        # Select two different families
        fam1_idx, fam2_idx = random.sample(range(num_families), 2)
        
        # Select one member from each family
        id1 = random.choice(families[fam1_idx])
        id2 = random.choice(families[fam2_idx])
        
        # Create a more distant relationship
        ibd_amount = random.uniform(50, 300)  # Distant cousin range
        
        # Create a segment
        segment = {
            'chrom': 1,
            'start_cm': 0,
            'end_cm': ibd_amount,
            'length_cm': ibd_amount
        }
        
        # Ensure id1 < id2 for consistent keys
        pair = (min(id1, id2), max(id1, id2))
        if pair not in id_to_shared_ibd:
            id_to_shared_ibd[pair] = []
        id_to_shared_ibd[pair].append(segment)
    
    return id_to_shared_ibd

# Set random seed for reproducibility
random.seed(42)

# Generate sample IBD data
id_to_shared_ibd = generate_sample_ibd_data(
    num_individuals=30,
    num_families=3,
    family_size=10
)

# Print summary of IBD data
print(f"Generated IBD data with {len(id_to_shared_ibd)} pairs")

# Build connectivity graph
connectivity_graph = build_connectivity_graph(id_to_shared_ibd, min_ibd_threshold=100)
print(f"Connectivity graph has {len(connectivity_graph.nodes())} nodes and {len(connectivity_graph.edges())} edges")

# Identify clusters
clusters = identify_related_clusters(connectivity_graph)
print(f"Identified {len(clusters)} clusters:")
for i, cluster in enumerate(clusters):
    print(f"Cluster {i+1}: {len(cluster)} individuals - {sorted(cluster)}")

# Visualize clusters
visualize_clusters(connectivity_graph, clusters, "Sample IBD Connectivity Clusters")

### 1.2 Demographic Constraints

Another important search space pruning technique is the use of demographic constraints to rule out impossible relationships. By leveraging age, sex, and other demographic information, Bonsai can eliminate many implausible pedigree configurations before evaluating them.

In [ ]:
def establish_demographic_constraints(id_to_info):
    """
    Establish constraints based on demographic information.
    
    Args:
        id_to_info: Dict with demographic information for individuals
        
    Returns:
        constraints: Dict of constraints for relationship types
    """
    constraints = {
        'parent_child': [],  # (parent_id, child_id) pairs that are possible
        'sibling': [],       # (id1, id2) pairs that could be siblings
        'invalid': []        # (id1, id2) pairs that cannot be directly related
    }
    
    # Get all pairs of individuals
    ids = list(id_to_info.keys())
    all_pairs = [(ids[i], ids[j]) for i in range(len(ids)) for j in range(i+1, len(ids))]
    
    for id1, id2 in all_pairs:
        info1 = id_to_info.get(id1, {})
        info2 = id_to_info.get(id2, {})
        
        # Extract demographic information
        age1 = info1.get('age')
        age2 = info2.get('age')
        sex1 = info1.get('sex')
        sex2 = info2.get('sex')
        
        # Check parent-child relationships
        if age1 is not None and age2 is not None:
            age_diff = abs(age1 - age2)
            
            if age1 > age2 and age_diff >= 15:
                # id1 could be parent of id2
                constraints['parent_child'].append((id1, id2))
            
            if age2 > age1 and age_diff >= 15:
                # id2 could be parent of id1
                constraints['parent_child'].append((id2, id1))
            
            # Check for sibling relationship
            if age_diff < 30:
                # Could be siblings
                constraints['sibling'].append((id1, id2))
            
            # Check for impossible direct relationships.
            # Note: gaps of 12-14 years are neither "possible" (>= 15) nor "invalid" (< 12)
            if age_diff < 12:
                # Age difference too small for parent-child
                constraints['invalid'].append(('parent_child', id1, id2))
                constraints['invalid'].append(('parent_child', id2, id1))
        
        # Check sex-based constraints: the two parents of a child must have different sexes
        if sex1 in ('M', 'F') and sex1 == sex2:
            constraints['invalid'].append(('parent_pair', id1, id2))
    
    return constraints

def apply_demographic_constraints(clusters, id_to_info):
    """
    Apply demographic constraints to refine clusters.
    
    Args:
        clusters: List of sets containing individual IDs in each cluster
        id_to_info: Dict with demographic information for individuals
        
    Returns:
        refined_clusters: List of sets with additional demographic constraints applied
    """
    # Establish constraints
    constraints = establish_demographic_constraints(id_to_info)
    
    # Extract pairs that cannot be parent-child (a set gives O(1) membership tests)
    invalid_pairs = {(id1, id2) for rel_type, id1, id2 in constraints['invalid']
                     if rel_type == 'parent_child'}
    
    # Create a new graph for each cluster with demographic constraints
    refined_clusters = []
    
    for cluster in clusters:
        # Skip small clusters
        if len(cluster) < 3:
            refined_clusters.append(cluster)
            continue
        
        # Create a new graph for this cluster
        G = nx.Graph()
        G.add_nodes_from(cluster)
        
        # Add edges for possible relationships
        for id1 in cluster:
            for id2 in cluster:
                if id1 == id2:
                    continue
                    
                # Skip invalid pairs
                if (id1, id2) in invalid_pairs or (id2, id1) in invalid_pairs:
                    continue
                
                # Add edge for possible relationship
                G.add_edge(id1, id2)
        
        # Find connected components after constraints
        sub_clusters = list(nx.connected_components(G))
        refined_clusters.extend(sub_clusters)
    
    return refined_clusters
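`establish_demographic_constraints` compares ages pair by pair in a Python loop. A vectorized alternative computes every age difference at once with NumPy broadcasting. The sketch below mirrors the loop's thresholds (a gap of at least 15 years allows parent-child, a gap under 12 rules it out); `age_gap_masks` is a hypothetical helper. The trade-off is memory: the `n × n` float64 difference matrix needs `8·n²` bytes, roughly 800 MB for 10,000 individuals.

```python
import numpy as np

def age_gap_masks(ages, min_parent_gap=15, invalid_gap=12):
    """Vectorized age-gap checks over all ordered pairs (i, j)."""
    ages = np.asarray(ages, dtype=float)
    # diff[i, j] = age of i minus age of j, built by broadcasting (n, 1) - (1, n)
    diff = ages[:, None] - ages[None, :]
    # i could be the parent of j if i is at least min_parent_gap years older
    possible_parent = diff >= min_parent_gap
    # Parent-child is ruled out in both directions when the gap is too small
    parent_child_invalid = np.abs(diff) < invalid_gap
    np.fill_diagonal(parent_child_invalid, False)  # ignore self-pairs
    return possible_parent, parent_child_invalid

ids = np.array([101, 102, 103, 104])
ages = [70, 45, 44, 20]
possible_parent, invalid = age_gap_masks(ages)
# Row-major (parent_id, child_id) pairs, like constraints['parent_child']
parent_pairs = [(int(ids[i]), int(ids[j])) for i, j in np.argwhere(possible_parent)]
print(parent_pairs)
```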

In [ ]:
# Demonstrate demographic constraints
def generate_sample_demographic_info(num_individuals=30):
    """
    Generate sample demographic information for individuals.
    
    Args:
        num_individuals: Number of individuals to generate info for
        
    Returns:
        id_to_info: Dict mapping IDs to demographic information
    """
    id_to_info = {}
    
    # Define age ranges for different generations
    gen1_range = (60, 80)  # Grandparent generation
    gen2_range = (35, 55)  # Parent generation
    gen3_range = (10, 30)  # Child generation
    
    # Assign generations to individuals
    gen1_count = num_individuals // 5
    gen2_count = num_individuals // 3
    gen3_count = num_individuals - gen1_count - gen2_count
    
    generations = ([1] * gen1_count) + ([2] * gen2_count) + ([3] * gen3_count)
    random.shuffle(generations)
    
    for i in range(1, num_individuals + 1):
        gen = generations[i-1]
        
        # Assign age based on generation
        if gen == 1:
            age = random.randint(*gen1_range)
        elif gen == 2:
            age = random.randint(*gen2_range)
        else:
            age = random.randint(*gen3_range)
        
        # Assign sex
        sex = random.choice(['M', 'F'])
        
        # Store demographic information
        id_to_info[i] = {
            'id': i,
            'age': age,
            'sex': sex,
            'generation': gen
        }
    
    return id_to_info

# Generate demographic information
id_to_info = generate_sample_demographic_info(30)

# Print summary of demographic information
print("Sample demographic information:")
generations = defaultdict(list)
for id_val, info in id_to_info.items():
    generations[info['generation']].append(id_val)

for gen, ids in sorted(generations.items()):
    print(f"Generation {gen} ({len(ids)} individuals): {sorted(ids)}")
    
# Establish demographic constraints
constraints = establish_demographic_constraints(id_to_info)

# Print summary of constraints
print("\nDemographic constraints:")
print(f"Possible parent-child relationships: {len(constraints['parent_child'])}")
print(f"Possible sibling relationships: {len(constraints['sibling'])}")
print(f"Invalid relationships: {len(constraints['invalid'])}")

# Apply demographic constraints to clusters
refined_clusters = apply_demographic_constraints(clusters, id_to_info)

# Print summary of refined clusters
print(f"\nAfter applying demographic constraints: {len(refined_clusters)} clusters")
for i, cluster in enumerate(refined_clusters):
    print(f"Refined Cluster {i+1}: {len(cluster)} individuals - {sorted(cluster)}")

### 1.3 Combining Search Space Pruning Techniques

Now let's combine these pruning techniques to create a simplified version of Bonsai v3's search space pruning:

In [ ]:
def prune_search_space_simplified(id_to_shared_ibd, id_to_info):
    """
    Simplified implementation of Bonsai v3's search space pruning.
    
    Args:
        id_to_shared_ibd: Dict mapping ID pairs to their IBD segments
        id_to_info: Dict with demographic information for individuals
        
    Returns:
        pruned_space: Dict containing the pruned search space
    """
    pruned_space = {}
    
    # Step 1: Build connectivity graph
    connectivity_graph = build_connectivity_graph(id_to_shared_ibd, min_ibd_threshold=100)
    pruned_space['connectivity_graph'] = connectivity_graph
    
    # Step 2: Identify initial clusters
    initial_clusters = identify_related_clusters(connectivity_graph)
    pruned_space['initial_clusters'] = initial_clusters
    
    # Step 3: Apply demographic constraints
    refined_clusters = apply_demographic_constraints(initial_clusters, id_to_info)
    pruned_space['refined_clusters'] = refined_clusters
    
    # Step 4: Establish search order for clusters
    # (Prioritize larger clusters and those with more IBD connections)
    search_order = []
    
    # Calculate total IBD within each cluster
    cluster_scores = []
    for i, cluster in enumerate(refined_clusters):
        # Calculate total IBD within cluster
        total_ibd = 0
        cluster_list = list(cluster)
        for j in range(len(cluster_list)):
            for k in range(j + 1, len(cluster_list)):
                id1, id2 = cluster_list[j], cluster_list[k]
                pair = (min(id1, id2), max(id1, id2))
                if pair in id_to_shared_ibd:
                    segments = id_to_shared_ibd[pair]
                    total_ibd += sum(seg.get('length_cm', 0) for seg in segments)
        
        # Score is a combination of size and IBD density
        score = len(cluster) * 100 + total_ibd
        cluster_scores.append((i, score))
    
    # Sort by score (descending)
    cluster_scores.sort(key=lambda x: x[1], reverse=True)
    search_order = [refined_clusters[i] for i, _ in cluster_scores]
    pruned_space['search_order'] = search_order
    
    # Step 5: Generate simplified search partitions
    # (Each partition represents a subset of the search space to explore independently)
    partitions = []
    for cluster in search_order:
        # Create a partition for this cluster
        partition = {
            'individuals': list(cluster),
            'constraints': establish_demographic_constraints({id_val: id_to_info[id_val] for id_val in cluster if id_val in id_to_info})
        }
        partitions.append(partition)
    
    pruned_space['partitions'] = partitions
    
    return pruned_space

# Demonstrate combined pruning
pruned_space = prune_search_space_simplified(id_to_shared_ibd, id_to_info)

# Print summary of pruned search space
print("Pruned Search Space Summary:")
print(f"Initial clusters: {len(pruned_space['initial_clusters'])}")
print(f"Refined clusters after demographic constraints: {len(pruned_space['refined_clusters'])}")
print(f"Search partitions: {len(pruned_space['partitions'])}")

# Print search order
print("\nSearch Order:")
for i, cluster in enumerate(pruned_space['search_order']):
    print(f"{i+1}. Cluster with {len(cluster)} individuals: {sorted(cluster)}")

# Calculate search space reduction
total_individuals = len(set().union(*pruned_space['initial_clusters']))
avg_partition_size = sum(len(partition['individuals']) for partition in pruned_space['partitions']) / len(pruned_space['partitions'])

# Naive search space size: a crude proxy that grows as 2^n in the number of individuals
naive_space_size = 2 ** total_individuals

# Pruned search space size: the same proxy summed over independent partitions
pruned_space_size = sum(2 ** len(partition['individuals']) for partition in pruned_space['partitions'])

# Calculate reduction
reduction_factor = naive_space_size / pruned_space_size

print("\nSearch Space Reduction:")
print(f"Total individuals: {total_individuals}")
print(f"Average partition size: {avg_partition_size:.1f}")
print(f"Naive search space size: 2^{total_individuals} ≈ {naive_space_size:.2e}")
print(f"Pruned search space size: ≈ {pruned_space_size:.2e}")
print(f"Reduction factor: ≈ {reduction_factor:.2e}x")
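Step 4 of `prune_search_space_simplified` scores each cluster by visiting every pair inside it, which costs O(c²) dictionary lookups for a cluster of size c even when most pairs share no IBD. An alternative is to walk the IBD dictionary once and attribute each pair to its cluster. This sketch (`score_clusters_one_pass` is a hypothetical name) computes the same score, `100 · size + within-cluster cM`, assuming the clusters do not overlap and IBD keys are stored as `(min_id, max_id)` as in this notebook.

```python
from collections import defaultdict

def score_clusters_one_pass(clusters, id_to_shared_ibd, size_weight=100):
    """Rank clusters by size_weight * size + total within-cluster IBD (cM)."""
    # Map each individual to the index of the (non-overlapping) cluster it belongs to
    id_to_cluster = {}
    for idx, cluster in enumerate(clusters):
        for id_val in cluster:
            id_to_cluster[id_val] = idx

    # Single pass over the IBD pairs; cross-cluster pairs are skipped
    within_ibd = defaultdict(float)
    for (id1, id2), segments in id_to_shared_ibd.items():
        c1 = id_to_cluster.get(id1)
        if c1 is not None and c1 == id_to_cluster.get(id2):
            within_ibd[c1] += sum(seg.get('length_cm', 0) for seg in segments)

    scored = [(cluster, len(cluster) * size_weight + within_ibd[idx])
              for idx, cluster in enumerate(clusters)]
    # Highest score first, like the search order built in Step 4
    scored.sort(key=lambda item: item[1], reverse=True)
    return scored

example_clusters = [{3, 4, 5}, {1, 2}]
example_ibd = {(1, 2): [{'length_cm': 500}], (3, 4): [{'length_cm': 100}],
               (2, 3): [{'length_cm': 60}]}
print(score_clusters_one_pass(example_clusters, example_ibd))
```

The cost becomes proportional to the number of IBD pairs plus the number of individuals, rather than to the sum of squared cluster sizes.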

## Part 2: Parallel Processing

Another key optimization technique in Bonsai v3 is parallel processing, which leverages multiple CPU cores to speed up computation. This is particularly important for operations that can be naturally parallelized, such as evaluating multiple pedigree configurations or processing different chromosomes independently.
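As an illustration of the per-chromosome case, the sketch below groups IBD segments by chromosome and processes each group as its own task. `total_cm_for_chromosome` is a hypothetical stand-in for real per-chromosome work. `Executor.map` returns results in input order, so they can be zipped back onto the chromosome list safely. Threads are used here for simplicity; for CPU-bound pure-Python work, CPython's global interpreter lock means a `ProcessPoolExecutor` is usually needed for a real speedup.

```python
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor

def total_cm_for_chromosome(segments):
    """Hypothetical per-chromosome task; here it just sums segment lengths."""
    return sum(seg['length_cm'] for seg in segments)

def per_chromosome_totals(segments, max_workers=4):
    """Group segments by chromosome and process each group in parallel."""
    by_chrom = defaultdict(list)
    for seg in segments:
        by_chrom[seg['chrom']].append(seg)
    chroms = sorted(by_chrom)
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # map() yields results in the same order as its input
        totals = pool.map(total_cm_for_chromosome, [by_chrom[c] for c in chroms])
        return dict(zip(chroms, totals))

segments = [{'chrom': 1, 'length_cm': 10.0}, {'chrom': 2, 'length_cm': 5.0},
            {'chrom': 1, 'length_cm': 2.5}]
print(per_chromosome_totals(segments))
```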

In [ ]:
class ParallelExecutor:
    """
    Simplified implementation of Bonsai v3's parallel processing.
    """
    
    def __init__(self, max_workers=None, use_processes=False):
        """
        Initialize the parallel executor.
        
        Args:
            max_workers: Maximum number of worker threads/processes
            use_processes: Whether to use processes instead of threads
        """
        # os.cpu_count() can return None, so fall back to 1 in that case
        self.max_workers = max_workers or min(32, (os.cpu_count() or 1) + 4)
        self.use_processes = use_processes
    
    def execute(self, tasks):
        """
        Execute a list of tasks in parallel.
        
        Args:
            tasks: List of (function, args, kwargs) tuples to execute
            
        Returns:
            List of results in the same order as `tasks` (None for failed tasks)
        """
        # Note: with use_processes=True, functions and arguments must be picklable, and
        # functions defined inside a notebook may not be importable by worker processes
        # started with "spawn" (the default on Windows and macOS).
        executor_cls = concurrent.futures.ProcessPoolExecutor if self.use_processes else concurrent.futures.ThreadPoolExecutor
        
        with executor_cls(max_workers=self.max_workers) as executor:
            # Submit all tasks
            futures = [executor.submit(func, *args, **kwargs) for func, args, kwargs in tasks]
            
            # Collect results in submission order so results[i] corresponds to tasks[i]
            # (as_completed would yield them in completion order instead)
            results = []
            for future in futures:
                try:
                    results.append(future.result())
                except Exception as e:
                    print(f"Task failed with error: {e}")
                    results.append(None)
            
            return results

def simulate_expensive_calculation(partition, delay=0.1):
    """
    Simulate an expensive calculation on a partition.
    
    Args:
        partition: A partition from the pruned search space; an optional
            'delay' key overrides the default delay
        delay: Default simulated computation time in seconds
        
    Returns:
        result: Simulated calculation result
    """
    # Honor a per-partition delay if one has been set
    delay = partition.get('delay', delay)
    
    # Simulate computation time. time.sleep releases the GIL, so threads can overlap here;
    # CPU-bound pure-Python work would not speed up this way with threads.
    time.sleep(delay)
    
    # Simulated calculation: for each pair of individuals in the partition,
    # count the relationship configurations allowed by the constraints
    individuals = partition['individuals']
    constraints = partition['constraints']
    
    # Build the set of constrained pairs once (unordered) for O(1) lookups
    invalid_pairs = {frozenset((i1, i2)) for _, i1, i2 in constraints.get('invalid', [])}
    
    # Calculate number of possible relationship configurations
    n = len(individuals)
    possible_configs = 0
    
    for i in range(n):
        for j in range(i+1, n):
            if frozenset((individuals[i], individuals[j])) not in invalid_pairs:
                # Assume 4 possible relationship types for each valid pair
                possible_configs += 4
    
    return {
        'partition_size': n,
        'possible_configs': possible_configs,
        'total_pairs': n * (n - 1) // 2,
        'computation_time': delay
    }
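When results are matched back to partitions by position, their order matters. `concurrent.futures.as_completed` yields futures as they finish, whereas iterating the list of futures in submission order (or calling `Executor.map`) keeps results aligned with the inputs. This small, self-contained demonstration uses a hypothetical `sleepy_identity` task with different sleep times:

```python
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

def sleepy_identity(x, delay):
    """Hypothetical task: wait `delay` seconds, then return x unchanged."""
    time.sleep(delay)
    return x

tasks = [(0, 0.3), (1, 0.1), (2, 0.2)]
with ThreadPoolExecutor(max_workers=3) as pool:
    futures = [pool.submit(sleepy_identity, x, delay) for x, delay in tasks]
    # Order in which the tasks happened to finish (timing-dependent)
    completion_order = [f.result() for f in as_completed(futures)]
    # Order in which the tasks were submitted (always matches `tasks`)
    submission_order = [f.result() for f in futures]

print("as_completed order:", completion_order)
print("submission order:  ", submission_order)
```

With these delays the first list will usually come out as `[1, 2, 0]`, but only the second ordering is guaranteed.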

In [ ]:
# Demonstrate parallel processing
def process_all_partitions_sequential(partitions):
    """
    Process all partitions sequentially.
    
    Args:
        partitions: List of partitions to process
        
    Returns:
        results: List of results for each partition
        elapsed_time: Time taken in seconds
    """
    start_time = time.time()
    
    results = []
    for partition in partitions:
        result = simulate_expensive_calculation(partition)
        results.append(result)
    
    elapsed_time = time.time() - start_time
    
    return results, elapsed_time

def process_all_partitions_parallel(partitions, max_workers=None):
    """
    Process all partitions in parallel.
    
    Args:
        partitions: List of partitions to process
        max_workers: Maximum number of worker threads
        
    Returns:
        results: List of results for each partition
        elapsed_time: Time taken in seconds
    """
    start_time = time.time()
    
    # Create tasks
    tasks = [(simulate_expensive_calculation, (partition,), {}) for partition in partitions]
    
    # Execute tasks in parallel
    executor = ParallelExecutor(max_workers=max_workers)
    results = executor.execute(tasks)
    
    elapsed_time = time.time() - start_time
    
    return results, elapsed_time

# Get partitions from our pruned search space
partitions = pruned_space['partitions']

# Add a delay to make the speed difference more noticeable
for partition in partitions:
    partition['delay'] = 0.2 * len(partition['individuals']) / 10

# Process sequentially
print("Processing partitions sequentially...")
sequential_results, sequential_time = process_all_partitions_sequential(partitions)
print(f"Sequential processing took {sequential_time:.2f} seconds")

# Process in parallel
print("\nProcessing partitions in parallel...")
parallel_results, parallel_time = process_all_partitions_parallel(partitions)
print(f"Parallel processing took {parallel_time:.2f} seconds")

# Calculate speedup
speedup = sequential_time / parallel_time
print(f"Speedup: {speedup:.2f}x")

# Compare results
print("\nResults summary:")
sequential_configs = sum(result['possible_configs'] for result in sequential_results)
parallel_configs = sum(result['possible_configs'] for result in parallel_results)
print(f"Sequential processing found {sequential_configs} possible configurations")
print(f"Parallel processing found {parallel_configs} possible configurations")

# Plot results
plt.figure(figsize=(10, 6))
plt.title("Sequential vs. Parallel Processing Time")
plt.bar(['Sequential', 'Parallel'], [sequential_time, parallel_time], color=['blue', 'green'])
plt.ylabel("Time (seconds)")
plt.grid(axis='y', linestyle='--', alpha=0.7)
plt.tight_layout()
plt.show()

# Plot scaling with number of workers
def measure_scaling():
    """Measure how performance scales with number of workers"""
    worker_counts = [1, 2, 4, 8, 16]
    times = []
    
    for workers in worker_counts:
        _, elapsed_time = process_all_partitions_parallel(partitions, max_workers=workers)
        times.append(elapsed_time)
    
    return worker_counts, times

worker_counts, times = measure_scaling()

plt.figure(figsize=(10, 6))
plt.title("Scaling with Number of Workers")
plt.plot(worker_counts, times, 'o-', linewidth=2, markersize=8)
plt.xlabel("Number of Workers")
plt.ylabel("Time (seconds)")
plt.grid(True, linestyle='--', alpha=0.7)
plt.tight_layout()
plt.show()
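The scaling curve usually flattens as workers are added. Amdahl's law gives a best-case bound: if a fraction p of the runtime can run in parallel, the speedup on n workers is at most S(n) = 1 / ((1 - p) + p / n). Two further limits apply to this demo. Once workers outnumber partitions, the extra workers sit idle. Because each partition's delay grows with its size, the largest partition also sets a floor on wall-clock time. Threads overlap well if `simulate_expensive_calculation` spends its time sleeping, but CPU-bound pure-Python work under threads is limited by CPython's GIL. A minimal sketch, with p = 0.9 as an assumed value rather than a measurement:

```python
def amdahl_speedup(parallel_fraction, workers):
    """Amdahl's law: best-case speedup when only `parallel_fraction` of the work parallelizes."""
    if not 0.0 <= parallel_fraction <= 1.0:
        raise ValueError("parallel_fraction must be between 0 and 1")
    if workers < 1:
        raise ValueError("workers must be at least 1")
    serial_fraction = 1.0 - parallel_fraction
    return 1.0 / (serial_fraction + parallel_fraction / workers)


# Same worker counts as measure_scaling(); 0.9 is an assumed parallel fraction
for workers in [1, 2, 4, 8, 16]:
    print(f"{workers:>2} workers: at most {amdahl_speedup(0.9, workers):.2f}x")
```

With 90% parallel work, 16 workers give at most 6.4x, and even unlimited workers cannot exceed 1 / (1 - p) = 10x.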

## Part 3: Adaptive Parameter Selection

Bonsai v3 uses adaptive parameter selection to optimize performance based on dataset characteristics. Rather than using fixed parameters for all situations, Bonsai dynamically adjusts its parameters based on factors like dataset size, IBD density, and computational resources available.

In [ ]:
def optimize_parameters(id_to_shared_ibd, id_to_info, available_memory=8 * 1024 * 1024 * 1024, available_cores=None):
    """
    Dynamically optimize algorithm parameters based on dataset characteristics.
    
    Args:
        id_to_shared_ibd: Dict mapping ID pairs to their IBD segments
        id_to_info: Dict with demographic information for individuals
        available_memory: Available memory in bytes
        available_cores: Available CPU cores
        
    Returns:
        optimized_params: Dict of optimized parameter values
    """
    # Use all available cores if not specified (os.cpu_count() may return None)
    available_cores = available_cores or os.cpu_count() or 1
    
    # Get dataset characteristics
    num_individuals = len({ind_id for pair in id_to_shared_ibd for ind_id in pair})
    
    # Calculate IBD density and average length
    total_pairs = num_individuals * (num_individuals - 1) // 2
    actual_pairs = len(id_to_shared_ibd)
    ibd_density = actual_pairs / total_pairs if total_pairs > 0 else 0
    
    total_length = 0
    total_segments = 0
    for segments in id_to_shared_ibd.values():
        for segment in segments:
            total_length += segment.get('length_cm', 0)
            total_segments += 1
    
    avg_ibd_length = total_length / total_segments if total_segments > 0 else 0
    
    # Initialize parameters with default values
    params = {
        'max_up': 3,             # Maximum generations to extend upward
        'n_keep': 5,             # Number of top pedigrees to keep
        'ibd_threshold': 20,     # Minimum IBD amount to consider (cM)
        'max_iterations': 100,   # Maximum iterations for optimization
        'batch_size': 10,        # Batch size for parallel processing
        'use_threading': True,   # Whether to use threading
    }
    
    # Adjust max_up based on IBD density and length
    if ibd_density > 0.5 and avg_ibd_length > 1000:
        # Dense IBD with long segments - likely close relatives
        params['max_up'] = 2
    elif ibd_density < 0.1 or avg_ibd_length < 100:
        # Sparse IBD with short segments - likely distant relatives
        params['max_up'] = 4
    
    # Adjust n_keep based on available memory
    memory_gb = available_memory / (1024 * 1024 * 1024)
    if memory_gb < 4:  # Less than 4GB
        params['n_keep'] = 3
    elif memory_gb > 16:  # More than 16GB
        params['n_keep'] = 10
    
    # Adjust batch_size based on available cores
    params['batch_size'] = min(max(available_cores, 2), 32)
    
    # Adjust ibd_threshold based on dataset size
    if num_individuals > 100:
        params['ibd_threshold'] = 30
    elif num_individuals < 20:
        params['ibd_threshold'] = 10
    
    # Decide whether to use threading or processes
    if memory_gb > 8:  # More than 8GB
        params['use_threading'] = False  # Use processes for better parallelism
    
    return params
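`optimize_parameters` assumes 8 GB when `available_memory` is not supplied. A best-effort way to detect installed memory uses POSIX `os.sysconf`. The helper below is hypothetical, and on platforms without `sysconf` (such as Windows) it returns the fallback. It reports total physical RAM, not memory that is currently free, so it overestimates what the process can safely use.

```python
import os


def detect_total_memory_bytes(fallback=8 * 1024 ** 3):
    """Best-effort total physical memory in bytes via POSIX sysconf (hypothetical helper).

    Returns `fallback` where sysconf or these names are unavailable.
    Note: this is installed RAM, not currently free memory.
    """
    try:
        page_size = os.sysconf('SC_PAGE_SIZE')
        num_pages = os.sysconf('SC_PHYS_PAGES')
    except (AttributeError, ValueError, OSError):
        return fallback
    if page_size <= 0 or num_pages <= 0:
        return fallback
    return page_size * num_pages


total = detect_total_memory_bytes()
print(f"Detected memory: {total / 1024 ** 3:.1f} GB")
print("Cores:", os.cpu_count() or 1)
```

The result could be passed to `optimize_parameters` as `available_memory`. To get free rather than total memory, the third-party `psutil` package (`psutil.virtual_memory().available`) is a common choice.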

In [ ]:
# Demonstrate adaptive parameter selection
def generate_random_ibd_data(num_individuals, ibd_density, avg_segment_length):
    """
    Generate random IBD data with given characteristics.
    
    Args:
        num_individuals: Number of individuals
        ibd_density: Density of IBD connections (0-1)
        avg_segment_length: Average segment length in cM
        
    Returns:
        id_to_shared_ibd: Dict mapping ID pairs to their IBD segments
    """
    id_to_shared_ibd = {}
    
    # Generate IDs
    ids = list(range(1, num_individuals + 1))
    
    # Calculate total possible pairs
    total_pairs = num_individuals * (num_individuals - 1) // 2
    
    # Determine number of pairs that share IBD
    num_ibd_pairs = int(total_pairs * ibd_density)
    
    # Randomly select pairs to share IBD
    all_pairs = [(i, j) for i in ids for j in ids if i < j]
    ibd_pairs = random.sample(all_pairs, min(num_ibd_pairs, len(all_pairs)))
    
    # Generate random segments for each pair
    for id1, id2 in ibd_pairs:
        # Number of segments (random between 1-5)
        num_segments = random.randint(1, 5)
        
        # Generate segments
        segments = []
        for _ in range(num_segments):
            # Random segment length with variation around avg_segment_length
            length = max(1, random.gauss(avg_segment_length, avg_segment_length / 4))
            
            segment = {
                'chrom': random.randint(1, 22),
                'start_cm': random.uniform(0, 200),
                'end_cm': 0,  # Will be set based on length
                'length_cm': length
            }
            segment['end_cm'] = segment['start_cm'] + length
            
            segments.append(segment)
        
        # Store segments
        id_to_shared_ibd[(id1, id2)] = segments
    
    return id_to_shared_ibd

# Generate different datasets with varying characteristics
small_close_dataset = generate_random_ibd_data(
    num_individuals=15,
    ibd_density=0.7,
    avg_segment_length=1500
)

medium_mixed_dataset = generate_random_ibd_data(
    num_individuals=50,
    ibd_density=0.3,
    avg_segment_length=500
)

large_distant_dataset = generate_random_ibd_data(
    num_individuals=200,
    ibd_density=0.05,
    avg_segment_length=100
)

# Generate random demographic info for each dataset
def generate_random_demographic_info(ids):
    """Generate random demographic info for a list of IDs"""
    id_to_info = {}
    for id_val in ids:
        id_to_info[id_val] = {
            'id': id_val,
            'age': random.randint(10, 80),
            'sex': random.choice(['M', 'F'])
        }
    return id_to_info

small_demo = generate_random_demographic_info(range(1, 16))
medium_demo = generate_random_demographic_info(range(1, 51))
large_demo = generate_random_demographic_info(range(1, 201))

# Optimize parameters for each dataset
small_params = optimize_parameters(small_close_dataset, small_demo)
medium_params = optimize_parameters(medium_mixed_dataset, medium_demo)
large_params = optimize_parameters(large_distant_dataset, large_demo)

# Print optimized parameters
print("Parameters for small dataset with close relatives:")
for param, value in small_params.items():
    print(f"  {param}: {value}")

print("\nParameters for medium dataset with mixed relationships:")
for param, value in medium_params.items():
    print(f"  {param}: {value}")

print("\nParameters for large dataset with distant relatives:")
for param, value in large_params.items():
    print(f"  {param}: {value}")

# Visualize parameter adaptations
params_to_plot = ['max_up', 'n_keep', 'ibd_threshold', 'batch_size']
dataset_labels = ['Small, Close', 'Medium, Mixed', 'Large, Distant']
param_values = [
    [small_params[p] for p in params_to_plot],
    [medium_params[p] for p in params_to_plot],
    [large_params[p] for p in params_to_plot]
]

# Create a grouped bar chart
fig, ax = plt.subplots(figsize=(12, 6))
x = np.arange(len(params_to_plot))
width = 0.25

for i, (values, label) in enumerate(zip(param_values, dataset_labels)):
    ax.bar(x + i * width, values, width, label=label)

ax.set_title('Adaptive Parameter Selection')
ax.set_ylabel('Parameter Value')
ax.set_xticks(x + width)
ax.set_xticklabels(params_to_plot)
ax.legend()
plt.grid(axis='y', linestyle='--', alpha=0.7)
plt.tight_layout()
plt.show()

## Part 4: Specialized Data Structures

Bonsai v3 uses specialized data structures optimized for the specific requirements of pedigree reconstruction. These data structures are designed to minimize memory usage and computational overhead while still providing efficient access to the information needed for reconstruction.

In [ ]:
class CompactIBDStore:
    """
    Memory-efficient storage for IBD segment data.
    
    This class provides a compact representation of IBD segments,
    optimized for the specific access patterns used in Bonsai v3.
    """
    
    def __init__(self, id_to_shared_ibd):
        """
        Initialize the compact IBD store from a standard IBD representation.
        
        Args:
            id_to_shared_ibd: Dict mapping ID pairs to their IBD segments
        """
        # Convert to more efficient representation
        self.pairs = []
        self.segments = []
        self.pair_to_segments = {}
        self.pair_to_idx = {}
        
        for (id1, id2), segs in id_to_shared_ibd.items():
            # Normalize to (smaller, larger) so lookups in either order succeed
            pair = (min(id1, id2), max(id1, id2))
            pair_idx = len(self.pairs)
            self.pairs.append(pair)
            self.pair_to_idx[pair] = pair_idx
            
            # Store segment indices
            seg_indices = []
            for seg in segs:
                seg_idx = len(self.segments)
                # Store only essential fields
                compact_seg = {
                    'chrom': seg.get('chrom', 0),
                    'start_cm': seg.get('start_cm', 0),
                    'end_cm': seg.get('end_cm', 0),
                    'length_cm': seg.get('length_cm', 0)
                }
                self.segments.append(compact_seg)
                seg_indices.append(seg_idx)
            
            self.pair_to_segments[pair_idx] = seg_indices
    
    def get_shared_segments(self, id1, id2):
        """
        Get the IBD segments shared by two individuals.
        
        Args:
            id1, id2: IDs of the individuals
            
        Returns:
            List of shared IBD segments
        """
        pair = (min(id1, id2), max(id1, id2))
        pair_idx = self.pair_to_idx.get(pair)
        
        if pair_idx is None:
            return []
            
        seg_indices = self.pair_to_segments.get(pair_idx, [])
        return [self.segments[idx] for idx in seg_indices]
    
    def get_total_ibd(self, id1, id2):
        """
        Get the total IBD shared by two individuals.
        
        Args:
            id1, id2: IDs of the individuals
            
        Returns:
            Total shared IBD in centimorgans
        """
        segments = self.get_shared_segments(id1, id2)
        return sum(seg['length_cm'] for seg in segments)
    
    def memory_usage(self):
        """
        Estimate memory usage of this data structure.
        
        Returns:
            Estimated memory usage in bytes
        """
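        # Idealized estimate: the bytes a packed, fixed-width layout would need.
        # The Python objects actually held here (tuples, dicts, lists) use considerably more.
        pairs_size = len(self.pairs) * 16  # Two 8-byte integers per pair
        segments_size = len(self.segments) * 32  # Four 8-byte numbers per segment
        index_size = len(self.pair_to_segments) * 24  # Rough per-entry index overhead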
        
        return pairs_size + segments_size + index_size

class SparseRelationshipMatrix:
    """
    Efficient representation of pairwise relationships.
    
    This class stores only the pairs whose relationship has been set,
    keyed by (smaller ID, larger ID), i.e. a dictionary-of-keys sparse layout.
    """
    
    def __init__(self):
        """
        Initialize an empty sparse relationship matrix.
        """
        self.relationships = {}  # (id1, id2) -> relationship tuple
    
    def set_relationship(self, id1, id2, relationship):
        """
        Set the relationship between two individuals.
        
        Args:
            id1, id2: IDs of the individuals
            relationship: Tuple of (up, down, num_ancs)
        """
        # Ensure id1 < id2 for consistent keys
        if id1 > id2:
            id1, id2 = id2, id1
            # Swap up and down for symmetric relationship representation
            up, down, num_ancs = relationship
            relationship = (down, up, num_ancs)
        
        self.relationships[(id1, id2)] = relationship
    
    def get_relationship(self, id1, id2):
        """
        Get the relationship between two individuals.
        
        Args:
            id1, id2: IDs of the individuals
            
        Returns:
            Relationship tuple (up, down, num_ancs) or None if not set
        """
        # Ensure id1 < id2 for consistent keys
        if id1 > id2:
            id1, id2 = id2, id1
            swap = True
        else:
            swap = False
        
        relationship = self.relationships.get((id1, id2))
        
        if relationship is not None and swap:
            # Swap up and down for symmetric relationship representation
            up, down, num_ancs = relationship
            return (down, up, num_ancs)
        
        return relationship
    
    def get_all_relationships_for(self, id_val):
        """
        Get all relationships involving a specific individual.
        
        Args:
            id_val: ID of the individual
            
        Returns:
            Dict mapping other IDs to relationship tuples
        """
        result = {}
        
        # Scan every stored pair; id_val may be the first or second ID in the key
        for (id1, id2), relationship in self.relationships.items():
            if id1 == id_val:
                result[id2] = relationship
            elif id2 == id_val:
                # Swap up and down for symmetric relationship representation
                up, down, num_ancs = relationship
                result[id1] = (down, up, num_ancs)
        
        return result
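`CompactIBDStore` still holds one Python dict per segment, so its `memory_usage()` reports what a packed layout would need rather than what the object actually occupies. The sketch below shows a layout where such an estimate becomes the real footprint. `PackedIBDStore` is hypothetical, not a Bonsai v3 class. It keeps each segment field in a NumPy array, locates a pair's segments through CSR-style offsets, and finds pairs by binary search (`np.searchsorted`) over sorted int64 keys. It assumes non-negative integer IDs below 2**31 and integer chromosome numbers.

```python
import numpy as np


class PackedIBDStore:
    """Hypothetical array-backed IBD store (illustrative, not a Bonsai v3 class).

    Assumes non-negative integer IDs below 2**31 and integer chromosome numbers.
    """

    def __init__(self, id_to_shared_ibd):
        # Pack each normalized pair (lo, hi) into one int64 key: lo in the high 32 bits
        keyed = {}
        for (a, b), segs in id_to_shared_ibd.items():
            lo, hi = min(a, b), max(a, b)
            keyed[(lo << 32) | hi] = segs
        keys = sorted(keyed)  # sorted keys allow binary search with np.searchsorted

        offsets, chrom, start, end, length, totals = [0], [], [], [], [], []
        for key in keys:
            segs = keyed[key]
            for seg in segs:
                chrom.append(seg.get('chrom', 0))
                start.append(seg.get('start_cm', 0.0))
                end.append(seg.get('end_cm', 0.0))
                length.append(seg.get('length_cm', 0.0))
            offsets.append(len(length))  # pair i's segments live in [offsets[i], offsets[i+1])
            totals.append(sum(seg.get('length_cm', 0.0) for seg in segs))

        self.keys = np.array(keys, dtype=np.int64)
        self.offsets = np.array(offsets, dtype=np.int64)
        self.chrom = np.array(chrom, dtype=np.int8)
        self.start_cm = np.array(start, dtype=np.float32)
        self.end_cm = np.array(end, dtype=np.float32)
        self.length_cm = np.array(length, dtype=np.float32)
        self.total_cm = np.array(totals, dtype=np.float64)  # precomputed per-pair totals

    def _find(self, id1, id2):
        key = (min(id1, id2) << 32) | max(id1, id2)
        i = int(np.searchsorted(self.keys, key))
        if i < len(self.keys) and self.keys[i] == key:
            return i
        return None

    def get_total_ibd(self, id1, id2):
        i = self._find(id1, id2)
        return 0.0 if i is None else float(self.total_cm[i])

    def segment_slice(self, id1, id2):
        """Slice into the per-segment arrays for this pair (empty if the pair is absent)."""
        i = self._find(id1, id2)
        if i is None:
            return slice(0, 0)
        return slice(int(self.offsets[i]), int(self.offsets[i + 1]))

    def nbytes(self):
        arrays = (self.keys, self.offsets, self.chrom, self.start_cm,
                  self.end_cm, self.length_cm, self.total_cm)
        return sum(arr.nbytes for arr in arrays)


store = PackedIBDStore({
    (2, 1): [{'chrom': 1, 'start_cm': 0.0, 'end_cm': 50.0, 'length_cm': 50.0}],
    (3, 4): [{'chrom': 2, 'start_cm': 10.0, 'end_cm': 40.0, 'length_cm': 30.0},
             {'chrom': 5, 'start_cm': 0.0, 'end_cm': 20.0, 'length_cm': 20.0}],
})
print(store.get_total_ibd(1, 2), store.get_total_ibd(4, 3), store.get_total_ibd(1, 3))  # 50.0 50.0 0.0
print(store.nbytes())  # 95
```

This design has trade-offs. The store is read-only, so adding a pair means rebuilding the arrays. The float32 fields also trade some precision for half the memory of float64.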

In [ ]:
# Demonstrate specialized data structures
def compare_ibd_storage_performance(id_to_shared_ibd):
    """
    Compare performance of different IBD storage methods.
    
    Args:
        id_to_shared_ibd: Dict mapping ID pairs to their IBD segments
        
    Returns:
        comparison: Dict containing performance metrics
    """
    # Build the query set: every stored pair plus up to 100 random pairs (mostly absent)
    queries = []
    
    # Create a list of all pairs to query
    pairs = list(id_to_shared_ibd.keys())
    
    # Add some pairs that don't exist
    max_id = max([max(pair) for pair in pairs])
    for _ in range(100):
        id1 = random.randint(1, max_id)
        id2 = random.randint(1, max_id)
        if id1 != id2:
            queries.append((min(id1, id2), max(id1, id2)))
    
    # Add all existing pairs
    queries.extend(pairs)
    random.shuffle(queries)
    
    # Measure standard dict performance
    standard_start = time.time()
    standard_results = []
    
    for id1, id2 in queries:
        pair = (id1, id2)
        if pair in id_to_shared_ibd:
            segments = id_to_shared_ibd[pair]
            total_cm = sum(seg.get('length_cm', 0) for seg in segments)
            standard_results.append(total_cm)
        else:
            standard_results.append(0)
    
    standard_time = time.time() - standard_start
    
    # Create and measure CompactIBDStore
    compact_store = CompactIBDStore(id_to_shared_ibd)
    
    compact_start = time.time()
    compact_results = []
    
    for id1, id2 in queries:
        total_cm = compact_store.get_total_ibd(id1, id2)
        compact_results.append(total_cm)
    
    compact_time = time.time() - compact_start
    
    # Compare memory usage (rough estimate)
    standard_size = sys.getsizeof(id_to_shared_ibd)
    for pair, segments in id_to_shared_ibd.items():
        standard_size += sys.getsizeof(pair)
        standard_size += sys.getsizeof(segments)
        for segment in segments:
            standard_size += sys.getsizeof(segment)
    
    compact_size = compact_store.memory_usage()
    
    # Verify results are the same
    results_match = standard_results == compact_results
    
    return {
        'standard_time': standard_time,
        'compact_time': compact_time,
        'standard_size': standard_size,
        'compact_size': compact_size,
        'speedup': standard_time / compact_time if compact_time > 0 else float('inf'),
        'memory_reduction': standard_size / compact_size if compact_size > 0 else float('inf'),
        'results_match': results_match
    }

# Create a large test dataset
large_test_dataset = generate_random_ibd_data(
    num_individuals=1000,
    ibd_density=0.01,
    avg_segment_length=200
)

# Compare standard and compact storage
print("Comparing IBD storage methods...")
comparison = compare_ibd_storage_performance(large_test_dataset)

print("Standard Dict:")
print(f"  Lookup time: {comparison['standard_time']:.6f} seconds")
print(f"  Estimated memory: {comparison['standard_size'] / (1024*1024):.2f} MB")

print("\nCompactIBDStore:")
print(f"  Lookup time: {comparison['compact_time']:.6f} seconds")
print(f"  Estimated memory: {comparison['compact_size'] / (1024*1024):.2f} MB")

print("\nPerformance Comparison:")
print(f"  Speedup: {comparison['speedup']:.2f}x")
print(f"  Memory reduction: {comparison['memory_reduction']:.2f}x")
print(f"  Results match: {comparison['results_match']}")

# Demonstrate SparseRelationshipMatrix
def test_sparse_relationship_matrix():
    """Test the SparseRelationshipMatrix class"""
    matrix = SparseRelationshipMatrix()
    
    # Add some relationships
    matrix.set_relationship(1, 2, (0, 1, 1))  # 1 is parent of 2
    matrix.set_relationship(3, 4, (1, 1, 2))  # 3 and 4 are full siblings
    matrix.set_relationship(5, 6, (1, 1, 1))  # 5 and 6 are half siblings
    
    # Test retrieval
    print("Relationship between 1 and 2:", matrix.get_relationship(1, 2))
    print("Relationship between 2 and 1:", matrix.get_relationship(2, 1))
    print("Relationship between 3 and 4:", matrix.get_relationship(3, 4))
    print("Relationship between 5 and 6:", matrix.get_relationship(5, 6))
    
    # Test getting all relationships for an individual
    print("\nAll relationships for individual 1:", matrix.get_all_relationships_for(1))
    print("All relationships for individual 2:", matrix.get_all_relationships_for(2))

print("\nTesting SparseRelationshipMatrix:")
test_sparse_relationship_matrix()
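`get_all_relationships_for` scans every stored pair. Each call therefore costs time proportional to the total number of relationships, not to the number of that individual's relatives. If this query is frequent, one option is to keep a neighbor index alongside the pair dict. This trades extra memory, since each relationship is recorded under both individuals, for per-individual lookups. The class below is a hypothetical sketch, not a Bonsai v3 class. It keeps the same (up, down, num_ancs) convention and the same swap rule.

```python
from collections import defaultdict


class IndexedRelationshipMatrix:
    """Hypothetical variant of SparseRelationshipMatrix with a per-individual neighbor index."""

    def __init__(self):
        self.relationships = {}            # (smaller_id, larger_id) -> (up, down, num_ancs)
        self.neighbors = defaultdict(set)  # id -> IDs it has a stored relationship with

    def set_relationship(self, id1, id2, relationship):
        if id1 > id2:
            # Store under (smaller, larger) and swap up/down to match the new order
            id1, id2 = id2, id1
            up, down, num_ancs = relationship
            relationship = (down, up, num_ancs)
        self.relationships[(id1, id2)] = relationship
        self.neighbors[id1].add(id2)
        self.neighbors[id2].add(id1)

    def get_relationship(self, id1, id2):
        if id1 <= id2:
            return self.relationships.get((id1, id2))
        stored = self.relationships.get((id2, id1))
        if stored is None:
            return None
        up, down, num_ancs = stored
        return (down, up, num_ancs)

    def get_all_relationships_for(self, id_val):
        # Visits only id_val's neighbors instead of scanning every stored pair
        return {other: self.get_relationship(id_val, other)
                for other in self.neighbors.get(id_val, ())}


matrix = IndexedRelationshipMatrix()
matrix.set_relationship(1, 2, (0, 1, 1))  # 1 is parent of 2
matrix.set_relationship(3, 1, (2, 1, 1))  # from 3: up 2, down 1, one shared ancestor (half avuncular)
print(matrix.get_relationship(2, 1))       # (1, 0, 1)
print(matrix.get_all_relationships_for(1))
```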

## Part 5: Early Termination and Lazy Evaluation

Bonsai v3 uses early termination and lazy evaluation strategies to avoid unnecessary computation. These techniques allow Bonsai to quickly eliminate unpromising pedigree configurations without fully evaluating them.
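The evaluation function in this part illustrates early termination. Lazy evaluation is the related idea of computing a value only when something asks for it, and never computing it twice. The minimal sketch below uses toy assumptions; its helper names and scoring rule are hypothetical, not Bonsai v3 code. Per-pair IBD totals are computed on first request and cached with `functools.lru_cache`. Pair scores come from a generator, so once the running score drops below the threshold, the remaining pairs are never touched.

```python
import math
from functools import lru_cache


def make_cached_total_cm(id_to_shared_ibd):
    """Return total_cm(id1, id2), which sums a pair's IBD on first request and caches it."""

    @lru_cache(maxsize=None)
    def _total_for_pair(pair):
        return sum(seg.get('length_cm', 0) for seg in id_to_shared_ibd.get(pair, []))

    def total_cm(id1, id2):
        # Normalize the key so (a, b) and (b, a) share one cache entry
        return _total_for_pair((min(id1, id2), max(id1, id2)))

    total_cm.cache_info = _total_for_pair.cache_info
    return total_cm


def lazy_pair_scores(pairs, total_cm, score_fn):
    """Generator: each pair is looked up and scored only when the consumer asks for it."""
    for id1, id2 in pairs:
        yield score_fn(total_cm(id1, id2))


def score_until_threshold(pairs, total_cm, score_fn, threshold):
    """Pull scores lazily and stop as soon as the running sum drops below threshold.

    Returns (score or None, number of pairs actually scored).
    """
    running = 0.0
    scored = 0
    for pair_score in lazy_pair_scores(pairs, total_cm, score_fn):
        running += pair_score
        scored += 1
        if running < threshold:
            return None, scored
    return running, scored


def toy_score(cm):
    # Hypothetical scoring rule: reward shared IBD, penalize pairs with none
    return math.log1p(cm) if cm > 0 else -100.0


toy_ibd = {(1, 2): [{'length_cm': 50.0}]}  # only pair (1, 2) shares IBD
pairs = [(1, 2), (1, 3), (2, 3), (3, 4)]
total_cm = make_cached_total_cm(toy_ibd)

print(score_until_threshold(pairs, total_cm, toy_score, threshold=-50.0))  # (None, 2)
print(score_until_threshold(pairs, total_cm, toy_score, threshold=-500.0))
print(total_cm.cache_info())  # 2 hits from the second call, 4 misses in total
```

In the first call only two of the four pairs are ever scored. The second call reuses those two cached totals and computes only the two new ones.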

In [ ]:
def evaluate_pedigree_with_early_termination(up_dct, id_to_shared_ibd, early_term_threshold=-1000.0):
    """
    Evaluate the likelihood of a pedigree with early termination.
    
    Args:
        up_dct: Up-node dictionary representing the pedigree
        id_to_shared_ibd: Dict mapping ID pairs to their IBD segments
        early_term_threshold: Threshold for early termination
        
    Returns:
        log_likelihood: Log-likelihood of the pedigree, or None if terminated early
    """
    # Initialize log-likelihood
    log_like = 0.0
    
    # Get all pairs of individuals in the pedigree
    all_ids = list(up_dct.keys())
    all_pairs = [(all_ids[i], all_ids[j]) for i in range(len(all_ids)) for j in range(i+1, len(all_ids))]
    
    # First evaluate high-IBD pairs (more informative)
    high_ibd_pairs = []
    low_ibd_pairs = []
    
    for id1, id2 in all_pairs:
        pair = (min(id1, id2), max(id1, id2))
        if pair in id_to_shared_ibd:
            total_cm = sum(seg.get('length_cm', 0) for seg in id_to_shared_ibd[pair])
            if total_cm > 100:  # High IBD threshold
                high_ibd_pairs.append((id1, id2, total_cm))
            else:
                low_ibd_pairs.append((id1, id2, total_cm))
        else:
            low_ibd_pairs.append((id1, id2, 0))
    
    # Sort high-IBD pairs by IBD amount (descending)
    high_ibd_pairs.sort(key=lambda x: x[2], reverse=True)
    
    # Simplified relationship lookup for this demo: a pair is "direct" if either
    # individual lists the other as a parent in up_dct ({child: {parent: ...}}),
    # otherwise "distant". A full implementation would derive (up, down, num_ancs).
    def pair_relationship(id1, id2):
        if id1 in up_dct.get(id2, {}) or id2 in up_dct.get(id1, {}):
            return "direct"
        return "distant"
    
    # Evaluate high-IBD pairs first
    for num_evaluated, (id1, id2, total_cm) in enumerate(high_ibd_pairs, start=1):
        relationship = pair_relationship(id1, id2)
        
        # Calculate likelihood for this pair based on relationship and IBD
        if relationship == "direct":
            # Direct relatives should have high IBD
            if total_cm < 700:
                pair_ll = -500  # Severe penalty for direct relatives with low IBD
            else:
                pair_ll = math.log(1 + total_cm)
        else:
            # Non-direct relatives should have lower IBD
            if total_cm > 1500:
                pair_ll = -300  # Penalty for non-direct relatives with high IBD
            else:
                pair_ll = math.log(1 + total_cm / 10)
        
        # Update total likelihood
        log_like += pair_ll
        
        # Check for early termination
        if log_like < early_term_threshold:
            print(f"Early termination after {num_evaluated} of {len(high_ibd_pairs)} high-IBD pairs. "
                  f"Score: {log_like:.2f}")
            return None
    
    # Evaluate remaining pairs
    for id1, id2, total_cm in low_ibd_pairs:
        # Same scoring idea as for high-IBD pairs, with milder rewards
        relationship = pair_relationship(id1, id2)
        
        if relationship == "direct":
            pair_ll = -200  # Severe penalty for direct relatives with low IBD
        else:
            pair_ll = math.log(1 + total_cm / 50)
        
        # Update total likelihood
        log_like += pair_ll
        
        # Check for early termination
        if log_like < early_term_threshold:
            return None
    
    return log_like
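Early termination is only safe if the pairs not yet scored cannot lift the total back above the threshold. When every per-pair term is a log-probability (at most 0), the running sum can only fall, so stopping as soon as it crosses the threshold is safe. In the toy scoring above, however, terms such as `math.log(1 + total_cm)` are positive, so a running total that dips below `early_term_threshold` could still recover. One standard remedy is a bound-based cutoff, sketched here with hypothetical names: terminate only when the running score plus an optimistic bound on every remaining term is still below the threshold. For the toy scoring, `math.log(1 + total_cm)` is such a bound, since every branch scores at or below it.

```python
def evaluate_with_safe_cutoff(items, score_fn, upper_bound_fn, threshold):
    """Sum score_fn(item) over items, terminating early only when it is provably safe.

    upper_bound_fn(item) must never be smaller than score_fn(item).
    Returns (total score or None, number of items scored).
    """
    # best_remaining[i] = largest total that items[i:] could still add
    best_remaining = [0.0] * (len(items) + 1)
    for i in range(len(items) - 1, -1, -1):
        best_remaining[i] = best_remaining[i + 1] + upper_bound_fn(items[i])

    running = 0.0
    for i, item in enumerate(items):
        running += score_fn(item)
        # Stop only if even the best case for every remaining item cannot reach the threshold
        if running + best_remaining[i + 1] < threshold:
            return None, i + 1
    return running, len(items)


def naive_cutoff(items, score_fn, threshold):
    """Stops as soon as the running sum dips below threshold (unsafe with positive terms)."""
    running = 0.0
    for i, item in enumerate(items):
        running += score_fn(item)
        if running < threshold:
            return None, i + 1
    return running, len(items)


# Toy per-pair scores: an early penalty followed by a large positive term
scores = [-5.0, 10.0]
identity = lambda x: x
optimistic = lambda x: max(x, 0.0)  # hypothetical bound; never below identity(x)

print(naive_cutoff(scores, identity, threshold=0.0))                          # (None, 1)
print(evaluate_with_safe_cutoff(scores, identity, optimistic, threshold=0.0))  # (5.0, 2)
```

The naive rule discards a candidate whose final score (5.0) clears the threshold, while the bounded rule keeps it. The cost is computing the bounds up front, which only pays off when a bound is much cheaper than the exact score.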

In [ ]:
# Demonstrate early termination
def generate_pedigrees_with_varying_quality(id_to_shared_ibd, num_pedigrees=5):
    """
    Generate pedigrees with varying quality for demonstration.
    
    Args:
        id_to_shared_ibd: Dict mapping ID pairs to their IBD segments
        num_pedigrees: Number of pedigrees to generate
        
    Returns:
        pedigrees: List of pedigrees with varying quality
    """
    pedigrees = []
    
    # Get all individuals
    all_ids = set()
    for id1, id2 in id_to_shared_ibd.keys():
        all_ids.add(id1)
        all_ids.add(id2)
    
    # Create a good pedigree - preserve high-IBD relationships
    good_ped = {id_val: {} for id_val in all_ids}
    
    # Find pairs with high IBD
    high_ibd_pairs = []
    for pair, segments in id_to_shared_ibd.items():
        total_cm = sum(seg.get('length_cm', 0) for seg in segments)
        if total_cm > 1000:  # Parent-child or sibling level
            high_ibd_pairs.append((pair[0], pair[1], total_cm))
    
    # Sort by IBD amount (descending)
    high_ibd_pairs.sort(key=lambda x: x[2], reverse=True)
    
    # Add parent-child relationships for highest IBD pairs
    for id1, id2, _ in high_ibd_pairs[:len(high_ibd_pairs)//2]:
        # Make id1 parent of id2
        good_ped[id2][id1] = 1
    
    pedigrees.append(("Good", good_ped))
    
    # Create a mediocre pedigree - some relationships preserved, some not
    mediocre_ped = {id_val: {} for id_val in all_ids}
    
    # Add parent-child relationships for some high IBD pairs
    for id1, id2, _ in high_ibd_pairs[:len(high_ibd_pairs)//4]:
        # Make id1 parent of id2
        mediocre_ped[id2][id1] = 1
    
    # Add some random relationships
    for _ in range(5):
        id1, id2 = random.sample(list(all_ids), 2)
        if id1 != id2:
            mediocre_ped[id2][id1] = 1
    
    pedigrees.append(("Mediocre", mediocre_ped))
    
    # Create increasingly poor pedigrees with random relationships
    for i in range(num_pedigrees - 2):
        poor_ped = {id_val: {} for id_val in all_ids}
        
        # Add random relationships
        for _ in range(10 + i * 5):
            id1, id2 = random.sample(list(all_ids), 2)
            if id1 != id2:
                poor_ped[id2][id1] = 1
        
        pedigrees.append((f"Poor {i+1}", poor_ped))
    
    return pedigrees

# Demonstrate early termination with different pedigrees
def compare_evaluation_methods(pedigrees, id_to_shared_ibd):
    """
    Compare standard evaluation with early termination.
    
    Args:
        pedigrees: List of (name, pedigree) tuples
        id_to_shared_ibd: Dict mapping ID pairs to their IBD segments
    """
    results = []
    
    for name, pedigree in pedigrees:
        # Evaluate with standard method (no early termination)
        start_time = time.time()
        standard_ll = evaluate_pedigree_with_early_termination(pedigree, id_to_shared_ibd, float('-inf'))
        standard_time = time.time() - start_time
        
        # Evaluate with early termination
        start_time = time.time()
        early_term_ll = evaluate_pedigree_with_early_termination(pedigree, id_to_shared_ibd, -500)
        early_term_time = time.time() - start_time
        
        terminated_early = early_term_ll is None
        
        results.append({
            'name': name,
            'standard_ll': standard_ll,
            'standard_time': standard_time,
            'early_term_ll': early_term_ll,
            'early_term_time': early_term_time,
            'terminated_early': terminated_early,
            'speedup': standard_time / early_term_time if early_term_time > 0 else float('inf')
        })
    
    # Print results
    print(f"{'Pedigree':<10} | {'Standard LL':<12} | {'Standard Time':<14} | {'Early Term LL':<14} | {'Early Term Time':<15} | {'Terminated':<10} | {'Speedup':<8}")
    print("-" * 90)
    
    for result in results:
        standard_ll = f"{result['standard_ll']:.2f}" if result['standard_ll'] is not None else "None"
        early_term_ll = f"{result['early_term_ll']:.2f}" if result['early_term_ll'] is not None else "None"
        
        print(f"{result['name']:<10} | {standard_ll:<12} | {result['standard_time']:.6f} s | "
              f"{early_term_ll:<14} | {result['early_term_time']:.6f} s | "
              f"{result['terminated_early']!s:<10} | {result['speedup']:.2f}x")
    
    # Plot results
    plt.figure(figsize=(12, 6))
    plt.title("Standard vs. Early Termination Evaluation Time")
    
    names = [r['name'] for r in results]
    standard_times = [r['standard_time'] * 1000 for r in results]  # Convert to ms
    early_term_times = [r['early_term_time'] * 1000 for r in results]  # Convert to ms
    
    x = np.arange(len(names))
    width = 0.35
    
    plt.bar(x - width/2, standard_times, width, label='Standard')
    plt.bar(x + width/2, early_term_times, width, label='Early Termination')
    
    plt.xlabel('Pedigree')
    plt.ylabel('Evaluation Time (ms)')
    plt.xticks(x, names)
    plt.legend()
    
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    plt.tight_layout()
    plt.show()

# Generate test data
random.seed(42)
test_ibd_data = generate_random_ibd_data(
    num_individuals=20,
    ibd_density=0.3,
    avg_segment_length=500
)

# Generate pedigrees with varying quality
pedigrees = generate_pedigrees_with_varying_quality(test_ibd_data)

# Compare evaluation methods
compare_evaluation_methods(pedigrees, test_ibd_data)

## Summary

In this lab, we explored the sophisticated optimization techniques used in Bonsai v3 to enable efficient pedigree reconstruction at scale. These optimizations are crucial for handling real-world genetic genealogy datasets with hundreds or thousands of individuals.

Key optimization strategies we explored include:

1. **Search Space Pruning**: By clustering individuals based on IBD connectivity and applying demographic constraints, Bonsai v3 can dramatically reduce the number of pedigree configurations to evaluate. We saw how this pruning can reduce the search space by many orders of magnitude.

2. **Parallel Processing**: By leveraging multiple CPU cores, Bonsai v3 can significantly speed up computation for tasks that can be naturally parallelized, such as evaluating multiple pedigree configurations or processing different chromosomes independently.

3. **Adaptive Parameter Selection**: Instead of using fixed parameters for all datasets, Bonsai v3 dynamically adjusts its parameters based on dataset characteristics like size, IBD density, and average segment length. This adaptive approach tunes search depth, memory use, and parallelism to each dataset instead of relying on one-size-fits-all defaults.

4. **Specialized Data Structures**: Purpose-built structures, illustrated in this lab by CompactIBDStore and SparseRelationshipMatrix, store each pair once under a normalized key and keep only the fields reconstruction needs. The goal is to reduce memory overhead and keep lookups cheap as datasets grow.

5. **Early Termination and Lazy Evaluation**: By prioritizing evaluation of high-information pairs and terminating evaluation early for unpromising configurations, Bonsai v3 can avoid unnecessary computation and focus resources on the most promising pedigree candidates.

These techniques are not incidental implementation details. Together, they are what make pedigree reconstruction practical at the scale of real-world datasets.

In [ ]:
# Convert this notebook to PDF using poetry
!poetry run jupyter nbconvert --to pdf Lab18_Optimization_Techniques.ipynb

# Note: PDF conversion requires LaTeX to be installed on your system
# If you encounter errors, you may need to install it:
# On Ubuntu/Debian: sudo apt-get install texlive-xetex
# On macOS with Homebrew: brew install texlive