[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/notebooks/blob/main/camenduru's_stable_diffusion_tile.ipynb)

In [1]:
# Environment setup for a Colab TPU runtime.
# Install pinned CUDA 11.6 builds of the torch family (from the PyTorch wheel index).
!pip install -q torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 torchtext==0.14.1 torchdata==0.5.1 --extra-index-url https://download.pytorch.org/whl/cu116 -U
# diffusers from camenduru's fork, `padding` branch.
# NOTE(review): presumably this branch provides the tiling/padding behaviour this notebook relies on — confirm.
!pip install git+https://github.com/camenduru/diffusers.git@padding
!pip install --upgrade jax jaxlib
# piexif: writes JPEG EXIF; fold-to-ascii: used later to fold EXIF text fields to ASCII.
!pip install piexif fold-to-ascii

# Attach JAX to the Colab TPU, using the pinned TPU driver version string.
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu('tpu_driver_20221011')

!pip install flax transformers ftfy
# Show the visible devices (the name `jax` is already bound by `import jax.tools.colab_tpu` above).
jax.devices()

import jax
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard

import os, gc, requests, subprocess, random
from diffusers import FlaxStableDiffusionPipeline

from IPython.display import clear_output
import csv, piexif
from PIL import Image
from PIL.PngImagePlugin import PngInfo
from fold_to_ascii import fold
# Hide the install/import output.
clear_output()

In [None]:
# Log in to the Hugging Face Hub so model weights can be downloaded.
from huggingface_hub import notebook_login
# Configure git to persist credentials via the `store` helper (kept on disk for this runtime).
!git config --global credential.helper store
notebook_login()

In [3]:
# Load the Flax Stable Diffusion pipeline from the "bf16" weights revision, computing in
# bfloat16, with the safety checker disabled.
pipe, params = FlaxStableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jax.numpy.bfloat16, safety_checker=None)
# Alternative model. NOTE(review): the JPEG EXIF "Model" tag written by generate() defaults to
# runwayml/stable-diffusion-v1-5 regardless of which of these two lines is active.
#pipe, params = FlaxStableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", revision="bf16", dtype=jax.numpy.bfloat16, safety_checker=None)
# Copy the parameters to every local device (flax.jax_utils.replicate) for the jit=True pipeline call.
params = replicate(params)
# Running counter used in the output filenames.
name = 0
clear_output()

In [17]:
# Reset the running filename counter that the generation loop increments.
name = 0

In [None]:
def image_grid(imgs, rows, cols):
    """Tile `imgs` row-major into a single RGB image of `rows` x `cols` cells.

    Every cell takes the size of the first image. Image number i is pasted at
    row i // cols, column i % cols. An empty `imgs` raises IndexError.
    """
    cell_w, cell_h = imgs[0].size
    canvas = Image.new('RGB', size=(cols * cell_w, rows * cell_h))
    for idx, tile in enumerate(imgs):
        row, col = divmod(idx, cols)
        canvas.paste(tile, box=(col * cell_w, row * cell_h))
    return canvas

token = '' #@param {type: 'string'}
channel_id = 0 #@param {type: 'integer'}
# Authorization header for the Discord bot API.
header = {"authorization": f"Bot {token}"}
by = 'camenduru' #@param {type: 'string'}
is_png = False #@param {type: 'boolean'}

root_folder = '/content/art' #@param {type: 'string'}
image_folder = 'tile' #@param {type: 'string'}
# Output folders: <root_folder>/<image_folder> and /content/png/<image_folder>.
# makedirs(exist_ok=True) creates every missing level, so the sub-folder is
# created even when its parent already exists (and re-running is harmless).
os.makedirs(f"{root_folder}/{image_folder}", exist_ok=True)
os.makedirs(f"/content/png/{image_folder}", exist_ok=True)

# Output resolution passed to the pipeline.
# NOTE(review): Stable Diffusion presumably needs multiples of 8 here — confirm before changing.
height = 640 #@param {type: 'integer'}
width = 1024 #@param {type: 'integer'}
prompts_csv = 'fav.txt' #@param {type: 'string'}

# Module-level PngInfo. Pillow's add_text appends a new chunk on every call, so
# one shared instance accumulates text chunks across images.
# NOTE(review): prefer a fresh PngInfo per image.
metadata = PngInfo()

def generate(prompt, name, category, artist, score, model_id="runwayml/stable-diffusion-v1-5"):
  """Generate one image per JAX device for `prompt`, save the images and post them to Discord.

  Args:
    prompt: text prompt for the pipeline; also stored in the PNG text chunks and EXIF.
    name: integer counter, zero-padded to 4 digits in the output filenames.
    category: sub-folder name, filename component and metadata field.
    artist: filename component and metadata field.
    score: filename component and metadata field.
    model_id: value written to the JPEG EXIF `Model` tag. The default keeps the
      previously hard-coded value. NOTE(review): pass the id of the model that was
      actually loaded.

  Reads notebook globals: pipe, params, by, is_png, root_folder, image_folder,
  height, width, token, channel_id, header.
  """
  category_dir = f"{root_folder}/{image_folder}/{category}"
  png_dir = f"/content/png/{image_folder}/{category}"
  # makedirs also creates missing parents, unlike a bare mkdir.
  os.makedirs(category_dir, exist_ok=True)
  os.makedirs(png_dir, exist_ok=True)

  # Fresh PngInfo per call: add_text appends chunks, so a shared instance would
  # embed every previous call's prompt/category/... into each new PNG.
  metadata = PngInfo()
  metadata.add_text("prompt", f"{prompt}")
  metadata.add_text("by", f"{by}")
  metadata.add_text("category", f"{category}")
  metadata.add_text("artist", f"{artist}")
  metadata.add_text("score", f"{score}")

  gc.collect()
  real_seed = random.randint(0, 2147483647)
  prng_seed = jax.random.PRNGKey(real_seed)
  # One copy of the prompt (and one PRNG key) per device.
  num_samples = jax.device_count()
  prompt_ids = shard(pipe.prepare_inputs(num_samples * [prompt]))
  prng_seed = jax.random.split(prng_seed, num_samples)
  images = pipe(prompt_ids, params, prng_seed, num_inference_steps=50, height=height, width=width, guidance_scale=7.5, jit=True).images
  # Flatten the per-device leading axes into one batch before converting to PIL.
  images = pipe.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))

  for image_number, image in enumerate(images):
    stem = f"{category}_{artist}_{score}_{name:04}_{image_number}"
    if is_png:
      image.save(f"{category_dir}/{stem}.png", pnginfo=metadata)
    else:
      # fold() makes the text ASCII before it is written to the EXIF ASCII tags.
      zeroth_ifd = {piexif.ImageIFD.Artist: f"{fold(artist)}",
            piexif.ImageIFD.DocumentName: f"{fold(category)}",
            piexif.ImageIFD.ImageDescription: f"{fold(prompt)}",
            piexif.ImageIFD.Make: f"{fold(by)}",
            piexif.ImageIFD.Model: f"{model_id}",
            piexif.ImageIFD.Copyright: f"Attribution 4.0 International (CC BY 4.0)",
            piexif.ImageIFD.Software: f"{fold(score)}"}
      exif_bytes = piexif.dump({"0th": zeroth_ifd})
      image.save(f"{category_dir}/{stem}.jpg", "JPEG", quality=70, exif=exif_bytes)

    # A PNG copy is always written. It is the file uploaded to Discord.
    png_path = f"{png_dir}/{stem}.png"
    image.save(png_path, pnginfo=metadata)
    # Skip the upload when no bot token is configured (it could only fail).
    if token:
      with open(png_path, "rb") as png_file:
        files = {f"{stem}.png": png_file.read()}
      payload = {"content": f"{category} : {artist} : {score} : {prompt}"}
      requests.post(f"https://discord.com/api/v9/channels/{channel_id}/messages", data=payload, headers=header, files=files, timeout=60)
    clear_output()

max_files = 100 #@param {type: 'integer'}
is_prompts_from_csv = False #@param {type: 'boolean'}
is_prompts_from_txt = False #@param {type: 'boolean'}
prompts_txt = 'prompts.txt' #@param {type: 'string'}
category = 'test' #@param {type: 'string'}
artist = 'artist' #@param {type: 'string'}
score = 'score' #@param {type: 'string'}
# `name` is the notebook-global filename counter; each generated batch increments it.
if is_prompts_from_csv:
  # One generation per CSV row (columns: artist, category, score). This mode
  # processes the whole file and does not stop at max_files.
  with open(f"{prompts_csv}", 'r') as file:
    csv_file = csv.DictReader(file)
    for row in csv_file:
      name += 1
      prompt_csv_prefix = '' #@param {type: 'string'}
      prompt_csv_suffix = '' #@param {type: 'string'}
      prompt = f"{prompt_csv_prefix} {row['artist']} {prompt_csv_suffix}"
      generate(prompt, name, row['category'], row['artist'], row['score'])
elif is_prompts_from_txt:
  # Cycle through the prompts file until max_files is reached. The file is re-read
  # on every pass, so edits made while running are picked up on the next cycle.
  while name < max_files:
    with open(f'{prompts_txt}', "r") as file:
      # Strip the trailing newline from each line and drop blank lines.
      prompts = [line.strip() for line in file if line.strip()]
    if not prompts:
      # Without this guard an empty file would loop forever.
      print(f"no prompts found in {prompts_txt}")
      break
    for prompt in prompts:
      if name >= max_files:
        break
      name += 1
      generate(prompt, name, category, artist, score)
else:
  while name < max_files:
    prompt = 'sci-fi landscape' #@param {type: 'string'}
    name += 1
    generate(prompt, name, category, artist, score)

In [None]:
from PIL import Image, ExifTags
# Print the EXIF tags of a saved image, with numeric tag ids mapped to their names.
with Image.open(f"/content/artiststostudy2/dog/cartoon/test.jpg") as img:
  # getexif() is Pillow's public API. It returns an empty mapping (not None, as the
  # private _getexif() does) for images without EXIF. It reads the 0th IFD, which is
  # where this notebook's generate() writes its tags ({"0th": ...}).
  exif = {ExifTags.TAGS[k]: v for k, v in img.getexif().items() if k in ExifTags.TAGS}
print(exif)

In [None]:
# Set the global git author identity used by the commit cells below.
!git config --global user.name "camenduru"
!git config --global user.email "camenduru@gmail.com"

In [None]:
# Turn /content/artiststostudy3 into a new repo on branch `main` and push it to GitLab.
%cd /content/artiststostudy3
!git init
!git checkout -b  main
# NOTE(review): `token` is a placeholder credential inside the URL; `git remote add` stores
# this URL (including the credential) in .git/config. Substitute a real token only in a
# private runtime, and avoid committing or sharing the resulting .git directory.
!git remote add origin https://camenduru:token@gitlab.com/camenduru/artiststostudy3.git
!git add .
!git commit -m "Initial commit"
!git push -u origin main
%cd /content

In [None]:
# Pull the latest remote changes into /content/artiststostudy3 (`token` is a placeholder credential).
%cd /content/artiststostudy3
!git pull https://camenduru:token@gitlab.com/camenduru/artiststostudy3.git
%cd /content

In [None]:
# Stage everything, commit, and push `main` to the configured `origin`.
%cd /content/artiststostudy3
!git add .
!git commit -m "fix images"
!git push -u origin main
%cd /content

In [None]:
# Clone the repo into /content/artiststostudy3 (`token` is a placeholder credential stored in the clone's remote URL).
!git clone https://camenduru:token@gitlab.com/camenduru/artiststostudy3.git

In [None]:
import os
from PIL import Image

def crop_center(pil_img, crop_width, crop_height):
    """Return the `crop_width` x `crop_height` region centred in `pil_img`.

    Offsets use floor division, so when the leftover margin is odd the extra
    pixel falls on the right/bottom side.
    """
    width, height = pil_img.size
    left = (width - crop_width) // 2
    top = (height - crop_height) // 2
    right = (width + crop_width) // 2
    bottom = (height + crop_height) // 2
    return pil_img.crop((left, top, right, bottom))
    
def crop_max_square(pil_img):
    """Return the largest centred square crop of `pil_img` (side = its shorter edge)."""
    side = min(pil_img.size)
    return crop_center(pil_img, side, side)

root_path = 'artiststostudy2'
jpg_paths = ['anime', 'black-white', 'c', 'cartoon', 'digipa-high-impact', 'digipa-low-impact', 'digipa-med-impact', 'fareast', 'fineart', 'n', 'nudity', 'scribbles', 'special', 'ukioe', 'weird']
prompt_paths = ['spaceship']

# For every file in /content/<root_path>/<prompt>/<style>/, write a square
# thumbnail (at most 300x300, JPEG quality 50) under the same relative path in
# /content/<root_path>/thumbnail/<prompt>/<style>/.
os.makedirs(f"/content/{root_path}/thumbnail", exist_ok=True)

for prompt_path in prompt_paths:
  os.makedirs(f"/content/{root_path}/thumbnail/{prompt_path}", exist_ok=True)

  for jpg_path in jpg_paths:
    src_dir = f"/content/{root_path}/{prompt_path}/{jpg_path}"
    dst_dir = f"/content/{root_path}/thumbnail/{prompt_path}/{jpg_path}"
    # os.listdir is outside the per-file try, so report and skip a missing
    # style folder instead of letting it abort the whole run.
    if not os.path.isdir(src_dir):
      print(f"{src_dir}/", "missing, skipped")
      continue
    os.makedirs(dst_dir, exist_ok=True)
    for jpg_image in os.listdir(src_dir):
      src = f"{src_dir}/{jpg_image}"
      try:
        print(src)
        # `with` closes each file handle instead of leaking one per image.
        with Image.open(src) as image:
          thumb = crop_max_square(image)
          thumb.thumbnail((300, 300))
          # JPEG cannot store alpha/palette modes (RGBA, P, LA, ...).
          if thumb.mode not in ("RGB", "L", "CMYK"):
            thumb = thumb.convert("RGB")
          # NOTE(review): EXIF from the source is not carried over to the thumbnail.
          thumb.save(f"{dst_dir}/{jpg_image}", "JPEG", quality=50)
      except Exception as e:
        # Best-effort batch: report the failing file and continue with the rest.
        print(src, e)