# 41. Python version of heritability analysis

a5: Tried computing heritability (almost) from scratch with `statsmodels` in this version.
Result: the kernel crashed after the kinship matrix was computed.

In a6, we investigate why the kernel crashes during the heritability calculation. We also re-added the code that skips windows containing only one SNP.

In [1]:
import os
import time

import numpy as np
import pandas as pd
import statsmodels.api as sm
from pgenlib import PgenReader
from statsmodels.regression.mixed_linear_model import MixedLM, VCSpec

# -----------------------------
# Parameters (Adjust as needed)
# -----------------------------
window_sizes = [10000]                       # Window sizes in base pairs

window_size = window_sizes[0]                # Single window size

chunk_start = 1                              # Start index for CpG sites (1-based)
chunk_end = 50                               # End index for CpG sites (1-based, inclusive)
benchmark = True                             # Whether to measure timing

# Fail fast on an impossible chunk: chunk_start = 0 would turn the later
# slice cpg_columns[chunk_start - 1:chunk_end] into a negative-index slice.
if not 1 <= chunk_start <= chunk_end:
    raise ValueError(
        f"Expected 1 <= chunk_start <= chunk_end, got chunk_start={chunk_start}, chunk_end={chunk_end}"
    )

# -----------------------------
# Paths (Adjust these paths according to your data)
# -----------------------------
df_csv_path = "/dcs04/lieber/statsgen/mnagle/mwas/CpGWAS/scripts/09.5-OUT_matched_SNP_meth_cov_chunked_JHPCE.csv"
output_dir = "./41-OUT_heritability_a1"      # Relative paths are resolved against the launch directory

# -----------------------------
# Initialize Benchmarking
# -----------------------------
if benchmark:
    start_time_total = time.time()

# -----------------------------
# Create Output Directory
# -----------------------------
# Remember the directory the script/notebook was started from (kept across cell
# re-runs). Without this, re-running the cell after the os.chdir() below would
# resolve the relative output_dir against the *new* cwd and create nested
# "41-OUT_heritability_a1/41-OUT_heritability_a1/..." directories.
if "launch_dir" not in globals():
    launch_dir = os.getcwd()
output_dir = os.path.normpath(os.path.join(launch_dir, output_dir))

os.makedirs(output_dir, exist_ok=True)
os.chdir(output_dir)
print(f"Output directory set to: {output_dir}")

# -----------------------------
# Read the Metadata DataFrame
# -----------------------------
# Fatal errors in this cell use `raise SystemExit(1)` instead of `exit(1)`:
# `exit` is an interactive helper installed by the `site` module (IPython swaps
# in its own version inside notebooks) and is not meant for use in programs.
# SystemExit ends a script with status 1 and stops the current notebook cell.
try:
    df = pd.read_csv(df_csv_path)
except Exception as e:
    print(f"Error reading metadata CSV '{df_csv_path}': {e}")
    raise SystemExit(1) from e
print(f"Metadata loaded from '{df_csv_path}'.")

# -----------------------------
# Select the Row for Processing
# -----------------------------
df_row = 0  # Adjust as needed
if df.empty:
    print("Metadata DataFrame is empty. Exiting.")
    raise SystemExit(1)
if df_row not in df.index:
    print(f"Row {df_row} not found in metadata DataFrame ({len(df)} rows). Exiting.")
    raise SystemExit(1)

# Extract paths from the data frame
gwas_dir = os.path.dirname(df.loc[df_row, 'SNP_data'])
methylation_file = df.loc[df_row, 'modified_methylation_data']

# Point the methylation path at the mnagle copy and swap rda/rds for csv.
# NOTE(review): str.replace rewrites "rda"/"rds" anywhere in the path, not just
# in the extension; fine for the current paths, but watch directory names.
methylation_file = methylation_file.replace(
    "/dcs04/lieber/statsgen/shizhong/michael/mwas/pheno/",
    "/dcs04/lieber/statsgen/mnagle/mwas/pheno/"
).replace("rda", "csv").replace("rds", "csv")

print(f"Genotype Directory: {gwas_dir}")
print(f"Methylation File: {methylation_file}")

# -----------------------------
# Load Methylation Data
# -----------------------------
# Expected layout: a 'sample_id' column plus one column per CpG named 'pos_<position>'.
try:
    methylation_df = pd.read_csv(methylation_file)
except Exception as e:
    print(f"Error reading methylation file '{methylation_file}': {e}")
    raise SystemExit(1) from e
print(f"Methylation data loaded from '{methylation_file}'.")

# Ensure 'sample_id' is present and treated as a string (it is matched against .psam IIDs later)
if 'sample_id' not in methylation_df.columns:
    print("'sample_id' column not found in methylation data. Exiting.")
    raise SystemExit(1)

methylation_df['sample_id'] = methylation_df['sample_id'].astype(str)
print("'sample_id' column confirmed and converted to string.")

# Extract CpG columns (all columns except 'sample_id')
cpg_columns = methylation_df.columns.drop('sample_id')

# Extract numeric CpG positions from column names (e.g., 'pos_1069461' -> 1069461)
try:
    cpg_positions = [int(col.split('_')[1]) for col in cpg_columns]
except IndexError as e:
    print(f"Error parsing CpG positions in column names: {e}")
    raise SystemExit(1) from e
except ValueError as e:
    print(f"Non-integer CpG position found in column names: {e}")
    raise SystemExit(1) from e
print("CpG positions extracted from column names.")

# Create a mapping from column names to positions
cpg_col_to_pos = dict(zip(cpg_columns, cpg_positions))

# Select the CpG positions for the specified chunk (chunk_start/chunk_end are 1-based, inclusive)
selected_cpg_cols = cpg_columns[chunk_start - 1:chunk_end]
selected_cpg_positions = [cpg_col_to_pos[col] for col in selected_cpg_cols]

print(f"Selected CpG Columns: {selected_cpg_cols.tolist()}")
print(f"Selected CpG Positions: {selected_cpg_positions}")

# -----------------------------
# Helper functions for the per-CpG pipeline
# -----------------------------
def _load_genotype_metadata(prefix):
    """Read .psam sample IDs and the .pvar variant table for a PLINK 2 fileset.

    Returns (sample_id_to_index, pvar_df), or None when a file is missing or
    unreadable (the reason is printed). Called once, because every CpG in the
    chunk uses the same chromosome fileset.
    """
    pgen, pvar, psam = (f"{prefix}{ext}" for ext in (".pgen", ".pvar", ".psam"))
    if not all(os.path.exists(f) for f in (pgen, pvar, psam)):
        print("One or more PLINK 2 files are missing.")
        return None
    print("All necessary PLINK 2 files found.")

    try:
        psam_df = pd.read_csv(psam, sep='\t')
    except Exception as e:
        print(f"Error reading .psam file '{psam}': {e}.")
        return None
    if '#IID' not in psam_df.columns:
        print(f"'#IID' column not found in .psam file '{psam}'.")
        return None
    geno_sample_ids = psam_df['#IID'].astype(str)
    print("Genotype sample IDs loaded from .psam file.")

    try:
        # comment='#' skips the header lines; the resulting 0-based row index is
        # used as the variant index into the .pgen file.
        pvar_df = pd.read_csv(pvar, sep='\t', comment='#',
                              names=['CHROM', 'POS', 'ID', 'REF', 'ALT', 'QUAL', 'FILTER', 'INFO', 'FORMAT'])
    except Exception as e:
        print(f"Error reading .pvar file '{pvar}': {e}.")
        return None
    print("SNP positions loaded from .pvar file.")

    sample_id_to_index = {sid: i for i, sid in enumerate(geno_sample_ids)}
    return sample_id_to_index, pvar_df


def _impute_missing(geno_buffer):
    """Return a float copy of geno_buffer (SNPs x samples) with missing calls mean-imputed.

    Missing calls are coded -9. Imputing in float keeps the fractional per-SNP
    mean (writing it into the int32 buffer would truncate it). A SNP with no
    non-missing calls is filled with 0.
    """
    geno = geno_buffer.astype(float)
    missing = geno_buffer == -9
    if missing.any():
        print("Missing genotype data detected. Imputing missing values with mean genotype.")
        for var in np.flatnonzero(missing.any(axis=1)):
            called = ~missing[var]
            mean_geno = geno[var, called].mean() if called.any() else 0.0
            geno[var, missing[var]] = mean_geno
            print(f"  Imputed missing values for SNP {var + 1} with mean genotype {mean_geno:.2f}.")
    return geno


def _standardize(M):
    """Center each column of M (samples x SNPs) and scale it to sample SD 1.

    Monomorphic columns (SD 0) are only centered, which leaves them all zero.
    """
    mu = M.mean(axis=0, keepdims=True)
    sigma = M.std(axis=0, ddof=1, keepdims=True)
    sigma[sigma == 0] = 1
    return (M - mu) / sigma


def _estimate_h2(y, Z):
    """Estimate heritability of y with one genetic variance component (MixedLM default fit, REML).

    Model: y = b0 + Z u + e, with u ~ N(0, v_g I) and e ~ N(0, v_e I); all
    samples form a single group. When Z @ Z.T equals the normalized kinship K,
    the genetic covariance is v_g * K (the usual GREML model). Z has one column
    per SNP, so this stays cheap even for many samples.

    Returns (h2, v_g, v_e); h2 is NaN when v_g + v_e is not positive.
    """
    n = len(y)
    vc = VCSpec(["genetic"], [[[f"z{j}" for j in range(Z.shape[1])]]], [[Z]])
    model = MixedLM(endog=y, exog=np.ones((n, 1)), groups=np.zeros(n), exog_vc=vc)
    result = model.fit()
    v_g = float(result.vcomp[0])  # Genetic variance component
    v_e = float(result.scale)     # Residual (environmental) variance component
    total = v_g + v_e
    h2 = v_g / total if total > 0 else float("nan")
    return h2, v_g, v_e


# -----------------------------
# Load Genotype Metadata (once: all CpGs in this chunk share the chromosome files)
# -----------------------------
pgen_prefix = os.path.join(gwas_dir, f"libd_chr{df.loc[df_row, 'Chr']}")
pgen_file = f"{pgen_prefix}.pgen"
geno_meta = _load_genotype_metadata(pgen_prefix)

heritability_results = []  # One dict per CpG with a successful h2 estimate

# -----------------------------
# Iterate Over Selected CpG Sites
# -----------------------------
for idx, (cpg_col, cpg_pos) in enumerate(zip(selected_cpg_cols, selected_cpg_positions), start=1):
    print(f"\nProcessing CpG site {idx}: {cpg_col} at position {cpg_pos}")

    if geno_meta is None:
        print("Genotype metadata unavailable. Skipping this CpG site.")
        continue
    sample_id_to_index, pvar_df = geno_meta

    # -----------------------------
    # Extract Methylation Data for the Selected CpG Site
    # -----------------------------
    pheno_df = methylation_df[['sample_id', cpg_col]].dropna()
    print(f"Number of samples with non-missing methylation data: {len(pheno_df)}")
    if pheno_df.empty:
        print("No samples with non-missing methylation data. Skipping this CpG site.")
        continue

    # Keep only samples that also have genotypes, ordered by .psam index.
    # PgenReader needs a sorted, duplicate-free sample_subset and returns
    # genotypes in that order, so y must follow the same order.
    pheno_df = (
        pheno_df.assign(geno_idx=pheno_df['sample_id'].map(sample_id_to_index))
        .dropna(subset=['geno_idx'])
        .drop_duplicates(subset='geno_idx')
        .sort_values('geno_idx')
    )
    n_samples = len(pheno_df)
    if n_samples == 0:
        print("No matching samples between genotype and methylation data. Skipping this CpG site.")
        continue
    print(f"Number of matching samples: {n_samples}")
    if n_samples < 3:
        print("Too few matching samples for variance-component estimation. Skipping this CpG site.")
        continue

    geno_indices = pheno_df['geno_idx'].to_numpy(dtype=np.uint32)
    y = pheno_df[cpg_col].to_numpy(dtype=float)

    # -----------------------------
    # Define Genomic Window and Select SNPs
    # -----------------------------
    p1 = max(cpg_pos - window_size, 0)
    p2 = cpg_pos + window_size
    print(f"Genomic window: {p1} - {p2} bp")

    snps_in_window = pvar_df[(pvar_df['POS'] >= p1) & (pvar_df['POS'] <= p2)]
    if snps_in_window.empty:
        print("No SNPs found within the genomic window. Skipping this CpG site.")
        continue
    print(f"Number of SNPs within the window: {len(snps_in_window)}")

    # Skip if there's only one SNP
    if len(snps_in_window) < 2:
        print("Only one SNP in window; skipping heritability estimation.")
        continue

    # Get variant indices (0-based)
    variant_indices = snps_in_window.index.to_numpy()

    # -----------------------------
    # Read Genotypes (rows = SNPs, cols = matched samples in .psam order)
    # -----------------------------
    if benchmark:
        start_time_geno = time.time()

    try:
        pgr = PgenReader(pgen_file.encode('utf-8'), sample_subset=geno_indices)
    except Exception as e:
        print(f"Error initializing PgenReader: {e}. Skipping this CpG site.")
        continue
    print("PgenReader initialized.")

    # Buffer width must equal the number of samples in sample_subset.
    geno_buffer = np.empty((len(variant_indices), n_samples), dtype=np.int32)
    try:
        for var_idx, variant_idx in enumerate(variant_indices):
            pgr.read(int(variant_idx), geno_buffer[var_idx, :], allele_idx=1)
    except Exception as e:
        print(f"Error reading genotype data: {e}. Skipping this CpG site.")
        continue
    finally:
        pgr.close()  # Release the .pgen file handle on every path
    print("Genotype data successfully read and stored in buffer.")

    if benchmark:
        geno_time = time.time() - start_time_geno
        print(f"Genotype reading time: {geno_time:.2f} seconds")

    # -----------------------------
    # Impute Missing Calls and Standardize (Samples x SNPs)
    # -----------------------------
    geno = _impute_missing(geno_buffer)
    print("Standardizing genotype data.")
    S = _standardize(geno.T)
    print("Genotype data standardized.")

    # -----------------------------
    # Compute Kinship Matrix using GEMMA Method
    # -----------------------------
    print("Computing kinship matrix.")
    if benchmark:
        start_time_kin = time.time()
    n_snps = S.shape[1]
    K = S @ S.T / n_snps  # Shape: (Samples, Samples)
    print("Kinship matrix computed.")
    if benchmark:
        # Time only the kinship step, not total elapsed time since start-up.
        kinship_time = time.time() - start_time_kin
        print(f"Kinship computation time: {kinship_time:.2f} seconds")

    # -----------------------------
    # Normalize Kinship Matrix (mean diagonal = 1)
    # -----------------------------
    mean_diag = np.mean(np.diagonal(K))
    if mean_diag == 0:
        print("Mean of the diagonal of the kinship matrix is zero. Cannot normalize. Skipping this CpG site.")
        continue
    K_normalized = K / mean_diag
    print("Kinship matrix normalized.")

    print(f"K_normalized shape: {K_normalized.shape}")
    print(f"K_normalized diagonal: {np.diagonal(K_normalized)}")
    print(f"K_normalized summary stats: mean={np.mean(K_normalized)}, std={np.std(K_normalized)}")
    if np.any(np.isnan(K_normalized)):
        print("NaN values found in K_normalized matrix.")
        raise SystemExit(1)
    if np.any(np.isinf(K_normalized)):
        print("Infinite values found in K_normalized matrix.")
        raise SystemExit(1)

    # -----------------------------
    # Estimate Heritability Using MixedLM in Statsmodels
    # -----------------------------
    # K_normalized is a covariance, not a design matrix. Passing it as exog_re
    # with one group per sample makes MixedLM estimate a free n x n cov_re
    # (~n^2/2 parameters), which is intractable. Instead use the
    # (samples x SNPs) design Z, with Z @ Z.T == K_normalized, as a single
    # variance component.
    Z = S / np.sqrt(n_snps * mean_diag)
    try:
        h2, v_g, v_e = _estimate_h2(y, Z)
    except Exception as e:
        print(f"Heritability estimation failed: {e}")
        continue
    print(f"Estimated heritability (h2): {h2:.4f}")
    heritability_results.append({
        "cpg": cpg_col,
        "position": cpg_pos,
        "n_samples": n_samples,
        "n_snps": n_snps,
        "v_g": v_g,
        "v_e": v_e,
        "h2": h2,
    })

heritability_df = pd.DataFrame(heritability_results)
print(f"\nHeritability estimated for {len(heritability_df)} of {len(selected_cpg_cols)} CpG sites.")

Output directory set to: ./41-OUT_heritability_a1
Metadata loaded from '/dcs04/lieber/statsgen/mnagle/mwas/CpGWAS/scripts/09.5-OUT_matched_SNP_meth_cov_chunked_JHPCE.csv'.
Genotype Directory: /dcs04/lieber/statsgen/shizhong/michael/mwas/gwas
Methylation File: /dcs04/lieber/statsgen/mnagle/mwas/pheno/caud/out/chr1_AA_8982-28981.csv
Methylation data loaded from '/dcs04/lieber/statsgen/mnagle/mwas/pheno/caud/out/chr1_AA_8982-28981.csv'.
'sample_id' column confirmed and converted to string.
CpG positions extracted from column names.
Selected CpG Columns: ['pos_1069461', 'pos_1069467', 'pos_1069470', 'pos_1069477', 'pos_1069484', 'pos_1069498', 'pos_1069506', 'pos_1069516', 'pos_1069530', 'pos_1069533', 'pos_1069539', 'pos_1069544', 'pos_1069569', 'pos_1069573', 'pos_1069591', 'pos_1069599', 'pos_1069601', 'pos_1069603', 'pos_1069613', 'pos_1069626', 'pos_1069629', 'pos_1069635', 'pos_1069637', 'pos_1069645', 'pos_1069651', 'pos_1069653', 'pos_1069669', 'pos_1069682', 'pos_1069691', 'pos_10


KeyboardInterrupt



In [2]:
# Fixed-effects design: one column of ones, i.e. an intercept-only model.
exog = np.ones(shape=(n_samples, 1), dtype=np.float64)

In [3]:
# Alias the normalized kinship matrix (samples x samples) under the name
# MixedLM uses for its random-effects design matrix.
# NOTE(review): K_normalized is a covariance matrix, not a design matrix. If
# this alias is passed to MixedLM as `exog_re` together with one group per
# sample, statsmodels treats each of its n columns as a separate random effect
# and tries to estimate a free n x n cov_re -- likely the source of the crash.
exog_re = K_normalized

In [None]:
# Fit heritability with the genetic effect as a single variance component.
# Passing the n x n kinship as `exog_re` with one group per sample made
# MixedLM estimate a free n x n random-effects covariance (~n^2/2 parameters),
# which is intractable for realistic n. Instead, put all samples in one group
# and use a design Z with Z @ Z.T == K_normalized, so that
# Cov(genetic effect) = v_g * K_normalized.
eigvals, eigvecs = np.linalg.eigh(K_normalized)
keep = eigvals > 1e-8 * eigvals.max()          # Drop null / round-off (negative) directions
Z = eigvecs[:, keep] * np.sqrt(eigvals[keep])  # (samples x rank) factor of K_normalized
vc = VCSpec(["genetic"], [[[f"pc{j}" for j in range(Z.shape[1])]]], [[Z]])
model = MixedLM(endog=y, exog=exog, groups=np.zeros(n_samples), exog_vc=vc)
result = model.fit()
v_g = result.vcomp[0]  # Genetic variance component
v_e = result.scale  # Environmental (residual) variance component
h2 = v_g / (v_g + v_e)
print(f"Estimated heritability (h2): {h2:.4f}")