# Lab 5: Age-Based Relationship Modeling

## Overview

This lab explores how Bonsai v3 incorporates age information to enhance relationship inference. Age is a critical demographic constraint that can help disambiguate between genetically similar relationships. We'll examine:

1. The mathematical models behind age difference distributions for various relationships
2. How Bonsai collects and applies age-based priors
3. How age information is combined with genetic evidence for improved relationship inference
4. Working with the age-modeling components in the Bonsai v3 codebase

In [None]:
# Standard imports
import os
import sys
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import inspect
import importlib
import json
from IPython.display import display, HTML, Markdown
import warnings
warnings.filterwarnings('ignore')

sys.path.append(os.path.dirname(os.getcwd()))

# Cross-compatibility setup
from scripts_support.lab_cross_compatibility import setup_environment, is_jupyterlite, save_results, save_plot

# Set up environment-specific paths
DATA_DIR, RESULTS_DIR = setup_environment()

# Set visualization styles
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_context("notebook")

In [None]:
# Setup Bonsai module paths
if not is_jupyterlite():
    # In local environment, add the utils directory to system path
    utils_dir = os.getenv('PROJECT_UTILS_DIR', os.path.join(os.path.dirname(DATA_DIR), 'utils'))
    bonsaitree_dir = os.path.join(utils_dir, 'bonsaitree')
    
    # Add to path if it exists and isn't already there
    if os.path.exists(bonsaitree_dir) and bonsaitree_dir not in sys.path:
        sys.path.append(bonsaitree_dir)
        print(f"Added {bonsaitree_dir} to sys.path")
else:
    # In JupyterLite, use a simplified approach
    print("⚠️ Running in JupyterLite: Some Bonsai functionality may be limited.")
    print("This notebook is primarily designed for local execution where the Bonsai codebase is available.")

In [None]:
# Helper functions for exploring modules
def display_module_classes(module_name):
    """Display classes and their docstrings from a module"""
    try:
        # Import the module
        module = importlib.import_module(module_name)
        
        # Find all classes
        classes = inspect.getmembers(module, inspect.isclass)
        
        # Filter classes defined in this module (not imported)
        classes = [(name, cls) for name, cls in classes if cls.__module__ == module_name]
        
        # Print info for each class
        for name, cls in classes:
            print(f"\n## {name}")
            
            # Get docstring
            doc = inspect.getdoc(cls)
            if doc:
                print(f"Docstring: {doc}")
            else:
                print("No docstring available")
            
            # Get methods
            methods = inspect.getmembers(cls, inspect.isfunction)
            if methods:
                print("\nMethods:")
                for method_name, method in methods:
                    if not method_name.startswith('_'):  # Skip private methods
                        print(f"- {method_name}")
    except ImportError as e:
        print(f"Error importing module {module_name}: {e}")
    except Exception as e:
        print(f"Error processing module {module_name}: {e}")

def display_module_functions(module_name):
    """Display functions and their docstrings from a module"""
    try:
        # Import the module
        module = importlib.import_module(module_name)
        
        # Find all functions
        functions = inspect.getmembers(module, inspect.isfunction)
        
        # Filter functions defined in this module (not imported)
        functions = [(name, func) for name, func in functions if func.__module__ == module_name]
        
        # Print info for each function
        for name, func in functions:
            if name.startswith('_'):  # Skip private functions
                continue
                
            print(f"\n## {name}")
            
            # Get signature
            sig = inspect.signature(func)
            print(f"Signature: {name}{sig}")
            
            # Get docstring
            doc = inspect.getdoc(func)
            if doc:
                print(f"Docstring: {doc}")
            else:
                print("No docstring available")
    except ImportError as e:
        print(f"Error importing module {module_name}: {e}")
    except Exception as e:
        print(f"Error processing module {module_name}: {e}")

def view_function_source(module_name, function_name):
    """Display the source code of a function or method.

    `function_name` may be a dotted path such as 'PwLogLike.get_pw_ll'
    to reach a method defined on a class in the module.
    """
    try:
        # Import the module
        module = importlib.import_module(module_name)
        
        # Resolve the (possibly dotted) name one attribute at a time,
        # since getattr() alone cannot follow 'Class.method'
        obj = module
        for attr in function_name.split('.'):
            obj = getattr(obj, attr)
        
        # Get the source code and render it as Markdown
        source = inspect.getsource(obj)
        display(Markdown(f"```python\n{source}\n```"))
    except ImportError as e:
        print(f"Error importing module {module_name}: {e}")
    except AttributeError:
        print(f"{function_name} not found in module {module_name}")
    except Exception as e:
        print(f"Error processing function {function_name}: {e}")

## Importing Bonsai Modules

Let's start by importing the relevant Bonsai v3 modules for age-based relationship modeling:

In [None]:
try:
    # Import Bonsai v3 modules
    from utils.bonsaitree.bonsaitree.v3 import likelihoods
    from utils.bonsaitree.bonsaitree.v3 import moments
    from utils.bonsaitree.bonsaitree.v3 import constants
    
    print("✅ Successfully imported Bonsai v3 modules")
except ImportError as e:
    print(f"❌ Failed to import Bonsai v3 modules: {e}")
    print("This lab requires access to the Bonsai v3 codebase.")

## Age Difference Distributions in Relationships

Different relationship types have characteristic age difference distributions. For example:
- Parent-child age differences typically fall around 20-40 years
- Siblings typically have small age differences (about 0-10 years)
- Cousins (first, second, ...) belong to the same generation, so their age differences are centered near zero, but they vary more widely than sibling age differences

Let's examine how Bonsai models these age distributions:

In [None]:
# Examine the likelihoods module's functions related to age modeling
try:
    view_function_source('utils.bonsaitree.bonsaitree.v3.likelihoods', 'get_age_mean_std_for_rel_tuple')
    view_function_source('utils.bonsaitree.bonsaitree.v3.likelihoods', 'get_age_log_like')
except Exception as e:
    print(f"Could not display the age-related functions: {e}")
    print("\nThese functions compute the expected age difference and likelihood of observed age differences")
    print("for different relationship types.")

### Loading Age Moments

Bonsai stores statistical moments (mean and standard deviation) for age differences in various relationships. These are used to compute relationship likelihoods based on age information. Let's examine these moments:

In [None]:
def load_and_display_age_moments():
    """Load and display the age moments from Bonsai"""
    try:
        # Try to use the Bonsai function to load moments
        age_moments = moments.load_age_moments()
        
        # Convert to a DataFrame for better display
        moment_data = []
        for rel_tuple, (mean, std) in age_moments.items():
            # Skip None values
            if rel_tuple is None:
                continue
                
            # Parse the relationship tuple
            if len(rel_tuple) == 3:
                up, down, num_ancs = rel_tuple
            else:  # Handle other formats if they exist
                up, down, num_ancs = rel_tuple[0], rel_tuple[1], 1
                
            # Add relationship description
            rel_name = describe_relationship(up, down, num_ancs)
            
            moment_data.append({
                'Relationship': rel_name,
                'Up Generations': up,
                'Down Generations': down,
                'Common Ancestors': num_ancs,
                'Mean Age Difference': mean,
                'Std Dev Age Difference': std
            })
        
        # Create DataFrame and sort
        df = pd.DataFrame(moment_data)
        df = df.sort_values(['Up Generations', 'Down Generations', 'Common Ancestors'])
        
        return df
    except Exception as e:
        print(f"Could not load age moments: {e}")
        print("\nFalling back to illustrative sample moments (hand-picked values, not Bonsai's fitted model)...")
        
        # Illustrative (mean, std) age moments in years for common relationships
        sample_moments = {
            (0, 1, 1): (30, 5),    # Parent-Child
            (1, 1, 2): (0, 5),     # Full Siblings
            (1, 1, 1): (0, 6),     # Half Siblings
            (0, 2, 1): (60, 10),   # Grandparent-Grandchild
            (1, 2, 1): (30, 10),   # Uncle/Aunt-Niece/Nephew
            (2, 2, 2): (0, 15),    # First Cousins
            (2, 2, 1): (0, 16),    # Half First Cousins
            (3, 3, 2): (0, 20),    # Second Cousins
        }
        
        # Convert to a DataFrame
        moment_data = []
        for rel_tuple, (mean, std) in sample_moments.items():
            up, down, num_ancs = rel_tuple
            rel_name = describe_relationship(up, down, num_ancs)
            
            moment_data.append({
                'Relationship': rel_name,
                'Up Generations': up,
                'Down Generations': down,
                'Common Ancestors': num_ancs,
                'Mean Age Difference': mean,
                'Std Dev Age Difference': std
            })
        
        # Create DataFrame and sort
        df = pd.DataFrame(moment_data)
        df = df.sort_values(['Up Generations', 'Down Generations', 'Common Ancestors'])
        
        return df

def describe_relationship(up, down, num_ancs):
    """Convert relationship tuple parameters to a human-readable description"""
    if up == 0 and down == 1 and num_ancs == 1:
        return "Parent-Child"
    elif up == 1 and down == 1 and num_ancs == 2:
        return "Full Siblings"
    elif up == 1 and down == 1 and num_ancs == 1:
        return "Half Siblings"
    elif up == 0 and down == 2 and num_ancs == 1:
        return "Grandparent-Grandchild"
    elif up == 1 and down == 2 and num_ancs == 1:
        return "Uncle/Aunt-Niece/Nephew"
    elif up == 2 and down == 2 and num_ancs == 2:
        return "First Cousins"
    elif up == 2 and down == 2 and num_ancs == 1:
        return "Half First Cousins"
    elif up == 3 and down == 3 and num_ancs == 2:
        return "Second Cousins"
    elif up == 3 and down == 3 and num_ancs == 1:
        return "Half Second Cousins"
    elif up == 4 and down == 4 and num_ancs == 2:
        return "Third Cousins"
    else:
        return f"Relationship({up}, {down}, {num_ancs})"

# Load and display age moments
age_moments_df = load_and_display_age_moments()
display(age_moments_df)

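Each row of this table is a pair of moments $(\mu_r, \sigma_r)$ for a relationship $r$. If the age difference $d = a_1 - a_2$ (Person 1 minus Person 2, the convention used in the plots below) is modelled as normally distributed, which is the assumption this notebook's plots and fallback code make (Bonsai's own `get_age_log_like` may use a different form), the age log-likelihood of $r$ is:

$$
\log p(d \mid r) = -\frac{1}{2}\log\left(2\pi\sigma_r^2\right) - \frac{(d - \mu_r)^2}{2\sigma_r^2}
$$

The first term penalizes wide distributions; the second charges $\tfrac{1}{2}z^2$ nats for an observation $z = (d - \mu_r)/\sigma_r$ standard deviations from the mean, so a gap two standard deviations away costs 2 nats relative to one exactly at the mean.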
### Visualizing Age Difference Distributions

Now let's visualize the age difference distributions for various relationships. We'll plot the normal distributions based on the means and standard deviations from Bonsai's age moments model:

In [None]:
def plot_age_difference_distributions(moments_df):
    """Plot age difference distributions for different relationships"""
    # Define which relationships to include
    relationships_to_include = [
        "Parent-Child",
        "Full Siblings",
        "Half Siblings",
        "Grandparent-Grandchild",
        "Uncle/Aunt-Niece/Nephew",
        "First Cousins"
    ]
    
    # Filter the DataFrame for these relationships
    filtered_df = moments_df[moments_df['Relationship'].isin(relationships_to_include)]
    
    # Create the plot
    plt.figure(figsize=(12, 8))
    
    # Define a range of age differences
    age_diffs = np.linspace(-80, 80, 1000)
    
    # Plot each relationship
    from scipy.stats import norm
    for _, row in filtered_df.iterrows():
        mean = row['Mean Age Difference']
        std = row['Std Dev Age Difference']
        rel_name = row['Relationship']
        
        # Calculate normal distribution values
        pdf_values = norm.pdf(age_diffs, mean, std)
        
        # Plot the distribution
        plt.plot(age_diffs, pdf_values, label=rel_name, linewidth=2)
    
    plt.xlabel('Age Difference (Person 1 - Person 2)', fontsize=12)
    plt.ylabel('Probability Density', fontsize=12)
    plt.title('Age Difference Distributions by Relationship Type', fontsize=14)
    plt.legend(fontsize=10)
    plt.grid(alpha=0.3)
    plt.tight_layout()
    
    plt.show()

# Plot the age difference distributions
plot_age_difference_distributions(age_moments_df)

### Examining How Age Differences Scale with Relationship Distance

Let's visualize how the variability in age differences (standard deviation) increases with more distant relationships:

In [None]:
def plot_std_vs_relationship_distance(moments_df):
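    """Plot how std dev of age differences scales with relationship distance"""
    # Work on a copy so the caller's DataFrame is not modified
    moments_df = moments_df.copy()
    # Relationship distance = total generations separating the pair via the common ancestor
    moments_df['Relationship Distance'] = moments_df['Up Generations'] + moments_df['Down Generations']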
    
    # Create the plot
    plt.figure(figsize=(10, 6))
    
    # Plot points with different colors based on common ancestors
    for num_ancs in [1, 2]:
        subset = moments_df[moments_df['Common Ancestors'] == num_ancs]
        label = f"{num_ancs} Common Ancestor{'s' if num_ancs > 1 else ''}"
        plt.scatter(
            subset['Relationship Distance'], 
            subset['Std Dev Age Difference'],
            label=label,
            s=80,
            alpha=0.7
        )
    
    # Add annotations for some key relationships
    for _, row in moments_df.iterrows():
        if row['Relationship'] in ["Parent-Child", "Full Siblings", "First Cousins", "Second Cousins"]:
            plt.annotate(
                row['Relationship'],
                (row['Relationship Distance'], row['Std Dev Age Difference']),
                xytext=(5, 5),
                textcoords='offset points',
                fontsize=9
            )
    
    # Fit and plot a trend line
    from scipy.stats import linregress
    x = moments_df['Relationship Distance']
    y = moments_df['Std Dev Age Difference']
    slope, intercept, r_value, p_value, std_err = linregress(x, y)
    plt.plot(x, intercept + slope*x, 'r--', label=f'Trend (r={r_value:.2f})')
    
    plt.xlabel('Relationship Distance (Up + Down Generations)', fontsize=12)
    plt.ylabel('Standard Deviation of Age Difference', fontsize=12)
    plt.title('Variability in Age Differences vs. Relationship Distance', fontsize=14)
    plt.legend(fontsize=10)
    plt.grid(alpha=0.3)
    plt.tight_layout()
    
    plt.show()

# Plot std dev vs relationship distance
plot_std_vs_relationship_distance(age_moments_df)

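Why should the spread grow with distance? A minimal sketch, assuming (a simplification, not Bonsai's fitted model) that every parent-to-child age gap is an independent normal draw with mean `gap_mean` and standard deviation `gap_std`: the age difference between two relatives is then a signed sum of `up + down` independent gaps, so its standard deviation should be `gap_std * sqrt(up + down)`. That grows sublinearly, unlike the straight trend line fitted above. The hypothetical helper `simulate_age_diff_std` checks this by Monte Carlo:

```python
import numpy as np

def simulate_age_diff_std(up, down, gap_mean=30.0, gap_std=5.0, n=200_000, seed=0):
    """Monte Carlo mean/std of (age1 - age2) for person 1 `up` generations and
    person 2 `down` generations below a shared ancestor.

    Assumption: each parent-to-child age gap is an independent
    N(gap_mean, gap_std**2) draw. Illustrative only; not how Bonsai
    estimates its moments.
    """
    rng = np.random.default_rng(seed)
    # Measure birth years relative to the shared ancestor (born in year 0);
    # each person's birth year is the sum of the gaps along their own line.
    birth1 = rng.normal(gap_mean, gap_std, size=(n, up)).sum(axis=1)
    birth2 = rng.normal(gap_mean, gap_std, size=(n, down)).sum(axis=1)
    # Ages are taken in the same year, so age1 - age2 = birth2 - birth1
    age_diff = birth2 - birth1
    return age_diff.mean(), age_diff.std()

for up, down in [(0, 1), (1, 1), (0, 2), (1, 2), (2, 2), (3, 3)]:
    mean, std = simulate_age_diff_std(up, down)
    predicted = 5.0 * np.sqrt(up + down)
    print(f"up={up} down={down}: mean={mean:6.1f}  std={std:5.2f}  sqrt rule={predicted:5.2f}")
```

Under this model the number of common ancestors never enters, so half and full relationships get the same spread; real family data are messier.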
## Age-Based Likelihood in Bonsai

Now let's examine how Bonsai uses age information in its likelihood computations. The `get_age_log_like` function computes the log-likelihood of an observed age difference given a relationship type:

In [None]:
def compute_age_likelihoods(age1, age2):
    """Compute age-based likelihoods for various relationships"""
    # Define relationships to test
    relationships = [
        ((0, 1, 1), "Parent-Child"),
        ((1, 1, 2), "Full Siblings"),
        ((1, 1, 1), "Half Siblings"),
        ((0, 2, 1), "Grandparent-Grandchild"),
        ((1, 2, 1), "Uncle/Aunt-Niece/Nephew"),
        ((2, 2, 2), "First Cousins"),
        ((2, 2, 1), "Half First Cousins"),
        ((3, 3, 2), "Second Cousins"),
        (None, "Unrelated")
    ]
    
    try:
        # Use Bonsai's function to compute log-likelihoods
        results = []
        for rel_tuple, rel_name in relationships:
            ll = likelihoods.get_age_log_like(age1, age2, rel_tuple)
            results.append({
                'Relationship': rel_name,
                'Log-Likelihood': ll,
                'Likelihood': np.exp(ll)
            })
    except Exception as e:
        print(f"Could not compute age likelihoods using Bonsai: {e}")
        print("Using simplified model instead...")
        
        # Use our age_moments_df to compute approximate log-likelihoods
        from scipy.stats import norm
        results = []
        age_diff = age1 - age2
        
        for rel_tuple, rel_name in relationships:
            if rel_tuple is None:
                # For unrelated, use a wide distribution centered at 0
                mean, std = 0, 30
            else:
                # Find the relationship in our DataFrame
                match = age_moments_df[
                    (age_moments_df['Up Generations'] == rel_tuple[0]) & 
                    (age_moments_df['Down Generations'] == rel_tuple[1]) & 
                    (age_moments_df['Common Ancestors'] == rel_tuple[2])
                ]
                
                if len(match) > 0:
                    mean = match['Mean Age Difference'].values[0]
                    std = match['Std Dev Age Difference'].values[0]
                else:
                    # If no match, use defaults based on relationship structure
                    mean = 30 * (rel_tuple[1] - rel_tuple[0])
                    std = 10 * (rel_tuple[0] + rel_tuple[1])
            
            # Compute log-likelihood
            ll = norm.logpdf(age_diff, mean, std)
            
            results.append({
                'Relationship': rel_name,
                'Log-Likelihood': ll,
                'Likelihood': np.exp(ll)
            })
    
    # Create DataFrame and sort by likelihood
    results_df = pd.DataFrame(results)
    results_df = results_df.sort_values('Likelihood', ascending=False)
    
    return results_df

# Define two individuals with specific ages
person1_age = 70
person2_age = 40

print(f"Person 1 Age: {person1_age}")
print(f"Person 2 Age: {person2_age}")
print(f"Age Difference: {person1_age - person2_age}")
print("\nComputing age-based relationship likelihoods...\n")

# Compute and display likelihoods
age_ll_df = compute_age_likelihoods(person1_age, person2_age)
display(age_ll_df)

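The relationship tuples are directional. Under the convention used in this notebook, `(0, 1, 1)` with a mean of +30 means Person 1 is the older parent, which is why Parent-Child ranks first when Person 1 is 30 years older. A minimal sketch, assuming the same normal model: the reversed tuple `(1, 0, 1)` (Person 1 is the child) simply flips the sign of the mean, and swapping the two ages then gives exactly the same log-likelihood. The helpers `gaussian_age_ll` and `reverse_moments` are hypothetical, written for this illustration:

```python
import math

def gaussian_age_ll(age1, age2, mean, std):
    """Normal log-density of the age difference age1 - age2."""
    d = age1 - age2
    return -0.5 * math.log(2 * math.pi * std**2) - 0.5 * ((d - mean) / std) ** 2

def reverse_moments(mean, std):
    """Moments for the reversed tuple (down, up, num_ancs): the difference flips sign."""
    return -mean, std

parent_child = (30.0, 5.0)                     # illustrative (0, 1, 1) moments
child_parent = reverse_moments(*parent_child)  # moments for (1, 0, 1)

print(gaussian_age_ll(70, 40, *parent_child))  # Person 1 older: plausible parent
print(gaussian_age_ll(40, 70, *parent_child))  # ages swapped: implausible under (0, 1, 1)
print(gaussian_age_ll(40, 70, *child_parent))  # ...but identical to the first line under (1, 0, 1)
```

So when scoring a pair by age, either order the pair consistently or consider both orientations of an asymmetric relationship.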
### Visualizing Age Likelihood Heatmap

Now let's create a heatmap showing how the likelihood of different relationships varies with different age combinations:

In [None]:
def plot_age_likelihood_heatmap(rel_tuple, rel_name):
    """Create a heatmap of age-based relationship likelihoods"""
    # Define age ranges to explore
    ages = np.arange(20, 81, 5)  # Ages from 20 to 80 in steps of 5
    
    # Create a grid of likelihoods. It must not be called `likelihoods`: assigning
    # that name here would shadow the imported Bonsai module for the whole function,
    # so every Bonsai call below would fail and silently fall back.
    like_grid = np.zeros((len(ages), len(ages)))
    
    try:
        # Compute likelihoods using Bonsai
        for i, age1 in enumerate(ages):
            for j, age2 in enumerate(ages):
                ll = likelihoods.get_age_log_like(age1, age2, rel_tuple)
                like_grid[i, j] = np.exp(ll) if not np.isneginf(ll) else 0
    except Exception as e:
        print(f"Could not compute age likelihoods using Bonsai: {e}")
        print("Using simplified model instead...")
        
        # Use our age_moments_df to compute approximate likelihoods
        from scipy.stats import norm
        
        if rel_tuple is None:
            # For unrelated, use a wide distribution centered at 0
            mean, std = 0, 30
        else:
            # Find the relationship in our DataFrame
            match = age_moments_df[
                (age_moments_df['Up Generations'] == rel_tuple[0]) & 
                (age_moments_df['Down Generations'] == rel_tuple[1]) & 
                (age_moments_df['Common Ancestors'] == rel_tuple[2])
            ]
            
            if len(match) > 0:
                mean = match['Mean Age Difference'].values[0]
                std = match['Std Dev Age Difference'].values[0]
            else:
                # If no match, use defaults based on relationship structure
                mean = 30 * (rel_tuple[1] - rel_tuple[0])
                std = 10 * (rel_tuple[0] + rel_tuple[1])
        
        for i, age1 in enumerate(ages):
            for j, age2 in enumerate(ages):
                # Normal density of the age difference (Person 1 - Person 2)
                like_grid[i, j] = norm.pdf(age1 - age2, mean, std)
    
    # Normalize to [0, 1] to emphasize the pattern (skip if the grid is all zeros)
    max_like = np.max(like_grid)
    if max_like > 0:
        like_grid = like_grid / max_like
    
    # Create the heatmap
    plt.figure(figsize=(10, 8))
    sns.heatmap(
        like_grid,
        xticklabels=ages,
        yticklabels=ages,
        cmap='viridis',
        annot=False,
        cbar_kws={'label': 'Normalized Likelihood'}
    )
    
    plt.xlabel('Person 2 Age', fontsize=12)
    plt.ylabel('Person 1 Age', fontsize=12)
    plt.title(f'Age-Based Likelihood Heatmap for {rel_name} Relationship', fontsize=14)
    
    # Add diagonal line for reference (where ages are equal)
    diagonal_indices = np.array([(i, i) for i in range(len(ages))])
    plt.plot(diagonal_indices[:, 1] + 0.5, diagonal_indices[:, 0] + 0.5, 'r--', linewidth=1, alpha=0.7)
    
    plt.tight_layout()
    plt.show()

# Plot heatmaps for parent-child and sibling relationships
plot_age_likelihood_heatmap((0, 1, 1), "Parent-Child")
plot_age_likelihood_heatmap((1, 1, 2), "Full Siblings")

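Under the normal model used by the fallback branch, the likelihood depends only on the age difference, so every diagonal of the heatmap (cells sharing the same `age1 - age2`) holds a single value, and the bright band runs parallel to the red reference line. A minimal sketch with NumPy broadcasting builds the whole grid without the double loop and checks this property; the helper `gaussian_age_grid` is hypothetical and uses illustrative parent-child moments (mean 30, std 5):

```python
import numpy as np

def gaussian_age_grid(ages, mean, std):
    """Return L with L[i, j] = normal pdf of (ages[i] - ages[j]); rows are Person 1."""
    ages = np.asarray(ages, dtype=float)
    diff = ages[:, None] - ages[None, :]   # broadcasting: diff[i, j] = ages[i] - ages[j]
    z = (diff - mean) / std
    return np.exp(-0.5 * z**2) / (std * np.sqrt(2 * np.pi))

ages = np.arange(20, 81, 5)
grid = gaussian_age_grid(ages, mean=30.0, std=5.0)

# Offset -6 selects cells grid[i + 6, i], where Person 1 is 6 * 5 = 30 years older:
# that diagonal sits exactly at the mean, so it holds the grid's maximum everywhere.
peak_diagonal = np.diagonal(grid, offset=-6)
print(np.allclose(peak_diagonal, grid.max()))
```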
## Combining Genetic and Age Evidence in PwLogLike

A key feature of Bonsai v3 is that it combines genetic and age evidence when scoring candidate relationships. Let's examine how the `PwLogLike` class integrates these two components:

In [None]:
try:
    # Look at the core methods in PwLogLike for age-based likelihood
    view_function_source('utils.bonsaitree.bonsaitree.v3.likelihoods', 'PwLogLike.get_pw_age_ll')
    view_function_source('utils.bonsaitree.bonsaitree.v3.likelihoods', 'PwLogLike.get_pw_ll')
except Exception as e:
    print(f"Could not display the PwLogLike methods: {e}")
    print("\nThe get_pw_age_ll method computes the age component of relationship likelihood.")
    print("The get_pw_ll method combines genetic and age components for overall likelihood.")

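How Bonsai combines the two components is best confirmed in the source displayed above. As a minimal sketch, assuming genetic and age evidence are conditionally independent given the relationship, the combined log-likelihood is simply their sum; adding a log prior and normalizing with a numerically stable softmax (the log-sum-exp trick) turns it into a posterior over the candidate relationships. The function `combine_evidence` and every number in the example are hypothetical:

```python
import numpy as np

def combine_evidence(genetic_ll, age_ll, log_prior=None):
    """Posterior over relationships from per-relationship log-likelihoods.

    Assumes genetic and age evidence are conditionally independent given the
    relationship, so their log-likelihoods add. `log_prior` defaults to uniform.
    """
    rels = list(genetic_ll)
    if log_prior is None:
        log_prior = {r: 0.0 for r in rels}
    scores = np.array([genetic_ll[r] + age_ll[r] + log_prior[r] for r in rels])
    scores -= scores.max()          # shift by the max before exponentiating (avoids underflow)
    weights = np.exp(scores)
    return dict(zip(rels, weights / weights.sum()))

# Made-up numbers: nearly tied genetic evidence, age evidence favouring half siblings
genetic = {"Half Siblings": -15.2, "Grandparent-Grandchild": -15.5, "Avuncular": -15.3}
age = {"Half Siblings": -3.0, "Grandparent-Grandchild": -9.0, "Avuncular": -6.0}
posterior = combine_evidence(genetic, age)
print({rel: round(float(p), 3) for rel, p in posterior.items()})
```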
## Demonstrating Age-Based Relationship Disambiguation

A key benefit of incorporating age information is the ability to disambiguate between genetically similar relationships. For example, half siblings, grandparent-grandchild, and avuncular (aunt/uncle-niece/nephew) pairs all share about 25% of their DNA on average, so their genetic signatures are similar, yet their age patterns are very different.

Let's create a simulation to demonstrate this:

In [None]:
def simulate_disambiguation_case():
    """Simulate a case where age helps disambiguate relationships"""
    # Define similar genetic signatures for three relationship types
    similar_relationships = [
        ((1, 1, 1), "Half Siblings"),          # Half siblings
        ((0, 2, 1), "Grandparent-Grandchild"),  # Grandparent-grandchild
        ((1, 2, 1), "Avuncular")                # Uncle/aunt-niece/nephew
    ]
    
    # Create synthetic genetic likelihood scores (similar values)
    genetic_ll = {
        "Half Siblings": -15.2,
        "Grandparent-Grandchild": -15.5,
        "Avuncular": -15.3
    }
    
    # Create two simulated age scenarios
    scenarios = [
        {"name": "Similar Ages (40 and 35)", "age1": 40, "age2": 35},
        {"name": "Large Age Gap (70 and 30)", "age1": 70, "age2": 30}
    ]
    
    results = []
    
    for scenario in scenarios:
        age1, age2 = scenario["age1"], scenario["age2"]
        
        for rel_tuple, rel_name in similar_relationships:
            # Compute age likelihood
            try:
                age_ll = likelihoods.get_age_log_like(age1, age2, rel_tuple)
            except Exception:
                # Fallback to simplified model
                match = age_moments_df[
                    (age_moments_df['Up Generations'] == rel_tuple[0]) & 
                    (age_moments_df['Down Generations'] == rel_tuple[1]) & 
                    (age_moments_df['Common Ancestors'] == rel_tuple[2])
                ]
                
                if len(match) > 0:
                    mean = match['Mean Age Difference'].values[0]
                    std = match['Std Dev Age Difference'].values[0]
                else:
                    mean = 30 * (rel_tuple[1] - rel_tuple[0])
                    std = 10 * (rel_tuple[0] + rel_tuple[1])
                
                from scipy.stats import norm
                age_diff = age1 - age2
                age_ll = norm.logpdf(age_diff, mean, std)
            
            # Compute combined likelihood
            combined_ll = genetic_ll[rel_name] + age_ll
            
            results.append({
                'Scenario': scenario["name"],
                'Relationship': rel_name,
                'Age 1': age1,
                'Age 2': age2,
                'Genetic LL': genetic_ll[rel_name],
                'Age LL': age_ll,
                'Combined LL': combined_ll,
                'Relative Likelihood': np.exp(combined_ll)
            })
    
    # Convert to DataFrame
    results_df = pd.DataFrame(results)
    
    # Normalize likelihoods within each scenario
    for scenario in scenarios:
        mask = results_df['Scenario'] == scenario["name"]
        max_ll = results_df.loc[mask, 'Combined LL'].max()
        results_df.loc[mask, 'Normalized LL'] = results_df.loc[mask, 'Combined LL'] - max_ll
        results_df.loc[mask, 'Probability'] = np.exp(results_df.loc[mask, 'Normalized LL'])
        
        # Ensure probabilities sum to 1
        total_prob = results_df.loc[mask, 'Probability'].sum()
        results_df.loc[mask, 'Probability'] = results_df.loc[mask, 'Probability'] / total_prob
    
    return results_df

# Simulate and display results
disambiguation_df = simulate_disambiguation_case()
display(disambiguation_df[['Scenario', 'Relationship', 'Age 1', 'Age 2', 'Genetic LL', 'Age LL', 'Combined LL', 'Probability']])

In [None]:
def plot_disambiguation_results(results_df):
    """Plot the disambiguation results"""
    plt.figure(figsize=(14, 6))
    
    # Convert probabilities to percentages for display
    results_df['Probability %'] = results_df['Probability'] * 100
    
    # Set up the plot
    scenarios = results_df['Scenario'].unique()
    relationships = results_df['Relationship'].unique()
    
    # Set positions for grouped bars
    x = np.arange(len(scenarios))
    width = 0.25
    multiplier = 0
    
    # Create bar groups
    for relationship in relationships:
        subset = results_df[results_df['Relationship'] == relationship]
        offset = width * multiplier
        rects = plt.bar(x + offset, subset['Probability %'], width, label=relationship)
        
        # Add value labels on bars
        for rect, prob in zip(rects, subset['Probability %']):
            height = rect.get_height()
            plt.text(rect.get_x() + rect.get_width()/2., height + 1,
                     f'{prob:.1f}%', ha='center', va='bottom', fontsize=9)
            
        multiplier += 1
    
    # Add labels and legend
    plt.ylabel('Probability (%)', fontsize=12)
    plt.title('Relationship Probabilities with Different Age Scenarios', fontsize=14)
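    # Center each group's tick under its bars (offsets run from 0 to (n - 1) * width)
    plt.xticks(x + width * (len(relationships) - 1) / 2, scenarios, fontsize=10)
    plt.legend(title='Relationship', fontsize=10)
    plt.ylim(0, 110)  # headroom for the percentage labels drawn above each bar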
    plt.grid(axis='y', alpha=0.3)
    
    plt.tight_layout()
    plt.show()

# Plot the disambiguation results
plot_disambiguation_results(disambiguation_df)

## Exploring the Age Moments Data Format

Bonsai v3 loads its age moments from a data file on disk. The cell below assumes this is a JSON file whose path is exposed as `constants.AGE_MOMENT_FP`; let's examine its structure to understand how the system works:

In [None]:
def display_age_moments_file_structure():
    """Display the structure of the age moments file used by Bonsai"""
    try:
        # Try to load directly from Bonsai's constants
        file_path = constants.AGE_MOMENT_FP
        
        if os.path.exists(file_path):
            # Load and display the first few entries
            with open(file_path, 'r') as f:
                data = json.load(f)
            
            print(f"Age moments file path: {file_path}")
            print(f"Contains data for {len(data)} relationship types")
            print("\nFirst few entries (sample):\n")
            
            # Display a few sample entries
            sample_count = 0
            for key, value in data.items():
                if sample_count < 5:
                    print(f"Relationship: {key}")
                    print(f"  Mean: {value[0]}")
                    print(f"  Std Dev: {value[1]}")
                    print()
                    sample_count += 1
            
            # Display the structure overall
            print("\nFile Structure:")
            print("  - JSON file mapping relationship tuples to [mean, std_dev] arrays")
            print("  - Keys are stringified tuples representing (up, down, num_ancs)")
            print("  - Values are arrays with [mean_age_diff, std_dev_age_diff]")
            
            return data
        else:
            print(f"Age moments file not found at path: {file_path}")
            print("\nCreating a simplified example of the expected structure:")
    except Exception as e:
        print(f"Error examining age moments file: {e}")
        print("\nCreating a simplified example of the expected structure:")
    
    # Create an example structure
    example_data = {
        "(0, 1, 1)": [30.0, 5.0],   # Parent-Child
        "(1, 1, 2)": [0.0, 5.0],    # Full Siblings
        "(1, 1, 1)": [0.0, 6.0],    # Half Siblings
        "(0, 2, 1)": [60.0, 10.0],  # Grandparent-Grandchild
        "(1, 2, 1)": [30.0, 10.0]   # Uncle/Aunt-Niece/Nephew
    }
    
    print("Example data structure:")
    for key, value in example_data.items():
        print(f"Relationship: {key}")
        print(f"  Mean: {value[0]}")
        print(f"  Std Dev: {value[1]}")
        print()
    
    print("\nFile Structure:")
    print("  - JSON file mapping relationship tuples to [mean, std_dev] arrays")
    print("  - Keys are stringified tuples representing (up, down, num_ancs)")
    print("  - Values are arrays with [mean_age_diff, std_dev_age_diff]")
    
    return example_data

# Display the age moments file structure
age_moments_data = display_age_moments_file_structure()

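Because JSON object keys must be strings, tuple keys such as `(0, 1, 1)` have to be stored as text like `"(0, 1, 1)"`, the layout assumed above. A minimal sketch for turning them back into real tuples uses `ast.literal_eval`, which parses Python literals without executing arbitrary code the way `eval` would. The function `parse_age_moments` is hypothetical, and skipping non-tuple keys (such as a `"None"` entry, which the loader earlier in this lab also skips) is an assumption:

```python
import ast
import json

def parse_age_moments(json_text):
    """Map '{"(up, down, num_ancs)": [mean, std]}' JSON to {(up, down, num_ancs): (mean, std)}."""
    raw = json.loads(json_text)
    parsed = {}
    for key, value in raw.items():
        rel = ast.literal_eval(key)      # "(0, 1, 1)" -> (0, 1, 1); never runs code
        if not isinstance(rel, tuple):   # e.g. a "None" key: not a relationship tuple
            continue
        mean, std = value
        parsed[rel] = (float(mean), float(std))
    return parsed

example_json = '{"(0, 1, 1)": [30.0, 5.0], "(1, 1, 2)": [0.0, 5.0], "None": [0.0, 30.0]}'
moments_by_tuple = parse_age_moments(example_json)
print(moments_by_tuple)
```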
## Building Your Own Age Moments Model

Bonsai's age models are based on empirical data, but you might want to create your own model with different parameters. Let's see how to create a custom age moments model:

In [None]:
def create_custom_age_model():
    """Create a custom age moments model based on user specifications"""
    # Define parameters for the model
    parent_child_std = 5    # Std dev of parent-child age difference
    sibling_std = 4         # Std dev of sibling age differences
    generational_multiple = 28  # Mean age gap per generation (so the parent-child mean is 28)
    
    # Create the model
    custom_model = {}
    
    # Function to calculate mean and std for a relationship
    def calculate_moments(up, down, num_ancs):
        # Mean is based on generational difference
        gen_diff = down - up
        mean = gen_diff * generational_multiple
        
        # Std dev depends on relationship structure
        if up == 0 or down == 0:  # Direct line (ancestors/descendants)
            std = parent_child_std * max(1, abs(gen_diff))
        elif num_ancs == 2 and up == 1 and down == 1:  # Full siblings
            std = sibling_std
        elif num_ancs == 1 and up == 1 and down == 1:  # Half siblings
            std = sibling_std * 1.2  # Slightly more variable
        else:  # Other relationships
            # More distant relationships have more variability
            std = parent_child_std * (up + down) * 0.8
        
        return mean, std
    
    # Generate a comprehensive set of relationships
    for up in range(5):  # 0 to 4 generations up
        for down in range(5):  # 0 to 4 generations down
            if up == 0 and down == 0:  # Skip self
                continue
                
            # For each feasible number of common ancestors
            for num_ancs in [1, 2]:
                # Skip impossible combinations (e.g., direct line with 2 common ancestors)
                if (up == 0 or down == 0) and num_ancs > 1:
                    continue
                
                mean, std = calculate_moments(up, down, num_ancs)
                rel_key = f"({up}, {down}, {num_ancs})"
                custom_model[rel_key] = [mean, std]
    
    # Add special cases
    custom_model["(0, 0, 0)"] = [0, 0]  # Self
    
    return custom_model

# Create a custom age model
custom_age_model = create_custom_age_model()

import ast

# Display some key relationships from the custom model
key_relationships = [
    "(0, 1, 1)",  # Parent-Child
    "(1, 1, 2)",  # Full Siblings
    "(1, 1, 1)",  # Half Siblings
    "(0, 2, 1)",  # Grandparent-Grandchild
    "(1, 2, 2)",  # Uncle/Aunt-Niece/Nephew (full; (1, 2, 1) would be half)
    "(2, 2, 2)",  # First Cousins
]

print("Custom Age Model - Key Relationships:")
for rel in key_relationships:
    if rel in custom_age_model:
        mean, std = custom_age_model[rel]
        rel_tuple = ast.literal_eval(rel)  # Parse the tuple string safely instead of using eval()
        rel_name = describe_relationship(*rel_tuple)
        print(f"{rel_name} ({rel}): Mean = {mean}, Std Dev = {std}")
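The rules in `calculate_moments` can be written compactly. Writing $u$ for `up`, $d$ for `down`, $a$ for `num_ancs`, $g$ for `generational_multiple` (28), $s_{pc}$ for `parent_child_std` (5), and $s_{sib}$ for `sibling_std` (4), the custom model assigns:

$$
\mu(u, d, a) = g\,(d - u)
$$

$$
\sigma(u, d, a) =
\begin{cases}
s_{pc}\,\max(1,\ |d - u|) & \text{if } u = 0 \text{ or } d = 0 \\
s_{sib} & \text{if } (u, d, a) = (1, 1, 2) \\
1.2\,s_{sib} & \text{if } (u, d, a) = (1, 1, 1) \\
0.8\,s_{pc}\,(u + d) & \text{otherwise}
\end{cases}
$$

with self, $(0, 0, 0)$, fixed at $\mu = \sigma = 0$. For example, grandparent-grandchild $(0, 2, 1)$ gives $\mu = 2 \cdot 28 = 56$ and $\sigma = 5 \cdot 2 = 10$ years, while first cousins $(2, 2, 2)$ give $\mu = 0$ and $\sigma = 0.8 \cdot 5 \cdot 4 = 16$ years.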

### Comparing Custom and Default Models

Let's compare our custom age model with Bonsai's default model. Only relationships present in both models are compared, and rows are sorted by the absolute difference in mean age gap:

In [None]:
import ast

def compare_age_models(custom_model, default_model):
    """Compare the moments of relationships that appear in both the custom and default age models"""
    comparison_data = []

    # Only relationships defined in both models can be compared
    common_keys = set(custom_model) & set(default_model)

    for key in common_keys:
        # Get [mean, std] values from both models
        custom_mean, custom_std = custom_model[key]
        default_mean, default_std = default_model[key]

        # Human-readable name; fall back to the raw key if it cannot be parsed or described
        try:
            rel_name = describe_relationship(*ast.literal_eval(key))
        except Exception:
            rel_name = key

        comparison_data.append({
            'Relationship': rel_name,
            'Tuple': key,
            'Custom Mean': custom_mean,
            'Default Mean': default_mean,
            'Mean Difference': custom_mean - default_mean,
            'Custom Std': custom_std,
            'Default Std': default_std,
            'Std Difference': custom_std - default_std,
        })

    columns = ['Relationship', 'Tuple', 'Custom Mean', 'Default Mean', 'Mean Difference',
               'Custom Std', 'Default Std', 'Std Difference']
    df = pd.DataFrame(comparison_data, columns=columns)
    if df.empty:
        return df  # No shared relationships, so there is nothing to sort

    # Sort by absolute mean difference, largest first
    df['Abs Mean Diff'] = df['Mean Difference'].abs()
    return df.sort_values('Abs Mean Diff', ascending=False)

# Compare models if we have default data
if isinstance(age_moments_data, dict) and len(age_moments_data) > 0:
    comparison_df = compare_age_models(custom_age_model, age_moments_data)
    display(comparison_df[['Relationship', 'Tuple', 'Custom Mean', 'Default Mean', 'Mean Difference', 'Custom Std', 'Default Std', 'Std Difference']].head(10))
else:
    print("No default model data available for comparison.")
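Differences in means and standard deviations matter because they change the likelihood assigned to an observed age difference. Assuming, as in this lab, that each relationship's age difference follows a normal distribution, the sketch below scores one observed difference against three second-degree relationships using the custom model's moments (not Bonsai's defaults). The function `normal_log_pdf` is written out here for illustration and is not Bonsai's API; the sign convention (first person's age minus second person's) and the 27-year observation are assumptions.

```python
import math


def normal_log_pdf(x: float, mean: float, std: float) -> float:
    """Log density of the normal distribution N(mean, std**2) evaluated at x."""
    z = (x - mean) / std
    return -0.5 * math.log(2 * math.pi * std ** 2) - 0.5 * z ** 2


# (mean, std) in years, following the custom model's rules with generational_multiple=28,
# parent_child_std=5, sibling_std=4
candidate_moments = {
    "Half Siblings (1, 1, 1)": (0.0, 4.8),
    "Grandparent-Grandchild (0, 2, 1)": (56.0, 10.0),
    "Full Uncle/Aunt-Niece/Nephew (1, 2, 2)": (28.0, 12.0),
}

# Hypothetical observation: the first person is 27 years older than the second
observed_age_diff = 27.0

scores = {
    rel: normal_log_pdf(observed_age_diff, mean, std)
    for rel, (mean, std) in candidate_moments.items()
}

# Print candidates from most to least likely given the age difference alone
for rel, score in sorted(scores.items(), key=lambda item: item[1], reverse=True):
    print(f"{rel}: log-likelihood = {score:.2f}")
```

A 27-year gap lies within a tenth of a standard deviation of the avuncular mean, 2.9 standard deviations below the grandparent mean, and more than 5 standard deviations above the half-sibling mean, so the avuncular relationship ranks first. All three are second-degree relationships with similar expected IBD sharing, which is exactly the case where age information does the disambiguating.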

## Summary

In this lab, we've explored how Bonsai v3 incorporates age information to enhance relationship inference:

1. **Age Difference Distributions**: We examined the statistical distributions of age differences for various relationships, represented by normal distributions with characteristic means and standard deviations.

2. **Age-Based Likelihood Calculation**: We explored how Bonsai computes likelihood scores based on age differences using the actual code from the `likelihoods.py` module.

3. **Combining Evidence**: We saw how the `PwLogLike` class integrates genetic and age-based evidence to get more accurate relationship inferences.

4. **Relationship Disambiguation**: We demonstrated how age information can disambiguate between genetically similar relationships, such as half-siblings, grandparent-grandchild, and avuncular relationships.

5. **Age Moments Data Format**: We examined the structure of Bonsai's age moments data and how it's stored and retrieved.

6. **Custom Age Models**: We learned how to create custom age models with different parameters for specific populations or use cases.

Combining demographic constraints with genetic evidence lets Bonsai v3 down-weight relationships that fit the observed IBD sharing but are implausible given the individuals' ages, which improves accuracy in real-world relationship inference and pedigree reconstruction.