In [None]:
%matplotlib notebook

In [None]:
import sys
sys.path.append("../")

In [None]:
import numpy as np
from pickle import load, dump
from pathlib import Path
from tqdm import tqdm
import torch
from torchvision.utils import make_grid
import torchvision.transforms as T
from torchvision.transforms.functional import to_pil_image
from PIL import Image
from multiprocessing import Pool
from scripts.image.kmeans import kmeans
import pandas as pd
import matplotlib.pyplot as plt
import pickle

Let's load the data from the pickle file. We are using the embeddings **after t-SNE**; you can obtain them by running the `<root>/scripts/reduce_dims.py` script.

```
python ./scripts/reduce_dims.py -i ./<path_to_clip_embeddings> -i . -k 2 --method tsne
```

In [None]:
# Path to the pickle holding the reduced embeddings (see the reduce_dims command above).
DATA_PATH = "../reduced-tsne-k=2.pk"

In [None]:
# NOTE(review): pickle.load can execute arbitrary code while deserialising;
# only open pickle files that you generated yourself.
with open(DATA_PATH, "rb") as f:
    data = pickle.load(f)

# Bare expression: shown as the cell output for a quick look at what was loaded.
data

Let's also add the category to each element

In [None]:
def add_categories_to_data(data, categories):
    """Attach a category label to every image path stored in ``data``.

    A fresh list is stored under ``data['categories']`` (replacing any previous
    one) with one entry per element of ``data['image_paths']``. The label is
    looked up in ``categories`` using the second component of the image path as
    the row label and ``category`` as the column.

    Args:
        data: Mapping with an ``'image_paths'`` entry; it is modified in place.
        categories: Table indexed by dataset folder name with a ``category`` column.

    Returns:
        The same ``data`` object that was passed in.
    """
    data['categories'] = []
    for path in tqdm(data['image_paths']):
        # The dataset folder is the second path component,
        # e.g. 'rf100/chess-pieces-mjzgj/train/images/foo.jpg' -> 'chess-pieces-mjzgj'
        dataset_folder = Path(path).parts[1]
        data['categories'].append(categories.loc[dataset_folder, "category"])

    return data

In [None]:
# The first CSV column becomes the index; add_categories_to_data looks rows up
# by dataset folder name, so that column presumably holds those names — verify against the CSV.
categories = pd.read_csv("../metadata/categories.csv", index_col=0)
data = add_categories_to_data(data, categories)

In [None]:
# Sanity check: list the distinct category labels that were attached to the samples.
np.unique(data['categories'])

Sweet, now we have all the data we need. Let's do some clustering.

## Clustering
Let's define some transformations

In [None]:
import torchvision.transforms.functional as F

def read_image_and_transform(image_path, size=(224, 224), root="../"):
    """Load an image from disk as a resized RGB tensor.

    Args:
        image_path: Image location relative to ``root``, as stored in the
            ``image_paths`` entries of the loaded data.
        size: Target size forwarded to ``F.resize``.
        root: Folder that ``image_path`` is relative to. Defaults to ``"../"``,
            the prefix this notebook has always used.

    Returns:
        The result of ``F.to_tensor`` applied to the resized RGB image.
    """
    # Open inside a context manager so the file handle is released
    # deterministically; ``convert`` returns an independent, fully loaded image.
    with Image.open(Path(root) / image_path) as raw:
        img = raw.convert("RGB")
    img = F.resize(img, size)
    img = F.to_tensor(img)
    return img

### Cluster per category

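We first want to get the most representative images per category. To do so, we filter the data by category and run k-means on that subset (100 clusters by default); the sample closest to each cluster centre is used as that cluster's representative.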

In [None]:
def reored_by_left_top(x):
    """Return the indices that sort 2-D points by distance to their top-left corner.

    The corner is ``(min of column 0, max of column 1)`` taken over ``x`` itself.

    Args:
        x: Points of shape ``(n, 2)``, given either as a ``torch.Tensor`` or as
            anything ``torch.as_tensor`` accepts (e.g. a NumPy array).

    Returns:
        A 1-D tensor of indices into ``x``, closest to the corner first.
    """
    # Convert once so the arithmetic below is pure torch instead of relying on
    # implicit Tensor/ndarray interoperability.
    x = torch.as_tensor(x)
    # stack (rather than torch.tensor([...])) keeps the dtype and device of x
    corner = torch.stack((x[:, 0].min(), x[:, 1].max()))
    indexes = ((corner - x) ** 2).sum(dim=-1).argsort()
    return indexes

In [None]:
# Default every savefig call to a tight bounding box around the figure content.
plt.rcParams["savefig.bbox"] = 'tight'

def show(imgs):
    """Draw one or several image tensors side by side and return the figure.

    Args:
        imgs: A single image tensor or a list of them. Each one is detached and
            converted with ``F.to_pil_image`` before being drawn.

    Returns:
        The matplotlib figure, with one tick-less axis per image in a single row.
    """
    images = imgs if isinstance(imgs, list) else [imgs]
    fig, axs = plt.subplots(ncols=len(images), squeeze=False)
    # squeeze=False keeps ``axs`` two-dimensional, so its first row is iterable
    for ax, tensor in zip(axs[0], images):
        pil_image = F.to_pil_image(tensor.detach())
        ax.imshow(np.asarray(pil_image))
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    return fig

In [None]:
def cluter_per_category(category="real world", num_clusters=100, nrow=25):
    """Plot a grid with one representative image per k-means cluster of a category.

    Reads the notebook-level ``data`` mapping (entries ``'x'``, ``'image_paths'``
    and ``'categories'``). The samples of ``category`` are clustered with
    ``kmeans``; for every cluster mean the closest sample is picked, and the
    picked images are ordered with ``reored_by_left_top`` before being tiled.

    Args:
        category: Label to select from ``data['categories']``.
        num_clusters: Number of clusters, i.e. number of images in the grid.
        nrow: Forwarded to ``make_grid``. Defaults to 25, the value that used
            to be hard-coded.

    Returns:
        The matplotlib figure produced by ``show``.

    Raises:
        ValueError: If no sample is labelled with ``category``.
    """
    mask = np.array(data["categories"]) == category
    if not np.any(mask):
        raise ValueError(f"no samples found for category {category!r}")
    # data['x'] holds the reduced (t-SNE) embeddings loaded from DATA_PATH.
    # Convert once so every computation below is pure torch.
    points = torch.from_numpy(data['x'][mask])
    paths = np.array(data['image_paths'])[mask]
    # only the cluster means are needed; the second return value is unused
    means, _ = kmeans(points, num_clusters=num_clusters, num_iters=100)
    # Squared distance between every mean and every point, shape (clusters, points).
    # NOTE(review): the intermediate is (clusters, points, dim) and can be large.
    sq_dists = ((means[:, None, :] - points[None, ...]) ** 2).sum(dim=-1)
    # the sample closest to each mean represents that cluster
    nearest = sq_dists.argmin(dim=1)
    order = reored_by_left_top(points[nearest])
    # index the NumPy path array with NumPy indices, not tensors
    paths = paths[nearest.cpu().numpy()][order.cpu().numpy()]

    thumbnails = [read_image_and_transform(path) for path in paths]
    return show(make_grid(thumbnails, nrow=nrow))

In [None]:
dst = Path("../paper/images/grid/")
# savefig does not create missing folders, so make sure the output folder exists
dst.mkdir(parents=True, exist_ok=True)

for category in categories.category.unique():
    # "real world" gets 200 clusters, every other category 50
    num_clusters = 200 if category == "real world" else 50
    fig = cluter_per_category(category, num_clusters)
    fig.savefig(dst / f"{category}.png", dpi=800, bbox_inches='tight')

## Grid Image

In [None]:
def make_cluster_grid(num_clusters, nrow, image_size=(128, 128)):
    """Plot a grid with one representative image per k-means cluster of all data.

    Reads the notebook-level ``data`` mapping (entries ``'x'`` and
    ``'image_paths'``). All samples are clustered with ``kmeans``; for every
    cluster mean the closest sample is picked, and the picked images are ordered
    with ``reored_by_left_top`` before being tiled.

    Args:
        num_clusters: Number of clusters, i.e. number of images in the grid.
        nrow: Forwarded to ``make_grid``.
        image_size: Size each image is resized to before tiling. Defaults to
            ``(128, 128)``, the value that used to be hard-coded.

    Returns:
        The matplotlib figure produced by ``show``.
    """
    # data['x'] holds the reduced (t-SNE) embeddings loaded from DATA_PATH.
    # Convert once so every computation below is pure torch.
    points = torch.from_numpy(data['x'])
    paths = np.array(data['image_paths'])
    # only the cluster means are needed; the second return value is unused
    means, _ = kmeans(points, num_clusters=num_clusters, num_iters=50)
    # Squared distance between every mean and every point, shape (clusters, points).
    # NOTE(review): the intermediate is (clusters, points, dim) and can be large.
    sq_dists = ((means[:, None, :] - points[None, ...]) ** 2).sum(dim=-1)
    # the sample closest to each mean represents that cluster
    nearest = sq_dists.argmin(dim=1)
    order = reored_by_left_top(points[nearest])
    # index the NumPy path array with NumPy indices, not tensors
    paths = paths[nearest.cpu().numpy()][order.cpu().numpy()]

    thumbnails = [read_image_and_transform(path, size=image_size) for path in paths]
    return show(make_grid(thumbnails, nrow=nrow))

In [None]:
# 40 rows of 60 images: one k-means cluster per grid cell
num_clusters = 60 * 40

grid_figure = make_cluster_grid(num_clusters, nrow=60)
grid_figure.savefig(dst / "rf100-40x60.png", dpi=800, bbox_inches='tight')

In [None]:
# An 8x8 grid needs 8 * 8 = 64 clusters. Reusing the 40 * 60 value from the
# previous cell would produce 300 rows of 8 images under an "8x8" file name.
make_cluster_grid(8 * 8, nrow=8).savefig(dst / "rf100-8x8.png", dpi=800, bbox_inches='tight')