[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/stable-diffusion-diffusers-colab/blob/main/flax_stable_diffusion_2.ipynb)

In [None]:
# Output destination:
#   "here"    - keep images on the Colab VM filesystem,
#   "gdrive"  - mount Google Drive and write images under MyDrive (see root_folder below),
#   "discord" - upload each image to a Discord channel; the local file is removed after posting.
save_to = "here" #@param ["here", "discord", "gdrive"]
if save_to == "gdrive":
  from google.colab import drive
  drive.mount('/content/gdrive')

# Library versions are pinned. Presumably the Flax pipeline API used below
# (neg_prompt_ids=, jit=True, prepare_inputs) matches these releases - verify before upgrading.
!pip install -U jax==0.3.25 jaxlib==0.3.25 flax==0.6.2 transformers piexif fold_to_ascii ftfy
!pip install -U diffusers==0.10.0

# Attach JAX to the Colab TPU runtime.
# NOTE(review): jax.tools.colab_tpu / 'tpu_driver_nightly' belong to the pinned old JAX;
# newer JAX releases may not provide this helper - confirm if versions change.
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu('tpu_driver_nightly')

import jax, random, os, gc, requests, json, piexif
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxStableDiffusionPipeline
from PIL import Image
from PIL.PngImagePlugin import PngInfo
from fold_to_ascii import fold
# Module-level PngInfo. PIL's PngInfo.add_text appends a new chunk on every call,
# so reusing this single object across saves carries earlier text entries into later PNGs.
metadata = PngInfo()

def closestNumber(n, m):
    """Return the multiple of ``m`` that is nearest to ``n``.

    Two candidates are compared: the multiple obtained by truncating ``n / m``
    toward zero, and the neighbouring multiple one step further from zero.
    On an exact tie the candidate further from zero is returned.
    """
    truncated = m * int(n / m)
    away = truncated + m if n * m > 0 else truncated - m
    return truncated if abs(n - truncated) < abs(n - away) else away

def image_grid(imgs, rows, cols):
    """Tile ``imgs`` row-major onto a new RGB canvas of ``rows`` x ``cols`` cells.

    Every cell is sized like the first image, so ``imgs`` must be non-empty.
    Images whose index is beyond ``rows * cols`` are pasted outside the canvas.
    """
    tile_w, tile_h = imgs[0].size
    canvas = Image.new('RGB', size=(cols * tile_w, rows * tile_h))
    for index, tile in enumerate(imgs):
        row, col = divmod(index, cols)
        canvas.paste(tile, box=(col * tile_w, row * tile_h))
    return canvas

folder_max_files = 500 #@param {type: 'integer'}
root_folder = 'ai_images' #@param {type: 'string'}
if save_to == "gdrive":
    # Keep the whole output tree inside the mounted Google Drive.
    root_folder = f"/content/gdrive/MyDrive/{root_folder}"

# Resume numbering from what is already on disk: the highest numeric
# sub-folder of root_folder, then the highest numeric file stem inside it.
# Non-numeric entries (hidden/system files, files without a numeric stem)
# are skipped instead of making int() raise ValueError.
os.makedirs(root_folder, exist_ok=True)
image_folder = max(
    (int(entry) for entry in os.listdir(root_folder) if entry.isdecimal()),
    default=0,
)
os.makedirs(f"{root_folder}/{image_folder:04}", exist_ok=True)
name = max(
    (
        int(stem)
        for stem in (
            entry.partition(".")[0]
            for entry in os.listdir(f"{root_folder}/{image_folder:04}")
        )
        if stem.isdecimal()
    ),
    default=0,
)

# Model repo to load (Colab dropdown; allow-input lets another model id be typed in).
model_folder = "flax/stable-diffusion-2" #@param ["flax/stable-diffusion-2", "flax/stable-diffusion-2-base"] {allow-input: true}
revision = "main"
# Load Flax weights directly (from_pt=False) in bfloat16, with the safety checker disabled.
pipe, params = FlaxStableDiffusionPipeline.from_pretrained(model_folder, revision=revision, from_pt=False, dtype=jax.numpy.bfloat16, safety_checker=None)
# Copy the parameters onto every JAX device; generate() runs one sample per device.
params = replicate(params)

def _save_with_metadata(image, path, suffix, prompt, by):
    """Write ``image`` to ``path`` with the prompt and author embedded.

    ``suffix == 'png'`` saves a PNG with tEXt chunks; any other suffix is saved
    as JPEG (quality 70) carrying EXIF 0th-IFD description/make/model tags.
    """
    if suffix == 'png':
        # Fresh PngInfo per image: add_text appends a chunk on every call, so a
        # shared object would stack every earlier prompt into later files.
        png_info = PngInfo()
        png_info.add_text("Prompt", f"{prompt}")
        png_info.add_text("by", f"{by}")
        image.save(path, pnginfo=png_info)
    else:
        # fold() reduces the text to ASCII for these EXIF string tags.
        zeroth_ifd = {piexif.ImageIFD.ImageDescription: f"{fold(prompt)}", piexif.ImageIFD.Make: f"{fold(by)}", piexif.ImageIFD.Model: f"{model_folder}"}
        exif_bytes = piexif.dump({"0th": zeroth_ifd})
        image.save(path, "JPEG", quality=70, exif=exif_bytes)


def _post_to_discord(path, upload_name, content, discord_token, discord_channel_id):
    """Upload the file at ``path`` to a Discord channel and delete it only on success."""
    with open(path, "rb") as fh:
        files = {upload_name: fh.read()}
    response = requests.post(
        f"https://discord.com/api/v9/channels/{discord_channel_id}/messages",
        data={"content": content},
        headers={"authorization": f"Bot {discord_token}"},
        files=files,
        timeout=120,  # never hang the generation loop forever on a stalled connection
    )
    if response.ok:
        os.remove(path)
    else:
        # Keep the image locally so a failed upload does not lose it.
        print(f"Discord upload failed (HTTP {response.status_code}); kept {path}")


def generate(discord_token, discord_channel_id, discord_user, by, num_inference_steps, guidance_scale, sampler, width, height, prompt, negative_prompt, suffix, image_folder, name):
    """Generate one image per JAX device, save them as a grid, and optionally post it to Discord.

    Args:
        discord_token: Bot token for the Discord ``authorization`` header (used only when save_to == "discord").
        discord_channel_id: Discord channel that receives the upload.
        discord_user: Name appended to the Discord message.
        by: Author string embedded in the image metadata.
        num_inference_steps: Denoising steps passed to the pipeline.
        guidance_scale: Guidance scale passed to the pipeline.
        sampler: Label reported in the Discord message only; it is not passed to the pipeline.
        width, height: Requested size, rounded to the nearest multiple of 8.
        prompt, negative_prompt: Text prompts, replicated once per device.
        suffix: "png" for PNG with text chunks, otherwise JPEG with EXIF.
        image_folder, name: Output path is ``{root_folder}/{image_folder:04}/{name:04}.{suffix}``.
    """
    width = closestNumber(width, 8)
    height = closestNumber(height, 8)
    gc.collect()
    real_seed = random.randint(0, 2147483647)
    num_samples = jax.device_count()
    # One PRNG key, prompt and negative prompt per device; shard() splits the
    # leading batch axis across the devices.
    prng_seed = jax.random.split(jax.random.PRNGKey(real_seed), num_samples)
    prompt_ids = shard(pipe.prepare_inputs(num_samples * [prompt]))
    negative_prompt_ids = shard(pipe.prepare_inputs(num_samples * [negative_prompt]))
    images = pipe(prompt_ids, params, prng_seed, neg_prompt_ids=negative_prompt_ids, num_inference_steps=num_inference_steps, height=height, width=width, guidance_scale=guidance_scale, jit=True).images
    # Collapse all leading axes into one batch, keeping the last three image axes.
    images = pipe.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
    # Up to 4 tiles per row: 8 devices give the same 2x4 grid as before, while
    # other device counts no longer leave blank tiles or lose images off-canvas.
    cols = min(num_samples, 4)
    rows = (num_samples + cols - 1) // cols
    image = image_grid(images, rows, cols)
    path = f"{root_folder}/{image_folder:04}/{name:04}.{suffix}"
    _save_with_metadata(image, path, suffix, prompt, by)
    if save_to == "discord":
        content = f"{prompt}\nNegative prompt: {negative_prompt}\nSteps: {num_inference_steps}, Sampler: {sampler}, CFG scale: {guidance_scale}, Seed: {real_seed}, Size: {width}x{height}, Model folder: {model_folder} - {discord_user}"
        _post_to_discord(path, f"{image_folder:04}_{name:04}.{suffix}", content, discord_token, discord_channel_id)

In [None]:
discord_token = "token" #@param {type: 'string'}
discord_channel_id = 0 #@param {type: 'integer'}
prompt = "duck" #@param {type: 'string'}
negative_prompt = "" #@param {type: 'string'}
width  = 512 #@param {type: 'integer'}
height  = 512 #@param {type: 'integer'}
guidance_scale = 7.5 #@param {type: 'number'}
num_inference_steps = 50 #@param {type: 'integer'}
suffix = "jpg" #@param ["jpg", "png"]
by = "camenduru" #@param {type: 'string'}

# Bundle the form values into one job description; the run loop reads it
# back from template.json, one JSON object per line.
# NOTE: the Discord token is stored in this file in plain text.
template = dict(
    discord_token=discord_token,
    discord_channel_id=discord_channel_id,
    by=by,
    num_inference_steps=num_inference_steps,
    guidance_scale=guidance_scale,
    sampler="PLMS",
    width=width,
    height=height,
    prompt=prompt,
    negative_prompt=negative_prompt,
    suffix=suffix,
)
with open("template.json", "w") as outfile:
    json.dump(template, outfile)

is_loop = False #@param {type:"boolean"}


def _advance_slot():
    """Move the module-level (image_folder, name) counters to the next output slot.

    If the current folder already holds ``folder_max_files`` images, a new
    numbered folder is created and numbering restarts *before* the next image
    is produced, so a single (non-loop) run always generates an image.
    """
    global image_folder, name
    if name >= folder_max_files:
        image_folder += 1
        os.makedirs(f"{root_folder}/{image_folder:04}", exist_ok=True)
        name = 0
    name += 1


def _run_template_file():
    """Generate one image for every JSON job line in template.json."""
    with open("template.json", "r") as file:
        job_lines = file.readlines()
    for line in job_lines:
        if not line.strip():
            continue  # tolerate blank lines instead of failing in json.loads
        d = json.loads(line)
        # Check the folder limit per image, so multi-line templates cannot overflow it.
        _advance_slot()
        generate(d["discord_token"], d["discord_channel_id"], "camenduru", d["by"], d["num_inference_steps"], d["guidance_scale"], d["sampler"], d["width"], d["height"], d["prompt"], d["negative_prompt"], d["suffix"], image_folder, name)


# Run the template once, or forever when is_loop is checked.
while True:
    _run_template_file()
    if not is_loop:
        break