diff --git a/impc_etl/jobs/load/impc_kg/mouse_gene_mapper.py b/impc_etl/jobs/load/impc_kg/mouse_gene_mapper.py
index a626f2c4..dfeba536 100644
--- a/impc_etl/jobs/load/impc_kg/mouse_gene_mapper.py
+++ b/impc_etl/jobs/load/impc_kg/mouse_gene_mapper.py
@@ -1,131 +1,120 @@
-import luigi
-from impc_etl.jobs.extract.allele_ref_extractor import ExtractAlleleRef
-from impc_etl.jobs.extract.gene_ref_extractor import ExtractGeneRef
-from impc_etl.jobs.load.impc_bulk_api.impc_api_mapper import to_camel_case
-from luigi.contrib.spark import PySparkTask
-from pyspark import SparkContext
-from pyspark.sql import SparkSession
-from pyspark.sql.functions import col, collect_set
-
-from impc_etl.jobs.load.impc_kg.impc_kg_helper import add_unique_id, map_unique_ids
-from impc_etl.jobs.load.solr.gene_mapper import GeneLoader
-from impc_etl.workflow.config import ImpcConfig
-
-
-class ImpcKgMouseGeneMapper(PySparkTask):
-    """
-    PySpark Task class to parse GenTar Product report data.
-    """
-
-    #: Name of the Spark task
-    name: str = "ImpcKgMouseGeneMapper"
-
-    #: Path of the output directory where the new parquet file will be generated.
-    output_path: luigi.Parameter = luigi.Parameter()
-
-    def requires(self):
-        return [GeneLoader(), ExtractGeneRef(), ExtractAlleleRef()]
-
-    def output(self):
-        """
-        Returns the full parquet path as an output for the Luigi Task
-        (e.g. impc/dr15.2/parquet/product_report_parquet)
-        """
-        return ImpcConfig().get_target(f"{self.output_path}/impc_kg/mouse_gene_json")
-
-    def app_options(self):
-        """
-        Generates the options pass to the PySpark job
-        """
-        return [
-            self.input()[0].path,
-            self.input()[1].path,
-            self.input()[2].path,
-            self.output().path,
-        ]
-
-    def main(self, sc: SparkContext, *args):
+"""
+Module to generate the mouse gene data as JSON for the KG.
+"""
+import logging
+import textwrap
+
+from airflow.sdk import Variable, asset
+
+from impc_etl.utils.airflow import create_input_asset, create_output_asset
+from impc_etl.utils.spark import with_spark_session
+
+task_logger = logging.getLogger("airflow.task")
+dr_tag = Variable.get("data_release_tag")
+
+gene_ref_parquet_path_asset = create_input_asset("output/gene_ref_parquet")
+allele_ref_parquet_path_asset = create_input_asset("output/allele_ref_parquet")
+gene_data_include_parquet_path_asset = create_input_asset("output/gene_data_include_parquet")
+
+mouse_gene_output_asset = create_output_asset("/impc_kg/mouse_gene_json")
+
+@asset.multi(
+    schedule=[
+        gene_ref_parquet_path_asset,
+        allele_ref_parquet_path_asset,
+        gene_data_include_parquet_path_asset,
+    ],
+    outlets=[mouse_gene_output_asset],
+    dag_id=f"{dr_tag}_impc_kg_mouse_gene_mapper",
+    description=textwrap.dedent(
         """
-        Takes in a SparkContext and the list of arguments generated by `app_options` and executes the PySpark job.
+        PySpark task to create the Knowledge Graph JSON files for
+        mouse gene data from the gene_ref_parquet, allele_ref_parquet and the output of the gene_mapper.
         """
-        spark = SparkSession(sc)
-
-        # Parsing app options
-        gene_parquet_path = args[0]
-        gene_ref_parquet_path = args[1]
-        allele_ref_parquet_path = args[2]
-        output_path = args[3]
-
-        gene_df = spark.read.parquet(gene_parquet_path)
-        gene_ref_df = spark.read.parquet(gene_ref_parquet_path)
-        allele_ref_df = spark.read.parquet(allele_ref_parquet_path)
-
-        gene_df = add_unique_id(
-            gene_df,
-            "mouse_gene_id",
-            ["mgi_accession_id"],
+    ),
+    tags=["impc_kg"],
+)
+@with_spark_session
+def impc_kg_mouse_gene_mapper():
+
+    from impc_etl.jobs.load.impc_web_api.impc_web_api_helper import to_camel_case
+    from impc_etl.jobs.load.impc_kg.impc_kg_helper import add_unique_id, map_unique_ids
+
+    from pyspark.sql import SparkSession
+    from pyspark.sql.functions import collect_set, col
+
+    spark = SparkSession.builder.getOrCreate()
+
+    gene_df = spark.read.parquet(gene_data_include_parquet_path_asset.uri)
+    gene_ref_df = spark.read.parquet(gene_ref_parquet_path_asset.uri)
+    allele_ref_df = spark.read.parquet(allele_ref_parquet_path_asset.uri)
+
+    gene_df = add_unique_id(
+        gene_df,
+        "mouse_gene_id",
+        ["mgi_accession_id"],
+    )
+
+    allele_ref_df = allele_ref_df.select(
+        "mgi_marker_acc_id", "mgi_allele_acc_id"
+    ).distinct()
+    allele_ref_df = allele_ref_df.withColumnRenamed(
+        "mgi_marker_acc_id", "mgi_accession_id"
+    )
+
+    allele_ref_df = allele_ref_df.groupBy("mgi_accession_id").agg(
+        collect_set("mgi_allele_acc_id").alias("mouse_allele_acc_ids")
+    )
+
+    gene_df = gene_df.join(allele_ref_df, "mgi_accession_id", "left_outer")
+
+    gene_df = map_unique_ids(gene_df, "mouse_alleles", "mouse_allele_acc_ids")
+
+    gene_ref_df = gene_ref_df.select("mgi_gene_acc_id", "human_gene_acc_id")
+
+    gene_df = gene_df.join(
+        gene_ref_df, col("mgi_accession_id") == col("mgi_gene_acc_id")
+    ).drop("mgi_gene_acc_id")
+
+    gene_df = map_unique_ids(
+        gene_df, "human_gene_orthologues", "human_gene_acc_id"
+    )
+
+    mouse_gene_col_map = {
+        "marker_name": "name",
+        "marker_symbol": "symbol",
+        "mgi_accession_id": "mgiGeneAccessionId",
+        "marker_synonym": "synonyms",
+        "ensembl_gene_id": "ensembl_acc_id",
+    }
+
+    output_cols = [
+        "marker_name",
+        "marker_symbol",
+        "mgi_accession_id",
+        "marker_synonym",
+        "seq_region_start",
+        "seq_region_end",
+        "chr_strand",
+        "seq_region_id",
+        "entrezgene_id",
+        "ensembl_gene_id",
+        "ccds_id",
+        "ncbi_id",
+        "human_gene_orthologues",
+        "mouse_alleles",
+    ]
+    output_df = gene_df.select(*output_cols).distinct()
+    for col_name in output_df.columns:
+        output_df = output_df.withColumnRenamed(
+            col_name,
+            (
+                to_camel_case(col_name)
+                if col_name not in mouse_gene_col_map
+                else to_camel_case(mouse_gene_col_map[col_name])
+            ),
+        )
+    output_df.distinct().coalesce(1).write.json(
+        mouse_gene_output_asset.uri, mode="overwrite", compression="gzip"
+    )
-        )
-
-        allele_ref_df = allele_ref_df.select(
-            "mgi_marker_acc_id", "mgi_allele_acc_id"
-        ).distinct()
-        allele_ref_df = allele_ref_df.withColumnRenamed(
-            "mgi_marker_acc_id", "mgi_accession_id"
-        )
-
-        allele_ref_df = allele_ref_df.groupBy("mgi_accession_id").agg(
-            collect_set("mgi_allele_acc_id").alias("mouse_allele_acc_ids")
-        )
-
-        gene_df = gene_df.join(allele_ref_df, "mgi_accession_id", "left_outer")
-
-        gene_df = map_unique_ids(gene_df, "mouse_alleles", "mouse_allele_acc_ids")
-
-        gene_ref_df = gene_ref_df.select("mgi_gene_acc_id", "human_gene_acc_id")
-
-        gene_df = gene_df.join(
-            gene_ref_df, col("mgi_accession_id") == col("mgi_gene_acc_id")
-        ).drop("mgi_gene_acc_id")
-
-        gene_df = map_unique_ids(
-            gene_df, "human_gene_orthologues", "human_gene_acc_id"
-        )
-
-        mouse_gene_col_map = {
-            "marker_name": "name",
-            "marker_symbol": "symbol",
-            "mgi_accession_id": "mgiGeneAccessionId",
-            "marker_synonym": "synonyms",
-            "ensembl_gene_id": "ensembl_acc_id",
-        }
-
-        output_cols = [
-            "marker_name",
-            "marker_symbol",
-            "mgi_accession_id",
-            "marker_synonym",
-            "seq_region_start",
-            "seq_region_end",
-            "chr_strand",
-            "seq_region_id",
-            "entrezgene_id",
-            "ensembl_gene_id",
-            "ccds_id",
-            "ncbi_id",
-            "human_gene_orthologues",
-            "mouse_alleles",
-        ]
-        output_df = gene_df.select(*output_cols).distinct()
-        for col_name in output_df.columns:
-            output_df = output_df.withColumnRenamed(
-                col_name,
-                (
-                    to_camel_case(col_name)
-                    if col_name not in mouse_gene_col_map
-                    else to_camel_case(mouse_gene_col_map[col_name])
-                ),
-            )
-        output_df.distinct().coalesce(1).write.json(
-            output_path, mode="overwrite", compression="gzip"
-        )