<a href="https://colab.research.google.com/github/dhenypatungka/StableDiffusionNotebook/blob/main/kohya-LoRA-trainer-XL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# @markdown ## ♻ **Install SDXL**
#@markdown - Please use T4 for the hardware accelerator at runtime.
#@markdown - Prepare to upload photos with a maximum resolution of 2000 pixels.
#@markdown - For best training results, upload at least 30 photos with an aspect ratio of 1:1 (close-up), 8:10 (half body), or 4:6 (full body).

import os
import zipfile
import shutil
import time
import requests
import torch
from subprocess import getoutput
from IPython.utils import capture
from google.colab import drive

# Reload any variables persisted by earlier `%store` calls.
%store -r

# root_dir
# Every working directory used by the notebook is derived from `root_dir`
# (or from `repo_dir` / `training_dir`, which are themselves under it).
root_dir          = "/content"
drive_dir         = os.path.join(root_dir, "drive/MyDrive")
deps_dir          = os.path.join(root_dir, "deps")
repo_dir          = os.path.join(root_dir, "kohya-trainer")
training_dir      = os.path.join(root_dir, "LoRA")
pretrained_model  = os.path.join(root_dir, "pretrained_model")
vae_dir           = os.path.join(root_dir, "vae")
lora_dir          = os.path.join(root_dir, "network_weight")
repositories_dir  = os.path.join(root_dir, "repositories")
config_dir        = os.path.join(training_dir, "config")
tools_dir         = os.path.join(repo_dir, "tools")
finetune_dir      = os.path.join(repo_dir, "finetune")
accelerate_config = os.path.join(repo_dir, "accelerate_config/config.yaml")

# Persist the listed path variables with `%store` so later cells can restore
# them via `%store -r`. Output of the magic is captured and discarded.
# NOTE: `drive_dir`, `deps_dir` and `lora_dir` are not in this list.
for store in ["root_dir", "repo_dir", "training_dir", "pretrained_model", "vae_dir", "repositories_dir", "accelerate_config", "tools_dir", "finetune_dir", "config_dir"]:
    with capture.capture_output() as cap:
        %store {store}
        del cap

# Human-readable repository label -> git URL.
repo_dict = {
    "qaneel/kohya-trainer (forked repo, stable, optimized for colab use)" : "https://github.com/qaneel/kohya-trainer",
    "kohya-ss/sd-scripts (original repo, latest update)"                    : "https://github.com/kohya-ss/sd-scripts",
}

# The trainer repository is fixed here (not a form field); `branch` is the
# branch passed to `git clone -b`.
repository        = "qaneel/kohya-trainer (forked repo, stable, optimized for colab use)"
repo_url          = repo_dict[repository]
branch            = "main"
# When True, `mount_drive` mounts Google Drive and returns an output directory
# under `drive_dir`.
output_to_drive   = False  # @param {type: "boolean"}

def clone_repo(url, dir, branch):
    """Clone branch `branch` of the git repository `url` into `dir`.

    Nothing is done when `dir` already exists, so re-running the cell does
    not re-clone (and does not update an existing checkout either).
    """
    if not os.path.exists(dir):
       !git clone -b {branch} {url} {dir}

def mount_drive(dir):
    """Return the directory where training output should be written.

    With `output_to_drive` disabled this is `<training_dir>/output`.
    Otherwise Google Drive is mounted (only if `drive_dir` does not exist
    yet) and `<drive_dir>/kohya-trainer/output` is returned.

    The `dir` argument is accepted for interface compatibility but is not
    used; the global `drive_dir` is consulted instead.
    """
    if not output_to_drive:
        return os.path.join(training_dir, "output")

    drive_already_mounted = os.path.exists(drive_dir)
    if not drive_already_mounted:
        # The mount point is the parent of "MyDrive".
        drive.mount(os.path.dirname(drive_dir))

    return os.path.join(drive_dir, "kohya-trainer/output")

def setup_directories():
    """Resolve the global `output_dir` and create all working directories.

    `output_dir` is published as a module-level global. Directories that
    already exist are left untouched.
    """
    global output_dir

    output_dir = mount_drive(drive_dir)

    required_dirs = (
        training_dir,
        config_dir,
        pretrained_model,
        vae_dir,
        repositories_dir,
        output_dir,
    )
    for path in required_dirs:
        os.makedirs(path, exist_ok=True)

def pastebin_reader(id):
    """Fetch a Pastebin paste and return its text split on newlines.

    `id` may be a bare paste id or a pastebin.com URL. A URL that does not
    contain the substring 'raw' is rewritten to its `/raw` form first.

    Raises whatever `requests` raises for a non-success HTTP status
    (`raise_for_status`).
    """
    if "pastebin.com" not in id:
        raw_url = "https://pastebin.com/raw/" + id
    elif 'raw' in id:
        raw_url = id
    else:
        raw_url = id.replace('pastebin.com', 'pastebin.com/raw')

    reply = requests.get(raw_url)
    reply.raise_for_status()
    return reply.text.split('\n')

def install_repository():
    """Clone the two auxiliary repositories and install their requirements.

    Publishes `infinite_image_browser_dir`, `voldy` and
    `discordia_archivum_dir` as module-level globals.
    """
    global infinite_image_browser_dir, voldy, discordia_archivum_dir

    # Only the first two lines of the paste are used; the first is discarded
    # and the second becomes `voldy`. Unpacking fails with ValueError if the
    # paste has fewer than two lines.
    # NOTE(review): `voldy` is fetched at runtime and interpolated into the
    # clone URL below, so the paste content decides which repository is cloned.
    _, voldy = pastebin_reader("kq6ZmHFU")[:2]

    infinite_image_browser_url  = f"https://github.com/zanllp/{voldy}-infinite-image-browsing.git"
    infinite_image_browser_dir  = os.path.join(repositories_dir, f"infinite-image-browsing")
    infinite_image_browser_deps = os.path.join(infinite_image_browser_dir, "requirements.txt")

    discordia_archivum_url = "https://github.com/Linaqruf/discordia-archivum"
    discordia_archivum_dir = os.path.join(repositories_dir, "discordia-archivum")
    discordia_archivum_deps = os.path.join(discordia_archivum_dir, "requirements.txt")

    # `clone_repo` skips the clone when the target directory already exists.
    clone_repo(infinite_image_browser_url, infinite_image_browser_dir, "main")
    clone_repo(discordia_archivum_url, discordia_archivum_dir, "main")

    !pip install -q --upgrade -r {infinite_image_browser_deps}
    !pip install python-dotenv
    !pip install -q --upgrade -r {discordia_archivum_deps}

def install_dependencies():
    requirements_file = os.path.join(repo_dir, "requirements.txt")
    model_util        = os.path.join(repo_dir, "library/model_util.py")
    gpu_info          = getoutput('nvidia-smi')
    t4_xformers_wheel = "https://github.com/Linaqruf/colab-xformers/releases/download/0.0.20/xformers-0.0.20+1d635e1.d20230519-cp310-cp310-linux_x86_64.whl"

    !apt install aria2 lz4
    !wget https://github.com/camenduru/gperftools/releases/download/v1.0/libtcmalloc_minimal.so.4 -O /content/libtcmalloc_minimal.so.4
    !pip install -q --upgrade -r {requirements_file}

    !pip install -q xformers==0.0.22.post7

    from accelerate.utils import write_basic_config

    if not os.path.exists(accelerate_config):
        write_basic_config(save_location=accelerate_config)

def prepare_environment():
    """Set the process environment variables used by the training scripts.

    Sets `LD_PRELOAD` to the tcmalloc library path under /content and
    assigns fixed values to `TF_CPP_MIN_LOG_LEVEL`, `SAFETENSORS_FAST_GPU`
    and `PYTHONWARNINGS`.
    """
    os.environ.update(
        {
            "LD_PRELOAD": "/content/libtcmalloc_minimal.so.4",
            "TF_CPP_MIN_LOG_LEVEL": "3",
            "SAFETENSORS_FAST_GPU": "1",
            "PYTHONWARNINGS": "ignore",
        }
    )

def main():
    """Run the install cell end to end.

    The order matters: the trainer repository must be cloned before the
    working directory is switched into it, and dependencies are installed
    from files inside the cloned repositories.
    """
    os.chdir(root_dir)
    clone_repo(repo_url, repo_dir, branch)
    # Subsequent steps run with the trainer checkout as working directory.
    os.chdir(repo_dir)
    setup_directories()
    install_repository()
    install_dependencies()
    prepare_environment()

main()

In [None]:
# @markdown ## ♻ **Download SDXL**
import os
import re
import json
import glob
import gdown
import requests
import subprocess
from IPython.utils import capture
from urllib.parse import urlparse, unquote
from pathlib import Path
from huggingface_hub import HfFileSystem
from huggingface_hub.utils import validate_repo_id, HfHubHTTPError

# Restore the path variables stored by the install cell (root_dir, vae_dir, ...).
%store -r

os.chdir(root_dir)


# NOTE(review): this access token is hard-coded in the notebook and is sent as
# a Bearer header by `aria2_download`. Committed credentials should be revoked
# and read from a secret/environment variable instead.
HUGGINGFACE_TOKEN     = "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"
# When True, `main` first tries to use the model as a diffusers repository
# before falling back to downloading a single checkpoint file.
LOAD_DIFFUSERS_MODEL  = True
SDXL_MODEL_URL        = "stabilityai/stable-diffusion-xl-base-1.0" # @param ["stabilityai/stable-diffusion-xl-base-1.0", "PASTE MODEL URL OR GDRIVE PATH HERE"] {allow-input: true}
SDXL_VAE_URL          = "Original VAE" # @param ["None", "Original VAE", "FP16 VAE", "PASTE VAE URL OR GDRIVE PATH HERE"] {allow-input: true}

# Form preset -> direct download URL.
MODEL_URLS = {
    "stabilityai/stable-diffusion-xl-base-1.0" : "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors",
}
VAE_URLS = {
    "None"                    : "",
    "Original VAE"           : "https://huggingface.co/stabilityai/sdxl-vae/resolve/main/sdxl_vae.safetensors",
    "FP16 VAE"           : "https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/resolve/main/sdxl_vae.safetensors"
}

# Translate a preset name to its URL; any other input (a pasted URL or path)
# is kept unchanged.
SDXL_MODEL_URL = MODEL_URLS.get(SDXL_MODEL_URL, SDXL_MODEL_URL)
SDXL_VAE_URL = VAE_URLS.get(SDXL_VAE_URL, SDXL_VAE_URL)

def get_filename(url):
    """Work out the file name a download from `url` should be saved under.

    URLs that already end in a known checkpoint extension are answered from
    the URL alone. Otherwise the server is queried and the name is taken
    from the `Content-Disposition` header, falling back to the last path
    component of the URL when the header is absent or carries no filename.

    Raises whatever `requests` raises for a non-success HTTP status.
    """
    checkpoint_exts = (".ckpt", ".safetensors", ".pt", ".pth")
    if url.endswith(checkpoint_exts):
        return os.path.basename(url)

    # stream=True avoids downloading the body; the context manager releases
    # the connection once the headers have been read.
    with requests.get(url, stream=True) as response:
        response.raise_for_status()
        disposition = response.headers.get('content-disposition', '')

    # Stop at ';' as well as '"' so trailing header parameters are not
    # swallowed into the name.
    match = re.search(r'filename="?([^";]+)"?', disposition)
    if match:
        return match.group(1).strip()

    return unquote(os.path.basename(urlparse(url).path))

def aria2_download(dir, filename, url):
    """Download `url` to `<dir>/<filename>` with aria2c.

    For huggingface.co URLs the configured `HUGGINGFACE_TOKEN` is sent as a
    Bearer authorization header. The aria2c exit status is not checked.
    """
    aria2_args = [
        "aria2c",
        "--console-log-level=error",
        "--summary-interval=10",
    ]

    # Only add the header when it applies. The previous version passed an
    # empty-string argument to aria2c for every non-Hugging-Face URL.
    if "huggingface.co" in url and HUGGINGFACE_TOKEN:
        aria2_args.append(f"--header=Authorization: Bearer {HUGGINGFACE_TOKEN}")

    aria2_args += [
        "--continue=true",
        "--max-connection-per-server=16",
        "--min-split-size=1M",
        "--split=16",
        f"--dir={dir}",
        f"--out={filename}",
        url,
    ]
    subprocess.run(aria2_args)

def download(url, dst):
    """Download `url` into directory `dst` and return the resulting path.

    Google Drive links are fetched with gdown; everything else goes through
    `aria2_download`, with Hugging Face `/blob/` links rewritten to
    `/resolve/` first. The file name is determined by `get_filename` from
    the URL as given.
    """
    print(f"Starting downloading from {url}")
    filename = get_filename(url)
    filepath = os.path.join(dst, filename)

    if "drive.google.com" in url:
        gdown.download(url, filepath, quiet=False)
    else:
        is_hf_blob_link = "huggingface.co" in url and "/blob/" in url
        source_url = url.replace("/blob/", "/resolve/") if is_hf_blob_link else url
        aria2_download(dst, filename, source_url)

    print(f"Download finished: {filepath}")
    return filepath

def all_folders_present(base_model_url, sub_folders):
    """Return True when every name in `sub_folders` is listed under the repo.

    `base_model_url` is passed to `HfFileSystem.ls` as-is. If the listing
    fails (for example the repository does not exist or the value is not a
    valid repo id) the function reports the problem and returns False, so
    the caller can fall back to a plain file download instead of crashing.
    """
    fs = HfFileSystem()
    try:
        existing_folders = set(fs.ls(base_model_url, detail=False))
    except (OSError, ValueError, NotImplementedError, HfHubHTTPError) as err:
        print(f"Could not list '{base_model_url}' on the Hugging Face Hub: {err}")
        return False

    return all(
        f"{base_model_url}/{folder}" in existing_folders
        for folder in sub_folders
    )

def get_total_ram_gb():
    """Return the `MemTotal` value from /proc/meminfo divided by 1024**2.

    Returns None when no line containing "MemTotal" is found.
    """
    with open('/proc/meminfo', 'r') as meminfo:
        mem_total_line = next(
            (entry for entry in meminfo if "MemTotal" in entry),
            None,
        )

    if mem_total_line is None:
        return None

    # The second whitespace-separated field is the numeric amount.
    amount = int(mem_total_line.split()[1])
    return amount / (1024**2)  # Convert to GB

def get_gpu_name():
    """Return the GPU name reported by `nvidia-smi`, or None if unavailable.

    None is returned when `nvidia-smi` is missing, exits with an error, or
    prints output that is not ASCII.
    """
    command = ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader,nounits"]
    try:
        output = subprocess.check_output(command)
    except (OSError, subprocess.CalledProcessError):
        # OSError: binary not found / not executable.
        # CalledProcessError: the tool ran but reported a failure.
        return None

    try:
        return output.decode('ascii').strip()
    except UnicodeDecodeError:
        return None

def main():
    """Resolve and download the SDXL model and VAE selected in the form.

    Publishes `model_path` and `vae_path` as module-level globals. For the
    model, when `LOAD_DIFFUSERS_MODEL` is set, a Hugging Face repository
    that contains all diffusers sub-folders is used directly (its repo id
    becomes `model_path`); otherwise the file is downloaded.
    """
    global model_path, vae_path, LOAD_DIFFUSERS_MODEL

    model_path, vae_path = None, None

    required_sub_folders = [
        'scheduler',
        'text_encoder',
        'text_encoder_2',
        'tokenizer',
        'tokenizer_2',
        'unet',
        'vae',
    ]

    download_targets = {
        "model": (SDXL_MODEL_URL, pretrained_model),
        "vae": (SDXL_VAE_URL, vae_dir),
    }

    total_ram = get_total_ram_gb()
    gpu_name = get_gpu_name()

    # Check hardware constraints. `total_ram` is None when it could not be
    # determined; skip the check then instead of comparing None with a number.
    low_ram = total_ram is not None and total_ram < 13
    if low_ram and gpu_name in ("Tesla T4", "Tesla V100"):
        print("Attempt to load diffusers model instead due to hardware constraints.")
        LOAD_DIFFUSERS_MODEL = True

    for target, (url, dst) in download_targets.items():
        # Skip empty selections and the untouched "PASTE ... HERE" placeholder.
        if not url or url.startswith(f"PASTE {target.upper()} URL OR GDRIVE PATH HERE"):
            continue

        if target == "model" and LOAD_DIFFUSERS_MODEL:
            # Derive "<user>/<repo>" from a Hugging Face URL, but keep `url`
            # intact: it is still needed for the fallback download below.
            repo_id = url
            if 'huggingface.co' in url:
                match = re.search(r'huggingface\.co/([^/]+)/([^/]+)', url)
                if match:
                    repo_id = f"{match.group(1)}/{match.group(2)}"

            if all_folders_present(repo_id, required_sub_folders):
                print(f"Diffusers model is loaded : {repo_id}")
                model_path = repo_id
            else:
                print("Repository doesn't exist or no diffusers model detected.")
                # Continue with the regular download of the original URL.
                model_path = download(url, dst)
        else:
            filepath = download(url, dst)

            if target == "model":
                model_path = filepath
            elif target == "vae":
                vae_path = filepath

        print()

    if model_path:
        print(f"Selected model: {model_path}")

    if vae_path:
        print(f"Selected VAE: {vae_path}")

main()

In [None]:
# @markdown ## ♻ **Directory Config**
# @markdown Specify the location of your training data in the following cell. A folder with the same name as your input will be created.
import os

# Restore variables stored by earlier cells.
%store -r

train_data_dir = "/content/LoRA/train_data"  # @param {'type' : 'string'}
# Persist the chosen directory so later cells can restore it with `%store -r`.
%store train_data_dir

# Create the directory (and any missing parents); an existing one is kept.
os.makedirs(train_data_dir, exist_ok=True)
print(f"Your train data directory : {train_data_dir}")

In [None]:
# @markdown ## ♻ **Unzip Dataset**

import os
import subprocess
import re
from urllib.parse import unquote
import requests
from zipfile import ZipFile

# @markdown If your dataset is in a `zip` file and has been uploaded to a location, use this section to extract it.
# @markdown The dataset will be downloaded and automatically extracted to `train_data_dir` if `unzip_to` is empty.

# Source of the archive: per `download_dataset`, a path starting with
# "/content", a Google Drive link, or any other URL.
zipfile_url = ""  # @param {type:"string"}
# Extraction target; when left empty, `train_data_dir` is used.
unzip_to = ""  # @param {type:"string"}
# NOTE(review): this access token is hard-coded in the notebook and is sent as
# a Bearer header by `download_with_aria2c`. Committed credentials should be
# revoked and read from a secret/environment variable instead.
hf_token = "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"

# Download with aria2c when True, otherwise with `requests`.
use_aria2c = True # @param {type:"boolean"}
# Keep the archive's folder structure when True; flatten all files otherwise.
preserve_folders = True # @param {type:"boolean"}
# Delete the archive after extraction (never for files under /content/drive).
remove_after_unzipping = False # @param {type:"boolean"}

# Rewrite a Hugging Face `/blob/` link to its `/resolve/` download form.
# Match the full path segment: a bare "blob" substring can also occur inside
# a user, repository or file name and must not be rewritten.
if "huggingface.co" in zipfile_url and "/blob/" in zipfile_url:
    zipfile_url = zipfile_url.replace("/blob/", "/resolve/")

# Default the extraction target to the training data directory.
if not unzip_to:
    unzip_to = train_data_dir

def get_filename_from_url(url):
    """Return the file name to use for the archive referenced by `url`.

    Hugging Face URLs and paths containing "/content/" are answered from the
    last path component. For anything else a HEAD request is made and the
    name is read from the `Content-Disposition` header; "zipfile.zip" is
    returned when the request fails or no filename is advertised.
    """
    # Both membership tests are needed; `"huggingface.co" or ...` was always
    # truthy and made the header lookup below unreachable.
    if "huggingface.co" in url or "/content/" in url:
        return os.path.basename(url)

    try:
        response = requests.head(url, allow_redirects=True)
    except requests.RequestException:
        return "zipfile.zip"

    cd = response.headers.get('content-disposition')
    if cd:
        # Optional quotes are excluded and the match stops at ';' so that
        # further header parameters do not end up in the name.
        match = re.search(r'filename="?([^";]+)"?', cd)
        if match:
            return unquote(match.group(1).strip())

    return "zipfile.zip"

def download_with_requests(url, output_path):
    """Stream `url` to `output_path` using `requests` and return the path.

    Raises whatever `requests` raises for a non-success HTTP status, so an
    error page is never written to disk in place of the archive.
    """
    print(f"Downloading {url} with requests...")
    # The context manager closes the connection even if writing fails.
    with requests.get(url, stream=True) as response:
        response.raise_for_status()
        with open(output_path, 'wb') as file:
            for chunk in response.iter_content(chunk_size=8192):
                file.write(chunk)
    print(f"Downloaded to {output_path}")
    return output_path

def download_with_aria2c(url, output_path):
    """Download `url` to `output_path` with aria2c and return the path.

    For huggingface.co URLs `hf_token` is sent as a Bearer authorization
    header. The aria2c exit status is not checked.
    """
    print(f"Downloading {url} with aria2c...")

    # aria2c takes the directory and the file name as separate options.
    target_dir, target_name = os.path.split(output_path)
    command = [
        'aria2c',
        '--console-log-level=error',
        '--summary-interval=10',
        '--continue=true',
        '--max-connection-per-server=16',
        '--min-split-size=1M',
        '--split=16',
        f'--dir={target_dir}',
        f'--out={target_name}',
    ]
    if "huggingface.co" in url:
        command.append(f"--header=Authorization: Bearer {hf_token}")
    command.append(url)

    subprocess.run(command)
    print(f"Downloaded to {output_path}")
    return output_path

def move_files(train_dir):
    """Move `meta_*.json` files from `train_dir` into `training_dir`.

    A metadata file is moved only when `training_dir` does not already hold
    a file of the same name; otherwise the copy in `train_dir` is deleted.
    """
    for filename in os.listdir(train_dir):
        if not (filename.startswith("meta_") and filename.endswith(".json")):
            continue

        file_path = os.path.join(train_dir, filename)
        # Check the destination. Testing the source path (which always
        # exists, it was just listed) meant every metadata file was deleted
        # and none was ever moved.
        destination = os.path.join(training_dir, filename)
        if not os.path.exists(destination):
            shutil.move(file_path, destination)
        else:
            os.remove(file_path)

def remove_empty_dirs(path):
    """Delete every empty directory below `path` and report each removal.

    The tree is walked bottom-up, so a directory whose only content was
    empty sub-directories is removed as well. `path` itself is kept.
    """
    for parent, subdirs, _ in os.walk(path, topdown=False):
        for name in subdirs:
            candidate = os.path.join(parent, name)
            if os.listdir(candidate):
                # Still has content; leave it alone.
                continue
            os.rmdir(candidate)
            print(f"Removed empty directory: {candidate}")

def extract_dataset(zip_file, output_path):
    """Extract the archive `zip_file` into `output_path`.

    With `preserve_folders` set the archive is extracted as-is. Otherwise
    each file is extracted and then moved directly into `output_path`
    (its directory part dropped), after which the emptied directories are
    removed.
    """
    with ZipFile(zip_file, 'r') as archive:
        print(f"Extracting {zip_file} to {output_path}...")

        if preserve_folders:
            archive.extractall(output_path)
        else:
            for entry in archive.namelist():
                flat_name = os.path.basename(entry)
                if not flat_name:
                    # Directory entries have an empty base name.
                    continue
                archive.extract(entry, output_path)
                extracted_at = os.path.join(output_path, entry)
                flattened_to = os.path.join(output_path, flat_name)
                os.rename(extracted_at, flattened_to)

            remove_empty_dirs(output_path)

        print("Extraction completed!")

def download_dataset(url, output_path):
    """Make the archive at `url` available locally and return its path.

    A `url` starting with "/content" is treated as an existing local file
    and returned unchanged. Google Drive links are fetched with the `gdown`
    command; all other URLs are downloaded with aria2c or `requests`,
    depending on `use_aria2c`.
    """
    if url.startswith("/content"):
        print(f"Using file at {url}")
        return url

    if "drive.google.com" in url:
        print("Downloading from Google Drive...")
        # The file id is taken to be the second-to-last path component.
        file_id = url.split('/')[-2]
        subprocess.run(['gdown', '--id', file_id, '-O', output_path])
        return output_path

    if use_aria2c:
        return download_with_aria2c(url, output_path)
    return download_with_requests(url, output_path)

def main():
    """Download (if needed) and extract the dataset archive.

    Does nothing when `zipfile_url` is empty, which is the form's default.
    After extraction, metadata files are relocated by `move_files` and the
    archive is optionally deleted (never when it lives under /content/drive).
    """
    if not zipfile_url:
        # Nothing to fetch; without this guard the cell failed while trying
        # to download and open an archive from an empty URL.
        print("No `zipfile_url` provided; skipping dataset download and extraction.")
        return

    zipfile_name = get_filename_from_url(zipfile_url)
    output_path = os.path.join(root_dir, zipfile_name)

    zip_file = download_dataset(zipfile_url, output_path)

    extract_dataset(zip_file, unzip_to)

    move_files(unzip_to)

    if remove_after_unzipping and "/content/drive" not in zip_file:
        os.remove(zip_file)
        print(f"Removed {zip_file}")

main()


In [None]:
# @markdown ## ♻ **Data Cleaning**
import os
import random
import concurrent.futures
from tqdm import tqdm
from PIL import Image

# Restore variables stored by earlier cells (root_dir, train_data_dir, ...).
%store -r

os.chdir(root_dir)

# NOTE: `test` is not referenced again in this cell; the call does fail early
# if `train_data_dir` does not exist.
test = os.listdir(train_data_dir)
#@markdown This section removes unsupported media types such as `.mp4`, `.webm`, and `.gif`, as well as any unnecessary files.
#@markdown To convert a transparent dataset with an alpha channel (RGBA) to RGB and give it a white background, set the `convert` parameter to `True`.
convert = False  # @param {type:"boolean"}
#@markdown Alternatively, you can give the background a `random_color` instead of white by checking the corresponding option.
random_color = False  # @param {type:"boolean"}
# When True, `clean_directory` also descends into sub-directories.
recursive = False

# Number of images handed to the thread pool per progress-bar step.
batch_size = 32
# File extensions that `clean_directory` keeps; everything else is deleted.
supported_types = [
    ".png",
    ".jpg",
    ".jpeg",
    ".webp",
    ".bmp",
    ".caption",
    ".npz",
    ".txt",
    ".json",
]

# RGB background candidates used by `process_image` when `random_color` is set.
background_colors = [
    (255, 255, 255),
    (0, 0, 0),
    (255, 0, 0),
    (0, 255, 0),
    (0, 0, 255),
    (255, 255, 0),
    (255, 0, 255),
    (0, 255, 255),
]

def clean_directory(directory):
    """Delete files in `directory` whose extension is not in `supported_types`.

    Extensions are compared case-insensitively, so files such as `IMG.JPG`
    are kept. Sub-directories are processed only when `recursive` is set.
    """
    for item in os.listdir(directory):
        file_path = os.path.join(directory, item)
        if os.path.isfile(file_path):
            # `supported_types` is all lower-case; normalise before comparing
            # so upper-case extensions are not mistaken for unsupported files.
            file_ext = os.path.splitext(item)[1].lower()
            if file_ext not in supported_types:
                print(f"Deleting file {item} from {directory}")
                os.remove(file_path)
        elif os.path.isdir(file_path) and recursive:
            clean_directory(file_path)

def process_image(image_path):
    """Normalise one image file in place.

    Images in mode "RGBA" or "LA" are composited onto a solid background
    (white, or a random entry of `background_colors` when `random_color` is
    set). Files named `*.webp` are written out as JPEG under the same name
    with a `.jpg` suffix and the original is removed; every other file is
    re-saved as PNG at its original path.
    """
    img_dir, image_name = os.path.split(image_path)
    is_webp = image_name.endswith(".webp")
    if is_webp:
        # Swap only the trailing suffix; str.replace would also rewrite a
        # ".webp" occurring earlier in the name.
        new_image_path = os.path.join(img_dir, image_name[: -len(".webp")] + ".jpg")

    # The context manager closes the underlying file once saving is done.
    with Image.open(image_path) as img:
        has_alpha = img.mode in ("RGBA", "LA")
        if has_alpha:
            if random_color:
                background_color = random.choice(background_colors)
            else:
                background_color = (255, 255, 255)
            result = Image.new("RGB", img.size, background_color)
            # The last band of RGBA/LA is the alpha channel; use it as mask.
            result.paste(img, mask=img.split()[-1])
        else:
            result = img

        if is_webp:
            result.save(new_image_path, "JPEG")
        else:
            result.save(image_path, "PNG")

    if is_webp:
        os.remove(image_path)
        print(f" Converted image: {image_name} to {os.path.basename(new_image_path)}")
    elif has_alpha:
        print(f" Converted image: {image_name}")

def find_images(directory):
    """Return the paths of all `.png` and `.webp` files below `directory`.

    The whole tree is searched; paths come back in `os.walk` order.
    """
    wanted_suffixes = (".png", ".webp")
    return [
        os.path.join(folder, name)
        for folder, _, names in os.walk(directory)
        for name in names
        if name.endswith(wanted_suffixes)
    ]

# Remove unsupported files first, then collect the images to convert.
clean_directory(train_data_dir)
images = find_images(train_data_dir)
num_batches = len(images) // batch_size + 1

if convert:
    failed = []
    with concurrent.futures.ThreadPoolExecutor() as executor:
        for i in tqdm(range(num_batches)):
            start = i * batch_size
            end = start + batch_size
            batch = images[start:end]
            # Wait for each batch and inspect every result. The results of
            # `executor.map` were never consumed before, so an exception
            # raised by `process_image` disappeared without a trace.
            futures = {executor.submit(process_image, path): path for path in batch}
            for future in concurrent.futures.as_completed(futures):
                error = future.exception()
                if error is not None:
                    failed.append((futures[future], error))

    for path, error in failed:
        print(f"Failed to convert {path}: {error}")

    if failed:
        print(f"{len(failed)} image(s) could not be converted")
    else:
        print("All images have been converted")

In [None]:
#@markdown ### ♻ **BLIP Captioning**
#@markdown BLIP is a pre-training framework for unified vision-language understanding and generation, which achieves state-of-the-art results on a wide range of vision-language tasks. It can be used as a tool for image captioning, for example, `astronaut riding a horse in space`.
import os

# NOTE: this cell has no `%store -r`; `finetune_dir` and `train_data_dir` must
# already exist in the kernel from an earlier cell.
os.chdir(finetune_dir)

beam_search = True #@param {type:'boolean'}
min_length = 5 #@param {type:"slider", min:0, max:100, step:5.0}
max_length = 75 #@param {type:"slider", min:0, max:100, step:5.0}

# Options for make_captions.py. Keys starting with "_" become positional
# arguments; all other keys become `--key` options (see the loop below).
config = {
    "_train_data_dir"   : train_data_dir,
    "batch_size"        : 8,
    "beam_search"       : beam_search,
    "min_length"        : min_length,
    "max_length"        : max_length,
    "debug"             : True,
    "caption_extension" : ".caption",
    "max_data_loader_n_workers" : 2,
    "recursive"         : True
}

# Render `config` as a command-line string:
#   "_key"        -> "value"        (positional, quoted)
#   str           -> --key="value"
#   True          -> --key          (False is omitted)
#   float / int   -> --key=value
# Values of any other type (e.g. None) are skipped.
args = ""
for k, v in config.items():
    if k.startswith("_"):
        args += f'"{v}" '
    elif isinstance(v, str):
        args += f'--{k}="{v}" '
    elif isinstance(v, bool) and v:
        args += f"--{k} "
    elif isinstance(v, float) and not isinstance(v, bool):
        args += f"--{k}={v} "
    elif isinstance(v, int) and not isinstance(v, bool):
        args += f"--{k}={v} "

final_args = f"python make_captions.py {args}"

# The script is invoked by relative name, so it is run from `finetune_dir`.
os.chdir(finetune_dir)
!{final_args}

In [None]:
#@markdown ### ♻ **Waifu Diffusion 1.4 Tagger V2**
import os
# Restore variables stored by earlier cells (finetune_dir, train_data_dir, ...).
%store -r

os.chdir(finetune_dir)

# Passed to the tagger script as `--repo_id`.
model = "SmilingWolf/wd-v1-4-moat-tagger-v2"
undesired_tags = "" #@param {type:'string'}
general_threshold = 0.35 #@param {type:"slider", min:0, max:1, step:0.05}
character_threshold = 0.85 #@param {type:"slider", min:0, max:1, step:0.05}

# Options for tag_images_by_wd14_tagger.py. Keys starting with "_" become
# positional arguments; all other keys become `--key` options.
config = {
    "_train_data_dir"           : train_data_dir,
    "batch_size"                : 8,
    "repo_id"                   : model,
    "recursive"                 : True,
    "remove_underscore"         : True,
    "general_threshold"         : general_threshold,
    "character_threshold"       : character_threshold,
    "caption_extension"         : ".txt",
    "max_data_loader_n_workers" : 2,
    "debug"                     : True,
    "undesired_tags"            : undesired_tags
}

# Render `config` as a command-line string:
#   "_key"        -> "value"        (positional, quoted)
#   str           -> --key="value"  (an empty string still yields --key="")
#   True          -> --key          (False is omitted)
#   float / int   -> --key=value
args = ""
for k, v in config.items():
    if k.startswith("_"):
        args += f'"{v}" '
    elif isinstance(v, str):
        args += f'--{k}="{v}" '
    elif isinstance(v, bool) and v:
        args += f"--{k} "
    elif isinstance(v, float) and not isinstance(v, bool):
        args += f"--{k}={v} "
    elif isinstance(v, int) and not isinstance(v, bool):
        args += f"--{k}={v} "

final_args = f"python tag_images_by_wd14_tagger.py {args}"

# The script is invoked by relative name, so it is run from `finetune_dir`.
os.chdir(finetune_dir)
!{final_args}

In [None]:
# @markdown ### ♻ **Custom Caption/Tag**
import os

# Restore variables stored by earlier cells (root_dir, train_data_dir, ...).
%store -r

os.chdir(root_dir)

# @markdown Add or remove custom tags here.
extension   = ".txt"  # @param [".txt", ".caption"]
custom_tag  = "white background, grey background, simple background"  # @param {type:"string"}
# @markdown Use `sub_folder` option to specify a subfolder for multi-concept training.
# @markdown > Specify `--all` to process all subfolders/`recursive`
sub_folder  = "" #@param {type: "string"}
# @markdown Enable this to append custom tags at the end of lines.
append      = False  # @param {type:"boolean"}
# @markdown Enable this if you want to remove captions/tags instead.
remove_tag  = True  # @param {type:"boolean"}
recursive   = False

# Resolve which directory to process:
#   ""            -> the training data directory itself
#   "--all"       -> the training data directory, including sub-directories
#   "/content..." -> that absolute path as given
#   anything else -> a sub-folder of the training data directory (created if
#                    it does not exist)
if sub_folder == "":
    image_dir = train_data_dir
elif sub_folder == "--all":
    image_dir = train_data_dir
    recursive = True
elif sub_folder.startswith("/content"):
    image_dir = sub_folder
else:
    image_dir = os.path.join(train_data_dir, sub_folder)
    os.makedirs(image_dir, exist_ok=True)

def read_file(filename):
    """Return the complete text content of `filename`."""
    with open(filename, "r") as handle:
        return handle.read()

def write_file(filename, contents):
    """Replace the content of `filename` with the text `contents`."""
    with open(filename, "w") as handle:
        handle.write(contents)

def process_tags(filename, custom_tag, append, remove_tag):
    """Add or remove comma-separated tags in the caption file `filename`.

    `custom_tag` is a comma-separated list; underscores in each tag are
    replaced by spaces. With `remove_tag` set, every occurrence of each
    tag is removed. Otherwise tags not yet present are added: at the end
    when `append` is set, else at the front. The file is rewritten with
    tags joined by ", ".
    """
    contents = read_file(filename)
    # Drop empty entries. An empty caption file (or a stray comma) used to
    # yield an empty tag and therefore a dangling ", " in the output.
    tags = [tag.strip() for tag in contents.split(',') if tag.strip()]
    requested_tags = [
        tag.strip().replace("_", " ")
        for tag in custom_tag.split(',')
        if tag.strip()
    ]

    if remove_tag:
        unwanted = set(requested_tags)
        tags = [tag for tag in tags if tag not in unwanted]
    else:
        for new_tag in requested_tags:
            if new_tag in tags:
                continue
            if append:
                tags.append(new_tag)
            else:
                # Each tag is inserted at the very front, so several new
                # tags end up in reverse order of `custom_tag`.
                tags.insert(0, new_tag)

    write_file(filename, ', '.join(tags))

def process_directory(image_dir, tag, append, remove_tag, recursive):
    """Apply `process_tags` to every caption file in `image_dir`.

    Caption files are those whose name ends with the global `extension`.
    Sub-directories are processed too when `recursive` is set; otherwise
    they are treated like any other entry and matched by name only.
    """
    for entry in os.listdir(image_dir):
        entry_path = os.path.join(image_dir, entry)

        descend = os.path.isdir(entry_path) and recursive
        if descend:
            process_directory(entry_path, tag, append, remove_tag, recursive)
            continue

        if entry.endswith(extension):
            process_tags(entry_path, tag, append, remove_tag)

tag = custom_tag

# If the directory holds no caption file at all, create an empty one next to
# every image so there is something to add the custom tags to.
if not any(
    filename.endswith(extension) for filename in os.listdir(image_dir)
):
    for filename in os.listdir(image_dir):
        if filename.endswith((".png", ".jpg", ".jpeg", ".webp", ".bmp")):
            # splitext removes only the final suffix; splitting on the first
            # "." truncated names such as "photo.v2.png" to "photo".
            open(
                os.path.join(image_dir, os.path.splitext(filename)[0] + extension),
                "w",
            ).close()

if custom_tag:
    process_directory(image_dir, tag, append, remove_tag, recursive)

In [None]:
# @markdown ## ♻ **Bucketing and Latents Caching**
# Restore variables stored by earlier cells.
# NOTE: this cell imports nothing itself; `os`, `time`, `model_path` and
# `vae_path` must already exist in the kernel from earlier cells.
%store -r

# @markdown This code will create buckets based on the `bucket_resolution` provided for multi-aspect ratio training, and then convert all images within the `train_data_dir` to latents.
bucketing_json    = os.path.join(training_dir, "meta_lat.json")
metadata_json     = os.path.join(training_dir, "meta_clean.json")
bucket_resolution = 1024  # @param {type:"slider", min:512, max:1024, step:128}
mixed_precision   = "no"  # @param ["no", "fp16", "bf16"] {allow-input: false}
skip_existing     = False  # @param{type:"boolean"}
flip_aug          = False  # @param{type:"boolean"}
# @markdown Use `clean_caption` option to clean such as duplicate tags, `women` to `girl`, etc
clean_caption     = False #@param {type:"boolean"}
#@markdown Use the `recursive` option to process subfolders as well
recursive         = True #@param {type:"boolean"}

# Arguments for merge_all_to_metadata.py. Keys starting with "_" are
# positional; `full_path` follows the `recursive` setting.
metadata_config = {
    "_train_data_dir": train_data_dir,
    "_out_json": metadata_json,
    "recursive": recursive,
    "full_path": recursive,
    "clean_caption": clean_caption
}

# Arguments for prepare_buckets_latents.py. The metadata file written by the
# first script is the input here. The model argument is the VAE path when
# one was selected, otherwise the model path.
bucketing_config = {
    "_train_data_dir": train_data_dir,
    "_in_json": metadata_json,
    "_out_json": bucketing_json,
    "_model_name_or_path": vae_path if vae_path else model_path,
    "recursive": recursive,
    "full_path": recursive,
    "flip_aug": flip_aug,
    "skip_existing": skip_existing,
    "batch_size": 4,
    "max_data_loader_n_workers": 2,
    "max_resolution": f"{bucket_resolution}, {bucket_resolution}",
    "mixed_precision": mixed_precision,
}

def generate_args(config):
    """Render a config mapping as a command-line argument string.

    Rules, applied per key in mapping order:
      * key starting with "_"  -> the value, double-quoted (positional)
      * str value              -> --key="value"
      * True                   -> --key   (False is omitted)
      * float or int value     -> --key=value
    Values of any other type are skipped. Parts are joined by single spaces.
    """
    parts = []
    for key, value in config.items():
        if key.startswith("_"):
            parts.append(f'"{value}"')
        elif isinstance(value, str):
            parts.append(f'--{key}="{value}"')
        elif isinstance(value, bool):
            # bool must be tested before int: True/False are ints as well.
            if value:
                parts.append(f"--{key}")
        elif isinstance(value, (float, int)):
            parts.append(f"--{key}={value}")
    return " ".join(parts)

merge_metadata_args = generate_args(metadata_config)
prepare_buckets_args = generate_args(bucketing_config)

merge_metadata_command = f"python merge_all_to_metadata.py {merge_metadata_args}"
prepare_buckets_command = f"python prepare_buckets_latents.py {prepare_buckets_args}"

# Both scripts are invoked by relative name, so they are run from
# `finetune_dir`. The metadata merge must run first: its output JSON is the
# input of the bucketing step.
os.chdir(finetune_dir)
!{merge_metadata_command}
time.sleep(1)
!{prepare_buckets_command}


In [None]:
import toml

# @markdown ## ♻ **LoRa: Low-Rank Adaptation Config**
# Selects an entry of `network_config` below; must be one of its keys.
network_category = "LoRA_LierLa"
# Extra network arguments as the text of a Python list literal,
# e.g. '["args=1", "args=2"]'. Empty means none.
network_args    = ""
network_dim     = 32  # @param {'type':'number'}
network_alpha   = 16  # @param {'type':'number'}
# `conv_dim`, `conv_alpha` and `unit` are only used by the categories in
# `network_config` whose "args" reference them.
conv_dim        = 32  # @param {'type':'number'}
conv_alpha      = 16  # @param {'type':'number'}
unit = 4  # @param {'type':'number'}

# `ast.literal_eval` is used below, but this cell only imports `toml`.
# Import it here so a non-empty `network_args` does not fail with NameError.
import ast

# Turn `network_args` into a list. Accepted input is the text of a Python
# list literal; anything else is reported and replaced by an empty list.
if isinstance(network_args, str):
    network_args = network_args.strip()
    if network_args.startswith('[') and network_args.endswith(']'):
        try:
            network_args = ast.literal_eval(network_args)
        except (SyntaxError, ValueError) as e:
            print(f"Error parsing network_args: {e}\n")
            network_args = []
    elif len(network_args) > 0:
        print(f"WARNING! '{network_args}' is not a valid list! Put args like this: [\"args=1\", \"args=2\"]\n")
        network_args = []
    else:
        network_args = []
else:
    network_args = []

# Network category -> module name and the category-specific arguments that
# are appended to the user-supplied `network_args`.
network_config = {
    "LoRA_LierLa": {
        "module": "networks.lora",
        "args"  : []
    },
    "LoRA_C3Lier": {
        "module": "networks.lora",
        "args"  : [
            f"conv_dim={conv_dim}",
            f"conv_alpha={conv_alpha}"
        ]
    },
    "DyLoRA_LierLa": {
        "module": "networks.dylora",
        "args"  : [
            f"unit={unit}"
        ]
    },
    "DyLoRA_C3Lier": {
        "module": "networks.dylora",
        "args"  : [
            f"conv_dim={conv_dim}",
            f"conv_alpha={conv_alpha}",
            f"unit={unit}"
        ]
    },
    "LoCon": {
        "module": "lycoris.kohya",
        "args"  : [
            f"algo=locon",
            f"conv_dim={conv_dim}",
            f"conv_alpha={conv_alpha}"
        ]
    },
    "LoHa": {
        "module": "lycoris.kohya",
        "args"  : [
            f"algo=loha",
            f"conv_dim={conv_dim}",
            f"conv_alpha={conv_alpha}"
        ]
    },
    "IA3": {
        "module": "lycoris.kohya",
        "args"  : [
            f"algo=ia3",
            f"conv_dim={conv_dim}",
            f"conv_alpha={conv_alpha}"
        ]
    },
    "LoKR": {
        "module": "lycoris.kohya",
        "args"  : [
            f"algo=lokr",
            f"conv_dim={conv_dim}",
            f"conv_alpha={conv_alpha}"
        ]
    },
    "DyLoRA_Lycoris": {
        "module": "lycoris.kohya",
        "args"  : [
            f"algo=dylora",
            f"conv_dim={conv_dim}",
            f"conv_alpha={conv_alpha}"
        ]
    }
}

# Raises KeyError if `network_category` is not a key of `network_config`.
network_module = network_config[network_category]["module"]
# User-supplied arguments come first, category defaults are appended after.
network_args.extend(network_config[network_category]["args"])

lora_config = {
    "additional_network_arguments": {
        "no_metadata"                     : False,
        "network_module"                  : network_module,
        "network_dim"                     : network_dim,
        "network_alpha"                   : network_alpha,
        "network_args"                    : network_args,
        "network_train_unet_only"         : True,
        "training_comment"                : None,
    },
}

# Show the resulting TOML section for review.
print(toml.dumps(lora_config))

In [None]:
import toml
import ast

# @markdown ## ♻ **Optimizer Config**
optimizer_type = "AdaFactor"  # @param ["AdamW", "AdamW8bit", "Lion8bit", "Lion", "SGDNesterov", "SGDNesterov8bit", "DAdaptation(DAdaptAdamPreprint)", "DAdaptAdaGrad", "DAdaptAdam", "DAdaptAdan", "DAdaptAdanIP", "DAdaptLion", "DAdaptSGD", "AdaFactor"]
# Extra optimizer arguments as the text of a Python list literal; parsed into
# a list further below.
optimizer_args = "[ \"scale_parameter=False\", \"relative_step=False\", \"warmup_init=False\" ]"
learning_rate = 1e-4  # @param {'type':'number'}
lr_scheduler = "constant_with_warmup"  # @param ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup", "adafactor"] {allow-input: false}
lr_warmup_steps = 100  # @param {'type':'number'}
# Used as `lr_scheduler_num_cycles` for "cosine_with_restarts" and as
# `lr_scheduler_power` for "polynomial"; ignored for other schedulers.
lr_scheduler_num = 0  # @param {'type':'number'}

# Turn `optimizer_args` into a list. Accepted input is the text of a Python
# list literal; anything else is reported and replaced by an empty list.
if not isinstance(optimizer_args, str):
    optimizer_args = []
else:
    optimizer_args = optimizer_args.strip()
    if not optimizer_args:
        optimizer_args = []
    elif optimizer_args.startswith('[') and optimizer_args.endswith(']'):
        try:
            optimizer_args = ast.literal_eval(optimizer_args)
        except (SyntaxError, ValueError) as err:
            print(f"Error parsing optimizer_args: {err}\n")
            optimizer_args = []
    else:
        print(f"WARNING! '{optimizer_args}' is not a valid list! Put args like this: [\"args=1\", \"args=2\"]\n")
        optimizer_args = []

optimizer_config = {
    "optimizer_arguments": {
        "optimizer_type"          : optimizer_type,
        "learning_rate"           : learning_rate,
        "max_grad_norm"           : 0,
        "optimizer_args"          : optimizer_args,
        "lr_scheduler"            : lr_scheduler,
        "lr_warmup_steps"         : lr_warmup_steps,
        # `lr_scheduler_num` feeds exactly one of the next two settings,
        # depending on the chosen scheduler; the other stays None.
        "lr_scheduler_num_cycles" : lr_scheduler_num if lr_scheduler == "cosine_with_restarts" else None,
        "lr_scheduler_power"      : lr_scheduler_num if lr_scheduler == "polynomial" else None,
        "lr_scheduler_type"       : None,
        "lr_scheduler_args"       : None,
    },
}

# Show the resulting TOML section for review.
print(toml.dumps(optimizer_config))


In [None]:
# @markdown ## ♻ **Training Config**
import toml
import os
from subprocess import getoutput

%store -r

# @markdown ### Project Config
project_name                = "sdxl_lora"  # @param {type:"string"}
# An empty key selects tensorboard logging further down in this cell.
wandb_api_key = ""
in_json                     = "/content/LoRA/meta_lat.json"  # @param {type:"string"}

gradient_checkpointing = True
no_half_vae = True

cache_text_encoder_outputs  = False  # @param {type:"boolean"}
min_timestep                = 0 # @param {type:"number"}
max_timestep                = 1000 # @param {type:"number"}

num_repeats                 = 1  # @param {type:"number"}
resolution                  = 1024  # @param {type:"slider", min:512, max:1024, step:128}
keep_tokens                 = 0  # @param {type:"number"}

num_epochs                  = 10  # @param {type:"number"}
train_batch_size            = 4  # @param {type:"number"}
mixed_precision             = "fp16"  # @param ["no","fp16","bf16"] {allow-input: false}
seed                        = -1  # @param {type:"number"}
optimization = "scaled dot-product attention"

save_precision              = "fp16"  # @param ["float", "fp16", "bf16"] {allow-input: false}
save_every_n_epochs         = 1  # @param {type:"number"}
enable_sample = False
sampler = "euler_a"

# Each preset maps to a (positive, negative) prompt pair; an unknown preset
# name leaves both prompts empty.
quality_prompt = "NovelAI"
positive_prompt, negative_prompt = {
    "NovelAI": (
        "masterpiece, best quality, ",
        "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, ",
    ),
    "AbyssOrangeMix": (
        "masterpiece, best quality, ",
        "(worst quality, low quality:1.4), ",
    ),
    "Stable Diffusion XL": (
        "",
        "3d render, smooth, plastic, blurry, grainy, low-resolution, deep-fried, oversaturated",
    ),
}.get(quality_prompt, ("", ""))

# The hand-written prompt is only kept while prompts are not taken from
# caption files (prompt_from_caption == "none").
prompt_from_caption = "none"
custom_prompt = (
    "face focus, cute, 1girl, green hair, sweater, looking at viewer, upper body, beanie, outdoors, night, turtleneck"
    if prompt_from_caption == "none"
    else ""
)
num_prompt = 2
logging_dir                 = os.path.join(training_dir, "logs")

# Flag the machine as low-RAM when total memory is below 15 GiB.
# /proc/meminfo reports MemTotal in kB, hence the division by 1024**2.
# The file is opened in a `with` block so the handle is closed deterministically
# instead of being left to the garbage collector.
with open('/proc/meminfo') as meminfo:
    lowram                  = int(next(line.split()[1] for line in meminfo if "MemTotal" in line)) / (1024**2) < 15

# Everything below runs with the trainer repository as working directory.
os.chdir(repo_dir)

# Base settings for sample-image generation; "subset" starts empty and is
# filled in by prompt_convert() when sampling is enabled.
prompt_config = {
    "prompt": {
        "negative_prompt" : negative_prompt,
        "width"           : resolution,
        "height"          : resolution,
        "scale"           : 12,
        "sample_steps"    : 28,
        "subset"          : [],
    }
}

# Training configuration, grouped into the sections that are later serialised
# to TOML. Values set to None here are placeholders for options that are not
# in use.
train_config = {
    "sdxl_arguments": {
        "cache_text_encoder_outputs" : cache_text_encoder_outputs,
        # Use the variable set at the top of the cell instead of a second literal.
        "no_half_vae"                : no_half_vae,
        "min_timestep"               : min_timestep,
        "max_timestep"               : max_timestep,
        # Caption shuffling is switched off whenever text-encoder outputs are cached.
        "shuffle_caption"            : not cache_text_encoder_outputs,
        "lowram"                     : lowram
    },
    "model_arguments": {
        "pretrained_model_name_or_path" : model_path,
        "vae"                           : vae_path,
    },
    "dataset_arguments": {
        "debug_dataset"                 : False,
        "in_json"                       : in_json,
        "train_data_dir"                : train_data_dir,
        "dataset_repeats"               : num_repeats,
        "keep_tokens"                   : keep_tokens,
        # Square training resolution, formatted as "W,H".
        "resolution"                    : f"{resolution},{resolution}",
        "color_aug"                     : False,
        "face_crop_aug_range"           : None,
        "token_warmup_min"              : 1,
        "token_warmup_step"             : 0,
    },
    "training_arguments": {
        "output_dir"                    : os.path.join(output_dir, project_name),
        "output_name"                   : project_name if project_name else "last",
        "save_precision"                : save_precision,
        "save_every_n_epochs"           : save_every_n_epochs,
        "save_n_epoch_ratio"            : None,
        "save_last_n_epochs"            : None,
        "resume"                        : None,
        "train_batch_size"              : train_batch_size,
        "max_token_length"              : 225,
        "mem_eff_attn"                  : False,
        # Exactly one attention backend flag is set, chosen by `optimization`.
        "sdpa"                          : optimization == "scaled dot-product attention",
        "xformers"                      : optimization == "xformers",
        "max_train_epochs"              : num_epochs,
        "max_data_loader_n_workers"     : 8,
        "persistent_data_loader_workers": True,
        # A non-positive seed means "no fixed seed".
        "seed"                          : seed if seed > 0 else None,
        "gradient_checkpointing"        : gradient_checkpointing,
        "gradient_accumulation_steps"   : 1,
        "mixed_precision"               : mixed_precision,
    },
    "logging_arguments": {
        # wandb is used only when an API key was supplied; tensorboard otherwise.
        "log_with"          : "wandb" if wandb_api_key else "tensorboard",
        "log_tracker_name"  : project_name if wandb_api_key and project_name != "last" else None,
        "logging_dir"       : logging_dir,
        "log_prefix"        : project_name if not wandb_api_key else None,
    },
    "sample_prompt_arguments": {
        "sample_every_n_steps"    : None,
        # Sampling, when enabled, follows the checkpoint-saving cadence.
        "sample_every_n_epochs"   : save_every_n_epochs if enable_sample else None,
        "sample_sampler"          : sampler,
    },
    "saving_arguments": {
        "save_model_as": "safetensors"
    },
}

def write_file(filename, contents):
    """Write *contents* to *filename* as UTF-8 text, replacing any existing file.

    The encoding is pinned to UTF-8 so that prompts containing non-ASCII
    characters are written identically regardless of the platform locale.

    Args:
        filename: Path of the file to create or overwrite.
        contents: Text to write.
    """
    with open(filename, "w", encoding="utf-8") as f:
        f.write(contents)

def prompt_convert(enable_sample, num_prompt, train_data_dir, prompt_config, custom_prompt):
    """Return the sample-prompt config with its ``subset`` list filled in.

    When sampling is disabled the given ``prompt_config`` is returned as-is.
    Otherwise caption files under ``train_data_dir`` whose names end with the
    module-level ``prompt_from_caption`` suffix are searched recursively:

    * if any are found, up to ``num_prompt`` of them are picked at random and
      their stripped contents become the prompts;
    * if none are found, ``custom_prompt`` (or a built-in default when it is
      empty) is used as the only prompt.

    Each prompt is prefixed with the module-level ``positive_prompt``.

    Args:
        enable_sample: Whether sample generation is enabled.
        num_prompt: Maximum number of caption files to sample.
        train_data_dir: Root directory searched for caption files.
        prompt_config: Base config; must contain a ``"prompt"`` dict.
        custom_prompt: Fallback prompt used when no caption file matches.

    Returns:
        ``prompt_config`` itself when sampling is disabled, otherwise an
        independent copy with ``["prompt"]["subset"]`` populated. The input
        is never modified.
    """
    # Imported here because this cell's import header does not provide them.
    import copy
    import glob
    import random

    if not enable_sample:
        return prompt_config

    search_pattern = os.path.join(train_data_dir, '**/*' + prompt_from_caption)
    caption_files = glob.glob(search_pattern, recursive=True)

    if caption_files:
        selected_files = random.sample(caption_files, min(num_prompt, len(caption_files)))
        prompts = []
        for file in selected_files:
            with open(file, 'r', encoding='utf-8') as f:
                prompts.append(f.read().strip())
    else:
        if not custom_prompt:
            custom_prompt = "masterpiece, best quality, 1girl, aqua eyes, baseball cap, blonde hair, closed mouth, earrings, green background, hat, hoop earrings, jewelry, looking at viewer, shirt, short hair, simple background, solo, upper body, yellow shirt"
        prompts = [custom_prompt]

    # A deep copy is required: dict.copy() would share the nested "prompt"
    # dict, so assigning "subset" would also rewrite the caller's config.
    new_prompt_config = copy.deepcopy(prompt_config)
    new_prompt_config['prompt']['subset'] = [
        {"prompt": positive_prompt + prompt if positive_prompt else prompt}
        for prompt in prompts
    ]
    return new_prompt_config

def eliminate_none_variable(config):
    """Replace empty-string values with None, in place, and return *config*.

    Only two levels are inspected: the top-level values of *config* and the
    values of any dict found directly under it. Deeper nesting is left alone.
    """
    for name, entry in config.items():
        if isinstance(entry, dict):
            blank_fields = [field for field, value in entry.items() if value == ""]
            for field in blank_fields:
                entry[field] = None
        elif entry == "":
            config[name] = None

    return config

# Merge the sections built by the other config cells into train_config.
# The optimizer and LoRA sections are mandatory: a missing one is re-raised
# with a hint about which cell to run, chained to the original NameError.
try:
    train_config.update(optimizer_config)
except NameError as err:
    raise NameError("'optimizer_config' dictionary is missing. Please run  '4.1. Optimizer Config' cell.") from err

try:
    train_config.update(lora_config)
except NameError as err:
    raise NameError("'lora_config' dictionary is missing. Please run  '4.1. LoRa: Low-Rank Adaptation Config' cell.") from err

# The advanced section is optional: its absence only sets a flag that is
# reported after the config has been written.
advanced_training_warning = False
try:
    train_config.update(advanced_training_config)
except NameError:
    advanced_training_warning = True

# Fill in the sample prompts (no-op when sampling is disabled).
prompt_config = prompt_convert(enable_sample, num_prompt, train_data_dir, prompt_config, custom_prompt)

config_path         = os.path.join(config_dir, "config_file.toml")
prompt_path         = os.path.join(config_dir, "sample_prompt.toml")

# Empty strings are turned into None before serialising — presumably so the
# TOML encoder leaves those keys out; confirm against the toml package.
config_str          = toml.dumps(eliminate_none_variable(train_config))
prompt_str          = toml.dumps(eliminate_none_variable(prompt_config))

# Make sure the target directory exists so the writes below cannot fail with
# FileNotFoundError.
os.makedirs(config_dir, exist_ok=True)

write_file(config_path, config_str)
write_file(prompt_path, prompt_str)

print(config_str)

# Remind the user (in colour) that the optional advanced section was skipped.
if advanced_training_warning:
    import textwrap
    error_message = "WARNING: This is not an error message, but the [advanced_training_config] dictionary is missing. Please run the '4.2. Advanced Training Config' cell if you intend to use it, or continue to the next step."
    wrapped_message = textwrap.fill(error_message, width=80)
    print('\033[38;2;204;102;102m' + wrapped_message + '\033[0m\n')

print(prompt_str)


In [None]:
#@markdown ## ♻ **Start Training**
#@markdown You can import config from another session if you want.

import os
import toml

# Paths of the two TOML files handed to the training script below:
# `sample_prompt` is passed only if the file exists, `config_file` always.
sample_prompt   = "/content/LoRA/config/sample_prompt.toml" #@param {type:'string'}
config_file     = "/content/LoRA/config/config_file.toml" #@param {type:'string'}

def read_file(filename):
    """Return the full text of *filename*, decoded as UTF-8.

    The encoding is explicit so the result does not depend on the platform
    locale.

    Args:
        filename: Path of the text file to read.

    Returns:
        The file contents as a string.
    """
    with open(filename, "r", encoding="utf-8") as f:
        return f.read()

def train(config):
    """Render *config* as a command-line argument string.

    Rules, applied per key in insertion order:

    * key starting with ``_``: the value is emitted alone, in double quotes;
    * string value: ``--key="value"``;
    * ``True``: bare ``--key`` flag (``False`` is dropped);
    * int or float: ``--key=value``;
    * anything else (e.g. ``None``): dropped.

    Every emitted argument is followed by a single space.
    """
    pieces = []
    for key, value in config.items():
        if key.startswith("_"):
            pieces.append(f'"{value}" ')
            continue

        if isinstance(value, str):
            pieces.append(f'--{key}="{value}" ')
        elif isinstance(value, bool):
            # bool is tested before the numeric case because it is an int subclass.
            if value:
                pieces.append(f"--{key} ")
        elif isinstance(value, (int, float)):
            pieces.append(f"--{key}={value} ")

    return "".join(pieces)

accelerate_conf = {
    "config_file" : "/content/kohya-trainer/accelerate_config/config.yaml",
    "num_cpu_threads_per_process" : 1,
}

train_conf = {
    "sample_prompts"  : sample_prompt if os.path.exists(sample_prompt) else None,
    "config_file"     : config_file,
    "wandb_api_key"   : wandb_api_key if wandb_api_key else None
}

accelerate_args = train(accelerate_conf)
train_args = train(train_conf)

final_args = f"accelerate launch {accelerate_args} sdxl_train_network.py {train_args}"

os.chdir(repo_dir)
!{final_args}