# KOHYA TRAINER XL 

## Install Kohya Trainer

In [None]:
!nvidia-smi

In [None]:
import os, shutil
from pathlib import Path

# Workspace layout. Everything lives under root_dir; the later cells reuse these names.
root_dir = Path("/workspace")
repo_dir = root_dir / "kohya-trainer"
training_dir = root_dir / "fine_tune"
model_dir = root_dir / "model"
vae_dir = root_dir / "vae"
lora_dir = root_dir / "network_weight"
config_dir = training_dir / "config"
output_dir = training_dir / "outputs"
# Sub-directories of the cloned repository.
tools_dir = repo_dir / "tools"
finetune_dir = repo_dir / "finetune"
# Location of the accelerate config file that install_dependencies() regenerates.
accelerate_config = repo_dir / "accelerate_config" / "config.yaml"

# Repository cloned into repo_dir; the previous source is kept below for reference.
# repo_url = "https://github.com/qaneel/kohya-trainer"
repo_url = "https://github.com/kohya-ss/sd-scripts"

# Bearer token sent with download requests; empty by default.
HUGGINGFACE_TOKEN = ""

def clone_repo(url, dir, branch="main"):
    dir = Path(dir)
    if not dir.exists():
        !git clone -b {branch} {url} {dir}

def install_dependencies():
    """Install system and Python dependencies and write a fresh accelerate config.

    Runs IPython shell escapes, so it only works inside a notebook. The pip
    command ends with ``-e .``, which installs the current working directory
    in editable mode; main() changes into repo_dir before calling this.
    """
    !apt update -yqq
    !apt install aria2 -yqq
    # Pinned package versions, plus the current directory as an editable install.
    !pip install -q --upgrade xformers==0.0.21 accelerate==0.23.0 transformers==4.30.2 diffusers[torch]==0.21.2 ftfy==6.1.1 opencv-python==4.7.0.68 einops==0.6.0 pytorch-lightning==1.9.0 safetensors==0.3.1 toml==0.10.2 voluptuous==0.13.1 huggingface-hub==0.15.1 wandb==0.15.7 invisible-watermark==0.2.0 open-clip-torch==2.20.0 tensorflow==2.10.1 -e .
    # !pip install -q --upgrade -r requirements.txt
    
    # Delete any existing accelerate config before writing a new one.
    # NOTE(review): presumably needed so write_basic_config does not keep an old file - confirm.
    !rm $accelerate_config
    # Imported here because accelerate is only installed by the pip command above.
    from accelerate.utils import write_basic_config

    write_basic_config(save_location=accelerate_config)

    # Environment settings for this process: TensorFlow log level and Python warnings.
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
    os.environ["PYTHONWARNINGS"] = "ignore"

# initialize function

def get_filename(url, bearer_token):
    """Work out the file name a download URL will be saved under.

    Sends a streaming GET request (the body is not read) with a bearer
    ``Authorization`` header and inspects the ``Content-Disposition`` response
    header. When that header is absent, or carries no ``filename=`` parameter,
    the last path component of ``url`` (percent-decoded) is used instead.

    Args:
        url: HTTP(S) URL of the file.
        bearer_token: Token placed in the ``Authorization: Bearer`` header.

    Returns:
        The file name as a string.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    headers = {"Authorization": f"Bearer {bearer_token}"}
    with requests.get(url, headers=headers, stream=True) as response:
        response.raise_for_status()
        content_disposition = response.headers.get("content-disposition", "")

    # A header such as "inline" has no filename parameter; indexing the empty
    # match list used to raise IndexError, so fall through to the URL instead.
    matches = re.findall(r'filename="?([^"]+)"?', content_disposition)
    if matches:
        return matches[0]

    url_path = urlparse(url).path
    return unquote(Path(url_path).name)

def parse_args(config, aria=False):
    """Turn a config mapping into a list of command-line argument strings.

    Rules, applied per key in insertion order:
      * keys starting with ``_`` emit only the value (a positional argument);
      * string values emit ``--key=value`` when ``aria`` is true, otherwise
        ``--key='value'`` with the value single-quoted for a shell;
      * ``True`` emits the bare flag ``--key``; ``False`` emits nothing;
      * ints and floats emit ``--key=value``;
      * anything else (``None``, lists, paths, ...) is skipped.

    Args:
        config: Mapping of option name to value.
        aria: When true, string values are left unquoted because the result is
            passed to ``subprocess`` as an argument list rather than to a shell.

    Returns:
        List of argument strings.
    """
    args = []

    for k, v in config.items():
        if k.startswith("_"):
            args.append(f"{v}")
        elif isinstance(v, str):
            if aria:
                args.append(f"--{k}={v}")
            else:
                # Close the quote, add an escaped quote, reopen: keeps values
                # containing ' from terminating the shell quoting early.
                escaped = v.replace("'", "'\\''")
                args.append(f"--{k}='{escaped}'")
        elif isinstance(v, bool):
            # bool is checked before the numeric case because it subclasses int.
            if v:
                args.append(f"--{k}")
        elif isinstance(v, (int, float)):
            args.append(f"--{k}={v}")

    return args

def aria2_download(dir, filename, url, token):
    """Download ``url`` to ``dir/filename`` by running the ``aria2c`` executable.

    The bearer ``Authorization`` header is only attached when the URL contains
    "huggingface.co"; otherwise the header option is None and parse_args
    leaves it out. The exit status of aria2c is not inspected.
    """
    auth_header = None
    if "huggingface.co" in url:
        auth_header = f"Authorization: Bearer {token}"

    # Key order is preserved so the generated command line matches option order.
    options = {
        "console-log-level": "error",
        "summary-interval": 10,
        "header": auth_header,
        "continue": True,
        "max-connection-per-server": 16,
        "min-split-size": "1M",
        "split": 16,
        "dir": str(dir),
        "out": filename,
        "_url": url,
    }

    command = ["aria2c"]
    command.extend(parse_args(options, aria=True))
    subprocess.run(command)
    
def download(url, dst, token):
    """Fetch ``url`` into directory ``dst`` and return the resulting file path.

    A ``url`` that starts with "/workspace" is treated as a file already on
    disk: nothing is downloaded and the path is returned as-is. This check
    runs first, because probing such a path over HTTP can only fail.

    For remote URLs the file name is determined with get_filename(), any
    "/blob/" segment is rewritten to "/resolve/", and the download is
    delegated to aria2_download().

    Args:
        url: Remote URL, or a local path under /workspace.
        dst: Destination directory for remote downloads.
        token: Bearer token forwarded to get_filename() and aria2_download().

    Returns:
        pathlib.Path of the local file.
    """
    if url.startswith("/workspace"):
        return Path(url)

    filename = get_filename(url, token)
    dst = Path(dst)

    if "/blob/" in url:
        url = url.replace("/blob/", "/resolve/")

    aria2_download(dst, filename, url, token)

    return dst / filename
    
def main():
    """Clone the trainer repository, create the working folders and install dependencies."""
    os.chdir(root_dir)
    clone_repo(repo_url, repo_dir)

    # The rest runs from inside the repository; the editable pip install relies on it.
    os.chdir(repo_dir)

    required_dirs = (training_dir, config_dir, model_dir, vae_dir, output_dir)
    for folder in required_dirs:
        folder.mkdir(parents=True, exist_ok=True)

    install_dependencies()

main()

## Download SDXL

In [None]:
import os, re, requests, subprocess
from urllib.parse import urlparse, unquote
from pathlib import Path

# Placeholders; main() below overwrites them with the resolved model and VAE paths.
model_path = Path()
vae_path = Path()

# A value starting with "/workspace" is used as a local file; any other value is downloaded.
SDXL_MODEL_URL    = "/workspace/fine_tune/outputs/animagine-xl-3.0/animagine-xl-3.0-step00062000.safetensors"
SDXL_VAE_URL      = "https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/resolve/main/sdxl_vae.safetensors"

def main():
    """Resolve the SDXL model and VAE to local paths and publish them as globals.

    Sources starting with "/workspace" are taken as existing local files; the
    others are downloaded into their target directory. Each resolved path that
    exists on disk is reported.
    """
    global model_path, vae_path

    os.chdir(root_dir)

    sources = (
        ("model", SDXL_MODEL_URL, model_dir),
        ("vae", SDXL_VAE_URL, vae_dir),
    )

    resolved = {}
    for kind, source, target_dir in sources:
        if source.startswith("/workspace"):
            resolved[kind] = Path(source)
            continue
        resolved[kind] = download(source, target_dir, HUGGINGFACE_TOKEN)

    model_path = resolved.get("model", model_path)
    vae_path = resolved.get("vae", vae_path)

    for category, path in (("model", model_path), ("vae", vae_path)):
        if not path:
            continue
        if path.exists():
            print(f"Selected {category}: {path}")

main()
    

## Directory Config

In [None]:
import os

# project_name is reused by later cells for metadata, config, and output file names.
project_name = "animagine-xl-3.0"
train_data_dir = root_dir / "train_data" / "animagine-xl-3.0"

# Create the dataset directory up front so later cells can write into it.
train_data_dir.mkdir(parents=True, exist_ok=True)
print(f"Your train data directory : {train_data_dir}")

## Data Gathering

## Unzip Dataset

In [None]:
import os, zipfile, shutil
from pathlib import Path

# Archive to fetch, and where to unpack it.
src_url = "https://huggingface.co/datasets/Linaqruf/sdxl-dataset/resolve/main/aesthetic-beta-raw.zip"
dst_dir = ""

# An empty dst_dir means "unpack into the training data directory".
if not dst_dir:
    dst_dir = train_data_dir

dst_dir = Path(dst_dir)
dst_dir.mkdir(parents=True, exist_ok=True)

def extract_zipfile(zip_file, output_path):
    """Unpack every member of the archive ``zip_file`` into ``output_path``."""
    with zipfile.ZipFile(zip_file, mode="r") as archive:
        archive.extractall(path=output_path)

def main():
    """Download the dataset archive, unpack it into dst_dir, then delete the archive."""
    archive_path = download(src_url, root_dir, HUGGINGFACE_TOKEN)
    extract_zipfile(archive_path, dst_dir)
    # The archive is only removed once extraction has finished without raising.
    os.remove(archive_path)

main()

## WD Tagger

In [None]:
import os

# Work from the repo's finetune directory; the tagger script below is called by relative name.
os.chdir(finetune_dir)

# Tagger variants; `model` selects the one spliced into the repo id below.
models = ["moat", "convnextv2", "swinv2", "convnext", "vit"]
model = models[1]

# parse_args emits "_"-prefixed keys as bare positional values and the other
# keys as --key / --key=value options.
tagger_config = {
    "_train_data_dir" : train_data_dir,
    "batch_size" : 24,
    "repo_id" : f"SmilingWolf/wd-v1-4-{model}-tagger-v2",
    "recursive" : True,
    "remove_underscore" : True,
    "general_threshold" : 0.35,
    "character_threshold" : 1,
    "caption_extension" : ".txt",
    "max_data_loader_n_workers" : 8, 
    "force_download" : True, 
    "undesired_tags" : ""
}

# Join the arguments into one command string for the shell escape below.
tagger_args = ' '.join(parse_args(tagger_config))
final_args = f"python tag_images_by_wd14_tagger.py {tagger_args}"

# NOTE(review): repeats the chdir at the top of this cell.
os.chdir(finetune_dir)
! {final_args}

## Aspect Ratio Bucketing and Caching latents

In [None]:
import os

# Metadata files: raw_metadata is the bucketing input, processed_metadata its output.
raw_metadata = training_dir / f"{project_name}_clean.json"
processed_metadata = training_dir / f"{project_name}_lat.json"
recursive = True
resolution = 1024

# Arguments for merge_all_to_metadata.py ("_" keys become positional values).
metadata_config = {
    "_train_data_dir": train_data_dir,
    "_out_json": raw_metadata,
    "recursive": recursive,
    "full_path": recursive,
}

# Arguments for prepare_buckets_latents.py.
bucketing_config = {
    "_train_data_dir": train_data_dir,
    "_in_json": raw_metadata,
    "_out_json": processed_metadata,
    # NOTE(review): a pathlib.Path is always truthy, so when vae_path is a Path
    # (as set in the download cell) the model_path fallback is never taken.
    "_model_name_or_path": vae_path if vae_path else model_path,
    "recursive": recursive,
    "full_path": recursive,
    "flip_aug": False,
    # Bucket size limits are derived from the base resolution (double and half).
    "max_bucket_reso" : int(resolution * 2),
    "min_bucket_reso" : int(resolution / 2),
    "bucket_no_upscale" : False, 
    "bucket_reso_steps" : 64, 
    "batch_size": 8,
    "skip_existing": True,
    "max_data_loader_n_workers": 1,
    # Produces the string "<resolution>, <resolution>".
    "max_resolution": ", ".join([str(resolution)] * 2),
    "mixed_precision": "fp16",
}

merge_metadata_args = ' '.join(parse_args(metadata_config))
prepare_buckets_args = ' '.join(parse_args(bucketing_config))

merge_metadata_command = f"python merge_all_to_metadata.py {merge_metadata_args}"
prepare_buckets_command = f"python prepare_buckets_latents.py {prepare_buckets_args}"

os.chdir(finetune_dir)
# Fetch the merge script when the cloned repository does not provide it.
if not Path("merge_all_to_metadata.py").exists():
    !wget https://raw.githubusercontent.com/qaneel/kohya-trainer/main/finetune/merge_all_to_metadata.py
# The merge step is currently disabled, so raw_metadata must already exist
# for the bucketing command below to read it.
# !{merge_metadata_command}
!{prepare_buckets_command}

# 20:23 18696
# 20:28 21064 
# 20:33 23168
# 20:38 25576 
# 20:43 28032  

## Optimizer Config

In [None]:
import toml

# Base learning rate; both text-encoder rates below are half of it.
learning_rate = 7.5e-6

# Merged into train_config by the Training Config cell, which requires this
# dictionary to exist and raises NameError otherwise.
optimizer_config = {
    "optimizer_arguments": {
        "optimizer_type" : "AdaFactor",
        "learning_rate" : learning_rate,
        "train_text_encoder" : True,
        "learning_rate_te1" : learning_rate / 2,
        "learning_rate_te2" : learning_rate / 2,
        "optimizer_args" : ['scale_parameter=False', 'relative_step=False', 'warmup_init=False'],
        "lr_scheduler" : "constant_with_warmup",
        "lr_warmup_steps" : 100,
        "lr_scheduler_num_cycles" : None, # cosine_with_restarts
        "lr_scheduler_power" : None, # polynomial
        "lr_scheduler_type" : None,
        "lr_scheduler_args" : None,
        "max_grad_norm" : 0
    },
}

# Preview the section as TOML.
# NOTE(review): presumably None entries are left out of the dump - confirm against the toml package.
print(toml.dumps(optimizer_config))

## Advanced Training Config
1. Set `optimizer_state_path` to a saved training state to resume training with its optimizer state.
2. `noise_offset` and `multires_noise` cannot be used at the same time.


In [None]:
import toml

# Saved training state to resume from; passed to the trainer as "resume".
optimizer_state_path      = "/workspace/fine_tune/outputs/animagine-xl-3.0/animagine-xl-3.0-step00062000-state" 

# Optional section: the Training Config cell merges it into train_config when
# it is defined and prints a warning when it is not.
advanced_training_config = {
    "advanced_training_config": {
        "resume" : optimizer_state_path,
        "resume_from_huggingface": False,
        # Noise options, all currently disabled:
        # "noise_offset" : 0.0357,
        # "adaptive_noise_scale" : 0.00357,
        # "multires_noise_iterations" : 6, 
        # "multires_noise_discount" : 0.3, 
        # "min_snr_gamma" : 5
    }
}

# Preview the section as TOML.
print(toml.dumps(advanced_training_config))

## Deployment Config

In [None]:
import toml
from datetime import datetime
# Token stored as "huggingface_token" below; leaving it empty disables this section.
WRITE_TOKEN = ""

# Timestamp that makes the upload path unique for each run of this cell.
current_datetime = datetime.now().strftime("%Y%m%d_%H%M%S")

deployment_config = {
    "save_to_hub_config": {
        "huggingface_repo_id" : "",
        "huggingface_repo_type" : "model", 
        "huggingface_path_in_repo" : f"{project_name}_{current_datetime}",
        # "resume_from_huggingface"
        "huggingface_token" : WRITE_TOKEN,
        "async_upload" : True,
        "save_state_to_huggingface" : True,
        "huggingface_repo_visibility" : "private",
    }
}
# NOTE(review): this prints the token into the cell output when WRITE_TOKEN is set.
print(toml.dumps(deployment_config))

# Without a token the name is deleted, so the Training Config cell hits
# NameError on it and skips the section with a warning.
if WRITE_TOKEN == "":
    del deployment_config

## Training Config
1. Get your `wandb_api_key` here: https://wandb.ai/settings


In [None]:
import toml

# Optional Weights & Biases key; the Start Training cell passes it on only when non-empty.
wandb_api_key = "" 

# Base resolution, written to the dataset section as "<resolution>, <resolution>".
resolution = 1024

# Sample-prompt settings, written to their own TOML file further down in this cell.
prompt_config = {
    "prompt": {
        "negative_prompt" : "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry",
        "scale"           : 10,
        "sample_steps"    : 28,
        "subset"          : [
            {
                "prompt" : "1girl, hoshimachi suisei, hololive, looking at viewer, upper body, outdoors, night, masterpiece, best quality",
                "width"  : 896,
                "height" : 1152,                
            },
        ],
    }
}

# Main training configuration. The optimizer, advanced and deployment sections
# from the earlier cells are merged in further down before it is written out.
train_config = {
    "sdxl_arguments": {
        "cache_text_encoder_outputs" : False,
        "no_half_vae" : False,
        "min_timestep" : 0,
        "max_timestep" : 1000,
    },
    "model_arguments": {
        # Paths resolved by the Download SDXL cell.
        "pretrained_model_name_or_path" : str(model_path),
        "vae" : str(vae_path),
    },
    "dataset_arguments": {
        "shuffle_caption" : True,
        "debug_dataset" : False,
        # Same file name as processed_metadata in the bucketing cell.
        "in_json" : str(training_dir / f"{project_name}_lat.json"),
        "train_data_dir" : str(train_data_dir),
        "dataset_repeats" : 1,
        "keep_tokens" : None,
        "keep_tokens_separator" : "|||",
        "resolution" : ", ".join([str(resolution)] * 2),
        "caption_dropout_rate" : 0,
        "caption_tag_dropout_rate" : 0,
        "caption_dropout_every_n_epochs": 0,
        "token_warmup_min" : 1,
        "token_warmup_step" : 0,
    },
    "training_arguments": {
        "output_dir" : str(output_dir / project_name),
        "output_name" : project_name,
        "save_precision" : "fp16",
        # "save_every_n_epochs" : 1,
        "save_every_n_steps" : 500,
        "save_n_epoch_ratio" : None,
        # "save_last_n_epochs" : None,
        # NOTE(review): the two save_last_n_steps* options are booleans here;
        # their names suggest a step count - confirm against the trainer.
        "save_last_n_steps" : True,
        "save_state" : True,
        # "save_last_n_epochs_state" : True,
        "save_last_n_steps_state" : True,
        "train_batch_size" : 48,
        "max_token_length" : 225,
        "mem_eff_attn" : False,
        "xformers" : True,
        "sdpa" : False, 
        # "max_train_epochs" : 10,
        # NOTE(review): presumably total steps minus steps already trained
        # (62000 matches the resumed checkpoint's step number) - confirm.
        "max_train_steps": 132590 - 41259 - 62000,
        "max_data_loader_n_workers" : 8,
        "persistent_data_loader_workers": True,
        "seed" : None,
        "gradient_checkpointing" : True,
        "gradient_accumulation_steps" : 1,
        "mixed_precision" : "fp16",
    },
    "logging_arguments": {
        "log_with" : "wandb",
        "log_tracker_name" : project_name,
        "logging_dir" : str(training_dir / "logs"),
    },
    "sample_prompt_arguments": {
        "sample_every_n_steps" : 100,
        "sample_every_n_epochs" : None,
        "sample_sampler" : "euler_a",
    },
    "saving_arguments": {
        "save_model_as": "safetensors"
    },
}

def write_file(filename, contents):
    """Write the string ``contents`` to ``filename``, replacing any existing file.

    The file is always encoded as UTF-8 rather than the platform's locale
    encoding: the callers write TOML, which is required to be UTF-8.
    """
    with open(filename, "w", encoding="utf-8") as f:
        f.write(contents)

def eliminate_none_variable(config):
    """Replace empty-string values with None, in place, and return ``config``.

    Top-level values and the values of directly nested dicts are handled;
    anything nested deeper (or inside lists) is left untouched.
    """
    for name, value in config.items():
        if isinstance(value, dict):
            # Existing keys are reassigned, never added or removed, so
            # updating the dict while iterating over it is safe.
            for inner_name, inner_value in value.items():
                if inner_value == "":
                    value[inner_name] = None
            continue
        if value == "":
            config[name] = None

    return config

# Assemble the final training configuration from the sections defined in the
# earlier cells, write it and the sample prompts to TOML files, and report
# which optional sections were missing.
import textwrap


def _warn_missing_config(dict_name, cell_name):
    """Print a coloured, wrapped notice that an optional config dict is undefined."""
    message = (
        f"WARNING: This is not an error message, but the [{dict_name}] dictionary is missing. "
        f"Please run the '{cell_name}' cell if you intend to use it, or continue to the next step."
    )
    print('\033[38;2;204;102;102m' + textwrap.fill(message, width=80) + '\033[0m\n')


# The optimizer section is mandatory: fail with an explanatory message.
try:
    train_config.update(optimizer_config)
except NameError as err:
    raise NameError("'optimizer_config' dictionary is missing. Please run  'Optimizer Config' cell.") from err

# The remaining sections are optional; a NameError only means their cell was
# not run (or, for deployment, that the cell deleted the dict again).
advanced_training_warning = False
try:
    train_config.update(advanced_training_config)
except NameError:
    advanced_training_warning = True

deployment_config_warning = False
try:
    train_config.update(deployment_config)
except NameError:
    deployment_config_warning = True

config_path         = config_dir / f"{project_name}_config_file.toml"
prompt_path         = config_dir / f"{project_name}_sample_prompt.toml"

# Empty strings are turned into None before both dictionaries are dumped.
config_str          = toml.dumps(eliminate_none_variable(train_config))
prompt_str          = toml.dumps(eliminate_none_variable(prompt_config))

write_file(config_path, config_str)
write_file(prompt_path, prompt_str)

print(config_str)

if advanced_training_warning:
    _warn_missing_config("advanced_training_config", "Advanced Training Config")

if deployment_config_warning:
    # The cell is titled 'Deployment Config'; the message used to name a
    # 'Deployment Training Config' cell, which does not exist.
    _warn_missing_config("deployment_config", "Deployment Config")

print(prompt_str)

## Start Training

In [None]:
import os
import toml

# Files written by the Training Config cell.
sample_prompt   = f"/workspace/fine_tune/config/{project_name}_sample_prompt.toml"
config_file     = f"/workspace/fine_tune/config/{project_name}_config_file.toml"

# Index 0 selects sdxl_train.py; index 1 would select sdxl_train_network.py.
script_names = ["sdxl_train.py", "sdxl_train_network.py"]
script_name = script_names[0]

# Options placed between "accelerate launch" and the script name.
accelerate_conf = {
    "config_file" : str(accelerate_config),
    "num_cpu_threads_per_process" : 1,
    # "num_processes" : 2, 
    # "multi_gpu" : True,
    # "num_machines" : 1, 
    # "gpu_ids" : "0,1"
}

# Options for the training script; None values are left out by parse_args.
train_conf = {
    "sample_prompts"  : sample_prompt if os.path.exists(sample_prompt) else None,
    "config_file"     : config_file,
    "wandb_api_key"   : wandb_api_key if wandb_api_key else None,
}

train_args = ' '.join(parse_args(train_conf))
accelerate_args = ' '.join(parse_args(accelerate_conf))

final_args = f"accelerate launch {accelerate_args} {script_name} {train_args}"

# The script name is relative, so run from the repository root.
os.chdir(repo_dir)
!{final_args}
# NOTE(review): the printed command includes wandb_api_key when one is set.
print(final_args)

In [None]:
import time

def sleep_with_countdown(interval):
    """Block for ``interval`` seconds while showing a MM:SS countdown.

    The timer is printed with a trailing carriage return so each update
    overwrites the previous one. A non-positive interval returns at once.
    """
    for seconds_left in reversed(range(1, interval + 1)):
        minutes, seconds = divmod(seconds_left, 60)
        print(f"{minutes:02d}:{seconds:02d}", end="\r")
        time.sleep(1)

# Set the interval for 10 minutes (600 seconds)
# Set the interval for 10 minutes (600 seconds)
interval = 600

# Display the countdown while the program is paused
print("Program is paused. Time remaining:")
sleep_with_countdown(interval)

print("Program resumes after countdown.")

# Remove the pod with this hard-coded id through the runpodctl CLI.
# NOTE(review): the id belongs to one specific pod - update it before reusing this notebook.
RUNPOD_POD_ID = "kg58rpik9lk8pi"
!runpodctl remove pod $RUNPOD_POD_ID


In [None]:
!pip install huggingface_hub