# Colab的默认环境跟这里的默认环境有差异，需要安装额外的依赖

In [None]:
import subprocess

# Python packages this notebook needs; each entry is passed verbatim to `pip install`.
libraries_to_install = [
    "black[jupyter]",
    "wget",
    "transformers",
    "bitsandbytes",
    "jax",
    "xformers",
    "triton==2.0.0.dev20221120",
]

# System packages; each is probed with `dpkg -s` and installed with `apt-get install`.
tools_to_install = [
    "p7zip-full",
    "iputils-ping",
    "git-lfs"
]

# Output of `pip freeze`, split into one string per line.
installed_libraries = subprocess.run(["pip", "freeze"], capture_output=True).stdout.decode().split("\n")

# Install every library that does not appear to be installed yet.
# NOTE(review): the check is a plain substring test against each `pip freeze` line,
# so e.g. "jax" is also satisfied by any line that merely contains "jax".
for library in libraries_to_install:
    # Special case: "black[jupyter]" is looked up by the bare name "black".
    if library == "black[jupyter]" and any("black" in installed_library for installed_library in installed_libraries):
        print(f"【 {library} 】已经安装，跳过安装")
        continue
    elif not any(library in installed_library for installed_library in installed_libraries):
        !pip install $library
    else:
        print(f"【 {library} 】已经安装，跳过安装")
        
# Install every system tool that is not installed yet.
# First pass: stop at the first missing tool, only to decide whether `apt-get update` is needed.
update_needed = False
for tool in tools_to_install:
    exit_code = subprocess.run(["dpkg", "-s", tool], capture_output=True).returncode
    if exit_code != 0:
        update_needed = True
        break
    else:
        print(f"【 {tool} 】已经安装，跳过安装")
# Second pass: refresh the package index once, then install each tool that `dpkg -s` does not know.
if update_needed:
    !apt-get update
    for tool in tools_to_install:
        exit_code = subprocess.run(["dpkg", "-s", tool], capture_output=True).returncode
        if exit_code != 0:
            !apt-get install -y $tool
    


# 配置accelerate

In [None]:
# Make the accelerate setup script in the current directory executable, then run it.
!chmod +x ./accelerate.sh
!./accelerate.sh

# 创建Temp文件夹并下载fast-DreamBooth.ipynb到Temp目录下

In [None]:
import wget
import os, sys

sys.path.append("../")  # add the parent directory to the import path so the `func` package below can be imported
from func.env import setProxyCLI, proxyWget

# setProxyCLI() returns a mapping; its "proxy", "region" and "proxyURL" entries are kept as globals.
# `proxy` is later used as a shell command prefix (`!$proxy && wget ...`).
cb = setProxyCLI()
proxy = cb["proxy"]
region = cb["region"]
proxyURL = cb["proxyURL"]

# 在当前运行的ipynb文件所在的目录下创建临时文件夹temp
def create_temp_folder(temp_folder):
    """Create ``temp_folder`` (including missing parents) unless it already exists.

    Parameters
    ----------
    temp_folder : str
        Path of the directory to create.

    Raises
    ------
    FileExistsError
        If ``temp_folder`` exists but is not a directory.
    """
    # exist_ok=True replaces the former "check, then create" sequence, so a
    # directory that appears between the two steps no longer causes an error.
    os.makedirs(temp_folder, exist_ok=True)


# Current working directory (presumably the directory of the running notebook — verify)
cwd = os.getcwd()

# Path of the temporary folder "temp" inside that directory
temp_folder = os.path.join(cwd, "temp")

# Create the temp folder if it does not exist yet
create_temp_folder(temp_folder)

# fast-DreamBooth.ipynb is downloaded into the temp folder, replacing any previous copy
file_name = "fast-DreamBooth.ipynb"
fastDreamBoothPath = os.path.join(temp_folder, file_name)
# Remove a stale copy first
if os.path.exists(fastDreamBoothPath):
    !rm -rf $fastDreamBoothPath

url = "https://raw.githubusercontent.com/TheLastBen/fast-stable-diffusion/main/fast-DreamBooth.ipynb"
# Download target; same path as fastDreamBoothPath
out = os.path.join(temp_folder, file_name)

import subprocess

# Execute the command with a timeout of 5 seconds
exit_code = subprocess.call(["timeout", "--preserve-status", "5", "wget", "-O", out, url])

# Check the exit status
if exit_code == 124:
    print("Command timed out, trying again")
    # Execute the command again
    !$proxy && wget -O $out $url
else:
    print("Command completed successfully")

# fastDreamBoothPath = proxyWget(
#     url=url,
#     out=out,
#     proxyURL=proxyURL)

# print(fastDreamBoothPath)

# 创建Content文件夹

In [None]:
import os

def check_environment(output):
    """Return the platform settings matching the platform name found in ``output``.

    The test is a case-sensitive substring search; "AutoDL" is tried before
    "OpenBayes".

    Parameters
    ----------
    output : str
        Text to inspect.

    Returns
    -------
    dict
        ``{"content_path": ..., "env_name": ...}`` for the first platform name
        found, or an empty dict when neither name occurs.
    """
    known_platforms = (
        ("AutoDL", "/root/autodl-tmp/content"),
        ("OpenBayes", "/openbayes/home/content"),
    )
    for platform_name, platform_content_path in known_platforms:
        if platform_name in output:
            return {
                "content_path": platform_content_path,
                "env_name": platform_name,
            }
    # Neither platform name is present
    return {}


def detect_environment():
    """Detect the hosting platform by running each platform's welcome script.

    Every command is run through the shell and its standard output is handed
    to ``check_environment``; the first command whose output names a known
    platform wins. The outcome is also printed.

    Returns
    -------
    dict
        ``{"content_path": str | None, "env_name": str | None}``; both values
        are ``None`` when no platform was recognised.
    """
    import subprocess

    content_path = None
    env_name = None

    # One probe command per platform
    commands = [
        "cd /openbayes/home && chmod +x /etc/welcome && /etc/welcome",
        "chmod +x /etc/autodl-motd && /etc/autodl-motd"
    ]

    for command in commands:
        try:
            r = subprocess.run(command, shell=True, stdout=subprocess.PIPE)
        except OSError:
            # The shell itself could not be started; try the next probe
            continue
        # Undecodable bytes must not abort detection, so they are replaced
        output = r.stdout.decode(errors="replace")
        result = check_environment(output)
        # An empty result means "not this platform"; .get() handles that
        # directly instead of relying on a swallowed KeyError.
        content_path = result.get("content_path")
        env_name = result.get("env_name")
        if content_path and env_name:
            break

    # Report the outcome
    if content_path and env_name:
        print("当前运行环境：", env_name)
        print("内容路径：", content_path)
    else:
        print("未检测到当前运行环境")

    return {
        "content_path":content_path,
        "env_name":env_name
    }


# Resolve the content folder of the detected platform and make sure it exists.
content_path=detect_environment()["content_path"]
if content_path is None:
    # Without this guard os.path.exists(None) fails with an opaque TypeError
    raise RuntimeError("未检测到当前运行环境，无法确定Content文件夹的路径")
content_folder = content_path
# exist_ok=True: no error when the folder is already there
os.makedirs(content_folder, exist_ok=True)

# 定义一个校验和生成函数，通过 SHA-256 生成对应文件夹的校验和（依据是目录结构和文件名，不包含文件内容），用于之后查找对应的文件夹

In [None]:
import os
import hashlib
import time

def generate_checksum(folder_path, print_elapsed_time=False):
    """Return a SHA-256 hex digest describing the layout of ``folder_path``.

    Entries are visited in name order. A file contributes its name; a
    sub-directory contributes the digest computed recursively for it (its own
    name is not hashed). File contents are never read.

    Parameters
    ----------
    folder_path : str
        Directory to fingerprint.
    print_elapsed_time : bool
        When true, print how long this call took.

    Returns
    -------
    str
        Hexadecimal SHA-256 digest.
    """
    started_at = time.time()
    digest = hashlib.sha256()
    with os.scandir(folder_path) as listing:
        for item in sorted(listing, key=lambda e: e.name):
            # Directories feed their recursive digest, files feed their name
            token = generate_checksum(item.path) if item.is_dir() else item.name
            digest.update(token.encode())
    elapsed_time = time.time() - started_at
    if print_elapsed_time:
        print(f'生成校验码耗时: {elapsed_time:.2f} 秒')
    return digest.hexdigest()


# # 生产环境下无需运行以下代码，因为已经获取到 stable-diffusion-v2-768 的 code 是 7d3e9649526266e0e69b9bad33e509975ba261fb5b6c8ce3b7c26495c4ac4dcb
# # 设置要生成校验码的文件夹路径
# targetFolder='/output/content/stable-diffusion-v2-768'

# # 调用 generate_checksum 函数生成校验码
# code = generate_checksum(
#     folder_path=targetFolder,
#     print_elapsed_time=True
# )

# # 输出校验码
# print(code)

# 定义find_folder函数，以便后续通过校验和找回对应的文件夹

In [None]:
import time

def find_folder(root_folder, checksum):
    """Search ``root_folder`` for a directory whose ``generate_checksum`` value equals ``checksum``.

    ``root_folder`` itself is tested first, then every directory reported by
    ``os.walk`` in walk order. The result and the elapsed time are printed.

    Returns
    -------
    str or None
        Path of the first matching directory, or ``None`` when none matches.
    """
    began = time.time()

    def _candidates():
        # The root comes first, followed by each sub-directory in walk order
        yield root_folder
        for parent, subdirs, _files in os.walk(root_folder):
            for subdir in subdirs:
                yield os.path.join(parent, subdir)

    for candidate in _candidates():
        if generate_checksum(candidate) != checksum:
            continue
        elapsed_time = time.time() - began
        print(f'\033[32m在目录 {root_folder} 下，找到校验和为【{checksum}】的目录，耗时: {elapsed_time:.2f} 秒，文件路径为：{candidate}\033[0m')
        return candidate

    elapsed_time = time.time() - began
    print(f'\033[31m在目录 {root_folder} 下，未找到校验和为【{checksum}】的目录，耗时: {elapsed_time:.2f} 秒\033[0m')
    return None

# # 生产环境下无需运行以下代码，
# find_folder('/input1', 'c7ee605be833a2d9')

# 自动搜寻并挂载数据集和模型，或对数据集进行解压，或创建符号链接（Symbolic Link，软链接）

In [None]:
import os

def create_symlink(source_path, symlink_path, symlink_name):
    """Create (or replace) the symlink ``symlink_path/symlink_name`` pointing at ``source_path``.

    An existing symlink at that location is removed first; anything else that
    already exists there makes ``os.symlink`` raise. The outcome is printed.
    """
    link = "{}/{}".format(symlink_path, symlink_name)

    # Replace a previous link instead of failing on it
    if os.path.islink(link):
        os.remove(link)

    os.symlink(source_path, link)

    # Verify and report
    if not os.path.islink(link):
        print(f"{link} 不是一个符号链接。")
        return
    print(f"{link} 是一个符号链接。")
    print(f"\033[1m\033[92m{link}\033[0m ➜ \033[34m{os.readlink(link)}\033[0m")

        
# # 在output目录下创建一个tf_dir指向dreambooth打印的logs软连接  
# create_symlink(
#     source_path="/output/content/models/FuXingHao768/logs/", 
#     symlink_path='/output', 
#     symlink_name='tf_dir'
# )

env_name=detect_environment()["env_name"]

if env_name=="OpenBayes":
    # 遍历input0~4，找到带有依赖压缩包的数据集（或文件夹），并进行复制操作
    def check_dependencies_and_make_symbollink():
        # 遍历 input0、input1、input2、input3、input4 五个目录
        for i in range(5):
            dependencies_dir=find_folder(f'/input{i}','6c089def2938fcab36321d0fb27ec016a2b6fc5074a1c89e91b74ce973f307f2')
            if dependencies_dir:
                print(f'已找到带有依赖压缩包的数据集（或文件夹），目录所在位置为：{dependencies_dir}')
                break
        if dependencies_dir:
        # !cp -r $dependencies_dir/* /openbayes/home/content/
            for file in os.listdir(dependencies_dir):
                create_symlink(
                    source_path=os.path.join(dependencies_dir,file), # openbayes 特有的目录，是外部数据集挂载到当前镜像的目录，编号从0~4
                    symlink_path='/output/content', # fast-dreambooth.ipynb的 model download cell一般下载模型的路径
                    symlink_name=file # 模型名称
                )
    
    check_dependencies_and_make_symbollink()
    
    def check_model_and_make_symbollink():
        stable_diffusion_v2_dir=find_folder('/output/','7d3e9649526266e0e69b9bad33e509975ba261fb5b6c8ce3b7c26495c4ac4dcb') #该校验和是通过前两个 cell 中的 generate_checksum 函数生成的
        # print(stable_diffusion_v2_dir)
        if stable_diffusion_v2_dir:
            if stable_diffusion_v2_dir=='/output/content/stable-diffusion-v2-768':
                print('stable-diffusion-v2的diffuser形态的模型已经在目标位置，无需创建软连接')
            else:
                # 在Content目录下创建一个指向挂载模型的软连接
                create_symlink(
                    source_path=stable_diffusion_v2_dir, # openbayes 特有的目录，是外部数据集挂载到当前镜像的目录，编号从0~4
                    symlink_path='/output/content', # fast-dreambooth.ipynb的 model download cell一般下载模型的路径
                    symlink_name='stable-diffusion-v2-768' # 模型名称
                )
        else:
            print('找不到stable-diffusion-v2的diffuser形态的模型，尝试通过运行convertodiffv2.ipynb来获取模型')
            %run ../tools/model_convert_tools/convertodiffv2.ipynb
            check_model_and_make_symbollink()
    
    check_model_and_make_symbollink()
    
    # # 免安装（挂载数据集）DreamBooth所需依赖（能够省掉fast-DreamBooth中的第一个环节）
    # !rm -rf /usr/local/lib/python3.8/dist-packages
    # create_symlink(
    #     source_path='/input1/local/lib/python3.8/dist-packages', 
    #     symlink_path='/usr/local/lib/python3.8',
    #     symlink_name='dist-packages' 
    # )

    # 因为默认的Sessions路径太深（原作者是为了Colab而设计的），所以要在根目录创建一个链接到Sessions的快捷方式
    create_symlink(
        source_path='/output/content/gdrive/MyDrive/Fast-Dreambooth/Sessions', 
        symlink_path='/openbayes',
        symlink_name='Sessions' 
    )
    
elif env_name=="AutoDL":
    # 在/root目录下创建一个指向/autodl-tmp/content的软连接
    create_symlink(
        source_path='/root/autodl-tmp/content', # openbayes 特有的目录，是外部数据集挂载到当前镜像的目录，编号从0~4
        symlink_path='/root', # fast-dreambooth.ipynb的 model download cell一般下载模型的路径
        symlink_name='content' # 模型名称
    )
    # 因为默认的Sessions路径太深（原作者是为了Colab而设计的），所以要在根目录创建一个链接到Sessions的快捷方式
    create_symlink(
        source_path='/root/autodl-tmp/content/gdrive/MyDrive/Fast-Dreambooth/Sessions', 
        symlink_path='/root',
        symlink_name='Sessions' 
    )
    
    


# 安装Diffusers

In [None]:
import os
import subprocess

# 定义get_thisrepo_root函数，获取当前repo所在的父目录
def get_thisrepo_root():
    """Return the parent directory of the git top-level directory.

    Runs ``git rev-parse --show-toplevel`` and applies ``os.path.dirname`` to
    its stripped standard output. When git prints nothing to stdout the result
    is the empty string.
    """
    completed = subprocess.run(
        ["git", "rev-parse", "--show-toplevel"],
        capture_output=True,
        text=True,
    )
    toplevel = completed.stdout.strip()
    # One level above the repository root
    return os.path.dirname(toplevel)

installed_libraries = subprocess.run(["pip", "freeze"], capture_output=True).stdout.decode().split("\n")
# 检查 diffusers 库是否已经安装
diffusers_installed = False
for library in installed_libraries:
    if library.startswith("diffusers"):
        diffusers_installed = True
        break

# 如果尚未安装 diffusers 库，则下载并安装它
if not diffusers_installed:
    print('检测到Diffusers还没安装，正在安装...')
    # 获取根目录的父目录
    parent_dir = get_thisrepo_root()
    content_folder=os.path.join(parent_dir,'./content')

    # 检测 diffusers 目录是否存在
    if not os.path.exists(os.path.join(content_folder, 'diffusers')):
        # 如果 diffusers 目录不存在，则进行 clone
        !cd $content_folder &&\
        git clone --branch updt https://github.com/TheLastBen/diffusers
    else:
        # 如果 diffusers 目录已经存在，则不进行 clone，直接输出提示信息
        print('[1;32mdiffusers 目录已经存在，不进行 clone 操作。\033[0m')    
    
    !pip uninstall -y diffusers
    !pip install /openbayes/home/content/diffusers
    print('[1;32m已成功安装Diffusers!\033[0m')
else:
    print('[1;32m已安装Diffusers，跳过下载和安装!\033[0m')

# 找到Diffusers项目中的dreambooth.py进行打印信息方面的修改

In [None]:
# 定义get_thisrepo_root函数，获取当前repo所在的父目录
def get_thisrepo_root():
    """Return the parent directory of the git top-level directory.

    Runs ``git rev-parse --show-toplevel`` and applies ``os.path.dirname`` to
    its stripped standard output. When git prints nothing to stdout the result
    is the empty string.
    """
    completed = subprocess.run(
        ["git", "rev-parse", "--show-toplevel"],
        capture_output=True,
        text=True,
    )
    toplevel = completed.stdout.strip()
    # One level above the repository root
    return os.path.dirname(toplevel)

# 定义 get_file_path 函数，遍历目录，找到与 anchor_file 位于同一目录下的 target_file
def get_file_path(root_dir, target_file, anchor_file):
    """
    Find ``target_file`` in a directory that also contains ``anchor_file``.

    Parameters
    ----------
    root_dir : str
        Directory tree to search (walked top-down).
    target_file : str
        Name of the file to locate.
    anchor_file : str
        Name of a file that must exist next to ``target_file``.

    Returns
    -------
    str or None
        Full path of the first ``target_file`` found with ``anchor_file``
        beside it, or ``None`` when there is no such directory.
    """
    for current_dir, _subdirs, filenames in os.walk(root_dir, topdown=True):
        if target_file not in filenames:
            continue
        anchor_path = os.path.join(current_dir, anchor_file)
        target_path = os.path.join(current_dir, target_file)
        # The anchor must exist and resolve to the very same directory
        shares_dir = os.path.dirname(anchor_path) == os.path.dirname(target_path)
        if os.path.exists(anchor_path) and shares_dir:
            return target_path
    return None

# The diffusers checkout is expected at <parent of this repo>/content/diffusers
parent_dir = get_thisrepo_root()
content_folder=os.path.join(parent_dir,'./content')
diffusers_dir=os.path.realpath(os.path.join(content_folder, 'diffusers'))

# Locate the train_dreambooth.py that sits next to train_dreambooth_kg.py.
# get_file_path is the lookup helper defined in this cell; the name called
# here before, get_train_dreambooth_py_dir, is not defined anywhere in the notebook.
# Despite the variable name, the value is the full path of the file (or None).
train_dreambooth_dir=get_file_path(
    diffusers_dir,
    target_file='train_dreambooth.py',
    anchor_file='train_dreambooth_kg.py'
)

print(train_dreambooth_dir)


# 清理Notebook中的输出，对Colab版中的路径进行替换，以适应openbayes或autodl的路径

In [None]:
import json
import os
import re


def modify_file_name(file_path):
    """Return ``file_path`` with "-modified" inserted in front of the extension."""
    stem, suffix = os.path.splitext(file_path)
    return f"{stem}-modified{suffix}"


def _source_lines(cell):
    """Return the cell source as a list of lines (a single-string source is split, line ends kept)."""
    source = cell["source"]
    if isinstance(source, str):
        return source.splitlines(keepends=True)
    return list(source)


def _rewrite_source_line(line, old_root, new_root, proxy_url):
    """Apply the per-line Colab-to-local rewrites to one source line and return the result."""
    # Lines that start with a URL keep their path untouched
    if not line.startswith(("http", "https")):
        line = line.replace(old_root, new_root)
    # Colab-only imports and Google Drive mounting are commented out
    if "google.colab" in line or "drive.mount" in line:
        line = "# " + line
    # Shell downloads run the proxy command first
    if "!wget" in line:
        line = line.replace("!wget", "!$proxy && wget")
    # wget.download(<args>) becomes proxyWget(<args>, os.getcwd(), '<proxy url>').
    # The arguments come from a capture group: str.strip("wget.download()")
    # strips a *set of characters* and also ate characters of the arguments.
    match = re.search(r"wget\.download\((.+)\)", line)
    if match:
        # Keep the leading whitespace of the line, not one space per space in it
        indent = line[: len(line) - len(line.lstrip())]
        line = f"{indent}proxyWget({match.group(1)}, os.getcwd(), '{proxy_url}')\n"
    return line


def replace_root_path(notebook: str, old_root: str, new_root: str) -> str:
    """Write a cleaned copy of ``notebook`` adapted to the detected platform.

    Steps: drop cells containing platform-specific marker strings; for every
    code cell reset metadata, execution count and outputs; replace
    ``old_root`` with ``new_root``; comment out Colab / Google Drive lines;
    route wget downloads through the proxy; normalise quotes on ``@param``
    lines; prepend the proxy import header to cells that mention ``wget``;
    remove the quiet flags of pip and wget; unquote numbers written in
    scientific notation.

    Reads the globals ``proxyURL``, ``detect_environment`` and
    ``modify_file_name``.

    Parameters
    ----------
    notebook : str
        Path of the notebook to read.
    old_root : str
        Root path used by the original notebook.
    new_root : str
        Root path to use instead.

    Returns
    -------
    str
        Path of the written notebook (``modify_file_name(notebook)``).

    Raises
    ------
    RuntimeError
        If the detected environment is neither "OpenBayes" nor "AutoDL".
    """
    with open(notebook, "r", encoding="utf-8") as f:
        nb = json.load(f)

    env_name=detect_environment()["env_name"]

    # Marker strings of cells that are dropped on every supported platform
    common_markers = [
        '@markdown #Concept Images',
        'Only if you have trouble connecting to the local server.',
        "@markdown - [Create a write access token](https://huggingface.co/settings/tokens) ",
        "@markdown - Upload showcase images of your trained model",
        "Display the list of sessions from your gdrive and choose which ones to remove.",
        "Test The Trained Model",
        "Upload The Trained Model to Hugging Face",
        "@markdown  - Important! Choose the correct version and resolution of the model\n"
    ]
    if env_name=="OpenBayes":
        # OpenBayes additionally drops the Instance Images cell
        delList = ['@markdown #Instance Images'] + common_markers
    elif env_name=="AutoDL":
        delList = common_markers
    else:
        # Previously this surfaced as an unbound-variable error further down
        raise RuntimeError(
            f"Unsupported environment {env_name!r}: expected 'OpenBayes' or 'AutoDL'"
        )

    # Build a new list instead of removing from the list being scanned
    nb["cells"] = [
        cell
        for cell in nb["cells"]
        if not any(
            keyword in line for line in _source_lines(cell) for keyword in delList
        )
    ]

    import_header = [
        s + "\n"
        for s in (
            "import sys",
            "from IPython.utils import capture",
            "sys.path.append('../../')",
            "from func.env import setProxyCLI,proxyWget",
            "cb=setProxyCLI()",
            "proxy=cb['proxy']",
            "region=cb['region']",
            "proxyURL=cb['proxyURL']",
            "",
        )
    ]
    # A quoted number in scientific notation, e.g. "1e-5"
    sci_pattern = re.compile(r'"([\d.]+e[+-]\d+)"')

    for cell in nb["cells"]:
        if cell["cell_type"] != "code":
            continue
        # Collapse the cell and clear its execution state
        cell["metadata"] = {"collapsed": True}
        cell["execution_count"] = None
        cell["outputs"] = []

        # Each rewrite works on the result of the previous one, so none of
        # them discards an earlier change to the same line.
        source = [
            _rewrite_source_line(line, old_root, new_root, proxyURL)
            for line in _source_lines(cell)
        ]
        # @param annotations use double quotes only
        source = [
            line.replace("'", '"') if "@param" in line else line for line in source
        ]
        # Cells that mention wget get the proxy import header once, at the top
        if any("wget" in line for line in source):
            source = import_header + source
        # Make pip and wget verbose again
        source = [
            line.replace("pip install -q", "pip install").replace("wget -q", "wget")
            for line in source
        ]
        # Drop the quotes around scientific-notation numbers
        source = [sci_pattern.sub(r"\1", line) for line in source]
        cell["source"] = source

    # Save next to the input under the "-modified" name
    newNotebook = modify_file_name(notebook)
    with open(newNotebook, "w", encoding="utf-8") as f:
        json.dump(nb, f)

    return newNotebook

# Rewrite the downloaded notebook: the "/content" root becomes the platform's content folder.
newNotebookPath = replace_root_path(
    notebook=fastDreamBoothPath, old_root="/content", new_root=content_folder
)

# # Format the rewritten notebook
# !black $newNotebookPath

print(newNotebookPath)

# 根据特征片段进行整个Cell的替换

In [None]:
import json
import copy

def modify_ipynb_file(path_to_modified, path_to_clipboard, keywords=None):
    """
    Replace cells of one notebook with same-keyword cells of a "clipboard" notebook.

    For every cell of ``path_to_modified`` that contains one of ``keywords``,
    the cell source is replaced by a deep copy of the source of the first
    clipboard cell containing that same keyword. Cells without a keyword, and
    cells whose keyword has no clipboard counterpart, are left untouched. The
    file at ``path_to_modified`` is rewritten in place.

    Parameters
    ----------
    path_to_modified : str
        Notebook that is modified in place.
    path_to_clipboard : str
        Notebook providing the replacement cells.
    keywords : list of str, optional
        Marker strings identifying replaceable cells; defaults to
        ``['@markdown # Dependencies']``.
    """
    if keywords is None:
        keywords = ['@markdown # Dependencies']

    with open(path_to_modified, 'r', encoding='utf-8') as f:
        modified = json.load(f)
    with open(path_to_clipboard, 'r', encoding='utf-8') as f:
        clipboard = json.load(f)

    def _contains(cell, keyword):
        # True when any source line of the cell contains the keyword
        return any(keyword in line for line in cell['source'])

    for cell_modified in modified['cells']:
        # Which marker, if any, identifies this cell?
        keyword = next((k for k in keywords if _contains(cell_modified, k)), None)
        if keyword is None:
            continue
        # The donor must carry the SAME marker, not just any marker
        donor = next(
            (c for c in clipboard['cells'] if _contains(c, keyword)), None
        )
        if donor is None:
            continue
        # Deep copy so both notebooks never share list objects
        cell_modified['source'] = copy.deepcopy(donor['source'])
        print(f'代码单元格包含特征字符串“{keyword}”的内容已被代码剪贴板中包含相同特征字符串的代码单元格所替换。')

    # Write the modified notebook back to disk
    with open(path_to_modified, 'w', encoding='utf-8') as f:
        json.dump(modified, f)

# 示例：修改被修改对象文件，使用代码剪贴板文件中的内容

# Replace the marked cells of the rewritten notebook with the cells kept in the OpenBayes clipboard notebook.
modify_ipynb_file(
    path_to_modified=newNotebookPath,
    path_to_clipboard='./clip/clipForOpenbayes.ipynb'
)

# 对指定的变量名进行字符串和参数的替换

In [None]:
import json
import os

def modify_file_name(file_path):
    """Return ``file_path`` with "-customData" inserted in front of the extension."""
    stem, suffix = os.path.splitext(file_path)
    return f"{stem}-customData{suffix}"


# 读取修改后的notebook
def update_ipynb_vars(inputFilePath, outputFilePath, var_map):
    """Write a copy of a notebook with selected form parameters overridden.

    In every code cell, a line is rewritten when it contains the ``#@`` form
    marker and starts with ``<var> =`` or ``<var>=`` for a key of ``var_map``.
    The rewritten line is ``<var>=<value> #@<original annotation>``; string
    values are wrapped in double quotes, other values are formatted as they
    are. Code cells also get collapsed/hidden metadata.

    Parameters
    ----------
    inputFilePath : str
        Notebook to read.
    outputFilePath : str
        Where the updated notebook is written.
    var_map : dict
        Maps variable names to their new values.

    Returns
    -------
    str or None
        ``outputFilePath`` when the written file reads back identical to the
        data that was written, otherwise ``None``.
    """
    with open(inputFilePath, "r", encoding="utf-8") as f:
        notebook_data = json.load(f)

    # Lines carrying "#@" are the ones offered to the user as form parameters
    identifier = "#@"
    # Render each value once, the way it has to appear in the source line
    rendered = {
        var: f'"{val}"' if isinstance(val, str) else val
        for var, val in var_map.items()
    }

    for cell in notebook_data["cells"]:
        if cell["cell_type"] != "code":
            continue

        cell["metadata"]={
            "collapsed": True,
            "jupyter": {
                "outputs_hidden": True,
                "source_hidden": True
            },
            "tags": []
        }

        for i, thisLine in enumerate(cell["source"]):
            if identifier not in thisLine:
                continue
            for var, val in rendered.items():
                if thisLine.startswith((f"{var} =", f"{var}=")):
                    # maxsplit=1 keeps the whole annotation even if it contains "#@" again
                    annotation = thisLine.split(identifier, 1)[1]
                    cell["source"][i] = f"{var}={val} {identifier+annotation}"
                    break

    with open(outputFilePath, "w", encoding="utf-8") as f:
        json.dump(notebook_data, f)

    # Verify by reading the file back as JSON. Comparing the dict with the raw
    # file text, as before, could never be equal and always reported success.
    with open(outputFilePath, "r", encoding="utf-8") as f:
        written_data = json.load(f)
    if written_data != notebook_data:
        print("修改失败")
        return None

    print("修改成功")
    return outputFilePath


# Values to write into the notebook's form parameters (variable name -> new value)
replacements = {
    "Model_Version": "V2.1-768px",
    "Session_Name": "FuXingHao768",
    "Crop_size": 768,
    "Remove_existing_instance_images": False,
    # "IMAGES_FOLDER_OPTIONAL": "/openbayes/home/content/gdrive/MyDrive/Fast-Dreambooth/Sessions/FuXingHao768/instance_images",
    "Crop_images": False,
    "Resolution": 768,
}

# Apply the values; the result is written to a "-customData" copy of the rewritten notebook
updatedDataNotebookPath = modify_file_name(newNotebookPath)
update_ipynb_vars(newNotebookPath, updatedDataNotebookPath, replacements)

# 检查最终版本的ipynb里面的赋值情况

In [None]:
import json
import os


def print_with_color(string, keywords):
    """Return ``string`` with every space-separated word wrapped in ANSI colour codes.

    Colour choice per word, in this order: an absolute path (surrounding
    double quotes ignored) gets white on green; a numeric word gets cyan; a
    word listed in ``keywords`` gets the escape sequence stored there;
    anything else gets bold yellow. Each word is followed by a space and a
    reset sequence. Nothing is printed despite the name.

    Parameters
    ----------
    string : str
        Text to colourise.
    keywords : dict
        Maps a word to the escape sequence to put in front of it.

    Returns
    -------
    str
        The colourised text.
    """
    reset = " \033[0m"
    pieces = []
    for token in string.split(" "):
        if os.path.isabs(token.strip('"')):
            prefix = "\033[37;42m"
        elif token.isnumeric():
            prefix = "\033[0;36m"
        elif token in keywords:
            prefix = keywords[token]
        else:
            prefix = "\033[1;33m"
        pieces.append(f"{prefix}{token}{reset}")
    return "".join(pieces)


# Word -> ANSI escape sequence that print_with_color puts in front of that word.
keywords = {
    "=": "\033[1;34m",  # "=": bold blue
    "False": "\033[1;31m",  # False: bold red
    "True": "\033[0;32m",  # True: green
    "config": "\033[1;33m",  # config: bold yellow
    "--global": "\033[1;34m",  # --global: bold blue
    '""': "\033[1;35m",  # empty string literal: bold purple
    "user.email": "\033[1;36m",  # user.email: bold cyan
    "user.name": "\033[1;37m",  # user.name: bold white
    "http.proxy": "\033[1;38m",  # http.proxy: NOTE(review) code 38 normally takes further arguments; rendered colour unclear — confirm
    "proxyURL": "\033[1;39m",  # proxyURL: bold, default foreground colour
    "cd": "\033[1;31m",  # cd: bold red
    "dir_path": "\033[1;41m",  # dir_path: bold on a red background
}

# colored_string = print_with_color("git config --global user.email user@example.com", keywords)
# print(colored_string)


def detect_constant_assignments(inputFilePath: str):
    """Print the form-parameter assignments found in a notebook, in colour.

    A source line of a code cell is reported when it contains exactly one
    ``=``, contains ``@param`` and carries the ``#@para`` marker after the
    ``=``. Each variable name is reported once (first occurrence). The value
    part is coloured with ``print_with_color`` using the global ``keywords``.

    Parameters
    ----------
    inputFilePath : str
        Path of the notebook to inspect.
    """
    marker = "#@para"
    # Variable names that were already printed
    seen_names = set()
    with open(inputFilePath, "r", encoding="utf-8") as f:
        ipynb_data = json.load(f)

    for cell in ipynb_data["cells"]:
        if cell["cell_type"] != "code":
            continue
        for line in cell["source"]:
            parts = line.split("=")
            if len(parts) != 2 or "@param" not in line:
                continue
            var_name = parts[0].strip()
            # partition never raises, unlike unpacking split(): a line such as
            # "x = 1 # @param" (marker absent) is skipped instead of crashing,
            # and a repeated marker stays in the annotation part.
            before, found, after = parts[1].strip().partition(marker)
            if not found or var_name in seen_names:
                continue
            seen_names.add(var_name)
            print(
                f"\033[1;37m{var_name}"
                + print_with_color(f" = {before}", keywords)
                + "\033[1;30m"
                + f"{marker}{after}"
            )


# Show the parameter assignments of the final ("-customData") notebook
detect_constant_assignments(updatedDataNotebookPath)

# 可以用线上的这个工具来进行Escape Code Color的合成：https://ansi.gabebanks.net/
def refsColor():
    """Print a reference chart of ANSI escape sequences: foreground colours, background colours, then two combined samples."""
    samples = (
        "\033[1;33m \t\t\tprint颜色参考 \033[0m \n",
        # Foreground colours
        "\033[1;33m \t 字体颜色:\n",
        "\033[0m none:\n",
        "\033[0;30m back:\n",
        "\033[1;30m dark_back:\n",
        "\033[0;34m blue:\n",
        "\033[1;34m light_blue:\n",
        "\033[0;32m green:\n",
        "\033[1;32m light_green:\n",
        "\033[0;36m cyan:\n",
        "\033[1;36m light_scan:\n",
        "\033[0;31m red:\n",
        "\033[1;31m light_read:\n",
        "\033[0;35m purple:\n",
        "\033[1;35m light_purple:\n",
        "\033[0;33m brown:\n",
        "\033[1;33m yellow:\n",
        "\033[0;37m light_yellow:\n",
        "\033[1;37m white:\n",
        # Background colours
        "\033[1;33m \t 背景颜色:\n",
        "\033[0m none:\n",
        "\033[0;40m back:\n",
        "\033[0;44m blue:\n",
        "\033[0;42m green:\n",
        "\033[0;46m cyan:\n",
        "\033[0;41m red:\n",
        "\033[0;45m purple:\n",
        "\033[0;43m brown:\n",
        "\033[0;47m light_yellow:\033[0m\n",
        # Background and foreground combined
        "\033[1;33m \t 背景字体颜色:\n",
        "\033[47;31m hello world\033[?25l",
        "\033[42;50m hello world\033[?25l",
    )
    for sample in samples:
        print(sample)

# refsColor()