feat: graph based clumping
DSuveges committed Feb 10, 2023
1 parent 19ee061 commit 31cab8d
Showing 6 changed files with 217 additions and 100 deletions.
31 changes: 1 addition & 30 deletions poetry.lock

Generated file; diff not rendered.

1 change: 0 additions & 1 deletion pyproject.toml
@@ -17,7 +17,6 @@ numpy = "^1.23.2"
gcsfs = "^2022.8.2"
pytest = "^7.1.3"
hail = "^0.2.98"
graphframes = "^0.6"

[tool.poetry.dev-dependencies]
pre-commit = "^2.15.0"
126 changes: 125 additions & 1 deletion src/etl/gwas_ingest/clumping.py
@@ -5,14 +5,105 @@
from typing import TYPE_CHECKING

import pyspark.sql.functions as f
from graphframes import GraphFrame
from graphframes.lib import Pregel
from pyspark.sql import Window

from etl.common.spark_helpers import adding_quality_flag
from etl.gwas_ingest.pics import _neglog_p

if TYPE_CHECKING:
from pyspark.sql import DataFrame


def resolve_graph(df: DataFrame) -> DataFrame:
"""Graph resolver for clumping.
It takes a dataframe with a list of variants and their explained variants, and returns a dataframe
with a list of variants and their resolved roots
Args:
df (DataFrame): DataFrame
Returns:
A dataframe with the resolved roots.
"""
# Convert to vertices:
nodes = df.select(
"studyId",
"variantId",
# Generating node identifier column (has to be unique):
f.concat_ws("_", f.col("studyId"), f.col("variantId")).alias("id"),
# Generating the original root list. This is the original message which is propagated across nodes:
f.when(f.col("variantId") == f.col("explained"), f.col("variantId")).alias(
"origin_root"
),
).distinct()

# Convert to edges (more significant points to less significant):
edges = (
df.filter(f.col("variantId") != f.col("explained"))
.select(
f.concat_ws("_", f.col("studyId"), f.col("variantId")).alias("dst"),
f.concat_ws("_", f.col("studyId"), f.col("explained")).alias("src"),
f.lit("explains").alias("edgeType"),
)
.distinct()
)

# Building graph:
graph = GraphFrame(nodes, edges)

# Extracting nodes with at least one edge - most nodes are unconnected, so dropping them keeps the graph small:
filtered_nodes = (
graph.outDegrees.join(graph.inDegrees, on="id", how="outer")
.drop("outDegree", "inDegree")
.join(nodes, on="id", how="inner")
.repartition("studyId", "variantId")
)

# Re-building the graph using only the connected nodes:
graph = GraphFrame(filtered_nodes, edges)

# Pregel resolver:
resolved_nodes = (
graph.pregel.setMaxIter(5)
# Vertex column holding the message (the root identifier) being propagated:
.withVertexColumn(
"message",
f.when(f.col("origin_root").isNotNull(), f.col("origin_root")),
f.when(Pregel.msg().isNotNull(), Pregel.msg()),
)
.withVertexColumn(
"resolved_roots",
# The value is initialized by the original root value:
f.when(
f.col("origin_root").isNotNull(), f.array(f.col("origin_root"))
).otherwise(f.array()),
# When a new message arrives at the node, it is merged with the existing list:
f.when(
Pregel.msg().isNotNull(),
f.array_union(f.split(Pregel.msg(), " "), f.col("resolved_roots")),
).otherwise(f.col("resolved_roots")),
)
# The root identifier is propagated along the edges, from the explaining lead to the lead it explains:
.sendMsgToDst(Pregel.src("message"))
# Messages arriving at a node are collapsed into a single, space-separated string:
.aggMsgs(f.concat_ws(" ", f.collect_set(Pregel.msg())))
.run()
.orderBy("studyId", "id")
.persist()
)

# Joining back the dataset:
return df.join(
# The `resolved_roots` column will be null for nodes with no connections.
resolved_nodes.select("resolved_roots", "studyId", "variantId"),
on=["studyId", "variantId"],
how="left",
)
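
# A minimal usage sketch of the resolver (illustrative only): assumes an active
# SparkSession (`spark`), the graphframes package available, and a Spark checkpoint
# directory already set, since Pregel checkpoints between iterations.
toy = spark.createDataFrame(
    [
        ("GCST0001", "v1", "v1"),  # v1 explains itself -> it is an independent root
        ("GCST0001", "v2", "v1"),  # v2 is explained by v1
        ("GCST0001", "v3", "v2"),  # v3 is explained by v2, and transitively by v1
    ],
    ["studyId", "variantId", "explained"],
)
resolve_graph(toy).show()
# The first Pregel iteration resolves v2 to ["v1"]; the second pushes the root one
# step further down the chain, so v3 ends up with ["v1"] as well.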


def clumping(df: DataFrame) -> DataFrame:
"""Clump non-independent credible sets.
@@ -25,6 +116,7 @@ def clumping(df: DataFrame) -> DataFrame:
- Removing overall R from non-independent leads.
- Adding QC flags to non-independent leads pointing to the relevant lead.
"""
# Window to rank the leads of each variant pair by significance within a study:
w = Window.partitionBy("studyId", "variantPair").orderBy(f.col("negLogPVal").desc())

# This dataframe contains all the resolved and independent leads; however, not all linked signals are properly assigned to a more significant lead yet:
@@ -69,7 +161,39 @@
)
.withColumn("explained", f.explode("all_explained"))
.drop("reference_lead", "all_explained")
.persist()
.transform(resolve_graph)
# Generate QC notes for explained associations:
.withColumn(
"qualityControl",
adding_quality_flag(
f.col("qualityControl"),
~f.col("keep_lead"),
f.concat_ws(
" ",
f.lit("Association explained by:"),
f.concat_ws(", ", f.col("resolved_roots")),
),
),
)
# Remove tag information if the lead is explained by another variant:
.withColumn(
"tagVariantId",
f.when(f.col("explained") == f.col("variantId"), f.col("tagVariantId")),
)
.withColumn(
"R_overall", f.when(f.col("tagVariantId").isNotNull(), f.col("R_overall"))
)
# Drop unused columns:
.drop(
"variantPair",
"explained",
"negLogPVal",
"rank",
"keep_lead",
"resolved_roots",
)
.distinct()
.orderBy("studyId", "variantId")
)

# Test
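The QC annotation above relies on the `adding_quality_flag` helper from `etl.common.spark_helpers`, whose implementation is not part of this diff. A hypothetical minimal sketch of the behaviour the clumping code appears to rely on, assuming `qualityControl` is an array of strings: append the flag message whenever the condition holds.

from pyspark.sql import Column
import pyspark.sql.functions as f


def adding_quality_flag_sketch(qc: Column, condition: Column, flag_text: Column) -> Column:
    """Hypothetical stand-in: append `flag_text` to the QC array when `condition` is true."""
    return f.when(condition, f.array_union(qc, f.array(flag_text))).otherwise(qc)

Whatever the real helper does internally, the effect in `clumping` is that non-independent leads (where `keep_lead` is false) carry a flag naming the lead(s) that explain them.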
26 changes: 15 additions & 11 deletions src/etl/gwas_ingest/pics.py
@@ -13,6 +13,7 @@

from etl.common.spark_helpers import _neglog_p, adding_quality_flag
from etl.gwas_ingest.clumping import clumping
from etl.gwas_ingest.pics import ld_annotation_by_locus_ancestry
from etl.json import data

if TYPE_CHECKING:
@@ -407,6 +408,9 @@ def pics_all_study_locus(
Returns:
DataFrame: study-locus associations with PICS fine-mapping results
"""
# GraphFrames needs this... terrible.
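# (Pregel checkpoints its intermediate results every few iterations to truncate the lineage, so a checkpoint directory must be configured before the clumping graph is resolved.)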
etl.spark.sparkContext.setCheckpointDir("<pwd_output>")

# Extracting ancestry information from study table, then map to gnomad population:
gnomad_mapped_studies = _get_study_gnomad_ancestries(
etl, studies.withColumnRenamed("id", "studyId")
@@ -433,17 +437,10 @@
.distinct()
)

# Number of distinct variants/population pairs to map:
etl.logger.info(f"Number of variant/ancestry pairs: {variant_population.count()}")
etl.logger.info(
f'Number of unique variants: {variant_population.select("variantId").distinct().count()}'
)

# LD information for all loci and ancestries
ld_r = etl.spark.read.parquet("gs://ot-team/dsuveges/ld_expanded_dataset")
# ld_r = ld_annotation_by_locus_ancestry(
# etl, variant_population, ld_populations, min_r2
# ).persist()
ld_r = ld_annotation_by_locus_ancestry(
etl, variant_population, ld_populations, min_r2
).persist()

# Joining association with linked variants (while keeping unresolved associations).
association_ancestry_ld = (
@@ -490,9 +487,16 @@
# Collapse the data by study, lead, tag
.drop("relativeSampleSize", "r", "gnomadPopulation").distinct()
# Clumping non-independent associations together:
.transform(clumping)
# .transform(clumping)
.persist()
)

# The persisted dataset is then clumped:
associations_ld_allancestries = (
associations_ld_allancestries
# Clumping non-independent associations together:
.transform(clumping)
)
pics_results = calculate_pics(associations_ld_allancestries, k)

return pics_results
