# ⭐ Lora Trainer by Hollowstrawberry

Here is a [guide to using this colab](https://civitai.com/models/22530). It will help you make a dataset quickly and use it to train a Lora.

This is based on the work of [Kohya-ss and Linaqruf](https://colab.research.google.com/github/Linaqruf/kohya-trainer/blob/main/kohya-LoRA-dreambooth.ipynb). Thank you!


|Colab|English|Spanish|
|:--|:-:|:-:|
| 📊 **Dataset Maker** | <a target="_blank" href="https://colab.research.google.com/github/hollowstrawberry/kohya-colab/blob/main/Dataset_Maker.ipynb"> <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/> </a> | [🇪🇸](https://colab.research.google.com/github/hollowstrawberry/kohya-colab/blob/main/Spanish_Dataset_Maker.ipynb) |
| ⭐ **Lora Trainer** | <a target="_blank" href="https://colab.research.google.com/github/hollowstrawberry/kohya-colab/blob/main/Lora_Trainer.ipynb"> <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/> </a> | [🇪🇸](https://colab.research.google.com/github/hollowstrawberry/kohya-colab/blob/main/Spanish_Lora_Trainer.ipynb) |

In [None]:
#@title ## 🚩 Start Here
import os
import re
import toml
import shutil
import zipfile
from google.colab import drive
from google.colab import output as console
from IPython.display import Markdown, display

if "model_url" in globals():
  old_model_url = model_url
else:
  old_model_url = None
if "dependencies_installed" not in globals():
  dependencies_installed = False
if "model_file" not in globals():
  model_file = None

#@markdown ### ▶️ Setup
#@markdown Your project name will be the same as the folder containing your images. Spaces aren't allowed.
project_name = "" #@param {type:"string"}
#@markdown Decide the model that will be downloaded and used for training. The base models produce the cleanest and most consistent results. You can instead specify a custom model if you're really sure.
training_model = "Anime (animefull-final-pruned-fp16.safetensors)" #@param ["Anime (animefull-final-pruned-fp16.safetensors)", "Photorealism (sd-v1-5-pruned-noema-fp16.safetensors)"]
custom_model_url = "" #@param {type:"string"}

if custom_model_url:
  model_url = custom_model_url
elif "Anime" in training_model:
  model_url = "https://huggingface.co/hollowstrawberry/stable-diffusion-guide/resolve/main/models/animefull-final-pruned-fp16.safetensors"
else:
  model_url = "https://huggingface.co/hollowstrawberry/stable-diffusion-guide/resolve/main/models/sd-v1-5-pruned-noema-fp16.safetensors"

#@markdown ### ▶️ Files <p>
#@markdown If you used [my dataset maker](https://colab.research.google.com/github/hollowstrawberry/kohya-colab/blob/main/Dataset_Maker.ipynb#scrollTo=-rdgF2AWLS2h), you're ready to go. Otherwise, create a folder in your Google Drive like this: `lora_training/datasets/project_name` and fill it with your images and their descriptions. You may use the Extras at the bottom to extract a zip file. <p>
#@markdown A resolution of 512 is standard for Stable Diffusion 1.5. Images are automatically scaled during training, so you don't need to crop or resize anything.
resolution = 512 #@param {type:"slider", min:512, max:1024, step:256}
#@markdown This option trains on your images both normally and flipped, at no extra cost, to learn more from them. Turn it on especially if you have fewer than 20 images. <p>
#@markdown **Turn it off if you care about asymmetrical elements in your Lora**.
flip_aug = True #@param {type:"boolean"}
#@markdown If you have an activation tag at the start of every text file, increase `keep_tokens` to 1.
keep_tokens = 0 #@param {type:"slider", min:0, max:5, step:1}
#@markdown If your text files use anime tags, keep this active.
shuffle_caption = True #@param {type:"boolean"}
caption_extension = ".txt"

#@markdown ### ▶️ Steps <p>
#@markdown Your images will repeat this number of times during training. I recommend that the number of images multiplied by their repeats be between 200 and 400.
num_repeats = 10 #@param {type:"number"}
#@markdown One epoch is a number of training steps equal to: your number of images multiplied by their repeats, divided by batch size. <p>
#@markdown More epochs will give your Lora more time to learn and more options for you to test. If you followed the rest of the instructions, I recommend 10 to 30 epochs which would result in 1000 to 6000 total training steps. You'll see your total steps when the training begins.
max_train_epochs = 15 #@param {type:"number"}
save_every_n_epochs = 1 #@param {type:"number"}
if save_every_n_epochs < 1:
  save_every_n_epochs = max_train_epochs
#@markdown Increasing the batch size may help for lots of images, but you can leave it as is.
train_batch_size = 2 #@param {type:"slider", min:1, max:8, step:1}

#@markdown ### ▶️ Training
#@markdown The learning rate is the most important setting for your results. If you want to train slower with lots of images, or if your dim and alpha are high, move the unet learning rate to 1e-4 or lower. <p>
#@markdown Training the text encoder helps your Lora learn concepts slightly better. It is recommended to set its learning rate to half or a fifth of the unet's. If you're training a style, you may set it to 0.
unet_lr = 5e-4 #@param {type:"number"}
text_encoder_lr = 1e-4 #@param {type:"number"}
#@markdown The scheduler is the algorithm that guides the learning rate. If you're not sure, pick `constant` and ignore the number. I personally recommend `cosine_with_restarts` with 3 restarts.
lr_scheduler = "cosine_with_restarts" #@param ["constant", "cosine", "cosine_with_restarts", "constant_with_warmup", "linear", "polynomial"]
lr_scheduler_number = 3 #@param {type:"number"}
lr_scheduler_num_cycles = lr_scheduler_number if lr_scheduler == "cosine_with_restarts" else 0
lr_scheduler_power = lr_scheduler_number if lr_scheduler == "polynomial" else 0
#@markdown Steps spent "warming up" the learning rate during training for efficiency. I recommend leaving it at 5%.
lr_warmup_ratio = 0.05 #@param {type:"slider", min:0.0, max:0.5, step:0.01}
lr_warmup_steps = 0
#@markdown More dim means a larger Lora: it can hold more information but can also hold more garbage. A dim between 8 and 32 is recommended. The standard used to be 128, but that is overkill as long as you increase the learning rate to preserve the detail. <p>
#@markdown Alpha is recommended to be equal to the dim, half of the dim, or 1. <p>
#@markdown You can leave both as is.
network_dim = 16 #@param {type:"slider", min:1, max:128, step:1}
network_alpha = 8 #@param {type:"slider", min:1, max:128, step:1}

#@markdown ### ▶️ Lora Type
#@markdown LoCon and LoHa are new types of LoRA with different/expanded capacity for learning. If you want to experiment with them go ahead, otherwise **don't worry about it!** <p>
#@markdown LoCons are said to be great with artstyles. If you select LoRA then the following 2 values do nothing. More info [here](https://github.com/KohakuBlueleaf/Lycoris). You'll need [this extension](https://github.com/KohakuBlueleaf/a1111-sd-webui-locon) to use them in webui. <p>
#@markdown Recommended values (from lycoris repo):

#@markdown | type | network_dim | network_alpha | conv_dim | conv_alpha |
#@markdown | :---: | :---: | :---: | :---: | :---: |
#@markdown | LoRA | 32 | 16 | - | - |
#@markdown | LoCon | 16 | 8 | 8 | 1 |
#@markdown | LoHa | 8 | 4 | 4 | 1 |

lora_type = "LoRA" #@param ["LoRA", "LoCon Kohya", "LoCon Lycoris", "LoHa Lycoris"]
conv_dim = 8 #@param {type:"slider", min:1, max:64, step:1}
conv_alpha = 1 #@param {type:"slider", min:1, max:64, step:1}
conv_compression = False #@param {type:"boolean"}

network_module = "lycoris.kohya" if "Lycoris" in lora_type else "networks.lora"
network_args = None if lora_type == "LoRA" else [
  f"conv_dim={conv_dim}",
  f"conv_alpha={conv_alpha}",
]
if "Lycoris" in lora_type:
  # Pick the algorithm from the selected Lora type (the list of args never contains "LoHa")
  network_args.append(f"algo={'loha' if 'LoHa' in lora_type else 'lora'}")
  network_args.append(f"disable_conv_cp={str(not conv_compression)}")

#@markdown ### ▶️ Ready
#@markdown You can now run this cell to cook your Lora. Good luck! <p>
#markdown Save additional data equaling ~500 MB allowing you to resume training later.
save_state = False #param {type:"boolean"}
#markdown Resume training if a save state is found.
resume = False #param {type:"boolean"}


# 👩‍💻 Cool code goes here

root_dir = "/content"
deps_dir = os.path.join(root_dir,"deps")
repo_dir = os.path.join(root_dir,"kohya-trainer")

main_dir = os.path.join(root_dir,"drive/MyDrive/lora_training")
config_dir = os.path.join(main_dir,"config")
datasets_dir = os.path.join(main_dir,"datasets")
output_dir = os.path.join(main_dir,"output")
logging_dir = os.path.join(main_dir,"log")

images_folder = os.path.join(datasets_dir, project_name)
output_folder = os.path.join(output_dir, project_name)
config_folder = os.path.join(config_dir, project_name)

accelerate_config = os.path.join(repo_dir, "accelerate_config/config.yaml")
tools_dir = os.path.join(repo_dir,"tools")
finetune_dir = os.path.join(repo_dir,"finetune")

images = None
dataset_config_file = None
config_file = None
  
def clone_repo():
  os.chdir(root_dir)
  !git clone https://github.com/Linaqruf/kohya-trainer {repo_dir}
  os.chdir(repo_dir)
  !git reset --hard 86de685a8c37e60a610d08cbece3da6b3a553bc0

def install_ubuntu_deps(url, name, dst):
    os.chdir(repo_dir)
    !wget -q --show-progress {url}
    with zipfile.ZipFile(name, "r") as deps:
        deps.extractall(dst)
    !dpkg -i {dst}/*
    os.remove(name)
    shutil.rmtree(dst)

def install_dependencies():
  os.chdir(repo_dir)
  !pip -q install --upgrade -r requirements.txt
  !pip install -q xformers=="0.0.17.dev476"
  !pip install -q triton=="2.0.0.post1"

  # patch kohya for minor stuff
  !sed -i 's/from PIL import Image/from PIL import Image, ImageFile\nImageFile.LOAD_TRUNCATED_IMAGES=True/g' library/train_util.py # fix truncated jpegs error
  !sed -i 's/{:06d}/{:02d}/g' library/train_util.py # make epoch names shorter
  !sed -i 's/model_name + "."/model_name + "-{:02d}.".format(num_train_epochs)/g' train_network.py # name of the last epoch will match the rest

  from accelerate.utils import write_basic_config
  if not os.path.exists(accelerate_config):
    write_basic_config(save_location=accelerate_config)

def validate_dataset():
  global images, lr_warmup_steps, lr_warmup_ratio
  supported_types = (".png", ".jpg", ".jpeg")

  print("\n💿 Checking dataset...")
  if not project_name.strip():
    print("💥 Error: Please choose a project name.")
    return

  # this is starting to spaghetti
  if "override_dataset_config_file" in globals() and override_dataset_config_file:
    try:
      datconf = toml.load(override_dataset_config_file)
      datasets = {d["image_dir"]: d["num_repeats"] for d in datconf["datasets"][0]["subsets"]}
    except Exception:
      print("💥 Error: Your custom dataset config file is invalid!")
      return
    folders = list(datasets.keys())
  else:
    datasets = {images_folder: num_repeats}
    folders = [images_folder]

  # Check that every folder exists before listing it, so a bad path gives a clear message instead of a traceback
  for folder in folders:
    if not os.path.exists(folder):
      print(f"💥 Error: The folder {folder.replace('/content/drive/', '')} doesn't exist.")
      return
  files = [f for folder in folders for f in os.listdir(folder)]
  images_repeats = {folder: (len([f for f in os.listdir(folder) if f.endswith(supported_types)]), datasets[folder]) for folder in folders}
  for folder, (img, rep) in images_repeats.items():
    if not img:
      print(f"💥 Error: Your {folder.replace('/content/drive/', '')} folder is empty.")
      return
  if not [txt for txt in files if txt.endswith(".txt")]:
    print("💥 Error: You don't have text files with captions next to your images. If you want to proceed, add a caption to at least 1 of your images.")
    return
  for f in files:
    if not f.endswith(".txt") and not f.endswith(supported_types):
      print(f"💥 Error: Invalid file in dataset: \"{f}\". Aborting.")
      return

  pre_steps_per_epoch = sum(img*rep for (img, rep) in images_repeats.values())
  steps_per_epoch = pre_steps_per_epoch/train_batch_size
  total_steps = int(max_train_epochs*steps_per_epoch)
  lr_warmup_steps = int(total_steps * lr_warmup_ratio)

  for folder, (img, rep) in images_repeats.items():
    print("📁"+folder.replace("/content/drive/", ""))
    print(f"📈 Found {img} images with {rep} repeats, equaling {img*rep} steps")
  print(f"📉 Divide {pre_steps_per_epoch} steps by {train_batch_size} batch size to get {steps_per_epoch} steps per epoch.")
  print(f"🔮 There will be {max_train_epochs} epochs, for around {total_steps} total training steps.")

  if total_steps > 10000:
    print("💥 Error: Your total steps are too high. You probably made a mistake. Aborting...") 
    return
  return True

def create_config():
  global dataset_config_file, config_file, model_file

  os.makedirs(os.path.join(config_dir, project_name), exist_ok=True)
  dataset_config_file = os.path.join(config_dir, project_name, "dataset_config.toml")
  config_file = os.path.join(config_dir, project_name, "training_config.toml")

  if resume:
    resume_points = [f.path for f in os.scandir(output_folder) if f.is_dir()]
    resume_points.sort()
    last_resume_point = resume_points[-1] if resume_points else None
  else:
    last_resume_point = None

  if "override_config_file" in globals() and override_config_file:
    config_file = override_config_file
    print(f"⭕ Using custom config file {config_file}")
  else:
    config_dict = {
      "additional_network_arguments": {
        "unet_lr": unet_lr,
        "text_encoder_lr": text_encoder_lr,
        "network_dim": network_dim,
        "network_alpha": network_alpha,
        "network_module": network_module,
        "network_args": network_args,
        "network_train_unet_only": True if text_encoder_lr == 0 else None,
      },
      "optimizer_arguments": {
        "learning_rate": unet_lr,
        "lr_scheduler": lr_scheduler,
        "lr_scheduler_num_cycles": lr_scheduler_num_cycles if lr_scheduler == "cosine_with_restarts" else None,
        "lr_scheduler_power": lr_scheduler_power if lr_scheduler == "polynomial" else None,
        "lr_warmup_steps": lr_warmup_steps,
        "optimizer_type": "AdamW8bit",
      },
      "training_arguments": {
        "max_train_epochs": max_train_epochs,
        "save_every_n_epochs": save_every_n_epochs,
        "train_batch_size": train_batch_size,
        "noise_offset": None,
        "clip_skip": 2,
        "seed": 42,
        "max_token_length": 225,
        "xformers": True,
        "lowram": True,
        "max_data_loader_n_workers": 8,
        "persistent_data_loader_workers": True,
        "save_precision": "fp16",
        "mixed_precision": "fp16",
        "output_dir": output_folder,
        "logging_dir": logging_dir,
        "output_name": project_name,
        "log_prefix": project_name,
        "save_state": save_state,
        "save_last_n_epochs_state": 1 if save_state else None,
        "resume": last_resume_point
      },
      "model_arguments": {
        "pretrained_model_name_or_path": model_file,
      },
      "saving_arguments": {
        "save_model_as": "safetensors",
      },
      "dreambooth_arguments": {
        "prior_loss_weight": 1.0,
      },
      "dataset_arguments": {
        "cache_latents": True,
      },
    }

    for key in config_dict:
      if isinstance(config_dict[key], dict):
        config_dict[key] = {k: v for k, v in config_dict[key].items() if v is not None}

    with open(config_file, "w") as f:
      f.write(toml.dumps(config_dict))
    print(f"📄 Config saved to {config_file}")

  if "override_dataset_config_file" in globals() and override_dataset_config_file:
    dataset_config_file = override_dataset_config_file
    print(f"⭕ Using custom dataset config file {dataset_config_file}")
  else:
    dataset_config_dict = {
      "general": {
        "resolution": resolution,
        "shuffle_caption": shuffle_caption,
        "flip_aug": flip_aug,        
        "caption_extension": caption_extension,
        "enable_bucket": True,
        "bucket_reso_steps": 64,
        "bucket_no_upscale": False,
        "min_bucket_reso": 320 if resolution > 640 else 256,
        "max_bucket_reso": 1280 if resolution > 640 else 1024,   
      },
      "datasets": [
        {
          "subsets": [
            {
              "num_repeats": num_repeats,
              "keep_tokens": keep_tokens,
              "image_dir": images_folder
            }
          ]
        }
      ]
    }

    for key in dataset_config_dict:
      if isinstance(dataset_config_dict[key], dict):
        dataset_config_dict[key] = {k: v for k, v in dataset_config_dict[key].items() if v is not None}

    with open(dataset_config_file, "w") as f:
      f.write(toml.dumps(dataset_config_dict))
    print(f"📄 Dataset config saved to {dataset_config_file}")

def download_model():
  global old_model_url, model_url, model_file
  real_model_url = model_url.strip()

  if not real_model_url:
    real_model_url = "https://huggingface.co/hollowstrawberry/animemodel/resolve/main/model.safetensors"
  
  if real_model_url.endswith((".ckpt", ".safetensors")):
    model_file = f"/content{real_model_url[real_model_url.rfind('/'):]}"
  else:
    model_file = "/content/downloaded_model.safetensors"

  if m := re.search(r"(?:https?://)?(?:www\.)?huggingface\.co/[^/]+/[^/]+/blob", model_url):
    real_model_url = real_model_url.replace("blob", "resolve")
  elif m := re.search(r"(?:https?://)?(?:www\.)?civitai\.com/models/([0-9]+)", model_url):
    real_model_url = f"https://civitai.com/api/download/models/{m.group(1)}"

  !wget "{real_model_url}" -O "{model_file}"

  if model_file.endswith(".safetensors"):
    from safetensors.torch import load_file as load_safetensors
    try:
      test = load_safetensors(model_file)
      del test
    except Exception as e:
      #if "HeaderTooLarge" in str(e):
      new_model_file = os.path.splitext(model_file)[0]+".ckpt"
      !mv "{model_file}" "{new_model_file}"
      model_file = new_model_file
      print(f"Renamed model to {model_file}")

  if model_file.endswith(".ckpt"):
    from torch import load as load_ckpt
    try:
      test = load_ckpt(model_file)
      del test
    except Exception as e:
      return False
      
  return True

def main():
  global dependencies_installed

  if not os.path.exists('/content/drive'):
    print("📂 Connecting to Google Drive...")
    drive.mount('/content/drive')
  
  # Avoid shadowing the built-in dir() with the loop variable
  for folder_path in [deps_dir, repo_dir, main_dir, config_dir, datasets_dir, output_dir, logging_dir, images_folder, output_folder, config_folder]:
    os.makedirs(folder_path, exist_ok=True)

  if not validate_dataset():
    return
  
  if not dependencies_installed:
    print("\n🏭 Installing dependencies...\n")
    clone_repo()
    !apt -y update -qq
    install_ubuntu_deps("https://huggingface.co/Linaqruf/fast-repo/resolve/main/ram_patch.zip", "ram_patch.zip", deps_dir)
    %env LD_PRELOAD=libtcmalloc.so
    install_ubuntu_deps("https://huggingface.co/Linaqruf/fast-repo/resolve/main/deb-libs.zip", "deb-libs.zip", deps_dir)
    install_dependencies()
    console.clear()
    print("✅ Installation finished.")
    dependencies_installed = True
  else:
    print("\n✅ Dependencies already installed.")

  if old_model_url != model_url or not model_file or not os.path.exists(model_file):
    print("\n🔄 Downloading model... Don't interrupt this at any cost...\n")
    if not download_model():
      print("\n💥 Error: The model you selected is invalid or corrupted. You can use a civitai or huggingface link, or any direct download link.")
      return
  else:
    print("\n🔄 Model already downloaded.\n")

  create_config()
  
  print("\n⭐ Starting trainer...\n")
  os.chdir(repo_dir)
  
  !accelerate launch --config_file={accelerate_config} --num_cpu_threads_per_process=1 train_network.py --dataset_config={dataset_config_file} --config_file={config_file}

  if not get_ipython().__dict__['user_ns']['_exit_code']:
    display(Markdown("### ✅ Done! [Go download your Lora(s) from Google Drive](https://drive.google.com/drive/my-drive)"))

main()
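
The step arithmetic described in the Steps section can be checked ahead of time in plain Python. This is a minimal sketch that mirrors the formula used in `validate_dataset`; the function name `estimate_steps` and the example numbers are illustrative assumptions, not part of the trainer.

```python
def estimate_steps(num_images, num_repeats, batch_size, epochs, warmup_ratio=0.05):
    """Mirror the notebook's arithmetic: images * repeats / batch size per epoch."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    steps_per_epoch = num_images * num_repeats / batch_size
    total_steps = int(epochs * steps_per_epoch)
    # Warmup steps are a fraction of the total, truncated to an integer
    warmup_steps = int(total_steps * warmup_ratio)
    return steps_per_epoch, total_steps, warmup_steps


# Hypothetical dataset: 30 images, 10 repeats, batch size 2, 15 epochs
steps_per_epoch, total_steps, warmup_steps = estimate_steps(30, 10, 2, 15)
print(f"{steps_per_epoch} steps per epoch, {total_steps} total, {warmup_steps} warmup")
```

If the total comes out above 10000, the trainer cell will refuse to start, so lowering the repeats or epochs beforehand saves a failed run.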


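The link rewriting done in `download_model` can also be tried on its own. The sketch below reuses the same two patterns as a pure function; the name `resolve_download_url` is hypothetical, and whether the resulting links still match what each site serves is not verified here.

```python
import re


def resolve_download_url(url):
    """Rewrite a Hugging Face 'blob' page link or a civitai model page link
    into the direct-download form used by the trainer cell."""
    url = url.strip()
    if re.search(r"(?:https?://)?(?:www\.)?huggingface\.co/[^/]+/[^/]+/blob/", url):
        # Replace only the first "/blob/" path segment, leaving file names untouched
        return url.replace("/blob/", "/resolve/", 1)
    match = re.search(r"(?:https?://)?(?:www\.)?civitai\.com/models/([0-9]+)", url)
    if match:
        return f"https://civitai.com/api/download/models/{match.group(1)}"
    # Anything else is assumed to already be a direct download link
    return url


print(resolve_download_url("https://huggingface.co/user/repo/blob/main/model.safetensors"))
```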
## *️⃣ Extras

In [None]:
#@markdown ### 📂 Unzip dataset
#@markdown It's much slower to upload individual files to your Drive, so you may want to upload a zip if you have your dataset on your computer.
zip = "/content/drive/MyDrive/lora_training/datasets/example.zip" #@param {type:"string"}
extract_to = "/content/drive/MyDrive/lora_training/datasets/example/" #@param {type:"string"}

import os, zipfile

if not os.path.exists('/content/drive'):
  from google.colab import drive
  print("📂 Connecting to Google Drive...")
  drive.mount('/content/drive')

with zipfile.ZipFile(zip, 'r') as f:
  f.extractall(extract_to)

print("✅ Done")
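
Because the trainer aborts on any file that is not an image or a `.txt` caption, it can help to inspect a zip before extracting it. This is a minimal sketch using only the standard library; the helper name `unexpected_entries` is hypothetical, and the demo builds an in-memory archive so it runs without Google Drive.

```python
import io
import zipfile

# Same extensions the trainer accepts, plus caption files (case-sensitive, like the trainer)
SUPPORTED = (".png", ".jpg", ".jpeg", ".txt")


def unexpected_entries(zip_source):
    """Return archive entries that are neither folders nor supported dataset files."""
    with zipfile.ZipFile(zip_source, "r") as archive:
        return [
            name for name in archive.namelist()
            if not name.endswith("/") and not name.endswith(SUPPORTED)
        ]


# Demo archive with one file the trainer would reject
buffer = io.BytesIO()
with zipfile.ZipFile(buffer, "w") as archive:
    archive.writestr("example/001.png", b"")
    archive.writestr("example/001.txt", "1girl, solo")
    archive.writestr("example/notes.docx", b"")
buffer.seek(0)
print(unexpected_entries(buffer))
```

A path to a zip file on Drive can be passed in place of the in-memory buffer.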




In [None]:
#@markdown ### 🔢 Count datasets
#@markdown Google Drive makes it impossible to count the files in a folder, so this will show you the file counts in all folders and subfolders.
folder = "/content/drive/MyDrive/lora_training/datasets" #@param {type:"string"}

import os
from google.colab import drive

if not os.path.exists('/content/drive'):
    print("📂 Connecting to Google Drive...\n")
    drive.mount('/content/drive')

tree = {}
for root, dirs, files in os.walk(folder):
  images = len([f for f in files if f.endswith((".png", ".jpg", ".jpeg"))])
  captions = len([f for f in files if f.endswith(".txt")])
  others = len(files) - images - captions
  path = root[folder.rfind("/")+1:]
  tree[path] = None if not images and not captions and not others \
                    else f"{images:>4} images | {captions:>4} text files | {others:>4} other files"
pad = max((len(k) for k in tree), default=0)  # default avoids an error when the folder doesn't exist
print("\n".join(f"📁{k.ljust(pad)} | {v}" for k, v in tree.items() if v))
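
When the image and text file counts differ, a quick way to find the culprits is to list images that have no caption beside them. The sketch below assumes captions share the image's base name with a `.txt` extension; the helper name `images_without_captions` is hypothetical, and the demo uses a temporary folder rather than Drive.

```python
import os
import tempfile

IMAGE_TYPES = (".png", ".jpg", ".jpeg")


def images_without_captions(folder, caption_extension=".txt"):
    """List image files in `folder` that have no caption file with the same base name."""
    files = set(os.listdir(folder))
    missing = []
    for name in sorted(files):
        if name.endswith(IMAGE_TYPES):
            caption = os.path.splitext(name)[0] + caption_extension
            if caption not in files:
                missing.append(name)
    return missing


# Demo: a.png has a caption, b.jpg does not
with tempfile.TemporaryDirectory() as demo:
    for name in ("a.png", "a.txt", "b.jpg"):
        open(os.path.join(demo, name), "w").close()
    print(images_without_captions(demo))
```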


In [None]:
#@markdown ### 📚 Load custom config file
#@markdown **WARNING!** Advanced users only. <p>
#@markdown Changing the `dataset_config_file` will allow you to use multiple datasets at once, with different repeats, keep_tokens, etc. <p>
#@markdown Changing the `config_file` will override ALL settings, including project name and model name, but will allow you to use any obscure kohya trainer settings you want. <p>
#@markdown The configs are in `.toml` format, which this trainer saves into the `lora_training/config` folder before training.
#@markdown Example dataset config: <p>
#@markdown ```toml
#@markdown [[datasets]]
#@markdown 
#@markdown [[datasets.subsets]]
#@markdown image_dir = "/content/drive/MyDrive/lora_training/datasets/mylora_good"
#@markdown num_repeats = 5
#@markdown keep_tokens = 1
#@markdown 
#@markdown [[datasets.subsets]]
#@markdown image_dir = "/content/drive/MyDrive/lora_training/datasets/mylora_bad"
#@markdown num_repeats = 1
#@markdown keep_tokens = 0
#@markdown #is_reg = true # Add this to use as regularization images
#@markdown 
#@markdown [general]
#@markdown resolution = 512
#@markdown shuffle_caption = true
#@markdown flip_aug = true
#@markdown caption_extension = ".txt"
#@markdown enable_bucket = true
#@markdown bucket_reso_steps = 64
#@markdown bucket_no_upscale = false
#@markdown min_bucket_reso = 256
#@markdown max_bucket_reso = 1024
#@markdown ```
#@markdown Run this with empty values to disable custom configs.
override_dataset_config_file = "/content/drive/MyDrive/lora_training/datasets/mylora_dataset_config.toml" #@param {type:"string"}
override_config_file = "" #@param {type:"string"}
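
A dataset config in the format shown above can also be generated rather than written by hand. This is a minimal sketch that builds the text with the standard library only; the function name `dataset_config_toml` and the example paths are illustrative assumptions, and the `[general]` values are copied from the example above.

```python
import json


def dataset_config_toml(subsets, resolution=512):
    """Render a dataset config. `subsets` is a list of
    (image_dir, num_repeats, keep_tokens) tuples."""
    lines = ["[[datasets]]", ""]
    for image_dir, num_repeats, keep_tokens in subsets:
        lines += [
            "[[datasets.subsets]]",
            # JSON quoting is assumed to be valid TOML for plain ASCII paths
            f"image_dir = {json.dumps(image_dir)}",
            f"num_repeats = {num_repeats}",
            f"keep_tokens = {keep_tokens}",
            "",
        ]
    lines += [
        "[general]",
        f"resolution = {resolution}",
        "shuffle_caption = true",
        "flip_aug = true",
        'caption_extension = ".txt"',
        "enable_bucket = true",
        "bucket_reso_steps = 64",
        "bucket_no_upscale = false",
        "min_bucket_reso = 256",
        "max_bucket_reso = 1024",
    ]
    return "\n".join(lines) + "\n"


# Hypothetical folders mirroring the example above
base = "/content/drive/MyDrive/lora_training/datasets"
print(dataset_config_toml([(f"{base}/mylora_good", 5, 1), (f"{base}/mylora_bad", 1, 0)]))
```

The returned text can be saved to a `.toml` file on Drive and its path entered in `override_dataset_config_file`.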
