In [62]:
import torch
import numpy as np
from datasets import load_dataset
from tqdm.auto import tqdm
from transformers import AutoProcessor, CLIPModel, AutoTokenizer
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
import seaborn as sns

In [80]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")

In [81]:
imagenette = load_dataset(
    'frgfm/imagenette',
    '160px',
 split='validation',
 revision="4d512db"
)

In [82]:
labels = imagenette.features["label"].names
print(f"Class labels in dataset: {labels}")

Class labels in dataset: ['tench', 'English springer', 'cassette player', 'chain saw', 'church', 'French horn', 'garbage truck', 'gas pump', 'golf ball', 'parachute']


In [83]:
selected_images = []
selected_labels = []

for example in tqdm(imagenette):
    label = example["label"]
    selected_images.append(example["image"])
    selected_labels.append(label)

  0%|          | 0/3925 [00:00<?, ?it/s]

In [84]:
text_inputs = processor([f"a photo of a {c}" for c in labels], return_tensors="pt", padding=True).to(device)

In [85]:
with torch.no_grad():
    label_emb = model.get_text_features(input_ids=text_inputs['input_ids'], attention_mask=text_inputs['attention_mask'])
    label_emb = label_emb.cpu().numpy()

#label_emb = label_emb / np.linalg.norm(label_emb, axis=1)[:, np.newaxis]

In [None]:
all_similarities = []
preds = []
batch_size = 50

for i in tqdm(range(0, len(selected_images), batch_size)):
    i_end = min(i + batch_size, len(selected_images))
    images = processor(
        images=selected_images[i:i_end],
        return_tensors='pt'
    )['pixel_values'].to(device)

    with torch.no_grad():
        img_emb = model.get_image_features(images)
        img_emb = img_emb.cpu().numpy()
        
    #img_emb = img_emb / np.linalg.norm(img_emb, axis=1)[:, np.newaxis]
    
    # Вычисление косинусного сходства
    similarities = np.dot(img_emb, label_emb.T)
    preds.extend(np.argmax(similarities, axis=1))
    all_similarities.extend(similarities)

  0%|          | 0/79 [00:00<?, ?it/s]

In [88]:
all_similarities = np.array(all_similarities)

In [89]:
print("\nСредние значения косинусного сходства для каждого класса:")
for i, label in enumerate(labels):
    mean_similarity = np.mean(all_similarities[:, i])
    print(f"{label}: {mean_similarity:.4f}")


Средние значения косинусного сходства для каждого класса:
tench: 34.3483
English springer: 30.5799
cassette player: 35.5207
chain saw: 39.4871
church: 41.3656
French horn: 36.2596
garbage truck: 41.1660
gas pump: 33.6668
golf ball: 36.2639
parachute: 41.5595


In [110]:
accuracy = accuracy_score(selected_labels, preds)
print(f"Zero-shot classification accuracy on Imagenette: {accuracy * 100:.2f}%")

Zero-shot classification accuracy on Imagenette: 99.29%


In [126]:
from deep_translator import GoogleTranslator

for elem in labels:
    translated = GoogleTranslator(source='en', target='ru').translate(elem) 
    print(translated)# output -> Weiter so, du bist großartig

линь
английский спрингер
кассетный плеер
цепная пила
церковь
валторна
мусоровоз
газовый насос
мяч для гольфа
парашют
