In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.window import Window
import json
import os
import copy
import configparser
import datetime

In [2]:
def load_cdm_tables(spark, property_ini_file_path):

    """Read the OMOP CDM domain tables over JDBC and return them as dataframes

    Keyword arguments:
    spark -- the spark session object
    property_ini_file_path -- the ini file that contains the jdbc connection parameters

    The DEFAULT section of the ini file is handed to the JDBC reader as its properties and
    its "base_url" entry is used as the connection url. The result is a tuple ordered as
    (visit_occurrence, condition_occurrence, drug_exposure, procedure_occurrence,
    measurement, observation).

    """

    #Parse the properties
    parser = configparser.ConfigParser()
    parser.read(property_ini_file_path)
    jdbc_properties = parser.defaults()
    jdbc_url = jdbc_properties["base_url"]

    #The tables are read in exactly the order in which they are returned
    qualified_table_names = (
        "dbo.visit_occurrence",
        "dbo.condition_occurrence",
        "dbo.drug_exposure",
        "dbo.procedure_occurrence",
        "dbo.measurement",
        "dbo.observation",
    )

    return tuple(
        spark.read.jdbc(jdbc_url, table_name, properties=jdbc_properties)
        for table_name in qualified_table_names
    )

In [3]:
def join_domain_to_visit(domain_tables, visit_occurrence):
    
    """Join domain tables to visit_occurrence
    
    Keyword arguments:
    domain_tables -- the array containing the OMOP domain tables except visit_occurrence
    visit_occurrence -- the OMOP visit_occurrence table
    
    Each domain table is inner-joined to visit_occurrence on visit_occurrence_id, which limits the
    records to those that have a valid visit_occurrence_id. The output columns of every domain table
    are then converted to the same standard format
    (person_id, time_window, standard_concept_id, domain), where time_window is the
    visit_occurrence_id. In this case, co-occurrence is defined as those concept ids that have
    co-occurred within the same visit.
    
    Returns a list with one standardized dataframe per input domain table, in input order.
    
    """
    
    joined_domain_tables = []
    
    for domain_table in domain_tables:
        #extract the domain concept_id from the table fields. E.g. condition_concept_id from condition_occurrence
        #(the first column whose name contains "concept_id" is used)
        concept_id_field = [f.name for f in domain_table.schema.fields if "concept_id" in f.name][0]
        #extract the name of the domain, e.g. "condition" from condition_concept_id
        table_domain_field = concept_id_field.replace("_concept_id", "")
        #limit the domain records to those which have a visit_occurrence_id.
        #The join must use the visit_occurrence argument; relying on a notebook-level
        #global would make the function fail whenever that global is not defined.
        joined_domain_table = domain_table \
            .join(visit_occurrence,
                  domain_table["visit_occurrence_id"] == visit_occurrence["visit_occurrence_id"])
        #standardize the output columns
        joined_domain_tables.append(
            joined_domain_table \
                .select(domain_table["person_id"], 
                    domain_table["visit_occurrence_id"].alias("time_window"), 
                    domain_table[concept_id_field].alias("standard_concept_id"), 
                    lit(table_domain_field).alias("domain"))
        )
        
    return joined_domain_tables

In [4]:
def join_domain_lifetime(domain_tables):

    """Convert OMOP domain tables to the standard format using one lifetime window

    Keyword arguments:
    domain_tables -- the array containing the OMOP domain tables except visit_occurrence

    Every domain table is reduced to the standard columns
    (person_id, time_window, standard_concept_id, domain), with time_window set to the
    constant 1 for every record. Co-occurrence is therefore defined as concept ids that
    have co-occurred at any point within the lifetime of a patient.

    Returns a list with one standardized dataframe per input table, in input order.

    """

    def _to_standard_format(source_table):
        #the first column whose name contains "concept_id", e.g. condition_concept_id
        concept_column = [field.name for field in source_table.schema.fields
                          if "concept_id" in field.name][0]
        #the domain name is the concept column without its suffix, e.g. "condition"
        domain_name = concept_column.replace("_concept_id", "")
        #all records of a patient share the single time_window 1
        single_window_table = source_table.withColumn("time_window", lit(1))
        return single_window_table.select(
            single_window_table["person_id"],
            single_window_table["time_window"],
            single_window_table[concept_column].alias("standard_concept_id"),
            lit(domain_name).alias("domain"))

    return [_to_standard_format(source_table) for source_table in domain_tables]

In [5]:
def join_domain_time_window(domain_tables, span, start_year, end_year):

    """Convert OMOP domain tables to the standard format using fixed-width year windows

    Keyword arguments:
    domain_tables -- the array containing the OMOP domain tables except visit_occurrence
    span -- the span of the time window in years
    start_year -- the start year for the time windows
    end_year -- the end year for the time windows

    Every domain table is reduced to the standard columns
    (person_id, time_window, standard_concept_id, domain).
    The year of a record is taken from the leading four characters of the first column whose
    name contains "date". Records dated before start_year get time_window -1, records dated
    after end_year get int((end_year - start_year) / span) + 1, and all other records get
    (year - start_year) / span cast to an integer.
    Co-occurrence is therefore defined as concept ids that have co-occurred within the same
    time window of a patient.

    Returns a list with one standardized dataframe per input table, in input order.

    """

    #the window that collects every record dated after end_year
    overflow_window = int((end_year - start_year) / span) + 1

    def _first_field_containing(table, fragment):
        #name of the first schema field whose name contains the fragment
        return [field.name for field in table.schema.fields if fragment in field.name][0]

    standardized_tables = []

    for source_table in domain_tables:
        #e.g. condition_concept_id from condition_occurrence
        concept_column = _first_field_containing(source_table, "concept_id")
        #the first date-like column of the domain table
        date_column = _first_field_containing(source_table, "date")
        #e.g. "condition" from condition_concept_id
        domain_name = concept_column.replace("_concept_id", "")

        year_expression = substring(source_table[date_column], 0, 4).cast("integer")
        table_with_year = source_table.withColumn("year", year_expression)

        #the cast applies to the whole when/otherwise expression
        window_expression = when(table_with_year["year"] < start_year, -1) \
            .when(table_with_year["year"] > end_year, overflow_window) \
            .otherwise((table_with_year["year"] - start_year) / span) \
            .cast("integer")
        table_with_window = table_with_year.withColumn("time_window", window_expression)

        #standardize the output columns
        standardized_tables.append(
            table_with_window.select(
                table_with_window["person_id"],
                table_with_window["time_window"],
                table_with_window[concept_column].alias("standard_concept_id"),
                lit(domain_name).alias("domain"))
        )

    return standardized_tables

In [6]:
def create_cooccurrence_matrix(domain_tables,
                               patient_time_period_concept_output,
                               concept_occurrence_output,
                               cooccurrence_matrix_output):
    
    """Create the co-occurrence matrix across all domains
    
    Keyword arguments:
    domain_tables -- the array containing the standardized OMOP domain tables, each providing the
                     columns person_id, time_window and standard_concept_id
    patient_time_period_concept_output -- the path for writing the distinct
                     (person_id, time_window, standard_concept_id) records
    concept_occurrence_output -- the path for writing the concept occurrence matrix
    cooccurrence_matrix_output -- the path for writing the co-occurrence matrix
    
    This function unions all the domain tables and calculates the co-occurrence of concepts for the
    same person within the same time window. All three outputs are written as parquet in
    overwrite mode.
    
    Returns the tuple (cooccurrence_matrix, concept_occurrence_matrix).
    Raises ValueError if domain_tables is empty.
    
    """
    
    domain_tables = list(domain_tables)
    if not domain_tables:
        raise ValueError("domain_tables must contain at least one dataframe")
    
    #Union domain tables
    patient_visit_concept = domain_tables[0]
    for domain_table in domain_tables[1:]:
        patient_visit_concept = patient_visit_concept.union(domain_table)
    
    #person_id, time_window, and concept_id from all domains; concept id 0 is excluded
    patient_visit_concept = patient_visit_concept.select("person_id", "time_window", "standard_concept_id") \
        .distinct() \
        .where(col("standard_concept_id") != 0)
    
    patient_visit_concept.write.mode("overwrite") \
        .parquet(patient_time_period_concept_output)
    
    #Create the concept occurrence matrix: the number of distinct (person, time_window) records per
    #concept, plus a dense sequential id ordered by concept id.
    #NOTE: the window has no partitioning, so Spark evaluates the ranking in a single partition.
    concept_occurrence_matrix = patient_visit_concept \
        .groupBy("standard_concept_id").count() \
        .withColumn("id", dense_rank().over(Window.orderBy("standard_concept_id"))) \
        .select("id", "standard_concept_id", "count")
        
    concept_occurrence_matrix.write.mode("overwrite") \
        .parquet(concept_occurrence_output)
    
    #Make two aliased views of the patient_visit_concept dataframe for the self-join. Qualified
    #column names keep both sides distinguishable without a round trip through an RDD.
    pvc_1 = patient_visit_concept.alias("pvc_1")
    pvc_2 = patient_visit_concept.alias("pvc_2")
    
    #Create the cooccurrence matrix via a self-join where the concept_ids are NOT the same.
    #Both orderings of a pair, (a, b) and (b, a), are kept, so the matrix is symmetric.
    cooccurrence_matrix = pvc_1 \
        .join(pvc_2, (col("pvc_1.person_id") == col("pvc_2.person_id"))
              & (col("pvc_1.time_window") == col("pvc_2.time_window"))) \
        .where(col("pvc_1.standard_concept_id") != col("pvc_2.standard_concept_id")) \
        .select(col("pvc_1.standard_concept_id").alias("standard_concept_id_1"),
                col("pvc_2.standard_concept_id").alias("standard_concept_id_2")) \
        .groupBy("standard_concept_id_1", "standard_concept_id_2").count()
    
    #Two renamed views of the concept occurrence table, one per side of the pair, so that the
    #joins below can be expressed on plain column names.
    concept_1 = concept_occurrence_matrix.select(
        col("standard_concept_id").alias("standard_concept_id_1"),
        col("count").alias("standard_concept_id_1_count"),
        col("id").alias("id_1"))
    concept_2 = concept_occurrence_matrix.select(
        col("standard_concept_id").alias("standard_concept_id_2"),
        col("count").alias("standard_concept_id_2_count"),
        col("id").alias("id_2"))
    
    #Join the cooccurrence matrix to the concept occurrence table to normalize the cooccurrence
    #frequency: normalized_count = 2 * count / (count of concept 1 + count of concept 2)
    cooccurrence_matrix = cooccurrence_matrix \
        .join(concept_1, "standard_concept_id_1") \
        .join(concept_2, "standard_concept_id_2") \
        .withColumn("normalized_count",
                    col("count") * 2 / (col("standard_concept_id_1_count") + col("standard_concept_id_2_count"))) \
        .select("id_1", "id_2", "normalized_count", "standard_concept_id_1", "standard_concept_id_2", "count")
    
    #Save the cooccurrence matrix
    cooccurrence_matrix.write.mode("overwrite").parquet(cooccurrence_matrix_output)
    
    return (cooccurrence_matrix, concept_occurrence_matrix)

In [None]:
if __name__ == "__main__":

    spark = SparkSession.builder.appName("Phenotype Cooccurrence").getOrCreate()
    spark.sparkContext.setLogLevel("WARN")

    v, c, d, p, m, o = load_cdm_tables(spark, "omop_database_properties.ini")
    
    #every domain table except visit_occurrence
    domain_tables = [c, d, p, m, o]
    
    #create_cooccurrence_matrix takes three output paths, in this order: the distinct
    #(person_id, time_window, concept) records, the concept occurrence matrix and the
    #co-occurrence matrix.
    
    #Create the cooccurrence_matrix based on visits
    create_cooccurrence_matrix(join_domain_to_visit(domain_tables, v), 
                               "patient_time_period_concept_visit",
                               "concept_occurrence_visit", "cooccurrence_matrix_visit")
    
    #Create the cooccurrence matrix based on a 5-year time window
    start_year = 1985
    end_year = current_year = datetime.datetime.now().year
    time_window_span = 5
    create_cooccurrence_matrix(join_domain_time_window(domain_tables, time_window_span, start_year, end_year),
                               "patient_time_period_concept_5",
                               "concept_occurrence_5", "cooccurrence_matrix_5")
    
    #Create the cooccurrence_matrix based on the lifetime events.
    #join_domain_lifetime only takes the domain tables; it does not use visit_occurrence.
    create_cooccurrence_matrix(join_domain_lifetime(domain_tables), 
                               "patient_time_period_concept_lifetime",
                               "concept_occurrence_lifetime", "cooccurrence_matrix_lifetime")