# In [None]:
from pyspark.sql import SparkSession
import hail as hl
import os
import time
import dxpy
import logging
import pandas as pd


# Get (or create) a Hive-enabled Spark session and initialise Hail on its context
builder = SparkSession.builder.enableHiveSupport()
spark = builder.getOrCreate()
hl.init(sc=spark.sparkContext)



# Directory holding the pVCF files and the chromosome processed by this run
vcf_dir = "/mnt/project/Bulk/Exome sequences/Population level exome OQFE variants, pVCF format - final release/"
chr_num = "2"
# Sorted file:// URIs of every "vcf.gz" file whose name contains "_c<chr_num>_"
vcf_files = sorted(
    "file://" + os.path.join(vcf_dir, fname)
    for fname in os.listdir(vcf_dir)
    if f"_c{chr_num}_" in fname and fname.endswith("vcf.gz")
)



def get_rare_variants(mt, af_threshold=0.0001, ac_threshold=1):
    """
    Returns a matrix table restricted to rare variants seen more than once.

    Call statistics are computed per row from ``mt.GT`` / ``mt.alleles`` and
    stored in a new ``gt_stats`` row field. A row is kept when, for the allele
    at index 1, the allele frequency is strictly below ``af_threshold`` AND
    the allele count is strictly above ``ac_threshold``.

    Only allele index 1 is inspected, so rows are expected to be bi-allelic
    (``get_burden_table`` splits multi-allelic sites before calling this).

    Parameters
    ----------
    mt : matrix table with a ``GT`` entry field and an ``alleles`` row field
    af_threshold : exclusive upper bound on the alt allele frequency
    ac_threshold : exclusive lower bound on the alt allele count

    Returns
    -------
    The filtered matrix table, with the ``gt_stats`` row field added.
    """
    mt = mt.annotate_rows(gt_stats=hl.agg.call_stats(mt.GT, mt.alleles))
    alt_af = mt.gt_stats.AF[1]
    alt_ac = mt.gt_stats.AC[1]
    return mt.filter_rows((alt_af < af_threshold) & (alt_ac > ac_threshold))


def add_annotations(mt, vep_file="file:///mnt/project/exome_annot/annot_run/vep_config_109_v2.json"):
    """
    Annotate the rows with VEP, then with the CADD and dbNSFP_variants datasets.

    ``vep_file`` is the VEP configuration passed to ``hl.vep``. The CADD and
    dbNSFP annotations are pulled from the Hail annotation DB (region 'us',
    cloud 'aws').
    """
    vep_annotated = hl.vep(mt, vep_file)
    annotation_db = hl.experimental.DB(region='us', cloud='aws')
    return annotation_db.annotate_rows_db(vep_annotated, 'CADD', 'dbNSFP_variants')


def get_protein_coding_variants(mt):
    """
    Keep rows where at least one VEP transcript consequence has the
    biotype "protein_coding".
    """
    biotypes = mt.vep.transcript_consequences.biotype
    has_protein_coding = hl.any(lambda biotype: biotype == "protein_coding", biotypes)
    return mt.filter_rows(has_protein_coding)

def create_deleteriousness_scores(mt):
    """
    Add one 0/1 row field per dbNSFP predictor plus their sum.

    For each of the eight "D" predictors the ``<tool>_pred`` row field is 1
    when any of the tool's dbNSFP prediction strings contains "D", else 0.
    ``MutationAssessor_pred`` is built the same way but looks for "H".
    ``del_score`` is the sum of the nine fields, and ``cadd`` is copied from
    ``mt.CADD.PHRED_score``.
    """
    def vote(table, field, letter):
        # 1. when any prediction string for `field` contains `letter`, otherwise 0
        predictions = table.dbNSFP_variants[field]
        return hl.if_else(hl.any(lambda pred: pred.contains(letter), predictions), 1., 0)

    # tools for which "D" marks a deleterious prediction
    d_tools = ["SIFT", "LRT", "FATHMM", "PROVEAN", "MetaSVM", "MetaLR", "PrimateAI", "DEOGEN2"]
    d_fields = [f"{tool}_pred" for tool in d_tools]
    mt = mt.annotate_rows(**{field: vote(mt, field, "D") for field in d_fields})

    # MutationAssessor is matched on "H" instead of "D"
    mt = mt.annotate_rows(MutationAssessor_pred=vote(mt, "MutationAssessor_pred", "H"))

    vote_fields = d_fields + ["MutationAssessor_pred"]
    return mt.annotate_rows(
        del_score=hl.sum([mt[field] for field in vote_fields]),
        cadd=mt.CADD.PHRED_score,
    )

def create_lof_annotation(mt):
    """
    Add loss-of-function row annotations derived from the VEP output.

    ``lof`` is True when any consequence term of any transcript consequence
    is one of the listed loss-of-function terms. The ``loftee_*`` fields are
    copied as-is from the ``lof``, ``lof_flags``, ``lof_filter`` and
    ``lof_info`` fields of the transcript consequences.
    """
    consequences = mt.vep.transcript_consequences
    lof_terms = hl.set([
        "transcript_ablation", "splice_acceptor_variant", "splice_donor_variant",
        "stop_gained", "frameshift_variant", "stop_lost", "start_lost"
    ])
    # every consequence term reported for the row, across all transcripts
    observed_terms = hl.set(hl.flatten(consequences.consequence_terms))
    return mt.annotate_rows(
        lof=hl.len(lof_terms.intersection(observed_terms)) != 0,
        loftee_lof=consequences.lof,
        loftee_lof_flag=consequences.lof_flags,
        loftee_lof_filter=consequences.lof_filter,
        loftee_lof_info=consequences.lof_info,
    )

def get_deleterious(mt):
    """
    Keep rows that are flagged ``lof`` or whose ``del_score`` is above 4.
    """
    # all lof variants are kept (loftee score maybe?)
    is_lof = mt.lof == True  # noqa: E712
    # deleterious missense variants: majority vote of the predictors (5/9)
    # keep cadd greater than ?
    is_predicted_damaging = mt.del_score > 4
    return mt.filter_rows(is_lof | is_predicted_damaging)

def add_gene_info(mt):
    """
    Add a ``gene`` row field and explode the rows so each row carries one gene.

    The gene symbols come from the VEP transcript consequences, which list
    one symbol per transcript. They are de-duplicated before exploding so a
    variant hitting several transcripts of the same gene produces a single
    row for that gene; otherwise the per-gene counts computed by
    ``create_burden_matrix`` would count the same variant once per transcript.

    Returns
    -------
    The matrix table with one row per (variant, distinct gene symbol).
    """
    distinct_genes = hl.set(mt.vep.transcript_consequences.gene_symbol)
    mt = mt.annotate_rows(gene=distinct_genes)
    return mt.explode_rows(mt.gene)

def keep_relevant_columns(mt):
    """
    Restrict the row fields to the score, lof, gene, cadd and loftee fields.
    """
    kept_fields = (
        "del_score", "lof", "gene", "cadd",
        "loftee_lof", "loftee_lof_flag", "loftee_lof_filter", "loftee_lof_info",
    )
    return mt.select_rows(*(mt[field] for field in kept_fields))

def create_burden_matrix(mt):
    """
    Collapse the rows by ``gene`` into a gene x sample burden matrix.

    The ``n_variants`` entry counts, per gene and sample, the grouped rows
    whose genotype has at least one alt allele. Genes whose ``n_variants``
    sum to 0 across samples are removed.
    """
    carries_alt = mt.GT.n_alt_alleles() > 0
    per_gene = mt.group_rows_by(mt.gene).aggregate(n_variants=hl.agg.count_where(carries_alt))
    # filter to genes with at least one rare variant!
    gene_total = hl.agg.sum(per_gene.n_variants)
    return per_gene.filter_rows(gene_total > 0)

def merge_combine(df1, df2):
    """
    Outer-merge two gene/samples tables and union their sample lists.

    Both inputs need a ``gene`` column and a ``samples`` column holding a
    comma-separated string of sample ids. Rows containing missing values are
    dropped from each input before merging. For every gene present in either
    input the result holds the union of both sample lists.

    The union is de-duplicated, stripped of empty tokens and sorted, so the
    output is reproducible from run to run (joining a bare ``set`` gives an
    order that depends on string hashing).

    Parameters
    ----------
    df1, df2 : pandas DataFrames with columns ``gene`` and ``samples``

    Returns
    -------
    A new DataFrame with columns ``gene`` and ``samples``; the inputs are not
    modified.
    """
    merged = df1.dropna().merge(df2.dropna(), on="gene", how="outer")
    # a gene present on one side only has no sample list on the other side
    merged = merged.fillna("")
    merged["samples"] = [
        ",".join(sorted({sample for sample in f"{left},{right}".split(",") if sample}))
        for left, right in zip(merged["samples_x"], merged["samples_y"])
    ]
    return merged.drop(columns=["samples_x", "samples_y"])

def get_burden_table(mt):
    """
    Run the full filtering/annotation pipeline on one matrix table and return
    the per-gene burden table as a pandas DataFrame.

    Each row of the result is a gene; its ``samples`` column is a
    comma-delimited string of the sample ids whose ``n_variants`` for that
    gene is above 0. Rows with missing values are dropped.
    """
    # split multi-allelic hits to bi-allelic
    variants = hl.split_multi_hts(mt, permit_shuffle=True)

    pipeline = (
        get_rare_variants,              # filter for rare variants only
        add_annotations,                # add vep and dbnsfp annotations
        get_protein_coding_variants,    # only keep protein coding variants
        create_deleteriousness_scores,  # create deleteriousness scores for all variants
        create_lof_annotation,          # create lof annotations
        get_deleterious,                # lof variants or variants above the deleteriousness threshold
        add_gene_info,                  # add gene info and explode by gene
        keep_relevant_columns,          # only keep important columns
    )
    for step in pipeline:
        variants = step(variants)

    # create burden matrix
    gene_burden = create_burden_matrix(variants)
    # per gene, collect the ids of the samples with a non-zero count
    carrier_ids = hl.agg.filter(gene_burden.n_variants > 0, hl.agg.collect(gene_burden.s))
    gene_burden = gene_burden.annotate_rows(
        samples=hl.bind(lambda ids: hl.delimit(ids, ","), carrier_ids)
    )

    # get burden table
    return gene_burden.rows().to_pandas().dropna()

def upload_file_to_project(filename, proj_dir):
    """
    Upload a local file into ``proj_dir`` of the project, then delete the
    local copy.

    ``parents=True`` is passed to the upload call together with the target
    folder. The local file is only removed after the upload call returns.
    """
    dxpy.upload_local_file(filename, folder=proj_dir, parents=True)
    print(f"*********{filename} uploaded!!*********")
    # the local copy is no longer needed once the upload call has returned
    os.remove(filename)


# Logging configuration: one log file per chromosome
logging.basicConfig(filename=f"chr{chr_num}_annot.log", level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', datefmt='%m/%d/%Y %I:%M:%S %p')

burden_df = pd.DataFrame({"gene":[], "samples":[]})

# The database holding the block matrix tables is the same for every block,
# so it is resolved once instead of once per iteration.
db_name = f"exome_chr{chr_num}"
db_uri = dxpy.find_one_data_object(name=f"{db_name}".lower(), classname="database")['id']
# project folder receiving the per-block burden tables
proj_dir = f"/exome_annot/annot_run/notebooks/chr{chr_num}/burden_tables_0001/"
tmp_dir = "/tmp/"

# One matrix table block ("block_<i>.mt") is processed per VCF file
for i in range(len(vcf_files)):
    time_start = time.time()

    mt_name = f"block_{i}.mt"
    mt_url = f"dnax://{db_uri}/{mt_name}"

    try:
        # read the matrix table (inside the try so that one unreadable block
        # is logged and skipped instead of aborting the remaining blocks)
        mt = hl.read_matrix_table(mt_url)
        # create burden table
        burden_df = get_burden_table(mt)
        # save burden table to local
        burden_df_name = f"block_{i}.tsv"
        burden_df.to_csv(burden_df_name, sep='\t')
        # upload table to project
        upload_file_to_project(burden_df_name, proj_dir)

        time_end = time.time()
        time_taken = (time_end - time_start)/60
        logging.info(f"Time to annotate block {i}: {time_taken} mins\n")
    except Exception as error:
        # best effort per block: record the failure (with traceback) and move on
        logging.warning(f"block {i} not annotated due to {error}\n", exc_info=True)
        print(f"!!!!!!!!block {i} not annotated!!!!!!!!")
    finally:
        # remove tmp files created by hail to prevent storage issues; done for
        # failed blocks too, otherwise their leftovers accumulate in /tmp
        try:
            for file in os.listdir(tmp_dir):
                if file.startswith("persist_Table"):
                    os.remove(os.path.join(tmp_dir, file))
        except OSError as cleanup_error:
            logging.warning(f"could not clean {tmp_dir} after block {i}: {cleanup_error}\n")
