Skip to content

Commit

Permalink
review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
akiezun committed Apr 14, 2015
1 parent 27003ee commit ef80890
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 66 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.broadinstitute.hellbender.tools.walkers.vqsr;

import com.google.common.annotations.VisibleForTesting;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.broadinstitute.hellbender.exceptions.GATKException;
Expand All @@ -11,29 +12,36 @@
import java.util.*;

/*
* Tranche in VQSR.
* Package-private because it's not usable outside.
* Represents a truth sensitivity tranche in VQSR.
* (Package-private because it's not usable outside.)
*/
final class Tranche {
public final double targetTruthSensitivity;
public final double minVQSLod; //minimum value of VQSLOD in this tranche
public final double knownTiTv; //titv value of known sites in this tranche
public final double novelTiTv; //titv value of novel sites in this tranche
public final int numKnown; //number of known sites in this tranche
public final int numNovel; //number of novel sites in this tranche
public final String name; //Name of the tranche
public final VariantRecalibratorArgumentCollection.Mode model; //this is a SNP VQSR tranche or Indel tranche?

private static final String DEFAULT_TRANCHE_NAME = "anonymous";
private static final String COMMENT_STRING = "#";
private static final String VALUE_SEPARATOR = ",";
private static final int EXPECTED_COLUMN_COUNT = 11;

static final Comparator<Tranche> TRUTH_SENSITIVITY_ORDER = (tranche1, tranche2) -> Double.compare(tranche1.targetTruthSensitivity, tranche2.targetTruthSensitivity);

private final static Logger logger = LogManager.getLogger(Tranche.class);

//Note: visibility is set to package-local for testing
final double targetTruthSensitivity;
final double minVQSLod; //minimum value of VQSLOD in this tranche
final double knownTiTv; //titv value of known sites in this tranche
final double novelTiTv; //titv value of novel sites in this tranche
final int numKnown; //number of known sites in this tranche
final int numNovel; //number of novel sites in this tranche
final String name; //Name of the tranche
private final VariantRecalibratorArgumentCollection.Mode model; //this is a SNP VQSR tranche or Indel tranche?

private final int accessibleTruthSites;
private final int callsAtTruthSites;

public Tranche(double targetTruthSensitivity, double minVQSLod, int numKnown, double knownTiTv, int numNovel, double novelTiTv, int accessibleTruthSites, int callsAtTruthSites, VariantRecalibratorArgumentCollection.Mode model) {
this(targetTruthSensitivity, minVQSLod, numKnown, knownTiTv, numNovel, novelTiTv, accessibleTruthSites, callsAtTruthSites, model, "anonymous");
}

public Tranche(double targetTruthSensitivity, double minVQSLod, int numKnown, double knownTiTv, int numNovel, double novelTiTv, int accessibleTruthSites, int callsAtTruthSites, VariantRecalibratorArgumentCollection.Mode model, String name) {
if ( targetTruthSensitivity < 0.0 || targetTruthSensitivity > 100.0) {
throw new UserException("Target FDR is unreasonable " + targetTruthSensitivity);
throw new GATKException("Target FDR is unreasonable " + targetTruthSensitivity);
}

if ( numKnown < 0 || numNovel < 0) {
Expand All @@ -57,8 +65,6 @@ public Tranche(double targetTruthSensitivity, double minVQSLod, int numKnown, do
this.callsAtTruthSites = callsAtTruthSites;
}

public static final Comparator<Tranche> TRUTH_SENSITIVITY_ORDER = (tranche1, tranche2) -> Double.compare(tranche1.targetTruthSensitivity, tranche2.targetTruthSensitivity);

@Override
public String toString() {
return String.format("Tranche targetTruthSensitivity=%.2f minVQSLod=%.4f known=(%d @ %.4f) novel=(%d @ %.4f) truthSites(%d accessible, %d called), name=%s]",
Expand All @@ -67,26 +73,42 @@ public String toString() {

private static double getRequiredDouble(Map<String,String> bindings, String key ) {
if ( bindings.containsKey(key) ) {
return Double.valueOf(bindings.get(key));
try {
return Double.valueOf(bindings.get(key));
} catch (NumberFormatException e){
throw new UserException.MalformedFile("Malformed tranches file. Invalid value for key " + key);
}
} else {
throw new UserException.MalformedFile("Malformed tranches file. Missing required key " + key);
}
}

private static double getOptionalDouble(Map<String,String> bindings, String key, double defaultValue ) {
return Double.valueOf(bindings.getOrDefault(key, String.valueOf(defaultValue)));
try{
return Double.valueOf(bindings.getOrDefault(key, String.valueOf(defaultValue)));
} catch (NumberFormatException e){
throw new UserException.MalformedFile("Malformed tranches file. Invalid value for key " + key);
}
}

private static int getRequiredInteger(Map<String,String> bindings, String key) {
if ( bindings.containsKey(key) ) {
return Integer.valueOf(bindings.get(key));
try{
return Integer.valueOf(bindings.get(key));
} catch (NumberFormatException e){
throw new UserException.MalformedFile("Malformed tranches file. Invalid value for key " + key);
}
} else {
throw new UserException.MalformedFile("Malformed tranches file. Missing required key " + key);
}
}

private static int getOptionalInteger(Map<String,String> bindings, String key, int defaultValue) {
return Integer.valueOf(bindings.getOrDefault(key, String.valueOf(defaultValue)));
try{
return Integer.valueOf(bindings.getOrDefault(key, String.valueOf(defaultValue)));
} catch (NumberFormatException e){
throw new UserException.MalformedFile("Malformed tranches file. Invalid value for key " + key);
}
}

/**
Expand All @@ -99,14 +121,14 @@ public static List<Tranche> readTranches(File f) throws IOException{

try (XReadLines xrl = new XReadLines(f) ) {
for (final String line : xrl) {
if (line.startsWith("#")) { //comment
if (line.startsWith(COMMENT_STRING)) {
continue;
}

final String[] vals = line.split(",");
final String[] vals = line.split(VALUE_SEPARATOR);
if (header == null) { //reading the header
header = vals;
if (header.length != 11) {
if (header.length != EXPECTED_COLUMN_COUNT) {
throw new UserException.MalformedFile(f, "Expected 11 elements in header line " + line);
}
} else {
Expand All @@ -115,7 +137,9 @@ public static List<Tranche> readTranches(File f) throws IOException{
}

Map<String, String> bindings = new HashMap<>();
for (int i = 0; i < vals.length; i++) bindings.put(header[i], vals[i]);
for (int i = 0; i < vals.length; i++) {
bindings.put(header[i], vals[i]);
}
tranches.add(new Tranche(
getRequiredDouble(bindings, "targetTruthSensitivity"),
getRequiredDouble(bindings, "minVQSLod"),
Expand All @@ -135,38 +159,25 @@ public static List<Tranche> readTranches(File f) throws IOException{
}
}

protected final static Logger logger = LogManager.getLogger(Tranche.class);

// Code to determine FDR tranches for VariantDatum[]
public static abstract class SelectionMetric {
final String name;

public SelectionMetric(String name) {
this.name = name;
}

public String getName() { return name; }

public abstract double getThreshold(double tranche);
public abstract void calculateRunningMetric(List<VariantDatum> data);
public abstract double getRunningMetric(int i);
}

public static final class TruthSensitivityMetric extends SelectionMetric {
double[] runningSensitivity;
@VisibleForTesting
static final class TruthSensitivityMetric {
private final String name;
private double[] runningSensitivity;
private final int nTrueSites;

public TruthSensitivityMetric(int nTrueSites) {
super("TruthSensitivity");
this.name = "TruthSensitivity";
this.nTrueSites = nTrueSites;
}

@Override
public String getName(){
return name;
}

public double getThreshold(double tranche) {
return 1.0 - tranche/100.0; // tranche of 1 => 99% sensitivity target
}

@Override
public void calculateRunningMetric(List<VariantDatum> data) {
int nCalledAtTruth = 0;
runningSensitivity = new double[data.size()];
Expand All @@ -178,14 +189,12 @@ public void calculateRunningMetric(List<VariantDatum> data) {
}
}

@Override
public double getRunningMetric(int i) {
return runningSensitivity[i];
}

}

public static List<Tranche> findTranches( final List<VariantDatum> data, final double[] trancheThresholds, final SelectionMetric metric, final VariantRecalibratorArgumentCollection.Mode model) {
public static List<Tranche> findTranches( final List<VariantDatum> data, final double[] trancheThresholds, final TruthSensitivityMetric metric, final VariantRecalibratorArgumentCollection.Mode model) {
logger.info(String.format("Finding %d tranches for %d variants", trancheThresholds.length, data.size()));

Collections.sort(data, VariantDatum.VariantDatumLODComparator);
Expand All @@ -209,18 +218,18 @@ public static List<Tranche> findTranches( final List<VariantDatum> data, final d
return tranches;
}

private static Tranche findTranche( final List<VariantDatum> data, final SelectionMetric metric, final double trancheThreshold, final VariantRecalibratorArgumentCollection.Mode model ) {
logger.info(String.format(" Tranche threshold %.2f => selection metric threshold %.3f", trancheThreshold, metric.getThreshold(trancheThreshold)));
private static Tranche findTranche( final List<VariantDatum> data, final TruthSensitivityMetric metric, final double trancheThreshold, final VariantRecalibratorArgumentCollection.Mode model ) {
logger.debug(String.format(" Tranche threshold %.2f => selection metric threshold %.3f", trancheThreshold, metric.getThreshold(trancheThreshold)));

double metricThreshold = metric.getThreshold(trancheThreshold);
int n = data.size();
for ( int i = 0; i < n; i++ ) {
if ( metric.getRunningMetric(i) >= metricThreshold ) {
// we've found the largest group of variants with sensitivity >= our target truth sensitivity
Tranche t = trancheOfVariants(data, i, trancheThreshold, model);
logger.info(String.format(" Found tranche for %.3f: %.3f threshold starting with variant %d; running score is %.3f ",
logger.debug(String.format(" Found tranche for %.3f: %.3f threshold starting with variant %d; running score is %.3f ",
trancheThreshold, metricThreshold, i, metric.getRunningMetric(i)));
logger.info(String.format(" Tranche is %s", t));
logger.debug(String.format(" Tranche is %s", t));
return t;
}
}
Expand Down Expand Up @@ -254,7 +263,7 @@ private static Tranche trancheOfVariants( final List<VariantDatum> data, int min
int accessibleTruthSites = VariantDatum.countCallsAtTruth(data, Double.NEGATIVE_INFINITY);
int nCallsAtTruth = VariantDatum.countCallsAtTruth(data, minLod);

return new Tranche(ts, minLod, numKnown, knownTiTv, numNovel, novelTiTv, accessibleTruthSites, nCallsAtTruth, model);
return new Tranche(ts, minLod, numKnown, knownTiTv, numNovel, novelTiTv, accessibleTruthSites, nCallsAtTruth, model, DEFAULT_TRANCHE_NAME);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
*/
final class VariantRecalibratorArgumentCollection {

static public enum Mode {
public enum Mode {
SNP,
INDEL,
BOTH
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,27 @@
* doSomeWork(line);
* }
*
* For the love of god, please use this system for reading lines in a file.
* Please use this class for reading lines in a file.
*/
public final class XReadLines implements Iterator<String>, Iterable<String>, AutoCloseable {
private final BufferedReader in; // The stream we're reading from
private String nextLine = null; // Return value of next call to next()
private final boolean trimWhitespace;
private final String commentPrefix;

/**
* Opens the given file for reading lines.
* The file may be a text file or a gzipped text file (the distinction is made by the file extension).
* By default, it will trim whitespaces.
*/
public XReadLines(final File filename) {
this(filename, true);
}

/**
* Opens the given file for reading lines and optionally trim whitespaces.
* The file may be a text file or a gzipped text file (the distinction is made by the file extension).
*/
public XReadLines(final File filename, final boolean trimWhitespace) {
this(IOUtils.makeReaderMaybeGzipped(filename), trimWhitespace, null);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,26 +59,23 @@ public final List<Tranche> read(File f) throws IOException{
return Tranche.readTranches(f);
}

private static void assertTranchesAreTheSame(List<Tranche> newFormat, List<Tranche> oldFormat, boolean completeP, boolean includeName) {
private static void assertTranchesAreTheSame(List<Tranche> newFormat, List<Tranche> oldFormat) {
Assert.assertEquals(oldFormat.size(), newFormat.size());
for ( int i = 0; i < newFormat.size(); i++ ) {
Tranche n = newFormat.get(i);
Tranche o = oldFormat.get(i);
Assert.assertEquals(n.targetTruthSensitivity, o.targetTruthSensitivity, 1e-3);
Assert.assertEquals(n.numNovel, o.numNovel);
Assert.assertEquals(n.novelTiTv, o.novelTiTv, 1e-3);
if ( includeName )
Assert.assertEquals(n.name, o.name);
if ( completeP ) {
Assert.assertEquals(n.numKnown, o.numKnown);
Assert.assertEquals(n.knownTiTv, o.knownTiTv, 1e-3);
}
Assert.assertEquals(n.numKnown, o.numKnown);
Assert.assertEquals(n.name, o.name);
Assert.assertEquals(n.knownTiTv, o.knownTiTv, 1e-3);
}
}

private static List<Tranche> findMyTranches(ArrayList<VariantDatum> vd, double[] tranches) {
final int nCallsAtTruth = VariantDatum.countCallsAtTruth( vd, Double.NEGATIVE_INFINITY );
final Tranche.SelectionMetric metric = new Tranche.TruthSensitivityMetric( nCallsAtTruth );
final Tranche.TruthSensitivityMetric metric = new Tranche.TruthSensitivityMetric( nCallsAtTruth );
return Tranche.findTranches(vd, tranches, metric, VariantRecalibratorArgumentCollection.Mode.SNP);
}

Expand All @@ -87,7 +84,7 @@ public final void testFindTranches1() throws IOException {
ArrayList<VariantDatum> vd = readData();
List<Tranche> tranches = findMyTranches(vd, TRUTH_SENSITIVITY_CUTS);
tranches.sort(Tranche.TRUTH_SENSITIVITY_ORDER);
assertTranchesAreTheSame(read(EXPECTED_TRANCHES_NEW), tranches, true, false);
assertTranchesAreTheSame(read(EXPECTED_TRANCHES_NEW), tranches);
}

@Test(expectedExceptions = {UserException.class})
Expand Down

0 comments on commit ef80890

Please sign in to comment.