In [None]:
# Generate new gold standards from chembl:
from pyspark.sql import Window
from pyspark.sql.types import StructType, StructField, StringType
from pyspark.sql import SparkSession

from pyspark.sql import Window
from pyspark.sql.functions import row_number, desc, col

from pyspark.sql.functions import explode
import sys
from gentropy.common.session import Session
from pyspark.sql import functions as f

from pyspark import SparkConf
from pyspark.sql import SparkSession
# Name shown for this application in the Spark UI, and the service-account
# key file handed to the GCS connector for authentication.
app_name = "example_app"
CREDENTIALS = "/Users/xg1/.config/gcloud/service_account_credentials.json"

# Options that register the GCS Hadoop filesystem and point it at the
# service-account key above.
GCS_CONNECTOR_CONF = {
    "spark.hadoop.fs.gs.impl": "com.google.cloud.hadoop.fs.gcs.GoogleHadoopFileSystem",
    "spark.jars": "https://storage.googleapis.com/hadoop-lib/gcs/gcs-connector-hadoop3-latest.jar",
    "spark.hadoop.google.cloud.auth.service.account.enable": "true",
    "spark.hadoop.google.cloud.auth.service.account.json.keyfile": CREDENTIALS,
}

# Driver-side resource limits.
extended_spark_conf = {
    "spark.driver.memory": "12g",
    "spark.kryoserializer.buffer.max": "500m",
    "spark.driver.maxResultSize": "3g",
}

# Merge the two option sets; on a key clash the resource limits would win.
combined_conf = dict(GCS_CONNECTOR_CONF)
combined_conf.update(extended_spark_conf)

# Wrap the merged options in a SparkConf and start (or reuse) the session.
spark_config = SparkConf().setAll(list(combined_conf.items()))
session = (
    SparkSession.builder
    .config(conf=spark_config)
    .appName(app_name)
    .getOrCreate()
)

def TA_OncoOrNot(diseases):
    """Attach a single therapeutic area (Oncology / Other) to each disease.

    The ``therapeuticAreas`` array of every disease is exploded and each
    therapeutic-area id is looked up in a fixed table. Per disease, only the
    row whose therapeutic area has the lowest ``taRank`` is kept; the rank is
    generated in the order the table rows are listed below. Therapeutic areas
    missing from the table get a null rank and never survive the filter.

    Uses the module-level ``session`` to build the lookup table.

    Args:
        diseases: Spark DataFrame with ``id``, ``therapeuticAreas`` and
            ``parents`` columns.

    Returns:
        Spark DataFrame with columns ``taId``, ``diseaseId``, ``parents``,
        ``taLabel`` and ``taLabelSimple``.
    """
    # (therapeutic-area id, label, Oncology/Other flag); list order sets priority.
    ta_rows = [
        ("MONDO_0045024", "cell proliferation disorder", "Oncology"),
        ("EFO_0005741", "infectious disease", "Other"),
        ("OTAR_0000014", "pregnancy or perinatal disease", "Other"),
        ("EFO_0005932", "animal disease", "Other"),
        ("MONDO_0024458", "disease of visual system", "Other"),
        ("EFO_0000319", "cardiovascular disease", "Other"),
        ("EFO_0009605", "pancreas disease", "Other"),
        ("EFO_0010282", "gastrointestinal disease", "Other"),
        ("OTAR_0000017", "reproductive system or breast disease", "Other"),
        ("EFO_0010285", "integumentary system disease", "Other"),
        ("EFO_0001379", "endocrine system disease", "Other"),
        ("OTAR_0000010", "respiratory or thoracic disease", "Other"),
        ("EFO_0009690", "urinary system disease", "Other"),
        ("OTAR_0000006", "musculoskeletal or connective tissue disease", "Other"),
        ("MONDO_0021205", "disease of ear", "Other"),
        ("EFO_0000540", "immune system disease", "Other"),
        ("EFO_0005803", "hematologic disease", "Other"),
        ("EFO_0000618", "nervous system disease", "Other"),
        ("MONDO_0002025", "psychiatric disorder", "Other"),
        ("MONDO_0024297", "nutritional or metabolic disease", "Other"),
        ("OTAR_0000018", "genetic, familial or congenital disease", "Other"),
        ("OTAR_0000009", "injury, poisoning or other complication", "Other"),
        ("EFO_0000651", "phenotype", "Other"),
        ("EFO_0001444", "measurement", "Other"),
        ("GO_0008150", "biological process", "Other"),
    ]
    ta_schema = StructType(
        [
            StructField(field_name, StringType(), True)
            for field_name in ("taId", "taLabel", "taLabelSimple")
        ]
    )
    ta_lookup = session.createDataFrame(data=ta_rows, schema=ta_schema).withColumn(
        "taRank", f.monotonically_increasing_id()
    )

    # One row per (disease, therapeutic area).
    disease_ta = diseases.withColumn("taId", f.explode("therapeuticAreas")).select(
        f.col("id").alias("diseaseId"), "taId", "parents"
    )

    # Annotate each row with the best (lowest) rank seen for its disease.
    per_disease = Window.partitionBy("diseaseId")
    ranked = disease_ta.join(ta_lookup, on="taId", how="left").withColumn(
        "minRank", f.min("taRank").over(per_disease)
    )

    # Keep only the top-priority therapeutic area and discard helper columns.
    best_only = ranked.where(f.col("taRank") == f.col("minRank"))
    return best_only.drop("taRank", "minRank")

# Disease index read from a local copy of the release.
diseases = session.read.parquet(
    "/Users/xg1/Downloads/otg_releases/diseases.parquet"
)

# Disease id with its Oncology/Other flag, cached because later cells reuse it.
TA_diseases = (
    TA_OncoOrNot(diseases)
    .select("diseaseId", "taLabelSimple")
    .persist()
)

# Window grouping evidence rows by (disease, target) pair.
windowSpec = Window.partitionBy("diseaseFromSourceMappedId", "targetId")

In [None]:
# NOTE(review): the first cell already imports from gentropy before this path
# is appended, so this only affects imports made after this point — confirm
# it is still needed.
sys.path.append("../../gentropy/src/")
release_path = "../../otg_releases"
release_ver = "2406"

# ChEMBL evidence restricted, per (disease, target) pair, to the rows at the
# highest clinical phase reached for that pair.
# The partition directory is addressed as plain "sourceId=chembl": the former
# "sourceId\=chembl" was a shell-escaped path, and "\=" is an invalid escape
# sequence in a Python string literal.
chembl_evidence = (
    session.read.parquet(
        "/Users/xg1/Downloads/platform_2409_evidence/evidence/sourceId=chembl"
    )
    .select(
        "targetId",
        "drugId",
        "clinicalPhase",
        "diseaseFromSourceMappedId",
        "diseaseFromSource",
        "clinicalStatus",
    )
    .withColumn("maxClinicalPhase", f.max("clinicalPhase").over(windowSpec))
    .filter(f.col("clinicalPhase") == f.col("maxClinicalPhase"))
    .drop("clinicalPhase", "drugId")
    .distinct()
)

# Attach the Oncology/Other flag and keep every row not flagged as Oncology.
chembl_evidence_noOncology = chembl_evidence.join(
    TA_diseases.withColumnRenamed("diseaseId", "diseaseFromSourceMappedId"),
    on="diseaseFromSourceMappedId",
    how="inner",
).filter(
    (f.col("taLabelSimple") != "Oncology") | f.col("taLabelSimple").isNull()
)

# Cell output: number of distinct (disease, gene) pairs with a completed
# trial at phase 3 or above.
(
    chembl_evidence_noOncology
    .withColumnRenamed("diseaseFromSourceMappedId", "efo_terms")
    .withColumnRenamed("targetId", "geneId")
    .filter(f.col("maxClinicalPhase") >= 3)
    .filter(f.col("clinicalStatus") == "Completed")
    .select("efo_terms", "geneId")
    .distinct()
    .count()
)

In [None]:
# Legacy locus-to-gene gold standard for the configured release.
GS = session.read.json(
    f"{release_path}/{release_ver}/locus_to_gene_gold_standard.json"
)

# High/Medium confidence entries, reshaped to the effector-gene-list columns.
selected_df = (
    GS.filter(
        f.col("gold_standard_info.highest_confidence").isin(["High", "Medium"])
    )
    .select(
        f.col("trait_info.ontology").alias("efo_terms"),
        f.col("gold_standard_info.gene_id").alias("geneId"),
        f.col("trait_info.reported_trait_name").alias("diseaseFromSource"),
    )
    .withColumn("GS_source", f.lit("old_otg_gs"))
)

# Non-oncology ChEMBL pairs with a completed trial at phase 3 or above.
chembl_GS = (
    chembl_evidence_noOncology
    .filter(
        (f.col("maxClinicalPhase") >= 3)
        & (f.col("clinicalStatus") == "Completed")
    )
    .select(
        f.col("diseaseFromSourceMappedId").alias("efo_terms"),
        f.col("targetId").alias("geneId"),
        "diseaseFromSource",
    )
    .distinct()
    .withColumn("GS_source", f.lit("chembl_p3_p4_2409"))
)

# The legacy gold standard stores several EFO ids per row; emit one row each
# so its columns line up with the ChEMBL set before the union.
expanded_df = selected_df.select(
    f.explode("efo_terms").alias("efo_terms"),
    "geneId",
    "diseaseFromSource",
    "GS_source",
)
new_egl = expanded_df.unionByName(chembl_GS)

new_egl.write.mode("overwrite").parquet(
    "/Users/xg1/Downloads/Effector_gene_list.parquet"
)

In [None]:
# Checkpoint for selecting gold-standard positives from the effector gene list.
new_egl = session.read.parquet("/Users/xg1/Downloads/Effector_gene_list.parquet")
fm = session.read.parquet(
    "gs://ot_orchestration/releases/24.10_freeze4/locus_to_gene_feature_matrix"
)
study_index = session.read.parquet(
    "gs://ot_orchestration/releases/24.10_freeze4/study_index"
)
credible_set = session.read.parquet(
    "gs://ot_orchestration/releases/24.10_freeze4/credible_set"
)

# Credible sets of GWAS studies, one row per mapped trait id of the study.
study_to_credible_set = credible_set.select("studyLocusId", "studyId").join(
    study_index.filter(f.col("studyType") == "gwas")
    .select("studyId", "traitFromSource", "traitFromSourceMappedIds")
    .withColumn("traitFromSourceMappedId", f.explode("traitFromSourceMappedIds"))
    .drop("traitFromSourceMappedIds"),
    on="studyId",
    how="inner",
).persist()

# Add the feature-matrix rows available for each credible set.
study_to_credible_set_fm = study_to_credible_set.join(
    fm, on="studyLocusId", how="inner"
).persist()

# Keep only the (trait, gene) pairs listed in the effector gene list.
GS_fm = study_to_credible_set_fm.join(
    new_egl.withColumnRenamed("efo_terms", "traitFromSourceMappedId"),
    on=["traitFromSourceMappedId", "geneId"],
    how="inner",
).persist()

# Define the prioritised features
Prioritised_features = ["eQtlColocClppMaximum", "eQtlColocH4Maximum", "pQtlColocClppMaximum",
                        "pQtlColocH4Maximum", "sQtlColocClppMaximum", "sQtlColocH4Maximum",
                        "vepMaximum", "vepMean"]

# NOTE: despite its name, this expression scores 1 for every prioritised
# feature that is neither null nor zero, so "null_or_zero_count" holds the
# number of features that are PRESENT. The column name is left unchanged
# because later cells and the written outputs refer to it.
null_or_zero_count_per_row = sum(
    f.when((f.col(c).isNull()) | (f.col(c) == 0), 0).otherwise(1)
    for c in Prioritised_features
)
GS_fm_with_count = GS_fm.withColumn("null_or_zero_count", null_or_zero_count_per_row)

# Per (trait, gene) keep the row with the most features present. The extra
# ordering columns break ties deterministically; ordering by the count alone
# lets row_number() pick an arbitrary row among ties, which may change every
# time the lazy plan is re-evaluated.
window_spec = Window.partitionBy("traitFromSourceMappedId", "geneId").orderBy(
    f.col("null_or_zero_count").desc(),
    f.col("studyLocusId"),
    f.col("GS_source"),
    f.col("diseaseFromSource"),
)
ranked_df = GS_fm_with_count.withColumn("row_number", f.row_number().over(window_spec))
filtered_df = ranked_df.filter(f.col("row_number") == 1).drop("row_number")

# Distribution of the selected rows by number of features present.
gs_source_counts = filtered_df.groupBy("null_or_zero_count").count()
gs_source_counts.orderBy(f.col("null_or_zero_count")).show(50)

In [None]:
# Rows whose "null_or_zero_count" is above zero form the gold-standard
# positive set. filtered_df is lazily evaluated, so the selection is cached
# once and both outputs are written from that same frame instead of from two
# independent recomputations of the upstream plan.
gsp_df = filtered_df.filter(f.col("null_or_zero_count") > 0).persist()

# mode("overwrite") matches the Effector_gene_list write earlier in the
# notebook and lets the cell be re-run; the default mode raises if the
# target path already exists.
gsp_df.write.mode("overwrite").parquet(
    "/Users/xg1/Downloads/feature_matrix_gsp.parquet"
)
gsp_df.select(
    "traitFromSourceMappedId",
    "geneId",
    "studyLocusId",
    "studyId",
    "traitFromSource",
    "GS_source",
).write.mode("overwrite").parquet("/Users/xg1/Downloads/EGL_GSP.parquet")