In [1]:
import cv2, math, random, time, glob
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
from scipy import misc
from sklearn import cluster
from sklearn.metrics import silhouette_samples, silhouette_score
from skimage.transform import pyramid_gaussian
from annoy import AnnoyIndex

In [2]:
# Palette of reference RGB colour bins. get_color snaps an image's dominant
# colour to the nearest row of this (10, 3) integer array.
color_array = np.array([
    [140, 70, 50],
    [40, 40, 40],
    [222, 222, 222],
    [180, 170, 160],
    [160, 20, 90],
    [170, 10, 40],
    [30, 50, 90],
    [30, 70, 170],
    [150, 70, 140],
    [70, 120, 90],
])

def get_color(image_RGB, n_clusters=5, verbose=False):
    """Return the palette colour (a row of ``color_array``) closest to an image's dominant colour.

    Near-white pixels (all three channels >= 250) are treated as background and
    ignored. The remaining pixels are clustered with OpenCV k-means. The centroid
    of the most populated cluster is then snapped to the nearest entry of the
    module-level ``color_array`` palette by Euclidean distance.

    Parameters
    ----------
    image_RGB : np.ndarray
        Image whose pixels have 3 channels; it is reshaped to (-1, 3).
    n_clusters : int, optional
        Number of k-means clusters. The default of 5 is the value previously
        hard-coded.
    verbose : bool, optional
        If True, print the cluster centroids and cluster sizes. The original
        version always printed this debug output.

    Returns
    -------
    np.ndarray
        The selected row of ``color_array``.

    Raises
    ------
    ValueError
        If the image contains no pixels.
    """
    pixels = image_RGB.reshape((-1, 3))
    if pixels.shape[0] == 0:
        raise ValueError("get_color: image contains no pixels")

    # Drop near-white background pixels. cv2.kmeans needs at least K samples,
    # so fall back to every pixel when too few remain (e.g. an almost blank image).
    foreground = pixels[(pixels >= 250).sum(axis=1) != 3]
    if foreground.shape[0] < n_clusters:
        foreground = pixels
    Z = np.float32(foreground)
    K = min(n_clusters, Z.shape[0])

    # Cluster the pixel colours into K groups (10 attempts, random initial centres).
    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
    _, labels_array, centroid_array = cv2.kmeans(Z, K, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS)

    # The centroid of the most populated cluster is taken as the image colour.
    labels_count = np.bincount(labels_array.ravel(), minlength=K)
    query_color = centroid_array[np.argmax(labels_count)]

    # Nearest palette entry by Euclidean distance.
    cost_array = np.linalg.norm(color_array - query_color, axis=1)
    color_result = color_array[np.argmin(cost_array)]

    if verbose:
        print(centroid_array)
        print(labels_count)

    return color_result

##########################################################################################
#Define keypoint functions

#Initialize the SIFT detector used by the keypoint functions below.
#In OpenCV >= 4.4, SIFT lives in the main cv2 namespace.
#Older builds only expose it through the contrib module cv2.xfeatures2d.
sift = cv2.SIFT_create() if hasattr(cv2, "SIFT_create") else cv2.xfeatures2d.SIFT_create()

def get_keypoints_and_descriptors(image):
    """Detect SIFT keypoints on each RGB channel of an image file and append a colour feature.

    Parameters
    ----------
    image : str
        Path of the image file, read with ``cv2.imread``.

    Returns
    -------
    keypoints_array : np.ndarray
        The ``cv2.KeyPoint`` objects found on the R, G and B channels,
        concatenated in that order.
    descriptors_array : np.ndarray
        One row per keypoint: the SIFT descriptor followed by 3 colour values
        (``10 * get_color(image_RGB)``), the same for every row.

    Raises
    ------
    IOError
        If the file cannot be read as an image.
    """
    bgr = cv2.imread(image)
    if bgr is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise IOError("Could not read image: %s" % image)
    image_RGB = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    keypoints_list = []
    descriptors_list = []

    # Run SIFT independently on each colour channel.
    for i in range(3):
        k, d = sift.detectAndCompute(image_RGB[:, :, i], None)
        keypoints_list.extend(k)
        # detectAndCompute returns None descriptors when a channel has no keypoints.
        if d is not None:
            descriptors_list.append(d)

    keypoints_array = np.array(keypoints_list)
    if descriptors_list:
        descriptors_array = np.vstack(descriptors_list)
    else:
        # Keep a 2-D shape so the hstack below still works with zero keypoints.
        descriptors_array = np.zeros((0, sift.descriptorSize()), dtype=np.float32)

    # Add the image colour as extra descriptor values; multiply by 10 to increase its weight.
    image_color = 10 * get_color(image_RGB)

    # Join the N x 3 colour block to the N x 128 SIFT block.
    color_features = np.tile(image_color, (descriptors_array.shape[0], 1))
    descriptors_array = np.hstack((descriptors_array, color_features))

    return keypoints_array, descriptors_array


def get_bow_descriptors(query_image, clusters_array=range(100, 101)):
    """Build a bag-of-words codebook from the SIFT+colour descriptors of one image.

    K-means (k-means++ initialisation) is fitted once for every cluster count in
    ``clusters_array``. The model with the highest silhouette score wins, and its
    cluster centres are returned as the codewords.

    Parameters
    ----------
    query_image : str
        Path of the image, passed to ``get_keypoints_and_descriptors``.
    clusters_array : iterable of int, optional
        Candidate numbers of clusters. The default ``range(100, 101)`` tries only
        100, as before. Counts outside ``[2, n_descriptors - 1]`` are skipped
        because the silhouette score is undefined there.

    Returns
    -------
    np.ndarray
        The ``cluster_centers_`` of the best-scoring KMeans model.

    Raises
    ------
    ValueError
        If no candidate cluster count is valid for the number of descriptors.
    """
    keypoints, descriptors = get_keypoints_and_descriptors(query_image)
    Z = np.float32(descriptors)
    n_samples = Z.shape[0]

    candidates = [k for k in clusters_array if 2 <= k < n_samples]
    if not candidates:
        raise ValueError(
            "get_bow_descriptors: no valid cluster count in clusters_array for %d descriptors" % n_samples)

    best_model, best_score = None, -np.inf
    for k in candidates:
        model = cluster.KMeans(n_clusters=k, init="k-means++").fit(Z)
        # Reuse the labels from this fit; fit_predict would refit the model from scratch.
        score = silhouette_score(Z, model.labels_)
        if score > best_score:
            best_model, best_score = model, score

    # The bag-of-words codewords are the cluster centres of the best model.
    return best_model.cluster_centers_

##########################################################################################
#Plot images

def plot_matches(im1, im2, p1, p2, matches, start_time=None):
    """Draw two images stacked vertically, with lines joining matched keypoints.

    Parameters
    ----------
    im1, im2 : np.ndarray
        Images to display. ``im1`` goes on top and ``im2`` below it, separated by
        a 20-pixel black band. Both may be 2-D (grayscale) or 3-D, but they must
        share the same trailing (channel) shape. The canvas is uint8.
    p1, p2 : sequence of cv2.KeyPoint
        Keypoints of ``im1`` and ``im2``; only their ``.pt`` (x, y) is used.
    matches : iterable of (ind1, ind2, d)
        An index into ``p1``, an index into ``p2``, and a third value that is ignored.
    start_time : float, optional
        A ``time.time()`` timestamp. When given, the elapsed runtime is printed.
        The original version read a global ``t`` that is only defined in a later cell.
    """
    separation = 20  # black band between the two images, in pixels

    # Canvas tall enough for both images plus the band; trailing dims (channels) follow im1.
    top_h, top_w = im1.shape[0], im1.shape[1]
    offset = top_h + separation  # y offset of im2 on the canvas
    height = offset + im2.shape[0]
    width = max(top_w, im2.shape[1])
    canvas = np.zeros((height, width) + im1.shape[2:], dtype=np.uint8)
    canvas[:top_h, :top_w] = im1
    canvas[offset:, :im2.shape[1]] = im2

    fig, ax = plt.subplots()
    ax.imshow(canvas, cmap='gray' if canvas.ndim == 2 else None)
    ax.autoscale(False)

    for ind1, ind2, _ in matches:
        x1, y1 = p1[int(ind1)].pt
        x2, y2 = p2[int(ind2)].pt
        ax.plot([x1, x2], [y1, y2 + offset], '-x')

    if start_time is not None:
        print('total runtime: ' + str(round(time.time() - start_time, 1)) + ' seconds')
    plt.show()

##########################################################################################
#Functions for the object detection

def generate_image_pyramids(query_image, scale = 2, min_size=30, gaussian=True):
    """Build an image pyramid by repeatedly downscaling an image file.

    Parameters
    ----------
    query_image : str
        Path of the image to read; it is converted from OpenCV's BGR order to RGB.
    scale : float, optional
        Factor by which each level shrinks (default 2). Must be > 1; otherwise the
        image never gets smaller and the loop would never end.
    min_size : int, optional
        A level is not produced if its height or width would fall below this value.
    gaussian : bool, optional
        Currently unused. Levels are always produced with ``cv2.INTER_AREA``
        resizing.

    Returns
    -------
    list of np.ndarray
        The pyramid, full-resolution image first. The full-resolution image is
        always included, even if it is already smaller than ``min_size``.

    Raises
    ------
    ValueError
        If ``scale <= 1``.
    IOError
        If the file cannot be read as an image.
    """
    if scale <= 1:
        raise ValueError("scale must be > 1, got %r" % (scale,))

    bgr = cv2.imread(query_image)
    if bgr is None:
        raise IOError("Could not read image: %s" % query_image)
    image = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    # A level must be at least 1 pixel, since cv2.resize rejects a zero-sized target.
    smallest = max(min_size, 1)
    image_pyramid = [image]
    while True:
        new_w = int(image.shape[1] * (1.0 / scale))
        new_h = int(image.shape[0] * (1.0 / scale))
        # Check the target size before resizing; a too-small level would be discarded anyway.
        if new_h < smallest or new_w < smallest:
            break
        image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)
        image_pyramid.append(image)

    return image_pyramid

In [7]:
def create_image_database(bow=False):
    """Index the SIFT+colour descriptors of every database image in an Annoy index.

    Each descriptor is stored under the id ``image_index * 10000 + keypoint_index``.
    A neighbour id can therefore be mapped back to its image with ``id // 10000``.
    The index is built with 10 trees and saved to ``image_database.ann``.

    Parameters
    ----------
    bow : bool, optional
        Unused; kept for backward compatibility.
    """
    # Images are listed in glob order; the query side must list them identically.
    directory_list = ['square-mens-shortsleeve-tshirts']
    database_array = np.array(glob.glob(directory_list[0] + "/*.jpg"))
    database_size = database_array.shape[0]

    n_dims = 256          # index dimension; must match the loader's AnnoyIndex
    max_keypoints = 10000 # id stride per image; more keypoints would collide with the next image

    # 'angular' is Annoy's historical default metric, passed explicitly.
    ann_database = AnnoyIndex(n_dims, 'angular')

    for i in range(database_size):

        # Track progress
        print (str(round(100*(1.0*i/database_size),1)) + "%")

        image = database_array[i]
        keypoints, descriptors = get_keypoints_and_descriptors(image)

        if descriptors.shape[0] > max_keypoints:
            print ('Warning: %s has %d keypoints; only the first %d are indexed'
                   % (image, descriptors.shape[0], max_keypoints))
            descriptors = descriptors[:max_keypoints]

        # Zero-pad each descriptor (128 SIFT + 3 colour values) up to the index
        # dimension. Annoy rejects vectors whose length differs from it, and zero
        # padding leaves angular and Euclidean distances unchanged.
        padded = np.zeros((descriptors.shape[0], n_dims), dtype=np.float32)
        padded[:, :descriptors.shape[1]] = descriptors

        for j in range(padded.shape[0]):
            ann_database.add_item(i * max_keypoints + j, padded[j])

    #Save the database
    print ('Saving database')
    ann_database.build(10)
    ann_database.save('image_database.ann')
    print ('Database saved')


In [4]:
#Initialize the SIFT detector. This re-creates the detector defined in the keypoint-functions cell.
#In OpenCV >= 4.4, SIFT lives in cv2; older builds only provide cv2.xfeatures2d.SIFT_create.
sift = cv2.SIFT_create() if hasattr(cv2, "SIFT_create") else cv2.xfeatures2d.SIFT_create()

In [5]:
##########################################################################################
#Define the ANN for the query image

def get_ann_image_subset(query_image, n_results=5, n_neighbors=3):
    """Rank database images by how many of the query's descriptors land near theirs.

    The Annoy index ``image_database.ann`` is loaded. For every descriptor of the
    query image, its ``n_neighbors`` nearest indexed descriptors each cast one vote
    for the image they came from (``id // 10000``).

    Parameters
    ----------
    query_image : str
        Path of the query image.
    n_results : int, optional
        How many top-scoring images to return (default 5).
    n_neighbors : int, optional
        Nearest neighbours retrieved per query descriptor (default 3).

    Returns
    -------
    np.ndarray
        Indices into the database file list of the best ``n_results`` images,
        in ascending order of score (best match last).
    """
    print ("Start: create image subset")

    n_dims = 256          # must match the dimension the index was built with
    max_keypoints = 10000 # id stride per image used when the index was built

    # Load the database; 'angular' is Annoy's historical default metric.
    ann_database = AnnoyIndex(n_dims, 'angular')
    ann_database.load('image_database.ann')

    # Image list in the same glob order used when the index was built.
    directory_list = ['square-mens-shortsleeve-tshirts']
    database_array = np.array(glob.glob(directory_list[0] + "/*.jpg"))
    database_size = database_array.shape[0]

    keypoints, descriptors = get_keypoints_and_descriptors(query_image)

    # Zero-pad descriptors to the index dimension. Annoy rejects query vectors of
    # the wrong length, and zero padding does not change distances.
    padded = np.zeros((descriptors.shape[0], n_dims), dtype=np.float32)
    padded[:, :descriptors.shape[1]] = descriptors

    # Score the images: one vote per neighbour for the image owning it.
    score_array = np.zeros(database_size)
    for des in padded:
        nearest_items = ann_database.get_nns_by_vector(des, n_neighbors, include_distances=False)
        for item in nearest_items:
            score_array[item // max_keypoints] += 1

    # Take the top n_results from the descending order, then restore ascending order.
    # This is safe for n_results == 0, where a [-0:] slice would return everything.
    sorted_score_index = np.argsort(score_array)
    top_index = sorted_score_index[::-1][:max(n_results, 0)][::-1]
    print (database_array[top_index])

    return np.array(top_index)

##########################################################################################

In [9]:
t = time.time()  # start timestamp for the runtime report

#Define query input
query_image = 'square-mens-shortsleeve-tshirts/mens-longsleeve-shirts-11.jpg'

#Rebuilding the index re-extracts SIFT features for every database image, which is slow.
#Set this to False to reuse an existing image_database.ann.
REBUILD_DATABASE = True

############################################
#Run algorithm

if REBUILD_DATABASE:
    create_image_database()
similar_images_set = get_ann_image_subset(query_image)
print ('total runtime: ' + str(round(time.time() - t, 1)) + ' seconds')
similar_images_set
