<a href="https://colab.research.google.com/github/karaage0703/ai_zoo_keeper/blob/main/ai_zoo_animal_creator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# AI Zoo animal creator

Generate animal images for AI Zoo with Stable Diffusion, remove their backgrounds, and download them as a zip file.

In [None]:
#@title **Setup**

# Reference for the PIL.Image.Resampling workaround used below when setting up rembg:
# https://stackoverflow.com/questions/71738218/module-pil-has-not-attribute-resampling

!pip -qq install diffusers[torch]==0.11.1 transformers
!pip -qq install --upgrade --pre triton

# install xformers
# https://github.com/XavierXiao/Dreambooth-Stable-Diffusion/issues/102
# !pip -qq install https://github.com/metrolobo/xformers_wheels/releases/download/1d31a3ac_various_6/xformers-0.0.14.dev0-cp37-cp37m-linux_x86_64.whl
!pip -qq install https://github.com/metrolobo/xformers_wheels/releases/download/4c06c79_various6/xformers-0.0.15.dev0_4c06c79.d20221201-cp38-cp38-linux_x86_64.whl

!pip install -qq tqdm
!pip install -qq rembg
!pip install -qq pillow==9.2.0

import PIL.Image
if not hasattr(PIL.Image, 'Resampling'):  # Pillow<9.0
    PIL.Image.Resampling = PIL.Image
import torch
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
import numpy as np
from tqdm import tqdm
from matplotlib import pyplot as plt
import os
import cv2
from rembg.bg import remove

# Device that hosts the pipeline and every tensor created by later cells.
device = "cuda"
# Hugging Face repository the weights and the scheduler config are loaded from.
model_id = "stabilityai/stable-diffusion-2"

if model_id == "stabilityai/stable-diffusion-2":
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id,
        # The scheduler configuration lives in the repository's "scheduler" subfolder.
        scheduler=EulerDiscreteScheduler.from_pretrained(
            model_id,
            subfolder="scheduler",
        ),
        # Load the half-precision ("fp16") revision of the weights.
        torch_dtype=torch.float16,
        revision="fp16",
    ).to(device)
    # diffusers option intended to lower peak GPU memory use during attention.
    pipe.enable_attention_slicing()
else:
    # Fail here instead of leaving `pipe` undefined and hitting a NameError
    # in a later cell.
    raise ValueError(f"Unsupported model_id: {model_id!r}")

In [None]:
#@title **Generate Image**
#@markdown Enter the parameters (note: seed_number = -1 picks a random seed)

# Prefix of the PNG file names written by the "Making variations" cell.
name = 'karaage' #@param {type:"string"}
# Text prompt passed to the Stable Diffusion pipeline.
animal = 'kawaii panda' #@param {type:"string"}
# Seed for the initial latent noise; a negative value requests a random seed.
seed_number = -1 #@param

# Sampling settings forwarded unchanged to every pipeline call.
num_inference_steps  = 20
guidance_scale_value = 7.5
# Output size in pixels; the latent tensor uses one eighth of each dimension.
width_image = 512
height_image = 512

def infer(prompt, seed_number, num_inference_steps, guidance_scale_value, width_image, height_image):
    """Generate one image from freshly sampled latent noise.

    Args:
        prompt: Text prompt handed to the pipeline.
        seed_number: Seed for the latent noise. A negative value lets the
            generator pick a new random seed.
        num_inference_steps: Number of denoising steps.
        guidance_scale_value: Guidance scale passed to the pipeline.
        width_image: Output width in pixels.
        height_image: Output height in pixels.

    Returns:
        A tuple ``(image, image_latent)``: the first image returned by the
        pipeline and the initial noise tensor of shape
        ``(1, unet.in_channels, height_image // 8, width_image // 8)`` that
        produced it, so callers can reuse or perturb the same starting point.
    """
    generator = torch.Generator(device=device)

    # Generator.seed() seeds the generator with a fresh random value and
    # returns it; otherwise use the caller's seed as is.
    seed = generator.seed() if seed_number < 0 else seed_number
    generator.manual_seed(seed)
    # Report the seed so a randomly seeded result can be reproduced later.
    print(f"seed: {seed}")

    image_latent = torch.randn(
        (1, pipe.unet.in_channels, height_image // 8, width_image // 8),
        generator=generator,
        device=device,
    )

    # Rendering from explicit latents is exactly what the shared helper does.
    image = draw_image_from_latents(
        prompt,
        num_inference_steps,
        guidance_scale_value,
        width_image,
        height_image,
        image_latent,
    )

    return image, image_latent

def draw_image_from_latents(prompt, num_inference_steps, guidance_scale_value, width_image, height_image, image_latent):
    """Run the pipeline on ``prompt`` starting from a given latent tensor.

    Args:
        prompt: Text prompt; it is wrapped in a one-element list for the pipeline.
        num_inference_steps: Number of denoising steps.
        guidance_scale_value: Guidance scale passed to the pipeline.
        width_image: Output width in pixels.
        height_image: Output height in pixels.
        image_latent: Initial latent tensor handed to the pipeline as ``latents``.

    Returns:
        The first entry of the pipeline result's ``images``.
    """
    pipeline_options = dict(
        width=width_image,
        height=height_image,
        guidance_scale=guidance_scale_value,
        num_inference_steps=num_inference_steps,
        latents=image_latent,
    )

    # The pipeline call itself runs under CUDA autocast.
    with torch.autocast('cuda'):
        result = pipe([prompt], **pipeline_options)

    return result.images[0]

def draw_image(image):
    """Show ``image`` on a new 10x10 matplotlib figure with the axes hidden."""
    plt.figure(figsize=(10, 10))  # becomes the current figure for the calls below
    plt.imshow(image)
    plt.axis('off')
    plt.show()

# Generate the base image from the form parameters. `latents` keeps the
# starting noise so the variations cell can perturb it.
image, latents = infer(
    prompt=animal,
    seed_number=seed_number,
    num_inference_steps=num_inference_steps,
    guidance_scale_value=guidance_scale_value,
    width_image=width_image,
    height_image=height_image,
)

draw_image(image)

In [None]:
#@title **Making variations and save images**
#@markdown　Execute for generating images

# Number of variation images to generate and save.
number_frames = 4
# Largest change applied to a single latent value in one step of the walk.
max_distance = 0.1 #@param {type:"slider", min:0.01, max:0.5, step:0.01}
random_walk = np.random.default_rng()

# Random walk in latent space. Walk on a copy so that `latents` keeps
# describing the image shown by the "Generate Image" cell and re-running this
# cell always starts from that same image.
walk_latents = latents.clone()

for n in tqdm(range(number_frames)):
    # One step: add independent uniform noise in [-max_distance, max_distance)
    # to every latent value with a single vectorised operation.
    step = random_walk.uniform(-max_distance, max_distance, size=tuple(walk_latents.shape))
    walk_latents += torch.as_tensor(step, dtype=walk_latents.dtype, device=walk_latents.device)

    image = draw_image_from_latents(animal, num_inference_steps, guidance_scale_value, width_image, height_image, walk_latents)
    print('below image is number ' + str(n))
    image_np = np.array(image)

    # Shrink to 256x256 and cut out the background with rembg.
    image_np = cv2.resize(image_np, dsize=(256, 256))
    image_np = remove(image_np)

    # The array originates from an RGB image, while cv2.imwrite expects
    # BGR(A) channel order; convert so red and blue are not swapped on disk.
    if image_np.ndim == 3 and image_np.shape[2] == 4:
        image_out = cv2.cvtColor(image_np, cv2.COLOR_RGBA2BGRA)
    elif image_np.ndim == 3 and image_np.shape[2] == 3:
        image_out = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
    else:
        # Single-channel data has no channel order to fix.
        image_out = image_np

    # File name pattern must stay in sync with the "Check images" cell.
    file_path = f"{name}_{n:001}.png"
    cv2.imwrite(file_path, image_out)

In [None]:
#@title **Check images**
#@markdown　Execute for checking images

from IPython.display import Image as IPImage
from IPython.display import display_png

# Show every frame saved by the "Making variations" cell. Iterating over
# `number_frames` keeps this cell correct when the frame count is changed.
for n in range(number_frames):
    # Same file name pattern as the one used when the images were written.
    file_path = f"{name}_{n:001}.png"
    display_png(IPImage(file_path))

### Download

In [None]:
#@title **Download images**
#@markdown　Execute for downloading images

# Shell command: add every PNG file in the current working directory to
# animal.zip (this is not limited to the files of the current `name`).
!zip animal.zip *.png

# `google.colab` is provided by the Colab runtime; this cell is expected to
# work only there.
from google.colab import files
# Hand the archive to Colab's download helper.
files.download('animal.zip')

## Submit

Submit an issue with your animal images attached:

https://github.com/karaage0703/ai_zoo_keeper/issues

## Reference
