<a href="https://colab.research.google.com/github/j7tfj7f8k1f/DL-Fine-Tuning/blob/main/clip_inference_gradio.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 前置

In [None]:
# Mount Google Drive at /content/drive so that later cells can read the
# checkpoint, word list and images stored under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Install the libraries imported by this notebook (torch, transformers,
# gradio) together with torchvision.
!pip install torch torchvision transformers gradio
# Presumably prints the transformers environment/version report, useful when
# debugging version mismatches — confirm against the transformers CLI docs.
!transformers-cli env



# CLIP Gradio 使用 - 文搜圖 / 圖搜文

In [None]:
import torch
from PIL import Image
from transformers import ChineseCLIPProcessor, ChineseCLIPModel
import gradio as gr
import os
import base64
from io import BytesIO

# Select the compute device: GPU when available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the pretrained Chinese-CLIP model, then overwrite its weights with the
# fine-tuned checkpoint stored on Google Drive.
model = ChineseCLIPModel.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16")
# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state dict needs, and avoids executing arbitrary pickled code.
state_dict = torch.load(
    "/content/drive/MyDrive/best_clip_model.pth",
    map_location=device,
    weights_only=True,
)
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()  # switch to evaluation mode; this notebook only runs inference

processor = ChineseCLIPProcessor.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16")

# Root folder that is searched (recursively) for gallery images.
image_dir = "/content/drive/MyDrive"
# File the best-matching image is written to. It must be a file path (saving an
# image to a directory path fails) and it is kept outside image_dir so the
# saved result is not picked up as a gallery image by later searches.
output_image_path = "/content/best_match.png"

# Load the Chinese vocabulary, one entry per line. Blank lines are skipped so
# that empty strings never become candidate words.
with open('/content/drive/MyDrive/word_list.txt', 'r', encoding='utf-8') as f:
    vocab = [word for word in (line.strip() for line in f) if word]


# Image -> text search: find the vocabulary entries closest to an image.
def find_similar_words(image, top_k=3):
    """Return the vocabulary entries most similar to ``image``.

    The image is embedded once, every vocabulary entry is embedded in batches,
    and the entries are ranked by cosine similarity to the image.

    Args:
        image: A PIL image (as delivered by ``gr.Image(type="pil")``), or
            ``None`` when nothing was uploaded.
        top_k: Number of entries to return. Non-integer values (e.g. floats
            from a slider) are truncated, and the value is clamped to the
            range ``1..len(vocab)``.

    Returns:
        A newline-separated string of ``"Top i: word"`` lines, or a
        human-readable message when no image was given, the vocabulary is
        empty, or inference failed.
    """
    # No upload yet: report it directly instead of surfacing an AttributeError.
    if image is None:
        return "請先上傳圖片"

    try:
        if not vocab:
            return "詞彙表為空，無法比對"

        # torch.topk needs an int k that does not exceed the number of scores.
        top_k = max(1, min(int(top_k), len(vocab)))

        # The processor expects RGB input.
        if image.mode != "RGB":
            image = image.convert("RGB")

        batch_size = 16  # vocabulary entries encoded per forward pass

        # Release cached, unused GPU memory before a potentially long run.
        if device == "cuda":
            torch.cuda.empty_cache()

        similarities = []
        with torch.no_grad():
            # The image does not change between batches, so embed it once.
            image_inputs = processor(images=image, return_tensors="pt").to(device)
            image_embeds = model.get_image_features(**image_inputs)
            image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)

            for start in range(0, len(vocab), batch_size):
                batch_vocab = vocab[start:start + batch_size]
                text_inputs = processor(
                    text=batch_vocab,
                    return_tensors="pt",
                    padding=True
                ).to(device)
                text_embeds = model.get_text_features(**text_inputs)
                text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
                # (1, batch) -> (batch,): one cosine similarity per word.
                similarities.append(torch.matmul(image_embeds, text_embeds.T).squeeze(0))

        # Concatenate the per-batch scores into one vector aligned with vocab.
        similarity = torch.cat(similarities, dim=0)

        # Pick the highest-scoring entries and format them for display.
        top_k_indices = torch.topk(similarity, top_k).indices.tolist()
        top_k_words = [vocab[idx] for idx in top_k_indices]
        return "\n".join(f"Top {rank}: {word}" for rank, word in enumerate(top_k_words, start=1))

    except Exception as e:
        # UI boundary: show the failure in the result box instead of crashing.
        return f"推理過程中發生錯誤: {str(e)}"

# Text -> image search: find the gallery image closest to a query word.
def find_image_for_word(word):
    """Find the image under ``image_dir`` that best matches ``word``.

    The query is embedded once; every ``.png``/``.jpg``/``.jpeg`` file found
    by walking ``image_dir`` is embedded and compared by cosine similarity.
    The best match is re-saved as PNG and also returned Base64-encoded.

    Args:
        word: Query text. ``None`` or a blank string is rejected up front.

    Returns:
        A 3-tuple ``(saved_image_path, message, base64_png)``. When there is
        no query, no image is found, or an error occurs, the first and third
        items are ``None`` and ``message`` explains why.
    """
    # Reject empty queries before doing an expensive scan of the image folder.
    if word is None or not word.strip():
        return None, "請輸入詞彙", None

    try:
        # Release cached, unused GPU memory before a potentially long run.
        if device == "cuda":
            torch.cuda.empty_cache()

        # output_image_path may be configured as a directory; saving an image
        # needs a file path, so place the result inside that directory then.
        save_path = output_image_path
        if os.path.isdir(save_path):
            save_path = os.path.join(save_path, "best_match.png")
        save_path_abs = os.path.abspath(save_path)

        # Embed the query text once and L2-normalise it.
        with torch.no_grad():
            text_inputs = processor(
                text=[word],
                return_tensors="pt",
                padding=True
            ).to(device)

            text_embeds = model.get_text_features(**text_inputs)
            text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)

        highest_similarity = float("-inf")
        best_match_image_path = None

        # Walk the gallery folder and score every supported image.
        for root, _, files in os.walk(image_dir):
            for file in files:
                # Case-insensitive so that e.g. ".JPG" files are included.
                if not file.lower().endswith(('.png', '.jpg', '.jpeg')):
                    continue

                image_path = os.path.join(root, file)
                # Skip a result saved by an earlier search so it cannot match.
                if os.path.abspath(image_path) == save_path_abs:
                    continue

                try:
                    # The context manager closes the file handle promptly.
                    with Image.open(image_path) as img:
                        image = img.convert("RGB")
                except OSError:
                    # One unreadable or corrupt file must not abort the search.
                    continue

                with torch.no_grad():
                    image_inputs = processor(
                        images=image,
                        return_tensors="pt"
                    ).to(device)

                    image_embeds = model.get_image_features(**image_inputs)
                    image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)

                # Cosine similarity between the query and this image.
                similarity = torch.matmul(text_embeds, image_embeds.T).item()

                if similarity > highest_similarity:
                    highest_similarity = similarity
                    best_match_image_path = image_path

        if best_match_image_path is None:
            return None, "找不到匹配的圖片", None

        # Re-save the winner as PNG so the UI always receives the same format.
        with Image.open(best_match_image_path) as img:
            best_image = img.convert("RGB")
        best_image.save(save_path, format="PNG")

        # Also provide the PNG bytes as a Base64 string.
        buffered = BytesIO()
        best_image.save(buffered, format="PNG")
        img_base64 = base64.b64encode(buffered.getvalue()).decode()

        return save_path, f"匹配圖片：{os.path.basename(best_match_image_path)}，相似度：{highest_similarity:.2f}", img_base64

    except Exception as e:
        # UI boundary: show the failure in the result box instead of crashing.
        return None, f"搜尋過程中發生錯誤: {str(e)}", None

# # 設置 Gradio 介面
# iface = gr.Interface(
#     fn=find_image_for_word,
#     inputs=[gr.Textbox(label="輸入詞彙")],
#     outputs=[gr.Image(type="filepath"), gr.Textbox(), gr.Textbox(label="Base64 編碼")],
#     title="CLIP 中文模型 - 詞彙與圖像匹配",
#     description="輸入一個詞彙，模型會從圖鑑中找到與詞彙最相關的圖片"
# )

# # 啟動共享連結
# iface.launch(share=True, debug=True)

def launch_interface():
    """Build and launch the Gradio UI for both search directions.

    A dropdown selects between image -> words (``find_similar_words``) and
    word -> image (``find_image_for_word``); the input and output widgets that
    do not apply to the selected mode are hidden. The app is launched with a
    public share link and ``debug=True``.
    """
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                # Default to image -> words so the initial widget visibility
                # below matches a real selection and the run button never
                # receives an empty choice.
                function_choice = gr.Dropdown(
                    ["find_similar_words", "find_image_for_word"],
                    value="find_similar_words",
                    label="選擇功能",
                )
                input_image = gr.Image(type="pil", label="上傳圖片")
                input_word = gr.Textbox(label="輸入詞彙", visible=False)
                # step=1 keeps the slider on whole numbers; it is visible from
                # the start because it belongs to the default mode.
                top_k = gr.Slider(1, 10, value=3, step=1, label="Top K")
                run_button = gr.Button("執行")

            with gr.Column():
                output_text = gr.Textbox(label="結果")
                output_image = gr.Image(type="filepath", label="匹配圖片", visible=False)
                output_base64 = gr.Textbox(label="Base64 編碼", visible=False)

        # Show/hide widgets according to the selected function.
        def update_inputs_outputs(choice):
            image_mode = choice == "find_similar_words"
            return (
                gr.update(visible=image_mode),       # input_image
                gr.update(visible=not image_mode),   # input_word
                gr.update(visible=image_mode),       # top_k
                gr.update(visible=True),             # output_text
                gr.update(visible=not image_mode),   # output_image
                gr.update(visible=not image_mode),   # output_base64
            )

        # The callback returns six updates, so six components must be listed.
        function_choice.change(
            update_inputs_outputs,
            inputs=[function_choice],
            outputs=[input_image, input_word, top_k, output_text, output_image, output_base64],
        )

        # Dispatch the run button to the selected search function.
        def run_function(choice, image=None, word=None, top_k=3):
            if choice == "find_similar_words":
                # Sliders can deliver floats; the ranking needs an int.
                return find_similar_words(image, int(top_k)), None, None
            image_path, result_text, img_base64 = find_image_for_word(word)
            return result_text, image_path, img_base64

        run_button.click(
            run_function,
            inputs=[function_choice, input_image, input_word, top_k],
            outputs=[output_text, output_image, output_base64],
        )

    demo.launch(share=True, debug=True)

# Build the UI and start serving it. Because launch() is called with
# debug=True, this cell keeps running until it is interrupted.
launch_interface()


  model.load_state_dict(torch.load("/content/drive/MyDrive/chiikawa/model/best_clip_model.pth", map_location=device))


Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://a0c594662477a008f4.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


  return self.preprocess(images, **kwargs)
  return self.preprocess(images, **kwargs)
  return self.preprocess(images, **kwargs)
  return self.preprocess(images, **kwargs)
  return self.preprocess(images, **kwargs)
