In [1]:
"""Step to run Locus to Gene either for inference or for training."""
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

import pyspark.sql.functions as f

if TYPE_CHECKING:
    from pyspark.sql import DataFrame

from otg.common.session import ETLSession
from otg.config import LocusToGeneConfig

from otg.common.spark_helpers import get_record_with_maximum_value
from otg.dataset.study_locus import StudyLocus
from otg.dataset.study_locus_overlap import StudyLocusOverlap
from otg.dataset.v2g import V2G
from otg.method.locus_to_gene import LocusToGeneTrainer


In [1]:
from otg.common.session import ETLSession
from otg.config import LocusToGeneConfig

etl=ETLSession("local[*]", "ot_genetics_local", "overwrite")
cfg = LocusToGeneConfig(
    run_mode="train",
    study_locus_path="/Users/irenelopez/MEGAsync/EBI/repos/genetics_etl_python/mock_data/mock_study_locus",
    study_locus_overlap_path="/Users/irenelopez/MEGAsync/EBI/repos/genetics_etl_python/mock_data/mock_study_locus_overlap",
    variant_gene_path="/Users/irenelopez/MEGAsync/EBI/repos/genetics_etl_python/mock_data/mock_v2g",
    colocalisation_path="/Users/irenelopez/MEGAsync/EBI/repos/genetics_etl_python/mock_data/mock_colocalisation",
    study_index_path="/Users/irenelopez/MEGAsync/EBI/repos/genetics_etl_python/mock_data/mock_study_index",
    gold_standard_curation_path="/Users/irenelopez/MEGAsync/EBI/repos/genetics_etl_python/mock_data/curation.json",
    gene_interactions_path="/Users/irenelopez/MEGAsync/EBI/repos/genetics_etl_python/mock_data/interaction",
)

23/03/10 22:28:59 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


In [34]:
@dataclass
class LocusToGeneStep:
    """Locus to gene step."""

    cfg: LocusToGeneConfig

    def run(self: LocusToGeneStep) -> None:
        """Run Locus to Gene step."""
        print("hola")
        # if self.run_mode == "train":
        #     gold_standards = get_gold_standards(
        #         etl=self.etl,
        #         study_locus_path=self.study_locus_path,
        #         v2g_path=self.variant_gene_path,
        #         study_locus_overlap_path=self.study_locus_overlap_path,
        #         gold_standard_curation=self.gold_standard_curation_path,
        #         interactions_path=self.gene_interactions_path,
        #     )
            # print(gold_standards.printSchema())
            # gold_standards.write.parquet(
            #     "/Users/irenelopez/MEGAsync/EBI/repos/genetics_etl_python/mock_data/staging/gs"
            # )
            # fm = L2GFeatureMatrix  # FIXME: debug credset
            # data = gold_standards.join(
            #     fm, on="studyLocusId", how="inner"
            # ).train_test_split(frac=0.1, seed=42)
            # # TODO: data normalization and standardisation of features

            # LocusToGeneTrainer.train(
            #     train_set=data["train"],
            #     test_set=data["test"],
            #     **self.hyperparameters,
            #     # TODO: Add push to hub, and push to W&B
            # )


def get_gold_standards(
    etl: ETLSession,
    gold_standard_curation: str,
    v2g_path: str,
    study_locus_path: str,
    study_locus_overlap_path: str,
    interactions_path: str,
) -> DataFrame:
    """Process gold standard curation to use as training data."""
    # FIXME: assign function to class - something is wrong instantiating the classes, used to work
    overlaps_df = StudyLocusOverlap.from_parquet(
        etl, study_locus_overlap_path
    ).df.select("left_studyLocusId", "right_studyLocusId")
    interactions = process_gene_interactions(etl, interactions_path)
    return (
        etl.spark.read.json(gold_standard_curation)
        .select(
            f.col("association_info.otg_id").alias("studyId"),
            f.col("gold_standard_info.gene_id").alias("geneId"),
            f.concat_ws(
                "_",
                f.col("sentinel_variant.locus_GRCh38.chromosome"),
                f.col("sentinel_variant.locus_GRCh38.position"),
                f.col("sentinel_variant.alleles.reference"),
                f.col("sentinel_variant.alleles.alternative"),
            ).alias("variantId"),
        )
        .filter(f.col("gold_standard_info.highest_confidence").isin(["High", "Medium"]))
        # Bring studyLocusId - TODO: what if I don't have one?
        .join(
            StudyLocus.from_parquet(etl, study_locus_path)._df.select(
                "studyId", "variantId", "studyLocusId"
            ),
            on=["studyId", "variantId"],
            how="inner",
        )
        # Assign Positive or Negative Status based on confidence
        .join(
            V2G.from_parquet(etl, v2g_path)._df.select(
                "variantId", "geneId", "distance"
            ),
            on=["variantId", "geneId"],
            how="inner",
        )
        .withColumn(
            "gsStatus",
            f.when(f.col("distance") <= 500_000, "Positive").otherwise("Negative"),
        )
        # Remove redundant loci
        .alias("left")
        .join(
            overlaps_df.alias("right"),
            (f.col("left.variantId") == f.col("right.left_studyLocusId"))
            | (f.col("left.variantId") == f.col("right.right_studyLocusId")),
            how="left",
        )
        .distinct()
        # Remove redundant genes
        .join(
            interactions.alias("interactions"),
            (f.col("left.geneId") == f.col("interactions.geneIdA"))
            | (f.col("left.geneId") == f.col("interactions.geneIdB")),
            how="left",
        )
        .withColumn("interacting", (f.col("score") > 0.7))
        # filter out genes where geneIdA has gsStatus Negative but geneIdA and gene IdB are interacting
        .filter(
            ~(
                (f.col("gsStatus") == "Negative")
                & (f.col("interacting"))
                & (
                    (f.col("left.geneId") == f.col("interactions.geneIdA"))
                    | (f.col("left.geneId") == f.col("interactions.geneIdB"))
                )
            )
        )
    )


def process_gene_interactions(etl: ETLSession, interactions_path: str) -> DataFrame:
    """Extract top scoring gene-gene interaction from the Platform."""
    # FIXME: assign function to class
    return get_record_with_maximum_value(
        etl.spark.read.parquet(interactions_path),
        ["targetA", "targetB"],
        "scoring",
    ).selectExpr(
        "targetA as geneIdA",
        "targetB as geneIdB",
        "scoring as score",
    )


In [26]:
get_gold_standards(
    etl=etl,
    gold_standard_curation=cfg.gold_standard_curation_path,
    v2g_path=cfg.variant_gene_path,
    study_locus_path=cfg.study_locus_path,
    study_locus_overlap_path=cfg.study_locus_overlap_path,
    interactions_path=cfg.gene_interactions_path,
).printSchema()

                                                                                

root
 |-- variantId: string (nullable = false)
 |-- geneId: string (nullable = true)
 |-- studyId: string (nullable = true)
 |-- studyLocusId: string (nullable = true)
 |-- distance: long (nullable = true)
 |-- gsStatus: string (nullable = false)
 |-- left_studyLocusId: string (nullable = true)
 |-- right_studyLocusId: string (nullable = true)
 |-- geneIdA: string (nullable = true)
 |-- geneIdB: string (nullable = true)
 |-- score: double (nullable = true)
 |-- interacting: boolean (nullable = true)



In [35]:
LocusToGeneStep(cfg=cfg).run()

hola


23/03/09 16:58:46 WARN HeartbeatReceiver: Removing executor driver with no recent heartbeats: 213110 ms exceeds timeout 120000 ms
23/03/09 16:58:46 WARN SparkContext: Killing executors is not supported by current scheduler.


In [4]:
from otg.common.schemas import parse_spark_schema
from otg.dataset.study_index import StudyIndex
from otg.dataset.study_locus import StudyLocus
from otg.dataset.colocalisation import Colocalisation

from dataclasses import field

from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.types import StructType

spark = etl.spark
from functools import partial
@dataclass
class L2GFeatureMatrix:
    """Dataset with features for Locus to Gene prediction."""

    _schema: StructType = parse_spark_schema("l2g_feature_matrix.json")
    _df: DataFrame = field(
        default_factory=partial(spark.createDataFrame, [], schema=_schema)
    )

    def get_distance_features(
        self: L2GFeatureMatrix, study_locus: StudyLocus, distances: V2G, etl: ETLSession
    ) -> None:
        """Get distance features."""
        distance = study_locus._get_tss_distance_features(distances, etl)
        return L2GFeatureMatrix(_df=self._df.unionByName(distance, allowMissingColumns=True) if isinstance(distance, DataFrame) else self._df)
    
    def get_coloc_features(
        self: L2GFeatureMatrix, colocalisation: Colocalisation, study_locus: StudyLocus, studies: StudyIndex
    ):
        """Get coloc features."""
        ss_coloc = colocalisation.get_max_llr_per_study_locus(study_locus, studies)
        return L2GFeatureMatrix(
            _df=self._df.unionByName(ss_coloc, allowMissingColumns=True)
        )


feature_matrix = (
    L2GFeatureMatrix()
    .get_distance_features(
        StudyLocus.from_parquet(etl, cfg.study_locus_path),
        V2G.from_parquet(etl, cfg.variant_gene_path),
        etl,
    )
    # .get_coloc_features(
    #     Colocalisation.from_parquet(etl, cfg.colocalisation_path),
    #     StudyLocus.from_parquet(etl, cfg.study_locus_path),
    #     StudyIndex.from_parquet(etl, cfg.study_index_path),

    # )
)

feature_matrix._df.printSchema()

                                                                                

root
 |-- studyLocusId: string (nullable = true)
 |-- feature: string (nullable = true)
 |-- geneId: string (nullable = true)
 |-- value: string (nullable = true)



In [6]:
Colocalisation.from_parquet(etl, cfg.colocalisation_path).get_max_llr_per_study_locus(
    StudyLocus.from_parquet(etl, cfg.study_locus_path),
    StudyIndex.from_parquet(etl, cfg.study_index_path),
)

                                                                                

AnalysisException: cannot resolve '`studyType`' given input columns: [neighbourhood_max.studyLocusId];
'Project ['studyType, studyLocusId#194, 'geneId, studyId#198 AS studyId_nbh#446, leadVariantId#266 AS leadVariantId_nbh#447, tagVariantId#268 AS tagVariantId_nbh#448, tagPValueConditioned#269 AS tagPValueConditioned_nbh#449, coloc_log2_h4_h3#180 AS coloc_log2_h4_h3_nbh#450]
+- Project [studyLocusId#194]
   +- Project [studyLocusId#194]
      +- Aggregate [studyLocusId#194], [studyLocusId#194, pivotfirst(studyType#228, first(neighbourhood_max.`coloc_log2_h4_h3`)#441, 0, 0) AS __pivot_first(neighbourhood_max.`coloc_log2_h4_h3`) AS `first(neighbourhood_max.``coloc_log2_h4_h3``)`#443]
         +- Aggregate [studyLocusId#194, studyType#228], [studyLocusId#194, studyType#228, first(coloc_log2_h4_h3#180, false) AS first(neighbourhood_max.`coloc_log2_h4_h3`)#441]
            +- SubqueryAlias neighbourhood_max
               +- Project [studyId#198, studyLocusId#194, leadVariantId#266, tagVariantId#268, tagPValueConditioned#269, coloc_log2_h4_h3#180, studyType#228, geneId#231]
                  +- Filter (row_number#377 = 1)
                     +- Project [studyId#198, studyLocusId#194, leadVariantId#266, tagVariantId#268, tagPValueConditioned#269, coloc_log2_h4_h3#180, studyType#228, geneId#231, row_number#377]
                        +- Project [studyId#198, studyLocusId#194, leadVariantId#266, tagVariantId#268, tagPValueConditioned#269, coloc_log2_h4_h3#180, studyType#228, geneId#231, row_number#377, row_number#377]
                           +- Window [row_number() windowspecdefinition(studyType#228, studyLocusId#194, geneId#231, coloc_log2_h4_h3#180 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS row_number#377], [studyType#228, studyLocusId#194, geneId#231], [coloc_log2_h4_h3#180 DESC NULLS LAST]
                              +- Project [studyId#198, studyLocusId#194, leadVariantId#266, tagVariantId#268, tagPValueConditioned#269, coloc_log2_h4_h3#180, studyType#228, geneId#231]
                                 +- Project [studyId#198, studyLocusId#194, leadVariantId#266, tagVariantId#268, tagPValueConditioned#269, coloc_log2_h4_h3#180, studyType#228, geneId#231]
                                    +- Join LeftOuter, (studyId#198 = studyId#226)
                                       :- Project [studyLocusId#194, studyId#198, leadVariantId#266, tagVariantId#268, tagPValueConditioned#269, coloc_log2_h4_h3#180]
                                       :  +- Join Inner, (studyLocusId#194 = studyLocusId#284)
                                       :     :- SubqueryAlias sentinel_study_locus
                                       :     :  +- Deduplicate [tagPValueConditioned#269, studyLocusId#194, tagVariantId#268, leadVariantId#266, studyId#198]
                                       :     :     +- Project [studyLocusId#194, studyId#198, leadVariantId#266, tagVariantId#268, tagPValueConditioned#269]
                                       :     :        +- Filter (leadVariantId#266 = tagVariantId#268)
                                       :     :           +- Project [studyLocusId#194, studyId#198, variantId#195 AS leadVariantId#266, credibleSetExploded#270, credibleSetExploded#270.tagVariantId AS tagVariantId#268, credibleSetExploded#270.tagPValueConditioned AS tagPValueConditioned#269]
                                       :     :              +- Generate explode(credibleSet#209), false, [credibleSetExploded#270]
                                       :     :                 +- Relation[studyLocusId#194,variantId#195,chromosome#196,position#197,studyId#198,beta#199,oddsRatio#200,oddsRatioConfidenceIntervalLower#201,oddsRatioConfidenceIntervalUpper#202,betaConfidenceIntervalLower#203,betaConfidenceIntervalUpper#204,pValueMantissa#205,pValueExponent#206L,qualityControls#207,finemappingMethod#208,credibleSet#209] parquet
                                       :     +- Project [right_studyLocusId#171 AS studyLocusId#284, coloc_log2_h4_h3#180]
                                       :        +- Relation[left_studyLocusId#170,right_studyLocusId#171,chromosome#172,colocalisationMethod#173,coloc_n_vars#174L,coloc_h0#175,coloc_h1#176,coloc_h2#177,coloc_h3#178,coloc_h4#179,coloc_log2_h4_h3#180,clpp#181] parquet
                                       +- Project [studyId#226, studyType#228, geneId#231]
                                          +- Relation[studyId#226,projectId#227,studyType#228,traitFromSource#229,traitFromSourceMappedIds#230,geneId#231,pubmedId#232,publicationTitle#233,publicationFirstAuthor#234,publicationDate#235,publicationJournal#236,backgroundTraitFromSourceMappedIds#237,initialSampleSize#238,nCases#239L,nControls#240L,nSamples#241L,discoverySamples#242,replicationSamples#243,summarystatsLocation#244,hasSumstats#245] parquet


In [24]:
from otg.common.spark_helpers import _convert_from_wide_to_long

coloc = Colocalisation.from_parquet(etl, cfg.colocalisation_path)
studies = StudyIndex.from_parquet(etl, cfg.study_index_path)

sentinel_study_locus = (
    StudyLocus.from_parquet(etl, cfg.study_locus_path).get_sentinels().alias("sentinel_study_locus")
    .join(
        coloc._df.selectExpr(
                    "right_studyLocusId as studyLocusId", "coloc_log2_h4_h3"
                ),
                on="studyLocusId",
                how="inner",
    )
    .join(
                # bring study metadata
                studies._df.select("studyId", "studyType", "geneId"),
                on="studyId",
                how="left",
            )
)

wide_local_max_df = (
    get_record_with_maximum_value(
            sentinel_study_locus,
            ["studyType", "studyLocusId", "geneId"],
            "coloc_log2_h4_h3",
    ).withColumnRenamed("coloc_log2_h4_h3", "coloc_llr_local_max")
    # .transform(
    #         lambda df: pivot_df(
    #             df, "studyType", "coloc_log2_h4_h3", ["studyLocusId", "geneId"]
    #         )
    #     )
)

neighbourhood_max = (
            get_record_with_maximum_value(
                sentinel_study_locus,
                ["studyType", "studyLocusId"],
                "coloc_log2_h4_h3",
            )
            .withColumnRenamed("coloc_log2_h4_h3", "coloc_llr_nbh_max")
        )

(
    _convert_from_wide_to_long(
            wide_local_max_df,
            id_vars=("studyLocusId", "geneId"),
            var_name="feature",
            value_name="value",
            spark=etl.spark,  # not great, but necessary to go from pandas to spark
        )
)
neighbourhood_max.printSchema()


root
 |-- studyId: string (nullable = true)
 |-- studyLocusId: string (nullable = true)
 |-- leadVariantId: string (nullable = true)
 |-- tagVariantId: string (nullable = true)
 |-- tagPValueConditioned: double (nullable = true)
 |-- coloc_log2_h4_h3: double (nullable = true)
 |-- studyType: string (nullable = true)
 |-- geneId: string (nullable = true)



In [12]:
coloc._df.printSchema()

root
 |-- left_studyLocusId: string (nullable = true)
 |-- right_studyLocusId: string (nullable = true)
 |-- chromosome: string (nullable = true)
 |-- colocalisationMethod: string (nullable = true)
 |-- coloc_n_vars: long (nullable = true)
 |-- coloc_h0: double (nullable = true)
 |-- coloc_h1: double (nullable = true)
 |-- coloc_h2: double (nullable = true)
 |-- coloc_h3: double (nullable = true)
 |-- coloc_h4: double (nullable = true)
 |-- coloc_log2_h4_h3: double (nullable = true)
 |-- clpp: double (nullable = true)



In [4]:
from enum import Enum
class CredibleInterval(Enum):
    """Credible interval enum.

    Interval within which an unobserved parameter value falls with a particular probability.

    Attributes:
        IS95 (str): 95% credible interval
        IS99 (str): 99% credible interval
    """

    IS95 = "is95CredibleSet"
    IS99 = "is99CredibleSet"

pdf = (
    StudyLocus.from_parquet(etl, cfg.study_locus_path)
    .credible_set(CredibleInterval.IS95.value)
    .select(
                "studyLocusId",
                "variantId",
                f.explode("credibleSet.tagVariantId").alias("tagVariantId"),
            )
    .join(
        V2G.from_parquet(etl, cfg.variant_gene_path).df.selectExpr(
                    "variantId as tagVariantId", "geneId", "distance"
                ),
                on="tagVariantId",
                how="inner"
            )
    .groupBy("studyLocusId", "variantId", "geneId")
    .agg(
        f.min("distance").alias("dist_tss_min"),
        f.mean("distance").alias("dist_tss_ave"),
    )
    .toPandas()
    .melt(id_vars=["studyLocusId", "geneId"], var_name="feature", value_name="value")
    
)

pdf

                                                                                

Unnamed: 0,studyLocusId,geneId,feature,value


In [10]:
from pandas import DataFrame as PandasDataFrame
import pyspark.sql.types as t

def _get_spark_schema_from_pandas_df(pdf: PandasDataFrame) -> t.StructType:
    """Returns the Spark schema based on a Pandas DataFrame."""
    return t.StructType([
        t.StructField(field, _get_spark_type(pdf[field].dtype), True)
        for field in pdf.columns
    ])

def _get_spark_type(pandas_type: str) -> t.DataType:
    """Returns the Spark type based on the Pandas type."""
    try:
        if pandas_type == "object":
            return t.StringType()
        elif pandas_type == "int64":
            return t.IntegerType()
        elif pandas_type == "float64":
            return t.FloatType()
    except Exception as e:
        raise ValueError(f"Unsupported type: {pandas_type}") from e
    
def _convert_from_wide_to_long(
    df: DataFrame,
    id_vars: list[str],
    var_name: str,
    value_name: str,
    spark: SparkSession,
) -> DataFrame:
    """Converts a dataframe from wide to long format using Pandas melt built-in function.

    Args:
        df (DataFrame): Dataframe to melt
        id_vars (list[str]): List of fixed columns to keep
        var_name (str): Name of the column containing the variable names
        value_name (str): Name of the column containing the values
        spark (SparkSession): Spark session

    Returns:
        DataFrame: Melted dataframe

    Examples:
    >>> df = spark.createDataFrame([("a", 1, 2)], ["id", "feature_1", "feature_2"])
    >>> _convert_from_wide_to_long(df, ["id"], "feature", "value", spark).show()
    +---+---------+-----+
    | id|  feature|value|
    +---+---------+-----+
    |  a|feature_1|    1|
    |  a|feature_2|    2|
    +---+---------+-----+
    <BLANKLINE>
    """
    pandas_df = df.toPandas().melt(
        id_vars=id_vars, var_name=var_name, value_name=value_name
    )
    schema = _get_spark_schema_from_pandas_df(pandas_df)
    return spark.createDataFrame(pandas_df, schema)



In [12]:
etl.spark.createDataFrame(pdf, _get_spark_schema_from_pandas_df(pdf))

DataFrame[studyLocusId: string, geneId: string, feature: string, value: string]

In [9]:
StudyLocus.from_parquet(etl, cfg.study_locus_path).df.schema

StructType(List(StructField(studyLocusId,StringType,true),StructField(variantId,StringType,true),StructField(chromosome,StringType,true),StructField(position,IntegerType,true),StructField(studyId,StringType,true),StructField(beta,DoubleType,true),StructField(oddsRatio,DoubleType,true),StructField(oddsRatioConfidenceIntervalLower,DoubleType,true),StructField(oddsRatioConfidenceIntervalUpper,DoubleType,true),StructField(betaConfidenceIntervalLower,DoubleType,true),StructField(betaConfidenceIntervalUpper,DoubleType,true),StructField(pValueMantissa,DoubleType,true),StructField(pValueExponent,LongType,true),StructField(qualityControls,ArrayType(StringType,true),true),StructField(finemappingMethod,StringType,true),StructField(credibleSet,ArrayType(StructType(List(StructField(is95CredibleSet,BooleanType,true),StructField(is99CredibleSet,BooleanType,true),StructField(logABF,DoubleType,true),StructField(posteriorProbability,DoubleType,true),StructField(tagVariantId,StringType,true),StructField(

In [None]:
# created a processed gs file ad hoc from the old one

gs_processed = spark.read.parquet("gs://genetics-portal-dev-staging/l2g/221107/gold_standards/featurematrix_w_goldstandards.training_only.221107.parquet")

gs_processed_slim = (
    gs_processed.filter(f.col("gs_confidence") != "Low")
    .select(
        f.col("study_id").alias("studyId"),
        f.concat_ws(
            "_", f.col("chrom"), f.col("pos"), f.col("ref"), f.col("alt")
        ).alias("variantId"),
        f.col("gene_id").alias("geneId"),
        f.col("gold_standard_status").alias("gsStatus"),
    )
)

# studyLocusId must be created separately
assocs = assocs = gs_processed_slim.select("studyId", "variantId").distinct().withColumn("studyLocusId", f.monotonically_increasing_id())
gs_processed_slim = gs_processed_slim.join(assocs, on=["studyId", "variantId"], how="inner")




In [None]:
# Distance feature matrix

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

from otg.common.schemas import parse_spark_schema
from otg.dataset.dataset import Dataset
from otg.dataset.study_locus import StudyLocus
from otg.dataset.v2g import V2G

if TYPE_CHECKING:
    from pyspark.sql.types import StructType

    from otg.common.session import ETLSession

@dataclass
class L2GFeature:
    """Property of a study locus pair."""

    study_id: str  # TODO: think about moving this to a trait id - so that we can extract the best study for that trait to train on
    locus_id: str
    gene_id: str
    feature_name: str
    feature_value: float


@dataclass
class L2GFeatureMatrix(Dataset):
    """Dataset with features for Locus to Gene prediction."""

    _schema: StructType = parse_spark_schema("l2g_feature_matrix.json")

    @classmethod
    def from_parquet(
        cls: type[L2GFeatureMatrix], etl: ETLSession, path: str
    ) -> Dataset:
        """Initialise L2GFeatureMatrix from parquet file.

        Args:
            etl (ETLSession): ETL session
            path (str): Path to parquet file

        Returns:
            Dataset: Locus to gene feature matrix
        """
        return super().from_parquet(etl, path), cls.schema
    
    def get_distance_features(
            self: L2GFeatureMatrix, study_locus: StudyLocus, distances: V2G
        ) -> L2GFeatureMatrix:
            """Get distance features."""
            distance = study_locus._get_tss_distance_features(distances)
            # return L2GFeatureMatrix(
            #     _df=self._df.unionByName(distance, allowMissingColumns=True)
            # )   
            self._df = self.df.unionByName(distance, allowMissingColumns=True)

    @classmethod
    def get_all_features(
            cls: type[L2GFeatureMatrix], study_locus: StudyLocus, distances: V2G
        ) -> L2GFeatureMatrix:
            """Get all features."""
            distance_features = L2GFeatureMatrix()
            return self


In [None]:
fm = L2GFeatureMatrix(_df=cfg.etl.spark.createDataFrame([], schema=_schema), path=None).get_distance_features(
    study_locus=StudyLocus.from_parquet(
        cfg.etl,
        "/Users/irenelopez/MEGAsync/EBI/repos/genetics_etl_python/mock_data/processed_gs",
    ),
    distances=V2G.from_parquet(cfg.etl, cfg.variant_gene_path),
)


AttributeError: 'NoneType' object has no attribute 'unionByName'

23/02/27 17:57:59 WARN HeartbeatReceiver: Removing executor driver with no recent heartbeats: 547751 ms exceeds timeout 120000 ms
23/02/27 17:57:59 WARN SparkContext: Killing executors is not supported by current scheduler.


In [None]:
L2GFeatureMatrix.__dict__

mappingproxy({'__module__': '__main__',
              '__annotations__': {'schema': 'StructType'},
              '__doc__': 'Dataset with features for Locus to Gene prediction.',
              'schema': StructType(List(StructField(studyLocusId,StringType,false),StructField(feature,StringType,true),StructField(geneId,StringType,true),StructField(value,DoubleType,true))),
              'from_parquet': <classmethod at 0x12c3d28e0>,
              'get_distance_features': <function __main__.L2GFeatureMatrix.get_distance_features(self: 'L2GFeatureMatrix', study_locus: 'StudyLocus', distances: 'V2G') -> 'L2GFeatureMatrix'>,
              '__dict__': <attribute '__dict__' of 'L2GFeatureMatrix' objects>,
              '__weakref__': <attribute '__weakref__' of 'L2GFeatureMatrix' objects>,
              '__dataclass_params__': _DataclassParams(init=True,repr=True,eq=True,order=False,unsafe_hash=False,frozen=False),
              '__dataclass_fields__': {'schema': Field(name='schema',type='Struct

In [57]:
def validate_schema(expected_schema, observed_schema) -> None:
    """Validate DataFrame schema against expected class schema.

    Raises:
        ValueError: DataFrame schema is not valid
    """
    # Do not look at the nullable flag
    from collections import namedtuple

    schema_tuple = namedtuple("schema", ["name", "type"])
    expected_schema = [schema_tuple(e.name, e.dataType) for e in expected_schema.fields]
    observed_schema = [schema_tuple(e.name, e.dataType) for e in observed_schema.fields]

    # Unexpected fields in the observed schema
    missing_struct_fields = [x for x in observed_schema if x not in expected_schema]
    error_message = f"The {missing_struct_fields} StructFields are not included in DataFrame schema: {expected_schema}"
    if missing_struct_fields:
        raise ValueError(error_message)

    # Required fields not in dataset
    required_fields = [x for x in expected_schema if not x.nullable]
    missing_required_fields = [
        x for x in required_fields if x not in observed_schema
    ]
    error_message = f"The {missing_required_fields} StructFields are required but missing from the DataFrame schema: {expected_schema}"
    if missing_required_fields:
        raise ValueError(error_message)

In [60]:
v2g = etl.spark.read.parquet(cfg.variant_gene_path)
observed_schema = v2g.schema
expected_schema = V2G.schema


In [65]:
observed_schema.fields[0].name

'geneId'

In [56]:
missing_struct_fields = [x.name for x in observed_schema if x not in expected_schema]

missing_struct_fields

[]

In [14]:
from otg.dataset.l2g_feature_matrix import L2GFeatureMatrix

fm = L2GFeatureMatrix.generate_features(
    study_locus_path=cfg.study_locus_path,
    study_index_path=cfg.study_index_path,
    variant_gene_path=cfg.variant_gene_path,
    colocalisation_path=cfg.colocalisation_path,
    etl=etl,
)



NameError: name 'StudyLocus' is not defined

In [12]:
from otg.dataset.study_locus import StudyLocus

study_locus = StudyLocus.from_parquet(etl, cfg.study_locus_path)

In [9]:
from otg.method.l2g_utils.feature_factory import (
    ColocalisationFactory,
    StudyLocusFactory,
)
from functools import reduce
from otg.common.spark_helpers import _convert_from_long_to_wide
from dataclasses import dataclass
from otg.dataset.dataset import Dataset
from otg.dataset.study_locus import StudyLocus
from otg.dataset.study_index import StudyIndex
from otg.dataset.colocalisation import Colocalisation
from otg.dataset.v2g import V2G
from pyspark.sql import DataFrame

from typing import Type

@dataclass
class L2GFeatureMatrix(Dataset):
    """Dataset with features for Locus to Gene prediction."""

    _df: DataFrame
    _schema =  ""
    
    @classmethod
    def generate_features(
        cls: Type[L2GFeatureMatrix],
        study_locus_path: str,
        study_index_path: str,
        variant_gene_path: str,
        colocalisation_path: str,
        etl: ETLSession,
    ) -> L2GFeatureMatrix:
        """Generate features from the OTG datasets."""
        # Load datasets
        study_locus = StudyLocus.from_parquet(etl, study_locus_path)
        studies = StudyIndex.from_parquet(etl, study_index_path)
        distances = V2G.from_parquet(etl, variant_gene_path)
        coloc = Colocalisation.from_parquet(etl, colocalisation_path)

        # Extract features
        coloc_features = ColocalisationFactory._get_max_llr_per_study_locus(
            study_locus, studies, coloc
        )
        distance_features = StudyLocusFactory._get_tss_distance_features(distances)

        fm = reduce(
            lambda x, y: x.unionByName(y),
            [coloc_features._df, distance_features._df],
        )

        return L2GFeatureMatrix(_df=_convert_from_long_to_wide(fm))
    




NameError: name 'L2GFeatureMatrix' is not defined

In [27]:
fm = L2GFeatureMatrix.generate_features(
    study_locus_path=cfg.study_locus_path,
    study_index_path=cfg.study_index_path,
    variant_gene_path=cfg.variant_gene_path,
    colocalisation_path=cfg.colocalisation_path,
    etl=etl,
)



TypeError: _get_max_llr_per_study_locus() takes 3 positional arguments but 4 were given

In [20]:
from otg.common.schemas import parse_spark_schema
from pyspark.sql.types import StructType

@dataclass
class Colocalisation(Dataset):
    """Colocalisation results for pairs of overlapping study-locus."""

    _schema: StructType = parse_spark_schema("colocalisation.json")

    @classmethod
    def from_parquet(
        cls: Type[Colocalisation], etl: ETLSession, path: str
    ) -> Colocalisation:
        """Initialise Colocalisation dataset from parquet file.

        Args:
            etl (ETLSession): ETL session
            path (str): Path to parquet file

        Returns:
            Colocalisation: Colocalisation results
        """
        return super().from_parquet(etl, path, cls._schema)
    
Colocalisation.from_parquet(etl, cfg.colocalisation_path)



Colocalisation(path='/Users/irenelopez/MEGAsync/EBI/repos/genetics_etl_python/mock_data/mock_colocalisation', _schema=StructType(List(StructField(left_studyLocusId,StringType,false),StructField(right_studyLocusId,StringType,false),StructField(chromosome,StringType,false),StructField(colocalisationMethod,StringType,false),StructField(coloc_n_vars,LongType,false),StructField(coloc_h0,DoubleType,true),StructField(coloc_h1,DoubleType,true),StructField(coloc_h2,DoubleType,true),StructField(coloc_h3,DoubleType,true),StructField(coloc_h4,DoubleType,true),StructField(coloc_log2_h4_h3,DoubleType,true),StructField(clpp,DoubleType,true))), _df=DataFrame[left_studyLocusId: string, right_studyLocusId: string, chromosome: string, colocalisationMethod: string, coloc_n_vars: bigint, coloc_h0: double, coloc_h1: double, coloc_h2: double, coloc_h3: double, coloc_h4: double, coloc_log2_h4_h3: double, clpp: double])

In [14]:
cfg.colocalisation_path

'/Users/irenelopez/MEGAsync/EBI/repos/genetics_etl_python/mock_data/mock_colocalisation'

In [18]:
sl = StudyLocus.from_parquet(etl, cfg.study_locus_path)

sl

StudyLocus(path='/Users/irenelopez/MEGAsync/EBI/repos/genetics_etl_python/mock_data/mock_study_locus', _schema=StructType(List(StructField(studyLocusId,StringType,false),StructField(variantId,StringType,false),StructField(chromosome,StringType,true),StructField(position,IntegerType,true),StructField(studyId,StringType,false),StructField(beta,DoubleType,true),StructField(oddsRatio,DoubleType,true),StructField(oddsRatioConfidenceIntervalLower,DoubleType,true),StructField(oddsRatioConfidenceIntervalUpper,DoubleType,true),StructField(betaConfidenceIntervalLower,DoubleType,true),StructField(betaConfidenceIntervalUpper,DoubleType,true),StructField(pValueMantissa,DoubleType,true),StructField(pValueExponent,LongType,true),StructField(qualityControls,ArrayType(StringType,true),true),StructField(finemappingMethod,StringType,true),StructField(credibleSet,ArrayType(StructType(List(StructField(is95CredibleSet,BooleanType,true),StructField(is99CredibleSet,BooleanType,true),StructField(logABF,DoubleT

In [28]:
ColocalisationFactory._get_max_llr_per_study_locus()

TypeError: _get_max_llr_per_study_locus() missing 2 required positional arguments: 'study_locus' and 'studies'