In [3]:
import torch
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from blip_itm import blip_itm
from datasets import load_dataset

# adapted from https://colab.research.google.com/github/salesforce/BLIP/blob/main/demo.ipynb

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def preprocess(image, image_size, device):
    # just take the preprocessing method from their demo for now
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
    ])
    return transform(image).unsqueeze(0).to(device)

def score(output):
    return torch.nn.functional.softmax(output, dim=1)[:, 1].item()

image_size = 384 # default for BLIP
model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth'

model = blip_itm(pretrained=model_url, image_size=image_size, vit='base')
model.eval()
model = model.to(device=device)

auth_token = "" # use your own
winoground = load_dataset("facebook/winoground", use_auth_token=auth_token)["test"]

f = open("BLIP_scores.txt", "w")
f.write("ID\xa0tag\xa0secondary_tag\xa0num_main_preds\xa0collapsed_tag\xa0C0I0\xa0C0I1\xa0C1I0\xa0C1I1\n")

Downloading: 100%|██████████| 232k/232k [00:00<00:00, 3.02MB/s]
Downloading: 100%|██████████| 28.0/28.0 [00:00<00:00, 45.1kB/s]
Downloading: 100%|██████████| 570/570 [00:00<00:00, 916kB/s]
100%|██████████| 1.78G/1.78G [00:52<00:00, 36.0MB/s]


load checkpoint from https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth


Using the latest cached version of the module from /home/samuelyu/.cache/huggingface/modules/datasets_modules/datasets/facebook--winoground/ce486f3e39fab90997d6f3c58c4b0103eb9c37011049ef775a465f0ab2e78d7d (last modified on Sat Oct  8 23:11:20 2022) since it couldn't be found locally at facebook/winoground., or remotely on the Hugging Face Hub.
Found cached dataset winoground (/home/samuelyu/.cache/huggingface/datasets/facebook___winoground/default/0.0.0/ce486f3e39fab90997d6f3c58c4b0103eb9c37011049ef775a465f0ab2e78d7d)
100%|██████████| 1/1 [00:00<00:00, 948.29it/s]


70

In [4]:


for example in winoground:
  i0 = preprocess(example["image_0"].convert("RGB"), image_size, device)
  i1 = preprocess(example["image_1"].convert("RGB"), image_size, device)
  c0 = example["caption_0"]
  c1 = example["caption_1"]
  output_c0_i0 = model(i0, c0, match_head="itm")
  output_c1_i0 = model(i0, c1, match_head="itm")
  output_c0_i1 = model(i1, c0, match_head="itm")
  output_c1_i1 = model(i1, c1, match_head="itm")
  print(output_c0_i0.shape)
  break
  score_c0_i0 = score(output_c0_i0)
  score_c1_i0 = score(output_c1_i0)
  score_c0_i1 = score(output_c0_i1)
  score_c1_i1 = score(output_c1_i1)
  row = [str(example["id"]), example["tag"], example["secondary_tag"], str(example["num_main_preds"]),
         example["collapsed_tag"], str(score_c0_i0), str(score_c0_i1), str(score_c1_i0),
         str(score_c1_i1)]
  f.write("\xa0".join(row) + "\n")

torch.Size([1, 2])
