In [None]:
# Parameter: the chromosome to process. Accepted forms (see import_single_chrom_vcf /
# get_partitioned_chrom): a single chromosome label (1-22, 'X', 'Y'), or one partition of a
# chromosome's VCF blocks written as "{chr}-{i}of{n}", e.g. "8-1of4".
# Left as None here; presumably supplied per run (e.g. via parameter injection) — TODO confirm.
chrom = None

In [1]:
import pyspark
import dxpy
import hail as hl
import pandas as pd
from math import ceil

# Local working directory of the notebook instance (Hail writes its log file here, and
# export_table stages exported files under this same path before uploading them).
WD='/opt/notebooks'

In [2]:
# Resolve the DNAnexus database object named "my_database" in the current project.
# Its dnax:// URL is used as the root for Hail's scratch directory and for MatrixTable outputs.
my_database = dxpy.find_one_data_object(
    name="my_database",
    project=dxpy.find_one_project()["id"],
)["id"]
database_dir = f'dnax://{my_database}'

# Bring up Spark, then hand the existing SparkContext to Hail.
sc = pyspark.SparkContext()
spark = pyspark.sql.SparkSession(sc)
hl.init(
    sc=sc,
    tmp_dir=f'{database_dir}/tmp/',
)

pip-installed Hail requires additional configuration options in Spark referring
  to the path to the Hail Python module directory HAIL_DIR,
  e.g. /path/to/python/site-packages/hail:
    spark.jars=HAIL_DIR/hail-all-spark.jar
    spark.driver.extraClassPath=HAIL_DIR/hail-all-spark.jar
    spark.executor.extraClassPath=./hail-all-spark.jarRunning on Apache Spark version 2.4.4
SparkUI available at http://ip-10-60-90-238.eu-west-2.compute.internal:8081
Welcome to
     __  __     <>__
    / /_/ /__  __/ /
   / __  / _ `/ / /
  /_/ /_/\_,_/_/_/   version 0.2.78-b17627756568
LOGGING: writing to /opt/notebooks/hail-20220928-0846-0.2.78-b17627756568.log


## S0. Define functions, load data

In [3]:
def get_gnomad_vcf_path(chrom, blocks):
    """
    Build the path (possibly a Hadoop glob) to the UKB gnomAD-processed exome VCF(s) of a chromosome.

    Parameters
    ----------
    chrom : str or int
        Chromosome label as it appears in the file names (e.g. 21, '21', 'X').
    blocks : '*', a single block index (int or str), or an iterable of block indices
        '*' matches every block of the chromosome. A single index selects one file.
        An iterable of indices becomes a brace glob such as ``{0,1,2}``.

    Returns
    -------
    str
        A ``file://`` path under the mounted project directory.

    Raises
    ------
    ValueError
        If ``blocks`` is an empty iterable. Previously that silently produced the glob ``{}``.
    """
    vcf_dir = 'file:///mnt/project/Bulk/Exome sequences_Alternative exome processing/Exome variant call files (gnomAD) (VCFs)'
    if blocks != '*':
        if isinstance(blocks, str) or not hasattr(blocks, '__iter__'):
            # A single block (e.g. 7 or '12'). Before this check, a str was split into
            # characters ('12' -> '{1,2}') and an int raised TypeError.
            blocks = str(blocks)
        else:
            block_ids = [str(b) for b in blocks]
            if not block_ids:
                raise ValueError('blocks must contain at least one block index')
            blocks = '{' + ','.join(block_ids) + '}'

    return f'{vcf_dir}/ukb24068_c{chrom}_b{blocks}_v1.vcf.gz'


def get_partitioned_chrom(chrom_w_suffix):
    """
    Return the VCF path glob covering one partition of a chromosome's VCF blocks.

    chrom_w_suffix should be of the form "{chr}-?of?", e.g. "8-1of4" for partition 1 of 4 in chromosome 8.

    The chromosome's block files are counted, then split into `total_parts` contiguous runs of
    ceil(n_blocks / total_parts) blocks each. The last non-empty run may be shorter.

    Raises
    ------
    AssertionError
        If the chromosome label or the partition index is malformed.
    ValueError
        If the requested partition contains no blocks. With ceil-sized runs this happens when
        `total_parts` is large relative to the number of block files (e.g. 10 blocks split into 6
        parts leaves part 6 empty). Previously this silently produced an empty `{}` glob.
    """
    chrom, suffix = chrom_w_suffix.split('-')
    assert chrom in list(map(str, range(1,23)))+['X','Y'], "chrom must be in  {1-22, X, Y}"
    part_idx, total_parts = map(int, suffix.split('of'))
    assert 1 <= part_idx <= total_parts

    # NOTE(review): assumes block files are numbered contiguously from 0 to n_blocks-1 — confirm.
    total_vcfs = len(hl.hadoop_ls(get_gnomad_vcf_path(chrom=chrom, blocks="*")))

    part_size = ceil(total_vcfs / total_parts)

    start_idx = (part_idx - 1) * part_size
    stop_idx = min(part_idx * part_size, total_vcfs)  # exclusive upper bound

    if start_idx >= stop_idx:
        raise ValueError(
            f'partition {part_idx}of{total_parts} of chromosome {chrom} is empty '
            f'({total_vcfs} VCF blocks in {part_size}-block partitions)'
        )

    return get_gnomad_vcf_path(chrom, blocks=range(start_idx, stop_idx))
    
    

def import_single_chrom_vcf(chrom, blocks = '*'):
    """
    Import the gnomAD-processed exome VCF(s) for one chromosome, or one partition of it, as a Hail MatrixTable.

    If `chrom` (as a string) contains 'of', it is treated as a partition spec such as "8-1of4".
    In that case `blocks` is ignored. Otherwise `chrom` is a plain chromosome label and `blocks`
    selects which block files to read ('*' = all).
    """
    is_partition_spec = 'of' in str(chrom)
    vcf_path = (
        get_partitioned_chrom(chrom_w_suffix=chrom)
        if is_partition_spec
        else get_gnomad_vcf_path(chrom=chrom, blocks=blocks)
    )

    return hl.import_vcf(vcf_path, force_bgz=True, reference_genome='GRCh38')


def get_mad_threshold_tsv_fname(n_mads, classification):
    """File name of the MAD-threshold sample-QC table for a given MAD count and population classification."""
    template = 'ukb_wes_450k.mad_threshold.nmad_{}.popclass_{}.tsv.gz'
    return template.format(n_mads, classification)


def get_pass_mad_threshold_expr(mt, n_mads='4', classification='strict',
                                mad_dir='file:///mnt/project/data/03_mad_threshold'):
    """
    Per-sample boolean expression: does the sample pass the MAD-threshold sample QC?

    Loads the TSV for the given `n_mads` / `classification`, which has columns `s` (str) and
    `pass` (bool). The table is keyed by `s`, and the function looks up each column key `mt.s`
    in it. Samples absent from the TSV get a missing value from the Hail key lookup.

    Parameters
    ----------
    mt : hail.MatrixTable
        Must have a column field `s` (sample ID).
    n_mads, classification :
        Select which threshold file to use (see get_mad_threshold_tsv_fname).
    mad_dir : str
        Directory holding the threshold TSVs. The default is the mounted project copy. Pass e.g.
        'file:///opt/notebooks' to read a local copy instead.

    Returns
    -------
    hail.BooleanExpression indexed by column.
    """
    mad_fname = get_mad_threshold_tsv_fname(n_mads=n_mads, classification=classification)
    mad_path = f'{mad_dir}/{mad_fname}'
    print(mad_path)
    mad_ht = hl.import_table(
        mad_path,
        types={
            's': hl.tstr,
            'pass': hl.tbool
        },
        key='s',
        # presumably needed because the file is plain gzip (.gz), not block-gzipped — confirm
        force=True
    )

    # `pass` is a Python keyword, so the field is indexed with ['pass'] rather than attribute access.
    return mad_ht[mt.s]['pass']

def get_fail_interval_qc_expr(mt):
    """Row flag `info.fail_interval_qc`: the site fails gnomAD interval QC."""
    info = mt.info
    return info.fail_interval_qc

def get_lcr_expr(mt):
    """Row flag `info.lcr`: the site lies in a low-complexity region."""
    info = mt.info
    return info.lcr

def get_segdup_expr(mt):
    """Row flag `info.segdup` (presumably: site falls in a segmental-duplication region)."""
    info = mt.info
    return info.segdup

def get_filter_contains_rf_expr(mt):
    """True when the row's FILTER set contains "RF" (the random-forest filter)."""
    filters = mt.filters
    return filters.contains('RF')

def get_inbreeding_coeff(mt):
    """First element of the `info.InbreedingCoeff` array (assumes biallelic rows — TODO confirm)."""
    coeffs = mt.info.InbreedingCoeff
    return coeffs[0]

def site_filter(mt):
    """
    Genotype-level QC: keep only entries that pass all of the following (failing entries are filtered out, so their genotype reads as missing):

      - DP >= 10
      - GQ >= 20
      - for heterozygous calls: alt allele fraction AD[1] / DP > 0.2
        (AD[1] is the first alt allele; presumably rows are biallelic — confirm)

    Non-heterozygous calls are not subject to the allele-balance check.
    """
    SITE_DP_MIN = 10
    SITE_GQ_MIN = 20
    SITE_HET_AB_MIN = 0.2

    is_het = mt.GT.is_het()
    good_depth = mt.DP >= SITE_DP_MIN
    good_gq = mt.GQ >= SITE_GQ_MIN
    alt_fraction = mt.AD[1] / mt.DP
    good_balance = ~is_het | (is_het & (alt_fraction > SITE_HET_AB_MIN))

    return mt.filter_entries(good_depth & good_gq & good_balance)


def final_variant_filter(mt):
    """
    Drop variant rows that fail any of the final variant-level QC criteria.

    A row is removed if any of these holds:
      - the FILTER field contains "RF" (random-forest filter)
      - excess heterozygotes: first InbreedingCoeff value < -0.3
      - the site fails gnomAD interval QC
      - the site lies in a low-complexity region (info.lcr)
      - info.segdup is set
      - every genotype aggregated over the row is missing, i.e. no sample has a high-quality
        genotype. This relies on site_filter having run first (see final_filter).

    NOTE(review): per Hail's filter_rows documentation, a row whose condition evaluates to
    missing is removed regardless of keep=False (e.g. a missing InbreedingCoeff with every
    other flag False) — confirm this is intended.
    """
    MIN_INBREEDING_COEFF = -0.3

    failure_flags = [
        get_filter_contains_rf_expr(mt),
        get_inbreeding_coeff(mt) < MIN_INBREEDING_COEFF,
        get_fail_interval_qc_expr(mt),
        get_lcr_expr(mt),
        get_segdup_expr(mt),
        hl.agg.all(hl.is_missing(mt.GT)),
    ]

    # Left-fold with `|`, matching the original `a | b | c | ...` association.
    any_failure = failure_flags[0]
    for flag in failure_flags[1:]:
        any_failure = any_failure | flag

    return mt.filter_rows(any_failure, keep=False)

def final_variant_filter_skip_hq_gt(mt):
    """
    Row-level variant QC without the "no sample has a high-quality genotype" check.

    A row is removed if any of these holds:
      - the FILTER field contains "RF" (random-forest filter)
      - excess heterozygotes: first InbreedingCoeff value < -0.3
      - the site fails gnomAD interval QC
      - the site lies in a low-complexity region (info.lcr)
      - info.segdup is set

    Unlike final_variant_filter, this does not aggregate over genotypes, so it does not depend
    on site_filter having been applied first (the S1 cell runs it on the raw import).

    NOTE(review): per Hail's filter_rows documentation, a row whose condition evaluates to
    missing is removed regardless of keep=False (e.g. a missing InbreedingCoeff with every
    other flag False) — confirm this is intended.
    """
    MIN_INBREEDING_COEFF = -0.3
    fails_inbreeding_coeff = get_inbreeding_coeff(mt) < MIN_INBREEDING_COEFF

    # The all-genotypes-missing aggregation used by final_variant_filter is intentionally absent.
    return mt.filter_rows(
        get_filter_contains_rf_expr(mt)
        | fails_inbreeding_coeff
        | get_fail_interval_qc_expr(mt)
        | get_lcr_expr(mt)
        | get_segdup_expr(mt),
        keep=False
    )

def export_table(ht, fname, out_folder):
    """
    Export a Hail Table to a single local text file, then upload it to DNAnexus.

    The table is coalesced to one partition so a single file is written. The file is staged
    under the notebook working directory `WD`, then uploaded with dxpy into `out_folder`.
    parents=True is passed, presumably so missing folders are created.

    Parameters
    ----------
    ht : hail.Table
        Table to export.
    fname : str
        Local file name, also used as the name of the uploaded file.
    out_folder : str
        Destination folder in the DNAnexus project.
    """
    # Build the staging path once, so the Hail export target and the upload source cannot drift apart.
    local_path = f'{WD}/{fname}'
    ht.naive_coalesce(1).export(f'file://{local_path}')

    dxpy.upload_local_file(
        filename=local_path,
        name=fname,
        folder=out_folder,
        parents=True
    )
    
def export_file(path, out_folder):
    """
    Upload a local file at `path` to the DNAnexus folder `out_folder`, keeping its base name.

    parents=True is passed to dxpy, presumably so missing folders are created.
    """
    base_name = path.rsplit('/', 1)[-1]
    dxpy.upload_local_file(filename=path, name=base_name, folder=out_folder, parents=True)

def final_filter(mt):
    """
    Apply the full QC pipeline in order: sample QC, genotype QC, then variant QC.

    1. Keep samples passing the 4-MAD / 'strict' threshold table.
    2. site_filter: drop low-quality genotype entries.
    3. final_variant_filter: drop failing variant rows.
    """
    passes_mad = get_pass_mad_threshold_expr(mt, n_mads='4', classification='strict')
    samples_qced = mt.filter_cols(passes_mad)

    genotypes_qced = site_filter(samples_qced)

    # Order matters: the variant filter must see the site-filtered genotypes, so that variants
    # where no individual keeps a high-quality genotype are removed.
    return final_variant_filter(genotypes_qced)

def get_final_filter_mt_path(chrom):
    """Location of the final QC'd MatrixTable for `chrom` inside the DNAnexus database directory."""
    mt_name = f'ukb_wes_450k.qced.chr{chrom}.mt'
    return f'{database_dir}/04_final_filter_write_to_mt/{mt_name}'

def get_final_filter_count_tsv_fname(chrom):
    """File name for the post-final-filter variant/sample count table of `chrom`."""
    prefix = 'variant_sample_count.final_filter'
    return f'{prefix}.c{chrom}.tsv'


def export_count_as_tsv(mt, chrom, fname, out_folder):
    """
    Count rows (variants) and columns (samples) of `mt`, then export the counts as a one-row table via export_table.

    The output has columns `row_count` and `col_count`.
    NOTE(review): `chrom` is accepted but not used; the output name comes entirely from `fname`.
    """
    n_rows, n_cols = mt.count()

    counts_df = pd.DataFrame({
        'row_count': [n_rows],
        'col_count': [n_cols],
    })

    export_table(ht=hl.Table.from_pandas(counts_df), fname=fname, out_folder=out_folder)

In [5]:
# Default to chromosome 21 only when no chromosome was supplied in the `chrom = None`
# parameter cell. An unconditional `chrom=21` here would overwrite any value set or
# injected there, so every run would silently process chr21.
if chrom is None:
    chrom = 21

In [None]:
%%time

chrom = chrom

raw = import_single_chrom_vcf(chrom)

## S1. Write list of variants to keep

In [None]:
# Apply the row-level variant filters to the raw import. The "no high-quality genotype" check
# is skipped, presumably because site_filter has not been applied at this point.
mt = final_variant_filter_skip_hq_gt(raw)

# Unkey first so select() can replace locus/alleles with plain columns.
ht = mt.rows().key_by()
ht = ht.select(
    chrom=ht.locus.contig,
    position=ht.locus.position,
    a1=ht.alleles[0],
    a2=ht.alleles[1],  # assumes biallelic rows — TODO confirm
)

fname = f'ukb_wes_450k.pass_variants.chr{chrom}.tsv.gz'
export_table(ht=ht, fname=fname, out_folder='/data/04_final_filter')

2022-09-27 14:22:45 Hail: INFO: Coerced sorted dataset
2022-09-27 14:22:45 Hail: INFO: Coerced dataset with out-of-order partitions.
