# Install Dependencies

In [None]:
!pip install transformers
!pip install datasets

# Load Winoground (requires a Hugging Face access token; loading fails without one)

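To get a token, log in to https://huggingface.co/ and go to your profile -> Settings -> Access Tokens -> New Token. If loading still fails, check that your account has been granted access to the dataset on its Hub page (https://huggingface.co/datasets/facebook/winoground).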

In [None]:
import os
from getpass import getpass

from datasets import load_dataset

# Winoground requires a Hugging Face access token (Profile -> Settings -> Access Tokens -> New Token).
# Read it from the HF_TOKEN environment variable, or prompt for it, so the token is never
# hardcoded (and accidentally committed or shared) in the notebook.
auth_token = os.environ.get("HF_TOKEN") or getpass("Hugging Face access token: ")
# `token` replaces load_dataset's deprecated `use_auth_token` argument.
winoground = load_dataset("facebook/winoground", token=auth_token)["test"]

# Load FLAVA

In [None]:
import torch
from transformers import FlavaProcessor, FlavaForPreTraining

# Use the GPU when one is available; otherwise fall back to CPU (slower) instead of crashing.
device = "cuda" if torch.cuda.is_available() else "cpu"
# eval() switches the model to inference mode (e.g. disables dropout) before scoring.
flava_model = FlavaForPreTraining.from_pretrained("facebook/flava-full").eval().to(device)
flava_processor = FlavaProcessor.from_pretrained("facebook/flava-full")

# Look at an example from Winoground and get the image-caption scores from FLAVA

In [None]:
from matplotlib import pyplot as plt
import torch

# Which Winoground example to inspect. Index the dataset once so its images are decoded once.
EXAMPLE_IDX = 155
example = winoground[EXAMPLE_IDX]

# Show the two images side by side, one panel each.
fig, axes = plt.subplots(1, 2)
for ax, image_key in zip(axes, ("image_0", "image_1")):
    ax.set_title(image_key)
    ax.imshow(example[image_key].convert("RGB"))
plt.show()

print("caption_0:", example["caption_0"])
print("caption_1:", example["caption_1"])

# Score all four caption/image combinations. Keys follow the "c{caption}_i{image}" convention
# used by the full-dataset loop and the metric functions later in the notebook.
contrastive_scores = {}
itm_scores = {}
for c in (0, 1):
    for i in (0, 1):
        # Some images in Winoground are RGBA and some are RGB, so convert all of them to RGB.
        # padding="max_length" matches the full-dataset loop, so these scores agree with it.
        inputs = flava_processor(
            text=[example[f"caption_{c}"]],
            images=[example[f"image_{i}"].convert("RGB")],
            return_tensors="pt",
            padding="max_length",
            max_length=77,
            return_codebook_pixels=True,
            return_image_mask=True,
        ).to(flava_model.device)
        # FlavaForPreTraining accepts masked inputs for its pretraining objectives. Pass an
        # unmasked copy of the token ids and an all-zero patch mask so nothing is masked.
        inputs["input_ids_masked"] = inputs["input_ids"].detach().clone()
        inputs["bool_masked_pos"] = torch.zeros_like(inputs["bool_masked_pos"])
        # Inference only: no_grad avoids building an autograd graph for each forward pass.
        with torch.no_grad():
            outputs = flava_model(**inputs)
        key = f"c{c}_i{i}"
        contrastive_scores[key] = outputs.contrastive_logits_per_image.item()
        # Probability of ITM class 1 ("match"); dim=-1 makes the class axis explicit.
        itm_scores[key] = torch.nn.functional.softmax(outputs.itm_logits, dim=-1)[0][1].item()

for title, scores in (
    ("FLAVA contrastive image-text match scores:", contrastive_scores),
    ("FLAVA itm image-text match scores:", itm_scores),
):
    print()
    print(title)
    for i in (0, 1):
        for c in (0, 1):
            print(f"image_{i}, caption_{c}:", scores[f"c{c}_i{i}"])

# Get FLAVA image-caption scores from the whole dataset

In [None]:
from tqdm.auto import tqdm
import torch


def flava_pair_scores(caption, image):
    """Score one caption/image pair with FLAVA.

    Returns a ``(contrastive, itm)`` tuple:
    - the contrastive image-text logit;
    - the softmax probability of ITM class 1, which this notebook reports as the
      image-text "match" score.
    """
    # Some images in Winoground are RGBA and some are RGB, so convert all of them to RGB.
    inputs = flava_processor(
        text=[caption],
        images=[image.convert("RGB")],
        return_tensors="pt",
        padding="max_length",
        max_length=77,
        return_codebook_pixels=True,
        return_image_mask=True,
    ).to(flava_model.device)
    # FlavaForPreTraining accepts masked inputs for its pretraining objectives. Pass an
    # unmasked copy of the token ids and an all-zero patch mask so nothing is masked.
    inputs["input_ids_masked"] = inputs["input_ids"].detach().clone()
    inputs["bool_masked_pos"] = torch.zeros_like(inputs["bool_masked_pos"])
    # Inference only: no_grad avoids building an autograd graph for every forward pass.
    with torch.no_grad():
        outputs = flava_model(**inputs)
    contrastive = outputs.contrastive_logits_per_image.item()
    # dim=-1 makes the class axis explicit (implicit-dim softmax is deprecated).
    itm = torch.nn.functional.softmax(outputs.itm_logits, dim=-1)[0][1].item()
    return contrastive, itm


# One dict per example: {"id", "c0_i0", "c0_i1", "c1_i0", "c1_i1"}, where
# "c{c}_i{i}" is the score of caption c against image i.
winoground_flava_contrastive_scores = []
winoground_flava_itm_scores = []
for example in tqdm(winoground):
    contrastive_row = {"id": example["id"]}
    itm_row = {"id": example["id"]}
    for c in (0, 1):
        for i in (0, 1):
            key = f"c{c}_i{i}"
            contrastive_row[key], itm_row[key] = flava_pair_scores(
                example[f"caption_{c}"], example[f"image_{i}"]
            )
    winoground_flava_contrastive_scores.append(contrastive_row)
    winoground_flava_itm_scores.append(itm_row)

# Define the text, image, and group metrics, and compute the overall performance of FLAVA

In [None]:
def _wins_every_pair(result, pairs):
    """Return True iff result[winner] > result[loser] for every (winner, loser) pair.

    Pairs are checked in order and evaluation stops at the first failing pair.
    """
    return all(result[winner] > result[loser] for winner, loser in pairs)


def text_correct(result):
    """Winoground text criterion for one example's score dict.

    Each image must score its own caption above the other caption:
    s(c0, i0) > s(c1, i0) and s(c1, i1) > s(c0, i1).
    """
    return _wins_every_pair(result, (("c0_i0", "c1_i0"), ("c1_i1", "c0_i1")))


def image_correct(result):
    """Winoground image criterion for one example's score dict.

    Each caption must score its own image above the other image:
    s(c0, i0) > s(c0, i1) and s(c1, i1) > s(c1, i0).
    """
    return _wins_every_pair(result, (("c0_i0", "c0_i1"), ("c1_i1", "c1_i0")))


def group_correct(result):
    """Winoground group criterion: both the image and the text criteria hold."""
    if not image_correct(result):
        return False
    return text_correct(result)


def report_winoground_scores(label, results):
    """Print and return the Winoground text, image and group results for one scoring method.

    Args:
        label: prefix for the printed lines, e.g. "contrastive" or "itm".
        results: list of per-example score dicts with keys "c0_i0", "c0_i1", "c1_i0", "c1_i1".

    Prints each score as the fraction of examples satisfying text_correct, image_correct and
    group_correct respectively.

    Returns:
        (text_count, image_count, group_count): the raw numbers of correct examples.

    Raises:
        ValueError: if `results` is empty, since no score can be computed.
    """
    if not results:
        raise ValueError(f"no {label} results to score")
    text_count = sum(1 for result in results if text_correct(result))
    image_count = sum(1 for result in results if image_correct(result))
    group_count = sum(1 for result in results if group_correct(result))
    denominator = len(results)
    print(f"{label} text score:", text_count / denominator)
    print(f"{label} image score:", image_count / denominator)
    print(f"{label} group score:", group_count / denominator)
    return text_count, image_count, group_count


(
    contrastive_text_correct_count,
    contrastive_image_correct_count,
    contrastive_group_correct_count,
) = report_winoground_scores("contrastive", winoground_flava_contrastive_scores)

(
    itm_text_correct_count,
    itm_image_correct_count,
    itm_group_correct_count,
) = report_winoground_scores("itm", winoground_flava_itm_scores)
