# Scaling to Real eQTL Data

In real-world applications (e.g., eQTL studies with millions of tests across
dozens of tissues), it is impractical to load and fit all tests at once.
This notebook demonstrates the recommended workflow using two subsets:

- **Strong signals** — top eQTLs per gene, used for learning covariance patterns
- **Random subset** — an unbiased sample of all tests, used for fitting the model
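To make the two subsets concrete, here is a minimal numpy sketch of how they differ. It assumes "strong" means the largest |z| = |Bhat/Shat| across conditions exceeds a fixed threshold. That is a simplified stand-in for the lfsr-based `mash_1by1` selection used later in this notebook. `select_subsets` and `z_thresh` are hypothetical names, not part of pymash.

```python
import numpy as np


def select_subsets(bhat, shat, n_random, z_thresh=4.0, seed=0):
    """Hypothetical helper (not pymash): return (strong_idx, random_idx).

    Assumption: a test is "strong" if its largest |z| across conditions
    exceeds z_thresh; pymash's own selection uses mash_1by1 + lfsr instead.
    """
    z = bhat / shat                          # J x R matrix of z-scores
    max_abs_z = np.abs(z).max(axis=1)        # strongest signal per test
    strong_idx = np.flatnonzero(max_abs_z > z_thresh)

    # Random subset: uniform sample of ALL tests without replacement,
    # so it contains null and non-null tests in their natural proportions.
    rng = np.random.default_rng(seed)
    n_random = min(n_random, bhat.shape[0])
    random_idx = np.sort(rng.choice(bhat.shape[0], size=n_random, replace=False))
    return strong_idx, random_idx


# Toy data: 1,000 null tests followed by 50 large effects in all 5 conditions.
rng = np.random.default_rng(1)
bhat = rng.normal(size=(1050, 5))
bhat[1000:] += 8.0
shat = np.ones_like(bhat)
strong_idx, random_idx = select_subsets(bhat, shat, n_random=200, seed=42)
print(len(strong_idx), len(random_idx))
```

The strong subset is almost entirely non-null by construction, while the random subset is mostly null, mirroring the data it was drawn from. That difference is why each subset is used for a different job.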

In [Urbut et al. 2019](https://doi.org/10.1038/s41588-018-0268-8), the strong
set contained ~16k tests and the random set ~20k tests.

pymash's **workflow module** (`select_training_effects`, `fit_mash_prior`,
`apply_mash_prior`) handles the two-stage train/apply pattern, including
subsetting, fitting with `outputlevel=1`, and applying with `fixg=True`.

## Strategy

1. Estimate null correlations from the random subset
2. Learn data-driven covariances from the strong subset
3. Fit the model (mixture proportions) on the random subset via `fit_mash_prior`
4. Compute posteriors for any subset using `apply_mash_prior`

In [None]:
import numpy as np
import pymash as mash

## Simulate Data

We simulate a larger dataset (40k tests) to mimic a real scenario.
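The 40,000 comes from `nsamp=10000`. Assuming `mash.simple_sims` mirrors mashr's `simple_sims`, it stacks four blocks of `nsamp` tests: null effects, effects independent across conditions, effects in condition 1 only, and effects equal across all conditions. The sketch below is a hypothetical numpy re-implementation (`toy_simple_sims`) written to illustrate that structure, not pymash's source.

```python
import numpy as np


def toy_simple_sims(nsamp=100, ncond=5, err_sd=0.01, seed=None):
    """Hypothetical mashr-style simulator (not pymash): J = 4 * nsamp tests."""
    rng = np.random.default_rng(seed)
    b_null = np.zeros((nsamp, ncond))                       # no effect anywhere
    b_indep = rng.normal(size=(nsamp, ncond))               # independent across conditions
    b_first = np.zeros((nsamp, ncond))
    b_first[:, 0] = rng.normal(size=nsamp)                  # effect in condition 1 only
    b_equal = np.repeat(rng.normal(size=(nsamp, 1)), ncond, axis=1)  # same effect everywhere
    B = np.vstack([b_null, b_indep, b_first, b_equal])      # true effects, J x R
    Shat = np.full(B.shape, err_sd)                         # constant standard errors
    Bhat = B + rng.normal(scale=Shat)                       # observed = truth + noise
    return {"B": B, "Bhat": Bhat, "Shat": Shat}


sim_toy = toy_simple_sims(nsamp=10000, ncond=5, err_sd=1.0, seed=1)
print(sim_toy["Bhat"].shape)  # (40000, 5)
```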

In [2]:
sim = mash.simple_sims(nsamp=10000, ncond=5, err_sd=1.0, seed=1)
print(f"Total tests: {sim['Bhat'].shape[0]}")

Total tests: 40000


### Identify Strong and Random Subsets

In [None]:
# Find strong signals using a quick 1-by-1 analysis
full_data = mash.mash_set_data(sim["Bhat"], sim["Shat"])
m1 = mash.mash_1by1(full_data)
strong_idx = mash.get_significant_results(m1, thresh=0.05)
print(f"Strong signals: {len(strong_idx)}")

# Select a random training subset using the workflow module.
# For real eQTL data with millions of tests, use n_train=20000-50000.
random_idx = mash.select_training_effects(
    full_data, n_train=5000, method="random", seed=42,
)
print(f"Random subset: {len(random_idx)}")

## Step 1: Estimate Null Correlations

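Estimate the correlation of measurement errors across conditions from the
random subset. This captures correlations that are present even for null
effects (for example, from overlapping samples across conditions), not
correlations among the true effects.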

**Use the random subset**, not the strong subset: the strong subset is selected
for large effects, so it contains few if any null tests.
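As a sketch of what a "simple" null-correlation estimator does, assuming `estimate_null_correlation_simple` follows mashr's function of the same name: compute z = Bhat/Shat, keep the tests whose largest |z| is below a threshold (2 in mashr), and take the correlation matrix of those null-looking z-scores. `null_corr_simple` is a hypothetical name. Because the filter truncates the z distribution, the estimate is somewhat attenuated toward zero.

```python
import numpy as np


def null_corr_simple(bhat, shat, z_thresh=2.0):
    """Hypothetical sketch (not pymash) of a mashr-style simple null correlation."""
    z = bhat / shat
    nullish = np.abs(z).max(axis=1) < z_thresh          # looks null in every condition
    if nullish.sum() < z.shape[1]:
        # A strong-signal subset typically fails here: almost no rows pass.
        raise ValueError("too few null-looking tests; use a random subset")
    return np.corrcoef(z[nullish], rowvar=False)        # R x R correlation matrix


# Toy check: null z-scores whose errors correlate at 0.5 across 3 conditions.
rng = np.random.default_rng(0)
V_true = np.full((3, 3), 0.5) + 0.5 * np.eye(3)
z_null = rng.multivariate_normal(np.zeros(3), V_true, size=20000)
V_est = null_corr_simple(z_null, np.ones_like(z_null))
print(np.round(V_est, 2))  # off-diagonals positive, a bit below 0.5
```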

In [None]:
data_temp = mash.mash_set_data(
    sim["Bhat"][random_idx], sim["Shat"][random_idx]
)
Vhat = mash.estimate_null_correlation_simple(data_temp)
print("Estimated null correlation matrix:")
print(np.array2string(Vhat, precision=3))

### Create Data Objects with Estimated Correlations

Create the full dataset (for fitting and posterior computation) and a
strong-signal subset (for covariance learning only).
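Passing `V=Vhat` matters because of how the mash model uses it. As documented for mashr's `mash_set_data`, the observed effects for test j are modeled as Bhat_j ~ N(b_j, diag(Shat_j) V diag(Shat_j)), so V is rescaled by each test's standard errors. A small numpy sketch of that per-test error covariance, using a hypothetical `error_cov` helper:

```python
import numpy as np


def error_cov(shat_row, V):
    """Hypothetical helper: error covariance diag(s_j) @ V @ diag(s_j) for one test."""
    S = np.diag(shat_row)
    return S @ V @ S


V = np.array([[1.0, 0.4],
              [0.4, 1.0]])           # null correlation, e.g. from Step 1
shat_j = np.array([0.5, 2.0])        # one test's standard errors in 2 conditions
Sigma_j = error_cov(shat_j, V)
# Entries: 0.25, 0.4, 0.4, 4.0; equivalently np.outer(shat_j, shat_j) * V
print(Sigma_j)
```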

In [None]:
data_all = mash.mash_set_data(
    sim["Bhat"], sim["Shat"], V=Vhat
)
data_strong = mash.mash_set_data(
    sim["Bhat"][strong_idx], sim["Shat"][strong_idx], V=Vhat
)
print(f"Full data: {data_all.n_effects} tests x {data_all.n_conditions} conditions")
print(f"Strong data: {data_strong.n_effects} tests x {data_strong.n_conditions} conditions")

## Step 2: Learn Data-Driven Covariances from Strong Signals
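The data-driven covariances start from the strong z-scores. As a rough sketch, assuming `cov_pca` follows mashr's approach: take an SVD of the strong z-score matrix, turn each of the top `npc` principal directions into a rank-1 covariance matrix, and add a "tPCA" matrix built from all of them. `cov_ed` then refines these starting matrices with extreme deconvolution, which is not shown. `pca_covs` and the exact set of matrices it returns are assumptions, not pymash's output.

```python
import numpy as np


def pca_covs(z_strong, npc):
    """Hypothetical sketch (not pymash): rank-1 PCA covariances plus 'tPCA'."""
    n = z_strong.shape[0]
    # Thin SVD z = U diag(d) Vt; rows of Vt are the principal directions.
    _, d, vt = np.linalg.svd(z_strong, full_matrices=False)
    f = vt[:npc].T                                            # R x npc loadings
    covs = {f"PCA_{k + 1}": np.outer(f[:, k], f[:, k]) for k in range(npc)}
    covs["tPCA"] = f @ np.diag(d[:npc] ** 2) @ f.T / n        # low-rank z'z / n
    return covs


# Toy strong z-scores: a large effect shared equally across 4 conditions, plus noise.
rng = np.random.default_rng(0)
z = rng.normal(scale=5.0, size=(500, 1)) * np.ones((1, 4)) + rng.normal(size=(500, 4))
U = pca_covs(z, npc=2)
print(sorted(U))                 # ['PCA_1', 'PCA_2', 'tPCA']
print(np.round(U["PCA_1"], 2))   # close to 0.25 everywhere: the "shared" pattern
```

With `npc` equal to the number of conditions, the tPCA matrix reduces to the empirical second-moment matrix z'z / n.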

In [None]:
U_pca = mash.cov_pca(data_strong, npc=5)
U_ed = mash.cov_ed(data_strong, U_pca)
print(f"Data-driven covariances: {list(U_ed.keys())}")

## Step 3: Fit the Model on the Random Subset

Use `fit_mash_prior` to fit mixture weights on the random training subset.
It handles subsetting internally — pass the full `data_all` and the
`random_idx` from `select_training_effects`. It fits with `outputlevel=1`
(mixture weights only, no posteriors).
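Conceptually, "mixture weights only" means estimating proportions pi over a J x K likelihood matrix. Entry (j, k) is the density of test j's estimates under component k, a multivariate normal with covariance U_k + Sigma_j. Below is a minimal EM sketch of that step. It assumes one shared error covariance Sigma (true here, since all `Shat` equal 1), no grid scaling of the U_k, and no penalty on the weights. Real mash scales each U_k over a grid and by default favors the null component, and its optimizer may differ. `loglik_matrix` and `em_weights` are hypothetical names.

```python
import numpy as np


def mvn_logpdf(x, cov):
    """Log density of N(0, cov) at each row of x (rows are tests)."""
    r = cov.shape[0]
    _, logdet = np.linalg.slogdet(cov)
    sol = np.linalg.solve(cov, x.T)                  # cov^{-1} x_j, shape R x J
    quad = np.sum(x.T * sol, axis=0)
    return -0.5 * (r * np.log(2 * np.pi) + logdet + quad)


def loglik_matrix(bhat, Sigma, U_list):
    """Hypothetical: J x K log-likelihoods, assuming one shared error covariance."""
    return np.column_stack([mvn_logpdf(bhat, U + Sigma) for U in U_list])


def em_weights(L, n_iter=200):
    """Hypothetical: EM for mixture proportions given log-likelihoods L (J x K)."""
    K = L.shape[1]
    pi = np.full(K, 1.0 / K)
    lik = np.exp(L - L.max(axis=1, keepdims=True))   # row rescaling leaves EM unchanged
    for _ in range(n_iter):
        post = lik * pi                              # E-step: unnormalized responsibilities
        post /= post.sum(axis=1, keepdims=True)
        pi = post.mean(axis=0)                       # M-step: average responsibility
    return pi


# Toy training set: 600 null tests and 400 effects shared equally across 3 conditions.
rng = np.random.default_rng(0)
R = 3
U_list = [np.zeros((R, R)), 4.0 * np.eye(R), np.full((R, R), 4.0)]  # null, independent, shared
Sigma = np.eye(R)
B = np.vstack([np.zeros((600, R)), rng.normal(scale=2.0, size=(400, 1)) * np.ones((1, R))])
bhat = B + rng.normal(size=B.shape)
pi = em_weights(loglik_matrix(bhat, Sigma, U_list))
print(np.round(pi, 2))   # roughly [0.6, 0.0, 0.4]
```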

In [None]:
U_c = mash.cov_canonical(data_all)
U_all = {**U_ed, **U_c}

fitted_g, train_idx, train_result = mash.fit_mash_prior(
    data_all, U_all, train_indices=random_idx,
)
print(f"Log-likelihood: {train_result.loglik:.2f}")
print(f"Training subset: {len(train_idx)} tests")

## Step 4: Compute Posteriors on the Strong Subset

Use `apply_mash_prior` to apply the learned model to the strong signals.
This calls `mash(..., g=fitted_g, fixg=True)` internally.

You could also apply to the full dataset or any other subset of interest.
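With `fixg=True` the mixture weights are held fixed rather than re-estimated. For each new test, the posterior weight on component k is proportional to pi_k times that test's likelihood under k. Under the same simplifications as a plain normal-mixture sketch (shared Sigma, no grid), the component-wise posterior mean is U_k (U_k + Sigma)^{-1} bhat_j. `posterior_mean_fixed_pi` is a hypothetical helper, and it does not compute lfsr.

```python
import numpy as np


def posterior_mean_fixed_pi(bhat, Sigma, U_list, pi):
    """Hypothetical sketch: posterior means and component weights with pi fixed."""
    J, R = bhat.shape
    logw = np.empty((J, len(U_list)))
    comp_means = []
    for k, U in enumerate(U_list):
        C = U + Sigma                                     # marginal covariance of bhat_j
        _, logdet = np.linalg.slogdet(C)
        sol = np.linalg.solve(C, bhat.T)                  # C^{-1} bhat_j, shape R x J
        quad = np.sum(bhat.T * sol, axis=0)
        logw[:, k] = np.log(pi[k]) - 0.5 * (logdet + quad)  # shared 2*pi constant dropped
        comp_means.append((U @ sol).T)                    # U C^{-1} bhat_j, shape J x R
    logw -= logw.max(axis=1, keepdims=True)
    w = np.exp(logw)
    w /= w.sum(axis=1, keepdims=True)                     # rows sum to 1
    pm = np.zeros((J, R))
    for k, m in enumerate(comp_means):
        pm += w[:, [k]] * m                               # mix component means by weight
    return pm, w


R = 2
U_list = [np.zeros((R, R)), np.full((R, R), 4.0)]         # null, shared
Sigma = np.eye(R)
pi_fixed = np.array([0.6, 0.4])                           # e.g. learned on a training subset
bhat_new = np.array([[0.1, -0.2],                         # looks null
                     [3.0, 3.2]])                         # looks like a shared effect
pm, w = posterior_mean_fixed_pi(bhat_new, Sigma, U_list, pi_fixed)
print(np.round(w, 3))
print(np.round(pm, 3))
```

The second test is assigned almost entirely to the shared component, and its posterior mean is shrunk toward zero and toward a common value across conditions.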

In [None]:
m_strong = mash.apply_mash_prior(data_strong, fitted_g)

sig = mash.get_significant_results(m_strong, thresh=0.05)
print(f"Significant effects in strong set: {len(sig)}")
print("\nlfsr (first 5 effects):")
print(mash.get_lfsr(m_strong)[:5].round(3))

## Loading Real Data from CSV

In practice, you would load your Bhat and Shat matrices from files.
For example, starting from CSV files where rows are tests and columns
are conditions:

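```python
import pandas as pd
import pymash as mash

# Rows are tests, columns are conditions; the first column holds test IDs.
bhat_df = pd.read_csv("bhat.csv", index_col=0)
shat_df = pd.read_csv("shat.csv", index_col=0)

# Reorder Shat to match Bhat's test IDs and condition names
# (raises a KeyError if any label is missing from shat.csv).
shat_df = shat_df.loc[bhat_df.index, bhat_df.columns]

data = mash.mash_set_data(bhat_df.to_numpy(), shat_df.to_numpy())
```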

The key requirement is that `Bhat` and `Shat` are aligned matrices with
the same shape: J tests (rows) by R conditions (columns).
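A quick way to enforce that requirement before building the data object is an explicit check. The `check_bhat_shat` helper below is hypothetical (not part of pymash); besides the same-shape rule, it assumes finite values and strictly positive standard errors are needed. How pymash itself handles missing values is not covered here.

```python
import numpy as np


def check_bhat_shat(bhat, shat):
    """Hypothetical helper: raise ValueError unless bhat/shat are aligned J x R arrays."""
    bhat = np.asarray(bhat, dtype=float)
    shat = np.asarray(shat, dtype=float)
    if bhat.ndim != 2:
        raise ValueError(f"Bhat must be 2-D (J x R), got shape {bhat.shape}")
    if bhat.shape != shat.shape:
        raise ValueError(f"shape mismatch: Bhat {bhat.shape} vs Shat {shat.shape}")
    if not (np.isfinite(bhat).all() and np.isfinite(shat).all()):
        raise ValueError("Bhat and Shat must not contain NaN or inf")
    if (shat <= 0).any():
        raise ValueError("all standard errors in Shat must be > 0")
    return bhat.shape                                   # (J tests, R conditions)


print(check_bhat_shat(np.zeros((3, 2)), np.ones((3, 2))))  # (3, 2)
```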

## One-Shot Alternative: `mash_train_apply`

The entire train/apply workflow (steps 3–4) can be combined into a single
call. This is convenient when you don't need separate control over
each stage:

```python
result = mash.mash_train_apply(
    data_all, U_all,
    n_train=5000,
    select_method="random",
)
pm   = mash.get_pm(result.apply_result)
lfsr = mash.get_lfsr(result.apply_result)
print(f"Trained on {len(result.train_indices)} tests")
```