In [1]:
import jax
import jax.numpy as jnp

# Assuming these are the locations of your modules
from chewc.population import msprime_pop
from chewc.sp import SimParam
from chewc.trait import add_trait_a, _calculate_gvs_vectorized_alternative
from chewc.phenotype import set_pheno
import numpy as np

# Seed a root PRNG key and derive independent sub-keys: one for simulating
# the founders and one for sampling the trait effects.
key = jax.random.PRNGKey(42)
key, pop_key, trait_key = jax.random.split(key, 3)

# 1. Generate the founder population and its genetic map together.
founder_pop, genetic_map = msprime_pop(
    key=pop_key, n_ind=500, n_loci_per_chr=500, n_chr=3, ploidy=2
)

# 2. Use the founder population and its map to configure the simulation's rules.
SP = SimParam.from_founder_pop(
    founder_pop=founder_pop,
    gen_map=genetic_map,
    sexes="no"
)

# 3. Add two additive traits with a negative genetic correlation.
# Per-trait target means and variances.
trait_means = jnp.array([0.0, 0.0])
trait_vars = jnp.array([1.0, 1.0])
# The off-diagonal entries carry the correlation between the two traits; they
# must be negative for the traits to be negatively correlated.
neg_cor_matrix = jnp.array([[1.0, -0.6],
                            [-0.6, 1.0]])

# Register the traits on the simulation parameters. add_trait_a returns the
# updated SimParam, so SP is rebound to the result.
SP = add_trait_a(
    key=trait_key,
    founder_pop=founder_pop,
    sim_param=SP,
    n_qtl_per_chr=500,
    mean=trait_means,
    var=trait_vars,
    cor_a=neg_cor_matrix
)




In [2]:
# Display the founder population's repr as a quick sanity check.
founder_pop

Population(nInd=500, nTraits=0, has_ebv=No)

In [8]:

# Derive the phenotyping keys from a seed that is not used anywhere else.
# Seeding with 42 and splitting three ways (exactly what the setup cell does)
# would hand set_pheno the very same sub-keys that were already consumed for
# the founder simulation and the trait effects; a JAX key must only be used once.
key = jax.random.PRNGKey(43)
key, h2_key, varE_key = jax.random.split(key, 3)

# --- Example 1: Phenotyping with Heritability (h2) ---

# Narrow-sense heritability target, one entry per trait.
heritabilities = jnp.array([0.3, 0.6])

# Phenotype the founders using the h2 argument; the result is a new
# population object bound to pop_with_h2.
pop_with_h2 = set_pheno(
    key=h2_key,
    pop=founder_pop,
    traits=SP.traits,
    ploidy=SP.ploidy,
    h2=heritabilities
)

# --- Verification (Optional) ---
# Recompute breeding/genetic values for the founders directly from the traits.
bvs, gvs = _calculate_gvs_vectorized_alternative(founder_pop, SP.traits, SP.ploidy)

# Second trait (column 1): the traits were configured with mean 0 and
# variance 1, so these should print approximately 0 and 1.
print(np.mean(gvs[:, 1]))
print(np.var(gvs[:, 1]))

# The genetic values stored on the phenotyped population should agree with
# the directly computed ones above.
print(np.mean(pop_with_h2.gv[:, 1]))
print(np.var(pop_with_h2.gv[:, 1]))


4.0054324e-08
1.0
4.0054324e-08
1.0


In [7]:
# Shape of the founders' genetic-value matrix. `gvs` is already the unpacked
# array (the helper returns a (bvs, gvs) tuple), so indexing it with [0] would
# report the shape of a single row rather than of the whole matrix.
gvs.shape

(500, 2)

In [None]:
# Inspect the trait configuration stored on the simulation parameters.
SP.traits

In [None]:
# --- Verification against the values stored on the phenotyped population ---

# All checks below look at the second trait (column index 1).
trait_idx = 1

# Breeding values kept on the phenotyped population.
stored_bv_mean = np.mean(pop_with_h2.bv[:, trait_idx])
print("Mean of stored BVs:", stored_bv_mean)

# Genetic values rebuilt from the final object as stored BV + trait intercept.
final_gvs = pop_with_h2.bv + SP.traits.intercept
final_gv_mean = np.mean(final_gvs[:, trait_idx])
final_gv_var = np.var(final_gvs[:, trait_idx])

print("Mean of calculated GVs from final pop:", final_gv_mean)
print("Variance of calculated GVs from final pop:", final_gv_var)

# --- Side-by-side with a fresh computation on the founders ---
# The helper returns a (breeding values, genetic values) tuple.
initial_bvs, initial_gvs = _calculate_gvs_vectorized_alternative(
    founder_pop, SP.traits, SP.ploidy
)
initial_gv_mean = np.mean(initial_gvs[:, trait_idx])

print("\n--- Comparison ---")
print("Initial GV Mean:", initial_gv_mean)
print("Final GV Mean:  ", final_gv_mean)

In [None]:
import jax
import jax.numpy as jnp

# Assuming the previous setup code has been run to create founder_pop and SP
# founder_pop, SP, ...

from chewc.phenotype import set_pheno # Import the function

def _variance_summary(pop):
    """Return per-trait (var of pop.bv, var of pop.pheno, their ratio), taken over axis 0."""
    var_g = jnp.var(pop.bv, axis=0)
    var_p = jnp.var(pop.pheno, axis=0)
    return var_g, var_p, var_g / var_p


def _print_first_individuals(pop, n_show=5):
    """Print the BV and phenotype rows of the first `n_show` individuals of `pop`."""
    print(f"\n--- Phenotype Summary (first {n_show} individuals) ---")
    for i in range(n_show):
        print(f"Ind {i+1}: BV={pop.bv[i]}, Pheno={pop.pheno[i]}")


# Derive the two phenotyping keys from a seed that is not used anywhere else.
# Seeding with 42 and splitting three ways (exactly what the setup cell does)
# would reproduce the sub-keys already consumed for the founder simulation and
# the trait effects; a JAX key must only be used once.
key = jax.random.PRNGKey(43)
key, h2_key, varE_key = jax.random.split(key, 3)

# --- Example 1: Phenotyping with Heritability (h2) ---

# Narrow-sense heritability target, one entry per trait.
heritabilities = jnp.array([0.3, 0.6])

# Phenotype the founders using the h2 argument.
pop_with_h2 = set_pheno(
    key=h2_key,
    pop=founder_pop,
    traits=SP.traits,
    ploidy=SP.ploidy,
    h2=heritabilities
)

# Realized heritability = variance of stored BVs / variance of phenotypes.
var_g_h2, var_p_h2, realized_h2 = _variance_summary(pop_with_h2)

# --- Print h2 Results ---
print("--- Example 1: Phenotyping with Heritability (h2) ---")
print(f"\nTarget Heritabilities: {heritabilities}")
print(f"Realized Heritabilities: {realized_h2}")
print(f"\nGenetic Variance (from pop): {var_g_h2}")
print(f"Phenotypic Variance (from pop): {var_p_h2}")
_print_first_individuals(pop_with_h2)


print("\n" + "="*60 + "\n") # Separator


# --- Example 2: Phenotyping with Environmental Variance (varE) ---

# Absolute environmental variance, one entry per trait.
environmental_variances = jnp.array([23.0, 3.0])

# Phenotype the founders using the varE argument (independent key from Example 1).
pop_with_varE = set_pheno(
    key=varE_key,
    pop=founder_pop,
    traits=SP.traits,
    ploidy=SP.ploidy,
    varE=environmental_variances
)

# Here the heritability is an outcome of the chosen varE, not a target.
var_g_varE, var_p_varE, implied_h2 = _variance_summary(pop_with_varE)

# --- Print varE Results ---
print("--- Example 2: Phenotyping with Environmental Variance (varE) ---")
print(f"\nTarget Environmental Variances: {environmental_variances}")
print(f"Implied Heritabilities: {implied_h2}") # Note: this is a result, not a target
print(f"\nGenetic Variance (from pop): {var_g_varE}")
print(f"Phenotypic Variance (from pop): {var_p_varE}")
_print_first_individuals(pop_with_varE)

In [None]:
# Show the founder population object, for comparison with the phenotyped one.
founder_pop

In [None]:
# Show the population returned by set_pheno in the h2 example.
pop_with_h2

In [None]:
from chewc.cross import make_cross

In [None]:
import jax
import jax.numpy as jnp

# Assuming previous code has been run to create:
# - SP: The SimParam object
# - pop_with_h2: The phenotyped founder population
# - set_pheno: The phenotyping function

# Import the crossing function
from chewc.cross import make_cross

# Independent keys for meiosis (crossing) and for phenotyping the progeny.
key = jax.random.PRNGKey(91)
key, cross_key, progeny_pheno_key = jax.random.split(key, 3)

# --- 1. Define the Unique Crosses ---
# argsort is ascending, so the last five indices are the individuals with the
# highest phenotype for each trait. Row i pairs the i-th pick for trait 1 with
# the i-th pick for trait 2, giving 5 unique parent pairings.
top_5_trait1 = jnp.argsort(pop_with_h2.pheno[:, 0])[-5:]
top_5_trait2 = jnp.argsort(pop_with_h2.pheno[:, 1])[-5:]
unique_cross_plan = jnp.column_stack([top_5_trait1, top_5_trait2])

# --- 2. Create the Final, Repeated Crossing Plan ---
# Number of full-sib progeny produced per unique cross.
n_repeats = 100
# axis=0 repeats whole rows, so each pairing appears n_repeats times in a row.
cross_plan = jnp.repeat(unique_cross_plan, repeats=n_repeats, axis=0)


# --- 3. Execute the Crosses ---
# Start progeny IDs after the largest ID already present in the parents so the
# two generations cannot share an ID (a fixed start of 100 would fall inside
# the range of a 500-individual founder population).
# NOTE(review): assumes the parent population exposes integer `id`s like the
# progeny population does - confirm against chewc.population.
next_progeny_id = int(jnp.max(pop_with_h2.id)) + 1
progeny_pop = make_cross(
    key=cross_key,
    pop=pop_with_h2,
    cross_plan=cross_plan,
    sp=SP,
    next_id_start=next_progeny_id
)


# --- 4. Phenotype the New Progeny Generation ---
progeny_pop_phenotyped = set_pheno(
    key=progeny_pheno_key,
    pop=progeny_pop,
    traits=SP.traits,
    ploidy=SP.ploidy,
    h2=jnp.array([0.99, 0.99]) # Near-perfect heritability for both traits (not the founder values)
)


# --- 5. Print and Inspect the Results ---
print("--- Crossing Plan ---")
print(f"Created {unique_cross_plan.shape[0]} unique crosses and repeated each {n_repeats} times, "
      f"for a total of {cross_plan.shape[0]} progeny.")
print("\nUnique Parent Pairings (mother iid, father iid):")
print(unique_cross_plan)


print("\n--- Progeny Population ---")
print(progeny_pop_phenotyped)
# One progeny per row of cross_plan, i.e. 5 * n_repeats individuals.

print("\n--- Progeny Pedigree (showing full-sibs from first cross) ---")
# The first n_repeats progeny all come from the first cross, so the first 12
# rows share the same parents.
for i in range(12):
    print(f"Ind {progeny_pop_phenotyped.id[i]:<4}: "
          f"Mother={progeny_pop_phenotyped.mother[i]:<4}, "
          f"Father={progeny_pop_phenotyped.father[i]:<4}")

print("\n--- Progeny Phenotypes (first 5 full-sibs) ---")
# Full-sibs share parents but still differ in BV and phenotype because of
# random meiosis and environmental noise.
for i in range(5):
    print(f"Ind {progeny_pop_phenotyped.id[i]}: "
          f"BV={progeny_pop_phenotyped.bv[i]}, "
          f"Pheno={progeny_pop_phenotyped.pheno[i]}")

In [None]:
# Display the breeding values stored on the phenotyped progeny population.
progeny_pop_phenotyped.bv

In [None]:
# Display the genotype array of the phenotyped founder population.
pop_with_h2.geno

In [None]:
# Display the breeding values stored on the phenotyped founder population.
pop_with_h2.bv

In [None]:
# Shape of the stored genetic-value array (indexed elsewhere as [individual, trait]).
pop_with_h2.gv.shape

In [None]:
# Mean stored genetic value of the first trait (column 0) in the founders.
np.mean(pop_with_h2.gv[:,0])

In [None]:
# Mean stored breeding value of the first trait (column 0) in the founders.
np.mean(pop_with_h2.bv[:,0])

In [None]:
import matplotlib.pyplot as plt
import numpy as np # JAX arrays can be used directly, but aliasing is good practice

# --- Histograms of breeding values: parents (top row) vs. progeny (bottom row) ---

# 2x2 grid: rows are generations, columns are traits.
fig, axes = plt.subplots(2, 2, figsize=(12, 10), sharex=True)
fig.suptitle('Distribution of Breeding Values Before and After Selection', fontsize=16)

# Number of bins per histogram. Passing this integer straight to hist() would
# let every panel choose its own bin edges, so explicit edges shared by both
# generations are computed per trait below to keep the comparison fair.
bins = 20

generation_bvs = [
    ('Parent', pop_with_h2.bv),
    ('Progeny', progeny_pop_phenotyped.bv),
]
trait_colors = ['skyblue', 'salmon']

for col, color in enumerate(trait_colors):
    # Convert once to NumPy; column `col` holds this trait's values.
    per_generation = [np.asarray(bv[:, col]) for _, bv in generation_bvs]
    # Common bin edges spanning parents and progeny together.
    shared_edges = np.histogram_bin_edges(np.concatenate(per_generation), bins=bins)

    for row, ((generation, _), values) in enumerate(zip(generation_bvs, per_generation)):
        ax = axes[row, col]
        ax.hist(values, bins=shared_edges, color=color, edgecolor='black')
        ax.set_title(f'{generation} Population - Trait {col + 1}')
        # Dashed red line marks the generation mean for this trait.
        ax.axvline(float(values.mean()), color='r', linestyle='--', linewidth=2, label='Mean BV')
        ax.legend()
        if col == 0:
            ax.set_ylabel('Frequency')
        if row == len(generation_bvs) - 1:
            ax.set_xlabel('Breeding Value')

# Leave room at the top for the figure-level title.
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()

In [None]:
# Shape of the progeny genotype array (its first axis is used as the individual axis by the PCA cell).
progeny_pop_phenotyped.geno.shape

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import jax.numpy as jnp
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

# --- 1. Prepare the Genotype Data ---
geno_data = progeny_pop_phenotyped.geno
n_individuals = geno_data.shape[0]

# Collapse axis 2 by summation and flatten the rest to one row per individual.
# NOTE(review): assumes axis 2 of `geno` is the ploidy axis, so the sum is an
# allele dosage per locus - confirm against the Population layout.
dosage_matrix = jnp.sum(geno_data, axis=2).reshape(n_individuals, -1)
dosage_matrix_np = np.asarray(dosage_matrix)

# --- 2. Standardize the Data ---
scaler = StandardScaler()
scaled_dosage = scaler.fit_transform(dosage_matrix_np)

# --- 3. Perform PCA ---
pca = PCA(n_components=2)
principal_components = pca.fit_transform(scaled_dosage)

# --- 4. Derive Family Labels from the Pedigree ---
# A family is a unique (mother, father) pair. Reading it from the pedigree
# keeps the colouring correct for any number of crosses or repeats, instead of
# assuming 5 equally sized families laid out in order.
parent_pairs = np.column_stack([
    np.asarray(progeny_pop_phenotyped.mother),
    np.asarray(progeny_pop_phenotyped.father),
])
unique_pairs, family_ids = np.unique(parent_pairs, axis=0, return_inverse=True)
family_ids = family_ids.reshape(-1)  # one integer label per individual
n_families = unique_pairs.shape[0]

# --- 5. Create the Visualization ---
plt.figure(figsize=(10, 8))
scatter = plt.scatter(
    principal_components[:, 0],
    principal_components[:, 1],
    c=family_ids,
    cmap='viridis',
    alpha=0.8
)

# Add labels and title; axis labels report the variance explained by each PC.
plt.title('PCA of Progeny Genotypes', fontsize=16)
plt.xlabel(f'Principal Component 1 ({pca.explained_variance_ratio_[0]:.1%} Variance)', fontsize=12)
plt.ylabel(f'Principal Component 2 ({pca.explained_variance_ratio_[1]:.1%} Variance)', fontsize=12)
plt.grid(True, linestyle='--', alpha=0.6)

# Add a legend to identify the families; label exactly as many entries as
# there are handles so handles and labels can never get out of step.
family_handles = scatter.legend_elements()[0]
plt.legend(handles=family_handles, labels=[f'Family {i+1}' for i in range(len(family_handles))])

# Show the plot
plt.show()