# 准备工作

In [1]:
#@ Check GPU information; this cell can be skipped on machines without a GPU
!nvidia-smi -L

'nvidia-smi' 不是内部或外部命令，也不是可运行的程序
或批处理文件。


In [2]:
import os, subprocess

In [2]:
# Install the required libraries; skip this cell if they are already installed.
def setup():
    """Install the notebook's third-party dependencies with pip.

    Runs one ``pip install`` per package (gradio, open_clip_torch,
    clip-interrogator) and prints each command's output.

    pip is invoked as ``<current interpreter> -m pip`` so the packages land
    in the environment the kernel is actually running, rather than in
    whichever ``pip`` executable happens to be first on PATH.

    Returns:
        None. A failed install is reported with a printed warning and does
        not raise, so the remaining packages are still attempted.
    """
    # Local import: the notebook's import cell only brings in os and subprocess.
    import sys

    packages = ['gradio', 'open_clip_torch', 'clip-interrogator']

    for package in packages:
        cmd = [sys.executable, '-m', 'pip', 'install', package]
        # Merge stderr into stdout so pip errors show up in the cell output.
        result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
        # pip output is not guaranteed to be UTF-8 (e.g. localized consoles);
        # replace undecodable bytes instead of crashing the cell.
        print(result.stdout.decode('utf-8', errors='replace'))
        if result.returncode != 0:
            print(f"WARNING: installing {package} failed with exit code {result.returncode}")

# Call the function to install the required libraries
setup()

Defaulting to user installation because normal site-packages is not writeable

Defaulting to user installation because normal site-packages is not writeable

Defaulting to user installation because normal site-packages is not writeable



In [3]:
# Names of the CLIP model and the caption model to use
caption_model_name = 'blip-large' # @param ["blip-base", "blip-large", "git-large-coco"]
clip_model_name = 'ViT-L-14/openai' # @param ["ViT-L-14/openai", "ViT-H-14/laion2b_s32b_b79k"]

In [4]:
# 导入 Gradio 和 clip-interrogator 库
import gradio as gr   # GUI库
from clip_interrogator import Config, Interrogator

In [5]:
# Configure the Interrogator with the selected CLIP and caption model names,
# then construct it. `ci` is the shared instance used by the functions below.
config = Config()
config.clip_model_name = clip_model_name
config.caption_model_name = caption_model_name
ci = Interrogator(config)

Loading caption model blip-large...
Loading CLIP model ViT-L-14/openai...
Loaded CLIP model and data in 9.28 seconds.


In [6]:
# Analyze the input image
def image_analysis(image):
    """Score an image against five label categories.

    The image is converted to RGB and turned into a feature vector with
    ``ci.image_to_features``. For each of the mediums, artists, movements,
    trendings and flavors tables, the five top-ranked entries are requested
    and then paired with the values returned by ``ci.similarities``.

    Args:
        image: An object exposing ``convert('RGB')`` (the UI passes a PIL image).

    Returns:
        A 5-tuple of dicts ``(medium, artist, movement, trending, flavor)``,
        each mapping a ranked entry to its similarity value.
    """
    rgb_image = image.convert('RGB')
    # Convert the image into a feature vector
    features = ci.image_to_features(rgb_image)

    # First pass: the top five entries of every category, in output order.
    label_tables = (ci.mediums, ci.artists, ci.movements, ci.trendings, ci.flavors)
    top_entries = [table.rank(features, 5) for table in label_tables]

    # Second pass: pair each entry with its similarity to the image.
    rankings = [
        dict(zip(entries, ci.similarities(features, entries)))
        for entries in top_entries
    ]
    return tuple(rankings)

In [7]:
# Convert an image into prompt text
def image_to_prompt(image, mode):
    """Generate a prompt for an image using the selected interrogation mode.

    Before interrogating, ``ci.config.chunk_size`` and
    ``ci.config.flavor_intermediate_count`` are both set to 2048 when the
    configured CLIP model is "ViT-L-14/openai", and to 1024 otherwise.

    Args:
        image: An object exposing ``convert('RGB')`` (the UI passes a PIL image).
        mode: One of 'best', 'classic', 'fast' or 'negative'.

    Returns:
        The result of the matching ``ci.interrogate*`` call, or ``None`` when
        ``mode`` is not one of the four known values.
    """
    # Configure the Interrogator: both settings share the same value.
    uses_vit_l = ci.config.clip_model_name == "ViT-L-14/openai"
    batch_size = 2048 if uses_vit_l else 1024
    ci.config.chunk_size = batch_size
    ci.config.flavor_intermediate_count = batch_size

    rgb_image = image.convert('RGB')

    # Ordered (mode, Interrogator method name) pairs, checked with == in turn.
    interrogators = (
        ('best', 'interrogate'),
        ('classic', 'interrogate_classic'),
        ('fast', 'interrogate_fast'),
        ('negative', 'interrogate_negative'),
    )
    for known_mode, method_name in interrogators:
        if mode == known_mode:
            return getattr(ci, method_name)(rgb_image)
    return None

In [8]:
# Silence the "CUDA is not available" warning on machines without a GPU
import warnings
warnings.filterwarnings("ignore", message="User provided device_type of 'cuda', but CUDA is not available. Disabling")

# 生成图片解析GUI

In [9]:
# Build the prompt-generation tab
def prompt_tab():
    """Lay out the widgets of the prompt tab and wire up its button.

    Inside a Column, a Row holds the image input next to a nested Column
    containing the mode selector; the prompt textbox sits below that Row.
    The button is created after the Column block has closed. Clicking it
    calls ``image_to_prompt`` with the image and the mode and writes the
    result into the textbox.
    """
    available_modes = ['best', 'fast', 'classic', 'negative']

    with gr.Column():
        with gr.Row():
            # Image input, delivered to the callback as a PIL image
            source_image = gr.Image(type='pil', label="Image")
            with gr.Column():
                # Radio selector for the prompt-generation mode
                mode_choice = gr.Radio(available_modes, label='Mode', value='best')
        # Textbox that receives the generated prompt
        prompt_box = gr.Textbox(label="Prompt")

    # Button that triggers prompt generation
    generate_button = gr.Button("Generate prompt")
    generate_button.click(
        image_to_prompt,
        inputs=[source_image, mode_choice],
        outputs=prompt_box,
    )

# Build the image-analysis tab
def analyze_tab():
    """Lay out the widgets of the analysis tab and wire up its button.

    Inside a Column, the first Row holds the image input and the second Row
    holds five Label widgets (Medium, Artist, Movement, Trending, Flavor),
    each created with ``num_top_classes=5``. The button is created after the
    Column block has closed. Clicking it calls ``image_analysis`` with the
    image and sends its five results to the five labels, in that order.
    """
    category_titles = ("Medium", "Artist", "Movement", "Trending", "Flavor")

    with gr.Column():
        with gr.Row():
            # Image input, delivered to the callback as a PIL image
            uploaded_image = gr.Image(type='pil', label="Image")
        with gr.Row():
            # One result Label per analysis category
            category_labels = [
                gr.Label(label=title, num_top_classes=5)
                for title in category_titles
            ]

    # Button that triggers the analysis
    analyze_button = gr.Button("Analyze")
    analyze_button.click(image_analysis, inputs=uploaded_image, outputs=category_labels)

# Build the GUI with two tabs
with gr.Blocks() as ui:
    # Tab for prompt generation
    with gr.Tab("Prompt"):
        prompt_tab()
    # Tab for image analysis
    with gr.Tab("Analyze"):
        analyze_tab()

# Launch the GUI; show_api controls whether API info is shown and debug controls whether debug mode is enabled
ui.launch(show_api=False, debug=False)

Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.




100%|█| 6/6 [00:00<00:00, 12.03it/s]
100%|█| 50/50 [00:01<00:00, 45.70it/
