In [None]:
import os
import shlex
from urllib.parse import unquote, urlsplit

import toml
from accelerate.utils import write_basic_config
from google.colab import drive

### Гиперпараметры

In [None]:
# --- Setup --------------------------------------------------------------------
project_name = "loha_v2"  # Drive sub-folder name and output/log file prefix
model_url = "https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-nonema-pruned.safetensors"
is_based_on_sd2 = True  # turns on `v2` and `v_parameterization` in the training config

# --- Processing ---------------------------------------------------------------
resolution = 768
flip_aug = True

# --- Steps --------------------------------------------------------------------
num_repeats = 5
max_train_steps = 5000
save_every_n_epochs = 5
keep_only_last_n_epochs = 10
train_batch_size = 2
unet_lr = 2e-4
text_encoder_lr = 1e-6  # setting this to 0 trains the U-Net only
lr_scheduler = "cosine_with_restarts"
lr_scheduler_num_cycles = 3
min_snr_gamma = True
# Min-SNR weighting value, or None to leave the option out of the config.
min_snr_gamma_value = None
if min_snr_gamma:
    min_snr_gamma_value = 5.0

# --- Structure ----------------------------------------------------------------
lora_type = "LoHa Lycoris"
network_dim = 8
network_alpha = 4
conv_dim = 8
conv_alpha = 1
conv_compression = False

# Any "Lycoris" variant is trained through the LyCORIS kohya adapter module;
# otherwise sd-scripts' built-in LoRA network is used.
if "Lycoris" in lora_type:
    network_module = "lycoris.kohya"
else:
    network_module = "networks.lora"

# Extra key=value arguments forwarded to the network module.
network_args = [
    "conv_dim=" + str(conv_dim),
    "conv_alpha=" + str(conv_alpha),
    "algo=" + ("loha" if "LoHa" in lora_type else "lora"),
    "disable_conv_cp=" + str(not conv_compression),
]

### Пути и функции для установки зависимостей

In [None]:
# Working locations outside the Drive mount: the sd-scripts checkout and an
# auxiliary dependencies folder.
root_dir = "/content"
repo_dir = os.path.join(root_dir, "kohya-trainer")
deps_dir = os.path.join(root_dir, "deps")

# Persistent locations under the Google Drive mount (/content/drive): shared
# logs plus one folder per project holding dataset, outputs and configs.
main_dir = os.path.join(root_dir, "drive/MyDrive/Loras")
log_folder = os.path.join(main_dir, "logs")
config_folder = os.path.join(main_dir, project_name)
images_folder = os.path.join(config_folder, "dataset")
output_folder = os.path.join(config_folder, "output")

# Generated TOML configs (per project) and the accelerate launcher config.
config_file, dataset_config_file = (
    os.path.join(config_folder, toml_name)
    for toml_name in ("training_config.toml", "dataset_config.toml")
)
accelerate_config_file = os.path.join(repo_dir, "accelerate_config/config.yaml")


def clone_repo():
    """Fetch kohya-ss/sd-scripts into ``repo_dir`` pinned to a fixed commit.

    The clone is skipped when ``repo_dir`` already contains a git checkout
    (for example when the notebook is re-run in the same runtime). In both
    cases the checkout is hard-reset to the pinned commit, which also reverts
    local edits to tracked files, such as the ``sed`` patches applied by
    ``install_dependencies`` on a previous run. Afterwards a
    ``requirements.txt`` for this setup is downloaded into the repository.

    Side effect: leaves the working directory at ``repo_dir``. Later steps run
    relative commands (``pip install -r requirements.txt``, ``sed`` on
    ``library/...``) from there.

    Raises:
        RuntimeError: if any of the shell commands exits with a non-zero status.
    """
    commit = "5050971ac687dca70ba0486a583d283e8ae324e2"

    def _run(command):
        # os.system returns the child's wait status; non-zero means failure.
        status = os.system(command)
        if status != 0:
            raise RuntimeError(f"Command failed with status {status}: {command}")

    os.chdir(root_dir)
    # repo_dir may already exist (main() pre-creates it, and re-runs leave a
    # checkout behind); `git clone` would fail on a non-empty directory.
    if not os.path.isdir(os.path.join(repo_dir, ".git")):
        _run(f"git clone https://github.com/kohya-ss/sd-scripts {repo_dir}")
    os.chdir(repo_dir)
    _run(f"git reset --hard {commit}")
    _run(
        "wget https://raw.githubusercontent.com/hollowstrawberry/kohya-colab/main/requirements.txt -q -O requirements.txt"
    )


def install_dependencies():
    """Prepare the training environment inside the runtime.

    Steps:
      1. clone and pin sd-scripts via ``clone_repo`` (which also chdirs into it);
      2. install aria2 (used by ``download_model``) and the Python requirements;
      3. patch a few sd-scripts sources in place with ``sed``;
      4. write a default accelerate config if none exists yet;
      5. set environment variables, which the training process launched later
         via ``os.system`` inherits.

    The return codes of the apt/pip/sed commands are not checked; these steps
    are best-effort.
    """
    clone_repo()
    os.system("apt -y update -qq")
    os.system("apt -y install aria2")
    os.system("pip -q install --upgrade -r requirements.txt")

    # Replace the first "cpu" on every line of model_util.py with "cuda".
    os.system('sed -i "s@cpu@cuda@" library/model_util.py')

    # Let PIL load truncated images instead of raising. The "\\n" must reach
    # sed as a literal backslash-n, which GNU sed turns into a newline in the
    # replacement. A real newline character (plain "\n" in a Python string)
    # would split the s/// command, making sed fail with
    # "unterminated `s' command" and the patch never apply.
    os.system(
        "sed -i 's/from PIL import Image/from PIL import Image, ImageFile\\nImageFile.LOAD_TRUNCATED_IMAGES=True/g' library/train_util.py"
    )

    # Use 2-digit instead of 6-digit zero padding wherever train_util.py
    # formats with {:06d}. NOTE(review): presumably epoch/step numbers in
    # saved file names -- confirm against the pinned commit.
    os.system("sed -i 's/{:06d}/{:02d}/g' library/train_util.py")
    # Append "-<num_train_epochs>" to the final model file name in train_network.py.
    os.system(
        'sed -i \'s/model_name + "."/model_name + "-{:02d}.".format(num_train_epochs)/g\' train_network.py'
    )

    if not os.path.exists(accelerate_config_file):
        write_basic_config(save_location=accelerate_config_file)

    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
    os.environ["BITSANDBYTES_NOWELCOME"] = "1"
    os.environ["SAFETENSORS_FAST_GPU"] = "1"

### Функция для создания TOML-конфигураций обучения и датасета

In [None]:
def _drop_none_values(sections):
    """Return a shallow copy of ``sections`` with unset options removed.

    Every dict-valued section loses its ``None`` entries; non-dict sections
    (e.g. the ``datasets`` list) are kept as they are.
    """
    return {
        name: (
            {key: value for key, value in body.items() if value is not None}
            if isinstance(body, dict)
            else body
        )
        for name, body in sections.items()
    }


def _write_toml(path, data):
    """Serialize ``data`` to ``path`` as TOML, UTF-8 encoded as the TOML spec requires."""
    with open(path, "w", encoding="utf-8") as f:
        toml.dump(data, f)


def create_config(model_file, dataset_config_file, config_file):
    """Write the training and dataset TOML configs for sd-scripts' ``train_network.py``.

    Values are taken from the module-level hyperparameters and path settings.
    Options whose value is ``None`` are left out of each top-level section.

    Args:
        model_file: local path of the base checkpoint (``pretrained_model_name_or_path``).
        dataset_config_file: destination of the dataset TOML (``--dataset_config``).
        config_file: destination of the training TOML (``--config_file``).
    """
    config_dict = {
        "additional_network_arguments": {
            "unet_lr": unet_lr,
            "text_encoder_lr": text_encoder_lr,
            "network_dim": network_dim,
            "network_alpha": network_alpha,
            "network_module": network_module,
            "network_args": network_args,
            # Train only the U-Net when the text-encoder learning rate is 0.
            "network_train_unet_only": True if text_encoder_lr == 0 else None,
            "network_weights": None,
        },
        "optimizer_arguments": {
            "learning_rate": unet_lr,
            "lr_scheduler": lr_scheduler,
            "lr_scheduler_num_cycles": lr_scheduler_num_cycles,
            "lr_scheduler_power": None,
            "lr_warmup_steps": None,
            "optimizer_type": "AdamW8bit",
            "optimizer_args": None,
        },
        "training_arguments": {
            "max_train_steps": max_train_steps,
            "max_train_epochs": None,
            "save_every_n_epochs": save_every_n_epochs,
            "save_last_n_epochs": keep_only_last_n_epochs,
            "train_batch_size": train_batch_size,
            "noise_offset": None,
            "clip_skip": 2,
            "min_snr_gamma": min_snr_gamma_value,
            "weighted_captions": False,
            "seed": 42,
            "max_token_length": 225,
            "xformers": True,
            "lowram": True,
            "max_data_loader_n_workers": 8,
            "persistent_data_loader_workers": True,
            "save_precision": "fp16",
            "mixed_precision": "fp16",
            "output_dir": output_folder,
            "logging_dir": log_folder,
            "output_name": project_name,
            "log_prefix": project_name,
            "save_state": False,
            "save_last_n_epochs_state": None,
            "resume": None,
        },
        "model_arguments": {
            "pretrained_model_name_or_path": model_file,
            "v2": is_based_on_sd2,
            "v_parameterization": True if is_based_on_sd2 else None,
        },
        "saving_arguments": {
            "save_model_as": "safetensors",
        },
        "dreambooth_arguments": {
            "prior_loss_weight": 1.0,
        },
        "dataset_arguments": {
            "cache_latents": True,
        },
    }

    _write_toml(config_file, _drop_none_values(config_dict))
    print(f"\nConfig saved to {config_file}")

    dataset_config_dict = {
        "general": {
            "resolution": resolution,
            "shuffle_caption": False,
            "keep_tokens": 0,
            "flip_aug": flip_aug,
            "caption_extension": ".txt",
            "enable_bucket": True,
            "bucket_reso_steps": 64,
            "bucket_no_upscale": False,
            # Bucket size limits scale with the training resolution.
            "min_bucket_reso": 320 if resolution > 640 else 256,
            "max_bucket_reso": 1280 if resolution > 640 else 1024,
        },
        "datasets": [
            {
                "subsets": [
                    {
                        "num_repeats": num_repeats,
                        "image_dir": images_folder,
                        "class_tokens": None,
                    }
                ]
            }
        ],
    }

    _write_toml(dataset_config_file, _drop_none_values(dataset_config_dict))
    print(f"Dataset config saved to {dataset_config_file}")

### Функция для скачивания базовой модели

In [1]:
def download_model(model_url, dest_dir="/content"):
    """Download the base checkpoint with aria2c and return its local path.

    The local file name is the last path segment of ``model_url``:
    - any query string or fragment is removed;
    - percent-escapes are decoded.

    For example, ``.../model.safetensors?download=true`` is saved as
    ``model.safetensors``. aria2c is asked to resume partial downloads
    (``-c``) and to split the transfer across up to 16 connections.

    Args:
        model_url: direct download URL; surrounding whitespace is ignored.
        dest_dir: directory to save the file into (default ``/content``).

    Returns:
        ``os.path.join(dest_dir, <file name>)``.

    Raises:
        ValueError: if the URL path has no file-name component.
        RuntimeError: if aria2c exits with a non-zero status.
    """
    model_url = model_url.strip()
    file_name = os.path.basename(unquote(urlsplit(model_url).path))
    if not file_name:
        raise ValueError(f"Cannot derive a file name from model URL: {model_url!r}")
    model_file = os.path.join(dest_dir, file_name)

    # Quote every interpolated value: URLs routinely contain '&', '?' or
    # spaces that the shell would otherwise interpret.
    command = (
        f"aria2c {shlex.quote(model_url)} --console-log-level=warn "
        f"-c -s 16 -x 16 -k 10M -d {shlex.quote(dest_dir)} -o {shlex.quote(file_name)}"
    )
    status = os.system(command)
    if status != 0:
        raise RuntimeError(f"aria2c failed with status {status} for {model_url}")

    return model_file

### Основная функция

In [None]:
def main():
    """Run the full pipeline end to end.

    Steps: mount Google Drive, create the working folders, install sd-scripts
    and its dependencies, download the base model, write the TOML configs, and
    launch ``train_network.py`` through ``accelerate``.

    Raises:
        FileNotFoundError: if the dataset folder does not exist (checked before
            the slow installation steps).
        RuntimeError: if the training process exits with a non-zero status.
    """
    if not os.path.exists("/content/drive"):
        print("Connecting to Google Drive...")
        drive.mount("/content/drive")

    for folder in (main_dir, deps_dir, repo_dir, log_folder, output_folder, config_folder):
        os.makedirs(folder, exist_ok=True)

    # Fail fast: the dataset is not created above and must be uploaded by the user.
    if not os.path.isdir(images_folder):
        raise FileNotFoundError(
            f"Dataset folder not found: {images_folder}. "
            "Put the training images (and .txt captions) there before running."
        )

    print("\nInstalling dependencies...\n")
    install_dependencies()
    print("\nInstallation finished.")

    print("\nDownloading model...")
    model_file = download_model(model_url=model_url)
    print("\nDownloading finished.")

    print("\nCreating config files...\n")
    create_config(model_file=model_file, dataset_config_file=dataset_config_file, config_file=config_file)

    print("\nStarting training...\n")
    os.chdir(repo_dir)

    # Paths are shell-quoted so a project name with spaces does not break the command.
    status = os.system(
        f"accelerate launch --config_file={shlex.quote(accelerate_config_file)} "
        f"--num_cpu_threads_per_process=1 train_network.py "
        f"--dataset_config={shlex.quote(dataset_config_file)} --config_file={shlex.quote(config_file)}"
    )
    if status != 0:
        raise RuntimeError(f"Training failed (exit status {status}).")

In [None]:
# Runs the full pipeline defined in main() when this code is executed as the
# top-level module.
if __name__ == "__main__":
    main()