diff --git a/src/main/java/com/google/cloud/genomics/dataflow/readers/ReadStreamer.java b/src/main/java/com/google/cloud/genomics/dataflow/readers/ReadStreamer.java index c2de7c7..55579c8 100644 --- a/src/main/java/com/google/cloud/genomics/dataflow/readers/ReadStreamer.java +++ b/src/main/java/com/google/cloud/genomics/dataflow/readers/ReadStreamer.java @@ -15,9 +15,11 @@ import com.google.api.services.genomics.model.ReadGroupSet; import com.google.api.services.genomics.model.SearchReadGroupSetsRequest; +import com.google.cloud.dataflow.sdk.transforms.Aggregator; import com.google.cloud.dataflow.sdk.transforms.DoFn; import com.google.cloud.dataflow.sdk.transforms.PTransform; import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.Sum; import com.google.cloud.dataflow.sdk.values.PCollection; import com.google.cloud.genomics.grpc.Channels; import com.google.cloud.genomics.utils.Contig; @@ -40,6 +42,9 @@ */ public class ReadStreamer { + // TODO should be replaced with a better heuristic + private static final long SHARD_SIZE = 5000000L; + /** * Gets ReadGroupSetIds from a given datasetId using the Genomics API. */ @@ -60,8 +65,8 @@ public static List getReadGroupSetIds(String datasetId, GenomicsFactory. } /** - * Constructs a StreamReadsRequest for a readGroupSetId, assuming that the user wants all - * to include all references. + * Constructs a StreamReadsRequest for a readGroupSetId, assuming that the user wants to + * include all references. */ public static StreamReadsRequest getReadRequests(String readGroupSetId) { return StreamReadsRequest.newBuilder() @@ -77,11 +82,11 @@ public static List getReadRequests(final String readGroupSet final Iterable contigs = Contig.parseContigsFromCommandLine(references); return FluentIterable.from(contigs) .transformAndConcat(new Function>() { - @Override - public Iterable apply(Contig contig) { - return contig.getShards(100000000L); - } - }) + @Override + public Iterable apply(Contig contig) { + return contig.getShards(SHARD_SIZE); + } + }) .transform(new Function() { @Override public StreamReadsRequest apply(Contig shard) { @@ -111,20 +116,26 @@ public PCollection apply(PCollection input) { private static class RetrieveReads extends DoFn> { - private transient StreamingReadServiceGrpc.StreamingReadServiceBlockingStub readStub; + protected Aggregator initializedShardCount; + protected Aggregator finishedShardCount; @Override public void startBundle(Context c) throws IOException { - this.readStub = StreamingReadServiceGrpc.newBlockingStub(Channels.fromDefaultCreds()); + initializedShardCount = c.createAggregator("Initialized Shard Count", new Sum.SumIntegerFn()); + finishedShardCount = c.createAggregator("Finished Shard Count", new Sum.SumIntegerFn()); } @Override - public void processElement(ProcessContext c) { + public void processElement(ProcessContext c) throws IOException { + initializedShardCount.addValue(1); + StreamingReadServiceGrpc.StreamingReadServiceBlockingStub readStub + = StreamingReadServiceGrpc.newBlockingStub(Channels.fromDefaultCreds()); Iterator iter = readStub.streamReads(c.element()); while (iter.hasNext()) { StreamReadsResponse readResponse = iter.next(); c.output(readResponse.getAlignmentsList()); } + finishedShardCount.addValue(1); } } @@ -134,10 +145,18 @@ public void processElement(ProcessContext c) { */ private static class ConvergeReadsList extends DoFn, Read> { + protected Aggregator itemCount; + + @Override + public void startBundle(Context c) { + itemCount = c.createAggregator("Number of reads", new Sum.SumLongFn()); + } + @Override public void processElement(ProcessContext c) { for (Read r : c.element()) { c.output(r); + itemCount.addValue(1L); } } } diff --git a/src/main/java/com/google/cloud/genomics/dataflow/readers/VariantStreamer.java b/src/main/java/com/google/cloud/genomics/dataflow/readers/VariantStreamer.java index e510ece..41dbe24 100644 --- a/src/main/java/com/google/cloud/genomics/dataflow/readers/VariantStreamer.java +++ b/src/main/java/com/google/cloud/genomics/dataflow/readers/VariantStreamer.java @@ -15,9 +15,11 @@ import com.google.api.services.genomics.model.SearchVariantSetsRequest; import com.google.api.services.genomics.model.VariantSet; +import com.google.cloud.dataflow.sdk.transforms.Aggregator; import com.google.cloud.dataflow.sdk.transforms.DoFn; import com.google.cloud.dataflow.sdk.transforms.PTransform; import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.Sum; import com.google.cloud.dataflow.sdk.values.PCollection; import com.google.cloud.genomics.grpc.Channels; import com.google.cloud.genomics.utils.Contig; @@ -40,6 +42,9 @@ */ public class VariantStreamer { + // TODO should be replaced with a better heuristic + private static final long SHARD_SIZE = 5000000L; + /** * Gets VariantSetIds from a given datasetId using the Genomics API. */ @@ -79,7 +84,7 @@ public static List getVariantRequests(final String varian .transformAndConcat(new Function>() { @Override public Iterable apply(Contig contig) { - return contig.getShards(100000000L); + return contig.getShards(SHARD_SIZE); } }) .transform(new Function() { @@ -110,20 +115,26 @@ public PCollection apply(PCollection input) { private static class RetrieveVariants extends DoFn> { - private transient StreamingVariantServiceGrpc.StreamingVariantServiceBlockingStub variantStub; + protected Aggregator initializedShardCount; + protected Aggregator finishedShardCount; @Override public void startBundle(Context c) throws IOException { - this.variantStub = StreamingVariantServiceGrpc.newBlockingStub(Channels.fromDefaultCreds()); + initializedShardCount = c.createAggregator("Initialized Shard Count", new Sum.SumIntegerFn()); + finishedShardCount = c.createAggregator("Finished Shard Count", new Sum.SumIntegerFn()); } @Override - public void processElement(ProcessContext c) { + public void processElement(ProcessContext c) throws IOException { + initializedShardCount.addValue(1); + StreamingVariantServiceGrpc.StreamingVariantServiceBlockingStub variantStub = + StreamingVariantServiceGrpc.newBlockingStub(Channels.fromDefaultCreds()); Iterator iter = variantStub.streamVariants(c.element()); while (iter.hasNext()) { StreamVariantsResponse variantResponse = iter.next(); c.output(variantResponse.getVariantsList()); } + finishedShardCount.addValue(1); } } @@ -133,10 +144,18 @@ public void processElement(ProcessContext c) { */ private static class ConvergeVariantsList extends DoFn, Variant> { + protected Aggregator itemCount; + + @Override + public void startBundle(Context c) { + itemCount = c.createAggregator("Number of variants", new Sum.SumLongFn()); + } + @Override public void processElement(ProcessContext c) { for (Variant v : c.element()) { c.output(v); + itemCount.addValue(1L); } } }