In [None]:
import ast

import polars as pl
import pyspark as ps
from pyspark.sql import functions as F

In [None]:
# Reuse the active Spark session if one exists, otherwise start a new one.
spark = ps.sql.SparkSession.builder.getOrCreate()

In [None]:
# KGML-xDTD ground-truth edges (normalised), release v0.11.2, read into a polars DataFrame.
kgml = pl.read_parquet('gs://mtrx-us-central1-hub-dev-storage/kedro/data/releases/v0.11.2/datasets/integration/int/kgml_xdtd_ground_truth/edges.norm/')

In [None]:
# Off-label edges (normalised), release v0.11.2, read into a polars DataFrame.
off_label = pl.read_parquet('gs://mtrx-us-central1-hub-dev-storage/kedro/data/releases/v0.11.2/datasets/integration/int/off_label/edges.norm/')

In [None]:
# DrugBank ground-truth edges (normalised), read into a polars DataFrame.
# NOTE(review): this comes from the prod bucket at release v0.11.11-prod, while the
# KGML and off-label reads use the dev bucket at v0.11.2 — confirm the mix is intentional.
db_gt = pl.read_parquet('gs://mtrx-us-central1-hub-prod-storage/kedro/data/releases/v0.11.11-prod/datasets/integration/int/drugbank_ground_truth/edges.norm/')

In [None]:
# Display the DrugBank ground-truth frame for a quick look at its columns.
db_gt

In [None]:
# DrugBank ground-truth rows labelled y == 1, reduced to the pair columns and the label.
(
    db_gt
    .filter(pl.col("y") == 1)
    .select("subject", "object", "y")
)

In [None]:
# Convert the polars KGML frame to a Spark DataFrame (going through pandas) and preview it.
kgml_df = spark.createDataFrame(kgml.to_pandas())
kgml_df.show()

In [None]:
# Convert the DrugBank ground-truth edges from polars to a Spark DataFrame (via pandas).
# The name `db_gt` is reused for the result, so the conversion is guarded: re-running this
# cell once `db_gt` is already a Spark DataFrame is a no-op instead of failing on the
# polars-only `.to_pandas()` call.
if isinstance(db_gt, pl.DataFrame):
    db_gt = spark.createDataFrame(db_gt.to_pandas())

In [None]:
from functools import reduce

In [None]:
# For each ground-truth dataset, whether its positive and/or negative pairs are included.
datasets_to_include = dict(
    kgml_ground_truth=dict(positives=True, negatives=False),
    db_ground_truth=dict(positives=True, negatives=True),
)

In [None]:
# Ground-truth frames keyed by the same names used in `datasets_to_include`.
all_datasets = dict(kgml_ground_truth=kgml_df, db_ground_truth=db_gt)

In [None]:
# Loop version: for every dataset keep the rows whose label `y` is requested in
# `datasets_to_include`, rename subject/object to drug_id/disease_id, then show the
# de-duplicated union of all kept frames.
included_dataset_list = []
for dataset_name, dataset in all_datasets.items():
    # Label 0 is paired with "negatives" and label 1 with "positives".
    y_values_required = [
        y
        for y, pair_type in ((0, "negatives"), (1, "positives"))
        if datasets_to_include[dataset_name][pair_type]
    ]
    if not y_values_required:
        # Neither positives nor negatives requested: the dataset contributes nothing.
        continue
    pairs = F.col("subject").alias("drug_id"), F.col("object").alias("disease_id")
    included_dataset_list.append(
        dataset.filter(F.col("y").isin(y_values_required)).select(*pairs)
    ) if False else included_dataset_list.append(
        dataset
        .filter(F.col("y").isin(y_values_required))
        .select(F.col("subject").alias("drug_id"), F.col("object").alias("disease_id"))
    )
    del pairs
reduce(lambda left, right: left.union(right), included_dataset_list).distinct().show()

In [None]:
# Comprehension version of the same selection.
# Step 1: per dataset, the labels to keep (0 is paired with "negatives", 1 with "positives").
y_values_required = {
    name: [
        y
        for y, pair_type in ((0, "negatives"), (1, "positives"))
        if datasets_to_include[name][pair_type]
    ]
    for name in all_datasets
}
# Step 2: one (drug_id, disease_id) frame per dataset that keeps at least one label.
dataframes_to_concatenate = [
    df.filter(F.col("y").isin(y_values_required[name])).select(
        F.col("subject").alias("drug_id"),
        F.col("object").alias("disease_id"),
    )
    for name, df in all_datasets.items()
    if y_values_required[name]
]
# Step 3: stack the frames, drop duplicate pairs and preview the result.
reduce(lambda left, right: left.union(right), dataframes_to_concatenate).distinct().show()

In [None]:
# NOTE(review): unfinished draft. `dataset_name` is not defined in this cell; it only
# exists if the loop cell above has run and left its loop variable behind, in which
# case this builds a dict with that single key mapped to an empty list.
y_values = {dataset_name : []
}


In [None]:
# NOTE(review): this cell contained only the unfinished statement below, which is a
# SyntaxError and stops "Restart & Run All". It is kept as a comment so the draft is
# not lost; complete it or delete the cell.
# for dataset

In [None]:
# One (drug_id, disease_id) frame per dataset, keeping only rows whose label `y` is
# listed for that dataset in `y_values`.
# The result is bound to `included_dataframes` rather than `datasets_to_include`, so the
# inclusion-config dict of that name is not overwritten by a list of DataFrames (which
# would break re-running the cells that index it by dataset name).
# `y_values` needs an entry for every key of `all_datasets`; a missing key raises KeyError.
included_dataframes = [
    dataset
    .filter(F.col("y").isin(y_values[dataset_name]))
    .select(F.col("subject").alias("drug_id"), F.col("object").alias("disease_id"))
    for dataset_name, dataset in all_datasets.items()
    if len(y_values[dataset_name]) > 0
]

In [None]:
# Stack the full KGML and DrugBank frames (all columns, no filtering).
kgml_df.union(db_gt)

## Drug expansion

In [None]:
# Drug-list nodes (normalised), release v0.13.0, read into a polars DataFrame.
drugs_pl = pl.read_parquet("gs://mtrx-us-central1-hub-dev-storage/kedro/data/releases/v0.13.0/datasets/integration/int/drug_list/nodes.norm/")

In [None]:
# Display the drug-list frame.
drugs_pl

In [None]:
# Load drugs list with ATC codes
drugs_list_with_atc_new_raw = pl.read_csv('gs://mtrx-us-central1-wg2-modeling-dev-storage/known_entities/data/atc_codes.csv')

# Prepare drugs list with new ATC codes.
# String cells of `atc_name` are expected to hold a Python-literal collection of codes.
# They are parsed with ast.literal_eval rather than eval, so CSV content is only ever
# read as data and never executed. Non-string cells are passed through unchanged.
atc_new_col = [
    list(ast.literal_eval(x)) if isinstance(x, str) else x
    for x in drugs_list_with_atc_new_raw['atc_name'].to_list()
]
drugs_list_with_atc_new = (
    drugs_list_with_atc_new_raw
    .select(pl.col("translator_id").alias("drug_id"), pl.col("name").alias("drug_name"))
    .with_columns(atc_codes = pl.Series(atc_new_col))
)

# Explode multiple ATC codes into multiple rows
drugs_list_with_atc_new = drugs_list_with_atc_new.explode("atc_codes").rename({"atc_codes": "atc_level_5"})

# Fill in lower levels by taking prefixes of the level-5 code (1, 3, 4 and 5 characters)
drugs_list_with_atc_new = drugs_list_with_atc_new.with_columns(
    atc_level_1 = pl.col("atc_level_5").str.slice(0, 1),
    atc_level_2 = pl.col("atc_level_5").str.slice(0, 3),
    atc_level_3 = pl.col("atc_level_5").str.slice(0, 4),
    atc_level_4 = pl.col("atc_level_5").str.slice(0, 5)
)

In [None]:
# Distinct values of the `id` column in the raw ATC table.
# NOTE(review): the preparation cell reads `translator_id`, not `id` — confirm both
# columns exist in the CSV.
drugs_list_with_atc_new_raw.select("id").unique()

In [None]:
# Row count per level-5 ATC code, largest groups first.
(
    drugs_list_with_atc_new
    .group_by("atc_level_5")
    .count()
    .sort("count", descending=True)
)

Conclusion: Don't do drug expansion 

### Diseases

In [None]:
import owlready2

In [None]:
class OntologyMONDO:
    """Look up the IDs of terms related to a MONDO term via the loaded ontology.

    NOTE(review): "ancestors" are read from the term's `is_a` and "descendants" from
    `subclasses()`. In owlready2 these are documented as the direct parents and direct
    children only — confirm that a one-level expansion is what is wanted.
    """

    def __init__(self, owl_url: str = 'https://purl.obolibrary.org/obo/mondo.owl'):
        """Initialize the ontology with the given URL.

        Args:
            owl_url: A URL to download the MONDO ontology in OWL format.
        """
        self.ont = owlready2.get_ontology(owl_url).load()

    @staticmethod
    def _get_ids_from_owl_things(owl_things: "list[owlready2.entity.ThingClass]") -> list[str]:
        """Return the first `id` value of every item that has a non-empty `id`.

        Args:
            owl_things: Any iterable of ontology entities. Items without an `id`
                attribute, or with an empty one, are skipped.

        Returns:
            The first `id` entry of each remaining item, in input order.
        """
        ids = []
        for thing in owl_things:
            # An empty `id` would make `[0]` raise IndexError, so skip it like a missing one.
            thing_ids = getattr(thing, "id", None)
            if thing_ids:
                ids.append(thing_ids[0])
        return ids

    def get_related_ids(self, mondo_id: str) -> dict[str, list[str]]:
        """Return the IDs related to `mondo_id`, split by direction.

        Args:
            mondo_id: ID to search for in the ontology (matched on the `id` property).

        Returns:
            A dict with the keys 'ancestors' and 'descendants', each a list of IDs.
            Both lists are empty when the ID is not found in the ontology.
        """
        mondo_class = self.ont.search_one(id=mondo_id)
        if mondo_class is None:
            return {
                'ancestors': [],
                'descendants': []
            }
        return {
            'ancestors': self._get_ids_from_owl_things(mondo_class.is_a),
            'descendants': self._get_ids_from_owl_things(mondo_class.subclasses())
        }

    def get_equivalent_mondo_ids(self, mondo_id: str) -> list[str]:
        """Return the de-duplicated ancestor and descendant IDs of `mondo_id`.

        The input ID itself is not added. The order is deterministic: ancestors first,
        then descendants, each ID kept at its first occurrence.
        """
        related_ids = self.get_related_ids(mondo_id)
        # dict.fromkeys de-duplicates while keeping a stable order (a set would not).
        return list(dict.fromkeys(related_ids['ancestors'] + related_ids['descendants']))
        

In [None]:
# Off-label edges (normalised), release v0.11.2, read into a polars DataFrame.
# NOTE(review): same path as the earlier `off_label` read; this rebinds the name to a
# freshly loaded polars frame.
off_label = pl.read_parquet('gs://mtrx-us-central1-hub-dev-storage/kedro/data/releases/v0.11.2/datasets/integration/int/off_label/edges.norm/')

In [None]:
# Distinct values of the `object` column of the off-label edges (used below as disease IDs).
diseases = off_label.select("object").unique()

In [None]:
# Test: load the MONDO ontology from its URL and look up the related IDs of every
# entry in `diseases`.
ont = OntologyMONDO(owl_url='https://purl.obolibrary.org/obo/mondo.owl')
equiv = list(map(ont.get_equivalent_mondo_ids, diseases["object"].to_list()))

In [None]:
class OntologyTest(OntologyMONDO):
    """A class to override the OntologyMONDO in the test environment.

    No ontology is loaded; related IDs are derived from the input ID by adding a suffix.
    """

    def __init__(self):
        """Skip the parent initialiser so that no ontology is loaded."""

    def get_related_ids(self, mondo_id: str) -> dict[str, list[str]]:
        """Return one fake ancestor and one fake descendant for `mondo_id`.

        Returns:
            A dict with the keys 'ancestors' and 'descendants', each a one-element list
            holding `mondo_id` followed by '_ancestor' or '_descendant'.
        """
        return {
            'ancestors': [mondo_id + '_ancestor'],
            'descendants': [mondo_id + '_descendant']
        }


In [None]:
# Test: exercise the stubbed ontology with a made-up identifier.
# NOTE(review): this rebinds the notebook-global `ont` to the OntologyTest stub, so any
# later cell that reads `ont` gets the stub — confirm that is intended.
example_id = "RTX:123"  # previously named `id`, which shadowed the builtin id()
ont = OntologyTest()
ont.get_equivalent_mondo_ids(example_id)


In [None]:
# off_label = spark.createDataFrame(off_label.select(pl.col("subject").alias("drug_id"), pl.col("object").alias("disease_id")).to_pandas())

In [None]:
# Convert function to udf
from pyspark.sql.functions import col, udf
from pyspark.sql.types import ArrayType, StringType

In [None]:
# UDF returning all equivalent MONDO IDs including input ID itself.
# NOTE(review): the lambda reads the notebook-global `ont`, so the result depends on
# which ontology object is bound to that name when this cell runs.
equivalent_mondo_ids_udf = udf(
    lambda disease_id: ont.get_equivalent_mondo_ids(disease_id) + [disease_id],
    ArrayType(StringType()),
)

# The Spark operations below need a Spark DataFrame with `drug_id` / `disease_id`
# columns, but `off_label` is loaded as a polars frame with `subject` / `object`.
# Convert it here when it is still polars; if it has already been converted, use it as is.
off_label_pairs = (
    spark.createDataFrame(
        off_label
        .select(pl.col("subject").alias("drug_id"), pl.col("object").alias("disease_id"))
        .to_pandas()
    )
    if isinstance(off_label, pl.DataFrame)
    else off_label
)

# One row per (disease_id, equivalent_disease_id); the UDF is applied to distinct
# disease IDs only, so each ID is looked up once.
equivalent_diseases = (
    off_label_pairs
    .select("disease_id")
    .distinct()
    .withColumn("equivalent_disease_id", F.explode(equivalent_mondo_ids_udf(F.col("disease_id"))))
)

# Replace each pair's disease with every one of its equivalent diseases.
out = (
    off_label_pairs
    .join(equivalent_diseases, on="disease_id", how="left")
    .select("drug_id", F.col("equivalent_disease_id").alias("disease_id"))
)


