Bonsai: Robust Pedigree Reconstruction Exploration

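This notebook explores the Bonsai algorithm for reconstructing pedigrees from genetic data. It loads simulated IBD segments (.seg) and family metadata (.fam), splits individuals into IBD-sharing communities, runs Bonsai on one community, and then summarizes, visualizes, and saves the best reconstructed pedigree.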

In [None]:
import os
import random
import json
import pandas as pd
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import networkx.algorithms.community as nx_comm
from matplotlib.colors import to_rgba
import matplotlib.patches as mpatches
import sys
import re
from collections import defaultdict
import matplotlib.patches as mpatches

sys.path.append(os.path.dirname(os.getcwd()))

# Assuming utils.bonsaitree.bonsaitree.v3 is properly installed
from utils.bonsaitree.bonsaitree.v3 import bonsai

In [None]:
#######################
# 1. Data Preparation #
#######################

def _load_id_mapping(dict_file):
    """
    Read a tab-separated ``original_id<TAB>bonsai_id`` file into a dict.

    Lines that do not have exactly two tab-separated fields are ignored
    silently (e.g. blank lines). Lines whose Bonsai ID is not an integer are
    skipped with a warning instead of aborting the whole load, so one bad line
    can no longer leave the mapping silently truncated.

    Returns:
        Dict mapping original ID strings to integer Bonsai IDs ({} when
        dict_file is None/empty or does not exist).
    """
    individual_to_bonsai = {}
    if not dict_file or not os.path.exists(dict_file):
        return individual_to_bonsai

    print(f"Loading ID mapping from {dict_file}")
    try:
        with open(dict_file, 'r') as f:
            for line_number, line in enumerate(f, start=1):
                parts = line.strip().split('\t')
                if len(parts) != 2:
                    continue
                individual_id, bonsai_id = parts
                try:
                    individual_to_bonsai[individual_id] = int(bonsai_id)
                except ValueError:
                    print(f"Skipping malformed ID mapping on line {line_number}: {line.strip()!r}")
        print(f"Loaded {len(individual_to_bonsai)} ID mappings")
    except Exception as e:
        # Best-effort: report and continue with whatever was loaded.
        print(f"Error loading ID mapping: {e}")
    return individual_to_bonsai


def _decode_plink_sex(code):
    """Map a PLINK .fam sex code to 'M', 'F' or 'U' (1 = male, 2 = female, other = unknown)."""
    return {'1': 'M', '2': 'F'}.get(code, 'U')


def _parse_fam_file(fam_file, individual_to_bonsai):
    """
    Parse a whitespace-separated .fam file into a dict keyed by individual ID.

    Lines with fewer than 6 fields are skipped. When individual_to_bonsai is
    non-empty, individuals absent from it are skipped as well. The generation
    is parsed from a ``g<digits>-`` tag in the individual ID (None if absent).
    """
    individuals = {}
    with open(fam_file, 'r') as file:
        for line in file:
            fields = line.strip().split()
            if len(fields) < 6:
                continue

            family_id, individual_id, father_id, mother_id, sex_code = fields[:5]

            # Skip individuals not present in the ID mapping if using a mapping
            if individual_to_bonsai and individual_id not in individual_to_bonsai:
                continue

            match = re.search(r'g(\d+)-', individual_id)
            individuals[individual_id] = {
                'family_id': family_id,
                'father_id': father_id,
                'mother_id': mother_id,
                'sex': _decode_plink_sex(sex_code),
                'generation': int(match.group(1)) if match else None,
            }
    return individuals


def _print_individual_summary(individuals):
    """Print the individual count, a 3-record sample, and counts per generation."""
    print(f"Loaded metadata for {len(individuals)} individuals from FAM file")

    print("\nSample of individuals data:")
    for key in list(individuals)[:3]:
        print(f"{key}: {individuals[key]}")

    generations = {}
    for info in individuals.values():
        gen = info.get('generation')
        # `is not None` (not truthiness) so a generation numbered 0 is still counted.
        if gen is not None:
            generations[gen] = generations.get(gen, 0) + 1

    print("\nIndividuals by generation:")
    for gen, count in sorted(generations.items()):
        print(f"Generation {gen}: {count} individuals")


def load_genetic_data(seg_file, fam_file, dict_file=None):
    """
    Load and prepare genetic data from .seg, .fam, and optional dict files.

    Args:
        seg_file: Path to the tab-separated .seg file (exactly 9 columns).
        fam_file: Path to the whitespace-separated .fam file.
        dict_file: Path to the tab-separated ID mapping file (optional).

    Returns:
        seg_df: DataFrame with segment data (named columns).
        individuals: Dictionary of individual metadata keyed by original ID;
            'sex' is 'M', 'F' or 'U' (PLINK code 0/other = unknown).
        individual_to_bonsai: Mapping from original IDs to Bonsai IDs.
        Returns (None, None, None) if the .seg file has an unexpected number of
        columns or the .fam file cannot be read.
    """
    individual_to_bonsai = _load_id_mapping(dict_file)

    # Read the seg file
    seg_df = pd.read_csv(seg_file, sep="\t", header=None)
    if len(seg_df.columns) != 9:
        print(f"Warning: Unexpected number of columns in seg file: {len(seg_df.columns)}")
        print("Columns found:", seg_df.columns)
        return None, None, None
    seg_df.columns = ["sample1", "sample2", "chrom", "phys_start", "phys_end",
                      "ibd_type", "gen_start", "gen_end", "gen_seg_len"]

    unique_individuals_from_seg = set(seg_df["sample1"]).union(set(seg_df["sample2"]))
    print(f"Number of unique individuals in seg file: {len(unique_individuals_from_seg)}")

    try:
        individuals = _parse_fam_file(fam_file, individual_to_bonsai)
        _print_individual_summary(individuals)
    except Exception as e:
        print(f"Error loading FAM file: {e}")
        return None, None, None

    return seg_df, individuals, individual_to_bonsai

def create_bioinfo(individuals, individual_to_bonsai, rng=None):
    """
    Create bioinfo list for Bonsai with ages assigned based on generation.

    Ages are drawn uniformly at random. The latest generation gets 18-40.
    Each earlier generation gets 25-40 shifted up by 20 years per generation
    of distance from the latest one. Side effect: the drawn age is written
    back into ``individuals[id]['age']`` for every individual that has a
    generation, including those that are not in the ID mapping.

    Args:
        individuals: Dictionary of individual metadata (uses 'generation' and 'sex').
        individual_to_bonsai: Mapping of original IDs to Bonsai IDs; only
            mapped individuals are included in the returned list.
        rng: Optional randomness source for reproducible ages. This is either an
            int seed or an object with a ``randint(a, b)`` method (e.g.
            ``random.Random(42)``). None (default) uses the global ``random``
            module, matching the previous behaviour.

    Returns:
        bioinfo: List of ``{'genotype_id', 'age', 'sex'}`` dicts for Bonsai,
            or [] if no individual has generation information.
    """
    if rng is None:
        rand = random
    elif isinstance(rng, int):
        rand = random.Random(rng)
    else:
        rand = rng

    generations = [info.get('generation') for info in individuals.values()
                   if info.get('generation') is not None]
    if not generations:
        print("Warning: No generation information found in individuals data")
        return []

    latest_generation = max(generations)
    earliest_generation = min(generations)
    print(f"Generation range: {earliest_generation} to {latest_generation}")

    years_per_generation = 20  # age offset added per generation of distance

    # A single pass draws ages in the same order as before (dict order), so a
    # given seed yields the same ages the two-pass version would have.
    bioinfo = []
    for individual_id, info in individuals.items():
        generation = info.get('generation')
        if generation is None:
            # Skip individuals without generation info
            continue

        if generation == latest_generation:
            info['age'] = rand.randint(18, 40)
        else:
            gen_gap = latest_generation - generation
            info['age'] = rand.randint(25 + gen_gap * years_per_generation,
                                       40 + gen_gap * years_per_generation)

        if individual_id in individual_to_bonsai:
            bioinfo.append({
                'genotype_id': individual_to_bonsai[individual_id],
                'age': info['age'],
                'sex': info.get('sex', 'U'),  # Default sex if not available
            })

    return bioinfo

def _is_ibd2(ibd_type):
    """
    Return True when an ibd_type value denotes IBD2 ("full") sharing.

    ped-sim writes the type as a string (IBD1 / IBD2 / HBD), while other
    tools may write a bare number (1 / 2). Both spellings are accepted.
    Comparing the raw value to the integer 2 was always False for
    string-typed files.
    """
    text = str(ibd_type).strip().upper()
    if text in ("IBD2", "2"):
        return True
    try:
        return float(text) == 2.0
    except ValueError:
        return False


def create_ibd_segment_list(seg_df):
    """
    Create an unphased IBD segment list for Bonsai from the segment dataframe.

    Args:
        seg_df: DataFrame with at least the columns sample1, sample2, chrom,
            phys_start, phys_end, ibd_type and gen_seg_len.

    Returns:
        List of ``[id1, id2, chrom, start_bp, end_bp, is_full, len_cm]`` lists,
        where is_full is True for IBD2 segments. Rows whose fields cannot be
        converted (ValueError) are reported and skipped.
    """
    unphased_ibd_seg_list = []

    # itertuples is much faster than iterrows and does not upcast row values.
    for row in seg_df.itertuples(index=False):
        try:
            segment = [
                int(row.sample1),
                int(row.sample2),
                str(row.chrom),
                float(row.phys_start),
                float(row.phys_end),
                _is_ibd2(row.ibd_type),
                float(row.gen_seg_len),  # length in cM
            ]
        except ValueError as e:
            print(f"Error processing segment: {e}")
            continue
        unphased_ibd_seg_list.append(segment)

    return unphased_ibd_seg_list

##########################
# 2. Community Detection #
##########################

def detect_communities(ibd_seg_list, bioinfo, resolution=1.0, min_community_size=10, seed=None):
    """
    Detect communities using the Louvain algorithm to divide the dataset.

    Args:
        ibd_seg_list: List of IBD segments. Elements 0 and 1 are the two
            genotype IDs and element 6 is the segment length in cM, as built
            by create_ibd_segment_list.
        bioinfo: List of individual metadata dicts with a 'genotype_id' key.
        resolution: Resolution parameter for Louvain (higher = smaller communities).
        min_community_size: Minimum community size to keep.
        seed: Optional random seed forwarded to Louvain. It makes the detected
            communities, and their order (which a caller's community index
            relies on), reproducible. None keeps the previous nondeterministic
            behaviour.

    Returns:
        communities: List of detected communities (sets of individual IDs).
        On failure, a single community with all bioinfo IDs is returned.
    """
    G = nx.Graph()

    # Add nodes for all individuals in bioinfo
    genotype_ids = [info['genotype_id'] for info in bioinfo]
    G.add_nodes_from(genotype_ids)

    # Sum shared cM per *unordered* pair directly on the undirected graph.
    # has_edge(a, b) is True for an edge first added as (b, a). Keying an
    # accumulator by the ordered tuple (a, b) would instead make the second
    # add_edge call overwrite the weight of the first orientation.
    for segment in ibd_seg_list:
        id1, id2, cm_length = segment[0], segment[1], segment[6]
        if G.has_edge(id1, id2):
            G[id1][id2]['weight'] += cm_length
        else:
            G.add_edge(id1, id2, weight=cm_length)

    try:
        communities = list(nx.community.louvain_communities(
            G, resolution=resolution, weight='weight', seed=seed))

        # Filter out communities that are too small
        communities = [comm for comm in communities if len(comm) >= min_community_size]

        print(f"Detected {len(communities)} communities")
        for i, community in enumerate(communities):
            print(f"Community {i+1}: {len(community)} members")

        return communities
    except Exception as e:
        print(f"Error detecting communities: {e}")
        # If community detection fails, return a single community with all individuals
        print("Falling back to using all individuals as one community")
        return [set(genotype_ids)]

def filter_for_community(community, bioinfo, ibd_seg_list):
    """
    Restrict bioinfo records and IBD segments to members of one community.

    A bioinfo record is kept when its 'genotype_id' is in the community. A
    segment is kept only when both of its endpoint IDs (elements 0 and 1)
    are in the community.

    Returns:
        (community_bioinfo, community_ibd) as new lists, in input order.
    """
    def is_member(genotype_id):
        return genotype_id in community

    member_records = [record for record in bioinfo if is_member(record['genotype_id'])]
    member_segments = [
        segment for segment in ibd_seg_list
        if is_member(segment[0]) and is_member(segment[1])
    ]
    return member_records, member_segments

def visualize_communities(G, communities, output_file=None, figsize=(12, 12)):
    """
    Visualize communities in a graph.

    Nodes are coloured by the first community that contains them. Nodes in no
    community (e.g. members of communities filtered out as too small) are
    drawn in translucent grey.

    Args:
        G: networkx graph of individuals (edges = IBD sharing).
        communities: List of sets of node IDs.
        output_file: Optional path; if given, the figure is saved at 300 dpi.
        figsize: Matplotlib figure size.
    """
    colors = plt.cm.rainbow(np.linspace(0, 1, len(communities)))
    unassigned_color = (0.7, 0.7, 0.7, 0.5)

    # Build a node -> colour lookup once instead of scanning every community
    # per node. setdefault keeps "first community wins" semantics.
    node_to_color = {}
    for color, community in zip(colors, communities):
        for node in community:
            node_to_color.setdefault(node, color)
    node_colors = [node_to_color.get(node, unassigned_color) for node in G.nodes()]

    patches = [mpatches.Patch(color=color, label=f'Community {i+1}')
               for i, color in enumerate(colors)]

    # sfdp scales better to large graphs; neato gives nicer small layouts.
    if G.number_of_nodes() > 500:
        print("Using sfdp layout for large graph...")
        prog = 'sfdp'
        failure_message = "Graphviz sfdp layout failed, falling back to spring layout"
    else:
        prog = 'neato'
        failure_message = "Graphviz layout failed, falling back to spring layout"

    try:
        pos = nx.nx_agraph.graphviz_layout(G, prog=prog)
    except Exception:
        # Graphviz/pygraphviz missing or failing. Do not use a bare except,
        # which would also swallow KeyboardInterrupt.
        print(failure_message)
        pos = nx.spring_layout(G, k=0.3, iterations=50, seed=42)

    fig, ax = plt.subplots(figsize=figsize)
    nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=50, alpha=0.8, ax=ax)
    nx.draw_networkx_edges(G, pos, width=0.5, alpha=0.3, ax=ax)

    ax.set_title("IBD Network Communities", fontsize=16)
    ax.legend(handles=patches, loc='upper right')
    ax.axis('off')

    if output_file:
        fig.savefig(output_file, dpi=300, bbox_inches='tight')
        print(f"Network visualization saved to {output_file}")

    plt.show()

##########################
# 3. Running Bonsai      #
##########################

def run_bonsai(bioinfo, ibd_seg_list, min_seg_len=3, restrict_connections=False,
               verbose=True):
    """
    Build candidate pedigrees with Bonsai for the given individuals and segments.

    Args:
        bioinfo: List of ``{'genotype_id', 'age', 'sex'}`` dicts.
        ibd_seg_list: Unphased IBD segment list.
        min_seg_len: Minimum segment length in cM passed to Bonsai.
        restrict_connections: Forwarded as ``restrict_connection_points``.
        verbose: Print the run parameters and the number of pedigrees produced.

    Returns:
        Whatever ``bonsai.build_pedigree`` returns. On any exception, the
        error is printed and an empty list is returned.
    """
    if verbose:
        for message in (
            f"Running Bonsai with {len(bioinfo)} individuals, {len(ibd_seg_list)} segments",
            f"Min segment length: {min_seg_len} cM",
            f"Restrict connections: {restrict_connections}",
        ):
            print(message)

    try:
        pedigrees = bonsai.build_pedigree(
            bio_info=bioinfo,
            unphased_ibd_seg_list=ibd_seg_list,
            min_seg_len=min_seg_len,
            restrict_connection_points=restrict_connections,
        )
        # Kept inside the try: a result without len() is treated as a failure too.
        if verbose:
            print(f"Bonsai generated {len(pedigrees)} pedigrees")
    except Exception as err:
        print(f"Error running Bonsai: {err}")
        return []

    return pedigrees

##########################
# 4. Pedigree Analysis   #
##########################

def visualize_pedigree(pedigree, bioinfo, output_file=None, figsize=(15, 10)):
    """
    Visualize a pedigree as a directed parent -> child graph.

    Encoding:
    - Fill colour: green for real individuals (positive int IDs) and white
      with a black outline for inferred ancestors.
    - Shape: square for male and circle for female, based on bioinfo sex.
      Nodes with no sex information, which includes every inferred ancestor
      because they are not in bioinfo, are drawn as diamonds.

    Args:
        pedigree: Dict mapping each child node to an iterable (e.g. dict) of its parents.
        bioinfo: List of dicts with 'genotype_id' and 'sex' keys.
        output_file: Optional path; if given, the figure is saved at 300 dpi.
        figsize: Matplotlib figure size.
    """
    G = nx.DiGraph()
    sex_lookup = {info['genotype_id']: info['sex'] for info in bioinfo}

    def is_real(node):
        return isinstance(node, int) and node > 0

    for child, parents in pedigree.items():
        G.add_node(child, is_real=is_real(child))
        for parent in parents or {}:
            G.add_node(parent, is_real=is_real(parent))
            G.add_edge(parent, child)  # Parent to child direction

    # Group nodes by marker shape so each shape is drawn with one call.
    shape_for_sex = {'M': 's', 'F': 'o'}
    nodes_by_shape = defaultdict(list)
    colors_by_shape = defaultdict(list)
    for node, attrs in G.nodes(data=True):
        shape = shape_for_sex.get(sex_lookup.get(node), 'd')
        nodes_by_shape[shape].append(node)
        colors_by_shape[shape].append('green' if attrs['is_real'] else 'white')

    fig, ax = plt.subplots(figsize=figsize)

    # Calculate layout (hierarchical)
    try:
        pos = nx.nx_agraph.graphviz_layout(G, prog='dot')
    except Exception:
        # Fallback when graphviz/pygraphviz is unavailable. A bare except would
        # also swallow KeyboardInterrupt. The seed keeps the figure reproducible.
        print("Graphviz not available, using spring layout instead")
        pos = nx.spring_layout(G, scale=2, seed=42)

    for shape, nodes in nodes_by_shape.items():
        # Black outlines keep white (inferred) nodes visible on the white background.
        nx.draw_networkx_nodes(G, pos, nodelist=nodes, node_color=colors_by_shape[shape],
                               node_shape=shape, node_size=500, edgecolors='black', ax=ax)

    nx.draw_networkx_edges(G, pos, arrows=True, ax=ax)
    nx.draw_networkx_labels(G, pos, ax=ax)

    # Legend entries match what is actually drawn. Inferred ancestors have no
    # sex information and are therefore always diamonds.
    legend_elements = [
        plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='green',
                   markeredgecolor='black', markersize=10, label='Real (female)'),
        plt.Line2D([0], [0], marker='s', color='w', markerfacecolor='green',
                   markeredgecolor='black', markersize=10, label='Real (male)'),
        plt.Line2D([0], [0], marker='d', color='w', markerfacecolor='green',
                   markeredgecolor='black', markersize=10, label='Real (unknown sex)'),
        plt.Line2D([0], [0], marker='d', color='w', markerfacecolor='white',
                   markeredgecolor='black', markersize=10, label='Inferred ancestor'),
    ]
    ax.legend(handles=legend_elements, loc='upper right')

    ax.set_title("Reconstructed Pedigree", fontsize=16)
    ax.axis('off')

    if output_file:
        fig.savefig(output_file, dpi=300, bbox_inches='tight')
        print(f"Pedigree visualization saved to {output_file}")

    plt.show()

def pedigree_statistics(pedigree, bioinfo):
    """
    Calculate and display statistics about a pedigree.

    Args:
        pedigree: Dict mapping each node to a dict (or other iterable of keys)
            of its parents. Positive int IDs are real individuals and negative
            int IDs are inferred ancestors.
        bioinfo: Accepted for interface compatibility; not used.

    Returns:
        Dict with counts of real individuals, inferred ancestors, the maximum
        depth (longest chain of parent links above any real individual), the
        number of sibling groups, and the number of groups with at least two
        real siblings.
    """
    # Count every node that appears in the pedigree, including nodes that are
    # referenced only as a parent and never as a key.
    all_nodes = list(pedigree.keys())
    seen = set(all_nodes)
    for parents in pedigree.values():
        for parent in parents or {}:
            if parent not in seen:
                seen.add(parent)
                all_nodes.append(parent)

    real_individuals = [node for node in all_nodes if isinstance(node, int) and node > 0]
    inferred_ancestors = [node for node in all_nodes if isinstance(node, int) and node < 0]

    print(f"Pedigree has {len(real_individuals)} real individuals")
    print(f"Pedigree has {len(inferred_ancestors)} inferred ancestors")

    # Longest ancestral path, exploring *all* parents and not just the first
    # one (which underestimated the depth). Results are memoized per node. A
    # node already on the current path is treated as depth 0, so cycles
    # terminate.
    depth_cache = {}

    def ancestral_depth(node, on_path):
        if node in depth_cache:
            return depth_cache[node]
        if node in on_path:
            return 0
        on_path.add(node)
        parents = pedigree.get(node) or {}
        depth = max((1 + ancestral_depth(parent, on_path) for parent in parents), default=0)
        on_path.discard(node)
        depth_cache[node] = depth
        return depth

    max_depth = max((ancestral_depth(node, set()) for node in real_individuals), default=0)
    print(f"Maximum depth in the pedigree: {max_depth} generations")

    # Group children by their exact set of parents (full siblings share a key).
    siblings = defaultdict(list)
    for child, parents in pedigree.items():
        if not parents:  # Skip nodes with no parents
            continue
        siblings[frozenset(parents)].append(child)

    sibling_groups = [children for children in siblings.values() if len(children) >= 2]
    print(f"Found {len(sibling_groups)} sibling groups")

    real_sibling_groups = []
    for group in sibling_groups:
        real_siblings = [s for s in group if isinstance(s, int) and s > 0]
        if len(real_siblings) >= 2:
            real_sibling_groups.append(real_siblings)

    print(f"Found {len(real_sibling_groups)} real sibling groups")

    return {
        "real_individuals": len(real_individuals),
        "inferred_ancestors": len(inferred_ancestors),
        "max_depth": max_depth,
        "sibling_groups": len(sibling_groups),
        "real_sibling_groups": len(real_sibling_groups)
    }

###########################
# 5. Main Analysis Flow   #
###########################

def _build_ibd_graph(ibd_seg_list):
    """Build an undirected graph whose edge weight is the total shared cM per pair."""
    G = nx.Graph()
    for segment in ibd_seg_list:
        id1, id2, cm_length = segment[0], segment[1], segment[6]
        if G.has_edge(id1, id2):
            G[id1][id2]['weight'] += cm_length
        else:
            G.add_edge(id1, id2, weight=cm_length)
    return G


def _json_default(value):
    """JSON fallback for values json cannot encode natively (e.g. numpy integer scalars)."""
    if hasattr(value, 'item'):
        return value.item()
    return str(value)


def _save_pedigree_json(pedigree, pedigree_file):
    """Write a pedigree dict to JSON, stringifying node IDs because JSON keys must be strings."""
    str_pedigree = {
        str(node): {str(parent): degree for parent, degree in (parents or {}).items()}
        for node, parents in pedigree.items()
    }
    with open(pedigree_file, 'w') as f:
        json.dump(str_pedigree, f, indent=2, default=_json_default)
    print(f"\nSaved best pedigree to {pedigree_file}")


def run_bonsai_analysis(seg_file, fam_file, dict_file=None, output_dir="bonsai_results",
                        community_index=0, resolution=1.0, compare_parameters=True):
    """
    Run a complete Bonsai analysis pipeline.

    Steps: load data, build Bonsai bioinfo and the IBD segment list, detect
    and plot communities, run Bonsai on one community, then analyze,
    visualize and save the highest-likelihood pedigree.

    Args:
        seg_file: Path to the .seg file
        fam_file: Path to the .fam file
        dict_file: Path to the ID mapping file (optional)
        output_dir: Directory to save results (created if missing)
        community_index: Index of community to analyze (0-based). Out-of-range
            values, including negative ones, fall back to 0.
        resolution: Resolution parameter for community detection
        compare_parameters: If True (default, the previous behaviour), re-run
            Bonsai with min_seg_len=7 and with restricted connection points
            and print how many pedigrees each variant produced. Set False to
            skip these two extra, potentially slow, Bonsai runs.

    Returns:
        Dict with bioinfo, community_bioinfo, ibd_seg_list, community_ibd,
        communities, best_pedigree, best_likelihood and stats. Returns None
        if loading fails, no bioinfo is produced, or Bonsai returns nothing.
    """
    os.makedirs(output_dir, exist_ok=True)

    print("1. Loading genetic data...")
    seg_df, individuals, individual_to_bonsai = load_genetic_data(seg_file, fam_file, dict_file)
    if seg_df is None or individuals is None:
        print("Error loading data. Exiting.")
        return None

    print("\n2. Creating bioinfo...")
    bioinfo = create_bioinfo(individuals, individual_to_bonsai)
    print(f"Created bioinfo for {len(bioinfo)} individuals")

    if not bioinfo:
        print("Error: No valid bioinfo created. Check your FAM file format.")
        if not individual_to_bonsai:
            # Only individuals present in the ID mapping make it into bioinfo.
            print("Note: no ID mapping was loaded; check that dict_file exists and is "
                  "tab-separated.")
        return None

    print("\nSample of bioinfo data:")
    for record in bioinfo[:3]:
        print(record)

    print("\n3. Creating IBD segment list...")
    ibd_seg_list = create_ibd_segment_list(seg_df)
    print(f"Created {len(ibd_seg_list)} IBD segments")

    print("\nSample of IBD segments:")
    for segment in ibd_seg_list[:3]:
        print(segment)

    print("\n4. Detecting communities...")
    communities = detect_communities(ibd_seg_list, bioinfo, resolution=resolution)

    if not communities:
        print("No communities detected. Using all individuals as one community.")
        communities = [{info['genotype_id'] for info in bioinfo}]

    print("\n5. Visualizing communities...")
    visualize_communities(_build_ibd_graph(ibd_seg_list), communities,
                          output_file=os.path.join(output_dir, "communities.png"))

    # Reject negative indices too: they would silently select from the end
    # and print a misleading community number.
    if not 0 <= community_index < len(communities):
        print(f"Warning: Community index {community_index} out of range. Using first community.")
        community_index = 0

    selected_community = communities[community_index]
    print(f"\n6. Selected community {community_index + 1} with {len(selected_community)} individuals")

    community_bioinfo, community_ibd = filter_for_community(selected_community, bioinfo, ibd_seg_list)

    print("\n7. Running Bonsai on selected community...")
    up_dict_log_like_list = run_bonsai(community_bioinfo, community_ibd)

    if not up_dict_log_like_list:
        print("No pedigrees were generated by Bonsai.")
        return None

    # Choose the highest log-likelihood explicitly rather than assuming the
    # list is sorted. If it is sorted best-first, this picks the same entry.
    best_entry = max(up_dict_log_like_list, key=lambda entry: entry[1])
    best_pedigree, best_likelihood = best_entry[0], best_entry[1]

    print(f"\n8. Analyzing best pedigree (Log-likelihood: {best_likelihood:.2f})...")
    stats = pedigree_statistics(best_pedigree, community_bioinfo)

    visualize_pedigree(best_pedigree, community_bioinfo,
                       output_file=os.path.join(output_dir, "best_pedigree.png"))

    if compare_parameters:
        print("\n9. Trying different parameters...")

        print("\nRunning with min_seg_len=7...")
        up_dict_log_like_list_7cm = run_bonsai(
            community_bioinfo, community_ibd, min_seg_len=7, verbose=False
        )

        print("\nRunning with restricted connections...")
        up_dict_log_like_list_restricted = run_bonsai(
            community_bioinfo, community_ibd, restrict_connections=True, verbose=False
        )

        print("\nResults comparison:")
        print(f"Default settings: {len(up_dict_log_like_list)} pedigrees")
        print(f"Min segment 7cM: {len(up_dict_log_like_list_7cm)} pedigrees")
        print(f"Restricted connections: {len(up_dict_log_like_list_restricted)} pedigrees")

    _save_pedigree_json(best_pedigree, os.path.join(output_dir, "best_pedigree.json"))

    return {
        "bioinfo": bioinfo,
        "community_bioinfo": community_bioinfo,
        "ibd_seg_list": ibd_seg_list,
        "community_ibd": community_ibd,
        "communities": communities,
        "best_pedigree": best_pedigree,
        "best_likelihood": best_likelihood,
        "stats": stats,
    }


In [13]:
# Run the full pipeline on the first detected community of the ped-sim run.
if __name__ == "__main__":
    DATA_DIR = "../data/class_data"
    RUN_NAME = "ped_sim_run2"

    seg_file = f"{DATA_DIR}/{RUN_NAME}.seg"
    fam_file = f"{DATA_DIR}/{RUN_NAME}-everyone.fam"
    dict_file = f"{DATA_DIR}/{RUN_NAME}.seg_dict.txt"

    results = run_bonsai_analysis(seg_file, fam_file, dict_file, community_index=0)

1. Loading genetic data...
Loading ID mapping from ../data/class_data/ped_sim_run2.seg_dict.txt
Loaded 520 ID mappings
Number of unique individuals in seg file: 520
Loaded metadata for 520 individuals from FAM file

Sample of individuals data:
FAM1_g1-b1-s1: {'family_id': 'FAM1', 'father_id': '0', 'mother_id': '0', 'sex': 'F', 'generation': 1}
FAM1_g1-b1-i1: {'family_id': 'FAM1', 'father_id': '0', 'mother_id': '0', 'sex': 'M', 'generation': 1}
FAM1_g2-b1-s1: {'family_id': 'FAM1', 'father_id': '0', 'mother_id': '0', 'sex': 'F', 'generation': 2}

Individuals by generation:
Generation 1: 20 individuals
Generation 2: 40 individuals
Generation 3: 80 individuals
Generation 4: 120 individuals
Generation 5: 160 individuals
Generation 6: 100 individuals

2. Creating bioinfo...
Generation range: 1 to 6
Created bioinfo for 520 individuals

Sample of bioinfo data:
{'genotype_id': 1000, 'age': 131, 'sex': 'F'}
{'genotype_id': 1001, 'age': 126, 'sex': 'M'}
{'genotype_id': 1002, 'age': 111, 'sex': 'F

  log_term = np.log(1 - np.exp(-np.exp(log_mu_amt)))
  x = np.asarray((x - loc)/scale, dtype=dtyp)
  x = np.asarray((x - loc)/scale, dtype=dtyp)


KeyboardInterrupt: 

In [None]:
# Export this notebook to PDF by running nbconvert inside the project's Poetry environment.
!poetry run jupyter nbconvert --to pdf Lab15_Explore_Bonsai.ipynb