# In [None]:
import torch
import clip
import os
import glob
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# Folder whose *.webp images get embedded by the extraction cell. Swap the last
# component 'source' for 'output' to embed the generated images instead.
image_folder_path = '\\'.join([r'D:\BaiduNetdiskDownload', 'nai3', 'source'])

class CustomImageDataset(Dataset):
    """Map-style dataset that loads images from a list of file paths.

    Each item is ``(image, image_path)``: the file is opened with PIL,
    converted to RGB, and passed through ``transform`` when one is given.
    """

    def __init__(self, image_paths, transform=None):
        """
        Args:
            image_paths: Sequence of image file paths, indexed by ``__getitem__``.
            transform: Optional callable applied to the RGB PIL image.
        """
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        """Return the number of image paths."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` and return ``(image, image_path)``."""
        image_path = self.image_paths[idx]
        # Close the file handle as soon as the pixels are decoded; convert()
        # returns a new, fully loaded image that outlives the `with` block.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, image_path



# Preprocessing: resize the short side to 224, center-crop to 224x224,
# convert to a tensor and normalize.
# NOTE(review): the mean/std values look like OpenAI CLIP's normalization
# statistics. The preprocess returned by clip.load (discarded below) may use a
# different resize interpolation than this Resize's default; confirm before
# mixing features produced by the two pipelines.
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]),
])

# Only *.webp files directly inside the folder are embedded.
image_paths = glob.glob(os.path.join(image_folder_path, '*.webp'))
if not image_paths:
    # Fail loudly instead of silently pickling an empty feature dict.
    raise FileNotFoundError(f'No .webp images found in {image_folder_path!r}')
dataset = CustomImageDataset(image_paths=image_paths, transform=transform)
dataloader = DataLoader(dataset, batch_size=2, shuffle=False, num_workers=0)

# Load the model on CPU, then move it to the GPU when one is available.
model, _ = clip.load('ViT-L/14', device="cpu")
model.eval()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Maps image key -> image embedding from model.encode_image, kept on the CPU.
image_features_dict = {}

with torch.no_grad():
    for inputs, paths in tqdm(dataloader):
        inputs = inputs.to(device)
        outputs = model.encode_image(inputs)
        for path, feature in zip(paths, outputs.cpu()):
            # splitext strips only the final extension ("a.b.webp" -> "a.b"),
            # matching the comparison cell, which rebuilds key + ".webp".
            key = os.path.splitext(os.path.basename(path))[0]
            # For the generated images (the 'output' folder), key on the
            # prefix before the first '_' instead:
            # key = os.path.basename(path).split('_')[0]
            image_features_dict[key] = feature

import pickle
# Write gen_features.pickle instead when embedding the 'output' folder.
with open(r'D:\BaiduNetdiskDownload\nai3\real_features.pickle', 'wb') as handle:
    pickle.dump(image_features_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)











# In [ ]:
import pickle
import torch
import os
from PIL import Image
import matplotlib.pyplot as plt

# Cosine similarity along dim 1; each call below compares one (1, D) pair.
cosine_similarity = torch.nn.CosineSimilarity(dim=1, eps=1e-6)

# NOTE(review): pickle.load can execute arbitrary code; only load the locally
# produced feature pickles here.
with open(r'D:\BaiduNetdiskDownload\nai3\gen_features.pickle', 'rb') as handle:
    gen_dict = pickle.load(handle)

with open(r'D:\BaiduNetdiskDownload\nai3\real_features.pickle', 'rb') as handle:
    real_dict = pickle.load(handle)

image_names = list(gen_dict.keys())

# Similarity for every key present in both dicts. Scores are stored as plain
# floats (.item()) so sorting compares numbers instead of one-element tensors;
# .float() keeps the comparison working if the pickles differ in dtype.
sim_scores = {}
for img_name in image_names:
    if img_name not in real_dict:
        continue
    real_feature = real_dict[img_name].float()
    generated_feature = gen_dict[img_name].float()
    sim_scores[img_name] = cosine_similarity(
        generated_feature.unsqueeze(0), real_feature.unsqueeze(0)
    ).item()

# Keys ordered from most to least similar.
sorted_keys = sorted(sim_scores, key=sim_scores.get, reverse=True)

# Every image is rescaled to this width in pixels, preserving aspect ratio.
standard_width = 512

# Despite its name, this is a window of at most 10 keys counted from the
# least-similar end of the ranking.
top_100_keys = sorted_keys[-2010:-2000]
if not top_100_keys:
    raise ValueError(
        f'Nothing to plot: {len(sorted_keys)} scored pairs leave the '
        f'[-2010:-2000] window empty.'
    )

# Generated files are named "<key>_...": map each key prefix to its file name.
# os.listdir already returns bare names. If several files share a prefix, the
# last one listed wins.
fakeImagePathMap = {}
for iii in os.listdir(r"D:\BaiduNetdiskDownload\nai3\output"):
    fakeImagePathMap[iii.split('_')[0]] = iii


def _side_by_side(real_path, gen_path, width):
    """Return an RGB image with the real image on the left and the generated
    image on the right, each rescaled to ``width`` px wide (aspect kept)."""
    resized = []
    for path in (real_path, gen_path):
        # `with` closes each file once the resized copy has been produced.
        with Image.open(path) as img:
            height = int((width / img.width) * img.height)
            # LANCZOS is the filter formerly aliased as ANTIALIAS, which
            # Pillow 10 removed.
            resized.append(img.resize((width, height), Image.LANCZOS))
    canvas = Image.new('RGB', (2 * width, max(im.height for im in resized)))
    canvas.paste(resized[0], (0, 0))
    canvas.paste(resized[1], (width, 0))
    return canvas


# One subplot row per selected key (1 inch of height per row, as before);
# squeeze=False keeps the array 2-D even for a single row.
# NOTE(review): figsize is in inches, so the canvas is 1024 in wide;
# bbox_inches='tight' and dpi=2000 at save time presumably compensate.
# Consider deriving figsize/dpi from the pixel sizes instead.
fig, axs = plt.subplots(
    len(top_100_keys), 1,
    figsize=(standard_width * 2, len(top_100_keys)),
    squeeze=False,
)
axs = axs[:, 0]

for ax, img_name in zip(axs, top_100_keys):
    real_images_path = os.path.join(r"D:\BaiduNetdiskDownload\nai3\source", img_name + ".webp")
    gen_images_path = os.path.join(r"D:\BaiduNetdiskDownload\nai3\output", fakeImagePathMap[img_name])
    ax.imshow(_side_by_side(real_images_path, gen_images_path, standard_width))
    ax.axis('off')


pdf_filename = 'concatenated_images.pdf'
fig.savefig(pdf_filename, bbox_inches='tight', dpi=2000)




