# ⭐ Lora Trainer by Hollowstrawberry

Based on the work of [Kohya-ss](https://github.com/kohya-ss/sd-scripts) and [Linaqruf](https://colab.research.google.com/github/Linaqruf/kohya-trainer/blob/main/kohya-LoRA-dreambooth.ipynb). Thank you!


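### ⭕ Disclaimer
The purpose of this document is to research cutting-edge technologies in the field of machine learning.
Please read and follow the [Google Colab guidelines](https://research.google.com/colaboratory/faq.html) and its [Terms of Service](https://research.google.com/colaboratory/tos_v3.html).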

| |GitHub|🇬🇧 English|🇪🇸 Spanish|🇹🇼 Traditional Chinese|
|:--|:-:|:-:|:-:|:-:|
| 🏠 **Homepage** | [![GitHub](https://raw.githubusercontent.com/hollowstrawberry/kohya-colab/main/assets/github.svg)](https://github.com/hollowstrawberry/kohya-colab) | | | |
| 📊 **Dataset Maker** | [![GitHub](https://raw.githubusercontent.com/hollowstrawberry/kohya-colab/main/assets/github.svg)](https://github.com/hollowstrawberry/kohya-colab/blob/main/Dataset_Maker.ipynb) | [![Open in Colab](https://raw.githubusercontent.com/hollowstrawberry/kohya-colab/main/assets/colab-badge.svg)](https://colab.research.google.com/github/hollowstrawberry/kohya-colab/blob/main/Dataset_Maker.ipynb) | [![Abrir en Colab](https://raw.githubusercontent.com/hollowstrawberry/kohya-colab/main/assets/colab-badge-spanish.svg)](https://colab.research.google.com/github/hollowstrawberry/kohya-colab/blob/main/Spanish_Dataset_Maker.ipynb) | |
| ⭐ **Lora Trainer** | [![GitHub](https://raw.githubusercontent.com/hollowstrawberry/kohya-colab/main/assets/github.svg)](https://github.com/hollowstrawberry/kohya-colab/blob/main/Lora_Trainer.ipynb) | [![Open in Colab](https://raw.githubusercontent.com/hollowstrawberry/kohya-colab/main/assets/colab-badge.svg)](https://colab.research.google.com/github/hollowstrawberry/kohya-colab/blob/main/Lora_Trainer.ipynb) | [![Abrir en Colab](https://raw.githubusercontent.com/hollowstrawberry/kohya-colab/main/assets/colab-badge-spanish.svg)](https://colab.research.google.com/github/hollowstrawberry/kohya-colab/blob/main/Spanish_Lora_Trainer.ipynb) | [![Open in Colab](https://raw.githubusercontent.com/hinablue/kohya-colab/main/assets/colab-badge-spanish.svg)](https://colab.research.google.com/github/hinablue/kohya-colab/blob/main/Traditional_Chinese_Dataset_Maker.ipynb) |

In [None]:
import os
import re
import toml
import shutil
import zipfile
from time import time
from IPython.display import Markdown, display

# These carry information from past executions
if "model_url" in globals():
  old_model_url = model_url
else:
  old_model_url = None
if "dependencies_installed" not in globals():
  dependencies_installed = False
if "model_file" not in globals():
  model_file = None

# These may be set by other cells, some are legacy
if "custom_dataset" not in globals():
  custom_dataset = None
if "override_dataset_config_file" not in globals():
  override_dataset_config_file = None
if "override_config_file" not in globals():
  override_config_file = None
if "optimizer" not in globals():
  optimizer = "AdamW8bit"
if "optimizer_args" not in globals():
  optimizer_args = None
if "continue_from_lora" not in globals():
  continue_from_lora = ""
if "weighted_captions" not in globals():
  weighted_captions = False
if "adjust_tags" not in globals():
  adjust_tags = False
if "keep_tokens_weight" not in globals():
  keep_tokens_weight = 1.0

COLAB = True # low ram
COMMIT = "v0.6.3"
BETTER_EPOCH_NAMES = True
LOAD_TRUNCATED_IMAGES = True

#@title ## 🚩 Start Here

#@markdown ### ▶️ Setup
#@markdown Your project name must be the same as the folder containing your images. Spaces aren't allowed.
project_name = "" #@param {type:"string"}
#@markdown The folder structure doesn't matter and is purely for convenience. Make sure to always pick the same one. I prefer organizing by project.
folder_structure = "Organize by project (MyDrive/Loras/project_name/dataset)" #@param ["Organize by category (MyDrive/lora_training/datasets/project_name)", "Organize by project (MyDrive/Loras/project_name/dataset)"]
#@markdown Choose the model that will be downloaded and used for training. These options should produce clean and consistent results. You can also use your own model by pasting its download link.
training_model = "Anime (animefull-final-pruned-fp16.safetensors)" #@param ["Anime (animefull-final-pruned-fp16.safetensors)", "AnyLora (AnyLoRA_noVae_fp16-pruned.ckpt)", "Stable Diffusion (sd-v1-5-pruned-noema-fp16.safetensors)"]
#@markdown Download link for a custom model. Leave empty if you don't have one.
optional_custom_training_model_url = "" #@param {type:"string"}
#@markdown Check this option if your custom model is based on Stable Diffusion 2.0.
custom_model_is_based_on_sd2 = False #@param {type:"boolean"}

if optional_custom_training_model_url:
  model_url = optional_custom_training_model_url
elif "AnyLora" in training_model:
  model_url = "https://huggingface.co/Lykon/AnyLoRA/resolve/main/AnyLoRA_noVae_fp16-pruned.ckpt"
elif "Anime" in training_model:
  model_url = "https://huggingface.co/hollowstrawberry/stable-diffusion-guide/resolve/main/models/animefull-final-pruned-fp16.safetensors"
else:
  model_url = "https://huggingface.co/hollowstrawberry/stable-diffusion-guide/resolve/main/models/sd-v1-5-pruned-noema-fp16.safetensors"

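#@markdown ### ▶️ Processing
#@markdown A resolution of 512 is standard for Stable Diffusion 1.5. Higher resolutions make training slower but can capture finer detail.
#@markdown Images are scaled automatically during training, so you don't need to crop or resize them yourself.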
resolution = 512 #@param {type:"slider", min:512, max:1024, step:16}
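#@markdown This option flips images automatically during training. It adds no training time and lets the model learn more. Turn it on if you have fewer than 20 images.
#@markdown **Turn it off if you care about asymmetrical elements in your images.**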
flip_aug = False #@param {type:"boolean"}
#markdown File extension of the caption files. Leave empty if you have no captions.
caption_extension = ".txt" #param {type:"string"}
#@markdown Shuffling anime tags improves learning and prompting. Activation tags at the start of each text file are not shuffled.
shuffle_tags = True #@param {type:"boolean"}
shuffle_caption = shuffle_tags
#@markdown Number of activation tags. Set it to 0 if your images have none. Make sure the activation tags come first in every text file.
activation_tags = "1" #@param [0,1,2,3]
keep_tokens = int(activation_tags)

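#@markdown ### ▶️ Steps
#@markdown Your images will repeat this number of times during training. I recommend that your image count multiplied by the repeats lands between 200 and 400.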
num_repeats = 10 #@param {type:"number"}
#@markdown Choose how long you want to train for. A good starting point is around 10 epochs or around 2000 steps.
#@markdown One epoch is a number of steps equal to: your number of images multiplied by their repeats, divided by the batch size.
preferred_unit = "Epochs" #@param ["Epochs", "Steps"]
how_many = 10 #@param {type:"number"}
max_train_epochs = how_many if preferred_unit == "Epochs" else None
max_train_steps = how_many if preferred_unit == "Steps" else None
#@markdown Saving more epochs will let you compare your Lora's progress better.
save_every_n_epochs = 1 #@param {type:"number"}
keep_only_last_n_epochs = 10 #@param {type:"number"}
if not save_every_n_epochs:
  save_every_n_epochs = max_train_epochs
if not keep_only_last_n_epochs:
  keep_only_last_n_epochs = max_train_epochs
#@markdown Increasing the batch size makes training faster, but may make learning worse. Recommended 2 or 3.
train_batch_size = 2 #@param {type:"slider", min:1, max:8, step:1}

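#@markdown ### ▶️ Learning
#@markdown The learning rate is the most important factor for your results. If you want to train slower, have a lot of images, or use a high dim and alpha, lower the UNet learning rate to 2e-4 or less.
#@markdown The text encoder helps your Lora learn concepts better. It is recommended to set its learning rate to half or a fifth of the UNet's. If you're training a style, you can even set it to 0.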
unet_lr = 2e-4 #@param {type:"number"}
text_encoder_lr = 1e-4 #@param {type:"number"}
#@markdown The scheduler is the algorithm that guides the learning rate. If you're not sure, pick `constant` and ignore the number (`lr_scheduler_number`). I personally recommend `cosine_with_restarts` with 3 restarts.
lr_scheduler = "cosine_with_restarts" #@param ["constant", "cosine", "cosine_with_restarts", "constant_with_warmup", "linear", "polynomial"]
lr_scheduler_number = 3 #@param {type:"number"}
lr_scheduler_num_cycles = lr_scheduler_number if lr_scheduler == "cosine_with_restarts" else 0
lr_scheduler_power = lr_scheduler_number if lr_scheduler == "polynomial" else 0
#@markdown Share of the total steps spent "warming up" the learning rate at the start of training. It is ignored when the scheduler is `constant`. I recommend leaving it at 5%.
lr_warmup_ratio = 0.05 #@param {type:"slider", min:0.0, max:0.5, step:0.01}
lr_warmup_steps = 0
#@markdown New feature that adjusts loss over time, makes learning much more efficient, and lets training finish in about half as many epochs. Uses a value of 5.0 as recommended by [the paper](https://arxiv.org/abs/2303.09556).
min_snr_gamma = True #@param {type:"boolean"}
min_snr_gamma_value = 5.0 if min_snr_gamma else None

#@markdown ### ▶️ Structure
#@markdown LoRA is the classic type, while LoCon/LoHa are well suited to styles. Using LyCORIS in the WebUI requires [this extension](https://github.com/KohakuBlueleaf/a1111-sd-webui-lycoris). More info [here](https://github.com/KohakuBlueleaf/Lycoris).
lora_type = "LoRA" #@param ["LoRA", "LoCon Lycoris", "LoHa Lycoris"]

#@markdown Below are some recommended values:

#@markdown | Type | network_dim | network_alpha | conv_dim | conv_alpha |
#@markdown | :---: | :---: | :---: | :---: | :---: |
#@markdown | LoRA | 32 | 16 |   |   |
#@markdown | LoCon | 16 | 8 | 8 | 1 |
#@markdown | LoHa | 8 | 4 | 4 | 1 |

#@markdown A higher dim means a larger Lora. It can hold more information, but more isn't always better. A dim between 8 and 32 is recommended, with alpha equal to half the dim.
network_dim = 16 #@param {type:"slider", min:1, max:128, step:1}
network_alpha = 8 #@param {type:"slider", min:1, max:128, step:1}
#@markdown The following values don't affect LoRA. They work like dim/alpha, but only for the additional learning layers of LyCORIS.
conv_dim = 8 #@param {type:"slider", min:1, max:64, step:1}
conv_alpha = 1 #@param {type:"slider", min:1, max:64, step:1}
conv_compression = False #@param {type:"boolean"}

network_module = "lycoris.kohya" if "Lycoris" in lora_type else "networks.lora"
network_args = None if lora_type == "LoRA" else [
  f"conv_dim={conv_dim}",
  f"conv_alpha={conv_alpha}",
]
if "Lycoris" in lora_type:
  network_args.append(f"algo={'loha' if 'LoHa' in lora_type else 'lora'}")
  network_args.append(f"disable_conv_cp={str(not conv_compression)}")

#@markdown ### ▶️ More settings
#@markdown A value between 0.15 and 0.4 is recommended. The higher the value, the closer the trained results will be to the regularization images.
prior_loss_weight = 0.1 #@param {type:"number"}
#@markdown Noise offset can improve how bright and dark tones are handled.
noise_offset = 0.1 #@param {type:"number"}
#@markdown Black magic. A value of 2 is recommended; use 1 when training realistic subjects such as real people.
clip_skip = 2 #@param {type:"slider", min:0, max:10, step:1}
#@markdown Maximum caption token length. The default is 225.
max_token_length = 225 #@param [75,125,225]
#@markdown Enable `bucket_no_upscale` so images are not upscaled to fit a bucket. Set it to `False` if you want to use the min/max bucket resolutions.
bucket_no_upscale = True #@param {type:"boolean"}

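#@markdown **The following settings only apply when `bucket_no_upscale` is `False`.**

#@markdown If the resolution is below 512, 256 is recommended; otherwise use 320.
min_bucket_reso = 256 #@param {type:"number"}
#@markdown If the resolution is below 512, 1024 is recommended; otherwise use 1280.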
max_bucket_reso = 1024 #@param {type:"number"}

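#@markdown Color augmentation can improve the colors of the results. Enabling it forces latent caching off (`cache_latents = False`).
color_aug = False #@param {type:"boolean"}
#@markdown Crops images around the center of the face, using this lower and upper bound to decide the crop range. This can improve face training. Widen the bounds if you are training a style that includes backgrounds.
face_crop_aug_range = [1.0, 3.0] #@param {type:"raw"}
#@markdown Randomly crop a region of each image. `False` is recommended, but consider enabling it when training a style.
random_crop = False #@param {type:"boolean"}
#@markdown Increasing the gradient accumulation steps saves GPU memory, but slows training down and calls for a higher learning rate.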
gradient_accumulation_steps = 1 #@param {type:"number"}

#@markdown ### ▶️ Experimental
#@markdown Save additional data, about 1 GB, that lets you resume training later.
save_state = False #@param {type:"boolean"}
#@markdown Resume training if a saved state is found.
resume = False #@param {type:"boolean"}

#@markdown ### ▶️ Ready
#@markdown You can now run this cell to start training. Good luck!

# 👩‍💻 Cool code goes here

if optimizer == "DAdaptation":
  optimizer_args = ["decouple=True","weight_decay=0.02"]
  unet_lr = 0.5
  text_encoder_lr = 0.5
  lr_scheduler = "constant_with_warmup"
  network_alpha = network_dim

root_dir = "/content" if COLAB else os.path.expanduser("~/Loras")  # expand "~" so local paths resolve
deps_dir = os.path.join(root_dir, "deps")
repo_dir = os.path.join(root_dir, "kohya-trainer")

if "/Loras" in folder_structure:
  main_dir      = os.path.join(root_dir, "drive/MyDrive/Loras") if COLAB else root_dir
  log_folder    = os.path.join(main_dir, "_logs")
  config_folder = os.path.join(main_dir, project_name)
  images_folder = os.path.join(main_dir, project_name, "dataset")
  output_folder = os.path.join(main_dir, project_name, "output")
else:
  main_dir      = os.path.join(root_dir, "drive/MyDrive/lora_training") if COLAB else root_dir
  images_folder = os.path.join(main_dir, "datasets", project_name)
  output_folder = os.path.join(main_dir, "output", project_name)
  config_folder = os.path.join(main_dir, "config", project_name)
  log_folder    = os.path.join(main_dir, "log")

config_file = os.path.join(config_folder, "training_config.toml")
dataset_config_file = os.path.join(config_folder, "dataset_config.toml")
accelerate_config_file = os.path.join(repo_dir, "accelerate_config/config.yaml")

def clone_repo():
  os.chdir(root_dir)
  !git clone https://github.com/kohya-ss/sd-scripts {repo_dir}
  os.chdir(repo_dir)
  if COMMIT:
    !git reset --hard {COMMIT}
  !wget https://raw.githubusercontent.com/hollowstrawberry/kohya-colab/main/requirements.txt -q -O requirements.txt

def install_dependencies():
  clone_repo()
  !apt -y update -qq
  !apt -y install aria2
  !pip -q install --upgrade -r requirements.txt

  # patch kohya for minor stuff
  if COLAB:
    !sed -i "s@cpu@cuda@" library/model_util.py # low ram
  if LOAD_TRUNCATED_IMAGES:
    !sed -i 's/from PIL import Image/from PIL import Image, ImageFile\nImageFile.LOAD_TRUNCATED_IMAGES=True/g' library/train_util.py # fix truncated jpegs error
  if BETTER_EPOCH_NAMES:
    !sed -i 's/{:06d}/{:02d}/g' library/train_util.py # make epoch names shorter
    !sed -i 's/model_name + "."/model_name + "-{:02d}.".format(num_train_epochs)/g' train_network.py # name of the last epoch will match the rest

  from accelerate.utils import write_basic_config
  if not os.path.exists(accelerate_config_file):
    write_basic_config(save_location=accelerate_config_file)

  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
  os.environ["BITSANDBYTES_NOWELCOME"] = "1"  
  os.environ["SAFETENSORS_FAST_GPU"] = "1"

def validate_dataset():
  global lr_warmup_steps, lr_warmup_ratio, caption_extension, keep_tokens, keep_tokens_weight, weighted_captions, adjust_tags
  supported_types = (".png", ".jpg", ".jpeg")

  print("\n💿 Checking dataset...")
  if not project_name.strip() or any(c in project_name for c in " .()\"'\\/"):
    print("💥 Error: Please choose a valid project name.")
    return

  if custom_dataset:
    try:
      datconf = toml.loads(custom_dataset)
      datasets = {d["image_dir"]: d["num_repeats"] for d in datconf["datasets"][0]["subsets"]}
    except Exception:
      print("💥 Error: Your custom dataset is malformed. Please check it against the template.")
      return
    folders = list(datasets.keys())
  else:
    folders = [images_folder]

  # Check that every folder exists before listing its contents
  for folder in folders:
    if not os.path.exists(folder):
      print(f"💥 Error: The folder {folder.replace('/content/drive/', '')} doesn't exist.")
      return

  if custom_dataset:
    files = [f for folder in folders for f in os.listdir(folder)]
    images_repeats = {folder: (len([f for f in os.listdir(folder) if f.lower().endswith(supported_types)]), datasets[folder]) for folder in folders}
  else:
    files = os.listdir(images_folder)
    images_repeats = {images_folder: (len([f for f in files if f.lower().endswith(supported_types)]), num_repeats)}
  for folder, (img, rep) in images_repeats.items():
    if not img:
      print(f"💥 Error: Your {folder.replace('/content/drive/', '')} folder is empty.")
      return
  for f in files:
    if not f.lower().endswith(".txt") and not f.lower().endswith(supported_types):
      print(f"💥 Error: Invalid file in dataset: \"{f}\". Aborting.")
      return
    
  if not [txt for txt in files if txt.lower().endswith(".txt")]:
    caption_extension = ""
  if continue_from_lora and not (continue_from_lora.endswith(".safetensors") and os.path.exists(continue_from_lora)):
    print("💥 Error: Invalid path to an existing Lora. Example: /content/drive/MyDrive/Loras/example.safetensors")
    return

  pre_steps_per_epoch = sum(img*rep for (img, rep) in images_repeats.values())
  steps_per_epoch = pre_steps_per_epoch/train_batch_size
  total_steps = max_train_steps or int(max_train_epochs*steps_per_epoch)
  estimated_epochs = int(total_steps/steps_per_epoch)
  lr_warmup_steps = int(total_steps*lr_warmup_ratio)

  for folder, (img, rep) in images_repeats.items():
    print("📁"+folder.replace("/content/drive/", ""))
    print(f"📈 Found {img} images with {rep} repeats, equaling {img*rep} steps.")
  print(f"📉 Divide {pre_steps_per_epoch} steps by {train_batch_size} batch size to get {steps_per_epoch} steps per epoch.")
  if max_train_epochs:
    print(f"🔮 There will be {max_train_epochs} epochs, for around {total_steps} total training steps.")
  else:
    print(f"🔮 There will be {total_steps} steps, divided into {estimated_epochs} epochs and then some.")

  if total_steps > 18000:
    print("💥 Error: Your total steps are too high. You probably made a mistake. Aborting...")
    return

  if adjust_tags:
    print(f"\n📎 Weighted tags: {'ON' if weighted_captions else 'OFF'}")
    if weighted_captions:
      print(f"📎 Will use a weight of {keep_tokens_weight} on {keep_tokens} activation tag(s).")
    print("📎 Adjusting tags...")
    adjust_weighted_tags(folders, keep_tokens, keep_tokens_weight, weighted_captions)
  
  return True

def adjust_weighted_tags(folders, keep_tokens: int, keep_tokens_weight: float, weighted_captions: bool):
  weighted_tag = re.compile(r"\((.+?):[.\d]+\)(,|$)")
  for folder in folders:
    for txt in [f for f in os.listdir(folder) if f.lower().endswith(".txt")]:
      with open(os.path.join(folder, txt), 'r') as f:
        content = f.read()
      # reset previous changes
      content = content.replace('\\', '')
      content = weighted_tag.sub(r'\1\2', content)
      if weighted_captions:
        # re-apply changes
        content = content.replace(r'(', r'\(').replace(r')', r'\)').replace(r':', r'\:')
        if keep_tokens_weight > 1:
          tags = [s.strip() for s in content.split(",")]
          for i in range(min(keep_tokens, len(tags))):
            tags[i] = f'({tags[i]}:{keep_tokens_weight})'
          content = ", ".join(tags)
      with open(os.path.join(folder, txt), 'w') as f:
        f.write(content)

def create_config():
  global dataset_config_file, config_file, model_file

  if resume:
    resume_points = [f.path for f in os.scandir(output_folder) if f.is_dir()]
    resume_points.sort()
    last_resume_point = resume_points[-1] if resume_points else None
  else:
    last_resume_point = None

  if override_config_file:
    config_file = override_config_file
    print(f"\n⭕ Using custom config file {config_file}")
  else:
    config_dict = {
      "additional_network_arguments": {
        "unet_lr": unet_lr,
        "text_encoder_lr": text_encoder_lr,
        "network_dim": network_dim,
        "network_alpha": network_alpha,
        "network_module": network_module,
        "network_args": network_args,
        "network_train_unet_only": True if text_encoder_lr == 0 else None,
        "network_weights": continue_from_lora if continue_from_lora else None
      },
      "optimizer_arguments": {
        "learning_rate": unet_lr,
        "lr_scheduler": lr_scheduler,
        "lr_scheduler_num_cycles": lr_scheduler_num_cycles if lr_scheduler == "cosine_with_restarts" else None,
        "lr_scheduler_power": lr_scheduler_power if lr_scheduler == "polynomial" else None,
        "lr_warmup_steps": lr_warmup_steps if lr_scheduler != "constant" else None,
        "optimizer_type": optimizer,
        "optimizer_args": optimizer_args if optimizer_args else None,
      },
      "training_arguments": {
        "max_train_steps": max_train_steps,
        "max_train_epochs": max_train_epochs,
        "save_every_n_epochs": save_every_n_epochs,
        "save_last_n_epochs": keep_only_last_n_epochs,
        "train_batch_size": train_batch_size,
        "noise_offset": noise_offset if noise_offset > 0 else None,
        "clip_skip": clip_skip if clip_skip > 0 else None,
        "min_snr_gamma": min_snr_gamma_value,
        "weighted_captions": weighted_captions,
        "seed": 42,
        "max_token_length": max_token_length,
        "xformers": True,
        "lowram": COLAB,
        "max_data_loader_n_workers": 8,
        "persistent_data_loader_workers": True,
        "save_precision": "fp16",
        "mixed_precision": "fp16",
        "output_dir": output_folder,
        "logging_dir": log_folder,
        "output_name": project_name,
        "log_prefix": project_name,
        "save_state": save_state,
        "save_last_n_epochs_state": 1 if save_state else None,
        "resume": last_resume_point,
        "gradient_accumulation_steps": gradient_accumulation_steps,
      },
      "model_arguments": {
        "pretrained_model_name_or_path": model_file,
        "v2": custom_model_is_based_on_sd2,
        "v_parameterization": True if custom_model_is_based_on_sd2 else None,
      },
      "saving_arguments": {
        "save_model_as": "safetensors",
      },
      "dreambooth_arguments": {
        "prior_loss_weight": prior_loss_weight,
      },
      "dataset_arguments": {
        # Latents can only be cached when color augmentation is off
        "cache_latents": not color_aug,
      },
    }

    for key in config_dict:
      if isinstance(config_dict[key], dict):
        config_dict[key] = {k: v for k, v in config_dict[key].items() if v is not None}

    with open(config_file, "w") as f:
      f.write(toml.dumps(config_dict))
    print(f"\n📄 Config saved to {config_file}")

  if override_dataset_config_file:
    dataset_config_file = override_dataset_config_file
    print(f"⭕ Using custom dataset config file {dataset_config_file}")
  else:
    dataset_config_dict = {
      "general": {
        "resolution": resolution,
        "shuffle_caption": shuffle_caption,
        "keep_tokens": keep_tokens,
        "flip_aug": flip_aug,
        "caption_extension": caption_extension,
        "enable_bucket": True,
        "bucket_reso_steps": 64,
        "bucket_no_upscale": bucket_no_upscale,
        "min_bucket_reso": min_bucket_reso,
        "max_bucket_reso": max_bucket_reso,
        "color_aug": color_aug,
        "face_crop_aug_range": face_crop_aug_range,
        "random_crop": random_crop,
      },
      "datasets": toml.loads(custom_dataset)["datasets"] if custom_dataset else [
        {
          "subsets": [
            {
              "num_repeats": num_repeats,
              "image_dir": images_folder,
              "class_tokens": None if caption_extension else project_name
            }
          ]
        }
      ]
    }

    for key in dataset_config_dict:
      if isinstance(dataset_config_dict[key], dict):
        dataset_config_dict[key] = {k: v for k, v in dataset_config_dict[key].items() if v is not None}

    with open(dataset_config_file, "w") as f:
      f.write(toml.dumps(dataset_config_dict))
    print(f"📄 Dataset config saved to {dataset_config_file}")

def download_model():
  global old_model_url, model_url, model_file
  real_model_url = model_url.strip()
  
  if real_model_url.lower().endswith((".ckpt", ".safetensors")):
    model_file = f"/content{real_model_url[real_model_url.rfind('/'):]}"
  else:
    model_file = "/content/downloaded_model.safetensors"
    if os.path.exists(model_file):
      !rm "{model_file}"

  if m := re.search(r"(?:https?://)?(?:www\.)?huggingface\.co/[^/]+/[^/]+/blob", model_url):
    real_model_url = real_model_url.replace("blob", "resolve")
  elif m := re.search(r"(?:https?://)?(?:www\.)?civitai\.com/models/([0-9]+)", model_url):
    real_model_url = f"https://civitai.com/api/download/models/{m.group(1)}"

  !aria2c "{real_model_url}" --console-log-level=warn -c -s 16 -x 16 -k 10M -d / -o "{model_file}"

  if model_file.lower().endswith(".safetensors"):
    from safetensors.torch import load_file as load_safetensors
    try:
      test = load_safetensors(model_file)
      del test
    except Exception as e:
      #if "HeaderTooLarge" in str(e):
      new_model_file = os.path.splitext(model_file)[0]+".ckpt"
      !mv "{model_file}" "{new_model_file}"
      model_file = new_model_file
      print(f"Renamed model to {os.path.splitext(model_file)[0]}.ckpt")

  if model_file.lower().endswith(".ckpt"):
    from torch import load as load_ckpt
    try:
      test = load_ckpt(model_file)
      del test
    except Exception as e:
      return False
      
  return True

def main():
  global dependencies_installed

  if COLAB and not os.path.exists('/content/drive'):
    from google.colab import drive
    print("📂 Connecting to Google Drive...")
    drive.mount('/content/drive')
  
  for dir in (main_dir, deps_dir, repo_dir, log_folder, images_folder, output_folder, config_folder):
    os.makedirs(dir, exist_ok=True)

  if not validate_dataset():
    return
  
  if not dependencies_installed:
    print("\n🏭 Installing dependencies...\n")
    t0 = time()
    install_dependencies()
    t1 = time()
    dependencies_installed = True
    print(f"\n✅ Installation finished in {int(t1-t0)} seconds.")
  else:
    print("\n✅ Dependencies already installed.")

  if old_model_url != model_url or not model_file or not os.path.exists(model_file):
    print("\n🔄 Downloading model...")
    if not download_model():
      print("\n💥 Error: The model you selected is invalid or couldn't be downloaded. You can use a Civitai or Huggingface link, or any direct download link.")
      return
    print()
  else:
    print("\n🔄 Model already downloaded.\n")

  create_config()
  
  print("\n⭐ GPU info...\n")
  !nvidia-smi -L

  print("\n⭐ Starting trainer...\n")
  os.chdir(repo_dir)
  
  !accelerate launch --config_file={accelerate_config_file} --num_cpu_threads_per_process=2 train_network.py --dataset_config={dataset_config_file} --config_file={config_file}

  if not get_ipython().__dict__['user_ns']['_exit_code']:
    display(Markdown("### ✅ Done! [Go download your Lora from Google Drive](https://drive.google.com/drive/my-drive)"))

main()
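
The link handling inside `download_model` can be tried on its own. The sketch below mirrors that logic as a standalone function; the name `normalize_model_url` is hypothetical and not part of the notebook, and it only rewrites the string without downloading anything. Whether the numeric ID from a Civitai model page is accepted by the download endpoint is assumed here, exactly as in the cell above.

```python
import re

def normalize_model_url(url: str) -> str:
    """Rewrite a model page link into a direct download link (hypothetical helper)."""
    url = url.strip()
    # Hugging Face file pages use /blob/, while direct downloads use /resolve/
    if re.search(r"(?:https?://)?(?:www\.)?huggingface\.co/[^/]+/[^/]+/blob", url):
        return url.replace("/blob/", "/resolve/", 1)
    # Civitai model pages carry a numeric ID that is passed to the download API
    if m := re.search(r"(?:https?://)?(?:www\.)?civitai\.com/models/([0-9]+)", url):
        return f"https://civitai.com/api/download/models/{m.group(1)}"
    # Anything else is treated as a direct download link and left untouched
    return url

print(normalize_model_url("https://huggingface.co/user/repo/blob/main/model.safetensors"))
```

Replacing only the first `/blob/` segment, rather than every occurrence of `blob`, avoids mangling a file or repository name that happens to contain that word.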


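## *️⃣ Extras

You can run these before starting the training.

### 📚 Multiple folders in dataset
The template below lets you define multiple folders in your dataset. You must include the path of each folder and specify the number of repeats for each one. To add more folders, copy and paste a `[[datasets.subsets]]` block.

When you use this setting, the number of repeats set in the main cell is ignored, and so is the dataset folder derived from the project name.

You can turn one of the folders into a regularization folder by adding `is_reg = true`.
You can also set other parameters per folder, such as `keep_tokens`, `flip_aug`, and so on.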
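
With several folders, each one contributes its image count multiplied by its own repeats. The sketch below follows the arithmetic that `validate_dataset` performs in the main cell; the function name `estimate_steps` and the folder counts are made up for illustration.

```python
def estimate_steps(images_repeats, batch_size, epochs=None, max_steps=None, warmup_ratio=0.05):
    """images_repeats maps a folder name to (image_count, num_repeats)."""
    # Every image is seen num_repeats times per epoch
    pre_steps_per_epoch = sum(img * rep for img, rep in images_repeats.values())
    steps_per_epoch = pre_steps_per_epoch / batch_size
    # A step limit takes priority; otherwise derive the total from the epoch count
    total_steps = max_steps or int(epochs * steps_per_epoch)
    return {
        "steps_per_epoch": steps_per_epoch,
        "total_steps": total_steps,
        "estimated_epochs": int(total_steps / steps_per_epoch),
        "warmup_steps": int(total_steps * warmup_ratio),
    }

# Hypothetical dataset: 40 images repeated 3 times plus 80 images repeated once
plan = estimate_steps({"good_images": (40, 3), "normal_images": (80, 1)}, batch_size=2, epochs=10)
print(plan)
```

Here the two folders add up to 200 image repeats per epoch, which a batch size of 2 turns into 100 steps per epoch and 1000 steps over 10 epochs.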

In [None]:
custom_dataset = """
[[datasets]]

[[datasets.subsets]]
image_dir = "/content/drive/MyDrive/Loras/example/dataset/good_images"
num_repeats = 3
is_reg = false

[[datasets.subsets]]
image_dir = "/content/drive/MyDrive/Loras/example/dataset/normal_images"
num_repeats = 1
is_reg = false

[[datasets.subsets]]
image_dir = "/content/drive/MyDrive/Loras/example/dataset/reg_images"
num_repeats = 1
is_reg = true

"""

In [None]:
custom_dataset = None
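
If you would rather generate the multi-folder template than edit it by hand, a small helper can render it from a list of folders. This is only a sketch: `build_custom_dataset` is a hypothetical name, it assumes plain paths without quotes or backslashes, and it emits only the three keys shown in the template.

```python
def build_custom_dataset(subsets):
    """Render subset dicts into the [[datasets.subsets]] template format."""
    lines = ["[[datasets]]", ""]
    for subset in subsets:
        lines.append("[[datasets.subsets]]")
        lines.append(f'image_dir = "{subset["image_dir"]}"')
        lines.append(f'num_repeats = {int(subset["num_repeats"])}')
        # TOML booleans are lowercase, unlike Python's True/False
        lines.append(f'is_reg = {"true" if subset.get("is_reg", False) else "false"}')
        lines.append("")
    return "\n".join(lines)

custom_dataset_text = build_custom_dataset([
    {"image_dir": "/content/drive/MyDrive/Loras/example/dataset/good_images", "num_repeats": 3},
    {"image_dir": "/content/drive/MyDrive/Loras/example/dataset/reg_images", "num_repeats": 1, "is_reg": True},
])
print(custom_dataset_text)
```

The resulting string could then be assigned to `custom_dataset` in place of the handwritten template.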

In [None]:
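#@markdown ### 🔮 Advanced settings
#@markdown Change the optimizer used for training. `AdamW8bit` is the recommended default.
#@markdown Choosing the DAdaptation optimizer, which manages the learning rate automatically, overrides the following settings: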
#@markdown `learning_rate=0.5`, `lr_scheduler="constant_with_warmup"`, `optimizer_args=decouple=True,weight_decay=0.02`, `network_alpha=network_dim`
optimizer = "AdamW8bit" #@param ["AdamW8bit", "AdamW", "Lion", "DAdaptation", "SGDNesterov", "SGDNesterov8bit", "AdaFactor"]
optimizer_args = "weight_decay=0.1" #@param {type:"string"}
optimizer_args = [a.strip() for a in optimizer_args.split(",") if a]

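#@markdown Weighted captions is a new feature that lets you use (parentheses) to give more weight to certain tags in your dataset, just like in WebUI prompts.
#@markdown Normal parentheses in your tags, such as `(series names)`, need to be escaped like `\(series names\)`.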
weighted_captions = False #@param {type:"boolean"}

#@markdown By enabling `adjust_tags`, you let this colab modify your tags before running so that they match `weighted_captions` being on or off.
adjust_tags = False #@param {type:"boolean"}
activation_tag_weight = "1.0" #@param ["1.0","1.1","1.2"]
keep_tokens_weight = float(activation_tag_weight)

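#@markdown Here you can write a path in your Google Drive to load an existing Lora file and continue training on it.
#@markdown **Warning:** This is not the same as one long training session. The epochs start from scratch, and the results may be worse.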
continue_from_lora = "" #@param {type:"string"}
if continue_from_lora and not continue_from_lora.startswith("/content/drive/MyDrive"):
  import os
  continue_from_lora = os.path.join("/content/drive/MyDrive", continue_from_lora)
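
To see what `adjust_tags` does to a caption, the sketch below applies the same steps as `adjust_weighted_tags` to a single string instead of files on disk. The function name `adjust_caption` and the sample caption are illustrative only.

```python
import re

# Matches a weighted tag such as "(mychar:1.1)" followed by a comma or the end of the text
WEIGHTED_TAG = re.compile(r"\((.+?):[.\d]+\)(,|$)")

def adjust_caption(content, keep_tokens, weight, weighted_captions):
    # Undo earlier adjustments: drop escapes and unwrap (tag:1.1) back to tag
    content = content.replace("\\", "")
    content = WEIGHTED_TAG.sub(r"\1\2", content)
    if weighted_captions:
        # Escape literal parentheses and colons so they are not read as weights
        content = content.replace("(", r"\(").replace(")", r"\)").replace(":", r"\:")
        if weight > 1:
            tags = [t.strip() for t in content.split(",")]
            # Only the leading activation tags receive the extra weight
            for i in range(min(keep_tokens, len(tags))):
                tags[i] = f"({tags[i]}:{weight})"
            content = ", ".join(tags)
    return content

weighted = adjust_caption("mychar, 1girl, fate (series)", 1, 1.1, True)
print(weighted)
print(adjust_caption(weighted, 1, 1.1, False))
```

Because every call first resets previous changes, running the adjustment again with `weighted_captions` off restores the original caption.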


In [None]:
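#@markdown ### 📂 Unzip dataset
#@markdown Uploading individual dataset files is slow, so if your dataset is on your computer you may want to upload it to your Drive as a single zip file instead.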
zip = "/content/drive/MyDrive/my_dataset.zip" #@param {type:"string"}
extract_to = "/content/drive/MyDrive/Loras/example/dataset" #@param {type:"string"}

import os, zipfile

if not os.path.exists('/content/drive'):
  from google.colab import drive
  print("📂 Connecting to Google Drive...")
  drive.mount('/content/drive')

os.makedirs(extract_to, exist_ok=True)

with zipfile.ZipFile(zip, 'r') as f:
  f.extractall(extract_to)

print("✅ Done")
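
The trainer's dataset check stops when it finds a file that is neither an image nor a `.txt` caption, so it can help to inspect an archive before extracting it. The sketch below lists such files; `unsupported_members` is a hypothetical helper, and the in-memory archive only stands in for a real zip on your Drive.

```python
import io
import zipfile

# Extensions accepted by the trainer's dataset check, plus caption files
SUPPORTED = (".png", ".jpg", ".jpeg", ".txt")

def unsupported_members(zip_source):
    """Return the names of archived files that the dataset check would reject."""
    with zipfile.ZipFile(zip_source, "r") as archive:
        return [
            info.filename
            for info in archive.infolist()
            if not info.is_dir() and not info.filename.lower().endswith(SUPPORTED)
        ]

# Build a small in-memory archive to demonstrate
buffer = io.BytesIO()
with zipfile.ZipFile(buffer, "w") as archive:
    archive.writestr("001.png", b"")
    archive.writestr("001.txt", "mychar, 1girl")
    archive.writestr("notes.docx", b"")
buffer.seek(0)

rejected = unsupported_members(buffer)
print(rejected)
```

With a real dataset you would pass the path of the zip file instead of the buffer.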


In [None]:
#@markdown ### 🔢 Count datasets
#@markdown Google Drive doesn't show how many files a folder contains, so this will display the file counts for all folders and subfolders.
folder = "/content/drive/MyDrive/Loras" #@param {type:"string"}

import os
from google.colab import drive

if not os.path.exists('/content/drive'):
    print("📂 Connecting to Google Drive...\n")
    drive.mount('/content/drive')

tree = {}
exclude = ("_logs", "output")  # matched against folder names, which never contain "/"
for i, (root, dirs, files) in enumerate(os.walk(folder, topdown=True)):
  dirs[:] = [d for d in dirs if all(ex not in d for ex in exclude)]
  images = len([f for f in files if f.lower().endswith((".png", ".jpg", ".jpeg"))])
  captions = len([f for f in files if f.lower().endswith(".txt")])
  others = len(files) - images - captions
  path = root[folder.rfind("/")+1:]
  tree[path] = None if not images else f"{images:>4} images | {captions:>4} captions |"
  if tree[path] and others:
    tree[path] += f" {others:>4} other files"

pad = max(len(k) for k in tree)
print("\n".join(f"📁{k.ljust(pad)} | {v}" for k, v in tree.items() if v))
