<a href="https://colab.research.google.com/github/nakamura196/000_tools/blob/main/guie_sample.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GUIE（Google Universal Image Embedding）の学習済みモデルを使用して類似画像検索を行うサンプル

本プログラムの実行にあたり、以下のモデルを使用しています。

https://www.kaggle.com/code/w3579628328/2nd-place-solution

## ライブラリのインストール

In [None]:
# Install the extra packages this notebook needs (IPython shell magics).
# NOTE(review): torchvision 0.12.0 is pinned; that release is built against
# torch 1.11, so pip may also replace the runtime's preinstalled torch —
# confirm the pin is still required by the downloaded TorchScript model.
!pip install -U torchvision==0.12.0
# NOTE(review): open_clip is not imported anywhere in this notebook; it is
# presumably kept for the Kaggle solution's model — confirm it is still needed.
!pip install open_clip_torch

In [None]:
import os
import requests
import torch
import torchvision
from torchvision import transforms
import numpy as np
from PIL import Image
from glob import glob
from tqdm import tqdm
from google.colab import userdata
import json
from sklearn.metrics.pairwise import cosine_similarity
import matplotlib.pyplot as plt

In [None]:
# Show the installed PyTorch and torchvision versions (one per line) so the
# environment the notebook ran in is visible in the output.
print(torch.__version__, torchvision.__version__, sep="\n")

## Kaggleからモデルをダウンロードする

Kaggle APIの認証情報として、Colabのシークレットに `KAGGLE_USERNAME` と `KAGGLE_KEY` を事前に登録しておく必要があります。

In [None]:
# Download the 2nd-place GUIE solution's Kaggle kernel output (expected to
# include saved_model.pt, which is loaded further below) into /content/model.
# The whole step is skipped when /content/model already exists.
# NOTE(review): if an earlier download failed after /content/model was
# created, this cell will be skipped on rerun — delete the directory to retry.
if not os.path.exists("/content/model"):

  # Kaggle API credentials, read from the Colab secrets manager.
  KAGGLE_USERNAME = userdata.get('KAGGLE_USERNAME')
  KAGGLE_KEY = userdata.get('KAGGLE_KEY')

  kaggle_json = {
      "username": KAGGLE_USERNAME,
      "key": KAGGLE_KEY
  }

  # Write the credentials as kaggle.json, then move it to ~/.kaggle where the
  # Kaggle CLI looks for it.
  # NOTE(review): the file briefly sits in /content with default permissions
  # before the chmod below; writing it straight into ~/.kaggle would avoid that.
  with open("/content/kaggle.json", "w") as f:
      json.dump(kaggle_json, f)

  !mkdir -p ~/.kaggle
  !mv /content/kaggle.json ~/.kaggle/
  !chmod 600 ~/.kaggle/kaggle.json  # Restrict file permissions for security

  # Fetch the kernel's output files into /content/model.
  !kaggle kernels output w3579628328/2nd-place-solution -p "/content/model"

## 画像データの準備

In [None]:
# Sample images: IIIF Image API URLs on dl.ndl.go.jp, each requesting a
# region of a page scaled to 640px wide. "id" doubles as the local filename
# stem used when the images are downloaded.
rows = [
    {"id": image_id, "url": image_url}
    for image_id, image_url in (
        ("1", "https://dl.ndl.go.jp/api/iiif/1170851/R0000062/1004,2122,1389,1889/640,/0/default.jpg"),
        ("2", "https://dl.ndl.go.jp/api/iiif/1170851/R0000062/1037,466,2067,1555/640,/0/default.jpg"),
        ("3", "https://dl.ndl.go.jp/api/iiif/2551476/R0000002/410,897,5574,3807/640,/0/default.jpg"),
    )
]

In [None]:
# Download each sample image once (cached as /content/input/images/<id>.jpg)
# and display it.
for row in rows:
  path = f"/content/input/images/{row['id']}.jpg"

  if not os.path.exists(path):

    os.makedirs(os.path.dirname(path), exist_ok=True)

    # Fail loudly on hangs and HTTP errors. Otherwise an error page would be
    # saved as a .jpg and, because of the exists() check above, never re-fetched.
    response = requests.get(row['url'], timeout=60)
    response.raise_for_status()
    with open(path, 'wb') as f:
      f.write(response.content)

  # Show the image; the context manager closes the file handle afterwards.
  with Image.open(path) as img:
    plt.figure()
    plt.title(f"Image ID: {row['id']}")
    plt.imshow(img)
    plt.axis('off')
    plt.show()

## ベクトルの作成

In [None]:
# Load the TorchScript embedding model downloaded from Kaggle.
# Prefer the GPU, but fall back to the CPU so the notebook can still run on a
# CPU-only runtime. map_location remaps the saved storages onto that device,
# so a model saved on CUDA also loads there.
# NOTE(review): if the scripted model hard-codes CUDA ops internally, CPU
# inference may still fail — confirm.
model_path = "/content/model/saved_model.pt"
device = "cuda" if torch.cuda.is_available() else "cpu"
embedding_fn = torch.jit.load(model_path, map_location=device).eval()

In [None]:
def convert_and_save(path, opath):
    """Embed the image at ``path`` with ``embedding_fn`` and save it to ``opath``.

    The image is converted to RGB, resized to 224x224 and turned into a uint8
    tensor with a batch dimension of one. It is then passed through the global
    ``embedding_fn`` model, and the model's raw output is saved with
    ``np.save``. If ``opath`` already exists the image is skipped, so reruns
    only process new images.

    Args:
        path: Path of the input image (any format PIL can open).
        opath: Destination ``.npy`` path; missing parent directories are created.
    """
    if os.path.exists(opath):
        return

    out_dir = os.path.dirname(opath)
    if out_dir:  # os.makedirs("") raises, e.g. for a bare filename like "x.npy"
        os.makedirs(out_dir, exist_ok=True)

    # Force RGB so grayscale / RGBA / palette images all yield 3 channels.
    # The context manager closes the underlying file handle.
    with Image.open(path) as img:
        rgb_image = img.convert("RGB")

    # Resize to 224x224, then convert to a uint8 (C, H, W) tensor. PILToTensor
    # does not rescale or normalize, so any normalization presumably happens
    # inside the TorchScript model — NOTE(review): confirm.
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.PILToTensor(),
    ])
    input_batch = preprocess(rgb_image).unsqueeze(0)  # (C, H, W) -> (1, C, H, W)

    # Send the input to whichever device the model lives on instead of assuming
    # CUDA. Fall back to CUDA (the original behaviour) if the model exposes no
    # parameters.
    first_param = next(embedding_fn.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda")
    input_batch = input_batch.to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        embedding = embedding_fn(input_batch).cpu().numpy()

    np.save(opath, embedding)

In [None]:
# Embed every downloaded image. convert_and_save skips images whose .npy
# already exists, so rerunning this cell only processes new images.
# Sorting makes the processing order deterministic.
files = sorted(glob("/content/input/images/*.jpg"))

for file in tqdm(files):
  # splitext strips only the final extension; str.replace(".jpg", "") would
  # also remove ".jpg" occurring elsewhere in the filename.
  filename = os.path.splitext(os.path.basename(file))[0]
  output_path = os.path.join("/content/output/embeddings", filename + ".npy")
  convert_and_save(file, output_path)

## 類似画像検索

In [None]:
# Load all saved embeddings into memory. Sorting the glob makes the order
# deterministic across runs, so row i of `embeddings` always belongs to
# filenames[i].
embeddings = []
filenames = []

for file in sorted(glob("/content/output/embeddings/*.npy")):
    embeddings.append(np.load(file))
    filenames.append(os.path.basename(file))

# np.vstack([]) would fail with a cryptic "need at least one array" error.
if not embeddings:
    raise FileNotFoundError(
        "No embeddings found in /content/output/embeddings; "
        "run the embedding cell first."
    )

# Stack into one 2-D array, one row per file. This assumes each .npy holds a
# single embedding (shape (1, dim) or (dim,)) — TODO confirm.
embeddings = np.vstack(embeddings)

def find_similar_images(query_embedding, embeddings, filenames, top_k=5,
                        self_match_tol=1e-5):
    """Return the ``top_k`` entries most cosine-similar to the query, excluding itself.

    Args:
        query_embedding: 2-D array holding one row (e.g. shape ``(1, dim)``).
        embeddings: 2-D array of candidates, one row per entry of ``filenames``.
        filenames: Names aligned with the rows of ``embeddings``.
        top_k: Maximum number of results. ``None`` returns all non-self matches.
        self_match_tol: Candidates whose similarity is within this tolerance of
            1.0 are treated as the query image itself and skipped. A tolerance
            is needed because float rounding can put a vector's similarity with
            itself slightly below or above exactly 1.0. Exact duplicates of the
            query are skipped too.

    Returns:
        ``(similar_filenames, similar_scores)``: two lists ordered by descending
        similarity, each with at most ``top_k`` items.
    """
    similarities = cosine_similarity(query_embedding, embeddings).ravel()

    # Most similar first; a stable sort keeps ties in input order.
    order = np.argsort(-similarities, kind="stable")

    # Drop self-matches BEFORE truncating, so up to top_k genuine neighbours
    # are returned (filtering after slicing would yield only top_k - 1).
    threshold = 1.0 - self_match_tol
    kept = [idx for idx in order if similarities[idx] < threshold]
    if top_k is not None:
        kept = kept[:max(top_k, 0)]

    similar_filenames = [filenames[idx] for idx in kept]
    similar_scores = [similarities[idx] for idx in kept]
    return similar_filenames, similar_scores

def display_images(query_image_path, similar_images, scores,
                   image_dir="/content/input/images"):
    """Plot the query image in one row next to its similar images and their scores.

    Args:
        query_image_path: Path of the query image.
        similar_images: Embedding filenames (``<name>.npy``), as returned by
            ``find_similar_images``; each is shown from ``<image_dir>/<name>.jpg``.
        scores: Similarity scores aligned with ``similar_images``.
        image_dir: Directory holding the downloaded ``.jpg`` files. It defaults
            to the absolute path the download cell writes to, so the result no
            longer depends on the current working directory.
    """
    num_panels = len(similar_images) + 1
    plt.figure(figsize=(15, 5))

    # Query image in the first panel; the file handle is closed after drawing.
    with Image.open(query_image_path) as query_image:
        plt.subplot(1, num_panels, 1)
        plt.imshow(query_image)
    plt.title("Query Image")
    plt.axis('off')

    # Similar images in the remaining panels, titled with their scores.
    for i, (image, score) in enumerate(zip(similar_images, scores)):
        stem = os.path.splitext(image)[0]  # "<name>.npy" -> "<name>"
        image_path = os.path.join(image_dir, stem + ".jpg")
        with Image.open(image_path) as similar_image:
            plt.subplot(1, num_panels, i + 2)
            plt.imshow(similar_image)
        plt.title(f"Score: {score:.4f}")
        plt.axis('off')

    plt.show()

# Use each image in turn as the query: print and display its most similar images.
for query_index, query_file in enumerate(filenames):
    stem = os.path.splitext(query_file)[0]  # "<name>.npy" -> "<name>"
    query_image_path = f"/content/input/images/{stem}.jpg"

    # Reuse the embedding already loaded into `embeddings` (row i belongs to
    # filenames[i]) instead of re-reading it from disk, shaped as a single row.
    query_embedding = embeddings[query_index].reshape(1, -1)

    similar_images, scores = find_similar_images(query_embedding, embeddings, filenames)

    print(f"Top similar images for {query_file}:")
    for image, score in zip(similar_images, scores):
        print(f"Image: {image}, Similarity Score: {score}")

    # Display the query and similar images
    display_images(query_image_path, similar_images, scores)