Skip to content

Commit

Permalink
Annotated adversarially sampled negative pairs.
Browse files Browse the repository at this point in the history
  • Loading branch information
kosuke1701 committed Feb 16, 2021
1 parent f54d5d9 commit fcf1ae1
Show file tree
Hide file tree
Showing 3 changed files with 4,045 additions and 0 deletions.
126 changes: 126 additions & 0 deletions adversarial_sampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
from argparse import ArgumentParser
from glob import glob
import os

import numpy as np
from sklearn.utils import shuffle
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt

def load_image(filename):
try:
with open(filename, "rb") as f:
image = Image.open(f)
return image.convert("RGB")
except UserWarning as e:
print(filename)
input("Something wrong happens while loading image: {} {}".format(filename, str(e)))

# Example Model definition
class Model(object):
def __init__(self, dirname):
import animecv

self.encoder = animecv.general.create_OML_ImageFolder_Encoder(dirname)
self.encoder.to("cuda")

# img: PIL image
def encode(self, img):
vecs = self.encoder.encode([img]).detach().cpu().numpy()
return vecs[0]

if __name__=="__main__":
parser = ArgumentParser()

parser.add_argument("--test-pairs", help="CSV file which lists test image pairs.")
parser.add_argument("--test-dataset-dir", help="Directory of test images.")
parser.add_argument("--ignore-list", default=None, help="List of images which should be ignored during pair sampling.")

parser.add_argument("--out-fn", default="adversarial.csv")

parser.add_argument("--n-negative", type=int, default=3000)

args = parser.parse_args()

if not os.path.exists(args.out_fn):
if args.ignore_list is not None:
df = pd.read_csv(args.ignore_list, header=None)
ignore_list = set(df.values.flatten().tolist())
else:
ignore_list = set()

# Generate adversarial negative pairs.
model = Model("0206_resnet152")

images = glob(os.path.join(args.test_dataset_dir, "**"), recursive=True)
images = [fn for fn in images if os.path.isfile(fn)]
labels = [fn.split(os.path.sep)[-2] for fn in images]

vecs = []
for fn in tqdm(images):
img = load_image(fn)
vecs.append(model.encode(img).reshape((1,-1)))
vecs = np.concatenate(vecs, axis=0)

scores = np.sum(vecs[:,np.newaxis,:] * vecs[np.newaxis,:,:], axis=2)

negative_pairs = []
n_img = scores.shape[0]
sorted_idx = np.argsort(-scores, axis=None).tolist()
strip_len = len(args.test_dataset_dir + os.path.sep)
while len(negative_pairs) < args.n_negative:
idx = sorted_idx.pop(0)
i,j = idx // n_img, idx % n_img
if i<=j:
continue
if labels[i] == labels[j]:
continue
if os.path.basename(images[i]) in ignore_list:
continue
if os.path.basename(images[j]) in ignore_list:
continue
negative_pairs.append((images[i][strip_len:], images[j][strip_len:], 0, -1, 0))

# Reuse positive pairs.
positive_pairs = []
df = pd.read_csv(args.test_pairs)
for pathA, pathB in df[df["label"]==1][["pathA", "pathB"]].values:
#print(pathA, pathB)
positive_pairs.append((pathA, pathB, 1, -1, 0))

pairs = shuffle(positive_pairs + negative_pairs)

df = pd.DataFrame(pairs, columns=["pathA", "pathB", "label", "human_prediction", "invalid"])
df.to_csv(args.out_fn, index=False)
else:
print("Reload")
df = pd.read_csv(args.out_fn)

for i_row in tqdm(list(range(df.values.shape[0]))):
pathA, pathB, label, pred, invalid = df.loc[i_row].values
#print(pathA, pathB)
if pred >= 0:
continue
else:
im1 = np.array(Image.open(os.path.join(args.test_dataset_dir, pathA)))
im2 = np.array(Image.open(os.path.join(args.test_dataset_dir, pathB)))
ax = plt.subplot(1,2,1)
ax.imshow(im1)
ax = plt.subplot(1,2,2)
ax.imshow(im2)
plt.draw()
plt.pause(0.001)
cmd = input("correct?[y/n]: ")
if cmd=="y":
pred = 1
elif cmd=="n":
pred = 0
else:
pred = 0
df.loc[i_row, "invalid"] = 1
df.loc[i_row, "human_prediction"] = pred
df.to_csv(args.out_fn, index=False)
plt.close()
63 changes: 63 additions & 0 deletions dataset/ignore_images.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
3196587.jpg
2715436.jpg
4073331.jpg
355274.jpg
1375976.jpg
2022189.jpg
3679066.jpg
4072723.jpg
619031.jpg
906000.jpg
1306014.jpg
2440275.jpg
523824.jpg
2557668.jpg
3037438.jpg
2537158.jpg
131866.jpg
3471417.jpg
2679171.jpg
1297445.jpg
3654015.jpg
4059979.jpg
2022188.jpg
911056.jpg
2590354.jpg
4052408.jpg
456627.jpg
1943273.jpg
790640.jpg
3663882.jpg
1446387.jpg
1470990.jpg
2669278.jpg
2679325.jpg
2132647.jpg
664895.jpg
3771627.jpg
1610897.jpg
3564166.jpg
2017193.jpg
3306763.jpg
3603215.jpg
3035986.jpg
2665143.jpg
4250933.jpg
2917089.jpg
1779865.jpg
511047.jpg
560368.jpg
836229.jpg
1288484.jpg
481397.jpg
1239831.jpg
3865565.jpg
2518166.jpg
1610891.jpg
1449964.jpg
710851.jpg
3910247.jpg
4174543.jpg
2039196.jpg
651377.jpg
2132651.jpg
Loading

0 comments on commit fcf1ae1

Please sign in to comment.