In [1]:
#| default_exp workflow

In [2]:
#|export
import jax
import jax.numpy as jnp

# Assuming 'chewc' is installed or in the python path
from chewc.structs import quick_haplo, add_trait


In [3]:
import jax
import jax.numpy as jnp
from chewc.structs import quick_haplo, add_trait
from chewc.pheno import calculate_phenotypes

def main(
    n_individuals: int = 2000,
    n_chromosomes: int = 10,
    n_loci_per_chr: int = 1000,
    n_qtl_per_chr: int = 50,
    seed: int = 42,
) -> None:
    """
    Build a two-trait ChewC scenario end to end and print verification checks.

    The function generates a founder population with `quick_haplo`, attaches
    a two-trait architecture with `add_trait`, simulates phenotypes with
    `calculate_phenotypes`, and then prints the realized heritabilities,
    genetic correlation and phenotypic correlation next to their targets.

    Args:
        n_individuals: Number of founders to simulate. Must be at least 2,
            otherwise the variances and correlations below are undefined.
            The default (2000) is large enough for the realized statistics
            to land reasonably close to the targets.
        n_chromosomes: Number of chromosomes passed to `quick_haplo`.
        n_loci_per_chr: Segregating sites per chromosome (`seg_sites`).
        n_qtl_per_chr: QTL per chromosome passed to `add_trait`.
        seed: Seed for the root JAX PRNG key; every random draw in the
            scenario derives from it.

    Raises:
        ValueError: If `n_individuals` is smaller than 2.
    """
    # Validate before doing any work: with a single individual jnp.var is 0
    # and the h² ratio / correlation coefficients come out as NaN.
    if n_individuals < 2:
        raise ValueError(
            f"n_individuals must be >= 2 to compute variances and correlations, "
            f"got {n_individuals}"
        )

    print("--- Setting up and Verifying ChewC simulation scenario ---")
    print(f"\nIncreased scale: Simulating {n_individuals} individuals.")

    # --- Setup: one independent PRNG key per stochastic stage ---
    key = jax.random.PRNGKey(seed)
    pop_key, trait_key, pheno_key = jax.random.split(key, 3)

    # --- Step 1: founder population ---
    print("\nStep 1: Generating founder population...")
    # quick_haplo returns a (population, genetic map) pair; the map is not
    # needed for these checks.
    founder_pop, _ = quick_haplo(
        key=pop_key, n_ind=n_individuals, n_chr=n_chromosomes, seg_sites=n_loci_per_chr
    )
    print(f"  - Population generated. Genotype shape: {founder_pop.geno.shape}")

    # --- Step 2: correlated two-trait architecture ---
    print("\nStep 2: Generating a two-trait architecture...")
    target_means = jnp.array([100.0, 50.0])
    target_vars = jnp.array([10.0, 2.0])
    target_genetic_corr = -0.3

    # Covariance implied by the target correlation: cov = r * sd_1 * sd_2.
    cov = target_genetic_corr * jnp.sqrt(target_vars[0] * target_vars[1])
    sigma = jnp.array([[target_vars[0], cov], [cov, target_vars[1]]])

    trait_architecture = add_trait(
        key=trait_key,
        founder_pop=founder_pop,
        n_qtl_per_chr=n_qtl_per_chr,
        mean=target_means,
        var=target_vars,
        sigma=sigma,
    )
    print("  - Trait architecture created.")

    # --- Step 3: phenotypes for the founders ---
    print("\nStep 3: Calculating phenotypes for the founder population...")
    heritabilities = jnp.array([0.4, 0.7])
    print(f"  - Target heritabilities (h²): {heritabilities}")

    phenotypes, tbvs = calculate_phenotypes(
        key=pheno_key,
        population=founder_pop,
        trait=trait_architecture,
        heritability=heritabilities
    )
    print("  - Phenotypes and TBVs calculated.")

    # --- Verification ---
    print("\n--- Verification Checks ---")

    # Check 1: Realized Heritability (h² = V_a / V_p), one value per trait.
    # assumes tbvs/phenotypes are laid out as (individuals, traits) — TODO confirm
    realized_h2 = jnp.var(tbvs, axis=0) / jnp.var(phenotypes, axis=0)
    # Format element-wise: printing a rounded float32 array still shows
    # representation noise such as 0.39200002.
    realized_h2_text = ", ".join(f"{float(value):.3f}" for value in realized_h2)
    print("\n[Check 1: Heritability]")
    print(f"  - Target h²:   {heritabilities}")
    print(f"  - Realized h²: [{realized_h2_text}]")

    # Check 2: Genetic Correlation (correlation between the traits' TBVs).
    realized_genetic_corr_matrix = jnp.corrcoef(tbvs, rowvar=False)
    print("\n[Check 2: Genetic Correlation]")
    print(f"  - Target:   {target_genetic_corr:.3f}")
    print(f"  - Realized: {realized_genetic_corr_matrix[0, 1]:.3f}")

    # Check 3: Phenotypic Correlation
    # The expected phenotypic correlation is r_p = r_g * h_1 * h_2, where
    # h_i = sqrt(h²_i).
    # NOTE(review): this omits the environmental-correlation term, i.e. it
    # assumes calculate_phenotypes draws uncorrelated noise per trait — confirm.
    h_1 = jnp.sqrt(heritabilities[0])
    h_2 = jnp.sqrt(heritabilities[1])
    expected_phenotypic_corr = target_genetic_corr * h_1 * h_2
    realized_phenotypic_corr_matrix = jnp.corrcoef(phenotypes, rowvar=False)
    print("\n[Check 3: Phenotypic Correlation]")
    print(f"  - Expected: {expected_phenotypic_corr:.3f}")
    print(f"  - Realized: {realized_phenotypic_corr_matrix[0, 1]:.3f}")

    print("\n--- Verification complete! ---")


# Run the demo when this cell/script is executed directly; importing the
# code elsewhere defines main() without triggering the simulation.
if __name__ == "__main__":
    main()




--- Setting up and Verifying ChewC simulation scenario ---

Increased scale: Simulating 2000 individuals.

Step 1: Generating founder population...
  - Population generated. Genotype shape: (2000, 10, 2, 1000)

Step 2: Generating a two-trait architecture...
  - Trait architecture created.

Step 3: Calculating phenotypes for the founder population...
  - Target heritabilities (h²): [0.4 0.7]
  - Phenotypes and TBVs calculated.

--- Verification Checks ---

[Check 1: Heritability]
  - Target h²:   [0.4 0.7]
  - Realized h²: [0.39200002 0.666     ]

[Check 2: Genetic Correlation]
  - Target:   -0.300
  - Realized: -0.329

[Check 3: Phenotypic Correlation]
  - Expected: -0.159
  - Realized: -0.138

--- Verification complete! ---


In [4]:
#| hide
# Run nbdev's export step from inside the notebook. Per the nbdev docs this
# writes `#|export` cells to the module named by `#| default_exp`.
import nbdev; nbdev.nbdev_export()

Note nbdev2 no longer supports nbdev1 syntax. Run `nbdev_migrate` to upgrade.
See https://nbdev.fast.ai/getting_started.html for more information.
  warn(f"Notebook '{nbname}' uses `#|export` without `#|default_exp` cell.\n"
Note nbdev2 no longer supports nbdev1 syntax. Run `nbdev_migrate` to upgrade.
See https://nbdev.fast.ai/getting_started.html for more information.
  warn(f"Notebook '{nbname}' uses `#|export` without `#|default_exp` cell.\n"
Note nbdev2 no longer supports nbdev1 syntax. Run `nbdev_migrate` to upgrade.
See https://nbdev.fast.ai/getting_started.html for more information.
  warn(f"Notebook '{nbname}' uses `#|export` without `#|default_exp` cell.\n"
