# Lab 19: Multi-Sample Relationship Inference with Bonsai

In this lab, we'll explore how to use Bonsai for inferring relationships across multiple samples and building larger pedigrees from IBD data. Building on our understanding of Bonsai's architecture and data preprocessing from previous labs, we'll now focus on techniques for connecting multiple individuals into coherent family structures.

## Why This Matters

While pairwise relationship inference is valuable, the real power of genetic genealogy emerges when we can connect multiple individuals into larger family structures. Multi-sample relationship inference allows us to:
- Discover complex family relationships that span multiple generations
- Identify cryptic relationships that might be unclear from pairwise analysis alone
- Reconstruct historical pedigrees with greater accuracy
- Resolve ambiguous relationships by considering the broader family context
- Overcome challenges posed by missing individuals or sparse genetic data

**Learning Objectives**:
- Understand the principles behind multi-sample relationship inference
- Implement strategies for building consistent pedigrees from IBD data
- Apply constraint satisfaction techniques to resolve relationship conflicts
- Evaluate the confidence and reliability of inferred pedigree structures
- Build effective visualizations of complex pedigrees
- Create extensible frameworks for pedigree analysis and construction

## Environment Setup

In [ ]:
import os
import math
import logging
import sys
import re
import warnings
from pathlib import Path
import subprocess
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display, HTML
import pandas as pd
import numpy as np
import networkx as nx
from scipy import stats
from collections import defaultdict, Counter
import random
import time
import json
from tqdm.notebook import tqdm
from dotenv import load_dotenv

## 1. Principles of Multi-Sample Relationship Inference

Before diving into practical implementation, let's understand the key principles and challenges of multi-sample relationship inference.

### 1.1 From Pairwise to Multi-Sample Analysis

Traditional relationship inference focuses on analyzing pairs of individuals. Multi-sample inference extends this by:

1. **Considering multiple relationships simultaneously**: Rather than treating relationships as independent pairs, we analyze the entire network of relationships.

2. **Enforcing consistency across relationships**: A person cannot be both a grandparent and a sibling to another individual. Multi-sample analysis ensures all inferred relationships are consistent with each other.

3. **Leveraging indirect evidence**: Even without direct IBD sharing between two individuals, we can infer relationships based on their connections to other individuals in the network.

4. **Integrating non-genetic information**: Birth years, locations, and historical records can constrain possible pedigree configurations.
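
To make the consistency idea concrete, here is a minimal sketch (not Bonsai's implementation). It chains parent-child links to derive implied grandparent pairs, then flags any pairwise call that disagrees with them. The names `implied_grandparents`, `find_conflicts`, and the example individuals are hypothetical and exist only for illustration.

```python
from itertools import product

def implied_grandparents(parent_of):
    """Return (grandparent, grandchild) pairs implied by chaining parent links."""
    implied = set()
    for (a, b), (c, d) in product(parent_of, repeat=2):
        if b == c:  # a is a parent of b, and b is a parent of d
            implied.add((a, d))
    return implied

def find_conflicts(parent_of, pairwise_calls):
    """List pairwise calls that contradict the implied grandparent pairs."""
    conflicts = []
    for gp, gc in implied_grandparents(parent_of):
        call = pairwise_calls.get(frozenset((gp, gc)))
        if call is not None and call != "grandparent":
            conflicts.append((gp, gc, call))
    return sorted(conflicts)

# Two confident parent-child links, plus one ambiguous pairwise call
parent_of = {("Ann", "Bob"), ("Bob", "Cat")}
pairwise_calls = {frozenset(("Ann", "Cat")): "half-sibling"}

conflicts = find_conflicts(parent_of, pairwise_calls)
print(conflicts)
```

Here the pairwise call for Ann and Cat is "half-sibling", but the two parent-child links already force Ann to be Cat's grandparent. Considering the relationships together exposes the conflict and suggests the correction.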

### 1.2 Challenges in Multi-Sample Inference

Multi-sample inference comes with unique challenges:

1. **Computational complexity**: The number of possible pedigree configurations grows at least exponentially with the number of individuals, so exhaustive search is feasible only for very small families.

2. **Relationship ambiguity**: Multiple relationship types can explain the same amount of genetic sharing (e.g., half-siblings, grandparent-grandchild, and uncle-niece pairs all share roughly 25% of their DNA).

3. **Missing individuals**: Key relatives may be missing from the dataset, creating "phantom nodes" in the pedigree.

4. **Conflicting evidence**: Errors in IBD detection or other data sources may lead to inconsistent relationship suggestions.

5. **Endogamy and complex family structures**: Historical intermarriage or other complex family structures can create unusual patterns of genetic sharing.
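
To get a feel for the computational complexity, we can count how many directed acyclic graphs exist on `n` labeled nodes using Robinson's recurrence. This is only a rough upper bound on the pedigree search space, since real pedigrees add constraints such as at most two parents per person, but it shows why brute-force enumeration is not an option. The function name `count_labeled_dags` is ours, not part of Bonsai.

```python
from functools import lru_cache
from math import comb

@lru_cache(maxsize=None)
def count_labeled_dags(n):
    """Number of directed acyclic graphs on n labeled nodes (Robinson's recurrence)."""
    if n == 0:
        return 1
    return sum(
        (-1) ** (k + 1) * comb(n, k) * 2 ** (k * (n - k)) * count_labeled_dags(n - k)
        for k in range(1, n + 1)
    )

# Print the count for small n to see how quickly it grows
for n in range(1, 9):
    print(n, count_labeled_dags(n))
```

Even at four individuals there are already 543 possible graphs, which is why practical methods rely on constraints, scoring, and incremental search rather than enumeration.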

### 1.3 The Bonsai Approach to Multi-Sample Inference

Bonsai addresses these challenges through several key strategies:

1. **Up-node dictionary**: Efficiently represents potential ancestral relationships, enabling the algorithm to explore pedigree configurations.

2. **Constraint satisfaction**: Enforces biological and logical constraints on pedigree structures.

3. **Likelihood-based scoring**: Evaluates pedigree configurations based on their probability of generating the observed genetic data.

4. **Incremental construction**: Builds pedigrees incrementally, focusing on high-confidence relationships first.

5. **Optimization algorithms**: Uses techniques like simulated annealing to explore the space of possible pedigree configurations efficiently.
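
As a rough illustration of the up-node dictionary idea, the sketch below assumes a nested-dictionary layout in which each individual maps to its parents (`{child: {parent: degree}}`) and negative IDs stand for ungenotyped, inferred ancestors. Treat the exact layout as an assumption for this lab rather than a specification of Bonsai's internals. The helpers `get_ancestors` and `to_parent_child_edges` are hypothetical.

```python
# Assumed layout (hypothetical example): child -> {parent: degree}.
# Negative IDs stand in for ungenotyped (inferred) ancestors.
up_node_dict = {
    1: {-1: 1, -2: 1},  # sample 1 has two ungenotyped parents
    2: {-1: 1, -2: 1},  # sample 2 shares both parents, so 1 and 2 are full siblings
    3: {1: 1},          # sample 3 is a child of sample 1
    -1: {},             # founders have no recorded parents
    -2: {},
}

def get_ancestors(up_node_dict, node):
    """Collect every ancestor of `node` by walking parent links upward."""
    ancestors = set()
    stack = list(up_node_dict.get(node, {}))
    while stack:
        current = stack.pop()
        if current not in ancestors:
            ancestors.add(current)
            stack.extend(up_node_dict.get(current, {}))
    return ancestors

def to_parent_child_edges(up_node_dict):
    """Flatten the dictionary into sorted (parent, child) edges."""
    return sorted(
        (parent, child)
        for child, parents in up_node_dict.items()
        for parent in parents
    )

ancestors_of_3 = get_ancestors(up_node_dict, 3)
edges = to_parent_child_edges(up_node_dict)
print(ancestors_of_3)
print(edges)
```

Storing only upward links keeps each lookup of an individual's parents cheap, and the flattened edge list is exactly what a directed graph library needs to draw the pedigree.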

In [ ]:
# Let's visualize the concept of multi-sample inference
def create_example_pedigree(num_individuals=10, num_generations=3):
    """
    Create a simple example pedigree for demonstration purposes.
    
    Args:
        num_individuals: Total number of individuals in the pedigree
        num_generations: Number of generations to include
        
    Returns:
        nx.DiGraph: Directed graph representing the pedigree
    """
    G = nx.DiGraph()
    
    # Create individuals with generation and birth year information
    individuals = []
    for i in range(num_individuals):
        # Assign generation (older individuals first)
        generation = min(i // 3, num_generations - 1)
        
        # Assign birth year (approximate 25 years per generation, with some variation)
        birth_year = 1900 + generation * 25 + np.random.randint(-5, 6)
        
        # Create individual
        individuals.append({
            'id': i,
            'name': f"Person_{i}",
            'generation': generation,
            'birth_year': birth_year
        })
    
    # Add individuals to the graph
    for ind in individuals:
        G.add_node(ind['id'], **ind)
    
    # Add family relationships (parents to children)
    for i in range(num_individuals):
        generation = G.nodes[i]['generation']
        
        # Skip the oldest generation
        if generation == 0:
            continue
            
        # Find potential parents (individuals in the previous generation)
        potential_parents = [j for j in range(num_individuals) 
                             if G.nodes[j]['generation'] == generation - 1]
        
        # Ensure we have at least one parent
        if potential_parents:
            # Select two parents at random (or one if only one candidate exists)
            num_parents = min(2, len(potential_parents))
            parents = np.random.choice(potential_parents, size=num_parents, replace=False)
            
            # Add edges from parents to child
            for parent in parents:
                G.add_edge(parent, i, relationship='parent')
    
    return G

# Create an example pedigree
np.random.seed(42)  # For reproducibility
example_pedigree = create_example_pedigree(12, 3)

# Visualize the pedigree
plt.figure(figsize=(12, 8))

# Create a position layout based on generations (older generations at the top)
pos = nx.spring_layout(example_pedigree, seed=42)
for node in example_pedigree.nodes():
    # Adjust y position based on generation
    pos[node] = (pos[node][0], 0.8 - 0.3 * example_pedigree.nodes[node]['generation'])

# Draw nodes with generation-based colors
node_colors = [['lightblue', 'lightgreen', 'lightsalmon'][example_pedigree.nodes[node]['generation']] 
               for node in example_pedigree.nodes()]
nx.draw_networkx_nodes(example_pedigree, pos, node_color=node_colors, node_size=500, alpha=0.8)

# Draw edges
nx.draw_networkx_edges(example_pedigree, pos, edge_color='black', width=1.5, alpha=0.7, 
                       arrowsize=20, arrowstyle='-|>')

# Add labels with birth years
labels = {node: f"{node}\n({example_pedigree.nodes[node]['birth_year']})" 
          for node in example_pedigree.nodes()}
nx.draw_networkx_labels(example_pedigree, pos, labels=labels, font_size=10)

plt.title('Example Pedigree with 3 Generations', size=15)
plt.axis('off')
plt.tight_layout()
plt.show()

# Now, let's simulate what happens when we observe only a subset of individuals
# and need to infer the missing relationships

# Randomly select individuals to "observe" (exclude some individuals)
num_observed = 8
observed_individuals = np.random.choice(range(12), size=num_observed, replace=False)
unobserved_individuals = [i for i in range(12) if i not in observed_individuals]

# Create a subgraph with only the observed individuals
observed_pedigree = example_pedigree.subgraph(observed_individuals).copy()

# Visualize the observed pedigree
plt.figure(figsize=(12, 8))

# Reuse positions from the full pedigree
observed_pos = {node: pos[node] for node in observed_individuals}

# Draw nodes with generation-based colors
observed_node_colors = [['lightblue', 'lightgreen', 'lightsalmon'][observed_pedigree.nodes[node]['generation']] 
                       for node in observed_pedigree.nodes()]
nx.draw_networkx_nodes(observed_pedigree, observed_pos, node_color=observed_node_colors, 
                      node_size=500, alpha=0.8)

# Draw edges
nx.draw_networkx_edges(observed_pedigree, observed_pos, edge_color='black', 
                       width=1.5, alpha=0.7, arrowsize=20, arrowstyle='-|>')

# Add labels with birth years
observed_labels = {node: f"{node}\n({observed_pedigree.nodes[node]['birth_year']})" 
                  for node in observed_pedigree.nodes()}
nx.draw_networkx_labels(observed_pedigree, observed_pos, labels=observed_labels, font_size=10)

# Highlight unobserved individuals in the original positions (as "ghosts")
unobserved_pos = {node: pos[node] for node in unobserved_individuals}
nx.draw_networkx_nodes(example_pedigree.subgraph(unobserved_individuals), unobserved_pos, 
                       node_color='gray', node_size=500, alpha=0.3)

plt.title('Observed Pedigree with Missing Individuals', size=15)
plt.axis('off')
plt.tight_layout()
plt.show()

# Print out known relationships in the observed pedigree
print("Known relationships in the observed pedigree:")
for edge in observed_pedigree.edges():
    parent, child = edge
    print(f"Person_{parent} is a parent of Person_{child}")

# Print out hidden relationships that involve at least one observed individual
print("\nHidden relationships (involving at least one observed individual):")
for edge in example_pedigree.edges():
    parent, child = edge
    if parent in observed_individuals and child not in observed_individuals:
        print(f"Person_{parent} is a parent of Person_{child} (unobserved)")
    elif parent not in observed_individuals and child in observed_individuals:
        print(f"Person_{parent} (unobserved) is a parent of Person_{child}")

## 2. Building Pedigree Structures from IBD Data

Now that we understand the principles of multi-sample inference, let's explore how to build pedigree structures from IBD data.

### 2.1 Representing Pedigree Structures

A pedigree is a directed acyclic graph (nobody can be their own ancestor) where:
- Nodes represent individuals
- Edges represent parent-child relationships (directed from parent to child)
- Node attributes may include birth year, sex, and other metadata
- Edge attributes may include certainty scores or relationship types

Let's implement a flexible pedigree representation using NetworkX:

In [ ]:
class Pedigree:
    """A class to represent and manipulate pedigree structures."""
    
    def __init__(self):
        """Initialize an empty pedigree."""
        # Use a directed graph to represent parent->child relationships
        self.graph = nx.DiGraph()
        
        # Keep track of individuals by ID
        self.individuals = {}
        
        # Dictionary to track added relationships for quick lookup
        self.relationships = {}
    
    def add_individual(self, id, **attributes):
        """
        Add an individual to the pedigree.
        
        Args:
            id: Unique identifier for the individual
            **attributes: Additional attributes (birth_year, sex, etc.)
            
        Returns:
            bool: True if added successfully, False if individual already exists
        """
        if id in self.individuals:
            return False
        
        # Add to graph
        self.graph.add_node(id, **attributes)
        
        # Store in individuals dictionary
        self.individuals[id] = attributes
        
        return True
    
    def add_relationship(self, parent_id, child_id, certainty=1.0, **attributes):
        """
        Add a parent-child relationship to the pedigree.
        
        Args:
            parent_id: ID of the parent
            child_id: ID of the child
            certainty: Confidence score for this relationship (0.0 to 1.0)
            **attributes: Additional attributes for the relationship
            
        Returns:
            bool: True if added successfully, False otherwise
        """
        # Check if both individuals exist
        if parent_id not in self.individuals or child_id not in self.individuals:
            return False
        
        # Check for impossible relationships (e.g., child as a parent)
        if self.would_create_cycle(parent_id, child_id):
            return False
        
        # Add relationship to graph
        self.graph.add_edge(parent_id, child_id, certainty=certainty, **attributes)
        
        # Store relationship in dictionary for quick lookup
        rel_key = (parent_id, child_id)
        self.relationships[rel_key] = {'certainty': certainty, **attributes}
        
        return True
    
    def would_create_cycle(self, parent_id, child_id):
        """
        Check if adding parent_id as a parent of child_id would create a cycle.
        
        Args:
            parent_id: Proposed parent
            child_id: Proposed child
            
        Returns:
            bool: True if this would create a cycle, False otherwise
        """
        # If the child is already an ancestor of the parent (or they are the
        # same node), adding this edge would close a directed cycle
        return nx.has_path(self.graph, child_id, parent_id)
    
    def get_children(self, individual_id):
        """
        Get all children of an individual.
        
        Args:
            individual_id: ID of the individual
            
        Returns:
            list: IDs of children
        """
        return list(self.graph.successors(individual_id))
    
    def get_parents(self, individual_id):
        """
        Get all parents of an individual.
        
        Args:
            individual_id: ID of the individual
            
        Returns:
            list: IDs of parents
        """
        return list(self.graph.predecessors(individual_id))
    
    def get_ancestors(self, individual_id, max_generations=None):
        """
        Get all ancestors of an individual, optionally limited to a number of generations.
        
        Args:
            individual_id: ID of the individual
            max_generations: Maximum number of generations to ascend (None for all)
            
        Returns:
            set: IDs of ancestors
        """
        ancestors = set()
        current_generation = {individual_id}
        generation_count = 0
        
        while current_generation and (max_generations is None or generation_count < max_generations):
            next_generation = set()
            for person in current_generation:
                parents = self.get_parents(person)
                next_generation.update(parents)
            
            ancestors.update(next_generation)
            current_generation = next_generation
            generation_count += 1
        
        return ancestors
    
    def get_descendants(self, individual_id, max_generations=None):
        """
        Get all descendants of an individual, optionally limited to a number of generations.
        
        Args:
            individual_id: ID of the individual
            max_generations: Maximum number of generations to descend (None for all)
            
        Returns:
            set: IDs of descendants
        """
        descendants = set()
        current_generation = {individual_id}
        generation_count = 0
        
        while current_generation and (max_generations is None or generation_count < max_generations):
            next_generation = set()
            for person in current_generation:
                children = self.get_children(person)
                next_generation.update(children)
            
            descendants.update(next_generation)
            current_generation = next_generation
            generation_count += 1
        
        return descendants
    
    def get_siblings(self, individual_id, include_half=True):
        """
        Get siblings of an individual.
        
        Args:
            individual_id: ID of the individual
            include_half: Whether to include half-siblings
            
        Returns:
            set: IDs of siblings
        """
        parents = self.get_parents(individual_id)
        siblings = set()
        
        if include_half:
            # For half-siblings, we consider all children of all parents
            for parent in parents:
                siblings.update(self.get_children(parent))
        else:
            # For full siblings only, we need at least two parents
            if len(parents) >= 2:
                # Find children that share all the parents
                for child in self.individuals:
                    if child == individual_id:
                        continue
                    child_parents = self.get_parents(child)
                    if all(parent in child_parents for parent in parents):
                        siblings.add(child)
        
        # Remove the individual from the result
        siblings.discard(individual_id)
        
        return siblings
    
    def get_relationship_degree(self, id1, id2):
        """
        Calculate the relationship degree between two individuals.
        
        Args:
            id1: First individual ID
            id2: Second individual ID
            
        Returns:
            int: Degree of relationship (1 for parent-child, 2 for grandparent-grandchild or siblings, etc.)
                or None if no relationship is found
        """
        if id1 == id2:
            return 0  # Same person
        
        # Check if one is a direct ancestor of the other
        try:
            path = nx.shortest_path(self.graph, id1, id2)
            return len(path) - 1  # Direct path length
        except nx.NetworkXNoPath:
            pass
        
        try:
            path = nx.shortest_path(self.graph, id2, id1)
            return len(path) - 1  # Direct path length
        except nx.NetworkXNoPath:
            pass
        
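        # Otherwise, find the closest common ancestor and count the steps up
        # from each individual to it. (A shortest path in the undirected graph
        # would also link unrelated co-parents through their shared child.)
        up_graph = self.graph.reverse(copy=False)
        up1 = nx.single_source_shortest_path_length(up_graph, id1)
        up2 = nx.single_source_shortest_path_length(up_graph, id2)
        common = set(up1) & set(up2)
        if not common:
            return None  # No relationship found
        return min(up1[a] + up2[a] for a in common)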
    
    def is_consistent(self):
        """
        Check if the pedigree is biologically and logically consistent.
        
        Returns:
            bool: True if consistent, False otherwise
        """
        # Check for cycles (impossible in a real pedigree)
        if not nx.is_directed_acyclic_graph(self.graph):
            return False
        
        # Check birth year consistency (parents should be older than children)
        for parent, child in self.graph.edges():
            if 'birth_year' in self.individuals[parent] and 'birth_year' in self.individuals[child]:
                parent_birth = self.individuals[parent]['birth_year']
                child_birth = self.individuals[child]['birth_year']
                
                if parent_birth >= child_birth:
                    return False
        
        # All checks passed
        return True
    
    def get_max_generation(self):
        """
        Calculate the maximum generation depth in the pedigree.
        
        Returns:
            int: Maximum generation depth
        """
        # Find roots (individuals without parents)
        roots = [node for node in self.graph.nodes() if self.graph.in_degree(node) == 0]
        
        if not roots:
            return 0
        
        # The depth is the number of edges on the longest parent->child chain.
        # (Shortest-path lengths would underestimate the depth when a descendant
        # is reachable through lineages of different lengths.)
        return nx.dag_longest_path_length(self.graph)
    
    def assign_generations(self):
        """
        Assign generation numbers to all individuals in the pedigree.
        
        Generation 0 is assigned to roots (individuals without parents).
        Children are assigned generation numbers incremented from their parents.
        
        Returns:
            dict: Mapping of individual IDs to generation numbers
        """
        # Find roots (individuals without parents)
        roots = [node for node in self.graph.nodes() if self.graph.in_degree(node) == 0]
        
        generations = {}
        
        # Assign generation 0 to roots
        for root in roots:
            generations[root] = 0
        
        # Process all nodes in topological order (parents before children)
        for node in nx.topological_sort(self.graph):
            # If this node is a root, it's already assigned
            if node in generations:
                continue
            
            # Get parents' generations
            parents = list(self.graph.predecessors(node))
            if parents:
                # Assign generation one more than the maximum parent generation
                parent_generations = [generations.get(parent, 0) for parent in parents]
                generations[node] = max(parent_generations) + 1
            else:
                # No parents (but not a root), assign generation 0
                generations[node] = 0
        
        return generations
    
    def visualize(self, highlight_nodes=None, node_labels=True, figsize=(12, 8)):
        """
        Visualize the pedigree.
        
        Args:
            highlight_nodes: List of node IDs to highlight
            node_labels: Whether to show node labels
            figsize: Figure size tuple
            
        Returns:
            matplotlib figure
        """
        # Create figure
        plt.figure(figsize=figsize)
        
        if not self.graph.nodes():
            plt.text(0.5, 0.5, "Empty Pedigree", ha='center', va='center')
            plt.axis('off')
            return plt.gcf()
        
        # Assign generations
        generations = self.assign_generations()
        
        # Create a position layout based on generations (older generations at the top)
        pos = nx.spring_layout(self.graph, seed=42)
        
        # Adjust y position based on generation
        max_gen = max(generations.values()) if generations else 0
        for node in self.graph.nodes():
            gen = generations.get(node, 0)
            # Normalize to 0-1 range with oldest at the top
            norm_gen = 1 - (gen / max(1, max_gen))
            pos[node] = (pos[node][0], 0.8 * norm_gen + 0.1)
        
        # Node colors based on generation
        cmap = plt.cm.viridis
        node_colors = [cmap(generations.get(node, 0) / max(1, max_gen)) for node in self.graph.nodes()]
        
        # Draw nodes
        nx.draw_networkx_nodes(self.graph, pos, node_color=node_colors, 
                              node_size=500, alpha=0.8)
        
        # Highlight specific nodes if provided
        if highlight_nodes:
            highlight_nodes = [n for n in highlight_nodes if n in self.graph.nodes()]
            if highlight_nodes:
                nx.draw_networkx_nodes(self.graph.subgraph(highlight_nodes), pos, 
                                      node_color='red', node_size=600, alpha=0.8)
        
        # Draw edges with certainty-based alpha
        edge_alphas = [self.relationships.get((u, v), {}).get('certainty', 1.0) for u, v in self.graph.edges()]
        for i, (u, v) in enumerate(self.graph.edges()):
            # Draw edges with alpha based on certainty
            alpha = edge_alphas[i]
            nx.draw_networkx_edges(self.graph, pos, edgelist=[(u, v)], 
                                 width=1.5, alpha=alpha, arrows=True, 
                                 arrowsize=20, arrowstyle='-|>')
        
        # Add labels
        if node_labels:
            labels = {}
            for node in self.graph.nodes():
                label_parts = [str(node)]
                if 'birth_year' in self.individuals[node]:
                    label_parts.append(f"({self.individuals[node]['birth_year']})")
                labels[node] = "\n".join(label_parts)
            
            nx.draw_networkx_labels(self.graph, pos, labels=labels, font_size=10)
        
        plt.title('Pedigree Visualization', size=15)
        plt.axis('off')
        plt.tight_layout()
        
        return plt.gcf()

# Let's create a simple pedigree using our new class
sample_pedigree = Pedigree()

# Add individuals with birth years and sex
sample_pedigree.add_individual("A", birth_year=1900, sex="M")
sample_pedigree.add_individual("B", birth_year=1905, sex="F")
sample_pedigree.add_individual("C", birth_year=1930, sex="M")
sample_pedigree.add_individual("D", birth_year=1932, sex="F")
sample_pedigree.add_individual("E", birth_year=1935, sex="M")
sample_pedigree.add_individual("F", birth_year=1960, sex="F")
sample_pedigree.add_individual("G", birth_year=1962, sex="M")
sample_pedigree.add_individual("H", birth_year=1985, sex="F")
sample_pedigree.add_individual("I", birth_year=1987, sex="M")
sample_pedigree.add_individual("J", birth_year=2010, sex="F")

# Add relationships
sample_pedigree.add_relationship("A", "C")
sample_pedigree.add_relationship("B", "C")
sample_pedigree.add_relationship("A", "D")
sample_pedigree.add_relationship("B", "D")
sample_pedigree.add_relationship("A", "E")
sample_pedigree.add_relationship("C", "F")
sample_pedigree.add_relationship("D", "G")
sample_pedigree.add_relationship("F", "H")
sample_pedigree.add_relationship("G", "H")
sample_pedigree.add_relationship("G", "I")
sample_pedigree.add_relationship("H", "J")
sample_pedigree.add_relationship("I", "J")

# Visualize the pedigree
sample_pedigree.visualize()

# Let's check some relationships
print(f"Parents of H: {sample_pedigree.get_parents('H')}")
print(f"Children of G: {sample_pedigree.get_children('G')}")
print(f"Siblings of D: {sample_pedigree.get_siblings('D')}")
print(f"Relationship degree between A and H: {sample_pedigree.get_relationship_degree('A', 'H')}")
print(f"Relationship degree between C and E: {sample_pedigree.get_relationship_degree('C', 'E')}")
print(f"Is the pedigree consistent? {sample_pedigree.is_consistent()}")

# Let's try to add an inconsistent relationship
print(f"Can add H as parent of A? {sample_pedigree.add_relationship('H', 'A')}")

# Highlight specific individuals
sample_pedigree.visualize(highlight_nodes=["A", "F", "J"])

### 2.2 Converting IBD Data to Pedigree Relationships

Now let's develop methods to convert IBD data into pedigree relationships. We'll need to:
1. Interpret IBD segment data to infer likely relationship types
2. Estimate certainty scores for each relationship
3. Resolve ambiguities when multiple relationship types are possible
4. Integrate metadata (like birth years) to constrain possible relationships
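
Before building the converter, it helps to see where its expected values come from. Most of the means in the lookup table used in this lab follow a simple halving rule: start at roughly 3400 cM for parent-child and halve for each additional degree of separation. Full siblings are the exception, because they can also share both haplotypes at a locus. The sketch below uses hypothetical helper names (`expected_cm`, `nearest_degree`) and ignores the wide natural variance around each mean.

```python
import math

PARENT_CHILD_CM = 3400  # parent-child mean from this lab's lookup table

def expected_cm(degree):
    """Expected total IBD (cM) for a relationship of the given degree."""
    return PARENT_CHILD_CM / 2 ** (degree - 1)

def nearest_degree(total_cm, max_degree=7):
    """Pick the degree whose expectation is closest to total_cm on a log2 scale.

    total_cm must be positive.
    """
    return min(
        range(1, max_degree + 1),
        key=lambda d: abs(math.log2(total_cm / expected_cm(d))),
    )

# Degree 2 covers half-siblings, grandparents, and aunts/uncles;
# degree 3 is first cousins; degree 7 is third cousins.
for degree in range(1, 8):
    print(degree, expected_cm(degree))

print(nearest_degree(900))
```

Comparing on a log scale treats "twice as much" and "half as much" sharing symmetrically, which matches the halving structure. Note that a degree is all this estimate can give: telling a half-sibling from a grandparent requires the extra evidence described in steps 3 and 4.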

In [ ]:
class IBDToPedigree:
    """
    A class to convert IBD data into pedigree relationships.
    """
    
    # Define expected IBD sharing for different relationship types (in cM)
    # These are approximate averages and ranges
    RELATIONSHIP_IBD = {
        'parent-child': {'mean': 3400, 'std': 100, 'min': 3200, 'max': 3720},
        'full-sibling': {'mean': 2550, 'std': 180, 'min': 2200, 'max': 2950},
        'half-sibling': {'mean': 1700, 'std': 160, 'min': 1450, 'max': 2050},
        'grandparent': {'mean': 1700, 'std': 160, 'min': 1450, 'max': 2050},
        'aunt-uncle': {'mean': 1700, 'std': 160, 'min': 1450, 'max': 2050},
        'first-cousin': {'mean': 850, 'std': 150, 'min': 550, 'max': 1150},
        'first-cousin-once-removed': {'mean': 425, 'std': 120, 'min': 220, 'max': 650},
        'second-cousin': {'mean': 212.5, 'std': 100, 'min': 75, 'max': 360},
        'second-cousin-once-removed': {'mean': 106.25, 'std': 60, 'min': 30, 'max': 200},
        'third-cousin': {'mean': 53.13, 'std': 40, 'min': 10, 'max': 120}
    }
    
    # Mapping from relationship type to generation difference (0, 1, 2, etc.)
    RELATIONSHIP_GEN_DIFF = {
        'parent-child': 1,
        'full-sibling': 0,
        'half-sibling': 0,
        'grandparent': 2,
        'aunt-uncle': 1,
        'first-cousin': 0,
        'first-cousin-once-removed': 1,
        'second-cousin': 0,
        'second-cousin-once-removed': 1,
        'third-cousin': 0
    }
    
    def __init__(self):
        """Initialize the IBD to pedigree converter."""
        pass
    
    def infer_relationship_type(self, total_cm, num_segments=None, max_segment=None, birth_years=None):
        """
        Infer the most likely relationship type(s) based on total IBD sharing.
        
        Args:
            total_cm: Total IBD sharing in cM
            num_segments: Number of IBD segments (optional)
            max_segment: Length of longest segment (optional)
            birth_years: Tuple of (person1_birth_year, person2_birth_year) (optional)
            
        Returns:
            list: List of (relationship_type, probability) tuples, sorted by descending probability
        """
        if total_cm < 5:
            return [('distant/unrelated', 1.0)]
        
        relationships = []
        
        # For each relationship type, score how close total_cm is to the expected mean
        for rel_type, rel_data in self.RELATIONSHIP_IBD.items():
            # Skip relationships that are outside the min/max range
            if total_cm < rel_data['min'] or total_cm > rel_data['max']:
                continue
            
            # Check birth year compatibility if provided
            if birth_years is not None:
                year1, year2 = birth_years
                year_diff = abs(year1 - year2)
                
                # Skip relationships that are incompatible with the birth years
                gen_diff = self.RELATIONSHIP_GEN_DIFF[rel_type]
                expected_min_year_diff = gen_diff * 15  # Minimum years per generation
                expected_max_year_diff = gen_diff * 35  # Maximum years per generation
                
                # Special case for siblings and cousins (same generation)
                if gen_diff == 0:
                    if year_diff > 25:  # Siblings/cousins usually born within 25 years
                        continue
                else:
                    # For other relationships, check if the year difference is plausible
                    if year_diff < expected_min_year_diff or year_diff > expected_max_year_diff:
                        continue
            
            # Score the match with a simple linear heuristic on the z-score:
            # 1.0 at the expected mean, falling to 0 at five standard deviations.
            # Note: this is not a calibrated (or normalized) probability
            mean = rel_data['mean']
            std = rel_data['std']
            z_score = abs(total_cm - mean) / std
            probability = max(0, 1 - z_score / 5)
            
            relationships.append((rel_type, probability))
        
        # Add distant/unrelated with low probability if no other matches
        if not relationships:
            relationships.append(('distant/unrelated', 0.3))
        
        # Sort by descending probability
        relationships.sort(key=lambda x: x[1], reverse=True)
        
        return relationships
    
    def segments_to_pedigree(self, segments_df, individuals_metadata=None, min_certainty=0.5):
        """
        Convert a DataFrame of IBD segments to a pedigree structure.
        
        Args:
            segments_df: DataFrame with IBD segments (must have columns: sample1, sample2, gen_seg_len)
            individuals_metadata: Optional DataFrame with metadata about individuals
            min_certainty: Minimum certainty threshold for including relationships
            
        Returns:
            Pedigree: Constructed pedigree
        """
        pedigree = Pedigree()
        
        # First, extract all unique individuals and add them to the pedigree
        all_individuals = set(segments_df['sample1']).union(set(segments_df['sample2']))
        
        for ind_id in all_individuals:
            # Get metadata for this individual if available
            metadata = {}
            if individuals_metadata is not None and 'id' in individuals_metadata.columns:
                match = individuals_metadata[individuals_metadata['id'] == ind_id]
                if len(match) > 0:
                    metadata = match.iloc[0].to_dict()
            
            # Add individual to pedigree
            pedigree.add_individual(ind_id, **metadata)
        
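        # Calculate total IBD sharing between each pair. Sort the two IDs within each
        # row first so that (A, B) and (B, A) are aggregated as the same pair.
        ordered = segments_df.copy()
        swap = ordered['sample1'] > ordered['sample2']
        ordered.loc[swap, ['sample1', 'sample2']] = ordered.loc[swap, ['sample2', 'sample1']].values
        pair_sharing = ordered.groupby(['sample1', 'sample2'])['gen_seg_len'].agg(['sum', 'count', 'max']).reset_index()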
        
        # For each pair, infer relationship type and add to pedigree
        for _, row in pair_sharing.iterrows():
            sample1 = row['sample1']
            sample2 = row['sample2']
            total_cm = row['sum']
            num_segments = row['count']
            max_segment = row['max']
            
            # Get birth years if available
            birth_years = None
            if 'birth_year' in pedigree.individuals[sample1] and 'birth_year' in pedigree.individuals[sample2]:
                birth_years = (
                    pedigree.individuals[sample1]['birth_year'],
                    pedigree.individuals[sample2]['birth_year']
                )
            
            # Infer possible relationships
            relationships = self.infer_relationship_type(
                total_cm, num_segments, max_segment, birth_years)
            
            # Skip if no high-confidence relationships
            if not relationships or relationships[0][1] < min_certainty:
                continue
            
            # Get the most likely relationship
            rel_type, certainty = relationships[0]
            
            # Add appropriate relationship to pedigree
            self._add_relationship_to_pedigree(pedigree, sample1, sample2, rel_type, certainty, birth_years)
        
        return pedigree
    
    def _add_relationship_to_pedigree(self, pedigree, id1, id2, rel_type, certainty, birth_years):
        """
        Add the inferred relationship to the pedigree.
        
        Args:
            pedigree: The pedigree to update
            id1: ID of the first individual
            id2: ID of the second individual
            rel_type: Type of relationship
            certainty: Certainty score (0.0 to 1.0)
            birth_years: Tuple of birth years (or None)
            
        Returns:
            bool: True if relationship was added, False otherwise
        """
        # Handle parent-child relationship
        if rel_type == 'parent-child':
            # Determine which is the parent based on birth years
            if birth_years is not None:
                year1, year2 = birth_years
                if year1 < year2:
                    parent, child = id1, id2
                else:
                    parent, child = id2, id1
            else:
                # Without birth years, we can't determine direction
                # For now, arbitrarily choose id1 as parent
                parent, child = id1, id2
            
            return pedigree.add_relationship(parent, child, certainty=certainty, rel_type=rel_type)
        
        # Handle grandparent relationship
        elif rel_type == 'grandparent':
            # Determine which is the grandparent based on birth years
            if birth_years is not None:
                year1, year2 = birth_years
                if year1 < year2:
                    gparent, gchild = id1, id2
                else:
                    gparent, gchild = id2, id1
                
                # Create a phantom parent node (halfway between in birth years)
                phantom_year = (year1 + year2) // 2
                phantom_id = f"phantom_{gparent}_{gchild}"
                
                # Add phantom parent
                pedigree.add_individual(phantom_id, birth_year=phantom_year, is_phantom=True)
                
                # Add relationships
                gp_success = pedigree.add_relationship(gparent, phantom_id, certainty=certainty, rel_type='parent-child')
                p_success = pedigree.add_relationship(phantom_id, gchild, certainty=certainty, rel_type='parent-child')
                
                return gp_success and p_success
            else:
                # Without birth years, we can't determine the direction
                return False
        
        # Handle full siblings
        elif rel_type == 'full-sibling':
            # For siblings, we need to create phantom parents
            phantom_id1 = f"phantom_parent1_{id1}_{id2}"
            phantom_id2 = f"phantom_parent2_{id1}_{id2}"
            
            # Estimate birth years for phantom parents
            if birth_years is not None:
                year1, year2 = birth_years
                avg_year = (year1 + year2) // 2
                parent_year1 = avg_year - 25  # Approximate parent birth year
                parent_year2 = avg_year - 23  # Slight difference for the other parent
                
                # Add phantom parents
                pedigree.add_individual(phantom_id1, birth_year=parent_year1, is_phantom=True)
                pedigree.add_individual(phantom_id2, birth_year=parent_year2, is_phantom=True)
                
                # Add relationships from phantom parents to both siblings
                for parent_id in [phantom_id1, phantom_id2]:
                    for child_id in [id1, id2]:
                        pedigree.add_relationship(parent_id, child_id, certainty=certainty * 0.8, 
                                                 rel_type='parent-child', is_inferred=True)
                
                return True
            else:
                # Without birth years, we still create phantom parents but without birth years
                pedigree.add_individual(phantom_id1, is_phantom=True)
                pedigree.add_individual(phantom_id2, is_phantom=True)
                
                # Add relationships
                for parent_id in [phantom_id1, phantom_id2]:
                    for child_id in [id1, id2]:
                        pedigree.add_relationship(parent_id, child_id, certainty=certainty * 0.7, 
                                                 rel_type='parent-child', is_inferred=True)
                
                return True
        
        # For other relationship types (half-siblings, cousins, etc.), 
        # we would need more complex logic or manual inspection
        return False

# Let's create some synthetic IBD data for demonstration
def create_synthetic_ibd_data(num_individuals=15, num_relationships=30):
    """
    Create synthetic IBD data for demonstration purposes.
    
    Args:
        num_individuals: Number of individuals
        num_relationships: Number of relationships to generate
        
    Returns:
        tuple: (segments_df, individuals_df)
    """
    # Create individuals with birth years
    individuals = []
    for i in range(num_individuals):
        birth_year = 1920 + i * 5  # Spread birth years over time
        individuals.append({
            'id': f"ind_{i}",
            'birth_year': birth_year,
            'sex': 'F' if i % 2 == 0 else 'M'  # Alternate sexes
        })
    
    # Create relationships
    segments = []
    for _ in range(num_relationships):
        # Select two random individuals
        i, j = np.random.choice(range(num_individuals), size=2, replace=False)
        ind1, ind2 = f"ind_{i}", f"ind_{j}"
        
        # Determine birth year difference to inform relationship type
        year_diff = abs(individuals[i]['birth_year'] - individuals[j]['birth_year'])
        
        # Assign a relationship type based on birth year difference
        if year_diff > 20 and year_diff < 30:
            # Potential parent-child
            rel_type = 'parent-child'
            target_cm = np.random.normal(3400, 100)
            num_segments_to_generate = np.random.randint(35, 45)
        elif year_diff < 15:
            # Potential siblings or cousins
            if np.random.random() < 0.3:
                rel_type = 'full-sibling'
                target_cm = np.random.normal(2550, 180)
                num_segments_to_generate = np.random.randint(30, 40)
            else:
                rel_type = 'first-cousin'
                target_cm = np.random.normal(850, 150)
                num_segments_to_generate = np.random.randint(15, 25)
        elif year_diff > 40 and year_diff < 60:
            # Potential grandparent-grandchild
            rel_type = 'grandparent'
            target_cm = np.random.normal(1700, 160)
            num_segments_to_generate = np.random.randint(20, 30)
        else:
            # Distant relationship
            rel_type = 'distant'
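            # Clamp at 10 cM: a negative target would give a negative scale below,
            # which np.random.normal rejects
            target_cm = max(10, np.random.normal(100, 50))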
            num_segments_to_generate = np.random.randint(3, 8)
        
        # Generate segments to approximately match the target cM
        cm_per_segment = target_cm / num_segments_to_generate
        
        for s in range(num_segments_to_generate):
            # Generate segment length (with some variation)
            segment_cm = max(1, np.random.normal(cm_per_segment, cm_per_segment * 0.3))
            
            # Random chromosome
            chrom = np.random.randint(1, 23)
            
            # Random start position; the end follows from the segment length
            gen_start = np.random.uniform(0, 200)
            
            # Add the segment
            segments.append({
                'sample1': ind1,
                'sample2': ind2,
                'chrom': chrom,
                'gen_start': gen_start,
                'gen_end': gen_start + segment_cm,
                'gen_seg_len': segment_cm,
                'true_relationship': rel_type  # For evaluation
            })
    
    # Convert to DataFrames
    segments_df = pd.DataFrame(segments)
    individuals_df = pd.DataFrame(individuals)
    
    return segments_df, individuals_df

# Create synthetic data
np.random.seed(42)
synthetic_segments, synthetic_individuals = create_synthetic_ibd_data()

# Display summary
print(f"Generated {len(synthetic_segments)} segments for {len(synthetic_individuals)} individuals")

# Calculate total sharing per pair
pair_sharing = synthetic_segments.groupby(['sample1', 'sample2', 'true_relationship'])['gen_seg_len'].sum().reset_index()
print("\nSample of total IBD sharing between pairs:")
display(pair_sharing.head(10))

# Now let's convert this to a pedigree
ibd_converter = IBDToPedigree()
inferred_pedigree = ibd_converter.segments_to_pedigree(synthetic_segments, synthetic_individuals)

# Visualize the pedigree
inferred_pedigree.visualize()

# Let's check what relationships were inferred
print("\nInferred relationships:")
for u, v, data in inferred_pedigree.graph.edges(data=True):
    if not data.get('is_phantom', False):
        print(f"{u} -> {v} (Type: {data.get('rel_type', 'unknown')}, Certainty: {data.get('certainty', 0):.2f})")

# Extract just the direct relationships (not involving phantom nodes)
real_relationships = [(u, v) for u, v in inferred_pedigree.graph.edges() 
                     if not inferred_pedigree.individuals.get(u, {}).get('is_phantom', False) 
                     and not inferred_pedigree.individuals.get(v, {}).get('is_phantom', False)]

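# Create a pedigree of just the real individuals and their direct relationships
real_pedigree = Pedigree()
for ind_id, attrs in inferred_pedigree.individuals.items():
    if not attrs.get('is_phantom', False):
        real_pedigree.add_individual(ind_id, **attrs)
for u, v in real_relationships:
    certainty = inferred_pedigree.relationships.get((u, v), {}).get('certainty', 0.5)
    real_pedigree.add_relationship(u, v, certainty=certainty)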

# Visualize just the real individuals (without phantom nodes)
real_pedigree.visualize()
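
The certainty score in `infer_relationship_type` falls off linearly with the z-score, so it is not a probability in the statistical sense. As a minimal sketch of one alternative, the block below compares that linear score with normal densities normalized across the candidate relationships. The means and standard deviations are the illustrative values from the synthetic-data generator above, and `linear_score`, `normal_pdf`, and `normalized_likelihoods` are hypothetical helper names, not part of Bonsai.

```python
import math

# Illustrative (mean, std) in cM, copied from the synthetic-data generator above.
# These are assumptions for demonstration, not calibrated values.
REL_PARAMS = {
    'parent-child': (3400, 100),
    'full-sibling': (2550, 180),
    'grandparent': (1700, 160),
    'first-cousin': (850, 150),
}

def linear_score(total_cm, mean, std):
    """Linear falloff: 1.0 at the mean, 0.0 at five standard deviations."""
    return max(0.0, 1 - abs(total_cm - mean) / std / 5)

def normal_pdf(x, mean, std):
    """Density of a normal distribution at x."""
    z = (x - mean) / std
    return math.exp(-0.5 * z * z) / (std * math.sqrt(2 * math.pi))

def normalized_likelihoods(total_cm, params=REL_PARAMS):
    """Normal densities per relationship, rescaled to sum to 1."""
    densities = {rel: normal_pdf(total_cm, m, s) for rel, (m, s) in params.items()}
    total = sum(densities.values())
    if total == 0:
        # Every density underflowed: no candidate is plausible
        return {rel: 0.0 for rel in densities}
    return {rel: d / total for rel, d in densities.items()}

# Compare the two scores for a pair sharing 1900 cM
for rel, (m, s) in REL_PARAMS.items():
    weight = normalized_likelihoods(1900)[rel]
    print(f"{rel:14s} linear={linear_score(1900, m, s):.2f} normalized={weight:.4f}")
```

The normalized weights separate neighboring relationship types more sharply than the linear score does, but they still assume that one of the listed relationships is the true one.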

## 3. Resolving Relationship Conflicts with Constraints

As pedigrees grow larger, we often encounter conflicts between different relationship inferences. Bonsai uses constraint satisfaction to resolve these conflicts and ensure a biologically consistent pedigree. Let's implement some constraint handling techniques:
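
Before building the full solver, here is a minimal, self-contained sketch of three of the checks it performs: birth order, parent count, and cycles. It works on a plain list of `(parent, child)` edges rather than the `Pedigree` class, and `find_violations` is a hypothetical helper introduced only for illustration.

```python
from collections import defaultdict

def find_violations(edges, birth_years, max_parents=2):
    """Return short violation messages for a list of (parent, child) edges."""
    violations = []
    parents_of = defaultdict(list)
    children_of = defaultdict(list)

    for parent, child in edges:
        parents_of[child].append(parent)
        children_of[parent].append(child)
        # Parents must be born before their children (when both years are known)
        if parent in birth_years and child in birth_years:
            if birth_years[parent] >= birth_years[child]:
                violations.append(f"birth_year: {parent} -> {child}")

    # No individual may have more than max_parents parents
    for child, parents in parents_of.items():
        if len(parents) > max_parents:
            violations.append(f"max_parents: {child} has {len(parents)}")

    # Cycle detection with a depth-first search (0 = unvisited, 1 = in progress, 2 = done)
    state = defaultdict(int)

    def visit(node):
        state[node] = 1
        for nxt in children_of.get(node, []):
            if state[nxt] == 1:
                return True  # Reached a node that is still on the current path
            if state[nxt] == 0 and visit(nxt):
                return True
        state[node] = 2
        return False

    nodes = sorted(set(children_of) | set(parents_of))
    if any(state[n] == 0 and visit(n) for n in nodes):
        violations.append("cycle: pedigree is not a DAG")

    return violations

# A deliberately inconsistent example: C has three parents and is also A's parent
edges = [("A", "C"), ("B", "C"), ("X", "C"), ("C", "A")]
for message in find_violations(edges, {"A": 1900, "C": 1930}):
    print(message)
```

The class below follows the same pattern, with each check as a separate function that returns a list of violations, so new constraints can be registered without touching the others.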

In [ ]:
class PedigreeConstraintSolver:
    """
    Class for enforcing constraints and resolving conflicts in pedigree structures.
    """
    
    def __init__(self, pedigree):
        """
        Initialize the constraint solver with a pedigree.
        
        Args:
            pedigree: The Pedigree to analyze and constrain
        """
        self.pedigree = pedigree
        self.constraints = []
        self.violations = []
        
        # Define standard constraints
        self.add_constraint(self.no_cycles_constraint, "No cycles allowed in pedigree")
        self.add_constraint(self.birth_year_constraint, "Parents must be older than children")
        self.add_constraint(self.max_children_constraint, "No more than 15 children per parent")
        self.add_constraint(self.max_parents_constraint, "No more than 2 parents per child")
        self.add_constraint(self.max_generation_gap_constraint, "No more than 60 years between parent and child")
    
    def add_constraint(self, constraint_func, description):
        """
        Add a constraint to check.
        
        Args:
            constraint_func: Function that takes a pedigree and returns list of violations
            description: Description of the constraint
        """
        self.constraints.append({
            'func': constraint_func,
            'description': description
        })
    
    def no_cycles_constraint(self, pedigree):
        """Check for cycles in the pedigree."""
        violations = []
        
        # Check if it's a directed acyclic graph (DAG)
        if not nx.is_directed_acyclic_graph(pedigree.graph):
            # Find all cycles
            for cycle in nx.simple_cycles(pedigree.graph):
                violations.append({
                    'type': 'cycle',
                    'nodes': cycle,
                    'message': f"Cycle detected: {' -> '.join(str(node) for node in cycle)}"
                })
        
        return violations
    
    def birth_year_constraint(self, pedigree):
        """Check that parents are older than their children."""
        violations = []
        
        # Check each parent-child relationship
        for parent, child in pedigree.graph.edges():
            if 'birth_year' in pedigree.individuals[parent] and 'birth_year' in pedigree.individuals[child]:
                parent_birth = pedigree.individuals[parent]['birth_year']
                child_birth = pedigree.individuals[child]['birth_year']
                
                if parent_birth >= child_birth:
                    violations.append({
                        'type': 'birth_year',
                        'parent': parent,
                        'child': child,
                        'message': (f"Birth year violation: {parent} (born {parent_birth}) "
                                   f"cannot be parent of {child} (born {child_birth})")
                    })
        
        return violations
    
    def max_children_constraint(self, pedigree, max_children=15):
        """Check that no parent has too many children."""
        violations = []
        
        # Check each parent
        for node in pedigree.graph.nodes():
            children = list(pedigree.graph.successors(node))
            
            if len(children) > max_children:
                violations.append({
                    'type': 'max_children',
                    'parent': node,
                    'num_children': len(children),
                    'message': f"{node} has {len(children)} children, exceeding maximum of {max_children}"
                })
        
        return violations
    
    def max_parents_constraint(self, pedigree, max_parents=2):
        """Check that no child has too many parents."""
        violations = []
        
        # Check each child
        for node in pedigree.graph.nodes():
            parents = list(pedigree.graph.predecessors(node))
            
            if len(parents) > max_parents:
                violations.append({
                    'type': 'max_parents',
                    'child': node,
                    'num_parents': len(parents),
                    'message': f"{node} has {len(parents)} parents, exceeding maximum of {max_parents}"
                })
        
        return violations
    
    def max_generation_gap_constraint(self, pedigree, max_years=60):
        """Check that the generation gap isn't too large."""
        violations = []
        
        # Check each parent-child relationship
        for parent, child in pedigree.graph.edges():
            if 'birth_year' in pedigree.individuals[parent] and 'birth_year' in pedigree.individuals[child]:
                parent_birth = pedigree.individuals[parent]['birth_year']
                child_birth = pedigree.individuals[child]['birth_year']
                
                if child_birth - parent_birth > max_years:
                    violations.append({
                        'type': 'generation_gap',
                        'parent': parent,
                        'child': child,
                        'gap': child_birth - parent_birth,
                        'message': (f"Generation gap too large: {parent} (born {parent_birth}) "
                                   f"to {child} (born {child_birth}) is {child_birth - parent_birth} years")
                    })
        
        return violations
    
    def check_constraints(self):
        """
        Check all constraints and return violations.
        
        Returns:
            list: All constraint violations
        """
        all_violations = []
        
        for constraint in self.constraints:
            violations = constraint['func'](self.pedigree)
            for v in violations:
                v['constraint'] = constraint['description']
            all_violations.extend(violations)
        
        self.violations = all_violations
        return all_violations
    
    def resolve_violations(self, strategy='remove_lowest_certainty'):
        """
        Attempt to resolve constraint violations.
        
        Args:
            strategy: Strategy to use for resolving violations
                     - 'remove_lowest_certainty': Remove relationships with lowest certainty
                     - 'adjust_birth_years': Try to adjust birth years to resolve conflicts
                     - 'suggest_phantom_nodes': Suggest phantom nodes to resolve conflicts
                     
        Returns:
            dict: Summary of resolution actions
        """
        # First, check for violations
        violations = self.check_constraints()
        
        if not violations:
            return {'status': 'no_violations', 'message': 'No violations to resolve'}
        
        # Track resolution actions
        resolutions = {
            'num_violations': len(violations),
            'num_resolved': 0,
            'actions': []
        }
        
        if strategy == 'remove_lowest_certainty':
            # Identify relationships involved in violations
            problem_relationships = set()
            for violation in violations:
                if violation['type'] == 'cycle':
                    # For cycles, all edges in the cycle are problematic
                    nodes = violation['nodes']
                    for i in range(len(nodes)):
                        u, v = nodes[i], nodes[(i+1) % len(nodes)]
                        problem_relationships.add((u, v))
                elif violation['type'] in ['birth_year', 'generation_gap']:
                    # For birth year or generation gap violations, the specific edge is problematic
                    problem_relationships.add((violation['parent'], violation['child']))
                elif violation['type'] == 'max_children':
                    # For max children violations, all edges from the parent are candidates
                    parent = violation['parent']
                    for child in self.pedigree.graph.successors(parent):
                        problem_relationships.add((parent, child))
                elif violation['type'] == 'max_parents':
                    # For max parents violations, all edges to the child are candidates
                    child = violation['child']
                    for parent in self.pedigree.graph.predecessors(child):
                        problem_relationships.add((parent, child))
            
            # Get certainty scores for problematic relationships
            relationship_certainties = []
            for parent, child in problem_relationships:
                certainty = self.pedigree.relationships.get((parent, child), {}).get('certainty', 0.5)
                relationship_certainties.append((parent, child, certainty))
            
            # Sort by certainty (ascending)
            relationship_certainties.sort(key=lambda x: x[2])
            
            # Start removing relationships with lowest certainty until violations are resolved
            for parent, child, certainty in relationship_certainties:
                # Skip if this edge no longer exists (might have been removed already)
                if not self.pedigree.graph.has_edge(parent, child):
                    continue
                
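                # Remove the relationship from both the graph and the relationship records
                self.pedigree.graph.remove_edge(parent, child)
                self.pedigree.relationships.pop((parent, child), None)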
                
                # Record the action
                resolutions['actions'].append({
                    'action': 'remove_relationship',
                    'parent': parent,
                    'child': child,
                    'certainty': certainty,
                    'message': f"Removed relationship from {parent} to {child} (certainty: {certainty:.2f})"
                })
                
                resolutions['num_resolved'] += 1
                
                # Check if we've resolved all violations
                remaining_violations = self.check_constraints()
                if not remaining_violations:
                    break
        
        elif strategy == 'adjust_birth_years':
            # First, identify birth year violations
            birth_year_violations = [v for v in violations if v['type'] == 'birth_year']
            
            for violation in birth_year_violations:
                parent = violation['parent']
                child = violation['child']
                parent_birth = self.pedigree.individuals[parent]['birth_year']
                child_birth = self.pedigree.individuals[child]['birth_year']
                
                # Adjust birth years to satisfy the constraint
                # We'll try to adjust by the minimum necessary amount
                adjustment = parent_birth - child_birth + 1  # Make parent at least 1 year older
                
                # Adjust parent birth year to be older (one year before the child's)
                new_parent_birth = parent_birth - adjustment
                self.pedigree.individuals[parent]['birth_year'] = new_parent_birth
                
                # Record the action
                resolutions['actions'].append({
                    'action': 'adjust_birth_year',
                    'individual': parent,
                    'old_birth_year': parent_birth,
                    'new_birth_year': new_parent_birth,
                    'message': (f"Adjusted birth year of {parent} from {parent_birth} to {new_parent_birth} "
                              f"to be older than child {child} (born {child_birth})")
                })
                
                resolutions['num_resolved'] += 1
            
            # Re-check for remaining violations
            remaining_violations = self.check_constraints()
            if remaining_violations:
                resolutions['message'] = (f"Adjusted {resolutions['num_resolved']} birth years, "
                                        f"but {len(remaining_violations)} violations remain")
            else:
                resolutions['message'] = f"Resolved all violations by adjusting {resolutions['num_resolved']} birth years"
                
        elif strategy == 'suggest_phantom_nodes':
            # This strategy is more complex, as it requires adding new nodes and edges
            # We'll focus on resolving parent count violations
            max_parents_violations = [v for v in violations if v['type'] == 'max_parents']
            
            for violation in max_parents_violations:
                child = violation['child']
                parents = list(self.pedigree.graph.predecessors(child))
                
                # Skip if there are 2 or fewer parents (no problem)
                if len(parents) <= 2:
                    continue
                
                # Get certainty scores for each parent relationship
                parent_certainties = []
                for parent in parents:
                    certainty = self.pedigree.relationships.get((parent, child), {}).get('certainty', 0.5)
                    parent_certainties.append((parent, certainty))
                
                # Sort by certainty (descending)
                parent_certainties.sort(key=lambda x: x[1], reverse=True)
                
                # Keep the two highest certainty parents
                keep_parents = [p[0] for p in parent_certainties[:2]]
                extra_parents = [p[0] for p in parent_certainties[2:]]
                
                # Create a phantom grandparent node for the extra parents
                phantom_id = f"phantom_grandparent_{child}_{len(extra_parents)}"
                
                # Estimate birth year for phantom node
                child_birth = self.pedigree.individuals[child].get('birth_year')
                if child_birth:
                    phantom_birth = child_birth - 50  # Approximate grandparent age
                    self.pedigree.add_individual(phantom_id, birth_year=phantom_birth, is_phantom=True)
                else:
                    self.pedigree.add_individual(phantom_id, is_phantom=True)
                
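                # Link the phantom grandparent to one retained parent and to each extra
                # "parent", so the extras become siblings of a real parent (aunts/uncles
                # of the child). The phantom is not linked directly to the child, because
                # every edge into the child counts toward its parent limit.
                self.pedigree.add_relationship(phantom_id, keep_parents[0], certainty=0.5, 
                                             rel_type='parent-child', is_inferred=True)
                
                for parent in extra_parents:
                    self.pedigree.add_relationship(phantom_id, parent, certainty=0.3, 
                                                 rel_type='parent-child', is_inferred=True)
                    
                    # Remove the direct relationship to the child
                    self.pedigree.graph.remove_edge(parent, child)
                    self.pedigree.relationships.pop((parent, child), None)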
                
                # Record the action
                resolutions['actions'].append({
                    'action': 'add_phantom_grandparent',
                    'phantom_id': phantom_id,
                    'child': child,
                    'extra_parents': extra_parents,
                    'message': (f"Added phantom grandparent {phantom_id} connecting {len(extra_parents)} "
                              f"extra parents to child {child}")
                })
                
                resolutions['num_resolved'] += 1
            
            # Re-check for remaining violations
            remaining_violations = self.check_constraints()
            if remaining_violations:
                resolutions['message'] = (f"Added {resolutions['num_resolved']} phantom nodes, "
                                        f"but {len(remaining_violations)} violations remain")
            else:
                resolutions['message'] = f"Resolved all violations by adding {resolutions['num_resolved']} phantom nodes"
        
        return resolutions
    
    def visualize_violations(self, figsize=(15, 10)):
        """
        Visualize the pedigree with constraint violations highlighted.
        
        Returns:
            matplotlib figure
        """
        plt.figure(figsize=figsize)
        
        if not self.pedigree.graph.nodes():
            plt.text(0.5, 0.5, "Empty Pedigree", ha='center', va='center')
            plt.axis('off')
            return plt.gcf()
        
        # Check constraints if not already done
        if not self.violations:
            self.check_constraints()
        
        # Get problem nodes and edges
        problem_nodes = set()
        problem_edges = set()
        
        for violation in self.violations:
            if violation['type'] == 'cycle':
                # For cycles, all nodes and edges are problematic
                cycle_nodes = violation['nodes']
                problem_nodes.update(cycle_nodes)
                for i in range(len(cycle_nodes)):
                    u, v = cycle_nodes[i], cycle_nodes[(i+1) % len(cycle_nodes)]
                    problem_edges.add((u, v))
            elif violation['type'] in ['birth_year', 'generation_gap']:
                # For birth year and generation gap violations, both parent and child are problematic
                problem_nodes.add(violation['parent'])
                problem_nodes.add(violation['child'])
                problem_edges.add((violation['parent'], violation['child']))
            elif violation['type'] == 'max_children':
                # For max children violations, the parent is problematic
                problem_nodes.add(violation['parent'])
            elif violation['type'] == 'max_parents':
                # For max parents violations, the child is problematic
                problem_nodes.add(violation['child'])
        
        # Compute a base layout (fixed seed so the figure is reproducible)
        pos = nx.spring_layout(self.pedigree.graph, seed=42)
        
        # Adjust y position based on generations
        generations = self.pedigree.assign_generations()
        max_gen = max(generations.values()) if generations else 0
        for node in self.pedigree.graph.nodes():
            gen = generations.get(node, 0)
            norm_gen = 1 - (gen / max(1, max_gen))
            pos[node] = (pos[node][0], 0.8 * norm_gen + 0.1)
        
        # Node colors based on whether they're problematic
        node_colors = ['red' if node in problem_nodes else 'lightblue' 
                      for node in self.pedigree.graph.nodes()]
        
        # Draw nodes
        nx.draw_networkx_nodes(self.pedigree.graph, pos, node_color=node_colors, 
                              node_size=500, alpha=0.8)
        
        # Draw normal edges
        normal_edges = [(u, v) for u, v in self.pedigree.graph.edges() 
                       if (u, v) not in problem_edges]
        if normal_edges:
            nx.draw_networkx_edges(self.pedigree.graph, pos, edgelist=normal_edges, 
                                 width=1.5, alpha=0.7, arrows=True, 
                                 arrowsize=20, arrowstyle='-|>')
        
        # Draw problem edges
        if problem_edges:
            nx.draw_networkx_edges(self.pedigree.graph, pos, edgelist=problem_edges, 
                                 width=2.0, alpha=1.0, arrows=True, 
                                 arrowsize=20, arrowstyle='-|>',
                                 edge_color='red')
        
        # Add labels
        labels = {}
        for node in self.pedigree.graph.nodes():
            label_parts = [str(node)]
            if 'birth_year' in self.pedigree.individuals[node]:
                label_parts.append(f"({self.pedigree.individuals[node]['birth_year']})")
            labels[node] = "\n".join(label_parts)
        
        nx.draw_networkx_labels(self.pedigree.graph, pos, labels=labels, font_size=10)
        
        # Add title with violation count
        plt.title(f'Pedigree with {len(self.violations)} Constraint Violations', size=15)
        plt.axis('off')
        plt.tight_layout()
        
        return plt.gcf()

# Let's create a pedigree with some constraint violations for demonstration
def create_pedigree_with_violations():
    """Create a pedigree with deliberate constraint violations."""
    pedigree = Pedigree()
    
    # Add individuals
    pedigree.add_individual("A", birth_year=1900, sex="M")
    pedigree.add_individual("B", birth_year=1905, sex="F")
    pedigree.add_individual("C", birth_year=1930, sex="M")
    pedigree.add_individual("D", birth_year=1932, sex="F")
    pedigree.add_individual("E", birth_year=1935, sex="M")
    pedigree.add_individual("F", birth_year=1960, sex="F")
    pedigree.add_individual("G", birth_year=1962, sex="M")
    pedigree.add_individual("H", birth_year=1985, sex="F")
    pedigree.add_individual("I", birth_year=1987, sex="M")
    pedigree.add_individual("J", birth_year=2010, sex="F")
    
    # Add normal relationships
    pedigree.add_relationship("A", "C", certainty=0.9)
    pedigree.add_relationship("B", "C", certainty=0.9)
    pedigree.add_relationship("A", "D", certainty=0.9)
    pedigree.add_relationship("B", "D", certainty=0.9)
    pedigree.add_relationship("C", "F", certainty=0.9)
    pedigree.add_relationship("D", "G", certainty=0.9)
    pedigree.add_relationship("F", "H", certainty=0.9)
    pedigree.add_relationship("G", "H", certainty=0.9)
    pedigree.add_relationship("G", "I", certainty=0.9)
    pedigree.add_relationship("H", "J", certainty=0.9)
    pedigree.add_relationship("I", "J", certainty=0.9)
    
    # Add problematic relationships
    
    # 1. Birth year violation (child born before parent)
    pedigree.add_relationship("E", "A", certainty=0.3)  # E (1935) -> A (1900)
    
    # 2. Create a cycle
    pedigree.add_relationship("F", "A", certainty=0.2)  # F (1960) -> A (1900) -> C (1930) -> F (1960)
    
    # 3. Add too many parents for one child
    pedigree.add_relationship("E", "J", certainty=0.4)  # Third parent for J
    pedigree.add_relationship("F", "J", certainty=0.3)  # Fourth parent for J
    
    return pedigree

# Create pedigree with violations
problem_pedigree = create_pedigree_with_violations()

# Visualize the problematic pedigree
problem_pedigree.visualize()

# Create constraint solver
solver = PedigreeConstraintSolver(problem_pedigree)

# Check for violations
violations = solver.check_constraints()
print(f"Found {len(violations)} constraint violations:")
for i, violation in enumerate(violations, start=1):
    print(f"{i}. {violation['message']} ({violation['constraint']})")

# Visualize violations
solver.visualize_violations()

# Resolve violations using the "remove_lowest_certainty" strategy
resolution_results = solver.resolve_violations(strategy='remove_lowest_certainty')

print("\nResolution results:")
print(f"- {resolution_results['num_resolved']} of {resolution_results['num_violations']} violations resolved")
for action in resolution_results['actions']:
    print(f"- {action['message']}")

# Visualize the resolved pedigree
problem_pedigree.visualize()

# Check that all violations are resolved
remaining_violations = solver.check_constraints()
print(f"\nRemaining violations: {len(remaining_violations)}")
for violation in remaining_violations:
    print(f"- {violation['message']}")
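
The three kinds of violation planted in the demonstration pedigree (a parent born after a child, a cycle, and more than two parents) can each be detected with a few lines of graph code. The sketch below works on a plain list of `(parent, child)` edges and uses only the standard library. It is not Bonsai's implementation and not the `PedigreeConstraintSolver` used above; the function name `find_violations` and its return format are assumptions made for illustration.

```python
from collections import defaultdict

def find_violations(birth_years, edges, max_parents=2):
    """Return (kind, detail) tuples for a list of (parent, child) edges.

    Hypothetical helper for illustration only; not part of Bonsai or of
    the PedigreeConstraintSolver class in this lab.
    """
    violations = []
    parents = defaultdict(list)
    children = defaultdict(list)

    for parent, child in edges:
        parents[child].append(parent)
        children[parent].append(child)
        # Birth order: a parent must be born before the child
        p_year, c_year = birth_years.get(parent), birth_years.get(child)
        if p_year is not None and c_year is not None and p_year >= c_year:
            violations.append(("birth_order", (parent, child)))

    # Parent count: flag any child with more than max_parents parents
    for child, parent_list in parents.items():
        if len(parent_list) > max_parents:
            violations.append(("max_parents", child))

    # Cycles: depth-first search, marking nodes as unvisited, in progress, or done
    UNVISITED, IN_PROGRESS, DONE = 0, 1, 2
    state = defaultdict(int)

    def has_cycle(node):
        state[node] = IN_PROGRESS
        for nxt in children[node]:
            if state[nxt] == IN_PROGRESS:
                return True  # reached a node that is still on the current path
            if state[nxt] == UNVISITED and has_cycle(nxt):
                return True
        state[node] = DONE
        return False

    nodes = set(birth_years) | {n for edge in edges for n in edge}
    if any(state[n] == UNVISITED and has_cycle(n) for n in sorted(nodes)):
        violations.append(("cycle", None))

    return violations


# A small slice of the demonstration pedigree, including two of the bad edges
years = {"A": 1900, "C": 1930, "F": 1960, "E": 1935}
edges = [("A", "C"), ("C", "F"), ("E", "A"), ("F", "A")]
result = find_violations(years, edges)
for kind, detail in result:
    print(kind, detail)
```

Note that the edge `F -> A` triggers two checks at once: it violates birth order and it closes the cycle `A -> C -> F -> A`. A resolver that removes it fixes both problems with one action, which is one reason removing the lowest-certainty edge first can work well.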

## 4. Integrating Demographic and Historical Information

One of the key strengths of Bonsai is its ability to incorporate non-genetic information to refine pedigree reconstruction. Let's explore how to integrate demographic and historical data:

In [ ]:
class DemographicEnhancer:
    """
    A class for enhancing pedigrees with demographic and historical information.
    """
    
    def __init__(self, pedigree):
        """
        Initialize with a pedigree.
        
        Args:
            pedigree: Pedigree object to enhance
        """
        self.pedigree = pedigree
        
        # Define demographic parameters
        self.age_parameters = {
            'min_reproduction_age': 15,
            'max_reproduction_age': 55,
            'mean_generation_gap': 30,
            'std_generation_gap': 5
        }
        
        # Historical constraints (e.g., wars, migrations, census records)
        self.historical_events = []
    
    def add_historical_event(self, event_name, start_year, end_year, affected_regions=None, description=None):
        """
        Add a historical event that may constrain pedigree relationships.
        
        Args:
            event_name: Name of the event
            start_year: Starting year
            end_year: Ending year
            affected_regions: Regions affected by the event
            description: Description of how the event affects relationships
        """
        self.historical_events.append({
            'name': event_name,
            'start_year': start_year,
            'end_year': end_year,
            'affected_regions': affected_regions,
            'description': description
        })
    
    def infer_missing_birth_years(self):
        """
        Infer missing birth years based on known relationships and demographic parameters.
        
        Returns:
            dict: Mapping of individuals to inferred birth years
        """
        inferred_years = {}
        
        # First, create a copy of known birth years
        known_years = {}
        for ind_id, attrs in self.pedigree.individuals.items():
            if 'birth_year' in attrs:
                known_years[ind_id] = attrs['birth_year']
        
        # Iterate until no new birth years are inferred
        while True:
            new_inferences = False
            
            # Check each individual without a known birth year
            for ind_id, attrs in self.pedigree.individuals.items():
                if ind_id in known_years or ind_id in inferred_years:
                    continue  # Already have a birth year
                
                # Birth years (known or already inferred) of parents and children
                parent_years = [
                    known_years[p] if p in known_years else inferred_years[p]
                    for p in self.pedigree.get_parents(ind_id)
                    if p in known_years or p in inferred_years
                ]
                children_years = [
                    known_years[c] if c in known_years else inferred_years[c]
                    for c in self.pedigree.get_children(ind_id)
                    if c in known_years or c in inferred_years
                ]
                
                # Estimate one mean generation gap away from each side
                gap = self.age_parameters['mean_generation_gap']
                estimates = []
                if parent_years:
                    estimates.append(min(parent_years) + gap)
                if children_years:
                    estimates.append(max(children_years) - gap)
                
                if estimates:
                    # Average the estimates when both exist; using parents alone
                    # could place someone in the same year as their own child
                    inferred_years[ind_id] = int(round(sum(estimates) / len(estimates)))
                    new_inferences = True
                    continue
                
                # Check if we can infer from siblings
                siblings = self.pedigree.get_siblings(ind_id)
                sibling_years = []
                for sibling in siblings:
                    if sibling in known_years:
                        sibling_years.append(known_years[sibling])
                    elif sibling in inferred_years:
                        sibling_years.append(inferred_years[sibling])
                
                if sibling_years:
                    # Infer from siblings (average with small offset)
                    avg_sibling_year = sum(sibling_years) / len(sibling_years)
                    # Add a small random offset to avoid exact same birth years
                    # (call np.random.seed first if you need reproducible results)
                    offset = np.random.normal(0, 2)
                    inferred_years[ind_id] = int(avg_sibling_year + offset)
                    new_inferences = True
            
            # If no new inferences were made, we're done
            if not new_inferences:
                break
        
        return inferred_years
    
    def update_pedigree_with_inferences(self):
        """
        Update the pedigree with inferred demographic information.
        
        Returns:
            int: Number of updates made
        """
        num_updates = 0
        
        # Infer missing birth years
        inferred_years = self.infer_missing_birth_years()
        
        # Update the pedigree
        for ind_id, birth_year in inferred_years.items():
            self.pedigree.individuals[ind_id]['birth_year'] = birth_year
            self.pedigree.individuals[ind_id]['birth_year_inferred'] = True
            num_updates += 1
        
        return num_updates
    
    def verify_historical_consistency(self):
        """
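        Flag parent-child relationships that overlap a recorded historical event.
        
        A relationship is flagged when the child was born during the event, the
        parent's age at that birth is within the reproductive range, and (if the
        event lists regions) the parent or child is in an affected region. Flags
        are prompts for manual review, not proof that a relationship is wrong.
        
        Returns:
            list: Dicts describing each flagged relationship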
        """
        inconsistencies = []
        
        for event in self.historical_events:
            event_start = event['start_year']
            event_end = event['end_year']
            affected_regions = event['affected_regions'] or []
            
            # Check for relationships that span the event in affected regions
            for parent, child in self.pedigree.graph.edges():
                parent_data = self.pedigree.individuals[parent]
                child_data = self.pedigree.individuals[child]
                
                # Skip if birth years are not available
                if 'birth_year' not in parent_data or 'birth_year' not in child_data:
                    continue
                
                parent_birth = parent_data['birth_year']
                child_birth = child_data['birth_year']
                
                # Skip if the parent was born after the event ended
                # or the child was born before it began
                if parent_birth > event_end or child_birth < event_start:
                    continue
                
                # Check regions if specified
                if affected_regions:
                    parent_region = parent_data.get('region')
                    child_region = child_data.get('region')
                    if parent_region not in affected_regions and child_region not in affected_regions:
                        continue
                
                # If the parent was having children during the event, check if it's plausible
                parent_age_at_child_birth = child_birth - parent_birth
                if (event_start <= child_birth <= event_end and
                    parent_age_at_child_birth >= self.age_parameters['min_reproduction_age'] and 
                    parent_age_at_child_birth <= self.age_parameters['max_reproduction_age']):
                    
                    # This relationship might be affected by the historical event
                    # For example, the parent might have been away during a war
                    inconsistencies.append({
                        'type': 'historical_event',
                        'event': event['name'],
                        'parent': parent,
                        'child': child,
                        'parent_birth': parent_birth,
                        'child_birth': child_birth,
                        'message': f"Relationship may be affected by {event['name']} ({event_start}-{event_end})"
                    })
        
        return inconsistencies
    
    def recommend_additional_information(self):
        """
        Recommend additional information that would be useful for resolving ambiguities.
        
        Returns:
            dict: Recommendations for additional information
        """
        recommendations = {
            'birth_years': [],
            'locations': [],
            'relationships': []
        }
        
        # Identify individuals without birth years
        for ind_id, attrs in self.pedigree.individuals.items():
            if 'birth_year' not in attrs and not attrs.get('is_phantom', False):
                recommendations['birth_years'].append(ind_id)
        
        # Identify individuals without locations
        for ind_id, attrs in self.pedigree.individuals.items():
            if 'region' not in attrs and not attrs.get('is_phantom', False):
                recommendations['locations'].append(ind_id)
        
        # Identify ambiguous relationship sets
        for ind_id in self.pedigree.individuals:
            # Look for individuals with too many or too few parents
            parents = self.pedigree.get_parents(ind_id)
            if len(parents) > 2:
                recommendations['relationships'].append({
                    'type': 'too_many_parents',
                    'individual': ind_id,
                    'parents': parents,
                    'message': f"{ind_id} has {len(parents)} parents, need to determine correct two"
                })
            elif len(parents) == 1:
                recommendations['relationships'].append({
                    'type': 'missing_parent',
                    'individual': ind_id,
                    'known_parents': parents,
                    'message': f"{ind_id} has only {len(parents)} known parent(s), might need to identify the other"
                })
        
        return recommendations
    
    def assign_regions_based_on_family(self):
        """
        Assign regions to individuals based on family members when possible.
        
        Returns:
            int: Number of regions assigned
        """
        num_assigned = 0
        
        # Keep track of individuals with assigned regions
        known_regions = {}
        for ind_id, attrs in self.pedigree.individuals.items():
            if 'region' in attrs:
                known_regions[ind_id] = attrs['region']
        
        # Iterate until no new regions are assigned
        while True:
            new_assignments = False
            
            # For each individual without a region
            for ind_id, attrs in self.pedigree.individuals.items():
                if ind_id in known_regions or attrs.get('is_phantom', False):
                    continue
                
                # Collect regions from family members
                family_regions = []
                
                # Check parents
                for parent in self.pedigree.get_parents(ind_id):
                    if parent in known_regions:
                        family_regions.append(known_regions[parent])
                
                # Check children
                for child in self.pedigree.get_children(ind_id):
                    if child in known_regions:
                        family_regions.append(known_regions[child])
                
                # Check siblings
                for sibling in self.pedigree.get_siblings(ind_id):
                    if sibling in known_regions:
                        family_regions.append(known_regions[sibling])
                
                # If we have family regions, pick the most common one
                if family_regions:
                    region_counts = Counter(family_regions)
                    most_common_region = region_counts.most_common(1)[0][0]
                    
                    # Assign the region
                    self.pedigree.individuals[ind_id]['region'] = most_common_region
                    self.pedigree.individuals[ind_id]['region_inferred'] = True
                    known_regions[ind_id] = most_common_region
                    num_assigned += 1
                    new_assignments = True
            
            # If no new assignments were made, we're done
            if not new_assignments:
                break
        
        return num_assigned
    
    def highlight_relationships_by_certainty(self, figsize=(12, 8)):
        """
        Visualize the pedigree with edges colored by certainty.
        
        Returns:
            matplotlib figure
        """
        plt.figure(figsize=figsize)
        
        if not self.pedigree.graph.nodes():
            plt.text(0.5, 0.5, "Empty Pedigree", ha='center', va='center')
            plt.axis('off')
            return plt.gcf()
        
        # Assign generations
        generations = self.pedigree.assign_generations()
        
        # Create a position layout based on generations
        pos = nx.spring_layout(self.pedigree.graph, seed=42)
        
        # Adjust y position based on generation
        max_gen = max(generations.values()) if generations else 0
        for node in self.pedigree.graph.nodes():
            gen = generations.get(node, 0)
            norm_gen = 1 - (gen / max(1, max_gen))
            pos[node] = (pos[node][0], 0.8 * norm_gen + 0.1)
        
        # Node properties
        node_colors = []
        node_sizes = []
        
        for node in self.pedigree.graph.nodes():
            attrs = self.pedigree.individuals[node]
            
            # Phantom nodes are gray and smaller
            if attrs.get('is_phantom', False):
                node_colors.append('lightgray')
                node_sizes.append(300)
            # Nodes with inferred attributes are lighter
            elif attrs.get('birth_year_inferred', False) or attrs.get('region_inferred', False):
                node_colors.append('lightskyblue')
                node_sizes.append(400)
            # Regular nodes
            else:
                node_colors.append('steelblue')
                node_sizes.append(500)
        
        # Draw nodes
        nx.draw_networkx_nodes(self.pedigree.graph, pos, 
                              node_color=node_colors,
                              node_size=node_sizes, 
                              alpha=0.8)
        
        # Draw edges with color based on certainty
        edges = []
        edge_colors = []
        edge_widths = []
        
        for u, v in self.pedigree.graph.edges():
            certainty = self.pedigree.relationships.get((u, v), {}).get('certainty', 0.5)
            is_inferred = self.pedigree.relationships.get((u, v), {}).get('is_inferred', False)
            
            edges.append((u, v))
            
            # Use color to indicate certainty (red=low, green=high)
            color = plt.cm.RdYlGn(certainty)
            edge_colors.append(color)
            
            # Use thinner lines for inferred relationships
            if is_inferred:
                edge_widths.append(1.0)
            else:
                edge_widths.append(1.5)
        
        # Draw edges
        nx.draw_networkx_edges(self.pedigree.graph, pos, 
                             edgelist=edges,
                             edge_color=edge_colors,
                             width=edge_widths,
                             arrows=True,
                             arrowsize=20, 
                             arrowstyle='-|>')
        
        # Add labels
        labels = {}
        for node in self.pedigree.graph.nodes():
            attrs = self.pedigree.individuals[node]
            label_parts = [str(node)]
            
            # Add birth year if available
            if 'birth_year' in attrs:
                year_label = f"({attrs['birth_year']})"
                if attrs.get('birth_year_inferred', False):
                    year_label = f"({attrs['birth_year']}*)"
                label_parts.append(year_label)
            
            # Add region if available
            if 'region' in attrs:
                region_label = attrs['region']
                if attrs.get('region_inferred', False):
                    region_label = f"{attrs['region']}*"
                label_parts.append(region_label)
            
            labels[node] = "\n".join(label_parts)
        
        nx.draw_networkx_labels(self.pedigree.graph, pos, labels=labels, font_size=9)
        
        # Add a title
        plt.title('Pedigree with Demographic Information', size=15)
        
        # Add a colorbar to show certainty scale
        sm = plt.cm.ScalarMappable(cmap=plt.cm.RdYlGn, norm=plt.Normalize(0, 1))
        sm.set_array([])
        cbar = plt.colorbar(sm, ax=plt.gca(), orientation='horizontal', pad=0.05, fraction=0.05)
        cbar.set_label('Relationship Certainty')
        
        plt.axis('off')
        plt.tight_layout()
        
        return plt.gcf()


# Let's create a pedigree and enhance it with demographic information
def create_demographic_pedigree():
    """Create a pedigree with some demographic information."""
    pedigree = Pedigree()
    
    # Add individuals with demographic data
    pedigree.add_individual("A", birth_year=1900, sex="M", region="England")
    pedigree.add_individual("B", birth_year=1905, sex="F", region="England")
    pedigree.add_individual("C", birth_year=1930, sex="M", region="England")
    pedigree.add_individual("D", birth_year=1932, sex="F", region="Wales")
    pedigree.add_individual("E", birth_year=1935, sex="M", region="Scotland")
    pedigree.add_individual("F", birth_year=1960, sex="F", region="England")
    pedigree.add_individual("G", birth_year=1962, sex="M", region="Wales")
    pedigree.add_individual("H", birth_year=1985, sex="F")  # No region
    pedigree.add_individual("I", birth_year=1987, sex="M")  # No region
    pedigree.add_individual("J", sex="F")  # No birth year or region
    pedigree.add_individual("K", sex="M")  # No birth year or region
    pedigree.add_individual("L", birth_year=2015, sex="F")  # Only birth year
    
    # Add relationships with certainty scores
    pedigree.add_relationship("A", "C", certainty=0.98)
    pedigree.add_relationship("B", "C", certainty=0.98)
    pedigree.add_relationship("A", "D", certainty=0.97)
    pedigree.add_relationship("B", "D", certainty=0.97)
    pedigree.add_relationship("A", "E", certainty=0.95)
    pedigree.add_relationship("B", "E", certainty=0.95)
    pedigree.add_relationship("C", "F", certainty=0.96)
    pedigree.add_relationship("D", "G", certainty=0.96)
    pedigree.add_relationship("F", "H", certainty=0.85)
    pedigree.add_relationship("G", "H", certainty=0.85)
    pedigree.add_relationship("G", "I", certainty=0.88)
    pedigree.add_relationship("H", "J", certainty=0.77)
    pedigree.add_relationship("I", "J", certainty=0.77)
    pedigree.add_relationship("J", "L", certainty=0.65)
    pedigree.add_relationship("K", "L", certainty=0.40)  # Low certainty relationship
    
    return pedigree

# Create a pedigree with demographic data
demo_pedigree = create_demographic_pedigree()

# Visualize the initial pedigree
demo_pedigree.visualize()

# Create a demographic enhancer
enhancer = DemographicEnhancer(demo_pedigree)

# Add some historical events
enhancer.add_historical_event(
    event_name="World War II",
    start_year=1939,
    end_year=1945,
    affected_regions=["England", "Wales", "Scotland"],
    description="Major war that separated families and reduced birth rates"
)

enhancer.add_historical_event(
    event_name="UK Immigration Wave",
    start_year=1950,
    end_year=1970,
    affected_regions=["England", "Wales"],
    description="Period of increased immigration to the UK"
)

# Infer missing birth years
num_birth_years_inferred = enhancer.update_pedigree_with_inferences()
print(f"Inferred {num_birth_years_inferred} missing birth years")

# Assign regions based on family
num_regions_assigned = enhancer.assign_regions_based_on_family()
print(f"Assigned {num_regions_assigned} regions based on family")

# Check for historical consistency
historical_inconsistencies = enhancer.verify_historical_consistency()
print(f"Found {len(historical_inconsistencies)} potential historical inconsistencies:")
for i, inconsistency in enumerate(historical_inconsistencies, start=1):
    print(f"{i}. {inconsistency['message']}")

# Get recommendations for additional information
recommendations = enhancer.recommend_additional_information()
print("\nRecommendations for additional information:")
print(f"- Birth years needed for {len(recommendations['birth_years'])} individuals")
print(f"- Locations needed for {len(recommendations['locations'])} individuals")
print(f"- {len(recommendations['relationships'])} relationship issues to resolve")
for issue in recommendations['relationships']:
    print(f"  - {issue['message']}")

# Visualize the enhanced pedigree
enhancer.highlight_relationships_by_certainty()
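
The `age_parameters` dictionary in `DemographicEnhancer` defines `std_generation_gap`, but none of the methods above use it. One way it could be put to work is to score how plausible a parent-child age gap is, rather than only accepting or rejecting it. The sketch below is an assumption for illustration, not Bonsai's age model: the function name `generation_gap_score` is hypothetical, and the Gaussian shape is a simplification (real parental-age distributions are skewed and differ between mothers and fathers).

```python
import math

# Same values as DemographicEnhancer.age_parameters in this lab
AGE_PARAMETERS = {
    'min_reproduction_age': 15,
    'max_reproduction_age': 55,
    'mean_generation_gap': 30,
    'std_generation_gap': 5,
}

def generation_gap_score(parent_birth, child_birth, params=AGE_PARAMETERS):
    """Score a parent-child age gap from 0 (implausible) to 1 (typical).

    Hypothetical helper: gaps outside the reproductive range score 0;
    inside it, the score is an unnormalised Gaussian centred on the mean gap.
    """
    gap = child_birth - parent_birth
    if not params['min_reproduction_age'] <= gap <= params['max_reproduction_age']:
        return 0.0
    z = (gap - params['mean_generation_gap']) / params['std_generation_gap']
    return math.exp(-0.5 * z * z)


# Score a few (parent_birth, child_birth) pairs
for parent_birth, child_birth in [(1900, 1930), (1905, 1930), (1962, 1985), (1935, 1900)]:
    score = generation_gap_score(parent_birth, child_birth)
    print(f"{parent_birth} -> {child_birth}: {score:.3f}")
```

A score like this could be multiplied into an edge's `certainty` to down-weight relationships whose ages are possible but unusual, while the hard range check still rejects the impossible ones.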

## 5. Exercises

Complete the following exercises to test your understanding of multi-sample relationship inference with Bonsai:

### Exercise 1: Enhancing Pedigree Visualization
Modify the `Pedigree.visualize()` method to include sex information in the visualization (e.g., different node shapes for males and females).

### Exercise 2: Implementing a New Constraint
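Extend the `PedigreeConstraintSolver` class with a new constraint that ensures individuals with the same parents were born at least 9 months apart, unless they are twins. Because the pedigrees in this lab store only `birth_year`, you will need to record a finer-grained birth date to check this.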

### Exercise 3: Multi-Generation Pedigree Recovery
Create a more complex synthetic pedigree spanning at least 4 generations and 20 individuals. Randomly remove 30% of the individuals and their relationships, then use the techniques from this lab to recover the missing individuals and relationships.
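
As a possible starting point for Exercise 3, the sketch below generates a four-generation synthetic family and hides 30% of its members. It uses plain dictionaries and edge lists rather than the lab's `Pedigree` class so that it runs on its own. The function names `make_synthetic_pedigree` and `mask_individuals`, the one-founder-couple design, and the birth-year jitter are all assumptions chosen for illustration.

```python
import random

def make_synthetic_pedigree(n_generations=4, children_per_couple=2,
                            start_year=1900, generation_gap=30, seed=0):
    """Build a synthetic pedigree descending from one founder couple.

    Hypothetical helper. Each child (except in the last generation) is paired
    with an unrelated married-in spouse to form a couple in the next generation.
    Returns (individuals, edges) where edges are (parent_id, child_id) tuples.
    """
    rng = random.Random(seed)
    individuals = {}
    edges = []

    def new_person(generation, sex):
        ind_id = f"P{len(individuals) + 1}"
        year = start_year + generation * generation_gap + rng.randint(-3, 3)
        individuals[ind_id] = {'birth_year': year, 'sex': sex, 'generation': generation}
        return ind_id

    couples = [(new_person(0, 'M'), new_person(0, 'F'))]
    for generation in range(1, n_generations):
        next_couples = []
        for father, mother in couples:
            for _ in range(children_per_couple):
                sex = rng.choice('MF')
                child = new_person(generation, sex)
                edges.extend([(father, child), (mother, child)])
                if generation < n_generations - 1:
                    spouse = new_person(generation, 'F' if sex == 'M' else 'M')
                    pair = (child, spouse) if sex == 'M' else (spouse, child)
                    next_couples.append(pair)
        couples = next_couples
    return individuals, edges

def mask_individuals(individuals, edges, fraction=0.3, seed=0):
    """Remove a random fraction of individuals and every edge touching them."""
    rng = random.Random(seed)
    n_remove = round(fraction * len(individuals))
    removed = set(rng.sample(sorted(individuals), n_remove))
    kept_individuals = {i: a for i, a in individuals.items() if i not in removed}
    kept_edges = [(p, c) for p, c in edges if p not in removed and c not in removed]
    return kept_individuals, kept_edges, removed


truth_individuals, truth_edges = make_synthetic_pedigree()
observed_individuals, observed_edges, removed = mask_individuals(truth_individuals, truth_edges)
print(f"Full pedigree: {len(truth_individuals)} individuals, {len(truth_edges)} edges")
print(f"After masking: {len(observed_individuals)} individuals, {len(observed_edges)} edges")
```

To continue the exercise, load `observed_individuals` and `observed_edges` into a `Pedigree` with `add_individual` and `add_relationship` as in the earlier cells, then compare what you recover against `truth_edges` and `removed`.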

In [ ]:
# Example solution for Exercise 1: Enhancing Pedigree Visualization

def visualize_with_sex(pedigree, highlight_nodes=None, node_labels=True, figsize=(12, 8)):
    """
    Visualize the pedigree with different node shapes for males and females.
    
    Args:
        pedigree: The Pedigree object to visualize
        highlight_nodes: List of node IDs to highlight
        node_labels: Whether to show node labels
        figsize: Figure size tuple
        
    Returns:
        matplotlib figure
    """
    plt.figure(figsize=figsize)
    
    if not pedigree.graph.nodes():
        plt.text(0.5, 0.5, "Empty Pedigree", ha='center', va='center')
        plt.axis('off')
        return plt.gcf()
    
    # Assign generations
    generations = pedigree.assign_generations()
    
    # Create a position layout based on generations
    pos = nx.spring_layout(pedigree.graph, seed=42)
    
    # Adjust y position based on generation
    max_gen = max(generations.values()) if generations else 0
    for node in pedigree.graph.nodes():
        gen = generations.get(node, 0)
        norm_gen = 1 - (gen / max(1, max_gen))
        pos[node] = (pos[node][0], 0.8 * norm_gen + 0.1)
    
    # Separate males and females for different shapes
    male_nodes = []
    female_nodes = []
    unknown_nodes = []
    phantom_nodes = []
    
    for node in pedigree.graph.nodes():
        attrs = pedigree.individuals[node]
        
        if attrs.get('is_phantom', False):
            phantom_nodes.append(node)
        elif attrs.get('sex', '').upper() == 'M':
            male_nodes.append(node)
        elif attrs.get('sex', '').upper() == 'F':
            female_nodes.append(node)
        else:
            unknown_nodes.append(node)
    
    # Define node colors and sizes
    male_color = 'skyblue'
    female_color = 'lightpink'
    unknown_color = 'lightgray'
    phantom_color = 'white'
    
    # Draw nodes by sex
    if male_nodes:
        nx.draw_networkx_nodes(pedigree.graph, pos, nodelist=male_nodes,
                              node_color=male_color, node_shape='s',  # Square for males
                              node_size=500, alpha=0.8)
    
    if female_nodes:
        nx.draw_networkx_nodes(pedigree.graph, pos, nodelist=female_nodes,
                              node_color=female_color, node_shape='o',  # Circle for females
                              node_size=500, alpha=0.8)
    
    if unknown_nodes:
        nx.draw_networkx_nodes(pedigree.graph, pos, nodelist=unknown_nodes,
                              node_color=unknown_color, node_shape='d',  # Diamond for unknown
                              node_size=500, alpha=0.8)
    
    if phantom_nodes:
        nx.draw_networkx_nodes(pedigree.graph, pos, nodelist=phantom_nodes,
                              node_color=phantom_color, node_shape='h',  # Hexagon for phantom
                              node_size=300, alpha=0.6, edgecolors='gray')
    
    # Highlight specific nodes if provided
    if highlight_nodes:
        highlight_nodes = [n for n in highlight_nodes if n in pedigree.graph.nodes()]
        if highlight_nodes:
            nx.draw_networkx_nodes(pedigree.graph.subgraph(highlight_nodes), pos, 
                                  node_color='red', node_shape='*',  # Star for highlighted
                                  node_size=600, alpha=0.8)
    
    # Draw edges with certainty-based alpha
    edge_alphas = [pedigree.relationships.get((u, v), {}).get('certainty', 1.0) for u, v in pedigree.graph.edges()]
    for i, (u, v) in enumerate(pedigree.graph.edges()):
        # Draw edges with alpha based on certainty
        alpha = edge_alphas[i]
        nx.draw_networkx_edges(pedigree.graph, pos, edgelist=[(u, v)], 
                             width=1.5, alpha=alpha, arrows=True, 
                             arrowsize=20, arrowstyle='-|>')
    
    # Add labels
    if node_labels:
        labels = {}
        for node in pedigree.graph.nodes():
            label_parts = [str(node)]
            if 'birth_year' in pedigree.individuals[node]:
                label_parts.append(f"({pedigree.individuals[node]['birth_year']})")
            labels[node] = "\n".join(label_parts)
        
        nx.draw_networkx_labels(pedigree.graph, pos, labels=labels, font_size=10)
    
    # Add a legend for node shapes
    legend_elements = [
        plt.Line2D([0], [0], marker='s', color='w', markerfacecolor=male_color, markersize=15, label='Male'),
        plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=female_color, markersize=15, label='Female'),
        plt.Line2D([0], [0], marker='d', color='w', markerfacecolor=unknown_color, markersize=15, label='Unknown Sex'),
        plt.Line2D([0], [0], marker='h', color='w', markerfacecolor=phantom_color, markersize=15, label='Phantom Node')
    ]
    
    plt.legend(handles=legend_elements, loc='lower left')
    
    plt.title('Pedigree Visualization with Sex Information', size=15)
    plt.axis('off')
    plt.tight_layout()
    
    return plt.gcf()

# Test the enhanced visualization with our demographic pedigree
visualize_with_sex(demo_pedigree)
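
For Exercise 2, here is one possible shape for the sibling-spacing check, written as a standalone function so it can be tested before being wired into `PedigreeConstraintSolver`. Several things here are assumptions rather than part of the lab's code: the function name `sibling_spacing_violations`, the use of full `datetime.date` birth dates (the lab's pedigrees store only `birth_year`), the approximation of 9 months as 270 days, and the choice to treat siblings born within a day of each other as twins.

```python
from collections import defaultdict
from datetime import date
from itertools import combinations

def sibling_spacing_violations(birth_dates, parents_of, min_days=270, allow_twins=True):
    """Find full siblings born implausibly close together.

    Hypothetical helper for illustration.
    birth_dates: dict mapping individual id -> datetime.date
    parents_of:  dict mapping individual id -> iterable of parent ids
    """
    # Group children by their (unordered) pair of parents
    siblings_by_parents = defaultdict(list)
    for child, parents in parents_of.items():
        parent_pair = frozenset(parents)
        if len(parent_pair) == 2 and child in birth_dates:
            siblings_by_parents[parent_pair].append(child)

    violations = []
    for parent_pair, siblings in siblings_by_parents.items():
        for a, b in combinations(sorted(siblings), 2):
            gap_days = abs((birth_dates[a] - birth_dates[b]).days)
            if allow_twins and gap_days <= 1:
                continue  # treat as a multiple birth, not a violation
            if gap_days < min_days:
                violations.append({
                    'type': 'sibling_spacing',
                    'parents': tuple(sorted(parent_pair)),
                    'siblings': (a, b),
                    'gap_days': gap_days,
                })
    return violations


# Illustrative data: C and D are too close together, T1 and T2 are twins
birth_dates = {
    "C": date(1930, 1, 10),
    "D": date(1930, 8, 1),
    "E": date(1935, 5, 5),
    "T1": date(1940, 3, 3),
    "T2": date(1940, 3, 3),
}
parents_of = {name: ["A", "B"] for name in birth_dates}
violations = sibling_spacing_violations(birth_dates, parents_of)
for v in violations:
    print(f"{v['siblings'][0]} and {v['siblings'][1]} born {v['gap_days']} days apart")
```

Inside the solver, each returned dict could be appended to the list of violations alongside the existing `max_parents` and `max_children` entries, with a `message` key added to match their format.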

## 6. Summary and Next Steps

In this lab, we've explored how Bonsai approaches multi-sample relationship inference, extending beyond simple pairwise analysis to build coherent pedigree structures. Let's summarize what we've learned:

### Key Concepts Covered

1. **Multi-Sample Inference Principles**
   - Moving from pairwise to multi-sample analysis
   - Understanding the unique challenges of multi-sample inference
   - Core strategies in Bonsai's approach

2. **Pedigree Representation and Construction**
   - Efficient data structures for representing pedigrees
   - Converting IBD data to pedigree relationships
   - Handling ambiguity in relationship inference

3. **Constraint Satisfaction**
   - Identifying and resolving relationship conflicts
   - Enforcing biological and logical constraints
   - Different strategies for resolving violations

4. **Demographic and Historical Integration**
   - Incorporating non-genetic information
   - Inferring missing demographic data
   - Verifying historical consistency
   - Making recommendations for additional information

### Next Steps

1. **Advanced Bonsai Capabilities**
   - Explore Bonsai's optimization algorithms for pedigree reconstruction
   - Investigate techniques for scaling to larger datasets
   - Understand how to fine-tune Bonsai's parameters

2. **Practical Applications**
   - Apply these techniques to real-world genetic genealogy challenges
   - Develop strategies for handling particularly complex family structures
   - Integrate with other genetic analysis tools

3. **Further Learning**
   - Explore specialized use cases like endogamous populations
   - Study how Bonsai's algorithms compare to other methods
   - Investigate the latest research in relationship inference

In the next lab, we'll explore applications of these pedigree reconstruction techniques to real-world genetic genealogy problems.