In [1]:
import glob
from PIL import Image
import torch
from transformers import AutoProcessor, AutoModel
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

In [6]:
_DINO_CHECKPOINT = 'facebook/dino-vits16'
# Holds the (processor, model) pair after the first load so that repeated
# calls to dino_embeddings do not re-read the checkpoint every time.
_DINO_CACHE = {}


def _load_dino():
    """Return the cached DINO (processor, model) pair, loading it on first use."""
    if 'pair' not in _DINO_CACHE:
        processor = AutoProcessor.from_pretrained(_DINO_CHECKPOINT)
        model = AutoModel.from_pretrained(_DINO_CHECKPOINT, add_pooling_layer=False)
        model.eval()  # inference only
        _DINO_CACHE['pair'] = (processor, model)
    return _DINO_CACHE['pair']


def dino_embeddings(image_paths):
    """Embed images with DINO ViT-S/16 and mean-pool the last hidden state.

    Args:
        image_paths: Iterable of paths to image files. All images are sent
            through the model as a single batch.

    Returns:
        A torch tensor with one row per input image, obtained by averaging
        ``last_hidden_state`` over dim 1 (the token axis of the ViT output).

    Raises:
        ValueError: If ``image_paths`` is empty.
    """
    image_paths = list(image_paths)
    if not image_paths:
        raise ValueError("dino_embeddings requires at least one image path")

    dino_processor, dino_model = _load_dino()

    pil_images = []
    for path in image_paths:
        # Image.open is lazy and keeps the file handle open; convert() forces
        # the pixel data to load and returns a detached copy, so the handle can
        # be closed immediately. Converting to RGB also normalises grayscale /
        # RGBA / palette PNGs to the 3-channel input the processor expects.
        with Image.open(path) as image:
            pil_images.append(image.convert("RGB"))

    with torch.no_grad():
        inputs = dino_processor(images=pil_images, return_tensors="pt")
        outputs = dino_model(**inputs)
        return outputs.last_hidden_state.mean(dim=1)

def compute_similarity(data_a, data_b, pairwise=False):
    """Return a mean cosine similarity between two sets of row vectors.

    Note the meaning of the flag, which is easy to misread:

    * ``pairwise=False`` (default): build the full cosine-similarity matrix
      between ``data_a`` and ``data_b`` and average its strict upper triangle
      (everything above the main diagonal). This is meaningful when
      ``data_a`` and ``data_b`` are the same set, giving the mean similarity
      over all distinct pairs. With fewer than two rows there are no such
      pairs and the result is NaN.
    * ``pairwise=True``: average the cosine similarity between row ``i`` of
      ``data_a`` and row ``i`` of ``data_b`` (the diagonal of the similarity
      matrix). Both inputs must have the same 2-D shape.

    Args:
        data_a: 2-D array-like, one vector per row.
        data_b: 2-D array-like, one vector per row.
        pairwise: Selects row-by-row pairing instead of all distinct pairs.

    Returns:
        The mean similarity as a NumPy scalar.

    Raises:
        ValueError: If ``pairwise`` is true and the inputs are not non-empty
            2-D arrays of identical shape.
    """
    if pairwise:
        a = np.asarray(data_a)
        b = np.asarray(data_b)
        # Taking the diagonal of a non-square matrix would silently drop the
        # unmatched rows and hide a mis-paired dataset, so reject it instead.
        if a.ndim != 2 or a.shape != b.shape or a.shape[0] == 0:
            raise ValueError(
                "pairwise=True requires two non-empty 2-D inputs of identical "
                f"shape, got {a.shape} and {b.shape}"
            )
        # Only the diagonal is needed, so compute the row-wise cosine directly
        # rather than materialising the full N x N matrix.
        norm_a = np.linalg.norm(a, axis=1, keepdims=True)
        norm_b = np.linalg.norm(b, axis=1, keepdims=True)
        # Leave all-zero rows as zeros (similarity 0) instead of dividing by 0.
        norm_a[norm_a == 0] = 1
        norm_b[norm_b == 0] = 1
        paired = np.sum((a / norm_a) * (b / norm_b), axis=1)
        return np.mean(paired)

    similarity = cosine_similarity(data_a, data_b)
    # Indices of the upper triangle, excluding the diagonal.
    rows, cols = np.triu_indices_from(similarity, k=1)
    # Flatten those off-diagonal entries into a vector and average them.
    return np.mean(similarity[rows, cols])
    

def compute_dino_similarities(target_directory, real_directory="small_coco"):
    """Print DINO-embedding similarity statistics for one set of generated images.

    Real images are read from ``{real_directory}/*.png`` and generated images
    from ``generated/{target_directory}/*.png``. Both listings are sorted, and
    the i-th real image is compared with the i-th generated image for the
    ``real_fake`` figure, so the two directories must contain the same number
    of files in corresponding order.

    Printed values:
        real_real: mean similarity over all distinct pairs of real images.
        real_fake: mean similarity between each real image and its generated
            counterpart.
        fake_fake: mean similarity over all distinct pairs of generated images.

    Args:
        target_directory: Sub-directory of ``generated/`` to evaluate.
        real_directory: Directory holding the reference PNG images.

    Returns:
        None. Results are printed.

    Raises:
        FileNotFoundError: If either directory yields no PNG files.
        ValueError: If the two directories contain different numbers of files.
    """
    real_files = sorted(glob.glob(f"{real_directory}/*.png"))
    fake_files = sorted(glob.glob(f"generated/{target_directory}/*.png"))

    # Validate before the expensive embedding step.
    if not real_files:
        raise FileNotFoundError(f"no PNG files found in {real_directory}/")
    if not fake_files:
        raise FileNotFoundError(f"no PNG files found in generated/{target_directory}/")
    if len(real_files) != len(fake_files):
        raise ValueError(
            f"cannot pair images: {len(real_files)} real vs "
            f"{len(fake_files)} generated files"
        )

    real_embeddings = dino_embeddings(real_files)  # (N, D)
    fake_embeddings = dino_embeddings(fake_files)  # (N, D)

    sim_real_real = compute_similarity(real_embeddings, real_embeddings)
    sim_real_fake = compute_similarity(real_embeddings, fake_embeddings, pairwise=True)
    sim_fake_fake = compute_similarity(fake_embeddings, fake_embeddings)

    print(target_directory)
    print("real_real : ", sim_real_real)
    print("real_fake : ", sim_real_fake)
    print("fake_fake : ", sim_fake_fake)

In [7]:
# Evaluate each generator's output directory in turn.
for model_name in ("sd15", "sd21", "sdxl"):
    compute_dino_similarities(model_name)



sd15
real_real :  0.8542539
real_fake :  0.75566673
fake_fake :  0.7344196




sd21
real_real :  0.8542539
real_fake :  0.7556185
fake_fake :  0.7396738




sdxl
real_real :  0.8542539
real_fake :  0.77147603
fake_fake :  0.7619995
