Skip to content

Commit

Permalink
m
Browse files Browse the repository at this point in the history
  • Loading branch information
jean-philippe-martin committed May 15, 2015
1 parent f759546 commit 0e9f871
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 77 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,15 @@ public final class BQSR_Dataflow implements Serializable {
public static final TupleTag<Read> readTag = new TupleTag<>();
public static final TupleTag<SimpleInterval> intervalTag = new TupleTag<>();
public static final TupleTag<RecalibrationTables> tablesTag = new TupleTag<>();
// public static final TupleTag<TimingLog> logTag = new TupleTag<>();
// public static final TupleTag<ArrayListOfTimingLog> logListTag = new TupleTag<>();

/**
* Get a single RecalibrationTables object that represents the output of phase 1 of BQSR.
* <p>
* The reference file (*.fasta) must also have a .dict and a .fasta.fai next to it.
* <p>
* The work is done in three steps: the reads and the intervals to ignore are grouped by block
* (groupByBlock), one RecalibrationTables is computed per bundle of blocks
* (computeBlockStatistics), and those partial tables are merged (aggregateStatistics).
*
* @param readsHeader       header of the input reads
* @param reads             the reads to compute statistics over
* @param referenceFileName name of the reference file
* @param toolArgs          arguments forwarded to the per-block statistics computation
* @param placesToIgnore    intervals that are grouped with the reads and passed along as
*                          intervals to skip
* @return a PCollection holding the aggregated RecalibrationTables
*/
public static PCollection<RecalibrationTables> GetRecalibrationTables(GenomicsFactory.OfflineAuth offlineAuth, SAMFileHeader readsHeader, PCollection<Read> reads, String referenceFileName, BaseRecalibrationArgumentCollection toolArgs, PCollection<SimpleInterval> placesToIgnore) {
PCollection<RecalibrationTables> stats = computeBlockStatistics(offlineAuth, readsHeader, referenceFileName, toolArgs, groupByBlock(reads, placesToIgnore));
public static PCollection<RecalibrationTables> GetRecalibrationTables(SAMFileHeader readsHeader, PCollection<Read> reads, String referenceFileName, BaseRecalibrationArgumentCollection toolArgs, PCollection<SimpleInterval> placesToIgnore) {
// Per-block statistics first; they are merged into one table just below.
PCollection<RecalibrationTables> stats = computeBlockStatistics(readsHeader, referenceFileName, toolArgs, groupByBlock(reads, placesToIgnore));
PCollection<RecalibrationTables> oneStat = aggregateStatistics(stats);
//PCollection<ArrayListOfTimingLog> logs = gatherLogs(statsAndLogs.get(logTag));
//PCollectionTuple ret = PCollectionTuple.of(tablesTag, tables).and(logListTag, logs);
return oneStat;
}

Expand Down Expand Up @@ -103,9 +99,9 @@ private static String posKey(Locatable loc) {

// Sharding key for a read: the key of the one-base interval located at the
// read's alignment start, so that reads can be grouped by position.
private static String posKey(Read r) {
    // Read positions are 0-based while SimpleInterval is 1-based, hence the +1.
    final int oneBasedStart = r.getAlignment().getPosition().getPosition().intValue() + 1;
    final String contig = r.getAlignment().getPosition().getReferenceName();
    final SimpleInterval startOnly = new SimpleInterval(contig, oneBasedStart, oneBasedStart);
    return posKey(startOnly);
}
Expand Down Expand Up @@ -165,88 +161,88 @@ public void processElement(ProcessContext c) {
* <p>
* Basically just delegates to the CalibrationTablesBuilder.
*/
static PCollection<RecalibrationTables> computeBlockStatistics(GenomicsFactory.OfflineAuth offlineAuth, SAMFileHeader readsHeader, String referenceFileName, BaseRecalibrationArgumentCollection toolArgs, PCollection<KV<String, CoGbkResult>> readsAndIgnores) {
static PCollection<RecalibrationTables> computeBlockStatistics(SAMFileHeader readsHeader, String referenceFileName, BaseRecalibrationArgumentCollection toolArgs, PCollection<KV<String, CoGbkResult>> readsAndIgnores) {
// readsHeader isn't serializable, so we pass it as a String instead.
StringWriter stringWriter = new StringWriter();
new SAMTextHeaderCodec().encode(stringWriter, readsHeader);
final String serializableHeader = stringWriter.toString();
PCollection<RecalibrationTables> ret = readsAndIgnores.apply(ParDo
.named("computeBlockStatistics")
.of(new DoFn<KV<String, CoGbkResult>, RecalibrationTables>() {
CalibrationTablesBuilder ct;
Stopwatch timer;
int nBlocks = 0;
int nReads = 0;
.named("computeBlockStatistics")
.of(new DoFn<KV<String, CoGbkResult>, RecalibrationTables>() {
CalibrationTablesBuilder ct;
Stopwatch timer;
int nBlocks = 0;
int nReads = 0;

@Override
public void startBundle(DoFn.Context c) throws Exception {
timer = Stopwatch.createStarted();
SAMFileHeader header = new SAMTextHeaderCodec().decode(new StringLineReader(serializableHeader), serializableHeader);
@Override
public void startBundle(DoFn.Context c) throws Exception {
timer = Stopwatch.createStarted();
SAMFileHeader header = new SAMTextHeaderCodec().decode(new StringLineReader(serializableHeader), serializableHeader);

String localReference = referenceFileName;
if (referenceFileName.startsWith("gs://")) {
// the reference is on GCS, download all 3 files locally first.
System.out.println("Downloading reference files");
boolean first = true;
for (String fname : getRelatedFiles(referenceFileName)) {
String localName = "reference";
int slash = fname.lastIndexOf('/');
if (slash >= 0) {
localName = fname.substring(slash + 1);
}
// download reference if necessary
if (new File(localName).exists()) {
if (first) localReference = localName;
} else {
try (
InputStream in = Channels.newInputStream(new GcsUtil.GcsUtilFactory().create(c.getPipelineOptions()).open(GcsPath.fromUri(fname)));
FileOutputStream fout = new FileOutputStream(localName)) {
final byte[] buf = new byte[1024*1024];
int count;
while ((count = in.read(buf)) > 0) {
fout.write(buf, 0, count);
String localReference = referenceFileName;
if (referenceFileName.startsWith("gs://")) {
// the reference is on GCS, download all 3 files locally first.
System.out.println("Downloading reference files");
boolean first = true;
for (String fname : getRelatedFiles(referenceFileName)) {
String localName = "reference";
int slash = fname.lastIndexOf('/');
if (slash >= 0) {
localName = fname.substring(slash + 1);
}
// download reference if necessary
if (new File(localName).exists()) {
if (first) localReference = localName;
} else {
try (
InputStream in = Channels.newInputStream(new GcsUtil.GcsUtilFactory().create(c.getPipelineOptions()).open(GcsPath.fromUri(fname)));
FileOutputStream fout = new FileOutputStream(localName)) {
final byte[] buf = new byte[1024 * 1024];
int count;
while ((count = in.read(buf)) > 0) {
fout.write(buf, 0, count);
}
}
if (first) localReference = localName;
}
first = false;
}
if (first) localReference = localName;
System.out.printf("Done downloading reference files (%s ms)\n", timer.elapsed(TimeUnit.MILLISECONDS));
}
first = false;
}
System.out.printf("Done downloading reference files (%s ms)\n", timer.elapsed(TimeUnit.MILLISECONDS));
}

ct = new CalibrationTablesBuilder(header, localReference, toolArgs);
}
ct = new CalibrationTablesBuilder(header, localReference, toolArgs);
}

@Override
public void processElement(ProcessContext c) throws Exception {
nBlocks++;
// get the reads
KV<String, CoGbkResult> e = c.element();
List<Read> reads = new ArrayList<Read>();
Iterable<Read> readsIter = e.getValue().getAll(BQSR_Dataflow.readTag);
Iterables.addAll(reads, readsIter);
nReads += reads.size();
// get the skip intervals
List<SimpleInterval> skipIntervals = new ArrayList<SimpleInterval>();
Iterables.addAll(skipIntervals, e.getValue().getAll(BQSR_Dataflow.intervalTag));
Collections.sort(skipIntervals, new Comparator<SimpleInterval>() {
@Override
public int compare(SimpleInterval o1, SimpleInterval o2) {
return ComparisonChain.start()
.compare(o1.getContig(), o2.getContig())
.compare(o1.getStart(), o2.getStart())
.result();
public void processElement(ProcessContext c) throws Exception {
nBlocks++;
// get the reads
KV<String, CoGbkResult> e = c.element();
List<Read> reads = new ArrayList<Read>();
Iterable<Read> readsIter = e.getValue().getAll(BQSR_Dataflow.readTag);
Iterables.addAll(reads, readsIter);
nReads += reads.size();
// get the skip intervals
List<SimpleInterval> skipIntervals = new ArrayList<SimpleInterval>();
Iterables.addAll(skipIntervals, e.getValue().getAll(BQSR_Dataflow.intervalTag));
Collections.sort(skipIntervals, new Comparator<SimpleInterval>() {
@Override
public int compare(SimpleInterval o1, SimpleInterval o2) {
return ComparisonChain.start()
.compare(o1.getContig(), o2.getContig())
.compare(o1.getStart(), o2.getStart())
.result();
}
});
// update our statistics
ct.add(reads, skipIntervals);
}
});
// update our statistics
ct.add(reads, skipIntervals);
}

@Override
public void finishBundle(DoFn.Context c) throws Exception {
ct.done();
c.output(ct.getRecalibrationTables());
System.out.println("Finishing a block statistics bundle. It took " + timer.elapsed(TimeUnit.MILLISECONDS) + " ms to process " + nBlocks + " blocks, " + nReads + " reads.");
@Override
public void finishBundle(DoFn.Context c) throws Exception {
ct.done();
c.output(ct.getRecalibrationTables());
System.out.println("Finishing a block statistics bundle. It took " + timer.elapsed(TimeUnit.MILLISECONDS) + " ms to process " + nBlocks + " blocks, " + nReads + " reads.");
timer = null;
}
}));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,6 @@ protected Object doWork() {

GenomicsOptions popts = makeDirectPipelineOptions(secretsFile);

GenomicsFactory.OfflineAuth offlineAuth = GenomicsOptions.Methods.getGenomicsAuth(popts);

Pipeline pipeline = Pipeline.create(popts);
CoderRegistry coderRegistry = pipeline.getCoderRegistry();
coderRegistry.registerCoder(Read.class, GenericJsonCoder.of(Read.class));
Expand All @@ -144,7 +142,7 @@ protected Object doWork() {

// 2. set up computation
PCollection<RecalibrationTables> aggregated =
BQSR_Dataflow.GetRecalibrationTables(offlineAuth, inputs.header, inputs.reads, inputs.referenceFileName, inputs.toolArgs, inputs.skipIntervals);
BQSR_Dataflow.GetRecalibrationTables(inputs.header, inputs.reads, inputs.referenceFileName, inputs.toolArgs, inputs.skipIntervals);

if (outputTablesPath.startsWith("gs://")) {
saveSingleResultToGCS(aggregated, outputTablesPath);
Expand Down

0 comments on commit 0e9f871

Please sign in to comment.