Skip to content

Commit

Permalink
Merge 1ac7744 into b4b832f
Browse files Browse the repository at this point in the history
  • Loading branch information
cmungall committed Mar 2, 2017
2 parents b4b832f + 1ac7744 commit 055bc6a
Show file tree
Hide file tree
Showing 10 changed files with 634 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ public MatchSet findMatchProfileImpl(ProfileQuery q) {
// any node which has an off query parent is discounted
//EWAHCompressedBitmap maskedTargetProfileBM = nodesHtBM.and(queryBlanketProfileBM);

LOG.info("TARGET PROFILE for "+itemId+" "+nodesHtBM);
//LOG.info("TARGET PROFILE for "+itemId+" "+nodesHtBM);

// cumulative log-probability
double logp = 0.0;
Expand Down Expand Up @@ -398,7 +398,7 @@ public MatchSet findMatchProfileImpl(ProfileQuery q) {
indArr[n] = itemId;
sumOfProbs += p;
n++;
LOG.info("logp for "+itemId+" = "+logp+" sumOfLogProbs="+sumOfProbs);
//LOG.info("logp for "+itemId+" = "+logp+" sumOfLogProbs="+sumOfProbs);
}
for (n = 0; n<pvector.length; n++) {
double p = pvector[n] / sumOfProbs;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package org.monarchinitiative.owlsim.compute.matcher.impl;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import javax.inject.Inject;
Expand Down Expand Up @@ -36,18 +39,44 @@
public class NaiveBayesFixedWeightTwoStateProfileMatcher extends AbstractProfileMatcher implements ProfileMatcher {

// One shared, immutable logger for the class instead of a mutable per-instance field.
private static final Logger LOG = Logger.getLogger(NaiveBayesFixedWeightTwoStateProfileMatcher.class);

// set this to 1 or more for frequency-aware matching;
// a value of 0 defaults to frequency-unaware
private int kLeastFrequent = 0;


@Deprecated
private double defaultFalsePositiveRate = 0.002; // alpha

@Deprecated
private double defaultFalseNegativeRate = 0.10; // beta

/**
* A tuple of (weight, Classes)
*
*/
private class WeightedTypesBM {
// bitmap representing a set of classes assumed to be on
final EWAHCompressedBitmap typesBM;

// probability of the state in which all such classes are on
final double weight;

public WeightedTypesBM(EWAHCompressedBitmap typesBM, Double weight) {
super();
this.typesBM = typesBM;
this.weight = weight;
}
}

// TODO - replace when testing is over
//private double[] defaultFalsePositiveRateArr = new double[]{0.002};
//private double[] defaultFalseNegativeRateArr = new double[] {0.10};
private double[] defaultFalsePositiveRateArr = new double[]{1e-10,0.0005,0.001,0.005,0.01};
private double[] defaultFalseNegativeRateArr = new double[] {1e-10,0.005,0.01,0.05,0.1,0.2,0.4,0.8,0.9};

// maps a pair of (Individual, InterpretationIndex) to a set of inferred (self, direct, indirect) types
private Map<Integer,Map<Integer,WeightedTypesBM>> individualToInterpretationToTypesBM = new HashMap<>();

@Inject
protected NaiveBayesFixedWeightTwoStateProfileMatcher(BMKnowledgeBase kb) {
Expand All @@ -70,8 +99,31 @@ public boolean isUseBlanket() {
public String getShortName() {
return "naive-bayes-fixed-weight-two-state";
}



/**
 * Returns the current k, i.e. how many of the least frequent weighted
 * annotations are used in the probabilistic calculation.
 *
 * @return the kLeastFrequent setting; 0 means frequency-unaware
 */
public int getkLeastFrequent() {
    return this.kLeastFrequent;
}

/**
 * Sets k, the number of least frequent weighted annotations per individual
 * that are treated probabilistically.
 *
 * The default for this should be 0. When 0, the behavior is as for frequency unaware
 * (i.e. every instance-class association with frequency info will be treated as a
 * normal instance-class association).
 *
 * When k >= 1, the matcher makes use of the k least frequent annotations in the
 * probabilistic calculation. It enumerates 2^k on/off combinations per individual,
 * so the cost grows exponentially with k.
 *
 * Changing k discards the cache of per-individual interpretations, because the
 * cached entries were computed for the previous value of k.
 *
 * @param kLeastFrequent the kLeastFrequent to set; must not be negative
 * @throws IllegalArgumentException if kLeastFrequent is negative (a negative k
 *         would yield zero combinations and therefore all-zero probabilities)
 */
public void setkLeastFrequent(int kLeastFrequent) {
    if (kLeastFrequent < 0) {
        throw new IllegalArgumentException(
                "kLeastFrequent must be >= 0, got " + kLeastFrequent);
    }
    // reset cache
    individualToInterpretationToTypesBM = new HashMap<>();
    this.kLeastFrequent = kLeastFrequent;
}

/**
* Extends the query profile - for every node c, all the direct parents of c are in
* the query profile, then add c to the query profile.
*
Expand Down Expand Up @@ -132,50 +184,82 @@ public MatchSet findMatchProfileImpl(ProfileQuery q) {
double pvector[] = new double[indIds.size()];
String indArr[] = new String[indIds.size()];
int n=0;


for (String itemId : indIds) {
EWAHCompressedBitmap targetProfileBM = knowledgeBase.getTypesBM(itemId);
// any node which has an off query parent is discounted
targetProfileBM = targetProfileBM.and(queryBlanketProfileBM);
LOG.debug("TARGET PROFILE for "+itemId+" "+targetProfileBM);


// two state model.
// mapping to Bauer et al: these correspond to mxy1, x=Q, y=H/T
int numInQueryAndInTarget = queryProfileBM.andCardinality(targetProfileBM);
int numInQueryAndNOTInTarget = queryProfileBM.andNotCardinality(targetProfileBM);
int numNOTInQueryAndInTarget = targetProfileBM.andNotCardinality(queryProfileBM);
int numNOTInQueryAndNOTInTarget =
numClassesConsidered - (numInQueryAndInTarget + numInQueryAndNOTInTarget + numNOTInQueryAndInTarget);

double p = 0.0;
// TODO: optimize this
// integrate over a Dirichlet prior for alpha & beta, rather than gridsearch
// this can be done closed-form
for (double fnr : defaultFalseNegativeRateArr) {
for (double fpr : defaultFalsePositiveRateArr) {

double pQ1T1 = Math.pow(1-fnr, numInQueryAndInTarget);
double pQ0T1 = Math.pow(fnr, numNOTInQueryAndInTarget);
double pQ1T0 = Math.pow(fpr, numInQueryAndNOTInTarget);
double pQ0T0 = Math.pow(1-fpr, numNOTInQueryAndNOTInTarget);



//LOG.debug("pQ1T1 = "+(1-fnr)+" ^ "+ numInQueryAndInTarget+" = "+pQ1T1);
//LOG.debug("pQ0T1 = "+(fnr)+" ^ "+ numNOTInQueryAndInTarget+" = "+pQ0T1);
//LOG.debug("pQ1T0 = "+(fpr)+" ^ "+ numInQueryAndNOTInTarget+" = "+pQ1T0);
//LOG.debug("pQ0T0 = "+(1-fpr)+" ^ "+ numNOTInQueryAndNOTInTarget+" = "+pQ0T0);
//TODO: optimization. We can precalculate the logs for different integers
p +=
Math.exp(Math.log(pQ1T1) + Math.log(pQ0T1) + Math.log(pQ1T0) + Math.log(pQ0T0));

}
}
pvector[n] = p;
indArr[n] = itemId;
sumOfProbs += p;

int effectiveK = kLeastFrequent;
int twoToTheK = (int) Math.pow(2, kLeastFrequent);
int numWeightedTypes = knowledgeBase.getDirectWeightedTypes(itemId).size();
if (numWeightedTypes < kLeastFrequent) {
twoToTheK = (int) Math.pow(2, numWeightedTypes);
effectiveK = numWeightedTypes;
}

double cumulativePr = 0;
for (int comboIndex = 0; comboIndex < twoToTheK; comboIndex++) {

Double comboPr = null;
EWAHCompressedBitmap targetProfileBM;
if (kLeastFrequent == 0) {
targetProfileBM = knowledgeBase.getTypesBM(itemId);
}
else {
WeightedTypesBM wtbm = getTypesFrequencyAware(itemId, comboIndex, effectiveK);
comboPr = wtbm.weight;
targetProfileBM = wtbm.typesBM;
}

// any node which has an off query parent is discounted
targetProfileBM = targetProfileBM.and(queryBlanketProfileBM);
LOG.debug("TARGET PROFILE for "+itemId+" "+targetProfileBM);


// two state model.
// mapping to Bauer et al: these correspond to mxy1, x=Q, y=H/T
int numInQueryAndInTarget = queryProfileBM.andCardinality(targetProfileBM);
int numInQueryAndNOTInTarget = queryProfileBM.andNotCardinality(targetProfileBM);
int numNOTInQueryAndInTarget = targetProfileBM.andNotCardinality(queryProfileBM);
int numNOTInQueryAndNOTInTarget =
numClassesConsidered - (numInQueryAndInTarget + numInQueryAndNOTInTarget + numNOTInQueryAndInTarget);

double p = 0.0;
// TODO: optimize this
// integrate over a Dirichlet prior for alpha & beta, rather than gridsearch
// this can be done closed-form
for (double fnr : defaultFalseNegativeRateArr) {
for (double fpr : defaultFalsePositiveRateArr) {

double pQ1T1 = Math.pow(1-fnr, numInQueryAndInTarget);
double pQ0T1 = Math.pow(fnr, numNOTInQueryAndInTarget);
double pQ1T0 = Math.pow(fpr, numInQueryAndNOTInTarget);
double pQ0T0 = Math.pow(1-fpr, numNOTInQueryAndNOTInTarget);



//LOG.debug("pQ1T1 = "+(1-fnr)+" ^ "+ numInQueryAndInTarget+" = "+pQ1T1);
//LOG.debug("pQ0T1 = "+(fnr)+" ^ "+ numNOTInQueryAndInTarget+" = "+pQ0T1);
//LOG.debug("pQ1T0 = "+(fpr)+" ^ "+ numInQueryAndNOTInTarget+" = "+pQ1T0);
//LOG.debug("pQ0T0 = "+(1-fpr)+" ^ "+ numNOTInQueryAndNOTInTarget+" = "+pQ0T0);
//TODO: optimization. We can precalculate the logs for different integers
p +=
Math.exp(Math.log(pQ1T1) + Math.log(pQ0T1) + Math.log(pQ1T0) + Math.log(pQ0T0));

}
}

if (comboPr != null) {
p *= comboPr;
}
cumulativePr += p;
}
pvector[n] = cumulativePr;
indArr[n] = itemId;

sumOfProbs += cumulativePr;
n++;
LOG.debug("p for "+itemId+" = "+p);
LOG.debug("p for "+itemId+" = "+cumulativePr);

}
for (n = 0; n<pvector.length; n++) {
double p = pvector[n] / sumOfProbs;
Expand All @@ -186,6 +270,61 @@ public MatchSet findMatchProfileImpl(ProfileQuery q) {
mp.sortMatches();
return mp;
}

/**
 * Returns the inferred types and probability of one interpretation of an individual.
 *
 * For a value of n such that 0 <= n < 2^k, n represents a particular combination of
 * k boolean values t1, ..., tk, each representing the truth value for whether the
 * class t_i is indexed for the given individual. t1..tk are the k least frequent
 * weighted annotations for this individual.
 *
 * Bit i of n decides the state of the i-th least frequent class (weight w_i / 100):
 * when the bit is 0 the class is switched off and the state probability is
 * multiplied by (1 - w_i); when the bit is 1 the class stays on and the probability
 * is multiplied by w_i.
 *
 * Results are cached per (individual, n). The cache key does not include
 * effectiveK; the cache is reset whenever kLeastFrequent is changed.
 *
 * @param itemId identifier of the individual
 * @param n index of the on/off combination
 * @param effectiveK number of least frequent weighted types to consider; must not
 *        exceed the number of weighted types of the individual
 * @return the inferred types for this combination together with its probability
 */
private WeightedTypesBM getTypesFrequencyAware(String itemId, int n, int effectiveK) {
    Integer iix = knowledgeBase.getIndividualIndex(itemId);
    Map<Integer, WeightedTypesBM> interpretations =
            individualToInterpretationToTypesBM.computeIfAbsent(iix, key -> new HashMap<>());
    WeightedTypesBM cached = interpretations.get(n);
    if (cached != null) {
        // use cached value
        return cached;
    }

    // default direct type map.
    // note that associations with frequency annotations are included here alongside
    // normal associations
    EWAHCompressedBitmap dtmap = knowledgeBase.getDirectTypesBM(itemId);

    // associations with frequency info
    // map is from ClassIndex -> Weight
    Map<Integer, Integer> wmap = knowledgeBase.getDirectWeightedTypes(itemId);

    // sort with least frequent first; Integer.compare avoids the overflow-prone
    // subtraction idiom
    List<Integer> sortedTypeIndices = new ArrayList<>(wmap.keySet());
    sortedTypeIndices.sort((i, j) -> Integer.compare(wmap.get(i), wmap.get(j)));

    List<Integer> switchedOff = new ArrayList<>();
    double pr = 1.0;
    for (int i = 0; i < effectiveK; i++) {
        Integer iClassIx = sortedTypeIndices.get(i);
        double w = wmap.get(iClassIx) / 100.0;
        if (((n >> i) & 1) == 0) {
            switchedOff.add(iClassIx);
            pr *= 1 - w;
        } else {
            pr *= w;
        }
    }

    // Set the mask bits in ascending index order rather than weight order.
    // NOTE(review): javaewah recommends increasing order for set(), and older
    // releases appear to ignore out-of-order bits -- confirm for the version in use.
    switchedOff.sort(null);
    EWAHCompressedBitmap mask = new EWAHCompressedBitmap();
    for (int classIx : switchedOff) {
        mask.set(classIx);
    }

    // xor removes the switched-off classes, given that weighted types are also
    // present in the direct type map (see note above)
    EWAHCompressedBitmap dtmapMasked = dtmap.xor(mask);
    EWAHCompressedBitmap inferredTypesBM = knowledgeBase.getSuperClassesBM(dtmapMasked);
    WeightedTypesBM wtbm = new WeightedTypesBM(inferredTypesBM, pr);
    interpretations.put(n, wtbm);
    return wtbm;
}

/**
* @return probability a query class is a false positive
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ public MatchSet findMatchProfileImpl(ProfileQuery q) throws IncoherentStateExcep
indArr[n] = itemId;
sumOfProbs += p;
n++;
LOG.info("p for "+itemId+" = "+p);
//LOG.info("p for "+itemId+" = "+p);
}
for (n = 0; n<pvector.length; n++) {
double p = pvector[n] / sumOfProbs;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,12 +144,13 @@ public interface BMKnowledgeBase {
public EWAHCompressedBitmap getDirectSubClassesBM(String classId);


/**
* @param classIds
* @return union of all superclasses (direct and indirect and equivalent) as a bitmap
*/
public EWAHCompressedBitmap getSubClassesBM(Set<String> classIds);
/**
 * NOTE(review): this comment previously said "superclasses", which contradicts the
 * method name; presumably a copy-paste slip -- confirm against the implementation.
 *
 * @param classIds the classes to look up
 * @return union of all subclasses (direct and indirect and equivalent) as a bitmap
 */
public EWAHCompressedBitmap getSubClassesBM(Set<String> classIds);


/**
* @param classIds
* @return union of all direct subclasses as a bitmap
Expand Down Expand Up @@ -179,7 +180,13 @@ public interface BMKnowledgeBase {
* @return union of all superclasses as a bitmap
*/
public EWAHCompressedBitmap getSuperClassesBM(Set<String> classIds);


/**
 * Bitmap-based variant of the superclass lookup.
 *
 * @param classesBM the classes to look up, given as a bitmap
 * @return union of all superclasses (direct and indirect and equivalent) as a bitmap
 */
public EWAHCompressedBitmap getSuperClassesBM(EWAHCompressedBitmap classesBM);

/**
* @param classIndex
* @return superclasses (direct and indirect and equivalent) of classId as bitmap
Expand All @@ -199,12 +206,18 @@ public interface BMKnowledgeBase {
*/
public EWAHCompressedBitmap getTypesBM(String id);

/**
* @param id - an individual
* @return direct types as bitmap
*/
public EWAHCompressedBitmap getDirectTypesBM(String id);

/**
 * Returns the direct types of an individual.
 *
 * @param id - an individual
 * @return direct types as bitmap
 */
public EWAHCompressedBitmap getDirectTypesBM(String id);

/**
 * Returns the direct types of an individual that carry a weight.
 *
 * @param id - an individual
 * @return map from type class index to weight, where {@code 0 < weight <= 100}
 *         and probability = weight/100
 */
public Map<Integer, Integer> getDirectWeightedTypes(String id);

/**
* @param itemId
* @return bitmap representation of all (direct and indirect) classes known to be NOT instantiated
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.monarchinitiative.owlsim.kb.ewah;

import java.util.Collection;
import java.util.Set;

import com.googlecode.javaewah.EWAHCompressedBitmap;
Expand Down Expand Up @@ -63,7 +64,7 @@ public EWAHCompressedBitmap getSuperClasses(int clsIndex) {
return storedSuperClasses[clsIndex];
}

public EWAHCompressedBitmap getClasses(Set<Integer> clsIndices) {
public EWAHCompressedBitmap getClasses(Collection<Integer> clsIndices) {
EWAHCompressedBitmap bm = new EWAHCompressedBitmap();
for (int i : clsIndices) {
bm.set(i);
Expand Down
Loading

0 comments on commit 055bc6a

Please sign in to comment.