Skip to content
This repository was archived by the owner on Oct 29, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@

import com.google.api.services.genomics.model.ReadGroupSet;
import com.google.api.services.genomics.model.SearchReadGroupSetsRequest;
import com.google.cloud.dataflow.sdk.transforms.Aggregator;
import com.google.cloud.dataflow.sdk.transforms.DoFn;
import com.google.cloud.dataflow.sdk.transforms.PTransform;
import com.google.cloud.dataflow.sdk.transforms.ParDo;
import com.google.cloud.dataflow.sdk.transforms.Sum;
import com.google.cloud.dataflow.sdk.values.PCollection;
import com.google.cloud.genomics.grpc.Channels;
import com.google.cloud.genomics.utils.Contig;
Expand All @@ -40,6 +42,9 @@
*/
public class ReadStreamer {

// TODO should be replaced with a better heuristic
private static final long SHARD_SIZE = 5000000L;

/**
* Gets ReadGroupSetIds from a given datasetId using the Genomics API.
*/
Expand All @@ -60,8 +65,8 @@ public static List<String> getReadGroupSetIds(String datasetId, GenomicsFactory.
}

/**
* Constructs a StreamReadsRequest for a readGroupSetId, assuming that the user wants all
* to include all references.
* Constructs a StreamReadsRequest for a readGroupSetId, assuming that the user wants to
* include all references.
*/
public static StreamReadsRequest getReadRequests(String readGroupSetId) {
return StreamReadsRequest.newBuilder()
Expand All @@ -77,11 +82,11 @@ public static List<StreamReadsRequest> getReadRequests(final String readGroupSet
final Iterable<Contig> contigs = Contig.parseContigsFromCommandLine(references);
return FluentIterable.from(contigs)
.transformAndConcat(new Function<Contig, Iterable<Contig>>() {
@Override
public Iterable<Contig> apply(Contig contig) {
return contig.getShards(100000000L);
}
})
@Override
public Iterable<Contig> apply(Contig contig) {
return contig.getShards(SHARD_SIZE);
}
})
.transform(new Function<Contig, StreamReadsRequest>() {
@Override
public StreamReadsRequest apply(Contig shard) {
Expand Down Expand Up @@ -111,20 +116,26 @@ public PCollection<Read> apply(PCollection<StreamReadsRequest> input) {

private static class RetrieveReads extends DoFn<StreamReadsRequest, List<Read>> {

private transient StreamingReadServiceGrpc.StreamingReadServiceBlockingStub readStub;
protected Aggregator<Integer> initializedShardCount;
protected Aggregator<Integer> finishedShardCount;

@Override
public void startBundle(Context c) throws IOException {
this.readStub = StreamingReadServiceGrpc.newBlockingStub(Channels.fromDefaultCreds());
initializedShardCount = c.createAggregator("Initialized Shard Count", new Sum.SumIntegerFn());
finishedShardCount = c.createAggregator("Finished Shard Count", new Sum.SumIntegerFn());
}

@Override
public void processElement(ProcessContext c) {
public void processElement(ProcessContext c) throws IOException {
initializedShardCount.addValue(1);
StreamingReadServiceGrpc.StreamingReadServiceBlockingStub readStub
= StreamingReadServiceGrpc.newBlockingStub(Channels.fromDefaultCreds());
Iterator<StreamReadsResponse> iter = readStub.streamReads(c.element());
while (iter.hasNext()) {
StreamReadsResponse readResponse = iter.next();
c.output(readResponse.getAlignmentsList());
}
finishedShardCount.addValue(1);
}
}

Expand All @@ -134,10 +145,18 @@ public void processElement(ProcessContext c) {
*/
private static class ConvergeReadsList extends DoFn<List<Read>, Read> {

protected Aggregator<Long> itemCount;

@Override
public void startBundle(Context c) {
itemCount = c.createAggregator("Number of reads", new Sum.SumLongFn());
}

@Override
public void processElement(ProcessContext c) {
for (Read r : c.element()) {
c.output(r);
itemCount.addValue(1L);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@

import com.google.api.services.genomics.model.SearchVariantSetsRequest;
import com.google.api.services.genomics.model.VariantSet;
import com.google.cloud.dataflow.sdk.transforms.Aggregator;
import com.google.cloud.dataflow.sdk.transforms.DoFn;
import com.google.cloud.dataflow.sdk.transforms.PTransform;
import com.google.cloud.dataflow.sdk.transforms.ParDo;
import com.google.cloud.dataflow.sdk.transforms.Sum;
import com.google.cloud.dataflow.sdk.values.PCollection;
import com.google.cloud.genomics.grpc.Channels;
import com.google.cloud.genomics.utils.Contig;
Expand All @@ -40,6 +42,9 @@
*/
public class VariantStreamer {

// TODO should be replaced with a better heuristic
private static final long SHARD_SIZE = 5000000L;

/**
* Gets VariantSetIds from a given datasetId using the Genomics API.
*/
Expand Down Expand Up @@ -79,7 +84,7 @@ public static List<StreamVariantsRequest> getVariantRequests(final String varian
.transformAndConcat(new Function<Contig, Iterable<Contig>>() {
@Override
public Iterable<Contig> apply(Contig contig) {
return contig.getShards(100000000L);
return contig.getShards(SHARD_SIZE);
}
})
.transform(new Function<Contig, StreamVariantsRequest>() {
Expand Down Expand Up @@ -110,20 +115,26 @@ public PCollection<Variant> apply(PCollection<StreamVariantsRequest> input) {

private static class RetrieveVariants extends DoFn<StreamVariantsRequest, List<Variant>> {

private transient StreamingVariantServiceGrpc.StreamingVariantServiceBlockingStub variantStub;
protected Aggregator<Integer> initializedShardCount;
protected Aggregator<Integer> finishedShardCount;

@Override
public void startBundle(Context c) throws IOException {
this.variantStub = StreamingVariantServiceGrpc.newBlockingStub(Channels.fromDefaultCreds());
initializedShardCount = c.createAggregator("Initialized Shard Count", new Sum.SumIntegerFn());
finishedShardCount = c.createAggregator("Finished Shard Count", new Sum.SumIntegerFn());
}

@Override
public void processElement(ProcessContext c) {
public void processElement(ProcessContext c) throws IOException {
initializedShardCount.addValue(1);
StreamingVariantServiceGrpc.StreamingVariantServiceBlockingStub variantStub =
StreamingVariantServiceGrpc.newBlockingStub(Channels.fromDefaultCreds());
Iterator<StreamVariantsResponse> iter = variantStub.streamVariants(c.element());
while (iter.hasNext()) {
StreamVariantsResponse variantResponse = iter.next();
c.output(variantResponse.getVariantsList());
}
finishedShardCount.addValue(1);
}
}

Expand All @@ -133,10 +144,18 @@ public void processElement(ProcessContext c) {
*/
private static class ConvergeVariantsList extends DoFn<List<Variant>, Variant> {

protected Aggregator<Long> itemCount;

@Override
public void startBundle(Context c) {
itemCount = c.createAggregator("Number of variants", new Sum.SumLongFn());
}

@Override
public void processElement(ProcessContext c) {
for (Variant v : c.element()) {
c.output(v);
itemCount.addValue(1L);
}
}
}
Expand Down