In [None]:
!pip install --upgrade pip
!pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
!pip install jaxlib diffusers huggingface_hub flax transformers
# !pip install tfp-nightly[jax] --upgrade -q 2>&1 1> /dev/null
# !pip install tf-nightly-cpu -q -I 2>&1 1> /dev/null
# !pip install jax -I -q --upgrade 2>&1 1>/dev/null

In [None]:
import jax
import numpy as np
import jax.numpy as jnp

from pathlib import Path
from jax import pmap
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image

from huggingface_hub import notebook_login
from diffusers import FlaxStableDiffusionPipeline
import transformers
import time
from diffusers.utils import make_image_grid

In [None]:
# Look up the kind of the first JAX device and the total device count,
# then stop immediately if the runtime is not backed by a TPU.
device_type = jax.devices()[0].device_kind
NUM_DEVICES = jax.device_count()

assert "TPU" in device_type, (
    "Available device is not a TPU, please select TPU from "
    "Edit > Notebook settings > Hardware accelerator"
)

print("Found {} JAX devices of type {}.".format(NUM_DEVICES, device_type))

In [None]:
# Load the "bf16" revision of Stable Diffusion v1.4; `from_pretrained` hands
# back the pipeline object and its parameter tree as two separate values.
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="bf16",
    dtype=jnp.bfloat16,
)

In [None]:
# The scheduler entry is removed before the cast so that only the remaining
# leaves are converted to bfloat16; it is put back unchanged afterwards.
scheduler_state = params.pop("scheduler")
params = jax.tree_util.tree_map(
    lambda leaf: leaf.astype(jnp.bfloat16),
    params,
)
params["scheduler"] = scheduler_state

In [None]:
# Defaults used by `generate` for its `guidance_scale` and
# `num_inference_steps` parameters.
default_guidance_scale = 5.0
default_num_steps = 25

In [None]:
def tokenize_prompt(prompt, neg_prompt):
    """Run both prompts through the pipeline's input preparation.

    Args:
        prompt: Prompt passed to ``pipeline.prepare_inputs``.
        neg_prompt: Negative prompt passed to ``pipeline.prepare_inputs``.

    Returns:
        Tuple ``(prompt_ids, neg_prompt_ids)`` holding the two
        ``prepare_inputs`` results, in that order.
    """
    return pipeline.prepare_inputs(prompt), pipeline.prepare_inputs(neg_prompt)

# Inference never updates the model parameters, so a single replicated
# copy is created here and reused by every later call.
p_params = replicate(params)

# Number of JAX devices; used when splitting the random key per device.
NUM_DEVICES = jax.device_count()

def replicate_all(prompt_ids, neg_prompt_ids, seed):
    """Prepare per-device inputs for one generation call.

    Args:
        prompt_ids: Prompt ids to pass through ``replicate``.
        neg_prompt_ids: Negative prompt ids to pass through ``replicate``.
        seed: Integer seed for ``jax.random.PRNGKey``.

    Returns:
        Tuple of (replicated prompt ids, replicated negative prompt ids,
        random keys obtained by splitting the seeded key ``NUM_DEVICES`` ways).
    """
    device_keys = jax.random.split(jax.random.PRNGKey(seed), NUM_DEVICES)
    return replicate(prompt_ids), replicate(neg_prompt_ids), device_keys

def generate(
    prompt,
    negative_prompt,
    seed=0,
    guidance_scale=default_guidance_scale,
    num_inference_steps=default_num_steps,
):
    """Generate PIL images for a prompt / negative-prompt pair.

    Args:
        prompt: Text prompt.
        negative_prompt: Negative text prompt.
        seed: Seed forwarded to ``replicate_all`` for the random keys.
        guidance_scale: Forwarded to the pipeline call.
        num_inference_steps: Forwarded to the pipeline call.

    Returns:
        The result of ``pipeline.numpy_to_pil`` on the generated images after
        their two leading axes have been merged into one.
    """
    token_ids, neg_token_ids = tokenize_prompt(prompt, negative_prompt)
    p_token_ids, p_neg_token_ids, device_keys = replicate_all(token_ids, neg_token_ids, seed)

    result = pipeline(
        p_token_ids,
        p_params,
        device_keys,
        num_inference_steps=num_inference_steps,
        neg_prompt_ids=p_neg_token_ids,
        guidance_scale=guidance_scale,
        jit=True,
    )
    batched = result.images

    # Merge the two leading axes into a single batch axis, keeping the
    # trailing three axes, then convert to PIL.
    leading, inner = batched.shape[0], batched.shape[1]
    flat = batched.reshape((leading * inner,) + batched.shape[-3:])
    return pipeline.numpy_to_pil(np.array(flat))

In [None]:
default_prompt = "A cinematic film still of Shah Rukh Khan starring as Steven Tyler playing guitar, landscape, 18mm lens, deep focus, medium shot, highly detailed, cinematic"
# default_prompt = "An expressive oil painting of a basketball player dunking, depicted as an explosion of a nebula."
# default_neg_prompt = "low-quality, bad anatomy, deformed, fused fingers, mutation, mutilated, poorly drawn hands, ugly"
default_neg_prompt = "illustration, low-quality, bad anatomy, deformed, fused fingers, mutation, mutilated, poorly drawn hands, ugly"
default_seed = 1356

# `generate` calls the pipeline with jit=True, so the time printed for this
# first call covers compilation as well as the generation itself.
start = time.time()
print("Compiling ...")
images_gen = generate(default_prompt, default_neg_prompt, default_seed)
print(f"Compiled in {time.time() - start}")

# `make_image_grid` comes from the notebook's import cell; it is not
# re-imported here. The 2 x 4 layout assumes eight generated images.
make_image_grid(images_gen, 2, 4)

In [None]:
# Write the image at index 4 of `images_gen` to "first_image.jpg".
# NOTE(review): the filename says "first" but the index is 4 -- confirm which is intended.
images_gen[4].save("first_image.jpg")

# Stable Diffusion in JAX / Flax !
> https://huggingface.co/docs/diffusers/using-diffusers/stable_diffusion_jax_how_to

In [None]:
# Reload the same checkpoint, this time naming the dtype so it is passed
# explicitly to `from_pretrained`.
dtype = jnp.bfloat16
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", revision="bf16", dtype=dtype
)

In [None]:
# Build one copy of the prompt per JAX device.
prompt = [
    "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 80mm lens, shallow depth of field, close up, split lighting, cinematic"
] * jax.device_count()
prompt_ids = pipeline.prepare_inputs(prompt)
print(prompt_ids.shape)

# Replicate the parameters and shard the prompt ids; the second print
# shows the shape of the ids after `shard`.
p_params = replicate(params)
prompt_ids = shard(prompt_ids)
print(prompt_ids.shape)

def create_key(seed=0):
    """Return the JAX PRNG key built from ``seed``.

    Args:
        seed: Integer seed handed to ``jax.random.PRNGKey`` (default 0).

    Returns:
        The key produced by ``jax.random.PRNGKey(seed)``.
    """
    key = jax.random.PRNGKey(seed)
    return key

rng = create_key(12365)
rng = jax.random.split(rng, jax.device_count())

%time
images = pipeline(prompt_ids, p_params, rng, jit=True)[0]

images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)

make_image_grid(images, 2, 4)


In [None]:
# Every other call to `prepare_inputs` in this notebook passes a prompt;
# calling it with no arguments raises and stops "Run All". Show its
# signature and docstring instead.
help(pipeline.prepare_inputs)

In [None]:
# Display the current value of `images` (last bound to the
# `pipeline.numpy_to_pil` result in the inference cell above).
images

# **Flax Img2Img**

In [None]:
!pip install --upgrade pip
!pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
!pip install jaxlib diffusers huggingface_hub flax transformers
# !pip install tfp-nightly[jax] --upgrade -q 2>&1 1> /dev/null
# !pip install tf-nightly-cpu -q -I 2>&1 1> /dev/null
# !pip install jax -I -q --upgrade 2>&1 1>/dev/null

In [None]:
import jax
import numpy as np
import jax.numpy as jnp

from pathlib import Path
from jax import pmap
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image

from huggingface_hub import notebook_login
import transformers
import time
from diffusers.utils import make_image_grid

import requests
from io import BytesIO
from diffusers import FlaxStableDiffusionImg2ImgPipeline

In [None]:
# url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
url = "https://www.baseballbible.net/wp-content/uploads/What-Is-a-Pull-Hitter-in-Baseball.jpg"
# Bound the wait for the remote host, and surface HTTP errors here instead of
# letting PIL fail later on a response body that is not an image.
response = requests.get(url, timeout=30)
response.raise_for_status()
init_img = Image.open(BytesIO(response.content)).convert("RGB")

# init_img = Image.open('/kaggle/input/baseball-pic/Basseball_Batter.jpg').convert("RGB")

# PIL sizes are (width, height): 768 wide by 512 high.
init_img = init_img.resize((768, 512))
init_img

In [None]:
# Load the "flax" revision of Stable Diffusion v1.4 as an image-to-image
# pipeline; the pipeline object and its parameters are returned separately.
pipeline, params = FlaxStableDiffusionImg2ImgPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", revision="flax", dtype=jnp.bfloat16
)

In [None]:
# No cell in this section defines `NUM_DEVICES`; set it here so the section
# also runs on a fresh kernel, then display it.
NUM_DEVICES = jax.device_count()
NUM_DEVICES

In [None]:
# One sample per JAX device.
num_samples = jax.device_count()
seed = 42
# Split the key `num_samples` ways. The cell previously used `NUM_DEVICES`,
# which no cell in this section defines; `num_samples` is the device count
# computed just above.
rng = jax.random.split(jax.random.PRNGKey(seed), num_samples)
prompts = "Boy dressed in blue hitting with a baseball bat. Realistic. portrait, 80mm lens, shallow depth of field, close up, split lighting, cinematic"
# prompts = "A fantasy landscape, trending on artstation"

# Tokenize the prompt and preprocess the init image, one copy per sample.
prompt_ids, processed_image = pipeline.prepare_inputs(
    prompt=[prompts] * num_samples, image=[init_img] * num_samples
)

# Replicate the parameters and shard the per-sample inputs.
p_params = replicate(params)
prompt_ids = shard(prompt_ids)
processed_image = shard(processed_image)

# height/width match the (768, 512) = (width, height) resize applied to `init_img`.
output = pipeline(
    prompt_ids=prompt_ids,
    image=processed_image,
    params=p_params,
    prng_seed=rng,
    strength=0.75,
    num_inference_steps=50,
    jit=True,
    height=512,
    width=768,
).images

# Reshape to (num_samples, <last three axes>) and convert to PIL images.
output_images = pipeline.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))


In [None]:
# `output_images` is already a list of PIL images from the previous cell.
# Arrange them in a 2 x 4 grid (assumes eight images -- TODO confirm device count).

make_image_grid(output_images, 2, 4)

In [None]:
# Display the resized input image for comparison with the generated grid.
init_img

# **SDXL Flax Diffusers**

In [None]:
!pip install --upgrade pip
!pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
# !pip install jax
!pip install jaxlib huggingface_hub flax transformers
!pip install git+https://github.com/huggingface/diffusers

# !git clone https://github.com/huggingface/diffusers.git
# %cd diffusers
# !pip install -e ".[flax]"
# %cd ..

# import sys
# sys.path.append(r'/kaggle/working/diffusers')
# sys.path.append(r'/kaggle/working/diffusers/src/diffusers')

# !pip install tfp-nightly[jax] --upgrade -q 2>&1 1> /dev/null

In [None]:
import jax
import numpy as np
import jax.numpy as jnp

from pathlib import Path
from jax import pmap
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image

from huggingface_hub import notebook_login
from diffusers import FlaxStableDiffusionXLPipeline
import transformers
import time
from diffusers.utils import make_image_grid

In [None]:
# Target dtype for the parameters and the SDXL base checkpoint id.
dtype = jnp.bfloat16
base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"

base_pipeline, base_params = FlaxStableDiffusionXLPipeline.from_pretrained(
    base_model_id,
    split_head_dim=True,
)

# The scheduler entry is removed before the cast so that only the remaining
# leaves are converted to `dtype`; it is put back unchanged afterwards.
base_scheduler_state = base_params.pop("scheduler")
base_params = jax.tree_util.tree_map(
    lambda leaf: leaf.astype(dtype),
    base_params,
)
base_params["scheduler"] = base_scheduler_state

In [None]:
# Display the `tokenizer_2` attribute of the SDXL base pipeline.
base_pipeline.tokenizer_2

In [None]:
default_prompt = "photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang"
prompt_ids = base_pipeline.prepare_inputs(default_prompt)

# Feed the slice at index 0 of the second axis of `prompt_ids` to the first
# text encoder and display the second-to-last entry of its hidden states.
base_pipeline.text_encoder(
    input_ids=prompt_ids[:, 0, :],
    params=base_params["text_encoder"],
    output_hidden_states=True,
)["hidden_states"][-2]

In [None]:
# Display the ids returned by `base_pipeline.prepare_inputs` in the previous cell.
prompt_ids

In [None]:
# dtype = jnp.bfloat16
# refiner_model_id = "stabilityai/stable-diffusion-xl-refiner-1.0"

# refiner_pipeline, refiner_params = FlaxStableDiffusionXLPipeline.from_pretrained(
#     refiner_model_id, split_head_dim=True
# )

# refiner_scheduler_state = refiner_params.pop("scheduler")
# refiner_params = jax.tree_util.tree_map(lambda x: x.astype(dtype), refiner_params)
# refiner_params["scheduler"] = refiner_scheduler_state

In [None]:
# Defaults used by the SDXL `generate` function for its `seed`,
# `guidance_scale` and `num_inference_steps` parameters.
default_seed = 33
default_guidance_scale = 5.0
default_num_steps = 25

In [None]:
def tokenize_prompt(prompt, neg_prompt):
    """Run both prompts through the SDXL base pipeline's input preparation.

    Args:
        prompt: Prompt passed to ``base_pipeline.prepare_inputs``.
        neg_prompt: Negative prompt passed to ``base_pipeline.prepare_inputs``.

    Returns:
        Tuple ``(prompt_ids, neg_prompt_ids)`` holding the two
        ``prepare_inputs`` results, in that order.
    """
    return base_pipeline.prepare_inputs(prompt), base_pipeline.prepare_inputs(neg_prompt)

# Inference never updates the model parameters, so a single replicated
# copy of the SDXL base parameters is created here and reused.
p_params = replicate(base_params)

# Number of JAX devices; used when splitting the random key per device.
NUM_DEVICES = jax.device_count()

def replicate_all(prompt_ids, neg_prompt_ids, seed):
    """Prepare per-device inputs for one SDXL generation call.

    Args:
        prompt_ids: Prompt ids to pass through ``replicate``.
        neg_prompt_ids: Negative prompt ids to pass through ``replicate``.
        seed: Integer seed for ``jax.random.PRNGKey``.

    Returns:
        Tuple of (replicated prompt ids, replicated negative prompt ids,
        random keys obtained by splitting the seeded key ``NUM_DEVICES`` ways).
    """
    device_keys = jax.random.split(jax.random.PRNGKey(seed), NUM_DEVICES)
    return replicate(prompt_ids), replicate(neg_prompt_ids), device_keys

def generate(
    prompt,
    negative_prompt,
    seed=default_seed,
    guidance_scale=default_guidance_scale,
    num_inference_steps=default_num_steps,
):
    """Generate PIL images with the SDXL base pipeline.

    Args:
        prompt: Text prompt.
        negative_prompt: Negative text prompt.
        seed: Seed forwarded to ``replicate_all`` for the random keys.
        guidance_scale: Forwarded to the pipeline call.
        num_inference_steps: Forwarded to the pipeline call.

    Returns:
        The result of ``base_pipeline.numpy_to_pil`` on the generated images
        after their two leading axes have been merged into one.
    """
    token_ids, neg_token_ids = tokenize_prompt(prompt, negative_prompt)
    p_token_ids, p_neg_token_ids, device_keys = replicate_all(token_ids, neg_token_ids, seed)

    result = base_pipeline(
        p_token_ids,
        p_params,
        device_keys,
        num_inference_steps=num_inference_steps,
        neg_prompt_ids=p_neg_token_ids,
        guidance_scale=guidance_scale,
        jit=True,
    )
    batched = result.images

    # Disabled refiner pass, kept for reference:
    #     images = refiner_pipeline(
    #         prompt_ids,
    #         p_params,
    #         rng,
    #         num_inference_steps=num_inference_steps,
    #         neg_prompt_ids=neg_prompt_ids,
    #         guidance_scale=guidance_scale,
    #         jit=True,
    #         latents=latents,
    #     ).images

    # Merge the two leading axes into a single batch axis, keeping the
    # trailing three axes, then convert to PIL.
    leading, inner = batched.shape[0], batched.shape[1]
    flat = batched.reshape((leading * inner,) + batched.shape[-3:])
    return base_pipeline.numpy_to_pil(np.array(flat))


In [None]:
default_neg_prompt = "illustration, low-quality"
default_prompt = "photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang"

# `generate` calls the pipeline with jit=True, so the printed time for a
# first call covers compilation as well as generation.
print("Compiling ...")
start = time.time()
output_images = generate(default_prompt, default_neg_prompt)
print(f"Compiled in {time.time() - start}")

# Show the results as a grid with 2 rows and 4 columns.
make_image_grid(output_images, rows=2, cols=4)

In [None]:
# output_images[0].save(f"best.png")

In [None]:
default_neg_prompt = "illustration, low-quality, bad anatomy, deformed, fused fingers, mutation, mutilated, poorly drawn hands, ugly"
default_prompt = "photo of a boy playing a piano in a church, in 1800 Germany, cinematic, award winning photography, closeup"

# Same timing wrapper as the previous run; the message text is unchanged.
print("Compiling ...")
start = time.time()
output_images = generate(default_prompt, default_neg_prompt)
print(f"Compiled in {time.time() - start}")

# Show the results as a grid with 2 rows and 4 columns.
make_image_grid(output_images, rows=2, cols=4)

In [None]:
import importlib
from diffusers import DiffusionPipeline

# Take the text before the first "." in `DiffusionPipeline.__module__`
# (its top-level package name) and import that module; the returned module
# object is the cell's output.
importlib.import_module(DiffusionPipeline.__module__.split(".")[0])

In [None]:
from diffusers import StableDiffusionXLPipeline

from diffusers.loaders import (
    FromSingleFileMixin,
    IPAdapterMixin,
    StableDiffusionXLLoraLoaderMixin,
    TextualInversionLoaderMixin,
)


In [None]:
class myFlaxSDXLPipeline(FlaxStableDiffusionXLPipeline, StableDiffusionXLLoraLoaderMixin):
    """Experimental SDXL Flax pipeline that also inherits the SDXL LoRA loader mixin.

    NOTE(review): ``StableDiffusionXLLoraLoaderMixin`` is imported from
    ``diffusers.loaders``; nothing in this notebook establishes that its
    methods work with Flax modules and parameter trees -- confirm before
    relying on it.
    """

    def __init__(
        self,
        text_encoder,
        text_encoder_2,
        vae,
        tokenizer,
        tokenizer_2,
        unet,
        scheduler,
        dtype=jnp.float32,
    ):
        """Build the pipeline from its component modules.

        The previous constructor accepted no arguments, yet called the parent
        constructor without any and read ``vae``, ``unet`` and the other
        components as undefined names, so it could never run. The components
        are now explicit parameters forwarded to
        ``FlaxStableDiffusionXLPipeline.__init__``.

        Args:
            text_encoder: First text encoder module.
            text_encoder_2: Second text encoder module.
            vae: VAE module.
            tokenizer: Tokenizer paired with ``text_encoder``.
            tokenizer_2: Tokenizer paired with ``text_encoder_2``.
            unet: UNet module.
            scheduler: Scheduler instance.
            dtype: Computation dtype forwarded to the parent constructor.

        NOTE(review): module registration and ``vae_scale_factor`` are left
        to the parent constructor; the parameter names and the ``dtype``
        default are taken from the parent's signature -- verify against the
        installed diffusers version.
        """
        super().__init__(
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            vae=vae,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            unet=unet,
            scheduler=scheduler,
            dtype=dtype,
        )

In [None]:
# NOTE(review): `load_lora_weights` is called on the class itself with no
# arguments. If it is an instance method that needs a LoRA source (TODO
# confirm against the diffusers loader API), this line raises TypeError;
# it presumably needs a pipeline instance and a checkpoint identifier.
test_flax_class = myFlaxSDXLPipeline.load_lora_weights()

In [None]:
# Display the `unet` attribute of the SDXL base pipeline.
base_pipeline.unet

In [None]:
# List the keys stored under one `attn2` entry of the UNet parameter tree.
base_params['unet']['up_blocks_0']['attentions_2']['transformer_blocks_8']['attn2'].keys()

In [None]:
# Show the shape of every leaf under the `to_k` entry. `jax.tree_util.tree_map`
# is the spelling used by the other cells of this notebook; the top-level
# `jax.tree_map` alias is deprecated in newer JAX releases.
jax.tree_util.tree_map(lambda x: x.shape, base_params['unet']['up_blocks_0']['attentions_2']['transformer_blocks_8']['attn2']['to_k'])

In [None]:
# Display the `to_k` entry of the same `attn2` block of the UNet parameter tree.
base_params['unet']['up_blocks_0']['attentions_2']['transformer_blocks_8']['attn2']['to_k']