In [1]:
from __future__ import annotations
import pysam
import os
import pandas
import typing
import re

In [2]:
# Display the current working directory; the relative paths configured in this
# notebook ("../output/...", "../external_scripts/...") are resolved against it.
os.getcwd()

'/ceph01/homedirs/brand/Projects/extreme-memorizers/playbooks'

In [32]:
# --- Input files (relative to the notebook's working directory) ---------------
VCF_FILE = "../output/102-08068/102-08068.hard-filtered.annotated.vcf.gz"
LOF_METRICS = "../external_scripts/files/gnomad.v2.1.1.lof_metrics.by_gene.txt"

# --- Per-annotation thresholds -------------------------------------------------
# filter_cadd keeps a record when some annotation has CADD_PHRED >= MIN_CADD.
MIN_CADD = 24
# filter_gnomad keeps a record when some annotation has gnomAD_AF missing or
# <= MAX_GNOMAD_AF.
MAX_GNOMAD_AF = 0
# filter_impact keeps a record when some annotation's IMPACT is in this list.
IMPACT = ['HIGH', 'MODERATE']

# --- Gene-level constraint thresholds (columns of the LOF_METRICS table) -------
# filter_constraint compares: pLI >= MIN_PLI, mis_z > MIN_MIS_Z and
# oe_lof_upper < MAX_OE_LOF.
MIN_PLI = 0.9
MIN_MIS_Z = 3.09
MAX_OE_LOF = 0.35

In [33]:
# Type aliases used in the annotations below.
# record_type: a single VCF record as produced by pysam.
record_type = pysam.libcbcf.VariantRecord
# apply_fn: takes a record and returns the (possibly modified) record.
apply_fn = typing.Callable[[record_type], record_type]
# filter_fn: takes a record and returns True when the record should be kept.
filter_fn = typing.Callable[[record_type], bool]


# Record attributes that can be exported to a table row, mapped to how their
# value is flattened by FilterVariants._attrs_to_row:
#   'fixed' - the attribute value is used as-is,
#   'list'  - the attribute is iterated and its items are joined with ',',
#   'dict'  - the attribute is a mapping; either all of its values or the one
#             addressed as '<attribute>.<key>' are exported.
record_attributes = dict(
    alts='list',
    chrom='fixed',
    filter='list',
    format='list',
    id='fixed',
    info='dict',
    pos='fixed',
    qual='fixed',
    ref='fixed',
    samples='dict',
)


class FilterIteratorException(Exception):
    """Raised when a transform or filter is registered while records are being iterated."""


class RecordAttributeError(Exception):
    """Raised when a requested '<attribute>.<key>' entry does not exist on a record."""


class FilterVariants():
    """Chain transforms and filters over the records of a VCF file.

    Transforms (``apply``) and filters (``filter``) are registered in order and
    are evaluated lazily, record by record, when ``records`` is iterated.
    ``to_pandas`` flattens the surviving records into a ``pandas.DataFrame``
    with one row per (sample, annotation) combination.

    The object can also be used as a context manager to close the VCF file.
    """

    class _TransformFn():
        """Callable wrapper that tags a function as either a filter or a transform."""

        def __init__(self, fn: typing.Union[apply_fn, filter_fn], is_filter: bool = True):
            self._fn = fn
            # True: the return value is a keep/drop decision.
            # False: the return value replaces the record.
            self.is_filter = is_filter

        def __call__(self, record: record_type):
            return self._fn(record)

    # Patterns used by get_annotations to convert annotation strings to numbers.
    _INT_PATTERN = re.compile(r"^[+-]?\d+$")
    _FLOAT_PATTERN = re.compile(r'^[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?$')

    def __init__(self, vcf_file: str):
        """Open ``vcf_file`` with ``pysam.VariantFile``; no records are read yet."""
        self._vcf = pysam.VariantFile(vcf_file)
        self._transforms = []
        self._is_iterating = False
        # Number of records rejected by a filter, accumulated over all iterations.
        self._filtered_records = 0
        # Cache for the annotation_header property.
        self._annotation_header = None

    def __enter__(self) -> FilterVariants:
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        self.close()

    def close(self) -> None:
        """Close the underlying VCF file."""
        self._vcf.close()

    def _add_transform(self, fn: typing.Union[apply_fn, filter_fn], is_filter: bool) -> FilterVariants:
        """Append ``fn`` to the chain; refuses to do so while records are being iterated."""
        if self._is_iterating:
            raise FilterIteratorException("Changing transformations during iteration.")
        self._transforms.append(FilterVariants._TransformFn(fn, is_filter=is_filter))
        return self

    def apply(self, fn: apply_fn) -> FilterVariants:
        """Register a transform whose return value replaces the record.

        Returns ``self`` so calls can be chained.

        Raises:
            FilterIteratorException: if called while ``records`` is being iterated.
        """
        return self._add_transform(fn, is_filter=False)

    def filter(self, fn: filter_fn) -> FilterVariants:
        """Register a filter; records for which ``fn`` returns a falsy value are dropped.

        Returns ``self`` so calls can be chained.

        Raises:
            FilterIteratorException: if called while ``records`` is being iterated.
        """
        return self._add_transform(fn, is_filter=True)

    def _apply_transforms(self, record: record_type) -> typing.Tuple[bool, record_type]:
        """Run the chain on ``record``.

        Returns ``(is_filtered, record)``. Evaluation stops at the first filter
        that rejects the record; later transforms are not applied to it.
        """
        for fn in self._transforms:
            if fn.is_filter:
                if not fn(record):
                    return (True, record)
            else:
                record = fn(record)
        return (False, record)

    @property
    def records(self):
        """Generator over the records that pass every registered filter.

        Each access starts a new pass over ``self._vcf.fetch()``.
        """
        self._is_iterating = True
        try:
            for record in self._vcf.fetch():
                is_filtered, record = self._apply_transforms(record)
                if not is_filtered:
                    yield record
                else:
                    self._filtered_records += 1
        finally:
            # Also runs when the consumer stops early or a transform raises, so
            # the object never stays locked against apply()/filter().
            self._is_iterating = False

    @property
    def annotation_header(self) -> typing.List[str]:
        """Field names of the ``CSQ`` annotation, parsed once from the VCF header.

        The names are taken from the header record with ID ``CSQ``: the text
        after the first ':' of its Description, split on '|'.
        """
        if self._annotation_header is None:
            csq_records = [r for r in self._vcf.header.records if r.get('ID') == 'CSQ']
            description = csq_records[0].get('Description')
            self._annotation_header = description.split(':')[1].strip(" \"").split('|')
        return self._annotation_header

    @classmethod
    def get_annotations(cls, record: record_type, header: typing.List[str], info_field='CSQ') -> typing.List[typing.Mapping[str, typing.Union[str, float, int, None]]]:
        """Parse the '|'-separated annotation entries of ``record.info[info_field]``.

        Args:
            record: object whose ``info`` mapping holds an iterable of annotation strings.
            header: field names; the i-th value of each entry is stored under ``header[i]``.
            info_field: key of the annotation in ``record.info``.

        Returns:
            One dict per annotation entry. Values that look like integers or
            floats are converted, empty strings become ``None``, everything
            else stays a string.
        """
        def _parse_val(val: str) -> typing.Union[str, float, int, None]:
            if cls._INT_PATTERN.match(val):
                return int(val)
            if cls._FLOAT_PATTERN.match(val):
                return float(val)
            if val == '':
                return None
            return val

        return [
            {header[i]: _parse_val(value) for i, value in enumerate(entry.split('|'))}
            for entry in record.info[info_field]
        ]

    def _sample_to_row(self,
        record: record_type,
        attributes: typing.List[str] = None,
    ) -> typing.Tuple[typing.List[typing.List[str]], typing.List[str]]:
        """Build one row per sample of ``record``.

        Args:
            record: the record whose ``samples`` mapping is exported.
            attributes: optional selection; entries of the form
                ``'samples.<field>'`` pick individual fields, a bare
                ``'samples'`` (or ``None``) picks every field of the first sample.

        Returns:
            ``(rows, header)`` where each row is ``[sample_id, *values]`` with
            the values sorted by field name, and header is
            ``['sample_id', 'sample.<field>', ...]``. A record without samples
            yields a single ``[None]`` row so it is not lost downstream.
        """
        sample_ids = list(record.samples)
        if not sample_ids:
            return [[None]], ['sample_id']

        if attributes is None:
            wanted = set(record.samples[sample_ids[0]])
        else:
            requested = [a for a in attributes if a.startswith('samples')]
            if any('.' not in a for a in requested):
                # Bare 'samples' has no field part: export every field.
                wanted = set(record.samples[sample_ids[0]])
            else:
                wanted = {a.split('.')[1] for a in requested}

        rows = []
        fields = []
        for sample_id in sample_ids:
            sample = record.samples[sample_id]
            fields = sorted(set(sample) & wanted)
            rows.append([sample_id, *[sample[f] for f in fields]])
        # As before, the header reflects the fields of the last sample.
        return rows, ['sample_id', *[f'sample.{f}' for f in fields]]

    def _plain_values(self,
        record: record_type,
        attrs: typing.List[str],
        anno_field: str,
    ) -> typing.List[typing.Any]:
        """Flatten the requested record attributes into a list of values.

        ``attrs`` entries are either a key of ``record_attributes`` or
        ``'<key>.<subkey>'`` for mapping-like attributes.

        Raises:
            RecordAttributeError: if ``'<key>.<subkey>'`` is not present on the record.
        """
        row = []
        for a in attrs:
            parts = a.split('.')
            a_key = parts[0]
            kind = record_attributes[a_key]
            if kind == 'fixed':
                row.append(getattr(record, a_key))
            elif kind == 'list':
                value = getattr(record, a_key)
                # A missing list attribute (e.g. no alternate alleles) is exported as None.
                row.append(','.join(value) if value is not None else None)
            elif kind == 'dict':
                attr_value = getattr(record, a_key)
                if len(parts) == 1:
                    # Export every value except the raw annotation field.
                    # NOTE(review): this appends one value per key while the
                    # header carries the single name `a`; rows and header only
                    # line up when the mapping holds at most one exported key.
                    for key in list(attr_value):
                        if a_key == 'info' and key == anno_field:
                            continue
                        row.append(attr_value[key])
                elif parts[1] in list(attr_value):
                    row.append(attr_value[parts[1]])
                else:
                    raise RecordAttributeError(f"{a} could not be found on record.")
        return row

    def _attrs_to_row(self,
        record: record_type,
        attributes: typing.List[str] = None,
        excl_attr: typing.Sequence[str] = ('samples',),
        anno_field: str = 'CSQ',
        split_anno: bool = True,
    ) -> typing.Tuple[typing.List[typing.List[str]], typing.List[str]]:
        """Build the non-sample part of the table rows for ``record``.

        Args:
            record: the record to export.
            attributes: optional selection of ``record_attributes`` keys or
                ``'<key>.<subkey>'`` entries; ``None`` exports every key.
            excl_attr: top-level attribute names that are never exported here.
            anno_field: INFO key holding the '|'-separated annotations.
            split_anno: accepted for compatibility; currently not used.

        Returns:
            ``(rows, header)``. When annotations are exported there is one row
            per annotation entry, starting with the annotation values (header
            names ``'info.<field>'``) followed by the plain attribute values.
            Otherwise there is a single row of plain attribute values. Plain
            attribute columns are sorted by name so the order is deterministic.

        Raises:
            RecordAttributeError: if a requested ``'<key>.<subkey>'`` is missing.
        """
        excluded = set(excl_attr)
        if attributes is not None:
            attrs = {
                a for a in attributes
                if a.split('.')[0] in record_attributes
                and a.split('.')[0] not in excluded
                and a not in excluded
            }
        else:
            attrs = set(record_attributes.keys()) - excluded

        anno_prefix = f'info.{anno_field}'
        has_anno = any(a.startswith(anno_prefix) for a in attrs)
        if has_anno or attributes is None:
            anno = FilterVariants.get_annotations(
                record, self.annotation_header, info_field=anno_field
            )
        else:
            anno = None

        if has_anno:
            # The parsed annotation columns replace the raw annotation entry.
            attrs = {a for a in attrs if not a.startswith(anno_prefix)}

        if not (set(record.info) - {anno_field}):
            # Nothing but the annotation field in INFO: a bare 'info' column
            # would have a header name but no value.
            attrs.discard('info')

        columns = sorted(attrs)
        plain = self._plain_values(record, columns, anno_field)
        if not anno:
            return [plain], columns
        return (
            [[*a.values(), *plain] for a in anno],
            [*[f"info.{k}" for k in anno[0].keys()], *columns],
        )

    def _record_to_row(self,
        record: record_type,
        attributes: typing.List[str] = None,
        info_field: str = 'CSQ',
        split_info: bool = True
    ) -> typing.Tuple[typing.List[typing.List[str]], typing.List[str]]:
        """Combine attribute rows and sample rows of ``record``.

        Returns ``(rows, header)`` with one row per (sample, annotation) pair;
        every row starts with ``var_id`` (``contig_start_ref_alt``).
        """
        sample_rows, sample_header = self._sample_to_row(record, attributes)
        attr_rows, attr_header = self._attrs_to_row(
            record, attributes, anno_field=info_field, split_anno=split_info
        )
        # Only the first alternate allele is used; '.' when there is none.
        alt = record.alts[0] if record.alts else '.'
        # NOTE(review): `record.start` is used here while the exported 'pos'
        # column uses `record.pos`; confirm the two are meant to differ.
        var_id = f"{record.contig}_{record.start}_{record.ref}_{alt}"
        header = ['var_id', *attr_header, *sample_header]
        rows = [
            [var_id, *a, *s]
            for s in sample_rows
            for a in attr_rows
        ]
        return rows, header

    def to_pandas(self, attributes: typing.List[str] = None) -> pandas.DataFrame:
        """Export every record that passes the filters to a DataFrame.

        The column names are those produced for the last exported record; an
        empty result has no columns.
        """
        rows = []
        header = []
        for record in self.records:
            data, header = self._record_to_row(record, attributes)
            rows.extend(data)
        return pandas.DataFrame(rows, columns=header)
        

In [49]:
# Open the annotated VCF; records are only read once `filter_obj.records` is iterated.
filter_obj = FilterVariants(VCF_FILE)

# Tab-separated constraint table, indexed by gene while keeping the "gene" column.
lof = pandas.read_csv(
    LOF_METRICS,
    sep="\t",
).set_index(keys="gene", drop=False)


def filter_cadd(
    record: record_type,
    header: typing.List[str] = filter_obj.annotation_header,
    min_cadd: int = MIN_CADD
) -> bool:
    """Keep a record when at least one annotation has CADD_PHRED >= ``min_cadd``.

    Annotations without a CADD_PHRED field, or with an empty one, never match.
    """
    for annotation in FilterVariants.get_annotations(record, header):
        score = annotation.get('CADD_PHRED')
        if score is not None and score >= min_cadd:
            return True
    return False

def filter_gnomad(
    record: record_type,
    header: typing.List[str] = filter_obj.annotation_header,
    max_gnomad_af: float = MAX_GNOMAD_AF
) -> bool:
    """Keep a record when at least one annotation is rare enough in gnomAD.

    An annotation matches when it has a gnomAD_AF field that is either empty
    (parsed as ``None``) or <= ``max_gnomad_af``. Annotations lacking the field
    altogether never match.
    """
    for annotation in FilterVariants.get_annotations(record, header):
        if 'gnomAD_AF' not in annotation:
            continue
        allele_freq = annotation['gnomAD_AF']
        if allele_freq is None or allele_freq <= max_gnomad_af:
            return True
    return False

def filter_impact(
    record: record_type,
    header: typing.List[str] = filter_obj.annotation_header,
    impact: typing.List[str] = IMPACT,
) -> bool:
    """Keep a record when at least one annotation's IMPACT is listed in ``impact``.

    Annotations without an IMPACT field, or with an empty one, never match.
    """
    for annotation in FilterVariants.get_annotations(record, header):
        level = annotation.get('IMPACT')
        if level is not None and level in impact:
            return True
    return False

def filter_constraint(
    record: record_type,
    header: typing.List[str] = filter_obj.annotation_header,
    lof_metrics: pandas.DataFrame = lof,
    min_pli: float = MIN_PLI,
    min_mis_z: float = MIN_MIS_Z,
    max_oe_lof: float = MAX_OE_LOF,
) -> bool:
    """Keep a record when one of its annotations falls in a constrained gene.

    Each annotation is evaluated against the metrics of its own gene
    (``SYMBOL``), looked up in the index of ``lof_metrics``. Previously the
    gene of the first annotation was used for every annotation, so annotations
    on other genes were judged with the wrong metrics.

    An annotation matches when its gene is present in ``lof_metrics`` and either
      * its Consequence is exactly 'missense_variant' and the gene's
        ``mis_z`` > ``min_mis_z``, or
      * the gene's ``pLI`` >= ``min_pli`` and ``oe_lof_upper`` < ``max_oe_lof``.

    NOTE(review): as in the original, the second criterion does not look at the
    annotation's consequence; confirm it should not be limited to
    loss-of-function consequences.

    Returns:
        True if any annotation matches, False otherwise (including when no
        annotated gene is present in ``lof_metrics``).
    """
    anno = FilterVariants.get_annotations(record, header)

    # gene -> (missense_constrained, lof_constrained), or None when the gene is
    # not in the table. Cached so each gene is looked up once per record.
    verdicts = {}

    def _constraint(gene):
        if gene not in verdicts:
            if gene is None or gene not in lof_metrics.index:
                verdicts[gene] = None
            else:
                # List-based .loc always returns a DataFrame, so a gene that
                # occurs more than once in the index is handled as well; any
                # matching row counts.
                rows = lof_metrics.loc[[gene]]
                verdicts[gene] = (
                    bool((rows['mis_z'] > min_mis_z).any()),
                    bool(((rows['pLI'] >= min_pli) & (rows['oe_lof_upper'] < max_oe_lof)).any()),
                )
        return verdicts[gene]

    for a in anno:
        verdict = _constraint(a['SYMBOL'])
        if verdict is None:
            continue
        missense_constrained, lof_constrained = verdict
        if lof_constrained or (a['Consequence'] == 'missense_variant' and missense_constrained):
            return True
    return False

# Register the filters; they are evaluated in this order for every record.
filter_obj = (
    filter_obj
    .filter(filter_impact)
    .filter(filter_constraint)
    .filter(filter_gnomad)
    # .filter(filter_cadd)
)

# Print, per surviving record, the position followed by one list per
# annotation field (one entry per annotation).
for record in filter_obj.records:
    anno = FilterVariants.get_annotations(record, filter_obj.annotation_header)
    per_field = (
        [a[field] for a in anno]
        for field in ('gnomAD_AF', 'CADD_PHRED', 'IMPACT', 'Consequence')
    )
    print(record.chrom, record.pos, *per_field)


3 133969487 [None, None, None] [21.8, 21.8, 21.8] ['HIGH', 'MODIFIER', 'MODIFIER'] ['frameshift_variant', '5_prime_UTR_variant', 'upstream_gene_variant']
5 60628574 [0] [12.31] ['MODERATE'] ['inframe_deletion']


In [272]:
# Export the filtered records and keep only the annotation rows whose IMPACT is
# one of the configured levels (a kept record can still carry other annotations).
df = (
    filter_obj
    .to_pandas()
    .loc[lambda frame: frame['info.IMPACT'].isin(IMPACT)]
)

In [273]:
# Write the filtered table as a tab-separated file into the working directory,
# without the DataFrame index.
df.to_csv("filtered_vars.tsv", sep="\t", index=False)

In [46]:
# Spot check: pLI of a single gene, read with one row/column lookup.
lof.loc['NGF', 'pLI']

0.82026

In [22]:
# Inspect the parsed annotations of one record.
# NOTE(review): `anno` is whatever an earlier cell last assigned to it, so this
# output depends on execution order and is not reproducible with Run All.
anno

[{'Allele': '-',
  'Consequence': 'intron_variant',
  'IMPACT': 'MODIFIER',
  'SYMBOL': 'ILK',
  'Gene': 'ENSG00000166333',
  'Feature_type': 'Transcript',
  'Feature': 'ENST00000299421',
  'BIOTYPE': 'protein_coding',
  'EXON': None,
  'INTRON': '5/12',
  'HGVSc': None,
  'HGVSp': None,
  'cDNA_position': None,
  'CDS_position': None,
  'Protein_position': None,
  'Amino_acids': None,
  'Codons': None,
  'Existing_variation': 'COSV54990603',
  'DISTANCE': None,
  'STRAND': 1,
  'FLAGS': None,
  'SYMBOL_SOURCE': 'HGNC',
  'HGNC_ID': 6040,
  'SOURCE': None,
  'SIFT': None,
  'PolyPhen': None,
  'gnomAD_AF': None,
  'gnomAD_AFR_AF': None,
  'gnomAD_AMR_AF': None,
  'gnomAD_ASJ_AF': None,
  'gnomAD_EAS_AF': None,
  'gnomAD_FIN_AF': None,
  'gnomAD_NFE_AF': None,
  'gnomAD_OTH_AF': None,
  'gnomAD_SAS_AF': None,
  'CLIN_SIG': None,
  'SOMATIC': 1,
  'PHENO': 1,
  'CADD_PHRED': 1.914,
  'CADD_RAW': -0.052206,
  'gnomADg': 'rs11314683',
  'gnomADg_AF_AFR': 1.0,
  'gnomADg_AF_AMR': 1.0,
  'gn