# Lab 19: Caching Mechanisms for Computational Efficiency

## Overview

In this lab, we'll explore the caching mechanisms used in Bonsai v3 to improve computational efficiency. By storing and reusing previously computed results, these mechanisms significantly reduce redundant calculations and improve performance for large-scale pedigree reconstruction.

In [None]:
# Standard imports
import os
import sys
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx
from IPython.display import display, HTML, Markdown
import inspect
import importlib
import copy
import random
import math
import time
import hashlib
import json
import functools
from collections import defaultdict, OrderedDict

sys.path.append(os.path.dirname(os.getcwd()))

# Cross-compatibility setup
from scripts_support.lab_cross_compatibility import setup_environment, is_jupyterlite, save_results, save_plot

# Set up environment-specific paths
DATA_DIR, RESULTS_DIR = setup_environment()

# Set visualization styles
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_context("notebook")

In [None]:
# Setup Bonsai module paths
if not is_jupyterlite():
    # In local environment, add the utils directory to system path
    utils_dir = os.getenv('PROJECT_UTILS_DIR', os.path.join(os.path.dirname(DATA_DIR), 'utils'))
    bonsaitree_dir = os.path.join(utils_dir, 'bonsaitree')
    
    # Add to path if it exists and isn't already there
    if os.path.exists(bonsaitree_dir) and bonsaitree_dir not in sys.path:
        sys.path.append(bonsaitree_dir)
        print(f"Added {bonsaitree_dir} to sys.path")
else:
    # In JupyterLite, use a simplified approach
    print("⚠️ Running in JupyterLite: Some Bonsai functionality may be limited.")
    print("This notebook is primarily designed for local execution where the Bonsai codebase is available.")

In [None]:
# Helper functions for exploring modules
def display_module_classes(module_name):
    """Display classes and their docstrings from a module"""
    try:
        # Import the module
        module = importlib.import_module(module_name)
        
        # Find all classes
        classes = inspect.getmembers(module, inspect.isclass)
        
        # Filter classes defined in this module (not imported)
        classes = [(name, cls) for name, cls in classes if cls.__module__ == module_name]
        
        # Print info for each class
        for name, cls in classes:
            print(f"\n## {name}")
            
            # Get docstring
            doc = inspect.getdoc(cls)
            if doc:
                print(f"Docstring: {doc}")
            else:
                print("No docstring available")
            
            # Get methods
            methods = inspect.getmembers(cls, inspect.isfunction)
            if methods:
                print("\nMethods:")
                for method_name, method in methods:
                    if not method_name.startswith('_'):  # Skip private methods
                        print(f"- {method_name}")
    except ImportError as e:
        print(f"Error importing module {module_name}: {e}")
    except Exception as e:
        print(f"Error processing module {module_name}: {e}")

def display_module_functions(module_name):
    """Display functions and their docstrings from a module"""
    try:
        # Import the module
        module = importlib.import_module(module_name)
        
        # Find all functions
        functions = inspect.getmembers(module, inspect.isfunction)
        
        # Filter functions defined in this module (not imported)
        functions = [(name, func) for name, func in functions if func.__module__ == module_name]
        
        # Print info for each function
        for name, func in functions:
            if name.startswith('_'):  # Skip private functions
                continue
                
            print(f"\n## {name}")
            
            # Get signature
            sig = inspect.signature(func)
            print(f"Signature: {name}{sig}")
            
            # Get docstring
            doc = inspect.getdoc(func)
            if doc:
                print(f"Docstring: {doc}")
            else:
                print("No docstring available")
    except ImportError as e:
        print(f"Error importing module {module_name}: {e}")
    except Exception as e:
        print(f"Error processing module {module_name}: {e}")

def view_source(obj):
    """Display the source code of an object (function or class)"""
    try:
        source = inspect.getsource(obj)
        display(Markdown(f"```python\n{source}\n```"))
    except Exception as e:
        print(f"Error retrieving source: {e}")

## Check Bonsai Installation

Let's verify that the Bonsai v3 module is available for import:

In [None]:
try:
    from utils.bonsaitree.bonsaitree import v3
    print("✅ Successfully imported Bonsai v3 module")
except ImportError as e:
    print(f"❌ Failed to import Bonsai v3 module: {e}")
    print("This lab requires access to the Bonsai v3 codebase.")
    print("Make sure you've properly set up your environment with the Bonsai repository.")

## Lab 19: Caching Mechanisms for Computational Efficiency

Pedigree reconstruction involves many repetitive calculations, such as:

1. Computing likelihoods for the same relationship configuration multiple times
2. Finding ancestors or descendants of the same individual repeatedly
3. Evaluating the same IBD segments in different contexts

By implementing effective caching strategies, Bonsai v3 can avoid these redundant calculations and significantly improve performance. We'll focus on several key caching mechanisms:

1. **Memoization**: Storing the results of expensive function calls and returning the cached result when the same inputs occur again
2. **LRU Cache**: Using Least Recently Used (LRU) caching to maintain a fixed-size cache of the most recently accessed items
3. **Persistent Caching**: Storing computation results to disk for reuse across different runs
4. **Hierarchical Caching**: Using multi-level caching strategies for different types of calculations

We'll implement simplified versions of these mechanisms to understand how they work and why they're important for large-scale pedigree reconstruction.

## Part 1: Memoization

Memoization is a caching technique where the results of expensive function calls are stored so that the same computation isn't performed repeatedly for the same inputs. This is particularly useful in Bonsai v3, where many functions are called repeatedly with the same parameters during pedigree reconstruction.
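
As a point of comparison before we build our own decorator, Python's standard library already provides memoization through `functools.lru_cache`. The sketch below is not Bonsai's implementation. It uses a hypothetical `slow_square` function, invented here for illustration, to show that the function body runs only once per distinct input:

```python
import functools

call_log = []  # records every time the function body actually runs

@functools.lru_cache(maxsize=None)  # maxsize=None gives an unbounded memo table
def slow_square(x):
    """Hypothetical stand-in for an expensive computation."""
    call_log.append(x)
    return x * x

values = [slow_square(x) for x in (3, 4, 3, 3, 4)]
info = slow_square.cache_info()

print(values)                   # [9, 16, 9, 9, 16]
print(call_log)                 # [3, 4] -> the body ran only for the two distinct inputs
print(info.hits, info.misses)   # 3 2
```

Five calls produce only two real computations; the other three are answered from the cache.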

In [ ]:
# Import Bonsai caching modules if available
if not is_jupyterlite():
    try:
        from utils.bonsaitree.bonsaitree.v3.caching import memoize
        
        # Display the source code if available
        print("Source code for memoize:")
        view_source(memoize)
    except (ImportError, AttributeError) as e:
        print(f"Could not import function: {e}")
else:
    print("Cannot display source code in JupyterLite environment.")

### 1.1 Basic Memoization

Let's implement a basic memoization decorator that caches function results based on input arguments:

In [ ]:
def memoize(func):
    """
    A simple memoization decorator that caches function results.
    
    Args:
        func: The function to memoize
        
    Returns:
        A wrapper function that implements memoization
    """
    cache = {}
    
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Create a key from the arguments
        key = str(args) + str(sorted(kwargs.items()))
        
        # Return cached result if available
        if key in cache:
            return cache[key]
        
        # Compute and cache the result
        result = func(*args, **kwargs)
        cache[key] = result
        return result
    
    # Add a method to clear the cache
    def clear_cache():
        cache.clear()
    
    wrapper.clear_cache = clear_cache
    
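    # Add a method to get cache info.
    # This basic version does not count hits or misses, so they are reported
    # as None; section 1.2 adds real counters.
    def cache_info():
        return {"cache_hits": None, "cache_misses": None, "cache_size": len(cache)}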
    
    wrapper.cache_info = cache_info
    
    return wrapper

# Example of an expensive function that could benefit from memoization
def calculate_fibonacci(n):
    """
    Calculate the nth Fibonacci number recursively.
    This is intentionally inefficient to demonstrate memoization.
    """
    if n <= 1:
        return n
    return calculate_fibonacci(n-1) + calculate_fibonacci(n-2)

# Create a memoized version
@memoize
def calculate_fibonacci_memoized(n):
    """
    Calculate the nth Fibonacci number recursively, with memoization.
    """
    if n <= 1:
        return n
    return calculate_fibonacci_memoized(n-1) + calculate_fibonacci_memoized(n-2)

# Let's see how memoization improves performance
def benchmark_fibonacci(n):
    """Benchmark regular vs. memoized Fibonacci calculation"""
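    # Regular version
    # (time.perf_counter has finer resolution than time.time, which matters
    # because the memoized run can finish in a few microseconds)
    start_time = time.perf_counter()
    result = calculate_fibonacci(n)
    regular_time = time.perf_counter() - start_time
    
    # Memoized version
    calculate_fibonacci_memoized.clear_cache()  # Start with an empty cache
    start_time = time.perf_counter()
    memoized_result = calculate_fibonacci_memoized(n)
    memoized_time = time.perf_counter() - start_time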
    
    # Check results match
    assert result == memoized_result, "Results don't match!"
    
    # Calculate speedup
    speedup = regular_time / memoized_time if memoized_time > 0 else float('inf')
    
    return {
        "n": n,
        "result": result,
        "regular_time": regular_time,
        "memoized_time": memoized_time,
        "speedup": speedup
    }

# Benchmark for different values of n
results = []
for n in range(20, 36, 5):
    print(f"Benchmarking n={n}...")
    results.append(benchmark_fibonacci(n))

# Display results in a table
print("\nBenchmark Results:")
print(f"{'n':<5} | {'Result':<15} | {'Regular Time':<15} | {'Memoized Time':<15} | {'Speedup':<10}")
print("-" * 72)

for result in results:
    print(f"{result['n']:<5} | {result['result']:<15} | {result['regular_time']:<13.6f} s | {result['memoized_time']:<13.6f} s | {result['speedup']:.2f}x")

# Plot the results
plt.figure(figsize=(12, 6))
plt.title("Memoization Performance Improvement")

# Extract data for plotting
ns = [r["n"] for r in results]
regular_times = [r["regular_time"] for r in results]
memoized_times = [r["memoized_time"] for r in results]

# Create a logarithmic scale plot to handle large differences
plt.semilogy(ns, regular_times, 'o-', label='Regular', linewidth=2, markersize=8)
plt.semilogy(ns, memoized_times, 'o-', label='Memoized', linewidth=2, markersize=8)

plt.xlabel("Fibonacci Number (n)")
plt.ylabel("Computation Time (seconds, log scale)")
plt.grid(True, linestyle='--', alpha=0.7)
plt.legend()
plt.tight_layout()
plt.show()

### 1.2 Improved Memoization with Metrics

Let's improve our memoization decorator to track cache hits and misses, which will help us understand how effective the cache is:

In [ ]:
def memoize_with_metrics(func):
    """
    An enhanced memoization decorator that tracks cache metrics.
    
    Args:
        func: The function to memoize
        
    Returns:
        A wrapper function that implements memoization with metrics
    """
    cache = {}
    hits = 0
    misses = 0
    
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        nonlocal hits, misses
        
        # Create a key from the arguments
        key = str(args) + str(sorted(kwargs.items()))
        
        # Return cached result if available
        if key in cache:
            hits += 1
            return cache[key]
        
        # Compute and cache the result
        misses += 1
        result = func(*args, **kwargs)
        cache[key] = result
        return result
    
    # Add a method to clear the cache
    def clear_cache():
        nonlocal hits, misses
        cache.clear()
        hits = 0
        misses = 0
    
    wrapper.clear_cache = clear_cache
    
    # Add a method to get cache info
    def cache_info():
        return {
            "cache_hits": hits,
            "cache_misses": misses,
            "cache_size": len(cache),
            "hit_ratio": hits / (hits + misses) if hits + misses > 0 else 0
        }
    
    wrapper.cache_info = cache_info
    
    return wrapper

# Example function - calculating binomial coefficients
def calculate_binomial(n, k):
    """Calculate binomial coefficient C(n,k) using a recursive formula"""
    if k == 0 or k == n:
        return 1
    return calculate_binomial(n-1, k-1) + calculate_binomial(n-1, k)

@memoize_with_metrics
def calculate_binomial_memoized(n, k):
    """Calculate binomial coefficient C(n,k) using a recursive formula with memoization"""
    if k == 0 or k == n:
        return 1
    return calculate_binomial_memoized(n-1, k-1) + calculate_binomial_memoized(n-1, k)

# Let's calculate some binomial coefficients and monitor cache performance
def demonstrate_binomial_memoization():
    # Clear cache to start fresh
    calculate_binomial_memoized.clear_cache()
    
    # Calculate C(20,10)
    start_time = time.time()
    result = calculate_binomial_memoized(20, 10)
    elapsed_time = time.time() - start_time
    
    print(f"C(20,10) = {result}, calculated in {elapsed_time:.6f} seconds")
    
    # Check cache metrics after the first calculation
    first_metrics = calculate_binomial_memoized.cache_info()
    print(f"Cache hits: {first_metrics['cache_hits']}")
    print(f"Cache misses: {first_metrics['cache_misses']}")
    print(f"Cache size: {first_metrics['cache_size']}")
    print(f"Hit ratio: {first_metrics['hit_ratio']:.2%}")
    
    # Calculate C(20,11) - should reuse many cached results
    start_time = time.time()
    result = calculate_binomial_memoized(20, 11)
    elapsed_time = time.time() - start_time
    
    print(f"\nC(20,11) = {result}, calculated in {elapsed_time:.6f} seconds")
    
    # Check the cumulative cache metrics again
    metrics = calculate_binomial_memoized.cache_info()
    print(f"Cache hits: {metrics['cache_hits']}")
    print(f"Cache misses: {metrics['cache_misses']}")
    print(f"Cache size: {metrics['cache_size']}")
    print(f"Hit ratio: {metrics['hit_ratio']:.2%}")
    
    # Isolate the second calculation by subtracting the first call's counters
    second_hits = metrics['cache_hits'] - first_metrics['cache_hits']
    second_misses = metrics['cache_misses'] - first_metrics['cache_misses']
    second_total = second_hits + second_misses
    second_ratio = second_hits / second_total if second_total > 0 else 0
    print(f"\nSecond calculation alone: {second_hits} hits, {second_misses} misses "
          f"({second_ratio:.2%} hit ratio)")
    
    return metrics

# Let's run the demonstration
metrics = demonstrate_binomial_memoization()

# Visualize cache performance
labels = ['Hits', 'Misses']
sizes = [metrics['cache_hits'], metrics['cache_misses']]
colors = ['#66b3ff', '#ff9999']
explode = (0.1, 0)  # explode the 1st slice (Hits)

plt.figure(figsize=(8, 8))
plt.pie(sizes, explode=explode, labels=labels, colors=colors, autopct='%1.1f%%',
        shadow=True, startangle=90)
plt.axis('equal')  # Equal aspect ratio ensures that pie is drawn as a circle
plt.title('Cache Performance Metrics')
plt.tight_layout()
plt.show()

### 1.3 Application to Pedigree Reconstruction

Let's see how memoization can be applied to pedigree reconstruction functions in Bonsai v3, such as finding ancestors or descendants in a pedigree:

In [ ]:
# Non-memoized version of get_ancestors function
def get_ancestors(iid, up_dct, max_degree=None):
    """
    Get all ancestors of an individual in a pedigree.
    
    Args:
        iid: ID of the individual
        up_dct: Up-node dictionary representing the pedigree
        max_degree: Maximum generational distance to consider
        
    Returns:
        Set of ancestor IDs
    """
    ancestors = set()
    current_degree = 0
    current_gen = {iid}
    
    while current_gen and (max_degree is None or current_degree < max_degree):
        next_gen = set()
        for current_id in current_gen:
            # Get parents from the up_dct
            if current_id in up_dct:
                for parent_id in up_dct[current_id]:
                    if parent_id not in ancestors:
                        next_gen.add(parent_id)
                        ancestors.add(parent_id)
        
        current_gen = next_gen
        current_degree += 1
    
    return ancestors

# Memoized version
@memoize_with_metrics
def get_ancestors_memoized(iid, up_dct, max_degree=None):
    """
    Get all ancestors of an individual in a pedigree (memoized version).
    
    Args:
        iid: ID of the individual
        up_dct: Up-node dictionary representing the pedigree
        max_degree: Maximum generational distance to consider
        
    Returns:
        Set of ancestor IDs
    """
    # Note: the cache key is built by the decorator from str(args), so the whole
    # up_dct is converted to a string on every call, including cache hits.
    # For a large pedigree that conversion can cost more than the traversal
    # it saves, so check the measured speedup before relying on this pattern.
    
    ancestors = set()
    current_degree = 0
    current_gen = {iid}
    
    while current_gen and (max_degree is None or current_degree < max_degree):
        next_gen = set()
        for current_id in current_gen:
            # Get parents from the up_dct
            if current_id in up_dct:
                for parent_id in up_dct[current_id]:
                    if parent_id not in ancestors:
                        next_gen.add(parent_id)
                        ancestors.add(parent_id)
        
        current_gen = next_gen
        current_degree += 1
    
    return ancestors

# Generate a large test pedigree
def generate_test_pedigree(num_generations=5, branching_factor=2):
    """
    Generate a test pedigree with a specified number of generations.
    
    Args:
        num_generations: Number of generations in the pedigree
        branching_factor: Number of children per individual
        
    Returns:
        up_dct: Up-node dictionary representing the pedigree
        id_to_gen: Dict mapping IDs to their generation
    """
    up_dct = {}
    id_to_gen = {}
    next_id = 1
    
    # Create the first generation (founders)
    founders = []
    for _ in range(branching_factor ** (num_generations - 1)):
        id_to_gen[next_id] = 0  # Generation 0
        up_dct[next_id] = {}
        founders.append(next_id)
        next_id += 1
    
    # Create subsequent generations
    for gen in range(1, num_generations):
        parents = [id_val for id_val, g in id_to_gen.items() if g == gen - 1]
        
        for parent_id in parents:
            for _ in range(branching_factor):
                id_to_gen[next_id] = gen
                up_dct[next_id] = {parent_id: 1}
                next_id += 1
    
    return up_dct, id_to_gen

# Benchmark ancestor calculation with and without memoization
def benchmark_ancestor_calculation(up_dct, id_to_gen, num_queries=100):
    """
    Benchmark ancestor calculation with and without memoization.
    
    Args:
        up_dct: Up-node dictionary representing the pedigree
        id_to_gen: Dict mapping IDs to their generation
        num_queries: Number of queries to perform
        
    Returns:
        Dict with benchmark results
    """
    # Clear cache to start fresh
    get_ancestors_memoized.clear_cache()
    
    # Get all IDs in the pedigree
    all_ids = list(up_dct.keys())
    
    # Generate random queries (ID, max_degree)
    queries = []
    for _ in range(num_queries):
        iid = random.choice(all_ids)
        max_degree = random.randint(1, 5)
        queries.append((iid, max_degree))
    
    # Time the non-memoized version
    start_time = time.time()
    regular_results = []
    
    for iid, max_degree in queries:
        ancestors = get_ancestors(iid, up_dct, max_degree)
        regular_results.append(ancestors)
    
    regular_time = time.time() - start_time
    
    # Time the memoized version
    start_time = time.time()
    memoized_results = []
    
    for iid, max_degree in queries:
        ancestors = get_ancestors_memoized(iid, up_dct, max_degree)
        memoized_results.append(ancestors)
    
    memoized_time = time.time() - start_time
    
    # Verify results match
    for i, (regular, memoized) in enumerate(zip(regular_results, memoized_results)):
        if regular != memoized:
            print(f"Warning: Results don't match for query {i}!")
    
    # Get cache metrics
    metrics = get_ancestors_memoized.cache_info()
    
    # Calculate speedup
    speedup = regular_time / memoized_time if memoized_time > 0 else float('inf')
    
    return {
        "queries": num_queries,
        "regular_time": regular_time,
        "memoized_time": memoized_time,
        "speedup": speedup,
        "cache_hits": metrics["cache_hits"],
        "cache_misses": metrics["cache_misses"],
        "hit_ratio": metrics["hit_ratio"]
    }

# Generate test pedigree
print("Generating test pedigree...")
up_dct, id_to_gen = generate_test_pedigree(num_generations=6, branching_factor=2)
print(f"Pedigree has {len(up_dct)} individuals across {max(id_to_gen.values()) + 1} generations")

# Run benchmark
print("\nRunning benchmark...")
results = benchmark_ancestor_calculation(up_dct, id_to_gen, num_queries=100)

# Display results
print("\nBenchmark Results:")
print(f"Number of queries: {results['queries']}")
print(f"Non-memoized time: {results['regular_time']:.6f} seconds")
print(f"Memoized time: {results['memoized_time']:.6f} seconds")
print(f"Speedup: {results['speedup']:.2f}x")
print(f"Cache hits: {results['cache_hits']}")
print(f"Cache misses: {results['cache_misses']}")
print(f"Hit ratio: {results['hit_ratio']:.2%}")

# Create a bar chart comparing performance
plt.figure(figsize=(10, 6))
plt.bar(['Non-memoized', 'Memoized'], 
        [results['regular_time'], results['memoized_time']], 
        color=['#ff9999', '#66b3ff'])
plt.title('Ancestor Calculation Performance')
plt.ylabel('Time (seconds)')
plt.grid(axis='y', linestyle='--', alpha=0.7)
plt.tight_layout()
plt.show()
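
One way to avoid putting the pedigree itself into the cache key is to bind the cache to a single pedigree and key only on the query. The sketch below illustrates that design under stated assumptions: `AncestorCache` is a hypothetical helper written for this lab, not part of Bonsai's API, and it assumes the pedigree is not modified without calling `invalidate()`.

```python
class AncestorCache:
    """Hypothetical helper: memoizes ancestor queries for one fixed pedigree."""

    def __init__(self, up_dct):
        self.up_dct = up_dct
        self._cache = {}   # (iid, max_degree) -> frozenset of ancestor IDs
        self.hits = 0
        self.misses = 0

    def get_ancestors(self, iid, max_degree=None):
        key = (iid, max_degree)   # small hashable key; the pedigree is not part of it
        if key in self._cache:
            self.hits += 1
            return self._cache[key]
        self.misses += 1
        ancestors, current_gen, degree = set(), {iid}, 0
        while current_gen and (max_degree is None or degree < max_degree):
            next_gen = set()
            for current_id in current_gen:
                for parent_id in self.up_dct.get(current_id, {}):
                    if parent_id not in ancestors:
                        ancestors.add(parent_id)
                        next_gen.add(parent_id)
            current_gen = next_gen
            degree += 1
        result = frozenset(ancestors)   # immutable, so callers cannot alter the cached value
        self._cache[key] = result
        return result

    def invalidate(self):
        """Drop all cached answers; call this after the pedigree changes."""
        self._cache.clear()

# Toy pedigree: 4 -> 3 -> 2 -> 1 (each individual has one parent)
toy_up_dct = {1: {}, 2: {1: 1}, 3: {2: 1}, 4: {3: 1}}
ancestor_cache = AncestorCache(toy_up_dct)
all_ancestors = ancestor_cache.get_ancestors(4)           # miss
parents_only = ancestor_cache.get_ancestors(4, max_degree=1)  # miss (different key)
repeat = ancestor_cache.get_ancestors(4)                  # hit
print(sorted(all_ancestors), sorted(parents_only))        # [1, 2, 3] [3]
print(ancestor_cache.hits, ancestor_cache.misses)         # 1 2
```

The trade-off is that correctness now depends on explicit invalidation: any edit to the pedigree must clear the cache, otherwise stale ancestor sets are returned.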

## Part 2: Least Recently Used (LRU) Cache

While memoization is effective, it can lead to unbounded memory usage as the cache grows over time. Least Recently Used (LRU) caching addresses this by maintaining a fixed-size cache of the most recently used items, discarding the least recently used items when the cache is full.

In [ ]:
# Import Bonsai LRU caching modules if available
if not is_jupyterlite():
    try:
        from utils.bonsaitree.bonsaitree.v3.caching import lru_cache
        
        # Display the source code if available
        print("Source code for lru_cache:")
        view_source(lru_cache)
    except (ImportError, AttributeError) as e:
        print(f"Could not import function: {e}")
else:
    print("Cannot display source code in JupyterLite environment.")

### 2.1 Implementing a Simple LRU Cache

Let's implement a simple LRU cache using an OrderedDict, which maintains insertion order and allows us to move recently accessed items to the end:
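
The two `OrderedDict` operations that an LRU policy relies on can be traced on a three-item example. This is a minimal illustration of the data structure only, separate from the cache class itself:

```python
from collections import OrderedDict

recency = OrderedDict()
for key in ("a", "b", "c"):
    recency[key] = key.upper()        # insertion order: a, b, c

recency.move_to_end("a")              # "a" was just accessed, so it becomes most recent
order_after_access = list(recency)

evicted_key, _ = recency.popitem(last=False)   # remove from the front: least recently used
order_after_eviction = list(recency)

print(order_after_access)     # ['b', 'c', 'a']
print(evicted_key)            # b
print(order_after_eviction)   # ['c', 'a']
```

The front of the dictionary always holds the least recently used key, so eviction is a single `popitem(last=False)` call.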

In [ ]:
class LRUCache:
    """
    A simple implementation of an LRU (Least Recently Used) cache.
    
    This cache has a maximum size and evicts the least recently used items
    when that size is exceeded.
    
    Args:
        max_size: The maximum number of items to store in the cache
    """
    def __init__(self, max_size=128):
        self.cache = OrderedDict()
        self.max_size = max_size
        self.hits = 0
        self.misses = 0
    
    def get(self, key):
        """
        Get an item from the cache, marking it as recently used.
        
        Args:
            key: The key to look up
            
        Returns:
            The cached value, or None if not found
        """
        if key in self.cache:
            # Mark the key as most recently used
            self.cache.move_to_end(key)
            self.hits += 1
            return self.cache[key]
        self.misses += 1
        return None
    
    def put(self, key, value):
        """
        Add an item to the cache, evicting the least recently used item if necessary.
        
        Args:
            key: The key to store
            value: The value to store
            
        Returns:
            None
        """
        if key in self.cache:
            # Remove the existing entry
            self.cache.pop(key)
        elif len(self.cache) >= self.max_size:
            # Remove the least recently used item (first item in OrderedDict)
            self.cache.popitem(last=False)
        
        # Add the new item
        self.cache[key] = value
    
    def clear(self):
        """Clear the cache and reset statistics."""
        self.cache.clear()
        self.hits = 0
        self.misses = 0
    
    def info(self):
        """Return statistics about the cache."""
        total_accesses = self.hits + self.misses
        hit_ratio = self.hits / total_accesses if total_accesses > 0 else 0
        return {
            "size": len(self.cache),
            "max_size": self.max_size,
            "hits": self.hits,
            "misses": self.misses,
            "hit_ratio": hit_ratio
        }

### 2.2 Creating an LRU Cache Decorator

Now let's create a decorator that applies an LRU cache to functions, similar to Python's `functools.lru_cache`:

In [ ]:
def lru_cache(maxsize=128):
    """
    Decorator to wrap a function with an LRU cache.
    
    Args:
        maxsize: Maximum number of entries to keep in the cache
        
    Returns:
        Decorator function
    """
    def decorator(func):
        cache = LRUCache(max_size=maxsize)
        
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Create a key from the arguments
            key = str(args) + str(sorted(kwargs.items()))
            
            # Try to get from cache. Results are stored wrapped in a 1-tuple so
            # that a cached None can be told apart from a cache miss.
            entry = cache.get(key)
            if entry is not None:
                return entry[0]
            
            # Compute and store the result
            result = func(*args, **kwargs)
            cache.put(key, (result,))
            return result
        
        # Add methods to access cache information
        def cache_info():
            return cache.info()
        
        def cache_clear():
            cache.clear()
        
        wrapper.cache_info = cache_info
        wrapper.cache_clear = cache_clear
        
        return wrapper
    
    return decorator
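
A design point worth noting for any cache whose lookup returns `None` on a miss, as `LRUCache.get` does: a legitimately cached `None` looks identical to a miss, so functions that can return `None` may be recomputed on every call. A common remedy is a private sentinel object. The sketch below uses a hypothetical `memoize_with_sentinel` decorator and a plain dict to isolate the idea:

```python
import functools

_MISSING = object()   # unique marker; no function result can be this exact object

def memoize_with_sentinel(func):
    """Hypothetical decorator: caches every result, including None."""
    cache = {}

    @functools.wraps(func)
    def wrapper(*args):
        result = cache.get(args, _MISSING)
        if result is not _MISSING:   # identity check, so a cached None counts as a hit
            return result
        result = func(*args)
        cache[args] = result
        return result

    return wrapper

calls = []

@memoize_with_sentinel
def find_match(x):
    """Toy lookup that returns None when nothing is found."""
    calls.append(x)
    return x if x > 0 else None

results = [find_match(-1), find_match(-1), find_match(5), find_match(5)]
print(results)   # [None, None, 5, 5]
print(calls)     # [-1, 5] -> the None result was computed only once
```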

### 2.3 Benchmarking LRU Cache Performance

Let's apply our LRU cache to the Fibonacci function and benchmark its performance with different cache sizes:

In [ ]:
# Create a Fibonacci function with our LRU cache
@lru_cache(maxsize=64)
def fib_lru(n):
    """Calculate the nth Fibonacci number using LRU cache."""
    if n <= 1:
        return n
    return fib_lru(n-1) + fib_lru(n-2)

# Benchmark different cache sizes
def benchmark_lru_cache_sizes(n=35):
    """
    Benchmark LRU cache with different cache sizes.
    
    Args:
        n: Fibonacci number to calculate
        
    Returns:
        List of dicts, one per cache size, with timing and hit/miss statistics
    """
    results = []
    
    # Define a function to create and test a fib function with a specific cache size
    def test_cache_size(cache_size):
        # Create a new function with the specified cache size
        @lru_cache(maxsize=cache_size)
        def fib_test(n):
            if n <= 1:
                return n
            return fib_test(n-1) + fib_test(n-2)
        
        # Calculate the fibonacci number and measure time
        start_time = time.time()
        result = fib_test(n)
        elapsed_time = time.time() - start_time
        
        # Get cache info
        info = fib_test.cache_info()
        
        return {
            "cache_size": cache_size,
            "time": elapsed_time,
            "hits": info["hits"],
            "misses": info["misses"],
            "hit_ratio": info["hit_ratio"]
        }
    
    # Test different cache sizes
    for cache_size in [8, 16, 32, 64, 128, 256]:
        print(f"Testing cache size: {cache_size}")
        results.append(test_cache_size(cache_size))
    
    return results

# Run the benchmark
benchmark_results = benchmark_lru_cache_sizes(n=35)

# Create a visualization of the results
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

# Plot computation time vs cache size
cache_sizes = [r["cache_size"] for r in benchmark_results]
times = [r["time"] for r in benchmark_results]
hit_ratios = [r["hit_ratio"] * 100 for r in benchmark_results]

ax1.plot(cache_sizes, times, 'o-', linewidth=2, markersize=8)
ax1.set_xlabel("Cache Size")
ax1.set_ylabel("Computation Time (seconds)")
ax1.set_title("Computation Time vs Cache Size")
ax1.grid(True, linestyle='--', alpha=0.7)

# Plot hit ratio vs cache size
ax2.plot(cache_sizes, hit_ratios, 'o-', linewidth=2, markersize=8, color='green')
ax2.set_xlabel("Cache Size")
ax2.set_ylabel("Cache Hit Ratio (%)")
ax2.set_title("Cache Hit Ratio vs Cache Size")
ax2.grid(True, linestyle='--', alpha=0.7)

plt.tight_layout()
plt.show()

# Print the detailed results
print("\nDetailed Results:")
print(f"{'Cache Size':<10} | {'Time (s)':<10} | {'Hits':<6} | {'Misses':<6} | {'Hit Ratio':<8}")
print("-" * 50)
for result in benchmark_results:
    print(f"{result['cache_size']:<10} | {result['time']:<10.6f} | {result['hits']:<6} | {result['misses']:<6} | {result['hit_ratio']:.2%}")

### 2.4 Application to Pedigree Reconstruction

Let's apply LRU caching to a common pedigree reconstruction function: finding the lowest common ancestor (LCA) of two individuals. This function is frequently called during pedigree inference, and would benefit from caching:
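
Caches that hash their arguments, such as the standard library's `functools.lru_cache`, cannot accept a dictionary like `up_dct` directly, because dictionaries are unhashable. One workaround is to flatten the pedigree into a frozenset of `(child, parent)` pairs first. The sketch below shows this with two hypothetical helpers, `freeze_up_dct` and `count_ancestors`, written only for illustration:

```python
from functools import lru_cache

def freeze_up_dct(up_dct):
    """Hypothetical helper: flatten {child: {parent: degree}} into hashable edge pairs."""
    return frozenset(
        (child, parent) for child, parents in up_dct.items() for parent in parents
    )

@lru_cache(maxsize=1000)
def count_ancestors(iid, edges):
    """Count the distinct ancestors of iid, given frozen (child, parent) edges."""
    parents_of = {}
    for child, parent in edges:
        parents_of.setdefault(child, set()).add(parent)
    seen, stack = set(), [iid]
    while stack:
        for parent in parents_of.get(stack.pop(), ()):
            if parent not in seen:
                seen.add(parent)
                stack.append(parent)
    return len(seen)

toy_up_dct = {1: {}, 2: {}, 3: {1: 1, 2: 1}, 4: {3: 1}}
edges = freeze_up_dct(toy_up_dct)

try:
    count_ancestors(4, toy_up_dct)   # a dict cannot be hashed into a cache key
    dict_accepted = True
except TypeError:
    dict_accepted = False

first = count_ancestors(4, edges)    # computed
second = count_ancestors(4, edges)   # served from the cache
print(dict_accepted, first, second)  # False 3 3
```

Freezing has a cost of its own (one pass over every edge, plus hashing the frozenset), so it pays off mainly when the same frozen pedigree is reused across many queries.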

In [ ]:
# Function to find the lowest common ancestor of two individuals
def find_lowest_common_ancestor(iid1, iid2, up_dct):
    """
    Find the lowest common ancestor of two individuals in a pedigree.
    
    Args:
        iid1: ID of the first individual
        iid2: ID of the second individual
        up_dct: Up-node dictionary representing the pedigree
        
    Returns:
        ID of the lowest common ancestor, or None if none exists
    """
    # Get all ancestors of the first individual
    ancestors1 = set()
    queue = [iid1]
    while queue:
        current = queue.pop(0)
        if current in up_dct:
            for parent in up_dct[current]:
                if parent not in ancestors1:
                    ancestors1.add(parent)
                    queue.append(parent)
    
    # Check if the second individual is in the ancestors of the first
    if iid2 in ancestors1:
        return iid2
    
    # Traverse up from the second individual, checking for common ancestors
    visited = set()
    queue = [iid2]
    while queue:
        current = queue.pop(0)
        if current in ancestors1:
            return current
        
        visited.add(current)
        if current in up_dct:
            for parent in up_dct[current]:
                if parent not in visited:
                    queue.append(parent)
    
    # No common ancestor found
    return None

# Create an LRU-cached version
@lru_cache(maxsize=1000)
def find_lowest_common_ancestor_lru(iid1, iid2, up_dct_hashable):
    """
    Find the lowest common ancestor of two individuals in a pedigree (with LRU caching).
    
    Args:
        iid1: ID of the first individual
        iid2: ID of the second individual
        up_dct_hashable: Hashable representation of the up-node dictionary
        
    Returns:
        ID of the lowest common ancestor, or None if none exists
    """
    # Convert the hashable representation back to a dictionary
    up_dct = {}
    for id_val, parent_id in up_dct_hashable:
        if id_val not in up_dct:
            up_dct[id_val] = set()
        up_dct[id_val].add(parent_id)
    
    # Get all ancestors of the first individual
    ancestors1 = set()
    queue = [iid1]
    while queue:
        current = queue.pop(0)
        if current in up_dct:
            for parent in up_dct[current]:
                if parent not in ancestors1:
                    ancestors1.add(parent)
                    queue.append(parent)
    
    # Check if the second individual is in the ancestors of the first
    if iid2 in ancestors1:
        return iid2
    
    # Traverse up from the second individual, checking for common ancestors
    visited = set()
    queue = [iid2]
    while queue:
        current = queue.pop(0)
        if current in ancestors1:
            return current
        
        visited.add(current)
        if current in up_dct:
            for parent in up_dct[current]:
                if parent not in visited:
                    queue.append(parent)
    
    # No common ancestor found
    return None

# Generate a larger test pedigree for benchmarking
def generate_complex_pedigree(individuals=100, avg_children=2, consanguinity_rate=0.1):
    """
    Generate a more complex pedigree with consanguinity.
    
    Args:
        individuals: Total number of individuals to generate
        avg_children: Average number of children per pair
        consanguinity_rate: Probability of consanguineous mating
        
    Returns:
        up_dct: Up-node dictionary representing the pedigree
    """
    up_dct = {}
    available_mates = []
    next_id = 1
    
    # Create founders (10% of total)
    founders = []
    for _ in range(max(2, int(individuals * 0.1))):
        up_dct[next_id] = set()  # Founders have no parents; use a set like every other entry
        founders.append(next_id)
        available_mates.append(next_id)
        next_id += 1
    
    # Create the rest of the pedigree
    while next_id <= individuals:
        # Choose parents
        if random.random() < consanguinity_rate and len(available_mates) > 3:
            # Consanguineous mating - choose related individuals
            parent1 = random.choice(available_mates)
            # Find relatives of parent1
            relatives = []
            for iid in available_mates:
                if iid != parent1:
                    lca = find_lowest_common_ancestor(parent1, iid, up_dct)
                    if lca is not None:
                        relatives.append(iid)
            
            if relatives:
                parent2 = random.choice(relatives)
            else:
                # Fall back to random mating if no relatives found
                candidates = [iid for iid in available_mates if iid != parent1]
                parent2 = random.choice(candidates) if candidates else parent1
        else:
            # Random mating
            parent1 = random.choice(available_mates)
            candidates = [iid for iid in available_mates if iid != parent1]
            parent2 = random.choice(candidates) if candidates else parent1
        
        # Create children
        num_children = max(1, int(random.normalvariate(avg_children, 1)))
        for _ in range(num_children):
            if next_id > individuals:
                break
                
            up_dct[next_id] = {parent1, parent2}
            available_mates.append(next_id)
            next_id += 1
    
    return up_dct

# Benchmark the LCA function with and without LRU caching
def benchmark_lca(up_dct, num_queries=1000):
    """
    Benchmark the lowest common ancestor function with and without LRU caching.
    
    Args:
        up_dct: Up-node dictionary representing the pedigree
        num_queries: Number of LCA queries to perform
        
    Returns:
        Dict with benchmark results
    """
    # Create a hashable representation of the up_dct for the LRU cache
    up_dct_hashable = frozenset(
        (id_val, parent_id) 
        for id_val, parents in up_dct.items() 
        for parent_id in parents
    )
    
    # Clear the LRU cache
    find_lowest_common_ancestor_lru.cache_clear()
    
    # Generate random pairs of individuals
    all_ids = list(up_dct.keys())
    pairs = []
    for _ in range(num_queries):
        # random.sample returns two distinct individuals, so exactly
        # num_queries pairs are generated
        iid1, iid2 = random.sample(all_ids, 2)
        pairs.append((iid1, iid2))
    
    # Benchmark the regular version
    start_time = time.time()
    regular_results = []
    for iid1, iid2 in pairs:
        result = find_lowest_common_ancestor(iid1, iid2, up_dct)
        regular_results.append(result)
    regular_time = time.time() - start_time
    
    # Benchmark the LRU-cached version
    start_time = time.time()
    lru_results = []
    for iid1, iid2 in pairs:
        result = find_lowest_common_ancestor_lru(iid1, iid2, up_dct_hashable)
        lru_results.append(result)
    lru_time = time.time() - start_time
    
    # Verify results match
    for i, (reg, lru) in enumerate(zip(regular_results, lru_results)):
        if reg != lru:
            print(f"Warning: Results don't match for pair {i}!")
    
    # Get cache info
    cache_info = find_lowest_common_ancestor_lru.cache_info()
    
    # Calculate speedup
    speedup = regular_time / lru_time if lru_time > 0 else float('inf')
    
    return {
        "queries": num_queries,
        "regular_time": regular_time,
        "lru_time": lru_time,
        "speedup": speedup,
        "hits": cache_info["hits"],
        "misses": cache_info["misses"],
        "hit_ratio": cache_info["hit_ratio"]
    }

# Generate a pedigree and run the benchmark
print("Generating test pedigree...")
pedigree = generate_complex_pedigree(individuals=200, consanguinity_rate=0.2)
print(f"Pedigree has {len(pedigree)} individuals")

# Run benchmark
print("\nRunning LCA benchmark...")
lca_results = benchmark_lca(pedigree, num_queries=1000)

# Display results
print("\nLCA Benchmark Results:")
print(f"Number of queries: {lca_results['queries']}")
print(f"Regular time: {lca_results['regular_time']:.6f} seconds")
print(f"LRU-cached time: {lca_results['lru_time']:.6f} seconds")
print(f"Speedup: {lca_results['speedup']:.2f}x")
print(f"Cache hits: {lca_results['hits']}")
print(f"Cache misses: {lca_results['misses']}")
print(f"Hit ratio: {lca_results['hit_ratio']:.2%}")

# Create visual comparison
plt.figure(figsize=(12, 6))
labels = ['Regular', 'LRU-Cached']
times = [lca_results['regular_time'], lca_results['lru_time']]
colors = ['#ff9999', '#66b3ff']

bar_plot = plt.bar(labels, times, color=colors)
plt.ylabel('Time (seconds)')
plt.title('LCA Calculation Performance Comparison')
plt.grid(axis='y', linestyle='--', alpha=0.7)

# Add text labels
for bar, time_val in zip(bar_plot, times):
    height = bar.get_height()
    plt.text(bar.get_x() + bar.get_width()/2., height + 0.01,
             f'{time_val:.4f}s',
             ha='center', va='bottom')

# Add speedup annotation
plt.annotate(f'Speedup: {lca_results["speedup"]:.2f}x',
             xy=(0.5, max(times) * 0.5),
             xytext=(0.5, max(times) * 0.7),
             ha='center',
             bbox=dict(boxstyle="round,pad=0.3", fc="#d5f5e3", ec="black", alpha=0.8))

plt.tight_layout()
plt.show()
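
The benchmark above flattens `up_dct` into a `frozenset` of `(child, parent)` pairs before calling the cached function, because a cache that keys on its arguments needs those arguments to be hashable. The sketch below illustrates this with the standard library's `functools.lru_cache`, which may differ from the `lru_cache` decorator used in this lab. `to_hashable` and `count_parent_links` are hypothetical names introduced only for illustration.

```python
from functools import lru_cache


def to_hashable(up_dct):
    """Flatten {child: {parents}} into a frozenset of (child, parent) pairs."""
    return frozenset(
        (child, parent)
        for child, parents in up_dct.items()
        for parent in parents
    )


@lru_cache(maxsize=None)
def count_parent_links(up_dct_hashable):
    """Toy stand-in for an expensive pedigree computation."""
    return len(up_dct_hashable)


toy_pedigree = {1: set(), 2: set(), 3: {1, 2}}

# A plain dict cannot be used as part of a cache key
try:
    hash(toy_pedigree)
    dict_is_hashable = True
except TypeError:
    dict_is_hashable = False

# Equal frozensets hash equally, so the second call is a cache hit
first = count_parent_links(to_hashable(toy_pedigree))
second = count_parent_links(to_hashable(toy_pedigree))
info = count_parent_links.cache_info()
print(dict_is_hashable, first, second, info.hits, info.misses)
```

Individuals with no recorded parents contribute no pairs, so founders are absent from the flattened form. The LCA search above treats a missing entry the same as an empty parent set, so nothing is lost for that use.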

## Part 3: Persistent Caching

Memoization and LRU caching live in memory, so everything they store is lost when the process exits. For long-running or multi-stage pedigree reconstruction, a persistent cache that saves computed results to disk lets later runs reuse earlier work.

### 3.1 Implementing a Simple Disk Cache

Let's implement a simple persistent cache that stores results in JSON files:

In [ ]:
class DiskCache:
    """
    A persistent cache that stores results on disk.
    
    Args:
        cache_dir: Directory to store cache files
        encoder: Function to encode objects for storage (default: JSON serialization)
        decoder: Function to decode stored objects (default: JSON deserialization)
    """
    def __init__(self, cache_dir=None, encoder=None, decoder=None):
        self.cache_dir = cache_dir or os.path.join(RESULTS_DIR, 'cache')
        self.encoder = encoder or json.dumps
        self.decoder = decoder or json.loads
        self.hits = 0
        self.misses = 0
        
        # Create the cache directory if it doesn't exist
        os.makedirs(self.cache_dir, exist_ok=True)
    
    def _get_cache_path(self, key):
        """Convert a key to a file path"""
        # Use a hash to create a safe filename
        key_hash = hashlib.md5(str(key).encode()).hexdigest()
        return os.path.join(self.cache_dir, f"{key_hash}.json")
    
    def get(self, key):
        """
        Get a value from the cache.
        
        Args:
            key: Cache key
            
        Returns:
            Cached value or None if not found
        """
        cache_path = self._get_cache_path(key)
        
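        if os.path.exists(cache_path):
            try:
                with open(cache_path, 'r') as f:
                    value = self.decoder(f.read())
                # Count the hit only after the file was read and decoded,
                # so a corrupt file is not counted as both a hit and a miss
                self.hits += 1
                return value
            except (json.JSONDecodeError, IOError) as e:
                print(f"Error reading cache file: {e}")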
        
        self.misses += 1
        return None
    
    def put(self, key, value):
        """
        Store a value in the cache.
        
        Args:
            key: Cache key
            value: Value to store
        """
        cache_path = self._get_cache_path(key)
        
        try:
            with open(cache_path, 'w') as f:
                f.write(self.encoder(value))
            return True
        except (TypeError, IOError) as e:
            print(f"Error writing to cache file: {e}")
            return False
    
    def clear(self):
        """Clear all cached items."""
        for filename in os.listdir(self.cache_dir):
            if filename.endswith('.json'):
                os.remove(os.path.join(self.cache_dir, filename))
        self.hits = 0
        self.misses = 0
    
    def info(self):
        """Return cache statistics."""
        cache_files = [f for f in os.listdir(self.cache_dir) if f.endswith('.json')]
        total_size = sum(os.path.getsize(os.path.join(self.cache_dir, f)) for f in cache_files)
        
        total_accesses = self.hits + self.misses
        hit_ratio = self.hits / total_accesses if total_accesses > 0 else 0
        
        return {
            "items": len(cache_files),
            "size_bytes": total_size,
            "hits": self.hits,
            "misses": self.misses,
            "hit_ratio": hit_ratio
        }

# Create a disk cache decorator
def disk_cache(cache_dir=None):
    """
    Create a decorator for persistent disk caching.
    
    Args:
        cache_dir: Directory to store cache files
        
    Returns:
        Decorator function
    """
    cache = DiskCache(cache_dir=cache_dir)
    
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Create a key from the function name and arguments
            key = {
                "func": func.__name__,
                "args": args,
                "kwargs": sorted(kwargs.items())
            }
            
            # Try to get from cache
            result = cache.get(key)
            if result is not None:
                return result
            
            # Compute and store the result
            result = func(*args, **kwargs)
            cache.put(key, result)
            return result
        
        # Add methods to access cache information
        def cache_info():
            return cache.info()
        
        def cache_clear():
            cache.clear()
        
        wrapper.cache_info = cache_info
        wrapper.cache_clear = cache_clear
        
        return wrapper
    
    return decorator
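
Because `DiskCache` serializes with `json.dumps` by default, a value read back from disk is not always identical in type to the value that was stored. The following sketch uses only the standard `json` module; the sample dictionary is made up for illustration.

```python
import json

# Made-up result resembling what a cached function might return
original = {
    "relationship": "1st cousin",
    "segments": [(1000, 2000, 12.5)],  # tuple inside a list
    "by_chromosome": {1: 12.5},        # integer dictionary key
}

restored = json.loads(json.dumps(original))

# Tuples come back as lists, and non-string keys come back as strings
print(type(restored["segments"][0]).__name__)
print(list(restored["by_chromosome"].keys()))
print(restored == original)
```

The functions cached in this lab return flat dictionaries of strings and numbers, which survive the round trip unchanged. For richer return types, the `encoder` and `decoder` arguments of `DiskCache` are the place to plug in a different serialization.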

### 3.2 Using Persistent Caching in Pedigree Reconstruction

Let's apply our disk cache to a function that simulates a time-consuming pedigree reconstruction operation, such as inferring relationships from IBD segments:

In [ ]:
# Create a cache directory for our examples
example_cache_dir = os.path.join(RESULTS_DIR, 'example_cache')
os.makedirs(example_cache_dir, exist_ok=True)

# Define an expensive function that simulates relationship inference
def infer_relationship(ibd_segments, iid1, iid2):
    """
    Infer the relationship between two individuals based on IBD segments.
    This is a simplified simulation of a computationally expensive operation.
    
    Args:
        ibd_segments: List of (start, end, cM) tuples representing IBD segments
        iid1: ID of the first individual
        iid2: ID of the second individual
        
    Returns:
        Dict with relationship information
    """
    # Simulate computation time
    time.sleep(0.5)  # This would be much longer in reality
    
    # Calculate total cM shared
    total_cm = sum(segment[2] for segment in ibd_segments)
    
    # Infer relationship based on total cM (simplified)
    relationship = None
    confidence = None
    
    if total_cm > 3000:
        relationship = "parent-child"
        confidence = 0.99
    elif total_cm > 2000:
        relationship = "full-sibling"
        confidence = 0.95
    elif total_cm > 1000:
        relationship = "half-sibling/grandparent"
        confidence = 0.9
    elif total_cm > 500:
        relationship = "1st cousin"
        confidence = 0.85
    elif total_cm > 250:
        relationship = "2nd cousin"
        confidence = 0.8
    elif total_cm > 100:
        relationship = "3rd cousin"
        confidence = 0.7
    else:
        relationship = "distant"
        confidence = 0.5
    
    return {
        "relationship": relationship,
        "confidence": confidence,
        "total_cm": total_cm,
        "segments": len(ibd_segments)
    }

# Create a version with disk caching
@disk_cache(cache_dir=example_cache_dir)
def infer_relationship_cached(ibd_segments, iid1, iid2):
    """Cached version of the relationship inference function."""
    # Normalize the segments to a tuple of tuples before delegating.
    # (The decorator has already built the cache key from the raw arguments.)
    segments_hashable = tuple((start, end, cM) for start, end, cM in ibd_segments)
    
    # Call the original function
    return infer_relationship(segments_hashable, iid1, iid2)

# Generate some random IBD segments
def generate_random_ibd_segments(num_segments=5, relationship="distant"):
    """Generate random IBD segments consistent with a given relationship."""
    segments = []
    
    # Set cM range based on relationship
    if relationship == "parent-child":
        cm_range = (100, 200)
    elif relationship == "full-sibling":
        cm_range = (75, 150)
    elif relationship == "half-sibling":
        cm_range = (50, 100)
    elif relationship == "1st cousin":
        cm_range = (20, 50)
    elif relationship == "2nd cousin":
        cm_range = (10, 30)
    else:  # distant
        cm_range = (5, 15)
    
    # Generate segments
    for _ in range(num_segments):
        start = random.randint(1, 250000000)
        end = start + random.randint(5000000, 20000000)
        cM = random.uniform(cm_range[0], cm_range[1])
        segments.append((start, end, cM))
    
    return segments

# Benchmark the cached vs non-cached version
def benchmark_persistent_cache(num_pairs=5, num_repeats=3):
    """
    Benchmark persistent caching for relationship inference.
    
    Args:
        num_pairs: Number of distinct individual pairs to test
        num_repeats: Number of times to repeat each inference
        
    Returns:
        Dict with benchmark results
    """
    # Clear the cache
    infer_relationship_cached.cache_clear()
    
    # Generate test data
    test_data = []
    for i in range(num_pairs):
        relationship = random.choice(["parent-child", "full-sibling", "half-sibling", "1st cousin", "2nd cousin", "distant"])
        segments = generate_random_ibd_segments(num_segments=random.randint(3, 10), relationship=relationship)
        iid1 = f"ind_{i*2}"
        iid2 = f"ind_{i*2+1}"
        test_data.append((segments, iid1, iid2))
    
    # Test without caching
    start_time = time.time()
    non_cached_results = []
    
    for _ in range(num_repeats):
        for segments, iid1, iid2 in test_data:
            result = infer_relationship(segments, iid1, iid2)
            non_cached_results.append(result)
    
    non_cached_time = time.time() - start_time
    
    # Test with caching
    start_time = time.time()
    cached_results = []
    
    for _ in range(num_repeats):
        for segments, iid1, iid2 in test_data:
            result = infer_relationship_cached(segments, iid1, iid2)
            cached_results.append(result)
    
    cached_time = time.time() - start_time
    
    # Calculate speedup
    speedup = non_cached_time / cached_time if cached_time > 0 else float('inf')
    
    # Get cache info
    cache_info = infer_relationship_cached.cache_info()
    
    return {
        "num_pairs": num_pairs,
        "num_repeats": num_repeats,
        "non_cached_time": non_cached_time,
        "cached_time": cached_time,
        "speedup": speedup,
        "cache_info": cache_info
    }

# Run the benchmark and visualize results
benchmark_results = benchmark_persistent_cache(num_pairs=5, num_repeats=3)

print("Persistent Cache Benchmark Results:")
print(f"Number of pairs: {benchmark_results['num_pairs']}")
print(f"Number of repeats: {benchmark_results['num_repeats']}")
print(f"Non-cached time: {benchmark_results['non_cached_time']:.3f} seconds")
print(f"Cached time: {benchmark_results['cached_time']:.3f} seconds")
print(f"Speedup: {benchmark_results['speedup']:.2f}x")
print(f"Cache hits: {benchmark_results['cache_info']['hits']}")
print(f"Cache misses: {benchmark_results['cache_info']['misses']}")
print(f"Cache hit ratio: {benchmark_results['cache_info']['hit_ratio']:.2%}")
print(f"Cache items: {benchmark_results['cache_info']['items']}")
print(f"Cache size: {benchmark_results['cache_info']['size_bytes'] / 1024:.2f} KB")

# Create a bar chart comparing performance
plt.figure(figsize=(10, 6))
labels = ['Without Cache', 'With Disk Cache']
times = [benchmark_results['non_cached_time'], benchmark_results['cached_time']]
colors = ['#ff9999', '#66b3ff']

plt.bar(labels, times, color=colors)
plt.ylabel('Time (seconds)')
plt.title('Persistent Caching Performance')
plt.grid(axis='y', linestyle='--', alpha=0.7)

# Add speed-up annotation
plt.annotate(f"Speedup: {benchmark_results['speedup']:.2f}x",
             xy=(0.5, max(times) * 0.7),
             ha='center',
             bbox=dict(boxstyle="round,pad=0.3", fc="#d5f5e3", ec="black", alpha=0.8))

# Apply tight_layout before adding the inset, because manually positioned
# axes are not compatible with tight_layout
plt.tight_layout()

# Create a small pie chart showing cache hits vs misses
pie_ax = plt.axes([0.7, 0.5, 0.2, 0.2])  # Position the pie chart inside the main figure
hits = benchmark_results['cache_info']['hits']
misses = benchmark_results['cache_info']['misses']
pie_ax.pie([hits, misses], labels=['Hits', 'Misses'], colors=['#66b3ff', '#ff9999'],
          autopct='%1.1f%%', startangle=90)
pie_ax.set_title('Cache Hits vs Misses')

plt.show()
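
The benchmark above clears the cache first, so it measures reuse within a single session only. The main benefit of a disk cache appears when a later session finds results left by an earlier one. A minimal self-contained sketch of that effect follows. `TinyDiskCache` and `slow_square` are hypothetical stand-ins written for this example, not the `DiskCache` class above, and a temporary directory substitutes for `RESULTS_DIR`.

```python
import hashlib
import json
import os
import tempfile


class TinyDiskCache:
    """Hypothetical minimal JSON-file cache, one file per key."""

    def __init__(self, cache_dir):
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)

    def _path(self, key):
        digest = hashlib.md5(str(key).encode()).hexdigest()
        return os.path.join(self.cache_dir, f"{digest}.json")

    def get(self, key):
        path = self._path(key)
        if not os.path.exists(path):
            return None
        with open(path, "r") as f:
            return json.load(f)

    def put(self, key, value):
        with open(self._path(key), "w") as f:
            json.dump(value, f)


calls = []  # records every real (non-cached) computation


def slow_square(x, cache):
    """Compute x * x, consulting the cache first."""
    cached = cache.get(("slow_square", x))
    if cached is not None:
        return cached
    calls.append(x)
    result = x * x
    cache.put(("slow_square", x), result)
    return result


with tempfile.TemporaryDirectory() as tmp_dir:
    # "Session 1": the result is computed and written to disk
    session1 = TinyDiskCache(tmp_dir)
    a = slow_square(12, session1)

    # "Session 2": a fresh object with no in-memory state reuses the file
    session2 = TinyDiskCache(tmp_dir)
    b = slow_square(12, session2)

print(a, b, calls)
```

The second cache object starts with no state of its own, which approximates a new notebook session pointed at the same cache directory.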

## Part 4: Hierarchical Caching

The most sophisticated strategy covered in this lab is hierarchical caching, which Bonsai v3 uses to combine multiple caching mechanisms for different types of operations. Let's implement a simple hierarchical cache:

In [ ]:
class HierarchicalCache:
    """
    A hierarchical cache that combines multiple cache levels.
    
    This cache checks multiple levels in order, starting with the fastest.
    When a value is found, it's stored in all faster caches.
    
    Args:
        caches: List of cache objects in order from fastest to slowest
    """
    def __init__(self, caches):
        self.caches = caches
    
    def get(self, key):
        """
        Get a value from the cache, checking each level in order.
        
        Args:
            key: Cache key
            
        Returns:
            Cached value or None if not found
        """
        # Check each cache level
        for i, cache in enumerate(self.caches):
            value = cache.get(key)
            if value is not None:
                # Store the value in all faster caches
                for j in range(i):
                    self.caches[j].put(key, value)
                return value
        
        return None
    
    def put(self, key, value):
        """
        Store a value in all cache levels.
        
        Args:
            key: Cache key
            value: Value to store
        """
        for cache in self.caches:
            cache.put(key, value)
    
    def clear(self):
        """Clear all cache levels."""
        for cache in self.caches:
            cache.clear()
    
    def info(self):
        """Return information about all cache levels."""
        return {f"level_{i}": cache.info() for i, cache in enumerate(self.caches)}

# Create a hierarchical cache decorator combining in-memory LRU and disk caching
def hierarchical_cache(memory_size=128, cache_dir=None):
    """
    Create a decorator for hierarchical caching.
    
    Args:
        memory_size: Size of the in-memory LRU cache
        cache_dir: Directory for the disk cache
        
    Returns:
        Decorator function
    """
    memory_cache = LRUCache(max_size=memory_size)
    disk_cache_obj = DiskCache(cache_dir=cache_dir)
    hierarchical_cache_obj = HierarchicalCache([memory_cache, disk_cache_obj])
    
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Create a key from the function name and arguments.
            # A string is used because the in-memory level needs a hashable
            # key, and a dict is not hashable.
            key = repr((func.__name__, args, sorted(kwargs.items())))
            
            # Try to get from cache
            result = hierarchical_cache_obj.get(key)
            if result is not None:
                return result
            
            # Compute and store the result
            result = func(*args, **kwargs)
            hierarchical_cache_obj.put(key, result)
            return result
        
        # Add methods to access cache information
        def cache_info():
            return hierarchical_cache_obj.info()
        
        def cache_clear():
            hierarchical_cache_obj.clear()
        
        wrapper.cache_info = cache_info
        wrapper.cache_clear = cache_clear
        
        return wrapper
    
    return decorator
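
The promotion step in `HierarchicalCache.get` is easiest to see with two toy levels. The sketch below re-implements just that lookup logic with dictionary-backed levels. `DictLevel` and `two_level_get` are hypothetical names for this example; the real levels in this lab are `LRUCache` and `DiskCache`.

```python
class DictLevel:
    """Hypothetical cache level backed by a plain dict, counting lookups."""

    def __init__(self):
        self.store = {}
        self.lookups = 0

    def get(self, key):
        self.lookups += 1
        return self.store.get(key)

    def put(self, key, value):
        self.store[key] = value


def two_level_get(levels, key):
    """Check levels fastest-first; copy a found value into the faster levels."""
    for i, level in enumerate(levels):
        value = level.get(key)
        if value is not None:
            for faster in levels[:i]:
                faster.put(key, value)
            return value
    return None


fast, slow = DictLevel(), DictLevel()
slow.put("pair_1_2", "full-sibling")  # value present only in the slow level

first = two_level_get([fast, slow], "pair_1_2")   # misses fast, found in slow
second = two_level_get([fast, slow], "pair_1_2")  # now served by fast

print(first, second, fast.lookups, slow.lookups)
```

After the first lookup the value lives in both levels, so the second lookup never touches the slow level. With a disk-backed slow level, that is the difference between a dictionary access and a file read.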

### 4.1 Benchmarking Hierarchical Caching

Let's benchmark our hierarchical cache to see how it compares to single-level caching approaches:

In [ ]:
# Create a directory for hierarchical cache testing
hierarchical_cache_dir = os.path.join(RESULTS_DIR, 'hierarchical_cache')
os.makedirs(hierarchical_cache_dir, exist_ok=True)

# Create different cached versions of our relationship inference function
@lru_cache(maxsize=50)
def infer_relationship_lru(ibd_segments, iid1, iid2):
    """LRU-cached version of relationship inference."""
    segments_hashable = tuple((start, end, cM) for start, end, cM in ibd_segments)
    return infer_relationship(segments_hashable, iid1, iid2)

@disk_cache(cache_dir=hierarchical_cache_dir)
def infer_relationship_disk(ibd_segments, iid1, iid2):
    """Disk-cached version of relationship inference."""
    segments_hashable = tuple((start, end, cM) for start, end, cM in ibd_segments)
    return infer_relationship(segments_hashable, iid1, iid2)

@hierarchical_cache(memory_size=50, cache_dir=hierarchical_cache_dir)
def infer_relationship_hierarchical(ibd_segments, iid1, iid2):
    """Hierarchically-cached version of relationship inference."""
    segments_hashable = tuple((start, end, cM) for start, end, cM in ibd_segments)
    return infer_relationship(segments_hashable, iid1, iid2)

# Create a benchmark function
def benchmark_caching_strategies(num_pairs=20, num_repeats=3):
    """
    Benchmark different caching strategies.
    
    Args:
        num_pairs: Number of distinct individual pairs to test
        num_repeats: Number of times to repeat each inference
        
    Returns:
        Dict with benchmark results
    """
    # Clear all caches
    infer_relationship_lru.cache_clear()
    infer_relationship_disk.cache_clear()
    infer_relationship_hierarchical.cache_clear()
    
    # Generate test data
    test_data = []
    for i in range(num_pairs):
        relationship = random.choice(["parent-child", "full-sibling", "half-sibling", "1st cousin", "2nd cousin", "distant"])
        # Store the segments as a tuple so they are hashable for the in-memory LRU cache
        segments = tuple(generate_random_ibd_segments(num_segments=random.randint(3, 10), relationship=relationship))
        iid1 = f"ind_{i*2}"
        iid2 = f"ind_{i*2+1}"
        test_data.append((segments, iid1, iid2))
    
    # Function to test a specific caching strategy
    def test_strategy(inference_func, name):
        start_time = time.time()
        results = []
        
        for _ in range(num_repeats):
            for segments, iid1, iid2 in test_data:
                result = inference_func(segments, iid1, iid2)
                results.append(result)
        
        elapsed_time = time.time() - start_time
        
        return {
            "name": name,
            "time": elapsed_time
        }
    
    # Test each strategy
    no_cache_result = test_strategy(infer_relationship, "No Cache")
    lru_result = test_strategy(infer_relationship_lru, "LRU Cache")
    disk_result = test_strategy(infer_relationship_disk, "Disk Cache")
    hierarchical_result = test_strategy(infer_relationship_hierarchical, "Hierarchical Cache")
    
    # Calculate speedups
    baseline_time = no_cache_result["time"]
    lru_speedup = baseline_time / lru_result["time"] if lru_result["time"] > 0 else float('inf')
    disk_speedup = baseline_time / disk_result["time"] if disk_result["time"] > 0 else float('inf')
    hierarchical_speedup = baseline_time / hierarchical_result["time"] if hierarchical_result["time"] > 0 else float('inf')
    
    return {
        "num_pairs": num_pairs,
        "num_repeats": num_repeats,
        "no_cache": no_cache_result,
        "lru_cache": lru_result,
        "disk_cache": disk_result,
        "hierarchical_cache": hierarchical_result,
        "lru_speedup": lru_speedup,
        "disk_speedup": disk_speedup,
        "hierarchical_speedup": hierarchical_speedup
    }

# Run the benchmark
print("Running caching strategies benchmark...")
benchmark_results = benchmark_caching_strategies(num_pairs=20, num_repeats=3)

# Display results
print("\nCaching Strategies Benchmark Results:")
print(f"Number of pairs: {benchmark_results['num_pairs']}")
print(f"Number of repeats: {benchmark_results['num_repeats']}")
print("\nExecution times:")
print(f"No Cache: {benchmark_results['no_cache']['time']:.3f} seconds")
print(f"LRU Cache: {benchmark_results['lru_cache']['time']:.3f} seconds (speedup: {benchmark_results['lru_speedup']:.2f}x)")
print(f"Disk Cache: {benchmark_results['disk_cache']['time']:.3f} seconds (speedup: {benchmark_results['disk_speedup']:.2f}x)")
print(f"Hierarchical Cache: {benchmark_results['hierarchical_cache']['time']:.3f} seconds (speedup: {benchmark_results['hierarchical_speedup']:.2f}x)")

# Create a visualization of the results
plt.figure(figsize=(12, 8))

# Bar chart of execution times
labels = ['No Cache', 'LRU Cache', 'Disk Cache', 'Hierarchical Cache']
times = [
    benchmark_results['no_cache']['time'],
    benchmark_results['lru_cache']['time'],
    benchmark_results['disk_cache']['time'],
    benchmark_results['hierarchical_cache']['time']
]
colors = ['#ff9999', '#66b3ff', '#99ff99', '#ffcc99']

bar_plot = plt.bar(labels, times, color=colors)
plt.ylabel('Time (seconds)')
plt.title('Caching Strategy Performance Comparison')
plt.grid(axis='y', linestyle='--', alpha=0.7)

# Add speedup annotations
for i, (bar, time_val, name) in enumerate(zip(bar_plot, times, labels)):
    if i > 0:  # Skip "No Cache"
        # benchmark_results is a dict; map e.g. "LRU Cache" -> "lru_speedup"
        speedup_key = f"{name.lower().replace(' cache', '')}_speedup"
        speedup = benchmark_results[speedup_key]
        plt.text(bar.get_x() + bar.get_width()/2., time_val + 0.1,
                 f'{speedup:.2f}x',
                 ha='center', va='bottom',
                 bbox=dict(boxstyle="round,pad=0.2", fc="#d5f5e3", ec="black", alpha=0.8))

plt.tight_layout()
plt.show()

# Create a secondary plot showing relative performance
plt.figure(figsize=(10, 6))
baseline_time = benchmark_results['no_cache']['time']
relative_times = [t / baseline_time * 100 for t in times]

plt.bar(labels, relative_times, color=colors)
plt.ylabel('Relative Time (%)')
plt.title('Relative Performance of Caching Strategies')
plt.grid(axis='y', linestyle='--', alpha=0.7)
plt.axhline(y=100, color='r', linestyle='--')
plt.ylim(0, 110)

for i, (rel_time, label) in enumerate(zip(relative_times, labels)):
    plt.text(i, rel_time + 2, f'{rel_time:.1f}%', ha='center')

plt.tight_layout()
plt.show()
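
The benchmarks in this lab time each strategy once with `time.time()`. For short durations, the standard library's `time.perf_counter()` is the clock intended for interval measurement. A small sketch of a timing helper built on it follows; `time_call` is a hypothetical name, not part of the lab's code.

```python
import time


def time_call(func, *args, repeats=3, **kwargs):
    """Return (best_elapsed_seconds, last_result) over `repeats` calls."""
    best = float("inf")
    result = None
    for _ in range(repeats):
        start = time.perf_counter()
        result = func(*args, **kwargs)
        best = min(best, time.perf_counter() - start)
    return best, result


elapsed, value = time_call(sum, range(100_000))
print(f"best of 3: {elapsed:.6f} s, result = {value}")
```

For a cached function, repeated calls with the same arguments measure hit time, so the cache should be cleared between repeats when the cold (miss) time is the quantity of interest.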

## Summary

In this lab, we explored the caching mechanisms used in Bonsai v3 to improve computational efficiency:

1. **Memoization**: Simple in-memory caching that stores function results for reuse. It provides significant speedups for recursive and repetitive calculations but can lead to unbounded memory usage.

2. **LRU Cache**: Fixed-size cache that evicts least recently used items when full. It balances memory usage with performance, providing good caching benefits even with limited memory.

3. **Persistent Caching**: Stores computation results to disk, allowing them to persist between runs. This is especially valuable for long-running pedigree reconstruction tasks that may be split across multiple sessions.

4. **Hierarchical Caching**: Combines multiple caching strategies for optimal performance. Fast in-memory caches provide quick access to recently used items, while slower persistent caches maintain a larger history.

By implementing these caching strategies effectively, Bonsai v3 can dramatically improve performance in pedigree reconstruction tasks, especially for large pedigrees with many relationship calculations.
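
The eviction behavior that distinguishes an LRU cache from unbounded memoization can be seen in a few calls. This sketch uses the standard library's `functools.lru_cache` with a deliberately tiny capacity; `kinship_lookup` is a hypothetical toy function, and the lab's own LRU implementation may report statistics differently.

```python
from functools import lru_cache


@lru_cache(maxsize=2)
def kinship_lookup(pair_id):
    """Toy stand-in for an expensive per-pair computation."""
    return pair_id * 2


kinship_lookup(1)  # miss: cache holds {1}
kinship_lookup(2)  # miss: cache holds {1, 2}
kinship_lookup(1)  # hit: 1 becomes the most recently used entry
kinship_lookup(3)  # miss: cache is full, so 2 (least recently used) is evicted
kinship_lookup(1)  # hit: 1 is still cached
kinship_lookup(2)  # miss: 2 was evicted and must be recomputed

info = kinship_lookup.cache_info()
print(info.hits, info.misses, info.currsize)
```

The cache never grows beyond two entries, at the cost of recomputing a result that was evicted. An unbounded memoization table would have kept all three results.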

### Key Takeaways

- Caching is particularly effective for recursive operations like finding ancestors or calculating relationships
- Different caching strategies have different strengths and weaknesses:
  - Memoization: Simple but potentially memory-hungry
  - LRU Cache: Memory-efficient but limited capacity
  - Disk Cache: Persistent but slower access
  - Hierarchical Cache: Combines fast in-memory access with persistence but is the most complex
- The right caching strategy depends on the specific task and available resources
- Bonsai v3 uses all these strategies in different contexts to optimize performance

### Related Concepts

In the next lab, we'll explore error handling and data validation mechanisms in Bonsai v3, which ensure robust performance even with imperfect data inputs.