In [None]:
import os
from collections import Counter, defaultdict

import h5py
import cv2
import scipy.cluster
import numpy as np

import matplotlib.pyplot as plt
import matplotlib.patches

from vsim_common import load_vocabulary, descriptors_to_bow_vector, load_SIFT_file, \
                        VisualDatabase, SiftColornamesDatabase, AnnVisualDatabase, sift_file_for_image, inside_roi, \
                        filter_roi, label_data_annoy, label_data

%matplotlib notebook
%load_ext autoreload
%autoreload 2

In [None]:
# Root directory of the dataset: labels.txt, the .jpg images and the per-image
# feature files are all resolved relative to it in this notebook.
# Set the VSIM_DATASET_DIR environment variable to use another location without
# editing the notebook; the default is the original hard-coded path.
DATASET_DIR = os.environ.get('VSIM_DATASET_DIR', '/home/hannes/Datasets/narrative2/')

In [None]:
# Database selection. Only one `db = ...` line is active; the commented-out
# lines are alternative configurations kept for quick switching.
# `database_data` is currently an empty list because every entry is commented out.
database_data = [
                #'db_gridsift_2k.h5'
                 #'db_sift_15k.h5',
                 #'db_sift_10k.h5',
                 #'db_sift_5k.h5',
                 #'db_sift_2k.h5',
                ]

# NOTE(review): MultiVisualDatabase is not among the names imported from
# vsim_common at the top of this notebook; import it before re-enabling the
# lines below that use it.
#databases = [VisualDatabase.from_file(db_path, stop_bottom=0.05, stop_top=0.1) for db_path in database_data]
#db = MultiVisualDatabase(databases)

#db = AnnVisualDatabase.from_file('db_sift_300k.h5', stop_bottom=0.01, stop_top=0.1)
#db = AnnVisualDatabase.from_file('db_sift_200k.h5', stop_bottom=0, stop_top=0.1)
#db = AnnVisualDatabase.from_file('db_sift_100k.h5', stop_bottom=0, stop_top=0.2)
#db = AnnVisualDatabase.from_file('db_sift_200k.h5', stop_top=0.1)
#db = VisualDatabase.from_file('db_cname_500.h5', stop_top=0.1)
# Active configuration: combined SIFT + colornames database. Later cells access
# its `sift_db` and `cname_db` members.
db = SiftColornamesDatabase.from_files('db_sift_10k.h5', 'db_cname_15k.h5', sift_stop_top=0.1, cname_stop_top=0.1)
#db = VisualDatabase.from_file('db_sift_15k.h5', stop_top=0.1)

#ann_dbs = [AnnVisualDatabase.from_file('db_sift_{:d}k.h5'.format(k), stop_bottom=0, stop_top=0.1) for k in [100, 15]]
#db = MultiVisualDatabase(ann_dbs)

In [None]:
# Parse labels.txt into {image filename: [label, ...]}.
# Each non-empty line is whitespace separated: the first field is the image
# filename, the remaining fields (possibly none) are its labels.
label_dict = {}
with open(os.path.join(DATASET_DIR, 'labels.txt'), 'r') as f:
    for line in f:
        fields = line.split()
        if not fields:
            # A blank (e.g. trailing) line has no filename; unpacking it would raise.
            continue
        filename, *labels = fields
        label_dict[filename] = labels

def top_n_score(matches, n, label='H'):
    """Sum the 1-based ranks of the first `n` matches that carry `label`.

    `matches` is a ranked sequence of (fname, score, ...) tuples; the labels
    of an entry are looked up as `label_dict[fname + '.jpg']`. A lower sum
    means the labelled images appear earlier in the ranking. If fewer than
    `n` labelled entries exist, the ranks of those that were found are summed.

    Raises ValueError when `n` exceeds the number of matches.
    """
    if n > len(matches):
        raise ValueError("Selected n={:d} is larger than available matches {:d}".format(n, len(matches)))

    rank_total = 0
    num_found = 0
    for rank, (fname, _score, *_rest) in enumerate(matches, start=1):
        if label not in label_dict[fname + '.jpg']:
            continue
        rank_total += rank
        num_found += 1
        if num_found == n:
            # Enough labelled entries collected; later ranks do not count.
            break
    return rank_total

def top_n_prob(matches, n, label='PE'):
    """Return the fraction of the first `n` matches that carry `label`.

    `matches` is a ranked sequence of (fname, score, ...) tuples; the labels
    of an entry are looked up as `label_dict[fname + '.jpg']`.

    Raises ValueError when `n` exceeds the number of matches.
    """
    if n > len(matches):
        raise ValueError("Selected n={:d} is larger than available matches {:d}".format(n, len(matches)))

    num_hits = 0
    for fname, _score, *_rest in matches[:n]:
        if label in label_dict[fname + '.jpg']:
            num_hits += 1

    return num_hits / n

In [None]:
# Display the full filename -> labels mapping parsed from labels.txt.
label_dict

In [None]:
def plot_result(matches, n, maxn=None, label='H'):
    """Print and plot retrieval statistics for a ranked list of matches.

    Parameters
    ----------
    matches : sequence of (fname, score, ...) tuples, best match first.
        `fname + '.jpg'` must be a key of `label_dict`, and the file
        `fname + '.sift.h5'` in DATASET_DIR is opened to read the number
        of rows of its 'descriptors' dataset.
    n : int
        Size of the "top n" for which score and inlier rate are printed.
    maxn : int, optional
        Largest n used for the accuracy curve. Defaults to
        min(len(matches), len(label_dict)). Must not exceed len(matches).
    label : str, optional
        Label that counts as a correct match. It is used for every statistic
        in the figure, so the accuracy curve and the random-chance line refer
        to the same label.

    The upper axes show the top-k inlier rate for k = 1..maxn together with the
    random-chance rate (black), the total number of labelled images (red) and
    `n` (green). The lower axes show the match scores, a red line at every
    labelled match, and the number of keypoints per match on a twin axis.
    """
    # Ranks are 1-based, so the best achievable top-n score is 1 + 2 + ... + n.
    best_possible = sum(range(1, n + 1))
    print('Top {:d} score: {:d} (theoretical minimum={:d})'.format(n, top_n_score(matches, n, label), best_possible))

    if maxn is None:
        maxn = min(len(matches), len(label_dict))
    top_n_chances = np.array([top_n_prob(matches, k, label) for k in range(1, maxn+1)])
    num_inliers = sum(1 for labels in label_dict.values() if label in labels)
    random_chance = num_inliers / len(label_dict)
    print('Top {:d} inlier probability: {:.1f}% (random chance {:.1f}%)'.format(n, 100 * top_n_prob(matches, n, label), 100*random_chance))
    print('Images labelled {!r}: {:d}'.format(label, num_inliers))

    fig, (ax1, ax2) = plt.subplots(2,1, figsize=(12, 7))
    ax1.plot(range(1, maxn+1), top_n_chances)
    ax1.set_ylabel('Accuracy')
    ax1.axhline(random_chance, color='k')
    #plt.plot(range(1, maxn+1), top_n_chances / top_n_chances[-1])
    #plt.ylabel('Improvement over pure chance')
    #plt.axhline(1.0, color='k')
    ax1.axvline(num_inliers, color='r')
    ax1.axvline(n, color='g')
    ax1.set_xlim(0, min(maxn+1, int(5*n), len(label_dict)))
    ax1.set_ylim(0, 1.05)

    scores = [score for fname, score, *_ in matches]
    ax22 = ax2.twinx()
    xdata = np.arange(len(scores))
    ax2.plot(xdata, scores, '.')
    num_features = []
    for i, (fname, score, *_) in enumerate(matches):
        if label in label_dict[fname + '.jpg']:
            ax2.axvline(i, color='r')
        with h5py.File(os.path.join(DATASET_DIR, fname + '.sift.h5'), 'r') as f:
            num_features.append(f['descriptors'].shape[0])
    # Positional limits: equivalent to the xmin/xmax keyword aliases.
    ax2.set_xlim(-1, min(maxn, len(scores)))
    ax22.plot(xdata, num_features, color='g')
    ax2.set_ylabel('Distance')
    ax22.set_ylabel('#keypoints')

In [None]:
def min_match(scdb, image, roi, sift_file, cname_file, plot=True):
    """Rank images by the smaller of their SIFT and colornames scores.

    `scdb.sift_db` and `scdb.cname_db` are each queried with the same image
    and ROI (cosine distance, default method, stop list enabled). Both results
    are read as (key, score) pairs, and every key of the SIFT result is looked
    up in the colornames result.

    When `plot` is true, the sorted SIFT, colornames and combined scores are
    drawn in a new figure.

    Returns a list of (key, score) tuples sorted by ascending score.
    """
    query_options = dict(method='default', distance='cos', use_stop_list=True)
    sift_scores = dict(scdb.sift_db.query_image(image, roi, sift_file=sift_file, **query_options))
    cname_scores = dict(scdb.cname_db.query_image(image, roi, cname_file=cname_file, **query_options))

    combined = []
    for key, sift_score in sift_scores.items():
        combined.append((key, min(sift_score, cname_scores[key])))

    if plot:
        plt.figure()
        curves = (
            ('SIFT', sift_scores.values()),
            ('cname', cname_scores.values()),
            ('min', [score for _key, score in combined]),
        )
        for curve_label, values in curves:
            plt.plot(sorted(values), label=curve_label)
        plt.legend()

    combined.sort(key=lambda pair: pair[1])
    return combined

In [None]:
# Run the query and plot its retrieval statistics.
# NOTE: test_image, roi, test_sift_file and test_cname_file are assigned in a
# cell further down in this notebook (the one that sets test_filename); that
# cell has to be executed before this one.
# `kwargs` is also reused by the next cell for the per-database queries.
kwargs = {'distance': 'cos', 'method': 'default', 'use_stop_list': True}
#matches = db.query_image(test_image, roi, sift_file=test_sift_file, cname_file=test_cname_file, **kwargs)
#matches = db.sift_db.query_image(test_image, roi, sift_file=test_sift_file, **kwargs)
#matches = db.cname_db.query_image(test_image, roi, cname_file=test_cname_file, **kwargs)

# Combined ranking: per image, the smaller of the SIFT and colornames scores.
matches = min_match(db, test_image, roi, test_sift_file, test_cname_file)

plot_result(matches, 10)

In [None]:
# Query both databases separately and record, for every image key, its
# (score, rank) in the SIFT ranking followed by the same pair from the
# colornames ranking.
matches_sift = db.sift_db.query_image(test_image, roi, sift_file=test_sift_file, **kwargs)
matches_cname = db.cname_db.query_image(test_image, roi, cname_file=test_cname_file, **kwargs)

import collections
mval = collections.defaultdict(list)
for ranking in (matches_sift, matches_cname):
    for pos, (key, score) in enumerate(ranking):
        mval[key].append((score, pos))

# Parallel lists ordered like the SIFT ranking; consumed by the plotting cells below.
key_order = [key for key, _ in matches_sift]
image_keys, image_score_sift, image_pos_sift = [], [], []
image_score_cname, image_pos_cname = [], []
for key in key_order:
    # Exactly two entries are expected per key: SIFT first, then colornames.
    (ss, si), (cs, ci) = mval[key]
    image_keys.append(key)
    image_score_sift.append(ss)
    image_pos_sift.append(si)
    image_score_cname.append(cs)
    image_pos_cname.append(ci)
    


In [None]:
# Sorted score curves: SIFT, colornames, and their element-wise minimum.
plt.figure()
plt.plot(sorted(image_score_sift), label='SIFT')
plt.plot(sorted(image_score_cname), label='cname')
# Label the combined curve as well; without a label the legend omits it.
plt.plot(sorted(np.minimum(image_score_sift, image_score_cname)), label='min')
plt.legend()

In [None]:
# Compare the two databases image by image. The x axis follows the SIFT
# ranking order (the order in which image_keys was built).
xdata = np.arange(len(image_keys))
# Figure 1: score of each image in the SIFT and in the colornames query.
plt.figure()
plt.plot(xdata, image_score_sift, label='SIFT')
plt.plot(xdata, image_score_cname, label='cname')
plt.legend()

# Figure 2: rank position of each image in the two result lists.
plt.figure()
plt.plot(xdata, image_pos_sift, label='SIFT')
plt.plot(xdata, image_pos_cname, label='cname')
plt.legend()

In [None]:
import itertools, time, random

def spatial_score_affine(query_keypoints, query_descriptors, query_roi, test_keypoints, test_descriptors, label_func,
                         ransac_iterations=50000, ransac_threshold=5.0):
    """Score a query/test image pair by RANSAC-fitting an affine transform.

    Query keypoints outside `query_roi` are discarded. Both descriptor sets
    are labelled with `label_func`, and labels in the global `db.stop_list`
    (when it is not None) are removed. Every (query, test) keypoint pair with
    equal labels is a putative match. Each RANSAC iteration fits an affine
    transform to three random putative matches and counts how many of the
    remaining ones map to within `ransac_threshold` of their test keypoint.

    Parameters
    ----------
    query_keypoints, test_keypoints : sequences of keypoints with a `.pt` attribute
    query_descriptors, test_descriptors : arrays indexable with a list of row indices
    query_roi : region passed to `inside_roi`
    label_func : callable mapping a descriptor array to an array of labels
    ransac_iterations : int, number of random models to try
    ransac_threshold : float, inlier distance threshold in the units of `kp.pt`

    Returns
    -------
    int
        The largest inlier count found, or 0 when fewer than four putative
        matches exist (three define the model, at least one must remain to
        be scored).
    """
    # Cut out ROI part only
    valid = [i for i, kp in enumerate(query_keypoints) if inside_roi(kp, query_roi)]
    query_keypoints = [query_keypoints[i] for i in valid]
    query_descriptors = query_descriptors[valid]

    query_labels = label_func(query_descriptors)
    test_labels = label_func(test_descriptors)

    stop_list = db.stop_list

    if stop_list is not None:
        query_valid = [i for i, l in enumerate(query_labels) if l not in stop_list]
        test_valid = [i for i, l in enumerate(test_labels) if l not in stop_list]

        print('Stop list; Query {:d} -> {:d}, Test {:d} -> {:d}'.format(len(query_labels), len(query_valid),
                                                                      len(test_labels), len(test_valid)))

        # Index directly rather than testing `i in list` for every keypoint.
        query_keypoints = [query_keypoints[i] for i in query_valid]
        test_keypoints = [test_keypoints[i] for i in test_valid]
        query_labels = query_labels[query_valid]
        test_labels = test_labels[test_valid]

    query_pts = np.array([kp.pt for kp in query_keypoints], dtype='float32').reshape(-1, 1, 2)
    test_pts = np.array([kp.pt for kp in test_keypoints], dtype='float32').reshape(-1, 1, 2)

    putative = [(q, t) for (q, lq), (t, lt) in itertools.product(enumerate(query_labels), enumerate(test_labels))
                        if lq == lt]

    print('{:d} of putative matches with same labels, down from {:d} max'.format(
            len(putative), len(query_labels) * len(test_labels)))

    num_putative = len(putative)
    if num_putative < 4:
        # cv2.getAffineTransform needs exactly three point pairs and at least
        # one further pair must be left over to count as an inlier.
        return 0

    # Index arrays into query_pts / test_pts, one entry per putative match.
    putative_q = np.array([q for q, _ in putative])
    putative_t = np.array([t for _, t in putative])

    # RANSAC affine
    best_num_inliers = 0
    held_out = np.ones(num_putative, dtype=bool)
    for _ in range(ransac_iterations):
        # Three distinct random matches define the model. This draws the same
        # kind of sample as shuffling the whole list and taking its first three.
        model_idx = random.sample(range(num_putative), 3)
        A = cv2.getAffineTransform(query_pts[putative_q[model_idx]], test_pts[putative_t[model_idx]])

        # Score the model on every putative match that was not used to fit it.
        held_out[:] = True
        held_out[model_idx] = False
        projected = cv2.transform(query_pts[putative_q[held_out]], A)
        distances = np.linalg.norm(projected - test_pts[putative_t[held_out]], axis=-1).ravel()

        num_inliers = np.count_nonzero(distances < ransac_threshold)
        if num_inliers > best_num_inliers:
            best_num_inliers = num_inliers

    return best_num_inliers
    

def spatial_score_voting(query_keypoints, query_descriptors, query_roi, test_keypoints, test_descriptors, label_func,
                         num_close=10):
    """Score a query/test image pair by comparing local label neighbourhoods.

    Query keypoints outside `query_roi` are discarded. Both descriptor sets
    are labelled with `label_func`, and labels in the global `db.stop_list`
    (when it is not None) are removed. For every keypoint, the labels of its
    `num_close` nearest keypoints in the same image form a label histogram.
    Each query keypoint then votes with the largest histogram overlap it
    reaches with any test keypoint that has the same label.

    Parameters
    ----------
    query_keypoints, test_keypoints : sequences of keypoints with a `.pt` attribute
    query_descriptors, test_descriptors : arrays indexable with a list of row indices
    query_roi : region passed to `inside_roi`
    label_func : callable mapping a descriptor array to an array of labels
    num_close : int, neighbourhood size (clamped to the number of keypoints)

    Returns
    -------
    int
        Sum of the per-query-keypoint votes; 0 if either image has no
        keypoints left after filtering.
    """
    # Imported explicitly: the notebook only imports scipy.cluster at the top,
    # so scipy.spatial was only reachable as a side effect of that import.
    from scipy.spatial import cKDTree

    # Cut out ROI part only
    valid = [i for i, kp in enumerate(query_keypoints) if inside_roi(kp, query_roi)]
    query_keypoints = [query_keypoints[i] for i in valid]
    query_descriptors = query_descriptors[valid]

    query_labels = label_func(query_descriptors)
    test_labels = label_func(test_descriptors)

    stop_list = db.stop_list

    if stop_list is not None:
        query_valid = [i for i, l in enumerate(query_labels) if l not in stop_list]
        test_valid = [i for i, l in enumerate(test_labels) if l not in stop_list]

        print('Stop list; Query {:d} -> {:d}, Test {:d} -> {:d}'.format(len(query_labels), len(query_valid),
                                                                      len(test_labels), len(test_valid)))

        # Index directly rather than testing `i in list` for every keypoint.
        query_keypoints = [query_keypoints[i] for i in query_valid]
        test_keypoints = [test_keypoints[i] for i in test_valid]
        query_labels = query_labels[query_valid]
        test_labels = test_labels[test_valid]

    if len(query_keypoints) == 0 or len(test_keypoints) == 0:
        # No keypoints on one side: nothing can vote, and a k-d tree cannot be built.
        return 0

    query_pts = np.array([kp.pt for kp in query_keypoints], dtype='float32')
    test_pts = np.array([kp.pt for kp in test_keypoints], dtype='float32')

    query_kdtree = cKDTree(query_pts, leafsize=32)
    test_kdtree = cKDTree(test_pts, leafsize=32)

    # Never request more neighbours than there are points: the tree reports
    # missing neighbours with an index equal to the number of points, which is
    # out of range for the label arrays.
    query_k = min(num_close, len(query_pts))
    test_k = min(num_close, len(test_pts))

    # Neighbour indices for all points at once. reshape keeps a 2-D layout
    # even when k == 1, where query() drops the neighbour axis.
    _, query_neighbours = query_kdtree.query(query_pts, query_k)
    query_neighbours = np.asarray(query_neighbours).reshape(len(query_pts), -1)
    _, test_neighbours = test_kdtree.query(test_pts, test_k)
    test_neighbours = np.asarray(test_neighbours).reshape(len(test_pts), -1)

    # Neighbourhood label histogram of every test keypoint, grouped by the
    # keypoint's own label. Computed once instead of once per (query, test) pair.
    test_hists_by_label = defaultdict(list)
    for lt, neighbours in zip(test_labels, test_neighbours):
        test_hists_by_label[lt].append(Counter(test_labels[neighbours]))

    votes = 0
    for lq, neighbours in zip(query_labels, query_neighbours):
        candidates = test_hists_by_label.get(lq)
        if not candidates:
            continue
        cq = Counter(query_labels[neighbours])
        # Histogram intersection = number of valid 1-to-1 label assignments.
        votes += max(sum((cq & ct).values()) for ct in candidates)

    return votes

def label_func_annoy(descriptors, index):
    """Label each descriptor with its nearest neighbour in `index`.

    `index.get_nns_by_vector(x, 1)` is called once per descriptor and the
    first returned item is used as that descriptor's label.

    Returns a 1-D array of dtype 'uint' with one label per descriptor.
    """
    def nearest(vector):
        first, *_rest = index.get_nns_by_vector(vector, 1)
        return first

    return np.fromiter((nearest(vector) for vector in descriptors), dtype='uint')

from vsim_common import descriptors_to_bow_vector

# Spatial verification of the best-ranked matches: every candidate gets a vote
# count from spatial_score_voting in addition to its original score.
query_des, query_kps = load_SIFT_file(test_sift_file)
spatial_check_first_n = 30

# Labelling function bound to the database's ANN index.
label_with_db_index = lambda x, index=db.index: label_func_annoy(x, index)

spatial_check = []
for fname, score in matches[:spatial_check_first_n]:
    db_image_path = os.path.join(DATASET_DIR, fname + '.jpg')
    db_sift_file = sift_file_for_image(db_image_path)
    test_des, test_kps = load_SIFT_file(db_sift_file)
    votes = spatial_score_voting(query_kps, query_des, roi, test_kps, test_des, label_with_db_index)
    #votes = spatial_score_affine(query_kps, query_des, roi, test_kps, test_des, label_with_db_index)
    spatial_check.append((fname, score, votes))
    print(fname, score, votes, label_dict[fname + '.jpg'])

In [None]:
# Re-rank the spatially checked candidates by vote count (most votes first),
# then compare against the original ranking limited to the same number of entries.
spatial_check = sorted(spatial_check, key=lambda entry: entry[2], reverse=True)
print('Spatially verified\n---------------------------------------')
plot_result(spatial_check, 10)
print('Original matches\n---------------------------------------')
plot_result(matches, 10, maxn=len(spatial_check))

In [None]:
# Scratch check: collect the distinct values of a small random integer matrix.
tmp = np.random.randint(0, 10, size=(3,3))
print(tmp)
s = {value for value in tmp.flat}
s

In [None]:
# Scratch check of set difference: remove every character of 'xybz' from 'abcd'.
s = set('abcd')
s -= set('xybz')
s

In [None]:
# Build the set of visual-word labels that occur inside the query ROI.
#des_source = test_sift_file
des_source = test_cname_file
source_db = db.cname_db
qdes, qkps = load_SIFT_file(des_source)
print('Query image has {} keypoints'.format(len(qkps)), end=' ')
qkps, qdes = filter_roi(qkps, qdes, roi)
print('filtered down to {}'.format(len(qkps)))

# Only the presence of an `index` attribute decides between ANN and exact
# labelling. Looking it up separately keeps an AttributeError raised inside
# the labelling call itself from silently switching to the exact path.
ann_index = getattr(source_db, 'index', None)
if ann_index is not None:
    qlabelset = set(label_data_annoy(ann_index, qdes))
    print('Used ANN')
else:
    qlabelset = set(label_data(source_db.vocabulary, qdes))
    print('Used exact NN')
print('Unique labels:', len(qlabelset))
# The stop list may be None (the spatial scoring functions guard for this too).
if source_db.stop_list is not None:
    qlabelset.difference_update(source_db.stop_list)
print('After stoplist:', len(qlabelset))

In [None]:
# Gallery of the best-ranked database images, one per subplot. The title shows
# a [HIT] marker for images labelled 'PE', the score and the number of
# descriptors in the image's feature file.
rows, cols = 5, 2
# Set to True to overlay the keypoints: red where the label also occurs in the
# query ROI (qlabelset), yellow for other labels that are not on the stop list.
show_keypoints = False
fig, axes = plt.subplots(rows, cols, figsize=(12, 5*rows))

# The stop list may be None; treat that as "nothing is stopped".
stop_list = set(source_db.stop_list) if source_db.stop_list is not None else set()
for ax, (fname, score) in zip(axes.flatten(), matches):
    image_path = os.path.join(DATASET_DIR, fname + '.jpg')
    image = plt.imread(image_path)

    # Despite its name this currently points at the colornames feature file.
    #sift_file = sift_file_for_image(image_path)
    sift_file = os.path.splitext(image_path)[0] + '.cname.h5'

    if show_keypoints:
        des, kps = load_SIFT_file(sift_file)
        ann_index = getattr(source_db, 'index', None)
        if ann_index is not None:
            labels = label_data_annoy(ann_index, des)
            print('Using ANN')
        else:
            labels = label_data(source_db.vocabulary, des)
            print('Using Exact NN')
        valid_kps = [kp for kp, l in zip(kps, labels) if l in qlabelset]
        invalid_kps = [kp for kp, l in zip(kps, labels) if l not in qlabelset and l not in stop_list]
        print(fname, len(kps), '->', len(valid_kps), ' + ', len(invalid_kps), ' = ', len(valid_kps) + len(invalid_kps))
        grey_image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        #kp_image = cv2.drawKeypoints(grey_image, kps, None, (255, 0, 0), cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS)
        #ax.imshow(kp_image)

        ax.imshow(grey_image, interpolation='none', cmap=plt.cm.Greys_r)

        for group, color, alpha in [(invalid_kps, 'y', 0.75), (valid_kps, 'r', 1.0)]:
            if not group:
                # np.vstack raises on an empty sequence; nothing to draw anyway.
                continue
            pos = np.vstack([kp.pt for kp in group])
            ax.scatter(pos[:, 0], pos[:,1], color=color, marker='x', alpha=alpha)
    else:
        ax.imshow(image, interpolation='none')
    with h5py.File(sift_file, 'r') as f:
        num_kps = f['descriptors'].shape[0]
    ax.set_title("{}{} ({:.3f}, {:d})".format('[HIT] ' if 'PE' in label_dict[fname+'.jpg'] else '', fname, score, num_kps))
fig.tight_layout()

In [None]:
# Back-of-envelope calculation: 127e6 items times (6*8 + 64) units each,
# divided by 1024**3.
# NOTE(review): presumably a memory estimate in GiB with the per-item size in
# bytes; what the items are is not stated here - confirm.
127e6*((6*8)+64) / 1024**3

In [None]:
# Query image selection and region of interest (ROI).
# NOTE: the query cells further up use test_image, roi, test_sift_file and
# test_cname_file, so this cell has to be executed before them.
#test_filename = os.path.join(DATASET_DIR, '20161129_094450_000.jpg')
#test_filename = os.path.join(DATASET_DIR, '20161129_094114_000.jpg')
#test_filename = os.path.join(DATASET_DIR, '20161129_095828_000.jpg')
test_filename = os.path.join(DATASET_DIR, '20161129_091416_000.jpg')
test_sift_file = os.path.splitext(test_filename)[0] + '.sift.h5'
test_cname_file = os.path.splitext(test_filename)[0] + '.cname.h5'
test_image = cv2.imread(test_filename)
if test_image is None:
    # cv2.imread reports an unreadable or missing file by returning None.
    raise IOError('Could not read query image {}'.format(test_filename))
fig, ax = plt.subplots()
ax.imshow(cv2.cvtColor(test_image, cv2.COLOR_BGR2RGB))
rois = {'20161129_094450_000.jpg': [1515, 1200, 100, 250], # x, y, w, h
        '20161129_094114_000.jpg': [1130, 860, 410, 1020],
        '20161129_095828_000.jpg': [2197, 10, 185, 305],
        '20161129_091416_000.jpg': [1540, 860, 850, 1600]
       }
# None when no ROI is registered for the selected image.
roi = rois.get(os.path.basename(test_filename))

if roi:
    # Outline the ROI on top of the displayed image.
    rect = matplotlib.patches.Rectangle(roi[:2], roi[2], roi[3], facecolor='none', edgecolor='r')
    ax.add_patch(rect)


## Match features in query

In [None]:
# Load the query image's colornames features, keep only those inside the ROI
# and label them with the colornames vocabulary (exact labelling).
# NOTE: query_kps and query_des are also assigned by the spatial verification
# cell earlier in the notebook (from the SIFT file, unfiltered); this cell
# overwrites them.
source_file = test_cname_file
des, kps = load_SIFT_file(source_file)
query_kps, query_des = filter_roi(kps, des, roi)
query_labels = label_data(db.cname_db.vocabulary, query_des)
#test_image_gray = cv2.cvtColor(test_image, cv2.COLOR_BGR2GRAY)

# Cut out only ROI patch
def move_to_roi(kp, roi):
    """Return a copy of keypoint `kp` in ROI-local coordinates.

    `roi` is an (x, y, w, h) sequence; the ROI's top-left corner (x, y) is
    subtracted from the keypoint position. Size, angle, response, octave and
    class id are carried over unchanged, so the result differs from `kp`
    only in its position.
    """
    x, y, _w, _h = roi
    px, py = kp.pt
    # Positional arguments: x, y, size, angle, response, octave, class_id.
    return cv2.KeyPoint(px - x, py - y, kp.size, kp.angle, kp.response, kp.octave, kp.class_id)
    
# Crop the query image to the ROI (arrays are indexed [row, column] = [y, x]),
# convert it from BGR to RGB, and shift the query keypoints into the
# coordinate frame of the cropped patch.
x, y, w, h = roi
roi_patch_bgr = test_image[y:y+h, x:x+w]
test_image_roi = cv2.cvtColor(roi_patch_bgr, cv2.COLOR_BGR2RGB)
query_kps = list(map(lambda kp: move_to_roi(kp, roi), query_kps))

In [None]:
# Load one database image (converted to RGB) together with its colornames
# features, and label the features with the colornames vocabulary.
db_key = '20161129_094802_000'
db_stem = os.path.join(DATASET_DIR, db_key)
db_image_path = db_stem + '.jpg'
db_image = cv2.cvtColor(cv2.imread(db_image_path), cv2.COLOR_BGR2RGB)
db_source_file = db_stem + '.cname.h5'
db_des, db_kps = load_SIFT_file(db_source_file)
db_labels = label_data(db.cname_db.vocabulary, db_des)

In [None]:
# Visualise label agreement: for a few random query keypoints, link each one
# to (at most max_match) database keypoints that carry the same label.
num_input = 2
max_match = 10
# Sample without replacement so the same query keypoint is not drawn twice,
# and never ask for more keypoints than exist.
selected = np.random.choice(len(query_kps), min(num_input, len(query_kps)), replace=False)

# Kept in its own name: `matches` holds the image ranking used by other cells.
label_matches = []
for qi in selected:
    ql = query_labels[qi]
    num_found = 0
    for di, dl in enumerate(db_labels):
        if dl == ql:
            # The distance (100) is a placeholder; only the index pair is drawn.
            m = cv2.DMatch(int(qi), int(di), 100)
            label_matches.append(m)
            num_found += 1
            if num_found >= max_match:
                break

match_image = cv2.drawMatches(test_image_roi, query_kps, db_image, db_kps, label_matches, None, flags=cv2.DRAW_MATCHES_FLAGS_NOT_DRAW_SINGLE_POINTS)
print(len(label_matches), 'matches found')

In [None]:
# Show the query ROI and the database image side by side with the drawn matches.
plt.figure(figsize=(12, 8))
plt.imshow(match_image)

In [None]:
# Euclidean norm of each of the first five query descriptors
# (presumably a check of how the descriptors are normalised).
np.linalg.norm(query_des[:5], axis=1)

In [None]:
# Sum of the components of each of the first five query descriptors
# (presumably to see whether they sum to a constant).
np.sum(query_des[:5], axis=1)

In [None]:
# Plot the sorted values of the colornames database's private `_log_idf` attribute.
plt.figure()
plt.plot(sorted(db.cname_db._log_idf))