In [23]:
# GBLUP Study Setup - Population Evolution and Matrix Calculations
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt

# Import chewc components
from chewc.population import Population, quick_haplo, combine_populations, calc_g_matrix, calc_ibd_matrix
from chewc.sp import SimParam  
from chewc.trait import add_trait_a
from chewc.phenotype import set_pheno, set_bv
from chewc.select import TruncationSelection
from chewc.cross import make_cross

# Fixed pedigree matrix function
def calc_a_matrix_pedigree_fixed(pop: Population) -> jnp.ndarray:
    """Calculate pedigree-based A-matrix with proper ID mapping."""
    n_ind = pop.nInd
    A = np.zeros((n_ind, n_ind))
    
    # Create mapping from public ID to internal index
    id_to_iid = {int(pub_id): int(iid) for pub_id, iid in zip(pop.id, pop.iid)}
    mother_ids = np.array(pop.mother)
    father_ids = np.array(pop.father)
    
    # Henderson's algorithm
    for i in range(n_ind):
        A[i, i] = 1.0
        
        dam_pub_id = mother_ids[i] if mother_ids[i] >= 0 else None
        sire_pub_id = father_ids[i] if father_ids[i] >= 0 else None
        
        dam_idx = id_to_iid.get(dam_pub_id) if dam_pub_id is not None else None
        sire_idx = id_to_iid.get(sire_pub_id) if sire_pub_id is not None else None
        
        if dam_idx is not None and sire_idx is not None:
            A[i, i] = 1.0 + 0.5 * A[dam_idx, sire_idx]
            for j in range(i):
                A[i, j] = A[j, i] = 0.5 * (A[j, dam_idx] + A[j, sire_idx])
        elif dam_idx is not None:
            for j in range(i):
                A[i, j] = A[j, i] = 0.5 * A[j, dam_idx]
        elif sire_idx is not None:
            for j in range(i):
                A[i, j] = A[j, i] = 0.5 * A[j, sire_idx]
    
    return jnp.array(A)

print("=== GBLUP Study Setup ===")

# Set random seed for reproducibility
key = jax.random.PRNGKey(42)

# ====================================================================
# 1. CREATE FOUNDER POPULATION AND SETUP
# ====================================================================
key, founder_key = jax.random.split(key)

founder_pop, gen_map = quick_haplo(
    key=founder_key, n_ind=100, n_chr=10, n_loci_per_chr=1000, 
    ploidy=2, inbred=False, chr_len_cm=100.0
)

sp = SimParam.from_founder_pop(founder_pop, gen_map)
print(f"✓ Created {founder_pop.nInd} founders with {founder_pop.geno.shape[1] * founder_pop.geno.shape[3]} SNPs")

# ====================================================================
# 2. ADD TRAIT AND SET PHENOTYPES
# ====================================================================
key, trait_key, pheno_key = jax.random.split(key, 3)

sp = add_trait_a(
    key=trait_key, founder_pop=founder_pop, sim_param=sp,
    n_qtl_per_chr=20, mean=jnp.array([100.0]), var=jnp.array([400.0])
)

founder_pop = set_bv(founder_pop, sp.traits, sp.ploidy)
founder_pop = set_pheno(key=pheno_key, pop=founder_pop, traits=sp.traits, 
                       ploidy=sp.ploidy, h2=jnp.array([0.3]))

print(f"✓ Added trait with {sp.traits.n_loci} QTLs, h² = {jnp.var(founder_pop.bv) / jnp.var(founder_pop.pheno):.3f}")

# ====================================================================
# 3. EVOLUTION THROUGH 3 GENERATIONS
# ====================================================================
selector = TruncationSelection()
n_parents = int(founder_pop.nInd * 0.2)  # Top 20%

# Initialize tracking
generations = []
current_pop = founder_pop
current_id = founder_pop.nInd
all_populations = [founder_pop]  # Store all generations for pedigree

print(f"✓ Evolving through 3 generations with {n_parents} parents per generation...")

for gen in range(4):  # Gen 0 = founders, Gen 1-3 = offspring
    # Set breeding values and phenotypes (fix for offspring with zeros)
    if current_pop.bv is None:
        current_pop = set_bv(current_pop, sp.traits, sp.ploidy)
    
    if jnp.any(jnp.isnan(current_pop.pheno)) or jnp.allclose(current_pop.pheno, 0.0):
        key, pheno_key = jax.random.split(key)
        current_pop = set_pheno(key=pheno_key, pop=current_pop, traits=sp.traits, 
                               ploidy=sp.ploidy, h2=jnp.array([0.3]))
    
    # Record statistics
    gen_stats = {
        'generation': gen,
        'bv_mean': float(jnp.mean(current_pop.bv)),
        'bv_std': float(jnp.std(current_pop.bv)),
        'h2_realized': float(jnp.var(current_pop.bv) / jnp.var(current_pop.pheno))
    }
    generations.append(gen_stats)
    
    if gen == 3:  # Stop after recording final generation
        break
        
    # Selection and mating
    key, select_key, mating_key, father_key, cross_key = jax.random.split(key, 5)
    
    parent_indices = selector.select_parents(key=select_key, pop=current_pop, sp=sp, n_select=n_parents)
    
    # Random mating
    n_crosses = current_pop.nInd
    mother_indices = jax.random.choice(mating_key, len(parent_indices), shape=(n_crosses,), replace=True)
    father_indices = jax.random.choice(father_key, len(parent_indices), shape=(n_crosses,), replace=True)
    
    cross_plan = jnp.column_stack([parent_indices[mother_indices], parent_indices[father_indices]])
    
    # Create offspring
    offspring_pop = make_cross(key=cross_key, pop=current_pop, cross_plan=cross_plan, 
                              sp=sp, next_id_start=current_id)
    
    current_id += offspring_pop.nInd
    current_pop = offspring_pop
    all_populations.append(offspring_pop)  # Store for pedigree

print(f"✓ Evolution complete. BV progress: {generations[0]['bv_mean']:.1f} → {generations[-1]['bv_mean']:.1f}")

# ====================================================================
# 4. CREATE MULTI-GENERATIONAL POPULATION
# ====================================================================
multi_gen_pop = all_populations[0]
for pop in all_populations[1:]:
    multi_gen_pop = combine_populations(multi_gen_pop, pop)

print(f"✓ Multi-generational population: {multi_gen_pop.nInd} individuals ({len(all_populations)} generations)")

# ====================================================================
# 5. CALCULATE RELATIONSHIP MATRICES
# ====================================================================
final_pop = current_pop

# Ensure final population has proper phenotypes
if jnp.allclose(final_pop.pheno, 0.0):
    key, final_pheno_key = jax.random.split(key)
    final_pop = set_pheno(key=final_pheno_key, pop=final_pop, traits=sp.traits, 
                         ploidy=sp.ploidy, h2=jnp.array([0.3]))

final_pop = set_bv(final_pop, sp.traits, sp.ploidy)

# Calculate matrices
G_matrix = calc_g_matrix(final_pop.dosage)
A_ibd_matrix = calc_ibd_matrix(final_pop.ibd)

# Pedigree matrix using multi-generational data
A_ped_full = calc_a_matrix_pedigree_fixed(multi_gen_pop)
# Extract final generation subset
final_gen_start = multi_gen_pop.nInd - final_pop.nInd
final_gen_indices = jnp.arange(final_gen_start, multi_gen_pop.nInd)
A_ped_matrix = A_ped_full[jnp.ix_(final_gen_indices, final_gen_indices)]

print(f"✓ Calculated relationship matrices (all {final_pop.nInd}×{final_pop.nInd})")

# ====================================================================
# 6. SUMMARY AND VALIDATION
# ====================================================================
print("\n" + "="*50)
print("GBLUP DATA READY")
print("="*50)

# Matrix correlations
g_flat = G_matrix[jnp.triu_indices_from(G_matrix, k=1)]
a_ibd_flat = A_ibd_matrix[jnp.triu_indices_from(A_ibd_matrix, k=1)]  
a_ped_flat = A_ped_matrix[jnp.triu_indices_from(A_ped_matrix, k=1)]

print(f"Population: {final_pop.nInd} individuals, {final_pop.dosage.shape[1]} SNPs, {sp.traits.n_loci} QTLs")
print(f"Final h²: {jnp.var(final_pop.bv) / jnp.var(final_pop.pheno):.3f}")

print(f"\nMatrix correlations:")
print(f"  G vs A_IBD: {jnp.corrcoef(g_flat, a_ibd_flat)[0,1]:.3f}")
print(f"  G vs A_pedigree: {jnp.corrcoef(g_flat, a_ped_flat)[0,1]:.3f}")
print(f"  A_IBD vs A_pedigree: {jnp.corrcoef(a_ibd_flat, a_ped_flat)[0,1]:.3f}")

print(f"\nValidation checks:")
print(f"  ✓ Phenotypes have variance: {jnp.var(final_pop.pheno):.1f}")
print(f"  ✓ A_pedigree has off-diagonal values: {jnp.max(a_ped_flat):.3f}")
print(f"  ✓ All matrices positive definite: G({jnp.min(jnp.linalg.eigvals(G_matrix)):.3f}), A_IBD({jnp.min(jnp.linalg.eigvals(A_ibd_matrix)):.3f}), A_ped({jnp.min(jnp.linalg.eigvals(A_ped_matrix)):.3f})")

print(f"\nKey variables for GBLUP:")
print(f"  • final_pop.pheno: phenotypes ({final_pop.pheno.shape})")
print(f"  • final_pop.bv: true breeding values (for validation)")
print(f"  • final_pop.dosage: SNP dosages ({final_pop.dosage.shape})")
print(f"  • G_matrix: genomic relationships")
print(f"  • A_ibd_matrix: IBD relationships") 
print(f"  • A_ped_matrix: pedigree relationships")

print(f"\nReady for GBLUP implementation! 🧬")

=== GBLUP Study Setup ===
✓ Created 100 founders with 10000 SNPs
✓ Added trait with 200 QTLs, h² = 0.239
✓ Evolving through 3 generations with 20 parents per generation...
✓ Evolution complete. BV progress: 56.5 → 123.5
✓ Multi-generational population: 400 individuals (4 generations)
✓ Calculated relationship matrices (all 100×100)

GBLUP DATA READY
Population: 100 individuals, 10000 SNPs, 200 QTLs
Final h²: 0.357

Matrix correlations:
  G vs A_IBD: 0.701
  G vs A_pedigree: 0.800
  A_IBD vs A_pedigree: 0.624

Validation checks:
  ✓ Phenotypes have variance: 841.4
  ✓ A_pedigree has off-diagonal values: 1.125
  ✓ All matrices positive definite: G(-0.000+0.000j), A_IBD(0.046+0.000j), A_ped(0.437+0.000j)

Key variables for GBLUP:
  • final_pop.pheno: phenotypes ((100, 1))
  • final_pop.bv: true breeding values (for validation)
  • final_pop.dosage: SNP dosages ((100, 10000))
  • G_matrix: genomic relationships
  • A_ibd_matrix: IBD relationships
  • A_ped_matrix: pedigree relationships

R

In [24]:
final_pop.pheno

Array([[181.09564 ],
       [169.60098 ],
       [176.32455 ],
       [142.99876 ],
       [186.4361  ],
       [185.85002 ],
       [141.3191  ],
       [139.87769 ],
       [170.47449 ],
       [222.23183 ],
       [179.41925 ],
       [182.82259 ],
       [184.19508 ],
       [160.18872 ],
       [181.11237 ],
       [158.17108 ],
       [177.44794 ],
       [ 97.55769 ],
       [ 84.4271  ],
       [154.82501 ],
       [133.37837 ],
       [192.93526 ],
       [168.21289 ],
       [168.08492 ],
       [146.95811 ],
       [177.08153 ],
       [154.12373 ],
       [163.94667 ],
       [178.45831 ],
       [131.25273 ],
       [172.10278 ],
       [201.88728 ],
       [160.7001  ],
       [157.35904 ],
       [129.8335  ],
       [166.1706  ],
       [115.80762 ],
       [207.86388 ],
       [143.863   ],
       [195.9968  ],
       [157.68723 ],
       [152.97589 ],
       [159.01793 ],
       [155.43083 ],
       [207.76442 ],
       [179.6551  ],
       [173.97577 ],
       [173.0

In [25]:
G_matrix.shape

(100, 100)

In [26]:
A_ibd_matrix

Array([[0.5       , 0.06995   , 0.0477    , ..., 0.046725  , 0.19864999,
        0.048925  ],
       [0.06995   , 0.5       , 0.07045   , ..., 0.023875  , 0.070975  ,
        0.04875   ],
       [0.0477    , 0.07045   , 0.5       , ..., 0.04875   , 0.124925  ,
        0.09575   ],
       ...,
       [0.046725  , 0.023875  , 0.04875   , ..., 0.5       , 0.024125  ,
        0.02405   ],
       [0.19864999, 0.070975  , 0.124925  , ..., 0.024125  , 0.5       ,
        0.025325  ],
       [0.048925  , 0.04875   , 0.09575   , ..., 0.02405   , 0.025325  ,
        0.5       ]], dtype=float32)

In [27]:
A_ped_matrix

Array([[1.09375 , 0.078125, 0.15625 , ..., 0.078125, 0.390625, 0.09375 ],
       [0.078125, 1.0625  , 0.1875  , ..., 0.09375 , 0.109375, 0.171875],
       [0.15625 , 0.1875  , 1.0625  , ..., 0.0625  , 0.375   , 0.1875  ],
       ...,
       [0.078125, 0.09375 , 0.0625  , ..., 1.      , 0.09375 , 0.0625  ],
       [0.390625, 0.109375, 0.375   , ..., 0.09375 , 1.125   , 0.109375],
       [0.09375 , 0.171875, 0.1875  , ..., 0.0625  , 0.109375, 1.03125 ]],      dtype=float32)