In [None]:
# Bucket to write the report to; the Dataproc service account must have write access to it.
# Give the bare bucket name: no "gs://" prefix and no trailing slash.
output_bucket = "clingen-dataproc-workspace-kferrite"
# Path of the TSV inside the bucket. May contain slashes, e.g. "folder/file.tsv".
# Do not include a leading slash.
output_filename = "report.tsv"


# Gene-labelling mode for variants inside the input genes' GTF intervals:
#   False - label each variant with the input gene whose interval contains its locus.
#   True  - label variants with the gene symbols from their VEP transcript consequences.
# Note: True does NOT restrict output to coding regions; variants are still selected by
# whole-gene intervals, only the source of the gene label changes.
transcript_filter = False


In [None]:
import hail as hl
# `idempotent=True` makes this call a no-op when Hail is already initialized, so the
# notebook can be re-run top to bottom ("Run All") in the same kernel.
hl.init(idempotent=True)

In [None]:
# Obtain desired thresholds
import io, re

thresholds = """
MYH7 BA1 0.10%
MYH7 BS1 0.02%
PTPN11 BA1 0.05%
PTPN11 BS1 0.03%
CDH1 BA1 0.20%
CDH1 BS1 0.10%
RUNX1 BA1 0.15%
RUNX1 BS1 0.015%
TP53 BA1 0.10%
TP53 BS1 0.03%
GJB2 BA1 0.50%
GJB2 BS1 0.30%
PAH BA1 1.50%
PAH BS1 0.20%
GAA BA1 1%
GAA BS1 0.50%
HRAS BA1 0.05%
HRAS BS1 0.03%
NRAS BA1 0.05%
NRAS BS1 0.03%
"""

thresh_reader = io.StringIO(thresholds)

def parse_thresholds(reader):
    """
    Parse gene-specific allele-frequency thresholds.

    `reader` is a file-like object whose contents are lines of whitespace-separated fields:
        <gene-symbol> <thresh-name> <thresh-value>
    Blank and whitespace-only lines are skipped, and leading/trailing whitespace is ignored.
    <thresh-value> is either a plain fraction (e.g. 0.001) or a percentage with a trailing
    "%" (e.g. 0.10%), which is divided by 100 to give a fraction.

    Returns a nested dict gene(str) -> thresh_name(str) -> {"AF": fraction(float)}.
    For the lines "MYH7 BA1 0.10%" and "MYH7 BS1 0.02%":
        {"MYH7": {"BA1": {"AF": 0.001}, "BS1": {"AF": 0.0002}}}
    A repeated (gene, thresh_name) pair keeps the last value seen.

    Raises ValueError if a non-blank line does not have exactly three fields (the message
    includes the line number), or if the value is not a number.
    """
    thresholds = {}
    # The table is short and hand-written, so reading it all at once is fine
    for line_number, line in enumerate(reader.read().splitlines(), start=1):
        fields = line.split()
        if not fields:
            continue
        if len(fields) != 3:
            raise ValueError(
                "Threshold line %d: expected '<gene> <name> <value>', got %r" % (line_number, line)
            )
        gene, thresh_name, thresh = fields
        if thresh.endswith("%"):
            value = float(thresh[:-1]) / 100.0
        else:
            value = float(thresh)
        thresholds.setdefault(gene, {})[thresh_name] = {"AF": value}
    return thresholds


gene_thresholds = parse_thresholds(thresh_reader)
print(gene_thresholds)

In [None]:
import io
import re

# Load the gnomAD v2.1.1 exome and genome sites tables. Each row is tagged with the
# dataset it came from, so the two tables can be stacked later and still be distinguished.
def _read_gnomad_sites(path, source_label):
    """Read the gnomAD sites Hail Table at `path` and add a constant `source` row field."""
    sites = hl.read_table(path)
    return sites.annotate(source=source_label)


ds_exomes = _read_gnomad_sites(
    "gs://gnomad-public/release/2.1.1/ht/exomes/gnomad.exomes.r2.1.1.sites.ht",
    "gnomAD Exomes",
)
ds_genomes = _read_gnomad_sites(
    "gs://gnomad-public/release/2.1.1/ht/genomes/gnomad.genomes.r2.1.1.sites.ht",
    "gnomAD Genomes",
)

# Column trimming helper for both tables:
def select_necessary_cols(ds):
    """
    Return `ds` reduced to the row fields this report uses: freq, faf, vep and source.

    The row key is not listed here. Later cells still use `locus`/`alleles` after this
    select.
    """
    report_fields = (ds.freq, ds.faf, ds.vep, ds.source)
    return ds.select(*report_fields)

# Trim both tables to the report's columns (exomes first, then genomes), then stack them.
# Rows keep their `source` label, so a variant present in both datasets can appear once
# per source in the combined table.
ds_exomes, ds_genomes = [select_necessary_cols(table) for table in (ds_exomes, ds_genomes)]
ds = ds_genomes.union(ds_exomes, unify=True)

In [None]:
"""
ds.freq has raw frequency information, including AN, AC, and pop label. This is an array of 
structs, at indices determined by the categories in ds.globals.freq_index_dict

ds.faf has filtering allele frequency information, including the 95% and 99% confidence values faf95 and faf99.
This is an array of structs, at indices determined by the category map in ds.globals.faf_index_dict
"""

def add_popmax_af(ds, faf_key_filter=None):
    """
    Annotate each row of `ds` with its highest-FAF population ("popmax").

    Candidate populations are the keys of ds.globals.faf_index_dict accepted by
    `faf_key_filter`. For each row, the candidate whose ds.faf entry has the largest faf95
    wins, and three row fields are added:

      - popmax_faf: that ds.faf struct.
      - popmax_index_dict_key: its faf_index_dict key (e.g. "gnomad_afr"). This is similar
        to, but not the same as, popmax_faf.meta["pop"] (e.g. "afr").
      - popmax_faf_pop_freq: the ds.freq struct found by looking the same key up in
        ds.globals.freq_index_dict. This assumes freq_index_dict uses the same keys as
        faf_index_dict.

    Only one population is reported per variant, even if several qualify. Getting all of
    them would need an explode().

    Parameters
    ----------
    ds : hail Table
        Needs row fields `faf` and `freq`, and globals `faf_index_dict` and `freq_index_dict`.
    faf_key_filter : callable(str) -> bool, optional
        Chooses which faf_index_dict keys are candidates. The default accepts keys that
        start with "gnomad_", i.e. the superpopulation aggregates of the entire dataset
        rather than the global "gnomad" entry or a subset such as non-cancer.

    Returns
    -------
    hail Table
        `ds` with the three fields above added.

    Raises
    ------
    ValueError
        If no faf_index_dict key is accepted by `faf_key_filter`.
    """
    if faf_key_filter is None:
        def faf_key_filter(key):
            return key.startswith("gnomad_")

    # (faf index, faf_index_dict key) for every candidate population, in dict order
    faf_index_dict = hl.eval(ds.globals.faf_index_dict)
    candidates = [(idx, key) for key, idx in faf_index_dict.items() if faf_key_filter(key)]
    if not candidates:
        raise ValueError("add_popmax_af: no faf_index_dict keys matched faf_key_filter")

    # Sort the candidates once, by faf95 from high to low, and keep the winning
    # (index, key) pair. Both popmax_faf and popmax_index_dict_key come from this one pair,
    # so they always refer to the same population. The original used two separate sorts,
    # which were not guaranteed to agree on ties.
    ds = ds.annotate(
        _popmax_candidate=hl.sorted(
            hl.literal(candidates),
            key=lambda candidate: ds.faf[candidate[0]].faf95,
            reverse=True,
        )[0]
    )
    ds = ds.annotate(
        popmax_faf=ds.faf[ds._popmax_candidate[0]],
        popmax_index_dict_key=ds._popmax_candidate[1],
    )
    ds = ds.annotate(
        # freq_index_dict is keyed like faf_index_dict, so the popmax key selects the
        # raw AC/AN record for the same population
        popmax_faf_pop_freq=ds.freq[ds.globals.freq_index_dict[ds.popmax_index_dict_key]]
    )
    return ds.drop("_popmax_candidate")


# Attach popmax_faf, popmax_index_dict_key and popmax_faf_pop_freq to every row of the
# combined exome+genome table.
ds = add_popmax_af(ds)

In [None]:
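# The next two functions are adapted from helpers in hail.experimental. The upstream
# get_gene_intervals returns a bare, unordered list of intervals. This version returns one
# dict per GTF record (gene_symbol, gene_id, transcript_id, interval), so each interval can
# be traced back to its gene.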

import operator
import functools
from hail.genetics.reference_genome import reference_genome_type
from hail.typecheck import typecheck, nullable, sequenceof
from hail.utils.java import info
from hail.utils import new_temp_file

def _load_gencode_gtf(gtf_file=None, reference_genome=None):
    """
    Get Gencode GTF (from file or reference genome)

    Parameters
    ----------
    reference_genome : :obj:`str` or :class:`.ReferenceGenome`, optional
       Reference genome to use (passed along to import_gtf).
    gtf_file : :obj:`str`
       GTF file to load. If none is provided, but `reference_genome` is one of
       `GRCh37` or `GRCh38`, a default will be used (on Google Cloud Platform).

    Returns
    -------
    :class:`.Table`
    """
    GTFS = {
        'GRCh37': 'gs://hail-common/references/gencode/gencode.v19.annotation.gtf.bgz',
        'GRCh38': 'gs://hail-common/references/gencode/gencode.v29.annotation.gtf.bgz',
    }
    if reference_genome is None:
        reference_genome = hl.default_reference().name
    else:
        reference_genome = reference_genome.name
    if gtf_file is None:
        gtf_file = GTFS.get(reference_genome)
        if gtf_file is None:
            raise ValueError(
                'get_gene_intervals requires a GTF file, or the reference genome be one of GRCh37 or GRCh38 (when on Google Cloud Platform)')
    ht = hl.experimental.import_gtf(gtf_file, reference_genome=reference_genome,
                                    skip_invalid_contigs=True, min_partitions=12)
    ht = ht.annotate(gene_id=ht.gene_id.split('\\.')[0],
                     transcript_id=ht.transcript_id.split('\\.')[0])
    return ht

@typecheck(gene_symbols=nullable(sequenceof(str)),
           gene_ids=nullable(sequenceof(str)),
           transcript_ids=nullable(sequenceof(str)),
           verbose=bool, reference_genome=nullable(reference_genome_type), gtf_file=nullable(str))
def get_gene_intervals(gene_symbols=None, gene_ids=None, transcript_ids=None,
                       verbose=True, reference_genome=None, gtf_file=None):
    """Get intervals of genes or transcripts, each labelled with its identifiers.

    Adapted from ``hl.experimental.get_gene_intervals``. Instead of a bare list of
    intervals, it returns one dict per matching GTF record, so callers can tell which
    gene or transcript each interval belongs to.

    On Google Cloud platform:
    Gencode v19 (GRCh37) GTF available at: gs://hail-common/references/gencode/gencode.v19.annotation.gtf.bgz
    Gencode v29 (GRCh38) GTF available at: gs://hail-common/references/gencode/gencode.v29.annotation.gtf.bgz

    Example
    -------
    >>> entries = get_gene_intervals(gene_symbols=['PCSK9'], reference_genome='GRCh37')  # doctest: +SKIP
    >>> hl.filter_intervals(ht, [e['interval'] for e in entries])  # doctest: +SKIP

    Parameters
    ----------

    gene_symbols : :obj:`list` of :obj:`str`, optional
       Gene symbols (e.g. PCSK9), matched against GTF records whose feature is ``gene``.
    gene_ids : :obj:`list` of :obj:`str`, optional
       Gene IDs (e.g. ENSG00000223972). Any version suffix (``.N``) is ignored.
    transcript_ids : :obj:`list` of :obj:`str`, optional
       Transcript IDs (e.g. ENST00000456328), matched against GTF records whose feature
       is ``transcript``. Any version suffix (``.N``) is ignored.
    verbose : :obj:`bool`
       If ``True``, log which genes and transcripts were matched in the GTF file.
    reference_genome : :obj:`str` or :class:`.ReferenceGenome`, optional
       Reference genome to use (passed along to import_gtf).
    gtf_file : :obj:`str`
       GTF file to load. If none is provided, but `reference_genome` is one of
       `GRCh37` or `GRCh38`, a default will be used (on Google Cloud Platform).

    Returns
    -------
    :obj:`list` of :obj:`dict`
       One dict per matching GTF record, with keys ``'gene_symbol'``, ``'gene_id'``,
       ``'transcript_id'`` and ``'interval'`` (a :class:`.Interval`). The order comes from
       ``hl.agg.collect`` and should not be relied on.

    Raises
    ------
    ValueError
       If none of gene_symbols, gene_ids or transcript_ids is a non-empty list.
    """
    # Treat empty lists like None: with no selectors there would be no filter criteria,
    # and functools.reduce below would fail on an empty list.
    if not (gene_symbols or gene_ids or transcript_ids):
        raise ValueError('get_gene_intervals requires at least one of gene_symbols, gene_ids, or transcript_ids')
    ht = _load_gencode_gtf(gtf_file, reference_genome)
    criteria = []
    if gene_symbols:
        criteria.append(hl.any(lambda y: (ht.feature == 'gene') & (ht.gene_name == y), gene_symbols))
    if gene_ids:
        criteria.append(hl.any(lambda y: (ht.feature == 'gene') & (ht.gene_id == y.split('\\.')[0]), gene_ids))
    if transcript_ids:
        criteria.append(hl.any(lambda y: (ht.feature == 'transcript') & (ht.transcript_id == y.split('\\.')[0]), transcript_ids))

    # Keep records that match any of the requested selectors
    ht = ht.filter(functools.reduce(operator.or_, criteria))
    gene_info = ht.aggregate(hl.agg.collect((ht.feature, ht.gene_name, ht.gene_id, ht.transcript_id, ht.interval)))
    if verbose:
        info(f'get_gene_intervals found {len(gene_info)} entries:\n'
             + "\n".join(map(lambda x: f'{x[0]}: {x[1]} ({x[2] if x[0] == "gene" else x[3]})', gene_info)))
    return [
        {
            'gene_symbol': gene_name,
            'gene_id': gene_id,
            'transcript_id': transcript_id,
            'interval': interval,
        }
        for _feature, gene_name, gene_id, transcript_id, interval in gene_info
    ]

                         
# Fetch the GTF intervals for every gene in the threshold table in one query.
# get_gene_interval (below) also uses this list as its lookup cache.
gene_symbols = list(gene_thresholds)
intervals = get_gene_intervals(gene_symbols=gene_symbols, reference_genome="GRCh37")

def get_gene_interval(gene_symbol: str):
    """
    Return the genomic interval for `gene_symbol`, using the module-level `intervals` list as a cache.

    On a cache miss, queries the GRCh37 GTF through get_gene_intervals for this one symbol.
    The first returned entry is appended to `intervals` and its interval is returned.

    Raises IndexError (the same type the original `[0]` lookup raised) with a descriptive
    message if the GTF has no record for `gene_symbol`.
    """
    # `intervals` is only mutated (append), never rebound, so no `global` statement is needed
    cached = next((entry for entry in intervals if entry["gene_symbol"] == gene_symbol), None)
    if cached is not None:
        return cached["interval"]
    print("Getting new gene interval: %s" % gene_symbol)
    matches = get_gene_intervals(gene_symbols=[gene_symbol], reference_genome="GRCh37")
    if not matches:
        raise IndexError("No gene interval found in the GTF for gene symbol %r" % gene_symbol)
    intervals.append(matches[0])
    return matches[0]["interval"]

In [None]:
# Perform some preliminary annotations
ds_crit = ds

# A variant's VEP transcript_consequences can span more than one gene, so a single gene
# symbol cannot simply be read from vep.transcript_consequences.gene_symbol.

# Gene intervals as a Hail array, so each variant can be labelled with the gene containing it
ivl_struct_list = hl.literal(
    [hl.struct(
        gene_symbol=i["gene_symbol"],
        gene_id=i["gene_id"],
        transcript_id=i["transcript_id"],
        interval=i["interval"]
    ) for i in intervals]
)

# Filter by intervals of genes provided in input criteria
ds_crit = hl.filter_intervals(ds_crit, [i["interval"] for i in intervals])

# Attach the `gene` field (see transcript_filter in the config cell)
if not transcript_filter:
    # Label by the first gene interval containing the locus. If input genes overlap, only
    # the first match (in `intervals` order) is used.
    ds_crit = ds_crit.annotate(
        gene=ivl_struct_list.find(
            lambda ivl: ivl["interval"].contains(ds_crit.locus)
        ).gene_symbol
    )
else:
    # One row per DISTINCT gene symbol among the variant's VEP transcript consequences.
    # Exploding the set, rather than the consequence array, avoids emitting identical
    # report rows for every transcript of the same gene.
    ds_crit = ds_crit.annotate(
        gene=hl.set(ds_crit.vep.transcript_consequences.gene_symbol)
    )
    ds_crit = ds_crit.explode("gene")

# Use a separate name for the Hail copy of the thresholds. `gene_thresholds` must stay a
# Python dict, because the interval-lookup cell iterates it, and rebinding it would break
# re-running the notebook.
gene_thresholds_expr = hl.literal(gene_thresholds)

# Sort each gene's criteria by AF, descending, so the first hl.find match is the highest
# threshold met
gene_thresholds_sorted = gene_thresholds_expr.map_values(
    lambda gene_criteria: hl.sorted(
        # Transform {"BA1": {"AF": 0.02}} to list of [("BA1", {"AF": 0.02})]
        gene_criteria.keys().map(lambda crit_name: (crit_name, gene_criteria[crit_name])),
        key=lambda t: t[1]["AF"],
        reverse=True
    )
)
# Source-less (literal-derived) expression: evaluate it directly
print(hl.eval(gene_thresholds_sorted))

# Filter to variants in genes we care about
ds_crit = ds_crit.filter(
    gene_thresholds_expr.keys().contains(ds_crit.gene)
)

ds_crit = ds_crit.annotate(
    # Name of the highest AF threshold that is <= popmax_faf.faf95 (e.g. BA1 before BS1).
    # Missing when no threshold is met: hl.find yields missing, and [0] propagates it.
    criteria_satisfied=hl.or_missing(
        gene_thresholds_expr.keys().contains(ds_crit.gene),
        hl.find(
            lambda tpl: tpl[1]["AF"] <= ds_crit.popmax_faf.faf95,
            # List of (crit_name, {"AF": ...}), highest AF first
            gene_thresholds_sorted[ds_crit.gene]
        )[0]  # [0] is the criterion name (e.g. BA1)
    )
)

# Keep only variants that meet at least one criterion
ds_crit = ds_crit.filter(
    hl.is_defined(ds_crit.criteria_satisfied)
)


filtered_ds = ds_crit.select(
    criteria_satisfied = ds_crit.criteria_satisfied,
    source = ds_crit.source,
    gene = ds_crit.gene,
    popmax_pop = ds_crit.popmax_faf.meta["pop"],
    popmax_ac = ds_crit.popmax_faf_pop_freq.AC,
    popmax_an = ds_crit.popmax_faf_pop_freq.AN,
    faf95 = ds_crit.popmax_faf.faf95,
    genomic_coordinates = hl.format("%s-%s-%s-%s",
        ds_crit.locus.contig,
        hl.str(ds_crit.locus.position),
        ds_crit.alleles[0],
        ds_crit.alleles[1]
    )
)

In [None]:
# Import ClinVar VCF as Hail Table
# clinvar = hl.import_vcf("/path/to/clinvar.vcf.gz", force_bgz=True, drop_samples=True, skip_invalid_loci=True).rows()

# Download clinvar BGZF
import os, requests, subprocess

# Download a file to a local path. The ClinVar VCF is small enough for Dataproc's default local disk.
def download_to_file(url, filepath, chunk_size=1024, timeout=60):
    """
    Stream the content at `url` into the local file `filepath`.

    The HTTP status is checked before anything is written, so an error page is never saved
    in place of the expected file. `timeout` (seconds) bounds the connection and each read,
    not the whole transfer. The response is always closed.

    Raises whatever `response.raise_for_status()` raises on a non-success status.
    """
    response = requests.get(url, stream=True, timeout=timeout)
    try:
        response.raise_for_status()
        with open(filepath, "wb") as fout:
            for chunk in response.iter_content(chunk_size=chunk_size):
                if chunk:  # skip empty chunks
                    fout.write(chunk)
    finally:
        response.close()
# This URL always points to the latest dump file, updated periodically by ClinVar
clinvar_vcf_url = "https://ftp.ncbi.nlm.nih.gov/pub/clinvar/vcf_GRCh37/clinvar.vcf.gz"
clinvar_vcf_localpath = "/home/hail/clinvar.vcf.gz"
clinvar_vcf_hdfs = "clinvar.vcf.gz"
download_to_file(clinvar_vcf_url, clinvar_vcf_localpath)
if not os.path.exists(clinvar_vcf_localpath):
    raise FileNotFoundError("ClinVar VCF download did not produce %s" % clinvar_vcf_localpath)
print("Downloaded ClinVar VCF, file size (expecting ~28M): %d" % os.path.getsize(clinvar_vcf_localpath))

# Hail needs the file in HDFS. "-f" overwrites a copy left by a previous run of this cell;
# without it the copy fails and Hail would silently read the stale file.
clinvar_hdfs_copy = subprocess.run(
    ["hdfs", "dfs", "-cp", "-f", "file://" + clinvar_vcf_localpath, clinvar_vcf_hdfs],
    stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True,
)
print(clinvar_hdfs_copy.stdout, clinvar_hdfs_copy.stderr)
clinvar_hdfs_copy.check_returncode()


clinvar = hl.import_vcf(
    clinvar_vcf_hdfs,
    force_bgz=True,
    drop_samples=True,
    skip_invalid_loci=True
).rows()
print(clinvar.count())

# Join the ClinVar table to the gnomAD table. ClinVar fields are available under the
# table.clinvar struct, which is missing for variants absent from ClinVar.
gnomad_clinvar_ds = filtered_ds.annotate(
    clinvar=clinvar[filtered_ds.locus, filtered_ds.alleles]
)

# The ClinVar VCF export sets the ID column to the ClinVar Variation ID (not the rsid),
# and puts the rsid, if any, in the RS INFO field.
# (https://ftp.ncbi.nlm.nih.gov/pub/clinvar/README_VCF.txt)
# Hail imports that ID column as the `rsid` field of the clinvar struct. To keep only
# variants present in ClinVar:
# gnomad_clinvar_ds = gnomad_clinvar_ds.filter(
#     hl.is_defined(gnomad_clinvar_ds.clinvar.rsid)
# )

In [None]:
# gnomad_clinvar_ds.show()

In [None]:
import time

# Report columns in output order: the eight gnomAD-derived columns, then ClinVar details.
# The ClinVar INFO arrays are flattened into comma-separated strings for the TSV.
report_columns = (
    "criteria_satisfied",
    "gene",
    "faf95",
    "source",
    "popmax_pop",
    "popmax_ac",
    "popmax_an",
    "genomic_coordinates",
)
clinvar_entry = gnomad_clinvar_ds.clinvar
output_ds = gnomad_clinvar_ds.select(
    *(gnomad_clinvar_ds[column] for column in report_columns),
    clinvar_variation_id=clinvar_entry.rsid,
    clinvar_review_status=hl.delimit(clinvar_entry.info["CLNREVSTAT"], ","),
    clinvar_significance=hl.delimit(clinvar_entry.info["CLNSIG"], ","),
    clinvar_significance_interpretations=hl.delimit(clinvar_entry.info["CLNSIGCONF"], ","),
)

# Write the TSV and report how long the export took
report_filename = "report.tsv"
print("Starting export to %s" % report_filename)
export_started_at = time.time()
output_ds.export(report_filename)
print("Export took %.2f seconds" % (time.time() - export_started_at))

In [None]:
# The export is in HDFS now; copy it to a machine-local file, replacing any copy left by a
# previous run. The file is removed directly instead of shelling out to `rm`.
report_localpath = os.path.join(os.getcwd(), report_filename)
if os.path.exists(report_localpath):
    os.remove(report_localpath)
report_copy = subprocess.run(
    ["hdfs", "dfs", "-cp", report_filename, "file://" + report_localpath],
    stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True,
)
print(report_copy.stdout, report_copy.stderr)
# Fail loudly instead of uploading a missing or stale report in the next cell
report_copy.check_returncode()

In [None]:
# Upload to the bucket and path set at the top of the notebook. A stray "gs://" prefix or
# extra slashes in the config are tolerated rather than producing a malformed URL.
bucket_name = output_bucket[len("gs://"):] if output_bucket.startswith("gs://") else output_bucket
gs_output_file = "gs://%s/%s" % (bucket_name.strip("/"), output_filename.lstrip("/"))
upload = subprocess.run(
    ["gsutil", "cp", report_localpath, gs_output_file],
    stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True,
)
print(upload.stdout, upload.stderr)
upload.check_returncode()