## Setup

### Install dependencies

In [None]:
!git clone https://github.com/drdaxxy/dm_generations_decode
%cd dm_generations_decode
!pip install -r requirements.txt

### Prepare VQGAN decoder

In [None]:
import os
# By default JAX preallocates a large share of GPU memory up front. The
# "platform" allocator instead allocates on demand and frees memory when it is
# released. The variable is only read when JAX initialises, so it must be set
# before `jax` is imported below.
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

import jax
import jax.numpy as jnp
import numpy as np
from vqgan_jax.modeling_flax_vqgan import VQModel

# Checkpoint repository, pinned to a fixed revision so downloads are reproducible.
VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"

# All decoding runs on the first visible JAX device.
device = jax.devices()[0]

# With _do_init=False the call hands back the model and its parameters as a pair.
vqgan, vqgan_params = VQModel.from_pretrained(
    VQGAN_REPO,
    revision=VQGAN_COMMIT_ID,
    dtype=jnp.float32,
    _do_init=False,
)

# Only the decoder is used in this notebook; drop the encoder weights before
# transferring the parameters so they do not occupy device memory.
del vqgan_params["encoder"]
vqgan_params = jax.device_put(vqgan_params, device)

@jax.jit
def vqgan_decode(indices: np.ndarray, vqgan_params):
    """Decode VQGAN code indices into 8-bit images.

    The parameters are passed as an argument rather than captured from the
    enclosing scope, so the jitted function receives them as inputs instead
    of embedding them as constants.

    Args:
        indices: batch of codebook indices accepted by ``vqgan.decode_code``
            (presumably (batch, n_tokens) integers -- TODO confirm).
        vqgan_params: decoder parameters for the global ``vqgan`` model.

    Returns:
        ``uint8`` array of decoded pixels. Decoder output is clipped to
        [0, 1] and scaled to [0, 255]; the cast truncates rather than rounds.
    """
    decoded = vqgan.decode_code(indices, params=vqgan_params)
    scaled = decoded.clip(0.0, 1.0) * 255
    return scaled.astype(jnp.uint8)

### Program code

In [None]:
from dm_generations_decode import *

from collections import deque
from functools import partial
from math import ceil
from tqdm.auto import tqdm
from typing import Dict, Iterable

def process_groups(
    groups: Iterable[ImageGroup],
    batch_size: int,
    device,
    vqgan_params: Dict,
    task_factory: ImageGroupTaskFactory,
    decode_fn=None,
) -> None:
    """Decode all images of ``groups`` in batches and hand the results to per-group tasks.

    Args:
        groups: image groups to decode; each contributes ``group.ct`` images
            whose codes are read lazily from ``group.embeddings``.
        batch_size: maximum number of code sequences per decoder call; must be > 0.
        device: JAX device each batch is transferred to before decoding.
        vqgan_params: decoder parameters forwarded to ``decode_fn``.
        task_factory: called once per group; the returned task is called with
            consecutive slices of that group's decoded images.
        decode_fn: optional ``decode_fn(codes, vqgan_params)`` returning decoded
            images; defaults to the notebook-level ``vqgan_decode``.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    decode = vqgan_decode if decode_fn is None else decode_fn

    getter = lambda g: g.embeddings
    work = []
    total = 0
    for group in groups:
        # partial() binds the current group now; a bare closure over the loop
        # variable would only ever see the last group.
        work.append((task_factory(group), Producer(partial(getter, group))))
        total += group.ct

    n_batches = -(-total // batch_size)  # integer ceiling division
    for batch, index in tqdm(batch_iter(iter(work), batch_size), total=n_batches):
        codes = jax.device_put(batch, device)
        images = jax.device_get(decode(codes, vqgan_params))
        # `index` lists (task, chunk_len) pairs in batch order; give each task
        # its contiguous slice of the decoded batch.
        offset = 0
        for task, chunk_len in index:
            task(images[offset : offset + chunk_len])
            offset += chunk_len

### Open database

In [None]:
import os
import sqlite3

# sqlite3.connect() silently creates a new, empty database when the file does
# not exist (e.g. when the working directory is wrong), so the problem would
# only show up later when the first query fails. Fail early with a clear message.
if not os.path.isfile("dm_generations_sampling.db"):
    raise FileNotFoundError(
        f"dm_generations_sampling.db not found in {os.getcwd()!r}; "
        "run the setup cells first so the working directory is the cloned "
        "dm_generations_decode repository."
    )

# check_same_thread=False allows the connection to be used from threads other
# than the creating one -- presumably needed because work later runs on a
# ThreadPoolExecutor (TODO confirm which code touches the DB off-thread).
con = sqlite3.connect("dm_generations_sampling.db", check_same_thread=False)
con.row_factory = sqlite3.Row  # rows support lookup by column name
cur = con.cursor()

## Adjust settings and run
* "images" mode produces one file per generated image
  * **needed for the sample browser**
* "gallery" mode creates grids showing all images generated with the same settings, sorted by CLIP score
  * good for browsing in an image viewer

The first batch takes noticeably longer because the decoder is JIT-compiled on first use, so please be patient.

Decoding all 71,680 provided images took about 25 minutes on my RTX 3090 and should take about 75 minutes on Colab's T4 GPU.

In [None]:
mode = "images" #@param ["images", "gallery"]
gallery_cols = 8 #@param {type:"integer"}
gallery_rows = 4 #@param {type:"integer"}
# Code sequences decoded per GPU call; larger batches use more GPU memory.
batch_size = 32 #@param {type:"integer"}

from functools import partialmethod
import concurrent.futures

# Options forwarded to every task call.
task_user_args = {
    "out_dir": mode,
    "overwrite": False,
}

if mode == "images":
    base_task = SaveImagesTaskBase
elif mode == "gallery":
    base_task = SaveGalleryTaskBase
    task_user_args.update({
        "cols": gallery_cols,
        "rows": gallery_rows,
    })
else:
    raise ValueError(f"Unknown mode: {mode}")

# Worker pool handed to the tasks (presumably used to write output files off
# the main thread -- TODO confirm in the task implementations).
executor = concurrent.futures.ThreadPoolExecutor()

# Called once per group by process_groups; every __call__ of the resulting task
# receives the options above plus the shared executor.
class task_factory(base_task):
    __call__ = partialmethod(base_task.__call__, **task_user_args, executor=executor)

try:
    groups = ImageGroup.fetch_meta(cur)
    process_groups(groups, batch_size, device, vqgan_params, task_factory)
finally:
    # Wait for all work submitted to the pool before the cell returns, so the
    # output directory is complete (e.g. before it is zipped), and release the
    # worker threads instead of leaking a pool on every re-run of this cell.
    executor.shutdown(wait=True)

After the cell above finishes, the outputs are in `[dm_generations_decode]/images` or `[dm_generations_decode]/gallery`, depending on the selected `mode`.

If you're running this on a notebook server like Colab and want to download the results, archive them first:

In [None]:
# Archive the output directory for download: -0 stores files without
# compression (presumably the outputs are already-compressed images -- confirm),
# -r recurses into the directory, -q suppresses per-file output.
!zip -0 -r -q dm_generations_{mode}.zip {mode}