# SugarCrepe benchmark

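## Get dataset

The cells below read images from `val2017/<filename>` and the benchmark labels from `labels/<benchmark>.json`, both relative to the working directory.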

## Run benchmark

In [None]:
import json

import jax
import jax.numpy as jnp
import numpy as np
import optax
import orbax
import wandb
from flax.training import orbax_utils
from PIL import Image
from tqdm import tqdm
from transformers import AutoTokenizer

from clip_jax import CLIPModel
from clip_jax.data import image_to_logits, shift_tokens_left
from clip_jax.tokenizer import AutoTokenizer
from clip_jax.utils import load_config

In [None]:
# Fail fast unless JAX reports more than one local device.
# NOTE(review): none of the cells shown here shard work across devices, so the
# `> 1` requirement may be stricter than necessary — confirm the intent.
assert jax.local_device_count() > 1

In [None]:
# Name passed to AutoTokenizer.from_pretrained when the tokenizer is loaded.
tokenizer_name = "cappa_tokenizer"
# Artifact reference used both to load the model config (load_config) and to
# look up the checkpoint metadata through the wandb API.
model_checkpoint = "craiyon/cappa-jax/config-ydqtfo4c:latest"

In [None]:
# AutoTokenizer is imported twice at the top (from transformers, then from
# clip_jax.tokenizer); the later import rebinds the name, so this call uses the
# clip_jax.tokenizer class.
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

In [None]:
# Load the model configuration for the checkpoint reference. It is unpacked into
# CLIPModel(**config) and read for config["text_config"]["max_length"] when
# tokenizing captions.
config = load_config(model_checkpoint)

In [None]:
# Instantiate the model from the loaded config and build a zero-filled parameter
# tree with the same structure, shapes and dtypes that init_weights would produce.
model = CLIPModel(**config)
rng = jax.random.PRNGKey(0)


def _zeros_like_spec(spec):
    """Return a zero array with the shape and dtype of ``spec``."""
    return jnp.zeros(spec.shape, dtype=spec.dtype)


# eval_shape evaluates init_weights abstractly, yielding only shape/dtype specs.
logical_shape = jax.eval_shape(lambda key: model.init_weights(key), rng)["params"]
params = jax.tree.map(_zeros_like_spec, logical_shape)

In [None]:
# Look up the W&B artifact for the checkpoint reference and read two metadata
# entries: "output_dir" (passed to the checkpoint manager as its directory) and
# "step" (coerced to int; the step that gets restored).
artifact = wandb.Api().artifact(model_checkpoint)
model_path = artifact.metadata["output_dir"]
step = int(artifact.metadata["step"])

In [None]:
# Restore checkpoint: the existing `params` tree is wrapped in a dict and used
# as the restore target, from which the restore arguments are derived.
ckpt = {"params": params}
restore_args = orbax_utils.restore_args_from_target(ckpt)
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
orbax_options = orbax.checkpoint.CheckpointManagerOptions()
checkpoint_manager = orbax.checkpoint.CheckpointManager(model_path, orbax_checkpointer, orbax_options)
# Restore the state saved at `step` under `model_path` into the target tree.
ckpt = checkpoint_manager.restore(step, ckpt, restore_kwargs={"restore_args": restore_args, "transforms": {}})
# Keep only the restored parameters and drop the reference to the wrapper dict.
params = ckpt["params"]
del ckpt

In [None]:
def process_text(c):
    """Normalise a caption and tokenize it for scoring.

    The caption is lower-cased, a fixed sequence of punctuation substitutions
    is applied in order, leading/trailing commas, spaces, question marks and
    newlines are stripped, whitespace runs are collapsed to single spaces, and
    any space left in front of a comma is removed. The text is then tokenized,
    padded and truncated to ``config["text_config"]["max_length"]``.

    Returns:
        dict with ``input_ids`` and ``attention_mask`` from the tokenizer, plus
        ``labels`` and ``labels_mask``, which are ``shift_tokens_left`` applied
        to the ids (padding with the tokenizer's pad id) and to the mask
        (padding with 0).
    """
    # Applied strictly in this order: "." is first padded with a space and a
    # later rule then turns every "." into ", ".
    substitutions = (
        (",", ", "),
        (".", ". "),
        ("-", " "),
        (";", ", "),
        (":", ", "),
        ('"', ' " '),
        ("/", ", "),
        (".", ", "),
        (")", ", "),
        (" (", ", "),
    )
    text = c.lower()
    for old, new in substitutions:
        text = text.replace(old, new)
    words = text.strip(", ?\n").split()
    captions = " ".join(words).replace(" ,", ",")

    encoded = tokenizer(
        captions,
        padding="max_length",
        truncation=True,
        max_length=config["text_config"]["max_length"],
        return_tensors="np",
    )
    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": shift_tokens_left(input_ids, pad_token_id=tokenizer.pad_token_id),
        "labels_mask": shift_tokens_left(attention_mask, pad_token_id=0),
    }


In [None]:
def load_item(item):
    """Load the image and the tokenized caption pair for one benchmark entry.

    Args:
        item: mapping with ``"filename"`` (resolved under ``val2017/``),
            ``"caption"`` and ``"negative_caption"``.

    Returns:
        dict with ``"pixel_values"`` (the output of ``image_to_logits`` with a
        leading axis of size 1 added), ``"pos_inputs"`` and ``"neg_inputs"``
        (the outputs of ``process_text`` for the two captions).
    """
    # image: open through a context manager so the file handle is released
    # deterministically rather than whenever the image object is collected.
    with Image.open(f"val2017/{item['filename']}") as src:
        # Resize first, then convert, exactly as before, so pixel values are
        # unchanged. Both calls return new images that outlive the closed file.
        img = src.resize((256, 256)).convert("RGB")
    pixel_values = image_to_logits(img)
    pixel_values = pixel_values[np.newaxis, ...]
    # text
    pos_inputs = process_text(item["caption"])
    neg_inputs = process_text(item["negative_caption"])
    return {
        "pixel_values": pixel_values,
        "pos_inputs": pos_inputs,
        "neg_inputs": neg_inputs,
    }

In [None]:
@jax.jit
def get_scores(pixel_values, inputs, params):
    """Score one caption against one image; a higher score means a better match.

    The image is encoded with ``model.get_image_features`` and its last hidden
    state is fed as ``encoder_hidden_states`` to ``model.get_text_features``.
    The score is the negated softmax cross-entropy between the text output and
    ``inputs["labels"]``, weighted by ``inputs["labels_mask"]`` and summed over
    the last axis.

    Args:
        pixel_values: image batch; its leading dimension must be 1.
        inputs: dict with ``input_ids``, ``attention_mask``, ``labels`` and
            ``labels_mask``.
        params: model parameters passed to ``model.apply``.

    Returns:
        The score of the first (only) batch element.
    """
    assert pixel_values.shape[0] == 1, "only support 1 image at a time"
    variables = {"params": params}

    image_features = model.apply(variables, pixel_values=pixel_values, method=model.get_image_features)
    encoder_outputs = image_features["vision_model_output"]["last_hidden_state"]

    text_features = model.apply(
        variables,
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        encoder_hidden_states=encoder_outputs,
        decode=False,
        method=model.get_text_features,
    )
    logits = text_features["text_model_output"]["last_hidden_state"]

    # Negated cross-entropy per position; the mask zeroes out ignored positions.
    token_log_probs = -optax.softmax_cross_entropy_with_integer_labels(logits, inputs["labels"])
    masked = token_log_probs * inputs["labels_mask"]
    return masked.sum(axis=-1)[0]

In [None]:
# Evaluate every benchmark split. An item counts as a success when the true
# caption scores strictly higher than its negative caption; `results` maps each
# split name to its accuracy (success / count).
results = {}
for benchmark in ["add_att", "add_obj","replace_att", "replace_obj", "replace_rel", "swap_att", "swap_obj"]:
    print(f"benchmark: {benchmark}")
    # Context manager closes the label file promptly; the previous bare open()
    # left the handle to be closed by garbage collection.
    with open(f"labels/{benchmark}.json") as f:
        labels = json.load(f)
    count = 0
    success = 0
    for item in tqdm(labels.values()):
        inputs = load_item(item)
        pos_score = get_scores(inputs["pixel_values"], inputs["pos_inputs"], params)
        neg_score = get_scores(inputs["pixel_values"], inputs["neg_inputs"], params)
        count += 1
        if pos_score > neg_score:
            success += 1
    print(f"count: {count}, success: {success}, acc: {success / count}")
    results[benchmark] = success / count

In [None]:
# Bare expression: as the last statement of a notebook cell this displays the
# per-benchmark accuracy dict built by the loop above.
results