In [1]:
import os

import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt
from datasets import load_dataset

from utils import *
from patch_kmeans import get_latest_checkpoint

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
def visualize_kmeans(imgs, patch_size, k, save_dir: str = "outputs"):
    """Rebuild images from their nearest k-means patch centroids and save them.

    Every image is cut into ``patch_size`` x ``patch_size`` patches. For each
    patch and colour channel the per-patch minimum is subtracted, the closest
    centroid (largest dot product) is looked up, scaled to the norm of the
    patch it replaces, and the minimum is added back. For every image two PNGs
    are written to ``save_dir``: ``kmeans_{i}_real.png`` (the original, cropped
    to a whole number of patches) and ``kmeans_{i}_recon.png``.

    Args:
        imgs: Iterable of images convertible with ``np.array`` to an
            ``h x w x 3`` array (presumably uint8 RGB -- confirm against the
            dataset).
        patch_size: Side length of the square patches, in pixels.
        k: Number of centroids; selects which checkpoint is loaded.
        save_dir: Directory the PNGs are written to. Created if missing.

    Returns:
        None. The results are the files written to ``save_dir``.
    """
    imgs = list(imgs)
    if not imgs:
        # jnp.concat below cannot handle an empty list of patch arrays
        print('No images given; nothing to visualize.')
        return

    # plt.imsave does not create missing directories
    os.makedirs(save_dir, exist_ok=True)

    latest = get_latest_checkpoint(patch_size, k)
    if latest == -1:
        print(f'No saved means for the configuration: patch_size = {patch_size}, k = {k}')
        print('Running with random init...')
        means = jnp.array(np.random.rand(k, patch_size * patch_size), dtype=jnp.float32)
        # we want our means to be unit vectors
        means = means / jnp.linalg.norm(means, axis=1, keepdims=True)
    else:
        print(f'Using checkpoint {KMEANS_DIR}/{patch_size}/{k}/{latest}.npy...')
        # assumes the checkpoint holds a k x s array of unit vectors -- TODO confirm
        means = jnp.load(f'{KMEANS_DIR}/{patch_size}/{k}/{latest}.npy')

    size_squared = patch_size * patch_size # s
    patches = jnp.concat([patchify(jnp.array(img), patch_size) for img in imgs]) # p x s x 3
    # min color over patch
    min_val = jnp.min(patches, axis=1, keepdims=True) # p x 1 x 3
    patches = (patches - min_val).astype(jnp.float32) # p x s x 3
    # treat each color channel as its own vector
    patches = jnp.moveaxis(patches, 2, 1).reshape(-1, size_squared) # 3p x s
    norms = jnp.linalg.norm(patches, axis=1, keepdims=True) # 3p x 1
    # nearest mean = largest dot product; the positive per-row scale of the
    # patch does not change the argmax, so the patches need no normalising
    vectors = means[jnp.argmax(patches @ means.T, axis=1)] # 3p x s
    # scale vectors
    vectors = norms * vectors
    # split vectors back into color channels
    vectors = jnp.moveaxis(vectors.reshape(-1, 3, size_squared), 2, 1) # p x s x 3
    reconstructed = min_val + vectors # p x s x 3

    start_idx = 0
    for i, img in enumerate(imgs):
        arr = np.array(img)
        h, w, _ = arr.shape
        ph, pw = h // patch_size, w // patch_size
        # grab relevant patches
        arr_hat = reconstructed[start_idx : start_idx + ph * pw]
        # unflatten each patch
        arr_hat = arr_hat.reshape(-1, patch_size, patch_size, 3)
        # restructure the patches from concatenated columns into grid
        # first split up into columns and stack them along the x-axis
        arr_hat = jnp.concat(jnp.split(arr_hat, pw), axis=2)
        # then drop the patch axis and fuse the rows along the y-axis
        arr_hat = jnp.concat(jnp.unstack(arr_hat))

        # shift pointer
        start_idx += ph * pw
        # min + norm * mean can leave [0, 255]; clip first so the uint8 cast
        # saturates instead of wrapping around
        recon_u8 = np.clip(np.asarray(arr_hat), 0, 255).astype(np.uint8)
        # save images
        plt.imsave(f'{save_dir}/kmeans_{i}_real.png', arr[:ph * patch_size, :pw * patch_size])
        plt.imsave(f'{save_dir}/kmeans_{i}_recon.png', recon_u8)

In [3]:
# Fixed seed so the same images are sampled on every Restart & Run All.
SHUFFLE_SEED = 0
ds = load_dataset('nlphuji/flickr30k', split='test').select_columns('image').shuffle(seed=SHUFFLE_SEED)

In [4]:
# Reconstruct the first 20 shuffled images with 20x20 patches and 800 means.
visualize_kmeans(ds[:20]['image'], patch_size=20, k=800, save_dir="temp")

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


No saved means for the configuration: patch_size = 20, k = 800
Running with random init...
