In [1]:
!pip install git+https://github.com/openai/glide-text2im

Collecting git+https://github.com/openai/glide-text2im
  Cloning https://github.com/openai/glide-text2im to /tmp/pip-req-build-jegjdlsm
  Running command git clone --filter=blob:none --quiet https://github.com/openai/glide-text2im /tmp/pip-req-build-jegjdlsm
  Resolved https://github.com/openai/glide-text2im to commit 69b530740eb6cef69442d6180579ef5ba9ef063e
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting ftfy (from glide-text2im==0.0.0)
  Downloading ftfy-6.1.3-py3-none-any.whl (53 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.4/53.4 kB[0m [31m1.5 MB/s[0m eta [36m0:00:00[0m
Building wheels for collected packages: glide-text2im
  Building wheel for glide-text2im (setup.py) ... [?25l[?25hdone
  Created wheel for glide-text2im: filename=glide_text2im-0.0.0-py3-none-any.whl size=1953625 sha256=a31613d65e483a362c7fb9a01519e04af3255cdac79fce24f63f777897039390
  Stored in directory: /tmp/pip-ephem-wheel-cache-db5ovmlh/wheels/88/21/5e/57cab1c107

In [2]:
from PIL import Image
from IPython.display import display
import torch as th
import os
from glide_text2im.download import load_checkpoint
from glide_text2im.model_creation import (
    create_model_and_diffusion,
    model_and_diffusion_defaults,
    model_and_diffusion_defaults_upsampler
)


In [7]:
def save_images(batch: th.Tensor, output_path, prompt):
    os.makedirs(os.path.join(output_path, prompt), exist_ok=True)  # Ensure the prompt-specific directory exists
    scaled = ((batch + 1) * 127.5).round().clamp(0, 255).to(th.uint8).cpu()
    reshaped = scaled.permute(0, 2, 3, 1)
    for index, image in enumerate(reshaped):
        img = Image.fromarray(image.numpy())
        img.save(os.path.join(output_path, prompt, str(index) + ".png"))


# Create a classifier-free guidance sampling function
def model_fn(x_t, ts, **kwargs):
    half = x_t[: len(x_t) // 2]
    combined = th.cat([half, half], dim=0)
    model_out = model(combined, ts, **kwargs)
    eps, rest = model_out[:, :3], model_out[:, 3:]
    cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)
    half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
    eps = th.cat([half_eps, half_eps], dim=0)
    return th.cat([eps, rest], dim=1)


has_cuda = th.cuda.is_available()
device = th.device('cpu' if not has_cuda else 'cuda')


guidance_scale = 3.0

# Create base model.
options = model_and_diffusion_defaults()
options['use_fp16'] = has_cuda
options['timestep_respacing'] = '70' # use 100 diffusion steps for fast sampling
model, diffusion = create_model_and_diffusion(**options)
model.eval()
if has_cuda:
    model.convert_to_fp16()
model.to(device)
model.load_state_dict(load_checkpoint('base', device))


# Create upsampler model.
options_up = model_and_diffusion_defaults_upsampler()
options_up['use_fp16'] = has_cuda
options_up['timestep_respacing'] = 'fast27' # use 27 diffusion steps for very fast sampling
model_up, diffusion_up = create_model_and_diffusion(**options_up)
model_up.eval()
if has_cuda:
    model_up.convert_to_fp16()
model_up.to(device)
model_up.load_state_dict(load_checkpoint('upsample', device))

def txt2img(skip_grid=True, skip_save=False, n_samples=1, outdir=""):
    # Read prompts from the prompts.txt file
    with open("/content/prompts.txt", "r") as file:
        prompts = [line.strip() for line in file]

    # Sampling parameters
    batch_size = n_samples

    # Tune this parameter to control the sharpness of 256x256 images.
    # A value of 1.0 is sharper, but sometimes results in grainy artifacts.
    upsample_temp = 0.997

    ##############################
    # Sample from the base model #
    ##############################

    rows = len(prompts)
    for index, prompt in enumerate(prompts):
        # Create the text tokens to feed to the model.
        tokens = model.tokenizer.encode(prompt)
        tokens, mask = model.tokenizer.padded_tokens_and_mask(
            tokens, options['text_ctx']
        )

        # Create the classifier-free guidance tokens (empty)
        full_batch_size = batch_size * 2
        uncond_tokens, uncond_mask = model.tokenizer.padded_tokens_and_mask(
            [], options['text_ctx']
        )

        # Pack the tokens together into model kwargs.
        model_kwargs = dict(
            tokens=th.tensor(
                [tokens] * batch_size + [uncond_tokens] * batch_size, device=device
            ),
            mask=th.tensor(
                [mask] * batch_size + [uncond_mask] * batch_size,
                dtype=th.bool,
                device=device,
            ),
        )

        # Sample from the base model.
        model.del_cache()
        samples = diffusion.p_sample_loop(
            model_fn,
            (full_batch_size, 3, options["image_size"], options["image_size"]),
            device=device,
            clip_denoised=True,
            progress=True,
            model_kwargs=model_kwargs,
            cond_fn=None,
        )[:batch_size]
        model.del_cache()

        # Show the output


        ##############################
        # Upsample the 64x64 samples #
        ##############################

        tokens = model_up.tokenizer.encode(prompt)
        tokens, mask = model_up.tokenizer.padded_tokens_and_mask(
            tokens, options_up['text_ctx']
        )

        # Create the model conditioning dict.
        model_kwargs = dict(
            # Low-res image to upsample.
            low_res=((samples+1)*127.5).round()/127.5 - 1,

            # Text tokens
            tokens=th.tensor(
                [tokens] * batch_size, device=device
            ),
            mask=th.tensor(
                [mask] * batch_size,
                dtype=th.bool,
                device=device,
            ),
        )

        # Sample from the base model.
        model_up.del_cache()
        up_shape = (batch_size, 3, options_up["image_size"], options_up["image_size"])
        up_samples = diffusion_up.ddim_sample_loop(
            model_up,
            up_shape,
            noise=th.randn(up_shape, device=device) * upsample_temp,
            device=device,
            clip_denoised=True,
            progress=True,
            model_kwargs=model_kwargs,
            cond_fn=None,
        )[:batch_size]
        model_up.del_cache()

        # Save the output
        save_images(up_samples, "/content/images", prompt)

        if index % 100:
            print("Generated", str(index), "/", str(rows))



  0%|          | 0.00/1.54G [00:00<?, ?iB/s]

  0%|          | 0.00/1.59G [00:00<?, ?iB/s]

In [8]:
txt2img()

  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/27 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/27 [00:00<?, ?it/s]

Generated 1 / 5


  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/27 [00:00<?, ?it/s]

Generated 2 / 5


  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/27 [00:00<?, ?it/s]

Generated 3 / 5


  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/27 [00:00<?, ?it/s]

Generated 4 / 5


In [None]:
!rm -r /content/images/*