Skip to content

Commit

Permalink
Refactor make_examples and variant_labeler into an interface that's c…
Browse files Browse the repository at this point in the history
…ompatible with the to-be-added haplotype-aware v2 labeler.

-- Now the variant labeler sees all of the candidates in a region at once, rather than being handed them one at a time by make_examples.
-- As a side effect, we no longer produce examples for candidates that won't receive a label, likely significantly speeding up training runs of make_examples.
-- Refactored and cleaned up a bunch of code in variant_labeler, introducing a new VariantLabel class holding the result of the labeler on a candidate.
-- Moved VariantCounters into variant_labeler, as this functionality will be replaced by a new, better protobuf-based stats tracker in a future CL. Besides moving the code nothing changed here.
-- As part of the refactoring, the making of examples (in make_examples) is now more cleanly separated from the labeling of examples (in variant_labeler). Now make_examples really doesn't know anything about how labels are assigned. In particular, this impacted the previous function label_variant, which knew quite a bit about how the previous labeler actually worked.
-- Note there are many TODOs in the current code, which is intentional as they will be addressed in future CLs that bring in the new labeler algorithm.

PiperOrigin-RevId: 187180297
  • Loading branch information
mdepristo authored and Copybara-Service committed Feb 27, 2018
1 parent cf20d97 commit 0a159f3
Show file tree
Hide file tree
Showing 5 changed files with 480 additions and 290 deletions.
1 change: 0 additions & 1 deletion deepvariant/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,6 @@ py_library(
"//deepvariant/util:proto_utils",
"//deepvariant/util:py_utils",
"//deepvariant/util:ranges",
"//deepvariant/util:variant_utils",
"//deepvariant/util/io:vcf",
"//deepvariant/util/protos:core_py_pb2",
"//deepvariant/util/python:hts_verbose",
Expand Down
159 changes: 57 additions & 102 deletions deepvariant/make_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
from deepvariant.util import proto_utils
from deepvariant.util import ranges
from deepvariant.util import utils
from deepvariant.util import variant_utils
from deepvariant.util.protos import core_pb2
from deepvariant.util.python import hts_verbose
from deepvariant import exclude_contigs
Expand Down Expand Up @@ -589,63 +588,6 @@ def regions_to_process(contigs,
return partitioned


# ---------------------------------------------------------------------------
# Variant labeler
# ---------------------------------------------------------------------------


class _Counter(object):
  """A named tally of how many variants satisfied a selection predicate.

  Attributes:
    name: Label used when reporting this counter.
    selectp: Callable taking a variant; a truthy result means the variant is
      counted.
    n_selected: Number of variants selected so far.
  """

  def __init__(self, name, selectp):
    self.n_selected = 0
    self.selectp = selectp
    self.name = name


class VariantCounters(object):
  """Tracks, per named predicate, how many of the seen variants matched it."""

  def __init__(self, names_and_selectors):
    """Creates one counter for each (name, selector) pair.

    Args:
      names_and_selectors: Iterable of (name, selector) pairs, where selector
        is a callable taking a variant and returning a truthy value when the
        variant should be counted under name.
    """
    self.counters = [
        _Counter(label, predicate) for label, predicate in names_and_selectors
    ]
    self.n_total = 0

  def update(self, variant):
    """Records variant in the total and in every counter that selects it."""
    self.n_total += 1
    for tally in self.counters:
      if not tally.selectp(variant):
        continue
      tally.n_selected += 1

  def log(self):
    """Logs each counter as selected/total along with its percentage."""
    logging.info('----- VariantCounts -----')
    # Clamp the denominator so an empty run reports 0.00% instead of dividing
    # by zero.
    denominator = max(self.n_total, 1.0)
    for tally in self.counters:
      percent = (100.0 * tally.n_selected) / denominator
      logging.info('%s: %s/%s (%.2f%%)', tally.name, tally.n_selected,
                   self.n_total, percent)


def make_counters():
  """Builds the VariantCounters holding every statistic we want to track.

  Returns:
    A VariantCounters with one named counter each for: all variants, SNPs,
    indels, biallelic and multiallelic variants, and the HomRef, Het, HomAlt
    and NonRef (het or hom-var) genotype classes.
  """
  genotypes = variant_utils.GenotypeType

  def _genotype_in(*accepted):
    """Returns a predicate selecting variants whose genotype is in accepted."""

    def _predicate(variant):
      return variant_utils.genotype_type(variant) in accepted

    return _predicate

  names_and_selectors = [
      ('All', lambda v: True),
      ('SNPs', variant_utils.is_snp),
      ('Indels', variant_utils.is_indel),
      ('BiAllelic', variant_utils.is_biallelic),
      ('MultiAllelic', variant_utils.is_multiallelic),
      ('HomRef', _genotype_in(genotypes.hom_ref)),
      ('Het', _genotype_in(genotypes.het)),
      ('HomAlt', _genotype_in(genotypes.hom_var)),
      ('NonRef', _genotype_in(genotypes.het, genotypes.hom_var)),
  ]
  return VariantCounters(names_and_selectors)


# ---------------------------------------------------------------------------
# Region processor
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -727,9 +669,12 @@ def _initialize(self):
options=self.options.pic_options)

if in_training_mode(self.options):
self.labeler = variant_labeler.VariantLabeler(
vcf.VcfReader(self.options.truth_variants_filename),
read_confident_regions(self.options))
self.labeler = variant_labeler.make_labeler(
truth_variants_reader=vcf.VcfReader(
self.options.truth_variants_filename),
ref_reader=self.ref_reader,
confident_regions=read_confident_regions(self.options),
options=None)

self.variant_caller = variant_caller.VariantCaller(
self.options.variant_caller_options)
Expand Down Expand Up @@ -760,19 +705,21 @@ def process(self, region):

self.in_memory_sam_reader.replace_reads(self.region_reads(region))
candidates, gvcfs = self.candidates_in_region(region)
examples = []
for candidate in candidates:
for example in self.create_pileup_examples(candidate):
if in_training_mode(self.options):
if self.label_variant(example, candidate.variant):
examples.append(example)
else:
examples.append(example)

if in_training_mode(self.options):
examples = [
self.add_label_to_example(example, label)
for candidate, label in self.label_candidates(candidates)
for example in self.create_pileup_examples(candidate)
]
else:
examples = [
example for candidate in candidates
for example in self.create_pileup_examples(candidate)
]

logging.info('Found %s candidates in %s [%0.2fs elapsed]', len(examples),
ranges.to_literal(region), region_timer.Stop())
# Useful for debugging what examples are emitted...
# for example in examples:
# logging.info(' example: %s', tf_utils.example_key(example))
return candidates, examples, gvcfs

def region_reads(self, region):
Expand Down Expand Up @@ -857,35 +804,50 @@ def create_pileup_examples(self, dv_call):
image_format=tensor_format))
return examples

def label_variant(self, example, variant):
"""Adds the truth variant and label for variant to example.
def label_candidates(self, candidates):
"""Gets label information for each candidate.
This function uses VariantLabeler to find a match for variant and writes
in the corresponding truth variant and derived label to our example proto.
Args:
candidates: list[DeepVariantCalls]: The list of candidate variant calls we
want to label.
Yields:
Tuples of (candidate, label_variants.Label objects) for each candidate in
candidates that could be assigned a label. Candidates that couldn't be
labeled will not be returned.
"""
# Get our list of labels for each candidate variant.
labels = self.labeler.label_variants(
[candidate.variant for candidate in candidates])

# Remove any candidates we couldn't label, yielding candidate, label pairs.
for candidate, label in zip(candidates, labels):
if label.is_confident:
yield candidate, label

def add_label_to_example(self, example, label):
"""Adds label information about the assigned label to our example.
Args:
example: A tf.Example proto. We will write truth_variant and label into
this proto.
variant: A nucleus.genomics.v1.Variant proto.
This is the variant we'll use
to call our VariantLabeler.match to get our truth variant.
label: A variant_labeler.Label object containing the labeling information
to add to our example.
Returns:
True if the variant was in the confident region (meaning that it could be
given a label) and False otherwise.
The example proto with label fields added.
Raises:
ValueError: if label isn't confident.
"""
is_confident, truth_variant = self.labeler.match(variant)
if not is_confident:
return False
alt_alleles = tf_utils.example_alt_alleles(example, variant=variant)
if variant_utils.is_ref(variant):
label = 0
else:
label = self.labeler.match_to_alt_count(variant, truth_variant,
alt_alleles)
tf_utils.example_set_label(example, label)
tf_utils.example_set_truth_variant(example, truth_variant)
return True
if not label.is_confident:
raise ValueError('Cannot add a non-confident label to an example',
example, label)
alt_alleles_indices = tf_utils.example_alt_alleles_indices(example)
tf_utils.example_set_label(example,
label.label_for_alt_alleles(alt_alleles_indices))
tf_utils.example_set_truth_variant(example, label.truth_variant)
return example


def processing_regions_from_options(options):
Expand Down Expand Up @@ -993,9 +955,6 @@ def _write(self, writer_name, *protos):

def make_examples_runner(options):
"""Runs examples creation stage of deepvariant."""
# Counting variants.
counters = make_counters()

logging.info('Preparing inputs')
regions = processing_regions_from_options(options)

Expand Down Expand Up @@ -1023,16 +982,12 @@ def make_examples_runner(options):
# we'll never execute the write.
if gvcfs:
writer.write_gvcfs(*gvcfs)

if in_training_mode(options):
for example in examples:
counters.update(tf_utils.example_truth_variant(example))
writer.write_examples(*examples)

logging.info('Found %s candidate variants', n_candidates)
if in_training_mode(options):
# This printout is misleading if we are in calling mode.
counters.log()
# This printout is only valid in training mode.
region_processor.labeler.log()


def main(argv=()):
Expand Down

0 comments on commit 0a159f3

Please sign in to comment.