# Variance Explained

The goal of this notebook is to add a variance explained calculated as
`chi2.isf(pvalue, df=1)`


In [6]:
import polars as pl
from scipy.stats import chi2


### Load the dataset from previous notebook


In [7]:
lead_variant_maf_dataset = pl.read_parquet("../../data/lead-maf-vep/*.parquet")
print(lead_variant_maf_dataset.shape)
print(lead_variant_maf_dataset.columns)


(2622098, 19)
['variantId', 'studyId', 'studyLocusId', 'beta', 'zScore', 'pValueMantissa', 'pValueExponent', 'standardError', 'finemappingMethod', 'studyType', 'credibleSetSize', 'nSamples', 'nControls', 'nCases', 'majorPopulation', 'allelefrequencies', 'vepEffect', 'majorPopulationAF', 'majorPopulationMAF']


### Calculate variance explained

The variance explained follows the simplified formula

${variance\;explained}=\chi^2 / n $

* The $\chi^2$ is calculated as **Inverse survival function** by using `scipy.stats.isf` function from lead variant $pValue$ (depicted as `pValueMantissa` and `pValueExponent`). 
* The $n$ parameter is the number of samples derived from GWAS study description.

* In case where the `pValueExponent < 300` to avoid floating point errors we estimate $\chi^2$ statistic with $-log_{10}(pValue)$
* The $variance\;explained$ can be only calculated where the $n > 0$

In [9]:
def variance_explained(p_value_mantissa: pl.Expr, p_value_exponent: pl.Expr, n_samples: pl.Expr) -> pl.Expr:
    """Estimate the variance explained by the lead variant in a dataset.

    # NOTE! Calculate variance explained requires removal of the studies that have nSamples = 0
    """
    p_value = (p_value_mantissa.cast(pl.Float64()) * pl.lit(10).pow(p_value_exponent.cast(pl.Float64()))).alias(
        "pValue"
    )
    neglog_pval = -1 * p_value_mantissa.log10() + p_value_exponent
    neglog_approximation_intercept = -5.367
    neglog_approximation_coeff = 4.596
    chi2_stat = (
        pl.when(p_value_exponent < -300)
        .then(neglog_pval * neglog_approximation_coeff + neglog_approximation_intercept)
        .otherwise(p_value.map_elements(lambda x: chi2.isf(x, df=1), pl.Float64()))
    ).alias("chi2Stat")
    variance_exp = (chi2_stat / n_samples).alias("varianceExplained")

    return pl.struct(chi2_stat, p_value, variance_exp).alias("leadVariantStats")


In [10]:
df = lead_variant_maf_dataset.filter(pl.col("nSamples").is_not_null() | pl.col("nSamples") > 0).select(
    "*",
    variance_explained(
        pl.col("pValueMantissa"),
        pl.col("pValueExponent"),
        pl.col("nSamples"),
    ),
)


In [11]:
df.write_parquet("../../data/variance-explained.parquet")


In [12]:
lead_variant_maf_dataset.shape[0] - df.shape[0]


1050

Exactly 1050 samples have no `nSamples` disallowing us to calculate the varianceExplained


## Rescale marginal effect size


Rescaling of marginal effect size to the original value from the standardised marginal effect size is done via two formulas depending on trait being **quantitative** or **binary**


Estimation of the trait type is done on the basis of availability of reported `nCases` and `nControls` fields in the study description. 
* In case both fields are non empty and non zero we assume *binary trait*
* In case cases are zero or are not reported we assume *quantitative trait*

In both cases we estimate the marginal effect size $estimated\;\beta$ with following formula
$$estimated\;\beta = zscore \cdot se$$

Where 
* $zscore = \frac{\beta}{|{\beta}|} \cdot \sqrt{\chi^2}$
* $se$ depends on the trait type
* $\beta$ - *standardised beta reported from in the summary statistics* 

In case when $\beta$ was not reported we assumed the $\frac{\beta}{|{\beta}|}$ to be equal to 1

#### Binary trait marginal effect size estimation

$$se = \frac{1}{\sqrt{(varG \cdot prev \cdot (1 - prev))}}$$
* $varG = 2 \cdot f \cdot (1 - f)$
* $f$ - *Minor Allele Frequency of lead variant*
* $prev = \frac{nCases}{nCases + nControls}$

#### Quantative trait marginal effect size estimation

$$se = \frac{1}{\sqrt{varG}}$$
* $varG = 2 \cdot f \cdot (1 - f)$
* $f$ - *Minor Allele Frequency of lead variant*


The $\chi^2$ was esteimated as described in `variance Explained` calculation.

In [13]:
def rescale_beta(
    beta: pl.Expr,
    n_cases: pl.Expr,
    n_controls: pl.Expr,
    n_samples: pl.Expr,
    p_value_mantissa: pl.Expr,
    p_value_exponent: pl.Expr,
    maf: pl.Expr,
) -> pl.Expr:
    """Rescale beta to be between 0 and 1."""
    neglog_approximation_intercept = -5.367
    neglog_approximation_coeff = 4.596
    trait_class = (
        pl.when(n_cases.is_null())
        .then(pl.lit("quantitative"))
        .when((n_cases == 0) | (n_controls == 0))
        .then(pl.lit("quantitative"))
        .otherwise(pl.lit("binary"))
    )
    p_value = (p_value_mantissa.cast(pl.Float64()) * pl.lit(10).pow(p_value_exponent.cast(pl.Float64()))).alias(
        "pValue"
    )
    neglog_pval = -1 * p_value_mantissa.log10() + p_value_exponent
    n_samples = pl.when(trait_class == "quantitative").then(n_samples).otherwise(n_cases + n_controls)
    # Calculate the chi2 value - the calculation has to be approximated to the -neglog pval in case when exponent is very low
    # otherwise the chi2 will be infinity.
    chi2_stat = (
        pl.when(p_value_exponent < -300)
        .then(neglog_pval * neglog_approximation_coeff + neglog_approximation_intercept)
        .otherwise(p_value.map_elements(lambda x: chi2.isf(x, df=1), pl.Float64()))
    )
    # In case beta is positive or not reported we use 1 as a sign
    effect_direction = pl.when(beta < 0).then(pl.lit(-1)).otherwise(pl.lit(1))
    z_score = effect_direction * chi2_stat.sqrt()
    var_g = 2 * maf * (1 - maf)
    prev = n_cases / (n_cases + n_controls)
    se = (
        pl.when(trait_class == "quantitative")
        .then((1 / (var_g * n_samples)).sqrt())
        .otherwise((1 / (var_g * n_samples * prev * (1 - prev))).sqrt())
    )
    new_beta = z_score * se

    return pl.struct(
        new_beta.alias("estimatedBeta"),
        trait_class.alias("traitClass"),
        chi2_stat.alias("chi2Stat"),
        se.alias("estimatedSE"),
        var_g.alias("varG"),
        prev.alias("prev"),
        n_samples.alias("nSamples"),
    ).alias("rescaledStatistics")


In [14]:
df2 = df.select(
    "*",
    rescale_beta(
        pl.col("beta"),
        pl.col("nCases"),
        pl.col("nControls"),
        pl.col("nSamples"),
        pl.col("pValueMantissa"),
        pl.col("pValueExponent"),
        pl.col("majorPopulationMAF"),
    ),
)


In [15]:
df3 = df2.filter(pl.col("majorPopulationMAF") == 0.0).filter(
    pl.col("rescaledStatistics").struct.field("estimatedBeta").is_infinite()
)
df3.shape[0]


3778

We have 3778 lead variants that have major population MAF = 0.0


In [16]:
# Write the dataset
df2.write_parquet("../../data/rescaled-betas.parquet")


#### Sanity checks

In [17]:
# binary estimatedBeta distribtution
df2.select(pl.col("rescaledStatistics").struct.unnest()).filter(pl.col("traitClass") == "binary").filter(
    pl.col("estimatedBeta").is_finite()
).select(pl.col("estimatedBeta").abs()).describe()


statistic,estimatedBeta
str,f64
"""count""",69725.0
"""null_count""",0.0
"""mean""",0.907295
"""std""",6.406072
"""min""",0.010357
"""25%""",0.048874
"""50%""",0.085318
"""75%""",0.189065
"""max""",820.178296


In [18]:
# quantitative estimatedBeta distribution
df2.select(pl.col("rescaledStatistics").struct.unnest()).filter(pl.col("traitClass") == "quantitative").filter(
    pl.col("estimatedBeta").is_finite()
).select(pl.col("estimatedBeta").abs()).describe()


statistic,estimatedBeta
str,f64
"""count""",2539980.0
"""null_count""",0.0
"""mean""",1.066917
"""std""",2.991953
"""min""",0.003516
"""25%""",0.413165
"""50%""",0.693346
"""75%""",1.06016
"""max""",576.10404


In [None]:
df2.select(pl.col("rescaledStatistics").struct.unnest())


estimatedBeta,traitClass,chi2Stat,estimatedSE,varG,prev,nSamples
f64,str,f64,f64,f64,f64,i32
0.364678,"""quantitative""",31.912211,0.064555,0.413011,,581
0.372868,"""quantitative""",30.548044,0.067463,0.413011,,532
-1.364617,"""quantitative""",26.680755,0.264187,0.075014,,191
0.503056,"""quantitative""",27.462141,0.095995,0.478053,,227
-0.515738,"""quantitative""",20.263376,0.114571,0.461711,,165
…,…,…,…,…,…,…
0.724962,"""quantitative""",77.700681,0.082244,0.320695,,461
-0.088602,"""binary""",34.951044,0.014987,0.389375,0.030189,390539
0.079953,"""binary""",49.220051,0.011396,0.427172,0.115048,177039
0.056312,"""binary""",73.512517,0.006568,0.398685,0.457826,234253
