In [1]:
import os

import jax.lax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from tqdm import tqdm

from utils import patchify

In [2]:
# Load the reference image as a bfloat16 array of shape (h, w, 3).
# The context manager closes the underlying file once the pixels are copied, and
# convert('RGB') normalises palette / greyscale / alpha images so that the rest
# of the notebook can rely on exactly three colour channels.
with Image.open('example_imgs/teletubbies.webp') as source_image:
    img = jnp.array(source_image.convert('RGB'), dtype=jnp.bfloat16)
h, w, _ = img.shape

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


In [None]:
# distance substituted for exact zeros so the reciprocal below stays finite
eps = 0.5
# dists is a very large matrix so it has to be computed in chunks
def compute_weights(chunk, points, k=5):
    """Inverse-distance weights from each row of `chunk` to its nearest anchors.

    Args:
        chunk: (n, d) array of coordinates to compute weights for.
        points: (p, d) array of anchor coordinates.
        k: maximum number of nearest anchors kept per row. It is clamped to
            `p`, because `jax.lax.top_k` rejects a `k` larger than the axis
            it selects from.

    Returns:
        The `(values, indices)` pair produced by `jax.lax.top_k`: the largest
        `min(k, p)` weights per row (n x min(k, p)) and the positions of the
        corresponding anchors in `points`.
    """
    # pairwise Euclidean distances via broadcasting: (n, 1, d) - (1, p, d) -> n x p
    offsets = jnp.expand_dims(chunk, axis=1) - jnp.expand_dims(points, axis=0)
    dists = jnp.linalg.norm(offsets, axis=2)
    # a coordinate that coincides with an anchor has distance 0; use eps instead
    # to avoid division by zero
    dists = jnp.where(dists == 0, eps, dists)
    # weight is inversely proportional to distance
    weights = jnp.reciprocal(dists)
    # largest weights == closest anchors; never ask for more anchors than exist
    num_neighbors = min(k, points.shape[0])
    return jax.lax.top_k(weights, num_neighbors) # n x num_neighbors

In [None]:
k = 4 # how many nearby points to use for the gradients
# NOTE(review): `k` is never read in this cell; the neighbour count actually used
# is decided inside compute_weights. Confirm which value is intended.
num_iters = 12
# seeded generator so the sampled anchor positions are reproducible across runs
seed = 0
rng = np.random.default_rng(seed)
# one frame per iteration is written to grad/<iteration>.png; make sure the folder exists
os.makedirs('grad', exist_ok=True)
# (row, col) coordinate of every pixel, flattened to (h * w) x 2
indices = jnp.stack(jnp.meshgrid(jnp.arange(h), jnp.arange(w), indexing='ij'), axis=-1).reshape(-1, 2)
# jnp arrays can't be empty so we initialize with the center of the image
# indices of points to take from the image to compute gradients
points = jnp.array([[h // 2, w // 2]])
# the colors of those points
point_colors = img[jnp.unstack(points, axis=1)]
grad = jnp.zeros_like(img) # the gradient reconstruction
# this helps split up the distance computation into manageable chunks:
# upper bound on (pixels in a chunk) x (anchor points) per distance matrix
split_limit = int(1e6)
# the loop variable is deliberately not called `iter`, which would shadow the
# builtin for the rest of the kernel session
for iteration in tqdm(range(num_iters)):
    # analyze patches at different levels of granularity to determine
    # whether more points are needed to improve the patch's fidelity
    patch_size = int(30 * 0.75 ** iteration)
    if patch_size < 1:
        break
    # per-patch reconstruction error
    # NOTE(review): the reductions below presumably expect patchify to return
    # (num_patches, pixels_per_patch, channels) - confirm against utils.patchify
    patches = patchify(grad - img, patch_size)
    means = jnp.abs(jnp.mean(patches, axis=(1, 2)))
    stds = jnp.linalg.norm(jnp.std(patches, axis=1), axis=-1)

    # determine new points and map to grid coordinates
    # alternate strategy of only selecting the patches that are the furthest off
    # _, top_mean_indices = jax.lax.top_k(means, min(means.shape[0], 900))
    # add_new_point = jnp.zeros_like(means, dtype=jnp.bool).at[top_mean_indices].set(True)
    add_new_point = means > 5
    # patches with (relatively) low variation are good candidates because they're simpler
    if patch_size > 6:
        add_new_point = jnp.logical_and(add_new_point, stds < patch_size * 3)
    # unflatten and map from bools to actual indices
    # NOTE(review): splitting into w // patch_size pieces and stacking them on the
    # last axis reads the flat patch index as column-major (flat = col * rows + row).
    # Confirm that this matches the order in which patchify emits patches.
    add_new_point = jnp.stack(jnp.split(add_new_point, w // patch_size), axis=-1)
    grid = jnp.meshgrid(jnp.arange(h // patch_size), jnp.arange(w // patch_size), indexing='ij')
    grid = jnp.stack(grid, axis=-1)

    # resolve to full coordinates by adding some noise
    # top-left pixel of every selected grid cell
    new_points = grid[add_new_point] * patch_size
    # variation within the current grid cell: independent offsets in [0, patch_size)
    # for the row and column of each new point
    coordinate_noise = jnp.asarray(rng.integers(0, patch_size, size=new_points.shape))
    new_points = new_points + coordinate_noise

    # add new points
    points = jnp.append(points, new_points, axis=0)
    point_colors = jnp.append(point_colors, img[jnp.unstack(new_points, axis=1)], axis=0)

    # reconstruct the image
    # split up the computation to avoid memory overflow; ceil-divide so that each
    # chunk stays within split_limit and always use at least one chunk
    # (array_split cannot split into zero sections)
    total_pairs = indices.shape[0] * points.shape[0]
    num_splits = max(1, -(-total_pairs // split_limit))
    chunks = jnp.array_split(indices, num_splits)
    chunk_results = [compute_weights(chunk, points) for chunk in chunks]
    # recompute gradients
    # regroup the per-chunk (weights, indices) pairs into two parallel sequences
    weight_chunks, top_k_chunks = tuple(zip(*chunk_results))
    # consider only the closest anchor points selected by compute_weights
    weights, top_k_indices = jnp.concat(weight_chunks), jnp.concat(top_k_chunks)
    # weighted average of anchor colors where the weight is inversely proportional to distance
    weights = weights / jnp.sum(weights, axis = 1, keepdims=True)
    grad = jnp.sum(point_colors[top_k_indices] * jnp.expand_dims(weights, axis=-1), axis=1) # n x k x 3 -> n x 3
    grad = grad.reshape(h, w, 3) # n x 3 -> h x w x 3
    plt.imsave(f'grad/{iteration}.png', np.array(grad, dtype=np.uint8))