Skip to content

Commit

Permalink
After we fix the bases before we pass them on to BuildAlleleMap in Vc…
Browse files Browse the repository at this point in the history
…fCandidateImporter, we now need to simplify alleles for all.

PiperOrigin-RevId: 306257932
  • Loading branch information
pichuan authored and Copybara-Service committed Apr 13, 2020
1 parent 846d1b4 commit 77fdc54
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 4 deletions.
1 change: 1 addition & 0 deletions deepvariant/postprocess_variants.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,7 @@ def merge_predictions(call_variants_outputs, qual_filter=None):
first_call, other_calls = call_variants_outputs[0], call_variants_outputs[1:]
canonical_variant = first_call.variant
if not other_calls:
canonical_variant = simplify_alleles(canonical_variant)
return canonical_variant, first_call.genotype_probabilities

alt_alleles_to_remove = get_alt_alleles_to_remove(call_variants_outputs,
Expand Down
58 changes: 54 additions & 4 deletions deepvariant/postprocess_variants_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def _create_variant_with_alleles(ref=None, alts=None, start=0):


def _create_call_variants_output(indices,
probabilities,
probabilities=None,
ref=None,
alts=None,
variant=None):
Expand Down Expand Up @@ -549,13 +549,29 @@ def test_convert_call_variants_outputs_to_probs_dict(self,
alts=['C', 'G', 'T']),
], [0, 0.001, 0, 0.0002, 0, 0, 0.0002, 0.0003, 0.9997, 0.00006]),
)
def test_merge_predictions(self, inputs, expected_unnormalized_probs):
def test_merge_predictions_probs(self, inputs, expected_unnormalized_probs):
  """Checks merged probabilities for every ordering of the inputs.

  Args:
    inputs: sequence of call-variants outputs handed to merge_predictions.
    expected_unnormalized_probs: expected probabilities before normalization;
      each value is divided by their sum before the comparison.
  """
  total = sum(expected_unnormalized_probs)
  for ordering in itertools.permutations(inputs):
    # Only the second element of the returned pair is checked here.
    _, merged_probs = postprocess_variants.merge_predictions(ordering)
    normalized_expectation = [
        prob / total for prob in expected_unnormalized_probs
    ]
    np.testing.assert_almost_equal(merged_probs, normalized_expectation)

@parameterized.parameters(
([
_create_call_variants_output(indices=[0], ref='CA', alts=['A']),
], ['CA', 'A']),
([
_create_call_variants_output(indices=[0], ref='CAA', alts=['AA']),
], ['CA', 'A']),
([
_create_call_variants_output(indices=[0], ref='GA', alts=['GAA']),
], ['G', 'GA']),
)
def test_merge_predictions_simplify_variant(self, inputs, expected_alleles):
  """Checks the alleles of the variant returned by merge_predictions.

  Args:
    inputs: list of call-variants outputs handed to merge_predictions.
    expected_alleles: expected alleles of the returned variant, with the
      reference bases first and the alternate bases after it.
  """
  merged_variant, _ = postprocess_variants.merge_predictions(inputs)
  expected_ref = expected_alleles[0]
  self.assertEqual(merged_variant.reference_bases, expected_ref)
  expected_alts = expected_alleles[1:]
  self.assertEqual(merged_variant.alternate_bases, expected_alts)

@parameterized.parameters(
# With 1 alt allele, we expect to see 1 alt_allele_indices: [0].
(
Expand Down Expand Up @@ -1000,6 +1016,39 @@ def test_prune_alleles(self, canonical_variant, alt_alleles_to_remove,
alt_alleles_to_remove),
expected_variant)

@parameterized.parameters(
dict(
alleles=['CAA', 'CA', 'C'],
start=5,
expected_alleles=['CAA', 'CA', 'C'],
expected_end=8),
dict(
alleles=['CAA', 'CA'],
start=4,
expected_alleles=['CA', 'C'],
expected_end=6),
dict(
alleles=['CAA', 'C'],
start=3,
expected_alleles=['CAA', 'C'],
expected_end=6),
dict(
alleles=['CCA', 'CA'],
start=2,
expected_alleles=['CC', 'C'],
expected_end=4),
)
def test_simplify_alleles(self, alleles, start, expected_alleles,
                          expected_end):
  """Checks the alleles and coordinates produced by simplify_alleles.

  Args:
    alleles: input alleles; the first is the reference, the rest are alts.
    start: start position used to build the input variant; the simplified
      variant is expected to keep the same start.
    expected_alleles: expected alleles after simplification, reference first.
    expected_end: expected end position of the simplified variant.
  """
  ref_bases = alleles[0]
  alt_bases = alleles[1:]
  original = _create_variant_with_alleles(
      ref=ref_bases, alts=alt_bases, start=start)
  result = postprocess_variants.simplify_alleles(original)

  # Alleles first, then coordinates.
  self.assertEqual(result.reference_bases, expected_alleles[0])
  self.assertEqual(result.alternate_bases, expected_alleles[1:])
  self.assertEqual(result.start, start)
  self.assertEqual(result.end, expected_end)

@parameterized.parameters(
# Check that we are simplifying alleles and that the simplification
# depends on the alleles we've removed.
Expand Down Expand Up @@ -1031,8 +1080,9 @@ def test_prune_alleles(self, canonical_variant, alt_alleles_to_remove,
expected_alleles=['CC', 'C'],
expected_end=4),
)
def test_simplify_alleles(self, alleles, start, alt_alleles_to_remove,
expected_alleles, expected_end):
def test_prune_and_simplify_alleles(self, alleles, start,
alt_alleles_to_remove, expected_alleles,
expected_end):
"""Test that prune_alleles + simplify_alleles works as expected."""
variant = _create_variant_with_alleles(
ref=alleles[0], alts=alleles[1:], start=start)
Expand Down

0 comments on commit 77fdc54

Please sign in to comment.