# KOHYA TRAINER XL 

# Install Kohya Trainer

In [None]:
import os
import shutil

# Directory layout. Everything lives under `root_dir`.
root_dir          = "/workspace"
# NOTE(review): `drive_dir` and `lora_dir` are not referenced elsewhere in the visible notebook.
drive_dir         = os.path.join(root_dir, "drive", "MyDrive")
# Checkout location of the trainer repository cloned from `repo_url`.
repo_dir          = os.path.join(root_dir, "kohya-trainer")
# Parent directory for the training config, outputs, logs and metadata JSON files.
training_dir      = os.path.join(root_dir, "fine_tune")
# Download destinations for the base model and the VAE.
pretrained_model  = os.path.join(root_dir, "pretrained_model")
vae_dir           = os.path.join(root_dir, "vae")
lora_dir          = os.path.join(root_dir, "network_weight")
config_dir        = os.path.join(training_dir, "config")
output_dir        = os.path.join(training_dir, "outputs")
# Sub-directories of the cloned repository.
tools_dir         = os.path.join(repo_dir, "tools")
finetune_dir      = os.path.join(repo_dir, "finetune")
# Written by `install_dependencies` if it does not exist yet.
accelerate_config = os.path.join(repo_dir, "accelerate_config", "config.yaml")

# Repository and branch that `clone_repo` checks out.
repo_url          = "https://github.com/qaneel/kohya-trainer"
branch            = "main" 

def clone_repo(url, dir, branch):
    """Clone `branch` of the git repository at `url` into `dir`.

    Nothing is done when `dir` already exists, so re-running the cell does not
    re-clone. Relies on the IPython `!` shell escape, so it only works inside a
    notebook/IPython session.
    """
    # NOTE: `dir` shadows the builtin of the same name; it is kept because it is
    # part of the function's signature.
    if not os.path.exists(dir):
       !git clone -b {branch} {url} {dir}

def install_dependencies():
    """Install system and Python dependencies, then create the accelerate config.

    Uses IPython `!` shell escapes. `requirements.txt` is resolved relative to
    the current working directory; `main` changes into `repo_dir` before
    calling this function.
    """
    !apt update -yqq
    !apt install aria2 -yqq
    !pip install -q --upgrade -r requirements.txt
    !pip install xformers

    # Imported here rather than at the top of the cell, after the pip installs
    # above have run.
    from accelerate.utils import write_basic_config

    # Keep an existing accelerate config; only write one when it is missing.
    if not os.path.exists(accelerate_config):
        write_basic_config(save_location=accelerate_config)

def prepare_environment():
    """Export the fixed set of environment variables used by later cells.

    Sets `TF_CPP_MIN_LOG_LEVEL`, `SAFETENSORS_FAST_GPU` and `PYTHONWARNINGS`
    in the current process environment, overwriting any existing values.
    """
    overrides = {
        "TF_CPP_MIN_LOG_LEVEL": "3",
        "SAFETENSORS_FAST_GPU": "1",
        "PYTHONWARNINGS": "ignore",
    }
    os.environ.update(overrides)

def main():
    """Clone the trainer, create the working directories and install dependencies."""
    os.chdir(root_dir)
    clone_repo(repo_url, repo_dir, branch)

    # The dependency install resolves `requirements.txt` relative to the repo.
    os.chdir(repo_dir)

    working_dirs = (training_dir, config_dir, pretrained_model, vae_dir, output_dir)
    for path in working_dirs:
        os.makedirs(path, exist_ok=True)

    install_dependencies()
    prepare_environment()

main()

# Download SDXL

In [None]:
import os
import re
import requests
import subprocess
from urllib.parse import urlparse, unquote
from pathlib import Path

# Run the rest of this cell from the workspace root.
os.chdir(root_dir)

# NOTE(review): an access token is hard-coded here and is readable by anyone who
# can see this file. It should be revoked and supplied at runtime instead (for
# example from an environment variable). Left as-is so behaviour is unchanged.
HUGGINGFACE_TOKEN = "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"
# Direct download URLs for the base model and the VAE.
SDXL_MODEL_URL    = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9/resolve/main/sd_xl_base_0.9.safetensors"
SDXL_VAE_URL      = "https://huggingface.co/stabilityai/sdxl-vae/resolve/main/sdxl_vae.safetensors"

def get_supported_extensions():
    """Return the tuple of checkpoint file extensions, each with a leading dot."""
    return ".ckpt", ".safetensors", ".pt", ".pth"

def get_filename(url, bearer_token, quiet=True):
    """Work out the file name that `url` will be saved under.

    Issues a streamed GET request and reads the name from the
    `Content-Disposition` response header. When the header is absent, or does
    not carry a `filename=` entry, the last component of the URL path
    (percent-decoded) is used instead.

    Args:
        url: HTTP(S) URL of the file.
        bearer_token: Access token. It is only sent when non-empty and when
            `url` points at huggingface.co, matching what `aria2_download`
            does for the actual transfer.
        quiet: Accepted for backward compatibility; currently unused.

    Returns:
        The file name as a string.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    # Never hand the token to hosts other than Hugging Face.
    headers = {}
    if bearer_token and "huggingface.co" in url:
        headers["Authorization"] = f"Bearer {bearer_token}"

    # stream=True defers the body download; the `with` block releases the
    # connection once the headers have been read.
    with requests.get(url, headers=headers, stream=True) as response:
        response.raise_for_status()
        content_disposition = response.headers.get("content-disposition", "")

    match = re.search(r'filename="?([^"]+)"?', content_disposition)
    if match:
        return match.group(1)

    # No usable header: fall back to the URL's basename.
    url_path = urlparse(url).path
    return unquote(os.path.basename(url_path))

def parse_args(config):
    """Turn an option mapping into a list of command-line arguments.

    Rules, applied per `key: value` pair in mapping order:
      * key starting with "_"  -> the value itself, as a positional argument
      * `True`                 -> `--key`
      * `False`                -> omitted
      * str, int or float      -> `--key=value`
      * anything else (e.g. `None`) -> omitted
    """

    def render(key, value):
        # Positional arguments are emitted verbatim, whatever their type.
        if key.startswith("_"):
            return f"{value}"
        # bool is tested before the numeric types because it is a subclass of int.
        if isinstance(value, bool):
            return f"--{key}" if value else None
        if isinstance(value, (str, float, int)):
            return f"--{key}={value}"
        return None

    rendered = (render(key, value) for key, value in config.items())
    return [arg for arg in rendered if arg is not None]

def aria2_download(dir, filename, url, token):
    """Download `url` to `dir`/`filename` by running the `aria2c` executable.

    The bearer `token` is passed as an Authorization header only when `url`
    contains "huggingface.co". The exit status of `aria2c` is not checked.
    """
    if "huggingface.co" in url:
        auth_header = f"Authorization: Bearer {token}"
    else:
        # A None value is dropped by parse_args, so no header option is emitted.
        auth_header = None

    # Key order is preserved and determines the argument order; the "_url"
    # entry becomes the trailing positional argument.
    options = {
        "console-log-level"         : "error",
        "summary-interval"          : 10,
        "header"                    : auth_header,
        "continue"                  : True,
        "max-connection-per-server" : 16,
        "min-split-size"            : "1M",
        "split"                     : 16,
        "dir"                       : dir,
        "out"                       : filename,
        "_url"                      : url,
    }
    command = ["aria2c"] + parse_args(options)
    subprocess.run(command)

def download(url, dst, token):
    """Fetch `url` into directory `dst` and return the resulting file path.

    Args:
        url: Either a local path under "/workspace" or an HTTP(S) URL.
        dst: Destination directory for downloaded files.
        token: Access token forwarded to `get_filename` and `aria2_download`.

    Returns:
        `url` itself when it is a local "/workspace" path, otherwise
        `dst` joined with the file name resolved by `get_filename`.
    """
    # Local files need no download. This must be checked before get_filename,
    # which performs an HTTP request and cannot handle a filesystem path.
    if url.startswith("/workspace"):
        return url

    filename = get_filename(url, token, quiet=False)
    filepath = os.path.join(dst, filename)

    # NOTE(review): only huggingface.co URLs are actually downloaded; for any
    # other host the path is returned without fetching anything. Confirm
    # whether that is intended.
    if "huggingface.co" in url:
        # Web-viewer links ("/blob/") are rewritten to direct-download links.
        url = url.replace("/blob/", "/resolve/")
        aria2_download(dst, filename, url, token)

    return filepath

def main():
    """Download the configured model and VAE and publish their paths.

    Sets the module-level globals `model_path` and `vae_path` (left as `None`
    when the corresponding URL is empty) and prints each path that exists on
    disk.
    """
    global model_path, vae_path

    model_path = vae_path = None

    download_targets = {
        "model" : (SDXL_MODEL_URL, pretrained_model),
        "vae"   : (SDXL_VAE_URL, vae_dir),
    }
    selected_files = {}

    for target, (url, dst) in download_targets.items():
        if not url:
            continue

        selected_files[target] = download(url, dst, HUGGINGFACE_TOKEN)

        if target == "model":
            model_path = selected_files[target]
        elif target == "vae":
            vae_path = selected_files[target]

    for category, path in (("model", model_path), ("vae", vae_path)):
        if path is not None and os.path.exists(path):
            print(f"Selected {category}: {path}")

main()

# Directory Config

In [None]:
import os

# Directory that holds the training images; created here if it is missing.
train_data_dir = "/workspace/fine_tune/train_data"

os.makedirs(train_data_dir, exist_ok=True)
print(f"Your train data directory : {train_data_dir}")

# Data Gathering

## Unzip Dataset
If your dataset is a `zip` file hosted at a URL, use this section to download and extract it. The archive is extracted to `unzip_to`, or to `train_data_dir` if `unzip_to` is empty.

In [None]:
import os
import zipfile
import shutil
from pathlib import Path

# URL of the dataset archive to download.
zipfile_url  = "https://huggingface.co/datasets/Linaqruf/hitokomoru-lora-dataset/resolve/main/hitokomoru_dataset.zip"
# Extraction target; leave empty to extract into `train_data_dir`.
unzip_to     = ""

if unzip_to:
    os.makedirs(unzip_to, exist_ok=True)
else:
    # `train_data_dir` is defined by the "Directory Config" cell.
    unzip_to = train_data_dir

def extract_dataset(zip_file, output_path):
    """Extract every member of the archive `zip_file` into `output_path`."""
    with zipfile.ZipFile(zip_file) as archive:
        archive.extractall(path=output_path)
        
def remove_files(train_dir, files_to_move, dest_dir=None):
    """Move the named files out of `train_dir`, deleting redundant copies.

    For each entry of `train_dir` whose name is listed in `files_to_move`:
    if `dest_dir` does not yet contain a file of that name the entry is moved
    there; otherwise the copy in `train_dir` is deleted and the one already in
    `dest_dir` is kept.

    Args:
        train_dir: Directory to clean up.
        files_to_move: Collection of file names (not paths) to relocate.
        dest_dir: Destination directory. Defaults to the module-level
            `training_dir`.
    """
    if dest_dir is None:
        dest_dir = training_dir

    wanted = set(files_to_move)
    for filename in os.listdir(train_dir):
        if filename not in wanted:
            continue

        source = os.path.join(train_dir, filename)
        target = os.path.join(dest_dir, filename)

        # Source and destination are the same file: nothing to move, and
        # removing it would lose the only copy.
        if os.path.abspath(source) == os.path.abspath(target):
            continue

        # The existence check must look at the destination; the source came
        # from listdir and therefore always exists.
        if os.path.exists(target):
            os.remove(source)
        else:
            os.makedirs(dest_dir, exist_ok=True)
            shutil.move(source, target)

# Download the archive into the workspace root, extract it, then delete the
# archive itself. `download` and `HUGGINGFACE_TOKEN` come from the
# "Download SDXL" cell.
zip_file = download(zipfile_url, root_dir, HUGGINGFACE_TOKEN)
extract_dataset(zip_file, unzip_to)
os.remove(zip_file)

# Metadata file names handled by `remove_files` below.
files_to_move = (
    "meta_cap.json",
    "meta_cap_dd.json",
    "meta_lat.json",
    "meta_clean.json",
)

# NOTE(review): this scans `train_data_dir` even when a custom `unzip_to` was
# given — confirm that is intended.
remove_files(train_data_dir, files_to_move)

# Bucketing and Latents Caching
This code will create buckets based on the `bucket_resolution` provided for multi-aspect ratio training, and then convert all images within the `train_data_dir` to latents.

In [None]:
# @title ## **3.4. Bucketing and Latents Caching**
import time

# Output of the bucketing step and of the metadata-merge step, respectively.
bucketing_json    = os.path.join(training_dir, "meta_lat.json")
metadata_json     = os.path.join(training_dir, "meta_clean.json")
bucket_resolution = 1024
mixed_precision   = "no" # choose between ["no", "fp16", "bf16"]
flip_aug          = False 

# Use `clean_caption` option to clean such as duplicate tags, `women` to `girl`, etc
clean_caption     = True 
# Use the `recursive` option to process subfolders as well
recursive         = True

# Arguments for merge_all_to_metadata.py. Keys starting with "_" are rendered
# as positional arguments by `generate_args`, in the order listed here.
metadata_config = {
    "_train_data_dir": train_data_dir,
    "_out_json": metadata_json,
    "recursive": recursive,
    # `full_path` is tied to the `recursive` setting.
    "full_path": recursive,
    "clean_caption": clean_caption
}

# Arguments for prepare_buckets_latents.py. `model_path` is the global set by
# the "Download SDXL" cell.
bucketing_config = {
    "_train_data_dir": train_data_dir,
    "_in_json": metadata_json,
    "_out_json": bucketing_json,
    "_model_name_or_path": model_path,
    "recursive": recursive,
    "full_path": recursive,
    "flip_aug": flip_aug,
    "batch_size": 4,
    "max_data_loader_n_workers": 2,
    # Square buckets: the same value is used for both dimensions.
    "max_resolution": f"{bucket_resolution}, {bucket_resolution}",
    "mixed_precision": mixed_precision,
}

def generate_args(config):
    """Render an option mapping as a single shell argument string.

    Rules, applied per `key: value` pair in mapping order:
      * key starting with "_"  -> `"value"` (double-quoted positional argument)
      * str                    -> `--key="value"`
      * `True`                 -> `--key`
      * int or float           -> `--key=value`
      * `False`, `None`, other -> omitted

    Arguments are separated by single spaces; the result has no leading or
    trailing whitespace.
    """
    pieces = []
    for key, value in config.items():
        if key.startswith("_"):
            pieces.append(f'"{value}"')
        elif isinstance(value, bool):
            # bool is handled before int because it is a subclass of int.
            if value:
                pieces.append(f"--{key}")
        elif isinstance(value, str):
            pieces.append(f'--{key}="{value}"')
        elif isinstance(value, (float, int)):
            pieces.append(f"--{key}={value}")
    return " ".join(pieces).strip()

# Render both option sets as shell argument strings.
merge_metadata_args = generate_args(metadata_config)
prepare_buckets_args = generate_args(bucketing_config)

merge_metadata_command = f"python merge_all_to_metadata.py {merge_metadata_args}"
prepare_buckets_command = f"python prepare_buckets_latents.py {prepare_buckets_args}"

# Both scripts are invoked by bare file name, so run them from `finetune_dir`.
os.chdir(finetune_dir)
# The metadata merge runs first: its output JSON is the input of the bucketing step.
!{merge_metadata_command}
time.sleep(1)
!{prepare_buckets_command}


# Optimizer Config

1. For `optimizer_type`, use `Adafactor` optimizer. `RMSprop 8bit` or `Adagrad 8bit` may work. `AdamW 8bit` doesn't seem to work.
2. Choose between ["AdamW", "AdamW8bit", "Lion8bit", "Lion", "SGDNesterov", "SGDNesterov8bit", "DAdaptation(DAdaptAdamPreprint)", "DAdaptAdaGrad", "DAdaptAdam", "DAdaptAdan", "DAdaptAdanIP", "DAdaptLion", "DAdaptSGD", "AdaFactor"]
3. Specify `optimizer_args` to add `additional` args for optimizer, e.g: `["weight_decay=0.6"]`
4. Training the Text Encoder is not recommended for SDXL.
5. `lr_scheduler` provides several methods to adjust the learning rate based on the number of epochs.
6. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup", "adafactor"]
7. Specify `lr_scheduler_num` with `num_cycles` value for `cosine_with_restarts` or `power` value for `polynomial`

In [None]:
import toml
import ast

optimizer_type = "AdaFactor"
optimizer_args = "[ \"scale_parameter=False\", \"relative_step=False\", \"warmup_init=False\" ]"
learning_rate = 4e-7
train_text_encoder = False
lr_scheduler = "constant_with_warmup"
lr_warmup_steps = 100
lr_scheduler_num = 0

# Normalise `optimizer_args` into a list. Only a string holding a Python list
# literal is accepted; every other input ends up as an empty list.
if not isinstance(optimizer_args, str):
    optimizer_args = []
else:
    optimizer_args = optimizer_args.strip()
    if optimizer_args.startswith('[') and optimizer_args.endswith(']'):
        try:
            optimizer_args = ast.literal_eval(optimizer_args)
        except (SyntaxError, ValueError) as e:
            print(f"Error parsing optimizer_args: {e}\n")
            optimizer_args = []
    else:
        # Warn only when the user typed something; a blank field is silently ignored.
        if optimizer_args:
            print(f"WARNING! '{optimizer_args}' is not a valid list! Put args like this: [\"args=1\", \"args=2\"]\n")
        optimizer_args = []

# `lr_scheduler_num` is routed to whichever option the chosen scheduler uses;
# the other one stays None.
optimizer_config = {
    "optimizer_arguments": {
        "optimizer_type"          : optimizer_type,
        "learning_rate"           : learning_rate,
        "train_text_encoder"      : train_text_encoder,
        "max_grad_norm"           : 1.0,
        "optimizer_args"          : optimizer_args,
        "lr_scheduler"            : lr_scheduler,
        "lr_warmup_steps"         : lr_warmup_steps,
        "lr_scheduler_num_cycles" : lr_scheduler_num if lr_scheduler == "cosine_with_restarts" else None,
        "lr_scheduler_power"      : lr_scheduler_num if lr_scheduler == "polynomial" else None,
        "lr_scheduler_type"       : None,
        "lr_scheduler_args"       : None,
    },
}

print(toml.dumps(optimizer_config))


# Advanced Training Config
1. Specify `optimizer_state_path` to resume training with Optimizer State
2. You can't use both `noise_offset` and `multires_noise` at the same time
3. Uncomment if necessary


In [None]:
import toml

# Path of a saved state to resume from; leave empty to start fresh.
optimizer_state_path      = "" 
# Optional settings: uncomment a variable here together with its matching
# dictionary entry below.
# noise_offset          = 0.1 
# multires_noise_iterations = 6 
# multires_noise_discount = 0.3
# -1 acts as the "disabled" sentinel for min_snr_gamma.
min_snr_gamma             = -1 

advanced_training_config = {
    "advanced_training_config": {
        "resume"                    : optimizer_state_path,
        # "noise_offset"              : noise_offset, 
        # "multires_noise_iterations" : multires_noise_iterations, 
        # "multires_noise_discount"   : multires_noise_discount, 
        # The -1 sentinel is stored as None instead of being written as a value.
        "min_snr_gamma"             : min_snr_gamma if not min_snr_gamma == -1 else None,

    }
}

print(toml.dumps(advanced_training_config))

# Deployment Config

In [None]:
import toml

# Upload settings. Leave `huggingface_write_token` empty to skip uploading.
huggingface_repo_id = "sdxl_finetune"
huggingface_write_token = ""
huggingface_path_in_repo = ""
huggingface_repo_visibility = "private" # private or public
async_upload = True

deployment_config = {
    "save_to_hub_config": {
        "huggingface_repo_id"         : huggingface_repo_id,
        "huggingface_repo_type"       : "model", 
        "huggingface_path_in_repo"    : huggingface_path_in_repo, 
        "huggingface_token"           : huggingface_write_token,
        "async_upload"                : async_upload, 
        "huggingface_repo_visibility" : huggingface_repo_visibility,
    }
}
print(toml.dumps(deployment_config))

# Without a token the dictionary is deleted on purpose: the "Training Config"
# cell detects the missing name (NameError) and leaves the upload section out.
if huggingface_write_token == "":
    del deployment_config

# Training Config 
1. Get your `wandb_api_key` here: https://wandb.ai/settings
2. `cache_text_encoder_outputs` is the recommended parameter for SDXL training but if you enable it, `shuffle_caption` won't work
3. `min_timestep` and `max_timestep` can be used to train U-Net with different timesteps. The default values are 0 and 1000.
4. Sampler List: ["ddim", "pndm", "lms", "euler", "euler_a", "heun", "dpm_2", "dpm_2_a", "dpmsolver","dpmsolver++", "dpmsingle", "k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a"]

In [None]:
import toml
import os
import random
from subprocess import getoutput

# PROJECT CONFIG
# PROJECT CONFIG
project_name            = "sdxl_finetune"
# Leave empty to log with tensorboard instead of wandb.
wandb_api_key           = "" 
# Same path as `bucketing_json` in the bucketing cell.
in_json                 = "/workspace/fine_tune/meta_lat.json"

# SDXL CONFIG
grad_checkpointing      = True
no_half_vae             = True 
# When enabled, `shuffle_caption` is switched off in the generated config.
cache_text_encoder_outputs = True
min_timestep            = 0 
max_timestep            = 1000

# DATASET CONFIG
num_repeats             = 1
# Used for both width and height.
resolution              = 1024
keep_tokens             = 0

# GENERAL CONFIG
max_train_steps         = 2500
train_batch_size        = 4
mixed_precision         = "fp16"
# Values <= 0 mean "no fixed seed": the seed is then left out of the config.
seed                    = -1

# SAVE OUTPUT AS
save_precision          = "fp16" 
save_every_n_steps      = 1000
save_optimizer_state    = False 
save_model_as           = "safetensors"

# SAMPLE PROMPT
prompt                  = "masterpiece, best quality, 1girl, aqua eyes, baseball cap, blonde hair, closed mouth, earrings, green background, hat, hoop earrings, jewelry, looking at viewer, shirt, short hair, simple background, solo, upper body, yellow shirt"
custom_negative_prompt  = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry"
sample_interval         = 100 
sampler                 = "euler_a" 
logging_dir             = os.path.join(training_dir, "logs")

# Continue from the repository root.
os.chdir(repo_dir)

# Settings for the sample images rendered during training.
# The custom negative prompt takes precedence. Otherwise a `negative_prompt`
# global is used if the session happens to define one (this cell does not);
# if there is none, the entry is left empty instead of raising NameError.
prompt_config = {
    "prompt": {
        "negative_prompt" : custom_negative_prompt or globals().get("negative_prompt", ""),
        "width"           : resolution,
        "height"          : resolution,
        "scale"           : 7,
        "sample_steps"    : 28,
        "subset"          : [
            {
                "prompt" : prompt,
            }
        ],
    }
}

# Full training configuration, grouped into sections and later serialised to
# TOML. Entries set to None are placeholders for options that are not used.
train_config = {
    "sdxl_arguments": {
        "cache_text_encoder_outputs" : cache_text_encoder_outputs,
        "no_half_vae"                : no_half_vae,
        "min_timestep"               : min_timestep,
        "max_timestep"               : max_timestep,
        # Caption shuffling is turned off whenever text-encoder outputs are cached.
        "shuffle_caption"            : not cache_text_encoder_outputs,
    },
    "model_arguments": {
        "pretrained_model_name_or_path" : model_path,
        "vae"                           : vae_path,
    },
    "dataset_arguments": {
        "debug_dataset"                 : False,
        "in_json"                       : in_json,
        "train_data_dir"                : train_data_dir,
        "dataset_repeats"               : num_repeats,
        "keep_tokens"                   : keep_tokens,
        "resolution"                    : f"{resolution},{resolution}",
        "caption_dropout_rate"          : 0,
        "caption_tag_dropout_rate"      : 0,
        "caption_dropout_every_n_epochs": 0,
        "color_aug"                     : False,
        "face_crop_aug_range"           : None,
        "token_warmup_min"              : 1,
        "token_warmup_step"             : 0,
    },
    "training_arguments": {
        "output_dir"                    : output_dir,
        "output_name"                   : project_name if project_name else "last",
        "save_precision"                : save_precision,
        "save_every_n_steps"            : save_every_n_steps,
        "save_n_epoch_ratio"            : None,
        "save_last_n_epochs"            : None,
        # Honour the `save_optimizer_state` setting instead of always disabling it.
        "save_state"                    : save_optimizer_state,
        "save_last_n_epochs_state"      : None,
        "resume"                        : None,
        "train_batch_size"              : train_batch_size,
        "max_token_length"              : 225,
        "mem_eff_attn"                  : False,
        "xformers"                      : True,
        "max_train_steps"               : max_train_steps,
        "max_data_loader_n_workers"     : 8,
        "persistent_data_loader_workers": True,
        # Non-positive seeds mean "no fixed seed".
        "seed"                          : seed if seed > 0 else None,
        "gradient_checkpointing"        : grad_checkpointing,
        "gradient_accumulation_steps"   : 1,
        "mixed_precision"               : mixed_precision,
    },
    "logging_arguments": {
        # wandb is used only when an API key was provided.
        "log_with"          : "wandb" if wandb_api_key else "tensorboard",
        "log_tracker_name"  : project_name if wandb_api_key and not project_name == "last" else None,
        "logging_dir"       : logging_dir,
        "log_prefix"        : project_name if not wandb_api_key else None,
    },
    "sample_prompt_arguments": {
        "sample_every_n_steps"    : sample_interval,
        "sample_every_n_epochs"   : None,
        "sample_sampler"          : sampler,
    },
    "saving_arguments": {
        # Honour the `save_model_as` setting instead of a hard-coded format.
        "save_model_as": save_model_as
    },
}

def write_file(filename, contents):
    """Write the string `contents` to `filename`, replacing any existing file."""
    with open(filename, mode="w") as stream:
        stream.write(contents)

def eliminate_none_variable(config):
    """Replace empty-string values in `config` with None, in place.

    Covers top-level values and the values of dictionaries one level down;
    deeper nesting is left untouched. Returns the same `config` object.
    """
    for name, value in config.items():
        if isinstance(value, dict):
            for inner_name, inner_value in value.items():
                if inner_value == "":
                    value[inner_name] = None
        elif value == "":
            config[name] = None

    return config

# Merge the sections produced by the earlier config cells into `train_config`.
# The optimizer section is mandatory: a missing dictionary stops the cell.
try:
    train_config.update(optimizer_config)
except NameError:
    raise NameError("'optimizer_config' dictionary is missing. Please run  'Optimizer Config' cell.")

# The advanced-training and deployment sections are optional: a missing
# dictionary is only recorded here and reported after the config is printed.
advanced_training_warning = False
try:
    train_config.update(advanced_training_config)
except NameError:
    advanced_training_warning = True
    pass

deployment_config_warning = False
try:
    train_config.update(deployment_config)
except NameError:
    deployment_config_warning = True
    pass

config_path         = os.path.join(config_dir, "config_file.toml")
prompt_path         = os.path.join(config_dir, "sample_prompt.toml")

# Empty strings are turned into None before serialising.
config_str          = toml.dumps(eliminate_none_variable(train_config))
prompt_str          = toml.dumps(eliminate_none_variable(prompt_config))

write_file(config_path, config_str)
write_file(prompt_path, prompt_str)

print(config_str)

# Deferred notices for the optional sections, wrapped to 80 columns and
# printed using an ANSI 24-bit colour escape sequence.
if advanced_training_warning:
    import textwrap
    error_message = "WARNING: This is not an error message, but the [advanced_training_config] dictionary is missing. Please run the 'Advanced Training Config' cell if you intend to use it, or continue to the next step."
    wrapped_message = textwrap.fill(error_message, width=80)
    print('\033[38;2;204;102;102m' + wrapped_message + '\033[0m\n')
    pass
    
if deployment_config_warning:
    import textwrap
    error_message = "WARNING: This is not an error message, but the [deployment_config] dictionary is missing. Please run the 'Deployment Training Config' cell if you intend to use it, or continue to the next step."
    wrapped_message = textwrap.fill(error_message, width=80)
    print('\033[38;2;204;102;102m' + wrapped_message + '\033[0m\n')
    pass

print(prompt_str)

# Start Training

In [None]:
import os
import toml

# Files written by the "Training Config" cell (`prompt_path` and `config_path`
# there resolve to these same locations).
sample_prompt   = "/workspace/fine_tune/config/sample_prompt.toml"
config_file     = "/workspace/fine_tune/config/config_file.toml"

def read_file(filename):
    """Return the entire text content of `filename`."""
    with open(filename, mode="r") as stream:
        return stream.read()

def train(config):
    """Render an option mapping as a shell argument string.

    Rules, applied per `key: value` pair in mapping order:
      * key starting with "_"  -> `"value"` (double-quoted positional argument)
      * str                    -> `--key="value"`
      * `True`                 -> `--key`
      * int or float           -> `--key=value`
      * `False`, `None`, other -> omitted

    Every emitted argument is followed by one space, so a non-empty result
    ends with a trailing space.
    """
    chunks = []
    for name, val in config.items():
        if name.startswith("_"):
            chunks.append(f'"{val}"')
        elif isinstance(val, bool):
            # bool is handled before int because it is a subclass of int.
            if val:
                chunks.append(f"--{name}")
        elif isinstance(val, str):
            chunks.append(f'--{name}="{val}"')
        elif isinstance(val, (float, int)):
            chunks.append(f"--{name}={val}")

    return "".join(chunk + " " for chunk in chunks)

# Options for `accelerate launch` itself.
accelerate_conf = {
    "config_file" : accelerate_config,
    "num_cpu_threads_per_process" : 1,
}

# Options for the training script. None values are dropped by `train`, so the
# sample-prompt file and the wandb key are only passed when available.
train_conf = {
    "sample_prompts"  : sample_prompt if os.path.exists(sample_prompt) else None,
    "config_file"     : config_file,
    "wandb_api_key"   : wandb_api_key if wandb_api_key else None,
}

accelerate_args = train(accelerate_conf)
train_args = train(train_conf)

final_args = f"accelerate launch {accelerate_args} sdxl_train.py {train_args}"

# sdxl_train.py is referenced by bare file name, so launch from the repo root.
os.chdir(repo_dir)
!{final_args}

# Inference

In [None]:
import os
import math
from PIL import Image, ImageOps
from IPython.display import display

# Checkpoint to load for inference.
ckpt_path = "/workspace/pretrained_model/sd_xl_base_0.9.safetensors" 
prompt = "1girl, aqua eyes, baseball cap, blonde hair, closed mouth, earrings, green background, hat, hoop earrings, jewelry, looking at viewer, shirt, short hair, simple background, solo, upper body, yellow shirt" 
negative_prompt = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry" 
output_path = "/workspace/tmp/" 
# "W<sep>H", where <sep> is one of the characters in `separators` below.
resolution = "1024,1024"
optimization = "scaled dot-product attention" # ["xformers", "scaled dot-product attention"]
conditional_resolution = "1024,1024"
steps = 28 
sampler = "euler_a" 
scale = 7 
# Values <= 0 mean "no fixed seed": the seed option is then omitted.
seed = -1
images_per_prompt = 1
batch_size = 1 
clip_skip = 2

os.makedirs(output_path, exist_ok=True)

separators = ["*", "x", ","]

# Split both resolution strings on the first separator found in `resolution`.
# NOTE(review): `conditional_resolution` is split with that same separator, and
# if no separator matches, `width`/`height` stay undefined and the dictionary
# below raises NameError.
for separator in separators:
    if separator in resolution:
        width, height = [value.strip() for value in resolution.split(separator)]
        original_width, original_height = [value.strip() for value in conditional_resolution.split(separator)]
        break

# Command-line options for sdxl_gen_img.py. False and None values are dropped
# when the argument string is built.
config = {
    # The negative prompt is appended to the prompt after a " --n " marker.
    "prompt": prompt + " --n " + negative_prompt,
    "images_per_prompt": images_per_prompt,
    "outdir": output_path,
    "W": width,
    "H": height,
    "original_width": original_width,
    "original_height": original_height,
    "batch_size": batch_size,
    "vae_batch_size": 1,
    "no_half_vae": True,
    "steps": steps,
    "sampler": sampler,
    "scale": scale,
    "ckpt": ckpt_path,
    # `vae_path` is the global set by the "Download SDXL" cell.
    "vae": vae_path,
    "seed": seed if seed > 0 else None,
    "fp16": True,
    # Exactly one of the two attention flags is enabled, chosen by `optimization`.
    "sdpa": True if optimization == "scaled dot-product attention" else False,
    "xformers": True if optimization == "xformers" else False,
    "opt_channels_last": True,
    "clip_skip": clip_skip,
    "max_embeddings_multiples": 3,
}

def display_results(count):
    """Display up to `count` images from `output_path`.

    Only `.png` and `.jpg` files are considered. They are ordered by file name
    in descending order and each one is shown scaled to 512x512 pixels.

    Args:
        count: Maximum number of images to display.
    """
    # Filter before slicing so that non-image files in the directory cannot
    # use up the `count` quota.
    image_names = sorted(
        (name for name in os.listdir(output_path) if name.endswith((".png", ".jpg"))),
        reverse=True,
    )

    for name in image_names[:count]:
        # The context manager closes the source file once the resized copy
        # has been handed to display().
        with Image.open(os.path.join(output_path, name)) as img:
            display(img.resize((512, 512)))
                
# Render `config` as a shell argument string: strings are double-quoted, True
# becomes a bare flag, numbers become --key=value, and False/None are skipped.
args = ""
for k, v in config.items():
    if k.startswith("_"):
        args += f'"{v}" '
    elif isinstance(v, str):
        args += f'--{k}="{v}" '
    elif isinstance(v, bool) and v:
        args += f"--{k} "
    elif isinstance(v, float) and not isinstance(v, bool):
        args += f"--{k}={v} "
    elif isinstance(v, int) and not isinstance(v, bool):
        args += f"--{k}={v} "

final_args = f"python sdxl_gen_img.py {args}"

# sdxl_gen_img.py is referenced by bare file name, so run it from the repo root.
os.chdir(repo_dir)
!{final_args}

display_results(batch_size)