In [None]:
import logging
import os
import re
import shutil
import time

import dxpy
import hail as hl
import pandas as pd
from pyspark.sql import SparkSession


# Build spark: get (or create) a Hive-enabled SparkSession, then start Hail on
# that session's SparkContext so both share one Spark application.
builder = SparkSession.builder.enableHiveSupport()
spark = builder.getOrCreate()
hl.init(sc=spark.sparkContext)

In [None]:
def get_rare_variants(mt, max_af=0.001, min_ac=2):
    """
    Keep only rare, non-singleton variants.

    Per-row allele statistics are computed with ``hl.agg.call_stats`` over ``mt.GT``
    and stored in the ``gt_stats`` row field. A row is kept when the frequency of
    the first alternate allele (``AF[1]``) is strictly below ``max_af`` AND its
    allele count (``AC[1]``) is at least ``min_ac``. With the defaults this keeps
    variants with AF < 0.001 that are seen on at least two alleles (singletons are
    dropped). Only allele index 1 is inspected, so ``mt`` should be bi-allelic
    (e.g. after ``hl.split_multi_hts``).

    Parameters
    ----------
    mt : hl.MatrixTable
        Matrix table with a ``GT`` entry field and an ``alleles`` row field.
    max_af : float, default 0.001
        Exclusive upper bound on the alternate allele frequency.
    min_ac : int, default 2
        Inclusive lower bound on the alternate allele count (2 is the same rule as
        ``AC[1] > 1``).

    Returns
    -------
    hl.MatrixTable
        The filtered matrix table, with the ``gt_stats`` row field added.
    """
    mt = mt.annotate_rows(gt_stats=hl.agg.call_stats(mt.GT, mt.alleles))
    mt = mt.filter_rows((mt.gt_stats.AF[1] < max_af) & (mt.gt_stats.AC[1] >= min_ac))
    return mt


def add_annotations(mt, vep_file="file:///mnt/project/exome_annot/annot_run/vep_config_109_v4.json",
                    region='us', cloud='aws'):
    """
    Add VEP and dbNSFP annotations to the rows of ``mt``.

    Parameters
    ----------
    mt : hl.MatrixTable
        Matrix table to annotate.
    vep_file : str
        VEP configuration passed to ``hl.vep``; defaults to the project's
        ``vep_config_109_v4.json``.
    region : str, default 'us'
        Region argument for the Hail annotation database (``hl.experimental.DB``).
    cloud : str, default 'aws'
        Cloud argument for the Hail annotation database.

    Returns
    -------
    hl.MatrixTable
        ``mt`` with the ``vep`` row field produced by ``hl.vep`` and the
        ``dbNSFP_variants`` row field from the annotation database.
    """
    mt = hl.vep(mt, vep_file)  # annotate rows with VEP
    db = hl.experimental.DB(region=region, cloud=cloud)
    mt = db.annotate_rows_db(mt, 'dbNSFP_variants')  # add dbNSFP annotations
    return mt



def get_protein_coding_variants(mt):
    """
    Keep rows where at least one VEP transcript consequence has biotype
    "protein_coding".
    """
    biotypes = mt.vep.transcript_consequences.biotype
    has_coding_transcript = hl.any(lambda biotype: biotype == "protein_coding", biotypes)
    return mt.filter_rows(has_coding_transcript)


def create_deleteriousness_scores(mt, deleterious_codes=None):
    """
    Add one 0/1 vote per dbNSFP predictor and their sum as ``del_score``.

    For each predictor ``m`` a float row field ``{m}_pred`` is set to 1.0 when any
    element of ``mt.dbNSFP_variants[f"{m}_pred"]`` contains that predictor's
    deleterious code, and 0.0 otherwise. ``del_score`` is the sum of these fields.

    Parameters
    ----------
    mt : hl.MatrixTable
        Matrix table with a ``dbNSFP_variants`` row struct holding the
        ``{m}_pred`` collections.
    deleterious_codes : dict of str -> str, optional
        Predictor name -> substring that marks a deleterious call. Defaults to
        "D" for SIFT, LRT, FATHMM, PROVEAN, MetaSVM, MetaLR, PrimateAI and
        DEOGEN2, and "H" for MutationAssessor (nine votes in total).

    Returns
    -------
    hl.MatrixTable
        ``mt`` with the ``{m}_pred`` row fields and ``del_score`` added.

    Raises
    ------
    ValueError
        If ``deleterious_codes`` is empty.

    NOTE(review): a missing dbNSFP prediction presumably gives a missing
    ``{m}_pred`` rather than 0.0; confirm how ``hl.sum`` treats it in ``del_score``.
    """
    if deleterious_codes is None:
        deleterious_codes = {
            "SIFT": "D",
            "LRT": "D",
            "FATHMM": "D",
            "PROVEAN": "D",
            "MetaSVM": "D",
            "MetaLR": "D",
            "PrimateAI": "D",
            "DEOGEN2": "D",
            "MutationAssessor": "H",
        }
    if not deleterious_codes:
        # hl.sum of an empty Python list has no element type to infer.
        raise ValueError("deleterious_codes must name at least one predictor")

    def _vote(predictions, code):
        # 1.0 if any prediction string contains `code`, else 0.0; both literals are
        # floats so the field type does not depend on numeric unification.
        return hl.if_else(hl.any(lambda p: p.contains(code), predictions), 1.0, 0.0)

    votes = {
        f"{m}_pred": _vote(mt.dbNSFP_variants[f"{m}_pred"], code)
        for m, code in deleterious_codes.items()
    }
    mt = mt.annotate_rows(**votes)
    mt = mt.annotate_rows(del_score=hl.sum([mt[name] for name in votes]))
    return mt


def create_vtype_annotations(mt):
    """
    Flag each variant by the classes of VEP consequence terms it carries.

    The consequence terms of all transcript consequences are pooled into one set,
    and four boolean row fields are added; each is True when the pooled set shares
    at least one term with its class:

    * ``lof``: transcript_ablation, stop_gained, frameshift_variant, stop_lost,
      start_lost
    * ``missense``: missense_variant
    * ``splice_lof``: splice_acceptor_variant, splice_donor_variant
    * ``splice``: splice_donor_5th_base_variant, splice_region_variant,
      splice_donor_region_variant, splice_polypyrimidine_tract_variant
    """
    term_classes = {
        "lof": ["transcript_ablation", "stop_gained", "frameshift_variant", "stop_lost", "start_lost"],
        "missense": ["missense_variant"],
        "splice_lof": ["splice_acceptor_variant", "splice_donor_variant"],
        "splice": [
            "splice_donor_5th_base_variant",
            "splice_region_variant",
            "splice_donor_region_variant",
            "splice_polypyrimidine_tract_variant",
        ],
    }
    observed_terms = hl.set(hl.flatten(mt.vep.transcript_consequences.consequence_terms))
    flags = {
        name: hl.len(hl.set(terms).intersection(observed_terms)) != 0
        for name, terms in term_classes.items()
    }
    return mt.annotate_rows(**flags)


def get_deleterious(mt, min_del_score=None):
    """
    Keep variants with a loss-of-function, missense or splice annotation.

    A row is kept when any of the boolean row fields ``lof``, ``splice_lof``,
    ``missense`` or ``splice`` is True. By default no deleteriousness threshold is
    applied: every missense variant is kept and ``del_score`` is left for
    downstream filtering.

    Parameters
    ----------
    mt : hl.MatrixTable
        Matrix table with the ``lof``, ``splice_lof``, ``missense`` and ``splice``
        row fields (and ``del_score`` when ``min_del_score`` is given).
    min_del_score : float, optional
        When given, a missense variant only counts if ``del_score >= min_del_score``
        (e.g. 5 for a 5-of-9 majority vote). lof, splice_lof and splice variants
        are kept regardless.

    Returns
    -------
    hl.MatrixTable
        The filtered matrix table.
    """
    missense = mt.missense
    if min_del_score is not None:
        missense = missense & (mt.del_score >= min_del_score)
    return mt.filter_rows(mt.lof | mt.splice_lof | missense | mt.splice)


def add_gene_info(mt):
    """
    Add a ``gene`` row field and explode rows so there is one row per gene.

    Gene symbols come from all VEP transcript consequences of the variant.
    Duplicate symbols (several transcripts of the same gene) are collapsed before
    exploding. As a result, each variant appears once per distinct gene instead of
    once per transcript, and the later per-row sample aggregation is not repeated
    for identical rows.
    """
    gene_symbols = hl.set(mt.vep.transcript_consequences.gene_symbol)
    mt = mt.annotate_rows(gene=hl.array(gene_symbols))
    mt = mt.explode_rows(mt.gene)
    return mt


def keep_relevant_columns(mt):
    """
    Keep only the gene, variant-class flags and ``del_score`` row fields (the row
    key is retained by ``select_rows``).
    """
    wanted = ("gene", "lof", "missense", "splice_lof", "splice", "del_score")
    return mt.select_rows(*(mt[field] for field in wanted))


def get_annot_table(mt):
    """
    Run the annotation pipeline on a matrix table and return its rows as a DataFrame.

    Multi-allelic sites are split (``permit_shuffle=True``), then the variants are
    filtered and annotated by the steps listed below. Finally, each remaining row
    gets a ``samples`` field with the comma-separated IDs of samples whose GT has
    at least one alternate allele.

    Returns
    -------
    pandas.DataFrame
        The row table, with ``alleles`` joined by "_", exact duplicate rows removed
        and a fresh 0..n-1 index.
    """
    pipeline = (
        get_rare_variants,              # rare variants only
        add_annotations,                # VEP and dbNSFP annotations
        get_protein_coding_variants,    # protein-coding variants only
        create_deleteriousness_scores,  # per-predictor votes and del_score
        create_vtype_annotations,       # lof / missense / splice_lof / splice flags
        get_deleterious,                # keep lof, missense or splice variants
        add_gene_info,                  # one row per gene
        keep_relevant_columns,          # drop row fields that are not exported
    )
    # split multi-allelic hits to bi-allelic before any per-allele filtering
    mt_filtered = hl.split_multi_hts(mt, permit_shuffle=True)
    for step in pipeline:
        mt_filtered = step(mt_filtered)

    # carriers of each variant, joined into one comma-separated string
    carriers = hl.agg.filter(mt_filtered.GT.n_alt_alleles() > 0, hl.agg.collect(mt_filtered.s))
    mt_filtered = mt_filtered.annotate_rows(samples=hl.delimit(carriers, ","))

    annot_df = mt_filtered.rows().to_pandas()
    annot_df["alleles"] = annot_df["alleles"].map("_".join)
    return annot_df.drop_duplicates().reset_index(drop=True)


def upload_file_to_project(filename, proj_dir):
    """
    Upload a local file into the DNAnexus folder ``proj_dir`` (``parents=True``),
    then delete the local copy.

    If the upload raises, the local file is left in place.
    """
    dxpy.upload_local_file(filename, folder=proj_dir, parents=True)
    print(f"*********{filename} uploaded!!*********")
    os.remove(filename)

In [None]:
# Input pVCFs for the chromosome; the loop uses their count as the number of
# block_{i}.mt matrix tables to process.
vcf_dir = "/mnt/project/Bulk/Exome sequences/Population level exome OQFE variants, pVCF format - final release/"
chr_num = "X"
vcf_files = sorted(
    "file://" + os.path.join(vcf_dir, fp)
    for fp in os.listdir(vcf_dir)
    if f"_c{chr_num}_" in fp and fp.endswith("vcf.gz")
)

# Annotation configure
logging.basicConfig(filename=f"chr{chr_num}_annot_vep109.log", level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', datefmt='%m/%d/%Y %I:%M:%S %p')

# First block to annotate when nothing has been uploaded yet.
i = 0

# Project folder holding the uploaded tables, seen through the /mnt/project mount.
proj_dir = f"/mnt/project/exome_annot/annot_run/notebooks/chr{chr_num}/annot_tables_vep109/"

if os.path.exists(proj_dir):
    # Raw string so "\d" is a regex class, not an (invalid) string escape; dots escaped.
    block_pattern = re.compile(r"^block_(\d+)\.tsv\.gz$")
    finished_blocks = [
        int(match.group(1))
        for match in map(block_pattern.match, os.listdir(proj_dir))
        if match is not None
    ]
    # Files that are not block tables are ignored; an empty folder leaves i at 0.
    if finished_blocks:
        # A table is uploaded only after its block was annotated and written, so
        # resume after the highest finished block instead of redoing it (which would
        # upload a second file with the same name).
        # NOTE(review): assumes the mount lists only fully uploaded files — confirm
        # for uploads interrupted by an instance restart.
        i = max(finished_blocks) + 1
    
# DNAnexus database holding the block_{i}.mt matrix tables of this chromosome.
# The lookup does not depend on the block, so it is done once.
db_name = f"exome_chr{chr_num}"
db_uri = dxpy.find_one_data_object(name=f"{db_name}".lower(), classname="database")['id']
# Upload destination as a DNAnexus project path (no /mnt/project prefix). Kept in
# its own variable instead of overwriting the mount path in proj_dir.
upload_dir = f"/exome_annot/annot_run/notebooks/chr{chr_num}/annot_tables_vep109/"
# Local scratch directory where Hail leaves persist_Table* entries.
tmp_dir = "/tmp/"


def _cleanup_hail_tmp(tmp_dir):
    """Best-effort removal of ``persist_Table*`` files or directories in ``tmp_dir``."""
    for entry in os.listdir(tmp_dir):
        if not entry.startswith("persist_Table"):
            continue
        path = os.path.join(tmp_dir, entry)
        try:
            if os.path.isdir(path) and not os.path.islink(path):
                shutil.rmtree(path)
            else:
                os.remove(path)
        except OSError as cleanup_error:
            # A cleanup failure must not be reported as an annotation failure.
            logging.warning(f"could not remove {path}: {cleanup_error}")


while i < len(vcf_files):
    time_start = time.time()

    try:
        # read the matrix table; inside the try so a missing or unreadable block is
        # logged and skipped instead of aborting the whole run
        mt_name = f"block_{i}.mt"
        mt_url = f"dnax://{db_uri}/{mt_name}"
        mt = hl.read_matrix_table(mt_url)

        # create annot table
        annot_df = get_annot_table(mt)
        # save annot table to local
        annot_df_name = f"block_{i}.tsv.gz"
        annot_df.to_csv(annot_df_name, sep='\t')
        # upload table to project
        upload_file_to_project(annot_df_name, upload_dir)

        time_end = time.time()
        time_taken = (time_end - time_start)/60
        logging.info(f"Time to annotate block {i}: {time_taken} mins\n")
    except Exception as error:
        # Deliberately best-effort: record the failure (with traceback) and continue.
        logging.warning(f"block {i} not annotated due to {error}\n", exc_info=True)
        print(f"!!!!!!!!block {i} not annotated!!!!!!!!")
    finally:
        # remove tmp files created by hail to prevent storage issues, also after a failure
        _cleanup_hail_tmp(tmp_dir)

    i += 1

In [None]:
# Tear down: stop Hail first, then the SparkContext and the SparkSession it ran on.
hl.stop()
spark_context = spark.sparkContext
spark_context.stop()
spark.stop()