[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/trainer/blob/main/trainer.ipynb)

In [None]:
# Work from /content; the relative `git clone` later in this cell lands here.
%cd /content

# Refresh the apt package index before installing anything.
!apt -y update -qq
# Fetch a prebuilt libtcmalloc_minimal and export it through LD_PRELOAD.
# NOTE(review): presumably meant to swap the allocator for processes started
# after this point -- confirm it has the intended effect.
!wget https://github.com/camenduru/gperftools/releases/download/v1.0/libtcmalloc_minimal.so.4 -O /content/libtcmalloc_minimal.so.4
%env LD_PRELOAD=/content/libtcmalloc_minimal.so.4
# NOTE(review): presumably reduces TensorFlow C++ log verbosity -- confirm.
%env TF_CPP_MIN_LOG_LEVEL=1

# aria2 provides the `aria2c` downloader used for every download below.
!apt -y install -qq aria2
# PyTorch packages pinned to the +cu118 builds, resolved through the extra
# PyTorch wheel index.
!pip install -q torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2+cu118 torchtext==0.15.2 torchdata==0.6.1 --extra-index-url https://download.pytorch.org/whl/cu118 -U
# Training and UI libraries, all pinned to exact versions.
!pip install -q xformers==0.0.20 triton==2.0.0 diffusers==0.19.0 datasets==2.14.0 gradio==3.38.0 wandb==0.15.7 transformers==4.26.0 accelerate==0.16.0 bitsandbytes==0.41.0 -U

# Clone the trainer repository into the current directory.
!git clone https://github.com/camenduru/trainer

# Git tag of huggingface/diffusers to pull scripts from; IPython substitutes
# {diffusers_version} into the URLs below.
diffusers_version = "v0.19.0"
# Save the checkpoint-conversion script and the DreamBooth training script
# under trainer/diffusers/dreambooth, and the LoRA variant under
# trainer/diffusers/lora (-d is the target directory, -o the file name).
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M  https://raw.githubusercontent.com/huggingface/diffusers/{diffusers_version}/scripts/convert_diffusers_to_original_stable_diffusion.py -d /content/trainer/diffusers/dreambooth -o convert_diffusers_to_original_stable_diffusion.py
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M  https://raw.githubusercontent.com/huggingface/diffusers/{diffusers_version}/examples/dreambooth/train_dreambooth.py -d /content/trainer/diffusers/dreambooth -o train_dreambooth.py
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M  https://raw.githubusercontent.com/huggingface/diffusers/{diffusers_version}/examples/dreambooth/train_dreambooth_lora.py -d /content/trainer/diffusers/lora -o train_dreambooth_lora.py

# Hugging Face repository of the base model and the local directory that
# mirrors its layout; both are substituted into the commands below.
BaseModelUrl = "https://huggingface.co/uf/cyberrealistic_v3.2"
BaseModelDir = "/content/model"
# Download the model one file at a time, keeping the repository's
# sub-directory structure (vae, unet, tokenizer, text_encoder, scheduler,
# safety_checker, feature_extractor).
# JSON/text files are fetched via `raw/main`, the .bin weights via
# `resolve/main`.
# NOTE(review): presumably because the weights are stored through Git LFS --
# confirm.
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M {BaseModelUrl}/raw/main/model_index.json -d {BaseModelDir} -o model_index.json
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M {BaseModelUrl}/resolve/main/vae/diffusion_pytorch_model.bin -d {BaseModelDir}/vae -o diffusion_pytorch_model.bin
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M {BaseModelUrl}/raw/main/vae/config.json -d {BaseModelDir}/vae -o config.json
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M {BaseModelUrl}/resolve/main/unet/diffusion_pytorch_model.bin -d {BaseModelDir}/unet -o diffusion_pytorch_model.bin
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M {BaseModelUrl}/raw/main/unet/config.json -d {BaseModelDir}/unet -o config.json
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M {BaseModelUrl}/raw/main/tokenizer/vocab.json -d {BaseModelDir}/tokenizer -o vocab.json
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M {BaseModelUrl}/raw/main/tokenizer/tokenizer_config.json -d {BaseModelDir}/tokenizer -o tokenizer_config.json
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M {BaseModelUrl}/raw/main/tokenizer/special_tokens_map.json -d {BaseModelDir}/tokenizer -o special_tokens_map.json
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M {BaseModelUrl}/raw/main/tokenizer/merges.txt -d {BaseModelDir}/tokenizer -o merges.txt
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M {BaseModelUrl}/resolve/main/text_encoder/pytorch_model.bin -d {BaseModelDir}/text_encoder -o pytorch_model.bin
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M {BaseModelUrl}/raw/main/text_encoder/config.json -d {BaseModelDir}/text_encoder -o config.json
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M {BaseModelUrl}/raw/main/scheduler/scheduler_config.json -d {BaseModelDir}/scheduler -o scheduler_config.json
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M {BaseModelUrl}/resolve/main/safety_checker/pytorch_model.bin -d {BaseModelDir}/safety_checker -o pytorch_model.bin
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M {BaseModelUrl}/raw/main/safety_checker/config.json -d {BaseModelDir}/safety_checker -o config.json
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M {BaseModelUrl}/raw/main/feature_extractor/preprocessor_config.json -d {BaseModelDir}/feature_extractor -o preprocessor_config.json

# Switch into the cloned repository.
# NOTE(review): presumably so the `dreambooth` and `lora_d` imports below
# resolve from the repo -- confirm.
%cd /content/trainer

import os, shutil
import gradio as gr
from gradio import strings

from dreambooth import Dreambooth
from lora_d import LoraD

# Module-level Blocks app; launch() adds the tabs to it and serves it.
trainer = gr.Blocks(title="Trainer")

def upload_file(files, dest_dir='/content/images'):
    """Copy uploaded files into ``dest_dir`` and return their source paths.

    Args:
        files: Iterable of objects exposing a ``name`` attribute that holds a
            filesystem path (this is what the upload button hands to the
            callback). ``None`` or an empty iterable copies nothing.
        dest_dir: Directory that receives the copies. Defaults to
            ``/content/images``, the location that was previously hard-coded.

    Returns:
        list: The source paths of the given files, in input order.
    """
    file_paths = [file.name for file in (files or [])]
    # exist_ok=True removes the check-then-create race of exists() + mkdir()
    # and, unlike mkdir(), also creates any missing parent directories.
    os.makedirs(dest_dir, exist_ok=True)
    for file_path in file_paths:
        shutil.copy(file_path, dest_dir)
    return file_paths

def launch():
    """Populate the shared ``trainer`` Blocks app with its tabs and serve it.

    Adds an "Upload Images" tab wired to ``upload_file``, then the Dreambooth
    and LoraD tabs, and finally starts the queued app with a share link.
    """
    # NOTE(review): presumably the text gradio prints with the share link.
    strings.en["SHARE_LINK_MESSAGE"] = "😊"

    with trainer:
        with gr.Tab("Upload Images"):
            uploaded_listing = gr.File()
            picker = gr.UploadButton(
                "Upload Images",
                file_types=["image"],
                file_count="multiple",
            )
            # The button is both the event source and the callback's input;
            # the returned paths are shown in the File component.
            picker.upload(upload_file, picker, uploaded_listing)
        Dreambooth.tab()
        LoraD.tab()

    queued_app = trainer.queue()
    queued_app.launch(debug=True, share=True, inline=False)

# Start the UI only when this code runs as the main module.
if __name__ == "__main__":
    launch()