In [18]:
import sys
sys.path.append('../src')
import numpy as np
import pandas as pd
from vcf_processor import VCFProcessor
from variant_clustering import VariantClusterer
from deconvolution_models import DeconvolutionModels
import pyro
import pyro.distributions as dist  # For Dirichlet distribution
from pyro.infer import SVI, Trace_ELBO, Predictive  # NEW: For variational inference and sampling
from pyro.optim import Adam  # NEW: For optimizer (was implicit)
import torch
from pathlib import Path
import json
from scipy.optimize import minimize


In [19]:
# List of VCF paths (from context)
vcf_paths = [
    "../../../data/vault_pipeline_output/SRR12676841_output/snp/all_snp_from_perfect_umi.vcf",
    "../../../data/vault_pipeline_output/SRR12676842_output/snp/all_snp_from_perfect_umi.vcf",
    "../../../data/vault_pipeline_output/SRR12676843_output/snp/all_snp_from_perfect_umi.vcf",
    "../../../data/vault_pipeline_output/SRR12676844_output/snp/all_snp_from_perfect_umi.vcf",
    "../../../data/vault_pipeline_output/SRR12676845_output/snp/all_snp_from_perfect_umi.vcf",
    "../../../data/vault_pipeline_output/SRR12676846_output/snp/all_snp_from_perfect_umi.vcf"
]
# Extract cluster proportions from each VCF
historical_proportions = []
for vcf_path in vcf_paths:
    processor = VCFProcessor(vcf_path)
    variants_df = processor.read_vcf()
    filtered_variants = processor.apply_filters(include_indels=True, min_vaf=0.15)
    variant_matrix = processor.get_variant_matrix(continuous=True, include_indels=True)
    clusterer = VariantClusterer(method='spectral', n_clusters=20)
    _, cluster_labels = clusterer.fit_predict(variant_matrix)
    # Ignore -1 labels (noise/zeros)
    valid_labels = cluster_labels[cluster_labels >= 0]
    if len(np.unique(valid_labels)) < 2:
        print(f"Skipping {vcf_path}: insufficient clusters after filtering.")
        continue
    props = np.bincount(valid_labels, minlength=20) / len(valid_labels)
    historical_proportions.append(props)
historical_proportions = np.array(historical_proportions)
# Check if we have data
if len(historical_proportions) == 0:
    raise ValueError("No valid proportions extracted from any VCF. Check data or filtering.")
# Normalize proportions to ensure they sum to 1 (add epsilon for stability)
historical_proportions = historical_proportions / (historical_proportions.sum(axis=1, keepdims=True) + 1e-10)
# Fit Dirichlet prior (concentration params favor sparsity)
def dirichlet_neg_log_lik(alpha, data):
    alpha_t = torch.tensor(alpha)
    data_t = torch.tensor(data)
    return -torch.sum(dist.Dirichlet(alpha_t).log_prob(data_t))  # Use torch.sum for tensor compatibility
res = minimize(dirichlet_neg_log_lik, np.ones(20), args=(historical_proportions,), bounds=[(0.01, None)]*20)
dirichlet_alpha_prior = res.x
print("Extracted Dirichlet prior:", dirichlet_alpha_prior)
# Save prior
np.save('dirichlet_prior.npy', dirichlet_alpha_prior)


Reading VCF file: ../../../data/vault_pipeline_output/SRR12676841_output/snp/all_snp_from_perfect_umi.vcf
Loaded 82236 variants from VCF file
Unique UMIs: 3658
After PASS filter: 82236 variants (0 removed)
INDELs included (encoded as continuous features)
After VAF >= 0.15: 81379 variants (857 removed)
After depth >= 5: 4127 variants (77252 removed)
After alt reads >= 1: 4127 variants (0 removed)
After frequency filter (min 18 occurrences): 125 variants (4002 removed)

Final: 125 variants passed all filters (0.2%)
Unique UMIs remaining: 113
Encoding 125 INDELs as 3 features each (presence, length, type)
Created variant matrix: (16674, 113)
Total variants/features: 250
Mean value per UMI: 0.00
Sparsity: 0.999867
Applied PCA: Reduced to 50 components
Reading VCF file: ../../../data/vault_pipeline_output/SRR12676842_output/snp/all_snp_from_perfect_umi.vcf
Loaded 141140 variants from VCF file
Unique UMIs: 5368
After PASS filter: 141140 variants (0 removed)
INDELs included (encoded as contin



Skipping ../../../data/vault_pipeline_output/SRR12676844_output/snp/all_snp_from_perfect_umi.vcf: insufficient clusters after filtering.
Reading VCF file: ../../../data/vault_pipeline_output/SRR12676845_output/snp/all_snp_from_perfect_umi.vcf
Loaded 90480 variants from VCF file
Unique UMIs: 3456
After PASS filter: 90480 variants (0 removed)
INDELs included (encoded as continuous features)
After VAF >= 0.15: 85833 variants (4647 removed)
After depth >= 5: 11564 variants (74269 removed)
After alt reads >= 1: 11564 variants (0 removed)
After frequency filter (min 17 occurrences): 2499 variants (9065 removed)

Final: 2499 variants passed all filters (2.8%)
Unique UMIs remaining: 922
Encoding 2499 INDELs as 3 features each (presence, length, type)
Created variant matrix: (23796, 922)
Total variants/features: 5055
Mean value per UMI: 0.00
Sparsity: 0.999770
Applied PCA: Reduced to 50 components
Reading VCF file: ../../../data/vault_pipeline_output/SRR12676846_output/snp/all_snp_from_perfect_

In [20]:
# Load simulated/real data (example from previous notebook; adjust path if needed)
data_loader = np.load('../sc_mito_vars/real_data/imigseq_SRR12676843/numpy/vcf_processed_data.npz')
K_true = data_loader['K_true']  # Assume (n_positions x n_clusters); transpose if needed
C_observed = data_loader['C_observed']
# Load extracted prior
dirichlet_alpha_prior = np.load('dirichlet_prior.npy')
# Adjust shapes if necessary (ensure K is (n_clusters x n_positions) for matmul)
if K_true.shape[1] != len(dirichlet_alpha_prior):
    print("Warning: Prior dim mismatch; resizing prior.")
    dirichlet_alpha_prior = np.resize(dirichlet_alpha_prior, K_true.shape[1])
K_true = K_true.T  # Transpose to (n_clusters x n_positions)
# Deconvolve with Bayesian model (revised for stable sampling)
n_cells, n_positions = C_observed.shape
n_clusters = K_true.shape[0]  # Now n_clusters first
K_torch = torch.tensor(K_true, dtype=torch.float32)
C_torch = torch.tensor(C_observed, dtype=torch.float32)
alpha_prior = torch.tensor(dirichlet_alpha_prior, dtype=torch.float32)
def model(c_obs):
    p = pyro.sample("p", dist.Dirichlet(alpha_prior))
    with pyro.plate("data", n_positions):
        pyro.sample("obs", dist.Normal((p @ K_torch), 0.01), obs=c_obs)  # Adjusted matmul
def guide(c_obs):
    alpha_q = pyro.param("alpha_q", alpha_prior.clone(), constraint=dist.constraints.positive)
    pyro.sample("p", dist.Dirichlet(alpha_q))
p_estimated = np.zeros((n_cells, n_clusters))
p_uncertainty = np.zeros((n_cells, n_clusters))
for i in range(n_cells):
    svi = SVI(model, guide, Adam({"lr": 0.01}), Trace_ELBO())
    losses = []
    for step in range(1000):  # Increased steps for convergence
        loss = svi.step(C_torch[i])
        losses.append(loss)
        if step % 100 == 0:
            print(f"Cell {i}, Step {step}: Loss = {loss:.4f}")
    # Check convergence (simple threshold on final loss)
    if losses[-1] > 1e4:  # Arbitrary high loss threshold; adjust based on data
        print(f"Warning: High loss for cell {i}; using uniform fallback.")
        p_estimated[i] = np.ones(n_clusters) / n_clusters
        p_uncertainty[i] = np.ones(n_clusters) * 0.1  # Arbitrary uncertainty
        continue
    # Use Predictive for stable posterior sampling
    predictive = Predictive(model=guide, num_samples=1000, return_sites=("p",))
    samples = predictive(C_torch[i])
    p_samples = samples["p"].detach().numpy().squeeze()  # (1000 x n_clusters)
    p_estimated[i] = p_samples.mean(axis=0)
    p_uncertainty[i] = p_samples.std(axis=0)
print("Estimated proportions shape:", p_estimated.shape)
print("Uncertainty shape:", p_uncertainty.shape)


Cell 0, Step 0: Loss = -55885.8023
Cell 0, Step 100: Loss = -55896.7790
Cell 0, Step 200: Loss = -55895.5110
Cell 0, Step 300: Loss = -55897.7501
Cell 0, Step 400: Loss = -55899.6079
Cell 0, Step 500: Loss = -55898.4413
Cell 0, Step 600: Loss = -55899.4442
Cell 0, Step 700: Loss = -55898.7081
Cell 0, Step 800: Loss = -55898.5875
Cell 0, Step 900: Loss = -55898.5439
Cell 1, Step 0: Loss = -55863.6576
Cell 1, Step 100: Loss = -55891.7153
Cell 1, Step 200: Loss = -55898.4871
Cell 1, Step 300: Loss = -55895.7755
Cell 1, Step 400: Loss = -55895.1461
Cell 1, Step 500: Loss = -55897.2187
Cell 1, Step 600: Loss = -55895.5055
Cell 1, Step 700: Loss = -55896.8284
Cell 1, Step 800: Loss = -55895.1640
Cell 1, Step 900: Loss = -55895.3329
Cell 2, Step 0: Loss = -55846.2933
Cell 2, Step 100: Loss = -55857.5576
Cell 2, Step 200: Loss = -55860.5034
Cell 2, Step 300: Loss = -55860.0893
Cell 2, Step 400: Loss = -55861.0317
Cell 2, Step 500: Loss = -55863.7628
Cell 2, Step 600: Loss = -55861.8989
Cell 2,

In [21]:
# Assume ground truth P_cells_true from data
P_cells_true = data_loader['P_cells_true']
from sklearn.metrics import r2_score
r2 = r2_score(P_cells_true.flatten(), p_estimated.flatten())
print(f"R² with priors: {r2:.4f}")
# Uncertainty analysis
print("Mean uncertainty:", p_uncertainty.mean())


R² with priors: 0.9653
Mean uncertainty: 0.007448166846646927
