In [None]:
# Test if running on colab
import sys
COLAB = 'google.colab' in sys.modules


if COLAB:
  !git clone https://github.com/kk-digital/kcg-ml-sd1p4
  kcg = '/content/kcg-ml-sd1p4'
else:
  kcg = os.getcwd()

In [None]:
# The variable pool™ (I know, I'm terrible at structuring notebooks..)

# Directories:
project_name = ""
root_dir = "/content/Loras"
deps_dir = os.path.join(root_dir, "deps")
repo_dir = os.path.join(root_dir, "kohya-trainer")

main_dir      = os.path.join(root_dir, "lora_training")
images_folder = os.path.join(main_dir, "dataset", project_name)
output_folder = os.path.join(main_dir, "output", project_name)
config_folder = os.path.join(main_dir, "config", project_name)
log_folder    = os.path.join(main_dir, "log")

# LoRa parameters:

## Optimizer and token config
optimizer = "AdamW8bit"
optimizer_args = None
continue_from_lora = ""
weighted_captions = False
adjust_tags = False
keep_tokens_weight = 1.0

## Stable Diffusion model to use for training
model_filename = "sd-v1-5-pruned-noema-fp16.safetensors"
model_file = os.path.join(kcg, "inputs/models", model_filename)

!aria2c "https://huggingface.co/hollowstrawberry/stable-diffusion-guide/resolve/main/models/sd-v1-5-pruned-noema-fp16.safetensors" -d /-o {model_file}
custom_model_is_based_on_sd2 = False

## Dataset parameters
resolution = 512
flip_aug = False
caption_extension = ".txt"
activation_tags = "1"
keep_tokens = int(activation_tags)

## Step parameters (iterations to train for)
num_repeats = 10
max_train_epochs = 10
max_train_steps = None
save_every_n_epochs = 1
keep_only_last_n_epochs = 10

## Learning parameters
train_batch_size = 3
unet_lr = 5e-4
text_encoder_lr = 1e-4
lr_scheduler = "cosine_with_restarts"
lr_scheduler_num_cycles = 3
lr_warmup_ratio = 0.05
lr_warmup_steps = 0
min_snr_gamma_value = 5.0

## Network parameters
lora_type = "LoRA"
network_dim = 32
network_alpha = round(network_dim/2)
network_module = "networks.lora"
network_args = None

In [None]:
# Create the working directories and derive the config file paths used by
# the training cell.
import os

# exist_ok=True makes the cell safe to re-run; the loop variable is named
# folder_path so the builtin `dir` is not shadowed for the rest of the session.
for folder_path in (main_dir, deps_dir, repo_dir, log_folder, output_folder, config_folder):
  os.makedirs(folder_path, exist_ok=True)

config_file = os.path.join(config_folder, "training_config.toml")
dataset_config_file = os.path.join(config_folder, "dataset_config.toml")
accelerate_config_file = os.path.join(repo_dir, "accelerate_config/config.yaml")

In [None]:
COMMIT = "5050971ac687dca70ba0486a583d283e8ae324e2"

os.chdir(root_dir)
!git clone https://github.com/kohya-ss/sd-scripts {repo_dir}
os.chdir(repo_dir)
if COMMIT:
  !git reset --hard {COMMIT}

if COLAB:
  !wget https://raw.githubusercontent.com/hollowstrawberry/kohya-colab/main/requirements.txt -q -O requirements.txt
  !apt -y update -qq
  !pip -q install --upgrade -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118
else:
  !mv /content/kcg-ml-sd1p4/train/requirements.txt ./
  !pip -q install --upgrade -r requirements.txt 

!apt install -y aria2 -qq

# patch kohya for minor stuff
if COLAB:
  !sed -i "s@cpu@cuda@" library/model_util.py # low ram

!sed -i 's/from PIL import Image/from PIL import Image, ImageFile\nImageFile.LOAD_TRUNCATED_IMAGES=True/g' library/train_util.py # fix truncated jpegs error
!sed -i 's/{:06d}/{:02d}/g' library/train_util.py # make epoch names shorter
!sed -i 's/model_name + "."/model_name + "-{:02d}.".format(num_train_epochs)/g' train_network.py # name of the last epoch will match the rest

from accelerate.utils import write_basic_config
if not os.path.exists(accelerate_config_file):
  write_basic_config(save_location=accelerate_config_file)

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
os.environ["BITSANDBYTES_NOWELCOME"] = "1"  
os.environ["SAFETENSORS_FAST_GPU"] = "1"

fatal: destination path '/content/Loras/kohya-trainer' already exists and is not an empty directory.
HEAD is now at 5050971 Merge pull request #388 from kohya-ss/dev
4 packages can be upgraded. Run 'apt list --upgradable' to see them.
  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for library (setup.py) ... [?25l[?25hdone


In [None]:
import os, sys
import re
import toml
import shutil
import zipfile
from time import time
sys.path.append(os.path.join(kcg, "scripts"))

# Toggles; nothing in the visible cells reads them.
BETTER_EPOCH_NAMES = True
LOAD_TRUNCATED_IMAGES = True

# A falsy save interval or retention count falls back to the total number of
# training epochs.
save_every_n_epochs = save_every_n_epochs or max_train_epochs
keep_only_last_n_epochs = keep_only_last_n_epochs or max_train_epochs

print("Done")

def create_config():
  """Write the trainer's two TOML files from the notebook-level settings.

  Builds the training configuration and the dataset configuration from the
  module-level variables, removes None-valued entries from every dict-valued
  top-level section, and saves them to `config_file` and
  `dataset_config_file` respectively. Prints the destination of each file.
  """

  def strip_none(sections):
    """Return a copy in which each dict-valued section has its None entries dropped."""
    cleaned = {}
    for name, body in sections.items():
      if isinstance(body, dict):
        body = {k: v for k, v in body.items() if v is not None}
      cleaned[name] = body
    return cleaned

  # --- training configuration -------------------------------------------
  network_section = {
    "unet_lr": unet_lr,
    "text_encoder_lr": text_encoder_lr,
    "network_dim": network_dim,
    "network_alpha": network_alpha,
    "network_module": network_module,
    "network_args": network_args,
    # Only set when the text encoder is not being trained.
    "network_train_unet_only": True if text_encoder_lr == 0 else None,
    "network_weights": continue_from_lora or None
  }

  uses_restarts = lr_scheduler == "cosine_with_restarts"
  uses_warmup = lr_scheduler != "constant"
  optimizer_section = {
    "learning_rate": unet_lr,
    "lr_scheduler": lr_scheduler,
    "lr_scheduler_num_cycles": lr_scheduler_num_cycles if uses_restarts else None,
    "lr_scheduler_power": None,
    "lr_warmup_steps": lr_warmup_steps if uses_warmup else None,
    "optimizer_type": optimizer,
    "optimizer_args": optimizer_args or None,
  }

  training_section = {
    "max_train_steps": max_train_steps,
    "max_train_epochs": max_train_epochs,
    "save_every_n_epochs": save_every_n_epochs,
    "save_last_n_epochs": keep_only_last_n_epochs,
    "train_batch_size": train_batch_size,
    "noise_offset": None,
    "clip_skip": 2,
    "min_snr_gamma": min_snr_gamma_value,
    "weighted_captions": weighted_captions,
    "seed": 42,
    "max_token_length": 225,
    "xformers": True,
    "lowram": COLAB,
    "max_data_loader_n_workers": 8,
    "persistent_data_loader_workers": True,
    "save_precision": "fp16",
    "mixed_precision": "fp16",
    "output_dir": output_folder,
    "logging_dir": log_folder,
    "output_name": project_name,
    "log_prefix": project_name,
    "save_state": False,
    "save_last_n_epochs_state": None,
    "resume": None
  }

  model_section = {
    "pretrained_model_name_or_path": model_file,
    "v2": custom_model_is_based_on_sd2,
    "v_parameterization": True if custom_model_is_based_on_sd2 else None,
  }

  config_dict = strip_none({
    "additional_network_arguments": network_section,
    "optimizer_arguments": optimizer_section,
    "training_arguments": training_section,
    "model_arguments": model_section,
    "saving_arguments": {"save_model_as": "safetensors"},
    "dreambooth_arguments": {"prior_loss_weight": 1.0},
    "dataset_arguments": {"cache_latents": True},
  })

  with open(config_file, "w") as handle:
    handle.write(toml.dumps(config_dict))
  print(f"\n Config saved to {config_file}")

  # --- dataset configuration --------------------------------------------
  # Resolutions above 640 use a wider bucket range.
  min_reso, max_reso = (320, 1280) if resolution > 640 else (256, 1024)

  subset = {
    "num_repeats": num_repeats,
    "image_dir": images_folder,
    # The project name is used as class token only when no caption extension is set.
    "class_tokens": project_name if not caption_extension else None
  }

  dataset_config_dict = strip_none({
    "general": {
      "resolution": resolution,
      "keep_tokens": keep_tokens,
      "flip_aug": flip_aug,
      "caption_extension": caption_extension,
      "enable_bucket": True,
      "bucket_reso_steps": 64,
      "bucket_no_upscale": False,
      "min_bucket_reso": min_reso,
      "max_bucket_reso": max_reso,
    },
    "datasets": [{"subsets": [subset]}]
  })

  with open(dataset_config_file, "w") as handle:
    handle.write(toml.dumps(dataset_config_dict))
  print(f"Dataset config saved to {dataset_config_file}")

def main():
  """Write the config files, then launch training from the trainer checkout.

  Calls `create_config()`, changes the working directory to `repo_dir`, and
  runs `train_network.py` through `accelerate launch` with the accelerate,
  dataset, and training config paths defined in earlier cells.
  """

  create_config()
  
  print("\nStarting trainer...\n")
  # train_network.py is given as a relative path, so the launch must happen
  # from inside the trainer checkout.
  os.chdir(repo_dir)
  
  !accelerate launch --config_file={accelerate_config_file} --num_cpu_threads_per_process=1 train_network.py --dataset_config={dataset_config_file} --config_file={config_file}


# Run the training pipeline when the cell executes.
main()


In [None]:
import os

folder = "/content/drive/MyDrive/Loras"
exclude = ("_logs", "/output")


def summarize_dataset_tree(folder, exclude=("_logs", "/output")):
  """Count images, captions and other files in every directory under `folder`.

  Args:
    folder: Root directory to walk. Keys of the result are paths that start
      at the last component of `folder`.
    exclude: Substrings checked against "/<directory name>". A directory
      that matches any of them is skipped together with everything below
      it. A leading "/" anchors a pattern to the start of the name.

  Returns:
    A dict mapping each visited directory to a summary string, or to None
    when the directory contains no images. The dict is empty when `folder`
    does not exist.
  """
  summary = {}
  prefix_len = folder.rfind("/") + 1
  for root, dirs, files in os.walk(folder, topdown=True):
    # Prune in place so os.walk does not descend into excluded directories.
    # Bare names never contain "/", so each name is prefixed with one;
    # otherwise a pattern such as "/output" could never match.
    dirs[:] = [d for d in dirs if all(ex not in f"/{d}" for ex in exclude)]
    images = sum(f.lower().endswith((".png", ".jpg", ".jpeg")) for f in files)
    captions = sum(f.lower().endswith(".txt") for f in files)
    others = len(files) - images - captions
    path = root[prefix_len:]
    if not images:
      summary[path] = None
      continue
    line = f"{images:>4} images | {captions:>4} captions |"
    if others:
      line += f" {others:>4} other files"
    summary[path] = line
  return summary


tree = summarize_dataset_tree(folder, exclude)

# default=0 keeps the cell from raising when the folder is missing or empty.
pad = max((len(k) for k in tree), default=0)
print("\n".join(f"📁{k.ljust(pad)} | {v}" for k, v in tree.items() if v))
