# Similar Images Retrieval with Color Clusters
Given an image, this script will retrieve the K most color-similar images out of an image collection.

For the demonstration of the algorithm, a Pokemon dataset has been extracted. For each Pokemon, this dataset contains its weight, height, type and sprite. Furthermore, we investigate whether there is a correlation between the colors of a Pokemon and its type.

In [None]:
import requests
import json
import os

# Directory prefix under which the downloaded sprite PNGs are stored.
sprites_path = 'sprites/'
# JSON file that caches the per-Pokemon metadata between runs.
data_json = 'data.json'
# Maps Pokemon name -> [height, weight, types]; filled either by
# download_data() or by loading the cache file.
data = {}

def sprite_file(name):
    """Return the path of the sprite PNG belonging to the Pokemon *name*."""
    return ''.join((sprites_path, name, '.png'))


def download_file(url, name):
    """Download *url* and store the response body in the file *name*.

    Args:
        url: Address of the resource to fetch.
        name: Path of the file to write (opened in binary mode).

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status, so
            that an error page is never saved in place of the real content.
        requests.Timeout: If the server does not respond in time.
    """
    # Without a timeout requests can block forever on a stalled connection.
    r = requests.get(url, timeout=30)
    r.raise_for_status()
    with open(name, 'wb') as f:
        f.write(r.content)


def download_data():
    """Fetch metadata and front sprites of all listed Pokemon from PokeAPI.

    For every entry of the Pokemon listing the detail record is requested.
    When the record has a default front sprite, the sprite is downloaded
    (unless its file already exists) and ``[height, weight, types]`` is
    stored in the module-level ``data`` dict under the Pokemon's name.
    Entries without such a sprite are reported with a ``Missing`` message
    and are not added to ``data``.

    Raises:
        requests.HTTPError: If a PokeAPI request returns an error status.
        requests.Timeout: If a request does not complete in time.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(sprites_path, exist_ok=True)
    listing = requests.get('https://pokeapi.co/api/v2/pokemon?limit=-1', timeout=30)
    # Fail with a clear HTTP error instead of a JSON/Key error further down.
    listing.raise_for_status()
    for res in listing.json()['results']:
        name = res['name']
        detail = requests.get(res['url'], timeout=30)
        detail.raise_for_status()
        p = detail.json()
        height = p['height']
        weight = p['weight']
        types = [t['type']['name'] for t in p['types']]
        sprite = p['sprites']['front_default']
        if not sprite:
            print('Missing', name)
            continue
        fname = sprite_file(name)
        # Sprites already on disk are reused rather than downloaded again.
        if not os.path.exists(fname):
            download_file(sprite, fname)
        data[name] = [height, weight, types]


def save_data(data, fname):
    """Serialize *data* as JSON and write it to the file *fname*."""
    with open(fname, 'w') as out:
        json.dump(data, out)


def load_data(fname):
    """Read the JSON document stored in *fname* and return the decoded object."""
    with open(fname, 'r') as src:
        return json.load(src)


# Reuse the cached metadata when the cache file exists; otherwise download
# everything once and write the cache for the next run.
if not os.path.exists(data_json):
    print('Downloading data')
    download_data()
    save_data(data, data_json)
else:
    print('Loading data')
    data = load_data(data_json)
print('Done')

#### A sample visualization of the dataset

In [None]:
import matplotlib.pyplot as plt

def display_sample(rows, cols):
    """Show the first ``rows * cols`` sprites of the dataset in a grid.

    Prints the total dataset size and, for every displayed Pokemon, its name
    and stored metadata.

    Args:
        rows: Number of grid rows.
        cols: Number of grid columns.
    """
    print('Total size:', len(data))
    # squeeze=False always yields a 2-D array of axes, so a 1x1 grid (for
    # which subplots would otherwise return a bare Axes) works as well.
    _, axs = plt.subplots(rows, cols, figsize=(10, 10), squeeze=False)
    axs = axs.flatten()
    # Hide all axes up front so that cells left without a sprite stay blank.
    for ax in axs:
        ax.axis('off')
    # zip stops at the shorter input: at most one sprite per grid cell.
    for (name, value), ax in zip(data.items(), axs):
        print(name, value)
        ax.imshow(plt.imread(sprite_file(name)))


# Preview the first nine dataset entries in a 3x3 grid.
display_sample(3, 3)

#### Clustering on pixels

For every Pokemon image we extract its k most representative colors by clustering its pixels with K-Means. The pixels are grouped into k+1 clusters, and each cluster center stands for one frequent color of the image.

The darkest cluster center (the one with the smallest channel sum) is discarded; the remaining k centers, ordered from the most to the least frequent, are the representatives of the image. (Note: `DBSCAN` is still imported below, but the clustering itself is done with `KMeans`.)
From now on, we are using those representatives to "represent" the image.
We create (once) a new .json file named "extracted_data.json" which contains all of the new 'representations' of the images.

#### Removing all the transparent pixels - Transforming RGBA to RGB

In [None]:
import numpy as np

def get_pixels(im):
    """Return the visible pixels of an image as an ``(N, 3)`` RGB array.

    Args:
        im: numpy array holding either RGBA data (any shape whose size is a
            multiple of four, e.g. ``(H, W, 4)``) or an ``(H, W, 3)`` RGB
            image without an alpha channel.

    Returns:
        A 2-D array with one RGB row per pixel.  For RGBA input only pixels
        with alpha > 0 are kept and the alpha column is dropped; for RGB
        input every pixel is returned.  Values keep the dtype and scale of
        the input (no conversion to 0-255 is performed).
    """
    if im.ndim == 3 and im.shape[-1] == 3:
        # No alpha channel, so every pixel is visible.  Forcing such an
        # image into four columns would mix channels of neighbouring pixels.
        return im.reshape(-1, 3)
    rgba = im.reshape(-1, 4)
    # Fully transparent pixels carry no visible color; discard them.
    opaque = rgba[rgba[:, 3] > 0]
    return np.delete(opaque, 3, 1)

# Maps every Pokemon name to the array of visible RGB pixels of its sprite.
pixel_dict = {
    name: get_pixels(plt.imread(sprite_file(name)))
    for name in data
}

### Color cluster the image collection
The following block creates a new .json file in which the representatives of the images get stored. 

This is the step where the clustering takes place. It is time-consuming, so it is reasonable to cache its result.
If the "extracted_data.json" file does not exist in the project folder, this code is executed in order to create it.


In [None]:
from mpl_toolkits.mplot3d import Axes3D
from sklearn.cluster import DBSCAN, KMeans

# JSON file that caches the per-image color representatives.
extracted_json = 'extracted_data.json'
# Maps Pokemon name -> list of representative colors (filled further below).
extracted_colors = {}
# Number of representative colors kept per image; the clustering itself
# uses k+1 clusters and then discards one of the resulting centers.
k = 2

def KMeans_estimator(pixels):
    """Cluster *pixels* with K-Means and return the estimator and its centers.

    The pixels are grouped into ``k + 1`` clusters.  The cluster centers are
    ordered by ascending cluster size and the darkest one (smallest channel
    sum) is removed, leaving ``k`` centers.

    Args:
        pixels: 2-D array with one color per row.

    Returns:
        A tuple ``(est, centers)`` of the fitted ``KMeans`` estimator and the
        remaining centers, ordered from the least to the most frequent.
    """
    # NOTE(review): no random_state is passed, so results may differ
    # between runs; the JSON cache hides this after the first extraction.
    est = KMeans(n_clusters=k+1)
    est.fit(pixels)
    # bincount with minlength keeps one count per cluster label even when a
    # cluster received no pixels; counting only the labels that occur would
    # misalign the counts with the rows of cluster_centers_.
    counts = np.bincount(est.labels_, minlength=k+1)
    order = np.argsort(counts)
    centers = est.cluster_centers_[order]
    darkest = np.argmin(np.sum(centers, axis=1))
    centers = np.delete(centers, darkest, axis=0)
    return est, centers


def extract_data(pixels):
    """Return the ``k`` representative colors of *pixels*, most frequent first.

    Args:
        pixels: 2-D array with one color per row.

    Returns:
        Array of the cluster centers computed by ``KMeans_estimator``, in
        reversed order (descending cluster size).
    """
    _, centers = KMeans_estimator(pixels)
    return np.flip(centers, 0)

# Run the (slow) clustering only when no cache file exists; otherwise load
# the previously extracted colors.
if not os.path.exists(extracted_json):
    print('Extracting data')
    for name, pixels in pixel_dict.items():
        # tolist() turns the numpy array into plain lists for JSON.
        extracted_colors[name] = extract_data(pixels).tolist()
    save_data(extracted_colors, extracted_json)
else:
    print('Loading extracted data')
    extracted_colors = load_data(extracted_json)
print('Done')

#### Visualization of Clusterings
Below, we visualize the clustering that is produced for the image of a given Pokemon.


In [None]:
import math

def display_colors(name):
    """Plot the sprite of *name*, its pixel colors and its extracted colors.

    The first figure has three panels: the sprite itself, its visible pixels
    packed into a square in their stored order, and the same pixels sorted
    by color.  The second figure shows one 16x16 swatch per color stored in
    ``extracted_colors[name]``.

    Args:
        name: Key of the image in ``pixel_dict`` and ``extracted_colors``.
    """
    pixels = pixel_dict[name]
    # Side length of the largest square the available pixels can fill.
    a = int(math.sqrt(pixels.size//3))

    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(1, 3, 1)
    ax.axis('off')
    ax.imshow(plt.imread(sprite_file(name)))

    ax = fig.add_subplot(1, 3, 2)
    ax.axis('off')
    ax.imshow(pixels[:(a*a)].reshape((a, a, 3)))

    # Sort descending by a weighted R/G/B key.  The key is computed for all
    # rows at once instead of calling a Python function per pixel; the
    # stable argsort keeps rows with equal keys in their stored order.
    key = -pixels[:, 0]*255*255 - pixels[:, 1]*255 - pixels[:, 2]
    ordered = pixels[np.argsort(key, kind='stable')]
    ax = fig.add_subplot(1, 3, 3)
    ax.axis('off')
    ax.imshow(ordered[:(a*a)].reshape((a, a, 3)))

    colors = extracted_colors[name]
    fig = plt.figure(figsize=(10, 8))
    # The grid is sized by the number of stored colors rather than by the
    # global k, so a cache written with a different k still displays.
    for i, c in enumerate(colors):
        ax = fig.add_subplot(1, len(colors), i+1)
        ax.axis('off')
        ax.imshow([[c] * 16]*16)
    plt.show()

# Visualize only the first image of the collection (note the break below).
for name, pixels in pixel_dict.items():
    display_colors(name)
    colors = np.array(extracted_colors[name])

    fig = plt.figure(figsize=(10, 4))
    # Left panel: all visible pixels in RGB space, each drawn in its own
    # color.  Right panel: the extracted representative colors.
    for position, points in ((121, pixels), (122, colors)):
        ax = fig.add_subplot(position, projection='3d')
        for clear_labels in (ax.set_xticklabels, ax.set_yticklabels, ax.set_zticklabels):
            clear_labels([])
        ax.scatter(points[:,0], points[:,1], points[:,2], c=points)
    plt.show()
    break

### Selecting a good k

In [None]:
# Elbow plot for the first image only (note the break): K-Means inertia for
# 1..6 clusters, normalized by the inertia of the single-cluster fit.
for name, pixels in pixel_dict.items():
    K = range(1,7)
    inertia = [KMeans(n_clusters=y).fit(pixels).inertia_ for y in K]
    inertia = np.array(inertia)/inertia[0]
    plt.plot(K, inertia)
    break
plt.xlabel('k')
plt.ylabel('Inertia')
plt.title('Elbow Method For Optimal k')
plt.show()

### Retrieve K most similar Images

Now, for a given image specified by the name of its Pokemon (e.g. "bulbasaur"), we compute the Euclidean distances between that image and all the others (based on their new representations), and then display the images corresponding to the K smallest distances.

In [None]:
import itertools
from scipy.spatial import distance

# Computes the minimum distance between two images
def color_distance(a, b):
    """Return the smallest mean Euclidean distance between two color sets.

    The colors of *a* are paired, in order, with every permutation of the
    colors of *b*; the mean pair distance of the best permutation is
    returned, so the result does not depend on the order of *b*.

    Args:
        a: Sequence of colors, each a sequence of numbers.
        b: Sequence of colors in the same format as *a*.

    Returns:
        The minimal mean pair distance.  Pairing stops at the shorter
        sequence: surplus colors of *a* are ignored and, when *b* is longer,
        the best ``len(a)`` colors of *b* are used.  If no pair can be
        formed the result is NaN.
    """
    b = list(b)
    # Colors of a beyond len(b) can never be paired with anything.
    a = list(a)[:len(b)]
    # Each distance is needed by many permutations; compute it only once.
    pairwise = [[distance.euclidean(x, y) for y in b] for x in a]
    return np.min([
        np.mean([pairwise[i][j] for i, j in zip(range(len(a)), order)])
        for order in itertools.permutations(range(len(b)))
    ])


def retrieve_similar_images(name, rows, cols):
    """Display the ``rows * cols`` images most color-similar to *name*.

    The color distance between *name* and every other entry of
    ``extracted_colors`` is computed; the closest ones are drawn in a grid
    below the color overview of the query image.

    Args:
        name: Key of the query image in ``extracted_colors``.
        rows: Number of rows of the result grid.
        cols: Number of columns of the result grid.
    """
    query_colors = extracted_colors[name]
    # [name, distance] for every image other than the query itself.
    ranking = [
        [other, color_distance(query_colors, other_colors)]
        for other, other_colors in extracted_colors.items()
        if other != name
    ]
    ranking.sort(key=lambda pair: pair[1])
    # Keep only as many hits as the grid has cells.
    closest = ranking[:rows * cols]
    display_colors(name)

    fig = plt.figure(figsize=(16, 8))
    for position, (other, _) in enumerate(closest, start=1):
        ax = fig.add_subplot(rows, cols, position)
        ax.axis('off')
        ax.imshow(plt.imread(sprite_file(other)))


#### Some Testing...

In [None]:
# Bulbasaur: show its 2 x 4 = 8 most color-similar sprites
retrieve_similar_images('bulbasaur', 2, 4)

In [None]:
# Charmander: show its 3 x 4 = 12 most color-similar sprites
retrieve_similar_images('charmander', 3, 4)

In [None]:
# Squirtle: show its 3 x 4 = 12 most color-similar sprites
retrieve_similar_images('squirtle', 3, 4)

In [None]:
# Pikachu: show its 3 x 4 = 12 most color-similar sprites
retrieve_similar_images('pikachu', 3, 4)