# GENOTYPE EXTRACTION

This notebook demonstrates how to extract genotypes for specific genetic variants across all participants from the All of Us Researcher Workbench.

## WORKFLOW OVERVIEW:
1. Set up environment and utility functions
2. Configure paths and database connections
3. Load and explore genomic data using Hail
4. Extract genotypes for one or more specified variants
5. Create a genotype matrix (participants × variants)
6. (Optional) Add clinical/demographic data

## KEY FEATURES:
- Extract genotypes for **multiple variants simultaneously**
- Creates a wide-format DataFrame with one column per variant
- Each column shows actual genotypes (0/0, 0/1, 1/1) for all participants
- Handles multi-allelic sites automatically

## LEARNING OBJECTIVES:
- Understand Hail MatrixTable structure for genomic data
- Learn how to query specific genetic variants
- Extract participant genotypes across multiple variants
- Create analysis-ready genotype matrices

## OUTPUT FORMAT:
The function returns a Polars DataFrame with structure:
- **Columns**: person_id, chromosome:pos:ref:alt (one column per variant; no "chr" prefix, e.g. 19:39248514:TT:G)
- **Values**: "0/0", "0/1", "1/1", or null (missing)
- **Example**:
  ```
  person_id | 19:39248514:TT:G | 19:39247938:G:A
  ----------|------------------|------------------
  1000001   | 0/0              | 0/1
  1000002   | 0/1              | 0/0
  1000003   | 1/1              | null
  ```
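
To make the wide layout concrete, the sketch below pivots long-format (person_id, variant, genotype) records into one row per participant using only the standard library. The helper name `to_wide` and the toy records are illustrative assumptions, not part of this notebook's functions. Missing genotypes come out as `None`, mirroring the null values in the table above.

```python
def to_wide(records, variant_columns):
    """Pivot (person_id, variant, genotype) records into one row per participant."""
    rows = {}
    for person_id, variant, genotype in records:
        # Start every participant with None (null) for each variant column
        row = rows.setdefault(person_id, {col: None for col in variant_columns})
        row[variant] = genotype
    # Emit rows sorted by person_id, with person_id as the first column
    return [{"person_id": pid, **rows[pid]} for pid in sorted(rows)]


# Toy data only; these are not real participants
variant_columns = ["19:39248514:TT:G", "19:39247938:G:A"]
records = [
    (1000001, "19:39248514:TT:G", "0/0"),
    (1000001, "19:39247938:G:A", "0/1"),
    (1000002, "19:39248514:TT:G", "0/1"),
    (1000002, "19:39247938:G:A", "0/0"),
    (1000003, "19:39248514:TT:G", "1/1"),  # no call recorded for the second variant
]

wide = to_wide(records, variant_columns)
for row in wide:
    print(row)
```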

# SECTION 1: IMPORTS AND SETUP

In [None]:
from google.cloud import bigquery
import pandas as pd
import polars as pl
import numpy as np
import os
import sys
from itertools import combinations

# Hail will be imported later (after we set up paths)
# import hail as hl

# SECTION 2: UTILITY FUNCTIONS

In [None]:
def polars_gbq(query: str) -> pl.DataFrame:
    """
    Execute a SQL query on Google BigQuery and return result as Polars DataFrame.

    This is our standard method for querying the All of Us OMOP CDM database.
    Polars is preferred over pandas for better performance with large datasets.

    :param query: BigQuery SQL query string (can be multi-line)
    :type query: str
    :return: Query results as Polars DataFrame
    :rtype: pl.DataFrame

    """
    rows = bigquery.Client().query(query).result()
    df = pl.from_arrow(rows.to_arrow())
    return df

In [None]:
def spark_to_polars(spark_df) -> pl.DataFrame:
    """
    Convert Spark DataFrame to Polars DataFrame.

    WHY THIS IS NEEDED:
    - Hail (genomics framework) uses Apache Spark under the hood
    - Hail's .to_spark() method exports data as Spark DataFrames
    - We convert to Polars for easier analysis and better performance

    HOW IT WORKS:
    1. Converts Spark DataFrame to Apache Arrow format (columnar data structure)
    2. Converts Arrow to Polars DataFrame
    """
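    import pyarrow as pa

    # _collect_as_arrow() is a private PySpark method that returns a list of
    # Arrow record batches, so it may change between Spark versions
    polars_df = pl.from_arrow(pa.Table.from_batches(spark_df._collect_as_arrow()))
    return polars_df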

# SECTION 3: BIGQUERY SQL QUERIES FOR CLINICAL DATA

These functions generate SQL queries to extract clinical and demographic data from the All of Us OMOP CDM database.

**OMOP TABLES USED:**
- person: Demographics (birth date, sex)
- death: Death dates (if applicable)
- condition_occurrence: Diagnosis codes
- observation: Additional diagnosis codes
- concept: Code definitions and vocabularies
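
The query builders in this section interpolate a Python tuple directly into `IN {participant_ids}`. A one-element tuple renders as `(1000001,)` and an empty tuple as `()`, and BigQuery may reject both as invalid lists. A minimal sketch of a guard follows, assuming a hypothetical helper named `format_id_list` that is not part of the notebook:

```python
def format_id_list(participant_ids):
    """Render participant IDs as a SQL IN-list such as "(1, 2, 3)"."""
    # int() rejects non-numeric input, which also guards against SQL injection
    ids = [int(pid) for pid in participant_ids]
    if not ids:
        raise ValueError("participant_ids must contain at least one ID")
    return "(" + ", ".join(str(pid) for pid in ids) + ")"


# Compare with Python's default tuple rendering for a single ID
print(f"{(1000001,)}")
print(format_id_list((1000001,)))
print(format_id_list([1000001, 1000002, 1000003]))
```

The returned string could be interpolated in place of the raw tuple, for example `f"... IN {format_id_list(participant_ids)}"`.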

In [None]:
def current_age_query(ds: str, participant_ids: tuple) -> str:
    """
    Generate SQL query to calculate participant ages.

    WHAT IT CALCULATES:
    - Date of birth
    - Year of birth
    - Current age (or age at death if deceased)
    - Age squared (for polynomial regression models)
    - Age cubed (for polynomial regression models)

    HOW IT WORKS:
    1. Joins person table with death table (left join, so NULL if alive)
    2. Uses death date if available, otherwise current date
    3. Calculates age in years (accounting for leap years: 365.2425 days/year)
    4. Computes polynomial terms for flexible age modeling

    :param ds: Google BigQuery dataset ID containing OMOP data tables
    :type ds: str
    :param participant_ids: Tuple of participant IDs to query
    :type participant_ids: tuple
    :return: SQL query string
    :rtype: str

    EXAMPLE:
        participant_ids = (1000001, 1000002, 1000003)
        query = current_age_query(WORKSPACE_CDR, participant_ids)
        age_df = polars_gbq(query)
    """
    query = f"""
        SELECT
            DISTINCT p.person_id,
            EXTRACT(DATE FROM DATETIME(birth_datetime)) AS date_of_birth,
            EXTRACT(YEAR FROM DATETIME(birth_datetime)) AS year_of_birth,
            DATETIME_DIFF(
                IF(DATETIME(death_datetime) IS NULL, CURRENT_DATETIME(), DATETIME(death_datetime)),
                DATETIME(birth_datetime),
                DAY
            )/365.2425 AS current_age,
            POW(DATETIME_DIFF(
                IF(DATETIME(death_datetime) IS NULL, CURRENT_DATETIME(), DATETIME(death_datetime)),
                DATETIME(birth_datetime),
                DAY
            )/365.2425, 2) AS current_age_squared,
            POW(DATETIME_DIFF(
                IF(DATETIME(death_datetime) IS NULL, CURRENT_DATETIME(), DATETIME(death_datetime)),
                DATETIME(birth_datetime),
                DAY
            )/365.2425, 3) AS current_age_cubed
        FROM
            {ds}.person AS p
        LEFT JOIN
            {ds}.death AS d
        ON
            p.person_id = d.person_id
        WHERE
            p.person_id IN {participant_ids}
    """
    return query
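
The age arithmetic in the query above can be checked outside BigQuery. The sketch below reproduces it with the standard library, assuming a hypothetical helper `age_terms` whose `end` argument is the death date when one exists and the current date otherwise:

```python
from datetime import date

DAYS_PER_YEAR = 365.2425  # mean Gregorian year length, as used in the query


def age_terms(birth, end):
    """Return (age, age squared, age cubed) in years between two dates."""
    age = (end - birth).days / DAYS_PER_YEAR
    return age, age ** 2, age ** 3


# Toy dates only
age, age_squared, age_cubed = age_terms(date(2000, 1, 1), date(2020, 1, 1))
print(round(age, 3), round(age_squared, 1), round(age_cubed, 1))
```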

In [None]:
def ehr_dx_code_query(ds: str, participant_ids: tuple) -> str:
    """
    Generate SQL query to calculate EHR utilization metrics.

    WHAT IT CALCULATES:
    - Last EHR date: Most recent diagnosis code entry
    - EHR length: Years from first to last diagnosis code (measure of data span)
    - Dx code occurrence count: Total number of diagnosis codes recorded
    - Dx condition count: Number of unique diagnosis codes
    - Age at last EHR event: Age at most recent diagnosis

    WHY THIS IS USEFUL:
    - Controls for healthcare utilization bias in analyses
    - Participants with more EHR data have more opportunities for diagnoses
    - Common covariates in epidemiological studies

    DATA SOURCES:
    - Condition_occurrence table (primary source for diagnosis codes)
    - Observation table (supplementary diagnosis codes)
    - Both ICD9CM and ICD10CM vocabularies

    HOW IT WORKS:
    1. Unions 4 queries to capture all diagnosis codes:
       - condition_occurrence.condition_source_value (direct code match)
       - condition_occurrence.condition_source_concept_id (concept ID match)
       - observation.observation_source_value (direct code match)
       - observation.observation_source_concept_id (concept ID match)
    2. Joins with person table to get birth dates
    3. Calculates summary statistics per person

    :param ds: Google BigQuery dataset ID containing OMOP data tables
    :type ds: str
    :param participant_ids: Tuple of participant IDs to query
    :type participant_ids: tuple
    :return: SQL query string
    :rtype: str

    EXAMPLE:
        participant_ids = (1000001, 1000002, 1000003)
        query = ehr_dx_code_query(WORKSPACE_CDR, participant_ids)
        ehr_df = polars_gbq(query)
    """
    query = f"""
        SELECT DISTINCT
            df1.person_id,
            MAX(date) AS last_ehr_date,
            (DATETIME_DIFF(MAX(date), MIN(date), DAY) + 1)/365.2425 AS ehr_length,
            COUNT(code) AS dx_code_occurrence_count,
            COUNT(DISTINCT(code)) AS dx_condition_count,
            DATETIME_DIFF(MAX(date), MIN(birthday), DAY)/365.2425 AS age_at_last_ehr_event,
            POW(DATETIME_DIFF(MAX(date), MIN(birthday), DAY)/365.2425, 2) AS age_at_last_ehr_event_squared,
            POW(DATETIME_DIFF(MAX(date), MIN(birthday), DAY)/365.2425, 3) AS age_at_last_ehr_event_cubed
        FROM
            (
                -- Query 1: condition_occurrence with direct code match
                (
                SELECT DISTINCT
                    co.person_id,
                    co.condition_start_date AS date,
                    c.concept_code AS code
                FROM
                    {ds}.condition_occurrence AS co
                INNER JOIN
                    {ds}.concept AS c
                ON
                    co.condition_source_value = c.concept_code
                WHERE
                    c.vocabulary_id IN ("ICD9CM", "ICD10CM")
                    AND
                    person_id IN {participant_ids}
                )
            UNION DISTINCT
                -- Query 2: condition_occurrence with concept ID match
                (
                SELECT DISTINCT
                    co.person_id,
                    co.condition_start_date AS date,
                    c.concept_code AS code
                FROM
                    {ds}.condition_occurrence AS co
                INNER JOIN
                    {ds}.concept AS c
                ON
                    co.condition_source_concept_id = c.concept_id
                WHERE
                    c.vocabulary_id IN ("ICD9CM", "ICD10CM")
                    AND
                    person_id IN {participant_ids}
                )
            UNION DISTINCT
                -- Query 3: observation with direct code match
                (
                SELECT DISTINCT
                    o.person_id,
                    o.observation_date AS date,
                    c.concept_code AS code
                FROM
                    {ds}.observation AS o
                INNER JOIN
                    {ds}.concept AS c
                ON
                    o.observation_source_value = c.concept_code
                WHERE
                    c.vocabulary_id IN ("ICD9CM", "ICD10CM")
                    AND
                    person_id IN {participant_ids}
                )
            UNION DISTINCT
                -- Query 4: observation with concept ID match
                (
                SELECT DISTINCT
                    o.person_id,
                    o.observation_date AS date,
                    c.concept_code AS code
                FROM
                    {ds}.observation AS o
                INNER JOIN
                    {ds}.concept AS c
                ON
                    o.observation_source_concept_id = c.concept_id
                WHERE
                    c.vocabulary_id IN ("ICD9CM", "ICD10CM")
                    AND
                    person_id IN {participant_ids}
                )
            ) AS df1
        INNER JOIN
            (
                SELECT
                    person_id,
                    EXTRACT(DATE FROM DATETIME(birth_datetime)) AS birthday
                FROM
                    {ds}.person
                WHERE
                    person_id IN {participant_ids}
            ) AS df2
        ON
            df1.person_id = df2.person_id
        GROUP BY
            df1.person_id
    """
    return query
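
To illustrate what the aggregation above computes for a single participant, here is a small standard-library sketch. The helper `ehr_metrics` and the toy events are assumptions for illustration. The `set()` call plays the role of `UNION DISTINCT`, which removes repeated (date, code) pairs before counting.

```python
from datetime import date


def ehr_metrics(events):
    """Summarize (date, code) diagnosis events for one participant."""
    unique_events = set(events)  # mimic UNION DISTINCT on (date, code)
    dates = [event_date for event_date, _ in unique_events]
    codes = [code for _, code in unique_events]
    return {
        "last_ehr_date": max(dates),
        "ehr_length": ((max(dates) - min(dates)).days + 1) / 365.2425,
        "dx_code_occurrence_count": len(codes),    # COUNT(code)
        "dx_condition_count": len(set(codes)),     # COUNT(DISTINCT code)
    }


# Toy events only
events = [
    (date(2015, 3, 1), "E11.9"),
    (date(2015, 3, 1), "E11.9"),   # exact duplicate, removed before counting
    (date(2018, 6, 15), "E11.9"),
    (date(2018, 6, 15), "I10"),
]
metrics = ehr_metrics(events)
print(metrics)
```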

In [None]:
def sex_at_birth_query(ds: str, participant_ids: tuple) -> str:
    """
    Generate SQL query to extract sex at birth.

    ENCODING:
    - Male = 1
    - Female = 0

    DATA SOURCE:
    - Uses sex_at_birth_source_concept_id from person table
    - Concept ID 1585846 = Male
    - Concept ID 1585847 = Female

    NOTE:
    - This is self-reported sex at birth from surveys
    - For genetics analyses, you may also want to use genomic sex (from ploidy)

    :param ds: Google BigQuery dataset ID containing OMOP data tables
    :type ds: str
    :param participant_ids: Tuple of participant IDs to query
    :type participant_ids: tuple
    :return: SQL query string
    :rtype: str

    EXAMPLE:
        participant_ids = (1000001, 1000002, 1000003)
        query = sex_at_birth_query(WORKSPACE_CDR, participant_ids)
        sex_df = polars_gbq(query)
    """
    query = f"""
        SELECT
            *
        FROM
            (
                -- Male participants
                (
                SELECT
                    person_id,
                    1 AS sex_at_birth
                FROM
                    {ds}.person
                WHERE
                    sex_at_birth_source_concept_id = 1585846
                AND
                    person_id IN {participant_ids}
                )
            UNION DISTINCT
                -- Female participants
                (
                SELECT
                    person_id,
                    0 AS sex_at_birth
                FROM
                    {ds}.person
                WHERE
                    sex_at_birth_source_concept_id = 1585847
                AND
                    person_id IN {participant_ids}
                )
            )
    """
    return query

# SECTION 4: CONFIGURATION AND PATHS

Set up paths to genomic data and database connections.

**IMPORTANT:** You must run `_reference/verily/00_setup_workspace.ipynb` first!

That notebook sets up the WORKSPACE_CDR and WORKSPACE_BUCKET environment variables that are needed for all All of Us analyses.
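
One alternative to printing a message and calling `sys.exit(1)` is to raise an exception, which stops the cell with a traceback and leaves the kernel usable. A minimal sketch, assuming a hypothetical helper `require_env` and a placeholder variable name used only for the demonstration:

```python
import os


def require_env(name):
    """Return the value of an environment variable or raise a clear error."""
    value = os.environ.get(name)
    if not value:
        raise RuntimeError(
            f"{name} is not set; run _reference/verily/00_setup_workspace.ipynb first"
        )
    return value


# Demonstration with a placeholder variable (not a real workspace setting)
os.environ["EXAMPLE_WORKSPACE_VAR"] = "example-project.example_dataset"
print(require_env("EXAMPLE_WORKSPACE_VAR"))
```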

In [None]:
# Get workspace environment variables
WORKSPACE_CDR = os.environ.get('WORKSPACE_CDR')
WORKSPACE_BUCKET = os.environ.get('WORKSPACE_BUCKET')
GOOGLE_PROJECT = os.environ.get('GOOGLE_PROJECT')

# Validate environment setup
if WORKSPACE_CDR is None:
    print("ERROR: WORKSPACE_CDR not set!")
    print("Please run _reference/verily/00_setup_workspace.ipynb first")
    sys.exit(1)

print(f"Workspace CDR: {WORKSPACE_CDR}")
print(f"Workspace Bucket: {WORKSPACE_BUCKET}")
print(f"Google Project: {GOOGLE_PROJECT}")

## GENOMIC DATA PATHS

These paths point to the All of Us genomic data in Google Cloud Storage.

**cdr8_mt_path:**
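- Hail MatrixTable (.mt) holding the CDR v8 short-read SNP/indel genotypes from the ACAF-threshold callset (per the `acaf_threshold/splitMT` path), so very rare variants may be absent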
- Contains ~414,000 participants with short-read whole genome sequencing
- Format: Hail MatrixTable (requires Hail to read)
- Location: Controlled-tier All of Us dataset

**ancestry_pred_path:**
- TSV file with predicted genetic ancestry for each participant
- Includes ancestry labels (EUR, AFR, AMR, EAS, SAS)
- Includes first 16 principal components (PCs) from ancestry PCA
- Used for ancestry stratification in GWAS
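
As a rough illustration of how such a file might be unpacked, the sketch below parses a toy TSV with the standard library. The column names `research_id`, `ancestry_pred`, and `pca_features` (a bracketed list of PC values) are assumptions about the file layout, and `parse_ancestry` is a hypothetical helper. Check your own file's header before adapting it.

```python
import csv
import io
import json

# Toy stand-in for an ancestry file; the column names are assumptions
toy_tsv = (
    "research_id\tancestry_pred\tpca_features\n"
    "1000001\teur\t[0.12, -0.03, 0.44]\n"
    "1000002\tafr\t[-0.51, 0.20, 0.07]\n"
)


def parse_ancestry(handle, n_pcs):
    """Read ancestry rows and expand the first n_pcs components into PC columns."""
    rows = []
    for record in csv.DictReader(handle, delimiter="\t"):
        pcs = json.loads(record["pca_features"])[:n_pcs]
        row = {
            "person_id": int(record["research_id"]),
            "ancestry_pred": record["ancestry_pred"],
        }
        row.update({f"PC{i}": value for i, value in enumerate(pcs, 1)})
        rows.append(row)
    return rows


ancestry_rows = parse_ancestry(io.StringIO(toy_tsv), n_pcs=2)
for row in ancestry_rows:
    print(row)
```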

In [None]:
# All of Us CDR v8 genomic data (Hail MatrixTable)
cdr8_mt_path = "gs://fc-aou-datasets-controlled/v8/wgs/short_read/snpindel/acaf_threshold/splitMT/hail.mt"

# Ancestry predictions - SET THIS TO YOUR ANCESTRY FILE PATH
# The user should provide this path
ancestry_pred_path = "YOUR_ANCESTRY_PREDICTIONS_PATH_HERE"
# Example: ancestry_pred_path = f"{WORKSPACE_BUCKET}/data/ancestry_predictions.tsv"

# SECTION 5: GENOTYPE EXTRACTION

This is the main workflow for extracting genotypes across all participants.

## STEPS:
1. Define variant(s) of interest (chromosome, position, alleles)
2. Initialize Hail and load genomic data
3. For each variant:
   - Filter to variant of interest
   - Extract genotypes for all participants
   - Format genotypes as strings (0/0, 0/1, 1/1)
4. Merge all variants into a single DataFrame
5. Save genotype matrix file

## OUTPUT:
A wide-format DataFrame where:
- Rows = participants (person_id)
- Columns = variants (chromosome:pos:ref:alt, without a "chr" prefix)
- Values = genotypes ("0/0", "0/1", "1/1", or null)
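
The function below builds two strings per variant: the string passed to Hail, which carries a `chr` prefix on GRCh38, and the output column name, which does not. A standalone sketch of that naming logic, using a hypothetical helper `build_variant_strings` that raises an error instead of exiting:

```python
def build_variant_strings(chromosome, position, ref, alt, reference_genome="GRCh38"):
    """Return (Hail variant string, output column name) for one variant."""
    base_locus = f"{chromosome}:{position}"
    if reference_genome == "GRCh38":
        locus = "chr" + base_locus  # GRCh38 contig names carry a "chr" prefix
    elif reference_genome == "GRCh37":
        locus = base_locus
    else:
        raise ValueError("reference_genome must be 'GRCh37' or 'GRCh38'")
    return f"{locus}:{ref}:{alt}", f"{base_locus}:{ref}:{alt}"


variant_string, column_name = build_variant_strings(19, 39248514, "TT", "G")
print(variant_string)
print(column_name)
```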

In [None]:
def extract_genotypes_for_variants(
    variant_list: list,
    reference_genome: str = "GRCh38",
    output_file_path: str = None
) -> pl.DataFrame:
    """
    Extract genotypes for all participants across multiple specified variants.

    PARAMETERS:
    -----------
    variant_list : list of dict
        List of variants to query. Each variant is a dictionary with keys:
        - 'chromosome': int or str (1-22, 'X', or 'Y'; the value is inserted
          verbatim into the locus string, so use 'X'/'Y' rather than 23/24)
        - 'position': int (genomic position in base pairs)
        - 'ref': str (reference allele)
        - 'alt': str (alternative allele)

        Example:
        [
            {'chromosome': 19, 'position': 39248514, 'ref': 'TT', 'alt': 'G'},
            {'chromosome': 19, 'position': 39247938, 'ref': 'G', 'alt': 'A'}
        ]

    reference_genome : str
        Reference genome version ("GRCh37" or "GRCh38")
        All of Us uses GRCh38 by default

    output_file_path : str
        Path to save the genotype matrix as a TSV file
        If None, defaults to "genotype_matrix.tsv"

    RETURNS:
    --------
    pl.DataFrame
        DataFrame with columns:
        - person_id: participant ID
        - One column per variant, named "chromosome:pos:ref:alt" (e.g. "19:39248514:TT:G")
        - Each variant column contains genotype strings: "0/0", "0/1", "1/1", or null

    EXAMPLE:
    --------
    # Extract genotypes for two variants
    variants = [
        {'chromosome': 19, 'position': 39248514, 'ref': 'TT', 'alt': 'G'},
        {'chromosome': 19, 'position': 39247938, 'ref': 'G', 'alt': 'A'}
    ]
    
    genotype_df = extract_genotypes_for_variants(
        variant_list=variants,
        output_file_path="genotypes.tsv"
    )
    """

    # ========================================================================
    # STEP 1: VALIDATE INPUTS
    # ========================================================================

    print("\n" + "="*70)
    print("GENOTYPE EXTRACTION FOR MULTIPLE VARIANTS")
    print("="*70)

    if not variant_list:
        print("❌ ERROR: variant_list is empty. Please provide at least one variant.")
        sys.exit(1)

    print(f"\n📍 Extracting genotypes for {len(variant_list)} variant(s):")
    for i, var in enumerate(variant_list, 1):
        print(f"   {i}. chr{var['chromosome']}:{var['position']}:{var['ref']}:{var['alt']}")

    # ========================================================================
    # STEP 2: INITIALIZE HAIL
    # ========================================================================

    import hail as hl

    print("\n⚙️  Initializing Hail...")
    try:
        hl.init(default_reference=reference_genome)
        print("✓ Hail initialized successfully")
    except Exception as err:
        if "IllegalArgumentException" not in str(err):
            raise
        else:
            print("✓ Hail already initialized (skipping)")

    # ========================================================================
    # STEP 3: LOAD GENOMIC DATA
    # ========================================================================

    print(f"\n📂 Loading genomic data from:")
    print(f"   {cdr8_mt_path}")

    mt = hl.read_matrix_table(cdr8_mt_path)
    print(f"✓ MatrixTable loaded")

    # ========================================================================
    # STEP 4: PROCESS EACH VARIANT
    # ========================================================================

    all_genotype_dfs = []

    for var_idx, var in enumerate(variant_list, 1):
        chromosome_number = var['chromosome']
        genomic_position = var['position']
        ref_allele = var['ref']
        alt_allele = var['alt']

        print(f"\n{'='*70}")
        print(f"PROCESSING VARIANT {var_idx}/{len(variant_list)}")
        print(f"{'='*70}")

        # Construct variant string
        alleles = f"{ref_allele}:{alt_allele}"
        base_locus = f"{chromosome_number}:{genomic_position}"

        if reference_genome == "GRCh38":
            locus = "chr" + base_locus
        elif reference_genome == "GRCh37":
            locus = base_locus
        else:
            print("❌ Invalid reference version. Allowed inputs are 'GRCh37' or 'GRCh38'.")
            sys.exit(1)

        variant_string = locus + ":" + alleles
        variant_col_name = f"{base_locus}:{ref_allele}:{alt_allele}"

        print(f"\n🔍 Searching for variant: {variant_string}")
        print(f"   Column name: {variant_col_name}")

        # Parse variant
        variant = hl.parse_variant(variant_string, reference_genome=reference_genome)

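        # Filter to locus (pass the reference genome explicitly so parsing does
        # not depend on the session default set by an earlier hl.init call)
        mt_variant = mt.filter_rows(
            mt.locus == hl.Locus.parse(locus, reference_genome=reference_genome)
        )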

        n_variants_at_locus = mt_variant.count_rows()
        if n_variants_at_locus == 0:
            print(f"❌ WARNING: Locus {locus} not found in dataset!")
            print(f"   This variant will be skipped.")
            continue
        else:
            print(f"✓ Found {n_variants_at_locus} variant(s) at locus {locus}")

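        # Handle multi-allelic sites
        # "info" is a row-level annotation, so read it from the rows table
        # instead of materializing one entry per participant
        allele_count_df = spark_to_polars(mt_variant.rows().select("info").to_spark())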

        if len(allele_count_df) > 0 and allele_count_df["info.AF"][0] is not None:
            allele_count = len(allele_count_df["info.AF"][0])

            if allele_count > 1:
                print(f"\n⚠️  Multi-allelic site detected ({allele_count} alt alleles)")
                print("   Splitting multi-allelic variants...")
                mt_variant = hl.split_multi_hts(mt_variant)
                print("✓ Split complete")

        # Filter to exact variant
        print(f"\n🎯 Filtering to exact variant: {variant_string}")
        mt_variant = mt_variant.filter_rows(
            (mt_variant.locus == variant["locus"]) &
            (mt_variant.alleles == variant["alleles"])
        )

        n_variants = mt_variant.count_rows()
        if n_variants == 0:
            print(f"❌ WARNING: Variant {variant_string} not found!")
            print(f"   This variant will be skipped.")
            continue
        else:
            print(f"✓ Variant {variant_string} found!")

        # Extract genotypes
        print("\n📊 Extracting genotypes for all participants...")

        spark_df = mt_variant.entries().select("GT").to_spark()
        polars_df = spark_to_polars(spark_df)

        print(f"✓ Extracted genotypes for {len(polars_df):,} participants")

        # Format genotypes
        print("\n🔄 Converting genotype format...")

        polars_df = polars_df.with_columns(
            pl.col("GT.alleles").list.get(0).cast(pl.Utf8).alias("GT0"),
            pl.col("GT.alleles").list.get(1).cast(pl.Utf8).alias("GT1"),
        )

        polars_df = polars_df.with_columns(
            (pl.col("GT0") + "/" + pl.col("GT1")).alias(variant_col_name)
        )

        # Rename 's' to 'person_id'
        polars_df = polars_df.rename({"s": "person_id"})
        polars_df = polars_df.with_columns(pl.col("person_id").cast(int))

        # Select only person_id and genotype column
        polars_df = polars_df.select(["person_id", variant_col_name])

        print("✓ Genotypes formatted")

        # Show genotype distribution
        print(f"\n📈 Genotype distribution for {variant_col_name}:")
        gt_counts = polars_df.group_by(variant_col_name).agg(
            pl.count().alias("count")
        ).sort("count", descending=True)
        
        for row in gt_counts.iter_rows(named=True):
            count = row['count']
            gt = row[variant_col_name]
            if count < 20:
                print(f"   {gt}: <20 participants")
            else:
                print(f"   {gt}: {count:,} participants")

        all_genotype_dfs.append(polars_df)

    # ========================================================================
    # STEP 5: MERGE ALL VARIANTS
    # ========================================================================

    if len(all_genotype_dfs) == 0:
        print("\n❌ ERROR: No variants were successfully extracted!")
        return None

    print(f"\n{'='*70}")
    print("MERGING VARIANTS")
    print(f"{'='*70}")

    # Start with first variant
    merged_df = all_genotype_dfs[0]
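    print(f"\n✓ Starting with {len(merged_df):,} participants from variant 1")

    # Join remaining variants
    for i, df in enumerate(all_genotype_dfs[1:], 2):
        print(f"\n🔗 Joining variant {i}...")
        # A coalescing full join keeps a single person_id column; recent Polars
        # spells this how="full", coalesce=True (older releases: how="outer_coalesce")
        merged_df = merged_df.join(df, on="person_id", how="full", coalesce=True)
        print(f"   ✓ Now have {len(merged_df):,} participants total")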

    # ========================================================================
    # STEP 6: SAVE AND SUMMARIZE
    # ========================================================================

    print(f"\n{'='*70}")
    print("FINAL GENOTYPE MATRIX SUMMARY")
    print(f"{'='*70}")

    print(f"\n✅ Total participants: {len(merged_df):,}")
    print(f"✅ Total variants: {len(merged_df.columns) - 1}")
    print(f"\n📋 Columns: {', '.join(merged_df.columns)}")

    # Count participants with a non-null genotype for each variant
    print(f"\n📊 Participants with genotype data per variant:")
    for col in merged_df.columns:
        if col != "person_id":
            count = merged_df.filter(pl.col(col).is_not_null()).shape[0]
            print(f"   {col}: {count:,}")

    # Set default output filename
    if output_file_path is None:
        output_file_path = "genotype_matrix.tsv"

    # Save to file
    merged_df.write_csv(output_file_path, separator="\t")
    print(f"\n💾 Genotype matrix saved to: {output_file_path}")

    print("\n" + "="*70 + "\n")

    return merged_df

# SECTION 6: EXAMPLE USAGE

This section demonstrates how to use the extract_genotypes_for_variants() function.

**EXAMPLE 1: Single variant**
Extract genotypes for one variant across all participants.

**EXAMPLE 2: Multiple variants**
Extract genotypes for multiple variants simultaneously.
The function will create a wide-format DataFrame with one column per variant.
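
Because each Hail lookup can be slow, it may help to check variant dictionaries for obvious mistakes before running the extraction. The sketch below is one possible pre-check, assuming plain A/C/G/T alleles and a hypothetical helper `validate_variant` that is not part of the notebook:

```python
REQUIRED_KEYS = ("chromosome", "position", "ref", "alt")
VALID_BASES = set("ACGT")


def validate_variant(variant):
    """Return a list of problems found in one variant dictionary (empty if none)."""
    problems = [f"missing key '{key}'" for key in REQUIRED_KEYS if key not in variant]
    if problems:
        return problems
    if not isinstance(variant["position"], int) or variant["position"] < 1:
        problems.append("position must be a positive integer")
    for key in ("ref", "alt"):
        allele = str(variant[key]).upper()
        if not allele or not set(allele) <= VALID_BASES:
            problems.append(f"{key} must contain only A, C, G, or T")
    return problems


good = {"chromosome": 19, "position": 39248514, "ref": "TT", "alt": "G"}
bad = {"chromosome": 19, "position": -5, "ref": "G", "alt": "N"}
print(validate_variant(good))
print(validate_variant(bad))
```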

In [None]:
print("\n" + "="*70)
print("EXAMPLE 1: Single variant - GPR15 rs28688207")
print("="*70)
print("\nExtracting genotypes for one variant associated with IBD")
print("Reference: chr2:233,269,839 C>T (GRCh38)")

# Define single variant
variants_single = [
    {
        'chromosome': 2,
        'position': 233269839,
        'ref': 'C',
        'alt': 'T'
    }
]

# Extract genotypes
genotypes_single = extract_genotypes_for_variants(
    variant_list=variants_single,
    reference_genome="GRCh38",
    output_file_path="gpr15_genotypes.tsv"
)

if genotypes_single is not None:
    print("\nFirst 10 rows:")
    print(genotypes_single.head(10))

print("\n" + "="*70)
print("EXAMPLE 2: Multiple variants - Custom variant list")
print("="*70)
print("\nExtracting genotypes for multiple variants simultaneously")

# Define multiple variants
# User can specify any variants they want to analyze
variants_multiple = [
    {
        'chromosome': 19,
        'position': 39248514,
        'ref': 'TT',
        'alt': 'G'
    },
    {
        'chromosome': 19,
        'position': 39247938,
        'ref': 'G',
        'alt': 'A'
    }
]

# Extract genotypes for all variants
genotypes_multiple = extract_genotypes_for_variants(
    variant_list=variants_multiple,
    reference_genome="GRCh38",
    output_file_path="multi_variant_genotypes.tsv"
)

if genotypes_multiple is not None:
    print("\nFirst 10 rows of multi-variant matrix:")
    print(genotypes_multiple.head(10))
    
    print("\n" + "="*70)
    print("GENOTYPE MATRIX STRUCTURE")
    print("="*70)
    print("\nThe output DataFrame has:")
    print(f"  - {len(genotypes_multiple):,} participants (rows)")
    print(f"  - {len(genotypes_multiple.columns) - 1} variant(s) (columns)")
    print(f"  - Columns: {', '.join(genotypes_multiple.columns)}")
    print("\nEach variant column contains genotypes:")
    print("  - '0/0' = homozygous reference")
    print("  - '0/1' = heterozygous")
    print("  - '1/1' = homozygous alternative")
    print("  - null = missing genotype data")

# SECTION 7: OPTIONAL - ADDING CLINICAL DATA

After creating the genotype matrix, you may want to add:
- Demographics (age, sex)
- Clinical data (EHR length, diagnosis counts)
- Ancestry information (genetic ancestry, PCs)

This section shows how to use the SQL query functions to retrieve this data and join it with your genotype matrix.

**Use case**: Combine genotype data with phenotype data for association analyses.
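
The example function in this section passes every participant ID into a single `IN (...)` list. For cohorts with hundreds of thousands of participants the query text can become very large, so one option is to query in batches and combine the results afterwards. A minimal sketch, assuming a hypothetical `chunked` helper and an arbitrary batch size:

```python
def chunked(ids, size):
    """Yield successive tuples of at most `size` IDs."""
    if size < 1:
        raise ValueError("size must be at least 1")
    ids = list(ids)
    for start in range(0, len(ids), size):
        yield tuple(ids[start:start + size])


# Toy IDs only; a real batch size would be much larger
batches = list(chunked(range(1000001, 1000009), size=4))
for batch in batches:
    print(batch)
```

Each batch could then be passed to a query builder such as `current_age_query`, with the per-batch results combined using `pl.concat`.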

In [None]:
def add_clinical_data_example(genotype_df: pl.DataFrame) -> pl.DataFrame:
    """
    Example function showing how to add clinical data to a genotype matrix.

    This is for educational purposes - shows the workflow step-by-step.

    :param genotype_df: Genotype matrix with person_id column and variant columns
    :return: Enhanced genotype matrix with clinical data
    """

    print("\n" + "="*70)
    print("ADDING CLINICAL DATA TO GENOTYPE MATRIX")
    print("="*70)

    # Get list of participant IDs
    participant_ids = tuple(genotype_df["person_id"].unique().to_list())
    print(f"\n📋 Adding data for {len(participant_ids):,} participants")

    # ------------------------------------------------------------------------
    # AGE DATA
    # ------------------------------------------------------------------------
    print("\n1️⃣  Retrieving age data...")
    age_query = current_age_query(WORKSPACE_CDR, participant_ids)
    age_df = polars_gbq(age_query)
    print(f"   ✓ Retrieved age data for {len(age_df):,} participants")

    # Join with genotype matrix
    genotype_df = genotype_df.join(
        age_df[["person_id", "current_age", "year_of_birth"]],
        how="left",
        on="person_id"
    )

    # ------------------------------------------------------------------------
    # SEX AT BIRTH
    # ------------------------------------------------------------------------
    print("\n2️⃣  Retrieving sex at birth...")
    sex_query = sex_at_birth_query(WORKSPACE_CDR, participant_ids)
    sex_df = polars_gbq(sex_query)
    print(f"   ✓ Retrieved sex data for {len(sex_df):,} participants")

    # Join with genotype matrix
    genotype_df = genotype_df.join(sex_df, how="left", on="person_id")

    # ------------------------------------------------------------------------
    # EHR UTILIZATION DATA
    # ------------------------------------------------------------------------
    print("\n3️⃣  Retrieving EHR utilization data...")
    ehr_query = ehr_dx_code_query(WORKSPACE_CDR, participant_ids)
    ehr_df = polars_gbq(ehr_query)
    print(f"   ✓ Retrieved EHR data for {len(ehr_df):,} participants")

    # Join with genotype matrix
    genotype_df = genotype_df.join(
        ehr_df[["person_id", "ehr_length", "dx_code_occurrence_count", "age_at_last_ehr_event"]],
        how="left",
        on="person_id"
    )

    # ------------------------------------------------------------------------
    # SUMMARY
    # ------------------------------------------------------------------------
    print("\n" + "="*70)
    print("ENHANCED GENOTYPE MATRIX SUMMARY")
    print("="*70)

    print(f"\n📊 Matrix size: {len(genotype_df):,} participants")
    print(f"\n📋 Columns: {', '.join(genotype_df.columns)}")

    # Separate genotype columns from clinical columns
    genotype_cols = [c for c in genotype_df.columns if ':' in c]
    clinical_cols = [c for c in genotype_df.columns if c not in genotype_cols and c != 'person_id']

    print(f"\n🧬 Genotype columns ({len(genotype_cols)}): {', '.join(genotype_cols)}")
    print(f"🏥 Clinical columns ({len(clinical_cols)}): {', '.join(clinical_cols)}")

    print("\n📈 Summary statistics (clinical data):")
    if clinical_cols:
        print(genotype_df.select(clinical_cols).describe())

    # Save enhanced genotype matrix
    output_path = "enhanced_genotype_matrix.tsv"
    genotype_df.write_csv(output_path, separator="\t")
    print(f"\n💾 Enhanced genotype matrix saved to: {output_path}")

    return genotype_df


# Example usage of clinical data addition:
# if genotypes_multiple is not None:
#     enhanced_genotype_matrix = add_clinical_data_example(genotypes_multiple)

# END OF EDUCATIONAL WORKFLOW

## NEXT STEPS:
1. Modify the variant_list for your own variants of interest
2. Add clinical/demographic data as needed (see Section 7 for examples)
3. Perform downstream analyses:
   - Association testing (compare genotypes across phenotypes)
   - PheWAS (phenome-wide association studies)
   - Stratification analysis
   - Visualization (genotype distributions, allele frequencies)

## USING YOUR GENOTYPE MATRIX:
The extracted genotype matrix can be used for many analyses:

**Example 1: Filter to specific genotypes**
```python
# Get participants with heterozygous genotype at first variant
het_carriers = genotype_df.filter(
    pl.col("19:39248514:TT:G") == "0/1"
)
```

**Example 2: Count genotypes**
```python
# Count each genotype for a variant
genotype_df.group_by("19:39248514:TT:G").agg(pl.count())
```

**Example 3: Compound genotype analysis**
```python
# Find participants with specific genotype combinations
compound = genotype_df.filter(
    (pl.col("19:39248514:TT:G") == "0/1") &
    (pl.col("19:39247938:G:A") == "1/1")
)
```
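
**Example 4: Convert genotype strings to allele dosage**

Association models usually need a numeric predictor rather than a string. The sketch below counts alternate alleles in a biallelic genotype string. `genotype_to_dosage` is a hypothetical helper, written in plain Python so the logic is easy to check before porting it to a DataFrame expression.

```python
def genotype_to_dosage(genotype):
    """Count alternate alleles in a biallelic genotype string such as "0/1"."""
    if genotype is None:
        return None  # keep missing genotypes missing
    alleles = genotype.replace("|", "/").split("/")  # accept phased or unphased calls
    return sum(1 for allele in alleles if allele == "1")


for gt in ["0/0", "0/1", "1/1", None]:
    print(gt, "->", genotype_to_dosage(gt))
```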

## RESOURCES:
- Hail documentation: https://hail.is/docs/0.2/
- All of Us Data Browser: https://databrowser.researchallofus.org/
- OMOP CDM documentation: https://ohdsi.github.io/CommonDataModel/

## TROUBLESHOOTING:
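- "Variant not found": Check that the chromosome, position, and ref/alt alleles match GRCh38; variants excluded from the `acaf_threshold` callset at `cdr8_mt_path` will also be absent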
- "Hail initialization error": Restart kernel and try again
- "Out of memory": Use fewer variants or request larger machine
- "Multi-allelic warning": This is normal - the function handles it automatically