In [None]:
from pyspark.sql import SparkSession
import hail as hl
import os
import time
import dxpy
import logging
import pandas as pd
import re


# Build the Spark session that Hail runs on.
# The resource settings below were raised to get around an RDD partition
# error hit when running this pipeline; they are kept as named constants so
# they can be tuned in one place.
# NOTE(review): per the Spark configuration docs, spark.driver.memory has no
# effect if the driver JVM is already running when this cell executes
# (client mode with a pre-started context) — confirm it is honored here.
SPARK_DRIVER_MEMORY = "96g"
SPARK_EXECUTOR_MEMORY = "108g"
SPARK_EXECUTOR_CORES = "30"

builder = (
    SparkSession
    .builder
    .appName("Variant annotation")  # name shown in the Spark UI
    .config("spark.driver.memory", SPARK_DRIVER_MEMORY)
    .config("spark.executor.memory", SPARK_EXECUTOR_MEMORY)
    .config("spark.executor.cores", SPARK_EXECUTOR_CORES)
    # Hive support backs the Spark SQL databases created later in this notebook.
    .enableHiveSupport()
)
spark = builder.getOrCreate()

# Attach Hail to the existing SparkContext; idempotent=True makes re-running
# this cell a no-op once Hail is initialized.
hl.init(sc=spark.sparkContext, idempotent=True)


In [None]:
def save_in_hail_format(hail_obj, db_name, hail_obj_name, rerun):
    """Ensure a DNAnexus-backed Spark database exists and optionally write a Hail object into it.

    Parameters
    ----------
    hail_obj :
        Object with a ``write(url, overwrite=...)`` method (used here with a
        ``hail.Table``). Only written when ``rerun`` is true.
    db_name : str
        Spark database name; created with ``LOCATION 'dnax://'`` if missing.
    hail_obj_name : str
        Name of the object inside the database, e.g. ``"chr1_annot.ht"``.
    rerun : bool
        If true, (re)write ``hail_obj`` with ``overwrite=True``. If false,
        nothing is written and only the URL of a previously written object is
        resolved.

    Returns
    -------
    str
        ``dnax://<database-id>/<hail_obj_name>`` URL of the object.
    """
    # DDL statements execute eagerly inside spark.sql(); calling .show() on the
    # empty result only printed an empty table, so it is not needed.
    spark.sql(f"CREATE DATABASE IF NOT EXISTS {db_name} LOCATION 'dnax://'")
    # Resolve the database ID and build the URL with the shared helper so the
    # lookup logic lives in one place.
    url = get_url(db_name, hail_obj_name)
    if rerun:
        hail_obj.write(url, overwrite=True)
    return url

def get_url(db_name, hail_obj_name):
    """Return the ``dnax://`` URL of ``hail_obj_name`` inside Spark database ``db_name``.

    The database is located by its lower-cased name via
    ``dxpy.find_one_data_object`` (classname ``"database"``). Nothing is read
    or written; only the URL string is built.
    """
    database = dxpy.find_one_data_object(
        name=format(db_name).lower(),
        classname="database",
    )
    return "dnax://{}/{}".format(database['id'], hail_obj_name)

def get_chrm_mt(chr_num, db_name="exomes"):
    """Read the per-chromosome ``chr{chr_num}_vqc.mt`` MatrixTable.

    Parameters
    ----------
    chr_num : int or str
        Chromosome identifier interpolated into the object name.
    db_name : str, default "exomes"
        Spark database that holds the per-chromosome tables.

    Returns
    -------
    hail.MatrixTable
    """
    # Reuse get_url so the database-ID lookup is not duplicated here.
    return hl.read_matrix_table(get_url(db_name, f"chr{chr_num}_vqc.mt"))

def upload_file_to_project(filename, proj_dir):
    """Upload a local file into DNAnexus folder ``proj_dir``, then delete the local copy.

    Missing parent folders are created (``parents=True``). Returns None.
    """
    dxpy.upload_local_file(filename, folder=proj_dir, parents=True)
    print(f"*********{filename} uploaded!!*********")
    # The local copy is no longer needed once it is in the project.
    os.remove(filename)

def get_rare_variants(mt, af_threshold=0.01):
    """Keep rows whose first alternate allele is rare and observed at least once.

    Allele frequency and count come from ``hl.agg.call_stats`` over ``mt.GT``
    for the columns currently in ``mt``. Filter samples *before* calling this
    if they should not count towards the frequencies.

    Parameters
    ----------
    mt : hail.MatrixTable
        Needs a ``GT`` entry field and an ``alleles`` row field.
    af_threshold : float, default 0.01
        Rows with alternate-allele frequency ``>= af_threshold`` are dropped.

    Returns
    -------
    hail.MatrixTable
        Filtered table with row fields ``gt_stats``, ``maf`` (frequency of
        ``alleles[1]``) and ``mac`` (count of ``alleles[1]``) added.
    """
    # compute per-variant allele statistics from the genotypes
    mt = mt.annotate_rows(gt_stats=hl.agg.call_stats(mt.GT, mt.alleles))
    # Index 1 is the first alternate allele.
    # NOTE(review): further alternates of multi-allelic rows are ignored —
    # presumably the input is split to bi-allelic; confirm.
    alt_af = mt.gt_stats.AF[1]
    alt_ac = mt.gt_stats.AC[1]
    # keep rare variants that are carried by at least one sample
    mt = mt.filter_rows((alt_af < af_threshold) & (alt_ac > 0))
    # "maf"/"mac" store the alternate-allele values; after the filter above the
    # alternate allele is the minor one whenever af_threshold <= 0.5.
    return mt.annotate_rows(maf=mt.gt_stats.AF[1], mac=mt.gt_stats.AC[1])

def variant_qc(mt):
    """Copy selected precomputed variant-QC metrics to top-level row fields.

    ``mt`` must already carry a ``variant_qc`` row struct (presumably produced
    by ``hl.variant_qc`` upstream); nothing is recomputed here.

    Adds ``call_rate``, ``p_value_hwe`` and ``min_rd`` (minimum of the DP stats).
    """
    qc = mt.variant_qc
    return mt.annotate_rows(
        call_rate=qc.call_rate,
        p_value_hwe=qc.p_value_hwe,
        min_rd=qc.dp_stats.min,
    )

def sample_qc(mt, sample_qc_annot_file="file:///mnt/project/notebooks/wes/sample_qc/data/flagged_samples.tsv"):
    """Drop samples that carry any sample-QC flag.

    Reads a TSV keyed by column ``s`` with a ``filters`` column, attaches the
    ``filters`` value to each column of ``mt`` as ``sample_filters``, and keeps
    only samples whose value is the empty string.

    NOTE(review): samples absent from the TSV get a missing ``sample_filters``;
    Hail's ``filter_cols`` drops columns whose condition is missing, so such
    samples are removed as well — confirm that is intended.
    """
    flags = hl.import_table(sample_qc_annot_file).key_by('s')
    mt = mt.annotate_cols(sample_filters=flags[mt.s].filters)
    passed = mt.sample_filters == ""
    return mt.filter_cols(passed)

def add_vep_annotations(mt, vep_file="file:///mnt/project/notebooks/wes/variant_annot/data/vep_config_109_v7.json"):
    """Annotate with VEP and keep one row per (variant, transcript) with a pLoF or missense consequence.

    Steps:
      1. run ``hl.vep`` with config ``vep_file``;
      2. build a ``(gene_symbol, transcript_id, consequences, biotype, lof)``
         tuple per transcript consequence, where ``consequences`` is that
         transcript's consequence terms joined with ``";"`` and ``lof`` is the
         VEP ``lof`` field (presumably LOFTEE's call — depends on the config);
      3. keep only that tuple plus ``maf``, ``mac``, ``call_rate``,
         ``p_value_hwe`` and ``min_rd`` (these must already exist on the rows);
      4. explode to one row per transcript consequence;
      5. flag ``lof`` / ``splice_lof`` / ``missense`` by regex on the
         consequence string and keep rows with at least one flag set.

    dbNSFP annotations are not added here; see ``add_dbnsfp_annotations``.
    """
    # add vep annotations
    mt = hl.vep(mt, vep_file)
    # combine the consequence terms of each transcript into one ";"-joined string
    mt = mt.annotate_rows(consequences=hl.map(
        lambda terms: hl.delimit(terms, delimiter=";"),
        mt.vep.transcript_consequences.consequence_terms,
    ))
    # one (gene, transcript, consequences, biotype, lof) tuple per transcript consequence
    mt = mt.annotate_rows(gene_transcript_consequence_biotype=hl.zip(
        mt.vep.transcript_consequences.gene_symbol,
        mt.vep.transcript_consequences.transcript_id,
        mt.consequences,
        mt.vep.transcript_consequences.biotype,
        mt.vep.transcript_consequences.lof,
    ))
    # only keep relevant row fields (the `vep` struct is dropped here)
    mt = mt.select_rows(
        mt.gene_transcript_consequence_biotype,
        mt.maf, mt.mac,
        mt.call_rate, mt.p_value_hwe, mt.min_rd,
    )
    # one row per transcript consequence
    mt = mt.explode_rows("gene_transcript_consequence_biotype")

    # Hail's StringExpression.matches is an unanchored regex search, so these
    # also hit when the term is one of several ";"-joined consequences.
    lof_mutations = "stop_gained|frameshift_variant|stop_lost|start_lost"
    splice_lof_mutations = "splice_acceptor_variant|splice_donor_variant"
    missense_mutations = "missense_variant"

    consequence = mt.gene_transcript_consequence_biotype[2]
    mt = mt.annotate_rows(
        lof=consequence.matches(lof_mutations),
        missense=consequence.matches(missense_mutations),
        splice_lof=consequence.matches(splice_lof_mutations),
    )

    # keep transcripts with at least one of the flagged consequence types
    return mt.filter_rows(mt.lof | mt.splice_lof | mt.missense)

def create_deleteriousness_scores(mt):
    """Score each (variant, transcript) row by how many dbNSFP predictors call it damaging.

    For every predictor a row field ``<metric>_pred`` is set to 1 when the
    dbNSFP prediction for this row's gene and transcript equals that
    predictor's damaging label (``"D"``, or ``"H"`` for MutationAssessor), and
    0 otherwise. It is missing when the lookup itself is missing.
    ``del_score`` is the sum of those fields (0-9).

    Expects the ``dbNSFP_variants`` row field (see ``add_dbnsfp_annotations``)
    and the exploded ``gene_transcript_consequence_biotype`` tuple (see
    ``add_vep_annotations``).
    """
    # Label counted as "damaging" for each predictor. Keeping it next to the
    # name avoids depending on MutationAssessor being last in a list.
    damaging_label = {
        "SIFT": "D",
        "LRT": "D",
        "FATHMM": "D",
        "PROVEAN": "D",
        "MetaSVM": "D",
        "MetaLR": "D",
        "PrimateAI": "D",
        "DEOGEN2": "D",
        "MutationAssessor": "H",
    }
    metrics = list(damaging_label)
    dbnsfp = mt.dbNSFP_variants

    # Step 1: per metric, build {gene: {transcript: prediction}} from the
    # ";"-separated transcript IDs and predictions (presumably one entry per
    # gene in dbNSFP_variants — confirm against the annotation DB schema).
    pred_lookup = {
        f"{m}_pred": hl.dict(hl.zip(
            dbnsfp.genename,
            hl.map(
                lambda tx_pred: hl.dict(hl.zip(tx_pred[0].split(";"), tx_pred[1].split(";"))),
                hl.zip(dbnsfp.Ensembl_transcriptid, dbnsfp[f"{m}_pred"]),
            ),
        ))
        for m in metrics
    }
    mt = mt.annotate_rows(**pred_lookup)

    def is_damaging(gtcb, del_pred, label):
        # 1 if del_pred[gene][transcript] == label, else 0 (missing if del_pred is missing)
        gene = gtcb[0]
        transcript = gtcb[1]
        return hl.if_else(
            del_pred.contains(gene) & del_pred[gene].contains(transcript) & (del_pred[gene][transcript] == label),
            1, 0,
        )

    # Step 2: replace each lookup dict with the 0/1 call for this row's transcript.
    mt = mt.annotate_rows(**{
        f"{m}_pred": is_damaging(mt.gene_transcript_consequence_biotype, mt[f"{m}_pred"], damaging_label[m])
        for m in metrics
    })
    # hl.sum skips missing elements by default, so missing calls contribute 0.
    mt = mt.annotate_rows(del_score=hl.sum([mt[f"{m}_pred"] for m in metrics]))
    return mt
    
def add_dbnsfp_annotations(mt):
    """Join dbNSFP variant annotations from the Hail Annotation DB and derive deleteriousness scores.

    Uses the AWS ``us`` region copy of the annotation database to add the
    ``dbNSFP_variants`` row field, then passes the result through
    ``create_deleteriousness_scores``.
    """
    annotation_db = hl.experimental.DB(region='us', cloud='aws')
    annotated = annotation_db.annotate_rows_db(mt, 'dbNSFP_variants')
    return create_deleteriousness_scores(annotated)

def keep_deleterious_variants(mt, del_score_threshold=4):
    """Keep pLoF rows and damaging missense rows, and unpack the VEP tuple into named fields.

    A row is kept if it is a loss-of-function or splice loss-of-function
    consequence, or a missense consequence with
    ``del_score > del_score_threshold``. With the integer scores from
    ``create_deleteriousness_scores``, the default 4 means at least 5
    predictors call it damaging.

    Parameters
    ----------
    mt : hail.MatrixTable
        Rows carry ``gene_transcript_consequence_biotype``, ``lof``,
        ``splice_lof``, ``missense``, ``del_score``, ``maf``, ``mac``,
        ``call_rate``, ``p_value_hwe`` and ``min_rd``.
    del_score_threshold : int, default 4
        Missense rows need a strictly greater ``del_score`` to be kept.

    Returns
    -------
    hail.MatrixTable
        Rows reduced to gene, transcript, consequence, biotype, loftee, lof,
        splice_lof, missense, del_score, maf, mac, call_rate, p_value_hwe and
        min_rd (plus the row key).
    """
    # filter to keep deleterious mutations only
    mt = mt.filter_rows(
        mt.lof | mt.splice_lof | (mt.missense & (mt.del_score > del_score_threshold))
    )
    # unpack the (gene, transcript, consequence, biotype, lof) tuple
    gtcb = mt.gene_transcript_consequence_biotype
    mt = mt.annotate_rows(
        gene=gtcb[0],
        transcript=gtcb[1],
        consequence=gtcb[2],
        biotype=gtcb[3],
        loftee=gtcb[4],
    )
    # only keep relevant row information
    return mt.select_rows(
        mt.gene, mt.transcript, mt.consequence, mt.biotype, mt.loftee,
        mt.lof, mt.splice_lof, mt.missense, mt.del_score,
        mt.maf, mt.mac, mt.call_rate, mt.p_value_hwe, mt.min_rd,
    )

def add_sample_info(mt):
    """Attach comma-separated lists of carrier sample IDs to each row.

    Adds ``samples`` (any alternate allele), ``hetz_samples`` (heterozygous)
    and ``homo_samples`` (homozygous alternate). Each is built by collecting
    ``mt.s`` over the entries where the genotype condition holds.
    """
    def _joined_ids(condition):
        # collect sample IDs whose genotype satisfies `condition`, joined with ","
        return hl.bind(
            lambda ids: hl.delimit(ids, ","),
            hl.agg.filter(condition, hl.agg.collect(mt.s)),
        )

    genotype = mt.GT
    return mt.annotate_rows(
        samples=_joined_ids(genotype.n_alt_alleles() > 0),
        hetz_samples=_joined_ids(genotype.is_het()),
        homo_samples=_joined_ids(genotype.is_hom_var()),
    )

def get_annot_table(chr_num):
    """Build the deleterious rare-variant annotation table for one chromosome.

    Pipeline: read ``chr{chr_num}_vqc.mt`` -> drop QC-flagged samples -> keep
    rare variants -> copy variant-QC metrics -> VEP -> dbNSFP scores -> keep
    deleterious variants -> add carrier sample lists. The resulting row table
    is written to ``chr{chr_num}_annot.ht`` in the ``exomes`` database
    (overwriting any previous copy), read back, and returned as a pandas
    DataFrame with ``alleles`` joined by ``"_"``.

    Sample QC runs *before* the rare-variant filter. That way ``maf``/``mac``
    and the "observed in at least one sample" check use only the samples that
    are kept. Otherwise a variant carried only by flagged samples would pass
    the filter and end up with an empty ``samples`` list.
    """
    mt = get_chrm_mt(chr_num)
    # drop QC-flagged samples first so frequencies reflect the analysed cohort
    mt = sample_qc(mt)
    mt = get_rare_variants(mt)
    mt = variant_qc(mt)
    mt = add_vep_annotations(mt)
    mt = add_dbnsfp_annotations(mt)
    mt = keep_deleterious_variants(mt)
    mt = add_sample_info(mt)
    # row-level table (one row per variant/transcript after the VEP explode)
    annot_table = mt.rows()
    # persist, then read the written table back for the pandas export
    url = save_in_hail_format(annot_table, "exomes", f"chr{chr_num}_annot.ht", rerun=True)
    annot_df = hl.read_table(url).to_pandas()
    annot_df["alleles"] = annot_df["alleles"].apply("_".join)
    return annot_df


In [None]:
# Chromosome to annotate; later cells read and write per-chromosome files.
chr_num: int = 1

In [None]:
# Run the full annotation pipeline for the selected chromosome.
annot_df = get_annot_table(chr_num=chr_num)

In [None]:
# Export the annotation table as a TSV (pandas infers gzip compression from
# the ".gz" suffix) and upload it; the upload helper deletes the local copy.
proj_dir = "/notebooks/wes/variant_annot/data/"
annot_df_name = "chr{}.tsv.gz".format(chr_num)
annot_df.to_csv(annot_df_name, index=False, sep='\t')
upload_file_to_project(annot_df_name, proj_dir=proj_dir)


In [None]:
# Shut down Hail, then the Spark session; SparkSession.stop() also stops its
# SparkContext, so a separate sparkContext.stop() call is not needed.
hl.stop()
spark.stop()
