In [None]:
import pandas as pd
import os
import subprocess
from datetime import datetime
from matplotlib import pyplot
import matplotlib.pyplot as plt
import seaborn as sns

# Workspace bucket URI taken from the environment. os.getenv returns None when
# WORKSPACE_BUCKET is not set, so this line never raises by itself.
my_bucket = os.getenv('WORKSPACE_BUCKET')

# Get the BigQuery curated dataset for the current workspace context.
# Indexing os.environ raises KeyError when WORKSPACE_CDR is missing.
CDR = os.environ['WORKSPACE_CDR']

from google.cloud import bigquery
# Instantiate a BigQuery client (no explicit project or credentials are passed).
client = bigquery.Client()
#!pip install upsetplot #if necessary

from tqdm import tqdm

# Raise the pandas display limits to 7000 columns / rows.
# NOTE(review): 'display.max_row' is not the full option name; it presumably
# resolves to 'display.max_rows' through pandas' option-name matching - confirm.
pd.set_option('display.max_columns', 7000)
pd.set_option('display.max_row', 7000)

In [None]:
def getTable(table_name, folder):
    """Copy a tab-separated file from the workspace bucket and load it.

    The file is copied with ``gsutil cp`` from
    ``$WORKSPACE_BUCKET/data/<folder>/<table_name>`` into the current working
    directory and then read with ``pd.read_csv(table_name, sep="\\t")``.

    Parameters:
    - table_name (str): File name inside the bucket folder; also the local path that is read.
    - folder (str): Sub-folder of ``data/`` in the bucket.

    Returns:
    - pd.DataFrame: The parsed table.

    Raises:
    - Whatever ``pd.read_csv`` raises (e.g. FileNotFoundError) when no local
      copy of ``table_name`` exists after the copy attempt.
    """
    my_bucket = os.getenv('WORKSPACE_BUCKET')
    source_uri = f"{my_bucket}/data/{folder}/{table_name}"

    # Pass an argument list instead of a shell string so that spaces, quotes or
    # shell metacharacters in the names are treated as plain text.
    try:
        completed = subprocess.run(
            ["gsutil", "cp", source_uri, "."], capture_output=True, text=True
        )
    except OSError as exc:
        # gsutil itself could not be started (e.g. not installed).
        print(f'[ERROR] could not run gsutil for {table_name}: {exc}')
    else:
        if completed.returncode == 0:
            print(f'[INFO] {table_name} is successfully downloaded into your working space')
        else:
            # Report the failure instead of claiming success.
            print(f'[ERROR] gsutil failed to copy {source_uri} (exit code {completed.returncode})')
            if completed.stderr:
                print(completed.stderr.strip())

    # As before, the local file is read regardless of the copy outcome; a missing
    # file surfaces as the read_csv error.
    table_read = pd.read_csv(table_name, sep="\t")
    return table_read
def getFile(table_name, folder):
    """Copy a file from the workspace bucket into the current working directory.

    Runs ``gsutil cp $WORKSPACE_BUCKET/data/<folder>/<table_name> .`` and
    reports whether the copy succeeded. Nothing is read or returned.

    Parameters:
    - table_name (str): File name inside the bucket folder.
    - folder (str): Sub-folder of ``data/`` in the bucket.

    Returns:
    - None
    """
    my_bucket = os.getenv('WORKSPACE_BUCKET')
    source_uri = f"{my_bucket}/data/{folder}/{table_name}"

    # Argument list (no shell) so quotes/spaces/metacharacters in the names are
    # passed through literally.
    try:
        completed = subprocess.run(
            ["gsutil", "cp", source_uri, "."], capture_output=True, text=True
        )
    except OSError as exc:
        # gsutil itself could not be started (e.g. not installed).
        print(f'[ERROR] could not run gsutil for {table_name}: {exc}')
        return

    if completed.returncode == 0:
        print(f'[INFO] {table_name} is successfully downloaded into your working space')
    else:
        # Report the failure instead of claiming success.
        print(f'[ERROR] gsutil failed to copy {source_uri} (exit code {completed.returncode})')
        if completed.stderr:
            print(completed.stderr.strip())
def describe_tables(**tables):
    """Summarise age at visit and sex at birth for each named table.

    Each table is first reduced to one row per ``person_id`` (first occurrence
    kept). The summary is ``describe()`` of ``AgeAtVisit_Years`` extended with
    ``male_count``, ``female_count`` and ``mf_ratio`` (males divided by
    males + females, or None when both counts are zero).

    Parameters:
    - **tables: keyword name -> pd.DataFrame with the columns ``person_id``,
      ``AgeAtVisit_Years`` and ``sex_at_birth_source_value``.

    Returns:
    - pd.DataFrame: One row per table, with the keyword name in ``Table_Name``.
    """
    def _rows_with_sex(frame, label):
        # Number of rows whose sex-at-birth value contains `label`; missing values never match.
        hits = frame["sex_at_birth_source_value"].str.contains(label, na=False)
        return frame[hits].shape[0]

    per_table = {}
    for name, frame in tables.items():
        one_per_person = frame.drop_duplicates("person_id", keep="first")
        stats = one_per_person["AgeAtVisit_Years"].describe()

        n_male = _rows_with_sex(one_per_person, "SexAtBirth_Male")
        n_female = _rows_with_sex(one_per_person, "SexAtBirth_Female")
        n_known = n_male + n_female

        # Extra entries are appended to the describe() Series.
        stats["male_count"] = n_male
        stats["female_count"] = n_female
        stats["mf_ratio"] = n_male / n_known if n_known > 0 else None

        per_table[name] = stats

    # Tables become rows; the table name moves from the index into a column.
    summary = pd.DataFrame(per_table).T
    summary.index.name = "Table_Name"
    return summary.reset_index()
def annoConsentDate(table):
    """Annotate a table with each participant's primary consent date.

    Queries the CDR for the latest ``observation_date`` per person among
    observations descending from the 'Consent PII' module concept, then
    left-merges the result onto ``table`` by ``person_id``.

    Parameters:
    - table (pd.DataFrame): Must contain a ``person_id`` column.

    Returns:
    - pd.DataFrame: A new frame equal to ``table`` plus a
      ``primary_consent_date`` column (missing where no consent row was found).
      When ``table`` has no non-missing person IDs, no query is issued and the
      column is filled with NaT.

    Raises:
    - ValueError / TypeError: if a ``person_id`` value cannot be converted to int.
    """
    # Unique, non-missing IDs. Forcing int guarantees that only numeric literals
    # are interpolated into the SQL text (no 'nan', no arbitrary strings).
    person_ids = [int(pid) for pid in table['person_id'].dropna().unique().tolist()]

    if not person_ids:
        # "IN ()" is not valid SQL, and there is nothing to look up anyway.
        return table.assign(primary_consent_date=pd.NaT)

    person_ids_query = ','.join(map(str, person_ids))

    # Run the consent query restricted to the person IDs present in the table.
    consent_query = f'''
    SELECT DISTINCT person_id, MAX(observation_date) AS primary_consent_date
    FROM `{CDR}.concept`
    JOIN `{CDR}.concept_ancestor` ON concept_id = ancestor_concept_id
    JOIN `{CDR}.observation` ON descendant_concept_id = observation_source_concept_id
    WHERE concept_name = 'Consent PII' 
      AND concept_class_id = 'Module'
      AND person_id IN ({person_ids_query})
    GROUP BY person_id
    '''
    consent_dates = pd.read_gbq(consent_query, progress_bar_type="tqdm_notebook")

    # Left merge keeps every input row, including people without a consent row.
    final_result = table.merge(consent_dates, on='person_id', how='left')

    return final_result
def describe_solution_tables(**tables):
    """Summarise age at consent and sex at birth for each named table.

    Each table is reduced to one row per ``person_id`` (first occurrence kept)
    before ``AgeAtConsent_Years`` is described. The summary is extended with
    ``male_count``, ``female_count`` and ``mf_ratio`` (males divided by
    males + females, or None when both counts are zero).

    Parameters:
    - **tables: keyword name -> pd.DataFrame with the columns ``person_id``,
      ``AgeAtConsent_Years`` and ``sex_at_birth_source_value``.

    Returns:
    - pd.DataFrame: One row per table, with the keyword name in ``Table_Name``.
    """
    summaries = {}

    for label, frame in tables.items():
        persons = frame.drop_duplicates("person_id", keep="first")
        age_stats = persons["AgeAtConsent_Years"].describe()

        # Boolean masks over the de-duplicated rows; missing values count as no match.
        sex = persons["sex_at_birth_source_value"]
        males = int(sex.str.contains("SexAtBirth_Male", na=False).sum())
        females = int(sex.str.contains("SexAtBirth_Female", na=False).sum())

        if males + females > 0:
            male_share = males / (males + females)
        else:
            male_share = None

        # Append the sex statistics to the describe() Series.
        age_stats["male_count"] = males
        age_stats["female_count"] = females
        age_stats["mf_ratio"] = male_share

        summaries[label] = age_stats

    # One row per table, with the table name exposed as a regular column.
    result_df = pd.DataFrame(summaries).T
    result_df.index.name = "Table_Name"
    return result_df.reset_index()
def calcAgeToday(solution_person_table, today=None):
    """Add each person's age relative to a reference date ("today").

    Works on a copy, so the input frame is left untouched. ``birth_datetime``
    is parsed to timezone-naive datetimes; rows whose birth date is missing or
    unparseable are reported and dropped.

    Parameters:
    - solution_person_table (pd.DataFrame): Must contain ``birth_datetime``.
    - today (datetime-like, optional): Reference date. Defaults to
      ``datetime.now()``; pass a fixed value for reproducible results.

    Returns:
    - pd.DataFrame: Copy of the input without the dropped rows, plus
      ``AgeToday`` (timedelta) and ``AgeToday_Years`` (days / 365.25).

    Raises:
    - ValueError: if the computed ``AgeToday`` column is not timedelta typed.
    """
    # Work on a copy to avoid modifying the caller's frame.
    result = solution_person_table.copy()

    # Parse to datetime and strip any timezone so it can be subtracted from a
    # naive reference date.
    result["birth_datetime"] = pd.to_datetime(
        result["birth_datetime"], errors="coerce"
    ).dt.tz_localize(None)

    # Report rows that cannot be used.
    missing_birth = result["birth_datetime"].isnull()
    if missing_birth.any():
        print("Warning: Missing or invalid birth_datetime values detected.")
        print(result[missing_birth])

    # .copy() makes the filtered frame independent, so the column assignments
    # below cannot trigger SettingWithCopyWarning.
    result = result.dropna(subset=["birth_datetime"]).copy()

    # Reference date, normalised to a timezone-naive Timestamp.
    reference = pd.Timestamp(datetime.now() if today is None else today)
    if reference.tzinfo is not None:
        reference = reference.tz_localize(None)

    # Age as a timedelta Series.
    result["AgeToday"] = reference - result["birth_datetime"]

    if not pd.api.types.is_timedelta64_dtype(result["AgeToday"]):
        raise ValueError("AgeToday is not a valid timedelta64 dtype.")

    # Whole days converted to years.
    result["AgeToday_Years"] = result["AgeToday"].dt.days / 365.25

    return result
def calcAgeAtConsent(ICD_person_table):
    """Add each person's age at primary consent, in place.

    ``primary_consent_date`` and ``birth_datetime`` are parsed to
    timezone-naive datetimes for the calculation only (the two source columns
    are not rewritten), so string, date and timezone-aware inputs can be mixed.
    Values that cannot be parsed yield a missing age.

    Parameters:
    - ICD_person_table (pd.DataFrame): Must contain ``primary_consent_date``
      and ``birth_datetime``. The frame is modified in place.

    Returns:
    - pd.DataFrame: The same object, with ``AgeAtConsent`` (timedelta) and
      ``AgeAtConsent_Years`` (days / 365.25) added.
    """
    def _as_naive_datetime(values):
        # Parse to UTC (naive inputs are taken as UTC), then drop the timezone
        # so both operands are comparable.
        return pd.to_datetime(values, errors="coerce", utc=True).dt.tz_localize(None)

    consent = _as_naive_datetime(ICD_person_table["primary_consent_date"])
    birth = _as_naive_datetime(ICD_person_table["birth_datetime"])

    # Difference as a timedelta Series.
    ICD_person_table["AgeAtConsent"] = consent - birth

    # Whole days converted to years.
    ICD_person_table["AgeAtConsent_Years"] = (
        ICD_person_table["AgeAtConsent"].dt.days / 365.25
    )

    return ICD_person_table
def calc_solved_rate(summary_df2):
    """Add a ``solved_rate`` column to a table summary.

    Row 0 receives the sum of ``count`` over all rows whose ``Table_Name`` does
    not contain "concept" or "female" (case-insensitive), divided by the
    ``count`` of row 0. Rows 1-5 are assigned, in the fixed order AD, XLmale,
    XLfemale, AR, ADAR, the summed ``count`` of rows whose name starts with
    ``<prefix>_`` divided by the number of unique ``person_id`` values in the
    global DataFrame ``<prefix>_solutions``. Prefixes without such a global
    are reported and their row is left unset.

    Parameters:
    - summary_df2 (pd.DataFrame): Needs ``Table_Name`` and ``count`` columns.

    Returns:
    - pd.DataFrame: ``summary_df2.reset_index()`` (so the previous index
      becomes a column) with ``solved_rate`` added.
    """
    # Overall numerator is computed before the index is reset.
    excluded = summary_df2['Table_Name'].str.contains("concept|female", case=False)
    concept_solved_n = summary_df2[~excluded]["count"].sum()

    summary_df2 = summary_df2.reset_index()

    concept_n = summary_df2.loc[0, "count"]
    summary_df2.loc[0, "solved_rate"] = concept_solved_n / concept_n

    namespace = globals()
    # Row label advances with the prefix whether or not the global table exists.
    for row_label, prefix in enumerate(["AD", "XLmale", "XLfemale", "AR", "ADAR"], start=1):
        solutions_name = f"{prefix}_solutions"

        if solutions_name not in namespace:
            print(f"{solutions_name} does not exist.")
            continue

        total_solved = namespace[solutions_name]["person_id"].nunique()
        name_matches = summary_df2["Table_Name"].str.contains(f"^{prefix}_", case=False)
        annotated = summary_df2[name_matches]["count"].sum()
        summary_df2.loc[row_label, "solved_rate"] = annotated / total_solved

    return summary_df2 
def calc_solved_rate_within(summary_df2):
    """Add a ``solved_rate`` column expressed relative to the concept cohort.

    Row 0 receives the sum of ``count`` over all rows whose ``Table_Name`` does
    not contain "concept" or "female" (case-insensitive), divided by the
    ``count`` of row 0. Rows 1-5 are assigned, in the fixed order AD, XLmale,
    XLfemale, AR, ADAR, the summed ``count`` of rows whose name starts with
    ``<prefix>_`` divided by that same row-0 ``count``. A row is only filled
    when a global DataFrame ``<prefix>_solutions`` exists; otherwise the
    missing table is reported.

    Parameters:
    - summary_df2 (pd.DataFrame): Needs ``Table_Name`` and ``count`` columns.

    Returns:
    - pd.DataFrame: ``summary_df2.reset_index()`` (so the previous index
      becomes a column) with ``solved_rate`` added.
    """
    # Overall numerator is computed before the index is reset.
    keep = ~summary_df2['Table_Name'].str.contains("concept|female", case=False)
    concept_solved_n = summary_df2[keep]["count"].sum()

    summary_df2 = summary_df2.reset_index()

    concept_n = summary_df2.loc[0, "count"]
    summary_df2.loc[0, "solved_rate"] = concept_solved_n / concept_n

    namespace = globals()
    row_label = 0
    for prefix in ("AD", "XLmale", "XLfemale", "AR", "ADAR"):
        row_label += 1
        solutions_name = f"{prefix}_solutions"

        if solutions_name in namespace:
            # The unique-person total is evaluated but does not enter the rate here.
            namespace[solutions_name]["person_id"].nunique()
            prefix_rows = summary_df2["Table_Name"].str.contains(f"^{prefix}_", case=False)
            annotated = summary_df2[prefix_rows]["count"].sum()
            summary_df2.loc[row_label, "solved_rate"] = annotated / concept_n
        else:
            print(f"{solutions_name} does not exist.")

    return summary_df2 
def generate_merges(base_name, concept_person_table, solution_dict):
    """
    Left-merge every solution table with a concept-person table and publish the
    results as module-level globals.

    For a solution named ``<prefix>_...`` the merged table is stored under the
    global name ``<prefix>_in_<base_name>``. The concept-person table itself is
    stored as ``<base_name>_concept_person``.

    Parameters:
    - base_name (str): Suffix/prefix used when building the global table names.
    - concept_person_table (pd.DataFrame): Right-hand side of the merge; joined on "person_id".
    - solution_dict (dict): Solution table name (e.g. "AD_solutions") -> DataFrame
                            providing variant_id, person_id, allele_count, GeneID_EG
                            and GeneID_Symbol.

    Returns:
    - list: Generated global names, with the concept-person name first.
    """
    namespace = globals()
    solution_columns = ['variant_id', 'person_id', 'allele_count', 'GeneID_EG', 'GeneID_Symbol']

    # The concept-person name always leads the returned list.
    concept_table_name = f"{base_name}_concept_person"
    table_names = [concept_table_name]

    for solution_name, solution_table in solution_dict.items():
        prefix = solution_name.split('_')[0]
        target_table_name = f"{prefix}_in_{base_name}"
        table_names.append(target_table_name)

        # how="left" keeps every solution row, matched to the concept table or not.
        namespace[target_table_name] = pd.merge(
            solution_table[solution_columns],
            concept_person_table,
            on="person_id",
            how="left",
        )
        print(f"Generated table: {target_table_name}")

    # Published last, after all merged tables.
    namespace[concept_table_name] = concept_person_table

    return table_names
def process_concept_and_solution_tables(concept_tables, solution_tables):
    """
    Loops through each concept_table and solution_table, generates merged tables,
    calculates summaries, and combines all summaries into one big table.

    Solved rates come from ``calc_solved_rate``.

    Parameters:
    - concept_tables (dict): Dictionary of concept tables (key: table name, value: DataFrame).
    - solution_tables (dict): Dictionary of solution tables (key: table name, value: DataFrame).

    Returns:
    - pd.DataFrame: A single merged summary table containing results for all concept and
                    solution tables; an empty DataFrame when ``concept_tables`` is empty.
    """
    all_summaries = []

    # Loop through each concept table
    for concept_name, concept_table in concept_tables.items():
        base_name = concept_name.split("_")[0]  # Extract base name (e.g., "retdys")

        # generate_merges publishes the merged tables as globals and returns their names
        merged_table_names = generate_merges(base_name, concept_table, solution_tables)

        # Retrieve the generated tables from the global namespace
        merged_tables = {name: globals()[name] for name in merged_table_names}

        # Describe the merged tables, then add the solved rates
        summary_df = describe_tables(**merged_tables)
        solved_summary = calc_solved_rate(summary_df)

        # Add a column to identify the concept
        solved_summary["concept"] = concept_name

        all_summaries.append(solved_summary)

    # pd.concat raises on an empty list, so return an empty frame instead.
    if not all_summaries:
        return pd.DataFrame()

    # Combine all summaries into one big table
    final_summary_table = pd.concat(all_summaries, axis=0)

    return final_summary_table
def process_concept_and_solution_tables_within(concept_tables, solution_tables):
    """
    Loops through each concept_table and solution_table, generates merged tables,
    calculates summaries, and combines all summaries into one big table.

    Solved rates come from ``calc_solved_rate_within``.

    Parameters:
    - concept_tables (dict): Dictionary of concept tables (key: table name, value: DataFrame).
    - solution_tables (dict): Dictionary of solution tables (key: table name, value: DataFrame).

    Returns:
    - pd.DataFrame: A single merged summary table containing results for all concept and
                    solution tables; an empty DataFrame when ``concept_tables`` is empty.
    """
    all_summaries = []

    # Loop through each concept table
    for concept_name, concept_table in concept_tables.items():
        base_name = concept_name.split("_")[0]  # Extract base name (e.g., "retdys")

        # generate_merges publishes the merged tables as globals and returns their names
        merged_table_names = generate_merges(base_name, concept_table, solution_tables)

        # Retrieve the generated tables from the global namespace
        merged_tables = {name: globals()[name] for name in merged_table_names}

        # Describe the merged tables, then add the within-concept solved rates
        summary_df = describe_tables(**merged_tables)
        solved_summary = calc_solved_rate_within(summary_df)

        # Add a column to identify the concept
        solved_summary["concept"] = concept_name

        all_summaries.append(solved_summary)

    # pd.concat raises on an empty list, so return an empty frame instead.
    if not all_summaries:
        return pd.DataFrame()

    # Combine all summaries into one big table
    final_summary_table = pd.concat(all_summaries, axis=0)

    return final_summary_table
def process_concept_solution_gene_summaries(concept_tables, solution_tables):
    """
    Loops through each concept_table and solution_table, generates merged tables,
    calculates gene-level summaries, and combines all summaries into one big table.

    Within each gene, only rows with a non-missing ``condition_source_value``
    (i.e. annotated rows) contribute to the variant and person counts.

    Parameters:
    - concept_tables (dict): Dictionary of concept tables (key: table name, value: DataFrame).
    - solution_tables (dict): Dictionary of solution tables (key: table name, value: DataFrame).

    Returns:
    - pd.DataFrame: Columns concept, table, gene, variant_count, person_count and
                    solution_table (the part of ``table`` before the first "_").
                    The frame has these columns and zero rows when nothing was summarised.
    """
    summary_columns = ["concept", "table", "gene", "variant_count", "person_count"]
    all_gene_summaries = []

    # Loop through each concept table
    for concept_name, concept_table in concept_tables.items():
        base_name = concept_name.split("_")[0]  # Extract base name (e.g., "retdys")

        # generate_merges publishes the merged tables as globals and returns their names
        merged_table_names = generate_merges(base_name, concept_table, solution_tables)

        # Retrieve the generated tables from the global namespace
        merged_tables = {name: globals()[name] for name in merged_table_names}

        for table_name, merged_table in merged_tables.items():
            # Tables without gene information cannot be summarised per gene
            if "GeneID_Symbol" not in merged_table.columns:
                print(f"Warning: 'GeneID_Symbol' column not found in {table_name}. Skipping...")
                continue

            for gene, group in merged_table.groupby("GeneID_Symbol"):
                # Keep only annotated rows (condition present) for the counts
                annotated = group[~group.condition_source_value.isnull()]

                all_gene_summaries.append({
                    "concept": concept_name,
                    "table": table_name,
                    "gene": gene,
                    "variant_count": annotated.variant_id.nunique(),
                    "person_count": annotated.person_id.nunique(),
                })

    # Explicit columns keep the schema stable when no summaries were collected;
    # without them the 'table' lookup below fails on an empty frame.
    gene_summary_table = pd.DataFrame(all_gene_summaries, columns=summary_columns)

    gene_summary_table['solution_table'] = gene_summary_table['table'].str.split('_').str[0]

    return gene_summary_table
def process_solution_gene_summaries(solution_tables):
    """
    Build per-gene summaries straight from the solution tables.

    For every gene (``GeneID_Symbol``) in every table the summary holds the
    number of unique variants and persons plus mean / std / min / max of
    ``allele_count``. Tables lacking ``GeneID_Symbol`` are skipped with a warning.

    Parameters:
    - solution_tables (dict): Dictionary of solution tables (key: table name, value: DataFrame).

    Returns:
    - pd.DataFrame: One row per (table, gene) pair.
    """
    def _gene_record(table_name, gene, rows):
        # Summary of one gene's rows within one solution table.
        return {
            "table": table_name,
            "gene": gene,
            "variant_count": rows["variant_id"].nunique(),
            "person_count": rows["person_id"].nunique(),
            "allele_count_mean": rows["allele_count"].mean(),
            "allele_count_std": rows["allele_count"].std(),
            "allele_count_min": rows["allele_count"].min(),
            "allele_count_max": rows["allele_count"].max(),
        }

    records = []
    for table_name, solution_table in solution_tables.items():
        if "GeneID_Symbol" in solution_table.columns:
            for gene, rows in solution_table.groupby("GeneID_Symbol"):
                records.append(_gene_record(table_name, gene, rows))
        else:
            print(f"Warning: 'GeneID_Symbol' column not found in {table_name}. Skipping...")

    return pd.DataFrame(records)
def calculate_annotation_rates(solution_gene_summary, concept_gene_summary):
    """
    Compute person- and variant-level annotation rates per (table, gene).

    The two summaries are inner-joined on "table" and "gene"; each rate is the
    concept-side count divided by the solution-side count. Missing results and
    positive infinity (division by zero) are reported as 0.

    NOTE(review): rows only pair up when both inputs use the same "table"
    values - confirm the callers' table naming satisfies this.

    Parameters:
    - solution_gene_summary (pd.DataFrame): Gene-level summaries for solution tables
                                             (output from process_solution_gene_summaries).
    - concept_gene_summary (pd.DataFrame): Gene-level summaries for intersection tables
                                           (output from process_concept_solution_gene_summaries).

    Returns:
    - pd.DataFrame: The joined summaries plus person_annotation_rate and variant_annotation_rate.
    """
    # Concept columns get the "_concept" suffix, solution columns "_solution".
    merged = concept_gene_summary.merge(
        solution_gene_summary,
        how="inner",
        on=["table", "gene"],
        suffixes=("_concept", "_solution"),
    )

    for level in ("person", "variant"):
        ratio = merged[f"{level}_count_concept"] / merged[f"{level}_count_solution"]
        # 0/0 gives NaN and x/0 gives inf; both are mapped to 0.
        merged[f"{level}_annotation_rate"] = ratio.fillna(0).replace([float('inf')], 0)

    return merged
def calculate_all_annotation_rates(solution_gene_summary, concept_gene_summary):
    """
    Compute gene-level annotation rates for every concept category / solution
    category pair.

    Categories are parsed from the "table" names and written onto the INPUT
    frames: the concept summary gains "category" (letters after "in_") and
    "category_inheritance" (letters before "_in"); the solution summary gains
    "category" (letters before "_solutions"). For each pair, solution rows whose
    category starts with the solution category are joined on "gene" with concept
    rows whose inheritance category equals it exactly.

    Parameters:
    - solution_gene_summary (pd.DataFrame): Gene-level summaries for solution tables.
    - concept_gene_summary (pd.DataFrame): Gene-level summaries for intersection tables.

    Returns:
    - pd.DataFrame: Joined rows with person_annotation_rate and variant_annotation_rate
                    (NaN and +inf reported as 0); an empty DataFrame if no pair matched.

    Raises:
    - ValueError: if either summary lacks table, gene, person_count, variant_count or category.
    """
    # Derive the grouping columns (this mutates the caller's frames).
    concept_gene_summary["category"] = concept_gene_summary["table"].str.extract(r"in_([a-zA-Z]+)").fillna("")
    concept_gene_summary["category_inheritance"] = concept_gene_summary["table"].str.extract(r"([a-zA-Z]+)_in").fillna("")
    solution_gene_summary["category"] = solution_gene_summary["table"].str.extract(r"([a-zA-Z]+)_solutions").fillna("")

    required_columns = {"table", "gene", "person_count", "variant_count", "category"}
    if not required_columns.issubset(solution_gene_summary.columns):
        raise ValueError("Solution gene summary is missing required columns.")
    if not required_columns.issubset(concept_gene_summary.columns):
        raise ValueError("Concept gene summary is missing required columns.")

    rate_frames = []

    for con_category in concept_gene_summary["category"].unique():
        in_category = concept_gene_summary[concept_gene_summary["category"] == con_category]

        for sol_category in solution_gene_summary["category"].unique():
            # Prefix match on the solution side, exact match on the concept side.
            solution_hits = solution_gene_summary["category"].str.contains(f"^{sol_category}")
            concept_hits = in_category["category_inheritance"].str.contains(f"^{sol_category}$")
            solution_subset = solution_gene_summary[solution_hits]
            concept_subset_inh = in_category[concept_hits]

            if concept_subset_inh.empty or solution_subset.empty:
                print(f"Skipping category '{con_category}' / '{sol_category}' because one of the subsets is empty.")
                continue

            print(f"Processing category '{con_category}' / '{sol_category}'")

            merged = concept_subset_inh.merge(
                solution_subset[['variant_count', 'person_count', 'gene']],
                how="inner",
                on=["gene"],
                suffixes=("_concept", "_solution"),
            )

            for level in ("person", "variant"):
                ratio = merged[f"{level}_count_concept"] / merged[f"{level}_count_solution"]
                # 0/0 gives NaN and x/0 gives inf; both are mapped to 0.
                merged[f"{level}_annotation_rate"] = ratio.fillna(0).replace([float('inf')], 0)

            rate_frames.append(merged)

    if not rate_frames:
        return pd.DataFrame()  # nothing matched

    return pd.concat(rate_frames, axis=0, ignore_index=True)
def process_solution_variant_summaries(solution_tables):
    """
    Build per-variant summaries straight from the solution tables.

    For every ``variant_id`` in every table the summary holds the gene symbol of
    the variant's first row (None when the table has no ``GeneID_Symbol``), the
    number of unique persons, and mean / std / min / max of ``allele_count``.
    Tables lacking ``variant_id`` are skipped with a warning.

    Parameters:
    - solution_tables (dict): Dictionary of solution tables (key: table name, value: DataFrame).

    Returns:
    - pd.DataFrame: One row per (table, variant) pair.
    """
    def _variant_record(table_name, variant, rows):
        # Summary of one variant's rows within one solution table.
        return {
            "table": table_name,
            "variant_id": variant,
            "gene": rows["GeneID_Symbol"].iloc[0] if "GeneID_Symbol" in rows.columns else None,
            "person_count": rows["person_id"].nunique(),
            "allele_count_mean": rows["allele_count"].mean(),
            "allele_count_std": rows["allele_count"].std(),
            "allele_count_min": rows["allele_count"].min(),
            "allele_count_max": rows["allele_count"].max(),
        }

    records = []
    for table_name, solution_table in solution_tables.items():
        if "variant_id" in solution_table.columns:
            for variant, rows in solution_table.groupby("variant_id"):
                records.append(_variant_record(table_name, variant, rows))
        else:
            print(f"Warning: 'variant_id' column not found in {table_name}. Skipping...")

    return pd.DataFrame(records)
def process_concept_solution_variant_summaries(concept_tables, solution_tables):
    """
    Build per-variant summaries for every concept table / solution table merge.

    Merged tables come from ``generate_merges`` (published as globals). Within
    each variant only rows with a non-missing ``condition_source_value``
    (annotated rows) are counted towards ``person_count``. Tables lacking
    ``variant_id`` are skipped with a warning.

    Parameters:
    - concept_tables (dict): Dictionary of concept tables (key: table name, value: DataFrame).
    - solution_tables (dict): Dictionary of solution tables (key: table name, value: DataFrame).

    Returns:
    - pd.DataFrame: One row per (concept, table, variant) with columns concept, table,
                    variant_id, gene and person_count.
    """
    records = []

    for concept_name, concept_table in concept_tables.items():
        # Base name is the part of the concept table name before the first "_".
        base_name = concept_name.split("_")[0]

        # Create the merged tables, then fetch them back from the globals by name.
        namespace = globals()
        merged_tables = {
            name: namespace[name]
            for name in generate_merges(base_name, concept_table, solution_tables)
        }

        for table_name, merged_table in merged_tables.items():
            if "variant_id" not in merged_table.columns:
                print(f"Warning: 'variant_id' column not found in {table_name}. Skipping...")
                continue

            for variant, rows in merged_table.groupby("variant_id"):
                # True where the row carries no condition, i.e. is unannotated.
                unannotated = rows.condition_source_value.isnull()

                records.append({
                    "concept": concept_name,
                    "table": table_name,
                    "variant_id": variant,
                    "gene": rows["GeneID_Symbol"].iloc[0] if "GeneID_Symbol" in rows.columns else None,
                    "person_count": rows[~unannotated].person_id.nunique(),
                })

    return pd.DataFrame(records)
def calculate_all_variant_annotation_rates(solution_variant_summary, concept_variant_summary):
    """
    Calculates annotation rates at the variant level for each category
    by comparing solution and concept summaries.

    Categories are parsed from the "table" names and written onto the INPUT
    frames: the concept summary gains "category" (letters after "in_") and
    "category_inheritance" (letters before "_in"); the solution summary gains
    "category" (letters before "_solutions"). For each pair, solution rows whose
    category starts with the solution category are joined on "variant_id" with
    concept rows whose inheritance category equals it exactly.

    Parameters:
    - solution_variant_summary (pd.DataFrame): Variant-level summaries for solution tables.
      Requires table, variant_id, person_count and allele_count_mean.
    - concept_variant_summary (pd.DataFrame): Variant-level summaries for intersection tables.
      Requires table, variant_id and person_count; allele_count_mean is optional.

    Returns:
    - pd.DataFrame: Joined rows with person_annotation_rate and allele_annotation_rate
      (NaN and +inf reported as 0). When the concept summary has no allele_count_mean,
      allele_annotation_rate is NaN and the solution's allele_count_mean is carried
      without a suffix. An empty DataFrame is returned if no pair matched.

    Raises:
    - ValueError: if a required column is missing from either summary.
    """
    all_annotation_rates = []

    # Add necessary columns for categorization (this mutates the caller's frames)
    concept_variant_summary["category"] = concept_variant_summary["table"].str.extract(r"in_([a-zA-Z]+)").fillna("")
    concept_variant_summary["category_inheritance"] = concept_variant_summary["table"].str.extract(r"([a-zA-Z]+)_in").fillna("")
    solution_variant_summary["category"] = solution_variant_summary["table"].str.extract(r"([a-zA-Z]+)_solutions").fillna("")

    # The concept-side summaries built in this notebook carry no allele statistics,
    # so allele_count_mean is only mandatory on the solution side.
    required_solution_columns = {"table", "variant_id", "person_count", "allele_count_mean", "category"}
    required_concept_columns = {"table", "variant_id", "person_count", "category"}
    if not required_solution_columns.issubset(solution_variant_summary.columns):
        raise ValueError("Solution variant summary is missing required columns.")
    if not required_concept_columns.issubset(concept_variant_summary.columns):
        raise ValueError("Concept variant summary is missing required columns.")

    concept_has_alleles = "allele_count_mean" in concept_variant_summary.columns

    # Process each category separately
    for con_category in concept_variant_summary["category"].unique():
        # Filter concept summaries for the current category
        concept_subset = concept_variant_summary[concept_variant_summary["category"] == con_category]

        for sol_category in solution_variant_summary["category"].unique():
            # Prefix match on the solution side, exact match on the concept side
            sol_category_q = f"^{sol_category}"
            consol_category_q = f"^{sol_category}$"
            solution_subset = solution_variant_summary[solution_variant_summary["category"].str.contains(sol_category_q)]
            concept_subset_inh = concept_subset[concept_subset["category_inheritance"].str.contains(consol_category_q)]

            if concept_subset_inh.empty or solution_subset.empty:
                print(f"Skipping category '{con_category}' / '{sol_category}' because one of the subsets is empty.")
                continue

            print(f"Processing category '{con_category}' / '{sol_category}'")

            # Merge the summaries for this category
            merged = pd.merge(
                concept_subset_inh,
                solution_subset[['variant_id', 'person_count', 'allele_count_mean']],
                on=["variant_id"],
                suffixes=("_concept", "_solution"),
                how="inner"
            )

            # Person rate; 0/0 (NaN) and x/0 (inf) are reported as 0
            person_rate = merged["person_count_concept"] / merged["person_count_solution"]
            merged["person_annotation_rate"] = person_rate.fillna(0).replace([float('inf')], 0)

            if concept_has_alleles:
                allele_rate = merged["allele_count_mean_concept"] / merged["allele_count_mean_solution"]
                merged["allele_annotation_rate"] = allele_rate.fillna(0).replace([float('inf')], 0)
            else:
                # Not computable without concept-side allele statistics; NaN keeps
                # this distinguishable from a genuine rate of 0.
                merged["allele_annotation_rate"] = float("nan")

            all_annotation_rates.append(merged)

    # Combine all results into a single DataFrame
    if all_annotation_rates:
        final_annotation_rates = pd.concat(all_annotation_rates, axis=0, ignore_index=True)
    else:
        final_annotation_rates = pd.DataFrame()  # Return an empty DataFrame if no data

    return final_annotation_rates
def saveToBucket(df, df_filename, data_folder):
    """Write `df` to a tab-separated file and copy it into the workspace bucket.

    The file is first written to the current working directory as
    `df_filename` (tab-separated, without the index) and then copied with
    `gsutil cp` to `<WORKSPACE_BUCKET>/data/<data_folder>/`.

    Parameters:
    - df (pd.DataFrame): Table to persist.
    - df_filename (str): Name of the local file to write and upload.
    - data_folder (str): Sub-folder under `data/` in the bucket.

    Returns:
    - None. Whatever gsutil wrote to stderr is printed; a non-zero exit status
      is reported as a warning rather than raised.
    """
    df.to_csv(df_filename, sep = "\t", index=False)

    # get the bucket name
    my_bucket = os.getenv('WORKSPACE_BUCKET')

    # copy csv file to the bucket (argument list: no shell involved, so no quoting issues)
    args = ["gsutil", "cp", f"./{df_filename}", f"{my_bucket}/data/{data_folder}/"]
    output = subprocess.run(args, capture_output=True, text=True)

    # A bare `output.stderr` expression shows nothing inside a function,
    # so print the captured text explicitly.
    if output.stderr:
        print(output.stderr)
    if output.returncode != 0:
        print(f"[WARN] gsutil exited with status {output.returncode}; '{df_filename}' may not have been uploaded")
def concatenate_variant_columns(df, chrom_col='CHROM', pos_col='POS', ref_col='REF', alt_col='ALT', new_col='Variant'):
    """
    Build a single variant identifier column of the form CHROM-POS-REF-ALT.

    The input DataFrame is modified in place (the new column is added to it)
    and the same object is returned.

    Parameters:
    - df (pd.DataFrame): Table holding the four component columns.
    - chrom_col (str): Chromosome column name. Default is 'CHROM'.
    - pos_col (str): Position column name. Default is 'POS'.
    - ref_col (str): Reference allele column name. Default is 'REF'.
    - alt_col (str): Alternate allele column name. Default is 'ALT'.
    - new_col (str): Name of the column to create. Default is 'Variant'.

    Returns:
    - pd.DataFrame: `df` itself, now carrying `new_col`.

    Raises:
    - ValueError: If any of the four component columns is absent.
    """
    component_cols = (chrom_col, pos_col, ref_col, alt_col)

    # Report the first absent component, in CHROM/POS/REF/ALT order
    absent = [name for name in component_cols if name not in df.columns]
    if absent:
        raise ValueError(f"Column '{absent[0]}' not found in the DataFrame.")

    # Cast every component to text, then join left-to-right with '-'
    as_text = [df[name].astype(str) for name in component_cols]
    joined = as_text[0]
    for piece in as_text[1:]:
        joined = joined + '-' + piece
    df[new_col] = joined

    return df
def split_variant_column(df, variant_col='Variant', chrom_col='CHROM', pos_col='POS', ref_col='REF', alt_col='ALT', sep='-'):
    """
    Expand a combined variant identifier into CHROM, POS, REF and ALT columns.

    The input DataFrame is modified in place (four columns are added or
    overwritten) and the same object is returned. The new columns hold the
    raw split strings; no type conversion is applied.

    Parameters:
    - df (pd.DataFrame): Table holding the combined variant column.
    - variant_col (str): Name of the combined column. Default is 'Variant'.
    - chrom_col (str): Chromosome column to create. Default is 'CHROM'.
    - pos_col (str): Position column to create. Default is 'POS'.
    - ref_col (str): Reference allele column to create. Default is 'REF'.
    - alt_col (str): Alternate allele column to create. Default is 'ALT'.
    - sep (str): Separator between components. Default is '-'.

    Returns:
    - pd.DataFrame: `df` itself, now carrying the four component columns.

    Raises:
    - ValueError: If `variant_col` is absent, or if splitting does not
      produce exactly four columns.
    """
    if variant_col not in df.columns:
        raise ValueError(f"Column '{variant_col}' not found in the DataFrame.")

    # One column per separator-delimited component
    pieces = df[variant_col].str.split(sep, expand=True)
    n_components = pieces.shape[1]

    if n_components != 4:
        raise ValueError(f"Expected 4 components in the Variant column separated by '{sep}', but got {n_components}.")

    # Components arrive in CHROM, POS, REF, ALT order
    for position, target in enumerate((chrom_col, pos_col, ref_col, alt_col)):
        df[target] = pieces[position]

    return df
def annotate_variants(
    anno_vars_table: pd.DataFrame,
    clinvar_query_table: pd.DataFrame,
    VAT_query_table: pd.DataFrame,
    remove_list: list = None,
    filter_canonical: bool = True,
    drop_original_variant_columns: bool = True
) -> pd.DataFrame:
    """
    Annotate a table of genetic variants with ClinVar and VAT annotations.

    Both annotation tables are left-joined onto `anno_vars_table` by variant
    identifier, so every input row is kept; a variant with several matching
    VAT rows yields several output rows. None of the inputs is modified.

    Parameters:
    - anno_vars_table (pd.DataFrame): DataFrame containing variants to annotate. Must include 'variant_id'.
    - clinvar_query_table (pd.DataFrame): DataFrame containing ClinVar annotations, keyed by 'VARIANT_forQuery'.
    - VAT_query_table (pd.DataFrame): DataFrame containing VAT annotations, keyed by 'vid'.
    - remove_list (list, optional): Transcript consequences to exclude. NOTE: only applied when
      `filter_canonical` is False; with the default `filter_canonical=True` it is ignored.
    - filter_canonical (bool, optional): If True, keep only VAT rows whose 'is_canonical_transcript' is True. Defaults to True.
    - drop_original_variant_columns (bool, optional): If True, drop 'VARIANT_forQuery' and 'vid' after merging. Defaults to True.

    Returns:
    - pd.DataFrame: Annotated DataFrame with ClinVar and VAT information.

    Raises:
    - ValueError: If a required column is missing from any of the three tables.
    """
    # Columns carried over from each annotation table. The validation below is
    # derived from these lists so the two can never drift apart.
    clinvar_cols_to_merge = ['VARIANT_forQuery', 'CLNHGVS', 'CLNREVSTAT', 'ID',
                             'CLNSIG', 'MC', 'Type', 'Name', 'RCVaccession']
    VAT_cols_to_merge = [
        'vid', 'gvs_all_af', 'transcript_source', 'transcript',
        'dna_change_in_transcript', 'aa_change', 'consequence',
        'variant_type', 'exon_number', 'intron_number', 'gene_id',
        'revel', 'splice_ai_acceptor_gain_score',
        'splice_ai_acceptor_loss_score',
        'splice_ai_donor_gain_score',
        'splice_ai_donor_loss_score',
        'omim_phenotypes_id', 'omim_phenotypes_name',
        'clinvar_classification', 'entrezgene'
    ]

    # Validate input DataFrames
    missing = set(clinvar_cols_to_merge) - set(clinvar_query_table.columns)
    if missing:
        raise ValueError(f"ClinVar query table is missing columns: {missing}")

    # 'is_canonical_transcript' is required even when filter_canonical is False
    required_VAT_cols = set(VAT_cols_to_merge) | {'is_canonical_transcript'}
    missing = required_VAT_cols - set(VAT_query_table.columns)
    if missing:
        raise ValueError(f"VAT query table is missing columns: {missing}")

    if 'variant_id' not in anno_vars_table.columns:
        raise ValueError("anno_vars_table must contain a 'variant_id' column.")

    # Step 1: Merge with ClinVar Query Table
    merged_df = anno_vars_table.merge(
        clinvar_query_table[clinvar_cols_to_merge],
        how='left',
        left_on='variant_id',
        right_on='VARIANT_forQuery'
    )
    if drop_original_variant_columns:
        merged_df = merged_df.drop(columns=['VARIANT_forQuery'])

    # Step 2: Filter VAT Query Table (boolean indexing returns a new frame,
    # so the caller's table is never mutated and no defensive copy is needed)
    if filter_canonical:
        # `== True` (rather than plain boolean indexing) also copes with
        # missing values in the flag column by treating them as non-canonical.
        VAT_query_table_filtered = VAT_query_table[VAT_query_table["is_canonical_transcript"] == True]
    elif remove_list is not None:
        # Remove specified transcript consequences
        VAT_query_table_filtered = VAT_query_table[~VAT_query_table["consequence"].isin(remove_list)]
    else:
        VAT_query_table_filtered = VAT_query_table

    # Step 3: Merge with VAT Query Table
    merged_df_vat = merged_df.merge(
        VAT_query_table_filtered[VAT_cols_to_merge],
        how='left',
        left_on='variant_id',
        right_on='vid'
    )
    if drop_original_variant_columns:
        merged_df_vat = merged_df_vat.drop(columns=['vid'])

    # Multiple VAT annotations per variant_id are returned as separate rows;
    # aggregate downstream if one row per variant is needed.
    return merged_df_vat
def processConsentDate(tbl):
    """Add consent-date information to `tbl` and derive age at consent.

    Pipeline: `annoConsentDate` -> parse 'birth_datetime' and
    'primary_consent_date' as datetimes -> localize the consent date to UTC ->
    `calcAgeAtConsent`. Both helpers are defined elsewhere in the notebook;
    presumably the first attaches 'primary_consent_date' per person_id and the
    second computes the age column (verify against their definitions).

    Parameters:
    - tbl (pd.DataFrame): Table expected to carry 'birth_datetime'
      (and presumably 'person_id' for the consent lookup).

    Returns:
    - Whatever `calcAgeAtConsent` returns for the prepared table.
    """
    annotated = annoConsentDate(tbl)

    # Parse both date columns before any timezone handling
    for date_col in ('birth_datetime', 'primary_consent_date'):
        annotated[date_col] = pd.to_datetime(annotated[date_col])

    # tz_localize attaches UTC to naive timestamps (it does not convert)
    annotated['primary_consent_date'] = annotated['primary_consent_date'].dt.tz_localize('UTC')

    return calcAgeAtConsent(annotated)
def procRace(df):
    """Collapse selected raw 'race_source_value' codes into "Unknown" / "Other".

    The DataFrame is modified in place and the same object is returned.
    Values not listed below are left untouched.

    Parameters:
    - df (pd.DataFrame): Table with a 'race_source_value' column.

    Returns:
    - pd.DataFrame: `df` itself with the recoded column.
    """
    # (replacement label, raw codes folded into it), applied in this order
    recodes = (
        ("Unknown", ["AoUDRC_NoneIndicated",
                     "PMI_Skip",
                     "PMI_PreferNotToAnswer"]),
        ("Other", ["WhatRaceEthnicity_GeneralizedMultPopulations",
                   "WhatRaceEthnicity_RaceEthnicityNoneOfThese"]),
    )

    for label, raw_codes in recodes:
        hits = df['race_source_value'].isin(raw_codes)
        df.loc[hits, 'race_source_value'] = label

    return df

def makecontable_bygene(gene, 
                        anno_rates_table,
                        concept = "screenretdys_concept_person", 
                        pop_size = 317964
                        ):
    """Build the 2x2 contingency table for one gene against one concept cohort.

    `concept` is both the value matched in `anno_rates_table["concept"]` and
    the NAME of a module-level table looked up via `get_global_var`; the row
    count of that table is used as the concept cohort size.

    Parameters:
    - gene: Gene whose row is selected from `anno_rates_table`.
    - anno_rates_table (pd.DataFrame): Needs 'concept', 'gene',
      'person_count_solution' and 'person_count_concept' columns.
    - concept (str): Concept name (see above).
    - pop_size (int): Total population size used for the table margins.

    Returns:
    - list: [table, pop_size, Cp, Sp, SpCp] where
      table = [[SpCp, Sp - SpCp], [Cp - SpCp, (pop_size - Sp) - (Cp - SpCp)]],
      Sp is the gene's 'person_count_solution', Cp the concept cohort size and
      SpCp the gene's 'person_count_concept'.

    Raises:
    - ValueError: From `.item()` if the concept/gene pair does not match
      exactly one row.
    """
    concept_cohort = get_global_var(concept)

    # Narrow to the single row describing this gene under this concept
    concept_rows = anno_rates_table[anno_rates_table["concept"] == concept]
    gene_rows = concept_rows[concept_rows.gene == gene]

    carriers = gene_rows["person_count_solution"].item()          # Sp
    cohort_size = concept_cohort.shape[0]                         # Cp
    carriers_in_cohort = gene_rows["person_count_concept"].item()  # SpCp

    carriers_outside = carriers - carriers_in_cohort              # SpCn
    noncarriers = pop_size - carriers                             # Sn
    noncarriers_in_cohort = cohort_size - carriers_in_cohort      # SnCp
    noncarriers_outside = noncarriers - noncarriers_in_cohort     # SnCn

    table = np.array([[carriers_in_cohort, carriers_outside],
                      [noncarriers_in_cohort, noncarriers_outside]])

    return [table, pop_size, cohort_size, carriers, carriers_in_cohort]

def loop_through_fisher_bygene(anno_rates_table, concept = "screenretdys_concept_person", alpha = 0.05):
    """Run Fisher's exact test for every gene annotated under one concept.

    Genes are processed in descending order of 'person_count_solution'. For
    each gene the 2x2 table comes from `makecontable_bygene` (using that
    function's default population size).

    Parameters:
    - anno_rates_table (pd.DataFrame): Needs 'concept', 'gene',
      'person_count_solution' and 'person_count_concept' columns.
    - concept (str): Concept to test; also the name of a module-level table
      (see `makecontable_bygene`).
    - alpha (float): Significance level for the odds-ratio confidence
      interval. The interval itself is NOT multiplicity-corrected.

    Returns:
    - pd.DataFrame: One row per gene with columns Test (gene), Concept,
      Odds_Ratio, Lower_CI, Upper_CI, p_value, Pop_size, num_IRDgenotype,
      Concept_size, Num_annotated, Concept_prevalence and DAR.
      `p_value` is Bonferroni-adjusted over the number of genes tested and
      capped at 1.
    """
    anno_rates_table = anno_rates_table.sort_values(by = "person_count_solution", ascending = False)
    gene_array = anno_rates_table[anno_rates_table.concept == concept].gene.unique()

    # Bonferroni factor: one test per gene under this concept
    n_tests = gene_array.shape[0]

    results = []
    for gene in gene_array:
        (fisher_table, pop_size, concept_size,
         num_IRDgenotype, num_annotated) = makecontable_bygene(gene = gene,
                                                               concept = concept,
                                                               anno_rates_table = anno_rates_table)

        # Fisher's exact test (returns odds ratio and p-value)
        odds_ratio, p_value = fisher_exact(fisher_table)
        # Adjusted p-values are probabilities, so they cannot exceed 1
        p_value = min(p_value * n_tests, 1.0)

        # Table2x2 (statsmodels) supplies the confidence interval for the odds ratio
        conf_int = Table2x2(fisher_table).oddsratio_confint(alpha=alpha)
        lower_ci, upper_ci = conf_int[0], conf_int[1]

        results.append({
            "Test": gene,
            "Concept": concept,
            "Odds_Ratio": odds_ratio,
            "Lower_CI": lower_ci,
            "Upper_CI": upper_ci,
            "p_value": p_value,
            "Pop_size" : pop_size,
            "num_IRDgenotype" : num_IRDgenotype,
            "Concept_size" : concept_size,
            "Num_annotated" : num_annotated,
            # share of the population that is in the concept cohort
            "Concept_prevalence" : concept_size / pop_size,
            # share of genotype carriers that are in the concept cohort
            "DAR": num_annotated / num_IRDgenotype
        })

    return pd.DataFrame(results)


def get_global_var(var_name):
    """Look up a module-level object by its name.

    Parameters:
    - var_name (str): Name of the global to fetch.

    Returns:
    - The object bound to `var_name` in this module's globals, or None when
      no such name exists.
    """
    module_namespace = globals()
    if var_name in module_namespace:
        return module_namespace[var_name]
    return None

def makecontable(anno_rates_table, concept, pop_size = 317964, sol_table_name = "ALL_IRD"):
    """Build the 2x2 contingency table for a solution table against a concept cohort.

    `concept` and `sol_table_name` are NAMES of module-level tables fetched
    with `get_global_var`; their row counts give the two cohort sizes.
    `concept` is also matched against `anno_rates_table["concept"]`.

    Parameters:
    - anno_rates_table (pd.DataFrame): Needs 'concept' and
      'person_count_concept' columns.
    - concept (str): Concept name (see above).
    - pop_size (int): Total population size used for the table margins.
    - sol_table_name (str): Name of the global solution table.

    Returns:
    - list: [table, pop_size, Cp, Sp, SpCp] where
      table = [[SpCp, Sp - SpCp], [Cp - SpCp, (pop_size - Sp) - (Cp - SpCp)]],
      Sp and Cp are the row counts of the solution and concept tables, and
      SpCp is the sum of 'person_count_concept' over the concept's rows.
    """
    concept_cohort = get_global_var(concept)
    solution_cohort = get_global_var(sol_table_name)
    concept_rows = anno_rates_table[anno_rates_table["concept"] == concept]

    n_solution = solution_cohort.shape[0]                  # Sp
    n_concept = concept_cohort.shape[0]                    # Cp
    n_both = sum(concept_rows.person_count_concept)        # SpCp

    solution_only = n_solution - n_both                    # SpCn
    n_not_solution = pop_size - n_solution                 # Sn
    concept_only = n_concept - n_both                      # SnCp
    neither = n_not_solution - concept_only                # SnCn

    table = np.array([[n_both, solution_only],
                      [concept_only, neither]])

    return [table, pop_size, n_concept, n_solution, n_both]

##test = makecontable(concept = "retdegen_concept_person", anno_rates_table = annotation_rates_solution_gene)
##print(test[3])
def loop_through_fisher(anno_rates_table, alpha = 0.05):
    """Run Fisher's exact test once per concept found in `anno_rates_table`.

    The 2x2 table for each concept comes from `makecontable` (using that
    function's default population size and solution table).

    Parameters:
    - anno_rates_table (pd.DataFrame): Needs a 'concept' column plus whatever
      `makecontable` reads; each concept value must also be the name of a
      module-level table.
    - alpha (float): Significance level for the odds-ratio confidence
      interval. The interval itself is NOT multiplicity-corrected.

    Returns:
    - pd.DataFrame: One row per concept with columns Test (concept),
      Odds_Ratio, Lower_CI, Upper_CI, p_value, Pop_size, num_IRDgenotype,
      Concept_size, Num_annotated, Concept_prevalence and DAR.
      `p_value` is Bonferroni-adjusted over the number of concepts and capped
      at 1.
    """
    concept_array = anno_rates_table["concept"].unique()

    # Bonferroni factor: one test per concept
    n_tests = concept_array.shape[0]

    results = []
    for concept in concept_array:
        (fisher_table, pop_size, concept_size,
         num_IRDgenotype, num_annotated) = makecontable(concept = concept,
                                                        anno_rates_table = anno_rates_table)

        # Fisher's exact test (returns odds ratio and p-value)
        odds_ratio, p_value = fisher_exact(fisher_table)
        # Adjusted p-values are probabilities, so they cannot exceed 1
        p_value = min(p_value * n_tests, 1.0)

        # Table2x2 (statsmodels) supplies the confidence interval for the odds ratio
        conf_int = Table2x2(fisher_table).oddsratio_confint(alpha=alpha)
        lower_ci, upper_ci = conf_int[0], conf_int[1]

        results.append({
            "Test": concept,
            "Odds_Ratio": odds_ratio,
            "Lower_CI": lower_ci,
            "Upper_CI": upper_ci,
            "p_value": p_value,
            "Pop_size" : pop_size,
            "num_IRDgenotype" : num_IRDgenotype,
            "Concept_size" : concept_size,
            "Num_annotated" : num_annotated,
            # share of the population that is in the concept cohort
            "Concept_prevalence" : concept_size / pop_size,
            # share of the solution cohort that is in the concept cohort
            "DAR": num_annotated / num_IRDgenotype
        })

    return pd.DataFrame(results)

def plot_fishertable_with_arrows(
    df_fisher,
    plot_title="Forest Plot of Fisher Exact Test Results",
    plot_output_file="forest_plot.png",
    logbase=2,
    custom_labels=False,
    dpi=300,
    arrow_length=0.1,
    offset_OR = 1.4, 
    offset_Pval = 7.5,
    fig_width=10,         
    fig_height=None,
    x_max=None,
    x_min=None,
    pad = 0.5,
    cap=None
):
    """Draw, save and show a log-scale forest plot; open-ended intervals end in arrows.

    One horizontal line per row of `df_fisher` spans the confidence interval,
    with a dot at the odds ratio. An interval that reaches the left plotting
    floor gets a left-pointing arrow, and an infinite upper bound gets a
    right-pointing arrow at `cap`; otherwise a '|' marker closes the line.
    A text column with "OR [CI]" and one with the p-value are written to the
    right of the axis.

    Parameters:
    - df_fisher (pd.DataFrame): Needs 'Test', 'Odds_Ratio', 'Lower_CI',
      'Upper_CI' and 'p_value' columns.
    - plot_title (str): Figure title.
    - plot_output_file (str): Path the figure is saved to.
    - logbase (number): Base of the powers used for x tick positions.
    - custom_labels: Optional sequence of y tick labels replacing 'Test'.
    - dpi (int): Figure resolution.
    - arrow_length (float): Relative length of the arrows.
    - offset_OR, offset_Pval (float): Multipliers placing the two text
      columns relative to the right-hand anchor (see `x_max`).
    - fig_width (float): Figure width in inches.
    - fig_height (float, optional): Figure height; defaults to 0.5 per row.
    - x_max (float, optional): Right axis limit and anchor for the text
      columns. When omitted, the largest tick is used as the anchor.
    - x_min (float, optional): Left axis limit and plotting floor
      (floor defaults to 0.5).
    - pad (float): Vertical padding above the first and below the last row.
    - cap (float, optional): Plotting ceiling for infinite upper bounds;
      defaults to ten times the largest finite 'Upper_CI'.

    Raises:
    - ValueError: If `cap` is not given and no finite 'Upper_CI' exists.
    """
    # Imported locally so the helper does not depend on the import cell
    from matplotlib.ticker import FuncFormatter

    # positive floor for the log axis
    floor = x_min if x_min else 0.5

    # ceiling used in place of infinite upper bounds
    if cap is None:
        finite_uppers = df_fisher["Upper_CI"].replace(np.inf, np.nan).dropna()
        if finite_uppers.empty:
            raise ValueError("No finite 'Upper_CI' value to derive the plotting cap from; pass `cap` explicitly.")
        cap = finite_uppers.max() * 10

    # y positions
    n = len(df_fisher)
    if fig_height is None:
        fig_height = n * 0.5
    y_positions = np.arange(n)

    # start figure
    plt.figure(figsize=(fig_width, fig_height), dpi=dpi)
    ax = plt.gca()
    plt.xscale('log')
    plt.axvline(1, color='red', linestyle='--')

    # tick positions: consecutive powers of `logbase` covering [floor, cap]
    start_exp = int(np.floor(np.log(floor) / np.log(logbase)))
    end_exp   = int(np.ceil(np.log(cap)   / np.log(logbase)))
    ticks = np.logspace(start_exp, end_exp, end_exp - start_exp + 1, base=logbase)
    ax.set_xticks(ticks)

    # plot each point + CI/arrows
    for y, row in zip(y_positions, df_fisher.itertuples()):
        OR = row.Odds_Ratio
        L  = row.Lower_CI
        U  = row.Upper_CI

        # clip the interval to the plottable range
        plot_L = max(L, floor) if L > 0 else floor
        if L > OR:
            # lower bound above the point estimate: draw from the floor instead
            plot_L = floor
        plot_U = min(U, cap)   if U < np.inf else cap

        # draw main error-bar line (no caps)
        ax.plot([plot_L, plot_U], [y, y], color='black', linewidth=1)

        # draw the central marker
        ax.plot(OR, y, 'o', color='black')

        # left end: arrow when the interval was clipped to the floor, else a cap
        if plot_L <= floor:
            ax.annotate(
                '', xy=(plot_L, y), xytext=(plot_L * (1+arrow_length), y),
                arrowprops=dict(arrowstyle='->', lw=1.5, color = "black")
            )
        else:
            ax.plot(plot_L, y, '|', color='black', markersize=7)

        # right end: arrow when the upper bound is infinite, else a cap
        if U == np.inf:
            ax.annotate(
                '', xy=(plot_U, y), xytext=(plot_U / (1 + arrow_length), y),
                arrowprops=dict(arrowstyle='->', lw=1.5, color = "black")
            )
        else:
            ax.plot(plot_U, y, '|', color='black', markersize=7)

    # labels
    plt.yticks(y_positions, df_fisher["Test"] if not custom_labels else custom_labels)
    ax.invert_yaxis()
    if x_max:
        ax.set_xlim(right=x_max)
    if x_min:
        ax.set_xlim(left=x_min)

    plt.xlabel("Odds Ratio")
    plt.title(plot_title, fontsize=10, va='bottom', fontweight = 'bold')

    # Text columns sit to the right of the axis. Without an explicit x_max,
    # anchor them on the largest tick so the default call works.
    text_anchor = x_max if x_max else ticks.max()
    OR_text  = text_anchor * offset_OR
    Pval_text = text_anchor * offset_Pval
    header_y = -1
    plt.text(OR_text, header_y, "OR [95% CI]", ha='left', va='center', fontsize=10, fontweight='bold')
    plt.text(Pval_text, header_y, "P-value*", ha='left', va='center', fontsize=10, fontweight='bold')

    for y, row in zip(y_positions, df_fisher.itertuples()):
        if row.Odds_Ratio == 0:
            OR_annotation = f"{row.Odds_Ratio:.2f} [0-{row.Upper_CI:.2f}]"
        else:
            OR_annotation = f"{row.Odds_Ratio:.2f} [{row.Lower_CI:.2f}-{row.Upper_CI:.2f}]"
        if row.p_value < 0.001:
            Pval_annotation = "<0.001"
        elif row.p_value >= 1:
            Pval_annotation = "1"
        else:
            Pval_annotation = f"{row.p_value:.3f}"
        plt.text(OR_text, y, OR_annotation, va='center', fontsize=9)
        plt.text(Pval_text, y, Pval_annotation, va='center', fontsize=9)

    # clean up
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_visible(False)
    ax.spines['top'].set_visible(True)
    ax.tick_params(axis='y', which='both', left=False, labelleft=True)

    # Replace the automatic x ticks (major and minor) with the custom ones
    ax.set_xticks([], minor=False)
    ax.set_xticks([], minor=True)
    ax.set_xticks(ticks)

    # y axis was inverted above, hence (bottom, top) = (ymax, ymin)
    ymin = -pad
    ymax = n - 1 + pad
    ax.set_ylim(ymax, ymin)

    for tick in plt.gca().yaxis.get_ticklabels():    # Italicize y-axis label
        tick.set_fontstyle("italic")

    def _format_tick(x, pos):
        if x < 1:
            # below 1: two decimals with trailing zeros stripped
            # (0.50 -> "0.5", 0.25 -> "0.25")
            return f"{x:.2f}".rstrip('0').rstrip('.')
        # 1 and above: no decimals
        return f"{int(x)}"

    # This formatter determines the final x tick labels
    ax.xaxis.set_major_formatter(FuncFormatter(_format_tick))

    plt.tight_layout()
    plt.savefig(plot_output_file)
    plt.show()

def plot_fishertable(df_fisher, plot_title = "Forest Plot of Fisher Exact Test Results", 
                     plot_output_file = "forest_plot.png", logbase = 2,figwidth=8,
                     custom_labels=False, dpi = 300):
    """Draw, save and show a log-scale forest plot of odds ratios with error bars.

    Tick positions are derived from the smallest 'Lower_CI' and the largest
    'Upper_CI', so every 'Lower_CI' must be strictly positive and every
    'Upper_CI' finite; otherwise the exponent computation fails.

    Parameters:
    - df_fisher (pd.DataFrame): Needs 'Test', 'Odds_Ratio', 'Lower_CI',
      'Upper_CI' and 'p_value' columns.
    - plot_title (str): Figure title.
    - plot_output_file (str): Path the figure is saved to.
    - logbase (number): Base of the powers used for x tick positions.
    - figwidth (float): Figure width in inches (height is 0.5 per row).
    - custom_labels: Optional sequence of y tick labels replacing 'Test'.
    - dpi (int): Resolution of the saved file.
    """
    # Compute the minimum lower CI and maximum upper CI
    min_lower = df_fisher["Lower_CI"].min()
    max_upper = df_fisher["Upper_CI"].max()

    # Exponents such that logbase^start <= min_lower and logbase^end > max_upper
    start_exponent = int(np.floor(np.log(min_lower) / np.log(logbase)))
    end_exponent = int(np.ceil(np.log(max_upper) / np.log(logbase))) + 1

    # Tick positions on consecutive powers of `logbase`, labelled with their float value
    ticks = np.logspace(start_exponent, end_exponent, num=end_exponent - start_exponent + 1, base=logbase)
    tick_labels = [str((t)) for t in ticks]

    # one row per test
    n_tests = df_fisher.shape[0]
    y_positions = np.arange(n_tests)

    plt.figure(figsize=(figwidth, n_tests * 0.5))
    ax = plt.gca()  # get current axis

    # Odds ratio with asymmetric error bars spanning the confidence interval
    plt.errorbar(df_fisher["Odds_Ratio"], y_positions,
                 xerr=[df_fisher["Odds_Ratio"] - df_fisher["Lower_CI"],
                       df_fisher["Upper_CI"] - df_fisher["Odds_Ratio"]],
                 fmt='o', color='black', capsize=5)

    # Set the x-axis to logarithmic scale
    plt.xscale('log')

    # Draw a vertical line at an odds ratio of 1 (no effect)
    plt.axvline(x=1, color='red', linestyle='--')

    # Label the y-axis with the test names (or the caller's labels)
    plt.yticks(y_positions, custom_labels if custom_labels else df_fisher["Test"])

    # Invert the y-axis so that the first row appears at the top.
    ax.invert_yaxis()

    plt.xlabel("Odds Ratio")
    plt.title(plot_title, fontsize = 12, va = 'bottom')

    # Text columns are placed to the right of the largest tick
    OR_text = max(ticks) * 1.4
    Pval_text = max(ticks) * 7.5

    # Header row above the per-test annotations
    header_y = -1
    plt.text(OR_text, header_y, "OR [95% CI]", ha='left', va='top', fontsize=10, fontweight='bold')
    plt.text(Pval_text, header_y, "P-value*", ha='left', va='top', fontsize=10, fontweight='bold')

    # Annotate each row with the odds ratio, its CI and the p-value
    for y, row in zip(y_positions, df_fisher.itertuples()):
        OR_annotation = f"{row.Odds_Ratio:.2f} [{row.Lower_CI:.2f}-{row.Upper_CI:.2f}]"
        if row.p_value < 0.001:
            Pval_annotation = "<0.001"
        elif row.p_value >=1:
            Pval_annotation = "1"
        else:
            Pval_annotation = f"{row.p_value:.3f}"
        plt.text(OR_text, y, OR_annotation, va='center', fontsize=9)
        plt.text(Pval_text, y, Pval_annotation, va = 'center', fontsize = 9)

    # Remove the left and right spines of the plot so they don't overlap with the annotations.
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_visible(False)
    ax.spines['top'].set_visible(True)

    # Hide y-axis tick marks but keep their labels
    ax.tick_params(axis='y', which='both', left=False, labelleft=True)

    # Replace the automatic x ticks (major and minor) with the custom ones
    ax.set_xticks([], minor=False)
    ax.set_xticks([], minor=True)   
    ax.set_xticks(ticks)
    ax.set_xticklabels(tick_labels)

    plt.tight_layout()
    # Honour the requested resolution in the saved file
    plt.savefig(plot_output_file, dpi=dpi)
    plt.show()

def plot_stacked_bars(sol_table_summ, solution_category,plot_title, save_path, fig_width = 7, fig_height = 7, offset = 0.3):
    """Draw one horizontal stacked bar per gene showing annotation rates per concept category.

    Rows of `sol_table_summ` whose 'solution_table' equals `solution_category`
    supply the bar values. Genes (one bar each) are taken from the WHOLE table,
    ordered by descending 'person_count_solution'. For each gene the
    'person_annotation_rate' (times 100) of every concept category is read as
    a cumulative value and drawn as the increment over the previous category;
    a category missing for a gene inherits the previous category's value
    (0 for the first), i.e. contributes a zero-width segment.

    Relies on module-level names defined elsewhere in the notebook:
    `colors_concepts` (segment colours, indexed by category position),
    `con_prev` (three values drawn as dashed vertical lines at value*100) and
    `mpatches`.

    Parameters:
    - sol_table_summ (pd.DataFrame): Needs 'solution_table', 'category',
      'gene', 'person_count_solution' and 'person_annotation_rate' columns.
    - solution_category: Value of 'solution_table' to plot.
    - plot_title (str): Figure title.
    - save_path (str): Output path; nothing is saved when falsy.
    - fig_width, fig_height (float): Figure size in inches.
    - offset (float): Padding added below the first and above the last bar.

    Side effects: prints the concept categories, shows the figure, and sets
    the global rcParam 'figure.constrained_layout.use' to True (after saving,
    so it affects figures created later rather than this one).
    """

    sol_table = sol_table_summ

    # rows for the requested solution table only
    con_sol_table = sol_table[sol_table.solution_table == solution_category] #plot AD

    # concept categories present in the selected rows; their order fixes the segment order
    concept_categories = con_sol_table["category"].unique()
    print(concept_categories)

    # Sort the full table so genes appear by descending person_count_solution
    sol_table = sol_table.sort_values("person_count_solution", ascending=False).reset_index()

    genes_to_plot = sol_table["gene"].unique()

    # bars are spaced 0.5 apart on the y axis
    y_positions = np.arange(0, len(genes_to_plot)/2, 0.5)
    fig, ax = plt.subplots(figsize=(fig_width, fig_height))

    bar_height = 0.4  # Height of each bar

    # one bar per gene
    for i, gene_to_plot in enumerate(genes_to_plot):

        # cumulative annotation rate (in %) per concept category for this gene
        x_abs = []

        # Filter the selected rows for the current gene
        con_sol_table_gene = con_sol_table[con_sol_table['gene'] == gene_to_plot]
        if not con_sol_table_gene.empty:
            # proceed if there is an entry for the gene in the sol x concept table (i.e. at least 1 person annotated)

            # Iterate over the unique concept categories
            for k, concept_category in enumerate(concept_categories):

                # Filter for the current concept category within the gene-specific table
                subset = con_sol_table_gene[con_sol_table_gene['category'] == concept_category]

                if not subset.empty:
                    # Take person_annotation_rate from the first matching row
                    rate = subset.iloc[0]['person_annotation_rate']*100
                    x_abs.append(rate)
                else:
                    # Category absent for this gene: carry the previous value forward
                    if k == 0:
                        x_abs.append(0)
                    else:
                        x_abs.append(x_abs[k-1])
        else:
            # gene has no rows in the selected solution table: three zero-width segments
            x_abs = [0,0,0]

        # segment widths = increments between consecutive cumulative values
        # (the comprehension's `i` is local to it and does not affect the loop's `i`)
        x_delta = [x_abs[0]] + [x_abs[i] - x_abs[i - 1] for i in range(1, len(x_abs))]

        y = y_positions[i]
        left = 0  # starting position for the first segment
        for j in range(len(x_delta)):
            ax.barh(y, x_delta[j],
                    left=left,
                    height=bar_height,
                    color=colors_concepts[j],
                    #align='edge',
                    edgecolor='none',   # Remove the edge (border) to eliminate white gaps
                    linewidth=0)        # Alternatively, set linewidth to 0
            left += x_delta[j]

    # Customize the y-axis to show group names at the center of each bar
    ax.set_yticks(y_positions)
    ax.set_yticklabels(genes_to_plot)
    ax.set_ylim(-offset, max(y_positions)+offset)

    # first gene at the top
    ax.invert_yaxis()

    ax.set_xlim(0, 100)
    plt.gca().set_xlabel('Disease Annotation Frequency (%)', fontdict=dict(weight='bold'))

    # dashed reference lines at the three module-level con_prev values
    # (presumably concept prevalences - confirm where con_prev is defined)
    plt.axvline(x=con_prev[0]*100, color='grey', linestyle='--', linewidth=1)
    plt.axvline(x=con_prev[1]*100, color='grey', linestyle='--', linewidth=1)
    plt.axvline(x=con_prev[2]*100, color='grey', linestyle='--', linewidth=1)

    # Add labels and title
    plt.title(
        plot_title,
        fontsize=16,
        fontweight="bold",
        ha="center"
    )

    # Legend patches are built but not attached: the ax.legend call below is disabled
    legend_handles = [
        mpatches.Patch(color=colors_concepts[0], label="IRD"),
        mpatches.Patch(color=colors_concepts[1], label="Retinal Degeneration"),
        mpatches.Patch(color=colors_concepts[2], label="Screening Set"),
    ]
    #ax.legend(handles=legend_handles, bbox_to_anchor=(1.05, 1), loc='upper left')

    for tick in plt.gca().yaxis.get_ticklabels():    # Italicize y-axis label
        tick.set_fontstyle("italic")

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
        print(f"Plot saved to {save_path}")
    # NOTE(review): global setting changed after this figure was created and saved;
    # it only influences figures made afterwards - confirm this is intended
    plt.rcParams['figure.constrained_layout.use'] = True

    plt.show()
def plot_stacked_bars_bygene(sol_table_summ, 
                             solution_category,plot_title, 
                             save_path, fig_width = 7, fig_height = 7, offset = 0.3):
    """
    Draw one horizontal stacked bar per gene showing annotation rates by code set.

    Bars are built from the rows of ``sol_table_summ`` whose ``solution_table``
    equals ``solution_category``. Every concept category present in those rows
    contributes one segment per gene: the right edge of segment k sits at that
    category's ``person_annotation_rate`` (converted to a percentage), so each
    segment shows the increment over the previous category. A category missing
    for a gene reuses the previous category's rate (zero-width segment), and a
    gene with no matching rows gets zero-width segments only.

    Genes are listed top to bottom in order of first appearance after sorting
    the *whole* input table by ``person_annotation_rate`` (descending).

    Relies on module-level names: ``colors_concepts`` (one colour per segment,
    indexed by segment position) and ``con_prev`` (its first three entries are
    drawn, times 100, as dashed vertical reference lines).

    Parameters:
    - sol_table_summ (pd.DataFrame): needs the columns "solution_table",
      "category", "gene" and "person_annotation_rate".
    - solution_category: value of "solution_table" selecting the rows to plot.
    - plot_title (str): Title for the plot.
    - save_path (str or None): if truthy, the figure is saved there at 300 dpi.
    - fig_width, fig_height (float): figure size in inches.
    - offset (float): padding added above the first and below the last bar.

    Returns:
    - None. The figure is shown with ``plt.show()``.
    """
    # Rows for the requested solution table; these drive the bar segments.
    con_sol_table = sol_table_summ[sol_table_summ["solution_table"] == solution_category]

    # Concept categories present in the selected rows, in order of appearance.
    concept_categories = con_sol_table["category"].unique()
    print(concept_categories)

    # Gene order comes from the full table, sorted by annotation rate.
    sorted_table = sol_table_summ.sort_values("person_annotation_rate", ascending=False)
    genes_to_plot = sorted_table["gene"].unique()

    if len(genes_to_plot) == 0:
        # Nothing to draw; max() on an empty position array would raise below.
        print("No genes available to plot")
        return

    # Bars are spaced 0.5 apart on the y-axis.
    y_positions = np.arange(len(genes_to_plot)) * 0.5
    fig, ax = plt.subplots(figsize=(fig_width, fig_height))

    bar_height = 0.4  # Height of each bar

    for y, gene_to_plot in zip(y_positions, genes_to_plot):
        con_sol_table_gene = con_sol_table[con_sol_table['gene'] == gene_to_plot]

        # Absolute rate (in %) reached at the end of each category's segment.
        x_abs = []
        for concept_category in concept_categories:
            subset = con_sol_table_gene[con_sol_table_gene['category'] == concept_category]
            if not subset.empty:
                # Only the first matching row is used.
                x_abs.append(subset.iloc[0]['person_annotation_rate'] * 100)
            else:
                # Missing category: carry the previous rate forward (0 if none yet).
                # This also yields an all-zero bar for genes with no rows at all.
                x_abs.append(x_abs[-1] if x_abs else 0)

        # Segment widths are the increments between consecutive absolute rates.
        # NOTE(review): a width is negative if a later category has a lower
        # rate than the previous one - presumably the sets are nested; confirm.
        x_delta = [current - previous
                   for previous, current in zip([0] + x_abs[:-1], x_abs)]

        left = 0  # starting position for the first segment
        for j, width in enumerate(x_delta):
            ax.barh(y, width,
                    left=left,
                    height=bar_height,
                    color=colors_concepts[j],
                    edgecolor='none',   # no border, avoids white gaps between segments
                    linewidth=0)
            left += width

    # Gene names at the centre of each bar, first gene at the top.
    ax.set_yticks(y_positions)
    ax.set_yticklabels(genes_to_plot)
    ax.set_ylim(-offset, max(y_positions)+offset)
    ax.invert_yaxis()

    ax.set_xlim(0, 100)
    ax.set_xlabel('Disease Annotation Frequency (%)', fontdict=dict(weight='bold'))

    # Dashed reference lines at the first three con_prev values (as percentages).
    for k in range(3):
        ax.axvline(x=con_prev[k]*100, color='grey', linestyle='--', linewidth=1)

    ax.set_title(
        plot_title,
        fontsize=16,
        fontweight="bold",
        ha="center"
    )

    for tick in ax.yaxis.get_ticklabels():    # Italicize the gene names
        tick.set_fontstyle("italic")

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
        print(f"Plot saved to {save_path}")

    # Global side effect kept from the original: this only affects figures
    # created after this point, not the one drawn above.
    plt.rcParams['figure.constrained_layout.use'] = True

    plt.show()

# Derive the segment colours for the stacked bar plots.
# Nested (cumulative) set sizes, smallest to largest.
set_sizes = [1060, 14765, 52836, 115412, 317964]

# Number of entries each nested set adds on top of the previous one.
segments = [
    larger - smaller
    for smaller, larger in zip([0] + set_sizes[:-1], set_sizes)
]

total = set_sizes[-1]

# Colormap positions: each cumulative set size as a fraction of the largest
# set, raised to the power `pwr`. Note these come from set_sizes, not segments.
pwr = 1/1.5
percentages = [(size / total)**pwr for size in set_sizes]

# Sample the Blues colormap, largest fraction first.
colors = plt.cm.Blues(percentages[::-1])

# The first three sampled colours are the bar segment colours
# (IRD, Retinal Degeneration, Screening Set).
#colors_concepts = ["darkblue", "blue", "lightblue"] #IRD, RetDegen, Screen
colors_concepts = colors[0:3]


def plot_gene_summary_aggregated(solution_gene_summary, 
                                 category = "retdys", 
                                 title="Gene Summary", 
                                 save_path=None, 
                                 height=8, 
                                 width=7, 
                                 offset = 2):
    """
    Plot participant and variant counts per gene for one category.

    Rows whose "category" matches ``category`` are sorted by
    "person_count_solution" (descending) and drawn as horizontal bars: the
    participant count in light grey with the variant count overlaid in
    semi-transparent blue. Each gene is annotated with its participant count.

    Parameters:
    - solution_gene_summary (pd.DataFrame): needs the columns "category",
      "gene", "person_count_solution" and "variant_count_solution".
    - category (str): Category to filter the data. It is interpolated into the
      regular expression ``^{category}$`` without escaping, so regex
      metacharacters keep their special meaning.
    - title (str): currently unused; no title is drawn on the plot.
    - save_path (str or None): if truthy, the figure is saved there at 300 dpi.
    - height, width (float): figure size in inches.
    - offset (float): horizontal gap, in data units, between the end of a
      participant bar and its count annotation.

    Returns:
    - None. Prints a message and returns early when no row matches.

    Side effects: switches the global seaborn style to "ticks" after drawing,
    which affects subsequently created figures.
    """
    # Filter the data for the specified category; na=False treats rows with a
    # missing category as non-matching instead of breaking the boolean mask.
    category_q = f"^{category}$"
    in_category = solution_gene_summary["category"].str.contains(category_q, na=False)
    category_data = solution_gene_summary[in_category]

    if category_data.empty:
        print(f"No data available for category: {category}")
        return

    # Largest participant count first; the fresh 0..n-1 index doubles as the
    # bar position used for the annotations below.
    category_data = (
        category_data
        .sort_values("person_count_solution", ascending=False)
        .reset_index(drop=True)
    )

    fig, ax = plt.subplots(figsize=(width, height))

    # Participant counts (background bars).
    sns.barplot(
        data=category_data,
        y="gene",  # Genes on the y-axis for a horizontal bar plot
        x="person_count_solution",
        color="lightgrey",
        ax=ax
    )

    # Variant counts, overlaid with transparency.
    sns.barplot(
        data=category_data,
        y="gene",
        x="variant_count_solution",
        color="blue",
        alpha=0.6,
        label="Variant Count",  # only this series appears in the legend
        ax=ax
    )

    # Participant count written to the right of each bar.
    # NOTE(review): the row index is used as the y position, which assumes one
    # row (and therefore one bar) per gene - confirm for the tables passed in.
    for index, row in category_data.iterrows():
        ax.text(
            row["person_count_solution"] + offset,
            index,
            f'{int(row["person_count_solution"])}',
            va="center",
            fontsize = 10
        )

    ax.yaxis.set_tick_params(labelsize=10)
    ax.set_xlabel('Number of Participants', fontdict=dict(weight='bold'))
    ax.set_ylabel('Gene', fontdict=dict(weight='bold'))
    ax.grid(axis='x', visible=False)
    for tick in ax.yaxis.get_ticklabels():    # Italicize the gene names
        tick.set_fontstyle("italic")
    ax.legend(loc="lower right")

    sns.despine(ax=ax)
    # Style changes made here apply to later figures, not to this one.
    sns.set_style("white")
    sns.set_style("ticks")

    # Save the plot if a save_path is provided
    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
        print(f"Plot saved to {save_path}")

    plt.show()

def procRace(df):
    """
    Collapse survey race codes and rename the column to ``race``.

    In the "race_source_value" column, "AoUDRC_NoneIndicated", "PMI_Skip" and
    "PMI_PreferNotToAnswer" become "Unknown";
    "WhatRaceEthnicity_GeneralizedMultPopulations" and
    "WhatRaceEthnicity_RaceEthnicityNoneOfThese" become "Other". All other
    values are left as they are. The column is then renamed to "race".

    Parameters:
    - df (pd.DataFrame): must contain a "race_source_value" column.

    Returns:
    - pd.DataFrame: a new frame; the input frame is not modified.
    """
    unknown_codes = ["AoUDRC_NoneIndicated",
                     "PMI_Skip",
                     "PMI_PreferNotToAnswer"]
    other_codes = ["WhatRaceEthnicity_GeneralizedMultPopulations",
                   "WhatRaceEthnicity_RaceEthnicityNoneOfThese"]

    # Work on a copy so the caller's frame is not edited behind its back
    # (this also avoids SettingWithCopyWarning when df is a slice).
    out = df.copy()
    out.loc[out['race_source_value'].isin(unknown_codes), 'race_source_value'] = "Unknown"
    out.loc[out['race_source_value'].isin(other_codes), 'race_source_value'] = "Other"

    return out.rename(columns={'race_source_value': 'race'})

def replaceRace(df):
    """
    Shorten the ``WhatRaceEthnicity_*`` codes in the "race" column.

    The codes for White, Black, Asian, AIAN and MENA are replaced by the bare
    label; every other value is left unchanged.

    Parameters:
    - df (pd.DataFrame): must contain a "race" column.

    Returns:
    - pd.DataFrame: the same object that was passed in, with its "race" column
      reassigned (the input is modified in place).
    """
    short_labels = ("White", "Black", "Asian", "AIAN", "MENA")
    code_to_label = {f"WhatRaceEthnicity_{label}": label for label in short_labels}

    df['race'] = df['race'].replace(code_to_label)
    return df


In [None]:
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import matplotlib.patches as mpatches
from scipy.stats import fisher_exact
from statsmodels.stats.contingency_tables import Table2x2
import numpy as np
import matplotlib

# Global matplotlib styling: a Tufte-like minimalist look plus default font sizes.
plt.rcParams.update({
    # Axes frame: no right/top spines, thin grey edges, white background
    'axes.spines.right': False,
    'axes.spines.top': False,
    'axes.edgecolor': 'gray',
    'axes.linewidth': 0.8,
    'axes.facecolor': 'white',
    'figure.facecolor': 'white',
    # Grid lines (when enabled): thin dashed grey
    'grid.linestyle': '--',
    'grid.linewidth': 0.5,
    'grid.color': 'gray',
    # No ticks on the top/right sides
    'xtick.top': False,
    'ytick.right': False,
    # Fonts
    'font.family': 'sans-serif',
    'font.sans-serif' : ['Arial', 'Tahoma', 'DejaVu Sans',
                               'Lucida Grande', 'Verdana'],
    'font.size': 12,          # default text size
    'axes.titlesize': 14,     # axes title font size
    'axes.labelsize': 12,     # axis label font size
    'xtick.labelsize': 10,    # x tick label size
    'ytick.labelsize': 10,    # y tick label size
    'legend.fontsize': 10,    # legend font size
})

In [None]:
# Load the solution tables (person IDs with IRD-compatible genotypes & associated metadata).
AD_subset = getTable(table_name='GS_AD_solutions.tsv', folder="gene_solutions")
XL_subset_male = getTable(table_name='GS_XLmale_solutions.tsv', folder="gene_solutions")
Hom_subset = getTable(table_name="pid_vid_hom_noXLmale.tsv", folder="personID_variant")
ALL_IRD = getTable(table_name="ALL_IRD.tsv", folder="personID_concept") #if you have already done the above

# Solution tables passed to the summary helpers, keyed by name - names must match!
# Only the combined table is active; enable the others to analyse them separately.
solution_tables = dict(
    ALLIRD_solutions=ALL_IRD,
    #AD_solutions=AD_subset,
    #XLmale_solutions=XL_subset_male,
    #Hom_solutions=Hom_subset,
)

In [None]:
# Summarise every entry of solution_tables with describe_solution_tables
# (defined elsewhere in this notebook), hand the result to saveToBucket under
# the "Concept_Solution" folder name, then display it as the cell output.
# NOTE(review): saveToBucket presumably writes to the workspace bucket - confirm.
solution_inheritance_summary = describe_solution_tables(**solution_tables)
saveToBucket(solution_inheritance_summary, "solution_inheritance_summary.tsv", "Concept_Solution")
solution_inheritance_summary

In [None]:
# Load the concept_person tables (person IDs matching ICD code sets & associated metadata).
def _concept_person(file_name):
    """Fetch one table from the "personID_concept" folder via getTable."""
    return getTable(table_name=file_name, folder="personID_concept")


# Core code sets
retdys_concept_person = _concept_person("retchordys_ICD_person_age_youngest.tsv")
retdegen_concept_person = _concept_person("retdegen_ICD_person_age_youngest.tsv")
screenretdys_concept_person = _concept_person("ScreenRetDysICD_person_age_youngest.tsv")

# AMD code sets
AMDspecific_concept_person = _concept_person("AMD_specific_ICD_person_age_youngest.tsv")
AMDexudative_concept_person = _concept_person("AMD_exudative_ICD_person_age_youngest.tsv")
AMDnonexudative_concept_person = _concept_person("AMD_nonexudative_ICD_person_age_youngest.tsv")

# CME code set
CME_concept_person = _concept_person("CME_ICD_person_age_youngest.tsv")

# Remaining code sets
myopia_concept_person = _concept_person("myopia_ICD_person_age_youngest.tsv")
hypermetropia_concept_person = _concept_person("hypermetropia_ICD_person_age_youngest.tsv")
pucker_concept_person = _concept_person("pucker_ICD_person_age_youngest.tsv")

In [None]:
# Describe the concept tables. Keys are the table names; insertion order is
# the order in which the tables are passed on to the summary helpers.
concept_tables = dict(
    retdys_concept_person=retdys_concept_person,
    retdegen_concept_person=retdegen_concept_person,
    screenretdys_concept_person=screenretdys_concept_person,
    #AMDexudative_concept_person=AMDexudative_concept_person,
    #AMDnonexudative_concept_person=AMDnonexudative_concept_person,
    AMDspecific_concept_person=AMDspecific_concept_person,
    CME_concept_person=CME_concept_person,
    pucker_concept_person=pucker_concept_person,
    myopia_concept_person=myopia_concept_person,
    hypermetropia_concept_person=hypermetropia_concept_person,
)

# One summary per concept table (unique persons, age and sex statistics).
concept_summary = describe_tables(**concept_tables)
#saveToBucket(concept_summary, "concept_summary.tsv", "Concept_Solution")
concept_summary

In [None]:
# Generate summaries from solution and concept tables.
# Both helpers are defined elsewhere in this notebook; the second one receives
# the concept tables together with the solution tables.
solution_gene_summary = process_solution_gene_summaries(solution_tables)
concept_gene_summary = process_concept_solution_gene_summaries(concept_tables, solution_tables)
    
# Calculate annotation rates from the two summaries above.
# The result is the input of the aggregation, enrichment and plotting cells below.
annotation_rates_solution_gene = calculate_all_annotation_rates(solution_gene_summary, concept_gene_summary)

# Optional: persist the intermediate tables (disabled).
#saveToBucket(annotation_rates_solution_gene, "annotation_rates_solution_GS.tsv", "Concept_Solution")
#saveToBucket(solution_gene_summary, "solution_geneGS_summary.tsv", "Concept_Solution")

In [None]:
# Aggregate, within each table, all genes with ULTRA_RARE_MAX_PERSONS or fewer
# participants into a single "Ultra-rare" row.
ULTRA_RARE_MAX_PERSONS = 20


def _safe_rate(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else float('nan')


aggregated_tables = []          # one aggregated frame per table, concatenated once below
table_subset = pd.DataFrame()
for table in annotation_rates_solution_gene["table"].unique():
    table_subset = annotation_rates_solution_gene[annotation_rates_solution_gene["table"] == table]
    mask = table_subset['person_count_solution'] <= ULTRA_RARE_MAX_PERSONS

    # sum up the "ultra-rare" group at person-level
    sum_person_count = table_subset.loc[mask, 'person_count_solution'].sum()
    sum_person_concept = table_subset.loc[mask, 'person_count_concept'].sum()

    # sum up the "ultra-rare" group at variant-level
    sum_variant_count = table_subset.loc[mask, 'variant_count_solution'].sum()
    sum_variant_concept = table_subset.loc[mask, 'variant_count_concept'].sum()

    # Build the aggregate row. Descriptive columns are copied from the first
    # row of the table.
    # NOTE(review): assumes these columns are constant within a table - confirm.
    # The rate is guarded on the denominator, so a group with participants but
    # no annotated ones gets a rate of 0.0 rather than NaN.
    # NOTE(review): the row is added even when no gene falls under the
    # threshold (all counts 0, rates NaN), as before.
    new_row = {
        'concept': table_subset["concept"].iloc[0],
        'table': table_subset["table"].iloc[0],
        'solution_table': table_subset["solution_table"].iloc[0],
        'category': table_subset["category"].iloc[0],
        'category_inheritance': table_subset["category_inheritance"].iloc[0],
        'gene': 'Ultra-rare',
        'person_count_solution': sum_person_count,
        'person_count_concept': sum_person_concept,
        'person_annotation_rate': _safe_rate(sum_person_concept, sum_person_count),
        'variant_count_solution' : sum_variant_count,
        'variant_count_concept' : sum_variant_concept,
        'variant_annotation_rate' : _safe_rate(sum_variant_concept, sum_variant_count)
    }

    # drop the individual "ultra-rare" rows and append the aggregate
    table_subset = pd.concat([
        table_subset.loc[~mask],        # genes above the participant threshold
        pd.DataFrame([new_row])         # aggregated row
    ], ignore_index=True)

    aggregated_tables.append(table_subset)

# Single concatenation instead of growing the frame inside the loop.
full_table = (
    pd.concat(aggregated_tables, ignore_index=True)
    if aggregated_tables
    else pd.DataFrame()
)

In [None]:
# Enrichment / forest plot of one code set, broken down by gene.
# - Change `concept` to plot another code set (retdys_concept_person = IRD set,
#   retdegen_concept_person = Retinopathy set, etc.).
# - Pass full_table instead of annotation_rates_solution_gene to plot the
#   aggregated data (ultra-rare genes collapsed); the raw table plots all genes.
# loop_through_fisher_bygene and plot_fishertable_with_arrows are defined
# elsewhere in this notebook; the image is written to plot_output_file.
a = loop_through_fisher_bygene(annotation_rates_solution_gene, concept = "retdys_concept_person")
plot_fishertable_with_arrows(a, plot_title = "Enrichment of IRD Set",
                plot_output_file = "forest_plot_gene_in_IRD_bygene.png", logbase = 2, 
                dpi=300, offset_OR = 1.5, offset_Pval = 8, fig_width = 10, fig_height = 10,
                            x_max=128, x_min=0.25, pad = 0.35, cap = 128)

In [None]:
# Enrichment plot for the entire dataset, as seen in Figure 1 of the paper.
# Display labels, one per code set, listed in the same order as the entries of
# concept_tables.
# NOTE(review): presumably loop_through_fisher returns its rows in that order;
# confirm, since the labels are applied by position (also in the heatmap cell).
custom_labels = ["IRD",
                "Retinopathy",
                "Screening Set",
                "AMD",
                "CME",
                "Pucker / ERM",
                "Myopia",
                "Hypermetropia"
                ]

# loop_through_fisher and plot_fishertable are defined elsewhere in this
# notebook; plot_fisher_tbl is reused by the heatmap cell below.
plot_fisher_tbl = loop_through_fisher(anno_rates_table = annotation_rates_solution_gene)
plot_fishertable(plot_fisher_tbl, plot_title = "Enrichment within IRD genotype",
                plot_output_file = "forest_plot_IRD_vs_concept.png", logbase = 2, custom_labels = custom_labels,
                 figwidth=9,
                dpi=300)

In [None]:
# Heatmap as seen in Figure 1: concept prevalence next to DAR, both in percent.
heatmap_data = plot_fisher_tbl[['Concept_prevalence', 'DAR']].to_numpy() * 100

n_tests = len(plot_fisher_tbl)

# Half an inch of height per row; the colour scale is capped at 40 %.
fig, ax = plt.subplots(figsize=(4, n_tests*0.5))
cax = ax.imshow(heatmap_data, aspect='auto', cmap='OrRd', vmax=40, origin='upper')

# Rows are labelled by position with custom_labels (defined in the enrichment
# plot cell), columns with fixed names.
ax.set_yticks(np.arange(n_tests))
ax.set_yticklabels(custom_labels)
ax.set_xticks(np.arange(heatmap_data.shape[1]))
ax.set_xticklabels(['Prevalence', 'DAF'])

# Write each value, with two decimals, into its cell.
for (i, j), cell_value in np.ndenumerate(heatmap_data):
    ax.text(j, i, f'{cell_value:.2f}', ha='center', va='center', color='black', fontsize=10)

ax.set_title('AoU Prevalence \n and DAF (%)')
#plt.colorbar(cax)

# Draw all four spines: the global rcParams hide the top and right ones.
for side in ('right', 'top', 'left', 'bottom'):
    ax.spines[side].set_visible(True)

# Adjust layout, save and display the plot
fig.tight_layout()
fig.savefig("ConPrev_andAnnoRate.png")
plt.show()

In [None]:
# Pie chart as seen in Figure 1.
# Participant counts per genotype group; the N= figures in the labels and in
# the title are derived from this list so the three cannot drift apart.
#data = solution_inheritance_summary["count"]
data = [374, 76, 31]

# Take the first len(data) colours of the qualitative Set3 colormap.
cmap = plt.get_cmap('Set3')
#colors = cmap(np.linspace(0, 1, len(data)))
colors = cmap.colors[:len(data)]

#labels = solution_inheritance_summary.Table_Name
labels = [f"Autosomal Dominant \n (N={data[0]})",
          f" X-linked Male or \n Homozygous Female\n (N={data[1]})",
          f"Homozygous \nAutosomal Recessive \n (N={data[2]})"]

# Create the pie chart with percentage labels on an explicit figure/axes.
fig, ax = plt.subplots()
ax.pie(data, labels=labels, colors = colors, autopct='%1.1f%%')
ax.set_title(f'Participants with definite \n IRD-compatible genotypes \n N={sum(data)}')
fig.tight_layout()

fig.savefig("participant_dist.png", dpi = 300)
plt.show()

In [None]:
# Plot of the number of participants per gene in the set, as seen in Figure 2.
# plot_gene_summary_aggregated reads the columns "category", "gene",
# "person_count_solution" and "variant_count_solution" of the table it is given;
# full_table is the aggregated table (ultra-rare genes collapsed into one row).
plot_gene_summary_aggregated(
    solution_gene_summary=full_table,
    category="retdys",  # only rows whose category matches ^retdys$ are plotted
    title="Number of Participants",
    height = 4,
    width = 5,
    save_path="ALL_Solutions_var.png"
)

In [None]:
# DAFs per gene in the set, as seen in Figure 2.

# Get the solution and concept categories, used for looping or indexing below.
# NOTE(review): core_concept_anno_gene and core_concept_anno_aggregated are not
# assigned in the cells around this one - confirm they are created earlier in
# the notebook, otherwise this cell fails on a fresh kernel.
solution_categories = core_concept_anno_gene["solution_table"].unique()
concept_categories = core_concept_anno_gene["category"].unique()


# Make the stacked bar plot - change the solution_category to plot different sets.
# solution_categories[0] selects the first solution table found above.
plot_stacked_bars(sol_table_summ = core_concept_anno_aggregated,
                  solution_category=solution_categories[0],
                  plot_title="",
                  save_path = "anno_rate_ALL.png",
                  fig_width=6,
                  fig_height=4,
                 offset = 0.3)

In [None]:
# Age plot as seen in Figure 4.

def _age_by_set(table, age_column, set_label):
    """Return a new two-column frame: `age_column` renamed to "Age", plus a constant "set" label."""
    return (
        table[[age_column]]
        .rename(columns={age_column: 'Age'})
        .assign(**{'set': set_label})
    )


# Split participants by whether they carry a code from each code set: rows of
# the ALLIRD_in_* tables with a missing condition_source_value are unannotated.
# NOTE(review): each mask is computed on an ALLIRD_in_* table but applied to
# ALL_IRD, which requires both frames to share the same row index - confirm.
nan_mask = ALLIRD_in_retdys["condition_source_value"].isnull()
in_retdys = ALLIRD_in_retdys[~nan_mask]
notin_retdys = ALL_IRD[nan_mask]

nan_mask = ALLIRD_in_retdegen["condition_source_value"].isnull()
in_retdegen = ALLIRD_in_retdegen[~nan_mask]
notin_retdegen = ALL_IRD[nan_mask]

nan_mask = ALLIRD_in_screenretdys["condition_source_value"].isnull()
in_screenretdys = ALLIRD_in_screenretdys[~nan_mask]
notin_screenretdys = ALL_IRD[nan_mask]

annotationRate_tables = {
    "in_retdys" : in_retdys,
    "in_retdegen" :in_retdegen,
    "in_screenretdys": in_screenretdys
}

Unannotated_tables = {
    "notin_retdys" : notin_retdys,
    "notin_retdegen" :notin_retdegen,
    "notin_screenretdys": notin_screenretdys
}

# Violin plot input: one (Age, set) frame per group. Annotated groups use the
# age at visit, unannotated groups the age at consent.
in_retdys_sm = _age_by_set(in_retdys, "AgeAtVisit_Years", "IRD")
in_retdegen_sm = _age_by_set(in_retdegen, "AgeAtVisit_Years", "Retinal Degeneration")
in_screenretdys_sm = _age_by_set(in_screenretdys, "AgeAtVisit_Years", "Screening Set")

# Not part of the plotted frame below; kept for reference.
ALL_IRD_sm = _age_by_set(ALL_IRD, "AgeAtConsent_Years", "All Participants")

notin_retdys_sm = _age_by_set(notin_retdys, "AgeAtConsent_Years", "Not in IRD")
notin_retdegen_sm = _age_by_set(notin_retdegen, "AgeAtConsent_Years", "Not in Retinal Degeneration")
notin_screenretdys_sm = _age_by_set(notin_screenretdys, "AgeAtConsent_Years", "Not in Screening Set")

# The concatenation order fixes the left-to-right order of the violins and must
# match the tick labels set further down.
df = pd.concat([in_retdys_sm, notin_retdys_sm, 
                in_retdegen_sm, notin_retdegen_sm, 
                in_screenretdys_sm, notin_screenretdys_sm])

# Annotated groups in shades of blue, unannotated groups in white.
palette = {
    'IRD':        '#0F5786',
    'Not in IRD':         'white',
    'Retinal Degeneration': '#58b1db',
    'Not in Retinal Degeneration':  'white',
    'Screening Set' : '#b0d2e7',
    'Not in Screening Set' : "white"
    # … add entries for every Code Set …
}

# Drop rows whose age cannot be parsed as a number.
df['Age'] = pd.to_numeric(df['Age'], errors='coerce')
df = df.dropna(subset=['Age'])

fig = plt.figure(figsize=(9, 6))
ax = fig.add_axes([0.25, 0.15, 0.70,0.80])
ax = sns.violinplot(
    data=df,
    x='set',
    y='Age',
    split=True,
    saturation = 1,
    inner_kws=dict(box_width=30, whis_width=2, color=".8"),
    density_norm="count",
    palette=palette,
    ax=ax
)

# Display labels, by position, for the six groups concatenated above.
ax.set_xticklabels(['IRD', 'Not in\nIRD', 
                    'Retinopathy', 'Not in\nRetinopathy', 
                    'Screening\nSet', 'Not in\nScreening\nSet'])

ax.set(xlabel="Code Set Annotation")
ax.set_ylim(-10, 110)

#sns.swarmplot(x ='set', y ='Age', data = df,color= "white")

plt.title('AoU Age Distribution')
#plt.tight_layout()
plt.savefig("AoUAgeDist.png", dpi = 300)

In [None]:
# UKBB age distribution plot as seen in Figure 4.

UKBB_data_images = getTable("UKBB_data_images.tsv", "UKBB_data")
# Alias, not a copy: the column assignments below also modify UKBB_data_images.
df = UKBB_data_images
#df = df.sort()

colors_UKBB = ["#9bd08a", "white"]
colors_UKBB2 = ["#116434", "#9bd08a", "white"]

# Turn the phenotype columns into ordered categoricals; the category order is
# the order in which the groups are drawn.
phenotype_levels = {
    "Phenotype": ["All Abnormal", "Normal"],
    "Phenotype2": ["Abnormal", "Abnormal & Unclear", "Normal"],
    "Phenotype3": ["Abnormal", "Normal"],
}
for column, levels in phenotype_levels.items():
    df[column] = pd.Categorical(df[column], categories=levels, ordered=True)

# Drop rows whose age cannot be parsed as a number.
df['Age'] = pd.to_numeric(df['Age'], errors='coerce')
df = df.dropna(subset=['Age'])

fig = plt.figure(figsize=(3.5, 6))
ax = fig.add_axes([0.25, 0.15, 0.70,0.80])

ax = sns.violinplot(
    ax=ax,
    data=df,
    x='Phenotype',
    y='Age',
    palette=colors_UKBB,
    split=True,
    density_norm="count",
    saturation = 1,
    inner_kws=dict(box_width=30, whis_width=2, color=".8"),
)

# Display labels, by position, for the two Phenotype categories.
ax.set_xticklabels(['IRD or Abnormal', 'Normal'])

ax.set_xlabel("Phenotype")
ax.set_ylim(-10, 110)

#sns.swarmplot(x ='set', y ='Age', data = df,color= "white")

ax.set_title('UKB Age Distribution')
#plt.tight_layout()
fig.savefig("UKBBAgeDist.png", dpi = 300)
