Skip to content

Commit

Permalink
Refactor simplify_alleles
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 308140249
  • Loading branch information
sgoe1 authored and Copybara-Service committed Apr 23, 2020
1 parent a9c4b52 commit 01bad74
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 60 deletions.
1 change: 1 addition & 0 deletions deepvariant/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,7 @@ py_test(
"//third_party/nucleus/protos:variants_py_pb2",
"//third_party/nucleus/util:errors",
"//third_party/nucleus/util:genomics_math",
"//third_party/nucleus/util:variant_utils",
"//third_party/nucleus/util:vcf_constants",
"@absl_py//absl/flags",
"@absl_py//absl/testing:absltest",
Expand Down
29 changes: 4 additions & 25 deletions deepvariant/postprocess_variants.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,28 +573,6 @@ def prune_alleles(variant, alt_alleles_to_remove):
return new_variant


def simplify_alleles(variant):
  """Rewrites a variant's ref/alt alleles in their simplified form.

  The reference and alternate alleles are passed together through
  variant_utils.simplify_alleles(), which removes common postfix (trailing)
  bases that may be left over after alleles have been pruned away. The variant
  is modified in place: its reference_bases, alternate_bases and end fields
  are overwritten, while start is left untouched.

  Args:
    variant: learning.genomics.genomics.Variant proto we want to simplify.

  Returns:
    The same variant object that was passed in, now carrying the simplified
    ref and alt alleles and an end recomputed from the new ref length.
  """
  original_ref = variant.reference_bases
  original_alts = variant.alternate_bases
  reduced = variant_utils.simplify_alleles(original_ref, *original_alts)
  # The first simplified allele is the reference; the rest are the alts, in
  # their original order.
  new_ref, new_alts = reduced[0], reduced[1:]
  variant.reference_bases = new_ref
  variant.alternate_bases[:] = new_alts
  # The span of the variant follows the (possibly shorter) reference allele.
  variant.end = variant.start + len(new_ref)
  return variant


def merge_predictions(call_variants_outputs, qual_filter=None):
"""Merges the predictions from the multi-allelic calls."""
# See the logic described in the class PileupImageCreator pileup_image.py
Expand All @@ -610,7 +588,8 @@ def merge_predictions(call_variants_outputs, qual_filter=None):
first_call, other_calls = call_variants_outputs[0], call_variants_outputs[1:]
canonical_variant = first_call.variant
if not other_calls:
canonical_variant = simplify_alleles(canonical_variant)
canonical_variant = variant_utils.simplify_variant_alleles(
canonical_variant)
return canonical_variant, first_call.genotype_probabilities

alt_alleles_to_remove = get_alt_alleles_to_remove(call_variants_outputs,
Expand All @@ -626,10 +605,10 @@ def merge_predictions(call_variants_outputs, qual_filter=None):
if sum(predictions) == 0:
predictions = [1.0] * len(predictions)
denominator = sum(predictions)
# Note the simplify_alleles call *must* happen after the predictions
# Note the simplify_variant_alleles call *must* happen after the predictions
# calculation above. flattened_probs_dict is indexed by alt allele, and
# simplify can change those alleles so we cannot simplify until afterwards.
canonical_variant = simplify_alleles(canonical_variant)
canonical_variant = variant_utils.simplify_variant_alleles(canonical_variant)
return canonical_variant, [i / denominator for i in predictions]


Expand Down
38 changes: 3 additions & 35 deletions deepvariant/postprocess_variants_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
from third_party.nucleus.protos import variants_pb2
from third_party.nucleus.testing import test_utils
from third_party.nucleus.util import genomics_math
from third_party.nucleus.util import variant_utils
from third_party.nucleus.util import vcf_constants
from deepvariant import dv_constants
from deepvariant import dv_vcf_constants
Expand Down Expand Up @@ -1016,39 +1017,6 @@ def test_prune_alleles(self, canonical_variant, alt_alleles_to_remove,
alt_alleles_to_remove),
expected_variant)

@parameterized.parameters(
    # Nothing to trim: the shortest alt is already a single base.
    {
        'alleles': ['CAA', 'CA', 'C'],
        'start': 5,
        'expected_alleles': ['CAA', 'CA', 'C'],
        'expected_end': 8,
    },
    # A shared trailing 'A' is dropped from both alleles.
    {
        'alleles': ['CAA', 'CA'],
        'start': 4,
        'expected_alleles': ['CA', 'C'],
        'expected_end': 6,
    },
    # Already minimal two-base deletion.
    {
        'alleles': ['CAA', 'C'],
        'start': 3,
        'expected_alleles': ['CAA', 'C'],
        'expected_end': 6,
    },
    # A shared trailing 'A' is dropped, leaving 'CC' -> 'C'.
    {
        'alleles': ['CCA', 'CA'],
        'start': 2,
        'expected_alleles': ['CC', 'C'],
        'expected_end': 4,
    },
)
def test_simplify_alleles(self, alleles, start, expected_alleles,
                          expected_end):
  """Checks alleles, start and end after postprocess_variants.simplify_alleles.

  Each case supplies `alleles` as [ref, alt1, ...] together with the variant
  start, and the alleles / end expected once simplification has run.
  """
  ref, alts = alleles[0], alleles[1:]
  want_ref, want_alts = expected_alleles[0], expected_alleles[1:]
  result = postprocess_variants.simplify_alleles(
      _create_variant_with_alleles(ref=ref, alts=alts, start=start))
  self.assertEqual(result.reference_bases, want_ref)
  self.assertEqual(result.alternate_bases, want_alts)
  # Simplification must leave start unchanged; only end is recomputed.
  self.assertEqual(result.start, start)
  self.assertEqual(result.end, expected_end)

@parameterized.parameters(
# Check that we are simplifying alleles and that the simplification depends
# on the alleles we've removed.
Expand Down Expand Up @@ -1083,11 +1051,11 @@ def test_simplify_alleles(self, alleles, start, expected_alleles,
def test_prune_and_simplify_alleles(self, alleles, start,
alt_alleles_to_remove, expected_alleles,
expected_end):
"""Test that prune_alleles + simplify_alleles works as expected."""
"""Test that prune_alleles + simplify_variant_alleles works as expected."""
variant = _create_variant_with_alleles(
ref=alleles[0], alts=alleles[1:], start=start)
pruned = postprocess_variants.prune_alleles(variant, alt_alleles_to_remove)
simplified = postprocess_variants.simplify_alleles(pruned)
simplified = variant_utils.simplify_variant_alleles(pruned)
self.assertEqual(simplified.reference_bases, expected_alleles[0])
self.assertEqual(simplified.alternate_bases, expected_alleles[1:])
self.assertEqual(simplified.start, start)
Expand Down
22 changes: 22 additions & 0 deletions third_party/nucleus/util/variant_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,28 @@ def all_the_same(items):
return alleles


def simplify_variant_alleles(variant):
  """Rewrites a variant's ref/alt alleles in their simplified form.

  The reference and alternate alleles are passed together through
  simplify_alleles(), which removes common postfix (trailing) bases that may
  be left over after alleles have been pruned away. The variant is modified in
  place: its reference_bases, alternate_bases and end fields are overwritten,
  while start is left untouched.

  Args:
    variant: learning.genomics.genomics.Variant proto we want to simplify.

  Returns:
    The same variant object that was passed in, now carrying the simplified
    ref and alt alleles and an end recomputed from the new ref length.
  """
  original_ref = variant.reference_bases
  original_alts = variant.alternate_bases
  reduced = simplify_alleles(original_ref, *original_alts)
  # The first simplified allele is the reference; the rest are the alts, in
  # their original order.
  new_ref, new_alts = reduced[0], reduced[1:]
  variant.reference_bases = new_ref
  variant.alternate_bases[:] = new_alts
  # The span of the variant follows the (possibly shorter) reference allele.
  variant.end = variant.start + len(new_ref)
  return variant


def is_filtered(variant):
"""Returns True if variant has a non-PASS filter field, or False otherwise."""
return bool(variant.filter) and any(
Expand Down
44 changes: 44 additions & 0 deletions third_party/nucleus/util/variant_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,17 @@
TRUE_MISS = variant_utils.AlleleMismatchType.unmatched_true_alleles
EVAL_MISS = variant_utils.AlleleMismatchType.unmatched_eval_alleles

_DEFAULT_SAMPLE_NAME = 'NA12878'


def _create_variant_with_alleles(ref=None, alts=None, start=0):
  """Builds a Variant proto with the given alleles and a single default call.

  Args:
    ref: Value for the variant's reference_bases field.
    alts: Value for the variant's alternate_bases field.
    start: Value for the variant's start field.

  Returns:
    A variants_pb2.Variant with one VariantCall whose call_set_name is
    _DEFAULT_SAMPLE_NAME. The end field is not set here.
  """
  sample_call = variants_pb2.VariantCall(call_set_name=_DEFAULT_SAMPLE_NAME)
  return variants_pb2.Variant(
      reference_bases=ref,
      alternate_bases=alts,
      start=start,
      calls=[sample_call])


class VariantUtilsTests(parameterized.TestCase):

Expand Down Expand Up @@ -441,6 +452,39 @@ def test_simplify_alleles(self, alleles, expected):
variant_utils.simplify_alleles(*reversed(alleles)),
tuple(reversed(expected)))

@parameterized.parameters(
    # Nothing to trim: the shortest alt is already a single base.
    {
        'alleles': ['CAA', 'CA', 'C'],
        'start': 5,
        'expected_alleles': ['CAA', 'CA', 'C'],
        'expected_end': 8,
    },
    # A shared trailing 'A' is dropped from both alleles.
    {
        'alleles': ['CAA', 'CA'],
        'start': 4,
        'expected_alleles': ['CA', 'C'],
        'expected_end': 6,
    },
    # Already minimal two-base deletion.
    {
        'alleles': ['CAA', 'C'],
        'start': 3,
        'expected_alleles': ['CAA', 'C'],
        'expected_end': 6,
    },
    # A shared trailing 'A' is dropped, leaving 'CC' -> 'C'.
    {
        'alleles': ['CCA', 'CA'],
        'start': 2,
        'expected_alleles': ['CC', 'C'],
        'expected_end': 4,
    },
)
def test_simplify_variant_alleles(self, alleles, start, expected_alleles,
                                  expected_end):
  """Checks alleles, start and end after simplify_variant_alleles.

  Each case supplies `alleles` as [ref, alt1, ...] together with the variant
  start, and the alleles / end expected once simplification has run.
  """
  ref, alts = alleles[0], alleles[1:]
  want_ref, want_alts = expected_alleles[0], expected_alleles[1:]
  result = variant_utils.simplify_variant_alleles(
      _create_variant_with_alleles(ref=ref, alts=alts, start=start))
  self.assertEqual(result.reference_bases, want_ref)
  self.assertEqual(result.alternate_bases, want_alts)
  # Simplification must leave start unchanged; only end is recomputed.
  self.assertEqual(result.start, start)
  self.assertEqual(result.end, expected_end)

@parameterized.parameters(
(['A', 'C'], ['A', 'C'], NO_MISMATCH),
(['A', 'AC'], ['A', 'AC'], NO_MISMATCH),
Expand Down

0 comments on commit 01bad74

Please sign in to comment.