In [79]:
import os
import types

import clip
import torch
from datasets import load_dataset
from tqdm import tqdm

In [8]:
#Load Winoground dataset
# The access token is read from the HF_TOKEN environment variable so that no
# credential is stored in the notebook. NOTE: the token that used to be
# hard-coded here has been exposed and should be revoked.
# When HF_TOKEN is unset or empty, True tells `datasets` to use the token
# cached by `huggingface-cli login`.
auth_token = os.environ.get("HF_TOKEN") or True
winoground = load_dataset("facebook/winoground", use_auth_token=auth_token)["test"]

Found cached dataset winoground (/home/samuelyu/.cache/huggingface/datasets/facebook___winoground/default/0.0.0/ce486f3e39fab90997d6f3c58c4b0103eb9c37011049ef775a465f0ab2e78d7d)
100%|██████████| 1/1 [00:00<00:00, 1078.50it/s]


In [26]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the ViT-B/32 CLIP checkpoint onto that device, together with the
# preprocessing callable that is later applied to each image.
clip_model, preprocess = clip.load("ViT-B/32", device=device)

In [47]:
#Calculate the image-caption score for all examples in winoground
def run_winoground(winoground, clip_model, preprocess, device=None):
    """Score every Winoground example with a CLIP model.

    Args:
        winoground: Iterable of examples. Each example is indexed with the keys
            "id", "image_0", "image_1", "caption_0" and "caption_1".
        clip_model: Model called as ``clip_model(images, captions)``. The second
            element of its return value is read as a 2x2 score matrix indexed
            ``[caption, image]``.
        preprocess: Callable applied to each image before the two images are
            stacked into a batch.
        device: Device the inputs are moved to. Defaults to the device of the
            model's first parameter.

    Returns:
        A list with one dict per example holding "id" and the four scores
        "c0_i0", "c0_i1", "c1_i0", "c1_i1" as Python floats.
    """
    if device is None:
        device = next(clip_model.parameters()).device

    winoground_clip_scores = []
    # Scoring is pure inference; disabling autograd avoids building (and
    # holding memory for) a computation graph on every example.
    with torch.no_grad():
        for example in tqdm(winoground):
            images = torch.stack(
                [preprocess(example["image_0"]), preprocess(example["image_1"])], dim=0
            ).to(device)
            captions = clip.tokenize([example["caption_0"], example["caption_1"]]).to(device)
            # Only the second output is used; row = caption, column = image.
            _, scores = clip_model(images, captions)
            winoground_clip_scores.append({
                "id": example["id"],
                "c0_i0": scores[0, 0].item(),
                "c0_i1": scores[0, 1].item(),
                "c1_i0": scores[1, 0].item(),
                "c1_i1": scores[1, 1].item(),
            })
    return winoground_clip_scores

#Functions to calculate Text, Image and Group Scores
def text_correct(result):
    """Return True when, for each image, its own caption outscores the other caption.

    `result` is a mapping with the score keys "c0_i0", "c1_i0", "c1_i1" and
    "c0_i1"; ties count as incorrect.
    """
    # Image 0: caption 0 must strictly beat caption 1.
    if not result["c0_i0"] > result["c1_i0"]:
        return False
    # Image 1: caption 1 must strictly beat caption 0.
    return result["c1_i1"] > result["c0_i1"]

def image_correct(result):
    """Return True when, for each caption, its own image outscores the other image.

    `result` is a mapping with the score keys "c0_i0", "c0_i1", "c1_i1" and
    "c1_i0"; ties count as incorrect.
    """
    # Caption 0: image 0 must strictly beat image 1.
    if not result["c0_i0"] > result["c0_i1"]:
        return False
    # Caption 1: image 1 must strictly beat image 0.
    return result["c1_i1"] > result["c1_i0"]

def group_correct(result):
    """Return True only when `result` passes both the image check and the text check."""
    # Same short-circuit order as before: the text check runs only if the
    # image check passed.
    if not image_correct(result):
        return False
    return text_correct(result)

In [49]:
#Calculate the Text, Image and Group Scores from Winoground's image-caption score
winoground_clip_scores = run_winoground(winoground, clip_model, preprocess)

# Each *_correct helper returns a bool, so summing them counts the examples
# that pass the corresponding check.
text_correct_count = sum(text_correct(result) for result in winoground_clip_scores)
image_correct_count = sum(image_correct(result) for result in winoground_clip_scores)
group_correct_count = sum(group_correct(result) for result in winoground_clip_scores)

# Report each count as a fraction of all scored examples.
denominator = len(winoground_clip_scores)
print("text score:", text_correct_count/denominator)
print("image score:", image_correct_count/denominator)
print("group score:", group_correct_count/denominator)

100%|██████████| 400/400 [00:24<00:00, 16.42it/s]

text score: 0.305
image score: 0.105
group score: 0.08





### Now let's adjust the weighting of the positional embeddings to see how it affects the model's performance.

In [52]:
# Detached backup of the model's positional embedding (the one added to the
# token embeddings in the text encoder). Detaching keeps the backup out of the
# autograd graph; cloning gives it its own storage.
original_embedding = clip_model.positional_embedding.detach().clone()

try:
    for alpha in [0.2, 0.5, 2.0, 5.0, 10.0]:
        # `original_embedding * alpha` is a fresh tensor on the same device as
        # the backup, so it can be wrapped in a Parameter directly. Calling
        # .to(device) on the Parameter afterwards could hand back a plain
        # tensor, which nn.Module refuses to assign to a parameter name.
        clip_model.positional_embedding = torch.nn.Parameter(original_embedding * alpha)
        winoground_clip_scores = run_winoground(winoground, clip_model, preprocess)

        # The *_correct helpers return bools, so sum() counts the passes.
        text_correct_count = sum(text_correct(result) for result in winoground_clip_scores)
        image_correct_count = sum(image_correct(result) for result in winoground_clip_scores)
        group_correct_count = sum(group_correct(result) for result in winoground_clip_scores)

        denominator = len(winoground_clip_scores)
        print("alpha:", alpha)
        print("text score:", text_correct_count/denominator)
        print("image score:", image_correct_count/denominator)
        print("group score:", group_correct_count/denominator)
finally:
    # Always put the unscaled weights back, even if the sweep is interrupted,
    # so later cells never run against a model left at the last alpha.
    clip_model.positional_embedding = torch.nn.Parameter(original_embedding.clone())

100%|██████████| 400/400 [00:24<00:00, 16.31it/s]


alpha: 0.2
text score: 0.1825
image score: 0.0675
group score: 0.0375


100%|██████████| 400/400 [00:23<00:00, 16.73it/s]


alpha: 0.5
text score: 0.215
image score: 0.0625
group score: 0.0375


100%|██████████| 400/400 [00:24<00:00, 16.54it/s]


alpha: 2.0
text score: 0.255
image score: 0.1025
group score: 0.0675


100%|██████████| 400/400 [00:24<00:00, 16.49it/s]


alpha: 5.0
text score: 0.135
image score: 0.1075
group score: 0.045


100%|██████████| 400/400 [00:23<00:00, 16.99it/s]

alpha: 10.0
text score: 0.1525
image score: 0.085
group score: 0.04





It is probably better to re-normalize the input after adding the adjusted positional embeddings, so that changing alpha does not also change the overall scale of what the transformer sees.

In [82]:
def new_encode(self, text):
    """Text encoder with a tunable weight on the positional embedding.

    Token and positional embeddings are mixed as
    ``(tok + alpha * pos) / (1 + alpha) * 2``, so the two mixing weights always
    sum to 2. With ``alpha == 1`` this is exactly ``tok + pos``; with
    ``alpha == 0`` the positional embedding is dropped and the token embedding
    is doubled.

    ``alpha`` is read from ``self.alpha`` on every call and defaults to 1.0
    when the attribute has not been set. ``alpha == -1`` makes the divisor
    zero and must be avoided.

    Args:
        text: Integer token ids indexed as [batch_size, n_ctx].

    Returns:
        One feature vector per sequence, taken at the position of the largest
        token id and multiplied by ``self.text_projection``.
    """
    # Falling back to 1.0 keeps the patched model usable before any experiment
    # has assigned an alpha.
    alpha = getattr(self, "alpha", 1.0)

    x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]

    # Fixed rescale by 2 / (1 + alpha). An alternative that was tried is to
    # rescale each vector to the norm of the unscaled (tok + pos) sum instead.
    x = (x + alpha*self.positional_embedding.type(self.dtype)) / (1 + alpha) * 2
    x = x.permute(1, 0, 2)  # NLD -> LND
    x = self.transformer(x)
    x = x.permute(1, 0, 2)  # LND -> NLD
    x = self.ln_final(x).type(self.dtype)

    # x.shape = [batch_size, n_ctx, transformer.width]
    # take features from the eot embedding (eot_token is the highest number in each sequence)
    x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

    return x

# Bind new_encode to this model instance as its encode_text method.
clip_model.encode_text = types.MethodType(new_encode, clip_model)
# new_encode reads self.alpha; with 1.0 its mixing formula reduces to the plain
# token + positional sum, so the patched model is usable before any sweep runs.
clip_model.alpha = 1.0
# Restore the unscaled positional embedding. Clone (and move) first, then wrap:
# the Parameter then has its own storage rather than sharing it with the
# backup, and what gets assigned is a real Parameter rather than the plain
# tensor that Parameter(...).to(device) can return.
clip_model.positional_embedding = torch.nn.Parameter(original_embedding.detach().clone().to(device))

In [83]:
try:
    for alpha in [0.0, 0.2, 0.5, 2.0, 5.0, 10.0]:
        # The patched encode_text reads this attribute on every call.
        clip_model.alpha = alpha
        winoground_clip_scores = run_winoground(winoground, clip_model, preprocess)

        # The *_correct helpers return bools, so sum() counts the passes.
        text_correct_count = sum(text_correct(result) for result in winoground_clip_scores)
        image_correct_count = sum(image_correct(result) for result in winoground_clip_scores)
        group_correct_count = sum(group_correct(result) for result in winoground_clip_scores)

        denominator = len(winoground_clip_scores)
        print("alpha:", alpha)
        print("text score:", text_correct_count/denominator)
        print("image score:", image_correct_count/denominator)
        print("group score:", group_correct_count/denominator)
finally:
    # Reset even if the sweep is interrupted: with alpha = 1 the mixing formula
    # (tok + alpha * pos) / (1 + alpha) * 2 is just tok + pos, so later cells do
    # not silently run with the last swept value.
    clip_model.alpha = 1.0

 12%|█▏        | 46/400 [00:02<00:22, 15.88it/s]

Interestingly, the first and last positions have a noticeably higher positional-embedding norm than all the positions in between.

In [69]:
# Norm of each row of the saved positional embedding (reduction over dim=1),
# displayed as the cell output; presumably one row per token position.
original_embedding.norm(dim=1)

tensor([0.3445, 0.1132, 0.1129, 0.1110, 0.1105, 0.1101, 0.1106, 0.1110, 0.1113,
        0.1113, 0.1117, 0.1126, 0.1117, 0.1133, 0.1139, 0.1139, 0.1152, 0.1150,
        0.1153, 0.1166, 0.1171, 0.1179, 0.1204, 0.1213, 0.1222, 0.1241, 0.1261,
        0.1276, 0.1283, 0.1303, 0.1350, 0.1379, 0.1392, 0.1398, 0.1406, 0.1408,
        0.1430, 0.1466, 0.1480, 0.1482, 0.1496, 0.1521, 0.1554, 0.1554, 0.1572,
        0.1592, 0.1598, 0.1614, 0.1642, 0.1661, 0.1669, 0.1669, 0.1676, 0.1687,
        0.1707, 0.1713, 0.1724, 0.1726, 0.1740, 0.1767, 0.1778, 0.1795, 0.1807,
        0.1813, 0.1835, 0.1871, 0.1900, 0.1920, 0.1953, 0.1955, 0.1966, 0.1978,
        0.2018, 0.2046, 0.2086, 0.2196, 0.3829], device='cuda:0',
       grad_fn=<CopyBackwards>)