In [1]:
import h5py
import os

import cvxpy as cp
import numpy as np
from scipy.stats import pearsonr, spearmanr
from sklearn.linear_model import RidgeCV
from tqdm.notebook import tqdm

In [2]:
DATASET_DIR = "random_split.384_bins"
# One HDF5 file per split, all stored directly under DATASET_DIR.
train_h5_path, val_h5_path, test_h5_path = (
    os.path.join(DATASET_DIR, file_name)
    for file_name in ("train.h5", "val.h5", "test.h5")
)

In [3]:
def load_h5_data(h5_path: str):
    """Read per-gene dosages, meta-features and z-scores from one HDF5 split file.

    The file's top-level ``genes`` dataset lists gene names; each gene is a group
    holding ``variants``, ``dosages``, ``meta_features`` and ``z_scores`` datasets.

    Args:
        h5_path: Path to the HDF5 file.

    Returns:
        A ``(features_per_gene, meta_features_per_gene, zscores_per_gene)`` tuple of
        dicts keyed by gene name, each value being the corresponding dataset cast to
        float32. Genes whose ``variants`` dataset has zero rows are left out of all
        three dicts.
    """
    dosages, meta, zscores = {}, {}, {}
    with h5py.File(h5_path, "r") as h5:
        gene_names = h5["genes"][:].astype(str)
        for gene in tqdm(gene_names):
            group = h5[gene]
            # A gene with no variants has nothing to model, so skip it entirely.
            if not group["variants"].shape[0]:
                continue
            dosages[gene] = group["dosages"][:].astype(np.float32)
            meta[gene] = group["meta_features"][:].astype(np.float32)
            zscores[gene] = group["z_scores"][:].astype(np.float32)
    return dosages, meta, zscores

In [4]:
# Only the training split's meta-features are kept (they parameterise the prior);
# for val/test they are discarded into `_`.
train_features, train_meta_features, train_zscores = load_h5_data(train_h5_path)
(val_features, _, val_zscores), (test_features, _, test_zscores) = (
    load_h5_data(split_path) for split_path in (val_h5_path, test_h5_path)
)

  0%|          | 0/3259 [00:00<?, ?it/s]

  0%|          | 0/3259 [00:00<?, ?it/s]

  0%|          | 0/3259 [00:00<?, ?it/s]

In [9]:
def get_num_meta_features(meta_features_per_gene):
    """Return the number of meta-feature columns shared by every gene.

    Args:
        meta_features_per_gene: Mapping from gene name to an array whose second
            axis (``shape[1]``) indexes meta-features.

    Returns:
        The common ``shape[1]`` of all arrays.

    Raises:
        ValueError: If the mapping is empty, or if the genes disagree on the
            number of meta-features. An explicit exception is used instead of
            ``assert`` so the check survives ``python -O``.
    """
    widths = {gene: meta.shape[1] for gene, meta in meta_features_per_gene.items()}
    if not widths:
        raise ValueError(
            "meta_features_per_gene is empty; cannot infer the number of meta-features"
        )
    distinct_widths = set(widths.values())
    if len(distinct_widths) != 1:
        raise ValueError(
            "Genes disagree on the number of meta-features; "
            f"found widths {sorted(distinct_widths)}"
        )
    return distinct_widths.pop()


def compute_priors_per_gene(
    meta_features_per_gene: dict[str, np.ndarray], beta: np.ndarray
):
    """Project each gene's meta-features onto ``beta`` to get per-variant priors.

    Returns a new dict with the same keys, in the same order, as
    ``meta_features_per_gene``, mapping each gene to ``meta_features @ beta``.
    """
    return {gene: meta @ beta for gene, meta in meta_features_per_gene.items()}


def update_weights_per_gene(
    features_per_gene: dict[str, np.ndarray],
    priors_per_gene: dict[str, np.ndarray],
    zscores_per_gene: dict[str, np.ndarray],
    alphas: "np.ndarray | None" = None,
) -> tuple[dict[str, np.ndarray], dict[str, float]]:
    """Fit one ridge regression per gene with features rescaled by the prior.

    For each gene with prior vector ``p`` (one entry per column of ``X``),
    ``RidgeCV`` is fit on ``X_tilde = X @ diag(sqrt(p))``. The coefficients are
    then mapped back with ``w = diag(sqrt(p)) @ w_tilde``. This equals ridge on
    ``X`` with a per-coefficient penalty ``alpha * w_j**2 / p_j``: variants with a
    larger prior are shrunk less, and a zero prior forces the coefficient to zero.

    Args:
        features_per_gene: Gene -> design matrix ``X`` (columns align with the
            gene's priors).
        priors_per_gene: Gene -> per-column prior. Negative entries are clipped
            to 0 before the square root.
        zscores_per_gene: Gene -> regression target for that gene.
        alphas: Candidate ridge penalties for ``RidgeCV``. Defaults to
            ``np.logspace(-3, 3, 50)``, the grid previously hard-coded here.

    Returns:
        ``(coefs_per_gene, intercepts_per_gene)``, both keyed by gene.
    """
    if alphas is None:
        alphas = np.logspace(-3, 3, 50)
    coefs_per_gene = {}
    intercepts_per_gene = {}
    for g in tqdm(features_per_gene):
        X = features_per_gene[g]
        # update_beta only enforces F @ beta >= 0 on a random subsample of rows, and
        # solvers may return tiny negatives. sqrt of a negative would be NaN and make
        # RidgeCV fail, so clip to 0 (i.e. no prior mass for that variant).
        sqrt_prior = np.sqrt(np.clip(priors_per_gene[g], 0.0, None))
        # Same as X @ diag(sqrt_prior), without an (n_variants x n_variants) matrix.
        X_tilde = X * sqrt_prior
        y = zscores_per_gene[g]

        model = RidgeCV(alphas=alphas)
        model.fit(X_tilde, y)

        # Undo the rescaling: w = diag(sqrt_prior) @ w_tilde.
        coefs_per_gene[g] = sqrt_prior * model.coef_
        intercepts_per_gene[g] = model.intercept_
    return coefs_per_gene, intercepts_per_gene


def update_beta(
    coefs_per_gene: dict[str, np.ndarray],
    meta_features_per_gene: dict[str, np.ndarray],
    C: float = 1.0,
    n_subsample: int = 10_000,
    epsilon: float = 1e-9,
    rng: "np.random.Generator | None" = None,
    verbose: bool = True,
):
    """Fit the meta-feature weights ``beta`` from the current per-variant coefficients.

    Stacks every gene's meta-features (one row per variant) into ``F`` and the
    matching coefficients into ``W``. It then randomly keeps up to ``n_subsample``
    rows and solves, with CVXPY::

        minimize_beta  sum_i W_i**2 / (F_i @ beta + epsilon) + C * sum_i F_i @ beta
        subject to     F_i @ beta >= 0      (subsampled rows only)

    Args:
        coefs_per_gene: Gene -> coefficient vector, one entry per variant.
        meta_features_per_gene: Gene -> meta-feature matrix whose rows align with
            that gene's coefficients.
        C: Weight of the linear penalty on the summed priors.
        n_subsample: Maximum number of rows passed to the solver.
        epsilon: Added to the prior inside the reciprocal so it stays finite at 0.
        rng: Optional numpy ``Generator`` used for subsampling, for reproducible
            runs. When ``None``, the global ``np.random`` state is used
            (previous behaviour).
        verbose: Forwarded to ``problem.solve``.

    Returns:
        The solved ``beta`` as a 1-D array with one entry per meta-feature.

    Raises:
        ValueError: If ``coefs_per_gene`` is empty, or if the stacked meta-features
            and coefficients have different numbers of rows.
        RuntimeError: If the solver does not reach an (approximately) optimal
            solution. Previously ``None`` was returned silently.

    Note:
        Non-negativity of ``F @ beta`` is only enforced on the subsampled rows.
        Priors computed from the returned ``beta`` for other variants can
        therefore be negative.
    """
    if not coefs_per_gene:
        raise ValueError("update_beta: coefs_per_gene is empty")
    genes = list(coefs_per_gene)
    F = np.vstack([meta_features_per_gene[g] for g in genes])
    W = np.concatenate([coefs_per_gene[g] for g in genes])
    if F.shape[0] != W.shape[0]:
        raise ValueError(
            f"update_beta: {F.shape[0]} meta-feature rows but {W.shape[0]} coefficients"
        )

    # Subsample for speed and tractability; the same indices are applied to F and W,
    # so they stay row-aligned.
    n_rows = F.shape[0]
    permutation = (
        np.random.permutation(n_rows) if rng is None else rng.permutation(n_rows)
    )
    subsample_idxs = permutation[:n_subsample]
    F = F[subsample_idxs]
    W = W[subsample_idxs]

    beta = cp.Variable(F.shape[1])
    z = F @ beta  # per-variant prior implied by beta

    # objective: sum(W^2 / z) + C * sum(z)
    objective = cp.Minimize(
        cp.sum(cp.multiply(np.square(W), cp.inv_pos(z + epsilon))) + C * cp.sum(z)
    )
    constraints = [z >= 0]
    problem = cp.Problem(objective, constraints)
    problem.solve(verbose=verbose)
    if beta.value is None or problem.status not in (cp.OPTIMAL, cp.OPTIMAL_INACCURATE):
        raise RuntimeError(
            f"update_beta: solver finished with status {problem.status!r}; "
            "no usable beta"
        )
    return beta.value


def compute_test_performance_per_gene(
    features_per_gene: dict[str, np.ndarray],
    coefs_per_gene: dict[str, np.ndarray],
    intercepts_per_gene: dict[str, np.ndarray],
    zscores_per_gene: dict[str, np.ndarray],
    progress: bool = True,
):
    """Score per-gene linear predictions against observed z-scores.

    For every gene in ``features_per_gene`` the prediction is
    ``features @ coefs + intercept``. It is compared with ``zscores_per_gene[gene]``.

    Args:
        features_per_gene: Gene -> design matrix for the evaluated split.
        coefs_per_gene: Gene -> fitted coefficient vector.
        intercepts_per_gene: Gene -> fitted intercept.
        zscores_per_gene: Gene -> observed targets for the evaluated split.
        progress: Wrap the gene loop in a ``tqdm`` progress bar (default, as before).

    Returns:
        ``(pearsons_per_gene, spearmans_per_gene, mse_per_gene)``, keyed like
        ``features_per_gene``. Correlations are NaN when undefined (fewer than two
        samples, or constant targets or predictions). All three metrics are NaN for
        a gene with no fitted coefficients or intercept. ``np.nanmean``-based
        summaries therefore skip those genes instead of crashing.
    """
    pearsons_per_gene = {}
    spearmans_per_gene = {}
    mse_per_gene = {}
    genes = tqdm(features_per_gene) if progress else features_per_gene
    for g in genes:
        if g not in coefs_per_gene or g not in intercepts_per_gene:
            # No fitted model for this gene (e.g. it was filtered out of the training
            # split); previously this raised KeyError.
            pearsons_per_gene[g] = np.nan
            spearmans_per_gene[g] = np.nan
            mse_per_gene[g] = np.nan
            continue
        X = features_per_gene[g]
        y = zscores_per_gene[g]
        y_pred = X @ coefs_per_gene[g] + intercepts_per_gene[g]
        mse_per_gene[g] = np.mean((y - y_pred) ** 2)
        # Correlation is undefined for < 2 points or zero variance. scipy raises
        # (< 2 points) or warns and returns NaN (constant input), so short-circuit.
        if len(y) < 2 or np.ptp(y) == 0 or np.ptp(y_pred) == 0:
            pearsons_per_gene[g] = np.nan
            spearmans_per_gene[g] = np.nan
        else:
            pearsons_per_gene[g] = pearsonr(y, y_pred)[0]
            spearmans_per_gene[g] = spearmanr(y, y_pred)[0]
    return pearsons_per_gene, spearmans_per_gene, mse_per_gene

In [6]:
n_meta_features = get_num_meta_features(train_meta_features)
# Uniform initial beta: every meta-feature gets the same weight 1 / n_meta_features.
beta = np.ones(n_meta_features, dtype=np.float32)
beta /= n_meta_features
priors_per_gene = compute_priors_per_gene(train_meta_features, beta)

In [7]:
# Fit one ridge model per gene on the training split, with each column of the
# design matrix scaled by the square root of its current prior.
coefs_per_gene, intercepts_per_gene = update_weights_per_gene(
    features_per_gene=train_features,
    priors_per_gene=priors_per_gene,
    zscores_per_gene=train_zscores,
)

  0%|          | 0/3252 [00:00<?, ?it/s]

In [10]:
# Evaluate the train-fitted per-gene models on the validation split. nanmean is used
# because some per-gene correlations are presumably NaN (constant targets or
# predictions); scipy warned on those lines when this cell last ran.
val_pearsons_per_gene, val_spearmans_per_gene, val_mses_per_gene = (
    compute_test_performance_per_gene(
        features_per_gene=val_features,
        coefs_per_gene=coefs_per_gene,
        intercepts_per_gene=intercepts_per_gene,
        zscores_per_gene=val_zscores,
    )
)
print(f"Mean val Pearson correlation: {np.nanmean(list(val_pearsons_per_gene.values()))}")
print(f"Mean val Spearman correlation: {np.nanmean(list(val_spearmans_per_gene.values()))}")
print(f"Mean val MSE: {np.nanmean(list(val_mses_per_gene.values()))}")

  0%|          | 0/3252 [00:00<?, ?it/s]

  pearsons_per_gene[g] = pearsonr(y, y_pred)[0]
  spearmans_per_gene[g] = spearmanr(y, y_pred)[0]


Mean val Pearson correlation: 0.28209403820126533
Mean val Spearman correlation: 0.26095346754381826
Mean val MSE: 1.1276605298791342


In [11]:
# Same evaluation as for validation, on the held-out test split.
test_pearsons_per_gene, test_spearmans_per_gene, test_mses_per_gene = (
    compute_test_performance_per_gene(
        features_per_gene=test_features,
        coefs_per_gene=coefs_per_gene,
        intercepts_per_gene=intercepts_per_gene,
        zscores_per_gene=test_zscores,
    )
)
print(f"Mean test Pearson correlation: {np.nanmean(list(test_pearsons_per_gene.values()))}")
print(f"Mean test Spearman correlation: {np.nanmean(list(test_spearmans_per_gene.values()))}")
print(f"Mean test MSE: {np.nanmean(list(test_mses_per_gene.values()))}")

  0%|          | 0/3252 [00:00<?, ?it/s]

Mean test Pearson correlation: 0.2834294585050532
Mean test Spearman correlation: 0.27159774893734484
Mean test MSE: 1.0031527366647355


In [9]:
# Refit beta from the current per-gene coefficients. Row subsampling inside
# update_beta draws from the unseeded global np.random state, so results vary
# between runs; CVXPY compilation alone took ~112 s when this cell last ran.
# NOTE(review): in the cells visible here, beta2 is not yet used to recompute priors.
beta2 = update_beta(
    coefs_per_gene=coefs_per_gene,
    meta_features_per_gene=train_meta_features,
)

                                     CVXPY                                     
                                     v1.4.1                                    
(CVXPY) Apr 29 05:11:09 PM: Your problem has 5313 variables, 1 constraints, and 0 parameters.
(CVXPY) Apr 29 05:11:09 PM: It is compliant with the following grammars: DCP, DQCP
(CVXPY) Apr 29 05:11:09 PM: (If you need to solve this problem multiple times, but with different data, consider using parameters.)
(CVXPY) Apr 29 05:11:09 PM: CVXPY will first compile your problem; then, it will invoke a numerical solver to obtain a solution.
(CVXPY) Apr 29 05:11:09 PM: Your problem is compiled with the CPP canonicalization backend.
-------------------------------------------------------------------------------
                                  Compilation                                  
-------------------------------------------------------------------------------
(CVXPY) Apr 29 05:11:09 PM: Compiling problem (target solver=ECOS).
(C

    Your problem is being solved with the ECOS solver by default. Starting in 
    CVXPY 1.5.0, Clarabel will be used as the default solver instead. To continue 
    using ECOS, specify the ECOS solver explicitly using the ``solver=cp.ECOS`` 
    argument to the ``problem.solve`` method.
    


(CVXPY) Apr 29 05:12:10 PM: Applying reduction ECOS
(CVXPY) Apr 29 05:13:01 PM: Finished problem compilation (took 1.122e+02 seconds).
-------------------------------------------------------------------------------
                                Numerical solver                               
-------------------------------------------------------------------------------
(CVXPY) Apr 29 05:13:01 PM: Invoking solver ECOS  to obtain a solution.
