In [None]:
import datetime
import glob
import os
import sys
from random import shuffle, randint

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.misc as misc
from skimage.transform import rotate

In [None]:
# Both image folders are resolved relative to the notebook's working directory:
# 'cellphone_imgs' is globbed for the *.jpg queries, 'png_imgs' for the *.png database.
CELLPHONE_IMG_PATH, DB_PATH = (
    os.path.join(os.getcwd(), folder) for folder in ('cellphone_imgs', 'png_imgs')
)

## Get Dataset

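Set the variables that configure the dataset reader and the training loop: dataset location, cropping options, and the number of training iterations.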

In [None]:
# NOTE(review): the original cell left all four right-hand sides empty, which is a
# SyntaxError. The values below are placeholders so the notebook parses and runs
# top to bottom -- TODO confirm each one against the real experiment setup.

# Dataset root; the reader globs '<data_dir>/images_png/*.png' and loads the
# annotation of each page from '<data_dir>/pix_annotations_png/'.
data_dir = os.path.join(os.getcwd(), 'data')  # TODO: point at the real dataset
# Whether each page is reduced to one random crop (mirrors the reader's default).
crop = True
# Crop size as [rows, cols] (mirrors the reader's default).
crop_size = [1000, 1000]
# Upper bound of the training loop in `main`.
MAX_ITERATION = 10000  # TODO: placeholder value

In [None]:
class seg_dataset_reader:
    """In-memory reader for a page-segmentation dataset.

    Expected layout under ``data_path``::

        images_png/<name>.png            input pages
        pix_annotations_png/<name>.png   per-pixel annotation of the same page

    All selected pages are loaded eagerly in ``__init__`` and split into a
    training set and a test set. Arrays returned by the accessors have a
    trailing singleton axis appended (``np.expand_dims(..., -1)``).
    """

    # Only immutable defaults live on the class. The image/annotation containers
    # are created per instance; as mutable class attributes they were shared, so a
    # second reader started with the first reader's images already in its lists.
    path = ""
    class_mappings = ""
    batch_offset = 0
    epochs_completed = 0

    def __init__(self, data_path, max_pages=40, crop=True, crop_size=[1000,1000], test_size=20):
        """
        Load and split the dataset.

        :param data_path: dataset root containing ``images_png`` and ``pix_annotations_png``
        :param max_pages: total number of pages to use (train + test); ``None`` means all
        :param crop: if true, every page is reduced to one random crop
        :param crop_size: crop size as ``[rows, cols]``; only read when ``crop`` is true
        :param test_size: number of pages held out as the test set
        :raises ValueError: if ``test_size`` leaves no page for training
        """
        print("Initializing Dataset Reader...")
        self.path = data_path
        self.crop = crop
        # Copy so the shared default list (or a caller's list) is never aliased.
        self.crop_size = list(crop_size) if crop_size is not None else None
        self.test_size = test_size

        self.files = []
        self.images = []
        self.annotations = []
        self.test_images = []
        self.test_annotations = []
        self.batch_offset = 0
        self.epochs_completed = 0

        # Sort before shuffling so a seeded `random` gives a reproducible split
        # regardless of the order the filesystem lists the files in.
        images_list = sorted(glob.glob(os.path.join(self.path, "images_png", "*.png")))
        shuffle(images_list)

        if max_pages is None:
            # "Use everything". The original exited the interpreter on this path.
            max_pages = len(images_list)
        elif max_pages > len(images_list):
            print("Not enough data, only " + str(len(images_list)) + " available")
            # Clamp so the split sizes reported and checked below are the real ones.
            max_pages = len(images_list)
        self.max_pages = max_pages

        if test_size >= max_pages:
            raise ValueError(
                "Test set too big (" + str(test_size) + "), max_pages is: " + str(max_pages))

        print("Splitting dataset, train: " + str(max_pages - test_size)
              + " images, test: " + str(test_size) + " images")
        test_image_list = images_list[0:test_size]
        train_image_list = images_list[test_size:max_pages]

        self._read_images(test_image_list, train_image_list)

    def _read_images(self, test_image_list, train_image_list, annotations=True):
        """Load both splits into the instance arrays.

        ``annotations`` is accepted for backward compatibility and is not used.
        """
        self.images, self.annotations = self._load_split(train_image_list)
        print("Training set done")
        self.test_images, self.test_annotations = self._load_split(test_image_list)
        print("Test set done")

    def _load_split(self, filenames):
        """Return ``(images, annotations)`` arrays for ``filenames``, each with a trailing axis."""
        pairs = [self._transform(filename) for filename in filenames]
        images = np.expand_dims(np.array([pair[0] for pair in pairs]), -1)
        annotations = np.expand_dims(np.array([pair[1] for pair in pairs]), -1)
        return images, annotations

    @staticmethod
    def _annotation_path(image_path):
        """Map ``<root>/images_png/<name>`` to ``<root>/pix_annotations_png/<name>``.

        Built from path components rather than a "/images_png/" string replace,
        which did not match relative paths such as ``images_png/x.png`` (nor
        backslash-separated paths) and then loaded the image as its own annotation.
        """
        images_dir, page_name = os.path.split(image_path)
        return os.path.join(os.path.dirname(images_dir), "pix_annotations_png", page_name)

    def _transform(self, filename):
        """Read one page and its annotation from disk and return ``[image, annotation]``."""
        # NOTE(review): `scipy.misc.imread` was removed from SciPy in 1.2.0; this
        # call needs an older SciPy or a replacement reader (e.g. imageio) -- TODO.
        image = misc.imread(filename)
        annotation = misc.imread(self._annotation_path(filename))
        return self._preprocess(image, annotation, filename)

    def _preprocess(self, image, annotation, filename=None):
        """Validate, grey-scale and optionally crop one image/annotation pair.

        :raises ValueError: if the two arrays differ in height/width, or the page
            is smaller than the requested crop
        """
        if image.shape[0:2] != annotation.shape[0:2]:
            raise ValueError(
                "input and annotation have different sizes! ("
                + str(image.shape[0:2]) + " vs " + str(annotation.shape[0:2])
                + ", file: " + str(filename) + ")")

        # Collapse colour channels only when a channel axis exists. The original
        # tested `shape[-1] != 1`, which for a 2-D greyscale page is its width and
        # averaged that axis away.
        if image.ndim == 3:
            image = np.mean(image, -1)

        if self.crop:
            crop_rows, crop_cols = self.crop_size[0], self.crop_size[1]
            if image.shape[0] < crop_rows or image.shape[1] < crop_cols:
                raise ValueError(
                    "image of size " + str(image.shape[0:2]) + " is smaller than crop_size "
                    + str(self.crop_size) + " (file: " + str(filename) + ")")
            # Same random window for image and annotation so pixels stay aligned.
            top = randint(0, image.shape[0] - crop_rows)
            left = randint(0, image.shape[1] - crop_cols)
            image = image[top:top + crop_rows, left:left + crop_cols]
            annotation = annotation[top:top + crop_rows, left:left + crop_cols]

        return [image, annotation]

    def get_records(self):
        """Return the full training set as ``(images, annotations)``."""
        return self.images, self.annotations

    def reset_batch_offset(self, offset=0):
        """Move the sequential batch cursor to ``offset``."""
        self.batch_offset = offset

    def get_test_records(self):
        """Return the full test set as ``(images, annotations)``."""
        return self.test_images, self.test_annotations

    def next_batch(self, batch_size):
        """Return the next sequential training batch.

        When the cursor would run past the end of the data, the epoch counter is
        incremented, the training set is reshuffled and the batch is taken from
        the start of the reshuffled data (the unfinished tail is not returned).
        """
        start = self.batch_offset
        self.batch_offset += batch_size
        if self.batch_offset > self.images.shape[0]:
            # Finished epoch
            self.epochs_completed += 1
            print("****************** Epochs completed: " + str(self.epochs_completed) + "******************")
            # Shuffle images and annotations with one shared permutation.
            perm = np.arange(self.images.shape[0])
            np.random.shuffle(perm)
            self.images = self.images[perm]
            self.annotations = self.annotations[perm]
            # Start next epoch
            start = 0
            self.batch_offset = batch_size

        end = self.batch_offset
        return self.images[start:end], self.annotations[start:end]

    def get_random_batch(self, batch_size):
        """Return ``batch_size`` training pairs drawn uniformly with replacement."""
        indexes = np.random.randint(0, self.images.shape[0], size=[batch_size]).tolist()
        return self.images[indexes], self.annotations[indexes]

In [None]:
def main(unused_argv):
    """Training-loop skeleton for a segmentation model fed by `seg_dataset_reader`.

    NOTE(review): the model section ("Apply FCN or model") is still missing, so
    `tf`, `FLAGS`, `sess`, `train_op`, `loss`, `saver` and `step` are not defined
    anywhere in this notebook and must be supplied before this can run.

    :param unused_argv: ignored
    """
    data_reader = seg_dataset_reader(data_dir, crop=crop, crop_size=crop_size)

    # Placeholders for the feed dict. The spatial shape is taken from the same
    # `crop_size` the reader was built with, as [rows, cols]; the original used
    # FLAGS.crop_size[0] for both axes, which only fits square crops.
    keep_probability_conv = tf.placeholder(tf.float32, name="keep_probability_conv")
    image = tf.placeholder(tf.float32, shape=[None, crop_size[0], crop_size[1], 1], name="image")
    annotation = tf.placeholder(tf.int32, shape=[None, crop_size[0], crop_size[1], 1], name="labels")

    # Apply FCN or model

    for itr in range(step, MAX_ITERATION):
        train_images, train_annotations = data_reader.next_batch(FLAGS.batch_size)
        feed_dict = {image: train_images, annotation: train_annotations, keep_probability_conv: 0.85}
        sess.run(train_op, feed_dict=feed_dict)

        print(itr)

        if itr % 10 == 0:
            # Train loss on the batch just used. This assignment was commented out
            # while the print below still read `train_loss`.
            train_loss = sess.run([loss], feed_dict=feed_dict)
            print("Step: %d, Train_loss: %g" % (itr, train_loss[0]))

        if itr % 500 == 0 and itr != 0:
            # Validation loss on the whole test split, then checkpoint. As above,
            # `valid_loss` was printed without ever being computed.
            valid_images, valid_annotations = data_reader.get_test_records()
            valid_loss = sess.run(loss, feed_dict={image: valid_images, annotation: valid_annotations, keep_probability_conv: 1.0})
            print("%s ---> Validation_loss: %g" % (datetime.datetime.now(), valid_loss))
            saver.save(sess, FLAGS.logs_dir + "model.ckpt", itr)

    test_images, test_annotations = data_reader.get_test_records()
    # TODO: final evaluation once the model exists, e.g.
    # valid_loss, output = sess.run([loss, pred_annotation], feed_dict={image: test_images, annotation: test_annotations, keep_probability_conv: 1.0})

## Generate database

In [None]:
# Sorted so the database order -- and therefore the order of positional ties in
# `search` -- does not depend on how the filesystem happens to list the files.
png_paths = sorted(glob.glob(os.path.join(DB_PATH, '*.png')))

In [None]:
# SIFT extractor shared by the database and the queries. The notebook never
# created it; `cv2.SIFT_create` is the main-module factory, with the
# `xfeatures2d` contrib location as a fallback for older OpenCV builds.
sift = cv2.SIFT_create() if hasattr(cv2, 'SIFT_create') else cv2.xfeatures2d.SIFT_create()

# db maps each database image path to its (keypoints, descriptors) pair.
db = {}
for i, png_path in enumerate(png_paths):
    img = cv2.imread(png_path, 0)  # flag 0: load as single-channel greyscale
    if img is None:
        # cv2.imread signals an unreadable file by returning None, not by raising.
        raise IOError("Could not read database image: {:}".format(png_path))
    kp, des = sift.detectAndCompute(img, None)
    db[png_path] = (kp, des)
    print(
        "Finish computing SIFT descriptor {:}/{:}".format(
            i + 1, len(png_paths)),
        file=sys.stderr)

## Get query

In [None]:
# Every phone photo (*.jpg) in the query folder, in lexicographic path order.
query_paths = glob.glob(os.path.join(CELLPHONE_IMG_PATH, '*.jpg'))
query_paths.sort()

In [None]:
# Preview the first query photo, loaded as single-channel greyscale.
img = cv2.imread(query_paths[0], 0)
if img is None:
    # cv2.imread signals an unreadable file by returning None, not by raising.
    raise IOError("Could not read query image: {:}".format(query_paths[0]))
plt.figure(figsize=(20, 20))
# Without an explicit colormap matplotlib renders 2-D data in false colour.
plt.imshow(img, cmap='gray')
plt.title(os.path.basename(query_paths[0]));

In [None]:
# Lookup filled by the next cell: query image name -> database image name.
groundtruth = dict()
# Ground-truth table, read from 'groundtruth.csv' in the working directory.
df = pd.read_csv(os.path.join(os.getcwd(), 'groundtruth.csv'))

In [None]:
# Each CSV row is unpacked as a (query image, sheet) pair, so the table must have
# exactly two columns; '.png' is appended to the sheet to form the value stored.
for cellphone_img, sheet_img in df.to_numpy():
    groundtruth[cellphone_img] = sheet_img + '.png'

## Search

In [None]:
# Leftover exploration: `%lsmagic` lists the available IPython magics; nothing
# else in this notebook reads its output.
%lsmagic

In [None]:
# FLANN algorithm id for the KD-tree index. It was referenced but never assigned
# in this notebook; 1 is the value OpenCV's feature-matching tutorial uses.
FLANN_INDEX_KDTREE = 1
index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
search_params = dict(checks=50)
# Matcher used by `search`; it was never constructed in the original notebook.
flann = cv2.FlannBasedMatcher(index_params, search_params)

In [None]:
def search(db_png_paths,
           des_query,
           expected_match,
           threshold_npairs=25,
           verbose=False):
    """Rank the database images against one query and return the rank of the expected match.

    Relies on two notebook globals: `db` (path -> (keypoints, descriptors)) and
    `flann` (a matcher exposing `knnMatch`).

    Every database image is scored by the number of matches passing Lowe's ratio
    test and by the summed distance of those matches. Images with more than
    `threshold_npairs` good matches are ranked first, by ascending summed
    distance; all remaining images follow in their original order.

    :param db_png_paths: database image paths to score (keys of `db`)
    :param des_query: descriptors of the query image
    :param expected_match: file name (no directory) of the correct database image
    :param threshold_npairs: an image needs more good matches than this to be ranked by distance
    :param verbose: if true, report per-image progress on stderr
    :return: 1-based rank of `expected_match`
    :raises ValueError: if no database path has `expected_match` as its file name
    """
    scoreList = []
    n_refs = len(db_png_paths)
    for idx, ref_path in enumerate(db_png_paths):
        ref_des = db[ref_path][1]
        if ref_des is None or des_query is None:
            # An image without keypoints has no descriptors; score it as no matches.
            matches = []
        else:
            # This call was commented out, leaving `matches` undefined below.
            matches = flann.knnMatch(ref_des, des_query, k=2)

        totalDistance = 0
        counterGood = 0
        # ratio test as per Lowe's paper
        for pair in matches:
            if len(pair) < 2:
                # The ratio test needs two neighbours; skip entries that have fewer.
                continue
            m, n = pair
            if m.distance < 0.7 * n.distance:
                counterGood += 1
                totalDistance += m.distance

        scoreList.append({
            'path': ref_path,
            'distance': totalDistance,
            'n_pairs': counterGood,
        })

        if verbose:
            # Progress is reported against the list being searched; the original
            # used the global `png_paths` here.
            print(
                "Finish searching {:}/{:} distance = {:} (# good pairs = {:})".
                format(idx + 1, n_refs, totalDistance, counterGood),
                file=sys.stderr)

    confident = sorted(
        (score for score in scoreList if score['n_pairs'] > threshold_npairs),
        key=lambda score: score['distance'])
    weak = [score for score in scoreList if score['n_pairs'] <= threshold_npairs]
    sortedScore = confident + weak

    for rank, score in enumerate(sortedScore, start=1):
        if os.path.basename(score['path']) == expected_match:
            return rank

    print("Expected match = {:}".format(expected_match), file=sys.stderr)
    print(sortedScore, file=sys.stderr)
    raise ValueError("not found")

In [None]:
# Evaluate retrieval over all queries: mean reciprocal rank and top-1 accuracy.
MRR = 0
top1acc = 0
query_num = 0
for query_path in query_paths:
    img_query = cv2.imread(query_path, 0)
    if img_query is None:
        # cv2.imread signals an unreadable file by returning None, not by raising.
        raise IOError("Could not read query image: {:}".format(query_path))
    # Descriptors must come from THIS query image. The original line was commented
    # out (leaving `des_query` undefined) and read the stale `img` instead.
    kp_query, des_query = sift.detectAndCompute(img_query, None)
    expected_match = groundtruth[os.path.split(query_path)[1]]
    rank = search(png_paths, des_query, expected_match)

    # Each query contributes 1/N of its reciprocal rank and of its top-1 hit.
    MRR += (1 / len(query_paths)) * (1 / rank)
    top1acc += (1 / len(query_paths)) * (rank == 1)

    query_num += 1
    print("Query {:} : rank = {:}".format(query_num, rank), file=sys.stderr)

In [None]:
# Report the two metrics accumulated by the evaluation loop.
print("MRR = {:}".format(MRR))
print("Top-1 accuracy = {:}".format(top1acc))


| Experiment               | MRR           | top-1 accuracy  |
| --------------           | ------------- | --------------- |
| SIFT (0 threshold)       | 0.01          | 0               |
| SIFT (50 threshold)      | 0.05          | 0.025           |
| SIFT (100 threshold)     | 0.03          | 0               |