# Experiments with word-order pair-swap debiasing on Winoground

In [None]:
from datasets import load_dataset
from transformers import BlipProcessor, BlipForConditionalGeneration, BlipForImageTextRetrieval, AutoTokenizer, AutoModelForMaskedLM, CLIPModel, CLIPProcessor
import torch
from tqdm import tqdm
from main.experiment import *
from main.utils import *

In [None]:
# Device every model is moved to. Falls back to CPU when no CUDA device is
# available, so the notebook still runs (slowly) instead of failing at the
# first `.to(device)`; identical to the old hard-coded "cuda" on GPU machines.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Winoground `test` split, fetched with the locally stored Hugging Face token.
# NOTE(review): newer `datasets` releases deprecate `use_auth_token` in favour
# of `token`; switch once the pinned version supports it.
winoground = load_dataset("facebook/winoground", use_auth_token=True)["test"]

# BLIP captioning model (conditional-generation head) and its processor.
# nn.Module.eval() returns the module itself, so it chains onto the device move.
blip_clm_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
blip_clm_model = (
    BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
    .to(device)
    .eval()
)

# BLIP image-text retrieval model and processor; this is the model the
# image-text scorer is built on further down.
blip_itm_processor = BlipProcessor.from_pretrained("Salesforce/blip-itm-large-coco")
blip_itm_model = (
    BlipForImageTextRetrieval.from_pretrained("Salesforce/blip-itm-large-coco")
    .to(device)
    .eval()
)

# Ends the cell with a statement (printing a blank line) rather than an
# expression, so the notebook does not echo a model repr.
print()

In [None]:
# Pull the four Winoground columns out as parallel per-example sequences.
images0, images1 = winoground["image_0"], winoground["image_1"]
captions0, captions1 = winoground["caption_0"], winoground["caption_1"]

In [None]:
# Expand each example into its four (image, caption) pairings, always in order:
#   0: (image_0, caption_0)   1: (image_1, caption_0)
#   2: (image_0, caption_1)   3: (image_1, caption_1)
image_data = [(img0, img1, img0, img1) for img0, img1 in zip(images0, images1)]
caption_data = [(cap0, cap0, cap1, cap1) for cap0, cap1 in zip(captions0, captions1)]

# RaggedList lets the per-example groups be flattened for batched scoring and
# regrouped afterwards with `unflatten`.
image_data_ragged, caption_data_ragged = RaggedList(image_data), RaggedList(caption_data)
image_data_flat, caption_data_flat = image_data_ragged.flatten(), caption_data_ragged.flatten()

In [None]:
# RoBERTa-base masked LM and its tokenizer; handed to the text scorer below.
mlm_tokenizer, mlm = (
    AutoTokenizer.from_pretrained('roberta-base'),
    AutoModelForMaskedLM.from_pretrained('roberta-base').to(device),
)

In [None]:
# Text scorer wrapped around the RoBERTa model.
# NOTE(review): `mlm` is loaded via AutoModelForMaskedLM (bidirectional), yet the
# scorer is named *Causal*; confirm CausalLLMTextScorer is meant to accept a
# masked LM, otherwise its likelihood scores may not be meaningful.
roberta_causal_score = CausalLLMTextScorer(mlm, mlm_tokenizer)

# Pair-swap caption generator keeping 16 candidates per caption (presumably the
# 16 best according to `text_scorer` — confirm in PairSwapsTextGenerator).
pair_text_gen = PairSwapsTextGenerator(best_k=16, text_scorer=roberta_causal_score)

In [None]:
# Generate word-order pair-swapped alternatives for every flattened caption,
# with autograd tracking disabled by inference_mode.
# alt_caps is wrapped in RaggedList later, so presumably it holds one list of
# candidate captions per entry of caption_data_flat — confirm.
with torch.inference_mode():
    alt_caps = pair_text_gen.generate(caption_data_flat)

In [None]:
# Image-text scorer on the BLIP retrieval model, using the CONTRASTIVE score type.
# Alternative configuration: score_type=BLIPScoreType.ITM together with
# clm_ignore_sep=True.
it_scorer = BLIPImageTextScorer(
    blip_itm_model,
    blip_itm_processor,
    score_type=BLIPScoreType.CONTRASTIVE,
)

In [None]:
# Score every flattened (image, caption) pair with the BLIP scorer. The result is
# regrouped per example later via caption_data_ragged.unflatten.
with torch.inference_mode():
    orig_scores = it_scorer.score(image_data_flat, caption_data_flat)

In [None]:
# Flatten the ragged per-caption alternatives into one list for batched scoring.
alt_caps_ragged = RaggedList(alt_caps)
alt_caps_flat = alt_caps_ragged.flatten()
# Align images with alt_caps_flat. flatten_broadcast is assumed to repeat
# image_data_flat[j] once per alternative of caption j — confirm in main.utils.
image_data_flat_flat = alt_caps_ragged.flatten_broadcast(image_data_flat)

In [None]:
# Score each alternative caption against the image it was aligned with by
# flatten_broadcast above.
with torch.inference_mode():
    new_scores = it_scorer.score(image_data_flat_flat, alt_caps_flat)

In [None]:
# Regroup the flat original scores into one 4-entry row per Winoground example,
# mirroring the 4-tuples of caption_data.
unflat_orig_scores = caption_data_ragged.unflatten(orig_scores)

In [None]:
# Stack into a tensor; given caption_data's 4-tuples this should be
# (num_examples, 4) — assumes unflatten returns a rectangular nested list.
original_scores = torch.tensor(unflat_orig_scores)

In [None]:
def top_k_mean(vec, k=10):
    """Return the mean of the ``k`` largest values of ``vec`` as a 0-d tensor.

    ``vec`` may be a list of numbers or a tensor (top-k is taken along the last
    dimension). If it holds fewer than ``k`` values, all of them are averaged;
    an empty input yields NaN. Integer/bool inputs are promoted to float so the
    mean is defined.
    """
    # as_tensor avoids the copy-and-warn that torch.tensor() does on tensors.
    values = torch.as_tensor(vec)
    if not values.is_floating_point():
        values = values.float()
    if values.numel() == 0:
        return values.new_tensor(float("nan"))
    if values.dim() == 0:
        values = values.reshape(1)
    # torch.topk raises if k exceeds the size of the last dimension.
    k = min(k, values.shape[-1])
    return torch.topk(values, k).values.mean()

# Regroup the alternative-caption scores: one row per flattened (image, caption)
# pair holding the scores of that caption's alternatives (presumably in the
# order the generator returned them — confirm).
new_scores_unflat = alt_caps_ragged.unflatten(new_scores)

In [None]:
# Sweep how many alternative captions are averaged into the bias estimate, and
# report the fraction of examples whose debiased scores satisfy Winoground's
# image-score criterion. Column layout (from image_data / caption_data):
#   0: (image_0, caption_0)   1: (image_1, caption_0)
#   2: (image_0, caption_1)   3: (image_1, caption_1)
# Debiased scores are kept in `debiased_scores` so the raw alternative-caption
# scores in `new_scores` are not clobbered; re-running the earlier cells would
# otherwise unflatten the wrong data.
for n_alts in range(1, 10):
    # Bias for each flattened (image, caption) pair: mean score of its first
    # `n_alts` alternative captions against the same image.
    # NOTE(review): an empty row gives a NaN mean, which makes both comparisons
    # below False for that example — confirm that is intended.
    avg_alt_scores = [torch.tensor(row[:n_alts]).mean() for row in new_scores_unflat]
    bias_scores = torch.tensor(caption_data_ragged.unflatten(avg_alt_scores))
    debiased_scores = original_scores - bias_scores
    # For each caption, the matching image must outscore the other image.
    image_correct = (debiased_scores[..., 0] > debiased_scores[..., 1]) & (
        debiased_scores[..., 3] > debiased_scores[..., 2]
    )
    print(n_alts, torch.sum(image_correct) / len(debiased_scores))