In [105]:
!gcloud auth application-default login

Your browser has been opened to visit:

    https://accounts.google.com/o/oauth2/auth?response_type=code&client_id=764086051850-6qr4p6gpi6hn506pt8ejuq83di341hur.apps.googleusercontent.com&redirect_uri=http%3A%2F%2Flocalhost%3A8085%2F&scope=openid+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.email+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fcloud-platform+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fsqlservice.login&state=V5XYMb6s7E1XMYY93rCgLAUApvyMGK&access_type=offline&code_challenge=DaC748rWYZFDN0wGZqKZk3ms2La2e9mvl8gG_YO8F9M&code_challenge_method=S256


Credentials saved to file: [/Users/yt4/.config/gcloud/application_default_credentials.json]

These credentials will be used by any library that requests Application Default Credentials (ADC).

Quota project "open-targets-genetics-dev" was added to ADC which can be used by Google client libraries for billing and quota. Note that some services may still bill the project owning the resource.


In [1]:
import os

import hail as hl
import numpy as np
import pyspark.sql.functions as f
from pyspark.sql import DataFrame

from gentropy.common.session import Session
from gentropy.dataset.study_index import StudyIndex
from gentropy.dataset.summary_statistics import SummaryStatistics
from gentropy.dataset.study_locus import StudyLocus
from gentropy.susie_finemapper import SusieFineMapperStep
#from gentropy.method.drug_enrichment_from_evid import chemblDrugEnrichment

"""Common utilities for the project."""

import os
from pathlib import Path
from gentropy.common.session import Session
import logging


def get_gcs_credentials() -> str:
    """Get the credentials for google cloud storage."""
    app_default_credentials = os.path.join(
        os.getenv("HOME", "."), ".config/gcloud/application_default_credentials.json"
    )

    service_account_credentials = os.path.join(
        os.getenv("HOME", "."), ".config/gcloud/service_account_credentials.json"
    )

    if Path(app_default_credentials).exists():
        return app_default_credentials
    else:
        raise FileNotFoundError("No GCS credentials found.")


def get_gcs_hadoop_connector_jar() -> str:
    """Get the google cloud storage hadoop connector for spark.

    This function will return the url to download the hadoop jar.
    """

    return (
        "https://storage.googleapis.com/hadoop-lib/gcs/gcs-connector-hadoop3-latest.jar"
    )


def gcs_conf(
    credentials_path=None, project="open-targets-genetics-dev"
) -> dict[str, str]:
    """Get the spark configuration with hadoop connector for google cloud storage."""
    credentials_path = credentials_path or get_gcs_credentials()
    return {
        "spark.driver.memory": "12g",
        "spark.kryoserializer.buffer.max": "500m",
        "spark.driver.maxResultSize":"2g",
        "spark.hadoop.fs.gs.impl": "com.google.cloud.hadoop.fs.gcs.GoogleHadoopFileSystem",
        "spark.jars": get_gcs_hadoop_connector_jar(),
        "spark.hadoop.google.cloud.auth.service.account.enable": "true",
        "spark.hadoop.fs.gs.project.id": project,
        "spark.hadoop.google.cloud.auth.service.account.json.keyfile": credentials_path,
        "spark.hadoop.fs.gs.requester.pays.mode": "AUTO",
    }


class GentropySession(Session):
    def __init__(self, *args, **kwargs):
        if "extended_spark_conf" in kwargs:
            kwargs["extended_spark_conf"].update(gcs_conf())
        else:
            kwargs["extended_spark_conf"] = gcs_conf()
        super().__init__(*args, **kwargs)

    @property
    def conf(self):
        logging.warning(
            "To change the config restart the session and use the `extended_spark_conf` parameter."
        )
        return self.spark.sparkContext.getConf().getAll()

session= GentropySession()


Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/06/25 14:49:43 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
25/06/25 14:49:44 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


In [2]:
path_to_release_folder="gs://open-targets-data-releases/25.06/"


si=StudyIndex.from_parquet(session,path_to_release_folder+"output/study/")
sl=StudyLocus.from_parquet(session,path_to_release_folder+"output/credible_set/")

                                                                                

# Combining with FM

In [3]:
fm=session.spark.read.parquet(path_to_release_folder+"intermediate/l2g_feature_matrix/")

In [4]:
cs=sl.df.select("studyLocusId","studyId")
result_df = fm.join(cs, on="studyLocusId", how="left")

In [5]:
si_df=si.df.select("studyId","diseaseIds").dropDuplicates(["studyId"])
result_df = result_df.join(si_df, on="studyId", how="left")

In [6]:
result_df.count()

                                                                                

31538858

In [7]:
fm.count()

                                                                                

31538858

In [8]:
#sgl=session.spark.read.parquet("gs://genetics-portal-dev-analysis/yt4/2506_release/training_set/EGL_all.parquet")
sgl=session.spark.read.parquet("gs://genetics-portal-dev-analysis/yt4/2506_release/training_set/20250625_EGL_2506_0.95_otg_chembl.parquet")
#sgl=session.spark.read.parquet("gs://genetics-portal-dev-analysis/yt4/2506_release/training_set/EGL_2506_lit_0.7_otg_chembl.parquet")

In [9]:
sgl.count()

                                                                                

42288

In [10]:
sgl = sgl.select("targetId","diseaseId").withColumnRenamed("targetId", "geneId_sgl")

In [11]:
joined_df = result_df.join(sgl, (f.array_contains(result_df["diseaseIds"], sgl["diseaseId"]))&(result_df["geneId"]==sgl["geneId_sgl"]),how="left")

In [12]:
df=joined_df.filter(f.col("diseaseId").isNotNull())
df=df.drop_duplicates(["studyLocusId","geneId","diseaseIds"])

In [13]:
df=df.cache()
df.count()

25/06/25 14:09:21 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
                                                                                

21799

In [14]:
df.drop_duplicates(["geneId","diseaseIds"]).count()

                                                                                

3090

In [15]:
study_locus_ids = df.select("studyLocusId")
study_locus_id_list = [row.studyLocusId for row in study_locus_ids.collect()]

In [16]:
len(study_locus_id_list)

21799

In [17]:
result_df.printSchema()

root
 |-- studyId: string (nullable = true)
 |-- studyLocusId: string (nullable = true)
 |-- geneId: string (nullable = true)
 |-- credibleSetConfidence: float (nullable = true)
 |-- distanceFootprintMean: float (nullable = true)
 |-- distanceFootprintMeanNeighbourhood: float (nullable = true)
 |-- distanceSentinelFootprint: float (nullable = true)
 |-- distanceSentinelFootprintNeighbourhood: float (nullable = true)
 |-- distanceSentinelTss: float (nullable = true)
 |-- distanceSentinelTssNeighbourhood: float (nullable = true)
 |-- distanceTssMean: float (nullable = true)
 |-- distanceTssMeanNeighbourhood: float (nullable = true)
 |-- eQtlColocClppMaximum: float (nullable = true)
 |-- eQtlColocClppMaximumNeighbourhood: float (nullable = true)
 |-- eQtlColocH4Maximum: float (nullable = true)
 |-- eQtlColocH4MaximumNeighbourhood: float (nullable = true)
 |-- geneCount500kb: double (nullable = true)
 |-- pQtlColocClppMaximum: float (nullable = true)
 |-- pQtlColocClppMaximumNeighbourhood:

In [18]:
fm1=result_df.filter(f.col("studyLocusId").isin(study_locus_id_list))

In [19]:
fm1.count()

25/06/25 14:11:14 WARN DAGScheduler: Broadcasting large task binary with size 1598.2 KiB
25/06/25 14:11:14 WARN DAGScheduler: Broadcasting large task binary with size 1598.1 KiB
                                                                                

892206

In [20]:
df.count()

21799

In [21]:
fm1 = fm1.join(
    df.select("studyLocusId", "geneId").withColumn("GSP", f.lit(1)),
    on=["studyLocusId", "geneId"],
    how="left"
).withColumn(
    "GSP",
    f.when(f.col("GSP").isNotNull(), 1).otherwise(0)
).cache()

# Show the result
fm1.show(1)

25/06/25 14:12:32 WARN DAGScheduler: Broadcasting large task binary with size 1617.4 KiB
25/06/25 14:12:32 WARN DAGScheduler: Broadcasting large task binary with size 1598.8 KiB
25/06/25 14:12:32 WARN DAGScheduler: Broadcasting large task binary with size 1019.7 KiB
                                                                                

+--------------------+---------------+------------+---------------------+---------------------+----------------------------------+-------------------------+--------------------------------------+-------------------+--------------------------------+---------------+----------------------------+--------------------+---------------------------------+------------------+-------------------------------+--------------+--------------------+---------------------------------+------------------+-------------------------------+---------------------+--------------------+---------------------------------+------------------+-------------------------------+----------+-----------------------+-------+--------------------+---------------+-------------+---+
|        studyLocusId|         geneId|     studyId|credibleSetConfidence|distanceFootprintMean|distanceFootprintMeanNeighbourhood|distanceSentinelFootprint|distanceSentinelFootprintNeighbourhood|distanceSentinelTss|distanceSentinelTssNeighbourhood|dista

In [22]:
fm1.count()

892206

In [23]:
fm1.filter(f.col("GSP") == 1).count()

21799

In [24]:
fm1.filter((f.col("GSP") == 1)&(f.col("isProteinCoding")==1)).count()

21778

In [25]:
fm1.write.mode("overwrite").parquet("gs://genetics-portal-dev-analysis/yt4/2506_release/training_set/not_filltered_FM_with_GSP.parquet")

                                                                                

# Making training set - step 1

In [3]:
fm=session.spark.read.parquet("gs://genetics-portal-dev-analysis/yt4/2506_release/training_set/not_filltered_FM_with_GSP.parquet")

In [4]:
fm.printSchema()

root
 |-- studyLocusId: string (nullable = true)
 |-- geneId: string (nullable = true)
 |-- studyId: string (nullable = true)
 |-- credibleSetConfidence: float (nullable = true)
 |-- distanceFootprintMean: float (nullable = true)
 |-- distanceFootprintMeanNeighbourhood: float (nullable = true)
 |-- distanceSentinelFootprint: float (nullable = true)
 |-- distanceSentinelFootprintNeighbourhood: float (nullable = true)
 |-- distanceSentinelTss: float (nullable = true)
 |-- distanceSentinelTssNeighbourhood: float (nullable = true)
 |-- distanceTssMean: float (nullable = true)
 |-- distanceTssMeanNeighbourhood: float (nullable = true)
 |-- eQtlColocClppMaximum: float (nullable = true)
 |-- eQtlColocClppMaximumNeighbourhood: float (nullable = true)
 |-- eQtlColocH4Maximum: float (nullable = true)
 |-- eQtlColocH4MaximumNeighbourhood: float (nullable = true)
 |-- geneCount500kb: double (nullable = true)
 |-- pQtlColocClppMaximum: float (nullable = true)
 |-- pQtlColocClppMaximumNeighbourhood:

In [5]:
fm.count()

                                                                                

892206

In [6]:
fm.filter(f.col("GSP") == 1).count()

                                                                                

21799

# Patch FM - SKIP

In [31]:
fm.count()

                                                                                

991408

In [32]:
from gentropy.dataset.target_index import TargetIndex
target_index_path=path_to_release_folder+"output/target/"
target_index = TargetIndex.from_parquet(
                session, target_index_path, recursiveFileLookup=True
            )

                                                                                

In [33]:
target_index_df = target_index.df.select("id", "biotype").withColumnRenamed(
    "id", "geneId"
)

target_index_df = target_index_df.withColumn(
    "isProteinCoding",
    f.when(f.col("biotype") == "protein_coding", 1).otherwise(0),
).drop("biotype")

fm= fm.drop("isProteinCoding").join(
    target_index_df, on="geneId", how="inner"
).cache()
fm.count()

                                                                                

991408

# Only qulified studies and measurments - SKIP

In [6]:
si.df.filter(f.col("studyType")=="gwas").count()

                                                                                

96404

In [7]:
qsi=session.spark.read.parquet("gs://genetics-portal-dev-analysis/yt4/20250403_for_gentropy_paper/qualified_studies_with_oncology")
qsi_measurements=session.spark.read.parquet("gs://genetics-portal-dev-analysis/yt4/20250403_for_gentropy_paper/qualified_studies_measurements")

In [8]:
qsi.count()

                                                                                

7300

In [9]:
qsi_measurements.count()

60451

In [10]:
qsi=qsi.union(qsi_measurements).distinct().cache()
qsi.count()

                                                                                

67751

In [11]:
fm=fm.join(qsi, on="studyId", how="inner").cache()
fm.count()

25/05/12 22:56:34 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
                                                                                

525524

In [12]:
fm.filter(f.col("GSP") == 1).count()

12520

# Replication criteria

In [7]:
repl_cs=session.spark.read.parquet("gs://genetics-portal-dev-analysis/yt4/20250403_for_gentropy_paper/list_of_gwas_replicated_CSs.parquet")

In [8]:
repl_cs.count()

263705

In [9]:
fm=fm.join(repl_cs, on="studyLocusId", how="inner").cache()
fm.count()

25/06/25 14:50:09 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
                                                                                

524141

In [10]:
fm.filter(f.col("GSP") == 1).count()

12883

# Max two GSP per CS

In [11]:
grouped = fm.filter(f.col("GSP")==1).groupBy("studyLocusId").agg(f.count("*").alias("count"))
filtered = grouped.filter(f.col("count")>2).select("studyLocusId").cache()
filtered.count()

                                                                                

142

In [12]:
grouped = fm.filter(f.col("GSP")==1).groupBy("studyLocusId").agg(f.count("*").alias("count"))
filtered = grouped.filter(f.col("count")>0).filter(f.col("count")<=2).select("studyLocusId").cache()
filtered.count()

                                                                                

12011

In [13]:
filtered.show(1)

+--------------------+
|        studyLocusId|
+--------------------+
|235e8ce166619f33e...|
+--------------------+
only showing top 1 row



In [14]:
fm=fm.join(filtered, on="studyLocusId", how="inner").cache()
fm.count()

                                                                                

517925

In [15]:
fm.filter((f.col("GSP")==1)).count()

12414

# String filter

In [16]:
inter=session.spark.read.parquet(path_to_release_folder+"/output/interaction/")
inter.show(1)

[Stage 50:>                                                         (0 + 1) / 1]

+--------------+---------------+---------------+------------------+---------------+---------------+------------------+--------------------+--------------------+-----+-------+
|sourceDatabase|        targetA|           intA|intABiologicalRole|        targetB|           intB|intBBiologicalRole|            speciesA|            speciesB|count|scoring|
+--------------+---------------+---------------+------------------+---------------+---------------+------------------+--------------------+--------------------+-----+-------+
|        string|ENSG00000004059|ENSP00000000233|  unspecified role|ENSG00000162222|ENSP00000325266|  unspecified role|{human, Homo sapi...|{human, Homo sapi...|    2|   0.18|
+--------------+---------------+---------------+------------------+---------------+---------------+------------------+--------------------+--------------------+-----+-------+
only showing top 1 row



                                                                                

In [17]:
inter.groupBy("sourceDatabase").agg(f.count("*").alias("count")).show()



+--------------+--------+
|sourceDatabase|   count|
+--------------+--------+
|        string|13217122|
|        intact| 1215084|
|        signor|   36648|
|      reactome|   55172|
+--------------+--------+



                                                                                

In [18]:
inter=inter.filter((f.col("sourceDatabase")=="string")&(f.col("scoring")>=0.8)&(~(f.col("targetA")==f.col("targetB")))).select("targetA","targetB").distinct().cache()
#inter=inter.filter(~(f.col("targetA")==f.col("targetB"))).select("targetA","targetB").distinct().cache()
inter.count()

                                                                                

314292

In [19]:
inter.show(1)

+---------------+---------------+
|        targetA|        targetB|
+---------------+---------------+
|ENSG00000067646|ENSG00000092377|
+---------------+---------------+
only showing top 1 row



In [20]:
fm1=fm.filter(f.col("GSP")==1).select("geneId","studyLocusId").cache()
fm1.count()

12414

In [21]:
fm1.select("geneId").distinct().count()

394

In [22]:
result = fm1.join(
    inter,
    fm1["geneId"] == inter["targetA"],
    how="left"
).select(
    "geneId",
    "studyLocusId",
    "targetB"
).cache()

result.count()

                                                                                

423729

In [23]:
result.filter(f.col("targetB").isNull()).count()

120

In [24]:
result.filter(f.col("targetB")==f.col("geneId")).count()

0

In [25]:
result.select(f.col("geneId")).distinct().count()

394

In [26]:
result=result.filter(f.col("targetB").isNotNull()).select("targetB","studyLocusId").withColumnRenamed("targetB","geneId").cache()
result.count()

423609

In [27]:
fm_s=fm.filter(f.col("GSP")==0).select("geneId","studyLocusId").cache()
fm_s.count()

505511

In [28]:
fm_s=fm_s.union(result.select("geneId","studyLocusId")).cache()
fm_s.count()

                                                                                

929120

In [29]:
grouped = fm_s.groupBy("geneId", "studyLocusId").agg(f.count("*").alias("count"))
filtered = grouped.filter(f.col("count") >= 2).select("geneId", "studyLocusId").cache()
filtered.count()

                                                                                

8657

In [30]:
fm.filter(f.col("GSP")==1).count()

12414

In [31]:
fm.count()

517925

In [32]:
fm_no_inter=fm.join(filtered, on=["geneId", "studyLocusId"], how="anti").cache()
fm_no_inter.count()

                                                                                

513018

In [33]:
fm_no_inter.filter(f.col("GSP")==1).count()

12414

In [34]:
fm_no_inter.filter(f.col("GSP")==1).select(f.col("geneId")).distinct().count()

394

In [35]:
fm_no_inter.filter(f.col("GSP")==1).select("geneId","diseaseIds").distinct().count()

1400

In [36]:
fm_no_inter.write.mode("overwrite").parquet("gs://genetics-portal-dev-analysis/yt4/2506_release/training_set/preset_no_interaction.parquet")

                                                                                

# Distance criteria

In [37]:
fm_no_inter=session.spark.read.parquet("gs://genetics-portal-dev-analysis/yt4/2506_release/training_set/preset_no_interaction.parquet")

In [38]:
fm_no_inter.printSchema()

root
 |-- geneId: string (nullable = true)
 |-- studyLocusId: string (nullable = true)
 |-- studyId: string (nullable = true)
 |-- credibleSetConfidence: float (nullable = true)
 |-- distanceFootprintMean: float (nullable = true)
 |-- distanceFootprintMeanNeighbourhood: float (nullable = true)
 |-- distanceSentinelFootprint: float (nullable = true)
 |-- distanceSentinelFootprintNeighbourhood: float (nullable = true)
 |-- distanceSentinelTss: float (nullable = true)
 |-- distanceSentinelTssNeighbourhood: float (nullable = true)
 |-- distanceTssMean: float (nullable = true)
 |-- distanceTssMeanNeighbourhood: float (nullable = true)
 |-- eQtlColocClppMaximum: float (nullable = true)
 |-- eQtlColocClppMaximumNeighbourhood: float (nullable = true)
 |-- eQtlColocH4Maximum: float (nullable = true)
 |-- eQtlColocH4MaximumNeighbourhood: float (nullable = true)
 |-- geneCount500kb: double (nullable = true)
 |-- pQtlColocClppMaximum: float (nullable = true)
 |-- pQtlColocClppMaximumNeighbourhood:

In [39]:
fm_no_inter.filter(((f.col("GSP")==1)&(f.col("distanceSentinelFootprint")==0))).count()

                                                                                

222

In [40]:
fm_0=fm_no_inter.filter(~((f.col("GSP")==1)&(f.col("distanceSentinelFootprint")==0))).cache()
fm_0.count()

                                                                                

512796

In [41]:
fm_0.filter(f.col("GSP")==1).select("geneId","diseaseIds").distinct().count()

1377

In [42]:
fm_0.filter(f.col("GSP")==1).count()

12192

In [43]:
fm_0.filter(f.col("GSP")==0).count()

500604

In [44]:
fm_0.filter((f.col("GSP")==1)&(f.col("isProteinCoding")==1)).count()

12192

In [45]:
fm_0.filter((f.col("GSP")==1)&(f.col("isProteinCoding")==0)).count()

0

In [46]:
fm_0.write.mode("overwrite").parquet("gs://genetics-portal-dev-analysis/yt4/2506_release/training_set/preset_distance_filtered_all_genes.parquet")

                                                                                

# Protein-coding only genes

In [47]:
fm=session.spark.read.parquet("gs://genetics-portal-dev-analysis/yt4/2506_release/training_set/preset_distance_filtered_all_genes.parquet")
#fm=session.spark.read.parquet("gs://genetics-portal-dev-analysis/yt4/2506_release/training_set/preset_no_interaction.parquet")

In [48]:
fm_0=fm.filter(f.col("isProteinCoding")==1).cache()
fm_0.count()

                                                                                

189240

In [49]:
grouped = fm_0.filter(f.col("GSP")==1).groupBy("studyLocusId").agg(f.count("*").alias("count"))
filtered = grouped.filter(f.col("count")>0).filter(f.col("count")<=2).select("studyLocusId").cache()
filtered.count()

11834

In [50]:
filtered.show(1)

+--------------------+
|        studyLocusId|
+--------------------+
|235e8ce166619f33e...|
+--------------------+
only showing top 1 row



In [51]:
fm_0=fm_0.join(filtered, on="studyLocusId", how="inner").cache()
fm_0.count()

184791

In [52]:
fm_0.write.mode("overwrite").parquet("gs://genetics-portal-dev-analysis/yt4/2506_release/training_set/prefinal_training_set_protein_coding_only_genes.parquet")

                                                                                

# Patching - removing duplicates

In [53]:
fm=session.spark.read.parquet("gs://genetics-portal-dev-analysis/yt4/2506_release/training_set/prefinal_training_set_protein_coding_only_genes.parquet")
fm.count()

184791

In [54]:
fm=fm.join(sl.df.select("studyLocusId","variantId"), on="studyLocusId", how="left").cache()
fm.count()

                                                                                

184791

In [55]:
n1 = fm.count()
tmp1 = fm.filter(f.col("GSP") == 1)
colnms = ["geneId", "diseaseIds", "variantId", "vepMaximum", "vepMean"]
colnms_to_round = [
    "eQtlColocClppMaximum", "pQtlColocClppMaximum", "sQtlColocClppMaximum",
    "eQtlColocH4Maximum", "pQtlColocH4Maximum", "sQtlColocH4Maximum"
]

                                                                                

In [56]:
# Step 4: Round specified columns to 2 decimal places
for col in colnms_to_round:
    tmp1 = tmp1.withColumn(col, f.round(f.col(col), 2))

In [57]:
# Step 5: Remove duplicates based on specified columns
tmp1 = tmp1.dropDuplicates(colnms + colnms_to_round)

# Step 6: Get unique studyLocusId values to keep
cs_to_keep = tmp1.select("studyLocusId").distinct()

# Step 7: Filter the original DataFrame to keep only rows with studyLocusId in cs_to_keep
fm_0 = fm.join(cs_to_keep, on="studyLocusId", how="inner")

In [58]:
# Step 8: Get the new row count
n2 = fm_0.count()

# Step 9: Print the reduction percentage
reduction_percentage = round(((n1 - n2) / n1) * 100, 2)
print(f"Reduced by: {reduction_percentage}%")

                                                                                

Reduced by: 28.04%


In [59]:
fm_0.count()

                                                                                

132970

In [60]:
fm_0.filter(f.col("GSP")==1).count()

                                                                                

8520

In [61]:
fm_0.filter(f.col("GSP")==1).select("geneId","diseaseIds").distinct().count()

                                                                                

1377

In [62]:
fm_0.filter(f.col("GSP")==1).select("geneId").distinct().count()

                                                                                

390

In [63]:
fm_0.filter(f.col("GSP")==1).count()

                                                                                

8520

In [64]:
fm_0.filter(f.col("GSP")==0).count()

                                                                                

124450

In [65]:
fm_0.write.parquet("gs://genetics-portal-dev-analysis/yt4/2506_release/training_set/250624_training_set_full_fm.parquet",mode="overwrite")

                                                                                

# Convert to JSON

In [66]:
fm_0=session.spark.read.parquet("gs://genetics-portal-dev-analysis/yt4/2506_release/training_set/250624_training_set_full_fm.parquet")

In [67]:
x=fm_0.select("studyLocusId","geneId","GSP","diseaseIds","variantId","studyId")

In [68]:
x.count()

                                                                                

132970

In [69]:
x.show(1)

+--------------------+---------------+---+--------------------+---------------+------------+
|        studyLocusId|         geneId|GSP|          diseaseIds|      variantId|     studyId|
+--------------------+---------------+---+--------------------+---------------+------------+
|08ef835a25f0bf2c8...|ENSG00000130173|  0|[EFO_0004611, EFO...|19_11079858_G_A|GCST90091598|
+--------------------+---------------+---+--------------------+---------------+------------+
only showing top 1 row



In [70]:
result_df = x.withColumn(
    "goldStandardSet",
    f.when(f.col("GSP") == 1, "positive").otherwise("negative")
).drop("GSP").cache()
result_df.count()

                                                                                

132970

In [71]:
result_df.show(3)

+--------------------+---------------+--------------------+---------------+------------+---------------+
|        studyLocusId|         geneId|          diseaseIds|      variantId|     studyId|goldStandardSet|
+--------------------+---------------+--------------------+---------------+------------+---------------+
|08ef835a25f0bf2c8...|ENSG00000130173|[EFO_0004611, EFO...|19_11079858_G_A|GCST90091598|       negative|
|08ef835a25f0bf2c8...|ENSG00000127616|[EFO_0004611, EFO...|19_11079858_G_A|GCST90091598|       negative|
|08ef835a25f0bf2c8...|ENSG00000130159|[EFO_0004611, EFO...|19_11079858_G_A|GCST90091598|       negative|
+--------------------+---------------+--------------------+---------------+------------+---------------+
only showing top 3 rows



In [73]:
result_df.groupBy("goldStandardSet").count().show()

+---------------+------+
|goldStandardSet| count|
+---------------+------+
|       positive|  8520|
|       negative|124450|
+---------------+------+



In [74]:
result_df.write.json("gs://genetics-portal-dev-analysis/yt4/2506_release/training_set/20250625_gentropy_paper_v1.json",mode="overwrite")

                                                                                

In [75]:
result_df.write.parquet("../gentropy_paper/data/20250625_gentropy_paper_v1",mode="overwrite")

# Deduplicating futher - SKIP

In [71]:
train_set=session.spark.read.json("gs://genetics-portal-dev-analysis/yt4/2506_release/training_set/250616_traing_set.json")

                                                                                

In [72]:
from pyspark.sql import Window
tmp1=train_set.filter(f.col("goldStandardSet")=="positive")
tmp1=tmp1.select("studyLocusId", "geneId", "diseaseIds", "variantId")
tmp1.count()

                                                                                

8732

In [73]:
window_spec = Window.partitionBy("geneId", "diseaseIds", "variantId").orderBy(f.lit(1))
tmp1 = tmp1.withColumn("row_number", f.row_number().over(window_spec))

tmp1 = tmp1.filter(f.col("row_number") == 1).drop("row_number")

tmp1=tmp1.select("studyLocusId").distinct().cache()
tmp1.count()

                                                                                

4661

In [74]:
train_set=train_set.join(tmp1, on="studyLocusId", how="inner").cache()
train_set.count()

                                                                                

55623

In [75]:
train_set.groupBy("goldStandardSet").count().show()

+---------------+-----+
|goldStandardSet|count|
+---------------+-----+
|       positive| 4661|
|       negative|50962|
+---------------+-----+



In [76]:
train_set.write.json("gs://genetics-portal-dev-analysis/yt4/2506_release/training_set/250616_traing_set_dedupl.json",mode="overwrite")

                                                                                

25/06/16 17:46:48 WARN HeartbeatReceiver: Removing executor driver with no recent heartbeats: 878176 ms exceeds timeout 120000 ms
25/06/16 17:46:48 WARN SparkContext: Killing executors is not supported by current scheduler.
25/06/16 17:46:57 ERROR Inbox: Ignoring error
org.apache.spark.SparkException: Exception thrown in awaitResult: 
	at org.apache.spark.util.SparkThreadUtils$.awaitResult(SparkThreadUtils.scala:56)
	at org.apache.spark.util.ThreadUtils$.awaitResult(ThreadUtils.scala:310)
	at org.apache.spark.rpc.RpcTimeout.awaitResult(RpcTimeout.scala:75)
	at org.apache.spark.rpc.RpcEnv.setupEndpointRefByURI(RpcEnv.scala:102)
	at org.apache.spark.rpc.RpcEnv.setupEndpointRef(RpcEnv.scala:110)
	at org.apache.spark.util.RpcUtils$.makeDriverRef(RpcUtils.scala:36)
	at org.apache.spark.storage.BlockManagerMasterEndpoint.driverEndpoint$lzycompute(BlockManagerMasterEndpoint.scala:124)
	at org.apache.spark.storage.BlockManagerMasterEndpoint.org$apache$spark$storage$BlockManagerMasterEndpoint$$

# Remove literature form traing set - SKIP

In [74]:
train_set=session.spark.read.json("gs://genetics-portal-dev-analysis/yt4/2506_release/training_set/2506_traing_set.json")

                                                                                

In [75]:
new_egl=session.spark.read.parquet("gs://genetics-portal-dev-analysis/yt4/2506_release/training_set/EGL_all_no_literature_0_95.parquet")

In [76]:
new_egl.show(1)

+-----------+---------------+
|  diseaseId|       targetId|
+-----------+---------------+
|EFO_0000289|ENSG00000115159|
+-----------+---------------+
only showing top 1 row



In [77]:
train_set.show(1)

+-------------+---------------+---------------+------------+--------------------+---------------+
|   diseaseIds|         geneId|goldStandardSet|     studyId|        studyLocusId|      variantId|
+-------------+---------------+---------------+------------+--------------------+---------------+
|[EFO_0009767]|ENSG00000021826|       positive|GCST90128889|028243d6514e41ef0...|2_210743655_T_G|
+-------------+---------------+---------------+------------+--------------------+---------------+
only showing top 1 row



In [78]:
train_set_p=train_set.filter(f.col("goldStandardSet")=="positive").cache()
train_set_p.count()

                                                                                

7448

In [79]:
expanded_df = train_set_p.withColumn("diseaseId", f.explode(f.col("diseaseIds"))).cache()
expanded_df.count()

11110

In [80]:
expanded_df.show(1)

+-------------+---------------+---------------+------------+--------------------+---------------+-----------+
|   diseaseIds|         geneId|goldStandardSet|     studyId|        studyLocusId|      variantId|  diseaseId|
+-------------+---------------+---------------+------------+--------------------+---------------+-----------+
|[EFO_0009767]|ENSG00000021826|       positive|GCST90128889|028243d6514e41ef0...|2_210743655_T_G|EFO_0009767|
+-------------+---------------+---------------+------------+--------------------+---------------+-----------+
only showing top 1 row



In [81]:
filtered_df = expanded_df.join(
    new_egl.withColumnRenamed("targetId", "geneId"),
    on=["diseaseId", "geneId"],
    how="inner"
)
filtered_df.count()

                                                                                

6875

In [82]:
slids=filtered_df.select("studyLocusId").distinct().cache()
slids.count()

                                                                                

6809

In [83]:
train_set.select("studyLocusId").distinct().count()

                                                                                

7448

In [84]:
train_set=train_set.join(slids, on="studyLocusId", how="inner").cache()
train_set.count()

                                                                                

75460

In [85]:
train_set.filter(f.col("goldStandardSet")=="positive").select("geneId","diseaseIds").distinct().count()

1121

In [86]:
train_set.write.json("gs://genetics-portal-dev-analysis/yt4/2506_release/training_set/2506_traing_set_no_literature_0_95.json",mode="overwrite")

                                                                                

In [8]:
sorted_df = train_set.withColumn("sorted_array", f.array_sort(f.col("diseaseIds"))).withColumn("first_element", f.element_at(f.col("diseaseIds"), 1))


In [5]:
train_set.filter(f.col("goldStandardSet")=="positive").select("geneId","diseaseIds").distinct().count()

                                                                                

1408

In [9]:
sorted_df.filter(f.col("goldStandardSet")=="positive").select("geneId","sorted_array").distinct().count()

                                                                                

1408

In [10]:
sorted_df.filter(f.col("goldStandardSet")=="positive").select("geneId","first_element").distinct().count()

                                                                                

824

In [12]:
sorted_df.filter(f.col("goldStandardSet")=="positive").select("geneId","sorted_array").distinct().show(5,truncate=False)



+---------------+--------------------------+
|geneId         |sorted_array              |
+---------------+--------------------------+
|ENSG00000066336|[MONDO_0004975]           |
|ENSG00000130164|[EFO_0004530, EFO_0004611]|
|ENSG00000101670|[EFO_0004612, EFO_0008591]|
|ENSG00000084674|[EFO_0004611, EFO_0020946]|
|ENSG00000165029|[EFO_0004612, EFO_0008589]|
+---------------+--------------------------+
only showing top 5 rows



                                                                                

In [13]:
sorted_df.filter(f.col("goldStandardSet")=="positive").show(10,truncate=False)

+--------------------------+---------------+---------------+--------------------------------------+--------------------------------+----------------+--------------------------+-------------+
|diseaseIds                |geneId         |goldStandardSet|studyId                               |studyLocusId                    |variantId       |sorted_array              |first_element|
+--------------------------+---------------+---------------+--------------------------------------+--------------------------------+----------------+--------------------------+-------------+
|[EFO_0009767]             |ENSG00000021826|positive       |GCST90128889                          |028243d6514e41ef0c9e571a10dcadee|2_210743655_T_G |[EFO_0009767]             |EFO_0009767  |
|[EFO_0004612, EFO_0010351]|ENSG00000101670|positive       |GCST90092946                          |0b022fe83fc2b9304f9835d401aa777b|18_49571205_T_C |[EFO_0004612, EFO_0010351]|EFO_0004612  |
|[EFO_0004611]             |ENSG00000169174|p

In [4]:
training_set=session.spark.read.json("gs://genetics-portal-dev-analysis/yt4/20241024_EGL_playground/training_set/patched_training_2503-testrun-1_all_string_extended_EGL_variants.json")

                                                                                

In [5]:
training_set.groupBy("goldStandardSet").count().show()



+---------------+-----+
|goldStandardSet|count|
+---------------+-----+
|       positive| 6616|
|       negative|98632|
+---------------+-----+



                                                                                

14.908101571946796

In [7]:
training_set.write.parquet("../gentropy_paper/data/2503_training_set",mode="overwrite")

                                                                                