In [None]:
import os
from pyspark.conf import SparkConf
from pyspark.sql import SparkSession
from pyspark.sql.functions import create_map, lit, col, explode, collect_list, array, struct, when, current_timestamp
from pyspark.sql.types import IntegerType, FloatType
from delta import *
import re

# Warehouse root used both as the Spark SQL warehouse directory and as the
# Derby system home. The value below is a placeholder to be replaced.
warehouse_diretory_path = '[YOUR_WAREHOUSE]'

# Key/value pairs applied to the SparkConf: local master, Arrow for PySpark,
# driver/executor memory and the Delta Lake SQL extension + catalog.
spark_settings = [
    ('spark.master', 'local[*]'),
    ('spark.driver.host', 'localhost'),
    ('spark.app.name', 'Transform TCGA Firebrowse miRNA Level 3 Sample To miRNA Patient Sample'),
    ('spark.ui.showConsoleProgress', 'true'),
    ('spark.sql.execution.arrow.pyspark.enabled', 'true'),
    ('spark.driver.memory', '4g'),
    ('spark.executor.memory', '4g'),
    ('spark.sql.extensions', 'io.delta.sql.DeltaSparkSessionExtension'),
    ('spark.sql.catalog.spark_catalog', 'org.apache.spark.sql.delta.catalog.DeltaCatalog'),
    ('spark.sql.warehouse.dir', warehouse_diretory_path),
    ('spark.driver.extraJavaOptions', f'-Dderby.system.home={warehouse_diretory_path}'),
]

conf = SparkConf()
conf.setAll(spark_settings)

# Build (or reuse) the Hive-enabled session with the configuration above.
spark = (
    SparkSession.builder
    .config(conf=conf)
    .enableHiveSupport()
    .getOrCreate()
)

In [None]:
# --- Column-name slicing ----------------------------------------------------
# Each pattern captures a fixed-width slice of a column name:
#   patient_id_pattern   -> a run of 12 characters
#   tcga_barcode_pattern -> a run of 15 characters
#   condition_pattern    -> the 2 characters following the first 13
# (presumably TCGA barcode segments: patient id, sample barcode and
# sample-type code -- confirm against the bronze table's column names).
patient_id_pattern = re.compile(r'(.{12})', re.IGNORECASE)
tcga_barcode_pattern = re.compile(r'(.{15})', re.IGNORECASE)
condition_pattern = re.compile(r'.{13}(.{2}).*', re.IGNORECASE)
condition_group_index = 1  # capture group of condition_pattern holding the 2-character code

# --- Catalog names and load filter -----------------------------------------
biological_database_name = 'biological_database'
bronze_tcga_mirna_sample_table_name = 'bronze_tcga_firebrowse_mirna_illumina_hiseq_level3_sample'
silver_patient_mirna_sample_table_name = 'silver_patient_mirna_sample'
silver_patient_mirna_sample_view_name = 'silver_patient_mirna_sample_view'

disease = 'Breast Invasive Carcinoma'

# --- Sample-type code -> description ---------------------------------------
# NOTE(review): the label for '07' duplicates the bone-marrow wording and may
# not match the published TCGA code table -- verify before relying on it.
# Codes '15' and '16' share the same description, so they end up with the
# same sample_type value downstream.
tcga_sample_type_codes = {
    '01': 'Primary Solid Tumor',
    '02': 'Recurrent Solid Tumor',
    '03': 'Primary Blood Derived Cancer - Peripheral Blood',
    '04': 'Recurrent Blood Derived Cancer - Bone Marrow',
    '05': 'Additional - New Primary',
    '06': 'Metastatic',
    '07': 'Primary Blood Derived Cancer - Bone Marrow',
    '08': 'Human Tumor Original Cells',
    '09': 'Primary Blood Derived Cancer',
    '10': 'Blood Derived Normal',
    '11': 'Solid Tissue Normal',
    '12': 'Buccal Cell Normal',
    '13': 'EBV Immortalized Normal',
    '14': 'Bone Marrow Normal',
    '15': 'sample type',
    '16': 'sample type',
}

# Codes '10'..'16' are treated as controls, '01'..'09' as cases.
tcga_sample_type_control_codes = [f'{code:02d}' for code in range(10, 17)]
tcga_sample_type_case_codes = [f'{code:02d}' for code in range(1, 10)]

# Spark map column literal(code) -> literal(description); create_map takes the
# keys and values as one flat, alternating sequence.
sample_types = create_map(*(lit(part)
                            for code_and_description in tcga_sample_type_codes.items()
                            for part in code_and_description))

In [None]:
# Make the biological database current so later unqualified table names
# resolve against it.
use_database_statement = f'USE {biological_database_name};'
spark.sql(use_database_statement)

In [None]:
# Read the bronze rows of the selected disease, then discard the two
# bookkeeping columns so only the remaining data columns are kept.
bronze_disease_query = f"""SELECT * FROM {bronze_tcga_mirna_sample_table_name} WHERE disease = '{disease}'"""
silver_mirna_sample_df = spark.sql(bronze_disease_query).drop('disease', 'timestamp')

In [None]:
# The identifier column is the first column whose name does not contain 'TCGA'.
non_tcga_column_names = [name for name in silver_mirna_sample_df.columns if 'TCGA' not in name]
mirna_id_column_name = non_tcga_column_names[0]

# The row whose identifier equals 'miRNA_ID' is taken as an embedded header:
# it is kept as a dict {column name -> value found in that row}.
embedded_header_row = silver_mirna_sample_df.where(f"{mirna_id_column_name} = 'miRNA_ID'").first()
mirna_sample_columns = embedded_header_row.asDict()

In [None]:
# First column of the embedded header row and the value it holds there
# (presumably the miRNA id column and the 'miRNA_ID' marker).
column_name = [*mirna_sample_columns][0]
column_value = mirna_sample_columns[column_name]

# Keep only rows whose value in that column does not contain the marker,
# which removes the embedded header row from the data.
holds_header_marker = col(column_name).contains(column_value)
silver_mirna_sample_df = silver_mirna_sample_df.filter(~holds_header_marker)

In [None]:
# Sample columns are every header-row column except the miRNA identifier
# column. Selecting them by name (instead of by position) stays correct even
# if the identifier column is not the first column of the table.
patient_ids = [name for name in mirna_sample_columns if name != mirna_id_column_name]

# Group the sample columns by the 15-character slice matched by
# tcga_barcode_pattern, in a single pass. The dict keeps first-seen order, so
# the resulting list of groups is deterministic from run to run and each
# group keeps the original column order.
columns_by_barcode = {}
for sample_column in patient_ids:
    barcode = tcga_barcode_pattern.search(sample_column).group()
    columns_by_barcode.setdefault(barcode, []).append(sample_column)

# One list of column names per barcode.
columns = list(columns_by_barcode.values())

In [None]:
def _sample_levels_struct(sample_columns):
    """Build the struct column describing one group of sample columns.

    The struct carries the miRNA id, the patient id and condition code sliced
    from the first column name of the group, plus the group's first three
    columns, each renamed to the label stored for it in mirna_sample_columns.
    """
    return struct(col(mirna_id_column_name).alias('mirna_id'),
                  lit(patient_id_pattern.search(sample_columns[0]).group()).alias('patient_id'),
                  lit(condition_pattern.search(sample_columns[0]).group(condition_group_index)).alias('condition'),
                  *[col(sample_columns[position]).alias(mirna_sample_columns[sample_columns[position]])
                    for position in (0, 1, 2)])


# Case/control label derived from the 2-character condition code; codes in
# neither list become an empty string.
condition_label = (
    when(col('condition').isin(tcga_sample_type_control_codes), 'control')
    .when(col('condition').isin(tcga_sample_type_case_codes), 'case')
    .otherwise('')
)

silver_mirna_sample_df = (
    silver_mirna_sample_df
    # Wide -> long: one struct per sample group, exploded into one row each.
    .select(array([_sample_levels_struct(c) for c in columns]).alias('levels'))
    .select(explode('levels').alias('levels'))
    .select('levels.*')
    # Cast the measurements and fold them into one struct per miRNA.
    .selectExpr('patient_id', 'condition', 'mirna_id', 'CAST(read_count AS int) read_count', 'CAST(reads_per_million_miRNA_mapped AS double) reads_per_million')
    .selectExpr('patient_id', 'condition', 'struct(mirna_id, read_count, reads_per_million) AS mirna_expression_levels')
    # One row per (patient, condition code) holding all of its miRNA levels.
    .groupBy('patient_id', 'condition')
    .agg(collect_list(col('mirna_expression_levels')).alias('mirna_expression_levels'))
    # sample_type must be looked up before 'condition' is replaced by its label.
    .withColumn('sample_type', sample_types[col('condition')])
    .withColumn('condition', condition_label)
    .withColumn('disease', lit(disease))
    .withColumn('timestamp', current_timestamp())
)
                                               

In [None]:
# Persist the transformed samples into the silver Delta table.
#
#   * table missing                      -> create it (partitioned)
#   * table exists, disease not loaded   -> append the new rows
#   * table exists, disease already there -> MERGE upsert, then delete rows of
#                                            THIS disease that are no longer
#                                            present in the new load
#
# The catalog is inspected once instead of twice.
table_exists = any(table.name == silver_patient_mirna_sample_table_name
                   for table in spark.catalog.listTables(biological_database_name))

# Only probe the table when it exists; limit(1) avoids counting every row.
disease_already_loaded = table_exists and \
    spark.table(silver_patient_mirna_sample_table_name) \
         .where(col('disease') == disease) \
         .limit(1) \
         .count() > 0

if not table_exists:
    # First load: create the table. Dynamic partition overwrite is not
    # requested here; it is pointless for a new table and Delta does not
    # accept it together with overwriteSchema.
    silver_mirna_sample_df.write \
        .format('delta') \
        .mode('overwrite') \
        .option('overwriteSchema', 'true') \
        .partitionBy('disease', 'condition', 'sample_type') \
        .saveAsTable(silver_patient_mirna_sample_table_name)
elif not disease_already_loaded:
    # The table holds other diseases only: append so their rows are never
    # touched (an overwrite here could replace unrelated data).
    silver_mirna_sample_df.write \
        .format('delta') \
        .mode('append') \
        .partitionBy('disease', 'condition', 'sample_type') \
        .saveAsTable(silver_patient_mirna_sample_table_name)
else:
    silver_mirna_sample_df.createOrReplaceTempView(silver_patient_mirna_sample_view_name)

    try:
        # Upsert: refresh rows that already exist, insert the new ones.
        spark.sql(f"""MERGE INTO {silver_patient_mirna_sample_table_name} AS target
                      USING {silver_patient_mirna_sample_view_name} AS source
                      ON source.disease = target.disease
                         AND source.patient_id = target.patient_id
                         AND source.condition = target.condition 
                         AND source.sample_type = target.sample_type               
                      WHEN MATCHED THEN 
                        UPDATE SET *
                      WHEN NOT MATCHED THEN INSERT * """)

        # Remove stale rows: target rows with no counterpart in the new load.
        # The anti-join is restricted to the current disease; without that
        # filter every row of every other disease would match and be deleted,
        # because the source view only contains the current disease.
        spark.sql(f"""MERGE INTO {silver_patient_mirna_sample_table_name} AS target
                      USING (SELECT 
                               disease,
                               patient_id,
                               condition,
                               sample_type
                             FROM (SELECT 
                                     target.disease, 
                                     target.patient_id, 
                                     target.condition,
                                     target.sample_type,                               
                                     source.disease AS source_disease
                                   FROM {silver_patient_mirna_sample_table_name} AS target
                                   LEFT JOIN {silver_patient_mirna_sample_view_name} AS source
                                     ON target.disease = source.disease
                                        AND target.patient_id = source.patient_id
                                        AND target.condition = source.condition
                                        AND target.sample_type = source.sample_type)
                             WHERE ISNULL(source_disease)
                               AND disease = '{disease}') AS source
                      ON source.disease = target.disease
                         AND source.patient_id = target.patient_id
                         AND source.condition = target.condition
                         AND source.sample_type = target.sample_type
                      WHEN MATCHED THEN                    
                        DELETE """)
    finally:
        # Always drop the temporary view, even when a MERGE fails.
        spark.catalog.dropTempView(silver_patient_mirna_sample_view_name)

In [None]:
# Stop the Spark session now that the notebook's work is done.
spark.stop()