# Работа с метриками

In [None]:
!pip install pyiqa > /dev/null

In [None]:
import io
import glob
import pandas as pd
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt

from transformers import CLIPImageProcessor, CLIPModel, CLIPTokenizer

import pyiqa

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

import torchvision.transforms as transforms 

In [None]:
# Metric object that pyiqa registers under the name "laion_aes"
# (presumably the LAION aesthetic predictor — confirm against pyiqa docs).
# aesthetic_score() below calls it to score an image.
# Reference link left by the author:
# https://www.frontiersin.org/articles/10.3389/frai.2022.976235
aes_metric = pyiqa.create_metric("laion_aes")

In [None]:
def cosine_sim(x, y, threshold):
    """Tell whether the cosine similarity of ``x`` and ``y`` exceeds ``threshold``.

    The similarity is taken along dimension 1, so for two ``(N, D)`` inputs
    the result is a boolean tensor with one entry per row.
    """
    similarity = nn.functional.cosine_similarity(x, y, dim=1)
    return similarity > threshold

def aesthetic_score(x, threshold):
    """Tell whether the score ``aes_metric`` assigns to ``x`` exceeds ``threshold``.

    ``x`` is handed to the module-level ``aes_metric`` unchanged.
    NOTE(review): pyiqa documents tensor inputs as RGB in [0, 1]; confirm
    that callers do not pass already-normalised pixel values.
    """
    score = aes_metric(x)
    return score > threshold

In [None]:
# Parquet files found under the Kaggle input directory. glob returns paths
# in arbitrary, filesystem-dependent order, so they are sorted to make the
# "first file" picked up by the loop below the same on every run.
files = sorted(glob.glob("/kaggle/input/pickapic-v1-validation/*.parquet"))

In [None]:
# A single checkpoint name feeds the model, the image processor and the
# tokenizer, so the three pieces always come from the same release.
_clip_checkpoint = "openai/clip-vit-base-patch32"
clip_model, clip_image_processor, clip_tokenizer = (
    loader.from_pretrained(_clip_checkpoint)
    for loader in (CLIPModel, CLIPImageProcessor, CLIPTokenizer)
)

In [None]:
class PickAPicDataset:
    """Index-addressable wrapper around a single Pick-a-Pic parquet file.

    Every row supplies two encoded images (columns ``jpg_0`` and ``jpg_1``)
    and a text prompt (column ``caption``). Indexing decodes both images and
    passes them, together with the caption, through the module-level
    ``clip_image_processor`` and ``clip_tokenizer``.
    """

    def __init__(self, parquet_file):
        """Read ``parquet_file`` and keep handles to the columns used later."""
        frame = pd.read_parquet(parquet_file)
        self.data = frame
        self.images_1 = frame["jpg_0"]
        self.images_2 = frame["jpg_1"]
        self.captions = frame["caption"]

    def __len__(self):
        """Return the number of captions, i.e. the number of samples."""
        return len(self.captions)

    def __getitem__(self, idx):
        """Return the sample stored at position ``idx`` as a dict.

        Keys: ``img1`` / ``img2`` (image-processor outputs), ``text``
        (tokenizer output), ``pillow_img1`` / ``pillow_img2`` (decoded PIL
        images) and ``caption`` (the raw prompt).
        """
        # Images are stored as raw bytes; decode them in column order.
        pillow_1, pillow_2 = (
            Image.open(io.BytesIO(column.iloc[idx]))
            for column in (self.images_1, self.images_2)
        )
        prompt = self.captions.iloc[idx]

        processed_1, processed_2 = (
            clip_image_processor(picture, return_tensors="pt")
            for picture in (pillow_1, pillow_2)
        )
        # Captions are truncated to at most 77 tokens.
        tokens = clip_tokenizer(
            prompt,
            padding=True,
            return_tensors="pt",
            truncation=True,
            max_length=77,
        )

        sample = {
            "img1": processed_1,
            "img2": processed_2,
            "text": tokens,
            "pillow_img1": pillow_1,
            "pillow_img2": pillow_2,
            "caption": prompt,
        }
        return sample

In [None]:
# Sample indices gathered by the filtering loop: ``good_ids`` for samples
# that pass both checks, ``bad_ids`` for samples that fail at least one.
good_ids, bad_ids = [], []

# Визуализация датасета

In [None]:
# Cut-offs for the filters applied to the first image of every pair.
threshold_img_img = 0.5  # image/image similarity cut-off; not used in this cell
threshold_img_text = 0.31  # minimum cosine similarity between image and caption
threshold_aes = 5  # minimum aesthetic score

# Number of passing / failing examples to collect and plot per row.
examples_per_row = 10

# pyiqa metrics take RGB tensors scaled to [0, 1] and apply their own
# preprocessing. Feeding them the CLIP-normalised ``pixel_values`` would
# normalise the image twice, so the decoded PIL image is converted instead.
to_unit_tensor = transforms.ToTensor()

for file in files:
    dataset = PickAPicDataset(file)

    # The collected ids are positions inside ``dataset``; drop anything left
    # over from an earlier run of this cell so they cannot be mixed up.
    good_ids.clear()
    bad_ids.clear()

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        # Wrapping the dataset (not the enumerate) lets tqdm show the total.
        for idx, batch in enumerate(tqdm(dataset)):
            # Only the first image of the pair is scored; the second image is
            # not encoded because nothing in this cell uses its features.
            img1_features = clip_model.get_image_features(**batch["img1"])
            text_features = clip_model.get_text_features(**batch["text"])
            aes_input = to_unit_tensor(batch["pillow_img1"].convert("RGB")).unsqueeze(0)

            if cosine_sim(img1_features, text_features, threshold_img_text) and aesthetic_score(aes_input, threshold_aes):
                good_ids.append(idx)
            else:
                bad_ids.append(idx)

            if len(good_ids) >= examples_per_row and len(bad_ids) >= examples_per_row:
                break

    # Top row: samples that passed (with their caption); bottom row: samples
    # that failed. zip() stops early when fewer than ``examples_per_row``
    # samples of a kind were found, leaving the remaining axes empty instead
    # of raising IndexError.
    f, axarr = plt.subplots(2, examples_per_row, figsize=(100, 10))
    for ax, sample_idx in zip(axarr[0], good_ids):
        sample = dataset[sample_idx]  # decode once, use for label and image
        ax.set_xlabel(sample["caption"])
        ax.imshow(sample["pillow_img1"])
    for ax, sample_idx in zip(axarr[1], bad_ids):
        ax.imshow(dataset[sample_idx]["pillow_img1"])
    plt.show()
    # Only the first parquet file is visualised.
    break