# SD Trainer Kaggle
#### Created by [licyk](https://github.com/licyk)

Jupyter Notebook 仓库：[licyk/sd-webui-all-in-one](https://github.com/licyk/sd-webui-all-in-one)

一个在 [Kaggle](https://www.kaggle.com) 部署 [SD Trainer](https://github.com/Akegarasu/lora-scripts) 的 Jupyter Notebook。

使用时请按顺序运行笔记单元。

## 提示：
1. 可以将训练数据上传至`/kaggle/input`文件夹，运行安装时会将训练数据复制到`/kaggle/data`文件夹中。
2. 训练需要的模型将下载至`/kaggle/lora-scripts/sd-models`文件夹中。
3. 推荐将训练输出的模型路径改为`/kaggle/working`文件夹，方便下载。

### 参数配置

In [None]:
# Marker checked by the later cells to confirm this cell has been run.
INIT_CONFIG = 1

# Console message helper shared by the whole notebook.
def echo(msg):
    """Print *msg* on its own line, prefixed with the ``:: `` marker."""
    line = ":: {}".format(msg)
    print(line)


# 将/kaggle/input中的文件复制到/kaggle/data
def cp_data():
    import os
    if not os.path.exists("/kaggle/data"):
        !mkdir -p /kaggle/data
    data_list = os.listdir("/kaggle/input")
    for i in data_list:
        file_path = os.path.join("/kaggle/input", i)
        !cp -rf {file_path} /kaggle/data


# ARIA2
class ARIA2:
    """Thin wrapper around the ``aria2c`` command-line downloader."""

    WORKSPACE = ""
    WORKFOLDER = ""


    def __init__(self, workspace, workfolder) -> None:
        self.WORKSPACE = workspace
        self.WORKFOLDER = workfolder


    # Downloader
    def aria2(self, url, path, filename):
        """Download *url* to ``path/filename`` with aria2c (16 connections, ``-c`` resume).

        The download is skipped only when the file already exists and no
        ``.aria2`` control file sits next to it; a leftover control file is
        treated as an unfinished download and aria2c is run again.
        """
        import os
        target = path + "/" + filename
        control_file = target + ".aria2"

        already_done = os.path.exists(target) and not os.path.exists(control_file)
        if already_done:
            echo(f"{filename} 文件已存在，路径: {path}/{filename}")
            return

        echo(f"开始下载 {filename} ，路径: {path}/{filename}")
        !aria2c --console-log-level=error -c -x 16 -s 16 "{url}" -d "{path}" -o "{filename}"
        completed = os.path.exists(target) and not os.path.exists(control_file)
        echo(f"{filename} 下载完成" if completed else f"{filename} 下载中断")


    # Checkpoint download: placeholder, overridden by subclasses.
    def get_sd_model(self, url, filename):
        pass


    # VAE download: placeholder, overridden by subclasses.
    def get_vae_model(self, url, filename):
        pass


# GIT
class GIT:
    WORKSPACE = ""
    WORKFOLDER = ""


    def __init__(self, workspace, workfolder) -> None:
        self.WORKSPACE = workspace
        self.WORKFOLDER = workfolder


    # 检测要克隆的项目是否存在于指定路径
    def exists(self, addr=None, path=None, name=None):
        import os
        if addr is not None:
            if path is None and name is None:
                path = os.getcwd() + "/" + addr.split("/").pop().split(".git", 1)[0]
            elif path is None and name is not None:
                path = os.getcwd() + "/" + name
            elif path is not None and name is None:
                path = os.path.normpath(path) + "/" + addr.split("/").pop().split(".git", 1)[0]

        if os.path.exists(path):
            return True
        else:
            return False


    # 克隆项目
    def clone(self, addr, path=None, name=None):
        import os
        repo = addr.split("/").pop().split(".git", 1)[0]
        if not self.exists(addr, path, name):
            echo(f"开始下载 {repo}")
            if path is None and name is None:
                path = os.getcwd()
                name = repo
            elif path is not None and name is None:
                name = repo
            elif path is None and name is not None:
                path = os.getcwd()
            !git clone {addr} "{path}/{name}" --recurse-submodules
        else:
            echo(f"{repo} 已存在")



# TUNNEL
class TUNNEL:
    """Expose local port ``PORT`` through one or more public tunnel services."""

    LOCALHOST_RUN = "localhost.run"
    REMOTE_MOE = "remote.moe"
    WORKSPACE = ""
    WORKFOLDER = ""
    PORT = ""


    def __init__(self, workspace, workfolder, port) -> None:
        self.WORKSPACE = workspace
        self.WORKFOLDER = workfolder
        self.PORT = port


    # ngrok tunnel
    def ngrok(self, ngrok_token: str):
        """Open (or reuse an existing) ngrok tunnel to ``self.PORT``; return its public URL."""
        from pyngrok import conf, ngrok
        conf.get_default().auth_token = ngrok_token
        conf.get_default().monitor_thread = False
        port = self.PORT
        tunnels = ngrok.get_tunnels(conf.get_default())
        if len(tunnels) == 0:
            tunnel = ngrok.connect(port, bind_tls=True)
            return tunnel.public_url
        return tunnels[0].public_url


    # CloudFlare tunnel
    def cloudflare(self):
        """Open a CloudFlare quick tunnel to ``self.PORT``; return its URL."""
        from pycloudflared import try_cloudflare
        port = self.PORT
        urls = try_cloudflare(port).tunnel
        return urls


    # Class-level names; they are used in the method annotations below.
    from typing import Union
    from pathlib import Path

    # Generate an ssh key pair
    def gen_key(self, path: Union[str, Path]) -> None:
        """Create a passphrase-less 4096-bit RSA key at *path* and chmod it to 600.

        Raises:
            subprocess.CalledProcessError: if ``ssh-keygen`` fails.
        """
        import subprocess
        from pathlib import Path
        path = Path(path)
        # Argument list (not a shell-style string) keeps paths with spaces intact.
        args = ["ssh-keygen", "-t", "rsa", "-b", "4096", "-N", "", "-q", "-f", path.as_posix()]
        subprocess.run(args, check=True)
        path.chmod(0o600)


    @staticmethod
    def _read_tunnel_url(stream, pattern, max_lines):
        """Scan at most *max_lines* lines of *stream* for *pattern*.

        Lines starting with ``Warning`` are echoed so the user sees them.

        Returns:
            The ``url`` named group of the first matching line, or None when
            no line within the budget matches.
        """
        for _ in range(max_lines):
            line = stream.readline()
            if line.startswith("Warning"):
                print(line, end="")
            match = pattern.search(line)
            if match:
                return match.group("url")
        return None


    # ssh reverse tunnel
    def ssh_tunnel(self, host: str) -> Union[str, None]:
        """Start ``ssh -R 80:127.0.0.1:PORT`` against *host* and return the public URL.

        The key ``WORKSPACE/id_rsa`` is generated if missing (falling back to a
        temporary directory when ``ssh-keygen`` fails there). On success the
        URL is also stored in ``os.environ`` under ``LOCALHOST_RUN`` or
        ``REMOTE_MOE``. Returns None if no URL is seen in the first lines.
        """
        import subprocess
        import atexit
        import re
        import os
        from pathlib import Path
        from tempfile import TemporaryDirectory

        ssh_name = "id_rsa"
        ssh_path = Path(self.WORKSPACE) / ssh_name
        port = self.PORT

        tmp = None
        if not ssh_path.exists():
            try:
                self.gen_key(ssh_path)
            # write permission error or similar: retry in a temporary directory
            except subprocess.CalledProcessError:
                tmp = TemporaryDirectory()
                ssh_path = Path(tmp.name) / ssh_name
                self.gen_key(ssh_path)

        args = [
            "ssh", "-R", f"80:127.0.0.1:{port}",
            "-o", "StrictHostKeyChecking=no",
            "-i", ssh_path.as_posix(),
            host,
        ]
        tunnel = subprocess.Popen(
            args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, encoding="utf-8"
        )

        atexit.register(tunnel.terminate)
        if tmp is not None:
            atexit.register(tmp.cleanup)

        if host == self.LOCALHOST_RUN:
            # localhost.run gets a larger line budget than remote.moe
            # (presumably more banner output precedes its URL).
            max_lines = 27
            pattern = re.compile(r"(?P<url>https?://\S+\.lhr\.life)")
            env_key = "LOCALHOST_RUN"
        else:
            max_lines = 5
            pattern = re.compile(r"(?P<url>https?://\S+\.remote\.moe)")
            env_key = "REMOTE_MOE"

        tunnel_url = self._read_tunnel_url(tunnel.stdout, pattern, max_lines)
        if tunnel_url is None:
            echo(f"启动 {host} 内网穿透失败")
            return None
        os.environ[env_key] = tunnel_url
        return tunnel_url


    # localhost.run tunnel
    def localhost_run(self):
        urls = self.ssh_tunnel(self.LOCALHOST_RUN)
        return urls


    # remote.moe tunnel
    def remote_moe(self):
        urls = self.ssh_tunnel(self.REMOTE_MOE)
        return urls


    # Gradio tunnel
    def gradio(self):
        """Run ``gradio-tunneling --port PORT`` and return its ``*.gradio.live`` URL (or None)."""
        import subprocess
        import atexit
        import re
        cmd = ["gradio-tunneling", "--port", str(self.PORT)]
        tunnel = subprocess.Popen(
            cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, encoding="utf-8"
        )

        atexit.register(tunnel.terminate)

        gradio_pattern = re.compile(r"(?P<url>https?://\S+\.gradio\.live)")
        tunnel_url = self._read_tunnel_url(tunnel.stdout, gradio_pattern, 5)
        if tunnel_url is None:
            echo("启动 Gradio 内网穿透失败")
        return tunnel_url


    # Start the requested tunnels
    def start(self, ngrok=False, ngrok_token=None, cloudflare=False, remote_moe=False, localhost_run=False, gradio=False):
        """Start every requested tunnel and print the resulting URLs.

        Args:
            ngrok: Start an ngrok tunnel; only honoured when *ngrok_token* is
                non-empty (an empty token would make ngrok fail).
            ngrok_token: ngrok auth token.
            cloudflare, remote_moe, localhost_run, gradio: Enable the
                corresponding service when exactly ``True``.

        Disabled or failed services are printed as ``None``.
        """
        use_ngrok = ngrok is True and bool(ngrok_token)
        wanted = [cloudflare, remote_moe, localhost_run, gradio]
        if use_ngrok or any(flag is True for flag in wanted):
            echo("启动内网穿透")
        if ngrok is True and not ngrok_token:
            echo("未填写 Ngrok Token，跳过 Ngrok 内网穿透")

        cloudflare_url = self.cloudflare() if cloudflare is True else None
        ngrok_url = self.ngrok(ngrok_token) if use_ngrok else None
        remote_moe_url = self.remote_moe() if remote_moe is True else None
        localhost_run_url = self.localhost_run() if localhost_run is True else None
        gradio_url = self.gradio() if gradio is True else None

        echo("下方为访问地址")
        print("==================================================================================")
        echo(f"CloudFlare: {cloudflare_url}")
        echo(f"Ngrok: {ngrok_url}")
        echo(f"remote.moe: {remote_moe_url}")
        echo(f"localhost_run: {localhost_run_url}")
        echo(f"Gradio: {gradio_url}")
        print("==================================================================================")



# ENV
class ENV:
    WORKSPACE = ""
    WORKFOLDER = ""


    def __init__(self, workspace, workfolder) -> None:
        self.WORKSPACE = workspace
        self.WORKFOLDER = workfolder


    # 准备ipynb笔记自身功能的依赖
    def prepare_env_depend(self, use_mirror=True):
        if use_mirror is True:
            pip_mirror = "--index-url https://mirrors.cloud.tencent.com/pypi/simple --find-links https://mirror.sjtu.edu.cn/pytorch-wheels/cu121/torch_stable.html"
        else:
            pip_mirror = "--index-url https://pypi.python.org/simple --find-links https://download.pytorch.org/whl/cu121/torch_stable.html"

        echo("安装自身组件依赖")
        !pip install pyngrok pycloudflared gradio-tunneling {pip_mirror}
        !apt update
        !apt install aria2 ssh google-perftools -y


    # 安装pytorch和xformers
    def prepare_torch(self, torch_ver, xformers_ver, use_mirror=False):
        if use_mirror is True:
            pip_mirror = "--index-url https://mirrors.cloud.tencent.com/pypi/simple --find-links https://mirror.sjtu.edu.cn/pytorch-wheels/cu121/torch_stable.html"
        else:
            pip_mirror = "--index-url https://pypi.python.org/simple --find-links https://download.pytorch.org/whl/cu121/torch_stable.html"
        
        if torch_ver != "":
            echo("安装 PyTorch")
            !pip install {torch_ver} {pip_mirror}
        if xformers_ver != "":
            echo("安装 xFormers")
            !pip install {xformers_ver} {pip_mirror}
    

    # 安装requirements.txt依赖
    def install_requirements(self, path, use_mirror=False):
        import os
        if use_mirror is True:
            pip_mirror = "--index-url https://mirrors.cloud.tencent.com/pypi/simple --find-links https://mirror.sjtu.edu.cn/pytorch-wheels/cu121/torch_stable.html"
        else:
            pip_mirror = "--index-url https://pypi.python.org/simple --find-links https://download.pytorch.org/whl/cu121/torch_stable.html"
        if os.path.exists(path):
            echo("安装依赖")
            !pip install -r "{path}" {pip_mirror}
        else:
            echo("依赖文件路径为空")


    # python软件包安装
    # 可使用的操作:
    # 安装: install -> install
    # 仅安装: install_single -> install --no-deps
    # 强制重装: force_install -> install --force-reinstall
    # 仅强制重装: force_install_single -> install --force-reinstall --no-deps
    # 更新: update -> install --upgrade
    # 卸载: uninstall -y
    def py_pkg_manager(self, pkg, type=None, use_mirror=False):
        if use_mirror is True:
            pip_mirror = "--index-url https://mirrors.cloud.tencent.com/pypi/simple --find-links https://mirror.sjtu.edu.cn/pytorch-wheels/cu121/torch_stable.html"
        else:
            pip_mirror = "--index-url https://pypi.python.org/simple --find-links https://download.pytorch.org/whl/cu121/torch_stable.html"

        if type == "install":
            func = "install"
            args = ""
        elif type == "install_single":
            func = "install"
            args = "--no-deps"
        elif type == "force_install":
            func = "install"
            args = "--force-reinstall"
        elif type == "force_install_single":
            func = "install"
            args = "install --force-reinstall --no-deps"
        elif type == "update":
            func = "install"
            args = "--upgrade"
        elif type == "uninstall":
            func = "uninstall"
            args = "-y"
            pip_mirror = ""
        else:
            echo(f"未知操作: {type}")
            return
        echo(f"执行操作: pip {func} {pkg} {args} {pip_mirror}")
        !pip {func} {pkg} {args} {pip_mirror}


    # 配置内存优化
    def tcmalloc(self):
        echo("配置内存优化")
        import os
        os.environ["LD_PRELOAD"] = "/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4"



# MANAGER
class MANAGER:
    """Notebook UI helpers: output clearing, GPU check and checkbox pickers."""

    WORKSPACE = ""
    WORKFOLDER = ""


    def __init__(self, workspace, workfolder) -> None:
        self.WORKSPACE = workspace
        self.WORKFOLDER = workfolder


    # Clear the notebook cell output
    def clear_up(self, opt):
        """Clear the current cell output; *opt* is forwarded as ``wait``."""
        from IPython.display import clear_output
        clear_output(wait=opt)


    # Check that a GPU is available
    def check_gpu(self):
        """Raise if TensorFlow reports no GPU device."""
        echo("检测 GPU 是否可用")
        import tensorflow as tf
        echo(f"TensorFlow 版本: {tf.__version__}")
        if tf.test.gpu_device_name():
            echo("GPU 可用")
        else:
            echo("GPU 不可用")
            # NOTE(review): the menu path in this message reads like Colab's UI; confirm for Kaggle.
            raise Exception("\n没有使用GPU，请在代码执行程序-更改运行时类型-设置为GPU！\n如果不能使用GPU，建议更换账号！")


    @staticmethod
    def _build_entries(data):
        """Turn ``[[url, flag], ...]`` into ``[(label, url, checked), ...]``.

        The label is the URL's last path segment. A repeated label gets a
        `` (2)``, `` (3)``... suffix so every URL keeps its own checkbox;
        ``checked`` is ``flag == 1``.
        """
        entries = []
        seen = {}
        for item in data:
            url, flag = item[0], item[1]
            label = url.split("/").pop()
            count = seen.get(label, 0) + 1
            seen[label] = count
            if count > 1:
                label = f"{label} ({count})"
            entries.append((label, url, flag == 1))
        return entries


    @staticmethod
    def _collect_selection(entries, states):
        """Return ``(labels, urls)`` of the entries whose state in *states* is ``True``."""
        chosen = [(label, url) for label, url, _ in entries if states.get(label) is True]
        return [label for label, _ in chosen], [url for _, url in chosen]


    def select_list(self, data, name):
        """Show a collapsible checkbox list for *data* and return the live URL list.

        Args:
            data: ``[[url, flag], ...]``; ``flag == 1`` pre-selects the entry.
            name: Title of the accordion panel.

        Returns:
            A list that is refilled in place with the selected URLs (in *data*
            order) every time a checkbox changes.
        """
        # https://stackoverflow.com/questions/57219796/ipywidgets-dynamic-creation-of-checkboxes-and-selection-of-data
        # https://gist.github.com/MattJBritton/9dc26109acb4dfe17820cf72d82f1e6f
        import ipywidgets as widgets
        from IPython.display import display

        entries = self._build_entries(data)
        checkboxes = {
            label: widgets.Checkbox(value=checked, description=label)
            for label, _, checked in entries
        }
        ui = widgets.VBox(children=list(checkboxes.values()))

        url_list = []  # returned to the caller and mutated in place

        def select_data(**states):  # runs on every checkbox change
            labels, urls = self._collect_selection(entries, states)
            listing = "".join(f"\n- {label}" for label in labels)
            print(f"已选择列表: {listing}")
            url_list[:] = urls

        out = widgets.interactive_output(select_data, checkboxes)
        ui.children = [*ui.children, out]
        ui = widgets.Accordion(children=[ui], titles=(name,))
        display(ui)
        return url_list



# SD_TRAINER
class SD_TRAINER(ARIA2, GIT, TUNNEL, MANAGER, ENV):
    """Install SD Trainer (Akegarasu/lora-scripts) and its models on Kaggle."""

    WORKSPACE = ""
    WORKFOLDER = ""

    # Kept for backward compatibility only: it is built while WORKSPACE is
    # still "", so its ssh key would land in the current directory.
    # __init__ replaces it with a correctly configured per-instance tunnel.
    tun = TUNNEL(WORKSPACE, WORKFOLDER, 28000)

    def __init__(self, workspace, workfolder) -> None:
        self.WORKSPACE = workspace
        self.WORKFOLDER = workfolder
        self.tun = TUNNEL(workspace, workfolder, 28000)


    def _models_dir(self):
        """Return the ``sd-models`` folder inside the SD Trainer checkout."""
        return self.WORKSPACE + "/" + self.WORKFOLDER + "/sd-models"


    def get_sd_model(self, url, filename = None):
        """Download a checkpoint into ``sd-models`` (filename defaults to the URL basename)."""
        filename = url.split("/").pop() if filename is None else filename
        super().aria2(url, self._models_dir(), filename)


    def get_vae_model(self, url, filename = None):
        """Download a VAE into ``sd-models`` (filename defaults to the URL basename)."""
        filename = url.split("/").pop() if filename is None else filename
        super().aria2(url, self._models_dir(), filename)


    def get_sd_model_from_list(self, list):
        """Download every non-empty checkpoint URL in *list*."""
        for url in list:
            if url != "":
                self.get_sd_model(url, url.split("/").pop())


    def get_vae_model_from_list(self, list):
        """Download every non-empty VAE URL in *list*."""
        for url in list:
            if url != "":
                self.get_vae_model(url, url.split("/").pop())


    def install_kohya_requirements(self, use_mirror=False):
        """Install ``sd-scripts/requirements.txt`` from inside the sd-scripts folder.

        Uses this instance's WORKSPACE/WORKFOLDER (previously module globals)
        and always returns to the SD Trainer folder afterwards.
        """
        import os
        root = self.WORKSPACE + "/" + self.WORKFOLDER
        os.chdir(root + "/sd-scripts")
        try:
            self.install_requirements(root + "/sd-scripts/requirements.txt", use_mirror)
        finally:
            os.chdir(root)


    def fix_lang(self):
        """Set ``LANG=zh_CN.UTF-8`` for child processes."""
        import os
        os.environ["LANG"] = "zh_CN.UTF-8"


    def fix_py_package(self, package, use_mirror):
        """Delete *package* from site-packages, then pip-uninstall and reinstall it.

        Removes ``<site-packages>/<package>`` and every ``<package>-*`` entry
        (e.g. its ``.dist-info`` folder) before calling pip.
        """
        import glob
        import os
        import shutil
        import sysconfig
        # Resolve site-packages from the running interpreter instead of the
        # hard-coded /opt/conda/lib/python3.10 path, which matches nothing
        # once the image ships a different Python version.
        site_packages = sysconfig.get_paths()["purelib"]
        base = os.path.join(site_packages, package)
        for target in [base, *glob.glob(glob.escape(base) + "-*")]:
            try:
                if os.path.isdir(target) and not os.path.islink(target):
                    shutil.rmtree(target)
                elif os.path.lexists(target):
                    os.remove(target)
            except OSError as err:
                # Best effort, like `rm -rf`: report and keep going.
                echo(f"删除 {target} 失败: {err}")
        self.py_pkg_manager(package, "uninstall", use_mirror)
        self.py_pkg_manager(package, "install", use_mirror)


    def install(self, torch_ver, xformers_ver, sd, vae, use_mirror):
        """Full setup: GPU check, deps, clone, PyTorch, requirements, models.

        *use_mirror* is now forwarded to every pip step, including PyTorch and
        the sd-scripts requirements.
        """
        import os
        self.check_gpu()
        self.prepare_env_depend(use_mirror)
        # NOTE(review): the clone lands in WORKSPACE/lora-scripts, so the chdir
        # below assumes WORKFOLDER == "lora-scripts".
        self.clone("https://github.com/Akegarasu/lora-scripts", self.WORKSPACE)
        os.chdir(f"{self.WORKSPACE}/{self.WORKFOLDER}")
        self.prepare_torch(torch_ver, xformers_ver, use_mirror)
        req_file = self.WORKSPACE + "/" + self.WORKFOLDER + "/requirements.txt"
        self.fix_py_package("aiohttp", use_mirror)
        self.install_kohya_requirements(use_mirror)
        self.install_requirements(req_file, use_mirror)
        self.py_pkg_manager("protobuf==3.20.0", "install", use_mirror)
        self.py_pkg_manager("numpy==1.26.4", "force_install", use_mirror)
        self.fix_lang()
        self.tcmalloc()
        self.get_sd_model_from_list(sd)
        self.get_vae_model_from_list(vae)

#############################################################

echo("初始化功能完成")
try:
    echo("尝试安装 ipywidgets 组件")
    !pip install ipywidgets -qq
    from IPython.display import clear_output
    clear_output(wait=False)
    INIT_CONFIG = 1
except:
    raise Exception("未初始化功能")

import ipywidgets as widgets


# SD Trainer is cloned to WORKSPACE/WORKFOLDER by the install cell.
WORKSPACE = "/kaggle"
WORKFOLDER = "lora-scripts"
# Default launch options; the "安装" cell overwrites these from the widgets below.
USE_NGROK = False
NGROK_TOKEN = ""
USE_CLOUDFLARE = False
USE_REMOTE_MOE = True
USE_LOCALHOST_RUN = True
USE_GRADIO_SHARE = False
TORCH_VER = ""
XFORMERS_VER = ""
# Checkpoint catalogue: [download URL, 1 = pre-selected / 0 = not selected].
sd_model = [
    ["https://huggingface.co/licyk/sd-model/resolve/main/sd_1.5/v1-5-pruned-emaonly.safetensors", 0],
    ["https://huggingface.co/licyk/sd-model/resolve/main/sd_1.5/animefull-final-pruned.safetensors", 1],
    ["https://huggingface.co/licyk/sd-model/resolve/main/sd_1.5/Counterfeit-V3.0_fp16.safetensors", 0],
    ["https://huggingface.co/licyk/sd-model/resolve/main/sd_1.5/cetusMix_Whalefall2.safetensors", 0],
    ["https://huggingface.co/licyk/sd-model/resolve/main/sd_1.5/cuteyukimixAdorable_neochapter3.safetensors", 0],
    ["https://huggingface.co/licyk/sd-model/resolve/main/sd_1.5/ekmix-pastel-fp16-no-ema.safetensors", 0],
    ["https://huggingface.co/licyk/sd-model/resolve/main/sd_1.5/ex2K_sse2.safetensors", 0],
    ["https://huggingface.co/licyk/sd-model/resolve/main/sd_1.5/kohakuV5_rev2.safetensors", 0],
    ["https://huggingface.co/licyk/sd-model/resolve/main/sd_1.5/meinamix_meinaV11.safetensors", 0],
    ["https://huggingface.co/licyk/sd-model/resolve/main/sd_1.5/oukaStar_10.safetensors", 0],
    ["https://huggingface.co/licyk/sd-model/resolve/main/sd_1.5/pastelMixStylizedAnime_pastelMixPrunedFP16.safetensors", 0],
    ["https://huggingface.co/licyk/sd-model/resolve/main/sd_1.5/rabbit_v6.safetensors", 0],
    ["https://huggingface.co/licyk/sd-model/resolve/main/sd_1.5/sweetSugarSyndrome_rev15.safetensors", 0],
    ["https://huggingface.co/licyk/sd-model/resolve/main/sd_1.5/AnythingV5Ink_ink.safetensors", 0],
    ["https://huggingface.co/licyk/sd-model/resolve/main/sd_1.5/bartstyledbBlueArchiveArtStyleFineTunedModel_v10.safetensors", 0],
    ["https://huggingface.co/licyk/sd-model/resolve/main/sd_1.5/meinapastel_v6Pastel.safetensors", 0],
    ["https://huggingface.co/licyk/sd-model/resolve/main/sd_1.5/qteamixQ_omegaFp16.safetensors", 0],
    ["https://huggingface.co/licyk/sd-model/resolve/main/sd_1.5/tmndMix_tmndMixSPRAINBOW.safetensors", 0],
    ["https://huggingface.co/licyk/sd-model/resolve/main/sdxl_1.0/sd_xl_base_1.0_0.9vae.safetensors", 0],
    ["https://huggingface.co/licyk/sd-model/resolve/main/sdxl_1.0/animagine-xl-3.0.safetensors", 0],
    ["https://huggingface.co/licyk/sd-model/resolve/main/sdxl_1.0/AnythingXL_xl.safetensors", 0],
    ["https://huggingface.co/licyk/sd-model/resolve/main/sdxl_1.0/abyssorangeXLElse_v10.safetensors", 0],
    ["https://huggingface.co/licyk/sd-model/resolve/main/sdxl_1.0/animaPencilXL_v200.safetensors", 0],
    ["https://huggingface.co/licyk/sd-model/resolve/main/sdxl_1.0/animagine-xl-3.1.safetensors", 1],
    ["https://huggingface.co/licyk/sd-model/resolve/main/sdxl_1.0/heartOfAppleXL_v20.safetensors", 0],
    ["https://huggingface.co/licyk/sd-model/resolve/main/sdxl_1.0/baxlBartstylexlBlueArchiveFlatCelluloid_xlv1.safetensors", 0],
    ["https://huggingface.co/licyk/sd-model/resolve/main/sdxl_1.0/kohaku-xl-delta-rev1.safetensors", 0],
    ["https://huggingface.co/licyk/sd-model/resolve/main/sdxl_1.0/kohakuXLEpsilon_rev1.safetensors", 1],
    ["https://huggingface.co/licyk/sd-model/resolve/main/sdxl_1.0/nekorayxl_v06W3.safetensors", 0],
    ["https://huggingface.co/licyk/sd-model/resolve/main/sdxl_1.0/CounterfeitXL-V1.0.safetensors", 0],
    ["https://huggingface.co/licyk/sd-model/resolve/main/sdxl_1.0/ponyDiffusionV6XL_v6StartWithThisOne.safetensors", 0]
]
# VAE catalogue, same [URL, pre-selected flag] layout as sd_model.
vae = [
    ["https://huggingface.co/licyk/sd-vae/resolve/main/sd_1.5/vae-ft-ema-560000-ema-pruned.safetensors", 0],
    ["https://huggingface.co/licyk/sd-vae/resolve/main/sd_1.5/vae-ft-mse-840000-ema-pruned.safetensors", 1],
    ["https://huggingface.co/licyk/sd-vae/resolve/main/sdxl_1.0/sdxl_fp16_fix_vae.safetensors", 1]
]

# Used here to build the checkbox pickers at the end of this cell.
manager = MANAGER(WORKSPACE, WORKFOLDER)

# Option widgets; the "安装" cell reads their .value fields.
torch_ver_state = widgets.Textarea(value="torch==2.3.0+cu121 torchvision==0.18.0+cu121 torchaudio==2.3.0+cu121", placeholder="请填写 PyTorch 版本", description="PyTorch 版本: ", disabled=False)
xformers_ver_state = widgets.Textarea(value="xformers==0.0.26.post1", placeholder="请填写 xFormers 版本", description="xFormers 版本: ", disabled=False)
use_ngrok_state = widgets.Checkbox(value=False, description="使用 Ngrok 内网穿透", disabled=False)
ngrok_token_state = widgets.Textarea(value="", placeholder="请填写 Ngrok Token（可在 Ngrok 官网获取）", description="Ngrok Token: ", disabled=False)
use_cloudflare_state = widgets.Checkbox(value=False, description="使用 CloudFlare 内网穿透", disabled=False)
use_remote_moe_state = widgets.Checkbox(value=True, description="使用 remote.moe 内网穿透", disabled=False)
use_localhost_run_state = widgets.Checkbox(value=True, description="使用 localhost.run 内网穿透", disabled=False)
use_gradio_share_state = widgets.Checkbox(value=False, description="使用 Gradio 内网穿透", disabled=False)
# Optional custom model download; the install cell fetches it only when both URL and name are filled in.
model_url = widgets.Textarea(value="", placeholder="请填写模型下载链接", description="模型链接: ", disabled=False)  #@param {type:"string"}
model_name = widgets.Textarea(value="", placeholder="请填写模型名称，包括后缀名，例：kohaku-xl.safetensors", description="模型名称: ", disabled=False)  #@param {type:"string"}
model_type = widgets.Dropdown(options=[("Stable Diffusion 模型（大模型）", "sd"), ("VAE 模型", "vae")], value="sd", description='模型种类: ')

display(torch_ver_state, xformers_ver_state, use_ngrok_state, ngrok_token_state, use_cloudflare_state, use_remote_moe_state, use_localhost_run_state, use_gradio_share_state, model_name, model_url, model_type)


# Checkbox pickers; each returned list is refilled in place with the selected URLs as boxes are toggled.
sd_model_list = manager.select_list(sd_model,"Stable Diffusion 模型")
vae_list = manager.select_list(vae, "VAE 模型")

### 安装

In [None]:
# Fail early if the "参数配置" cell (which defines INIT_CONFIG) has not run.
try:
    INIT_CONFIG
    INIT_CONFIG_1 = 1
except NameError:
    raise Exception("未运行\"参数配置\"单元") from None

# Read the widget values; strip Textarea input so a pasted trailing newline
# or space does not corrupt the ngrok token or the custom download URL/name.
TORCH_VER = torch_ver_state.value
XFORMERS_VER = xformers_ver_state.value
USE_NGROK = use_ngrok_state.value
NGROK_TOKEN = ngrok_token_state.value.strip()
USE_CLOUDFLARE = use_cloudflare_state.value
USE_REMOTE_MOE = use_remote_moe_state.value
USE_LOCALHOST_RUN = use_localhost_run_state.value
USE_GRADIO_SHARE = use_gradio_share_state.value
USE_MIRROR = False

sd_trainer = SD_TRAINER(WORKSPACE, WORKFOLDER)

import os
os.chdir(WORKSPACE)

echo("开始安装 SD Trainer")
sd_trainer.install(TORCH_VER, XFORMERS_VER, sd_model_list, vae_list, USE_MIRROR)

# Optional custom model: only fetched when both a URL and a file name are given.
custom_url = model_url.value.strip()
custom_name = model_name.value.strip()
if custom_url != "" and custom_name != "":
    if model_type.value == "sd":
        sd_trainer.get_sd_model(custom_url, custom_name)
    elif model_type.value == "vae":
        sd_trainer.get_vae_model(custom_url, custom_name)
cp_data()
sd_trainer.clear_up(False)
echo("SD Trainer 安装完成")

### 启动

In [None]:
try:
    i = INIT_CONFIG_1
except:
    raise Exception("未运行\"安装\"单元")

import os
os.chdir(WORKSPACE + "/" + WORKFOLDER)
echo("启动 SD Trainer 中")
sd_trainer.tun.start(ngrok=USE_NGROK, ngrok_token=NGROK_TOKEN, cloudflare=USE_CLOUDFLARE, remote_moe=USE_REMOTE_MOE, localhost_run=USE_LOCALHOST_RUN, gradio=USE_GRADIO_SHARE)
!python "{WORKSPACE}"/lora-scripts/gui.py
sd_trainer.clear_up(False)
echo("SD Trainer 已关闭")