In [None]:
import pathlib
import time
from PIL import Image, ImageTk
import re
import os
import json
from PIL import Image, ImageTk
from IPython.display import display
import datetime
import math
from decimal import Decimal


def get_trigger(train_instance_images_dir):
    """Return the trigger word encoded in a kohya-style training folder name.

    Walks *train_instance_images_dir* top-down and picks the first
    sub-directory whose name starts with a digit (``<repeats>_<trigger>``,
    e.g. ``10_ohwx man``). The trigger is the text after the first underscore
    up to the next underscore (or the end of the name), so ``10_ohwx man``
    yields ``"ohwx man"`` and ``10_my_name`` yields ``"my"``.

    Digit-prefixed folders that contain no underscore are skipped rather than
    raising ``IndexError``.

    Returns:
        The trigger string, or ``""`` when no matching directory exists.
    """
    for _root, dirs, _files in os.walk(train_instance_images_dir):
        for dir_ in dirs:
            if re.match(r"\d", dir_) is None:
                continue
            parts = dir_.split("_")
            if len(parts) < 2:
                # e.g. "10abc": no "<repeats>_" prefix, nothing to extract.
                continue
            return parts[1]
    return ""

def get_train_image_num(train_instance_images_dir):
    """Count the training images in the folder that carries the trigger word.

    The trigger comes from ``get_trigger``. The first sub-directory (top-down
    ``os.walk`` order) whose name contains the trigger is inspected, and its
    direct entries ending in .jpg/.jpeg/.png/.webp (case-insensitive) are
    counted.

    Returns:
        The image count, or 0 when no trigger folder exists.
    """
    trigger = get_trigger(train_instance_images_dir)
    if not trigger:
        # "" is a substring of every name, so without this guard the first
        # unrelated directory would be counted as the training set.
        return 0
    image_exts = ('.jpg', '.jpeg', '.png', '.webp')
    for root, dirs, _files in os.walk(train_instance_images_dir):
        for dir_ in dirs:
            if trigger in dir_:
                # Join with the directory being walked: dir_ may come from a
                # nested level rather than the top-level directory.
                folder = os.path.join(root, dir_)
                return sum(
                    1 for name in os.listdir(folder)
                    if name.lower().endswith(image_exts)
                )
    return 0

def get_repeats_nums(train_instance_images_dir):
    """Return the repeat count encoded in the trigger folder name.

    The first sub-directory (top-down ``os.walk`` order) whose name contains
    the trigger from ``get_trigger`` and contains a run of digits is used; the
    first run of digits is the repeat count (``10_ohwx man`` -> 10).

    Returns:
        The repeat count, or 0 when no trigger folder with digits exists.
    """
    trigger = get_trigger(train_instance_images_dir)
    if not trigger:
        # "" would match every directory, and a digit-less one used to crash
        # with IndexError.
        return 0
    for _root, dirs, _files in os.walk(train_instance_images_dir):
        for dir_ in dirs:
            if trigger not in dir_:
                continue
            digits = re.search(r'\d+', dir_)
            if digits is not None:
                return int(digits.group(0))
    return 0

def get_dirs(path):
    """Return the names of the immediate sub-directories of *path* as a list.

    Returns an empty list (not a dict) when *path* has no sub-directories or
    does not exist, so callers always receive the same type.
    """
    for _root, dirs, _files in os.walk(path):
        # os.walk is top-down: the first tuple describes *path* itself, and if
        # it has no sub-directories there is nothing deeper to visit.
        return list(dirs)
    return []

def get_warmup_steps(lr_warmup, train_batch_size, train_instance_images_dir, epoch=None):
    """Return the number of learning-rate warm-up steps for one training run.

    Args:
        lr_warmup: warm-up ratio as a percentage (0-100); values ``<= 0``
            disable warm-up.
        train_batch_size: training batch size.
        train_instance_images_dir: kohya training folder (trigger folder
            inside it supplies image count and repeats).
        epoch: number of training epochs. When omitted, the module-level
            ``epoch`` setting is used (the previous implicit behaviour).

    Returns:
        ``round(lr_warmup * max_train_steps / 100)`` with
        ``max_train_steps = ceil(images * repeats / batch_size * epoch)``,
        or 0 when warm-up is disabled.

    Raises:
        NameError: warm-up is enabled, *epoch* is not given and no
            module-level ``epoch`` exists.
    """
    image_num = get_train_image_num(train_instance_images_dir)
    repeats_num = get_repeats_nums(train_instance_images_dir)
    print(f'image_num = {image_num}')
    print(f'repeats_num = {repeats_num}')

    if lr_warmup > 0:
        if epoch is None:
            # Backward compatibility: the notebook defines `epoch` as a global.
            epoch = globals().get("epoch")
            if epoch is None:
                raise NameError(
                    "get_warmup_steps: 'epoch' not passed and no global 'epoch' defined"
                )
        repeats = int(image_num) * int(repeats_num)
        max_train_steps = int(
            math.ceil(float(repeats) / int(train_batch_size) * int(epoch))
        )
        return round(float(int(lr_warmup) * int(max_train_steps) / 100))
    return 0


def cut_config(sd_conf_path, args):
    """Overwrite selected keys of the SD WebUI JSON config file in place.

    Only keys that already exist in the config are updated; unknown keys in
    ``args["config_data"]`` are ignored. Used to switch the WebUI model.

    Args:
        sd_conf_path: path to the WebUI ``config.json``.
        args: dict with a ``"config_data"`` mapping of key -> new value. When
            that value is ``None`` the file is left untouched.

    Raises:
        TypeError: a new value is not JSON-serializable (the file is then
            left unchanged).
    """
    config_data = args["config_data"]
    if config_data is None:
        return
    with open(sd_conf_path, "r") as f:
        try:
            conf = json.load(f)
        except json.JSONDecodeError as e:
            print(f"JSON Decode Error:{e.msg},lineno:{e.lineno},colno:{e.colno}")
            return
    if not isinstance(conf, dict):
        print(f"config file is not a JSON object: {sd_conf_path}")
        return

    for field in config_data:
        if field in conf:
            conf[field] = config_data[field]

    # Serialize BEFORE opening for write: "w" truncates immediately, so a
    # serialization error would otherwise leave the WebUI config corrupted.
    new_text = json.dumps(conf, indent=4)
    with open(sd_conf_path, "w") as f:
        f.write(new_text)

def writ_html(log_html, temp, style):
    """Create the HTML preview page from a template unless it already has content.

    The template's ``%title%`` and ``%style%`` placeholders are both replaced
    with *style*. A missing or empty *log_html* is (re)initialised; a
    non-empty one is left untouched.

    Previously the page was opened with mode ``"w+"``, which truncates before
    reading, so the emptiness check always passed and existing content was
    wiped on every run.
    """
    try:
        with open(log_html, "r", encoding="utf-8") as f:
            content = f.read()
    except FileNotFoundError:
        content = ""
    if content:
        return
    # Read the template first so a missing template cannot truncate log_html.
    with open(temp, "r", encoding="utf-8") as tf:
        t_content = tf.read()
    t_content = t_content.replace("%title%", style).replace("%style%", style)
    with open(log_html, "w", encoding="utf-8") as f:
        f.write(t_content)
    print("init html log file success!")


# Write the SD WebUI inference handler script.
def write_train_codes(sd_path: str) -> None:
    """Write the standalone SD WebUI inference script to *sd_path*.

    The file is overwritten on every call with the concatenation of two
    string templates:

    * ``new_code`` -- WebUI initialisation (``initialize``), a ``Handler``
      class that runs txt2img with built-in default arguments overridden by
      any matching keys of ``data["inputs"]``, and ``save_image``, which writes
      each image to ``<output_dir>/<time.time()>.png``;
    * ``input_str`` -- a ``__main__`` entry point that reads
      ``<options_cli>/input.json``, calls ``Handler`` with output dir
      ``<options_cli>/output`` and dumps the result to
      ``<options_cli>/output.json``.

    NOTE(review): ``save_image`` always uses a ``.png`` file name even when
    ``opts.samples_format`` is jpg/jpeg/webp, so the bytes may not match the
    extension. The notebook's result scan only looks for ``.png`` names, so
    changing this needs a matching change there.
    NOTE(review): the script reads ``cmd_opts.options_cli`` but only
    registers ``--output_dir`` on ``shared.parser``; presumably
    ``--options-cli`` is defined by the WebUI's own parser -- confirm.
    """
    new_code = '''
import os
import sys
import time
import importlib
import signal
import re
from typing import Dict, List, Any
from packaging import version

import logging

logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())

from modules import errors
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call

import torch

if ".dev" in torch.__version__ or "+git" in torch.__version__:
    torch.__long_version__ = torch.__version__
    torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)

from modules import shared, devices, ui_tempdir
shared.parser.add_argument("--output_dir", type=str, default=None, help="output_dir")
import modules.codeformer_model as codeformer
import modules.face_restoration
import modules.gfpgan_model as gfpgan
import modules.img2img

import modules.lowvram
import modules.paths
import modules.scripts
import modules.sd_hijack
import modules.sd_models
import modules.sd_vae
import modules.txt2img
import modules.script_callbacks
import modules.textual_inversion.textual_inversion
import modules.progress

import modules.ui
from modules import modelloader
from modules.shared import cmd_opts, opts
import modules.hypernetworks.hypernetwork

from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
import base64
import io
from fastapi import HTTPException
from io import BytesIO
import piexif
import piexif.helper
from PIL import PngImagePlugin, Image

from modules import import_hook, errors, extra_networks, ui_extra_networks_checkpoints
from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks


def check_versions():
    if shared.cmd_opts.skip_version_check:
        return


def initialize():
    check_versions()
    extensions.list_extensions()
    localization.list_localizations(cmd_opts.localizations_dir)

    if cmd_opts.ui_debug_mode:
        shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
        modules.scripts.load_scripts()
        return

    modelloader.cleanup_models()
    modules.sd_models.setup_model()
    codeformer.setup_model(cmd_opts.codeformer_models_path)
    gfpgan.setup_model(cmd_opts.gfpgan_models_path)
    modelloader.list_builtin_upscalers()
    modules.scripts.load_scripts()
    modelloader.load_upscalers()

    modules.sd_vae.refresh_vae_list()
    modules.textual_inversion.textual_inversion.list_textual_inversion_templates()

    try:
        modules.sd_models.load_model()
    except Exception as e:
        errors.display(e, "loading stable diffusion model")
        print("", file=sys.stderr)
        print("Stable diffusion model failed to load, exiting", file=sys.stderr)
        exit(1)

    shared.opts.data["sd_model_checkpoint"] = shared.sd_model.sd_checkpoint_info.title

    shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
    shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
    shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
    shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
    modules.script_callbacks.before_ui_callback()

    def sigint_handler(sig, frame):
        print(f'Interrupted with signal {sig} in {frame}')
        os._exit(0)

    signal.signal(signal.SIGINT, sigint_handler)


class Handler():
    def __init__(self, path=""):
        initialize()
        self.shared = shared

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        print("-----inputs:", data["inputs"])
        print("-----output_dir:", data["output_dir"])
        args = {
            "do_not_save_samples": True,
            "do_not_save_grid": True,
            "outpath_samples": "./output",
            "prompt": "lora:koreanDollLikeness_v15:0.66, best quality, ultra high res, (photorealistic:1.4), 1girl, beige sweater, black choker, smile, laughing, bare shoulders, solo focus, ((full body), (brown hair:1), looking at viewer",
            "negative_prompt": "paintings, sketches, (worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, ((monochrome)), ((grayscale)), skin spots, acnes, skin blemishes, age spot, glans, (ugly:1.331), (duplicate:1.331), (morbid:1.21), (mutilated:1.21), (tranny:1.331), mutated hands, (poorly drawn hands:1.331), blurry, 3hands,4fingers,3arms, bad anatomy, missing fingers, extra digit, fewer digits, cropped, jpeg artifacts,poorly drawn face,mutation,deformed",
            "sampler_name": "DPM++ SDE Karras",
            "steps": 20,  # 25
            "cfg_scale": 8,
            "width": 512,
            "height": 768,
            "seed": -1,
            "n_iter":1,
        }
        if data["inputs"]:
            for field in args:
                if field in data["inputs"].keys():
                    args[field] = data["inputs"][field]

        print("prompt:", args)
        p = StableDiffusionProcessingTxt2Img(sd_model=self.shared.sd_model, **args)
        processed = process_images(p)
        ret={}
        output_dir = data['output_dir']
        idx=0
        for img in processed.images:
            path=save_image(img,output_dir)
            ret[str(idx)] = path
            idx=idx+1
        return {
            "img_path": ret
        }


import io
from modules.shared import cmd_opts, opts
from PIL import PngImagePlugin, Image
import piexif
import piexif.helper
from fastapi import HTTPException
import base64

def save_image(image, output_dir):
    with io.BytesIO() as output_bytes:

        if opts.samples_format.lower() == 'png':
            use_metadata = False
            metadata = PngImagePlugin.PngInfo()
            for key, value in image.info.items():
                if isinstance(key, str) and isinstance(value, str):
                    metadata.add_text(key, value)
                    use_metadata = True
            image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None),
                       quality=opts.jpeg_quality)

        elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"):
            parameters = image.info.get('parameters', None)
            exif_bytes = piexif.dump({
                "Exif": {
                    piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode")}
            })
            if opts.samples_format.lower() in ("jpg", "jpeg"):
                image.save(output_bytes, format="JPEG", exif=exif_bytes, quality=opts.jpeg_quality)
            else:
                image.save(output_bytes, format="WEBP", exif=exif_bytes, quality=opts.jpeg_quality)

        else:
            raise HTTPException(status_code=500, detail="Invalid image format")

        bytes_data = output_bytes.getvalue()

    img_name=time.time()
    img_path=f"{output_dir}/{img_name}.png"
    with open(img_path, 'wb') as f:
        f.write(bytes_data)
    return img_path


    '''

    # Entry point appended after new_code: input.json -> Handler -> output.json.
    input_str= '''
if __name__ == '__main__':
    import json
    cmd_opts = shared.parser.parse_args()
    options_cli = cmd_opts.options_cli
    # 配置文件
    input_file = f"{options_cli}/input.json"
    out_put_json = f"{options_cli}/output.json"
    # 输出路径
    output_dir = f"{options_cli}/output"
    h_func = Handler()

    input_ = None
    with open(input_file, 'r') as f:
        input_ = json.load(f)
    inp = {
        'inputs': input_,
        'output_dir': output_dir
    }
    res = h_func(inp)
    with open(out_put_json, 'w') as f:
        json.dump(res, f)
    print("txt2imageSuccess:")
    print(res)
    '''
    # input_str is used verbatim (no .format() substitution is applied).

    new_code = new_code + input_str
    with open(sd_path, "w") as f:
        f.write(new_code)
        print("step: infernec code write success!")
    return

def write_train_bash(bash_file_path, sd_path, output_dir):
    """Write the shell script that bridges the two virtual environments.

    The notebook runs in kohya's environment; the generated script activates
    ``<sd_path>/venv`` and runs ``<sd_path>/handler.py`` with
    ``--options-cli=<output_dir>`` plus fixed WebUI flags.

    Fixes over the previous version:
    * the ``#!/bin/bash`` shebang is now the very first line (it used to be
      preceded by a blank line and indentation, so the script fell back to
      ``/bin/sh``, where ``source`` may not exist);
    * the handler path is derived from *sd_path* instead of being hard-coded;
    * paths are shell-quoted so spaces/special characters are safe.
    """
    import shlex  # local: only needed to quote paths for the shell

    sd_dir = shlex.quote(str(sd_path))
    out_dir = shlex.quote(str(output_dir))
    script_lines = [
        "#!/bin/bash",
        f"SD_WORKER_DIR={sd_dir}",
        'source "$SD_WORKER_DIR/venv/bin/activate"',
        f'echo "python $SD_WORKER_DIR/handler.py" {out_dir}',
        f'python "$SD_WORKER_DIR/handler.py" --options-cli={out_dir} '
        "--disable-nan-check --disable-safe-unpickle --enable-insecure-extension-access "
        "--no-half --xformers --deepdanbooru",
    ]
    with open(bash_file_path, "w") as f:
        f.write("\n".join(script_lines) + "\n")
        print("step: infernec bash code write success!")

def writef(md_file_name, data, log_html):
    """Append one result row to the Markdown log and refresh the HTML version stamps.

    * *md_file_name*: created with the table header on first use; *data* is
      then appended as a new line.
    * *log_html*: every ``%version%`` placeholder and every ``v=<float>``
      value is replaced with the current ``time.time()`` so the preview page
      references fresh static-file versions.

    Files are read/written as UTF-8 explicitly because the header contains
    non-ASCII (Chinese) text.
    """
    header_str='''| 时间 | 版本 | trigger | 执行时长(s) | lr_scheduler | lr_warmup | epoch | learning_rate | text_encoder_lr | unet_lr | train_batch_size | network_dim | network_alpha | img1 | img2 | img3 | img4 |
| :----: | :----: | :----: | :----:| :----: | :----:| :----: | :----: | :----: | :----: | :----: | :----: | :----: | :----: | :----: | :----: | :----: |'''
    is_new_file = not os.path.exists(md_file_name)
    with open(md_file_name, "a", encoding="utf-8") as f:
        if is_new_file:
            f.write(f"{header_str}\n")
        f.write(f"{data}\n")

    # Read fully and close before reopening for write.
    with open(log_html, "r", encoding="utf-8") as r_hf:
        ctt = r_hf.read()
    v_str = str(time.time())
    ctt = ctt.replace("%version%", v_str)
    new_st = re.sub(r'(?<=v=)\d+\.\d+', v_str, ctt)
    with open(log_html, "w", encoding="utf-8") as w_hf:
        w_hf.write(new_st)



##############################
#      Directory layout       #
##############################
##############################
#   Initialize related files  #
##############################
# 1. Lay out kohya's training directories (mainly the face-photo folders).
# 2. Name of the trained LoRA model.
# 3. Lay out the SD WebUI working directory.
# 4. Lay out the style directory, which holds the style's prompt parameter
#    file and the output directory for generated images.


# Root of the style-template inference directory; holds each style's prompt files.
temp_dir='/data/aigc/inferenc_template/template'
# Style template name
style='snow_v1.1'

### Files per style
# 1. {style}.md   log of training + inference results
# 2. {style}.html HTML preview page
# 3. {style}.json prompt parameters used for inference
# {style}.html is generated from the template file "._template.html".
log_html_template=f"{temp_dir}/._template.html"
log_html=f"{temp_dir}/{style}.html"
log_name=f"{temp_dir}/{style}.md"


# Create {style}.html from the template (replacing %title% / %style%).
writ_html(log_html, log_html_template, style)

# kohya working directory
kohya_ss_dir="/data/aigc/kohya_ss"

# SD WebUI working directory
sd_worker_dir="/data/aigc/stable-diffusion-webui"
# SD WebUI config file (the checkpoint is switched here via cut_config)
sd_conf_path=f"{sd_worker_dir}/config.json"
# Inference handler script injected into the SD WebUI directory
sd_handler_path=f"{sd_worker_dir}/handler.py"
# Base model used for LoRA training
train_base_model=f"{sd_worker_dir}/models/Stable-diffusion/v1-5-pruned.ckpt"

# kohya training entry script
train_network=f"{kohya_ss_dir}/train_network.py"
# Shell script bridging kohya's environment and the SD WebUI venv
bash_file_path=f"{sd_worker_dir}/sd_inferenc.sh"

# Base URL used to build image links in the Markdown log.
# NOTE(review): hard-coded public IP over plain HTTP -- confirm this is intended.
review_url="http://34.231.135.235/template"

# LoRA model name
# >>>>>>> adjust here when targeting a specific person
# (the training loop below overwrites it with each training folder name)
lora_name=""



######################
# LoRA training parameters
######################
# Passed to kohya as --learning_rate.
learning_rate=0.0001

# Text-encoder learning rate. Per the original note: commonly a fixed 5e-5,
# usually 1/10-1/15 of unet_lr (around 1/8 of unet_lr is also common).
# NOTE(review): the original note also claimed that lowering it makes the
# model learn the text encoder more (more tag-sensitive) -- unverified.
# Also part of the LoRA version name (lora_version).
text_encoder_lr = 5e-5
# UNet learning rate; also part of the LoRA version name (lora_version).
unet_lr = 0.0005
# constant,constant_with_warmup,cosine,cosine_with_restarts,linear,polynomial,adafactor
lr_scheduler="cosine"

# LR warm-up ratio as a percentage (0~100); 0 disables warm-up.
lr_warmup=0
train_batch_size=4
epoch=15

network_dim=64
network_alpha=32


######################
# LoRA inference settings
######################
# LoRA weight used in the <lora:name:weight> prompt tag
lora_weight="0.7"
# Images generated per person (written into the prompt file as n_iter)
output_image_size = 4
style_prompt=f"{temp_dir}/{style}.json"

print("style_prompt:",style_prompt)
# print(f"train_output:{train_output}")

# Every sub-directory here is one person (one LoRA) to train.
train_instance_set_dir = f"{kohya_ss_dir}/train/men"
train_instances = get_dirs(train_instance_set_dir)


# import matplotlib.pyplot as plt
# import numpy as np
# images = [np.random.rand(100,100) for i in range(4)]
# fig = plt.figure(figsize=(10, 10), dpi=100)
# for i in range(len(images)):
#     ax = fig.add_subplot(1, len(images), i+1)
#     ax.imshow(images[i], cmap='gray')
#     ax.axis('off')
# plt.show()



# train_instances={"bairennan018"}
for its in train_instances:
    lora_name = its
    if its.find("nv") != -1:
        sex="woman"
    else:
        sex="man"
#     print("its:", its)
    # instance 人物照片
    train_instance_images=f"{train_instance_set_dir}/{lora_name}/images/instance"
    train_output=f"{sd_worker_dir}/models/Lora/pipelined/{lora_name}"
    pathlib.Path(train_output).mkdir(parents=True, exist_ok=True)
    #获取预热处理步数
    warmup_steps = get_warmup_steps(lr_warmup, train_batch_size, train_instance_images)

    ######################
    # pipeline......
    ######################

    #########
    # 背景：network_dim=32,network_alpha=16,text_encoder_lr0.000191,unet_lr=0.00251 这个参数训练出来的参数在进行推理的时候报错：
    # 错误：modules.devices.NansException: A tensor with all NaNs was produced in Unet. This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try setting the "Upcast cross attention layer to float32" option in Settings > Stable Diffusion or using the --no-half commandline argument to fix this. Use --disable-nan-check commandline argument to disable this check.
    #
    #
    #########
    # ,"polynomial","adafactor"

    #################################################################:start
    start_time=time.time()
    now_date=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
    print(f"current time:{now_date}\n")

    ######################
    # step.1 训练好后的Lora模型名称
    ######################

    ###version
    lora_version=f"t{text_encoder_lr}_u{unet_lr}_{lr_scheduler}"
    print(f"###############################lora_version:{lora_version}")

    if lora_version != "" and lora_version != None:
        lora_name_version=f"{lora_name}_{lora_version}"
    else:
        lora_name_version=lora_name

    # 推理后的输出目录
    output_dir=f"{temp_dir}/{style}/{lora_name_version}"
    if os.path.exists(output_dir):
        print(f'###############################skip:{lora_name_version}')
        continue

    ##图片地址封装：
    url_path_prefix=f"{review_url}/{style}/{lora_name_version}/output"
    output_img_dir=f"{output_dir}/output"
    pathlib.Path(output_img_dir).mkdir(parents=True, exist_ok=True)

    ##推理图片的输入输出文件
    prompt_conf=f"{output_dir}/input.json"
    out_put_json = f"{output_dir}/output.json"


    # # # # # # # # # # # # # # # #
    # step.2 训练模型
    # # # # # # # # # # # # # # # #

    train_script=f'''
    accelerate launch \
    --num_cpu_threads_per_process=2 "{train_network}" \
    --enable_bucket \
    --pretrained_model_name_or_path="{train_base_model}" \
    --train_data_dir="{train_instance_images}" \
    --output_dir="{train_output}" \
    --output_name="{lora_name_version}" \
    --resolution=512,512 \
    --bucket_reso_steps=64 \
    --train_batch_size="{train_batch_size}" \
    --learning_rate="{learning_rate}" \
    --text_encoder_lr={text_encoder_lr} \
    --unet_lr={unet_lr} \
    --lr_scheduler="{lr_scheduler}" \
    --max_train_epochs={epoch} \
    --lr_scheduler_num_cycles={epoch} \
    --network_dim={network_dim} \
    --network_alpha={network_alpha} \
    --save_model_as=safetensors \
    --network_module=networks.lora \
    --bucket_no_upscale \
    --cache_latents \
    --xformers \
    --optimizer_type="AdamW8bit" \
    --noise_offset=0.1 \
    --caption_extension=".txt" \
    --keep_tokens="1" \
    --shuffle_caption \
    --mixed_precision="fp16" \
    --save_precision="fp16" \
    --seed="1234" \
    --max_data_loader_n_workers="0" '''
    if warmup_steps>0:
        train_script+=f' --lr_warmup_steps={warmup_steps} '
    print("train_script:", train_script)

    # --lr_warmup_steps={} \
    # --text_encoder_lr=9e-05 \
    # --unet_lr=0.0004 \
    # --prior_loss_weight=0.3 \
    # --save_every_n_epochs=1 \
    #--logging_dir="/data/aigc/kohya_ss/train/lm_tw_nan020/logs/v10.0" \
    # --lr_scheduler="cosine" \

    trigger=get_trigger(train_instance_images)
    #         for root, dirs, files in os.walk(train_instance_images):
    #             if len(dirs) !=0:
    #                 for dir_ in dirs:
    #                     pattern = r"^\d"
    #                     match = re.search(pattern, dir_)
    #                     if match is not None:
    #                         trigger=get_trigger(dir_)

    if trigger!="":
#         print("trigger:", train_instance_images)
        trigger_images={}
        for root, dirs, files in os.walk(train_instance_images):
            if len(dirs) != 0:
                for dir_ in dirs:
                    tmp_dir=f"{train_instance_images}/{dir_}"
                    for r, d, fs in os.walk(tmp_dir):
                        for f in fs:
                            pth=os.path.join(r, f)
                            print("pth:", pth)
                            if pth.find(".png") != -1 or pth.find(".jpg") != -1 or pth.find(".jpeg") != -1:
                                ig=Image.open(pth)
                                display(ig)
        ! $train_script
    else:
        print("error: trigger is empty !")

    now_date=time.strftime("%m-%d %H:%M:%S", time.localtime())
    print(f"current time:{now_date}\n")



    # # # # # # # # # # # # # # # #
    # step.3 训练模型
    # # # # # # # # # # # # # # # #
    print("step:\n-------------------------------\n\nadjust model inferenc params……")
    # style_template="/data/aigc/inferenc_template/template"
    # temp_dir


    # sd 切换风格地址
    # /data/aigc/stable-diffusion-webui

    with open(style_prompt, "r") as f:
        style_data=json.load(f)
        style_data=style_data[sex]
        model_name=style_data['_model_name_']
        style_data['prompt']=style_data['prompt'] + f",<lora:{lora_name_version}:{lora_weight}>"
        style_data['prompt']=style_data['prompt'].replace("%s", trigger)
        style_data['n_iter']=output_image_size

    print("styledata:", style_data)


    if model_name=="" or model_name is None:
        print("error: adjust model error, model_name empty")
    else:
        params = {
            "config_data": {
                #切换模型名称
                "sd_model_checkpoint": model_name,
                "sd_checkpoint_hash":'',
                "sd_vae":''
            }
        }

    cut_config(sd_conf_path, params)

    print("adjust params success!")

    now_date=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
    print(f"\ncurrent time:{now_date}\n")



    # # # # # # # # # # # # # # # #
    # step.4 推理
    # # # # # # # # # # # # # # # #

    with open(prompt_conf, 'w') as f:
        json.dump(style_data, f, indent=4)

    write_train_codes(sd_handler_path)
    write_train_bash(bash_file_path, sd_worker_dir, output_dir)

    ! chmod +x $bash_file_path

    print("execute infernec:.....\n\n")
    ! $bash_file_path


    for root, dirs, files in os.walk(output_img_dir):
        for file in files:
            path=os.path.join(root, file)
            if path.find(".png") != -1:
                ig=Image.open(path)
                display(ig)

    #         now_date=time.strftime("%m-%d %H:%M:%S", time.localtime())
    #         print(f"current time:{now_date}\n")
    now_utc_offset = datetime.datetime.utcnow() + datetime.timedelta(hours=8)
    time_str = now_utc_offset.strftime('%m-%d %H:%M:%S')

    end_time=time.time()
    excution_time=end_time-start_time

    print(f"excution_time: {excution_time}")

    excution_time_int=int(excution_time)
    markdown_str=f"| {time_str} | {lora_version} | {trigger} | {excution_time_int} | {lr_scheduler} | {lr_warmup} | {epoch} | {learning_rate} | {text_encoder_lr} | {unet_lr} | {train_batch_size} | {network_dim} | {network_alpha} | "
    imgs_str=""
    idx=1
    for root, dirs, files in os.walk(output_img_dir):
        for file in files:
            path=os.path.join(root, file)
            if path.find(".png") != -1:
                if idx<=4:
                    output_images = url_path_prefix + path.replace(output_img_dir, "")
                    imgs_str=imgs_str+f" ![图片描述]({output_images}) |"
                idx=idx+1

    print("imgs_str:", imgs_str)
    markdown_str=markdown_str+imgs_str
    print(markdown_str)

    writef(log_name, markdown_str, log_html)
    #################################################################