diff --git a/src/main/java/com/google/cloud/genomics/dataflow/functions/JoinNonVariantSegmentsWithVariants.java b/src/main/java/com/google/cloud/genomics/dataflow/functions/JoinNonVariantSegmentsWithVariants.java index 657d03d..279897b 100644 --- a/src/main/java/com/google/cloud/genomics/dataflow/functions/JoinNonVariantSegmentsWithVariants.java +++ b/src/main/java/com/google/cloud/genomics/dataflow/functions/JoinNonVariantSegmentsWithVariants.java @@ -24,23 +24,19 @@ import com.google.cloud.dataflow.sdk.values.PCollection; import com.google.cloud.genomics.utils.OfflineAuth; import com.google.cloud.genomics.utils.ShardBoundary; +import com.google.cloud.genomics.utils.grpc.MergeNonVariantSegmentsWithSnps; +import com.google.cloud.genomics.utils.grpc.VariantEmitterStrategy; +import com.google.cloud.genomics.utils.grpc.VariantMergeStrategy; import com.google.cloud.genomics.utils.grpc.VariantStreamIterator; import com.google.cloud.genomics.utils.grpc.VariantUtils; -import com.google.common.base.Function; import com.google.common.base.Preconditions; import com.google.common.collect.Iterables; -import com.google.common.collect.Lists; -import com.google.common.collect.Ordering; import com.google.genomics.v1.StreamVariantsRequest; import com.google.genomics.v1.StreamVariantsResponse; import com.google.genomics.v1.Variant; -import com.google.genomics.v1.Variant.Builder; import java.util.ArrayList; -import java.util.Collections; -import java.util.Comparator; import java.util.Iterator; -import java.util.LinkedList; import java.util.List; /** @@ -66,6 +62,11 @@ public static interface Options extends PipelineOptions { int getBinSize(); void setBinSize(int binSize); + @Description("The class that determines the strategy for merging non-variant segments and variants.") + @Default.Class(MergeNonVariantSegmentsWithSnps.class) + Class getVariantMergeStrategy(); + void setVariantMergeStrategy(Class mergeStrategy); + public static class Methods { public static void validateOptions(Options options) { Preconditions.checkArgument(0 < options.getBinSize(), "binSize must be greater than zero"); @@ -211,85 +212,32 @@ public void processElement(DoFn>.Proces } } - /** - * This DoFn converts data with non-variant segments (such as data that was in - * source format Genome VCF (gVCF) or Complete Genomics) to variant-only data with calls from - * non-variant-segments merged into the variants with which they overlap. - * - * This is currently done only for SNP variants. Indels and structural variants are left as-is. - */ public static final class CombineVariantsFn extends DoFn, Variant> { + private VariantMergeStrategy merger; - /** - * Dev note: this code aims to minimize the amount of data held in memory. It should only - * be the current variant we are considering and any non-variant segments that overlap it. - */ @Override - public void processElement(ProcessContext context) throws Exception { - List records = Lists.newArrayList(context.element()); - - // The sort order is critical here so that candidate overlapping reference matching blocks - // occur prior to any variants they may overlap. - Collections.sort(records, NON_VARIANT_SEGMENT_COMPARATOR); - - // The upper bound on potential overlaps is the sample size plus the number of - // block records that occur between actual variants. - List blockRecords = new LinkedList<>(); - - for (Variant record : records) { - if (!VariantUtils.IS_NON_VARIANT_SEGMENT.apply(record)) { - // Dataflow does not allow the output of modified input items, so we make a copy and - // modify that, if applicable. - Builder updatedRecord = Variant.newBuilder(record); - // TODO: determine and implement the correct criteria for overlaps of non-SNP variants - if (VariantUtils.IS_SNP.apply(record)) { - for (Iterator iterator = blockRecords.iterator(); iterator.hasNext();) { - Variant blockRecord = iterator.next(); - if (isOverlapping(blockRecord, record)) { - updatedRecord.addAllCalls(blockRecord.getCallsList()); - } else { - // Remove the current element from the iterator and the list since it is - // left of the genomic region we are currently working on due to our sort. - iterator.remove(); - } - } - } - // Emit this variant and move on (no need to hang onto it in memory). - context.output(updatedRecord.build()); - } else { - blockRecords.add(record); - } - } + public void startBundle(DoFn, Variant>.Context c) throws Exception { + super.startBundle(c); + Options options = c.getPipelineOptions().as(Options.class); + merger = options.getVariantMergeStrategy().newInstance(); } - static final Ordering BY_START = Ordering.natural().onResultOf( - new Function() { - @Override - public Long apply(Variant variant) { - return variant.getStart(); - } - }); + @Override + public void processElement(ProcessContext context) throws Exception { + merger.merge(context.element(), new DataflowVariantEmitter(context)); + } + } - static final Ordering BY_FIRST_OF_ALTERNATE_BASES = Ordering.natural() - .nullsFirst().onResultOf(new Function() { - @Override - public String apply(Variant variant) { - if (null == variant.getAlternateBasesList() || variant.getAlternateBasesList().isEmpty()) { - return null; - } - return variant.getAlternateBases(0); - } - }); + public static class DataflowVariantEmitter implements VariantEmitterStrategy { + private final DoFn, Variant>.ProcessContext context; - // Special-purpose comparator for use in dealing with both variant and non-variant segment data. - // Sort by start position ascending and ensure that if a variant and a ref-matching block are at - // the same position, the non-variant segment record comes first. - static final Comparator NON_VARIANT_SEGMENT_COMPARATOR = BY_START - .compound(BY_FIRST_OF_ALTERNATE_BASES); + public DataflowVariantEmitter(DoFn, Variant>.ProcessContext context) { + this.context = context; + } - static final boolean isOverlapping(Variant blockRecord, Variant variant) { - return blockRecord.getStart() <= variant.getStart() - && blockRecord.getEnd() >= variant.getStart() + 1; + @Override + public void emit(Variant variant) { + context.output(variant); } } } diff --git a/src/main/java/com/google/cloud/genomics/dataflow/functions/ibs/AlleleSimilarityCalculator.java b/src/main/java/com/google/cloud/genomics/dataflow/functions/ibs/AlleleSimilarityCalculator.java index c056825..b49691c 100644 --- a/src/main/java/com/google/cloud/genomics/dataflow/functions/ibs/AlleleSimilarityCalculator.java +++ b/src/main/java/com/google/cloud/genomics/dataflow/functions/ibs/AlleleSimilarityCalculator.java @@ -17,7 +17,7 @@ import com.google.cloud.dataflow.sdk.values.KV; import com.google.cloud.genomics.dataflow.utils.CallFilters; import com.google.cloud.genomics.dataflow.utils.PairGenerator; -import com.google.cloud.genomics.utils.grpc.VariantUtils; +import com.google.cloud.genomics.utils.grpc.VariantCallUtils; import com.google.common.collect.ImmutableList; import com.google.common.collect.Maps; import com.google.genomics.v1.Variant; @@ -56,7 +56,7 @@ public void processElement(ProcessContext context) { CallSimilarityCalculator callSimilarityCalculator = callSimilarityCalculatorFactory.get(isReferenceMajor(variant)); for (KV pair : PairGenerator.WITHOUT_REPLACEMENT.allPairs( - getSamplesWithVariant(variant), VariantUtils.CALL_COMPARATOR)) { + getSamplesWithVariant(variant), VariantCallUtils.CALL_COMPARATOR)) { accumulateCallSimilarity(callSimilarityCalculator, pair.getKey(), pair.getValue()); } } diff --git a/src/main/java/com/google/cloud/genomics/dataflow/utils/ShardOptions.java b/src/main/java/com/google/cloud/genomics/dataflow/utils/ShardOptions.java index 070516b..b520b40 100644 --- a/src/main/java/com/google/cloud/genomics/dataflow/utils/ShardOptions.java +++ b/src/main/java/com/google/cloud/genomics/dataflow/utils/ShardOptions.java @@ -24,11 +24,11 @@ public interface ShardOptions extends GenomicsOptions { @Description("By default, variants analyses will be run on BRCA1. Pass this flag to run on all " + "references present in the dataset. Note that certain jobs such as PCA and IBS " + "will automatically exclude X and Y chromosomes when this option is true.") + @Default.Boolean(false) boolean isAllReferences(); void setAllReferences(boolean allReferences); - @Description("Comma separated tuples of reference:start:end,... " - + "Defaults to '17:41196311:41277499' for BRCA1.") + @Description("Comma separated tuples of reference:start:end,... ") @Default.String("17:41196311:41277499") String getReferences(); void setReferences(String references); diff --git a/src/test/java/com/google/cloud/genomics/dataflow/functions/JoinNonVariantSegmentsWithVariantsTest.java b/src/test/java/com/google/cloud/genomics/dataflow/functions/JoinNonVariantSegmentsWithVariantsTest.java index d3bb391..f5d7233 100644 --- a/src/test/java/com/google/cloud/genomics/dataflow/functions/JoinNonVariantSegmentsWithVariantsTest.java +++ b/src/test/java/com/google/cloud/genomics/dataflow/functions/JoinNonVariantSegmentsWithVariantsTest.java @@ -28,6 +28,7 @@ import com.google.cloud.dataflow.sdk.transforms.SerializableFunction; import com.google.cloud.dataflow.sdk.values.KV; import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.genomics.utils.grpc.VariantUtils; import com.google.common.collect.Lists; import com.google.genomics.v1.Variant; import com.google.genomics.v1.VariantCall; @@ -105,7 +106,7 @@ public void setUp() { @Test public void testVariantVariantComparator() { - Comparator comparator = JoinNonVariantSegmentsWithVariants.CombineVariantsFn.NON_VARIANT_SEGMENT_COMPARATOR; + Comparator comparator = VariantUtils.NON_VARIANT_SEGMENT_COMPARATOR; assertEquals(-1, comparator.compare(blockRecord1, snp1)); assertEquals(1, comparator.compare(blockRecord2, snp1)); @@ -141,10 +142,10 @@ public void testVariantVariantComparator() { @Test public void testIsOverlapping() { - assertTrue(JoinNonVariantSegmentsWithVariants.CombineVariantsFn.isOverlapping(blockRecord1, snp1)); - assertTrue(JoinNonVariantSegmentsWithVariants.CombineVariantsFn.isOverlapping(blockRecord1, snp2)); - assertFalse(JoinNonVariantSegmentsWithVariants.CombineVariantsFn.isOverlapping(blockRecord2, snp1)); - assertTrue(JoinNonVariantSegmentsWithVariants.CombineVariantsFn.isOverlapping(blockRecord2, snp2)); + assertTrue(VariantUtils.isOverlapping(blockRecord1, snp1)); + assertTrue(VariantUtils.isOverlapping(blockRecord1, snp2)); + assertFalse(VariantUtils.isOverlapping(blockRecord2, snp1)); + assertTrue(VariantUtils.isOverlapping(blockRecord2, snp2)); } @Test