# XenoBooth
Basado en el código de https://github.com/ShivamShrirao/diffusers/tree/main/examples/dreambooth, siguiendo el tutorial de https://medium.com/@pajunenpyry/easy-realistic-avatars-with-stable-diffusion-dreambooth-no-programming-step-by-step-seo-guide-no-711b70c91f69

v1.05

In [None]:
#@title ## 1 Conectar Google Drive
# Mount Google Drive at /content/drive. Every folder this notebook reads or
# writes is built from BASE_DIR ("/content/drive/MyDrive/"), which lives under
# this mount point, so this cell has to run first.
from google.colab import drive
drive.mount('/content/drive')


In [None]:
#@title ## 2 Comprobar GPU
# Print one CSV line per GPU (no header) with its name, total memory and
# currently free memory. The generation cell moves the pipeline to "cuda",
# so a GPU runtime is required.
!nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader

In [None]:
#@title ## 3 Librerías y módulos

# PyTorch stack pinned to the CUDA 11.8 (+cu118) wheels from the PyTorch index.
%pip install -q torch==2.0.0+cu118 torchvision==0.15.1+cu118 torchaudio==2.0.1+cu118 torchtext==0.15.1 torchdata==0.6.0 --extra-index-url https://download.pytorch.org/whl/cu118 -U
# xformers backs pipe.enable_xformers_memory_efficient_attention() in the
# generation cell. NOTE(review): the xformers/triton pins presumably match
# torch 2.0.0 - confirm before bumping any of them.
%pip install -q xformers==0.0.19 triton==2.0.0 -U

# Fetch the training script (run in the training cell) and the checkpoint
# conversion script from the ShivamShrirao fork, then install that fork of
# diffusers. Both downloads track the `main` branch, i.e. they are not pinned.
!wget -q https://github.com/ShivamShrirao/diffusers/raw/main/examples/dreambooth/train_dreambooth.py
!wget -q https://github.com/ShivamShrirao/diffusers/raw/main/scripts/convert_diffusers_to_original_stable_diffusion.py
%pip install -qq git+https://github.com/ShivamShrirao/diffusers

# Remaining runtime dependencies; gradio drives the UI in the generation cell.
# NOTE(review): bitsandbytes is presumably what --use_8bit_adam needs - confirm.
%pip install -q accelerate transformers ftfy bitsandbytes==0.35.0 gradio==3.48.0 natsort safetensors

import gradio as gr
import json
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import os
import time
import torch
from datetime import datetime
from diffusers import StableDiffusionPipeline, DDIMScheduler
from glob import glob
from IPython.display import display, HTML
from natsort import natsorted
from torch import autocast

def prettyprint(texto):
  """Render *texto* in the notebook output as a large-font HTML paragraph.

  The text is inserted into the markup verbatim, so callers may embed
  inline HTML tags (e.g. <b>, <code>) in it.
  """
  display(HTML("<p style='font-size: large;'>" + f"{texto}" + "</p>"))

In [None]:
#@title ## 4 El nombre de la cosa

#@markdown Introduce el nombre -en inglés- de la clase (woman, man, child, boy, girl, person...)
CLASS_NAME = "woman" #@param {type:"string"}

#@markdown Introduce el nombre de la instancia
INSTANCE_NAME = "brplz" #@param {type:"string"}

# Root of the working tree on the mounted Drive, and the base model to tune.
BASE_DIR = "/content/drive/MyDrive/"
MODEL_NAME = "runwayml/stable-diffusion-v1-5"

# Per-instance folders (weights, input photos, generated images) plus the
# per-class folder with the example images.
OUTPUT_DIR = f"{BASE_DIR}xenobooth/{INSTANCE_NAME}/pesos"
INSTANCE_DATA_DIR = f"{BASE_DIR}xenobooth/{INSTANCE_NAME}/selfies"
CLASS_DATA_DIR = f"{BASE_DIR}xenobooth/data/{CLASS_NAME}"
GENERATION_DIR = f"{BASE_DIR}xenobooth/{INSTANCE_NAME}/generadas"

for folder in (OUTPUT_DIR, GENERATION_DIR):
    os.makedirs(folder, exist_ok=True)

# Single concept: the instance prompt names the subject, the class prompt
# names the generic class it belongs to.
concepts_list = [
    {
        "instance_prompt": f"photo of {INSTANCE_NAME} {CLASS_NAME}",
        "class_prompt": f"photo of a {CLASS_NAME}",
        "instance_data_dir": INSTANCE_DATA_DIR,
        "class_data_dir": CLASS_DATA_DIR,
    }
]

# Make sure every data folder referenced by a concept exists.
for concept in concepts_list:
    for key in ("instance_data_dir", "class_data_dir"):
        os.makedirs(concept[key], exist_ok=True)

# The training cell passes this file through --concepts_list.
with open("concepts_list.json", "w") as fh:
    fh.write(json.dumps(concepts_list, indent=4))

# Prompt reused by the training cell as --save_sample_prompt.
CONCEPT = concepts_list[0]["instance_prompt"]

prettyprint(f"Copia los <b>ejemplos</b> en la carpeta <code>xenobooth/data/{CLASS_NAME}</code>")
prettyprint(f"Copia tus <b>fotos</b> en la carpeta <code>xenobooth/{INSTANCE_NAME}/selfies</code>")


In [None]:
#@title ## 5 Una vez copiadas, comprobamos que todo está OK

# Count the regular files directly inside the instance photo folder
# (sub-folders are not counted). The training cell derives its step counts
# from this number.
NUM_IMAGES = sum(
    1
    for entry in os.listdir(INSTANCE_DATA_DIR)
    if os.path.isfile(os.path.join(INSTANCE_DATA_DIR, entry))
)
prettyprint(f'Has subido {NUM_IMAGES} imágenes')

In [None]:
#@title ## 6 Arranca el entrenamiento
os.environ['CURL_CA_BUNDLE'] = ''

num_class_images = NUM_IMAGES * 12
max_train_steps = NUM_IMAGES * 80
lr_warmup_steps = max_train_steps // 10

start = time.time()

!python3 train_dreambooth.py \
  --num_class_images=$num_class_images \
  --max_train_steps=$max_train_steps \
  --lr_warmup_steps=$lr_warmup_steps \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --save_sample_prompt="$CONCEPT" \
  --pretrained_vae_name_or_path="stabilityai/sd-vae-ft-mse" \
  --output_dir=$OUTPUT_DIR \
  --revision="fp16" \
  --with_prior_preservation \
  --prior_loss_weight=1.0 \
  --seed=1337 \
  --resolution=512 \
  --train_batch_size=1 \
  --train_text_encoder \
  --mixed_precision="fp16" \
  --use_8bit_adam \
  --gradient_accumulation_steps=1 \
  --learning_rate=1e-6 \
  --lr_scheduler="constant" \
  --sample_batch_size=4 \
  --save_interval=10000 \
  --concepts_list="concepts_list.json"

end = time.time()
print(f"{end - start:.0f} segundos")

WEIGHTS_DIR = natsorted(glob(OUTPUT_DIR + os.sep + "*"))[-1]
prettyprint(f"### WEIGHTS_DIR={WEIGHTS_DIR}")

In [None]:
#@title ## 7 Primeros resultados

weights_folder = OUTPUT_DIR

# Collect the sample images of every checkpoint folder. Only folders with a
# decimal name (other than "0") and a non-empty "samples" sub-folder are used,
# so stray entries in the weights folder cannot break the int() sort below.
samples_by_step = {}
for name in os.listdir(weights_folder):
    if not name.isdecimal() or name == "0":
        continue
    samples_dir = os.path.join(weights_folder, name, "samples")
    if os.path.isdir(samples_dir):
        # Sorted so that the column order is stable between runs.
        files = sorted(os.listdir(samples_dir))
        if files:
            samples_by_step[name] = files

folders = sorted(samples_by_step, key=int)

if not folders:
    prettyprint(f"### ERROR: no hay muestras en {weights_folder}")
else:
    # One row per checkpoint, one column per sample image. The widest row
    # decides the column count so no checkpoint overflows the grid.
    row = len(folders)
    col = max(len(samples_by_step[folder]) for folder in folders)
    scale = 4
    # squeeze=False keeps `axes` two-dimensional even for a single row or
    # column, so axes[i, j] is always valid.
    fig, axes = plt.subplots(
        row, col,
        figsize=(col*scale, row*scale),
        squeeze=False,
        gridspec_kw={'hspace': 0, 'wspace': 0},
    )

    for i, folder in enumerate(folders):
        image_folder = os.path.join(weights_folder, folder, "samples")
        images = samples_by_step[folder]
        for j, image in enumerate(images):
            currAxes = axes[i, j]
            if i == 0:
                currAxes.set_title(f"Image {j}")
            if j == 0:
                # Label each row with its checkpoint (step) folder name.
                currAxes.text(-0.1, 0.5, folder, rotation=0, va='center', ha='center', transform=currAxes.transAxes)
            image_path = os.path.join(image_folder, image)
            img = mpimg.imread(image_path)
            currAxes.imshow(img, cmap='gray')
            currAxes.axis('off')
        # Hide the cells of this row that have no image.
        for j in range(len(images), col):
            axes[i, j].axis('off')

    plt.tight_layout()
    plt.savefig('grid.png', dpi=72)

In [None]:
#@title ## 8 Elegir modelo
#@markdown  Será una ruta parecida a esta: /content/drive/MyDrive/xenobooth/INSTANCE_NAME/pesos/1200

model_path = "/content/drive/MyDrive/xenobooth/brplz/pesos/1200" #@param {type:"string"}

# An explicit model_path wins. When it is left empty, fall back to the
# WEIGHTS_DIR produced by the training cell; globals().get() copes with the
# training cell not having been run in this session.
final_model_path = model_path or globals().get("WEIGHTS_DIR")

if not final_model_path:
    # Always assign final_model_path so that a value left over from an
    # earlier run of this cell is never reused by mistake.
    final_model_path = None
    prettyprint("### ERROR: model_path no está definido")
else:
    print(f"model_path = {final_model_path}")
    if not os.path.isdir(final_model_path):
        prettyprint(f"## No encuentro la carpeta {final_model_path}")


In [None]:
#@title ## 9 Generar imágenes
# Load the selected weights in half precision and move the pipeline to the GPU.
pipe = StableDiffusionPipeline.from_pretrained(
    final_model_path,
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")
# Swap in a DDIM scheduler built from the loaded scheduler's configuration.
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.enable_xformers_memory_efficient_attention()
# Shared CUDA random generator; inference() re-seeds it on every call.
g_cuda = torch.Generator(device="cuda")

def inference(prompt, negative_prompt, num_samples, random_seed, num_inference_steps=30, guidance_scale=9):
    """Generate images for *prompt*, save them to GENERATION_DIR and return them.

    Args:
        prompt: Text prompt passed to the pipeline.
        negative_prompt: Text the pipeline should steer away from.
        num_samples: Number of images to generate for the prompt.
        random_seed: Seed for the shared CUDA generator, so equal inputs
            reproduce equal images.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Guidance scale passed to the pipeline.

    Returns:
        The list of generated images (the pipeline's ``.images``).

    Every image is also written as ``xeno_<timestamp>_<n>.jpg`` (n starts
    at 1) inside GENERATION_DIR.
    """
    # UI widgets may deliver numbers as floats; manual_seed() needs an int.
    g_cuda.manual_seed(int(random_seed))
    with torch.autocast("cuda"), torch.inference_mode():
        images = pipe(
            prompt,
            height=512,
            width=512,
            negative_prompt=negative_prompt,
            num_images_per_prompt=int(num_samples),
            num_inference_steps=int(num_inference_steps),
            guidance_scale=float(guidance_scale),
            generator=g_cuda,
        ).images
    # The date is part of the timestamp so that images generated at the same
    # time of day on different days do not overwrite each other.
    ahora = datetime.now().strftime('%Y%m%d_%H%M%S')
    for index, image in enumerate(images, start=1):
        image.save(f"{GENERATION_DIR}/xeno_{ahora}_{index}.jpg", quality=100, subsampling=0)
    return images

# Gradio UI: prompt controls in the left column, result gallery in the right.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", value=f"a portrait of a {INSTANCE_NAME} {CLASS_NAME}, oil painting, by sorolla")
            negative_prompt = gr.Textbox(label="Negative Prompt", value="ugly, additional arms, additional legs, additional head, two heads, blurry, pixelated, extra hands, extra arms, collage, grainy, low, poor, monochrome")
            run = gr.Button(value="Generate")
            random_seed = gr.Slider(label="Seed", value=2345678, maximum=9999999, step=1)
            # At least one whole denoising step; 100 is the slider's default maximum.
            num_inference_steps = gr.Slider(label="Steps", value=30, minimum=1, maximum=100, step=1)
            with gr.Row():
                # precision=0 makes the widget deliver an integer sample count.
                num_samples = gr.Number(label="Number of Samples", value=1, precision=0)
                guidance_scale = gr.Number(label="Guidance Scale", value=9)
        with gr.Column():
            # `preview` is passed to the constructor because the chained
            # .style() call is deprecated in gradio 3.x and removed in 4.x.
            gallery = gr.Gallery(preview=True)

    # The input order must match the positional parameters of inference().
    run.click(inference, inputs=[prompt, negative_prompt, num_samples, random_seed, num_inference_steps, guidance_scale], outputs=gallery)

# debug=True keeps the cell running and shows errors in the cell output.
demo.launch(debug=True)