# Simplified Lora Workshop by Citron Legacy 🍋

Help fuel my passion!
 [![ko-fi](https://img.shields.io/badge/Support%20me%20on%20Ko--fi-F16061?logo=ko-fi&logoColor=white&style=flat)](https://ko-fi.com/citronlegacy)

Any amount would be awesome!

![](https://i.imgur.com/sjXiQwT.png)


### Project Description

This project is for simplifying the training of Loras for Stable Diffusion. There are 2 steps:
1. Make a Dataset
2. Make a Lora

There are a lot of great Lora training tools with nice features, but this one is intended to hide advanced settings and be the simplest trainer possible.




# Links
| Project |GitHub| Colab | | Other content | Link|
|:--|:-:|:-:|:-:|:--|:--|
| 🏠 **Homepage** | [![GitHub](https://raw.githubusercontent.com/citronlegacy/kohya-colab/main/assets/github.svg)](https://github.com/citronlegacy/kohya-colab) | | | ☕ **Ko-fi** | [![Ko-Fi](https://img.shields.io/badge/Ko--Fi-Support-orange.svg)](https://ko-fi.com/citronlegacy) |
| 🛠️ **Citron Lora Workshop (Dataset & Training)** | [![GitHub](https://raw.githubusercontent.com/citronlegacy/kohya-colab/main/assets/github.svg)](https://github.com/citronlegacy/kohya-colab/blob/main/Citron_Lora_Workshop.ipynb) | [![Open in Colab](https://raw.githubusercontent.com/citronlegacy/kohya-colab/main/assets/colab-badge.svg)](https://colab.research.google.com/github/citronlegacy/kohya-colab/blob/main/Citron_Lora_Workshop.ipynb) | |🤖 **CivitAI** | [![CivitAI](https://img.shields.io/badge/CivitAI-Models-blue.svg)](https://civitai.com/user/CitronLegacy/models) |
| 💪 **Citron Lora Trainer** | [![GitHub](https://raw.githubusercontent.com/citronlegacy/kohya-colab/main/assets/github.svg)](https://github.com/citronlegacy/kohya-colab/blob/main/Citron_Lora_Trainer.ipynb) | [![Open in Colab](https://raw.githubusercontent.com/citronlegacy/kohya-colab/main/assets/colab-badge.svg)](https://colab.research.google.com/github/citronlegacy/kohya-colab/blob/main/Citron_Lora_Trainer.ipynb) | | 🎨 **Pixiv** | [![Pixiv](https://img.shields.io/badge/Pixiv-Profile-purple.svg)](https://www.pixiv.net/en/users/95364318) |
| ⭐ **Coming soon!! CALM - Citron Auto Lora Maker** |  |  | | 🎬 **Youtube**  | [![YouTube](https://img.shields.io/badge/YouTube-Subscribe-red.svg)](https://www.youtube.com/@FujiwaraNoMokou11) |
| 📊 **Citron Dataset Maker** | [![GitHub](https://raw.githubusercontent.com/citronlegacy/kohya-colab/main/assets/github.svg)](https://github.com/citronlegacy/kohya-colab/blob/main/Citron_Dataset_Maker.ipynb) | [![Open in Colab](https://raw.githubusercontent.com/citronlegacy/kohya-colab/main/assets/colab-badge.svg)](https://colab.research.google.com/github/citronlegacy/kohya-colab/blob/main/Citron_Dataset_Maker.ipynb) | | | |


---
### Project Disclaimer
This is forked from the work of [Hollowstrawberry 🍓](https://github.com/hollowstrawberry/kohya-colab) which is based on the work of [Kohya-ss](https://github.com/kohya-ss/sd-scripts) and [Linaqruf](https://colab.research.google.com/github/Linaqruf/kohya-trainer/blob/main/kohya-LoRA-dreambooth.ipynb). Thank you!

Please read and follow the [Google Colab guidelines](https://research.google.com/colaboratory/faq.html) and its [Terms of Service](https://research.google.com/colaboratory/tos_v3.html).

---
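
For orientation, the cells below keep everything for a project under a single `lora_training` folder in Google Drive, split into `datasets`, `config`, `output`, and `log`. The following is a minimal sketch of that layout. The helper name `project_paths` and the temporary base directory are illustrative assumptions, not part of the notebook.

```python
import tempfile
from pathlib import Path

def project_paths(base_dir, project_name):
    """Return the folders used for one project (hypothetical helper)."""
    main_dir = Path(base_dir) / "lora_training"
    return {
        "images": main_dir / "datasets" / project_name,
        "config": main_dir / "config" / project_name,
        "output": main_dir / "output" / project_name,
        "log": main_dir / "log",
    }

# In Colab the base would be /content/drive/MyDrive; a temp dir stands in here
base = tempfile.mkdtemp()
paths = project_paths(base, "Hatsune_Miku")
for folder in paths.values():
    folder.mkdir(parents=True, exist_ok=True)
    print(folder.relative_to(base))
```

Images and their caption `.txt` files live together in the `datasets` folder, so that is the one to inspect between the two steps.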

In [None]:
#@title # Install
#@markdown ### 1️⃣ Setup: Connect to Google Drive and Install Dependencies
#@markdown Installation usually takes about 3 minutes

#@markdown ------------------------------------------------------

#@markdown Update 02/03/2024 - First release of my Citron Lora Workshop

import os
import time
from pathlib import Path
from google.colab.output import clear as clear_output
from google.colab import drive
import re
import toml
import shutil
import zipfile
from IPython.display import Markdown, display
from datetime import datetime, timedelta

print("📂 Connecting to Google Drive...")
drive.mount('/content/drive')



root_dir = "/content"
deps_dir = os.path.join(root_dir, "deps")
repo_dir = os.path.join(root_dir, "kohya-trainer")
COLAB = True # low ram
COMMIT = "e6ad3cbc66130fdc3bf9ecd1e0272969b1d613f7"
BETTER_EPOCH_NAMES = True
LOAD_TRUNCATED_IMAGES = True

def count_images_in_folder(folder_path):
    # Ensure the folder path is a valid directory
    folder_path = Path(folder_path)
    if not folder_path.is_dir():
        raise ValueError(f"The provided path '{folder_path}' does not exist. If it does exist but Colab can't find it try reconnecting to Google Drive")

    # Count the number of image files in the folder
    image_extensions = ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp']
    image_count = sum(1 for file in folder_path.iterdir() if file.suffix.lower() in image_extensions)

    return image_count


def clone_repo():
  os.chdir(root_dir)
  !git clone https://github.com/kohya-ss/sd-scripts {repo_dir}
  os.chdir(repo_dir)
  if COMMIT:
    !git reset --hard {COMMIT}
  !wget https://raw.githubusercontent.com/hollowstrawberry/kohya-colab/main/requirements.txt -q -O requirements.txt

def install_dependencies():
  clone_repo()
  !apt -y update -qq
  !apt -y install aria2 -qq
  !pip install --upgrade -r requirements.txt

  # patch kohya for minor stuff
  if COLAB:
    !sed -i "s@cpu@cuda@" library/model_util.py # low ram
  if LOAD_TRUNCATED_IMAGES:
    !sed -i 's/from PIL import Image/from PIL import Image, ImageFile\nImageFile.LOAD_TRUNCATED_IMAGES=True/g' library/train_util.py # fix truncated jpegs error
  if BETTER_EPOCH_NAMES:
    !sed -i 's/{:06d}/{:02d}/g' library/train_util.py # make epoch names shorter
    !sed -i 's/"." + args.save_model_as)/"-{:02d}.".format(num_train_epochs) + args.save_model_as)/g' train_network.py # name of the last epoch will match the rest

  from accelerate.utils import write_basic_config
  accelerate_config_file = os.path.join(repo_dir, "accelerate_config/config.yaml")
  if not os.path.exists(accelerate_config_file):
    write_basic_config(save_location=accelerate_config_file)

  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
  os.environ["BITSANDBYTES_NOWELCOME"] = "1"
  os.environ["SAFETENSORS_FAST_GPU"] = "1"

  global dependencies_installed
  dependencies_installed = True

aiInstalls_start_time = time.perf_counter()
install_dependencies()
aiInstalls_end_time = time.perf_counter()


#######################################################
##### Start - Define Citron's Library Functions
#######################################################

def writeToFile(filename, text):
  ! echo {text} >> {filename}
  #! cat {filename}

def clearFile(filename):
   ! echo "" > {filename}

def writeLineToFile(filename):
  ! echo "==============================" >> {filename}

#######################################################
##### End - Define Citron's Library Functions
#######################################################

#Import colabUtilities
!git clone https://github.com/citronlegacy/kohya-colab.git
# CD into project
%cd kohya-colab
# import modules
import colabUtilities
#return to original directory
%cd ..

print("AI Installation time: " + str(colabUtilities.get_time_hh_mm_ss(aiInstalls_end_time-aiInstalls_start_time)) + " minutes")
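
The `count_images_in_folder` helper defined in this cell can be tried outside Colab. The sketch below repeats the function so that it runs on its own, then checks it against a temporary folder. The file names are made up for illustration, and the error message is shortened.

```python
import tempfile
from pathlib import Path

IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp'}

def count_images_in_folder(folder_path):
    # Same logic as the notebook helper: count files by (case-insensitive) extension
    folder_path = Path(folder_path)
    if not folder_path.is_dir():
        raise ValueError(f"The provided path '{folder_path}' does not exist.")
    return sum(1 for file in folder_path.iterdir() if file.suffix.lower() in IMAGE_EXTENSIONS)

# Build a throwaway dataset: two images and one caption file
demo_dir = Path(tempfile.mkdtemp())
for name in ["miku_01.png", "miku_02.JPG", "miku_01.txt"]:
    (demo_dir / name).touch()

image_count = count_images_in_folder(demo_dir)
print(image_count)  # caption .txt files are not counted
```

Because the check only looks at file extensions, an empty or corrupted image file still counts. It guards against a wrong folder path, not against a bad dataset.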

In [None]:
#@title # Dataset Maker

import os
import time
from IPython import get_ipython
from IPython.display import display, Markdown
import json
from urllib.request import urlopen, Request

start_time = time.perf_counter()

COLAB = True

if COLAB:
  from google.colab.output import clear as clear_output
else:
  from IPython.display import clear_output

#defining variables at the beginning to fix a bug with the log file - lol such programming
remove_tags = "NA - Tagging Skipped"
topTags = "NA - Tagging Skipped"
total_steps = 0 #Defining this variable here so that it's a global variable


#@markdown 💡 Tip from [Citron Legacy](https://civitai.com/user/CitronLegacy/models): Dataset creation is the most important part of Lora training. Take your time and have fun collecting images of something you like.
#@markdown This project has a lot of tips but feel free to ignore them! You are the creator, don't let anything restrict your creativity! 🎉
#@markdown ### 1️⃣ Setup
download_images = True #@param {type:"boolean"}
tag_images = True #@param {type:"boolean"}
create_logs = True #@param {type:"boolean"}

#@markdown Your project name can't contain spaces
project_name = "Hatsune_Miku" #@param {type:"string"}
project_name = project_name.strip()
#@markdown Folder Structure is Organized by project name: MyDrive/lora_training/datasets/project_name
folder_structure = "Organize by category (MyDrive/lora_training/datasets/project_name)"

if not project_name or any(c in project_name for c in " .()\"'\\") or project_name.count("/") > 1:
  # Stop here so that no folders are created for an invalid name
  raise ValueError("Please write a valid project_name.")

project_base = project_name if "/" not in project_name else project_name[:project_name.rfind("/")]
project_subfolder = project_name if "/" not in project_name else project_name[project_name.rfind("/")+1:]


main_dir      = os.path.join(root_dir, "drive/MyDrive/lora_training") if COLAB else root_dir
config_folder = os.path.join(main_dir, "config", project_name)
images_folder = os.path.join(main_dir, "datasets", project_name)

for folder in [main_dir, deps_dir, images_folder, config_folder]:  # avoid shadowing the built-in dir()
  os.makedirs(folder, exist_ok=True)

print(f"✅ Project {project_name} is ready!")


#######################################################
##### STEP - Image Downloading
#######################################################
print("#######################################################")
print("##### STEP - Image Downloading")
print("#######################################################")

if (not download_images):
  print("skipping image download")
  gelbooruSearchQuery = "NA step is skipped" #setting this value because this step is skipped
if (download_images):

  #@markdown ### 2️⃣ Scrape images from Gelbooru

  #@markdown We will grab images from the popular anime gallery [Gelbooru](https://gelbooru.com/). Images are sorted by tags, including poses, scenes, character traits, character names, artists, etc. <p>
  #@markdown * If you instead want to download screencaps of anime episodes, try [this other colab by another person](https://colab.research.google.com/drive/1oBSntB40BKzNmKceXUlkXzujzdQw-Ci7). It's more complicated though.

  #@markdown Up to 1000 images may be downloaded by this step in just one minute. Don't abuse it. <p>

  #@markdown Note: putting a minus sign in front of a tag will exclude images with that tag from the search result. For example, `-pink_hair` will mean you won't get images with the `pink_hair` tag

  #@markdown 💡 Tip from [Citron Legacy](https://civitai.com/user/CitronLegacy/models): If you want to be really detailed you can run the download step several times with different search queries.
  #@markdown For example, get a few images with the `from_behind` or `from_side` tag so that the Lora can learn those angles/concepts.

  tags = "Hatsune_Miku" #@param {type:"string"}
  #@markdown You don't have to use this, but sometimes it is nice to define tags that never change in this separate input field. For example, I always want search results sorted by score and I never want images with certain tags.
  extra_tags = "sort:score solo -animated -crying -1boy -monochrome -duskyer -rakko_(r2) -slave -injury -scared -futanari -geebomb -bokuman -1340smile -osg_pk -2girls -among_us -rope -bdsm -feet " #@param {type:"string"}

  #@markdown The tag `rating:general` is basically the SFW tag on Gelbooru. You can use these checkboxes to either limit results to SFW images or exclude SFW images from results.

  #@markdown 💡 Tip from [Citron Legacy](https://civitai.com/user/CitronLegacy/models): I recommend having at least 75% SFW images if you want to train an outfit. Too many NSFW images may result in a lack of data about clothes.
  apply_tag_rating_general = False #@param {type:"boolean"}
  apply_tag_minus_rating_general = False #@param {type:"boolean"}
  if (apply_tag_rating_general):
    tags = "rating:general " + tags
  if (apply_tag_minus_rating_general):
     tags = "-rating:general " + tags

  tags = tags + " " + extra_tags
  gelbooruSearchQuery = tags
  gelbooruSearchQuery = gelbooruSearchQuery.replace("(", "\\(").replace(")", "\\)")  # escape parentheses so the query is safe to log through the shell

  ##@markdown If an image is bigger than this resolution a smaller version will be downloaded instead.
  max_resolution = 3072 #param {type:"slider", min:1024, max:8196, step:1024}
  ##@markdown Posts with a parent post are often minor variations of the same image.
  include_posts_with_parent = True #param {type:"boolean"}

  # URL-encode the characters that have a special meaning in a query string
  tags = (tags.replace(" ", "+")
              .replace("(", "%28")
              .replace(")", "%29")
              .replace(":", "%3a")
              .replace("&", "%26"))

  url = "https://gelbooru.com/index.php?page=dapi&json=1&s=post&q=index&limit=100&tags={}".format(tags)
  user_agent = "Mozilla/5.0 AppleWebKit/537.36 (KHTML, like Gecko; compatible; Googlebot/2.1; +http://www.google.com/bot.html) Chrome/93.0.4577.83 Safari/537.36"
  limit = 100 # hardcoded by gelbooru
  #@markdown Enter maximum number of images to download from Gelbooru (There is a bug where it sometimes downloads 1 more/less than the number entered)

  #@markdown 💡 Tip from [Citron Legacy](https://civitai.com/user/CitronLegacy/models): You can make a decent Lora with 50 to 100 images. 300+ images is great but can take a really long time to train.
  maxNumberOfImages = 50 #@param {type:"number"}
  total_limit = maxNumberOfImages
  # Testing setting the limit of images at this point in the code. If this works then the line where the url is set above can be deleted
  #urlwithLimit = "https://gelbooru.com/index.php?page=dapi&json=1&s=post&q=index&limit="+str(maxNumberOfImages)+"&tags={}".format(tags)
  #url = urlwithLimit.format(tags)
  #if (maxNumberOfImages < 100):
  #  url = "https://gelbooru.com/index.php?page=dapi&json=1&s=post&q=index&limit="+str(maxNumberOfImages)+"&tags={}".format(tags)

  #url = "https://gelbooru.com/index.php?page=dapi&json=1&s=post&q=index&limit="+str(maxNumberOfImages)+"&tags={}".format(tags)

  print(url)

  supported_types = (".png", ".jpg", ".jpeg")

  """
  def ubuntu_deps():
    print("🏭 Installing dependencies...\n")
    !apt -y install aria2
    return not get_ipython().__dict__['user_ns']['_exit_code']

  if "step2_installed_flag" not in globals():
    if ubuntu_deps():
      #clear_output()
      step2_installed_flag = True
    else:
      print("❌ Error installing dependencies, attempting to continue anyway...")

  """

  def get_json(url):
    print("get_json url = " + url)
    with urlopen(Request(url, headers={"User-Agent": user_agent})) as page:
      return json.load(page)

  def filter_images(data):
    return [p["file_url"] if p["width"]*p["height"] <= max_resolution**2 else p["sample_url"]
            for p in data.get("post", [])  # a page without results may have no "post" key
            if (p["parent_id"] == 0 or include_posts_with_parent)
            and p["file_url"].lower().endswith(supported_types)]

  def download_images():
    count = 0
    if(maxNumberOfImages < 100):
      smallerThan100DownloadURL = "https://gelbooru.com/index.php?page=dapi&json=1&s=post&q=index&limit="+str(maxNumberOfImages+1)+"&tags={}".format(tags)
      data = get_json(smallerThan100DownloadURL)
      count = data["@attributes"]["count"]
    else:
      data = get_json(url)
      count = data["@attributes"]["count"]

    if count == 0:
      print("📷 No results found")
      return

    print(f"🎯 Found {count} results")
    test_url = "https://gelbooru.com/index.php?page=post&s=list&tags={}".format(tags)
    display(Markdown(f"[Click here to open in browser!]({test_url})"))
    print(f"🔽 Will download to {images_folder.replace('/content/drive/', '')}")
    inp = 'yes' #input("❓ Enter the word 'yes' if you want to proceed with the download: ")

    if inp.lower().strip() != 'yes':
      print("❌ Download cancelled")
      return

    print("📩 Grabbing image list...")

    image_urls = set()
    image_urls = image_urls.union(filter_images(data)[:maxNumberOfImages])  # never keep more links than requested
    for i in range(total_limit // limit):
      print("inside loop")
      numberOfImagesDownloadLinksInFile = len(image_urls)
      numberOfDownloadsRemaining = maxNumberOfImages - numberOfImagesDownloadLinksInFile
      #Debugging log - can be deleted
      print (f"i = {i}; image_urls = {len(image_urls)}; total_limit = {total_limit}; limit = {limit}; numberOfImagesDownloadLinksInFile = {numberOfImagesDownloadLinksInFile}; numberOfDownloadsRemaining = {numberOfDownloadsRemaining} " )
      count -= limit
      if count <= 0:
        break
      time.sleep(0.1)

      # Reformat URLs to ensure only the correct amount of images is downloaded
      #Added if-block to ensure that the last set of images to download doesnt go over the maxNumberOfImages
      #Determine how many download links are remaining
      #If there are less than the limit (hardcoded to 100) then only add the right amount of urls to the downloads list
      numberOfDownloadsRemaining = maxNumberOfImages - numberOfImagesDownloadLinksInFile
      if numberOfDownloadsRemaining <= 0:
        break  # already have enough links, no need to request another page
      nextPageResult = filter_images(get_json(url+f"&pid={i+1}"))
      # A page holds at most `limit` links, so slicing only trims the final page
      image_urls = image_urls.union(nextPageResult[:numberOfDownloadsRemaining])

    scrape_file = os.path.join(config_folder, f"scrape_{project_subfolder}.txt")
    with open(scrape_file, "w") as f:
      f.write("\n".join(image_urls))

    print(f"🌐 Saved links to {scrape_file}\n\n🔁 Downloading images...\n")
    old_img_count = len([f for f in os.listdir(images_folder) if f.lower().endswith(supported_types)])

    os.chdir(images_folder)
    !aria2c --console-log-level=warn -c -x 16 -k 1M -s 16 -i {scrape_file}

    new_img_count = len([f for f in os.listdir(images_folder) if f.lower().endswith(supported_types)])
    print(f"\n✅ Downloaded {new_img_count - old_img_count} images.")
    print(f"\n number of images in image_urls: {len(image_urls)} ")

  download_images()
  #clear_output()

#######################################################
##### STEP - Tagging
#######################################################
print("#######################################################")
print("##### STEP - Tagging")
print("#######################################################")

#@markdown ### 3️⃣ Tag your images
#@markdown We will be using AI to automatically tag your images, specifically [Waifu Diffusion](https://huggingface.co/SmilingWolf/wd-v1-4-swinv2-tagger-v2) in the case of anime and [BLIP](https://huggingface.co/spaces/Salesforce/BLIP) in the case of photos.
#@markdown Giving tags/captions to your images allows for much better training. This process should take a couple minutes. <p>

#@markdown ❗ Important: you can choose to not enter anything in this section if you want to train your lora without a trigger

trigger = "Hatsune_Miku" #@param {type:"string"}
global_activation_tag = trigger.strip()
if (not tag_images):
  print("skipping image tagging")
if (tag_images):
  start_time_tagging = time.perf_counter()
  method = "Anime tags"
  #@markdown Absorb tags that represent your Lora. This could be details like eye color or concepts like `glowing`

  #@markdown 💡 Tip from [Citron Legacy](https://civitai.com/user/CitronLegacy/models): If you aren't sure what tags to absorb, run this to see what tags are most common. You can check the logs and then rerun this with more absorbed tags

  #@markdown 💡 Tip from [Citron Legacy](https://civitai.com/user/CitronLegacy/models): For Character Loras, I recommend absorbing `1girl` or `1boy`, `solo`, eye color, hair color, and hair length/style (`short_hair`, `long_hair`, `twintails`, etc)

  #@markdown 💡 Tip from [Citron Legacy](https://civitai.com/user/CitronLegacy/models): For Character Loras, I recommend not absorbing details about an outfit. If you absorb the outfit into the trigger the Lora will not be flexible enough to change the outfit

  absorbed_these_tags_into_trigger = "1girl, solo, " #@param {type:"string"}
  #@markdown Change the tag threshold if you are not getting enough tags
  tag_threshold = 0.4 #@param {type:"number"}
  #@markdown These tags will be ignored and thus not added to the captions
  blacklist_tags = "bangs, breasts, multicolored hair, two-tone hair, gradient hair, virtual youtuber, official alternate costume, official alternate hairstyle, official alternate hair length, alternate costume, alternate hairstyle, alternate hair length, alternate hair color" #@param {type:"string"}
  blacklist_tags_2 = ""

  numberOfTopTagsToAbsorb = 0

  extraAbsorbedTags = absorbed_these_tags_into_trigger
  caption_min = 15
  caption_max = 75

  %env PYTHONPATH=/env/python
  os.chdir(root_dir)
  kohya = "/content/kohya-trainer"
  if not os.path.exists(kohya):
    !git clone https://github.com/kohya-ss/sd-scripts {kohya}
    os.chdir(kohya)
    !git reset --hard 5050971ac687dca70ba0486a583d283e8ae324e2
    os.chdir(root_dir)

  if "tags" in method:
    """
    if "step4a_installed_flag" not in globals():
      print("\n🏭 Installing dependencies...\n")
      #!pip -q install tensorflow==2.12.0 huggingface-hub==0.12.0 accelerate==0.15.0 transformers==4.26.0 diffusers[torch]==0.10.2 einops==0.6.0 safetensors==0.2.6 torchvision albumentations
      !pip -q install -U tensorflow huggingface-hub==0.12.0 accelerate==0.15.0 transformers==4.26.0 diffusers[torch]==0.10.2 einops==0.6.0 safetensors==0.2.6 torchvision albumentations
      if not get_ipython().__dict__['user_ns']['_exit_code']:
        clear_output()
        step4a_installed_flag = True
      else:
        print("❌ Error installing dependencies, trying to continue anyway...")
    """
    print("\n🚶‍♂️ Launching program...\n")

    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    %env PYTHONPATH={kohya}
    !python {kohya}/finetune/tag_images_by_wd14_tagger.py \
      {images_folder} \
      --repo_id=SmilingWolf/wd-v1-4-swinv2-tagger-v2 \
      --model_dir={root_dir} \
      --thresh={tag_threshold} \
      --batch_size=8 \
      --caption_extension=.txt \
      --force_download

    if not get_ipython().__dict__['user_ns']['_exit_code']:
      print("removing underscores and blacklist...")
      blacklisted_tags = [t.strip() for t in blacklist_tags.split(",")]
      print("Processing 2nd Blacklist...")
      blacklist_tags_2 = [t.strip() for t in blacklist_tags_2.split(",")]
      combined_Blacklist = blacklisted_tags + blacklist_tags_2
      print("Processing extraAbsorbedTags...")
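      extraAbsorbedTags = [t.strip() for t in extraAbsorbedTags.split(",") if t.strip()]
      # Captions are rewritten below with spaces instead of underscores, so normalize absorbed tags the same way
      extraAbsorbedTags = [t.replace("_", " ") if len(t) > 3 else t for t in extraAbsorbedTags]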
      from collections import Counter
      top_tags = Counter()
      for txt in [f for f in os.listdir(images_folder) if f.lower().endswith(".txt")]:
        with open(os.path.join(images_folder, txt), 'r') as f:
          tags = [t.strip() for t in f.read().split(",")]
          tags = [t.replace("_", " ") if len(t) > 3 else t for t in tags]
          tags = [t for t in tags if t not in blacklisted_tags]
        top_tags.update(tags)
        with open(os.path.join(images_folder, txt), 'w') as f:
          f.write(", ".join(tags))


      %env PYTHONPATH=/env/python
      #clear_output()
      ### Original tagging output message
      #print(f"📊 Tagging complete. Here are the top 50 tags in your dataset:")
      #print("\n".join(f"{k}," for k, v in top_tags.most_common(50)))

      ### Top 10 Tags Code
      outputTags = [k for k, v in top_tags.most_common(50) if k not in blacklist_tags_2]
      top10Tags = outputTags[:numberOfTopTagsToAbsorb]
      tagsToAbsorb = top10Tags + extraAbsorbedTags
      ### Debugging line
      #print(f"-----\ntop10Tags = {top10Tags}")
      print(f"-----\ntagsToAbsorb = {tagsToAbsorb}")
      remove_tags = (" ".join(f"{y}," for y in tagsToAbsorb))
      print(f"-----\nremove_tags = {remove_tags}\n")

      ### New Tagging to only show non-blacklist
      print("Curated tagging complete. Here are the top 50 tags after purging items from blacklist 2:")
      print("\n".join(f"{k}," for k, v in top_tags.most_common(50) if k not in blacklist_tags_2))



  else: # Photos
    if "step4b_installed_flag" not in globals():
      print("\n🏭 Installing dependencies...\n")
      #!pip -q install timm==0.6.12 fairscale==0.4.13 transformers==4.26.0 requests==2.28.2 accelerate==0.15.0 diffusers[torch]==0.10.2 einops==0.6.0 safetensors==0.2.6
      !pip -q install -U timm==0.6.12 fairscale==0.4.13 transformers==4.26.0 requests==2.28.2 accelerate==0.15.0 diffusers[torch]==0.10.2 einops==0.6.0 safetensors==0.2.6
      if not get_ipython().__dict__['user_ns']['_exit_code']:
        clear_output()
        step4b_installed_flag = True
      else:
        print("❌ Error installing dependencies, trying to continue anyway...")

    print("\n🚶‍♂️ Launching program...\n")

    os.chdir(kohya)
    %env PYTHONPATH={kohya}
    !python {kohya}/finetune/make_captions.py \
      {images_folder} \
      --beam_search \
      --max_data_loader_n_workers=2 \
      --batch_size=8 \
      --min_length={caption_min} \
      --max_length={caption_max} \
      --caption_extension=.txt

    if not get_ipython().__dict__['user_ns']['_exit_code']:
      import random
      captions = [f for f in os.listdir(images_folder) if f.lower().endswith(".txt")]
      sample = []
      for txt in random.sample(captions, min(10, len(captions))):
        with open(os.path.join(images_folder, txt), 'r') as f:
          sample.append(f.read())

      os.chdir(root_dir)
      %env PYTHONPATH=/env/python
      clear_output()
      print(f"📊 Captioning complete. Here are {len(sample)} example captions from your dataset:")
      print("".join(sample))

  # Report the tagging time for both the anime-tag and the photo-caption branch
  end_time_tagging = time.perf_counter()
  time_total_tagging = end_time_tagging-start_time_tagging
  print(f"Tagging took {(time_total_tagging/60):0.1f} minutes ({time_total_tagging:0.1f} seconds)")

def print_a_line():
  print("========================================================================================================================")

def print_important_log (logMessage):
  print("========================================================================================================================")
  print(logMessage)
  print("========================================================================================================================")


#######################################################
##### STEP - Curate tags!
#######################################################
print("#######################################################")
print("##### STEP - Curate tags!")
print("#######################################################")

search_tags = ""
replace_with = ""
search_mode = "OR"
new_becomes_activation_tag = False
sort_alphabetically = False
remove_duplicates = False

def split_tags(tagstr):
  return [s.strip() for s in tagstr.split(",") if s.strip()]

activation_tag_list = split_tags(global_activation_tag)
remove_tags_list = split_tags(remove_tags)
search_tags_list = split_tags(search_tags)
replace_with_list = split_tags(replace_with)
replace_new_list = [t for t in replace_with_list if t not in search_tags_list]

replace_with_list = [t for t in replace_with_list if t not in replace_new_list]
replace_new_list.reverse()
activation_tag_list.reverse()

remove_count = 0
replace_count = 0

for txt in [f for f in os.listdir(images_folder) if f.lower().endswith(".txt")]:

  with open(os.path.join(images_folder, txt), 'r') as f:
    tags = [s.strip() for s in f.read().split(",")]

  if remove_duplicates:
    tags = list(set(tags))
  if sort_alphabetically:
    tags.sort()

  for rem in remove_tags_list:
    if rem in tags:
      remove_count += 1
      tags.remove(rem)

  if "AND" in search_mode and all(r in tags for r in search_tags_list) \
      or "OR" in search_mode and any(r in tags for r in search_tags_list):
    replace_count += 1
    for rem in search_tags_list:
      if rem in tags:
        tags.remove(rem)
    for add in replace_with_list:
      if add not in tags:
        tags.append(add)
    for new in replace_new_list:
      if new_becomes_activation_tag:
        if new in tags:
          tags.remove(new)
        tags.insert(0, new)
      else:
        if new not in tags:
          tags.append(new)

  for act in activation_tag_list:
    if act in tags:
      tags.remove(act)
    tags.insert(0, act)

  with open(os.path.join(images_folder, txt), 'w') as f:
    f.write(", ".join(tags))

if global_activation_tag:
  print(f"\n📎 Applied new activation tag(s): {', '.join(activation_tag_list)}")
if remove_tags:
  print(f"\n🚮 Removed {remove_count} tags.")
if search_tags:
  print(f"\n💫 Replaced in {replace_count} files.")


end_time = time.perf_counter()
time_total = end_time-start_time
print(f"\n✅ Done! Process took {(time_total/60):0.1f} minutes ({time_total:0.1f} seconds)")


#######################################################
##### Create log file - Tagging
#######################################################


#Create log file
directory = main_dir +"/log"
dateTimeFormatedForFilename = colabUtilities.getDateTimeFormatedForFilename()
logFileName = directory + "/" + project_name + "_" + dateTimeFormatedForFilename + ".log"



#Write all the top tags to file
if (tag_images):

  #Write the top 50 tags found in the images to file
  linuxSafeFormattedTop50Tags = colabUtilities.reformatToSafeString(str(" ".join(f"{logTags_i}," for logTags_i, v in top_tags.most_common(50))))
  topTags = linuxSafeFormattedTop50Tags

fileName = project_name + "_" + dateTimeFormatedForFilename + ".log"
trigger =  str(" ".join(f"{trigger_i}," for trigger_i in activation_tag_list))
removedTags = remove_tags
taggingTime = time_total


if (create_logs):
  colabUtilities.writeLogHeaderToFile(directory, fileName, project_name)
  colabUtilities.writeLogForTagging(directory,
                          fileName,
                          trigger,
                          gelbooruSearchQuery,
                          removedTags,
                          topTags,
                          taggingTime)
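
The curation step in this cell boils down to a small list transformation per caption file: drop the absorbed tags, then move the trigger to the front. Below is a minimal sketch on an in-memory caption string instead of files. The function name `curate_caption` and the sample caption are illustrative assumptions.

```python
def split_tags(tagstr):
    # Split a comma-separated tag string and drop empty entries
    return [s.strip() for s in tagstr.split(",") if s.strip()]

def curate_caption(caption, remove_tags, activation_tag):
    """Remove absorbed tags, then place the activation tag(s) first."""
    tags = split_tags(caption)
    for rem in split_tags(remove_tags):
        if rem in tags:
            tags.remove(rem)
    # Insert in reverse so several activation tags keep their given order
    for act in reversed(split_tags(activation_tag)):
        if act in tags:
            tags.remove(act)
        tags.insert(0, act)
    return ", ".join(tags)

caption = "1girl, solo, long hair, twintails, smile"
curated = curate_caption(caption, "1girl, solo,", "Hatsune_Miku")
print(curated)  # Hatsune_Miku, long hair, twintails, smile
```

Tags that are removed here are "absorbed" in the sense that the Lora learns them as part of the trigger rather than as separately promptable tags.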





In [None]:
#@title # Lora Trainer

#@markdown ### Define your dataset location
#@markdown Your project name will be the same as the folder containing your images. Spaces aren't allowed.
project_name = "Hatsune_Miku" #@param {type:"string"}
#@markdown Is your dataset in Google Drive or Google Colab? (This hasn't been fully tested with datasets in Colab)
environment_location = "Google Drive (/content/drive/MyDrive/)" #@param ["Google Drive (/content/drive/MyDrive/)", "Google Colab (/content/)"]


folder_structure = "Organize by category (MyDrive/lora_training/datasets/project_name)"

#@markdown Decide the model that will be downloaded and used for training. You can also choose your own by pasting its download link.

#@markdown 💡 Tip from [Citron Legacy](https://civitai.com/user/CitronLegacy/models): Use AnyLora if you want to train an art style.
training_model = "Anime (animefull-final-pruned-fp16.safetensors)" #@param ["Anime (animefull-final-pruned-fp16.safetensors)", "AnyLora (AnyLoRA_noVae_fp16-pruned.ckpt)", "Stable Diffusion (sd-v1-5-pruned-noema-fp16.safetensors)"]
optional_custom_training_model_url = "" #@param {type:"string"}
custom_model_is_based_on_sd2 = False #@param {type:"boolean"}

if "Drive" in environment_location:
    environment_path = "/content/drive/MyDrive/"

elif "Colab" in environment_location:
    environment_path = "/content/"  # absolute path, so later os.chdir calls can't break it

main_dir      = os.path.join(environment_path, "lora_training")
images_folder = os.path.join(main_dir, "datasets", project_name)
output_folder = os.path.join(main_dir, "output", project_name)
config_folder = os.path.join(main_dir, "config", project_name)
log_folder    = os.path.join(main_dir, "log")

assert count_images_in_folder(images_folder) > 0, f"Error: No images found in the specified folder: {images_folder}."

global total_steps
total_steps = ""

# Carries the model URL from the previous execution. This must run before
# model_url is reassigned below, otherwise a changed model is never detected.
if "model_url" in globals():
  old_model_url = model_url
else:
  old_model_url = None

if optional_custom_training_model_url:
  model_url = optional_custom_training_model_url
elif "AnyLora" in training_model:
  model_url = "https://huggingface.co/Lykon/AnyLoRA/resolve/main/AnyLoRA_noVae_fp16-pruned.ckpt"
elif "Anime" in training_model:
  model_url = "https://huggingface.co/hollowstrawberry/stable-diffusion-guide/resolve/main/models/animefull-final-pruned-fp16.safetensors"
else:
  model_url = "https://huggingface.co/hollowstrawberry/stable-diffusion-guide/resolve/main/models/sd-v1-5-pruned-noema-fp16.safetensors"

if "AnyLora" in training_model:
  training_modelLogName = "AnyLoRA_noVae_fp16-pruned"
elif "Anime" in training_model:
  training_modelLogName = "Animefull-final-pruned-fp16"
elif "Stable Diffusion" in training_model:
  training_modelLogName = "sd-v1-5-pruned-noema-fp16"
else:
  training_modelLogName = "other"

# These carry information from past executions
if "dependencies_installed" not in globals():
  dependencies_installed = False
if "model_file" not in globals():
  model_file = None
if "dependencies_installed" not in globals():
  dependencies_installed = False
if "model_file" not in globals():
  model_file = None

# These may be set by other cells, some are legacy
if "custom_dataset" not in globals():
  custom_dataset = None
if "override_dataset_config_file" not in globals():
  override_dataset_config_file = None
if "override_config_file" not in globals():
  override_config_file = None
if "optimizer" not in globals():
  optimizer = "AdamW8bit"
if "optimizer_args" not in globals():
  optimizer_args = None
if "continue_from_lora" not in globals():
  continue_from_lora = ""
if "weighted_captions" not in globals():
  weighted_captions = False
if "adjust_tags" not in globals():
  adjust_tags = False
if "keep_tokens_weight" not in globals():
  keep_tokens_weight = 1.0

COLAB = True # low ram
XFORMERS = True
COMMIT = "9a67e0df390033a89f17e70df5131393692c2a55"
BETTER_EPOCH_NAMES = True
LOAD_TRUNCATED_IMAGES = True

#@markdown ### ▶️ Processing
#@markdown Resolution of 512 is standard for Stable Diffusion 1.5. Higher resolution training is much slower but can lead to better details. <p>
#@markdown Images will be automatically scaled while training to produce the best results, so you don't need to crop or resize anything yourself.
resolution = 512 #@param {type:"slider", min:512, max:1024, step:128}
#@markdown This option will train your images both normally and flipped, for no extra cost, to learn more from them. Turn it on especially if you have fewer than 20 images. <p>
#@markdown **Turn it off if you care about asymmetrical elements in your Lora**.
flip_aug = False #@param {type:"boolean"}
caption_extension = ".txt"
shuffle_tags = True
shuffle_caption = shuffle_tags
activation_tags = "1"
keep_tokens = int(activation_tags)

#@markdown ### ▶️ Steps - Your images will repeat this number of times during training.

#@markdown 💡 Tip from [Hollowstrawberry](https://github.com/hollowstrawberry/kohya-colab): I recommend that your images multiplied by their repeats is between 200 and 400.

#@markdown 💡 Tip from [Citron Legacy](https://civitai.com/user/CitronLegacy/models): Use 10 Repeats (or less if you have too many steps and training takes longer than 2 hours.)


num_repeats = 10 #@param {type:"number"}
#@markdown Choose how long you want to train for. A good starting point is around 10 epochs or around 2000 steps.<p>
#@markdown One epoch is a number of steps equal to: your number of images multiplied by their repeats, divided by batch size. <p>

#@markdown 💡 Tip from [Citron Legacy](https://civitai.com/user/CitronLegacy/models): Always use 10 Epochs until you know what you are doing.

preferred_unit = "Epochs" #@param ["Epochs", "Steps"]
how_many = 10 #@param {type:"number"}
max_train_epochs = how_many if preferred_unit == "Epochs" else None
max_train_steps = how_many if preferred_unit == "Steps" else None
#@markdown Saving more epochs will let you compare your Lora's progress better.
save_every_n_epochs = 5 #@param {type:"number"}
keep_only_last_n_epochs = 10 #@param {type:"number"}
if not save_every_n_epochs:
  save_every_n_epochs = max_train_epochs
if not keep_only_last_n_epochs:
  keep_only_last_n_epochs = max_train_epochs
#@markdown Increasing the batch size makes training faster, but may make learning worse. Recommended 2 or 3.
train_batch_size = 2 #@param {type:"slider", min:1, max:8, step:1}

unet_lr = 5e-4
text_encoder_lr = 1e-4
lr_scheduler = "cosine_with_restarts"
lr_scheduler_number = 3
lr_scheduler_num_cycles = lr_scheduler_number if lr_scheduler == "cosine_with_restarts" else 0
lr_scheduler_power = lr_scheduler_number if lr_scheduler == "polynomial" else 0
lr_warmup_ratio = 0.05
lr_warmup_steps = 0
min_snr_gamma = True
min_snr_gamma_value = 5.0 if min_snr_gamma else None

#@markdown Dim: A higher dim produces a larger Lora file that can hold more information, but more isn't always better.
lora_type = "LoRA"
dim_to_use = "Dim=32; Alpha=16" #@param ["Dim=16; Alpha=8", "Dim=32; Alpha=16", "Dim=64; Alpha=32"]
dim_to_use_text = ""


network_dim = 32
network_alpha = 16

if (dim_to_use == "Dim=16; Alpha=8"):
  network_dim = 16
  network_alpha = 8
  dim_to_use_text = "Dim 16 and Alpha 8"
elif (dim_to_use == "Dim=32; Alpha=16"):
  network_dim = 32
  network_alpha = 16
  dim_to_use_text = "Dim 32 and Alpha 16"
elif (dim_to_use == "Dim=64; Alpha=32"):
  network_dim = 64
  network_alpha = 32
  dim_to_use_text = "Dim 64 and Alpha 32"

conv_dim = 8
conv_alpha = 4

network_module = "networks.lora"
network_args = None
if lora_type.lower() == "locon":
  network_args = [f"conv_dim={conv_dim}", f"conv_alpha={conv_alpha}"]



# 👩‍💻 Cool code goes here

if optimizer.lower() == "prodigy" or "dadapt" in optimizer.lower():
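  # This flag is not defined in this cell; fall back to True when no other cell has set it.
  if globals().get("override_values_for_dadapt_and_prodigy", True):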
    unet_lr = 0.5
    text_encoder_lr = 0.5
    lr_scheduler = "constant_with_warmup"
    lr_warmup_ratio = 0.05
    network_alpha = network_dim

  if not optimizer_args:
    optimizer_args = ["decouple=True","weight_decay=0.01","betas=[0.9,0.999]"]
    if optimizer.lower() == "prodigy":  # case-insensitive, matching the check above
      optimizer_args.extend(["d_coef=2","use_bias_correction=True"])
      if lr_warmup_ratio > 0:
        optimizer_args.append("safeguard_warmup=True")
      else:
        optimizer_args.append("safeguard_warmup=False")

root_dir = "/content"
deps_dir = os.path.join(root_dir, "deps")
repo_dir = os.path.join(root_dir, "kohya-trainer")




config_file = os.path.join(config_folder, "training_config.toml")
dataset_config_file = os.path.join(config_folder, "dataset_config.toml")
accelerate_config_file = os.path.join(repo_dir, "accelerate_config/config.yaml")


def validate_dataset():
  global lr_warmup_steps, lr_warmup_ratio, caption_extension, keep_tokens, keep_tokens_weight, weighted_captions, adjust_tags
  supported_types = (".png", ".jpg", ".jpeg", ".webp", ".bmp")

  print("\n💿 Checking dataset...")
  if not project_name.strip() or any(c in project_name for c in " .()\"'\\/"):
    print("💥 Error: Please choose a valid project name.")
    return

  if custom_dataset:
    try:
      datconf = toml.loads(custom_dataset)
      datasets = [d for d in datconf["datasets"][0]["subsets"]]
    except Exception:  # invalid TOML or missing keys; a bare except would also swallow KeyboardInterrupt
      print("💥 Error: Your custom dataset is invalid or contains an error! Please check the original template.")
      return
    reg = [d.get("image_dir") for d in datasets if d.get("is_reg", False)]
    datasets_dict = {d["image_dir"]: d["num_repeats"] for d in datasets}
    folders = datasets_dict.keys()
    files = [f for folder in folders for f in os.listdir(folder)]
    images_repeats = {folder: (len([f for f in os.listdir(folder) if f.lower().endswith(supported_types)]), datasets_dict[folder]) for folder in folders}
  else:
    reg = []
    folders = [images_folder]
    files = os.listdir(images_folder)
    images_repeats = {images_folder: (len([f for f in files if f.lower().endswith(supported_types)]), num_repeats)}

  for folder in folders:
    if not os.path.exists(folder):
      print(f"💥 Error: The folder {folder.replace('/content/drive/', '')} doesn't exist.")
      return
  for folder, (img, rep) in images_repeats.items():
    if not img:
      print(f"💥 Error: Your {folder.replace('/content/drive/', '')} folder is empty.")
      return
  for f in files:
    if not f.lower().endswith(".txt") and not f.lower().endswith(supported_types):
      print(f"💥 Error: Invalid file in dataset: \"{f}\". Aborting.")
      return

  if not [txt for txt in files if txt.lower().endswith(".txt")]:
    caption_extension = ""
  if continue_from_lora and not (continue_from_lora.endswith(".safetensors") and os.path.exists(continue_from_lora)):
    print(f"💥 Error: Invalid path to existing Lora. Example: /content/drive/MyDrive/Loras/example.safetensors")
    return

  pre_steps_per_epoch = sum(img*rep for (img, rep) in images_repeats.values())
  steps_per_epoch = pre_steps_per_epoch/train_batch_size
  global total_steps
  total_steps = max_train_steps or int(max_train_epochs*steps_per_epoch)
  estimated_epochs = int(total_steps/steps_per_epoch)
  lr_warmup_steps = int(total_steps*lr_warmup_ratio)

  for folder, (img, rep) in images_repeats.items():
    print("📁"+folder.replace("/content/drive/", "") + (" (Regularization)" if folder in reg else ""))
    print(f"📈 Found {img} images with {rep} repeats, equaling {img*rep} steps.")
  print(f"📉 Divide {pre_steps_per_epoch} steps by {train_batch_size} batch size to get {steps_per_epoch} steps per epoch.")
  if max_train_epochs:
    print(f"🔮 There will be {max_train_epochs} epochs, for around {total_steps} total training steps.")
  else:
    print(f"🔮 There will be {total_steps} steps, divided into {estimated_epochs} epochs and then some.")

  if total_steps > 20000:
    print("💥 Error: Your total steps are too high. You probably made a mistake. Aborting...")
    return

  if adjust_tags:
    print(f"\n📎 Weighted tags: {'ON' if weighted_captions else 'OFF'}")
    if weighted_captions:
      print(f"📎 Will use {keep_tokens_weight} weight on {keep_tokens} activation tag(s)")
    print("📎 Adjusting tags...")
    adjust_weighted_tags(folders, keep_tokens, keep_tokens_weight, weighted_captions)

  return True

def adjust_weighted_tags(folders, keep_tokens: int, keep_tokens_weight: float, weighted_captions: bool):
  weighted_tag = re.compile(r"\((.+?):[.\d]+\)(,|$)")
  for folder in folders:
    for txt in [f for f in os.listdir(folder) if f.lower().endswith(".txt")]:
      with open(os.path.join(folder, txt), 'r') as f:
        content = f.read()
      # reset previous changes
      content = content.replace('\\', '')
      content = weighted_tag.sub(r'\1\2', content)
      if weighted_captions:
        # re-apply changes
        content = content.replace(r'(', r'\(').replace(r')', r'\)').replace(r':', r'\:')
        if keep_tokens_weight > 1:
          tags = [s.strip() for s in content.split(",")]
          for i in range(min(keep_tokens, len(tags))):
            tags[i] = f'({tags[i]}:{keep_tokens_weight})'
          content = ", ".join(tags)
      with open(os.path.join(folder, txt), 'w') as f:
        f.write(content)

def create_config():
  global dataset_config_file, config_file, model_file

  if override_config_file:
    config_file = override_config_file
    print(f"\n⭕ Using custom config file {config_file}")
  else:
    config_dict = {
      "additional_network_arguments": {
        "unet_lr": unet_lr,
        "text_encoder_lr": text_encoder_lr,
        "network_dim": network_dim,
        "network_alpha": network_alpha,
        "network_module": network_module,
        "network_args": network_args,
        "network_train_unet_only": True if text_encoder_lr == 0 else None,
        "network_weights": continue_from_lora if continue_from_lora else None
      },
      "optimizer_arguments": {
        "learning_rate": unet_lr,
        "lr_scheduler": lr_scheduler,
        "lr_scheduler_num_cycles": lr_scheduler_num_cycles if lr_scheduler == "cosine_with_restarts" else None,
        "lr_scheduler_power": lr_scheduler_power if lr_scheduler == "polynomial" else None,
        "lr_warmup_steps": lr_warmup_steps if lr_scheduler != "constant" else None,
        "optimizer_type": optimizer,
        "optimizer_args": optimizer_args if optimizer_args else None,
      },
      "training_arguments": {
        "max_train_steps": max_train_steps,
        "max_train_epochs": max_train_epochs,
        "save_every_n_epochs": save_every_n_epochs,
        "save_last_n_epochs": keep_only_last_n_epochs,
        "train_batch_size": train_batch_size,
        "noise_offset": None,
        "clip_skip": 2,
        "min_snr_gamma": min_snr_gamma_value,
        "weighted_captions": weighted_captions,
        "seed": 42,
        "max_token_length": 225,
        "xformers": XFORMERS,
        "lowram": COLAB,
        "max_data_loader_n_workers": 8,
        "persistent_data_loader_workers": True,
        "save_precision": "fp16",
        "mixed_precision": "fp16",
        "output_dir": output_folder,
        "logging_dir": log_folder,
        "output_name": project_name,
        "log_prefix": project_name,
      },
      "model_arguments": {
        "pretrained_model_name_or_path": model_file,
        "v2": custom_model_is_based_on_sd2,
        "v_parameterization": True if custom_model_is_based_on_sd2 else None,
      },
      "saving_arguments": {
        "save_model_as": "safetensors",
      },
      "dreambooth_arguments": {
        "prior_loss_weight": 1.0,
      },
      "dataset_arguments": {
        "cache_latents": True,
      },
    }

    for key in config_dict:
      if isinstance(config_dict[key], dict):
        config_dict[key] = {k: v for k, v in config_dict[key].items() if v is not None}

    with open(config_file, "w") as f:
      f.write(toml.dumps(config_dict))
    print(f"\n📄 Config saved to {config_file}")

  if override_dataset_config_file:
    dataset_config_file = override_dataset_config_file
    print(f"⭕ Using custom dataset config file {dataset_config_file}")
  else:
    dataset_config_dict = {
      "general": {
        "resolution": resolution,
        "shuffle_caption": shuffle_caption,
        "keep_tokens": keep_tokens,
        "flip_aug": flip_aug,
        "caption_extension": caption_extension,
        "enable_bucket": True,
        "bucket_reso_steps": 64,
        "bucket_no_upscale": False,
        "min_bucket_reso": 320 if resolution > 640 else 256,
        "max_bucket_reso": 1280 if resolution > 640 else 1024,
      },
      "datasets": toml.loads(custom_dataset)["datasets"] if custom_dataset else [
        {
          "subsets": [
            {
              "num_repeats": num_repeats,
              "image_dir": images_folder,
              "class_tokens": None if caption_extension else project_name
            }
          ]
        }
      ]
    }

    for key in dataset_config_dict:
      if isinstance(dataset_config_dict[key], dict):
        dataset_config_dict[key] = {k: v for k, v in dataset_config_dict[key].items() if v is not None}

    with open(dataset_config_file, "w") as f:
      f.write(toml.dumps(dataset_config_dict))
    print(f"📄 Dataset config saved to {dataset_config_file}")

def download_model():
  global old_model_url, model_url, model_file
  real_model_url = model_url.strip()

  if real_model_url.lower().endswith((".ckpt", ".safetensors")):
    model_file = f"/content{real_model_url[real_model_url.rfind('/'):]}"
  else:
    model_file = "/content/downloaded_model.safetensors"
    if os.path.exists(model_file):
      !rm "{model_file}"

  # Match against the stripped URL and only rewrite the "/blob/" path segment.
  if re.search(r"(?:https?://)?(?:www\.)?huggingface\.co/[^/]+/[^/]+/blob/", real_model_url):
    real_model_url = real_model_url.replace("/blob/", "/resolve/", 1)
  elif m := re.search(r"(?:https?://)?(?:www\.)?civitai\.com/models/([0-9]+)(/[A-Za-z0-9_-]+)?", real_model_url):
    if m.group(2):
      model_file = f"/content{m.group(2)}.safetensors"
    if m := re.search(r"modelVersionId=([0-9]+)", model_url):
      real_model_url = f"https://civitai.com/api/download/models/{m.group(1)}"
    else:
      raise ValueError("optional_custom_training_model_url contains a civitai link, but the link doesn't include a modelVersionId. You can also right click the download button to copy the direct download link.")

  !aria2c "{real_model_url}" --console-log-level=warn -c -s 16 -x 16 -k 10M -d / -o "{model_file}"

  if model_file.lower().endswith(".safetensors"):
    from safetensors.torch import load_file as load_safetensors
    try:
      test = load_safetensors(model_file)
      del test
    except Exception as e:
      #if "HeaderTooLarge" in str(e):
      new_model_file = os.path.splitext(model_file)[0]+".ckpt"
      !mv "{model_file}" "{new_model_file}"
      model_file = new_model_file
      print(f"Renamed model to {model_file}")

  if model_file.lower().endswith(".ckpt"):
    from torch import load as load_ckpt
    try:
      test = load_ckpt(model_file)
      del test
    except Exception as e:
      return False

  return True

def main():
  global dependencies_installed

  if COLAB and not os.path.exists('/content/drive'):
    from google.colab import drive
    print("📂 Connecting to Google Drive...")
    drive.mount('/content/drive')

  for dir in (main_dir, deps_dir, repo_dir, log_folder, images_folder, output_folder, config_folder):
    os.makedirs(dir, exist_ok=True)

  if not validate_dataset():
    return

  if not dependencies_installed:
    print("\n🏭 Installing dependencies...\n")
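    # Local import: the rest of this cell uses `time` as a module (time.perf_counter()),
    # so calling time() directly would fail.
    from time import perf_counter
    t0 = perf_counter()
    install_dependencies()
    t1 = perf_counter()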
    dependencies_installed = True
    print(f"\n✅ Installation finished in {int(t1-t0)} seconds.")
  else:
    print("\n✅ Dependencies already installed.")

  if old_model_url != model_url or not model_file or not os.path.exists(model_file):
    print("\n🔄 Downloading model...")
    if not download_model():
      print("\n💥 Error: The model you selected is invalid or corrupted, or couldn't be downloaded. You can use a civitai or huggingface link, or any direct download link.")
      return
    print()
  else:
    print("\n🔄 Model already downloaded.\n")

  create_config()

  print("\n⭐ Starting trainer...\n")
  os.chdir(repo_dir)

  !accelerate launch --config_file={accelerate_config_file} --num_cpu_threads_per_process=1 train_network.py --dataset_config={dataset_config_file} --config_file={config_file}

  if not get_ipython().__dict__['user_ns']['_exit_code']:
    display(Markdown("[Download your Lora from Google Drive](https://drive.google.com/drive/my-drive)\n"
                     "There will be several files, you should try the latest version (the file with the largest number next to it)"))

start_time_make_lora = time.perf_counter()
main()

end_time_make_lora = time.perf_counter()
time_total_make_lora = end_time_make_lora - start_time_make_lora
time_total_timedelta = timedelta(seconds=time_total_make_lora)

# Extract hours, minutes, and seconds from the timedelta
hours, remainder = divmod(time_total_timedelta.seconds, 3600)
minutes, seconds = divmod(remainder, 60)

# Format the time
creation_time = f"{hours:02d}:{minutes:02d}:{seconds:02d}"

display(Markdown(f"### ✅ [{project_name}] is Done! Lora Creation Process took: {creation_time}"))

#Create Log file
def create_log_text(project_name, log_file_path, image_count, model_name, flip_aug, num_repeats, unit, num_epochs_or_steps, batch_size, total_steps, resolution, network_dim, creation_time):
    log_data = [
        f"Trained on: {image_count} images",
        f"Training Model: {model_name}",
        f"flip_aug: {flip_aug}",
        f"Num of Repeats: {num_repeats}",
        f"Unit is Epochs or Steps: {unit}",
        f"Number of Epochs or Steps: {num_epochs_or_steps}",
        f"Training Batch Size: {batch_size}",
        f"Total Steps: {total_steps}",
        f"Resolution: {resolution}",
        f"Network Dim: {network_dim}",
        f"Lora Creation Process took: {creation_time}",
    ]

    # Create a unique log filename based on the current timestamp
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_filename = f"{project_name}_{timestamp}.txt"
    log_path_and_name = os.path.join(log_file_path, log_filename)
    # Write the data to the text file
    with open(log_path_and_name, "w") as text_file:
        text_file.write("\n".join(log_data))

    return log_path_and_name


image_count = count_images_in_folder(images_folder)
unit = preferred_unit
num_epochs_or_steps = how_many


log_filename = create_log_text(project_name, log_folder, image_count, training_modelLogName, flip_aug, num_repeats, unit, num_epochs_or_steps, train_batch_size, total_steps, resolution, dim_to_use_text, creation_time)
print(f"Log file created: {log_filename}")


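The step estimate printed by the trainer cell can be reproduced on its own, which helps when choosing repeats and epochs before starting a run. The sketch below assumes the same arithmetic as `validate_dataset` (images × repeats ÷ batch size per epoch, warmup as a ratio of total steps). The function name `estimate_steps` is hypothetical and not part of the notebook.

```python
def estimate_steps(num_images, num_repeats, batch_size, epochs, warmup_ratio=0.05):
    """Hypothetical helper mirroring the step arithmetic of the trainer cell."""
    # One epoch sees every image `num_repeats` times, grouped into batches.
    steps_per_epoch = num_images * num_repeats / batch_size
    total_steps = int(epochs * steps_per_epoch)
    # Warmup is a fraction of the total step count.
    warmup_steps = int(total_steps * warmup_ratio)
    return steps_per_epoch, total_steps, warmup_steps


steps_per_epoch, total_steps, warmup_steps = estimate_steps(40, 10, 2, 10)
print(f"{steps_per_epoch} steps per epoch, {total_steps} total steps, {warmup_steps} warmup steps")
```

With 40 images, 10 repeats, batch size 2 and 10 epochs this lands at the "around 2000 steps" starting point mentioned in the cell's notes.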


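The link handling inside `download_model` can also be looked at in isolation. This is a minimal sketch, assuming the same two rules as the cell above: Hugging Face `blob` links become `resolve` links, and Civitai model pages need a `modelVersionId` to build a direct download link. The function name `resolve_model_url` is hypothetical, and the sketch performs no download.

```python
import re


def resolve_model_url(url):
    """Hypothetical helper: turn a model page link into a direct-download link."""
    url = url.strip()
    # Hugging Face file pages use /blob/, direct downloads use /resolve/.
    if re.search(r"huggingface\.co/[^/]+/[^/]+/blob/", url):
        return url.replace("/blob/", "/resolve/", 1)
    # Civitai model pages only identify a file through modelVersionId.
    if re.search(r"civitai\.com/models/[0-9]+", url):
        version = re.search(r"modelVersionId=([0-9]+)", url)
        if not version:
            raise ValueError("The civitai link doesn't include a modelVersionId.")
        return f"https://civitai.com/api/download/models/{version.group(1)}"
    # Anything else is assumed to already be a direct link.
    return url


print(resolve_model_url("https://huggingface.co/Lykon/AnyLoRA/blob/main/AnyLoRA_noVae_fp16-pruned.ckpt"))
```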
# Utilities

In [None]:
#@markdown ### Run this Cell to Disconnect from the Runtime
#@markdown This is useful if you have a long-running process and you want to disconnect once it's done. It helps you avoid wasting your free GPU time limit.
from google.colab import runtime
runtime.unassign()

In [None]:
#@title Duplicate a folder
#@markdown Use this if you want to make multiple projects with the same training data. The dataset location is hardcoded to Google Drive.
main_dir      = os.path.join("/content", "drive/MyDrive/lora_training")

local_var_working_dir = os.path.join(main_dir, "datasets")
folder_to_duplicate = "" #@param {type:"string"}
duplicate_folder_name = "" #@param {type:"string"}

%ls
%cp -av {local_var_working_dir}/{folder_to_duplicate} {local_var_working_dir}/{duplicate_folder_name}

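The copy above runs even when a field is left empty. A more defensive variant in plain Python could look like the sketch below. It assumes the same `lora_training/datasets` layout as the cell above. The function name `duplicate_dataset` is hypothetical and not part of the notebook.

```python
import os
import shutil


def duplicate_dataset(datasets_dir, source_name, copy_name):
    """Hypothetical helper: copy one dataset folder to a new name inside datasets_dir."""
    source_name, copy_name = source_name.strip(), copy_name.strip()
    if not source_name or not copy_name:
        raise ValueError("Both folder names must be filled in.")
    src = os.path.join(datasets_dir, source_name)
    dst = os.path.join(datasets_dir, copy_name)
    if not os.path.isdir(src):
        raise FileNotFoundError(f"Source folder not found: {src}")
    if os.path.exists(dst):
        raise FileExistsError(f"Destination already exists: {dst}")
    # copytree creates dst and copies every file and subfolder into it.
    shutil.copytree(src, dst)
    return dst


# Example call, with assumed folder names:
# duplicate_dataset("/content/drive/MyDrive/lora_training/datasets", "Hatsune_Miku", "Hatsune_Miku_v2")
```

Refusing to overwrite an existing destination keeps a typo from merging two datasets.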

In [None]:
#@title Check if the folder exists
import os
#@markdown Use this to make sure your dataset is where it should be.
project_name = "" #@param {type:"string"}
project_name = project_name.strip()
root_dir = "/content"
main_dir      = os.path.join(root_dir, "drive/MyDrive/lora_training")
images_folder = os.path.join(main_dir, "datasets", project_name)
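if not os.path.exists(images_folder):
  print("Error: Folder does not exist")
else:
  # Count images inline so this cell does not depend on helpers from other cells.
  image_types = (".png", ".jpg", ".jpeg", ".webp", ".bmp")
  images_found = sum(f.lower().endswith(image_types) for f in os.listdir(images_folder))
  print(f"Number of images in folder is {images_found}")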

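To see what a folder check like the one above counts, here is a minimal sketch that reports images and caption files separately. It assumes the image extensions accepted by the trainer cell (`.png`, `.jpg`, `.jpeg`, `.webp`, `.bmp`) and `.txt` captions. The function name `summarize_dataset` is hypothetical.

```python
import os

IMAGE_TYPES = (".png", ".jpg", ".jpeg", ".webp", ".bmp")


def summarize_dataset(folder):
    """Hypothetical helper: count images, captions and other files in a dataset folder."""
    names = [n.lower() for n in os.listdir(folder)]
    images = sum(n.endswith(IMAGE_TYPES) for n in names)
    captions = sum(n.endswith(".txt") for n in names)
    # Anything else would be rejected by the trainer as an invalid file.
    other = len(names) - images - captions
    return {"images": images, "captions": captions, "other": other}


# Example call, with an assumed project folder:
# summarize_dataset("/content/drive/MyDrive/lora_training/datasets/Hatsune_Miku")
```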
