In [33]:
import torch

import jax.numpy as jnp

from jax import random, jit
from tqdm.auto import tqdm

from diffusers import (
    FlaxAutoencoderKL,
    FlaxDDPMScheduler,
    FlaxUNet2DConditionModel,
)

from transformers import ByT5Tokenizer, FlaxT5Model

In [10]:
vae = FlaxAutoencoderKL.from_pretrained("flax/stable-diffusion-2-1", subfolder="vae")

In [11]:
unet = FlaxUNet2DConditionModel.from_pretrained("character-aware-diffusion/charred")

In [12]:
scheduler = FlaxDDPMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        num_train_timesteps=1000,
    )

In [13]:
tokenizer = ByT5Tokenizer.from_pretrained("google/byt5-base")
lm = FlaxT5Model.from_pretrained("google/byt5-base")

Some weights of the model checkpoint at google/byt5-base were not used when initializing FlaxT5Model: {('lm_head', 'kernel')}
- This IS expected if you are initializing FlaxT5Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxT5Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [14]:
prompt = ["a photograph of an astronaut riding a camel"]
height = 512
width = 512
num_inference_steps = 100
guidance_scale = 7.5 
key = random.PRNGKey(0)
batch_size = len(prompt)

In [15]:
text_input = tokenizer(
            text=prompt,
            max_length=1024,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).input_ids

In [16]:
text_embeddings = lm.encode(
            jnp.array(text_input),
            params=lm.params,
            train=False,
        )[0]

In [17]:
uncond_input = tokenizer(
    [""] * batch_size, padding="max_length", max_length=1024, return_tensors="pt"
).input_ids

In [18]:
uncond_embeddings = lm.encode(jnp.array(uncond_input), params=lm.params, train=False)[0]  

In [19]:
embeddings = jnp.concatenate([uncond_embeddings, text_embeddings])

In [22]:
latents = jnp.array(torch.randn(
    (batch_size, unet[0].in_channels, height // 8, width // 8),
))

In [26]:
state = scheduler.create_state()

In [37]:
latents = latents * state.init_noise_sigma

In [46]:
schduler_timestep = scheduler.set_timesteps(state, unet[0].in_channels)

In [50]:
for t in tqdm(test.timesteps):
    # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
    latent_model_input = jnp.concatenate([latents] * 2)

    latent_model_input = scheduler.scale_model_input(state, latent_model_input, timestep=t)

    # predict the noise residual
    #with torch.no_grad():
    noise_pred = unet[0](latent_model_input, t, encoder_hidden_states=text_embeddings).sample

    # perform guidance
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

    # compute the previous noisy sample x_t -> x_t-1
    latents = scheduler.step(noise_pred, t, latents).prev_sample

  0%|                                                                                                       | 0/4 [00:00<?, ?it/s]
