## 搜索文件函数

In [None]:
import os

def find_file(root, target):
    # 下面的命令会将标准错误输出重定向到/dev/null，这样就会抑制错误信息。
    # 这里的目的是忽略标准错误输出中可能产生的重定向错误信息。
    target_list = !find $root -name $target 2>/dev/null
    # target_list = list(filter(lambda x: x.endswith('.whl'), target))
    print(target_list)
    if target_list:
        target_path = target_list[0]
        return target_path
    else:
        return None

def find_file_list(root_dir, target_name, max_depth=2):
    """Collect every file named exactly ``target_name`` near ``root_dir``.

    The walk descends at most ``max_depth`` directory levels below
    ``root_dir``. The default of 2 checks ``root_dir`` itself plus two levels
    of subdirectories, i.e. three directory levels in total.

    Args:
        root_dir: Directory to start from.
        target_name: Exact file name to match (no globbing).
        max_depth: Deepest subdirectory level whose files are still checked.

    Returns:
        Matching paths sorted by string length, shortest first, or ``None``
        when nothing was found.
    """
    paths = []
    for dirpath, dirnames, filenames in os.walk(root_dir, topdown=True):
        if target_name in filenames:
            paths.append(os.path.join(dirpath, target_name))
        # Measure depth relative to root_dir via relpath instead of comparing
        # separator counts. Counting separators miscounted by one when
        # root_dir had a trailing slash (e.g. "/root/autodl-fs/") or was "/",
        # which let the walk go one level deeper than intended.
        rel = os.path.relpath(dirpath, root_dir)
        depth = 0 if rel == os.curdir else rel.count(os.sep) + 1
        if depth >= max_depth:
            # Clearing dirnames in place stops os.walk from descending further.
            del dirnames[:]
    paths.sort(key=len)
    return paths or None

def find_file_in_input(target):
    """Look for ``target`` in /input0 through /input4, in that order.

    Each directory is searched with ``find_file``; the first non-empty result
    is returned immediately. Returns ``None`` when no directory matches.
    """
    for mount_root in (f"/input{idx}" for idx in range(5)):
        hit = find_file(mount_root, target)
        if not hit:
            continue
        return hit
    return None
    

## 挂载器函数

In [None]:
import shutil

def create_symlink(source_path, symlink_path, symlink_name=None):
    """Create (or replace) a symbolic link that points at ``source_path``.

    Two calling conventions are supported:
      * ``create_symlink(src, "/dir/name")``: ``symlink_path`` is the full
        link path, and the link name is taken from its basename;
      * ``create_symlink(src, "/dir", "name")``: directory and name are given
        separately.

    Whatever already occupies the link location is removed first:
      * an existing symlink (including a dangling one) is unlinked;
      * a real directory is deleted recursively with ``shutil.rmtree``;
      * a regular file is deleted.

    Args:
        source_path: Target the link should point to.
        symlink_path: Link directory, or the full link path when
            ``symlink_name`` is omitted.
        symlink_name: Optional link file name inside ``symlink_path``.

    Returns:
        The full path of the created link.
    """
    # Derive the name from the path when it was not given separately.
    if not symlink_name:
        symlink_name = os.path.basename(symlink_path)
        symlink_path = os.path.dirname(symlink_path)

    symlink_file = os.path.join(symlink_path, symlink_name)

    # Test islink() first: os.path.exists() follows links and returns False
    # for a dangling symlink. That used to leave the stale link in place and
    # make os.symlink() fail with FileExistsError.
    if os.path.islink(symlink_file):
        os.unlink(symlink_file)
    elif os.path.isdir(symlink_file):
        shutil.rmtree(symlink_file)
    elif os.path.exists(symlink_file):
        os.remove(symlink_file)

    os.symlink(source_path, symlink_file)

    # Report the result: bold green link path -> blue target.
    if os.path.islink(symlink_file):
        print(f"\033[1m\033[92m{symlink_file}\033[0m ➜ \033[34m{os.readlink(symlink_file)}\033[0m")
    else:
        print(f"{symlink_file} 不是一个符号链接。")
    return symlink_file

def mount(source_path, symlink_path, symlink_name):
    """Alias for ``create_symlink``: link ``symlink_path/symlink_name`` to ``source_path``.

    Returns whatever ``create_symlink`` returns.
    """
    forwarded = dict(
        source_path=source_path,
        symlink_path=symlink_path,
        symlink_name=symlink_name,
    )
    return create_symlink(**forwarded)
    
def find_file_in_input_and_mount(target, symlink_path, symlink_name=None):
    """Locate ``target`` under /input0../input4 and symlink it into ``symlink_path``.

    Args:
        target: File name (or ``find -name`` glob) to search for.
        symlink_path: Directory in which to create the link.
        symlink_name: Link file name; defaults to ``target``.

    Returns:
        ``(source_path, symlink_file_path)``, where the second element is
        whatever ``mount`` returns. When ``target`` is not found, nothing is
        linked and ``(None, None)`` is returned.
    """
    source_path = find_file_in_input(target)
    if not symlink_name:
        symlink_name = target
    print(source_path, symlink_path, symlink_name)
    if not source_path:
        # Without this guard, os.symlink(None, ...) failed with a bare TypeError.
        print(f"未找到 {target}，跳过挂载")
        return None, None
    symlink_file_path = mount(
        source_path=source_path,
        symlink_path=symlink_path,
        symlink_name=symlink_name,
    )
    return source_path, symlink_file_path
    

## 模型查找与链接函数（含已注释的文件夹结构搜索代码）

In [None]:
def check_file(file_path, size):
    """Return True only when ``file_path`` exists and is exactly ``size`` bytes."""
    if not os.path.exists(file_path):
        return False
    actual_size = os.path.getsize(file_path)
    return actual_size == size

import os

def find_target_model_dist(env_name, model_name):
    """Return the path of ``model_name`` on the platform's shared storage.

    * ``"AutoDL"``: searches whichever of /root/autodl-nas and /root/autodl-fs
      exist (/root/autodl-pub is intentionally left out).
    * ``"OpenBayes"``: searches /input0 through /input4.

    Each root is searched with ``find_file_list`` (shortest match first), and
    the first hit wins. Returns ``None`` for other environments or when
    nothing matches.
    """
    if env_name == 'AutoDL':
        candidates = ('/root/autodl-nas', '/root/autodl-fs')
        search_roots = [root for root in candidates if os.path.exists(root)]
    elif env_name == 'OpenBayes':
        search_roots = [f'/input{idx}' for idx in range(0, 5)]
    else:
        search_roots = []

    for root in search_roots:
        hits = find_file_list(root, model_name)
        if hits:
            return hits[0]
    return None

def check_model_and_make_symbollink_general(model_name, models_dir, use_symlink=True, repo_dir=None, env_name=None):
    """Make ``model_name`` available in ``models_dir`` by symlinking or copying it.

    The model is located with ``find_target_model_dist``. If a file of the
    same size is already at the destination, nothing is done. Otherwise the
    destination is replaced by a symlink (``use_symlink=True``) or a copy.

    Args:
        model_name: File name of the model to look for.
        models_dir: Destination directory; relative paths are resolved
            against ``repo_dir``. Created if missing.
        use_symlink: Link to the model instead of copying it.
        repo_dir: Repository root; required.
        env_name: Environment name passed to ``find_target_model_dist``. When
            omitted, the module-level global ``env_name`` is used.

    Returns:
        The directory containing the located model, or ``None`` if it was
        not found.

    Raises:
        ValueError: ``repo_dir`` was not provided.
        NameError: ``env_name`` was neither passed nor defined globally.
    """
    if repo_dir is None:
        raise ValueError("repo_dir must be provided.")
    if env_name is None:
        # Backward compatible: fall back to the global the caller used to rely on.
        try:
            env_name = globals()['env_name']
        except KeyError:
            raise NameError("env_name must be passed or defined globally") from None

    if os.path.isabs(models_dir):
        abs_models_dir = models_dir
    else:
        abs_models_dir = os.path.realpath(os.path.join(repo_dir, models_dir))

    if not os.path.exists(abs_models_dir):
        os.makedirs(abs_models_dir)
        print(f'创建目录 {abs_models_dir}')

    # realpath follows an existing symlink, so a link created by an earlier
    # run resolves to the source file and passes the size check below.
    model_path = os.path.realpath(os.path.join(abs_models_dir, model_name))

    target_model_dist = find_target_model_dist(env_name, model_name)
    if not target_model_dist:
        print(f'不存在{model_name}')
        return None

    print(target_model_dist)

    # check_file() already tests existence; no separate os.path.exists() needed.
    if check_file(model_path, os.path.getsize(target_model_dist)):
        print(f'{model_name}的模型已经在目标位置，无需创建软连接')
    elif use_symlink:
        create_symlink(
            source_path=target_model_dist,  # located model file on shared storage
            symlink_path=abs_models_dir,  # directory the model must appear in
            symlink_name=model_name,  # link keeps the model's file name
        )
        print('创建符号链接完成！')
    else:
        # The original code re-tested the same condition here, which left an
        # unreachable "already in place" branch. It also shelled out to an
        # unquoted `cp`, which broke on paths containing spaces.
        import shutil

        print('正在进行复制...')
        shutil.copy(target_model_dist, os.path.join(abs_models_dir, model_name))
        print('复制完成！')
    return os.path.dirname(target_model_dist)
    
def check_model_and_make_symbollink(model_name, models_dir, use_symlink=True):
    """WebUI wrapper around ``check_model_and_make_symbollink_general``.

    Uses the global ``webUIDir`` as ``repo_dir``. When that global is
    undefined or empty, it is recomputed with ``getWebUIDir()`` and stored
    back into the global.

    Returns:
        Whatever ``check_model_and_make_symbollink_general`` returns. This
        value used to be discarded.
    """
    global webUIDir
    # globals().get() handles "never defined" without a bare except. The bare
    # except used to swallow every error, including KeyboardInterrupt and
    # failures raised by getWebUIDir() itself.
    if not globals().get('webUIDir'):
        webUIDir = getWebUIDir()
    return check_model_and_make_symbollink_general(model_name, models_dir, use_symlink, webUIDir)



In [None]:
# import os

# def search_dir_structure(root_dir, structure):
#     """
#     Searches for a specific directory structure in a given root directory.
#     root_dir: str, the root directory to search in
#     structure: dict, the directory structure to search for
#     """
#     dirpaths=[]
#     for dirpath, dirnames, filenames in os.walk(root_dir):
#         match = True
#         for subdir in structure:
#             if subdir not in dirnames:
#                 match = False
#                 break
#             # Check if the subdirectory has the same structure as defined in the structure
#             subdir_path = os.path.join(dirpath, subdir)
#             match = match and search_dir_structure(subdir_path, structure[subdir])
#         if match:
#              dirpaths.append(dirpath)    
#     return dirpaths
        
# structure = {
#     "BLIP": {
#         "data": {},
#         "models": {}
#     },
#     "CodeFormer": {

#     },
#     "k-diffusion": {},
#     "stable-diffusion-stability-ai": {},
#     "taming-transformers": {}
# }

# dir_list = search_dir_structure("/", structure)

# input_repos_dir = None

# for i in dir_list:
#     if 'input' in i:
#         input_repos_dir = i
#         break

# repos_dir = os.path.realpath(os.path.join(webUIDir,'./repositories'))

# print(input_repos_dir)
# print(os.path.dirname(repos_dir))
# repos_in_sd_list = search_dir_structure(os.path.dirname(repos_dir), structure)

# repos_in_sd_dir_toTest = None
# if len(repos_in_sd_list)>0:
#     for i in dir_list:
#         if 'stable-diffusion-webui' in i:
#             repos_in_sd_dir_toTest = i
#             break
            
# if input_repos_dir:
#     if os.path.exists(repos_dir) and search_dir_structure(os.path.dirname(repos_dir), structure):
#         print(f"{repos_dir} 已经存在并拥有正确的Repos")
#     else:
#         if not os.path.exists(repos_dir):
#             !mkdir $repos_dir
#         # !cp -r $input_repos_dir/* $repos_dir
#         if os.path.exists(repos_dir):
#             !rm -rf $repos_dir
#         !ln -s $input_repos_dir $repos_dir
# else:
#     print('找不到对应的数据集')