# Rare variant analysis

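1. Calculate the number of variants that have MAF > 0.01, broken down into MAF classes (common: MAF > 0.05; low frequency: 0.01 ≤ MAF ≤ 0.05; rare: MAF < 0.01), in:
   1. GWAS credible sets
   2. all credible sets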

In [8]:
import polars as pl
from collections.abc import Callable


In [None]:
# Input table written by the 01 and 02 notebooks; loaded once and reused by every cell below.
dataset = pl.read_parquet(source="rescaled-beta.parquet")


In [36]:
class VariantFrequencyClass:
    """A named allele-frequency class defined by a predicate over a MAF expression.

    Args:
        name: Alias given to the boolean expression produced by ``from_maf``.
        condition: Callable that maps a MAF expression to a boolean expression
            which is true for variants belonging to this class.
    """

    def __init__(self, name, condition: Callable[[pl.Expr], pl.Expr]) -> None:
        self.name, self.condition = name, condition

    def from_maf(self, maf: pl.Expr) -> pl.Expr:
        """Return a boolean membership expression for this class, aliased to ``self.name``.

        Rows where the condition is false, or evaluates to null, are flagged ``False``.
        """
        membership = self.condition(maf)
        # when/then/otherwise routes a null predicate to the otherwise branch,
        # so membership is always True/False, never null.
        flag = pl.when(membership).then(True).otherwise(False)
        return flag.alias(self.name)


def variant_maf_classification(
    maf: pl.Expr, variant_types: list[VariantFrequencyClass] | None = None
) -> pl.Expr:
    """Classify variants into MAF frequency classes.

    Args:
        maf: Expression evaluating to each variant's minor allele frequency.
        variant_types: Frequency classes to evaluate. When ``None`` (or empty),
            the default classes are used: ``common`` (MAF > 0.05),
            ``lowFrequency`` (0.01 <= MAF <= 0.05) and ``rare`` (MAF < 0.01).

    Returns:
        A struct expression named ``variantMAFClassification`` with one boolean
        field per class (named after the class), built by
        ``VariantFrequencyClass.from_maf`` for each class.
    """
    if not variant_types:
        variant_types = [
            VariantFrequencyClass("common", lambda m: m > 0.05),
            VariantFrequencyClass("lowFrequency", lambda m: (m <= 0.05) & (m >= 0.01)),
            VariantFrequencyClass("rare", lambda m: m < 0.01),
        ]
    # Evaluate every class against the caller-supplied expression (it must not be
    # replaced by a hard-coded column, otherwise the `maf` argument is ignored).
    distributions = [vt.from_maf(maf) for vt in variant_types]
    return pl.struct(*distributions).alias("variantMAFClassification")


### Calculate class balance in all lead variants

In [63]:
# Tally lead variants per MAF class, keeping only rows with a known, non-zero MAF.
df = (
    dataset.filter(
        pl.col("majorPopulationMAF").is_not_null(),
        pl.col("majorPopulationMAF") != 0.0,
    )
    .select(
        "variantId",
        variant_maf_classification(pl.col("majorPopulationMAF")),
    )
    # Expand the classification struct into one boolean column per class.
    .unnest("variantMAFClassification")
    # Long format: one row per (variant, class) pair; keep only actual memberships.
    .unpivot(
        index="variantId",
        on=["common", "lowFrequency", "rare"],
        variable_name="class",
        value_name="belongToClass",
    )
    .filter(pl.col("belongToClass"))
    .group_by("class")
    .agg(count=pl.len())
    .with_columns(percent=pl.col("count") / pl.col("count").sum() * 100)
    .sort("class")
)
df


class,count,percent
str,u32,f64
"""common""",2255546,86.234165
"""lowFrequency""",239447,9.154552
"""rare""",120613,4.611283


### Calculate class balance in GWAS-only lead variants

In [64]:
# Same tally as above, restricted to GWAS studies.
df = (
    dataset.filter(
        pl.col("majorPopulationMAF").is_not_null(),
        pl.col("majorPopulationMAF") != 0.0,
        pl.col("studyType") == "gwas",
    )
    .select(
        "variantId",
        variant_maf_classification(pl.col("majorPopulationMAF")),
    )
    # Expand the classification struct into one boolean column per class.
    .unnest("variantMAFClassification")
    # Long format: one row per (variant, class) pair; keep only actual memberships.
    .unpivot(
        index="variantId",
        on=["common", "lowFrequency", "rare"],
        variable_name="class",
        value_name="belongToClass",
    )
    .filter(pl.col("belongToClass"))
    .group_by("class")
    .agg(count=pl.len())
    .with_columns(percent=pl.col("count") / pl.col("count").sum() * 100)
    .sort("class")
)
df


class,count,percent
str,u32,f64
"""common""",474703,83.005417
"""lowFrequency""",57009,9.968456
"""rare""",40182,7.026127
