In [1]:
#| default_exp workflow

In [2]:
import msprime
import sys

print(f"msprime version being used: {msprime.__version__}")
print(f"Python executable: {sys.executable}")

msprime version being used: 1.3.4
Python executable: /bin/python


In [3]:
#|export

import jax
from chewc.config import SimConfig
from chewc.structs import SimParam
from chewc.popgen import quick_haplo
from chewc.state import init_state_from_founders
import jax.numpy as jnp


# 1. Define static simulation configuration
config = SimConfig(
    n_chr=10,
    ploidy=2,
    max_pop_size=1000,
    n_loci_per_chr=100,
    n_generations=50,
    n_select=50,
    population_size=101  # invalid value, not multiple
)

# 2. Create a master random key
key = jax.random.PRNGKey(42)
key, founder_key = jax.random.split(key)

# 3. Generate the founder population (runs on host)
founder_pop, genetic_map = quick_haplo(
    key=founder_key,
    n_ind=100,
    n_chr=config.n_chr,
    n_loci_per_chr=config.n_loci_per_chr,
    max_pop_size=config.max_pop_size,
    ploidy=config.ploidy
)

# 4. Define simulation parameters (e.g., genetic maps, trait info)
sp = SimParam(
    gen_map=genetic_map,
    ploidy=config.ploidy,
    # traits, var_e, etc. would be defined here
)

# 5. Initialize the dynamic simulation state (the carry for lax.scan)
# This uses your existing `init_state_from_founders` function.
key, state_key = jax.random.split(key)
initial_state = init_state_from_founders(
    key=state_key,
    founder_pop=founder_pop,
    sp=sp,
    config=config
)

print(f"Initial write position: {initial_state.write_pos}")
print(f"Next available ID: {initial_state.next_id}")
print(f"Founder population active: {jnp.sum(initial_state.is_active)}")

# `initial_state` is now ready to be passed to your JIT'd `generation_step`
# within a `lax.scan`.



Initial write position: 101
Next available ID: 100
Founder population active: 100




In [8]:
# | export

import jax
import jax.numpy as jnp
from chewc.config import SimConfig
from chewc.structs import SimParam
from chewc.popgen import quick_haplo
from chewc.state import init_state_from_founders
# Import the new trait module functions
from chewc.trait import add_trait_a, set_pheno_h2


# 1. Define static simulation configuration
config = SimConfig(
    n_chr=10,
    ploidy=2,
    max_pop_size=1000,
    n_loci_per_chr=100,
    n_generations=50,
    n_select=50,
    population_size=100
)

# 2. Create and split master random key
key = jax.random.PRNGKey(42)
key, founder_key, trait_key, pheno_key, state_key = jax.random.split(key, 5)

# 3. Generate the founder population
founder_pop, genetic_map = quick_haplo(
    key=founder_key,
    n_ind=100,
    n_chr=config.n_chr,
    n_loci_per_chr=config.n_loci_per_chr,
    max_pop_size=config.max_pop_size,
    ploidy=config.ploidy
)

# 4. Define initial simulation parameters
sp = SimParam(
    gen_map=genetic_map,
    ploidy=config.ploidy,
)

# 5. Add an additive trait to the simulation parameters
print("--- Defining Trait ---")
sp = add_trait_a(
    key=trait_key,
    founder_pop=founder_pop,
    sim_param=sp,
    n_qtl_per_chr=10,
    mean=jnp.array([0.0]),
    var=jnp.array([1.0]) # Target *genetic* variance
)
print(f"Trait added with {sp.traits.n_loci} QTLs.")

# 6. Set phenotypes using a target narrow-sense heritability (h2)
print("\n--- Setting Initial Phenotypes with h2 ---")
# Let's target a narrow-sense heritability of 0.4
h2 = 0.4
founder_pop_with_pheno = set_pheno_h2(
    key=pheno_key,
    pop=founder_pop,
    sp=sp,
    h2=h2
)

# 7. Initialize the dynamic simulation state
print("\n--- Initializing Simulation State ---")
initial_state = init_state_from_founders(
    key=state_key,
    founder_pop=founder_pop_with_pheno,
    sp=sp,
    config=config
)

print(f"Initial write position: {initial_state.write_pos}")
print(f"Next available ID: {initial_state.next_id}")
print(f"Founder population active: {jnp.sum(initial_state.is_active)}")

# 8. Verification
print("\n--- Verification ---")
active_mask = initial_state.is_active
var_a = jnp.var(initial_state.bv[active_mask])
var_p = jnp.var(initial_state.pheno[active_mask])
realized_h2 = var_a / var_p

print(f"Target narrow-sense heritability (h2): {h2:.4f}")
print(f"Realized additive variance (VarA) in founders: {var_a:.4f}")
print(f"Realized phenotypic variance (VarP) in founders: {var_p:.4f}")
print(f"Realized narrow-sense heritability (h2) in founders: {realized_h2:.4f}")

# `initial_state` is now ready for the simulation loop.

--- Defining Trait ---
Trait added with 100 QTLs.

--- Setting Initial Phenotypes with h2 ---

--- Initializing Simulation State ---
Initial write position: 100
Next available ID: 100
Founder population active: 100

--- Verification ---
Target narrow-sense heritability (h2): 0.4000
Realized additive variance (VarA) in founders: 1.0000
Realized phenotypic variance (VarP) in founders: 2.6278
Realized narrow-sense heritability (h2) in founders: 0.3805


In [7]:
#| hide
import nbdev; nbdev.nbdev_export()

SyntaxError: closing parenthesis ']' does not match opening parenthesis '(' (<unknown>, line 49)