[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/notebooks/blob/main/camenduru's_stable_diffusion_flax.ipynb)

In [None]:
# Environment setup: install pinned dependencies, mount Google Drive, attach
# JAX to the Colab TPU, and import everything the later cells use.
# Pin the CUDA 11.6 builds of PyTorch and its companion packages.
!pip install -q torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 torchtext==0.14.1 torchdata==0.5.1 --extra-index-url https://download.pytorch.org/whl/cu116 -U
# Mount Google Drive at /content/gdrive so generated images persist.
from google.colab import drive
drive.mount('/content/gdrive')

# diffusers from the GitHub main branch, plus the latest JAX.
!pip install git+https://github.com/huggingface/diffusers.git
!pip install --upgrade jax jaxlib 

# Point JAX at the Colab TPU using a pinned TPU driver version.
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu('tpu_driver_20221011')

!pip install flax transformers ftfy
# Sanity check of the devices JAX can see. NOTE(review): this is not the
# last expression of the cell, so its value is not displayed.
jax.devices()

import jax
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard

# NOTE(review): subprocess is imported but not used in the visible cells.
import os, gc, requests, subprocess, random
from diffusers import FlaxStableDiffusionPipeline

# Hide the installation output.
from IPython.display import clear_output
clear_output()

In [None]:
# Log in to the Hugging Face Hub. The git credential helper is set to
# "store" so git saves the entered credentials for later use.
from huggingface_hub import notebook_login
!git config --global credential.helper store
notebook_login()

In [9]:
# Load the Flax Stable Diffusion v1-4 pipeline from the "bf16" revision in
# bfloat16, with the safety checker disabled. Returns the pipeline object and
# its parameter tree separately.
pipe, params = FlaxStableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jax.numpy.bfloat16, safety_checker=None)
# Copy the parameters onto every local device; generate() runs the pipeline
# with jit=True on sharded inputs, one sample per device.
params = replicate(params)
# Hide the download output.
clear_output()

In [None]:
from PIL import Image
def image_grid(imgs, rows, cols):
    """Tile ``imgs`` into a new RGB image of ``rows`` x ``cols`` cells.

    Cells are filled left-to-right, then top-to-bottom. Every cell is sized
    after the first image. The number of images is not checked against
    ``rows * cols``.
    """
    cell_w, cell_h = imgs[0].size
    canvas = Image.new('RGB', size=(cols * cell_w, rows * cell_h))
    for idx, tile in enumerate(imgs):
        row, col = divmod(idx, cols)
        canvas.paste(tile, box=(col * cell_w, row * cell_h))
    return canvas

# Discord bot token and target channel used when uploading results.
token = '' #@param {type: 'string'}
channel_id = 0 #@param {type: 'integer'}
header = {"authorization": f"Bot {token}"}
# Author name written into each PNG's metadata.
by = 'camenduru' #@param {type: 'string'}

root_folder = '/content/gdrive/MyDrive/AI/StableDiffusion' #@param {type: 'string'}
image_folder = '000' #@param {type: 'string'}
# makedirs also creates root_folder when it is missing (os.mkdir raised
# FileNotFoundError there). exist_ok makes re-running this cell safe.
os.makedirs(f"{root_folder}/{image_folder}", exist_ok=True)
# Resume numbering after the highest existing "<number>.<ext>" file.
# Entries whose stem is not a plain number (desktop.ini,
# .ipynb_checkpoints, ...) are skipped instead of crashing int().
name = max(
  (int(stem)
   for stem in (f.partition('.')[0] for f in os.listdir(f"{root_folder}/{image_folder}"))
   if stem.isdecimal()),
  default=0,
)

from PIL.PngImagePlugin import PngInfo
# Module-level PNG text-chunk container. NOTE(review): a single shared
# instance keeps every add_text() call ever made on it, so anything that
# reuses it for several images writes all earlier entries into each file.
metadata = PngInfo()
 
# Output image size in pixels, passed to the pipeline as height/width.
# NOTE(review): Stable Diffusion pipelines generally expect multiples of 8 — confirm.
height = 448 #@param {type: 'integer'}
width = 832 #@param {type: 'integer'}

def generate(prompt, name):
  """Render one batch for ``prompt``, save it as a PNG grid and post it to Discord.

  One sample is generated per JAX device (``jax.device_count()``). Each sample
  gets its own key, split from a single random seed. The samples are tiled
  into a grid and saved as ``{root_folder}/{image_folder}/{name:04}.png``,
  with the prompt and ``by`` stored as PNG text metadata. When ``token`` is
  non-empty, the saved file is then uploaded to Discord channel ``channel_id``.

  Args:
    prompt: text prompt fed to the pipeline.
    name: integer index used to build the output file name.
  """
  # Fresh metadata per image. Reusing the global PngInfo appended every
  # earlier prompt to each newly saved file.
  png_info = PngInfo()
  png_info.add_text("Prompt", str(prompt))
  png_info.add_text("by", str(by))
  gc.collect()
  real_seed = random.randint(0, 2147483647)
  num_samples = jax.device_count()
  # One copy of the prompt and one PRNG key per device, sharded across devices.
  prompt_ids = shard(pipe.prepare_inputs(num_samples * [prompt]))
  prng_seed = jax.random.split(jax.random.PRNGKey(real_seed), num_samples)
  images = pipe(prompt_ids, params, prng_seed, num_inference_steps=50, height=height, width=width, guidance_scale=7.5, jit=True).images
  # Flatten the leading per-device axes into a flat batch of num_samples images.
  images = pipe.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
  # Size the grid from the real sample count instead of a hard-coded 2x4
  # (8 devices still yields 2 rows x 4 columns). rows = ceil(n / cols).
  cols = min(4, num_samples)
  rows = -(-num_samples // cols)
  image = image_grid(images, rows, cols)
  out_path = f"{root_folder}/{image_folder}/{name:04}.png"
  image.save(out_path, pnginfo=png_info)
  # Only upload when a bot token is configured; close the file after reading.
  if token:
    with open(out_path, "rb") as fh:
      files = {f"{image_folder}_{name:04}.png": fh.read()}
    payload = {"content": str(prompt)}
    requests.post(f"https://discord.com/api/v9/channels/{channel_id}/messages", data=payload, headers=header, files=files, timeout=60)
  clear_output()

# Generate images until the file index reaches max_files, taking prompts
# either from a text file (one prompt per line, cycled) or from the form below.
max_files = 100 #@param {type: 'integer'}
is_from_prompts_txt = False #@param {type: 'boolean'}
prompts_txt = 'prompts.txt' #@param {type: 'string'}
if is_from_prompts_txt:
  while name < max_files:
    # Re-read on every pass so edits to the file are picked up.
    with open(prompts_txt, "r", encoding="utf-8") as file:
      # Drop the trailing newline readlines() kept, and skip blank lines.
      prompts = [p for p in (line.strip() for line in file) if p]
    # Without this guard an empty file made the while loop spin forever.
    if not prompts:
      raise ValueError(f"{prompts_txt} contains no prompts")
    for prompt in prompts:
      # Stop mid-file instead of overshooting max_files.
      if name >= max_files:
        break
      name += 1
      generate(prompt, name)
else:
  prompt = 'panda by Mike Winkelmann Beeple, ultra-detailed pen and ink illustration' #@param {type: 'string'}
  while name < max_files:
    name += 1
    generate(prompt, name)
