In [None]:
import os, time, requests, copy, json, glob, toml, random
from subprocess import getoutput
from PIL import Image

# Remove the preinstalled flax package before the training dependencies are installed.
# NOTE(review): the reason is not stated here — presumably a dependency/version conflict; confirm.
!pip uninstall flax -y

def read_file(filename, encoding="utf-8"):
    """Return the full text content of *filename*.

    Args:
        filename: path of the text file to read.
        encoding: text encoding used to decode the file. Defaults to UTF-8 so
            captions containing non-ASCII tags decode the same way regardless
            of the platform's locale settings.

    Returns:
        The file contents as a single string.
    """
    with open(filename, "r", encoding=encoding) as f:
        return f.read()

def write_file(filename, contents, encoding="utf-8"):
    """Overwrite *filename* with the text *contents*.

    Args:
        filename: path of the file to create or truncate.
        contents: text to write.
        encoding: text encoding used to encode *contents*. Defaults to UTF-8 so
            non-ASCII tags never raise UnicodeEncodeError on non-UTF-8 locales.
    """
    with open(filename, "w", encoding=encoding) as f:
        f.write(contents)

def install_dependencies():
    """Install the repo's Python requirements and ensure an accelerate config file exists.

    Uses the relative paths library/model_util.py and requirements.txt, so it
    expects the working directory to be the training repo (the caller in this
    notebook does `%cd {repo_dir}` first). Reads the global `accelerate_config`
    path, which is defined elsewhere.
    """
    s = getoutput('nvidia-smi')

    # On a T4 GPU, patch library/model_util.py in place: this sed expression has
    # no /g flag, so it replaces only the first "cpu" on each line with "cuda".
    # NOTE(review): presumably moves model loading/conversion onto the GPU — confirm.
    if 'T4' in s:
        !sed -i "s@cpu@cuda@" library/model_util.py

    !pip install -q -r requirements.txt

    # Imported here, after the pip install, rather than at file top —
    # presumably because accelerate is provided by requirements.txt; confirm.
    from accelerate.utils import write_basic_config

    # Write a default accelerate config only on the first run; an existing file is left untouched.
    if not os.path.exists(accelerate_config):
        write_basic_config(save_location=accelerate_config)

def install_sd15train():
    %cd {root_dir}
    !pip install -q torch==2.1.0
    !apt install aria2 -qq
    for dir in [training_dir,config_dir,pretrained_model]:
        os.makedirs(dir, exist_ok=True)

    %cd {repo_dir}
    install_dependencies()
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
    os.environ["BITSANDBYTES_NOWELCOME"] = "1"
    os.environ["SAFETENSORS_FAST_GPU"] = "1"
    cuda_path = "/usr/local/cuda-11.8/targets/x86_64-linux/lib/"
    ld_library_path = os.environ.get("LD_LIBRARY_PATH", "")
    os.environ["LD_LIBRARY_PATH"] = f"{ld_library_path}:{cuda_path}"

#Caption

def clean_directory(directory):
    """Recursively delete every file under *directory* that is not a supported image.

    Supported images are .png, .jpg, .jpeg, .webp and .bmp. Extensions are
    compared case-insensitively, so e.g. "x.WEBP" or "x.Jpg" are kept rather
    than deleted. Every other file is removed, including .txt caption files.
    Each deletion is reported with print().

    Args:
        directory: root directory to clean; subdirectories are cleaned recursively.
    """
    supported_types = {".png", ".jpg", ".jpeg", ".webp", ".bmp"}
    for item in os.listdir(directory):
        file_path = os.path.join(directory, item)
        if os.path.isfile(file_path):
            if os.path.splitext(item)[1].lower() not in supported_types:
                print(f"Deleting file {item} from {directory}")
                os.remove(file_path)
        elif os.path.isdir(file_path):
            clean_directory(file_path)

def join_arg(config):
    """Render a config dict as a command-line argument string.

    - keys starting with "_" become a quoted positional value: "value"
    - str values become --key="value"
    - True becomes a bare --key flag; False is omitted
    - int/float values become --key=value
    - any other value (e.g. None) is omitted

    Each rendered piece is followed by a single space, so a non-empty
    result ends with a trailing space.
    """
    pieces = []
    for key, value in config.items():
        if key.startswith("_"):
            pieces.append(f'"{value}" ')
        elif isinstance(value, str):
            pieces.append(f'--{key}="{value}" ')
        elif isinstance(value, bool):
            # bool is an int subclass, so handle it before the numeric branch.
            if value:
                pieces.append(f"--{key} ")
        elif isinstance(value, (int, float)):
            pieces.append(f"--{key}={value} ")
    return "".join(pieces)

#Download

def aria_down(link,path,name):
  """Download *link* with aria2c into directory *path*, saved as file *name*.

  Uses up to 16 connections (-x 16 -s 16, 1M minimum split size) and -c to
  continue a partially downloaded file. *link* is substituted into the shell
  command unquoted, so callers must escape shell metacharacters themselves
  (download_lib escapes '&' before calling the overwrite variant).
  """
  !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M {link} -d  {path} -o {name}
def aria_down_over(link,path,name):
  """Same aria2c download as aria_down, plus --allow-overwrite=true.

  The extra flag lets aria2c replace an existing file named *name* in *path*
  instead of keeping it (see the aria2c --allow-overwrite option). Like
  aria_down, *link* goes into the shell command unquoted; callers must escape it.
  """
  !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M --allow-overwrite=true {link} -d  {path} -o {name}

def download_lib(model, modellist, pretrained_model):
    """Resolve *model* to a local checkpoint path, downloading it if needed.

    Args:
        model: one of
            - an http(s) URL: downloaded (overwriting) to
              "<pretrained_model>/model.safetensors";
            - a path containing "/content/": returned unchanged (no download);
            - a key of *modellist*: ".safetensors" is appended when the name
              has no .ckpt/.gguf/.safetensors extension. An unknown key falls
              back to "RealisticVision51.safetensors".
        modellist: mapping of model file name -> download URL.
        pretrained_model: directory that downloads are written into.

    Returns:
        Path of the model file.
    """
    if "https:" in model or "http:" in model:
        # The URL is spliced unquoted into an aria2c shell command; escape '&'
        # so query strings are not split into background shell jobs.
        # (Raw string: '\&' is an invalid escape sequence in a normal literal.)
        model = model.replace("&", r"\&")
        aria_down_over(model, pretrained_model, "model.safetensors")
        return f"{pretrained_model}/model.safetensors"
    if "/content/" in model:
        return model
    if not any(ext in model for ext in (".ckpt", ".gguf", ".safetensors")):
        model += ".safetensors"
    if model not in modellist:
        print(f"Model {model} not found in model list, using RealisticVision51.safetensors instead")
        model = "RealisticVision51.safetensors"
    aria_down(modellist[model], pretrained_model, model)
    return f"{pretrained_model}/{model}"

#Dataset

def check_dir(image_dir):
    """Create an empty caption file for each image in *image_dir* that has no captions yet.

    Does nothing if the directory already contains any ``.txt`` file. If it
    does not, every .png/.jpg/.jpeg/.webp/.bmp file (extension matched
    case-insensitively) gets an empty "<stem>.txt" next to it. The stem is
    taken with os.path.splitext, so "a.b.png" -> "a.b.txt". Subdirectories
    are not visited.

    Args:
        image_dir: directory to inspect (non-recursive).
    """
    image_exts = {".png", ".jpg", ".jpeg", ".webp", ".bmp"}
    names = os.listdir(image_dir)
    if any(name.endswith(".txt") for name in names):
        return
    for name in names:
        stem, ext = os.path.splitext(name)
        if ext.lower() in image_exts:
            open(os.path.join(image_dir, stem + ".txt"), "w").close()
              
def process_tags(filename, custom_tag, append, remove_tag):
    """Add or remove comma-separated tags in a caption file, rewriting it in place.

    Args:
        filename: path of the caption (.txt) file.
        custom_tag: one or more comma-separated tags. Underscores are treated
            as spaces, which is how tags are inserted.
        append: when adding, place new tags at the end (True) or at the front
            (False). In both cases the tags keep the order given in *custom_tag*.
        remove_tag: when True, drop every caption tag equal to one of the
            custom tags (raw or with "_" -> " ") instead of adding them.

    The file is rewritten as tags joined by ", "; empty entries are dropped.
    """
    contents = read_file(filename)
    # Filter empty entries so an empty caption file does not turn into "tag, ".
    tags = [tag.strip() for tag in contents.split(",") if tag.strip()]
    requested = [tag.strip() for tag in custom_tag.split(",") if tag.strip()]
    if remove_tag:
        # Whole-tag matching: removing "cat" must not mangle "cat ears" or leave stray commas.
        unwanted = set(requested) | {tag.replace("_", " ") for tag in requested}
        tags = [tag for tag in tags if tag not in unwanted]
    else:
        insert_at = 0
        for new_tag in (tag.replace("_", " ") for tag in requested):
            if new_tag in tags:
                continue
            if append:
                tags.append(new_tag)
            else:
                # Advance the insertion point so prepended tags keep their given order.
                tags.insert(insert_at, new_tag)
                insert_at += 1
    write_file(filename, ", ".join(tags))

def process_dir(image_dir, tag, append, remove_tag):
    """Recursively apply process_tags to every .txt caption under *image_dir*.

    Before each directory is scanned, check_dir creates empty captions for its
    images if the directory has no .txt files yet. The *tag*, *append* and
    *remove_tag* arguments are passed through to process_tags unchanged.
    """
    check_dir(image_dir)
    for entry in os.listdir(image_dir):
        entry_path = os.path.join(image_dir, entry)
        if os.path.isdir(entry_path):
            process_dir(entry_path, tag, append, remove_tag)
            continue
        if entry.endswith(".txt"):
            process_tags(entry_path, tag, append, remove_tag)

def add_forder_name(folder):
    """Prepend each subfolder's keyword as a tag to the captions beneath it.

    Every subdirectory of *folder* is visited recursively. A subdirectory
    named "<int>_<keyword>" (e.g. "10_sks dog") contributes "<keyword>".
    Any other name contributes the whole folder name. The keyword is added
    via process_dir (prepend, not remove) to every caption in that
    subdirectory tree. Nested folders therefore also receive their ancestors'
    keywords.
    """
    for entry in os.listdir(folder):
        subdir = os.path.join(folder, entry)
        if not os.path.isdir(subdir):
            continue
        name = os.path.basename(subdir)
        head, sep, tail = name.partition('_')
        try:
            int(head)
        except ValueError:
            keyword = name
        else:
            keyword = tail if sep else name
        process_dir(subdir, keyword, False, False)
        add_forder_name(subdir)

def get_num_repeats(folder):
    """Parse a "<repeats>_<class token>" folder name.

    Returns:
        (num_repeats, class_token). For "10_sks dog" this is (10, "sks dog").
        If the name has no "_" or its prefix is not an int, this returns
        (dataset_repeats, whole folder name). dataset_repeats is a global
        defined elsewhere in the notebook.
    """
    name = os.path.basename(folder)
    head, sep, tail = name.partition('_')
    if sep:
        try:
            return int(head), tail
        except ValueError:
            pass
    return dataset_repeats, name

def get_supported_images(folder):
    """Return paths of the supported image files directly inside *folder*.

    Files are listed with os.listdir rather than glob. That way folder names
    containing glob metacharacters (e.g. "10_[style]") are handled correctly.
    Extensions (.png/.jpg/.jpeg/.webp/.bmp) are matched case-insensitively.
    Hidden files (leading ".") are skipped, as glob's "*" would skip them, and
    so are directories. A missing folder yields an empty list.

    Returns:
        Sorted list of "<folder>/<file>" paths.
    """
    if not os.path.isdir(folder):
        return []
    supported_extensions = {".png", ".jpg", ".jpeg", ".webp", ".bmp"}
    return [
        os.path.join(folder, name)
        for name in sorted(os.listdir(folder))
        if not name.startswith(".")
        and os.path.splitext(name)[1].lower() in supported_extensions
        and os.path.isfile(os.path.join(folder, name))
    ]

def get_subfolders(folder):
    """Return every directory below *folder*, recursively, each exactly once.

    Direct children come first, in os.listdir order. Then come each child's
    descendants, in the same order. The original implementation extended the
    list while iterating over it, which returned folders at depth >= 3 more
    than once and so produced duplicate training subsets.
    """
    children = [
        os.path.join(folder, name)
        for name in os.listdir(folder)
        if os.path.isdir(os.path.join(folder, name))
    ]
    descendants = []
    for child in children:
        descendants.extend(get_subfolders(child))
    return children + descendants

def get_subfolders_with_supported_images(folder):
    """Return *folder* and its nested directories that directly contain supported images.

    Nested directories come first, in get_subfolders order. *folder* itself is
    checked last.
    """
    matches = []
    for candidate in get_subfolders(folder) + [folder]:
        if get_supported_images(candidate):
            matches.append(candidate)
    return matches

def get_subsets(train_data_dir, reg_data_dir=""):
    """Build the dataset subset entries for the training config.

    Args:
        train_data_dir: root of the training images. Every folder under it
            (itself included) that directly holds supported images becomes
            one subset. The repeats and class token come from the folder name
            via get_num_repeats.
        reg_data_dir: optional root of regularization images. Falsy values
            ("" or None) skip regularization. Otherwise each image folder
            becomes an "is_reg" subset with 1 repeat and the global
            activation_word (defined elsewhere) as its class token.

    Returns:
        List of subset dicts.
    """
    subsets = []
    for subfolder in get_subfolders_with_supported_images(train_data_dir):
        # Parse the folder name once instead of once per field.
        num_repeats, class_tokens = get_num_repeats(subfolder)
        subsets.append({
            "image_dir": subfolder,
            "class_tokens": class_tokens,
            "num_repeats": num_repeats,
        })
    if reg_data_dir:
        for subfolder in get_subfolders_with_supported_images(reg_data_dir):
            subsets.append({
                "is_reg": True,
                "image_dir": subfolder,
                "class_tokens": activation_word,
                "num_repeats": 1,
            })
    return subsets

#Config

def final_config(json_path, toml_path):
    """Render a TOML config file from a JSON template and the current globals.

    The JSON file maps section names to {setting_name: default} dicts. For
    each setting, a module-level global of the same name is created with the
    template default if it does not exist yet. This is a deliberate side
    effect on globals(): values the notebook already defined take precedence.
    The setting's current global value is then written under its section.
    Empty strings are converted to None. NOTE(review): presumably so
    toml.dumps omits them — confirm.

    Args:
        json_path: path of the JSON template.
        toml_path: path of the TOML file to write (overwritten).
    """
    with open(json_path, "r", encoding="utf-8") as file:
        template = json.load(file)

    namespace = globals()
    rendered = {}
    for section, settings in template.items():
        rendered[section] = {}
        for name, default in settings.items():
            # setdefault keeps an existing global and only fills in missing ones.
            # Because all sections share one namespace, the first section to
            # define a name decides its default.
            value = namespace.setdefault(name, default)
            rendered[section][name] = None if value == "" else value

    write_file(toml_path, toml.dumps(rendered))