In [1]:
from PIL import Image
import torch as th
import os

from glide_text2im.download import load_checkpoint
from glide_text2im.model_creation import (
    create_model_and_diffusion,
    model_and_diffusion_defaults,
    model_and_diffusion_defaults_upsampler
)

In [2]:
def init_cuda_device_options(timestep_respacing='100'):
    """Select the compute device and build the option dict for the base model.

    Args:
        timestep_respacing: Value stored under ``options['timestep_respacing']``.
            Defaults to ``'100'``, the value this notebook previously hard-coded,
            so existing zero-argument calls behave exactly as before.

    Returns:
        tuple: ``(has_cuda, device, options)`` where ``has_cuda`` is the result
        of ``th.cuda.is_available()``, ``device`` is the matching ``th.device``
        (``'cuda'`` or ``'cpu'``), and ``options`` is the dict returned by
        ``model_and_diffusion_defaults()`` with ``'use_fp16'`` and
        ``'timestep_respacing'`` overridden.
    """
    has_cuda = th.cuda.is_available()
    device = th.device('cuda' if has_cuda else 'cpu')
    options = model_and_diffusion_defaults()
    # Half precision is only requested when a GPU is present.
    options['use_fp16'] = has_cuda
    options['timestep_respacing'] = timestep_respacing
    return has_cuda, device, options


# Reuse the helper defined above rather than repeating its body line for line;
# the resulting has_cuda / device / options values are the same.
has_cuda, device, options = init_cuda_device_options()

In [4]:
# Rebuild the device flag, device and option dict, then construct the base model
# and the diffusion object from those options.
has_cuda, device, options = init_cuda_device_options()
model, diffusion = create_model_and_diffusion(**options)

In [5]:
# Prepare the base model for inference. The order below is deliberate:
# switch to eval mode, convert to fp16 (GPU only), move to the device, then load weights.
model.eval()
# Mirrors options['use_fp16'], which was set from the same has_cuda flag.
if has_cuda:
    model.convert_to_fp16()
model.to(device)
# Load the checkpoint named 'base' for the chosen device into the model.
model.load_state_dict(load_checkpoint('base', device))
# Report the total number of elements across all model parameters.
print('total base parameters', sum(x.numel() for x in model.parameters()))

total base parameters 385030726


In [6]:
# memory_summary() needs a usable CUDA device. This notebook also supports a CPU
# fallback, so only query the allocator when a GPU is actually present; on a
# CPU-only machine the cell reports that instead of raising.
if th.cuda.is_available():
    print(th.cuda.memory_summary())
else:
    print('CUDA is not available; no GPU memory summary to show.')

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |     806 MB |    2293 MB |    2293 MB |    1486 MB |
|       from large pool |     765 MB |    2226 MB |    2226 MB |    1461 MB |
|       from small pool |      40 MB |      66 MB |      66 MB |      25 MB |
|---------------------------------------------------------------------------|
| Active memory         |     806 MB |    2293 MB |    2293 MB |    1486 MB |
|       from large pool |     765 MB |    2226 MB |    2226 MB |    1461 MB |
|       from small pool |      40 MB |      66 MB |      66 MB |      25 MB |
|---------------------------------------------------------------

In [7]:
# Release cached, currently unused GPU memory held by PyTorch before sampling.
th.cuda.empty_cache()
# Text prompt that the tokenizer encodes in the next cell.
prompt = "a bird on a rocket"
# Number of prompt-conditioned samples; the sampling batch is twice this size
# because an unconditional copy of each sample is run alongside it.
batch_size = 3
# Weight applied to (cond_eps - uncond_eps) inside model_fn.
guidance_scale = 3.0

# Not read by any cell visible here; presumably consumed by a later upsampling
# stage -- TODO confirm.
upsample_temp = 0.997


In [8]:
# Encode the prompt and hand the result straight to the tokenizer's padding helper,
# which returns the padded tokens together with their mask for options['text_ctx'].
tokens, mask = model.tokenizer.padded_tokens_and_mask(
    model.tokenizer.encode(prompt),
    options['text_ctx'],
)

In [9]:
# Unconditional branch: pad an empty token list using the same options['text_ctx'].
uncond_tokens, uncond_mask = model.tokenizer.padded_tokens_and_mask(
    [],
    options['text_ctx'],
)
# Each sample is run twice: once with the prompt and once without.
full_batch_size = 2 * batch_size

In [10]:
# Batch layout: the first `batch_size` rows carry the prompt tokens/mask, the last
# `batch_size` rows carry the empty-prompt (unconditional) tokens/mask.
model_kwargs = {
    'tokens': th.tensor(
        [tokens] * batch_size + [uncond_tokens] * batch_size,
        device=device,
    ),
    'mask': th.tensor(
        [mask] * batch_size + [uncond_mask] * batch_size,
        dtype=th.bool,
        device=device,
    ),
}

In [11]:
# Create a classifier-free guidance sampling function

def model_fn(x_t, ts, **kwargs):
    """Run the model on a duplicated half-batch and blend the two eps predictions.

    The first half of ``x_t`` is repeated so both halves see the same input. The
    first three channels of the model output are treated as ``eps`` and split
    into two equal halves along dim 0 (``cond_eps`` first, ``uncond_eps``
    second). These are combined as
    ``uncond_eps + guidance_scale * (cond_eps - uncond_eps)``; the result is
    written to both halves and the remaining channels are passed through
    unchanged.

    Args:
        x_t: Input batch; only its first ``len(x_t) // 2`` entries are used.
        ts: Forwarded to ``model`` as its second positional argument.
        **kwargs: Forwarded to ``model`` unchanged.

    Returns:
        The guided ``eps`` channels concatenated with the untouched remaining
        channels along dim 1.
    """
    half_count = len(x_t) // 2
    first_half = x_t[:half_count]
    doubled = th.cat((first_half, first_half), dim=0)
    out = model(doubled, ts, **kwargs)

    # Channels 0..2 hold eps; everything after is returned as-is.
    eps_channels = out[:, :3]
    cond_eps, uncond_eps = th.split(eps_channels, len(eps_channels) // 2, dim=0)
    guided = uncond_eps + guidance_scale * (cond_eps - uncond_eps)

    guided_full = th.cat((guided, guided), dim=0)
    return th.cat((guided_full, out[:, 3:]), dim=1)

In [12]:
# Call the model's del_cache() before starting the sampling run.
model.del_cache()

# Sample the full (conditional + unconditional) batch with the guided model_fn.
samples = diffusion.p_sample_loop(
    model_fn,
    (full_batch_size, 3, options['image_size'], options['image_size']),
    model_kwargs=model_kwargs,
    device=device,
    cond_fn=None,
    clip_denoised=True,
    progress=True,
)
# Keep only the first `batch_size` results, i.e. the rows that carried the prompt.
samples = samples[:batch_size]

  0%|          | 0/100 [00:00<?, ?it/s]