In [None]:
import os
import re
import toml
import shutil
import zipfile
from time import time

# State carried over from earlier executions of this cell. On a fresh kernel the
# lookups fall back to the defaults; on a re-run they keep what the last run stored.
# old_model_url must be captured BEFORE model_url is reassigned further down:
# main() compares the two to decide whether the model has to be downloaded again.
old_model_url = globals().get("model_url")
dependencies_installed = globals().get("dependencies_installed", False)
model_file = globals().get("model_file")

# These variables may be set by other parts of the code.
custom_dataset = None
override_dataset_config_file = None
override_config_file = None
optimizer = "AdamW8bit"
optimizer_args = None
continue_from_lora = ""
weighted_captions = False
adjust_tags = False
keep_tokens_weight = 1.0
# Read by create_config() for the "v2"/"v_parameterization" model arguments.
# It is not assigned anywhere else in the visible notebook; a value set by an earlier cell wins.
custom_model_is_based_on_sd2 = globals().get("custom_model_is_based_on_sd2", False)


def _ask_int(prompt):
    """Prompt repeatedly until the answer parses as an integer and return it."""
    while True:
        answer = input(prompt).strip()
        try:
            return int(answer)
        except ValueError:
            print("Пожалуйста, введите целое число.")


# Project settings
project_name = input("Введите имя вашего проекта (не должно содержать пробелов): ").strip()

print("""
1. Организовать по категории (MyDrive/lora_training/datasets/project_name)
2. Организовать по проекту (/home/popkov-mi/Loras/project_name/dataset)
""")
folder_structure_choice = _ask_int("Выберите структуру папок (1/2): ")

# NOTE(review): folder_structure is recorded here, but the directory setup visible in this
# notebook always uses the per-project layout and never reads it.
if folder_structure_choice == 1:
    folder_structure = "Organize by category (MyDrive/lora_training/datasets/project_name)"
else:
    folder_structure = "Organize by project (/home/popkov-mi/Loras/project_name/dataset)"

print("""
1. Anime (animefull-final-pruned-fp16.safetensors)
2. AnyLora (AnyLoRA_noVae_fp16-pruned.ckpt)
3. Stable Diffusion (sd-v1-5-pruned-noema-fp16.safetensors)
4. Указать свою
""")
training_model_choice = _ask_int("Выберите модель для обучения (1/2/3/4): ")

# Preset download links; any answer other than 1, 2 or 4 selects the Stable Diffusion preset.
_PRESET_MODEL_URLS = {
    1: "https://huggingface.co/hollowstrawberry/stable-diffusion-guide/resolve/main/models/animefull-final-pruned-fp16.safetensors",
    2: "https://huggingface.co/Lykon/AnyLoRA/resolve/main/AnyLoRA_noVae_fp16-pruned.ckpt",
    3: "https://huggingface.co/hollowstrawberry/stable-diffusion-guide/resolve/main/models/sd-v1-5-pruned-noema-fp16.safetensors",
}
if training_model_choice == 4:
    model_url = input("Введите URL-ссылку на вашу модель: ")
else:
    model_url = _PRESET_MODEL_URLS.get(training_model_choice, _PRESET_MODEL_URLS[3])

# Processing settings
resolution = _ask_int("Введите разрешение (512, 640, 768, 896, 1024): ")
flip_aug = input("Обучать изображения в нормальном и перевернутом виде? (y/n): ").lower() == 'y'
caption_extension = input("Введите расширение для подписей (пустое значение для отсутствия подписей): ")
shuffle_tags = input("Перемешать теги аниме? (y/n): ").lower() == 'y'
# The number of activation tags doubles as keep_tokens.
keep_tokens = _ask_int("Введите активационные теги (0/1/2/3): ")
activation_tags = str(keep_tokens)

# Теперь вы можете продолжить выполнение своего кода здесь...


In [None]:
# --- Training schedule ---
num_repeats = 10
preferred_unit = "Epochs"
how_many = 10

# Only the limit that matches preferred_unit receives how_many; the other stays None.
max_train_epochs = None
max_train_steps = None
if preferred_unit == "Epochs":
    max_train_epochs = how_many
if preferred_unit == "Steps":
    max_train_steps = how_many
save_every_n_epochs = 1
keep_only_last_n_epochs = 10
train_batch_size = 2

# --- Learning rates and scheduler ---
unet_lr = 5e-4
text_encoder_lr = 1e-4
lr_scheduler = "cosine_with_restarts"
lr_scheduler_number = 3
# lr_scheduler_number is interpreted according to the scheduler in use:
# as a cycle count for "cosine_with_restarts", as a power for "polynomial", otherwise 0.
if lr_scheduler == "cosine_with_restarts":
    lr_scheduler_num_cycles = lr_scheduler_number
else:
    lr_scheduler_num_cycles = 0
if lr_scheduler == "polynomial":
    lr_scheduler_power = lr_scheduler_number
else:
    lr_scheduler_power = 0
lr_warmup_ratio = 0.05
lr_warmup_steps = 0
min_snr_gamma = True
if min_snr_gamma:
    min_snr_gamma_value = 5.0
else:
    min_snr_gamma_value = None

# --- LoRA structure ---
lora_type = "LoRA"
network_dim = 16
network_alpha = 8
conv_dim = 8
conv_alpha = 4

network_module = "networks.lora"
# The convolution arguments are only passed along for the "locon" type.
network_args = (
    [f"conv_dim={conv_dim}", f"conv_alpha={conv_alpha}"]
    if lora_type.lower() == "locon"
    else None
)

# TODO: Здесь добавьте код для дальнейших действий (например, инициализация модели, загрузка данных, запуск тренировки и т.д.)



In [None]:
import os
import subprocess

# Defaults for switches that this cell and the functions of this notebook read:
# override_values_for_dadapt_and_prodigy below; COMMIT, XFORMERS, LOAD_TRUNCATED_IMAGES and
# BETTER_EPOCH_NAMES in clone_repo() / install_dependencies() / create_config().
# None of them is assigned anywhere else in the visible notebook, so reading them raised
# NameError. A value already set by an earlier cell takes precedence over the default.
# NOTE(review): the default values are assumptions - confirm them for your setup.
override_values_for_dadapt_and_prodigy = globals().get("override_values_for_dadapt_and_prodigy", True)
COMMIT = globals().get("COMMIT")  # falsy: clone_repo() skips the `git reset --hard`
XFORMERS = globals().get("XFORMERS", True)
LOAD_TRUNCATED_IMAGES = globals().get("LOAD_TRUNCATED_IMAGES", True)
# The sed patches behind this switch rewrite exact source text of the trainer;
# presumably they only match a specific revision, so they stay off unless requested.
BETTER_EPOCH_NAMES = globals().get("BETTER_EPOCH_NAMES", False)

# Optimizer setup
optimizer_args = None
_optimizer_key = optimizer.lower()
if _optimizer_key == "prodigy" or "dadapt" in _optimizer_key:
    if override_values_for_dadapt_and_prodigy:
        unet_lr = 0.5
        text_encoder_lr = 0.5
        lr_scheduler = "constant_with_warmup"
        lr_warmup_ratio = 0.05
        network_alpha = network_dim

    if not optimizer_args:
        optimizer_args = ["decouple=True","weight_decay=0.01","betas=[0.9,0.999]"]
        # Compare case-insensitively, like the outer check, so "prodigy" also gets these extras.
        if _optimizer_key == "prodigy":
            optimizer_args.extend(["d_coef=2","use_bias_correction=True"])
            if lr_warmup_ratio > 0:
                optimizer_args.append("safeguard_warmup=True")
            else:
                optimizer_args.append("safeguard_warmup=False")

root_dir = "/home/popkov-mi/Loras"
deps_dir = os.path.join(root_dir, "deps")
repo_dir = os.path.join(root_dir, "kohya-trainer")

# Directory layout for this server: everything lives under root_dir, one folder per project.
main_dir = root_dir
log_folder = os.path.join(main_dir, "_logs")
config_folder = os.path.join(main_dir, project_name)
images_folder = os.path.join(main_dir, project_name, "dataset")
output_folder = os.path.join(main_dir, project_name, "output")

config_file = os.path.join(config_folder, "training_config.toml")
dataset_config_file = os.path.join(config_folder, "dataset_config.toml")
accelerate_config_file = os.path.join(repo_dir, "accelerate_config/config.yaml")

def run_command(command):
    """Run a command line through the system shell and wait for it to finish.

    Args:
        command: Shell command line as a string. It is executed with shell=True,
            so it must be built from trusted values only.

    Returns:
        The exit status of the command (0 on success). A non-zero status is also
        printed, so a failed setup step no longer passes unnoticed.
    """
    completed = subprocess.run(command, shell=True)
    if completed.returncode != 0:
        print(f"Warning: command exited with status {completed.returncode}: {command}")
    return completed.returncode

def clone_repo():
    """Fetch the trainer sources into repo_dir and download its requirements.txt.

    Steps:
      1. Clone https://github.com/kohya-ss/sd-scripts into repo_dir, unless repo_dir
         already contains a git checkout (a second `git clone` into it would fail).
      2. If COMMIT is truthy, hard-reset the checkout to that revision.
      3. Download requirements.txt into the checkout, replacing any existing file.

    Side effect: the process working directory is left at repo_dir.
    """
    import shlex  # local import: only this function builds shell lines from paths

    os.chdir(root_dir)
    if not os.path.isdir(os.path.join(repo_dir, ".git")):
        run_command(f"git clone https://github.com/kohya-ss/sd-scripts {shlex.quote(repo_dir)}")
    os.chdir(repo_dir)
    if COMMIT:
        run_command(f"git reset --hard {shlex.quote(str(COMMIT))}")
    run_command("wget https://raw.githubusercontent.com/hollowstrawberry/kohya-colab/xformers-fix/requirements.txt -q -O requirements.txt")

def install_dependencies():
    """Set up the trainer: sources, system/pip packages, source patches, accelerate config.

    Runs clone_repo() first, then installs packages through run_command(), optionally
    patches the checked-out trainer with sed, writes a basic accelerate config when the
    config file does not exist yet, and finally sets three environment variables.
    Which optional steps run is controlled by the XFORMERS, LOAD_TRUNCATED_IMAGES and
    BETTER_EPOCH_NAMES switches.
    """
    clone_repo()

    base_setup = (
        "apt -y update",
        "apt -y install aria2",
        "pip install --upgrade -r requirements.txt",
    )
    for setup_command in base_setup:
        run_command(setup_command)
    if XFORMERS:
        run_command("pip install xformers==0.0.22.post4")

    # kohya patches (applied in place with sed)
    if LOAD_TRUNCATED_IMAGES:
        # Extends the PIL import in library/train_util.py and sets ImageFile.LOAD_TRUNCATED_IMAGES=True.
        run_command('sed -i \'s/from PIL import Image/from PIL import Image, ImageFile\\nImageFile.LOAD_TRUNCATED_IMAGES=True/g\' library/train_util.py')
    if BETTER_EPOCH_NAMES:
        # Replaces {:06d} with {:02d} in library/train_util.py and rewrites the
        # save_model_as file-name expression in train_network.py.
        epoch_name_patches = (
            'sed -i \'s/{:06d}/{:02d}/g\' library/train_util.py',
            'sed -i \'s/"." + args.save_model_as)/"-{:02d}.".format(num_train_epochs) + args.save_model_as)/g\' train_network.py',
        )
        for patch_command in epoch_name_patches:
            run_command(patch_command)

    # Imported here rather than at the top: this runs after the pip install above.
    from accelerate.utils import write_basic_config
    if not os.path.exists(accelerate_config_file):
        write_basic_config(save_location=accelerate_config_file)

    os.environ.update({
        "TF_CPP_MIN_LOG_LEVEL": "3",
        "BITSANDBYTES_NOWELCOME": "1",
        "SAFETENSORS_FAST_GPU": "1",
    })



In [None]:
def validate_dataset():
    """Check the dataset folder(s) and derive the training step counts.

    Reads the notebook-level settings (project_name, custom_dataset, images_folder,
    num_repeats, train_batch_size, max_train_steps, max_train_epochs, ...).

    Side effects:
        * sets the global lr_warmup_steps from the total step count and lr_warmup_ratio;
        * resets the global caption_extension to "" when the dataset has no .txt files;
        * when adjust_tags is set, rewrites the caption files via adjust_weighted_tags().

    Returns:
        True when the dataset passed every check; None after printing an error otherwise.
    """
    global lr_warmup_steps, lr_warmup_ratio, caption_extension, keep_tokens, keep_tokens_weight, weighted_captions, adjust_tags
    supported_types = (".png", ".jpg", ".jpeg", ".webp", ".bmp")

    print("\n💿 Checking dataset...")
    if not project_name.strip() or any(c in project_name for c in " .()\"'\\/"):
        print("💥 Error: Please choose a valid project name.")
        return

    if custom_dataset:
        try:
            datconf = toml.loads(custom_dataset)
            datasets = list(datconf["datasets"][0]["subsets"])
            reg = [d.get("image_dir") for d in datasets if d.get("is_reg", False)]
            # Kept inside the try: a subset without image_dir/num_repeats is an invalid config too.
            datasets_dict = {d["image_dir"]: d["num_repeats"] for d in datasets}
        except Exception:
            print(f"💥 Error: Your custom dataset is invalid or contains an error! Please check the original template.")
            return
        folders = list(datasets_dict)
    else:
        reg = []
        datasets_dict = {images_folder: num_repeats}
        folders = [images_folder]

    # The existence check has to precede any os.listdir(); otherwise a missing folder
    # raises FileNotFoundError before this message can be printed.
    for folder in folders:
        if not os.path.exists(folder):
            print(f"💥 Error: The folder {folder} doesn't exist.")
            return

    files = []
    images_repeats = {}  # folder -> (number of images, repeats)
    for folder in folders:
        entries = os.listdir(folder)
        files.extend(entries)
        image_count = sum(1 for name in entries if name.lower().endswith(supported_types))
        images_repeats[folder] = (image_count, datasets_dict[folder])

    for folder, (img, rep) in images_repeats.items():
        if not img:
            print(f"💥 Error: Your {folder} folder is empty.")
            return
    for f in files:
        if not f.lower().endswith(".txt") and not f.lower().endswith(supported_types):
            print(f"💥 Error: Invalid file in dataset: \"{f}\". Aborting.")
            return

    if not [txt for txt in files if txt.lower().endswith(".txt")]:
        caption_extension = ""
    if continue_from_lora and not (continue_from_lora.endswith(".safetensors") and os.path.exists(continue_from_lora)):
        print(f"💥 Error: Invalid path to existing Lora. Example: /content/drive/MyDrive/Loras/example.safetensors")
        return

    pre_steps_per_epoch = sum(img*rep for (img, rep) in images_repeats.values())
    steps_per_epoch = pre_steps_per_epoch/train_batch_size
    total_steps = max_train_steps or int(max_train_epochs*steps_per_epoch)
    estimated_epochs = int(total_steps/steps_per_epoch)
    lr_warmup_steps = int(total_steps*lr_warmup_ratio)

    for folder, (img, rep) in images_repeats.items():
        print("📁"+folder.replace("/content/drive/", "") + (" (Regularization)" if folder in reg else ""))
        print(f"📈 Found {img} images with {rep} repeats, equaling {img*rep} steps.")

    print(f"📉 Divide {pre_steps_per_epoch} steps by {train_batch_size} batch size to get {steps_per_epoch} steps per epoch.")
    if max_train_epochs:
        print(f"🔮 There will be {max_train_epochs} epochs, for around {total_steps} total training steps.")
    else:
        print(f"🔮 There will be {total_steps} steps, divided into {estimated_epochs} epochs and then some.")

    if total_steps > 10000:
        print("💥 Error: Your total steps are too high. You probably made a mistake. Aborting...")
        return

    if adjust_tags:
        print(f"\n📎 Weighted tags: {'ON' if weighted_captions else 'OFF'}")
        if weighted_captions:
            print(f"📎 Will use {keep_tokens_weight} weight on {keep_tokens} activation tag(s)")
        print("📎 Adjusting tags...")
        adjust_weighted_tags(folders, keep_tokens, keep_tokens_weight, weighted_captions)

    return True


def adjust_weighted_tags(folders, keep_tokens: int, keep_tokens_weight: float, weighted_captions: bool):
    """Rewrite every .txt caption file in `folders` to match the weighted-caption settings.

    Each file is first reset: all backslashes are removed and `(tag:1.2)` style
    weights are stripped back to the plain tag. Then, only when `weighted_captions`
    is on, literal `(`, `)` and `:` are backslash-escaped and, if
    `keep_tokens_weight` is greater than 1, the first `keep_tokens` comma-separated
    tags are wrapped as `(tag:weight)`.

    Args:
        folders: Iterable of directory paths whose .txt files are rewritten in place.
        keep_tokens: Number of leading tags that receive the weight.
        keep_tokens_weight: Weight written for those tags; values <= 1 add no weight.
        weighted_captions: When False the files are only reset, never re-weighted.
    """
    weighted_tag = re.compile(r"\((.+?):[.\d]+\)(,|$)")
    for folder in folders:
        for txt in [f for f in os.listdir(folder) if f.lower().endswith(".txt")]:
            caption_path = os.path.join(folder, txt)
            with open(caption_path, 'r', encoding="utf-8") as f:
                content = f.read()

            # Reset previous changes.
            content = content.replace('\\', '')
            content = weighted_tag.sub(r'\1\2', content)

            if weighted_captions:
                # Re-apply changes: escape characters that would otherwise read as weight syntax.
                content = content.replace(r'(', r'\(').replace(r')', r'\)').replace(r':', r'\:')
                # Weights belong inside this branch: with weighted captions off,
                # "(tag:1.2)" would end up in the caption as literal text.
                if keep_tokens_weight > 1:
                    tags = [s.strip() for s in content.split(",")]
                    for i in range(min(keep_tokens, len(tags))):
                        tags[i] = f'({tags[i]}:{keep_tokens_weight})'
                    content = ", ".join(tags)

            with open(caption_path, 'w', encoding="utf-8") as f:
                f.write(content)


In [None]:
import toml

def create_config():
    """Write the training and dataset TOML files consumed by the trainer.

    For each of the two files: when an override path is set
    (override_config_file / override_dataset_config_file), the corresponding global
    path is switched to it and nothing is written; otherwise the settings of this
    notebook are collected into a dict, entries whose value is None are dropped from
    each top-level section, and the result is dumped as TOML to config_file /
    dataset_config_file.
    """
    global dataset_config_file, config_file, model_file

    if override_config_file:
        config_file = override_config_file
        print(f"\n⭕ Using custom config file {config_file}")
    else:
        config_dict = {
            "additional_network_arguments": {
                "unet_lr": unet_lr,
                "text_encoder_lr": text_encoder_lr,
                "network_dim": network_dim,
                "network_alpha": network_alpha,
                "network_module": network_module,
                "network_args": network_args,
                "network_train_unet_only": True if text_encoder_lr == 0 else None,
                "network_weights": continue_from_lora if continue_from_lora else None
            },
            "optimizer_arguments": {
                "learning_rate": unet_lr,
                "lr_scheduler": lr_scheduler,
                "lr_scheduler_num_cycles": lr_scheduler_num_cycles if lr_scheduler == "cosine_with_restarts" else None,
                "lr_scheduler_power": lr_scheduler_power if lr_scheduler == "polynomial" else None,
                "lr_warmup_steps": lr_warmup_steps if lr_scheduler != "constant" else None,
                "optimizer_type": optimizer,
                "optimizer_args": optimizer_args if optimizer_args else None,
            },
            "training_arguments": {
                "max_train_steps": max_train_steps,
                "max_train_epochs": max_train_epochs,
                "save_every_n_epochs": save_every_n_epochs,
                "save_last_n_epochs": keep_only_last_n_epochs,
                "train_batch_size": train_batch_size,
                "clip_skip": 2,
                "min_snr_gamma": min_snr_gamma_value,
                "weighted_captions": weighted_captions,
                "seed": 42,
                "max_token_length": 225,
                "xformers": XFORMERS,
                "max_data_loader_n_workers": 8,
                "persistent_data_loader_workers": True,
                "save_precision": "fp16",
                "mixed_precision": "fp16",
                "output_dir": output_folder,
                "logging_dir": log_folder,
                "output_name": project_name,
                "log_prefix": project_name,
            },
            "model_arguments": {
                "pretrained_model_name_or_path": model_file,
                "v2": custom_model_is_based_on_sd2,
                "v_parameterization": True if custom_model_is_based_on_sd2 else None,
            },
            "saving_arguments": {
                "save_model_as": "safetensors",
            },
            "dreambooth_arguments": {
                "prior_loss_weight": 1.0,
            },
            "dataset_arguments": {
                "cache_latents": True,
            },
        }

        # Drop unset (None) options from every section so they do not reach the TOML file.
        for key in config_dict:
            if isinstance(config_dict[key], dict):
                config_dict[key] = {k: v for k, v in config_dict[key].items() if v is not None}

        with open(config_file, "w", encoding="utf-8") as f:
            f.write(toml.dumps(config_dict))
        print(f"\n📄 Config saved to {config_file}")

    if override_dataset_config_file:
        dataset_config_file = override_dataset_config_file
        print(f"⭕ Using custom dataset config file {dataset_config_file}")
    else:
        dataset_config_dict = {
            "general": {
                "resolution": resolution,
                # The notebook stores the shuffle answer in shuffle_tags; the name
                # shuffle_caption used here before is not defined anywhere in it.
                "shuffle_caption": shuffle_tags,
                "keep_tokens": keep_tokens,
                "flip_aug": flip_aug,
                "caption_extension": caption_extension,
                "enable_bucket": True,
                "bucket_reso_steps": 64,
                "bucket_no_upscale": False,
                "min_bucket_reso": 320 if resolution > 640 else 256,
                "max_bucket_reso": 1280 if resolution > 640 else 1024,
            },
            "datasets": toml.loads(custom_dataset)["datasets"] if custom_dataset else [
                {
                    "subsets": [
                        {
                            "num_repeats": num_repeats,
                            "image_dir": images_folder,
                            # Without caption files the project name is used as the class token.
                            "class_tokens": None if caption_extension else project_name
                        }
                    ]
                }
            ]
        }

        # Only top-level dict sections are filtered here; "datasets" is a list and is left as is.
        for key in dataset_config_dict:
            if isinstance(dataset_config_dict[key], dict):
                dataset_config_dict[key] = {k: v for k, v in dataset_config_dict[key].items() if v is not None}

        with open(dataset_config_file, "w", encoding="utf-8") as f:
            f.write(toml.dumps(dataset_config_dict))
        print(f"📄 Dataset config saved to {dataset_config_file}")


In [None]:
import subprocess
import os
from time import time

def download_model():
    """Download the checkpoint referenced by model_url and check that it loads.

    The target path is stored in the global model_file: the URL's file name when
    the URL ends in .ckpt/.safetensors, otherwise "downloaded_model.safetensors"
    (an existing file of that name is deleted first). huggingface "blob" links are
    rewritten to "resolve" links and civitai model pages to the API download link.
    The file is fetched with aria2c. A .safetensors file that fails to load is
    renamed to .ckpt and then checked as a .ckpt.

    Returns:
        True when the file was downloaded and could be loaded, False otherwise.
    """
    global old_model_url, model_url, model_file
    real_model_url = model_url.strip()

    if real_model_url.lower().endswith((".ckpt", ".safetensors")):
        model_file = f"/home/popkov-mi/Loras{real_model_url[real_model_url.rfind('/'):]}"
    else:
        model_file = "/home/popkov-mi/Loras/downloaded_model.safetensors"
        if os.path.exists(model_file):
            os.remove(model_file)

    if re.search(r"(?:https?://)?(?:www\.)?huggingface\.co/[^/]+/[^/]+/blob", model_url):
        # Rewrite only the path segment after "<user>/<repo>/", not every "blob" in the URL.
        real_model_url = re.sub(r"(huggingface\.co/[^/]+/[^/]+/)blob", r"\1resolve", real_model_url, count=1)
    elif m := re.search(r"(?:https?://)?(?:www\.)?civitai\.com/models/([0-9]+)", model_url):
        real_model_url = f"https://civitai.com/api/download/models/{m.group(1)}"

    # Replaces the notebook magic `!aria2c ...`, which was only present as a comment,
    # so nothing was ever downloaded. Arguments are passed as a list: no shell is involved.
    download_command = [
        "aria2c", real_model_url,
        "--console-log-level=warn", "-c", "-s", "16", "-x", "16", "-k", "10M",
        "-d", os.path.dirname(model_file), "-o", os.path.basename(model_file),
    ]
    try:
        download_status = subprocess.run(download_command).returncode
    except FileNotFoundError:
        print("💥 Error: aria2c was not found. Install aria2 to download models.")
        return False
    if download_status != 0 or not os.path.exists(model_file):
        return False

    if model_file.lower().endswith(".safetensors"):
        from safetensors.torch import load_file as load_safetensors
        try:
            test = load_safetensors(model_file)
            del test
        except Exception:
            # Not a readable safetensors file: rename it and let the .ckpt check below decide.
            new_model_file = os.path.splitext(model_file)[0]+".ckpt"
            os.replace(model_file, new_model_file)
            model_file = new_model_file
            print(f"Renamed model to {os.path.splitext(model_file)[0]}.ckpt")

    if model_file.lower().endswith(".ckpt"):
        from torch import load as load_ckpt
        try:
            # NOTE(review): torch.load unpickles the file, which can execute code
            # embedded in it. Only use checkpoints from sources you trust.
            test = load_ckpt(model_file)
            del test
        except Exception:
            return False

    return True



In [None]:
def main():
    """Run the whole pipeline: folders, dataset check, setup, model, configs, training.

    Stops early (returning None) when the dataset is invalid, the model cannot be
    downloaded, or the trainer exits with a non-zero status. Sets the global
    dependencies_installed after a completed installation so that later calls in
    the same session skip it.
    """
    global dependencies_installed

    for directory in (main_dir, deps_dir, repo_dir, log_folder, images_folder, output_folder, config_folder):
        os.makedirs(directory, exist_ok=True)

    if not validate_dataset():
        return

    if not dependencies_installed:
        print("\n🏭 Installing dependencies...\n")
        t0 = time()
        install_dependencies()
        t1 = time()
        dependencies_installed = True
        print(f"\n✅ Installation finished in {int(t1-t0)} seconds.")
    else:
        print("\n✅ Dependencies already installed.")

    # old_model_url may never have been assigned; treat that as "no previous model".
    previous_model_url = globals().get("old_model_url")
    if previous_model_url != model_url or not model_file or not os.path.exists(model_file):
        print("\n🔄 Downloading model...")
        if not download_model():
            print("\n💥 Error: The model you selected is invalid or corrupted, or couldn't be downloaded. You can use a civitai or huggingface link, or any direct download link.")
            # Abort only on failure; an unconditional return here skipped the training entirely.
            return
        print()
    else:
        print("\n🔄 Model already downloaded.\n")

    create_config()

    print("\n⭐ Starting trainer...\n")
    os.chdir(repo_dir)
    trainer = subprocess.run(["accelerate", "launch", "--config_file={}".format(accelerate_config_file), "--num_cpu_threads_per_process=1", "train_network.py", "--dataset_config={}".format(dataset_config_file), "--config_file={}".format(config_file)])
    if trainer.returncode != 0:
        print(f"\n💥 Error: The trainer exited with status {trainer.returncode}.")
        return

    print("✅ Done! Check your Lora files in /home/popkov-mi/Loras")


# True both in a notebook kernel and when run as a script, but not on import.
if __name__ == "__main__":
    main()