<a href="https://colab.research.google.com/github/Linaqruf/kohya-trainer/blob/main/kohya-trainer-XL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

[![visitor][visitor-badge]][visitor-stats]
[![ko-fi][ko-fi-badge]][ko-fi-link]

# **Kohya Trainer XL**
A Colab Notebook For Native Training

[visitor-badge]: https://api.visitorbadge.io/api/visitors?path=Kohya%20Trainer%20XL&label=Visitors&labelColor=%2334495E&countColor=%231ABC9C&style=flat&labelStyle=none
[visitor-stats]: https://visitorbadge.io/status?path=Kohya%20Trainer%20XL
[ko-fi-badge]: https://img.shields.io/badge/Support%20me%20on%20Ko--fi-F16061?logo=ko-fi&logoColor=white&style=flat
[ko-fi-link]: https://ko-fi.com/linaqruf


| Notebook Name | Description | Link |
| --- | --- | --- |
| [Kohya LoRA Trainer XL](https://github.com/Linaqruf/kohya-trainer/blob/main/kohya-LoRA-trainer-XL.ipynb) | LoRA Training | [![](https://img.shields.io/static/v1?message=Open%20in%20Colab&logo=googlecolab&labelColor=5c5c5c&color=0f80c1&label=%20&style=flat)](https://colab.research.google.com/github/Linaqruf/kohya-trainer/blob/main/kohya-LoRA-trainer-XL.ipynb) |
| [Kohya Trainer XL](https://github.com/Linaqruf/kohya-trainer/blob/main/kohya-trainer-XL.ipynb) | Native Training | [![](https://img.shields.io/static/v1?message=Open%20in%20Colab&logo=googlecolab&labelColor=5c5c5c&color=0f80c1&label=%20&style=flat)](https://colab.research.google.com/github/Linaqruf/kohya-trainer/blob/main/kohya-trainer-XL.ipynb) |


# **I. Prepare Environment**

In [None]:
# Make sure the GPU is A100
# Prints the `nvidia-smi` report for the attached runtime so the GPU model can be checked by eye.
!nvidia-smi

In [None]:
# @title ## **1.1. Install Kohya Trainer**
import os
import zipfile
import shutil
import time
import requests
import torch
from subprocess import getoutput
from IPython.utils import capture
from google.colab import drive

# Reload variables that earlier cells saved with the IPython `%store` magic.
%store -r

# root_dir
# Every working folder below is derived from `root_dir`.
root_dir          = "/content"
drive_dir         = os.path.join(root_dir, "drive/MyDrive")
deps_dir          = os.path.join(root_dir, "deps")
repo_dir          = os.path.join(root_dir, "kohya-trainer")
training_dir      = os.path.join(root_dir, "fine_tune")
pretrained_model  = os.path.join(root_dir, "pretrained_model")
vae_dir           = os.path.join(root_dir, "vae")
lora_dir          = os.path.join(root_dir, "network_weight")
repositories_dir  = os.path.join(root_dir, "repositories")
config_dir        = os.path.join(training_dir, "config")
tools_dir         = os.path.join(repo_dir, "tools")
finetune_dir      = os.path.join(repo_dir, "finetune")
accelerate_config = os.path.join(repo_dir, "accelerate_config/config.yaml")

# Save the listed variables with `%store` so later cells can restore them via `%store -r`.
# The magic's output is captured and discarded to keep the cell quiet.
# NOTE(review): drive_dir, deps_dir and lora_dir are not in this list, so they are not stored.
for store in ["root_dir", "repo_dir", "training_dir", "pretrained_model", "vae_dir", "repositories_dir", "accelerate_config", "tools_dir", "finetune_dir", "config_dir"]:
    with capture.capture_output() as cap:
        %store {store}
        del cap

# Maps the label shown in the Colab form to the git URL that gets cloned.
repo_dict = {
    "qaneel/kohya-trainer (forked repo, stable, optimized for colab use)" : "https://github.com/qaneel/kohya-trainer",
    "kohya-ss/sd-scripts (original repo, latest update)"                    : "https://github.com/kohya-ss/sd-scripts",
}

# NOTE(review): the form allows free text input, but `repo_dict[repository]` raises KeyError
# for any value that is not one of the two labels above.
repository        = "qaneel/kohya-trainer (forked repo, stable, optimized for colab use)" #@param ["qaneel/kohya-trainer (forked repo, stable, optimized for colab use)", "kohya-ss/sd-scripts (original repo, latest update)"] {allow-input: true}
repo_url          = repo_dict[repository]
branch            = "main"  # @param {type: "string"}
# When True, `mount_drive` mounts Google Drive and places the output folder there.
output_to_drive   = True  # @param {type: "boolean"}

def clone_repo(url, dir, branch):
    """Clone `branch` of the git repository at `url` into `dir`.

    Nothing is done when `dir` already exists, so an existing checkout is
    neither updated nor switched to `branch`.
    """
    if not os.path.exists(dir):
       !git clone -b {branch} {url} {dir}

def mount_drive(dir):
    """Return the folder that training output should be written to.

    When the module-level `output_to_drive` flag is off, the folder is
    `<training_dir>/output`. When it is on, Google Drive is mounted first
    (only if `drive_dir` does not exist yet) and the folder is
    `<drive_dir>/kohya-trainer/output`.

    The `dir` argument is accepted but not used.
    """
    if not output_to_drive:
        return os.path.join(training_dir, "output")

    drive_is_mounted = os.path.exists(drive_dir)
    if not drive_is_mounted:
        # The mount point is the parent of "MyDrive".
        drive.mount(os.path.dirname(drive_dir))

    return os.path.join(drive_dir, "kohya-trainer/output")

def setup_directories():
    """Resolve the output folder and create every working folder.

    Sets the module-level `output_dir` from `mount_drive`, then creates the
    training, config, model, VAE, repositories and output folders. Folders
    that already exist are left as they are.
    """
    global output_dir

    output_dir = mount_drive(drive_dir)

    required_folders = (
        training_dir,
        config_dir,
        pretrained_model,
        vae_dir,
        repositories_dir,
        output_dir,
    )
    for folder in required_folders:
        os.makedirs(folder, exist_ok=True)

def pastebin_reader(id):
    """Download a Pastebin paste and return its text as a list of lines.

    Args:
        id: Either a bare paste id (e.g. "kq6ZmHFU") or a full pastebin.com
            URL. Non-raw URLs are rewritten to the raw endpoint.

    Returns:
        The paste body split on line-feed characters.

    Raises:
        requests.RequestException: on connection problems, a timeout, or a
            non-success HTTP status.
    """
    if "pastebin.com" in id:
        url = id
        # Look for the raw endpoint itself; a bare "raw" substring can also
        # occur inside a paste id, which would skip the rewrite by mistake.
        if "pastebin.com/raw" not in url:
            url = url.replace("pastebin.com", "pastebin.com/raw")
    else:
        url = "https://pastebin.com/raw/" + id

    # A timeout keeps the install cell from hanging on an unresponsive server.
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    return response.text.split('\n')

def install_repository():
    """Clone the two helper repositories and install their requirements.

    Publishes `infinite_image_browser_dir`, `voldy` and
    `discordia_archivum_dir` as module-level names for use by later cells.
    """
    global infinite_image_browser_dir, voldy, discordia_archivum_dir

    # Only the second line of the paste is kept; it becomes part of the repository name below.
    # NOTE(review): the clone URL therefore depends on remote paste content fetched at run time.
    _, voldy = pastebin_reader("kq6ZmHFU")[:2]

    infinite_image_browser_url  = f"https://github.com/zanllp/{voldy}-infinite-image-browsing.git"
    infinite_image_browser_dir  = os.path.join(repositories_dir, f"infinite-image-browsing")
    infinite_image_browser_deps = os.path.join(infinite_image_browser_dir, "requirements.txt")

    discordia_archivum_url = "https://github.com/Linaqruf/discordia-archivum"
    discordia_archivum_dir = os.path.join(repositories_dir, "discordia-archivum")
    discordia_archivum_deps = os.path.join(discordia_archivum_dir, "requirements.txt")

    # clone_repo skips a repository whose target folder already exists.
    clone_repo(infinite_image_browser_url, infinite_image_browser_dir, "main")
    clone_repo(discordia_archivum_url, discordia_archivum_dir, "main")

    # Install each repository's requirements, plus python-dotenv.
    !pip install -q --upgrade -r {infinite_image_browser_deps}
    !pip install python-dotenv
    !pip install -q --upgrade -r {discordia_archivum_deps}

def install_dependencies():
    """Install aria2, the trainer requirements and xformers, then ensure an accelerate config exists."""
    requirements_file = os.path.join(repo_dir, "requirements.txt")
    # NOTE(review): model_util, gpu_info and t4_xformers_wheel are assigned but never read in this function.
    model_util        = os.path.join(repo_dir, "library/model_util.py")
    gpu_info          = getoutput('nvidia-smi')
    t4_xformers_wheel = "https://github.com/Linaqruf/colab-xformers/releases/download/0.0.20/xformers-0.0.20+1d635e1.d20230519-cp310-cp310-linux_x86_64.whl"

    !apt install aria2 -yqq
    !pip install -q --upgrade -r {requirements_file}

    # xformers is pinned to one specific version.
    !pip install -q xformers==0.0.22.post7

    # Imported here because accelerate is only expected to be available after the installs above.
    from accelerate.utils import write_basic_config

    # Write a default accelerate config only when none exists, so an existing file is not overwritten.
    if not os.path.exists(accelerate_config):
        write_basic_config(save_location=accelerate_config)

def prepare_environment():
    """Set the process environment variables used for the training session.

    Sets TF_CPP_MIN_LOG_LEVEL to "3", SAFETENSORS_FAST_GPU to "1" and
    PYTHONWARNINGS to "ignore".
    """
    session_settings = {
        "TF_CPP_MIN_LOG_LEVEL": "3",
        "SAFETENSORS_FAST_GPU": "1",
        "PYTHONWARNINGS": "ignore",
    }
    os.environ.update(session_settings)

def main():
    """Run the full environment setup for this cell, in order.

    The order matters: the trainer repository is cloned before changing into
    it, and the working folders are created before the helper repositories
    are cloned into them.
    """
    os.chdir(root_dir)
    clone_repo(repo_url, repo_dir, branch)
    os.chdir(repo_dir)
    setup_directories()
    install_repository()
    install_dependencies()
    prepare_environment()

# Executed as soon as the cell runs.
main()

In [None]:
# @title ## **1.2. Download SDXL**
import os
import re
import json
import glob
import gdown
import requests
import subprocess
from IPython.utils import capture
from urllib.parse import urlparse, unquote
from pathlib import Path

# Reload the folder variables stored by the install cell.
%store -r

os.chdir(root_dir)

# @markdown Place your Huggingface [Read Token](https://huggingface.co/settings/tokens) Here. Get your SDXL access [here](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9).

# The token is sent as a Bearer header by `aria2_download` for huggingface.co URLs.
HUGGINGFACE_TOKEN = ""#@param {type: "string"}
# Each value may be a direct download URL, a Google Drive link, or a path under /content/drive/MyDrive/ (see `download`).
SDXL_MODEL_URL = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9/resolve/main/sd_xl_base_0.9.safetensors" #@param {type: "string"}
SDXL_VAE_URL = "https://huggingface.co/stabilityai/sdxl-vae/resolve/main/sdxl_vae.safetensors" #@param {type: "string"}

def get_supported_extensions():
    """Return the file suffixes accepted as model or VAE weight files."""
    return (".ckpt", ".safetensors", ".pt", ".pth")

def get_filename(url, quiet=True):
    """Work out the weight-file name that `url` refers to.

    Paths under `drive_dir` and URLs that already end in a supported suffix
    are resolved from the text alone. Anything else is requested over HTTP
    and the name is taken from the Content-Disposition header, falling back
    to the last segment of the URL path.

    Args:
        url: Download URL or local path.
        quiet: Accepted for call compatibility; not used.

    Returns:
        The file name when it ends in a supported suffix, otherwise None.

    Raises:
        requests.RequestException: when the probe request fails or times out.
    """
    extensions = tuple(get_supported_extensions())

    if url.startswith(drive_dir) or url.endswith(extensions):
        filename = os.path.basename(url)
    else:
        # stream=True avoids downloading the body just to read the headers;
        # the context manager releases the connection afterwards.
        with requests.get(url, stream=True, timeout=30) as response:
            response.raise_for_status()
            content_disposition = response.headers.get('content-disposition', '')

        # A header without a filename parameter used to raise IndexError here;
        # fall back to the URL path instead.
        matches = re.findall('filename="?([^"]+)"?', content_disposition)
        if matches:
            filename = matches[0]
        else:
            url_path = urlparse(url).path
            filename = unquote(os.path.basename(url_path))

    if filename.endswith(extensions):
        return filename
    return None

def get_most_recent_file(directory):
    """Return the path of the most recently modified entry in `directory`.

    Args:
        directory: Folder to inspect (only its direct entries are considered).

    Returns:
        The full path of the entry with the latest modification time, or None
        when the folder has no entries matching "*".
    """
    files = glob.glob(os.path.join(directory, "*"))
    if not files:
        return None
    return max(files, key=os.path.getmtime)

def parse_args(config):
    """Turn a config mapping into a list of command-line arguments.

    Rules, applied per entry in mapping order:
      * key starting with "_": the value is emitted on its own (positional);
      * bool: emits `--key` when True, nothing when False;
      * str, int or float: emits `--key=value`;
      * anything else (e.g. None): skipped.

    Args:
        config: Mapping of option name to value.

    Returns:
        A list of argument strings suitable for `subprocess.run`.
    """
    args = []

    for k, v in config.items():
        if k.startswith("_"):
            args.append(f"{v}")
        elif isinstance(v, bool):
            # Checked before the numeric case because bool is a subclass of int.
            if v:
                args.append(f"--{k}")
        elif isinstance(v, (str, int, float)):
            args.append(f"--{k}={v}")

    return args

def aria2_download(dir, filename, url):
    """Download `url` into `dir` with aria2c.

    Args:
        dir: Destination folder.
        filename: Output file name; when None the `out` option is omitted
            (parse_args skips None values).
        url: Source URL.

    The Hugging Face token is attached as a Bearer header only when a token
    has been entered and the URL points at huggingface.co. An empty
    "Bearer " value is not a valid credential, so it is no longer sent.
    """
    user_header = None
    if HUGGINGFACE_TOKEN and "huggingface.co" in url:
        user_header = f"Authorization: Bearer {HUGGINGFACE_TOKEN}"

    aria2_config = {
        "console-log-level"         : "error",
        "summary-interval"          : 10,
        "header"                    : user_header,
        "continue"                  : True,
        "max-connection-per-server" : 16,
        "min-split-size"            : "1M",
        "split"                     : 16,
        "dir"                       : dir,
        "out"                       : filename,
        "_url"                      : url,
    }
    aria2_args = parse_args(aria2_config)
    # Arguments are passed as a list, so no shell quoting is involved.
    subprocess.run(["aria2c", *aria2_args])

def gdown_download(url, dst, filepath):
    """Download a Google Drive file or folder with gdown.

    Args:
        url: Google Drive link.
        dst: Destination folder; used for folder links only.
        filepath: Output path for single-file links.

    Returns:
        Whatever gdown returns for the download, or None when the link
        matches none of the recognized forms.
    """
    # "/uc?id" covers the usual "uc?id=<file id>" links as well as the
    # "/uc?id/" spelling that was the only form matched before.
    if "/uc?id" in url:
        return gdown.download(url, filepath, quiet=False)
    elif "/file/d/" in url:
        return gdown.download(url=url, output=filepath, quiet=False, fuzzy=True)
    elif "/drive/folders/" in url:
        # download_folder is run from inside the destination folder.
        os.chdir(dst)
        return gdown.download_folder(url, quiet=True, use_cookies=False)
    return None

def download(url, dst):
    """Fetch `url` into the folder `dst` using the matching downloader.

    Google Drive links go through gdown, paths under
    /content/drive/MyDrive/ are used in place, and everything else is
    downloaded with aria2c.

    Args:
        url: Download URL, Google Drive link, or mounted-Drive path.
        dst: Destination folder.

    Returns:
        `url` itself for a mounted-Drive path (nothing is copied); None in
        every other case.
    """
    filename = get_filename(url, quiet=False)
    # get_filename returns None when no supported name can be determined;
    # os.path.join(dst, None) would raise TypeError. A path ending in a
    # separator asks gdown to choose the name inside `dst`.
    filepath = os.path.join(dst, filename) if filename else os.path.join(dst, "")

    if "drive.google.com" in url:
        gdown_download(url, dst, filepath)
    elif url.startswith("/content/drive/MyDrive/"):
        return url
    else:
        # Hugging Face "blob" pages are HTML; "resolve" serves the file itself.
        if "huggingface.co" in url and "/blob/" in url:
            url = url.replace("/blob/", "/resolve/")
        # A None filename makes aria2c pick the name (the option is omitted).
        aria2_download(dst, filename, url)

    return None

def get_filepath(url, dst):
    """Return the local path of the file downloaded from `url` into `dst`.

    When the name cannot be derived from the URL, the most recently modified
    entry in `dst` is assumed to be the download.

    Args:
        url: The URL or path that was passed to `download`.
        dst: The folder the file was downloaded into.

    Returns:
        The full path, or None when the name is unknown and `dst` is empty.
    """
    extensions = get_supported_extensions()
    filename = get_filename(url)

    # get_filename returns None for unsupported or unknown names; calling
    # .endswith on that raised AttributeError and made this fallback unreachable.
    if filename is None or not filename.endswith(extensions):
        most_recent_file = get_most_recent_file(dst)
        if most_recent_file is None:
            return None
        filename = os.path.basename(most_recent_file)

    return os.path.join(dst, filename)

def main():
    """Download the configured model and VAE and report where they are.

    Publishes the resulting paths as the module-level names `model_path` and
    `vae_path`; each stays None when its URL is empty or still holds the
    placeholder text.
    """
    global model_path, vae_path

    model_path = None
    vae_path = None

    sources = (
        ("model", SDXL_MODEL_URL, pretrained_model),
        ("vae", SDXL_VAE_URL, vae_dir),
    )

    for kind, link, folder in sources:
        placeholder = f"PASTE {kind.upper()} URL OR GDRIVE PATH HERE"
        if not link or placeholder in link:
            continue

        # `download` returns a path only for files used directly from Drive.
        direct_path = download(link, folder)
        located_path = get_filepath(link, folder)
        chosen_path = direct_path if direct_path else located_path

        if kind == "model":
            model_path = chosen_path
        else:
            vae_path = chosen_path

    for category, path in (("model", model_path), ("vae", vae_path)):
        if path is None:
            continue
        if os.path.exists(path):
            print(f"Selected {category}: {path}")

# Start the model/VAE download as soon as this cell is executed.
main()

In [None]:
# @title ## **1.3. Directory Config**
# @markdown Specify the location of your training data in the following cell. A folder with the same name as your input will be created.
import os

# Reload variables stored by earlier cells.
%store -r

train_data_dir = "/content/fine_tune/train_data"  # @param {'type' : 'string'}
# Persist the choice so later cells can restore it with `%store -r`.
%store train_data_dir

# Create the folder up front; an existing folder is left untouched.
os.makedirs(train_data_dir, exist_ok=True)
print(f"Your train data directory : {train_data_dir}")

In [None]:
# @title ## **1.4. Image Browser**
import os
import json
import random
import portpicker
from IPython.utils import capture
from IPython.display import clear_output
from threading import Thread
from imjoy_elfinder.app import main
from google.colab.output import serve_kernel_port_as_iframe, serve_kernel_port_as_window

# Reload variables stored by earlier cells (e.g. train_data_dir).
%store -r

# @markdown This cell allows you to view and manage your images in real-time. You can use it to:
# @markdown - Prepare your dataset before training
# @markdown - Monitor the sample outputs during training.

root_dir      = "/content"
browser_type  = "sd-webui-infinite-image-browsing" #@param ["imjoy-elfinder", "sd-webui-infinite-image-browsing"]
window_height = 550 #@param {type:"slider", min:0, max:1000, step:1}

# NOTE(review): infinite_image_browser_dir is set as a global by install_repository and is not in
# the `%store` list, so this cell presumably needs the install cell to have run in the same session.
main_app          = os.path.join(infinite_image_browser_dir, "app.py")
config_file       = os.path.join(infinite_image_browser_dir, "config.json")
# Free local port that the selected browser app is started on.
port              = portpicker.pick_unused_port()

# Written to `config_file` by `launch` and passed to the app via --sd_webui_config.
config = {
    "outdir_txt2img_samples": train_data_dir,
}

def write_file(filename, config):
    """Serialize `config` as JSON (4-space indent) into `filename`, replacing any existing content."""
    with open(filename, mode="w") as json_out:
        json.dump(config, json_out, indent=4)

def run_app():
    """Start the image browser app on `port` with `config_file`; all of its output is discarded."""
    !python {main_app} --port={port} --sd_webui_config={config_file} > /dev/null 2>&1

def launch():
    """Start the selected file browser in a background thread and embed it in the cell.

    For "sd-webui-infinite-image-browsing" the working directory becomes
    `train_data_dir` and the app config file is written first; any other
    selection starts the elfinder `main` entry point rooted at `root_dir`.
    """
    os.chdir(root_dir)

    if browser_type == "sd-webui-infinite-image-browsing":
        os.chdir(train_data_dir)
        write_file(config_file, config)
        worker = Thread(target=run_app)
    else:
        elfinder_args = [
            f"--root-dir={root_dir}",
            f"--port={port}",
            "--thumbnail",
        ]
        worker = Thread(target=main, args=[elfinder_args])

    worker.start()

    serve_kernel_port_as_iframe(
        port,
        width='100%',
        height=window_height,
        cache_in_notebook=False,
    )

    clear_output(wait=True)

# Start the browser as soon as this cell is executed.
launch()

# **II. Data Gathering**

You have three options for collecting your dataset:

1. Upload it to Colab's local files.
2. Use the scraper in section 2.2 (`Scrape Dataset`) to download images in bulk from Danbooru, Gelbooru, or Safebooru.
3. Locate your dataset in Google Drive.


In [None]:
# @title ## **2.1. Unzip Dataset**

import os
import shutil
from pathlib import Path

#@title ## Unzip Dataset
# @markdown If your dataset is in a `zip` file and has been uploaded to a location, use this section to extract it. The dataset will be downloaded and automatically extracted to `train_data_dir` if `unzip_to` is empty.
zipfile_url  = "https://huggingface.co/datasets/Linaqruf/hitokomoru-lora-dataset/resolve/main/hitokomoru_dataset.zip" #@param {type:"string"}
# Name the archive is saved under inside `root_dir` (see `download_dataset`).
zipfile_name = "zipfile.zip"
unzip_to     = "" #@param {type:"string"}

# NOTE(review): a literal access token is committed to the notebook here. It should be revoked and
# supplied by the user instead (e.g. through a form field); `download_dataset` sends it on every
# huggingface.co download.
hf_token     = "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"
# The header carries its own double quotes because it is interpolated into a shell command line.
user_header  = f'"Authorization: Bearer {hf_token}"'

# An explicit target folder is created; an empty value falls back to the training data folder.
if unzip_to:
    os.makedirs(unzip_to, exist_ok=True)
else:
    unzip_to = train_data_dir

def download_dataset(url):
    """Obtain the dataset archive for `url` and return the local path of the zip file.

    A value starting with "/content" is treated as an existing local file and
    returned unchanged. Google Drive links are fetched with gdown and all
    other URLs with aria2c, into `root_dir`.
    """
    if url.startswith("/content"):
        return url
    elif "drive.google.com" in url:
        os.chdir(root_dir)
        !gdown --fuzzy {url}
        # NOTE(review): no output name is passed to gdown, yet the returned path assumes the
        # file was saved as `zipfile_name` — confirm.
        return f"{root_dir}/{zipfile_name}"
    elif "huggingface.co" in url:
        # Rewrite a "blob" page link to the "resolve" link before downloading.
        if "/blob/" in url:
            url = url.replace("/blob/", "/resolve/")
        !aria2c --console-log-level=error --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 -d {root_dir} -o {zipfile_name} {url}
        return f"{root_dir}/{zipfile_name}"
    else:
        # Same aria2c invocation as above, without the authorization header.
        !aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {root_dir} -o {zipfile_name} {url}
        return f"{root_dir}/{zipfile_name}"

def extract_dataset(zip_file, output_path):
    """Extract `zip_file` into `output_path` with the `unzip` command.

    The command runs with `-j` and `-o`; per unzip's documented options these
    drop the archive's folder structure and overwrite existing files without
    prompting.
    """
    !unzip -j -o {zip_file} -d "{output_path}"

def remove_files(train_dir, files_to_move):
    """Move the listed metadata files out of `train_dir` into `training_dir`.

    A file is moved only when `training_dir` does not already hold a file of
    the same name; otherwise the copy in `train_dir` is deleted so that no
    metadata file stays next to the images.

    Args:
        train_dir: Folder to clean up.
        files_to_move: Collection of file names to relocate.
    """
    for filename in os.listdir(train_dir):
        if filename not in files_to_move:
            continue

        file_path = os.path.join(train_dir, filename)
        destination = os.path.join(training_dir, filename)

        # The check must look at the destination: the source path comes from
        # os.listdir and always exists, which made the move branch unreachable.
        if not os.path.exists(destination):
            shutil.move(file_path, destination)
        else:
            os.remove(file_path)

# Fetch the archive, extract it, then delete the archive.
zip_file = download_dataset(zipfile_url)
extract_dataset(zip_file, unzip_to)
# NOTE(review): when `zipfile_url` is a local /content path, this deletes the user's own zip file.
os.remove(zip_file)

# Metadata file names that `remove_files` handles inside the training data folder.
files_to_move = (
    "meta_cap.json",
    "meta_cap_dd.json",
    "meta_lat.json",
    "meta_clean.json",
)

# NOTE(review): this always cleans `train_data_dir`, even when the archive was extracted to a
# different `unzip_to` folder.
remove_files(train_data_dir, files_to_move)

In [None]:
#@title ## **2.2. Scrape Dataset**
import os
import html
from IPython.utils import capture
# Reload variables stored by earlier cells.
%store -r

os.chdir(root_dir)
#@markdown Use `gallery-dl` to scrape images from an imageboard site. To specify `prompt(s)`, separate them with commas (e.g., `hito_komoru, touhou`).
booru = "Danbooru" #@param ["Danbooru", "Gelbooru", "Safebooru"]
prompt = "hitokomoru" #@param {type: "string"}

#@markdown Alternatively, you can provide a `custom_url` instead of using a predefined site.
custom_url = "" #@param {type: "string"}

#@markdown Use the `sub_folder` option to organize the downloaded images into separate folders based on their concept or category.
sub_folder = "" #@param {type: "string"}

user_agent = "gdl/1.24.5"

#@markdown You can limit the number of images to download by using the `--range` option followed by the desired range (e.g., `1-200`).
# NOTE(review): this rebinds the builtin name `range` to a string for the rest of the session;
# later cells that call range(...) will fail after this cell has run.
range = "1-200" #@param {type: "string"}

write_tags = False #@param {type: "boolean"}

additional_arguments = "--filename /O --no-part"

# Comma-separated prompts become a single "+"-joined tag query.
tags = prompt.split(',')
tags = '+'.join(tags)

# Spaces are removed and "(", ")" and ":" are percent-encoded for use in the query string.
replacement_dict = {" ": "", "(": "%28", ")": "%29", ":": "%3a"}
tags = ''.join(replacement_dict.get(c, c) for c in tags)

# Resolve the download folder: the training data folder, an absolute /content path, or a subfolder.
# Only the subfolder case creates the folder here.
if sub_folder == "":
    image_dir = train_data_dir
elif sub_folder.startswith("/content"):
    image_dir = sub_folder
else:
    image_dir = os.path.join(train_data_dir, sub_folder)
    os.makedirs(image_dir, exist_ok=True)

# Any selection other than Danbooru or Gelbooru uses the Safebooru URL.
if booru == "Danbooru":
    url = "https://danbooru.donmai.us/posts?tags={}".format(tags)
elif booru == "Gelbooru":
    url = "https://gelbooru.com/index.php?page=post&s=list&tags={}".format(tags)
else:
    url = "https://safebooru.org/index.php?page=post&s=list&tags={}".format(tags)

# A custom URL, when given, takes precedence over the generated one.
valid_url = custom_url if custom_url else url

def scrape(config):
    """Render a config mapping as one command-line string.

    Per entry, in mapping order:
      * key starting with "_": the value alone, in double quotes;
      * str: `--key="value"`;
      * bool: `--key` when True, nothing when False;
      * int or float: `--key=value`;
      * anything else (e.g. None): skipped.

    Every emitted fragment is followed by a single space.
    """
    fragments = []

    for option, setting in config.items():
        if option.startswith("_"):
            fragments.append(f'"{setting}" ')
            continue
        if isinstance(setting, str):
            fragments.append(f'--{option}="{setting}" ')
            continue
        if isinstance(setting, bool):
            # Handled before numbers because bool is a subclass of int.
            if setting:
                fragments.append(f"--{option} ")
            continue
        if isinstance(setting, (float, int)):
            fragments.append(f"--{option}={setting} ")

    return "".join(fragments)

def pre_process_tags(directory):
    """Normalize every `.txt` tag file under `directory`, recursively.

    For each tag file:
      * the name is shortened by removing the inner extension
        (e.g. `123.jpg.txt` becomes `123.txt`);
      * HTML entities are unescaped and underscores become spaces;
      * the one-tag-per-line content is rewritten as a single
        comma-separated line.

    Blank lines are dropped, so a trailing newline no longer produces a
    dangling ", " at the end of the caption.

    Args:
        directory: Folder to process; subfolders are processed as well.
    """
    for item in os.listdir(directory):
        item_path = os.path.join(directory, item)

        if os.path.isdir(item_path):
            pre_process_tags(item_path)
            continue
        if not (os.path.isfile(item_path) and item.endswith(".txt")):
            continue

        # Strip two extensions, then put ".txt" back.
        stem = os.path.splitext(os.path.splitext(item)[0])[0]
        new_path = os.path.join(directory, stem + ".txt")
        os.rename(item_path, new_path)

        with open(new_path, "r", encoding="utf-8") as f:
            contents = f.read()

        contents = html.unescape(contents).replace("_", " ")
        tags = [line.strip() for line in contents.splitlines() if line.strip()]

        with open(new_path, "w", encoding="utf-8") as f:
            f.write(", ".join(tags))

# Options for the URL-listing run of gallery-dl ("get-urls"), used when tags are not written.
get_url_config = {
    "_valid_url" : valid_url,
    "get-urls" : True,
    "range" : range if range else None,
    "user-agent" : user_agent
}

# Options for the direct-download run of gallery-dl, used when tags are written.
scrape_config = {
    "_valid_url" : valid_url,
    "directory" : image_dir,
    "write-tags" : write_tags,
    "range" : range if range else None,
    "user-agent" : user_agent
}

get_url_args = scrape(get_url_config)
scrape_args = scrape(scrape_config)
# File that receives the URL list printed by gallery-dl.
scraper_text = os.path.join(root_dir, "scrape_this.txt")

if write_tags:
    # gallery-dl downloads the images together with tag files, which are then normalized.
    !gallery-dl {scrape_args} {additional_arguments}
    # NOTE(review): tags are post-processed under `train_data_dir`, not `image_dir`; an absolute
    # `sub_folder` outside the training data folder is therefore not processed — confirm intent.
    pre_process_tags(train_data_dir)
else:
    # Capture the printed URL list, save it to a file, and let aria2c download every entry.
    with capture.capture_output() as cap:
        !gallery-dl {get_url_args} {additional_arguments}
    with open(scraper_text, "w") as f:
        f.write(cap.stdout)

    os.chdir(image_dir)
    !aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -i {scraper_text}


In [None]:
#@title ## **2.3. Journey Scraper**
import os
# Reload variables stored by earlier cells.
%store -r

# @markdown Enter your Discord token below.
token = "" #@param {type: "string"}
channel_id = "1022054094476673085" #@param {type: "string"}
# @markdown Which bot do you want to scrape? This code is optimized to only scrape non-grid images from the Journey bot, so don't worry about cropping.
bot = "niji" #@param ["niji", "mid"]
# @markdown Set the limit of messages to scrape here. (This does not limit the number of messages to download.)
limit = 10000 #@param {type: "number"}
# @markdown To specify the `include_word` and `undesired_word`, separate them with commas (e.g., hito_komoru, touhou). By default, it scrapes the newest Niji model.
include_word = "" #@param {type:"string"}
undesired_word = "--style, --niji 4" #@param {type:"string"}
# NOTE(review): this value is not read anywhere in this cell.
download_attachments = "single"

def scrape(config):
    """Build the command-line string for the scraper from a config mapping.

    Entries are rendered in mapping order, each followed by one space:
    keys starting with "_" give the quoted value alone, strings give
    `--key="value"`, True gives `--key`, numbers give `--key=value`.
    False, None and other value types produce nothing.
    """
    def render(name, value):
        # One entry -> its command-line fragment ("" when it is skipped).
        if name.startswith("_"):
            return f'"{value}" '
        if isinstance(value, str):
            return f'--{name}="{value}" '
        if isinstance(value, bool):
            return f"--{name} " if value else ""
        if isinstance(value, (int, float)):
            return f"--{name}={value} "
        return ""

    return "".join(render(name, value) for name, value in config.items())

# Options handed to the scraper's main.py; `scrape` drops entries that are False or None.
scrape_config = {
    "token": token,
    "channel_id": channel_id,
    "nijijourney": True if bot == "niji" else False,
    "midjourney": True if bot == "mid" else False,
    "limit": limit if limit else None,
    "prompt": include_word,
    "single": True,
    "undesired_word": undesired_word,
    "download_attachments": True,
    "output_folder": train_data_dir,

}

scrape_args = scrape(scrape_config)

# NOTE(review): the Discord token is passed as a command-line argument to the script.
os.chdir(discordia_archivum_dir)
!python main.py {scrape_args}


# **III. Data Preprocessing**

In [None]:
# @title ## **3.1. Data Cleaning**
import os
import random
import concurrent.futures
from tqdm import tqdm
from PIL import Image

# Reload variables stored by earlier cells.
%store -r

os.chdir(root_dir)

# NOTE(review): `test` is not read anywhere in this cell; the call does fail early when
# `train_data_dir` does not exist.
test = os.listdir(train_data_dir)
#@markdown This section removes unsupported media types such as `.mp4`, `.webm`, and `.gif`, as well as any unnecessary files.
#@markdown To convert a transparent dataset with an alpha channel (RGBA) to RGB and give it a white background, set the `convert` parameter to `True`.
convert = False  # @param {type:"boolean"}
#@markdown Alternatively, you can give the background a `random_color` instead of white by checking the corresponding option.
random_color = False  # @param {type:"boolean"}
# When False, `clean_directory` leaves subfolders untouched.
recursive = False

batch_size = 32
# Files whose extension is not listed here are deleted by `clean_directory`.
supported_types = [
    ".png",
    ".jpg",
    ".jpeg",
    ".webp",
    ".bmp",
    ".caption",
    ".npz",
    ".txt",
    ".json",
]

# RGB choices used for the background when `random_color` is enabled.
background_colors = [
    (255, 255, 255),
    (0, 0, 0),
    (255, 0, 0),
    (0, 255, 0),
    (0, 0, 255),
    (255, 255, 0),
    (255, 0, 255),
    (0, 255, 255),
]

def clean_directory(directory):
    """Delete files in `directory` whose extension is not in `supported_types`.

    The extension is compared case-insensitively, so files such as
    `photo.JPG` or `image.PNG` are kept instead of being deleted as
    unsupported. Subfolders are only visited when the module-level
    `recursive` flag is set.

    Args:
        directory: Folder to clean.
    """
    for item in os.listdir(directory):
        file_path = os.path.join(directory, item)
        if os.path.isfile(file_path):
            file_ext = os.path.splitext(item)[1].lower()
            if file_ext not in supported_types:
                print(f"Deleting file {item} from {directory}")
                os.remove(file_path)
        elif os.path.isdir(file_path) and recursive:
            clean_directory(file_path)

def process_image(image_path):
    """Flatten transparency and convert one image file in place.

    * Images in mode RGBA or LA are pasted onto a solid RGB background
      (white, or a random entry of `background_colors` when `random_color`
      is set), using their alpha channel as the mask.
    * `.webp` files are saved as JPEG under the same stem and the original
      file is removed.
    * All other files are re-saved as PNG at their original path.

    Args:
        image_path: Path of the image file to process.
    """
    img_dir, image_name = os.path.split(image_path)
    is_webp = image_name.endswith(".webp")
    new_image_path = None

    # The context manager closes the source file handle when done.
    with Image.open(image_path) as img:
        has_alpha = img.mode in ("RGBA", "LA")

        if has_alpha:
            if random_color:
                background_color = random.choice(background_colors)
            else:
                background_color = (255, 255, 255)
            result = Image.new("RGB", img.size, background_color)
            # The last band of RGBA/LA is the alpha channel.
            result.paste(img, mask=img.split()[-1])
        else:
            result = img

        if is_webp:
            # Swap only the extension; str.replace would also rewrite any
            # other ".webp" occurring inside the file name.
            stem = os.path.splitext(image_name)[0]
            new_image_path = os.path.join(img_dir, stem + ".jpg")
            result.save(new_image_path, "JPEG")
        else:
            result.save(image_path, "PNG")

    if is_webp:
        os.remove(image_path)
        print(f" Converted image: {image_name} to {os.path.basename(new_image_path)}")
    elif has_alpha:
        print(f" Converted image: {image_name}")

def find_images(directory):
    """Return the paths of all `.png` and `.webp` files under `directory`, including subfolders."""
    wanted_suffixes = (".png", ".webp")
    return [
        os.path.join(folder, name)
        for folder, _, names in os.walk(directory)
        for name in names
        if name.endswith(wanted_suffixes)
    ]

# Remove unsupported files first, then collect the images that may need converting.
clean_directory(train_data_dir)
images = find_images(train_data_dir)
# Retained so the module-level name stays defined; the conversion below no longer batches.
num_batches = len(images) // batch_size + 1

if convert:
    with concurrent.futures.ThreadPoolExecutor() as executor:
        # One task per image. Each future is inspected so that a failing image
        # is reported instead of being dropped silently (results of
        # executor.map were never read before). This also avoids calling
        # range(), a name the scraping cell rebinds to a string.
        pending = {executor.submit(process_image, path): path for path in images}
        for finished in tqdm(concurrent.futures.as_completed(pending), total=len(pending)):
            error = finished.exception()
            if error is not None:
                print(f" Failed to convert {pending[finished]}: {error}")

    print("All images have been converted")

## **3.2. Data Captioning**

- For general images, use BLIP captioning.
- For anime and manga-style images, use Waifu Diffusion 1.4 Tagger V2.

In [None]:
#@title ### **3.2.1. BLIP Captioning**
#@markdown BLIP is a pre-training framework for unified vision-language understanding and generation, which achieves state-of-the-art results on a wide range of vision-language tasks. It can be used as a tool for image captioning, for example, `astronaut riding a horse in space`.
import os

os.chdir(finetune_dir)

beam_search = True #@param {type:'boolean'}
min_length = 5 #@param {type:"slider", min:0, max:100, step:5.0}
max_length = 75 #@param {type:"slider", min:0, max:100, step:5.0}

# Options for make_captions.py. The key starting with "_" is passed as a positional argument.
config = {
    "_train_data_dir"   : train_data_dir,
    "batch_size"        : 8,
    "beam_search"       : beam_search,
    "min_length"        : min_length,
    "max_length"        : max_length,
    "debug"             : True,
    "caption_extension" : ".caption",
    "max_data_loader_n_workers" : 2,
    "recursive"         : True
}

# Render the options as a command-line string: quoted positional values, --key="value" for
# strings, a bare --key for True, and --key=value for numbers. False values are left out.
args = ""
for k, v in config.items():
    if k.startswith("_"):
        args += f'"{v}" '
    elif isinstance(v, str):
        args += f'--{k}="{v}" '
    elif isinstance(v, bool) and v:
        args += f"--{k} "
    elif isinstance(v, float) and not isinstance(v, bool):
        args += f"--{k}={v} "
    elif isinstance(v, int) and not isinstance(v, bool):
        args += f"--{k}={v} "

final_args = f"python make_captions.py {args}"

# The script is run from inside the finetune folder.
os.chdir(finetune_dir)
!{final_args}

In [None]:
#@title ### **3.2.2. Waifu Diffusion 1.4 Tagger V2**
import os
# Reload variables stored by earlier cells.
%store -r

os.chdir(finetune_dir)

#@markdown [Waifu Diffusion 1.4 Tagger V2](https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags) is a Danbooru-styled image classification model developed by SmilingWolf. It can also be useful for general image tagging, for example, `1girl, solo, looking_at_viewer, short_hair, bangs, simple_background`.
model = "SmilingWolf/wd-v1-4-moat-tagger-v2" #@param ["SmilingWolf/wd-v1-4-moat-tagger-v2", "SmilingWolf/wd-v1-4-convnextv2-tagger-v2", "SmilingWolf/wd-v1-4-swinv2-tagger-v2", "SmilingWolf/wd-v1-4-convnext-tagger-v2", "SmilingWolf/wd-v1-4-vit-tagger-v2"]
#@markdown Separate `undesired_tags` with comma `(,)` if you want to remove multiple tags, e.g. `1girl,solo,smile`.
undesired_tags = "" #@param {type:'string'}
#@markdown Adjust `general_threshold` for pruning tags (less tags, less flexible). `character_threshold` is useful if you want to train with character tags, e.g. `hakurei reimu`.
general_threshold = 0.55 #@param {type:"slider", min:0, max:1, step:0.05}
character_threshold = 0.35 #@param {type:"slider", min:0, max:1, step:0.05}

# Options for tag_images_by_wd14_tagger.py. The key starting with "_" is passed as a positional argument.
config = {
    "_train_data_dir"           : train_data_dir,
    "batch_size"                : 8,
    "repo_id"                   : model,
    "recursive"                 : True,
    "remove_underscore"         : True,
    "general_threshold"         : general_threshold,
    "character_threshold"       : character_threshold,
    "caption_extension"         : ".txt",
    "max_data_loader_n_workers" : 2,
    "debug"                     : True,
    "undesired_tags"            : undesired_tags
}

# Render the options as a command-line string: quoted positional values, --key="value" for
# strings (an empty string still yields --key=""), a bare --key for True, --key=value for numbers.
args = ""
for k, v in config.items():
    if k.startswith("_"):
        args += f'"{v}" '
    elif isinstance(v, str):
        args += f'--{k}="{v}" '
    elif isinstance(v, bool) and v:
        args += f"--{k} "
    elif isinstance(v, float) and not isinstance(v, bool):
        args += f"--{k}={v} "
    elif isinstance(v, int) and not isinstance(v, bool):
        args += f"--{k}={v} "

final_args = f"python tag_images_by_wd14_tagger.py {args}"

# The script is run from inside the finetune folder.
os.chdir(finetune_dir)
!{final_args}

In [None]:
# @title ### **3.2.3. Custom Caption/Tag**
import os

# Reload variables stored by earlier cells.
%store -r

os.chdir(root_dir)

# @markdown Add or remove custom tags here.
extension   = ".txt"  # @param [".txt", ".caption"]
custom_tag  = "anime"  # @param {type:"string"}
# @markdown Use `sub_folder` option to specify a subfolder for multi-concept training.
# @markdown > Specify `--all` to process all subfolders/`recursive`
sub_folder  = "" #@param {type: "string"}
# @markdown Enable this to append custom tags at the end of lines.
append      = False  # @param {type:"boolean"}
# @markdown Enable this if you want to remove captions/tags instead.
remove_tag  = False  # @param {type:"boolean"}
# Switched on below only when `sub_folder` is "--all".
recursive   = False

# Resolve the folder to process: the training data folder (optionally with all subfolders),
# an absolute /content path, or a subfolder that is created when missing.
if sub_folder == "":
    image_dir = train_data_dir
elif sub_folder == "--all":
    image_dir = train_data_dir
    recursive = True
elif sub_folder.startswith("/content"):
    image_dir = sub_folder
else:
    image_dir = os.path.join(train_data_dir, sub_folder)
    os.makedirs(image_dir, exist_ok=True)

def read_file(filename):
    """Return the complete text content of `filename`."""
    with open(filename, mode="r") as source:
        return source.read()

def write_file(filename, contents):
    """Replace the content of `filename` with the text `contents` (the file is created if missing)."""
    with open(filename, mode="w") as target:
        target.write(contents)

def process_tags(filename, custom_tag, append, remove_tag):
    """Add custom tags to, or remove them from, one caption file.

    Args:
        filename: Path of the comma-separated caption file to rewrite.
        custom_tag: One or more comma-separated tags; underscores are
            replaced by spaces before comparing.
        append: When adding, put the tag at the end instead of the front.
        remove_tag: Remove the tags instead of adding them.

    Blank entries are discarded when reading, so an empty caption file yields
    "anime" rather than "anime, ", and stray double commas are cleaned up.
    """
    contents = read_file(filename)
    tags = [tag.strip() for tag in contents.split(',') if tag.strip()]
    requested_tags = [
        tag.strip().replace("_", " ")
        for tag in custom_tag.split(',')
        if tag.strip()
    ]

    for requested in requested_tags:
        if remove_tag:
            # Drop every occurrence of the tag.
            tags = [tag for tag in tags if tag != requested]
        elif requested not in tags:
            if append:
                tags.append(requested)
            else:
                tags.insert(0, requested)

    write_file(filename, ', '.join(tags))

def process_directory(image_dir, tag, append, remove_tag, recursive):
    """Apply `process_tags` to every caption file found in `image_dir`.

    Caption files are the entries whose name ends with the module-level
    `extension`. When `recursive` is true, sub-directories are walked too.
    """
    for entry in os.listdir(image_dir):
        entry_path = os.path.join(image_dir, entry)

        if recursive and os.path.isdir(entry_path):
            process_directory(entry_path, tag, append, remove_tag, recursive)
            continue

        if entry.endswith(extension):
            process_tags(entry_path, tag, append, remove_tag)

tag = custom_tag

# When the folder has no caption file at all, create an empty one next to each
# image so the custom tag has somewhere to go.
has_caption_files = any(
    filename.endswith(extension) for filename in os.listdir(image_dir)
)
if not has_caption_files:
    image_suffixes = (".png", ".jpg", ".jpeg", ".webp", ".bmp")
    for filename in os.listdir(image_dir):
        # Match case-insensitively so files such as "IMG_0001.JPG" are covered.
        if filename.lower().endswith(image_suffixes):
            # splitext drops only the final extension, so "img.v2.png" maps to
            # "img.v2" + extension rather than being truncated at the first dot.
            stem = os.path.splitext(filename)[0]
            with open(os.path.join(image_dir, stem + extension), "w"):
                pass

if custom_tag:
    process_directory(image_dir, tag, append, remove_tag, recursive)

In [None]:
# @title ## **3.4. Bucketing and Latents Caching**
%store -r

# @markdown This code will create buckets based on the `bucket_resolution` provided for multi-aspect ratio training, and then convert all images within the `train_data_dir` to latents.
bucketing_json    = os.path.join(training_dir, "meta_lat.json")
metadata_json     = os.path.join(training_dir, "meta_clean.json")
bucket_resolution = 1024  # @param {type:"slider", min:512, max:1024, step:128}
mixed_precision   = "no"  # @param ["no", "fp16", "bf16"] {allow-input: false}
flip_aug          = False  # @param{type:"boolean"}
# @markdown Use the `clean_caption` option to clean up captions (e.g. duplicate tags, `women` to `girl`, etc.)
clean_caption     = True #@param {type:"boolean"}
#@markdown Use the `recursive` option to process subfolders as well
recursive         = True #@param {type:"boolean"}

metadata_config = {
    "_train_data_dir": train_data_dir,
    "_out_json": metadata_json,
    "recursive": recursive,
    "full_path": recursive,
    "clean_caption": clean_caption
}

bucketing_config = {
    "_train_data_dir": train_data_dir,
    "_in_json": metadata_json,
    "_out_json": bucketing_json,
    "_model_name_or_path": model_path,
    "recursive": recursive,
    "full_path": recursive,
    "flip_aug": flip_aug,
    "batch_size": 4,
    "max_data_loader_n_workers": 2,
    "max_resolution": f"{bucket_resolution}, {bucket_resolution}",
    "mixed_precision": mixed_precision,
}

def generate_args(config):
    """Render a config dict as a command-line argument string.

    Keys starting with "_" become quoted positional values, strings become
    `--key="value"`, true booleans become bare `--key` flags, and ints/floats
    become `--key=value`. False booleans and any other value are omitted.
    """
    pieces = []
    for key, value in config.items():
        if key.startswith("_"):
            pieces.append(f'"{value}"')
        elif isinstance(value, str):
            pieces.append(f'--{key}="{value}"')
        elif isinstance(value, bool):
            # bool is checked before the numeric case because it is an int subclass.
            if value:
                pieces.append(f"--{key}")
        elif isinstance(value, (float, int)):
            pieces.append(f"--{key}={value}")
    return " ".join(pieces)

merge_metadata_args = generate_args(metadata_config)
prepare_buckets_args = generate_args(bucketing_config)

merge_metadata_command = f"python merge_all_to_metadata.py {merge_metadata_args}"
prepare_buckets_command = f"python prepare_buckets_latents.py {prepare_buckets_args}"

os.chdir(finetune_dir)
!{merge_metadata_command}
time.sleep(1)
!{prepare_buckets_command}


# **IV. Training**



In [None]:
# @title ## **4.1. Optimizer Config**
import toml
import ast

# @markdown Use `Adafactor` optimizer. `RMSprop 8bit` or `Adagrad 8bit` may work. `AdamW 8bit` doesn't seem to work.
optimizer_type = "AdaFactor"  # @param ["AdamW", "AdamW8bit", "Lion8bit", "Lion", "SGDNesterov", "SGDNesterov8bit", "DAdaptation(DAdaptAdamPreprint)", "DAdaptAdaGrad", "DAdaptAdam", "DAdaptAdan", "DAdaptAdanIP", "DAdaptLion", "DAdaptSGD", "AdaFactor"]
# @markdown Specify `optimizer_args` to add `additional` args for optimizer, e.g: `["weight_decay=0.6"]`
optimizer_args = "[ \"scale_parameter=False\", \"relative_step=False\", \"warmup_init=False\" ]"  # @param {'type':'string'}
# @markdown ### **Learning Rate Config**
# @markdown Different `optimizer_type` and `network_category` for some condition requires different learning rate. It's recommended to set `text_encoder_lr = 1/2 * unet_lr`
learning_rate = 4e-7  # @param {'type':'number'}
train_text_encoder = False  # @param {type:"boolean"}
# train_text_encoder = False  # @param {'type':'boolean'}
# @markdown ### **LR Scheduler Config**
# @markdown `lr_scheduler` provides several methods to adjust the learning rate based on the number of epochs.
lr_scheduler = "constant_with_warmup"  # @param ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup", "adafactor"] {allow-input: false}
lr_warmup_steps = 100  # @param {'type':'number'}
# @markdown Specify `lr_scheduler_num` with `num_cycles` value for `cosine_with_restarts` or `power` value for `polynomial`
lr_scheduler_num = 0  # @param {'type':'number'}

if isinstance(optimizer_args, str):
    optimizer_args = optimizer_args.strip()
    if optimizer_args.startswith('[') and optimizer_args.endswith(']'):
        try:
            optimizer_args = ast.literal_eval(optimizer_args)
        except (SyntaxError, ValueError) as e:
            print(f"Error parsing optimizer_args: {e}\n")
            optimizer_args = []
    elif len(optimizer_args) > 0:
        print(f"WARNING! '{optimizer_args}' is not a valid list! Put args like this: [\"args=1\", \"args=2\"]\n")
        optimizer_args = []
    else:
        optimizer_args = []
else:
    optimizer_args = []

optimizer_config = {
    "optimizer_arguments": {
        "optimizer_type"          : optimizer_type,
        "learning_rate"           : learning_rate,
        "train_text_encoder"      : train_text_encoder,
        "max_grad_norm"           : 0,
        "optimizer_args"          : optimizer_args,
        "lr_scheduler"            : lr_scheduler,
        "lr_warmup_steps"         : lr_warmup_steps,
        "lr_scheduler_num_cycles" : lr_scheduler_num if lr_scheduler == "cosine_with_restarts" else None,
        "lr_scheduler_power"      : lr_scheduler_num if lr_scheduler == "polynomial" else None,
        "lr_scheduler_type"       : None,
        "lr_scheduler_args"       : None,
    },
}

print(toml.dumps(optimizer_config))


In [None]:
# @title ## **4.2. Advanced Training Config** (Optional)
import toml


# @markdown ### **Resume With Optimizer State**
optimizer_state_path      = "" #@param {type:"string"}
# @markdown ### **Noise Control**
noise_control_type        = "none" #@param ["none", "noise_offset", "multires_noise"]
# @markdown #### **a. Noise Offset**
# @markdown Control and easily generate darker or lighter images by offsetting the noise when fine-tuning the model. Recommended value: `0.1`. Read [Diffusion With Offset Noise](https://www.crosslabs.org//blog/diffusion-with-offset-noise)
noise_offset_num          = 0.1  # @param {type:"number"}
# @markdown **[Experimental]**
# @markdown Automatically adjusts the noise offset based on the absolute mean values of each channel in the latents when used with `--noise_offset`. Specify a value around 1/10 to the same magnitude as the `--noise_offset` for best results. Set `0` to disable.
adaptive_noise_scale      = 0.01 # @param {type:"number"}
# @markdown #### **b. Multires Noise**
# @markdown enable multires noise with this number of iterations (if enabled, around 6-10 is recommended)
multires_noise_iterations = 6 #@param {type:"slider", min:1, max:10, step:1}
multires_noise_discount = 0.3 #@param {type:"slider", min:0.1, max:1, step:0.1}
# @markdown ### **Custom Train Function**
# @markdown Gamma for reducing the weight of high-loss timesteps. Lower numbers have a stronger effect. The paper recommends `5`. Read the paper [here](https://arxiv.org/abs/2303.09556).
min_snr_gamma             = -1 #@param {type:"number"}

advanced_training_config = {
    "advanced_training_config": {
        "resume"                    : optimizer_state_path,
        "noise_offset"              : noise_offset_num if noise_control_type == "noise_offset" else None,
        "adaptive_noise_scale"      : adaptive_noise_scale if adaptive_noise_scale and noise_control_type == "noise_offset" else None,
        "multires_noise_iterations" : multires_noise_iterations if noise_control_type =="multires_noise" else None,
        "multires_noise_discount"   : multires_noise_discount if noise_control_type =="multires_noise" else None,
        "min_snr_gamma"             : min_snr_gamma if not min_snr_gamma == -1 else None,
    }
}

print(toml.dumps(advanced_training_config))

In [None]:
# @title ## **4.3. Training Config**
import toml
import os
import random
from subprocess import getoutput

%store -r

# @markdown ### **Project Config**
project_name            = "sdxl_finetune"  # @param {type:"string"}
# @markdown Get your `wandb_api_key` [here](https://wandb.ai/settings) to log with wandb.
wandb_api_key           = "" # @param {type:"string"}
in_json                 = "/content/fine_tune/meta_lat.json"  # @param {type:"string"}
# @markdown ### **SDXL Config**
gradient_checkpointing      = True  # @param {type:"boolean"}
no_half_vae             = True  # @param {type:"boolean"}
#@markdown Recommended parameter for SDXL training but if you enable it, `shuffle_caption` won't work
cache_text_encoder_outputs = True  # @param {type:"boolean"}
#@markdown These options can be used to train U-Net with different timesteps. The default values are 0 and 1000.
min_timestep = 0 # @param {type:"number"}
max_timestep = 1000 # @param {type:"number"}
# @markdown ### **Dataset Config**
num_repeats             = 1  # @param {type:"number"}
# @markdown Please refer to `3.2.3. Custom Caption/Tag (Optional)` if you want to append `activation_word` to captions/tags
resolution              = 1024  # @param {type:"slider", min:512, max:1024, step:128}
keep_tokens             = 0  # @param {type:"number"}
# @markdown ### **General Config**
max_train_steps         = 2500  # @param {type:"number"}
train_batch_size        = 4  # @param {type:"number"}
mixed_precision         = "fp16"  # @param ["no","fp16","bf16"] {allow-input: false}
seed                    = -1  # @param {type:"number"}
# @markdown ### **Save Output Config**
save_precision          = "fp16"  # @param ["float", "fp16", "bf16"] {allow-input: false}
save_every_n_steps      = 1000  # @param {type:"number"}
save_optimizer_state    = False  # @param {type:"boolean"}
save_model_as           = "safetensors" #@param ["ckpt", "safetensors", "diffusers", "diffusers_safetensors"]
# @markdown ### **Sample Prompt Config**
enable_sample               = True  # @param {type:"boolean"}
sampler                     = "euler_a"  # @param ["ddim", "pndm", "lms", "euler", "euler_a", "heun", "dpm_2", "dpm_2_a", "dpmsolver","dpmsolver++", "dpmsingle", "k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a"]
positive_prompt             = ""
negative_prompt             = ""
quality_prompt              = "None"  # @param ["None", "Waifu Diffusion 1.5", "NovelAI", "AbyssOrangeMix", "Stable Diffusion XL"] {allow-input: false}
if quality_prompt          == "Waifu Diffusion 1.5":
    positive_prompt         = "(exceptional, best aesthetic, new, newest, best quality, masterpiece, extremely detailed, anime, waifu:1.2), "
    negative_prompt         = "lowres, ((bad anatomy)), ((bad hands)), missing finger, extra digits, fewer digits, blurry, ((mutated hands and fingers)), (poorly drawn face), ((mutation)), ((deformed face)), (ugly), ((bad proportions)), ((extra limbs)), extra face, (double head), (extra head), ((extra feet)), monster, logo, cropped, worst quality, jpeg, humpbacked, long body, long neck, ((jpeg artifacts)), deleted, old, oldest, ((censored)), ((bad aesthetic)), (mosaic censoring, bar censor, blur censor), "
if quality_prompt          == "NovelAI":
    positive_prompt         = "masterpiece, best quality, "
    negative_prompt         = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, "
if quality_prompt         == "AbyssOrangeMix":
    positive_prompt         = "masterpiece, best quality, "
    negative_prompt         = "(worst quality, low quality:1.4), "
if quality_prompt          == "Stable Diffusion XL":
    negative_prompt         = "3d render, smooth, plastic, blurry, grainy, low-resolution, deep-fried, oversaturated"
custom_prompt               = "1girl, aqua eyes, baseball cap, blonde hair, closed mouth, earrings, green background, hat, hoop earrings, jewelry, looking at viewer, shirt, short hair, simple background, solo, upper body, yellow shirt" # @param {type:"string"}
# @markdown Specify `prompt_from_caption` if you want to use caption as prompt instead. Will be chosen randomly.
prompt_from_caption         = "none"  # @param ["none", ".txt", ".caption"]
if prompt_from_caption != "none":
    custom_prompt           = ""
num_prompt                  = 2  # @param {type:"number"}
sample_interval             = 100  # @param {type:"number"}
logging_dir             = "/content/fine_tune/logs"

os.chdir(repo_dir)

prompt_config = {
    "prompt": {
        "negative_prompt" : negative_prompt,
        "width"           : resolution,
        "height"          : resolution,
        "scale"           : 7,
        "sample_steps"    : 28,
        "subset"          : [],
    }
}

# Training configuration written to config_file.toml. Values set to None (and
# empty strings, once passed through eliminate_none_variable) are left out of
# the generated TOML.
train_config = {
    "sdxl_arguments": {
        "cache_text_encoder_outputs" : cache_text_encoder_outputs,
        # "enable_bucket"              : True,
        "no_half_vae"                : no_half_vae,
        # "cache_latents"              : True,
        # "cache_latents_to_disk"      : True,
        # "vae_batch_size"             : 4,
        "min_timestep"               : min_timestep,
        "max_timestep"               : max_timestep,
        # Caption shuffling is disabled whenever text encoder outputs are cached.
        "shuffle_caption"            : True if not cache_text_encoder_outputs else False,
    },
    "model_arguments": {
        "pretrained_model_name_or_path" : model_path,
        "vae"                           : vae_path,
    },
    "dataset_arguments": {
        "debug_dataset"                 : False,
        "in_json"                       : in_json,
        "train_data_dir"                : train_data_dir,
        "dataset_repeats"               : num_repeats,
        "keep_tokens"                   : keep_tokens,
        "resolution"                    : str(resolution) + ',' + str(resolution),
        "caption_dropout_rate"          : 0,
        "caption_tag_dropout_rate"      : 0,
        "caption_dropout_every_n_epochs": 0,
        "color_aug"                     : False,
        "face_crop_aug_range"           : None,
        "token_warmup_min"              : 1,
        "token_warmup_step"             : 0,
    },
    "training_arguments": {
        "output_dir"                    : output_dir,
        "output_name"                   : project_name if project_name else "last",
        "save_precision"                : save_precision,
        "save_every_n_steps"            : save_every_n_steps,
        "save_n_epoch_ratio"            : None,
        "save_last_n_epochs"            : None,
        # Honor the `save_optimizer_state` form field (previously ignored).
        "save_state"                    : save_optimizer_state,
        "save_last_n_epochs_state"      : None,
        "resume"                        : None,
        "train_batch_size"              : train_batch_size,
        "max_token_length"              : 225,
        "mem_eff_attn"                  : False,
        "xformers"                      : True,
        "max_train_steps"               : max_train_steps,
        "max_data_loader_n_workers"     : 8,
        "persistent_data_loader_workers": True,
        # A non-positive seed means "no fixed seed".
        "seed"                          : seed if seed > 0 else None,
        "gradient_checkpointing"        : gradient_checkpointing,
        "gradient_accumulation_steps"   : 1,
        "mixed_precision"               : mixed_precision,
    },
    "logging_arguments": {
        "log_with"          : "wandb" if wandb_api_key else "tensorboard",
        "log_tracker_name"  : project_name if wandb_api_key and not project_name == "last" else None,
        "logging_dir"       : logging_dir,
        "log_prefix"        : project_name if not wandb_api_key else None,
    },
    "sample_prompt_arguments": {
        "sample_every_n_steps"    : sample_interval,
        "sample_every_n_epochs"   : None,
        "sample_sampler"          : sampler,
    },
    "saving_arguments": {
        # Honor the `save_model_as` form field (previously hard-coded to "safetensors").
        "save_model_as": save_model_as
    },
}

def write_file(filename, contents):
    """Overwrite `filename` with the text `contents`."""
    with open(filename, "w") as handle:
        handle.write(contents)

def prompt_convert(enable_sample, num_prompt, train_data_dir, prompt_config, custom_prompt):
    """Build the sample-prompt config used to render preview images.

    Reads the module-level `prompt_from_caption` (caption file extension, or
    "none") and `positive_prompt` (quality prefix, may be empty).

    Args:
        enable_sample: When false, `prompt_config` is returned untouched.
        num_prompt: Maximum number of caption files to sample prompts from.
        train_data_dir: Directory searched recursively for caption files.
        prompt_config: Base config holding a "prompt" table.
        custom_prompt: Prompt used when no caption file is available; a
            built-in default is used if it is empty.

    Returns:
        A new config whose `["prompt"]["subset"]` lists one `{"prompt": ...}`
        entry per prompt. The input `prompt_config` is not modified.
    """
    # Imported here because this cell does not import glob itself.
    import glob
    import random

    if not enable_sample:
        return prompt_config

    caption_files = []
    # "none" means captions are not used as prompts, so there is nothing to search for.
    if prompt_from_caption != "none":
        search_pattern = os.path.join(train_data_dir, '**/*' + prompt_from_caption)
        caption_files = glob.glob(search_pattern, recursive=True)

    if caption_files:
        selected_files = random.sample(caption_files, min(num_prompt, len(caption_files)))
        prompts = []
        for file in selected_files:
            with open(file, 'r') as f:
                prompts.append(f.read().strip())
    else:
        if not custom_prompt:
            custom_prompt = "masterpiece, best quality, 1girl, aqua eyes, baseball cap, blonde hair, closed mouth, earrings, green background, hat, hoop earrings, jewelry, looking at viewer, shirt, short hair, simple background, solo, upper body, yellow shirt"
        prompts = [custom_prompt]

    subset = [
        {"prompt": positive_prompt + prompt if positive_prompt else prompt}
        for prompt in prompts
    ]

    # Copy both levels that change so the caller's dictionaries stay untouched
    # (dict.copy() alone would share the nested "prompt" table).
    new_prompt_config = dict(prompt_config)
    new_prompt_config["prompt"] = dict(prompt_config["prompt"], subset=subset)
    return new_prompt_config

def eliminate_none_variable(config):
    """Replace empty-string values with None, in place, and return `config`.

    Only the top level and one level of nested dictionaries are visited.
    """
    for key, value in config.items():
        if isinstance(value, dict):
            for sub_key, sub_value in value.items():
                if sub_value == "":
                    value[sub_key] = None
        elif value == "":
            config[key] = None

    return config

try:
    train_config.update(optimizer_config)
except NameError:
    raise NameError("'optimizer_config' dictionary is missing. Please run  '4.1. Optimizer Config' cell.")

advanced_training_warning = False
try:
    train_config.update(advanced_training_config)
except NameError:
    advanced_training_warning = True
    pass

prompt_config       = prompt_convert(enable_sample, num_prompt, train_data_dir, prompt_config, custom_prompt)

config_path         = os.path.join(config_dir, "config_file.toml")
prompt_path         = os.path.join(config_dir, "sample_prompt.toml")

config_str          = toml.dumps(eliminate_none_variable(train_config))
prompt_str          = toml.dumps(eliminate_none_variable(prompt_config))

write_file(config_path, config_str)
write_file(prompt_path, prompt_str)

print(config_str)

if advanced_training_warning:
    import textwrap
    error_message = "WARNING: This is not an error message, but the [advanced_training_config] dictionary is missing. Please run the '4.2. Advanced Training Config' cell if you intend to use it, or continue to the next step."
    wrapped_message = textwrap.fill(error_message, width=80)
    print('\033[38;2;204;102;102m' + wrapped_message + '\033[0m\n')
    pass

print(prompt_str)

In [None]:
#@title ## **4.4. Start Training**
import os
import toml

#@markdown Check your config here if you want to edit something:
#@markdown - `sample_prompt` : /content/fine_tune/config/sample_prompt.toml
#@markdown - `config_file` : /content/fine_tune/config/config_file.toml

#@markdown You can import config from another session if you want.

sample_prompt   = "/content/fine_tune/config/sample_prompt.toml" #@param {type:'string'}
config_file     = "/content/fine_tune/config/config_file.toml" #@param {type:'string'}

def read_file(filename):
    """Return the full text content of `filename`."""
    with open(filename, "r") as handle:
        return handle.read()

def train(config):
    """Render a config dict as a command-line argument string.

    Keys starting with "_" become quoted positional values, strings become
    `--key="value"`, true booleans become bare `--key` flags, and ints/floats
    become `--key=value`. False booleans and any other value (e.g. None) are
    omitted. Every emitted argument is followed by a single space.
    """
    pieces = []
    for key, value in config.items():
        if key.startswith("_"):
            pieces.append(f'"{value}"')
        elif isinstance(value, str):
            pieces.append(f'--{key}="{value}"')
        elif isinstance(value, bool):
            # bool is checked before the numeric case because it is an int subclass.
            if value:
                pieces.append(f"--{key}")
        elif isinstance(value, (float, int)):
            pieces.append(f"--{key}={value}")

    return "".join(piece + " " for piece in pieces)

accelerate_conf = {
    "config_file" : accelerate_config,
    "num_cpu_threads_per_process" : 1,
}

train_conf = {
    "sample_prompts"  : sample_prompt if os.path.exists(sample_prompt) else None,
    "config_file"     : config_file,
    "wandb_api_key"   : wandb_api_key if wandb_api_key else None,
}

accelerate_args = train(accelerate_conf)
train_args = train(train_conf)

final_args = f"accelerate launch {accelerate_args} sdxl_train.py {train_args}"

os.chdir(repo_dir)
!{final_args}

In [None]:
#@title ## **4.5. Inference**

import os
%store -r

# @markdown ### Model Config
ckpt_path = "/content/fine_tune/output/sdxl_finetune.safetensors" #@param {type:'string'}
# @markdown ### Prompt Config
prompt = "1girl, aqua eyes, baseball cap, blonde hair, closed mouth, earrings, green background, hat, hoop earrings, jewelry, looking at viewer, shirt, short hair, simple background, solo, upper body, yellow shirt" #@param {type:'string'}
negative_prompt = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry" #@param {type:'string'}
output_path = "/content/tmp/" #@param {type:'string'}
resolution = "1024,1024" # @param {type: "string"}
optimization = "scaled dot-product attention" # @param ["xformers", "scaled dot-product attention"]
conditional_resolution = "1024,1024" # @param {type: "string"}
steps = 28 # @param {type: "number"}
sampler = "euler_a"  # @param ["ddim", "pndm", "lms", "euler", "euler_a", "heun", "dpm_2", "dpm_2_a", "dpmsolver","dpmsolver++", "dpmsingle", "k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a"]
scale = 7 # @param {type: "number"}
seed = -1 # @param {type: "number"}
images_per_prompt = 1 # @param {type: "number"}
batch_size = 1 # @param {type: "number"}
clip_skip = 2 # @param {type: "number"}

os.makedirs(output_path, exist_ok=True)

separators = ["*", "x", ","]


def _parse_resolution(value, label):
    """Split a "WIDTH<sep>HEIGHT" string into a (width, height) pair of strings.

    The first entry of `separators` found in `value` is used. Raises
    ValueError, naming the offending form field `label`, when no separator is
    present or when the split does not yield exactly two non-empty parts.
    """
    for separator in separators:
        if separator in value:
            parts = [part.strip() for part in value.split(separator)]
            if len(parts) == 2 and all(parts):
                return parts[0], parts[1]
            break
    raise ValueError(
        f"Invalid {label} '{value}': expected WIDTH and HEIGHT separated by one of {separators}"
    )


# Each field is parsed with its own separator, so "1024x1024" and "1024,1024"
# can be mixed, and a malformed value fails here with a clear message instead
# of surfacing later as an undefined `width`.
width, height = _parse_resolution(resolution, "resolution")
original_width, original_height = _parse_resolution(conditional_resolution, "conditional_resolution")

config = {
    "prompt": prompt + " --n " + negative_prompt,
    "images_per_prompt": images_per_prompt,
    "outdir": output_path,
    "W": width,
    "H": height,
    "original_width": original_width,
    "original_height": original_height,
    "batch_size": batch_size,
    "vae_batch_size": 1,
    "no_half_vae": True,
    "steps": steps,
    "sampler": sampler,
    "scale": scale,
    "ckpt": ckpt_path,
    "vae": vae_path,
    "seed": seed if seed > 0 else None,
    "fp16": True,
    "sdpa": True if optimization == "scaled dot-product attention" else False,
    "xformers": True if optimization == "xformers" else False,
    "opt_channels_last": True,
    "clip_skip": clip_skip,
    "max_embeddings_multiples": 3,
}

args = ""
for k, v in config.items():
    if k.startswith("_"):
        args += f'"{v}" '
    elif isinstance(v, str):
        args += f'--{k}="{v}" '
    elif isinstance(v, bool) and v:
        args += f"--{k} "
    elif isinstance(v, float) and not isinstance(v, bool):
        args += f"--{k}={v} "
    elif isinstance(v, int) and not isinstance(v, bool):
        args += f"--{k}={v} "

final_args = f"python sdxl_gen_img.py {args}"

os.chdir(repo_dir)
!{final_args}

# **V. Deployment**

In [None]:
# @title ## **5.1. Huggingface Hub config**
from huggingface_hub import login
from huggingface_hub import HfApi
from huggingface_hub.utils import validate_repo_id, HfHubHTTPError

# @markdown Login to Huggingface Hub
# @markdown > Get **your** huggingface `WRITE` token [here](https://huggingface.co/settings/tokens)
write_token = ""  # @param {type:"string"}
# @markdown Fill this if you want to upload to your organization, or just leave it empty.
orgs_name = ""  # @param{type:"string"}
# @markdown If your model/dataset repo does not exist, it will automatically create it.
model_name = ""  # @param{type:"string"}
dataset_name = ""  # @param{type:"string"}
make_private = False  # @param{type:"boolean"}

def authenticate(write_token):
    """Log in to the Hugging Face Hub with `write_token`.

    Returns a `(whoami_result, api)` pair: the result of `HfApi.whoami` for
    the token, and the `HfApi` client that produced it.
    """
    login(write_token, add_to_git_credential=True)
    client = HfApi()
    account = client.whoami(write_token)
    return account, client


def create_repo(api, user, orgs_name, repo_name, repo_type, make_private=False):
    """Create a Hub repo if needed and remember its id in a module global.

    Args:
        api: Client whose `create_repo` method is called.
        user: Mapping with a "name" entry, used as the repo owner when
            `orgs_name` is empty.
        orgs_name: Organization that owns the repo, or "" for the user.
        repo_name: Repo name without the owner prefix (whitespace is stripped).
        repo_type: "model" stores the id in `model_repo`; any other value
            stores it in `datasets_repo`.
        make_private: Passed to `create_repo` as `private`.
    """
    global model_repo
    global datasets_repo

    owner = user["name"] if orgs_name == "" else orgs_name
    repo_id = owner + "/" + repo_name.strip()

    try:
        validate_repo_id(repo_id)
        api.create_repo(repo_id=repo_id, repo_type=repo_type, private=make_private)
        print(f"{repo_type.capitalize()} repo '{repo_id}' didn't exist, creating repo")
    except HfHubHTTPError as e:
        # Only HTTP 409 (Conflict) means the repo already exists. Any other
        # HTTP failure (bad token, missing permission, ...) is reported as
        # such instead of being mislabelled as "exists".
        status_code = getattr(getattr(e, "response", None), "status_code", None)
        if status_code == 409:
            print(f"{repo_type.capitalize()} repo '{repo_id}' exists, skipping create repo")
        else:
            print(f"WARNING: {repo_type.capitalize()} repo '{repo_id}' could not be created: {e}")

    if repo_type == "model":
        model_repo = repo_id
        print(f"{repo_type.capitalize()} repo '{repo_id}' link: https://huggingface.co/{repo_id}\n")
    else:
        datasets_repo = repo_id
        print(f"{repo_type.capitalize()} repo '{repo_id}' link: https://huggingface.co/datasets/{repo_id}\n")

user, api = authenticate(write_token)

if model_name:
    create_repo(api, user, orgs_name, model_name, "model", make_private)
if dataset_name:
    create_repo(api, user, orgs_name, dataset_name, "dataset", make_private)


In [None]:
# @title ## **5.2. Upload Checkpoint to Huggingface**
from huggingface_hub import HfApi
from pathlib import Path

api = HfApi()

# @markdown This will be uploaded to model repo
model_path = "/content/fine_tune/output/sdxl_finetune.safetensors"  # @param {type :"string"}
path_in_repo = ""  # @param {type :"string"}
# @markdown Now you can save your config file for future use
config_path = "/content/fine_tune/config"  # @param {type :"string"}
# @markdown Other Information
commit_message = ""  # @param {type :"string"}

if not commit_message:
    commit_message = "feat: upload " + project_name + " checkpoint"

# Flags telling upload_model whether `model_path` is a diffusers-style folder.
# They are always defined: os.path.exists returns False for anything under a
# missing path, so upload_model no longer hits a NameError when `model_path`
# does not exist.
vae_exists = os.path.exists(os.path.join(model_path, "vae"))
unet_exists = os.path.exists(os.path.join(model_path, "unet"))
text_encoder_exists = os.path.exists(os.path.join(model_path, "text_encoder"))


def upload_model(model_paths, is_folder: bool, is_config: bool):
    """Upload a checkpoint file, a model folder, or the config folder to `model_repo`.

    Reads the module-level `api`, `model_repo`, `commit_message`,
    `path_in_repo`, `project_name` and the `vae_exists` / `unet_exists` /
    `text_encoder_exists` flags.

    Args:
        model_paths: Local file or folder to upload.
        is_folder: Upload with `upload_folder` when true, else `upload_file`.
        is_config: The folder holds training configs; it is stored under
            "<path_in_repo or project_name>_config".
    """
    # Destination inside the repo: last path component unless overridden.
    trained_model = Path(model_paths).parts[-1]

    if path_in_repo:
        trained_model = path_in_repo

    if is_config:
        if path_in_repo:
            trained_model = f"{path_in_repo}_config"
        else:
            trained_model = f"{project_name}_config"

    print(f"Uploading {trained_model} to https://huggingface.co/" + model_repo)
    print(f"Please wait...")

    if is_folder:
        # The layout flags describe `model_path` only, so they must not apply
        # to the config folder; otherwise the configs would be spilled into
        # the repo root instead of "<name>_config".
        upload_to_root = (
            not is_config
            and vae_exists
            and unet_exists
            and text_encoder_exists
        )
        if upload_to_root:
            api.upload_folder(
                folder_path=model_paths,
                repo_id=model_repo,
                commit_message=commit_message,
                ignore_patterns=".ipynb_checkpoints",
            )
        else:
            api.upload_folder(
                folder_path=model_paths,
                path_in_repo=trained_model,
                repo_id=model_repo,
                commit_message=commit_message,
                ignore_patterns=".ipynb_checkpoints",
            )
        print(
            f"Upload success, located at https://huggingface.co/"
            + model_repo
            + "/tree/main\n"
        )
    else:
        api.upload_file(
            path_or_fileobj=model_paths,
            path_in_repo=trained_model,
            repo_id=model_repo,
            commit_message=commit_message,
        )

        print(
            f"Upload success, located at https://huggingface.co/"
            + model_repo
            + "/blob/main/"
            + trained_model
            + "\n"
        )


def upload():
    """Upload `model_path`, then the config folder when `config_path` is set.

    A path ending in .ckpt, .safetensors or .pt is uploaded as a single file;
    anything else is uploaded as a folder.
    """
    is_single_file = model_path.endswith((".ckpt", ".safetensors", ".pt"))
    upload_model(model_path, not is_single_file, False)

    if config_path:
        upload_model(config_path, True, True)


upload()

In [None]:
# @title ## **5.3. Upload Dataset to Huggingface** (optional)
from huggingface_hub import HfApi
from pathlib import Path
import shutil
import zipfile
import os

api = HfApi()

# @markdown This will be compressed to zip and uploaded to datasets repo, leave it empty if not necessary
train_data_path = "/content/fine_tune/train_data"  # @param {type :"string"}
meta_lat_path = "/content/fine_tune/meta_lat.json"  # @param {type :"string"}
last_state_path = "/content/fine_tune/output/last-state"  # @param {type :"string"}
# @markdown `Nerd stuff, only if you want to save training logs`
logs_path = "/content/fine_tune/logs"  # @param {type :"string"}

if project_name:
    tmp_dataset = "/content/fine_tune/" + project_name + "_dataset"
    tmp_last_state = "/content/fine_tune/" + project_name + "_last_state"

else:
    tmp_dataset = "/content/fine_tune/tmp_dataset"
    tmp_last_state = "/content/fine_tune/tmp_last_state"

tmp_train_data = tmp_dataset + "/train_data"
dataset_zip = tmp_dataset + ".zip"
last_state_zip = tmp_last_state + ".zip"

# @markdown  Other Information
commit_message = ""  # @param {type :"string"}

if not commit_message:
    commit_message = "feat: upload " + project_name + " dataset and logs"

# NOTE(review): these entries are string literals, not the `tmp_dataset` /
# `tmp_last_state` / `tmp_train_data` path variables defined above, so the
# loop below creates directories with these literal names relative to the
# current working directory. Confirm whether the variables were intended;
# pre-creating the real paths would change where shutil.move places the data.
tmp_folder = ["tmp_dataset", "tmp_last_state", "tmp_train_data"]


def makedirs(tmp_folders):
    """Create the directory `tmp_folders`; do nothing if it already exists."""
    os.makedirs(tmp_folders, exist_ok=True)


# Create every folder listed in `tmp_folder`.
for folder in tmp_folder:
    makedirs(folder)


def upload_dataset(dataset_paths, is_zip: bool):
    """Upload a file or a folder to the dataset repo `datasets_repo`.

    Args:
        dataset_paths: Local path to upload; its last component is used as
            the destination path inside the repo.
        is_zip: Upload with `upload_file` when true, else with `upload_folder`.
    """
    dataset_name = Path(dataset_paths).parts[-1]

    print(
        f"Uploading {dataset_name} to https://huggingface.co/datasets/"
        + datasets_repo
    )
    print(f"Please wait...")

    if is_zip:
        api.upload_file(
            path_or_fileobj=dataset_paths,
            path_in_repo=dataset_name,
            repo_id=datasets_repo,
            repo_type="dataset",
            commit_message=commit_message,
        )
        view_segment = "/blob/main/"
    else:
        api.upload_folder(
            folder_path=dataset_paths,
            path_in_repo=dataset_name,
            repo_id=datasets_repo,
            repo_type="dataset",
            commit_message=commit_message,
            ignore_patterns=".ipynb_checkpoints",
        )
        view_segment = "/tree/main/"

    print(
        f"Upload success, located at https://huggingface.co/datasets/"
        + datasets_repo
        + view_segment
        + dataset_name
        + "\n"
    )


def zip_file(tmp_folders):
    """Write every file under `tmp_folders` into the archive `tmp_folders + ".zip"`.

    Files are added under the path they are found at (no `arcname` is given),
    and directories without files are not recorded.
    """
    archive_path = tmp_folders + ".zip"
    with zipfile.ZipFile(archive_path, "w") as archive:
        for current_dir, _subdirs, filenames in os.walk(tmp_folders):
            for filename in filenames:
                archive.write(os.path.join(current_dir, filename))


def move(src_path, dst_path, is_metadata: bool):
    """Move `src_path` to `dst_path`, optionally with its sibling metadata files.

    `src_path` is moved only if it exists. When `is_metadata` is true, every
    known metadata JSON file found in the directory of `src_path` is moved to
    `dst_path` as well.
    """
    metadata_names = (
        "meta_cap.json",
        "meta_cap_dd.json",
        "meta_lat.json",
        "meta_clean.json",
        "meta_final.json",
    )

    if os.path.exists(src_path):
        shutil.move(src_path, dst_path)

    if not is_metadata:
        return

    source_dir = os.path.dirname(src_path)
    for name in os.listdir(source_dir):
        if name in metadata_names:
            shutil.move(os.path.join(source_dir, name), dst_path)


def upload():
    """Zip and upload the dataset and last training state, then upload the logs.

    Each step is skipped when its form field is empty; the last-state step is
    also skipped when `last_state_path` does not exist. Uploaded zip files are
    deleted locally afterwards.
    """
    dataset_requested = train_data_path and meta_lat_path
    if dataset_requested:
        move(train_data_path, tmp_train_data, False)
        move(meta_lat_path, tmp_dataset, True)
        zip_file(tmp_dataset)
        upload_dataset(dataset_zip, True)
        os.remove(dataset_zip)

    if last_state_path and os.path.exists(last_state_path):
        move(last_state_path, tmp_last_state, False)
        zip_file(tmp_last_state)
        upload_dataset(last_state_zip, True)
        os.remove(last_state_zip)

    if logs_path:
        upload_dataset(logs_path, False)


upload()