-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathevaluate.py
65 lines (52 loc) · 2.12 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from argparse import ArgumentParser
import os
import numpy as np
from scipy.optimize import brentq
from scipy.interpolate import interp1d
from sklearn.metrics import roc_curve
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
from PIL import Image
def load_image(filename):
    """Open *filename* with PIL and return an RGB-converted copy.

    The file is opened in binary mode and closed once the RGB image has
    been produced.

    If a ``UserWarning`` is *raised* while loading, the following happens:
    the filename is printed, execution pauses on an interactive prompt,
    and the function returns ``None``. Under default warning filters a
    warning is not raised as an exception, so this normally only happens
    when warnings are escalated to errors (e.g. ``python -W error``).

    Any other exception (e.g. a missing file) propagates to the caller.
    """
    try:
        with open(filename, "rb") as handle:
            rgb = Image.open(handle).convert("RGB")
        return rgb
    except UserWarning as warn:
        print(filename)
        input("Something wrong happens while loading image: {} {}".format(filename, str(warn)))
    return None
# Example Model definition
class Model:
    """Example image-pair scorer built on an ``animecv`` encoder.

    The encoder is created from *dirname* via
    ``animecv.general.create_OML_ImageFolder_Encoder`` and moved to *device*.
    """

    def __init__(self, dirname, device="cuda"):
        """Build the encoder from *dirname* and move it to *device*.

        Args:
            dirname: Directory handed to ``create_OML_ImageFolder_Encoder``.
            device: Value passed to the encoder's ``.to()``. Defaults to
                ``"cuda"`` (the previously hard-coded value). Another value
                such as ``"cpu"`` can be passed if the encoder supports it.
        """
        # Imported here rather than at module level, so importing this
        # module does not require animecv to be installed.
        import animecv
        self.encoder = animecv.general.create_OML_ImageFolder_Encoder(dirname)
        self.encoder.to(device)

    def score(self, img1, img2):
        """Return the dot product of the embeddings of *img1* and *img2*.

        Args:
            img1, img2: PIL images.

        Returns:
            A scalar similarity score. Callers treat higher as "more
            similar". Presumably the encoder emits normalized vectors, which
            would make this a cosine similarity (TODO confirm).
        """
        # Encode both images in one batch, then move the result to host
        # memory as a NumPy array of shape (2, dim).
        vecs = self.encoder.encode([img1, img2]).detach().cpu().numpy()
        return np.dot(vecs[0], vecs[1])
def _parse_args(argv=None):
    """Parse the evaluation script's command-line options.

    ``--test-pairs`` and ``--test-dataset-dir`` are required, since every run
    needs them. ``--model-dir`` defaults to the directory the script always
    used before.
    """
    parser = ArgumentParser()
    parser.add_argument("--test-pairs", required=True, help="CSV file which lists test image pairs.")
    parser.add_argument("--test-dataset-dir", required=True, help="Directory of test images.")
    parser.add_argument("--target-fnr", type=float, default=0.139, help="Reference FNR used to compute FPR.")
    parser.add_argument("--model-dir", default="0206_seresnet152",
                        help="Directory passed to Model (default: 0206_seresnet152).")
    return parser.parse_args(argv)


def _score_pairs(model, pairs, root):
    """Score each (pathA, pathB) pair with *model*; paths are resolved under *root*."""
    scores = []
    for path_a, path_b in tqdm(pairs):
        img1 = load_image(os.path.join(root, path_a))
        img2 = load_image(os.path.join(root, path_b))
        scores.append(model.score(img1, img2))
    return scores


def _compute_metrics(true_labels, scores, target_fnr):
    """Return ``(fpr_at_target_fnr, threshold_at_target_fnr, eer)`` from an ROC curve.

    The EER is found by root-finding on ``fnr(tpr) - fpr(tpr)``, where FPR
    is linearly interpolated as a function of TPR. At the root FNR == FPR,
    and ``1 - tpr`` there is the EER.

    The FPR and decision threshold at *target_fnr* are linear
    interpolations over the ROC points.
    """
    fpr, tpr, threshold = roc_curve(true_labels, scores)
    # Build the interpolant once; brentq evaluates the objective many times.
    fpr_at_tpr = interp1d(tpr, fpr)
    eer = 1. - brentq(lambda x: 1. - x - fpr_at_tpr(x), 0., 1.)
    fnr = 1. - tpr
    return interp1d(fnr, fpr)(target_fnr), interp1d(fnr, threshold)(target_fnr), eer


def main(argv=None):
    """Score the valid test pairs, then print FPR/threshold at the target FNR and the EER."""
    args = _parse_args(argv)

    # Read and validate the pairs before loading the model, so bad input fails fast.
    df = pd.read_csv(args.test_pairs)
    df = df[df["invalid"] == 0]
    true_labels = df["label"].values
    # roc_curve needs both classes. An empty or one-class set would otherwise
    # fail (or produce NaNs) later, in the interpolation/root-finding steps.
    n_classes = len(set(true_labels))
    if n_classes < 2:
        raise SystemExit(
            "Need both positive and negative valid pairs in {}; found {} distinct label(s)".format(
                args.test_pairs, n_classes))

    model = Model(args.model_dir)
    scores = _score_pairs(model, df[["pathA", "pathB"]].values, args.test_dataset_dir)
    fpr_at_target, threshold_at_target, eer = _compute_metrics(true_labels, scores, args.target_fnr)

    print("False Positive Rate: ", fpr_at_target)
    print("Threshold: ", threshold_at_target)
    print("Equal Error Rate: ", eer)


if __name__ == "__main__":
    main()