In [None]:
# 1. Install the required packages
# On Windows, you only need to run this cell once per Python environment.
# NOTE(review): the git+https installs below are not pinned to a tag or commit,
# so two runs of this cell may install different code -- consider pinning.
try:
    # Imported only to detect whether the notebook is running inside Google Colab.
    import google.colab
except ImportError:
    # Not on Colab: transformers and accelerate are installed from source here.
    # (Presumably Colab already provides them -- TODO confirm.)
    %pip install -q git+https://github.com/huggingface/transformers
    %pip install -q git+https://github.com/huggingface/accelerate

%pip install -q git+https://github.com/huggingface/diffusers
%pip install -q gradio ftfy tensorboard
%pip install -q bitsandbytes
# Alternative: install bitsandbytes from source instead of the released wheel.
#%pip install -U git+https://github.com/TimDettmers/bitsandbytes.git
# xformers is pulled from the PyTorch wheel index for CUDA 12.4 (cu124).
%pip install -q xformers --index-url https://download.pytorch.org/whl/cu124
# Alternative: install xformers from source instead of the prebuilt wheel.
#%pip install -U git+https://github.com/facebookresearch/xformers.git@main
print("Package installation finished.")

In [None]:
# 2. Create folders and download training scripts
import os, shutil
import urllib.request

dataset_dir = "./dataset"
output_dir = "./output"
logging_dir = "./log"

# The dataset folder is only created, never emptied, so uploaded images survive re-runs.
os.makedirs(dataset_dir, exist_ok=True)
# The output and log folders are wiped so every run starts from a clean state.
for stale_dir in (output_dir, logging_dir):
    shutil.rmtree(stale_dir, ignore_errors=True)
    os.makedirs(stale_dir, exist_ok=True)
# To also delete the 'dataset' folder and its contents, uncomment:
# shutil.rmtree(dataset_dir, ignore_errors=True)
# os.makedirs(dataset_dir, exist_ok=True)

# Fetch the training and conversion scripts if they are not present yet.
# urllib is used instead of `!wget` because wget is not available on a stock
# Windows install, and it keeps this step independent of the shell.
script_base_url = "https://raw.githubusercontent.com/jomo0825/MrFuGenerativeAI/main/Dreambooth"
for script_name in ("train_dreambooth.py", "convert_diffusers_to_original_stable_diffusion.py"):
    if os.path.exists(script_name):
        print(f"{script_name} already exists, skipping download.")
        continue
    # Download to a temporary name first so an interrupted download is never
    # mistaken for a complete script by the existence check above.
    partial_path = script_name + ".part"
    urllib.request.urlretrieve(f"{script_base_url}/{script_name}", partial_path)
    os.replace(partial_path, script_name)
    print(f"Downloaded {script_name}.")

# Remove tool-generated folders from the working directory and the dataset folder
# (presumably so .ipynb_checkpoints is not picked up as training data -- TODO confirm).
ipynb_checkpoints = os.path.join(dataset_dir, ".ipynb_checkpoints")
shutil.rmtree(".gradio", ignore_errors=True)
# NOTE(review): this deletes ".config" relative to the current working directory.
# If the notebook is ever started from a home directory this would remove the
# user's own ~/.config -- confirm this is only intended for Colab's /content.
shutil.rmtree(".config", ignore_errors=True)
shutil.rmtree(ipynb_checkpoints, ignore_errors=True)


In [None]:
# 3. Upload dataset images
def load_dataset():
  """Let the user choose dataset images and move them into ``dataset_dir``.

  On Google Colab the browser upload widget is used and the uploaded files are
  moved from the working directory into ``dataset_dir``. Everywhere else a
  tkinter file dialog is shown and the selected files are moved into
  ``dataset_dir``.

  Reads the notebook-level ``dataset_dir`` variable. Returns nothing; progress
  is reported with ``print``.
  """
  import os
  import shutil

  # Make sure the target exists even if the setup cell was skipped.
  os.makedirs(dataset_dir, exist_ok=True)

  # Only the import decides which branch runs; an ImportError raised later
  # (inside the upload itself) must not silently fall through to tkinter.
  try:
    from google.colab import files
  except ImportError:
    files = None

  if files is not None:
    # Upload files from the local machine (they land in the working directory).
    uploaded = files.upload()

    for filename in uploaded.keys():
        destination_path = os.path.join(dataset_dir, filename)
        # shutil.move instead of a shell `mv`: file names containing quotes,
        # `$` or other shell metacharacters are handled safely.
        shutil.move(filename, destination_path)
        print(f"Moved {filename} to {dataset_dir}")
    return

  import tkinter as tk
  from tkinter import filedialog

  root = tk.Tk()
  try:
    root.withdraw()  # Hide the root window
    root.attributes('-topmost', True)

    # Patterns are given as a tuple; a single ';'-joined string is not a
    # portable way to list several patterns for Tk file dialogs.
    file_paths = filedialog.askopenfilenames(
        title='Select Dataset Images',
        filetypes=[('Image Files', ('*.png', '*.jpg', '*.jpeg', '*.bmp', '*.gif'))]
    )
  finally:
    # Always release the hidden Tk root, even if the dialog raises.
    root.destroy()

  target_directory = dataset_dir

  for file_path in file_paths:
      filename = os.path.basename(file_path)
      destination = os.path.join(target_directory, filename)
      # shutil.move keeps the original "move" semantics but, unlike os.rename,
      # also works when source and destination are on different drives.
      shutil.move(file_path, destination)
      print(f'File {filename} saved to {target_directory}')

# Start the upload / file-selection step as soon as this cell is executed.
load_dataset()

In [None]:
# 4. Create a WebUI for training
# The first training run downloads the Stable Diffusion v1.5 weights.
# A very good reference:
# https://www.reddit.com/r/StableDiffusion/comments/ybxv7h/good_dreambooth_formula/

import gradio as gr
import sys
import threading
#from textual_inversion import main as train_textual_inversion  # Assuming main function is the entry point in textual_inversion.py
# from train_dreambooth import main as train_dreambooth
import time, os, logging
from os import path
import subprocess
import shlex
import queue
from PIL import Image

def parse_lr_schedule(lr_schedule_str):
    """Parse a comma-separated learning-rate schedule into ``(lr, steps)`` tuples.

    Each comma-separated item is either ``"<lr>:<steps>"`` or a bare ``"<lr>"``.
    A bare rate is returned with ``steps`` set to ``None``, marking a constant
    learning rate with no step limit.

    Raises:
        ValueError: if an item cannot be converted to ``float``/``int`` or
            contains more than one ``:``.
    """
    def _to_entry(item):
        # No colon: a constant rate that is not bounded by a step count.
        if ':' not in item:
            return (float(item), None)
        rate_text, steps_text = item.split(':')
        return (float(rate_text), int(steps_text))

    return [_to_entry(item) for item in lr_schedule_str.split(',')]

def get_learning_rate_at_step(lr_schedule, step):
    """Return the learning rate that applies at ``step``.

    ``lr_schedule`` is a sequence of ``(lr, steps)`` pairs applied back to back.
    A pair whose ``steps`` is ``None`` applies to every step that reaches it.
    If ``step`` lies beyond all bounded segments, the rate of the last entry is
    returned.
    """
    segment_end = 0
    for rate, length in lr_schedule:
        # An unbounded segment covers everything that reaches it.
        if length is None:
            return rate
        segment_end += length
        if step < segment_end:
            return rate
    # Past the end of every bounded segment: keep using the final rate.
    return lr_schedule[-1][0]

# Callback to update the preview image in the UI
def preview_callback(image, step):
    """Record the newest preview image and a matching status message.

    Stores ``image`` in the global ``current_preview`` and sets the global
    ``current_status`` to a message naming ``step``. Returns nothing.
    """
    global current_preview, current_status
    current_status = "Preview updated at step {}".format(step)
    current_preview = image

def run_training(dataset_path, prompt, placeholder_token, initializer_token, num_training_steps,
                 learning_rate, batch_size, preview_save_steps, preview_seed):
    """Run DreamBooth training in a subprocess, then export ``model.safetensors``.

    This is a generator used as the Gradio click handler. Every yielded value
    is a pair ``(image update, status update)`` for the handler's two outputs.

    Args:
        dataset_path: Folder with the subject images (``--instance_data_dir``).
        prompt: Prompt used for validation previews (``--validation_prompt``).
        placeholder_token: Passed to the script as ``--instance_prompt``.
        initializer_token: Accepted to keep the handler signature; not used.
        num_training_steps: ``--max_train_steps``; checkpointing is set one
            step beyond it.
        learning_rate: ``--learning_rate``.
        batch_size: ``--train_batch_size``.
        preview_save_steps: ``--validation_steps``.
        preview_seed: ``--seed``.

    Side effects:
        Sets the globals ``current_preview``, ``current_status``, ``process``
        (the training subprocess) and ``pipeline`` (reset to ``None``). Reads
        the notebook-level ``output_dir`` and ``logging_dir``. On success the
        conversion script writes ``model.safetensors`` in the working directory.
    """
    global current_preview, current_status, process, pipeline
    current_preview = None  # Reset the preview
    current_status = "Training started..."  # Initial status

    # Number inputs may arrive as floats (or None when a field is cleared).
    # str(4.0) would give "4.0", which is not a valid value for an integer
    # command-line option, so normalise the values before building the command.
    try:
        max_train_steps = int(num_training_steps)
        train_batch_size = int(batch_size)
        validation_steps = int(preview_save_steps)
        seed = int(preview_seed)
        lr_value = float(learning_rate)
    except (TypeError, ValueError):
        current_status = "Invalid numeric setting. Please check the training parameters."
        yield gr.update(value=current_preview), gr.update(value=current_status)
        return

    # Define DreamBooth training parameters as a list of command-line arguments.
    # Adjust the paths, prompts, and hyperparameters to match your experiment.
    command = [
        "accelerate", "launch", "train_dreambooth.py",
        "--pretrained_model_name_or_path", "stable-diffusion-v1-5/stable-diffusion-v1-5",  # or your chosen model
        "--instance_data_dir", dataset_path,  # folder with your subject images
        "--instance_prompt", placeholder_token,  # prompt identifier for your subject
        "--output_dir", output_dir,          # where to save your DreamBooth model
        "--train_batch_size", str(train_batch_size),
        "--resolution", "512",
        "--lr_scheduler", "linear",
        "--learning_rate", str(lr_value),
        "--lr_warmup_steps", "0",
        "--gradient_accumulation_steps", "1",
        "--num_validation_images", "1",
        "--validation_prompt", prompt,
        "--validation_steps", str(validation_steps),
        "--max_train_steps", str(max_train_steps),
        "--mixed_precision", "fp16",
        "--use_8bit_adam",
        "--gradient_checkpointing",
        "--enable_xformers_memory_efficient_attention",
        "--set_grads_to_none",
        "--logging_dir", logging_dir,
        "--seed", str(seed),
        # One step past the end, so no intermediate checkpoint is written.
        "--checkpointing_steps", str(max_train_steps + 1),
        # "--train_only_unet",

        # "--class_data_dir", "./class_images",        # folder with class images (for prior preservation)
        # "--class_prompt", "a photo of a person",        # prompt for class images
        # "--with_prior_preservation",           # enable prior preservation if you have class images
        # "--num_class_images", "100",           # adjust based on your available class images
        # Add other DreamBooth parameters as needed
    ]

    # Print the command for debugging; shlex.join keeps arguments that contain
    # spaces (such as prompts) visibly quoted.
    print("Command:", shlex.join(command))
    yield None, gr.update(value="Training...")

    # Disable logging
    logging.getLogger("accelerate").disabled = True

    # Run the command in a separate process. It is kept in the global `process`
    # so another cell can terminate it.
    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

    # Wait for the process to complete
    stdout, stderr = process.communicate()

    # Print the output and errors (for debugging). errors="replace" avoids a
    # UnicodeDecodeError when the console output is not valid UTF-8.
    print("Output:", stdout.decode(errors="replace"))
    print("Errors:", stderr.decode(errors="replace"))

    # Drop any pipeline cached by the test cell so it is reloaded from disk.
    pipeline = None

    # Do not convert or report success when training itself failed or was killed.
    if process.returncode != 0:
        current_status = f"Training failed (exit code {process.returncode}). See the cell output for details."
        yield gr.update(value=current_preview), gr.update(value=current_status)
        return

    current_status = "Converting model..."
    yield gr.update(value=current_preview), gr.update(value=current_status)

    # Run the converter with the kernel's own interpreter and an argument list,
    # so the packages installed for this kernel are used and the output path
    # needs no shell quoting.
    conversion = subprocess.run(
        [
            sys.executable, "convert_diffusers_to_original_stable_diffusion.py",
            "--model_path", output_dir,
            "--checkpoint_path", "model.safetensors",
            "--use_safetensors",
        ],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    )
    print(conversion.stdout.decode(errors="replace"))

    if conversion.returncode != 0:
        current_status = f"Model conversion failed (exit code {conversion.returncode}). See the cell output for details."
        yield gr.update(value=current_preview), gr.update(value=current_status)
        return

    current_status = "Training completed!"
    yield gr.update(value=current_preview), gr.update(value=current_status)


def ui():
    """Build the Gradio Blocks interface for DreamBooth training.

    The layout has a dataset path box, a column of training settings next to a
    preview image, a start button and a status box. Clicking the button runs
    ``run_training`` with the settings as positional inputs and streams its
    yielded ``(image, status)`` pairs into the preview image and status box.

    Returns:
        The constructed ``gr.Blocks`` object (not yet launched).
    """
    with gr.Blocks() as demo:
        gr.Markdown("# Stable Diffusion Dreambooth WebUI")
        gr.Markdown("Train Stable Diffusion model using preloaded weights.")

        with gr.Row():
            dataset_path = gr.Textbox(label="Dataset Path", value="dataset", interactive=True)

        with gr.Row():
            with gr.Column(scale=1, min_width=300):
                placeholder_token = gr.Textbox(label="Placeholder Token", placeholder="Enter placeholder token here", interactive=True)
                class_token = gr.Textbox(label="Class Token (not implemented yet)", placeholder="Not implemented yet", interactive=False)
                prompt = gr.Textbox(label="Preview Prompt", placeholder="Enter your prompt here", interactive=True)
                # precision=0 makes Gradio round these fields and hand them to the
                # handler as ints; they are forwarded as integer command-line options.
                num_training_steps = gr.Number(label="Number of Training Steps", value=500, precision=0, interactive=True)
                learning_rate = gr.Number(label="Learning Rate", value=0.000005, interactive=True)
                batch_size = gr.Number(label="Batch Size", value=4, precision=0, interactive=True)
                preview_save_steps = gr.Number(label="Preview Steps", value=25, precision=0, interactive=True)
                preview_seed = gr.Number(label="Preview Seed", value=1, precision=0, interactive=True)
            with gr.Column(scale=1, min_width=300):
                output_image = gr.Image(label="Generated Image")

        generate_button = gr.Button("Start Training")

        generate_status = gr.Textbox(value="Status messages will appear here.", label="Status", interactive=False)

        # Inputs are matched to run_training's parameters by position;
        # class_token fills the `initializer_token` slot.
        generate_button.click(
            fn=run_training,
            inputs=[dataset_path, prompt, placeholder_token, class_token, num_training_steps,
                    learning_rate, batch_size, preview_save_steps, preview_seed],
            outputs=[output_image, generate_status],
            show_progress=True,
            queue=True
        )

    return demo

# Build the interface and start the Gradio server. `demo` is kept as a
# notebook-level name because the final cell calls demo.close() on it.
demo = ui()
demo.launch()


In [None]:
# 5. Load the training logs in TensorBoard
# If you are using Windows, you can also open http://localhost:8888 in a browser.
# NOTE(review): 8888 is also the usual default port of a local Jupyter server;
# if that port is already taken, choose a different --port below -- TODO confirm.
# Enable auto-reload (every 30 seconds), open the Images tab and wait for previews.
%load_ext tensorboard
%tensorboard --logdir=log/ --host localhost --port 8888

In [None]:
# 6. Test the Dreambooth model
from diffusers import StableDiffusionPipeline
import torch

preview_prompt = "wpg, a illustration of wpg. cyan, reflective, flower, 8k, lineart, extremly detailed eyes, digital painting. masterpiece, best quality."
negative_prompt = "vibrant color, grain, pattern, disfigured, kitsch, ugly, oversaturated, grain, low-res, Deformed, blurry, bad anatomy, disfigured, poorly drawn face, mutation, mutated, extra limb, poorly drawn hands, missing limb, blurry, floating limbs, disconnected limbs, malformed hands, blur, out of focus, long neck, long body, disgusting, poorly drawn, childish, mutilated, mangled, old, surreal"

# The training handler resets `pipeline` to None after each run. On a fresh
# kernel (model.safetensors on disk, no training run in this session) the name
# does not exist yet, so look it up without raising NameError.
if globals().get("pipeline") is None:
  pipeline = StableDiffusionPipeline.from_single_file(
      "model.safetensors",
      torch_dtype=torch.float16,
  ).to("cuda")

# Prompts are generation-time arguments: the negative prompt is passed to the
# pipeline call rather than to the loader, and the preview prompt defined
# above is the one actually rendered.
output = pipeline(
    preview_prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=30,
    guidance_scale=7,
)

display(output.images[0])

In [None]:
# 7. Download your model.safetensors
# If you are using Colab, this cell mounts Google Drive and copies model.safetensors into it.
import os
import shutil

# Only a missing google.colab module means "not on Colab". Mount or copy
# failures are no longer swallowed, so a failed upload is visible to the user.
try:
    from google.colab import drive
except ImportError:
    drive = None
    print("Not running on Colab; skipping the Google Drive upload.")

if drive is not None:
    model_path = "/content/model.safetensors"
    if not os.path.exists(model_path):
        print(f"{model_path} not found. Train a model first, then run this cell again.")
    else:
        drive.mount('/content/drive')

        # Create a directory in Google Drive if it doesn't exist
        target_dir = "/content/drive/MyDrive/Dreambooth"
        if not os.path.exists(target_dir):
            os.makedirs(target_dir, exist_ok=True)
            print(f"Created directory: {target_dir}")
        else:
            print(f"Directory already exists: {target_dir}")

        # Copy the model to Drive (pure Python instead of a shell `cp`).
        shutil.copy(model_path, os.path.join(target_dir, "model.safetensors"))
        print(f"Your Dreambooth model has been uploaded to your Google Drive folder {target_dir}")

In [None]:
# Force terminate the WebUI and training process
demo.close()
# `process` is only created once training has been started from the WebUI, so
# look it up defensively instead of raising NameError, and only kill a
# subprocess that is still running.
training_process = globals().get("process")
if training_process is not None and training_process.poll() is None:
    training_process.kill()