In [None]:
# Runtime setup: use the GPU when CUDA is available, otherwise fall back to the CPU.
import os
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
# Data loading, image handling and model libraries used by the cells below.
import datasets
import torchvision
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
from transformers import CLIPProcessor, CLIPModel, AutoProcessor, AutoModel

# Put the parent directory on the import path so the project-local `src`
# package can be imported (this must happen before `import src`).
import sys
sys.path.append('../')
import src

## Filter MS COCO

In [None]:
# Stream the test split of the MS COCO captions benchmark (no full download up front).
dataset = datasets.load_dataset(
    path="clip-benchmark/wds_mscoco_captions", split="test", streaming=True
)

### CLIP

In [None]:
# CLIP checkpoint under evaluation; swap in the line below for the patch32 variant.
# MODEL_NAME = "openai/clip-vit-base-patch32"
MODEL_NAME = "openai/clip-vit-base-patch16"
processor = CLIPProcessor.from_pretrained(MODEL_NAME)
# `.to(device)` returns the module itself, so loading and moving can be chained.
model = CLIPModel.from_pretrained(MODEL_NAME).to(device)

In [None]:
# Score every caption of each sample against its image and record the winner.
# `ret` gets one entry per sample: the index of the top-scoring caption and its logit.
ret = {'best_text_id': [], 'logit': []}
for sample in dataset:
    # The captions of one sample are stored newline-separated in the 'txt' field.
    captions = sample['txt'].split("\n")
    batch = processor(
        images=sample['jpg'],
        text=captions,
        return_tensors="pt",
        padding=True,
    ).to(device)
    with torch.no_grad():
        logits = model(**batch).logits_per_text
    ret['best_text_id'].append(logits.argmax().item())
    ret['logit'].append(logits.max().item())

In [None]:
# Tabulate the collected results: one row per processed sample.
df = pd.DataFrame(data=ret)

In [None]:
# MODEL_NAME contains a "/" (organisation/model), so the CSV lives in a nested
# directory. `to_csv` does not create missing parent directories, so create
# them first; `exist_ok=True` keeps the cell safe to re-run.
out_dir = f'../results/{MODEL_NAME}'
os.makedirs(out_dir, exist_ok=True)
df.to_csv(f'{out_dir}/mscoco_predictions.csv')

### SigLIP 2

In [None]:
# SigLIP 2 checkpoint, loaded through the generic Auto* classes.
MODEL_NAME = "google/siglip2-base-patch32-256"
processor = AutoProcessor.from_pretrained(MODEL_NAME)
# `.to(device)` returns the module itself, so loading and moving can be chained.
model = AutoModel.from_pretrained(MODEL_NAME).to(device)

In [None]:
# Score every caption of each sample against its image and record the winner.
# A sample that fails is stored as a (0, 0) placeholder so that row i of the
# results still corresponds to sample i of the dataset.
# NOTE(review): a placeholder is indistinguishable from a genuine
# "caption 0 with logit 0" row — filter on the printed messages if that matters.
ret = {'best_text_id': [], 'logit': []}
for i, d in enumerate(dataset):
    try:
        inputs = processor(
            images=d['jpg'],
            text=d['txt'].split("\n"),  # captions are newline-separated
            return_tensors="pt",
            padding="max_length",
            max_length=64
        )
        with torch.no_grad():
            outputs = model(**inputs.to(device))
        logits = outputs.logits_per_text
        best_text_id, best_logit = logits.argmax().item(), logits.max().item()
    except Exception as err:
        # Catch `Exception` rather than using a bare `except:` so that
        # KeyboardInterrupt / SystemExit still stop the loop, and report the
        # real error instead of assuming its cause (the original code assumed
        # a black-white image).
        best_text_id, best_logit = 0, 0
        print(f"sample {i} failed (assumed black-white image), storing placeholder: {err!r}")
    # Append both values in one place so the two lists can never get out of step.
    ret['best_text_id'].append(best_text_id)
    ret['logit'].append(best_logit)

In [None]:
# Tabulate the collected results: one row per processed sample.
df = pd.DataFrame(data=ret)

In [None]:
# MODEL_NAME contains a "/" (organisation/model), so the CSV lives in a nested
# directory. `to_csv` does not create missing parent directories, so create
# them first; `exist_ok=True` keeps the cell safe to re-run.
out_dir = f'../results/{MODEL_NAME}'
os.makedirs(out_dir, exist_ok=True)
df.to_csv(f'{out_dir}/mscoco_predictions.csv')

## Pointing Game with ImageNet-1k

In [None]:
# ImageNet-1k validation split. `trust_remote_code=True` permits the dataset's
# own loading script to run locally.
dataset_imagenet = datasets.load_dataset(
    path='imagenet-1k',
    split="validation",
    trust_remote_code=True,
)

In [None]:
# Show the dataset summary as the cell output.
dataset_imagenet

ImageNet-1k class-index reference (used to pick the label ids below): https://deeplearning.cms.waikato.ac.nz/user-guide/class-maps/IMAGENET/

In [None]:
# Short class name -> ImageNet-1k label id for the classes used in the pointing game.
# The short name also serves as the per-class output directory name.
labels = {
    # animals
    "goldfish": 1,
    "cat": 282,
    "husky": 248,
    # food
    "banana": 954,
    "pizza": 963,
    # vehicles
    "plane": 404,
    "tractor": 866,
    # other objects
    "ball": 805,
    "church": 497,
    "ipod": 605,
}

In [None]:
# Export every validation image of each selected class to
# ../data/imagenet_pointing_game/<label>/<i>.jpg, numbered in dataset order.

# Materialise the label column once; it does not change between classes.
all_labels = np.array(dataset_imagenet['label'])

# `class_id` instead of `id`: the loop variable stays in the notebook namespace
# afterwards and would otherwise shadow the builtin `id()` for all later cells.
for label, class_id in labels.items():
    print(class_id, label)
    idx_label = np.where(all_labels == class_id)[0]
    dataset_label = dataset_imagenet.select(idx_label)
    print(len(idx_label))
    path = f'../data/imagenet_pointing_game/{label}/'
    # exist_ok=True: safe on re-runs and free of the check-then-create race.
    os.makedirs(path, exist_ok=True)
    for i, item in enumerate(dataset_label):
        item['image'].save(f'{path}{i}.jpg')

In [None]:
# Pointing-game layouts: each entry names the four classes that are combined
# into one composite image, in the order they are placed.
games = [
    ["goldfish", "husky", "pizza", "tractor"],
    ["cat", "goldfish", "plane", "pizza"],
    ["banana", "cat", "tractor", "ball"],
    ["husky", "banana", "plane", "church"],
    ["pizza", "ipod", "goldfish", "banana"],
    ["ipod", "cat", "husky", "plane"],
    ["tractor", "ball", "banana", "ipod"],
    ["plane", "church", "ball", "goldfish"],
    ["church", "pizza", "ipod", "cat"],
    ["ball", "husky", "banana", "tractor"],
]

In [None]:
# Bring every tile to the same 224x224 size before tiles are combined.
resizer = torchvision.transforms.Resize(size=(224, 224))

In [None]:
# Build the composite images: for each game and each image index, the game's
# four class images are resized and combined with `src.utils.append_images`
# (first pair and second pair joined with direction='horizontal', then the two
# results joined with direction='vertical').
# Output: ../data/imagenet_pointing_game/<label1_label2_label3_label4>/<i>.jpg

# Number of per-class images (0.jpg .. 49.jpg) that are combined per game.
IMAGES_PER_CLASS = 50

for game in games:
    cl = "_".join(game)
    path = f'../data/imagenet_pointing_game/{cl}/'
    # exist_ok=True: safe on re-runs and free of the check-then-create race.
    os.makedirs(path, exist_ok=True)
    for i in range(IMAGES_PER_CLASS):
        images = []
        for label in game:
            # The context manager closes the source file once the resized tile
            # has been produced, instead of leaving that to garbage collection.
            with Image.open(f'../data/imagenet_pointing_game/{label}/{i}.jpg') as img:
                images.append(resizer(img))
        img1 = src.utils.append_images([images[0], images[1]], direction='horizontal')
        img2 = src.utils.append_images([images[2], images[3]], direction='horizontal')
        final = src.utils.append_images([img1, img2], direction='vertical')
        final.save(f'{path}{i}.jpg')