In [None]:
import pandas as pd
import os
import subprocess
from datetime import datetime
from matplotlib import pyplot
import matplotlib.pyplot as plt
import seaborn as sns

# Workspace bucket URL; os.getenv returns None when WORKSPACE_BUCKET is unset.
# (The helper functions in the next cell re-read this variable themselves.)
my_bucket = os.getenv('WORKSPACE_BUCKET')

# Get the BigQuery curated dataset for the current workspace context.
# Interpolated into SQL as `{CDR}.<table>`; os.environ[...] raises KeyError if unset.
CDR = os.environ['WORKSPACE_CDR']

from google.cloud import bigquery
# Instantiate a BigQuery client
client = bigquery.Client()
# NOTE(review): upsetplot is not imported in this cell; uncomment the install
# only if a later cell needs it.
#!pip install upsetplot #if necessary

from tqdm import tqdm

# Raise pandas display limits so wide/long summary tables are shown in full.
# NOTE(review): 'display.max_row' presumably resolves to 'display.max_rows'
# through pandas' option-name pattern matching; consider spelling it out.
pd.set_option('display.max_columns', 7000)
pd.set_option('display.max_row', 7000)

In [None]:
def getTable(table_name, folder):
    """Download a tab-separated table from the workspace bucket and load it.

    Copies ``{WORKSPACE_BUCKET}/data/{folder}/{table_name}`` into the current
    working directory with ``gsutil cp`` and parses the local copy.

    Parameters:
    - table_name (str): File name inside the bucket folder; also the local
      file name that is read back.
    - folder (str): Sub-folder under ``data/`` in the workspace bucket.

    Returns:
    - pd.DataFrame: The downloaded file parsed as tab-separated values.

    Raises:
    - OSError: if the WORKSPACE_BUCKET environment variable is not set.
    - subprocess.CalledProcessError: if ``gsutil cp`` fails.
    """
    my_bucket = os.getenv('WORKSPACE_BUCKET')
    if not my_bucket:
        raise OSError("WORKSPACE_BUCKET environment variable is not set")

    # Argument list, no shell: names containing quotes or spaces cannot break
    # (or inject into) the command. check=True raises on a failed copy instead
    # of silently falling through to read an old local file of the same name.
    subprocess.run(
        ["gsutil", "cp", f"{my_bucket}/data/{folder}/{table_name}", "."],
        check=True,
    )

    print(f'[INFO] {table_name} is successfully downloaded into your working space')
    return pd.read_csv(table_name, sep="\t")
def getFile(table_name, folder):
    """Download a file from the workspace bucket into the working directory.

    Copies ``{WORKSPACE_BUCKET}/data/{folder}/{table_name}`` to ``.`` with
    ``gsutil cp`` without loading it.

    Parameters:
    - table_name (str): File name inside the bucket folder.
    - folder (str): Sub-folder under ``data/`` in the workspace bucket.

    Returns:
    - None

    Raises:
    - OSError: if the WORKSPACE_BUCKET environment variable is not set.
    - subprocess.CalledProcessError: if ``gsutil cp`` fails, so a failed
      download is never reported as a success.
    """
    my_bucket = os.getenv('WORKSPACE_BUCKET')
    if not my_bucket:
        raise OSError("WORKSPACE_BUCKET environment variable is not set")

    # Argument list, no shell: the file/folder names are passed verbatim.
    subprocess.run(
        ["gsutil", "cp", f"{my_bucket}/data/{folder}/{table_name}", "."],
        check=True,
    )

    print(f'[INFO] {table_name} is successfully downloaded into your working space')
def describe_tables(**tables):
    """Summarise age-at-visit and sex-at-birth composition of each named table.

    Each keyword argument maps a table name to a DataFrame that has
    ``person_id``, ``AgeAtVisit_Years`` and ``sex_at_birth_source_value``
    columns. Only the first row per ``person_id`` is considered.

    Returns:
    - pd.DataFrame: one row per table, with a ``Table_Name`` column, the
      ``Series.describe()`` statistics of ``AgeAtVisit_Years``, and
      ``male_count``, ``female_count`` and ``mf_ratio`` (male share of the
      male+female participants; missing when both counts are zero).
    """
    per_table = {}
    for name, frame in tables.items():
        people = frame.drop_duplicates("person_id", keep="first")
        sex = people["sex_at_birth_source_value"]

        # Participants whose sex value contains each marker; missing values
        # never match.
        n_male = people[sex.str.contains("SexAtBirth_Male", na=False)].shape[0]
        n_female = people[sex.str.contains("SexAtBirth_Female", na=False)].shape[0]
        n_known = n_male + n_female

        stats = people["AgeAtVisit_Years"].describe()
        stats["male_count"] = n_male
        stats["female_count"] = n_female
        stats["mf_ratio"] = n_male / n_known if n_known > 0 else None
        per_table[name] = stats

    # Tables start out as columns; transpose so each table becomes a row.
    return pd.DataFrame(per_table).T.rename_axis("Table_Name").reset_index()
def annoConsentDate(table):
    """Annotate ``table`` with each participant's latest 'Consent PII' date.

    Queries the CDR for observations whose source concept descends (via
    ``concept_ancestor``) from the 'Consent PII' module concept, restricted to
    the participants in ``table``, and left-joins ``MAX(observation_date)`` per
    person as ``primary_consent_date``.

    Parameters:
    - table (pd.DataFrame): must contain a ``person_id`` column.

    Returns:
    - pd.DataFrame: ``table`` plus a ``primary_consent_date`` column (missing
      for participants without a matching observation). When there are no
      non-null person IDs, no query is run and the column is all-NaT.
    """
    # Drop missing IDs and coerce to int so only integer literals are ever
    # spliced into the SQL text (no 'nan' or other stray tokens).
    person_ids = [int(pid) for pid in table['person_id'].dropna().unique()]

    if not person_ids:
        # `IN ()` is a SQL syntax error, and there is nothing to look up.
        annotated = table.copy()
        annotated['primary_consent_date'] = pd.NaT
        return annotated

    person_ids_query = ','.join(map(str, person_ids))

    # Run the consent query using the person IDs from the input table.
    consent_query = f'''
    SELECT DISTINCT person_id, MAX(observation_date) AS primary_consent_date
    FROM `{CDR}.concept`
    JOIN `{CDR}.concept_ancestor` ON concept_id = ancestor_concept_id
    JOIN `{CDR}.observation` ON descendant_concept_id = observation_source_concept_id
    WHERE concept_name = 'Consent PII' 
      AND concept_class_id = 'Module'
      AND person_id IN ({person_ids_query})
    GROUP BY person_id
    '''
    consent_dates = pd.read_gbq(consent_query, progress_bar_type="tqdm_notebook")

    # Left join keeps every input row, including people with no consent record.
    return table.merge(consent_dates, on='person_id', how='left')
def describe_solution_tables(**tables):
    """Summarise age-at-consent and sex-at-birth composition of each named table.

    Each keyword argument maps a table name to a DataFrame with ``person_id``,
    ``AgeAtConsent_Years`` (e.g. as added by ``calcAgeAtConsent``) and
    ``sex_at_birth_source_value`` columns. Only the first row per
    ``person_id`` is considered.

    Returns:
    - pd.DataFrame: one row per table, with a ``Table_Name`` column, the
      ``Series.describe()`` statistics of ``AgeAtConsent_Years``, and
      ``male_count``, ``female_count`` and ``mf_ratio`` (male share of the
      male+female participants; missing when both counts are zero).
    """
    def _summarise(frame):
        first_rows = frame.drop_duplicates("person_id", keep="first")
        sex_values = first_rows["sex_at_birth_source_value"]
        males = first_rows[sex_values.str.contains("SexAtBirth_Male", na=False)].shape[0]
        females = first_rows[sex_values.str.contains("SexAtBirth_Female", na=False)].shape[0]

        summary = first_rows["AgeAtConsent_Years"].describe()
        summary["male_count"] = males
        summary["female_count"] = females
        summary["mf_ratio"] = males / (males + females) if (males + females) > 0 else None
        return summary

    summaries = {name: _summarise(frame) for name, frame in tables.items()}

    # One column per table -> transpose to one row per table.
    result_df = pd.DataFrame(summaries).T
    result_df.index.name = "Table_Name"
    return result_df.reset_index()
def calcAgeToday(solution_person_table, today=None):
    """Return a copy of the table with each participant's current age.

    ``birth_datetime`` is parsed with ``pd.to_datetime`` (unparseable values
    become NaT) and made timezone-naive; rows without a valid birth date are
    dropped and only their count is reported.

    Parameters:
    - solution_person_table (pd.DataFrame): must contain ``birth_datetime``.
      It is not modified.
    - today (datetime, optional): reference date; defaults to
      ``datetime.now()``. Pass a fixed value for reproducible results.

    Returns:
    - pd.DataFrame: copy with a normalised ``birth_datetime``, an ``AgeToday``
      timedelta column and ``AgeToday_Years`` (whole days / 365.25).

    Raises:
    - ValueError: if the age difference is not a timedelta64 column.
    """
    table = solution_person_table.copy()

    # Strip any timezone so the subtraction from a naive `today` works.
    table["birth_datetime"] = pd.to_datetime(
        table["birth_datetime"], errors="coerce"
    ).dt.tz_localize(None)

    invalid = table["birth_datetime"].isnull()
    if invalid.any():
        # Report only the count: printing the rows would dump participant-level
        # records into the notebook output.
        print(
            f"Warning: {int(invalid.sum())} row(s) with missing or invalid "
            "birth_datetime values detected; they will be dropped."
        )

    table = table.dropna(subset=["birth_datetime"])

    if today is None:
        today = datetime.now()

    table["AgeToday"] = today - table["birth_datetime"]

    if not pd.api.types.is_timedelta64_dtype(table["AgeToday"]):
        raise ValueError("AgeToday is not a valid timedelta64 dtype.")

    # Whole days converted to years using the mean Julian year length.
    table["AgeToday_Years"] = table["AgeToday"].dt.days / 365.25

    return table
def calcAgeAtConsent(ICD_person_table):
    """Add age-at-consent columns to ``ICD_person_table`` (in place) and return it.

    Parameters:
    - ICD_person_table (pd.DataFrame): must contain ``primary_consent_date``
      and ``birth_datetime``. Those two columns are left unchanged.

    Returns:
    - pd.DataFrame: the same object with ``AgeAtConsent`` (timedelta) and
      ``AgeAtConsent_Years`` (whole days / 365.25) added. Values that cannot
      be parsed as dates yield NaT / NaN.
    """
    def _as_naive_datetime(values):
        # birth_datetime may be tz-aware (calcAgeToday strips the tz for the
        # same reason) and consent dates may be plain date objects; mixing
        # those in a subtraction raises, so bring both to tz-naive datetime64.
        converted = pd.to_datetime(values, errors="coerce")
        if converted.dt.tz is not None:
            converted = converted.dt.tz_localize(None)
        return converted

    consent = _as_naive_datetime(ICD_person_table["primary_consent_date"])
    birth = _as_naive_datetime(ICD_person_table["birth_datetime"])

    ICD_person_table["AgeAtConsent"] = consent - birth

    # Whole days converted to years using the mean Julian year length.
    ICD_person_table["AgeAtConsent_Years"] = (
        ICD_person_table["AgeAtConsent"].dt.days / 365.25
    )

    return ICD_person_table
def calc_solved_rate(summary_df2):
    """Add a ``solved_rate`` column to a ``describe_tables`` summary.

    Expects row 0 to be the concept cohort (generate_merges lists the concept
    table first) and the other rows to be named ``"{inheritance}_in_{base}"``.

    - Row 0: sum of ``count`` over rows whose ``Table_Name`` contains neither
      "concept" nor "female" (case-insensitive), divided by row 0's ``count``.
    - Each subset in AD, XLmale, XLfemale, AR, ADAR: sum of ``count`` over the
      rows named ``"{subset}_..."`` divided by the number of unique
      ``person_id`` in the global ``{subset}_solutions`` table. Subsets whose
      solutions table is not defined are reported and get no rate.

    Parameters:
    - summary_df2 (pd.DataFrame): must contain ``Table_Name`` and ``count``.

    Returns:
    - pd.DataFrame: ``summary_df2.reset_index()`` (an ``index`` column is
      added) with ``solved_rate`` filled in.
    """
    concept_solved_n = summary_df2[~summary_df2['Table_Name'].str.contains("concept|female", case=False)]["count"].sum()
    summary_df2 = summary_df2.reset_index()

    concept_n = summary_df2.loc[0, "count"]
    summary_df2.loc[0, "solved_rate"] = concept_solved_n / concept_n

    for subset in ["AD", "XLmale", "XLfemale", "AR", "ADAR"]:
        solution_table_name = f"{subset}_solutions"
        if solution_table_name not in globals():
            print(f"{solution_table_name} does not exist.")
            continue

        tot_sol = globals()[solution_table_name]["person_id"].nunique()

        # Select the subset's row(s) by name rather than by position: row order
        # follows the solution dict given to generate_merges, and a positional
        # .loc[k] write past the last row would append a spurious row.
        subset_rows = summary_df2["Table_Name"].str.contains(f"^{subset}_", case=False)
        sol_anno = summary_df2.loc[subset_rows, "count"].sum()
        summary_df2.loc[subset_rows, "solved_rate"] = sol_anno / tot_sol

    return summary_df2
def calc_solved_rate_within(summary_df2):
    """Add a ``solved_rate`` column where every rate is relative to the concept cohort.

    Expects row 0 to be the concept cohort (generate_merges lists the concept
    table first) and the other rows to be named ``"{inheritance}_in_{base}"``.

    - Row 0: sum of ``count`` over rows whose ``Table_Name`` contains neither
      "concept" nor "female" (case-insensitive), divided by row 0's ``count``.
    - Each subset in AD, XLmale, XLfemale, AR, ADAR: sum of ``count`` over the
      rows named ``"{subset}_..."`` divided by row 0's ``count``. A rate is only
      written when the global ``{subset}_solutions`` table exists; otherwise
      the subset is reported.

    Parameters:
    - summary_df2 (pd.DataFrame): must contain ``Table_Name`` and ``count``.

    Returns:
    - pd.DataFrame: ``summary_df2.reset_index()`` (an ``index`` column is
      added) with ``solved_rate`` filled in.
    """
    concept_solved_n = summary_df2[~summary_df2['Table_Name'].str.contains("concept|female", case=False)]["count"].sum()
    summary_df2 = summary_df2.reset_index()

    concept_n = summary_df2.loc[0, "count"]
    summary_df2.loc[0, "solved_rate"] = concept_solved_n / concept_n

    for subset in ["AD", "XLmale", "XLfemale", "AR", "ADAR"]:
        solution_table_name = f"{subset}_solutions"
        if solution_table_name not in globals():
            print(f"{solution_table_name} does not exist.")
            continue

        # Select the subset's row(s) by name rather than by position: row order
        # follows the solution dict given to generate_merges, and a positional
        # .loc[k] write past the last row would append a spurious row.
        subset_rows = summary_df2["Table_Name"].str.contains(f"^{subset}_", case=False)
        sol_anno = summary_df2.loc[subset_rows, "count"].sum()
        summary_df2.loc[subset_rows, "solved_rate"] = sol_anno / concept_n

    return summary_df2
def generate_merges(base_name, concept_person_table, solution_dict):
    """Left-join each solution table onto a concept cohort and publish the results as globals.

    For every ``(solution_name, solution_table)`` in ``solution_dict`` a global
    named ``"{prefix}_in_{base_name}"`` is (re)assigned, where ``prefix`` is the
    part of ``solution_name`` before its first underscore (e.g.
    ``"AD_solutions"`` -> ``"AD_in_{base_name}"``). Each holds the solution's
    ``variant_id``, ``person_id``, ``allele_count``, ``GeneID_EG`` and
    ``GeneID_Symbol`` columns left-joined to ``concept_person_table`` on
    ``person_id``. ``concept_person_table`` itself is published as
    ``"{base_name}_concept_person"``.

    Parameters:
    - base_name (str): Suffix used in the generated table names.
    - concept_person_table (pd.DataFrame): Concept cohort, joined on ``person_id``.
    - solution_dict (dict): Solution table name -> DataFrame.

    Returns:
    - list: the concept table name first, then the merged table names in
      ``solution_dict`` order.
    """
    keep_cols = ['variant_id', 'person_id', 'allele_count', 'GeneID_EG', 'GeneID_Symbol']
    namespace = globals()
    merged_names = []

    for solution_name, solution_table in solution_dict.items():
        prefix = solution_name.split('_')[0]
        merged_name = f"{prefix}_in_{base_name}"
        merged_names.append(merged_name)

        namespace[merged_name] = solution_table[keep_cols].merge(
            concept_person_table, on="person_id", how="left"
        )
        print(f"Generated table: {merged_name}")

    concept_name = f"{base_name}_concept_person"
    namespace[concept_name] = concept_person_table
    return [concept_name] + merged_names
def process_concept_and_solution_tables(concept_tables, solution_tables):
    """Build solved-rate summaries for every concept table and stack them.

    For each concept table, ``generate_merges`` publishes the concept table and
    its intersections with every solution table as globals (overwriting any
    previous tables with the same names). ``describe_tables`` then summarises
    those tables and ``calc_solved_rate`` adds the ``solved_rate`` column.

    Parameters:
    - concept_tables (dict): Concept table name -> DataFrame. The part of the
      name before the first underscore (e.g. "retdys") is the merge base name.
    - solution_tables (dict): Solution table name -> DataFrame.

    Returns:
    - pd.DataFrame: all per-concept summaries concatenated (original row
      indices kept), with a ``concept`` column naming the source concept
      table. An empty DataFrame when ``concept_tables`` is empty.
    """
    all_summaries = []

    for concept_name, concept_table in concept_tables.items():
        base_name = concept_name.split("_")[0]  # e.g. "retdys"

        merged_table_names = generate_merges(base_name, concept_table, solution_tables)
        merged_tables = {name: globals()[name] for name in merged_table_names}

        solved_summary = calc_solved_rate(describe_tables(**merged_tables))
        solved_summary["concept"] = concept_name
        all_summaries.append(solved_summary)

    if not all_summaries:
        # pd.concat([]) raises "No objects to concatenate".
        return pd.DataFrame()

    return pd.concat(all_summaries, axis=0)
def process_concept_and_solution_tables_within(concept_tables, solution_tables):
    """Build within-concept solved-rate summaries for every concept table and stack them.

    Same pipeline as the non-"within" variant, but rates come from
    ``calc_solved_rate_within`` (every subset rate is relative to the concept
    cohort size). ``generate_merges`` publishes the merged tables as globals,
    overwriting any previous tables with the same names.

    Parameters:
    - concept_tables (dict): Concept table name -> DataFrame. The part of the
      name before the first underscore (e.g. "retdys") is the merge base name.
    - solution_tables (dict): Solution table name -> DataFrame.

    Returns:
    - pd.DataFrame: all per-concept summaries concatenated (original row
      indices kept), with a ``concept`` column naming the source concept
      table. An empty DataFrame when ``concept_tables`` is empty.
    """
    all_summaries = []

    for concept_name, concept_table in concept_tables.items():
        base_name = concept_name.split("_")[0]  # e.g. "retdys"

        merged_table_names = generate_merges(base_name, concept_table, solution_tables)
        merged_tables = {name: globals()[name] for name in merged_table_names}

        solved_summary = calc_solved_rate_within(describe_tables(**merged_tables))
        solved_summary["concept"] = concept_name
        all_summaries.append(solved_summary)

    if not all_summaries:
        # pd.concat([]) raises "No objects to concatenate".
        return pd.DataFrame()

    return pd.concat(all_summaries, axis=0)
def process_concept_solution_gene_summaries(concept_tables, solution_tables):
    """Per-gene counts of concept-annotated variants and participants.

    For each concept table, ``generate_merges`` left-joins every solution table
    onto it and publishes the results as globals. For each resulting table
    that has a ``GeneID_Symbol`` column, rows are grouped by gene, and only rows
    with a non-null ``condition_source_value`` are counted. Those are presumably
    the rows where the left join found a concept record for the person.

    Parameters:
    - concept_tables (dict): Concept table name -> DataFrame. The part of the
      name before the first underscore (e.g. "retdys") is the merge base name.
    - solution_tables (dict): Solution table name -> DataFrame.

    Returns:
    - pd.DataFrame: columns ``concept``, ``table``, ``gene``,
      ``variant_count``, ``person_count`` and ``solution_table`` (the
      ``table`` name up to its first underscore, e.g. "AD"). When no gene
      summaries are produced, an empty frame with these columns is returned.
    """
    summary_columns = ["concept", "table", "gene", "variant_count", "person_count"]
    all_gene_summaries = []

    for concept_name, concept_table in concept_tables.items():
        base_name = concept_name.split("_")[0]  # e.g. "retdys"

        merged_table_names = generate_merges(base_name, concept_table, solution_tables)
        merged_tables = {name: globals()[name] for name in merged_table_names}

        for table_name, merged_table in merged_tables.items():
            if "GeneID_Symbol" not in merged_table.columns:
                print(f"Warning: 'GeneID_Symbol' column not found in {table_name}. Skipping...")
                continue

            for gene, group in merged_table.groupby("GeneID_Symbol"):
                # Rows with no condition are unannotated by the concept join.
                annotated = group[group["condition_source_value"].notnull()]
                all_gene_summaries.append({
                    "concept": concept_name,
                    "table": table_name,
                    "gene": gene,
                    "variant_count": annotated["variant_id"].nunique(),
                    "person_count": annotated["person_id"].nunique(),
                })

    # Explicit columns keep the schema (and the `table` column used below)
    # even when nothing was summarised.
    gene_summary_table = pd.DataFrame(all_gene_summaries, columns=summary_columns)
    gene_summary_table["solution_table"] = gene_summary_table["table"].str.split("_").str[0]

    return gene_summary_table
def process_solution_gene_summaries(solution_tables):
    """Per-gene variant, participant and allele-count statistics for each solution table.

    Tables without a ``GeneID_Symbol`` column are skipped with a warning.

    Parameters:
    - solution_tables (dict): Solution table name -> DataFrame with
      ``GeneID_Symbol``, ``variant_id``, ``person_id`` and ``allele_count``.

    Returns:
    - pd.DataFrame: one row per (table, gene) with ``variant_count``,
      ``person_count`` (unique counts) and the mean/std/min/max of
      ``allele_count``.
    """
    records = []

    for table_name, frame in solution_tables.items():
        if "GeneID_Symbol" not in frame.columns:
            print(f"Warning: 'GeneID_Symbol' column not found in {table_name}. Skipping...")
            continue

        for gene, rows in frame.groupby("GeneID_Symbol"):
            records.append({
                "table": table_name,
                "gene": gene,
                "variant_count": rows["variant_id"].nunique(),
                "person_count": rows["person_id"].nunique(),
                "allele_count_mean": rows["allele_count"].mean(),
                "allele_count_std": rows["allele_count"].std(),
                "allele_count_min": rows["allele_count"].min(),
                "allele_count_max": rows["allele_count"].max(),
            })

    return pd.DataFrame(records)
def calculate_annotation_rates(solution_gene_summary, concept_gene_summary):
    """Gene-level person and variant annotation rates.

    Inner-joins the concept (intersection) summary with the solution summary
    on ``table`` and ``gene``. Each rate is concept count / solution count;
    NaN and +inf results (e.g. division by zero) are replaced with 0.

    Parameters:
    - solution_gene_summary (pd.DataFrame): output of
      ``process_solution_gene_summaries``.
    - concept_gene_summary (pd.DataFrame): output of
      ``process_concept_solution_gene_summaries``.

    Returns:
    - pd.DataFrame: the joined rows (overlapping columns suffixed
      ``_concept`` / ``_solution``) plus ``person_annotation_rate`` and
      ``variant_annotation_rate``.
    """
    merged = concept_gene_summary.merge(
        solution_gene_summary,
        how="inner",
        on=["table", "gene"],
        suffixes=("_concept", "_solution"),
    )

    for level in ("person", "variant"):
        ratio = merged[f"{level}_count_concept"] / merged[f"{level}_count_solution"]
        merged[f"{level}_annotation_rate"] = ratio.fillna(0).replace([float('inf')], 0)

    return merged
def calculate_all_annotation_rates(solution_gene_summary, concept_gene_summary):
    """Gene-level annotation rates for every (concept, inheritance) pairing.

    Category labels are derived from the ``table`` names and written onto the
    caller's DataFrames (in place):

    - concept rows ``"{inheritance}_in_{concept}"`` get ``category`` = concept
      and ``category_inheritance`` = inheritance (e.g. "AD_in_retdys" ->
      "retdys", "AD");
    - solution rows ``"{inheritance}_solutions"`` get ``category`` = inheritance.

    Unmatched names get "". For each concept category and each solution
    category, concept rows whose inheritance equals the solution category are
    inner-joined on ``gene`` with that category's solution rows. The
    person/variant annotation rates (concept count / solution count) are then
    computed, with NaN and +inf replaced by 0.

    Parameters:
    - solution_gene_summary (pd.DataFrame): Gene-level summaries for solution tables.
    - concept_gene_summary (pd.DataFrame): Gene-level summaries for intersection tables.

    Returns:
    - pd.DataFrame: all joined rows with ``person_annotation_rate`` and
      ``variant_annotation_rate``; empty if no category pair matched.

    Raises:
    - ValueError: if either summary lacks table/gene/person_count/variant_count.
    """
    concept_gene_summary["category"] = concept_gene_summary["table"].str.extract(r"in_([a-zA-Z]+)", expand=False).fillna("")
    concept_gene_summary["category_inheritance"] = concept_gene_summary["table"].str.extract(r"([a-zA-Z]+)_in", expand=False).fillna("")
    solution_gene_summary["category"] = solution_gene_summary["table"].str.extract(r"([a-zA-Z]+)_solutions", expand=False).fillna("")

    required_columns = {"table", "gene", "person_count", "variant_count", "category"}
    if not required_columns.issubset(solution_gene_summary.columns):
        raise ValueError("Solution gene summary is missing required columns.")
    if not required_columns.issubset(concept_gene_summary.columns):
        raise ValueError("Concept gene summary is missing required columns.")

    all_annotation_rates = []

    for con_category in concept_gene_summary["category"].unique():
        concept_subset = concept_gene_summary[concept_gene_summary["category"] == con_category]

        for sol_category in solution_gene_summary["category"].unique():
            # Exact equality on both sides: a prefix match ("^AD") would also
            # select the "ADAR" solutions, duplicating every gene shared by the
            # AD and ADAR solution tables in the merge below.
            solution_subset = solution_gene_summary[solution_gene_summary["category"] == sol_category]
            concept_subset_inh = concept_subset[concept_subset["category_inheritance"] == sol_category]

            if concept_subset_inh.empty or solution_subset.empty:
                print(f"Skipping category '{con_category}' / '{sol_category}' because one of the subsets is empty.")
                continue

            print(f"Processing category '{con_category}' / '{sol_category}'")

            merged = pd.merge(
                concept_subset_inh,
                solution_subset[['variant_count', 'person_count', 'gene']],
                on=["gene"],
                suffixes=("_concept", "_solution"),
                how="inner"
            )

            merged["person_annotation_rate"] = (
                merged["person_count_concept"] / merged["person_count_solution"]
            )
            merged["variant_annotation_rate"] = (
                merged["variant_count_concept"] / merged["variant_count_solution"]
            )

            # Division by zero yields inf/NaN; report those as a 0 rate.
            merged["person_annotation_rate"] = merged["person_annotation_rate"].fillna(0).replace([float('inf')], 0)
            merged["variant_annotation_rate"] = merged["variant_annotation_rate"].fillna(0).replace([float('inf')], 0)

            all_annotation_rates.append(merged)

    if all_annotation_rates:
        return pd.concat(all_annotation_rates, axis=0, ignore_index=True)
    return pd.DataFrame()
def process_solution_variant_summaries(solution_tables):
    """Per-variant participant and allele-count statistics for each solution table.

    Tables without a ``variant_id`` column are skipped with a warning. The
    ``gene`` field is the first ``GeneID_Symbol`` in each variant group, or
    None if the table has no such column.

    Parameters:
    - solution_tables (dict): Solution table name -> DataFrame with
      ``variant_id``, ``person_id`` and ``allele_count``.

    Returns:
    - pd.DataFrame: one row per (table, variant) with ``gene``,
      ``person_count`` (unique) and mean/std/min/max of ``allele_count``.
    """
    records = []

    for table_name, frame in solution_tables.items():
        if "variant_id" not in frame.columns:
            print(f"Warning: 'variant_id' column not found in {table_name}. Skipping...")
            continue

        gene_known = "GeneID_Symbol" in frame.columns

        for variant, rows in frame.groupby("variant_id"):
            records.append({
                "table": table_name,
                "variant_id": variant,
                "gene": rows["GeneID_Symbol"].iloc[0] if gene_known else None,
                "person_count": rows["person_id"].nunique(),
                "allele_count_mean": rows["allele_count"].mean(),
                "allele_count_std": rows["allele_count"].std(),
                "allele_count_min": rows["allele_count"].min(),
                "allele_count_max": rows["allele_count"].max(),
            })

    return pd.DataFrame(records)
def process_concept_solution_variant_summaries(concept_tables, solution_tables):
    """Per-variant counts of concept-annotated participants.

    For each concept table, ``generate_merges`` left-joins every solution table
    onto it and publishes the results as globals. Tables lacking ``variant_id``
    are skipped with a warning. For the others, rows are grouped by variant and
    participants are counted only where ``condition_source_value`` is non-null.

    Parameters:
    - concept_tables (dict): Concept table name -> DataFrame. The part of the
      name before the first underscore (e.g. "retdys") is the merge base name.
    - solution_tables (dict): Solution table name -> DataFrame.

    Returns:
    - pd.DataFrame: one row per (concept, table, variant) with ``gene`` (first
      ``GeneID_Symbol`` in the group, or None) and ``person_count``.
    """
    rows = []

    for concept_name, concept_table in concept_tables.items():
        base = concept_name.split("_")[0]  # e.g. "retdys"
        table_names = generate_merges(base, concept_table, solution_tables)
        tables = {name: globals()[name] for name in table_names}

        for table_name, merged in tables.items():
            if "variant_id" not in merged.columns:
                print(f"Warning: 'variant_id' column not found in {table_name}. Skipping...")
                continue

            has_gene = "GeneID_Symbol" in merged.columns

            for variant, group in merged.groupby("variant_id"):
                # Rows with no condition are unannotated by the concept join.
                annotated = group[group.condition_source_value.notnull()]
                rows.append({
                    "concept": concept_name,
                    "table": table_name,
                    "variant_id": variant,
                    "gene": group["GeneID_Symbol"].iloc[0] if has_gene else None,
                    "person_count": annotated.person_id.nunique(),
                })

    return pd.DataFrame(rows)
def calculate_all_variant_annotation_rates(solution_variant_summary, concept_variant_summary):
    """Variant-level annotation rates for every (concept, inheritance) pairing.

    Category labels are derived from the ``table`` names and written onto the
    caller's DataFrames (in place):

    - concept rows ``"{inheritance}_in_{concept}"`` get ``category`` = concept
      and ``category_inheritance`` = inheritance;
    - solution rows ``"{inheritance}_solutions"`` get ``category`` = inheritance.

    Unmatched names get "". For each concept category and each solution
    category, concept rows whose inheritance equals the solution category are
    inner-joined on ``variant_id`` with that category's solution rows. Two
    ratios are computed: ``person_annotation_rate`` (concept / solution
    ``person_count``) and ``allele_annotation_rate`` (concept / solution
    ``allele_count_mean``). NaN and +inf are replaced by 0.

    Parameters:
    - solution_variant_summary (pd.DataFrame): Variant-level summaries for solution tables.
    - concept_variant_summary (pd.DataFrame): Variant-level summaries for intersection tables.

    Returns:
    - pd.DataFrame: all joined rows with the two rates; empty if no category
      pair matched.

    Raises:
    - ValueError: if either summary lacks table/variant_id/person_count/allele_count_mean.
    """
    concept_variant_summary["category"] = concept_variant_summary["table"].str.extract(r"in_([a-zA-Z]+)", expand=False).fillna("")
    concept_variant_summary["category_inheritance"] = concept_variant_summary["table"].str.extract(r"([a-zA-Z]+)_in", expand=False).fillna("")
    solution_variant_summary["category"] = solution_variant_summary["table"].str.extract(r"([a-zA-Z]+)_solutions", expand=False).fillna("")

    # NOTE(review): process_concept_solution_variant_summaries currently has its
    # allele_count_* fields commented out, so its output fails the concept-side
    # check below. Confirm which concept summary is meant to be passed here.
    required_columns = {"table", "variant_id", "person_count", "allele_count_mean", "category"}
    if not required_columns.issubset(solution_variant_summary.columns):
        raise ValueError("Solution variant summary is missing required columns.")
    if not required_columns.issubset(concept_variant_summary.columns):
        raise ValueError("Concept variant summary is missing required columns.")

    all_annotation_rates = []

    for con_category in concept_variant_summary["category"].unique():
        concept_subset = concept_variant_summary[concept_variant_summary["category"] == con_category]

        for sol_category in solution_variant_summary["category"].unique():
            # Exact equality on both sides: a prefix match ("^AD") would also
            # select the "ADAR" solutions, duplicating every variant shared by
            # the AD and ADAR solution tables in the merge below.
            solution_subset = solution_variant_summary[solution_variant_summary["category"] == sol_category]
            concept_subset_inh = concept_subset[concept_subset["category_inheritance"] == sol_category]

            if concept_subset_inh.empty or solution_subset.empty:
                print(f"Skipping category '{con_category}' / '{sol_category}' because one of the subsets is empty.")
                continue

            print(f"Processing category '{con_category}' / '{sol_category}'")

            merged = pd.merge(
                concept_subset_inh,
                solution_subset[['variant_id', 'person_count', 'allele_count_mean']],
                on=["variant_id"],
                suffixes=("_concept", "_solution"),
                how="inner"
            )

            merged["person_annotation_rate"] = (
                merged["person_count_concept"] / merged["person_count_solution"]
            )
            merged["allele_annotation_rate"] = (
                merged["allele_count_mean_concept"] / merged["allele_count_mean_solution"]
            )

            # Division by zero yields inf/NaN; report those as a 0 rate.
            merged["person_annotation_rate"] = merged["person_annotation_rate"].fillna(0).replace([float('inf')], 0)
            merged["allele_annotation_rate"] = merged["allele_annotation_rate"].fillna(0).replace([float('inf')], 0)

            all_annotation_rates.append(merged)

    if all_annotation_rates:
        return pd.concat(all_annotation_rates, axis=0, ignore_index=True)
    return pd.DataFrame()
def saveToBucket(df, df_filename, data_folder):
    """Write ``df`` to a local TSV file and copy it into the workspace bucket.

    Parameters:
    - df (pd.DataFrame): Table to save (written tab-separated, without the index).
    - df_filename (str): Local file name to write in the current directory; the
      file keeps this name in the bucket.
    - data_folder (str): Sub-folder under ``{WORKSPACE_BUCKET}/data/`` to copy into.

    gsutil's stderr (where it reports progress and errors) is printed, and a
    warning is printed if gsutil exits with a non-zero status.
    """
    df.to_csv(df_filename, sep="\t", index=False)

    # get the bucket name
    my_bucket = os.getenv('WORKSPACE_BUCKET')

    # copy the tsv file to the bucket; an argument list avoids shell interpolation
    args = ["gsutil", "cp", f"./{df_filename}", f"{my_bucket}/data/{data_folder}/"]
    output = subprocess.run(args, capture_output=True, text=True)

    # print output from gsutil. A bare `output.stderr` expression inside a
    # function is a no-op, so it must be printed explicitly to be seen.
    if output.stderr:
        print(output.stderr)
    if output.returncode != 0:
        print(f"[WARNING] gsutil cp exited with status {output.returncode} for {df_filename}")
def concatenate_variant_columns(df, chrom_col='CHROM', pos_col='POS', ref_col='REF', alt_col='ALT', new_col='Variant'):
    """
    Build a single dash-joined variant identifier (e.g. ``chr1-100-A-G``).

    The four source columns are cast to ``str`` and joined with ``'-'`` in the
    order chromosome, position, reference, alternate. The result is written
    into ``df[new_col]``. ``df`` is modified in place and also returned.

    Parameters:
    - df (pd.DataFrame): Table holding the four variant component columns.
    - chrom_col (str): Chromosome column name. Default is 'CHROM'.
    - pos_col (str): Position column name. Default is 'POS'.
    - ref_col (str): Reference-allele column name. Default is 'REF'.
    - alt_col (str): Alternate-allele column name. Default is 'ALT'.
    - new_col (str): Name of the column to create. Default is 'Variant'.

    Returns:
    - pd.DataFrame: The same ``df`` object, now including ``new_col``.

    Raises:
    - ValueError: If any of the four component columns is absent (the first
      missing one, in the order above, is reported).
    """
    components = [chrom_col, pos_col, ref_col, alt_col]
    absent = [name for name in components if name not in df.columns]
    if absent:
        raise ValueError(f"Column '{absent[0]}' not found in the DataFrame.")

    # Left-fold the string-cast columns with '-' separators, in the same
    # left-to-right order as a chained `a + '-' + b + ...` expression.
    as_text = [df[name].astype(str) for name in components]
    joined = as_text[0]
    for piece in as_text[1:]:
        joined = joined + '-' + piece
    df[new_col] = joined

    return df
def split_variant_column(df, variant_col='Variant', chrom_col='CHROM', pos_col='POS', ref_col='REF', alt_col='ALT', sep='-'):
    """
    Break a delimited variant identifier into its four component columns.

    ``df[variant_col]`` is split on ``sep`` using ``Series.str.split(expand=True)``.
    The resulting pieces are written to the chromosome, position, reference and
    alternate columns, in that order. All outputs stay strings; no numeric
    conversion is done. ``df`` is modified in place and also returned.

    Parameters:
    - df (pd.DataFrame): Table holding the variant identifier column.
    - variant_col (str): Column to split. Default is 'Variant'.
    - chrom_col (str): Output chromosome column. Default is 'CHROM'.
    - pos_col (str): Output position column. Default is 'POS'.
    - ref_col (str): Output reference-allele column. Default is 'REF'.
    - alt_col (str): Output alternate-allele column. Default is 'ALT'.
    - sep (str): Delimiter between components. Default is '-'.

    Returns:
    - pd.DataFrame: The same ``df`` object with the four new columns.

    Raises:
    - ValueError: If ``variant_col`` is missing, or if the expanded split does
      not produce exactly four columns.
    """
    if variant_col not in df.columns:
        raise ValueError(f"Column '{variant_col}' not found in the DataFrame.")

    pieces = df[variant_col].str.split(sep, expand=True)
    n_parts = pieces.shape[1]
    if n_parts != 4:
        raise ValueError(f"Expected 4 components in the Variant column separated by '{sep}', but got {n_parts}.")

    # expand=True labels the pieces 0..3; map them onto the target names in order.
    for position, target in enumerate((chrom_col, pos_col, ref_col, alt_col)):
        df[target] = pieces[position]

    return df
def annotate_variants(
    anno_vars_table: pd.DataFrame,
    clinvar_query_table: pd.DataFrame,
    VAT_query_table: pd.DataFrame,
    remove_list: list = None,
    filter_canonical: bool = True,
    drop_original_variant_columns: bool = True
) -> pd.DataFrame:
    """
    Annotate a table of genetic variants with ClinVar and VAT annotations.

    Both annotation tables are left-joined onto ``anno_vars_table`` on
    ``variant_id``: ClinVar via ``VARIANT_forQuery``, VAT via ``vid``. Every
    input variant is therefore kept, and unmatched variants get NaN annotations.
    A variant is repeated once per matching row when an annotation table has
    several rows for it (e.g. several transcripts when ``filter_canonical=False``).

    Parameters:
    - anno_vars_table (pd.DataFrame): Variants to annotate. Must include 'variant_id'.
    - clinvar_query_table (pd.DataFrame): ClinVar annotations.
    - VAT_query_table (pd.DataFrame): VAT annotations.
    - remove_list (list, optional): Transcript consequences to exclude from VAT
      before merging. Applied in addition to the canonical-transcript filter
      when both are requested. Defaults to None (no consequence filtering).
    - filter_canonical (bool, optional): If True, keep only VAT rows whose
      'is_canonical_transcript' equals True. Defaults to True.
    - drop_original_variant_columns (bool, optional): If True, drop
      'VARIANT_forQuery' and 'vid' after merging. Defaults to True.

    Returns:
    - pd.DataFrame: Annotated DataFrame with ClinVar and VAT information.

    Raises:
    - ValueError: If an input table lacks a column this function uses.
    """
    # Columns pulled from each annotation table. The validation below is
    # derived from these lists, so every merged column is checked up front
    # ('ID' was previously merged without being validated, which surfaced as a
    # raw KeyError instead of the documented ValueError).
    clinvar_cols_to_merge = ['VARIANT_forQuery', 'CLNHGVS', 'CLNREVSTAT', 'ID',
                             'CLNSIG', 'MC', 'Type', 'Name', 'RCVaccession']
    VAT_cols_to_merge = [
        'vid', 'gvs_all_af', 'transcript_source', 'transcript',
        'dna_change_in_transcript', 'aa_change', 'consequence',
        'variant_type', 'exon_number', 'intron_number', 'gene_id',
        'revel', 'splice_ai_acceptor_gain_score',
        'splice_ai_acceptor_loss_score',
        'splice_ai_donor_gain_score',
        'splice_ai_donor_loss_score',
        'omim_phenotypes_id', 'omim_phenotypes_name',
        'clinvar_classification', 'entrezgene'
    ]

    # Validate input DataFrames
    required_clinvar_cols = set(clinvar_cols_to_merge)
    if not required_clinvar_cols.issubset(clinvar_query_table.columns):
        missing = required_clinvar_cols - set(clinvar_query_table.columns)
        raise ValueError(f"ClinVar query table is missing columns: {missing}")

    # 'is_canonical_transcript' is only used for filtering, not merged.
    required_VAT_cols = set(VAT_cols_to_merge) | {'is_canonical_transcript'}
    if not required_VAT_cols.issubset(VAT_query_table.columns):
        missing = required_VAT_cols - set(VAT_query_table.columns)
        raise ValueError(f"VAT query table is missing columns: {missing}")

    if 'variant_id' not in anno_vars_table.columns:
        raise ValueError("anno_vars_table must contain a 'variant_id' column.")

    # Step 1: left-merge ClinVar annotations
    merged_df = anno_vars_table.merge(
        clinvar_query_table[clinvar_cols_to_merge],
        how='left',
        left_on='variant_id',
        right_on='VARIANT_forQuery'
    )
    if drop_original_variant_columns:
        merged_df = merged_df.drop(columns=['VARIANT_forQuery'])

    # Step 2: filter VAT rows. Boolean indexing returns new frames, so the
    # caller's VAT_query_table is never modified. The two filters are
    # independent: an explicit remove_list is no longer silently ignored when
    # filter_canonical is True.
    VAT_query_table_filtered = VAT_query_table
    if filter_canonical:
        VAT_query_table_filtered = VAT_query_table_filtered[
            VAT_query_table_filtered["is_canonical_transcript"].eq(True)
        ]
    if remove_list is not None:
        VAT_query_table_filtered = VAT_query_table_filtered[
            ~VAT_query_table_filtered["consequence"].isin(remove_list)
        ]

    # Step 3: left-merge VAT annotations
    merged_df_vat = merged_df.merge(
        VAT_query_table_filtered[VAT_cols_to_merge],
        how='left',
        left_on='variant_id',
        right_on='vid'
    )
    if drop_original_variant_columns:
        merged_df_vat = merged_df_vat.drop(columns=['vid'])

    return merged_df_vat
def processConsentDate(tbl):
    """Attach each participant's primary consent date and compute age at consent.

    Parameters:
    - tbl (pd.DataFrame): Any table with ``person_id`` (per the original comment).
      It is passed through ``annoConsentDate``; the result must provide
      ``birth_datetime`` and ``primary_consent_date`` columns (NOTE(review):
      confirm against annoConsentDate's definition).

    Returns:
    - pd.DataFrame: The table returned by ``calcAgeAtConsent`` after both date
      columns are parsed and ``primary_consent_date`` is put in UTC.
    """
    tbl = annoConsentDate(tbl)
    tbl['birth_datetime'] = pd.to_datetime(tbl['birth_datetime'])
    consent = pd.to_datetime(tbl['primary_consent_date'])
    # Put consent dates in UTC. `tz_localize` raises TypeError on values that
    # are already tz-aware (e.g. source strings carrying a UTC offset), so only
    # localize naive values and convert aware ones.
    if consent.dt.tz is None:
        consent = consent.dt.tz_localize('UTC')
    else:
        consent = consent.dt.tz_convert('UTC')
    tbl['primary_consent_date'] = consent
    tbl = calcAgeAtConsent(tbl)
    return tbl

In [None]:
# Load the per-person concept tables ("*_person_age_youngest.tsv") and the ICD
# concept-ID lists for each concept set from the workspace bucket.
PERSON_CONCEPT_FOLDER = "personID_concept"
ICD_CODES_FOLDER = "Concept_Sets_ICDcodes"

# Per-person tables: IRD, retinopathy, and screening sets
retdys_concept_person = getTable(table_name="retchordys_ICD_person_age_youngest.tsv", folder=PERSON_CONCEPT_FOLDER)
retdegen_concept_person = getTable(table_name="retdegen_ICD_person_age_youngest.tsv", folder=PERSON_CONCEPT_FOLDER)
screenretdys_concept_person = getTable(table_name="ScreenRetDysICD_person_age_youngest.tsv", folder=PERSON_CONCEPT_FOLDER)

# Per-person tables: AMD (specific / exudative / non-exudative)
AMDspecific_concept_person = getTable(table_name="AMD_specific_ICD_person_age_youngest.tsv", folder=PERSON_CONCEPT_FOLDER)
AMDexudative_concept_person = getTable(table_name="AMD_exudative_ICD_person_age_youngest.tsv", folder=PERSON_CONCEPT_FOLDER)
AMDnonexudative_concept_person = getTable(table_name="AMD_nonexudative_ICD_person_age_youngest.tsv", folder=PERSON_CONCEPT_FOLDER)

# Per-person table: CME
CME_concept_person = getTable(table_name="CME_ICD_person_age_youngest.tsv", folder=PERSON_CONCEPT_FOLDER)

# Per-person tables: myopia, hypermetropia, pucker
myopia_concept_person = getTable(table_name="myopia_ICD_person_age_youngest.tsv", folder=PERSON_CONCEPT_FOLDER)
hypermetropia_concept_person = getTable(table_name="hypermetropia_ICD_person_age_youngest.tsv", folder=PERSON_CONCEPT_FOLDER)
pucker_concept_person = getTable(table_name="pucker_ICD_person_age_youngest.tsv", folder=PERSON_CONCEPT_FOLDER)

# ICD code lists for each concept set
retchordys_ICD = getTable(table_name="retchordys_ICD.tsv", folder=PERSON_CONCEPT_FOLDER)
retdegen_ICD = getTable(table_name="retdegen_ICD.tsv", folder=PERSON_CONCEPT_FOLDER)
screenretdys_ICD = getTable(table_name="ScreenRetDysICD.tsv", folder=PERSON_CONCEPT_FOLDER)

AMD_ICD = getTable(table_name="AMD_specific_ICD_codes.tsv", folder=ICD_CODES_FOLDER)
CME_ICD = getTable(table_name="CME_ICD_codes.tsv", folder=ICD_CODES_FOLDER)

pucker_ICD = getTable(table_name="pucker_ICD.tsv", folder=PERSON_CONCEPT_FOLDER)
myopia_ICD = getTable(table_name="myopia_ICD_codes.tsv", folder=ICD_CODES_FOLDER)
hypermetropia_ICD = getTable(table_name="hypermetropia_ICD_codes.tsv", folder=ICD_CODES_FOLDER)

In [None]:
# UpSet plot of how the ICD concept IDs overlap between code sets

import matplotlib.pyplot as plt
from upsetplot import UpSet, from_contents

# One entry per code set: display label -> set of ICD concept IDs.
# Further sets (AMD_All, AMD_Ex, AMD_NonEx, CME, Pucker, RSS, Myopia,
# Hypermetropia) can be added here in the same form.
icd_code_sets = {
    "IRD": set(retchordys_ICD["ICD_conceptID"].to_list()),
    "Retinopathy": set(retdegen_ICD["ICD_conceptID"].to_list()),
    "Screening Set": set(screenretdys_ICD["ICD_conceptID"].to_list()),
}
concepts = from_contents(icd_code_sets)

upset_chart = UpSet(
    concepts,
    subset_size="count",
    show_counts=True,
    min_subset_size=None,
)
ax_dict = upset_chart.plot()
plt.savefig("upsetplot_nested_sets_codes.png", dpi=300, bbox_inches="tight")

# alter the dictionary above as necessary for the desired upset plot

In [None]:
# UpSet plot of how participants (person_id) overlap between code sets

from upsetplot import UpSet, from_contents

# Derive the person-ID sets here from the person tables loaded earlier, so this
# cell does not depend on *_pid variables left over from other or deleted cells
# (they are not defined by the preceding cells, so a fresh kernel hits a NameError).
# NOTE(review): assumes each *_concept_person table has a `person_id` column.
retdys_concept_person_pid = set(retdys_concept_person["person_id"].to_list())
retdegen_concept_person_pid = set(retdegen_concept_person["person_id"].to_list())
screenretdys_concept_person_pid = set(screenretdys_concept_person["person_id"].to_list())

# One entry per code set: display label -> set of person_ids.
# Further sets (AMD_All, AMD_Ex, AMD_NonEx, CME, Pucker, RSS, Myopia,
# Hypermetropia) can be added the same way, from their *_concept_person tables.
person_sets = {
    "IRD": retdys_concept_person_pid,
    "Retinopathy": retdegen_concept_person_pid,
    "Screening Set": screenretdys_concept_person_pid,
}
concepts = from_contents(person_sets)

ax_dict = UpSet(concepts, subset_size="count", show_counts=True, min_subset_size=None).plot()
plt.savefig("upsetplot_nested_pID.png", dpi=300, bbox_inches="tight")

# alter the dictionary above as necessary for the desired upset plot