[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/trainer/blob/main/realistic.ipynb)

In [None]:
# --- Environment setup (IPython/Colab cell: `!` runs a shell command, `%` is an IPython magic) ---
%cd /content

!apt -y update -qq
# Fetch a prebuilt tcmalloc and LD_PRELOAD it for processes started from this kernel afterwards.
!wget https://github.com/camenduru/gperftools/releases/download/v1.0/libtcmalloc_minimal.so.4 -O /content/libtcmalloc_minimal.so.4
%env LD_PRELOAD=/content/libtcmalloc_minimal.so.4
# Filter TensorFlow's C++ INFO-level log messages.
%env TF_CPP_MIN_LOG_LEVEL=1

# aria2 is the downloader used for every remaining fetch below.
!apt -y install -qq aria2
# Pin exact versions: torch family from the CUDA 11.8 wheel index, then the training/UI stack.
!pip install -q torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2+cu118 torchtext==0.15.2 torchdata==0.6.1 --extra-index-url https://download.pytorch.org/whl/cu118 -U
!pip install -q xformers==0.0.20 triton==2.0.0 diffusers==0.19.0 datasets==2.14.0 gradio==3.38.0 wandb==0.15.7 transformers==4.26.0 accelerate==0.16.0 bitsandbytes==0.41.0 -U

# Clone the trainer repo; it presumably provides the `shared` module imported further down (verify).
!git clone https://github.com/camenduru/trainer

# Fetch diffusers' conversion and DreamBooth/LoRA training scripts at the git tag that
# matches the pinned pip release (diffusers==0.19.0), into the cloned repo's folders.
# aria2c flags: -c resume, -x/-s up to 16 connections/splits, -k 1M minimum split size.
diffusers_version = "v0.19.0"
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M  https://raw.githubusercontent.com/huggingface/diffusers/{diffusers_version}/scripts/convert_diffusers_to_original_stable_diffusion.py -d /content/trainer/diffusers/dreambooth -o convert_diffusers_to_original_stable_diffusion.py
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M  https://raw.githubusercontent.com/huggingface/diffusers/{diffusers_version}/examples/dreambooth/train_dreambooth.py -d /content/trainer/diffusers/dreambooth -o train_dreambooth.py
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M  https://raw.githubusercontent.com/huggingface/diffusers/{diffusers_version}/examples/dreambooth/train_dreambooth_lora.py -d /content/trainer/diffusers/lora -o train_dreambooth_lora.py

# --- Download the base model, one file at a time, in diffusers folder layout ---
# Each component (vae, unet, tokenizer, text_encoder, scheduler, safety_checker,
# feature_extractor) goes into its own subfolder of BaseModelDir. The training
# command and the test panel both point at /content/model.
# `resolve/main` URLs are used for the .bin weight files and `raw/main` for the
# JSON/text files (presumably because the weights are stored via Git LFS — verify).
BaseModelUrl = "https://huggingface.co/uf/cyberrealistic_v3.2"
BaseModelDir = "/content/model"
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M {BaseModelUrl}/raw/main/model_index.json -d {BaseModelDir} -o model_index.json
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M {BaseModelUrl}/resolve/main/vae/diffusion_pytorch_model.bin -d {BaseModelDir}/vae -o diffusion_pytorch_model.bin
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M {BaseModelUrl}/raw/main/vae/config.json -d {BaseModelDir}/vae -o config.json
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M {BaseModelUrl}/resolve/main/unet/diffusion_pytorch_model.bin -d {BaseModelDir}/unet -o diffusion_pytorch_model.bin
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M {BaseModelUrl}/raw/main/unet/config.json -d {BaseModelDir}/unet -o config.json
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M {BaseModelUrl}/raw/main/tokenizer/vocab.json -d {BaseModelDir}/tokenizer -o vocab.json
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M {BaseModelUrl}/raw/main/tokenizer/tokenizer_config.json -d {BaseModelDir}/tokenizer -o tokenizer_config.json
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M {BaseModelUrl}/raw/main/tokenizer/special_tokens_map.json -d {BaseModelDir}/tokenizer -o special_tokens_map.json
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M {BaseModelUrl}/raw/main/tokenizer/merges.txt -d {BaseModelDir}/tokenizer -o merges.txt
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M {BaseModelUrl}/resolve/main/text_encoder/pytorch_model.bin -d {BaseModelDir}/text_encoder -o pytorch_model.bin
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M {BaseModelUrl}/raw/main/text_encoder/config.json -d {BaseModelDir}/text_encoder -o config.json
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M {BaseModelUrl}/raw/main/scheduler/scheduler_config.json -d {BaseModelDir}/scheduler -o scheduler_config.json
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M {BaseModelUrl}/resolve/main/safety_checker/pytorch_model.bin -d {BaseModelDir}/safety_checker -o pytorch_model.bin
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M {BaseModelUrl}/raw/main/safety_checker/config.json -d {BaseModelDir}/safety_checker -o config.json
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M {BaseModelUrl}/raw/main/feature_extractor/preprocessor_config.json -d {BaseModelDir}/feature_extractor -o preprocessor_config.json

# Work from inside the cloned repo (presumably so `from shared import Shared` below
# resolves against the working directory), then wipe the install/download logs
# from the cell output.
%cd /content/trainer
from IPython.display import clear_output 
clear_output()

import os, shutil
import shlex

import gradio as gr
from gradio import strings

from shared import Shared

# Root Gradio container; launch() fills it with the UI inside `with trainer:` and serves it.
trainer = gr.Blocks(title="Trainer")

# Default LoRA training command shown in the command textbox before the user clicks
# "Update train command". It invokes diffusers' train_dreambooth_lora.py (fetched
# above) on the images in /content/images and writes the LoRA to /content/lora.
# NOTE: update_instance_prompt() rebuilds the same flag list with user values;
# keep the two in sync when changing flags.
train_lora_command = f"""python -u /content/trainer/diffusers/lora/train_dreambooth_lora.py \\
                            --pretrained_model_name_or_path="/content/model"  \\
                            --instance_data_dir="/content/images" \\
                            --output_dir="/content/lora" \\
                            --learning_rate=5e-6 \\
                            --max_train_steps=1250 \\
                            --instance_prompt="Required" \\
                            --resolution=512 \\
                            --center_crop \\
                            --train_batch_size=1 \\
                            --gradient_accumulation_steps=1 \\
                            --max_grad_norm=1.0 \\
                            --mixed_precision="fp16" \\
                            --gradient_checkpointing \\
                            --enable_xformers_memory_efficient_attention \\
                            --use_8bit_adam \\
                            --train_text_encoder"""

def upload_file(files, dest_dir="/content/images"):
    """Replace the training-image directory with the freshly uploaded files.

    Args:
        files: Uploaded file objects from ``gr.Files``; only their ``.name``
            attribute (a local file path) is read. ``None`` is treated as an
            empty upload.
        dest_dir: Directory to recreate and copy the files into. The default is
            the ``--instance_data_dir`` used by the training command.

    Returns:
        The list of source paths that were copied.
    """
    # Start from an empty directory so images from a previous upload do not
    # leak into the next training run. This replaces the IPython-only
    # `!rm -rf`, which is not valid Python inside a function body.
    shutil.rmtree(dest_dir, ignore_errors=True)
    os.makedirs(dest_dir, exist_ok=True)
    file_paths = [file.name for file in (files or [])]
    for file_path in file_paths:
        # dest_dir exists, so shutil.copy keeps each file's basename inside it.
        shutil.copy(file_path, dest_dir)
    return file_paths

def update_instance_prompt(learning_rate, max_train_steps, instance_prompt):
    """Rebuild the LoRA training command from the values edited in the UI.

    Args:
        learning_rate: Value for ``--learning_rate`` (e.g. ``"5e-6"``).
        max_train_steps: Value for ``--max_train_steps`` (e.g. ``"1250"``).
        instance_prompt: Value for ``--instance_prompt``; may contain spaces,
            quotes or other shell metacharacters.

    Returns:
        A multi-line command string with backslash line continuations that
        invokes ``train_dreambooth_lora.py``; all other flags are fixed.
    """
    # Shell-quote every user-supplied value. Previously the prompt was wrapped
    # in bare double quotes, so a prompt containing `"`, `$` or backticks broke
    # or altered the command when it was run by a POSIX shell. Plain values
    # such as 1250 or 5e-6 come back from shlex.quote unchanged.
    learning_rate = shlex.quote(str(learning_rate))
    max_train_steps = shlex.quote(str(max_train_steps))
    instance_prompt = shlex.quote(str(instance_prompt))
    command = f"""python -u /content/trainer/diffusers/lora/train_dreambooth_lora.py \\
                            --pretrained_model_name_or_path="/content/model"  \\
                            --instance_data_dir="/content/images" \\
                            --output_dir="/content/lora" \\
                            --learning_rate={learning_rate} \\
                            --max_train_steps={max_train_steps} \\
                            --instance_prompt={instance_prompt} \\
                            --resolution=512 \\
                            --center_crop \\
                            --train_batch_size=1 \\
                            --gradient_accumulation_steps=1 \\
                            --max_grad_norm=1.0 \\
                            --mixed_precision="fp16" \\
                            --gradient_checkpointing \\
                            --enable_xformers_memory_efficient_attention \\
                            --use_8bit_adam \\
                            --train_text_encoder"""
    return command

def launch():
    """Populate the module-level ``trainer`` Blocks with the UI and serve it.

    Top panel:
        - Upload training images (handled by ``upload_file``).
        - Edit the learning rate, step count and instance prompt.
        - Regenerate the command with ``update_instance_prompt``.
        - Run the command text with ``Shared.run_live``; its output goes to a textbox.

    Bottom panel: generate a test image with ``Shared.test_lora``. Its inputs
    are the base-model directory, the LoRA output directory, a prompt and
    negative prompt, sampling steps, guidance scale, and the "Load Model"
    checkbox.

    The app is launched with a request queue and a public share link.
    """
    # Override Gradio's share-link banner text.
    strings.en["SHARE_LINK_MESSAGE"] = "😊"
    with trainer:
        with gr.Group():
            with gr.Row():
                with gr.Box():
                    files = gr.Files(label="Upload Images", file_types=["image"], file_count="multiple")
                    files.upload(fn=upload_file, inputs=files)
                with gr.Box():
                    # Defaults must match --learning_rate / --max_train_steps in the
                    # initial train_lora_command. They were previously swapped, so
                    # "Update train command" produced lr=1250 and steps=5e-06.
                    # They are strings so the box shows the exact CLI value
                    # (str(5e-6) would be "5e-06").
                    learning_rate = gr.Textbox(label="learning_rate", value="5e-6")
                    max_train_steps = gr.Textbox(label="max_train_steps", value="1250")
                    instance_prompt = gr.Textbox(label="instance_prompt", value="Required")
                    update_command = gr.Button(value="Update train command")
                    lora_command = gr.Textbox(show_label=False, lines=16, value=train_lora_command)
                    update_command.click(
                        fn=update_instance_prompt,
                        inputs=[learning_rate, max_train_steps, instance_prompt],
                        outputs=lora_command,
                    )
                    train_lora_out_text = gr.Textbox(show_label=False)
                    btn_train_lora_run_live = gr.Button("Train Lora")
                    # Passes the current contents of the command textbox to the runner.
                    btn_train_lora_run_live.click(
                        Shared.run_live,
                        inputs=lora_command,
                        outputs=train_lora_out_text,
                        show_progress=False,
                    )
        with gr.Group():
            with gr.Row():
                with gr.Box():
                    image = gr.Image(show_label=False)
                with gr.Box():
                    model_dir = gr.Textbox(label="Enter your model dir", show_label=False, max_lines=1, value="/content/model")
                    output_dir = gr.Textbox(label="Enter your output dir", show_label=False, max_lines=1, value="/content/lora")
                    prompt = gr.Textbox(label="prompt", show_label=False, max_lines=1, placeholder="Enter your prompt")
                    negative_prompt = gr.Textbox(label="negative prompt", show_label=False, max_lines=1, placeholder="Enter your negative prompt")
                    steps = gr.Slider(label="Steps", minimum=5, maximum=50, value=25, step=1)
                    scale = gr.Slider(label="Guidance Scale", minimum=0, maximum=50, value=7.5, step=0.1)
                    checkbox = gr.Checkbox(label="Load Model", value=True)
                    btn_test_lora = gr.Button("Generate image")
                    btn_test_lora.click(
                        Shared.test_lora,
                        inputs=[model_dir, checkbox, output_dir, prompt, negative_prompt, steps, scale],
                        outputs=image,
                    )
    trainer.queue().launch(debug=True, share=True, inline=False)

# Start the UI when this cell/script runs; in an IPython/Jupyter cell __name__ is "__main__".
if __name__ == "__main__":
    launch()