Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

Merge pull request #433 from broadinstitute/eb_finish_chartl_likeliho…

…od_posteriors

Introducing the latest-and-greatest in genotyping: CalculatePosteriors.
  • Loading branch information...
commit 105f95166eef0c866db4c02efbd0e3a7d63f968a 2 parents 1581c7c + 5ae85ac
@eitanbanks eitanbanks authored
View
30 public/java/src/org/broadinstitute/sting/utils/MathUtils.java
@@ -1472,4 +1472,34 @@ else if ( x == 0.0 )
return sliceListByIndices(sampleIndicesWithoutReplacement(list.size(),N),list);
}
+ /**
+ * Return the likelihood of observing the counts of categories having sampled a population
+ * whose categorial frequencies are distributed according to a Dirichlet distribution
+ * @param dirichletParams - params of the prior dirichlet distribution
+ * @param dirichletSum - the sum of those parameters
+ * @param counts - the counts of observation in each category
+ * @param countSum - the sum of counts (number of trials)
+ * @return - associated likelihood
+ */
+ public static double dirichletMultinomial(final double[] dirichletParams, final double dirichletSum,
+ final int[] counts, final int countSum) {
+ if ( dirichletParams.length != counts.length ) {
+ throw new IllegalStateException("The number of dirichlet parameters must match the number of categories");
+ }
+ // todo -- lots of lnGammas here. At some point we can safely switch to x * ( ln(x) - 1)
+ double likelihood = log10MultinomialCoefficient(countSum,counts);
+ likelihood += log10Gamma(dirichletSum);
+ likelihood -= log10Gamma(dirichletSum+countSum);
+ for ( int idx = 0; idx < counts.length; idx++ ) {
+ likelihood += log10Gamma(counts[idx] + dirichletParams[idx]);
+ likelihood -= log10Gamma(dirichletParams[idx]);
+ }
+
+ return likelihood;
+ }
+
+ public static double dirichletMultinomial(double[] params, int[] counts) {
+ return dirichletMultinomial(params,sum(params),counts,(int) sum(counts));
+ }
+
}
View
14 public/java/src/org/broadinstitute/sting/utils/Utils.java
@@ -835,4 +835,18 @@ public static int longestCommonSuffix(final byte[] seq1, final byte[] seq2, fina
// don't perform array copies if we need to copy everything anyways
return ( trimFromFront == 0 && trimFromBack == 0 ) ? seq : Arrays.copyOfRange(seq, trimFromFront, seq.length - trimFromBack);
}
+
+ /**
+ * Simple wrapper for sticking elements of a int[] array into a List<Integer>
+ * @param ar - the array whose elements should be listified
+ * @return - a List<Integer> where each element has the same value as the corresponding index in @ar
+ */
+ public static List<Integer> listFromPrimitives(final int[] ar) {
+ final ArrayList<Integer> lst = new ArrayList<>(ar.length);
+ for ( final int d : ar ) {
+ lst.add(d);
+ }
+
+ return lst;
+ }
}
View
10 public/java/src/org/broadinstitute/sting/utils/variant/GATKVariantContextUtils.java
@@ -565,11 +565,11 @@ private static boolean likelihoodsAreUninformative(final double[] likelihoods) {
* @param newLikelihoods a vector of likelihoods to use if the method requires PLs, should be log10 likelihoods, cannot be null
* @param allelesToUse the alleles we are using for our subsetting
*/
- protected static void updateGenotypeAfterSubsetting(final List<Allele> originalGT,
- final GenotypeBuilder gb,
- final GenotypeAssignmentMethod assignmentMethod,
- final double[] newLikelihoods,
- final List<Allele> allelesToUse) {
+ public static void updateGenotypeAfterSubsetting(final List<Allele> originalGT,
+ final GenotypeBuilder gb,
+ final GenotypeAssignmentMethod assignmentMethod,
+ final double[] newLikelihoods,
+ final List<Allele> allelesToUse) {
gb.noAD();
switch ( assignmentMethod ) {
case SET_TO_NO_CALL:
View
334 public/java/test/org/broadinstitute/sting/utils/MathUtilsUnitTest.java
@@ -522,4 +522,338 @@ public void testMedian(final List<Comparable> values, final Comparable expected)
final Comparable actual = MathUtils.median(values);
Assert.assertEquals(actual, expected, "Failed with " + values);
}
+
+
+
+ // man. All this to test dirichlet.
+
+ private double[] unwrap(List<Double> stuff) {
+ double[] unwrapped = new double[stuff.size()];
+ int idx = 0;
+ for ( Double d : stuff ) {
+ unwrapped[idx++] = d == null ? 0.0 : d;
+ }
+
+ return unwrapped;
+ }
+
+ /**
+ * The PartitionGenerator generates all of the partitions of a number n, e.g.
+ * 5 + 0
+ * 4 + 1
+ * 3 + 2
+ * 3 + 1 + 1
+ * 2 + 2 + 1
+ * 2 + 1 + 1 + 1
+ * 1 + 1 + 1 + 1 + 1
+ *
+ * This is used to help enumerate the state space over which the Dirichlet-Multinomial is defined,
+ * to ensure that the distribution function is properly implemented
+ */
+ class PartitionGenerator implements Iterator<List<Integer>> {
+ // generate the partitions of an integer, each partition sorted numerically
+ int n;
+ List<Integer> a;
+ int y;
+ int k;
+ int state;
+ int x;
+ int l;
+
+ public PartitionGenerator(int n) {
+ this.n = n;
+ this.y = n - 1;
+ this.k = 1;
+ this.a = new ArrayList<Integer>();
+ for ( int i = 0; i < n; i++ ) {
+ this.a.add(i);
+ }
+ this.state = 0;
+ }
+
+ public void remove() { /* do nothing */ }
+
+ public boolean hasNext() { return ! ( this.k == 0 && state == 0 ); }
+
+ private String dataStr() {
+ return String.format("a = [%s] k = %d y = %d state = %d x = %d l = %d",
+ Utils.join(",",a), k, y, state, x, l);
+ }
+
+ public List<Integer> next() {
+ if ( this.state == 0 ) {
+ this.x = a.get(k-1)+1;
+ k -= 1;
+ this.state = 1;
+ }
+
+ if ( this.state == 1 ) {
+ while ( 2*x <= y ) {
+ this.a.set(k,x);
+ this.y -= x;
+ this.k++;
+ }
+ this.l = 1+this.k;
+ this.state = 2;
+ }
+
+ if ( this.state == 2 ) {
+ if ( x <= y ) {
+ this.a.set(k,x);
+ this.a.set(l,y);
+ x += 1;
+ y -= 1;
+ return this.a.subList(0, this.k + 2);
+ } else {
+ this.state =3;
+ }
+ }
+
+ if ( this.state == 3 ) {
+ this.a.set(k,x+y);
+ this.y = x + y - 1;
+ this.state = 0;
+ return a.subList(0, k + 1);
+ }
+
+ throw new IllegalStateException("Cannot get here");
+ }
+
+ public String toString() {
+ StringBuffer buf = new StringBuffer();
+ buf.append("{ ");
+ while ( hasNext() ) {
+ buf.append("[");
+ buf.append(Utils.join(",",next()));
+ buf.append("],");
+ }
+ buf.deleteCharAt(buf.lastIndexOf(","));
+ buf.append(" }");
+ return buf.toString();
+ }
+
+ }
+
+ /**
+ * NextCounts is the enumerator over the state space of the multinomial dirichlet.
+ *
+ * It filters the partition of the total sum to only those with a number of terms
+ * equal to the number of categories.
+ *
+ * It then generates all permutations of that partition.
+ *
+ * In so doing it enumerates over the full state space.
+ */
+ class NextCounts implements Iterator<int[]> {
+
+ private PartitionGenerator partitioner;
+ private int numCategories;
+ private int[] next;
+
+ public NextCounts(int numCategories, int totalCounts) {
+ partitioner = new PartitionGenerator(totalCounts);
+ this.numCategories = numCategories;
+ next = nextFromPartitioner();
+ }
+
+ public void remove() { /* do nothing */ }
+
+ public boolean hasNext() { return next != null; }
+
+ public int[] next() {
+ int[] toReturn = clone(next);
+ next = nextPermutation();
+ if ( next == null ) {
+ next = nextFromPartitioner();
+ }
+
+ return toReturn;
+ }
+
+ private int[] clone(int[] arr) {
+ int[] a = new int[arr.length];
+ for ( int idx = 0; idx < a.length ; idx ++) {
+ a[idx] = arr[idx];
+ }
+
+ return a;
+ }
+
+ private int[] nextFromPartitioner() {
+ if ( partitioner.hasNext() ) {
+ List<Integer> nxt = partitioner.next();
+ while ( partitioner.hasNext() && nxt.size() > numCategories ) {
+ nxt = partitioner.next();
+ }
+
+ if ( nxt.size() > numCategories ) {
+ return null;
+ } else {
+ int[] buf = new int[numCategories];
+ for ( int idx = 0; idx < nxt.size(); idx++ ) {
+ buf[idx] = nxt.get(idx);
+ }
+ Arrays.sort(buf);
+ return buf;
+ }
+ }
+
+ return null;
+ }
+
+ public int[] nextPermutation() {
+ return MathUtilsUnitTest.nextPermutation(next);
+ }
+
+ }
+
+ public static int[] nextPermutation(int[] next) {
+ // the counts can swap among each other. The int[] is originally in ascending order
+ // this generates the next array in lexicographic order descending
+
+ // locate the last occurrence where next[k] < next[k+1]
+ int gt = -1;
+ for ( int idx = 0; idx < next.length-1; idx++) {
+ if ( next[idx] < next[idx+1] ) {
+ gt = idx;
+ }
+ }
+
+ if ( gt == -1 ) {
+ return null;
+ }
+
+ int largestLessThan = gt+1;
+ for ( int idx = 1 + largestLessThan; idx < next.length; idx++) {
+ if ( next[gt] < next[idx] ) {
+ largestLessThan = idx;
+ }
+ }
+
+ int val = next[gt];
+ next[gt] = next[largestLessThan];
+ next[largestLessThan] = val;
+
+ // reverse the tail of the array
+ int[] newTail = new int[next.length-gt-1];
+ int ctr = 0;
+ for ( int idx = next.length-1; idx > gt; idx-- ) {
+ newTail[ctr++] = next[idx];
+ }
+
+ for ( int idx = 0; idx < newTail.length; idx++) {
+ next[gt+idx+1] = newTail[idx];
+ }
+
+ return next;
+ }
+
+
+ // before testing the dirichlet multinomial, we need to test the
+ // classes used to test the dirichlet multinomial
+
+ @Test
+ public void testPartitioner() {
+ int[] numsToTest = new int[]{1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20};
+ int[] expectedSizes = new int[]{1, 2, 3, 5, 7, 11, 15, 22, 30, 42, 56, 77, 101, 135, 176, 231, 297, 385, 490, 627};
+ for ( int testNum = 0; testNum < numsToTest.length; testNum++ ) {
+ PartitionGenerator gen = new PartitionGenerator(numsToTest[testNum]);
+ int size = 0;
+ while ( gen.hasNext() ) {
+ logger.debug(gen.dataStr());
+ size += 1;
+ gen.next();
+ }
+ Assert.assertEquals(size,expectedSizes[testNum],
+ String.format("Expected %d partitions, observed %s",expectedSizes[testNum],new PartitionGenerator(numsToTest[testNum]).toString()));
+ }
+ }
+
+ @Test
+ public void testNextPermutation() {
+ int[] arr = new int[]{1,2,3,4};
+ int[][] gens = new int[][] {
+ new int[]{1,2,3,4},
+ new int[]{1,2,4,3},
+ new int[]{1,3,2,4},
+ new int[]{1,3,4,2},
+ new int[]{1,4,2,3},
+ new int[]{1,4,3,2},
+ new int[]{2,1,3,4},
+ new int[]{2,1,4,3},
+ new int[]{2,3,1,4},
+ new int[]{2,3,4,1},
+ new int[]{2,4,1,3},
+ new int[]{2,4,3,1},
+ new int[]{3,1,2,4},
+ new int[]{3,1,4,2},
+ new int[]{3,2,1,4},
+ new int[]{3,2,4,1},
+ new int[]{3,4,1,2},
+ new int[]{3,4,2,1},
+ new int[]{4,1,2,3},
+ new int[]{4,1,3,2},
+ new int[]{4,2,1,3},
+ new int[]{4,2,3,1},
+ new int[]{4,3,1,2},
+ new int[]{4,3,2,1} };
+ for ( int gen = 0; gen < gens.length; gen ++ ) {
+ for ( int idx = 0; idx < 3; idx++ ) {
+ Assert.assertEquals(arr[idx],gens[gen][idx],
+ String.format("Error at generation %d, expected %s, observed %s",gen,Arrays.toString(gens[gen]),Arrays.toString(arr)));
+ }
+ arr = nextPermutation(arr);
+ }
+ }
+
+ private double[] addEpsilon(double[] counts) {
+ double[] d = new double[counts.length];
+ for ( int i = 0; i < counts.length; i ++ ) {
+ d[i] = counts[i] + 1e-3;
+ }
+ return d;
+ }
+
+ @Test
+ public void testDirichletMultinomial() {
+ List<double[]> testAlleles = Arrays.asList(
+ new double[]{80,240},
+ new double[]{1,10000},
+ new double[]{0,500},
+ new double[]{5140,20480},
+ new double[]{5000,800,200},
+ new double[]{6,3,1000},
+ new double[]{100,400,300,800},
+ new double[]{8000,100,20,80,2},
+ new double[]{90,20000,400,20,4,1280,720,1}
+ );
+
+ Assert.assertTrue(! Double.isInfinite(MathUtils.log10Gamma(1e-3)) && ! Double.isNaN(MathUtils.log10Gamma(1e-3)));
+
+ int[] numAlleleSampled = new int[]{2,5,10,20,25};
+ for ( double[] alleles : testAlleles ) {
+ for ( int count : numAlleleSampled ) {
+ // test that everything sums to one. Generate all multinomial draws
+ List<Double> likelihoods = new ArrayList<Double>(100000);
+ NextCounts generator = new NextCounts(alleles.length,count);
+ double maxLog = Double.MIN_VALUE;
+ //List<String> countLog = new ArrayList<String>(200);
+ while ( generator.hasNext() ) {
+ int[] thisCount = generator.next();
+ //countLog.add(Arrays.toString(thisCount));
+ Double likelihood = MathUtils.dirichletMultinomial(addEpsilon(alleles),thisCount);
+ Assert.assertTrue(! Double.isNaN(likelihood) && ! Double.isInfinite(likelihood),
+ String.format("Likelihood for counts %s and nAlleles %d was %s",
+ Arrays.toString(thisCount),alleles.length,Double.toString(likelihood)));
+ if ( likelihood > maxLog )
+ maxLog = likelihood;
+ likelihoods.add(likelihood);
+ }
+ //System.out.printf("%d likelihoods and max is (probability) %e\n",likelihoods.size(),Math.pow(10,maxLog));
+ Assert.assertEquals(MathUtils.sumLog10(unwrap(likelihoods)),1.0,1e-7,
+ String.format("Counts %d and alleles %d have nLikelihoods %d. \n Counts: %s",
+ count,alleles.length,likelihoods.size(), "NODEBUG"/*,countLog*/));
+ }
+ }
+ }
}
View
7 public/java/test/org/broadinstitute/sting/utils/variant/GATKVariantContextUtilsUnitTest.java
@@ -32,6 +32,7 @@
import org.broadinstitute.sting.utils.Utils;
import org.broadinstitute.sting.utils.collections.Pair;
import org.broadinstitute.variant.variantcontext.*;
+import org.broadinstitute.variant.vcf.VCFConstants;
import org.testng.Assert;
import org.testng.annotations.BeforeSuite;
import org.testng.annotations.DataProvider;
@@ -56,11 +57,7 @@ public void setup() {
ATCATC = Allele.create("ATCATC");
}
- private Genotype makeG(String sample, Allele a1, Allele a2) {
- return GenotypeBuilder.create(sample, Arrays.asList(a1, a2));
- }
-
- private Genotype makeG(String sample, Allele a1, Allele a2, double log10pError, double... pls) {
+ private Genotype makeG(String sample, Allele a1, Allele a2, double log10pError, int... pls) {
return new GenotypeBuilder(sample, Arrays.asList(a1, a2)).log10PError(log10pError).PL(pls).make();
}
Please sign in to comment.
Something went wrong with that request. Please try again.