<a href="https://colab.research.google.com/github/britkenko/kohya-trainer/blob/main/240418_sdxl_LoRA_trainer_XL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

[![visitor][visitor-badge]][visitor-stats]
[![ko-fi][ko-fi-badge]][ko-fi-link]

# **Kohya LoRA Trainer XL**
A Colab Notebook For SDXL LoRA Training (Fine-tuning Method)
*Johnsons Fork*

[visitor-badge]: https://api.visitorbadge.io/api/visitors?path=Kohya%20LoRA%20Trainer%20XL&label=Visitors&labelColor=%2334495E&countColor=%231ABC9C&style=flat&labelStyle=none
[visitor-stats]: https://visitorbadge.io/status?path=Kohya%20LoRA%20Trainer%20XL
[ko-fi-badge]: https://img.shields.io/badge/Support%20me%20on%20Ko--fi-F16061?logo=ko-fi&logoColor=white&style=flat
[ko-fi-link]: https://ko-fi.com/linaqruf

| Notebook Name | Description | Link |
| --- | --- | --- |
| [Kohya LoRA Trainer XL](https://github.com/Linaqruf/kohya-trainer/blob/main/kohya-LoRA-trainer-XL.ipynb) | LoRA Training | [![](https://img.shields.io/static/v1?message=Open%20in%20Colab&logo=googlecolab&labelColor=5c5c5c&color=0f80c1&label=%20&style=flat)](https://colab.research.google.com/github/Linaqruf/kohya-trainer/blob/main/kohya-LoRA-trainer-XL.ipynb) |
| [Kohya Trainer XL](https://github.com/Linaqruf/kohya-trainer/blob/main/kohya-trainer-XL.ipynb) | Native Training | [![](https://img.shields.io/static/v1?message=Open%20in%20Colab&logo=googlecolab&labelColor=5c5c5c&color=0f80c1&label=%20&style=flat)](https://colab.research.google.com/github/Linaqruf/kohya-trainer/blob/main/kohya-trainer-XL.ipynb) |


# **I. Prepare Environment**

In [None]:
# You can run this notebook in Tesla T4 but make sure you're using High RAM setting. (Colab Pro gated)
!nvidia-smi

Thu Apr 18 00:16:00 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off | 00000000:00:04.0 Off |                    0 |
| N/A   31C    P0              41W / 350W |      2MiB / 40960MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [None]:
# Upgrade pip, then replace the existing PyTorch packages with an unpinned reinstall.
!pip install --upgrade pip
!pip3 uninstall --yes torch torchaudio torchvision torchtext torchdata
!pip3 install torch torchaudio torchvision torchtext torchdata
!pip install xformers
# NOTE(review): the `pip show jax` output later in this notebook reports 0.4.26,
# so this 0.4.23 pin appears to be overridden by a later install -- confirm.
!pip install "jax[cuda12_pip]==0.4.23" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
!pip install scvi-tools

import os
import zipfile
import shutil
import time
import requests
import torch
from subprocess import getoutput
from IPython.utils import capture
from google.colab import drive

%store -r

# Directory layout: every working directory is derived from root_dir.
root_dir          = "/content"
drive_dir         = os.path.join(root_dir, "drive/MyDrive")
deps_dir          = os.path.join(root_dir, "deps")
repo_dir          = os.path.join(root_dir, "kohya-trainer")
training_dir      = os.path.join(root_dir, "LoRA")
pretrained_model  = os.path.join(root_dir, "pretrained_model")
vae_dir           = os.path.join(root_dir, "vae")
lora_dir          = os.path.join(root_dir, "network_weight")
repositories_dir  = os.path.join(root_dir, "repositories")
config_dir        = os.path.join(training_dir, "config")
tools_dir         = os.path.join(repo_dir, "tools")
finetune_dir      = os.path.join(repo_dir, "finetune")
accelerate_config = os.path.join(repo_dir, "accelerate_config/config.yaml")

# Persist the listed path variables with the IPython `%store` magic so that other
# cells can restore them with `%store -r`. The magic's output is captured so the
# cell stays quiet. drive_dir, deps_dir and lora_dir are not in this list.
for store in ["root_dir", "repo_dir", "training_dir", "pretrained_model", "vae_dir", "repositories_dir", "accelerate_config", "tools_dir", "finetune_dir", "config_dir"]:
    with capture.capture_output() as cap:
        %store {store}
        del cap

# Trainer repositories offered by the form below, keyed by their display label.
repo_dict = {
    "qaneel/kohya-trainer (forked repo, stable, optimized for colab use)" : "https://github.com/qaneel/kohya-trainer",
    "kohya-ss/sd-scripts (original repo, latest update)"                    : "https://github.com/kohya-ss/sd-scripts",
}

repository        = "qaneel/kohya-trainer (forked repo, stable, optimized for colab use)" #@param ["qaneel/kohya-trainer (forked repo, stable, optimized for colab use)", "kohya-ss/sd-scripts (original repo, latest update)"] {allow-input: true}
# The form field accepts free-text input, so an entry that is not one of the known
# labels is used as the repository URL itself instead of raising KeyError.
repo_url          = repo_dict.get(repository, repository)
branch            = "main"  # @param {type: "string"}
output_to_drive   = True  # @param {type: "boolean"}

def clone_repo(url, dir, branch):
    """Clone `branch` of the git repository at `url` into the directory `dir`.

    Nothing is done when `dir` already exists, so re-running the cell does not
    clone again.
    """
    if not os.path.exists(dir):
       !git clone -b {branch} {url} {dir}

def mount_drive(dir):
    """Return the directory that training output should be written to.

    When `output_to_drive` is false the result is `<training_dir>/output`.
    Otherwise Google Drive is mounted at the parent of `drive_dir` (only if
    `drive_dir` does not exist yet) and the result is
    `<drive_dir>/kohya-trainer/output`.

    The `dir` parameter is accepted but not used; the global `drive_dir` is
    what decides the mount point.
    """
    if not output_to_drive:
        return os.path.join(training_dir, "output")

    drive_is_mounted = os.path.exists(drive_dir)
    if not drive_is_mounted:
        drive.mount(os.path.dirname(drive_dir))

    return os.path.join(drive_dir, "kohya-trainer/output")

def setup_directories():
    """Resolve the output directory and create every working directory.

    Sets the global `output_dir` from `mount_drive` and then makes sure the
    training, config, model, VAE, repositories and output directories exist.
    """
    global output_dir

    output_dir = mount_drive(drive_dir)

    required_paths = (
        training_dir,
        config_dir,
        pretrained_model,
        vae_dir,
        repositories_dir,
        output_dir,
    )
    for path in required_paths:
        os.makedirs(path, exist_ok=True)

def pastebin_reader(id):
    """Download a Pastebin paste and return its text as a list of lines.

    Args:
        id: Either a bare paste ID or a string containing "pastebin.com".
            A pastebin.com URL that does not contain "raw" is rewritten to
            its raw form.

    Returns:
        The response text split on newline characters.

    Raises:
        requests.HTTPError: The server answered with an error status.
        requests.Timeout: The server did not respond within the timeout.
    """
    if "pastebin.com" in id:
        url = id
        if 'raw' not in url:
            url = url.replace('pastebin.com', 'pastebin.com/raw')
    else:
        url = "https://pastebin.com/raw/" + id

    # requests waits forever by default; bound the wait so a stalled
    # connection cannot hang the whole setup cell.
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    return response.text.split('\n')

def install_repository():
    """Clone the two auxiliary tool repositories and install their requirements.

    Publishes `infinite_image_browser_dir`, `voldy` and
    `discordia_archivum_dir` as globals.
    """
    global infinite_image_browser_dir, voldy, discordia_archivum_dir

    # The second line of this paste supplies the name fragment interpolated
    # into the image-browser repository URL below.
    _, voldy = pastebin_reader("kq6ZmHFU")[:2]

    infinite_image_browser_url  = f"https://github.com/zanllp/{voldy}-infinite-image-browsing.git"
    infinite_image_browser_dir  = os.path.join(repositories_dir, f"infinite-image-browsing")
    infinite_image_browser_deps = os.path.join(infinite_image_browser_dir, "requirements.txt")

    discordia_archivum_url = "https://github.com/Linaqruf/discordia-archivum"
    discordia_archivum_dir = os.path.join(repositories_dir, "discordia-archivum")
    discordia_archivum_deps = os.path.join(discordia_archivum_dir, "requirements.txt")

    # clone_repo skips a repository whose target directory already exists.
    clone_repo(infinite_image_browser_url, infinite_image_browser_dir, "main")
    clone_repo(discordia_archivum_url, discordia_archivum_dir, "main")

    !pip install -q --upgrade -r {infinite_image_browser_deps}
    !pip install python-dotenv
    !pip install -q --upgrade -r {discordia_archivum_deps}

def install_dependencies():
    requirements_file = os.path.join(repo_dir, "requirements.txt")
    model_util        = os.path.join(repo_dir, "library/model_util.py")
    gpu_info          = getoutput('nvidia-smi')
    t4_xformers_wheel = "https://github.com/Linaqruf/colab-xformers/releases/download/0.0.20/xformers-0.0.20+1d635e1.d20230519-cp310-cp310-linux_x86_64.whl"

    !apt install aria2 lz4
    !pip install -q --upgrade -r {requirements_file}

    if '2.0.1+cu118' in torch.__version__:
        if 'T4' in gpu_info:
            !pip install -q {t4_xformers_wheel}
        else:
         !pip3 install -q --upgrade xformers==0.0.20 triton==2.0.0 --force -reinstall
         !pip3 install diffusers==0.11.1
         !pip3 install transformers scipy ftfy accelerate
         !pip3 install bitsandbytes>=0.43.0
    else:

        from accelerate.utils import write_basic_config

        if not os.path.exists(accelerate_config):
          write_basic_config(save_location=accelerate_config)

def prepare_environment():
    """Set the environment variables this notebook configures before training.

    Assigns TF_CPP_MIN_LOG_LEVEL, SAFETENSORS_FAST_GPU and PYTHONWARNINGS in
    the current process environment.
    """
    os.environ.update(
        {
            "TF_CPP_MIN_LOG_LEVEL": "3",
            "SAFETENSORS_FAST_GPU": "1",
            "PYTHONWARNINGS": "ignore",
        }
    )

def main():
    """Run the environment setup steps in order.

    Clones the trainer repository, creates the working directories, installs
    the auxiliary repositories and dependencies, then sets the environment
    variables.
    """
    os.chdir(root_dir)
    clone_repo(repo_url, repo_dir, branch)
    # The remaining steps run with the cloned trainer repository as the
    # working directory.
    os.chdir(repo_dir)
    setup_directories()
    install_repository()
    install_dependencies()
    prepare_environment()

main()

[0mFound existing installation: torch 2.2.2
Uninstalling torch-2.2.2:
  Successfully uninstalled torch-2.2.2
Found existing installation: torchaudio 2.2.2
Uninstalling torchaudio-2.2.2:
  Successfully uninstalled torchaudio-2.2.2
Found existing installation: torchvision 0.17.2
Uninstalling torchvision-0.17.2:
  Successfully uninstalled torchvision-0.17.2
Found existing installation: torchtext 0.17.2
Uninstalling torchtext-0.17.2:
  Successfully uninstalled torchtext-0.17.2
Found existing installation: torchdata 0.7.1
Uninstalling torchdata-0.7.1:
  Successfully uninstalled torchdata-0.7.1
[0mCollecting torch
  Using cached torch-2.2.2-cp310-cp310-manylinux1_x86_64.whl.metadata (26 kB)
Collecting torchaudio
  Using cached torchaudio-2.2.2-cp310-cp310-manylinux1_x86_64.whl.metadata (6.4 kB)
Collecting torchvision
  Using cached torchvision-0.17.2-cp310-cp310-manylinux1_x86_64.whl.metadata (6.6 kB)
Collecting torchtext
  Using cached torchtext-0.17.2-cp310-cp310-manylinux1_x86_64.whl.me

In [None]:
# @title ## **1.2. Download SDXL**
import os
import re
import json
import glob
import gdown
import requests
import subprocess
from IPython.utils import capture
from urllib.parse import urlparse, unquote
from pathlib import Path

%store -r

os.chdir(root_dir)

# @markdown Place your Huggingface [Read Token](https://huggingface.co/settings/tokens) Here. Get your SDXL access [here](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9).

# NOTE(review): a literal Hugging Face access token is committed on the next line,
# so anyone who can read this notebook can read it. Revoke it and leave the field
# blank in shared copies of the notebook.
HUGGINGFACE_TOKEN = "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"#@param {type: "string"}
# Download sources for the base model and the VAE.
SDXL_MODEL_URL = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors" #@param {type: "string"}
SDXL_VAE_URL = "https://huggingface.co/stabilityai/sdxl-vae/resolve/main/sdxl_vae.safetensors" #@param {type: "string"}

def get_supported_extensions():
    """Return the checkpoint file extensions this notebook accepts, as a tuple."""
    return (".ckpt", ".safetensors", ".pt", ".pth")

def get_filename(url, quiet=True):
    """Work out the checkpoint filename that `url` refers to.

    A path under `drive_dir`, or a URL that already ends with a supported
    extension, is resolved from its basename without any network access.
    Anything else is requested once: the filename comes from the
    Content-Disposition header when it carries one, otherwise from the URL
    path.

    Args:
        url: Remote URL or local Drive path.
        quiet: Accepted for compatibility with callers; not used.

    Returns:
        The filename when it ends with a supported extension, else None.

    Raises:
        requests.HTTPError: The server answered with an error status.
    """
    extensions = get_supported_extensions()

    if url.startswith(drive_dir) or url.endswith(extensions):
        filename = os.path.basename(url)
    else:
        # stream=True avoids downloading the body just to read the headers.
        # The context manager releases the connection, which the streamed
        # response would otherwise keep open.
        with requests.get(url, stream=True, timeout=30) as response:
            response.raise_for_status()
            content_disposition = response.headers.get('content-disposition', '')

        matches = re.findall('filename="?([^"]+)"?', content_disposition)
        if matches:
            filename = matches[0]
        else:
            # No header, or a header without a filename: use the URL path.
            filename = unquote(os.path.basename(urlparse(url).path))

    if filename.endswith(extensions):
        return filename
    return None

def get_most_recent_file(directory):
    """Return the path of the most recently modified entry in `directory`.

    Args:
        directory: Directory whose direct children are examined.

    Returns:
        The full path of the entry with the latest modification time, or
        None when the directory has no entries.
    """
    files = glob.glob(os.path.join(directory, "*"))
    if not files:
        return None
    return max(files, key=os.path.getmtime)

def parse_args(config):
    """Turn a config mapping into a list of command-line arguments.

    Rules, applied per entry in mapping order:
      * a key starting with "_" contributes its value as a positional
        argument, formatted as-is;
      * a true boolean contributes the bare flag `--key`, a false one nothing;
      * a string, float or int contributes `--key=value`;
      * any other value (for example None) is skipped.
    """
    args = []

    for key, value in config.items():
        if key.startswith("_"):
            args.append(f"{value}")
            continue

        # bool is checked before the numeric types because it subclasses int.
        if isinstance(value, bool):
            if value:
                args.append(f"--{key}")
            continue

        if isinstance(value, (str, float, int)):
            args.append(f"--{key}={value}")

    return args

def aria2_download(dir, filename, url):
    """Download `url` into `dir` with aria2c.

    Args:
        dir: Destination directory.
        filename: Output filename; when None no `--out` option is passed.
        url: Source URL.

    The Hugging Face bearer token is attached only for huggingface.co URLs
    and only when `HUGGINGFACE_TOKEN` is non-empty, so a blank form field
    does not send an empty `Authorization` header. The exit status of aria2c
    is not checked.
    """
    auth_header = None
    if HUGGINGFACE_TOKEN and "huggingface.co" in url:
        auth_header = f"Authorization: Bearer {HUGGINGFACE_TOKEN}"

    aria2_config = {
        "console-log-level"         : "error",
        "summary-interval"          : 10,
        "header"                    : auth_header,
        "continue"                  : True,
        "max-connection-per-server" : 16,
        "min-split-size"            : "1M",
        "split"                     : 16,
        "dir"                       : dir,
        "out"                       : filename,
        "_url"                      : url,
    }
    # parse_args drops None values, so an absent header or filename simply
    # produces no option.
    aria2_args = parse_args(aria2_config)
    subprocess.run(["aria2c", *aria2_args])

def gdown_download(url, dst, filepath):
    """Download a Google Drive file or folder with gdown.

    Args:
        url: Google Drive URL.
        dst: Destination directory; used for folder downloads.
        filepath: Output path for single-file downloads.

    Returns:
        Whatever the gdown call returns, or None when the URL matches none
        of the recognised Drive URL shapes.
    """
    # Direct-download links have the form `.../uc?id=<file id>`. The original
    # "/uc?id/" spelling is kept so any caller relying on it still matches.
    if "/uc?id=" in url or "/uc?id/" in url:
        return gdown.download(url, filepath, quiet=False)
    elif "/file/d/" in url:
        return gdown.download(url=url, output=filepath, quiet=False, fuzzy=True)
    elif "/drive/folders/" in url:
        # download_folder is run from inside the destination directory.
        os.chdir(dst)
        return gdown.download_folder(url, quiet=True, use_cookies=False)
    return None

def download(url, dst):
    """Fetch `url` into the directory `dst` using the matching downloader.

    Args:
        url: A Google Drive URL, a path under /content/drive/MyDrive/, or any
            other URL (downloaded with aria2c).
        dst: Destination directory.

    Returns:
        `url` itself when it is a local Drive path (nothing is copied);
        otherwise None.
    """
    filename = get_filename(url, quiet=False)

    if "drive.google.com" in url:
        # get_filename returns None when it cannot resolve a supported
        # filename; joining None would raise TypeError. Passing the directory
        # with a trailing separator lets gdown choose the name instead.
        filepath = os.path.join(dst, filename) if filename else os.path.join(dst, "")
        gdown_download(url, dst, filepath)
    elif url.startswith("/content/drive/MyDrive/"):
        # Already on the mounted drive: use the file in place.
        return url
    else:
        if "huggingface.co" in url and "/blob/" in url:
            # A /blob/ link points at the web page; /resolve/ serves the file.
            url = url.replace("/blob/", "/resolve/")
        aria2_download(dst, filename, url)

    return None

def get_filepath(url, dst):
    """Return the path inside `dst` where the download of `url` is expected.

    When no supported filename can be derived from `url`, the most recently
    modified entry of `dst` is used instead.

    Args:
        url: The URL or path that was downloaded.
        dst: The directory it was downloaded into.

    Returns:
        The joined path, or None when no filename could be derived and `dst`
        is empty.
    """
    extensions = get_supported_extensions()
    filename = get_filename(url)

    # get_filename returns None for an unsupported or unknown name; test for
    # that first so the fallback below is reachable instead of raising
    # AttributeError.
    if not filename or not filename.endswith(extensions):
        most_recent_file = get_most_recent_file(dst)
        if most_recent_file is None:
            return None
        filename = os.path.basename(most_recent_file)

    return os.path.join(dst, filename)

def main():
    """Download the SDXL model and VAE and publish their paths.

    Sets the globals `model_path` and `vae_path` (None when a URL is blank or
    still the placeholder text) and prints each selected path that exists on
    disk.
    """
    global model_path, vae_path

    model_path = vae_path = None

    download_targets = {
        "model" : (SDXL_MODEL_URL, pretrained_model),
        "vae"   : (SDXL_VAE_URL, vae_dir),
    }

    for target, (url, dst) in download_targets.items():
        placeholder = f"PASTE {target.upper()} URL OR GDRIVE PATH HERE"
        if not url or placeholder in url:
            continue

        # download() returns the source path for files already on Drive and
        # None otherwise; in the latter case use the expected download path.
        in_place_path = download(url, dst)
        expected_path = get_filepath(url, dst)
        chosen_path = in_place_path if in_place_path else expected_path

        if target == "model":
            model_path = chosen_path
        elif target == "vae":
            vae_path = chosen_path

    for category, path in (("model", model_path), ("vae", vae_path)):
        if path is not None and os.path.exists(path):
            print(f"Selected {category}: {path}")

main()

Selected model: /content/pretrained_model/sd_xl_base_1.0.safetensors
Selected vae: /content/vae/sdxl_vae.safetensors


In [None]:
# @title ## **1.3. Directory Config**
# @markdown Specify the location of your training data in the following cell. A folder with the same name as your input will be created.
import os

%store -r

# Location of the training images; persisted with `%store` for use by other cells.
train_data_dir = "/content/drive/MyDrive/train/js"  # @param {'type' : 'string'}
%store train_data_dir

# Create the directory up front so it exists even before any data is added.
os.makedirs(train_data_dir, exist_ok=True)
print(f"Your train data directory : {train_data_dir}")

Stored 'train_data_dir' (str)
Your train data directory : /content/drive/MyDrive/train/js


In [None]:
!pip show jax


Name: jax
Version: 0.4.26
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /usr/local/lib/python3.10/dist-packages
Requires: ml-dtypes, numpy, opt-einsum, scipy
Required-by: chex, dopamine-rl, flax, numpyro, optax, orbax-checkpoint, scvi-tools


In [None]:
# @title ## **3.4. Bucketing and Latents Caching**
%store -r

# @markdown This code will create buckets based on the `bucket_resolution` provided for multi-aspect ratio training, and then convert all images within the `train_data_dir` to latents.
# Output files of the two preprocessing steps: cleaned metadata first, then the
# metadata extended by the bucketing/latents step.
bucketing_json    = os.path.join(training_dir, "meta_lat.json")
metadata_json     = os.path.join(training_dir, "meta_clean.json")
bucket_resolution = 1024  # @param {type:"slider", min:512, max:1024, step:128}
mixed_precision   = "no"  # @param ["no", "fp16", "bf16"] {allow-input: false}
flip_aug          = False  # @param{type:"boolean"}
# @markdown Use the `clean_caption` option to clean captions, e.g. remove duplicate tags, change `women` to `girl`, etc.
clean_caption     = True #@param {type:"boolean"}
#@markdown Use the `recursive` option to process subfolders as well
recursive         = True #@param {type:"boolean"}

# Keys starting with "_" become positional arguments (see generate_args below);
# the others become `--key` options. `full_path` follows the `recursive` setting.
metadata_config = {
    "_train_data_dir": train_data_dir,
    "_out_json": metadata_json,
    "recursive": recursive,
    "full_path": recursive,
    "clean_caption": clean_caption
}

# `model_path` is the global set by the model download cell.
bucketing_config = {
    "_train_data_dir": train_data_dir,
    "_in_json": metadata_json,
    "_out_json": bucketing_json,
    "_model_name_or_path": model_path,
    "recursive": recursive,
    "full_path": recursive,
    "flip_aug": flip_aug,
    "batch_size": 4,
    "max_data_loader_n_workers": 2,
    "max_resolution": f"{bucket_resolution}, {bucket_resolution}",
    "mixed_precision": mixed_precision,
}

def generate_args(config):
    """Render a config mapping as a single shell argument string.

    Rules, applied per entry in mapping order:
      * a key starting with "_" contributes its value as a double-quoted
        positional argument;
      * a string contributes `--key="value"`;
      * a true boolean contributes the bare flag `--key`, a false one nothing;
      * a float or int contributes `--key=value`;
      * any other value (for example None) is skipped.

    Returns:
        The pieces joined by single spaces, stripped of surrounding whitespace.
    """
    pieces = []
    for key, value in config.items():
        if key.startswith("_"):
            pieces.append(f'"{value}"')
        elif isinstance(value, str):
            pieces.append(f'--{key}="{value}"')
        elif isinstance(value, bool):
            # bool is handled before the numeric types because it subclasses int.
            if value:
                pieces.append(f"--{key}")
        elif isinstance(value, (float, int)):
            pieces.append(f"--{key}={value}")
    return " ".join(pieces).strip()

# Render both configs as argument strings and build the two script invocations.
merge_metadata_args = generate_args(metadata_config)
prepare_buckets_args = generate_args(bucketing_config)

merge_metadata_command = f"python merge_all_to_metadata.py {merge_metadata_args}"
prepare_buckets_command = f"python prepare_buckets_latents.py {prepare_buckets_args}"

# Both scripts are invoked by bare filename, so they are run from finetune_dir.
# The bucketing step reads the metadata file that the merge step writes.
os.chdir(finetune_dir)
!{merge_metadata_command}
time.sleep(1)
!{prepare_buckets_command}


Found 58 images.
Creating a new metadata file
Merging tags and captions into metadata json.
  0% 0/58 [00:00<?, ?it/s]

# **IV. Training**



In [None]:
import toml

# @title ## **4.1. LoRa: Low-Rank Adaptation Config**
# @markdown Kohya's `LoRA` has been renamed to `LoRA-LierLa` and Kohya's `LoCon` to `LoRA-C3Lier`; read the [official announcement](https://github.com/kohya-ss/sd-scripts/blob/849bc24d205a35fbe1b2a4063edd7172533c1c01/README.md#naming-of-lora).
network_category = "LoRA_LierLa"  # @param ["LoRA_LierLa", "LoRA_C3Lier", "DyLoRA_LierLa", "DyLoRA_C3Lier", "LoCon", "LoHa", "IA3", "LoKR", "DyLoRA_Lycoris"]

# @markdown | network_category | network_dim | network_alpha | conv_dim | conv_alpha | unit |
# @markdown | :---: | :---: | :---: | :---: | :---: | :---: |
# @markdown | LoRA-LierLa | 32 | 1 | - | - | - |
# @markdown | LoCon/LoRA-C3Lier | 16 | 8 | 8 | 1 | - |
# @markdown | LoHa | 8 | 4 | 4 | 1 | - |
# @markdown | Other Category | ? | ? | ? | ? | - |

# @markdown Specify `network_args` to add `optional` training args, like for specifying each 25 block weight, read [this](https://github.com/kohya-ss/sd-scripts/blob/main/train_network_README-ja.md#%E9%9A%8E%E5%B1%A4%E5%88%A5%E5%AD%A6%E7%BF%92%E7%8E%87)
network_args    = ""  # @param {'type':'string'}

# @markdown ### **Linear Layer Config**
# @markdown Used by all `network_category`. When in doubt, set `network_dim = network_alpha`
network_dim     = 32  # @param {'type':'number'}
network_alpha   = 1  # @param {'type':'number'}

# @markdown ### **Convolutional Layer Config**
# @markdown Only required if `network_category` is not `LoRA_LierLa`, as it involves training convolutional layers in addition to linear layers.
conv_dim        = 8  # @param {'type':'number'}
conv_alpha      = 1  # @param {'type':'number'}

# @markdown ### **DyLoRA Config**
# @markdown Only required if `network_category` is `DyLoRA_LierLa` or `DyLoRA_C3Lier`
unit = 4  # @param {'type':'number'}

# `ast` parses a list typed into the `network_args` field. It is imported here
# because this cell does not import it elsewhere; without it a bracketed value
# raises NameError unless another cell happened to import `ast` first.
import ast

# Normalise the free-text field to a list: a bracketed value is parsed as a
# Python literal, anything else non-empty is rejected with a warning, and an
# empty or non-string value becomes an empty list.
if isinstance(network_args, str):
    network_args = network_args.strip()
    if network_args.startswith('[') and network_args.endswith(']'):
        try:
            network_args = ast.literal_eval(network_args)
        except (SyntaxError, ValueError) as e:
            print(f"Error parsing network_args: {e}\n")
            network_args = []
    elif len(network_args) > 0:
        print(f"WARNING! '{network_args}' is not a valid list! Put args like this: [\"args=1\", \"args=2\"]\n")
        network_args = []
    else:
        network_args = []
else:
    network_args = []

# Per-category network module and the extra arguments that category needs.
# The four kohya-native categories are spelled out; the five LyCORIS categories
# share one module and differ only in their `algo` value, so they are generated.
network_config = {
    "LoRA_LierLa": {
        "module": "networks.lora",
        "args"  : [],
    },
    "LoRA_C3Lier": {
        "module": "networks.lora",
        "args"  : [f"conv_dim={conv_dim}", f"conv_alpha={conv_alpha}"],
    },
    "DyLoRA_LierLa": {
        "module": "networks.dylora",
        "args"  : [f"unit={unit}"],
    },
    "DyLoRA_C3Lier": {
        "module": "networks.dylora",
        "args"  : [f"conv_dim={conv_dim}", f"conv_alpha={conv_alpha}", f"unit={unit}"],
    },
    **{
        category: {
            "module": "lycoris.kohya",
            "args"  : [f"algo={algo}", f"conv_dim={conv_dim}", f"conv_alpha={conv_alpha}"],
        }
        for category, algo in (
            ("LoCon", "locon"),
            ("LoHa", "loha"),
            ("IA3", "ia3"),
            ("LoKR", "lokr"),
            ("DyLoRA_Lycoris", "dylora"),
        )
    },
}

# Resolve the selected category; its arguments are appended after any the user
# supplied through `network_args`.
network_module = network_config[network_category]["module"]
network_args += network_config[network_category]["args"]

lora_config = {
    "additional_network_arguments": dict(
        no_metadata=False,
        network_module=network_module,
        network_dim=network_dim,
        network_alpha=network_alpha,
        network_args=network_args,
        network_train_unet_only=True,
        training_comment=None,
    ),
}

print(toml.dumps(lora_config))

[additional_network_arguments]
no_metadata = false
network_module = "networks.lora"
network_dim = 32
network_alpha = 1
network_args = []
network_train_unet_only = true



In [None]:
import toml
import ast

# @title ## **4.2. Optimizer Config**
# @markdown Use `Adafactor` optimizer. `RMSprop 8bit` or `Adagrad 8bit` may work. `AdamW 8bit` doesn't seem to work.
optimizer_type = "AdaFactor"  # @param ["AdamW", "AdamW8bit", "Lion8bit", "Lion", "SGDNesterov", "SGDNesterov8bit", "DAdaptation(DAdaptAdamPreprint)", "DAdaptAdaGrad", "DAdaptAdam", "DAdaptAdan", "DAdaptAdanIP", "DAdaptLion", "DAdaptSGD", "AdaFactor"]
# @markdown Specify `optimizer_args` to add `additional` args for optimizer, e.g: `["weight_decay=0.6"]`
optimizer_args = "[ \"scale_parameter=False\", \"relative_step=False\", \"warmup_init=False\" ]"  # @param {'type':'string'}
# @markdown ### **Learning Rate Config**
# @markdown Different `optimizer_type` and `network_category` for some condition requires different learning rate. It's recommended to set `text_encoder_lr = 1/2 * unet_lr`
learning_rate = 3e-3  # @param {'type':'number'}
# @markdown ### **LR Scheduler Config**
# @markdown `lr_scheduler` provides several methods to adjust the learning rate based on the number of epochs.
lr_scheduler = "adafactor"  # @param ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup", "adafactor"] {allow-input: false}
lr_warmup_steps = None  # @param {'type':'number'}
# @markdown Specify `lr_scheduler_num` with `num_cycles` value for `cosine_with_restarts` or `power` value for `polynomial`
lr_scheduler_num = 1  # @param {'type':'number'}

# Normalise the free-text field to a list. A non-string value is treated like
# an empty string; a bracketed value is parsed as a Python literal; any other
# non-empty text is rejected with a warning.
optimizer_args = optimizer_args.strip() if isinstance(optimizer_args, str) else ""
if optimizer_args.startswith('[') and optimizer_args.endswith(']'):
    try:
        optimizer_args = ast.literal_eval(optimizer_args)
    except (SyntaxError, ValueError) as e:
        print(f"Error parsing optimizer_args: {e}\n")
        optimizer_args = []
else:
    if optimizer_args:
        print(f"WARNING! '{optimizer_args}' is not a valid list! Put args like this: [\"args=1\", \"args=2\"]\n")
    optimizer_args = []

# `lr_scheduler_num` feeds exactly one scheduler-specific option, chosen by the
# selected scheduler; both stay None for every other scheduler.
optimizer_config = {
    "optimizer_arguments": dict(
        optimizer_type=optimizer_type,
        learning_rate=learning_rate,
        max_grad_norm=0,
        optimizer_args=optimizer_args,
        lr_scheduler=lr_scheduler,
        lr_warmup_steps=lr_warmup_steps,
        lr_scheduler_num_cycles=(
            lr_scheduler_num if lr_scheduler == "cosine_with_restarts" else None
        ),
        lr_scheduler_power=(
            lr_scheduler_num if lr_scheduler == "polynomial" else None
        ),
        lr_scheduler_type=None,
        lr_scheduler_args=None,
    ),
}

print(toml.dumps(optimizer_config))


[optimizer_arguments]
optimizer_type = "AdaFactor"
learning_rate = 0.003
max_grad_norm = 0
optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False",]
lr_scheduler = "adafactor"



In [None]:
# @title ## **4.3. Advanced Training Config** (Optional)
import toml

# @markdown ### **Noise Control**
noise_control_type        = "multires_noise" #@param ["none", "noise_offset", "multires_noise"]
# @markdown #### **a. Noise Offset**
# @markdown Control and easily generate darker or lighter images by offsetting the noise when fine-tuning the model. Recommended value: `0.1`. Read [Diffusion With Offset Noise](https://www.crosslabs.org//blog/diffusion-with-offset-noise)
noise_offset_num          = 0.1  # @param {type:"number"}
# @markdown **[Experimental]**
# @markdown Automatically adjusts the noise offset based on the absolute mean values of each channel in the latents when used with `--noise_offset`. Specify a value around 1/10 to the same magnitude as the `--noise_offset` for best results. Set `0` to disable.
adaptive_noise_scale      = 0.01 # @param {type:"number"}
# @markdown #### **b. Multires Noise**
# @markdown Enable multires noise with this number of iterations (if enabled, around 6-10 is recommended)
multires_noise_iterations = 8 #@param {type:"slider", min:1, max:10, step:1}
multires_noise_discount = 0.6 #@param {type:"slider", min:0.1, max:1, step:0.1}
# @markdown ### **Custom Train Function**
# @markdown Gamma for reducing the weight of high-loss timesteps. Lower numbers have a stronger effect. The paper recommends `5`. Read the paper [here](https://arxiv.org/abs/2303.09556).
min_snr_gamma             = 3 #@param {type:"number"}

# Only the options of the selected noise control type are given a value; the
# others are None. A `min_snr_gamma` of -1 also maps to None.
advanced_training_config = {
    "advanced_training_config": {
        "noise_offset": (
            noise_offset_num if noise_control_type == "noise_offset" else None
        ),
        "adaptive_noise_scale": (
            adaptive_noise_scale
            if noise_control_type == "noise_offset" and adaptive_noise_scale
            else None
        ),
        "multires_noise_iterations": (
            multires_noise_iterations if noise_control_type == "multires_noise" else None
        ),
        "multires_noise_discount": (
            multires_noise_discount if noise_control_type == "multires_noise" else None
        ),
        "min_snr_gamma": None if min_snr_gamma == -1 else min_snr_gamma,
    }
}

print(toml.dumps(advanced_training_config))

[advanced_training_config]
multires_noise_iterations = 8
multires_noise_discount = 0.6
min_snr_gamma = 3



In [None]:
# @title ## **4.4. Training Config**
import toml
import os
from subprocess import getoutput

%store -r

# @markdown ### **Project Config**
project_name                = "js"  # @param {type:"string"}
# @markdown Get your `wandb_api_key` [here](https://wandb.ai/settings) to log with wandb. Leave it empty to use the `WANDB_API_KEY` environment variable instead (if that is unset too, TensorBoard is used).
# SECURITY: never save a real API key in the notebook -- the value of this
# field is stored in the file and published with it. The key that was
# previously hard-coded here must be treated as leaked and revoked at
# https://wandb.ai/settings.
wandb_api_key               = "" # @param {type:"string"}
# Fall back to the environment so the secret can stay out of the notebook.
wandb_api_key               = wandb_api_key or os.environ.get("WANDB_API_KEY", "")
in_json                     = "/content/LoRA/meta_lat.json"  # @param {type:"string"}
# Colab form fields. Where a field is stored in the config under a different
# key, or is reused elsewhere in this cell, a short note says so.
# @markdown ### **SDXL Config**
gradient_checkpointing      = True  # @param {type:"boolean"}
no_half_vae                 = True  # @param {type:"boolean"}
#@markdown Recommended parameter for SDXL training but if you enable it, `shuffle_caption` won't work
# `train_config` sets `shuffle_caption` to the opposite of this flag.
cache_text_encoder_outputs  = True  # @param {type:"boolean"}
#@markdown These options can be used to train U-Net with different timesteps. The default values are 0 and 1000.
min_timestep                = 0 # @param {type:"number"}
max_timestep                = 1000 # @param {type:"number"}
# @markdown ### **Dataset Config**
# Written to the config as `dataset_repeats`.
num_repeats                 = 10  # @param {type:"number"}
# Used for both sides: sample width/height and the "W,H" training resolution.
resolution                  = 1024  # @param {type:"slider", min:512, max:1024, step:128}
keep_tokens                 = 0  # @param {type:"number"}
# @markdown ### **General Config**
# Written to the config as `max_train_epochs`.
num_epochs                  = 6  # @param {type:"number"}
train_batch_size            = 4  # @param {type:"number"}
mixed_precision             = "bf16"  # @param ["no","fp16","bf16"] {allow-input: false}
# With the default of -1 the `seed` entry in `train_config` is None.
seed                        = -1  # @param {type:"number"}
# Selects which of the `xformers` / `sdpa` flags is turned on in `train_config`.
optimization                = "xformers" # @param ["xformers", "scaled dot-product attention"]
# @markdown ### **Save Output Config**
save_precision              = "bf16"  # @param ["float", "fp16", "bf16"] {allow-input: false}
# Also reused as `sample_every_n_epochs` when `enable_sample` is True.
save_every_n_epochs         = 1  # @param {type:"number"}
# @markdown ### **Sample Prompt Config**
enable_sample               = True  # @param {type:"boolean"}
sampler                     = "euler_a"  # @param ["ddim", "pndm", "lms", "euler", "euler_a", "heun", "dpm_2", "dpm_2_a", "dpmsolver","dpmsolver++", "dpmsingle", "k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a"]
quality_prompt              = "None"  # @param ["None", "Waifu Diffusion 1.5", "NovelAI", "AbyssOrangeMix", "Stable Diffusion XL"] {allow-input: false}
# Each preset maps to a (positive prefix, negative prompt) pair. Any other
# choice, including "None", leaves both strings empty.
positive_prompt, negative_prompt = {
    "Waifu Diffusion 1.5": (
        "(exceptional, best aesthetic, new, newest, best quality, masterpiece, extremely detailed, anime, waifu:1.2), ",
        "lowres, ((bad anatomy)), ((bad hands)), missing finger, extra digits, fewer digits, blurry, ((mutated hands and fingers)), (poorly drawn face), ((mutation)), ((deformed face)), (ugly), ((bad proportions)), ((extra limbs)), extra face, (double head), (extra head), ((extra feet)), monster, logo, cropped, worst quality, jpeg, humpbacked, long body, long neck, ((jpeg artifacts)), deleted, old, oldest, ((censored)), ((bad aesthetic)), (mosaic censoring, bar censor, blur censor), ",
    ),
    "NovelAI": (
        "masterpiece, best quality, ",
        "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, ",
    ),
    "AbyssOrangeMix": (
        "masterpiece, best quality, ",
        "(worst quality, low quality:1.4), ",
    ),
    # This preset only supplies a negative prompt.
    "Stable Diffusion XL": (
        "",
        "3d render, smooth, plastic, blurry, grainy, low-resolution, deep-fried, oversaturated",
    ),
}.get(quality_prompt, ("", ""))
# Prompt used for the sample images when `prompt_convert` finds no caption
# files to draw prompts from.
custom_prompt               = "a Korean ink wash painting by Jeongseon, cityscape of Rio, " # @param {type:"string"}
# @markdown Specify `prompt_from_caption` if you want to use caption as prompt instead. Will be chosen randomly.
prompt_from_caption         = "none"  # @param ["none", ".txt", ".caption"]
# Choosing a caption extension discards the hand-written prompt.
if prompt_from_caption != "none":
    custom_prompt           = ""
# Upper bound on how many caption files `prompt_convert` samples.
num_prompt                  = 2  # @param {type:"number"}
# NOTE(review): `training_dir` is not defined in this cell; presumably it is
# restored by `%store -r` or set by an earlier cell -- confirm.
logging_dir                 = os.path.join(training_dir, "logs")

# `repo_dir` likewise comes from outside this cell.
os.chdir(repo_dir)

# Settings shared by every sample prompt. `subset` starts out empty; it is
# the list that `prompt_convert` fills with the individual prompts.
prompt_config = {
    "prompt": dict(
        negative_prompt=negative_prompt,
        width=resolution,
        height=resolution,
        scale=7,
        sample_steps=28,
        subset=[],
    ),
}

# Training configuration, grouped into the sections that are later dumped to
# `config_file.toml`. A value of None marks an option that is left unset.
train_config = {
    "sdxl_arguments": {
        "cache_text_encoder_outputs" : cache_text_encoder_outputs,
        # "enable_bucket"              : True,
        # Taken from the form checkbox instead of being forced to True, so
        # that unticking `no_half_vae` actually has an effect.
        "no_half_vae"                : no_half_vae,
        # "cache_latents"              : True,
        # "cache_latents_to_disk"      : True,
        # "vae_batch_size"             : 4,
        "min_timestep"               : min_timestep,
        "max_timestep"               : max_timestep,
        # Caption shuffling is switched off whenever text-encoder outputs are
        # cached (see the note on the `cache_text_encoder_outputs` field).
        "shuffle_caption"            : not cache_text_encoder_outputs,
    },
    "model_arguments": {
        "pretrained_model_name_or_path" : model_path,
        "vae"                           : vae_path,
    },
    "dataset_arguments": {
        "debug_dataset"                 : False,
        "in_json"                       : in_json,
        "train_data_dir"                : train_data_dir,
        "dataset_repeats"               : num_repeats,
        "keep_tokens"                   : keep_tokens,
        # Square resolution written as "W,H".
        "resolution"                    : str(resolution) + ',' + str(resolution),
        "color_aug"                     : False,
        "face_crop_aug_range"           : None,
        "token_warmup_min"              : 1,
        "token_warmup_step"             : 0,
    },
    "training_arguments": {
        "output_dir"                    : output_dir,
        "output_name"                   : project_name if project_name else "last",
        "save_precision"                : save_precision,
        "save_every_n_epochs"           : save_every_n_epochs,
        "save_n_epoch_ratio"            : None,
        "save_last_n_epochs"            : None,
        "save_state"                    : None,
        "save_last_n_epochs_state"      : None,
        "resume"                        : None,
        "train_batch_size"              : train_batch_size,
        "max_token_length"              : 225,
        "mem_eff_attn"                  : False,
        # Exactly one attention backend flag follows the `optimization` choice.
        "sdpa"                          : optimization == "scaled dot-product attention",
        "xformers"                      : optimization == "xformers",
        "max_train_epochs"              : num_epochs,
        "max_data_loader_n_workers"     : 8,
        "persistent_data_loader_workers": True,
        # Negative values (the form default is -1) mean "no fixed seed".
        # 0 is a legitimate seed and must not be dropped.
        "seed"                          : seed if seed >= 0 else None,
        "gradient_checkpointing"        : gradient_checkpointing,
        "gradient_accumulation_steps"   : 1,
        "mixed_precision"               : mixed_precision,
    },
    "logging_arguments": {
        # wandb is used only when an API key is available.
        "log_with"          : "wandb" if wandb_api_key else "tensorboard",
        "log_tracker_name"  : project_name if wandb_api_key and project_name != "last" else None,
        "logging_dir"       : logging_dir,
        "log_prefix"        : project_name if not wandb_api_key else None,
    },
    "sample_prompt_arguments": {
        "sample_every_n_steps"    : None,
        # Samples are generated on the same schedule as checkpoints.
        "sample_every_n_epochs"   : save_every_n_epochs if enable_sample else None,
        "sample_sampler"          : sampler,
    },
    "saving_arguments": {
        "save_model_as": "safetensors"
    },
}

def write_file(filename, contents):
    """Write ``contents`` to ``filename`` as UTF-8 text.

    An existing file is overwritten. The encoding is fixed to UTF-8 instead
    of the platform default because this helper is used for TOML files,
    which must be UTF-8, and prompts may contain non-ASCII text.

    Args:
        filename: Path of the file to create or overwrite.
        contents: The text to write.
    """
    with open(filename, "w", encoding="utf-8") as f:
        f.write(contents)

def prompt_convert(enable_sample, num_prompt, train_data_dir, prompt_config, custom_prompt):
    """Return ``prompt_config`` with its sample prompts filled in.

    The prompts go into ``["prompt"]["subset"]`` as a list of
    ``{"prompt": text}`` entries. They come from up to ``num_prompt``
    randomly chosen caption files under ``train_data_dir`` when the
    module-level ``prompt_from_caption`` names a caption extension and such
    files exist; otherwise a single prompt built from ``custom_prompt`` (or a
    built-in default when that is empty) is used. The module-level
    ``positive_prompt`` is prepended to every prompt.

    Args:
        enable_sample: When false, ``prompt_config`` is returned untouched.
        num_prompt: Maximum number of caption files to sample.
        train_data_dir: Directory searched recursively for caption files.
        prompt_config: Dict with a ``"prompt"`` table; it is not modified.
        custom_prompt: Prompt to use when no caption files are used.

    Returns:
        ``prompt_config`` itself when sampling is disabled, otherwise a new
        dict whose ``"prompt"`` table carries the generated ``subset``.
    """
    # Imported locally: the import list of this cell does not include them.
    import glob
    import random

    if not enable_sample:
        return prompt_config

    caption_files = []
    # "none" is the "do not use captions" choice, not a file extension;
    # globbing for '**/*none' could pick up unrelated files ending in "none".
    if prompt_from_caption != "none":
        search_pattern = os.path.join(train_data_dir, '**/*' + prompt_from_caption)
        caption_files = glob.glob(search_pattern, recursive=True)

    if caption_files:
        selected_files = random.sample(caption_files, min(num_prompt, len(caption_files)))
        prompts = []
        for file in selected_files:
            with open(file, 'r', encoding="utf-8") as f:
                prompts.append(f.read().strip())
    else:
        if not custom_prompt:
            custom_prompt = "masterpiece, best quality, 1girl, aqua eyes, baseball cap, blonde hair, closed mouth, earrings, green background, hat, hoop earrings, jewelry, looking at viewer, shirt, short hair, simple background, solo, upper body, yellow shirt"
        prompts = [custom_prompt]

    subset = [
        {"prompt": positive_prompt + prompt if positive_prompt else prompt}
        for prompt in prompts
    ]

    # Copy the outer dict AND the nested "prompt" table: a shallow copy alone
    # would share the table and silently modify the caller's config.
    new_prompt_config = dict(prompt_config)
    new_prompt_config["prompt"] = dict(prompt_config["prompt"], subset=subset)
    return new_prompt_config

def eliminate_none_variable(config):
    """Replace empty-string values with ``None``, in place.

    Only two levels are visited: the values of ``config`` itself and the
    values of any dict stored directly in it. Deeper nesting and other
    containers are left as they are.

    Args:
        config: The dict to clean up; it is modified in place.

    Returns:
        The same ``config`` object.
    """
    # The top-level mapping plus every dict stored directly inside it.
    tables = [config]
    tables.extend(value for value in config.values() if isinstance(value, dict))

    for table in tables:
        for name, value in table.items():
            if value == "":
                table[name] = None

    return config

# Merge the sections produced by the other config cells into `train_config`.
# A NameError here means the dictionary does not exist in this session, i.e.
# the cell that defines it has not been run.

# `optimizer_config` is required: re-raise with a hint for the user.
try:
    train_config.update(optimizer_config)
except NameError:
    raise NameError("'optimizer_config' dictionary is missing. Please run  '4.1. Optimizer Config' cell.")

# `lora_config` is required as well.
try:
    train_config.update(lora_config)
except NameError:
    raise NameError("'lora_config' dictionary is missing. Please run  '4.1. LoRa: Low-Rank Adaptation Config' cell.")

# `advanced_training_config` is optional: remember that it was missing and
# print a notice after the config has been written.
advanced_training_warning = False
try:
    train_config.update(advanced_training_config)
except NameError:
    advanced_training_warning = True
    pass

# Fill in the sample prompts (only has an effect when sampling is enabled).
prompt_config = prompt_convert(enable_sample, num_prompt, train_data_dir, prompt_config, custom_prompt)

config_path         = os.path.join(config_dir, "config_file.toml")
prompt_path         = os.path.join(config_dir, "sample_prompt.toml")

# Empty strings are turned into None before serialising.
# NOTE(review): the cell output suggests `toml.dumps` omits None entries,
# which is what removes unset options from the file -- confirm.
config_str          = toml.dumps(eliminate_none_variable(train_config))
prompt_str          = toml.dumps(eliminate_none_variable(prompt_config))

write_file(config_path, config_str)
write_file(prompt_path, prompt_str)

print(config_str)

if advanced_training_warning:
    import textwrap
    error_message = "WARNING: This is not an error message, but the [advanced_training_config] dictionary is missing. Please run the '4.2. Advanced Training Config' cell if you intend to use it, or continue to the next step."
    wrapped_message = textwrap.fill(error_message, width=80)
    # ANSI 24-bit colour escape around the message, reset afterwards.
    print('\033[38;2;204;102;102m' + wrapped_message + '\033[0m\n')
    pass

print(prompt_str)


[sdxl_arguments]
cache_text_encoder_outputs = true
no_half_vae = true
min_timestep = 0
max_timestep = 1000
shuffle_caption = false

[model_arguments]
pretrained_model_name_or_path = "/content/pretrained_model/sd_xl_base_1.0.safetensors"
vae = "/content/vae/sdxl_vae.safetensors"

[dataset_arguments]
debug_dataset = false
in_json = "/content/LoRA/meta_lat.json"
train_data_dir = "/content/drive/MyDrive/train/js"
dataset_repeats = 10
keep_tokens = 0
resolution = "1024,1024"
color_aug = false
token_warmup_min = 1
token_warmup_step = 0

[training_arguments]
output_dir = "/content/drive/MyDrive/kohya-trainer/output"
output_name = "js"
save_precision = "bf16"
save_every_n_epochs = 1
train_batch_size = 4
max_token_length = 225
mem_eff_attn = false
sdpa = false
xformers = true
max_train_epochs = 6
max_data_loader_n_workers = 8
persistent_data_loader_workers = true
gradient_checkpointing = true
gradient_accumulation_steps = 1
mixed_precision = "bf16"

[logging_arguments]
log_with = "wandb"
log_tr

In [None]:
!pip install dtype

Collecting dtype
  Downloading dtype-0.0.11-py3-none-any.whl.metadata (211 bytes)
Downloading dtype-0.0.11-py3-none-any.whl (1.2 kB)
Installing collected packages: dtype
Successfully installed dtype-0.0.11
[0m

In [None]:
#@title ## **4.5. Start Training**
import os
import toml

#@markdown Check your config here if you want to edit something:
#@markdown - `sample_prompt` : /content/LoRA/config/sample_prompt.toml
#@markdown - `config_file` : /content/LoRA/config/config_file.toml


#@markdown You can import config from another session if you want.

# The defaults use the file names written by the '4.4. Training Config' cell
# (`sample_prompt.toml`, `config_file.toml`).
# NOTE(review): the directory is assumed to equal that cell's `config_dir` -- confirm.
sample_prompt   = "/content/LoRA/config/sample_prompt.toml" #@param {type:'string'}
config_file     = "/content/LoRA/config/config_file.toml" #@param {type:'string'}

def read_file(filename):
    """Return the whole content of ``filename`` as text.

    The file is decoded as UTF-8 rather than with the platform default
    encoding, so the TOML config files (which must be UTF-8) read the same
    way on every system.

    Args:
        filename: Path of the file to read.

    Returns:
        The file content as a single string.
    """
    with open(filename, "r", encoding="utf-8") as f:
        return f.read()

def train(config):
    """Render ``config`` as a string of command-line arguments.

    Each entry is turned into one fragment followed by a space:

    * key starting with ``_``: the value alone, in double quotes;
    * string value: ``--key="value"``;
    * ``True``: the bare flag ``--key``;
    * other int or float value: ``--key=value``.

    ``False``, ``None`` and values of any other type produce nothing.

    Args:
        config: Mapping of option names to values.

    Returns:
        The concatenated fragments (empty string for an empty mapping).
    """
    pieces = []
    for name, value in config.items():
        if name.startswith("_"):
            pieces.append(f'"{value}" ')
        elif isinstance(value, str):
            pieces.append(f'--{name}="{value}" ')
        elif isinstance(value, bool):
            # bool is checked before the numeric case because it is an int
            # subclass; only True yields a flag.
            if value:
                pieces.append(f"--{name} ")
        elif isinstance(value, (float, int)):
            pieces.append(f"--{name}={value} ")

    return "".join(pieces)

# Options for `accelerate launch` itself.
accelerate_conf = dict(
    config_file="/content/kohya-trainer/accelerate_config/config.yaml",
    num_cpu_threads_per_process=1,
)

# Options for the training script. If the sample-prompt file does not exist
# the entry is None, which `train` leaves out of the argument string.
train_conf = dict(
    sample_prompts=None if not os.path.exists(sample_prompt) else sample_prompt,
    config_file=config_file,
)

accelerate_args, train_args = train(accelerate_conf), train(train_conf)

final_args = "accelerate launch " + accelerate_args + " sdxl_train_network.py " + train_args

# The command on the following line is run from the repository directory.
os.chdir(repo_dir)
!{final_args}

    PyTorch 2.0.0+cu118 with CUDA 1108 (you have 2.2.2+cu121)
    Python  3.10.11 (you have 3.10.12)
  Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)
  Memory-efficient attention, SwiGLU, sparse and more won't be available.
  Set XFORMERS_MORE_DETAILS=1 for more details
Loading settings from /content/LoRA/config/config_file.toml...
/content/LoRA/config/config_file
prepare tokenizers
update token length: 225
Training with captions.
loading existing metadata: /content/LoRA/meta_lat.json
metadata has bucket info, enable bucketing / メタデータにbucket情報があるためbucketを有効にします
using bucket info in metadata / メタデータ内のbucket情報を使います
[Dataset 0]
  batch_size: 4
  resolution: (1024, 1024)
  enable_bucket: True
  min_bucket_reso: None
  max_bucket_reso: None
  bucket_reso_steps: None
  bucket_no_upscale: None

  [Subset 0 of Dataset 0]
    image_dir: "/content/drive/MyDrive/train/js"
    image_count: 58
    num_repeats: 10
    shuffle_caption: False
    keep_