diff --git a/public/java/src/org/broadinstitute/sting/utils/MathUtils.java b/public/java/src/org/broadinstitute/sting/utils/MathUtils.java index bfae7e94c2..8697fcab6b 100644 --- a/public/java/src/org/broadinstitute/sting/utils/MathUtils.java +++ b/public/java/src/org/broadinstitute/sting/utils/MathUtils.java @@ -1472,4 +1472,34 @@ public static List randomSubset(final List list, final int N) { return sliceListByIndices(sampleIndicesWithoutReplacement(list.size(),N),list); } + /** + * Return the likelihood of observing the counts of categories having sampled a population + * whose categorial frequencies are distributed according to a Dirichlet distribution + * @param dirichletParams - params of the prior dirichlet distribution + * @param dirichletSum - the sum of those parameters + * @param counts - the counts of observation in each category + * @param countSum - the sum of counts (number of trials) + * @return - associated likelihood + */ + public static double dirichletMultinomial(final double[] dirichletParams, final double dirichletSum, + final int[] counts, final int countSum) { + if ( dirichletParams.length != counts.length ) { + throw new IllegalStateException("The number of dirichlet parameters must match the number of categories"); + } + // todo -- lots of lnGammas here. At some point we can safely switch to x * ( ln(x) - 1) + double likelihood = log10MultinomialCoefficient(countSum,counts); + likelihood += log10Gamma(dirichletSum); + likelihood -= log10Gamma(dirichletSum+countSum); + for ( int idx = 0; idx < counts.length; idx++ ) { + likelihood += log10Gamma(counts[idx] + dirichletParams[idx]); + likelihood -= log10Gamma(dirichletParams[idx]); + } + + return likelihood; + } + + public static double dirichletMultinomial(double[] params, int[] counts) { + return dirichletMultinomial(params,sum(params),counts,(int) sum(counts)); + } + } diff --git a/public/java/src/org/broadinstitute/sting/utils/Utils.java b/public/java/src/org/broadinstitute/sting/utils/Utils.java index 75bd6a3d12..5cb1410746 100644 --- a/public/java/src/org/broadinstitute/sting/utils/Utils.java +++ b/public/java/src/org/broadinstitute/sting/utils/Utils.java @@ -835,4 +835,18 @@ public static byte[] trimArray(final byte[] seq, final int trimFromFront, final // don't perform array copies if we need to copy everything anyways return ( trimFromFront == 0 && trimFromBack == 0 ) ? seq : Arrays.copyOfRange(seq, trimFromFront, seq.length - trimFromBack); } + + /** + * Simple wrapper for sticking elements of a int[] array into a List + * @param ar - the array whose elements should be listified + * @return - a List where each element has the same value as the corresponding index in @ar + */ + public static List listFromPrimitives(final int[] ar) { + final ArrayList lst = new ArrayList<>(ar.length); + for ( final int d : ar ) { + lst.add(d); + } + + return lst; + } } diff --git a/public/java/src/org/broadinstitute/sting/utils/variant/GATKVariantContextUtils.java b/public/java/src/org/broadinstitute/sting/utils/variant/GATKVariantContextUtils.java index 11cd27a9fe..03bb9763c8 100644 --- a/public/java/src/org/broadinstitute/sting/utils/variant/GATKVariantContextUtils.java +++ b/public/java/src/org/broadinstitute/sting/utils/variant/GATKVariantContextUtils.java @@ -565,11 +565,11 @@ private static boolean likelihoodsAreUninformative(final double[] likelihoods) { * @param newLikelihoods a vector of likelihoods to use if the method requires PLs, should be log10 likelihoods, cannot be null * @param allelesToUse the alleles we are using for our subsetting */ - protected static void updateGenotypeAfterSubsetting(final List originalGT, - final GenotypeBuilder gb, - final GenotypeAssignmentMethod assignmentMethod, - final double[] newLikelihoods, - final List allelesToUse) { + public static void updateGenotypeAfterSubsetting(final List originalGT, + final GenotypeBuilder gb, + final GenotypeAssignmentMethod assignmentMethod, + final double[] newLikelihoods, + final List allelesToUse) { gb.noAD(); switch ( assignmentMethod ) { case SET_TO_NO_CALL: diff --git a/public/java/test/org/broadinstitute/sting/utils/MathUtilsUnitTest.java b/public/java/test/org/broadinstitute/sting/utils/MathUtilsUnitTest.java index a137975238..de049fe89f 100644 --- a/public/java/test/org/broadinstitute/sting/utils/MathUtilsUnitTest.java +++ b/public/java/test/org/broadinstitute/sting/utils/MathUtilsUnitTest.java @@ -522,4 +522,338 @@ public void testMedian(final List values, final Comparable expected) final Comparable actual = MathUtils.median(values); Assert.assertEquals(actual, expected, "Failed with " + values); } + + + + // man. All this to test dirichlet. + + private double[] unwrap(List stuff) { + double[] unwrapped = new double[stuff.size()]; + int idx = 0; + for ( Double d : stuff ) { + unwrapped[idx++] = d == null ? 0.0 : d; + } + + return unwrapped; + } + + /** + * The PartitionGenerator generates all of the partitions of a number n, e.g. + * 5 + 0 + * 4 + 1 + * 3 + 2 + * 3 + 1 + 1 + * 2 + 2 + 1 + * 2 + 1 + 1 + 1 + * 1 + 1 + 1 + 1 + 1 + * + * This is used to help enumerate the state space over which the Dirichlet-Multinomial is defined, + * to ensure that the distribution function is properly implemented + */ + class PartitionGenerator implements Iterator> { + // generate the partitions of an integer, each partition sorted numerically + int n; + List a; + int y; + int k; + int state; + int x; + int l; + + public PartitionGenerator(int n) { + this.n = n; + this.y = n - 1; + this.k = 1; + this.a = new ArrayList(); + for ( int i = 0; i < n; i++ ) { + this.a.add(i); + } + this.state = 0; + } + + public void remove() { /* do nothing */ } + + public boolean hasNext() { return ! ( this.k == 0 && state == 0 ); } + + private String dataStr() { + return String.format("a = [%s] k = %d y = %d state = %d x = %d l = %d", + Utils.join(",",a), k, y, state, x, l); + } + + public List next() { + if ( this.state == 0 ) { + this.x = a.get(k-1)+1; + k -= 1; + this.state = 1; + } + + if ( this.state == 1 ) { + while ( 2*x <= y ) { + this.a.set(k,x); + this.y -= x; + this.k++; + } + this.l = 1+this.k; + this.state = 2; + } + + if ( this.state == 2 ) { + if ( x <= y ) { + this.a.set(k,x); + this.a.set(l,y); + x += 1; + y -= 1; + return this.a.subList(0, this.k + 2); + } else { + this.state =3; + } + } + + if ( this.state == 3 ) { + this.a.set(k,x+y); + this.y = x + y - 1; + this.state = 0; + return a.subList(0, k + 1); + } + + throw new IllegalStateException("Cannot get here"); + } + + public String toString() { + StringBuffer buf = new StringBuffer(); + buf.append("{ "); + while ( hasNext() ) { + buf.append("["); + buf.append(Utils.join(",",next())); + buf.append("],"); + } + buf.deleteCharAt(buf.lastIndexOf(",")); + buf.append(" }"); + return buf.toString(); + } + + } + + /** + * NextCounts is the enumerator over the state space of the multinomial dirichlet. + * + * It filters the partition of the total sum to only those with a number of terms + * equal to the number of categories. + * + * It then generates all permutations of that partition. + * + * In so doing it enumerates over the full state space. + */ + class NextCounts implements Iterator { + + private PartitionGenerator partitioner; + private int numCategories; + private int[] next; + + public NextCounts(int numCategories, int totalCounts) { + partitioner = new PartitionGenerator(totalCounts); + this.numCategories = numCategories; + next = nextFromPartitioner(); + } + + public void remove() { /* do nothing */ } + + public boolean hasNext() { return next != null; } + + public int[] next() { + int[] toReturn = clone(next); + next = nextPermutation(); + if ( next == null ) { + next = nextFromPartitioner(); + } + + return toReturn; + } + + private int[] clone(int[] arr) { + int[] a = new int[arr.length]; + for ( int idx = 0; idx < a.length ; idx ++) { + a[idx] = arr[idx]; + } + + return a; + } + + private int[] nextFromPartitioner() { + if ( partitioner.hasNext() ) { + List nxt = partitioner.next(); + while ( partitioner.hasNext() && nxt.size() > numCategories ) { + nxt = partitioner.next(); + } + + if ( nxt.size() > numCategories ) { + return null; + } else { + int[] buf = new int[numCategories]; + for ( int idx = 0; idx < nxt.size(); idx++ ) { + buf[idx] = nxt.get(idx); + } + Arrays.sort(buf); + return buf; + } + } + + return null; + } + + public int[] nextPermutation() { + return MathUtilsUnitTest.nextPermutation(next); + } + + } + + public static int[] nextPermutation(int[] next) { + // the counts can swap among each other. The int[] is originally in ascending order + // this generates the next array in lexicographic order descending + + // locate the last occurrence where next[k] < next[k+1] + int gt = -1; + for ( int idx = 0; idx < next.length-1; idx++) { + if ( next[idx] < next[idx+1] ) { + gt = idx; + } + } + + if ( gt == -1 ) { + return null; + } + + int largestLessThan = gt+1; + for ( int idx = 1 + largestLessThan; idx < next.length; idx++) { + if ( next[gt] < next[idx] ) { + largestLessThan = idx; + } + } + + int val = next[gt]; + next[gt] = next[largestLessThan]; + next[largestLessThan] = val; + + // reverse the tail of the array + int[] newTail = new int[next.length-gt-1]; + int ctr = 0; + for ( int idx = next.length-1; idx > gt; idx-- ) { + newTail[ctr++] = next[idx]; + } + + for ( int idx = 0; idx < newTail.length; idx++) { + next[gt+idx+1] = newTail[idx]; + } + + return next; + } + + + // before testing the dirichlet multinomial, we need to test the + // classes used to test the dirichlet multinomial + + @Test + public void testPartitioner() { + int[] numsToTest = new int[]{1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20}; + int[] expectedSizes = new int[]{1, 2, 3, 5, 7, 11, 15, 22, 30, 42, 56, 77, 101, 135, 176, 231, 297, 385, 490, 627}; + for ( int testNum = 0; testNum < numsToTest.length; testNum++ ) { + PartitionGenerator gen = new PartitionGenerator(numsToTest[testNum]); + int size = 0; + while ( gen.hasNext() ) { + logger.debug(gen.dataStr()); + size += 1; + gen.next(); + } + Assert.assertEquals(size,expectedSizes[testNum], + String.format("Expected %d partitions, observed %s",expectedSizes[testNum],new PartitionGenerator(numsToTest[testNum]).toString())); + } + } + + @Test + public void testNextPermutation() { + int[] arr = new int[]{1,2,3,4}; + int[][] gens = new int[][] { + new int[]{1,2,3,4}, + new int[]{1,2,4,3}, + new int[]{1,3,2,4}, + new int[]{1,3,4,2}, + new int[]{1,4,2,3}, + new int[]{1,4,3,2}, + new int[]{2,1,3,4}, + new int[]{2,1,4,3}, + new int[]{2,3,1,4}, + new int[]{2,3,4,1}, + new int[]{2,4,1,3}, + new int[]{2,4,3,1}, + new int[]{3,1,2,4}, + new int[]{3,1,4,2}, + new int[]{3,2,1,4}, + new int[]{3,2,4,1}, + new int[]{3,4,1,2}, + new int[]{3,4,2,1}, + new int[]{4,1,2,3}, + new int[]{4,1,3,2}, + new int[]{4,2,1,3}, + new int[]{4,2,3,1}, + new int[]{4,3,1,2}, + new int[]{4,3,2,1} }; + for ( int gen = 0; gen < gens.length; gen ++ ) { + for ( int idx = 0; idx < 3; idx++ ) { + Assert.assertEquals(arr[idx],gens[gen][idx], + String.format("Error at generation %d, expected %s, observed %s",gen,Arrays.toString(gens[gen]),Arrays.toString(arr))); + } + arr = nextPermutation(arr); + } + } + + private double[] addEpsilon(double[] counts) { + double[] d = new double[counts.length]; + for ( int i = 0; i < counts.length; i ++ ) { + d[i] = counts[i] + 1e-3; + } + return d; + } + + @Test + public void testDirichletMultinomial() { + List testAlleles = Arrays.asList( + new double[]{80,240}, + new double[]{1,10000}, + new double[]{0,500}, + new double[]{5140,20480}, + new double[]{5000,800,200}, + new double[]{6,3,1000}, + new double[]{100,400,300,800}, + new double[]{8000,100,20,80,2}, + new double[]{90,20000,400,20,4,1280,720,1} + ); + + Assert.assertTrue(! Double.isInfinite(MathUtils.log10Gamma(1e-3)) && ! Double.isNaN(MathUtils.log10Gamma(1e-3))); + + int[] numAlleleSampled = new int[]{2,5,10,20,25}; + for ( double[] alleles : testAlleles ) { + for ( int count : numAlleleSampled ) { + // test that everything sums to one. Generate all multinomial draws + List likelihoods = new ArrayList(100000); + NextCounts generator = new NextCounts(alleles.length,count); + double maxLog = Double.MIN_VALUE; + //List countLog = new ArrayList(200); + while ( generator.hasNext() ) { + int[] thisCount = generator.next(); + //countLog.add(Arrays.toString(thisCount)); + Double likelihood = MathUtils.dirichletMultinomial(addEpsilon(alleles),thisCount); + Assert.assertTrue(! Double.isNaN(likelihood) && ! Double.isInfinite(likelihood), + String.format("Likelihood for counts %s and nAlleles %d was %s", + Arrays.toString(thisCount),alleles.length,Double.toString(likelihood))); + if ( likelihood > maxLog ) + maxLog = likelihood; + likelihoods.add(likelihood); + } + //System.out.printf("%d likelihoods and max is (probability) %e\n",likelihoods.size(),Math.pow(10,maxLog)); + Assert.assertEquals(MathUtils.sumLog10(unwrap(likelihoods)),1.0,1e-7, + String.format("Counts %d and alleles %d have nLikelihoods %d. \n Counts: %s", + count,alleles.length,likelihoods.size(), "NODEBUG"/*,countLog*/)); + } + } + } } diff --git a/public/java/test/org/broadinstitute/sting/utils/variant/GATKVariantContextUtilsUnitTest.java b/public/java/test/org/broadinstitute/sting/utils/variant/GATKVariantContextUtilsUnitTest.java index 220e64f7de..575fe49361 100644 --- a/public/java/test/org/broadinstitute/sting/utils/variant/GATKVariantContextUtilsUnitTest.java +++ b/public/java/test/org/broadinstitute/sting/utils/variant/GATKVariantContextUtilsUnitTest.java @@ -32,6 +32,7 @@ import org.broadinstitute.sting.utils.Utils; import org.broadinstitute.sting.utils.collections.Pair; import org.broadinstitute.variant.variantcontext.*; +import org.broadinstitute.variant.vcf.VCFConstants; import org.testng.Assert; import org.testng.annotations.BeforeSuite; import org.testng.annotations.DataProvider; @@ -56,11 +57,7 @@ public void setup() { ATCATC = Allele.create("ATCATC"); } - private Genotype makeG(String sample, Allele a1, Allele a2) { - return GenotypeBuilder.create(sample, Arrays.asList(a1, a2)); - } - - private Genotype makeG(String sample, Allele a1, Allele a2, double log10pError, double... pls) { + private Genotype makeG(String sample, Allele a1, Allele a2, double log10pError, int... pls) { return new GenotypeBuilder(sample, Arrays.asList(a1, a2)).log10PError(log10pError).PL(pls).make(); }