Skip to content

Commit

Permalink
Merge pull request #153 from melgor/LFW-speed
Browse files Browse the repository at this point in the history
SpeedUP LFW test
  • Loading branch information
Brandon Amos committed Jun 29, 2016
2 parents 026288a + 36cc8dd commit b0f9610
Showing 1 changed file with 22 additions and 12 deletions.
34 changes: 22 additions & 12 deletions evaluation/lfw.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#!/usr/bin/env python3
#!/usr/bin/env python2
#
# Copyright 2015-2016 Carnegie Mellon University
#
Expand Down Expand Up @@ -39,12 +39,14 @@

def main():
parser = argparse.ArgumentParser()
parser.add_argument('tag', type=str, help='The label/tag to put on the ROC curve.')
parser.add_argument(
'tag', type=str, help='The label/tag to put on the ROC curve.')
parser.add_argument('workDir', type=str,
help='The work directory with labels.csv and reps.csv.')
pairsDefault = os.path.expanduser("~/openface/data/lfw/pairs.txt")
parser.add_argument('--lfwPairs', type=str,
default=os.path.expanduser("~/openface/data/lfw/pairs.txt"),
default=os.path.expanduser(
"~/openface/data/lfw/pairs.txt"),

help='Location of the LFW pairs file from http://vis-www.cs.umass.edu/lfw/pairs.txt')
args = parser.parse_args()
Expand Down Expand Up @@ -137,27 +139,35 @@ def writeROC(fname, thresholds, embeddings, pairsTest):
return


def evalThresholdAccuracy(embeddings, pairs, threshold):
def getDistances(embeddings, pairsTrain):
list_dist = []
y_true = []
y_predict = []
for pair in pairs:
for pair in pairsTrain:
(x1, x2, actual_same) = getEmbeddings(pair, embeddings)
diff = x1 - x2
dist = np.dot(diff.T, diff)
predict_same = dist < threshold
y_predict.append(predict_same)
list_dist.append(dist)
y_true.append(actual_same)
return np.asarray(list_dist), np.array(y_true)


def evalThresholdAccuracy(embeddings, pairs, threshold):
distances, y_true = getDistances(embeddings, pairs)
y_predict = np.zeros(y_true.shape)
y_predict[np.where(distances < threshold)] = 1

y_true = np.array(y_true)
y_predict = np.array(y_predict)
accuracy = accuracy_score(y_true, y_predict)
return accuracy
return accuracy, pairs[np.where(y_true != y_predict)]


def findBestThreshold(thresholds, embeddings, pairsTrain):
bestThresh = bestThreshAcc = 0
distances, y_true = getDistances(embeddings, pairsTrain)
for threshold in thresholds:
accuracy = evalThresholdAccuracy(embeddings, pairsTrain, threshold)
y_predlabels = np.zeros(y_true.shape)
y_predlabels[np.where(distances < threshold)] = 1
accuracy = accuracy_score(y_true, y_predlabels)
if accuracy >= bestThreshAcc:
bestThreshAcc = accuracy
bestThresh = threshold
Expand All @@ -184,7 +194,7 @@ def verifyExp(workDir, pairs, embeddings):

bestThresh = findBestThreshold(
thresholds, embeddings, pairs[train])
accuracy = evalThresholdAccuracy(
accuracy, pairs_bad = evalThresholdAccuracy(
embeddings, pairs[test], bestThresh)
accuracies.append(accuracy)
f.write('{}, {:0.2f}, {:0.2f}\n'.format(
Expand Down

0 comments on commit b0f9610

Please sign in to comment.