diff --git a/pom.xml b/pom.xml index 9e1c59c..cf93de8 100644 --- a/pom.xml +++ b/pom.xml @@ -204,6 +204,12 @@ protobuf-java 3.0.0-alpha-3 + + org.apache.commons + commons-math3 + 3.2 + jar + diff --git a/src/main/java/com/google/cloud/genomics/dataflow/functions/LikelihoodFn.java b/src/main/java/com/google/cloud/genomics/dataflow/functions/LikelihoodFn.java new file mode 100644 index 0000000..32fc76c --- /dev/null +++ b/src/main/java/com/google/cloud/genomics/dataflow/functions/LikelihoodFn.java @@ -0,0 +1,175 @@ +/* + * Copyright 2015 Google. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.cloud.genomics.dataflow.functions; + +import com.google.api.services.genomics.model.Position; +import com.google.cloud.genomics.dataflow.model.ReadCounts; +import com.google.cloud.genomics.dataflow.model.ReadQualityCount; +import com.google.cloud.genomics.dataflow.model.ReadQualityCount.Base; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +import org.apache.commons.math3.analysis.UnivariateFunction; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Iterator; +import java.util.Map; + +/** + * Implementation of the likelihood function in equation (2) in + * G. Jun, M. Flickinger, K. N. Hetrick, Kurt, J. M. Romm, K. F. Doheny, + * G. Abecasis, M. Boehnke,and H. M. Kang, Detecting and Estimating + * Contamination of Human DNA Samples in Sequencing and Array-Based Genotype + * Data, American journal of human genetics doi:10.1016/j.ajhg.2012.09.004 + * (volume 91 issue 5 pp.839 - 848) + * http://www.sciencedirect.com/science/article/pii/S0002929712004788 + */ +public class LikelihoodFn implements UnivariateFunction { + + /** Possible genotypes for a SNP with a single alternate */ + enum Genotype { + REF_HOMOZYGOUS, HETEROZYGOUS, NONREF_HOMOZYGOUS + } + /** Possible error statuses for a base in a read */ + enum ReadStatus { + CORRECT, ERROR + } + + static int toTableIndex(Base observed, Genotype trueGenotype, ReadStatus status) { + return observed.ordinal() + + Base.values().length * (status.ordinal() + + ReadStatus.values().length * trueGenotype.ordinal()); + } + + /* + * P_OBS_GIVEN_TRUTH contains the probability of observing a particular + * base (reference, non-reference, or other) given the true genotype + * and the error status of the read. See Table 1 of Jun et al. 
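+ *
+ * For example, with a HETEROZYGOUS true genotype an error-free read shows
+ * REF or NONREF with probability 0.5 each, while an erroneous read lands
+ * uniformly on one of the three bases that were not sequenced, which
+ * yields the 1/6, 1/6, 2/3 row of the table.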
+ */ + private static final ImmutableList P_OBS_GIVEN_TRUTH; + static { + final ImmutableList pTable = ImmutableList.of( + // Observed base + // REF NONREF OTHER + 1.0, 0.0, 0.0, // P(base | REF_HOMOZYGOUS, CORRECT) + 0.0, 1.0 / 3.0, 2.0 / 3.0, // P(base | REF_HOMOZYGOUS, ERROR) + 0.5, 0.5, 0.0, // P(base | HETEROZYGOUS, CORRECT) + 1.0 / 6.0, 1.0 / 6.0, 2.0 / 3.0, // P(base | HETEROZYGOUS, ERROR) + 0.0, 1.0, 0.0, // P(base | NONREF_HOMOZYGOUS, CORRECT) + 1.0 / 3.0, 0.0, 2.0 / 3.0); // P(base | NONREF_HOMOZYGOUS, ERROR) + Iterator itProb = pTable.iterator(); + ArrayList pCond = new ArrayList<>(); + pCond.addAll(Collections.nCopies( + Base.values().length * ReadStatus.values().length * Genotype.values().length, 0.0)); + for ( + Genotype g : ImmutableList.of(Genotype.REF_HOMOZYGOUS, Genotype.HETEROZYGOUS, + Genotype.NONREF_HOMOZYGOUS)) { + for (ReadStatus r : ImmutableList.of(ReadStatus.CORRECT, ReadStatus.ERROR)) { + for (Base b : ImmutableList.of(Base.REF, Base.NONREF, Base.OTHER)) { + pCond.set(toTableIndex(b, g, r), itProb.next()); + } + } + } + P_OBS_GIVEN_TRUTH = ImmutableList.copyOf(pCond); + } + + private final Map readCounts; + + /** + * Create a new LikelihoodFn instance for a given set of read counts. + * + * @param readCounts counts of reads by quality for each position of interest + */ + public LikelihoodFn(Map readCounts) { + // copy the map so the counts don't get changed out from under us + this.readCounts = ImmutableMap.copyOf(readCounts); + } + + /** + * Compute the probability of a genotype given the reference allele probability. + */ + private static double pGenotype(Genotype g, double refProb) { + switch(g) { + case REF_HOMOZYGOUS: + return refProb * refProb; + case HETEROZYGOUS: + return refProb * (1.0 - refProb); + case NONREF_HOMOZYGOUS: + return (1.0 - refProb) * (1.0 - refProb); + default: + throw new IllegalArgumentException("Illegal genotype"); + } + } + + /** + * Look up the probability of an observation conditioned on the underlying state. + */ + private static double probObsGivenTruth(Base observed, Genotype trueGenotype, + ReadStatus trueStatus) { + return P_OBS_GIVEN_TRUTH.get(toTableIndex(observed, trueGenotype, trueStatus)); + } + + /** + * Compute the likelihood of a contaminant fraction alpha. + * + *
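+ * The returned value is the log-likelihood: per-position likelihoods are
+ * combined by summing their logs, which avoids floating-point underflow
+ * when many positions contribute.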
+ * <p>
See equation (2) in Jun et al. + */ + @Override + public double value(double alpha) { + double logLikelihood = 0.0; + for (ReadCounts rc : readCounts.values()) { + double refProb = rc.getRefFreq(); + + double pPosition = 0.0; + for (Genotype trueGenotype1 : Genotype.values()) { + double pGenotype1 = pGenotype(trueGenotype1, refProb); + for (Genotype trueGenotype2 : Genotype.values()) { + double pGenotype2 = pGenotype(trueGenotype2, refProb); + + double pObsGivenGenotype = 1.0; + + for (ReadQualityCount rqc : rc.getReadQualityCounts()) { + Base base = rqc.getBase(); + double pErr = phredToProb(rqc.getQuality()); + double pObs + = ((1.0 - alpha) + * probObsGivenTruth(base, trueGenotype1, ReadStatus.CORRECT) + + (alpha) + * probObsGivenTruth(base, trueGenotype2, ReadStatus.CORRECT) + ) * (1.0 - pErr) + + ((1.0 - alpha) + * probObsGivenTruth(base, trueGenotype1, ReadStatus.ERROR) + + (alpha) + * probObsGivenTruth(base, trueGenotype2, ReadStatus.ERROR) + ) * pErr; + pObsGivenGenotype *= Math.pow(pObs, rqc.getCount()); + } + pPosition += pObsGivenGenotype * pGenotype1 * pGenotype2; + } + } + logLikelihood += Math.log(pPosition); + } + return logLikelihood; + } + + /** + * Convert a Phred score to a probability. + */ + private static double phredToProb(int phred) { + return Math.pow(10.0, -(double) phred / 10.0); + } +} diff --git a/src/main/java/com/google/cloud/genomics/dataflow/model/AlleleFreq.java b/src/main/java/com/google/cloud/genomics/dataflow/model/AlleleFreq.java new file mode 100644 index 0000000..da99231 --- /dev/null +++ b/src/main/java/com/google/cloud/genomics/dataflow/model/AlleleFreq.java @@ -0,0 +1,61 @@ +/* + * Copyright 2015 Google. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.cloud.genomics.dataflow.model; + +import com.google.api.client.json.GenericJson; +import com.google.cloud.dataflow.sdk.coders.DefaultCoder; +import com.google.cloud.genomics.dataflow.coders.GenericJsonCoder; + +import java.util.List; + +/** + * Contains frequency for a set of alleles for a single position on a single chromosome. + * Used in VerifyBamId. + */ +@DefaultCoder(GenericJsonCoder.class) +public class AlleleFreq extends GenericJson { + // Strings of length 1 of one of the following bases: ['A', 'C', 'T', 'G']. + private String refBases; + // List of length 1 of a String of length 1 of one of the following bases: ['A', 'C', 'T', 'G']. + private List altBases; + // Frequency for a set of alleles for the given position on the given chromosome + // in the range [0,1]. 
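+ // Only the reference-allele frequency is stored; since these SNPs have a
+ // single alternate, the alternate frequency is 1 - refFreq.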
+ private double refFreq; + + public String getRefBases() { + return refBases; + } + + public void setRefBases(String refBases) { + this.refBases = refBases; + } + + public List getAltBases() { + return altBases; + } + + public void setAltBases(List altBases) { + this.altBases = altBases; + } + + public double getRefFreq() { + return refFreq; + } + + public void setRefFreq(double refFreq) { + this.refFreq = refFreq; + } +} diff --git a/src/main/java/com/google/cloud/genomics/dataflow/model/ReadCounts.java b/src/main/java/com/google/cloud/genomics/dataflow/model/ReadCounts.java new file mode 100644 index 0000000..b591984 --- /dev/null +++ b/src/main/java/com/google/cloud/genomics/dataflow/model/ReadCounts.java @@ -0,0 +1,69 @@ +/* + * Copyright 2015 Google. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.cloud.genomics.dataflow.model; + +import com.google.api.client.json.GenericJson; +import com.google.api.client.util.Lists; +import com.google.cloud.dataflow.sdk.coders.DefaultCoder; +import com.google.cloud.genomics.dataflow.coders.GenericJsonCoder; +import com.google.cloud.genomics.dataflow.model.ReadQualityCount.Base; + +import java.util.List; + +/** + * Counts of reads for a single SNP with a single alternate value for use in the + * VerifyBamId pipeline. For each SNP, we accumulate counts of bases and quality scores + * for associated aligned reads. + */ +@DefaultCoder(GenericJsonCoder.class) +public class ReadCounts extends GenericJson { + /** + * The count for a single base and quality score is stored in a ReadQualityCount object. + */ + private List readQualityCounts = Lists.newArrayList(); + /** + * refFreq contains the population frequency of the reference allele. + */ + private double refFreq; + + public List getReadQualityCounts() { + return readQualityCounts; + } + + public void setReadQualityCounts(List readQualityCounts) { + this.readQualityCounts = readQualityCounts; + } + + public void addReadQualityCount(Base base, int quality, long count) { + ReadQualityCount rqc = new ReadQualityCount(); + rqc.setBase(base); + rqc.setCount(count); + rqc.setQuality(quality); + this.readQualityCounts.add(rqc); + } + + public void addReadQualityCount(ReadQualityCount rqc) { + this.readQualityCounts.add(rqc); + } + + public double getRefFreq() { + return refFreq; + } + + public void setRefFreq(double refFreq) { + this.refFreq = refFreq; + } +} diff --git a/src/main/java/com/google/cloud/genomics/dataflow/model/ReadQualityCount.java b/src/main/java/com/google/cloud/genomics/dataflow/model/ReadQualityCount.java new file mode 100644 index 0000000..7cf0669 --- /dev/null +++ b/src/main/java/com/google/cloud/genomics/dataflow/model/ReadQualityCount.java @@ -0,0 +1,63 @@ +/* + * Copyright 2015 Google. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.cloud.genomics.dataflow.model; + +/** + * This class is used to count the number of reads aligned to a SNP that show the reference base, + * the non-reference base, some other base, or an unknown base. Within each category, we count the + * number with each quality score. + * + * For example, we might have 2 reads that show the reference base with quality 10, 5 reads that + * show the non-reference base with quality 60, and 1 read that shows a different nucleotide with + * quality 0. + */ +public class ReadQualityCount { + + private Base base; + private int quality; + private long count; + + /** + * Which type of Base this ReadQualityCount represents. + */ + public enum Base { + UNKNOWN, REF, NONREF, OTHER + }; + + public Base getBase() { + return base; + } + + public void setBase(Base base) { + this.base = base; + } + + public int getQuality() { + return quality; + } + + public void setQuality(int quality) { + this.quality = quality; + } + + public long getCount() { + return count; + } + + public void setCount(long count) { + this.count = count; + } +} diff --git a/src/main/java/com/google/cloud/genomics/dataflow/pipelines/VerifyBamId.java b/src/main/java/com/google/cloud/genomics/dataflow/pipelines/VerifyBamId.java new file mode 100644 index 0000000..ef861bb --- /dev/null +++ b/src/main/java/com/google/cloud/genomics/dataflow/pipelines/VerifyBamId.java @@ -0,0 +1,616 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.genomics.dataflow.pipelines; + +import com.google.api.services.genomics.Genomics; +import com.google.api.services.genomics.model.Position; +import com.google.api.services.genomics.model.ReadGroupSet; +import com.google.api.services.genomics.model.Reference; +import com.google.api.services.genomics.model.SearchReferencesRequest; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.SerializableCoder; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.options.Default; +import com.google.cloud.dataflow.sdk.options.Description; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.Filter; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.SerializableFunction; +import com.google.cloud.dataflow.sdk.transforms.View; +import com.google.cloud.dataflow.sdk.transforms.join.CoGbkResult; +import com.google.cloud.dataflow.sdk.transforms.join.CoGroupByKey; +import com.google.cloud.dataflow.sdk.transforms.join.KeyedPCollectionTuple; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.cloud.genomics.dataflow.coders.GenericJsonCoder; +import com.google.cloud.genomics.dataflow.functions.LikelihoodFn; +import com.google.cloud.genomics.dataflow.model.AlleleFreq; +import com.google.cloud.genomics.dataflow.model.ReadBaseQuality; +import com.google.cloud.genomics.dataflow.model.ReadBaseWithReference; +import com.google.cloud.genomics.dataflow.model.ReadCounts; +import com.google.cloud.genomics.dataflow.model.ReadQualityCount; +import com.google.cloud.genomics.dataflow.readers.ReadStreamer; +import com.google.cloud.genomics.dataflow.readers.VariantStreamer; +import com.google.cloud.genomics.dataflow.utils.DataflowWorkarounds; +import com.google.cloud.genomics.dataflow.utils.GCSOptions; +import com.google.cloud.genomics.dataflow.utils.GenomicsDatasetOptions; +import com.google.cloud.genomics.dataflow.utils.GenomicsOptions; +import com.google.cloud.genomics.dataflow.utils.ReadUtils; +import com.google.cloud.genomics.dataflow.utils.Solver; +import com.google.cloud.genomics.dataflow.utils.VariantUtils; +import com.google.cloud.genomics.utils.Contig; +import com.google.cloud.genomics.utils.GenomicsFactory; +import com.google.common.base.Function; +import com.google.common.collect.FluentIterable; +import com.google.common.collect.ImmutableMultiset; +import com.google.common.collect.Lists; +import com.google.common.collect.Multiset; +import com.google.genomics.v1.Read; +import com.google.genomics.v1.StreamReadsRequest; +import com.google.genomics.v1.StreamVariantsRequest; +import com.google.genomics.v1.Variant; +import com.google.protobuf.ListValue; + +import java.io.IOException; +import java.io.UnsupportedEncodingException; +import java.nio.ByteBuffer; +import java.security.GeneralSecurityException; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +/** + * Test a set of reads for contamination. 
+ * + * Takes a set of specified ReadGroupSets of reads to test and statistics on reference allele + * frequencies for SNPs with a single alternative from a specified set of VariantSets. + * + * Uses the sequence data alone approach described in: + * G. Jun, M. Flickinger, K. N. Hetrick, Kurt, J. M. Romm, K. F. Doheny, + * G. Abecasis, M. Boehnke,and H. M. Kang, Detecting and Estimating + * Contamination of Human DNA Samples in Sequencing and Array-Based Genotype + * Data, American journal of human genetics doi:10.1016/j.ajhg.2012.09.004 + * (volume 91 issue 5 pp.839 - 848) + * http://www.sciencedirect.com/science/article/pii/S0002929712004788 + */ +public class VerifyBamId { + + private static VerifyBamId.VerifyBamIdOptions options; + private static Pipeline p; + private static GenomicsFactory.OfflineAuth auth; + + /** + * Constant that represents the size that user given references will be parsed into for each + * individual request. + */ + private static final long SHARD_SIZE = 10000000L; + + /** + * String prefix used for sampling hash function + */ + private static final String HASH_PREFIX = ""; + + /** + * Options required to run this pipeline. + */ + public static interface VerifyBamIdOptions extends GenomicsDatasetOptions, GCSOptions { + + @Description("A comma delimited list of the IDs of the Google Genomics ReadGroupSets this " + + "pipeline is working with. Default (empty) indicates all ReadGroupSets in InputDatasetId." + + " This(and variantSetIds) or InputDatasetId must be set. InputDatasetId overrides " + + "ReadGroupSetIds (if InputDatasetId is set, this field will be ignored).") + @Default.String("") + String getReadGroupSetIds(); + + void setReadGroupSetIds(String readGroupSetId); + + @Description("A comma delimited list of the IDs of the Google Genomics VariantSets this " + + "pipeline is working with. Default (empty) indicates all VariantSets in InputDatasetId." + + " This(and readGroupSetIds) or InputDatasetId must be set. InputDatasetId overrides " + + "VariantSetIds (if InputDatasetId is set, this field will be ignored).") + @Default.String("") + String getVariantSetIds(); + + void setVariantSetIds(String variantSetId); + + @Description("The ID of the Google Genomics Dataset that the pipeline will get its input reads" + + " from. Default (empty) means to use ReadGroupSetIds and VariantSetIds instead. This or" + + " ReadGroupSetIds and VariantSetIds must be set. InputDatasetId overrides" + + " ReadGroupSetIds and VariantSetIds (if this field is set, ReadGroupSetIds and" + + " VariantSetIds will be ignored).") + @Default.String("") + String getInputDatasetId(); + + void setInputDatasetId(String inputDatasetId); + + @Description("The minimum allele frequency to use in analysis. Defaults to 0.01.") + @Default.Double(0.01) + double getMinFrequency(); + + void setMinFrequency(double minFrequency); + + @Description("The fraction of positions to check. Defaults to 0.01.") + @Default.Double(0.01) + double getSamplingFraction(); + + void setSamplingFraction(double minFrequency); + } + + /** + * Run the VerifyBamId algorithm and output the resulting contamination estimate. + */ + public static void main(String[] args) throws GeneralSecurityException, IOException { + // Register the options so that they show up via --help + PipelineOptionsFactory.register(VerifyBamIdOptions.class); + options = PipelineOptionsFactory.fromArgs(args) + .withValidation() + .as(VerifyBamId.VerifyBamIdOptions.class); + // Option validation is not yet automatic, we make an explicit call here. 
+ GenomicsDatasetOptions.Methods.validateOptions(options); + auth = GenomicsOptions.Methods.getGenomicsAuth(options); + + p = Pipeline.create(options); + DataflowWorkarounds.registerGenomicsCoders(p); + DataflowWorkarounds.registerCoder(p, Read.class, SerializableCoder.of(Read.class)); + DataflowWorkarounds.registerCoder(p, Variant.class, SerializableCoder.of(Variant.class)); + DataflowWorkarounds.registerCoder(p, ReadBaseQuality.class, + GenericJsonCoder.of(ReadBaseQuality.class)); + DataflowWorkarounds.registerCoder(p, AlleleFreq.class, GenericJsonCoder.of(AlleleFreq.class)); + DataflowWorkarounds.registerCoder(p, ReadCounts.class, GenericJsonCoder.of(ReadCounts.class)); + + if (options.getInputDatasetId().isEmpty() + && (options.getReadGroupSetIds().isEmpty() || options.getVariantSetIds().isEmpty())) { + throw new IllegalArgumentException("InputDatasetId or ReadGroupSetIds and VariantSetIds must" + + " be specified"); + } + + List rgsIds; + List vsIds; + if (options.getInputDatasetId().isEmpty()) { + rgsIds = Lists.newArrayList(options.getReadGroupSetIds().split(",")); + vsIds = Lists.newArrayList(options.getVariantSetIds().split(",")); + } else { + rgsIds = ReadStreamer.getReadGroupSetIds(options.getInputDatasetId(), auth); + vsIds = VariantStreamer.getVariantSetIds(options.getInputDatasetId(), auth); + } + + List contigs; + String referenceSetId = checkReferenceSetIds(rgsIds); + if (options.isAllReferences()) { + contigs = getAllReferences(referenceSetId); + } else { + contigs = parseReferences(options.getReferences(), referenceSetId); + } + + /* + TODO: We can reduce the number of requests needed to be created by doing the following: + 1. Stream the Variants first (rather than concurrently with the Reads). Select a subset of + them equal to some threshold (say 50K by default). + 2. Create the requests for streaming Reads by running a ParDo over the selected Variants + to get their ranges (we only need to stream Reads that overlap the selected Variants). + 3. Stream the Reads from the created requests. + */ + + // Reads in Reads. + PCollection reads = getReadsFromAPI(rgsIds); + + // Reads in Variants. TODO potentially provide an option to load the Variants from a file. + PCollection variants = getVariantsFromAPI(vsIds); + + PCollection> refFreq = getFreq(variants, options.getMinFrequency()); + + PCollection> readCountsTable = + combineReads(reads, options.getSamplingFraction(), HASH_PREFIX, refFreq); + + // Converts our results to a single Map of Position keys to ReadCounts values. + PCollectionView> view = readCountsTable + .apply(View.asMap().withSingletonValues()); + + // Calculates the contamination estimate based on the resulting Map above. + PCollection result = p.begin().apply(Create.of("")) + .apply(ParDo.of(new Maximizer(view)).withSideInputs(view)); + + // Writes the result to the given output location in Cloud Storage. + result.apply(TextIO.Write.to(options.getOutput()).named("WriteOutput").withoutSharding()); + + p.run(); + + } + + /** + * Compute a PCollection of reference allele frequencies for SNPs of interest. + * The SNPs all have only a single alternate allele, and neither the + * reference nor the alternate allele have a population frequency < minFreq. + * The results are returned in a PCollection indexed by Position. 
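+ * For example, with minFreq = 0.01 a SNP whose reference-allele frequency
+ * is 0.995 is dropped, because its alternate frequency (0.005) falls below
+ * the threshold.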
+ * + * @param variants a set of variant calls for a reference population + * @param minFreq the minimum allele frequency for the set + * @return a PCollection mapping Position to AlleleCounts + */ + static PCollection> getFreq( + PCollection variants, double minFreq) { + return variants.apply(Filter.by(VariantUtils.IS_PASSING).named("PassingFilter")) + .apply(Filter.by(VariantUtils.IS_ON_CHROMOSOME).named("OnChromosomeFilter")) + .apply(Filter.by(VariantUtils.IS_NOT_LOW_QUALITY).named("NotLowQualityFilter")) + .apply(Filter.by(VariantUtils.IS_SINGLE_ALTERNATE_SNP).named("SNPFilter")) + .apply(ParDo.of(new GetAlleleFreq())) + .apply(Filter.by(new FilterFreq(minFreq))); + } + + /** + * Filter, pile up, and sample reads, then join against reference statistics. + * + * @param reads A PCollection of reads + * @param samplingFraction Fraction of reads to keep + * @param samplingPrefix A prefix used in generating hashes used in sampling + * @param refCounts A PCollection mapping position to counts of alleles in + * a reference population. + * @return A PCollection mapping Position to a ReadCounts proto + */ + static PCollection> combineReads(PCollection reads, + double samplingFraction, String samplingPrefix, + PCollection> refFreq) { + // Runs filters on input Reads, splits into individual aligned bases (emitting the + // base and quality) and grabs a sample of them based on a hash mod of Position. + PCollection> joinReadCounts = + reads.apply(Filter.by(ReadUtils.IS_ON_CHROMOSOME).named("IsOnChromosome")) + .apply(Filter.by(ReadUtils.IS_NOT_QC_FAILURE).named("IsNotQCFailure")) + .apply(Filter.by(ReadUtils.IS_NOT_DUPLICATE).named("IsNotDuplicate")) + .apply(Filter.by(ReadUtils.IS_PROPER_PLACEMENT).named("IsProperPlacement")) + .apply(ParDo.of(new SplitReads())) + .apply(Filter.by(new SampleReads(samplingFraction, samplingPrefix))); + + TupleTag readCountsTag = new TupleTag<>(); + TupleTag refFreqTag = new TupleTag<>(); + // Pile up read counts, then join against reference stats. + PCollection> joined = KeyedPCollectionTuple + .of(readCountsTag, joinReadCounts) + .and(refFreqTag, refFreq) + .apply(CoGroupByKey.create()); + return joined.apply(ParDo.of(new PileupAndJoinReads(readCountsTag, refFreqTag))); + } + + /** + * Split reads into individual aligned bases and emit base + quality. + */ + static class SplitReads extends DoFn> { + + @Override + public void processElement(ProcessContext c) throws Exception { + List readBases = ReadUtils.extractReadBases(c.element()); + if (!readBases.isEmpty()) { + for (ReadBaseWithReference rb : readBases) { + c.output(KV.of(rb.getRefPosition(), rb.getRbq())); + } + } + } + } + + /** + * Sample bases via a hash mod of position. 
+ */ + static class SampleReads implements SerializableFunction, Boolean> { + + private final double samplingFraction; + private final String samplingPrefix; + + public SampleReads(double samplingFraction, String samplingPrefix) { + this.samplingFraction = samplingFraction; + this.samplingPrefix = samplingPrefix; + } + + @Override + public Boolean apply(KV input) { + if (samplingFraction == 1.0) { + return true; + } else { + byte[] msg; + Position p = input.getKey(); + try { + msg = (samplingPrefix + p.getReferenceName() + ":" + p.getPosition() + ":" + + p.getReverseStrand()).getBytes("UTF-8"); + } catch (UnsupportedEncodingException e) { + throw new AssertionError("UTF-8 not available - should not happen"); + } + MessageDigest md; + try { + md = MessageDigest.getInstance("MD5"); + } catch (NoSuchAlgorithmException e) { + throw new AssertionError("MD5 not available - should not happen"); + } + byte[] digest = md.digest(msg); + if (digest.length != 16) { + throw new AssertionError("MD5 should return 128 bits"); + } + ByteBuffer buffer = ByteBuffer.allocate(Long.SIZE); + buffer.put(Arrays.copyOf(digest, Long.SIZE)); + return ((((double) buffer.getLong(0) / (double) ((long) 1 << 63)) + 1.0) * 0.5) + < samplingFraction; + } + } + } + + /** + * Map a variant to a Position, AlleleFreq pair. + */ + static class GetAlleleFreq extends DoFn> { + + @Override + public void processElement(ProcessContext c) throws Exception { + ListValue lv = c.element().getInfo().get("AF"); + if (lv != null && lv.getValuesCount() > 0) { + Position p = new Position() + .setPosition(c.element().getStart()) + .setReferenceName(c.element().getReferenceName()); + AlleleFreq af = new AlleleFreq(); + af.setRefFreq(lv.getValues(0).getNumberValue()); + af.setAltBases(c.element().getAlternateBasesList()); + af.setRefBases(c.element().getReferenceBases()); + c.output(KV.of(p, af)); + } else { + // AF field wasn't populated in info, so we don't have frequency information + // for this Variant. + // TODO instead of straight throwing an exception, log a warning. If at the end of this + // step the number of AlleleFreqs retrieved is below a given threshold, then throw an + // exception. + throw new IllegalArgumentException("Variant " + c.element().getId() + " does not have " + + "allele frequency information stored."); + } + } + } + + /** + * Filters out AlleleFreqs for which the reference or alternate allele + * frequencies are below a minimum specified at construction. + */ + static class FilterFreq implements SerializableFunction, Boolean> { + + private final double minFreq; + + public FilterFreq(double minFreq) { + this.minFreq = minFreq; + } + + @Override + public Boolean apply(KV input) { + double freq = input.getValue().getRefFreq(); + if (freq >= minFreq && (1.0 - freq) >= minFreq) { + return true; + } + return false; + } + } + + /** + * Piles up reads and joins them against reference population statistics. 
+ */ + static class PileupAndJoinReads + extends DoFn, KV> { + + private final TupleTag readCountsTag; + private final TupleTag refFreqTag; + + public PileupAndJoinReads(TupleTag readCountsTag, + TupleTag refFreqTag) { + this.readCountsTag = readCountsTag; + this.refFreqTag = refFreqTag; + } + + @Override + public void processElement(ProcessContext c) throws Exception { + AlleleFreq af = null; + af = c.element().getValue().getOnly(refFreqTag, null); + if (af == null) { + // no ref stats + return; + } + if (af.getAltBases().size() != 1) { + throw new IllegalArgumentException("Wrong number (" + af.getAltBases().size() + ") of" + + " alternate bases for Position " + c.element().getKey()); + } + + Iterable reads = c.element().getValue().getAll(readCountsTag); + + ImmutableMultiset.Builder rqSetBuilder = ImmutableMultiset.builder(); + for (ReadBaseQuality r : reads) { + ReadQualityCount.Base b; + if (af.getRefBases().equals(r.getBase())) { + b = ReadQualityCount.Base.REF; + } else if (af.getAltBases().get(0).equals(r.getBase())) { + b = ReadQualityCount.Base.NONREF; + } else { + b = ReadQualityCount.Base.OTHER; + } + ReadQualityCount rqc = new ReadQualityCount(); + rqc.setBase(b); + rqc.setQuality(r.getQuality()); + rqSetBuilder.add(rqc); + } + + ReadCounts rc = new ReadCounts(); + rc.setRefFreq(af.getRefFreq()); + for (Multiset.Entry entry : rqSetBuilder.build().entrySet()) { + ReadQualityCount rq = entry.getElement(); + rq.setCount(entry.getCount()); + rc.addReadQualityCount(rq); + } + c.output(KV.of(c.element().getKey(), rc)); + } + } + + /** + * Calls the Solver to maximize via a univariate function the results of the pipeline, inputted + * as a PCollectionView (the best way to retrieve our results as a Map in Dataflow). + */ + static class Maximizer extends DoFn { + + private final PCollectionView> view; + // Target absolute error for Brent's algorithm + private static final double ABS_ERR = 0.00001; + // Target relative error for Brent's algorithm + private static final double REL_ERR = 0.0001; + // Maximum number of evaluations of the Likelihood function in Brent's algorithm + private static final int MAX_EVAL = 100; + // Maximum number of iterations of Brent's algorithm + private static final int MAX_ITER = 100; + // Grid search step size + private static final double GRID_STEP = 0.05; + + public Maximizer(PCollectionView> view) { + this.view = view; + } + + @Override + public void processElement(ProcessContext c) throws Exception { + c.output(Double.toString(Solver.maximize(new LikelihoodFn(c.sideInput(view)), + 0.0, 0.5, GRID_STEP, REL_ERR, ABS_ERR, MAX_ITER, MAX_EVAL))); + } + } + + /** + * Checks to make sure all of the input ReadGroupSets have the same referenceSetId. 
+ * + * @param readGroupSetIds list of input ReadGroupSet ids to check + * @return the referenceSetId of the given ReadGroupSets + */ + private static String checkReferenceSetIds(List readGroupSetIds) + throws GeneralSecurityException, IOException { + String referenceSetId = null; + for (String rgsId : readGroupSetIds) { + Genomics.Readgroupsets.Get rgsRequest = auth.getGenomics(auth.getDefaultFactory()) + .readgroupsets().get(rgsId).setFields("referenceSetId"); + ReadGroupSet rgs = rgsRequest.execute(); + if (referenceSetId == null) { + referenceSetId = rgs.getReferenceSetId(); + } else if (!rgs.getReferenceSetId().equals(referenceSetId)) { + throw new IllegalArgumentException("ReferenceSetIds must be the same for all" + + " ReadGroupSets in given input."); + } + } + return referenceSetId; + } + + /** + * Parses and shards all of the References in the given ReferenceSet. + * + * @param referenceSetId the id of the ReferenceSet we want the Contigs for + * @return list of Contigs representing all of the References in the given ReferenceSet + */ + public static List getAllReferences(String referenceSetId) + throws IOException, GeneralSecurityException { + List contigs = Lists.newArrayList(); + Genomics.References.Search refRequest = auth.getGenomics(auth.getDefaultFactory()) + .references() + .search(new SearchReferencesRequest().setReferenceSetId(referenceSetId)); + List referencesList = refRequest.execute().getReferences(); + for (Reference r : referencesList) { + contigs.add(new Contig(r.getName(), 0L, r.getLength())); + } + contigs = Lists.newArrayList(FluentIterable.from(contigs) + .transformAndConcat(new Function>() { + @Override + public Iterable apply(Contig contig) { + return contig.getShards(SHARD_SIZE); + } + })); + Collections.shuffle(contigs); + return contigs; + } + + /** + * Parses and shards the References in the given ReferenceSet. + * + * @param references the user given references from the command line. + * @param referenceSetId the id of the ReferenceSet we want the Contigs for + * @return list of Contigs representing all of the References in the given ReferenceSet + */ + public static List parseReferences(String references, + String referenceSetId) throws IOException, GeneralSecurityException { + List splitReferences = Lists.newArrayList(references.split(",")); + List contigs = Lists.newArrayList(); + List referencesList = null; // If needed + for (String ref : splitReferences) { + String[] splitPieces = ref.split(":"); + if (splitPieces.length != 3 && referencesList == null) { + // referencesList hasn't been needed up until this point, so we must request it. + // Assume that they are asking for one entire specific reference. + Genomics.References.Search refRequest = auth.getGenomics(auth.getDefaultFactory()) + .references() + .search(new SearchReferencesRequest().setReferenceSetId(referenceSetId)); + referencesList = refRequest.execute().getReferences(); + for (Reference r : referencesList) { + if (r.getName().equals(splitPieces[0])) { // Found the reference we want. + contigs.add(new Contig(splitPieces[0], 0L, r.getLength())); + break; + } + } + } else if (splitPieces.length != 3) { + // Assume that they are asking for one entire specific reference. + for (Reference r : referencesList) { + if (r.getName().equals(splitPieces[0])) { // Found the reference we want. 
+ contigs.add(new Contig(splitPieces[0], 0L, r.getLength())); + break; + } + } + } else { + contigs.add( + new Contig(splitPieces[0], Long.valueOf(splitPieces[1]), Long.valueOf(splitPieces[2]))); + } + } + contigs = Lists.newArrayList(FluentIterable.from(contigs) + .transformAndConcat(new Function>() { + @Override + public Iterable apply(Contig contig) { + return contig.getShards(SHARD_SIZE); + } + })); + Collections.shuffle(contigs); + return contigs; + } + + private static PCollection getReadsFromAPI(List rgsIds) + throws IOException, GeneralSecurityException { + List requests = Lists.newArrayList(); + for (String r : rgsIds) { + if (options.isAllReferences()) { + requests.add(ReadStreamer.getReadRequests(r)); + } else { + requests.addAll(ReadStreamer.getReadRequests(r, options.getReferences())); + } + } + PCollection readRequests = p.begin().apply(Create.of(requests)); + return readRequests.apply(new ReadStreamer.StreamReads()); + } + + private static PCollection getVariantsFromAPI(List vsIds) + throws IOException, GeneralSecurityException { + List requests = Lists.newArrayList(); + for (String v : vsIds) { + if (options.isAllReferences()) { + requests.add(VariantStreamer.getVariantRequests(v)); + } else { + requests.addAll(VariantStreamer.getVariantRequests(v, options.getReferences())); + } + } + PCollection variantRequests = p.begin().apply(Create.of(requests)); + return variantRequests.apply(new VariantStreamer.StreamVariants()); + } +} \ No newline at end of file diff --git a/src/main/java/com/google/cloud/genomics/dataflow/utils/Solver.java b/src/main/java/com/google/cloud/genomics/dataflow/utils/Solver.java new file mode 100644 index 0000000..785be18 --- /dev/null +++ b/src/main/java/com/google/cloud/genomics/dataflow/utils/Solver.java @@ -0,0 +1,98 @@ +/* + * Copyright 2015 Google. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.cloud.genomics.dataflow.utils; + +import org.apache.commons.math3.analysis.UnivariateFunction; +import org.apache.commons.math3.geometry.euclidean.oned.Interval; +import org.apache.commons.math3.optim.MaxEval; +import org.apache.commons.math3.optim.MaxIter; +import org.apache.commons.math3.optim.nonlinear.scalar.GoalType; +import org.apache.commons.math3.optim.univariate.BrentOptimizer; +import org.apache.commons.math3.optim.univariate.SearchInterval; +import org.apache.commons.math3.optim.univariate.UnivariateObjectiveFunction; +import org.apache.commons.math3.optim.univariate.UnivariatePointValuePair; + +/** + * Maximize a univariate (likelihood) function. + */ +public class Solver { + + /** + * Runs a grid search for the maximum value of a univariate function. 
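+ * The function is evaluated at start, start + step, start + 2*step, ...,
+ * plus the endpoint; the returned Interval spans one grid step to either
+ * side of the best point (upper end clamped to end), which brackets the
+ * maximum for Brent's method.
+ *
+ * <p>A usage sketch (grid and tolerance values are illustrative):
+ * <pre>{@code
+ * Interval bracket = Solver.gridSearch(fn, 0.0, 0.5, 0.05);
+ * double alphaHat = Solver.maximize(fn, 0.0, 0.5, 0.05, 1e-4, 1e-5, 100, 100);
+ * }</pre>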
+ * + * @param fn the likelihood function to minimize + * @param start lower bound of the interval to search + * @param end upper bound of the interval to search + * @param step grid step size + * @return an Interval bracketing the minimum + */ + static Interval gridSearch(UnivariateFunction fn, double start, + double end, double step) { + double lowMax = start; // lower bound on interval surrounding alphaMax + double alphaMax = start - step; + double likMax = 0.0; + + double lastAlpha = start; + double alpha = start; + while (alpha < end) { + double likelihood = fn.value(alpha); + if (alphaMax < start || likelihood > likMax) { + lowMax = lastAlpha; + alphaMax = alpha; + likMax = likelihood; + } + lastAlpha = alpha; + alpha += step; + } + // make sure we've checked the rightmost endpoint (won't happen if + // end - start is not an integer multiple of step, because of roundoff + // errors, etc) + double likelihood = fn.value(end); + if (likelihood > likMax) { + lowMax = lastAlpha; + alphaMax = end; + likMax = likelihood; + } + return new Interval(lowMax, Math.min(end, alphaMax + step)); + } + + /** + * Maximizes a univariate function using a grid search followed by Brent's algorithm. + * + * @param fn the likelihood function to minimize + * @param gridStart the lower bound for the grid search + * @param gridEnd the upper bound for the grid search + * @param gridStep step size for the grid search + * @param relErr relative error tolerance for Brent's algorithm + * @param absErr absolute error tolerance for Brent's algorithm + * @param maxIter maximum # of iterations to perform in Brent's algorithm + * @param maxEval maximum # of Likelihood function evaluations in Brent's algorithm + * + * @return the value of the parameter that maximizes the function + */ + public static double maximize(UnivariateFunction fn, double gridStart, double gridEnd, + double gridStep, double relErr, double absErr, int maxIter, int maxEval) { + Interval interval = gridSearch(fn, gridStart, gridEnd, gridStep); + BrentOptimizer bo = new BrentOptimizer(relErr, absErr); + UnivariatePointValuePair max = bo.optimize( + new MaxIter(maxIter), + new MaxEval(maxEval), + new SearchInterval(interval.getInf(), interval.getSup()), + new UnivariateObjectiveFunction(fn), + GoalType.MAXIMIZE); + return max.getPoint(); + } +} diff --git a/src/test/java/com/google/cloud/genomics/dataflow/functions/LikelihoodFnTest.java b/src/test/java/com/google/cloud/genomics/dataflow/functions/LikelihoodFnTest.java new file mode 100644 index 0000000..c82d96c --- /dev/null +++ b/src/test/java/com/google/cloud/genomics/dataflow/functions/LikelihoodFnTest.java @@ -0,0 +1,132 @@ +/* + * Copyright 2015 Google. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.google.cloud.genomics.dataflow.functions; + +import com.google.api.services.genomics.model.Position; +import com.google.cloud.genomics.dataflow.model.ReadCounts; +import com.google.cloud.genomics.dataflow.model.ReadQualityCount; +import com.google.cloud.genomics.dataflow.model.ReadQualityCount.Base; +import com.google.common.collect.ImmutableMap; + +import junit.framework.TestCase; + +import org.apache.commons.math3.util.Precision; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.Map; + +/** + * Unit tests for {@link LikelihoodFn}. + */ +@RunWith(JUnit4.class) +public class LikelihoodFnTest extends TestCase { + + @Test + public void testValue() { + final double tolerance = 0.0000001; // slop factor for floating point comparisons + final double alpha = 0.25; + final int errPhred = 10; // P(err) = 0.1 + final double pRef = 0.6; + + Position position1 = new Position() + .setReferenceName("1") + .setPosition(123L); + + /* + * Observe single REF read + */ + ImmutableMap.Builder countsBuilder + = ImmutableMap.builder(); + ReadCounts rc = new ReadCounts(); + rc.setRefFreq(pRef); + ReadQualityCount rqc = new ReadQualityCount(); + rqc.setBase(Base.REF); + rqc.setQuality(errPhred); + rqc.setCount(1); + rc.addReadQualityCount(rqc); + countsBuilder.put(position1, rc); + Map readCounts = countsBuilder.build(); + LikelihoodFn fn = new LikelihoodFn(readCounts); + // Likelihood of a single REF with the parameters above + final double likRef = 0.3354133333; + assertEquals(Precision.compareTo(fn.value(alpha), Math.log(likRef), tolerance), 0); + + /* + * Observe single NONREF read + */ + countsBuilder = ImmutableMap.builder(); + rc = new ReadCounts(); + rc.setRefFreq(pRef); + rqc = new ReadQualityCount(); + rqc.setBase(Base.NONREF); + rqc.setQuality(errPhred); + rqc.setCount(1); + rc.addReadQualityCount(rqc); + countsBuilder.put(position1, rc); + readCounts = countsBuilder.build(); + fn = new LikelihoodFn(readCounts); + // Likelihood of a single NONREF with the parameters above + final double likNonref = 0.20368; + assertEquals(Precision.compareTo(fn.value(alpha), Math.log(likNonref), tolerance), 0); + + /* + * Observe single OTHER read + */ + countsBuilder = ImmutableMap.builder(); + rc = new ReadCounts(); + rc.setRefFreq(pRef); + rqc = new ReadQualityCount(); + rqc.setBase(Base.OTHER); + rqc.setQuality(errPhred); + rqc.setCount(1); + rc.addReadQualityCount(rqc); + countsBuilder.put(position1, rc); + readCounts = countsBuilder.build(); + fn = new LikelihoodFn(readCounts); + // Likelihood of a single OTHER with the parameters above + final double likOther = 0.03850666667; + assertEquals(Precision.compareTo(fn.value(alpha), Math.log(likOther), tolerance), 0); + + // Likelihood for reads at 2 different positions should be product of + // likelihoods for individual reads + Position position2 = new Position() + .setReferenceName("1") + .setPosition(124L); + countsBuilder = ImmutableMap.builder(); + rc = new ReadCounts(); + rc.setRefFreq(pRef); + rqc = new ReadQualityCount(); + rqc.setBase(Base.REF); + rqc.setQuality(errPhred); + rqc.setCount(1); + rc.addReadQualityCount(rqc); + countsBuilder.put(position1, rc); + rc = new ReadCounts(); + rc.setRefFreq(pRef); + rqc = new ReadQualityCount(); + rqc.setBase(Base.NONREF); + rqc.setQuality(errPhred); + rqc.setCount(1); + rc.addReadQualityCount(rqc); + countsBuilder.put(position2, rc); + readCounts = countsBuilder.build(); + fn = new LikelihoodFn(readCounts); + 
assertEquals(Precision.compareTo(fn.value(alpha), Math.log(likRef * likNonref), + tolerance), 0); + } +} diff --git a/src/test/java/com/google/cloud/genomics/dataflow/pipelines/VerifyBamIdTest.java b/src/test/java/com/google/cloud/genomics/dataflow/pipelines/VerifyBamIdTest.java new file mode 100644 index 0000000..d6c5053 --- /dev/null +++ b/src/test/java/com/google/cloud/genomics/dataflow/pipelines/VerifyBamIdTest.java @@ -0,0 +1,331 @@ +/* + * Copyright 2015 Google. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.cloud.genomics.dataflow.pipelines; + +import com.google.api.services.genomics.model.Position; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFnTester; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.join.CoGbkResult; +import com.google.cloud.dataflow.sdk.transforms.join.CoGroupByKey; +import com.google.cloud.dataflow.sdk.transforms.join.KeyedPCollectionTuple; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.cloud.genomics.dataflow.coders.GenericJsonCoder; +import com.google.cloud.genomics.dataflow.model.AlleleFreq; +import com.google.cloud.genomics.dataflow.model.ReadBaseQuality; +import com.google.cloud.genomics.dataflow.model.ReadCounts; +import com.google.cloud.genomics.dataflow.model.ReadQualityCount; +import com.google.cloud.genomics.dataflow.pipelines.VerifyBamId.FilterFreq; +import com.google.cloud.genomics.dataflow.pipelines.VerifyBamId.GetAlleleFreq; +import com.google.cloud.genomics.dataflow.pipelines.VerifyBamId.PileupAndJoinReads; +import com.google.cloud.genomics.dataflow.pipelines.VerifyBamId.SampleReads; +import com.google.cloud.genomics.dataflow.pipelines.VerifyBamId.SplitReads; +import com.google.cloud.genomics.dataflow.utils.DataflowWorkarounds; +import com.google.common.collect.ImmutableList; +import com.google.genomics.v1.CigarUnit; +import com.google.genomics.v1.CigarUnit.Operation; +import com.google.genomics.v1.LinearAlignment; +import com.google.genomics.v1.Read; +import com.google.genomics.v1.Variant; +import com.google.protobuf.ListValue; +import com.google.protobuf.Value; + +import com.beust.jcommander.internal.Lists; + +import org.hamcrest.CoreMatchers; +import org.junit.Assert; +import org.junit.Test; + +/** + * Tests for the VerifyBamId pipeline. 
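+ *
+ * <p>Transform-level tests exercise each DoFn in isolation with DoFnTester;
+ * pipeline-level tests run on TestPipeline and verify results with
+ * DataflowAssert.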
+ */ +public class VerifyBamIdTest { + + @Test + public void testSplitReads() { + DoFnTester> splitReads = DoFnTester.of(new SplitReads()); + + // single matched base -> one SingleReadQuality proto + Read r = Read.newBuilder() + .setProperPlacement(true) + .setAlignment(LinearAlignment.newBuilder() + .setPosition(com.google.genomics.v1.Position.newBuilder() + .setReferenceName("1") + .setPosition(123)) + .addCigar(CigarUnit.newBuilder() + .setOperation(Operation.ALIGNMENT_MATCH) + .setOperationLength(1))) + .setAlignedSequence("A") + .addAlignedQuality(3) + .build(); + Assert.assertThat(splitReads.processBatch(r), CoreMatchers.hasItems(KV.of(new Position() + .setReferenceName("1") + .setPosition(123L), + new ReadBaseQuality("A", 3)))); + + // two matched bases -> two SingleReadQuality protos + r = Read.newBuilder() + .setProperPlacement(true) + .setAlignment(LinearAlignment.newBuilder() + .setPosition(com.google.genomics.v1.Position.newBuilder() + .setReferenceName("1") + .setPosition(123)) + .addCigar(CigarUnit.newBuilder() + .setOperation(Operation.ALIGNMENT_MATCH) + .setOperationLength(2))) + .setAlignedSequence("AG") + .addAllAlignedQuality(ImmutableList.of(3, 4)) + .build(); + Assert.assertThat(splitReads.processBatch(r), CoreMatchers.hasItems(KV.of(new Position() + .setReferenceName("1") + .setPosition(123L), + new ReadBaseQuality("A", 3)), + KV.of(new Position() + .setReferenceName("1") + .setPosition(124L), + new ReadBaseQuality("G", 4)))); + + // matched bases with different offsets onto the reference + r = Read.newBuilder() + .setProperPlacement(true) + .setAlignment(LinearAlignment.newBuilder() + .setPosition(com.google.genomics.v1.Position.newBuilder() + .setReferenceName("1") + .setPosition(123)) + .addCigar(CigarUnit.newBuilder() + .setOperation(Operation.INSERT) + .setOperationLength(1)) + .addCigar(CigarUnit.newBuilder() + .setOperation(Operation.ALIGNMENT_MATCH) + .setOperationLength(1)) + .addCigar(CigarUnit.newBuilder() + .setOperation(Operation.DELETE) + .setOperationLength(1)) + .addCigar(CigarUnit.newBuilder() + .setOperation(Operation.ALIGNMENT_MATCH) + .setOperationLength(2))) + // 1I1M1D2M; C -> ref 0, GT -> ref 2, 3 + .setAlignedSequence("ACGT") + .addAllAlignedQuality(ImmutableList.of(1, 2, 3, 4)) + .build(); + Assert.assertThat(splitReads.processBatch(r), CoreMatchers.hasItems(KV.of(new Position() + .setReferenceName("1") + .setPosition(123L), + new ReadBaseQuality("C", 2)), + KV.of(new Position() + .setReferenceName("1") + .setPosition(125L), + new ReadBaseQuality("G", 3)), + KV.of(new Position() + .setReferenceName("1") + .setPosition(126L), + new ReadBaseQuality("T", 4)))); + + // matched bases with different offsets onto the reference + r = Read.newBuilder() + .setAlignment(LinearAlignment.newBuilder() + .setPosition(com.google.genomics.v1.Position.newBuilder() + .setReferenceName("1") + .setPosition(123)) + .addCigar(CigarUnit.newBuilder() + .setOperation(Operation.DELETE) + .setOperationLength(1)) + .addCigar(CigarUnit.newBuilder() + .setOperation(Operation.ALIGNMENT_MATCH) + .setOperationLength(2)) + .addCigar(CigarUnit.newBuilder() + .setOperation(Operation.INSERT) + .setOperationLength(1)) + .addCigar(CigarUnit.newBuilder() + .setOperation(Operation.ALIGNMENT_MATCH) + .setOperationLength(1))) + // 1D2M1I1M; A and G match positions 1 and 3 of the ref + .setAlignedSequence("ACGT") + .addAllAlignedQuality(ImmutableList.of(1, 2, 3, 4)) + .build(); + Assert.assertThat(splitReads.processBatch(r), CoreMatchers.hasItems(KV.of(new Position() + 
.setReferenceName("1") + .setPosition(124L), + new ReadBaseQuality("A", 1)), + KV.of(new Position() + .setReferenceName("1") + .setPosition(125L), + new ReadBaseQuality("C", 2)), + KV.of(new Position() + .setReferenceName("1") + .setPosition(126L), + new ReadBaseQuality("T", 4)))); + } + + @Test + public void testSampleReads() { + SampleReads sampleReads = new SampleReads(0.5, ""); + Assert.assertTrue(sampleReads.apply(KV.of( + new Position() + .setReferenceName("1") + .setPosition(125L) + .setReverseStrand(false), + new ReadBaseQuality()))); + Assert.assertFalse(sampleReads.apply(KV.of( + new Position() + .setReferenceName("2") + .setPosition(124L) + .setReverseStrand(false), + new ReadBaseQuality()))); + } + + @Test + public void testGetAlleleFreq() { + DoFnTester> getAlleleFreq = DoFnTester.of( + new GetAlleleFreq()); + Position pos = new Position().setReferenceName("1").setPosition(123L); + Variant.Builder vBuild = Variant.newBuilder() + .setReferenceName("1") + .setStart(123L) + .setReferenceBases("C") + .addAlternateBases("T"); + vBuild.getMutableInfo().put("AF", ListValue.newBuilder() + .addValues(Value.newBuilder().setNumberValue(0.25).build()).build()); + AlleleFreq af = new AlleleFreq(); + af.setAltBases(Lists.newArrayList("T")); + af.setRefBases("C"); + af.setRefFreq(0.25); + Assert.assertThat(getAlleleFreq.processBatch(vBuild.build()), + CoreMatchers.hasItems(KV.of(pos, af))); + } + + @Test + public void testFilterFreq() { + FilterFreq filterFreq = new FilterFreq(0.01); + Position pos = new Position().setReferenceName("1").setPosition(123L); + AlleleFreq af = new AlleleFreq(); + af.setRefFreq(0.9999); + Assert.assertFalse(filterFreq.apply(KV.of(pos, af))); + af.setRefFreq(0.9901); + Assert.assertFalse(filterFreq.apply(KV.of(pos, af))); + af.setRefFreq(0.9899); + Assert.assertTrue(filterFreq.apply(KV.of(pos, af))); + } + + private final Position position1 + = new Position() + .setReferenceName("1") + .setPosition(123L); + private final Position position2 + = new Position() + .setReferenceName("1") + .setPosition(124L); + private final Position position3 + = new Position() + .setReferenceName("1") + .setPosition(125L); + + private final ImmutableList> refCountList; + { + ImmutableList.Builder> refBuilder + = ImmutableList.builder(); + AlleleFreq af = new AlleleFreq(); + af.setRefBases("A"); + af.setAltBases(Lists.newArrayList("G")); + af.setRefFreq(0.8); + refBuilder.add(KV.of(position1, af)); + af = new AlleleFreq(); + af.setRefBases("C"); + af.setAltBases(Lists.newArrayList("T")); + af.setRefFreq(0.5); + refBuilder.add(KV.of(position2, af)); + af = new AlleleFreq(); + af.setRefBases("T"); + af.setAltBases(Lists.newArrayList("C")); + af.setRefFreq(0.6); + refBuilder.add(KV.of(position3, af)); + refCountList = refBuilder.build(); + } + + @Test + public void testPileupAndJoinReads() { + VerifyBamId.VerifyBamIdOptions popts = + PipelineOptionsFactory.create().as(VerifyBamId.VerifyBamIdOptions.class); + Pipeline p = TestPipeline.create(popts); + DataflowWorkarounds.registerCoder(p, Position.class, GenericJsonCoder.of(Position.class)); + DataflowWorkarounds.registerCoder(p, ReadBaseQuality.class, + GenericJsonCoder.of(ReadBaseQuality.class)); + DataflowWorkarounds.registerCoder(p, AlleleFreq.class, GenericJsonCoder.of(AlleleFreq.class)); + DataflowWorkarounds.registerCoder(p, ReadCounts.class, GenericJsonCoder.of(ReadCounts.class)); + ReadBaseQuality srq = new ReadBaseQuality("A", 10); + PCollection> readCounts = p.apply( + Create.of(KV.of(position1, srq))); + PCollection> refFreq = 
p.apply(Create.of(refCountList)); + TupleTag readCountsTag = new TupleTag<>(); + TupleTag refFreqTag = new TupleTag<>(); + PCollection> joined = KeyedPCollectionTuple + .of(readCountsTag, readCounts) + .and(refFreqTag, refFreq) + .apply(CoGroupByKey.create()); + PCollection> result = joined.apply( + ParDo.of(new PileupAndJoinReads(readCountsTag, refFreqTag))); + ReadCounts rc = new ReadCounts(); + rc.setRefFreq(0.8); + rc.addReadQualityCount(ReadQualityCount.Base.REF, 10, 1); + DataflowAssert.that(result).containsInAnyOrder(KV.of(position1, rc)); + } + + @Test + public void testCombineReads() { + VerifyBamId.VerifyBamIdOptions popts = + PipelineOptionsFactory.create().as(VerifyBamId.VerifyBamIdOptions.class); + Pipeline p = TestPipeline.create(popts); + DataflowWorkarounds.registerCoder(p, Position.class, GenericJsonCoder.of(Position.class)); + DataflowWorkarounds.registerCoder(p, ReadBaseQuality.class, + GenericJsonCoder.of(ReadBaseQuality.class)); + DataflowWorkarounds.registerCoder(p, AlleleFreq.class, GenericJsonCoder.of(AlleleFreq.class)); + DataflowWorkarounds.registerCoder(p, ReadCounts.class, GenericJsonCoder.of(ReadCounts.class)); + PCollection> refCounts = p.apply(Create.of(this.refCountList)); + + PCollection reads = p.apply(Create.of(Read.newBuilder() + .setProperPlacement(true) + .setAlignment(LinearAlignment.newBuilder() + .setPosition(com.google.genomics.v1.Position.newBuilder() + .setReferenceName("1") + .setPosition(123)) + .addCigar(CigarUnit.newBuilder() + .setOperation(Operation.ALIGNMENT_MATCH) + .setOperationLength(3))) + .setAlignedSequence("ATG") + .addAllAlignedQuality(ImmutableList.of(3, 4, 5)) + .build())); + + PCollection> results = + VerifyBamId.combineReads(reads, 1.0, "", refCounts); + ReadCounts one = new ReadCounts(); + one.setRefFreq(0.8); + one.addReadQualityCount(ReadQualityCount.Base.REF, 3, 1L); + ReadCounts two = new ReadCounts(); + two.setRefFreq(0.5); + two.addReadQualityCount(ReadQualityCount.Base.NONREF, 4, 1L); + ReadCounts three = new ReadCounts(); + three.setRefFreq(0.6); + three.addReadQualityCount(ReadQualityCount.Base.OTHER, 5, 1L); + DataflowAssert.that(results) + .containsInAnyOrder(KV.of(position1, one), KV.of(position2, two), KV.of(position3, three)); + } +} diff --git a/src/test/java/com/google/cloud/genomics/dataflow/utils/SolverTest.java b/src/test/java/com/google/cloud/genomics/dataflow/utils/SolverTest.java new file mode 100644 index 0000000..e2f1892 --- /dev/null +++ b/src/test/java/com/google/cloud/genomics/dataflow/utils/SolverTest.java @@ -0,0 +1,173 @@ +/* + * Copyright 2015 Google. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
diff --git a/src/test/java/com/google/cloud/genomics/dataflow/utils/SolverTest.java b/src/test/java/com/google/cloud/genomics/dataflow/utils/SolverTest.java
new file mode 100644
index 0000000..e2f1892
--- /dev/null
+++ b/src/test/java/com/google/cloud/genomics/dataflow/utils/SolverTest.java
@@ -0,0 +1,173 @@
+/*
+ * Copyright 2015 Google.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.google.cloud.genomics.dataflow.utils;
+
+import com.google.api.services.genomics.model.Position;
+import com.google.cloud.genomics.dataflow.functions.LikelihoodFn;
+import com.google.cloud.genomics.dataflow.model.ReadCounts;
+import com.google.cloud.genomics.dataflow.model.ReadQualityCount;
+import com.google.cloud.genomics.dataflow.model.ReadQualityCount.Base;
+import com.google.common.collect.ImmutableMap;
+
+import junit.framework.TestCase;
+
+import org.apache.commons.math3.analysis.UnivariateFunction;
+import org.apache.commons.math3.geometry.euclidean.oned.Interval;
+import org.apache.commons.math3.util.Precision;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+import java.util.Map;
+
+/**
+ * Unit tests for Solver.
+ */
+@RunWith(JUnit4.class)
+public class SolverTest extends TestCase {
+
+  /** Parabola with its maximum value at a specified point. */
+  public class Parabola implements UnivariateFunction {
+    private final double max;
+
+    /** Create a Parabola with its maximum at the specified value. */
+    Parabola(double max) {
+      this.max = max;
+    }
+
+    @Override
+    public double value(double x) {
+      return 1.0 - Math.pow(x - this.max, 2.0);
+    }
+  }
+
+  @Test
+  public void testGridSearch() {
+    Interval interval = Solver.gridSearch(new Parabola(0.25), 0.0, 1.0, 0.1);
+    assertEquals(Precision.compareTo(interval.getInf(), 0.1, 1), 0);
+    assertEquals(Precision.compareTo(interval.getSup(), 0.3, 1), 0);
+
+    interval = Solver.gridSearch(new Parabola(1.2), 0.0, 1.0, 0.1);
+    assertEquals(Precision.compareTo(interval.getInf(), 0.9, 1), 0);
+    assertEquals(Precision.compareTo(interval.getSup(), 1.0, 1), 0);
+
+    interval = Solver.gridSearch(new Parabola(1.2), 0.0, 1.0, 0.3);
+    assertEquals(Precision.compareTo(interval.getInf(), 0.9, 1), 0);
+    assertEquals(Precision.compareTo(interval.getSup(), 1.0, 1), 0);
+  }
+
+  @Test
+  public void testMaximize() {
+    assertEquals(Precision.compareTo(Solver.maximize(new Parabola(0.47), 0.0, 1.0,
+        0.1, 0.00001, 0.00001, 100, 100), 0.47, 0.00001), 0);
+  }
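+
+  /*
+   * Arithmetic behind the expected optima below (near-zero error rate):
+   * if the true sample is ref-homozygous, a heterozygous contaminant at
+   * fraction alpha contributes alpha/2 NONREF reads, while a
+   * nonref-homozygous contaminant contributes alpha. Observing 10% NONREF
+   * therefore points to alpha = 0.2 (het) or alpha = 0.1 (hom), and with
+   * P(REF) = 0.8 the heterozygous contaminant genotype is far more probable
+   * (2 * 0.8 * 0.2 = 0.32 vs. 0.2^2 = 0.04), so the likelihood peaks near
+   * alpha = 0.2.
+   */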
+
+  @Test
+  public void testSolverOnKnownLikelihoodCases() {
+    int phred = 200;
+
+    Position position1 = new Position()
+        .setReferenceName("1")
+        .setPosition(123L);
+
+    /*
+     * Observe 900 REF reads and 100 NONREF.
+     * P(REF) = 0.8; the error probability is near 0.
+     * The most likely explanation should be ~20% contamination.
+     * (If P(REF) were 0.5, we'd have peaks at 10% (nonref-homozygous
+     * contaminant) and 20% (heterozygous contaminant).)
+     */
+    ImmutableMap.Builder<Position, ReadCounts> countsBuilder
+        = ImmutableMap.builder();
+    ReadCounts rc = new ReadCounts();
+    rc.setRefFreq(0.8);
+    ReadQualityCount rqc = new ReadQualityCount();
+    rqc.setBase(Base.REF);
+    rqc.setQuality(phred);
+    rqc.setCount(900);
+    rc.addReadQualityCount(rqc);
+    rqc = new ReadQualityCount();
+    rqc.setBase(Base.NONREF);
+    rqc.setQuality(phred);
+    rqc.setCount(100);
+    rc.addReadQualityCount(rqc);
+    rqc = new ReadQualityCount();
+    rqc.setBase(Base.OTHER);
+    rqc.setQuality(1);
+    rqc.setCount(2);
+    rc.addReadQualityCount(rqc);
+    countsBuilder.put(position1, rc);
+    Map<Position, ReadCounts> readCounts = countsBuilder.build();
+    assertEquals(Precision.compareTo(Solver.maximize(
+        new LikelihoodFn(readCounts), 0.0, 0.5, 0.05,
+        0.0001, 0.0001, 100, 100), 0.2, 0.0001), 0);
+
+    /*
+     * Make sure things are symmetrical: observe 900 NONREF reads and 100 REF.
+     * P(NONREF) = 0.8 (i.e. P(REF) = 0.2); the error probability is near 0.
+     * The most likely explanation should again be ~20% contamination.
+     */
+    countsBuilder = ImmutableMap.builder();
+    rc = new ReadCounts();
+    rc.setRefFreq(0.2);
+    rqc = new ReadQualityCount();
+    rqc.setBase(Base.NONREF);
+    rqc.setQuality(phred);
+    rqc.setCount(900);
+    rc.addReadQualityCount(rqc);
+    rqc = new ReadQualityCount();
+    rqc.setBase(Base.REF);
+    rqc.setQuality(phred);
+    rqc.setCount(100);
+    rc.addReadQualityCount(rqc);
+    countsBuilder.put(position1, rc);
+    readCounts = countsBuilder.build();
+    assertEquals(Precision.compareTo(Solver.maximize(
+        new LikelihoodFn(readCounts), 0.0, 0.5, 0.05,
+        0.0001, 0.0001, 100, 100), 0.2, 0.0001), 0);
+
+    /*
+     * Assume a heterozygous desired base pair with a homozygous contaminant.
+     * Observe 450 NONREF reads and 550 REF.
+     * P(REF) = 0.5; the error probability is near 0.
+     * The most likely explanation should be ~10% contamination.
+     */
+    countsBuilder = ImmutableMap.builder();
+    rc = new ReadCounts();
+    rc.setRefFreq(0.5);
+    rqc = new ReadQualityCount();
+    rqc.setBase(Base.NONREF);
+    rqc.setQuality(phred);
+    rqc.setCount(450);
+    rc.addReadQualityCount(rqc);
+    rqc = new ReadQualityCount();
+    rqc.setBase(Base.REF);
+    rqc.setQuality(phred);
+    rqc.setCount(550);
+    rc.addReadQualityCount(rqc);
+    countsBuilder.put(position1, rc);
+    readCounts = countsBuilder.build();
+    assertEquals(Precision.compareTo(Solver.maximize(
+        new LikelihoodFn(readCounts), 0.0, 0.5, 0.05,
+        0.0001, 0.0001, 100, 100), 0.1, 0.0001), 0);
+  }
+
+}
\ No newline at end of file
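A quick hand check of the expected optima above (my arithmetic, not part of the diff): with near-zero error, a ref-homozygous sample with a heterozygous contaminant at fraction alpha yields alpha/2 NONREF reads, so 100/1000 NONREF gives alpha = 0.2; a heterozygous sample with a ref-homozygous contaminant yields (1 - alpha)/2 NONREF reads, so 450/1000 gives alpha = 0.1. A self-contained sketch of that arithmetic (class name hypothetical):

    /** Standalone check of the contamination fractions expected by SolverTest. */
    public class ContaminationArithmeticCheck {
      public static void main(String[] args) {
        // Cases 1 and 2: ref-homozygous sample, heterozygous contaminant at
        // fraction alpha contributes alpha / 2 NONREF reads.
        double alphaHet = 2.0 * (100.0 / 1000.0);
        System.out.println(alphaHet);   // 0.2
        // Case 3: heterozygous sample, ref-homozygous contaminant; the NONREF
        // share is (1 - alpha) / 2, observed as 450/1000.
        double alphaHom = 1.0 - 2.0 * (450.0 / 1000.0);
        System.out.println(alphaHom);   // 0.1 (within floating-point rounding)
      }
    }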