In [10]:
import sys
sys.path.append("/home/mila/l/le.zhang/scratch/colxlip/src/")
import torch
from colxlip.factory import create_model_and_transforms
import open_clip

In [11]:
# Checkpoint of the fine-tuned ColXLIP model evaluated in this notebook.
# NOTE(review): absolute, user-specific path; adjust when running elsewhere.
COLXLIP_CHECKPOINT = (
    "/home/mila/l/le.zhang/scratch/colxlip/src/logs/"
    "2025_04_08-01_14_39-model_ViT-B-16-colxlip-lr_5e-06-b_196-j_8-p_amp/"
    "checkpoints/epoch_1.pt"
)

model, preprocess_train, preprocess_val = create_model_and_transforms(
    model_name="ViT-B-16-colxlip",
    pretrained=COLXLIP_CHECKPOINT,
)
# The notebook only runs inference, so switch the module to eval mode
# (disables dropout / stochastic depth / batch-norm statistic updates).
# Assumes the factory returns a torch.nn.Module, as open_clip's does; TODO confirm.
# Assigning the result keeps the cell from echoing the model repr.
model = model.eval()

In [12]:
from datasets import load_dataset

# Tokenizer registered in open_clip under the 'ViT-B-16' model name; used for every caption below.
tokenizer = open_clip.get_tokenizer('ViT-B-16')

# Winoground benchmark from the Hugging Face hub; later cells read its 'test' split.
winoground = load_dataset("facebook/winoground")

In [13]:
def compute_colbert_similarity(token_image_features, token_text_features):
    """
    Late-interaction (ColBERT-style) score between every text and every image.

    Only the text -> image direction is scored: each text token is matched to
    its best image token (largest dot product), and those per-token maxima are
    averaged. Text tokens whose maximum is exactly zero add nothing to the sum
    and are left out of the divisor.

    Args:
        token_image_features: Token-level features from images [batch_size_img, n_img_tokens, embed_dim]
        token_text_features: Token-level features from text [batch_size_txt, n_txt_tokens, embed_dim]

    Returns:
        Tensor of shape [batch_size_txt, batch_size_img]. Entries lie in [-1, 1]
        when every token vector has unit L2 norm; the function itself does not
        normalise its inputs.
    """
    # pairwise[t, i, a, b] = <token a of text t, token b of image i>
    pairwise = torch.einsum('tad,ibd->tiab', token_text_features, token_image_features)
    # Best image-token match for each text token.
    best_per_txt_token = pairwise.max(dim=3).values  # [batch_size_txt, batch_size_img, n_txt_tokens]

    # Divisor: how many text tokens have a non-zero best match. The epsilon keeps
    # the division finite when there are none (the numerator is 0 in that case).
    n_nonzero = best_per_txt_token.ne(0).float().sum(dim=2) + 1e-8
    total = best_per_txt_token.sum(dim=2)

    return total / n_nonzero  # [batch_size_txt, batch_size_img]

In [None]:
from matplotlib import pyplot as plt

winoground_test = winoground['test']
example_idx = 155

# Fetch the row once. Each `winoground_test[example_idx]` lookup materialises
# the whole example again, so indexing it separately for every field repeats
# that work.
example = winoground_test[example_idx]
# Note that some images in winoground are RGBA and some are RGB. Need to convert all to RGB with .convert('RGB')
example_images = [example["image_0"].convert("RGB"), example["image_1"].convert("RGB")]
example_captions = [example["caption_0"], example["caption_1"]]

# One panel per image (a 1x2 grid; a 1x3 grid would leave an empty slot).
fig, axes = plt.subplots(1, 2)
for ax, panel_title, pil_image in zip(axes, ("image_0", "image_1"), example_images):
    ax.set_title(panel_title)
    ax.imshow(pil_image)
plt.show()

print("caption_0:", example_captions[0])
print("caption_1:", example_captions[1])

# Define the preprocess function from your model
# (kept as a global because other cells refer to `preprocess`).
preprocess = preprocess_val

# Batch of the two preprocessed images, image_0 first.
images = torch.stack([preprocess(pil_image) for pil_image in example_images], dim=0)
text = tokenizer(example_captions)

# Get model outputs
with torch.no_grad():
    image_features, token_image_features = model.encode_image(images)
    text_features, token_text_features = model.encode_text(text)

    # Normalize features so that dot products are cosine similarities
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    token_image_features = token_image_features / token_image_features.norm(dim=-1, keepdim=True)
    token_text_features = token_text_features / token_text_features.norm(dim=-1, keepdim=True)

    # Calculate similarity scores; both matrices are indexed [image][caption]
    logits_per_image = image_features @ text_features.T
    logits_per_text_tokens = compute_colbert_similarity(token_image_features, token_text_features)
    logits_per_image_tokens = logits_per_text_tokens.T

# Print the four image/caption scores for each scoring method
for method, logits in (("CLIP", logits_per_image), ("ColBERT", logits_per_image_tokens)):
    print(f"\n{method} image-text match scores:")
    for img_idx in range(2):
        for cap_idx in range(2):
            print(f"image_{img_idx}, caption_{cap_idx}: {logits[img_idx][cap_idx].item():.4f}")

In [14]:
from tqdm import tqdm


def _score_record(example_id, logits):
    """Flatten a score matrix indexed [image][caption] into a Winoground record.

    Key ``c{j}_i{k}`` holds the score of caption j against image k, i.e. ``logits[k][j]``.
    """
    return {
        "id": example_id,
        "c0_i0": logits[0][0].item(),
        "c0_i1": logits[1][0].item(),
        "c1_i0": logits[0][1].item(),
        "c1_i1": logits[1][1].item(),
    }


winoground_clip_scores = []
winoground_colbert_scores = []
winoground_combined_scores = []

for example in tqdm(winoground['test']):
    # Process images and text.
    # Use the ColXLIP validation transform directly: the global `preprocess` is
    # only assigned in the single-example cell and is reassigned by the baseline
    # section, so relying on it makes this loop depend on cell execution order.
    images = torch.stack([
        preprocess_val(example["image_0"].convert("RGB")),
        preprocess_val(example["image_1"].convert("RGB")),
    ], dim=0)
    text = tokenizer([example["caption_0"], example["caption_1"]])

    # Get model outputs
    with torch.no_grad():
        image_features, token_image_features = model.encode_image(images)
        text_features, token_text_features = model.encode_text(text)

        # Normalize features so that dot products are cosine similarities
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        token_image_features = token_image_features / token_image_features.norm(dim=-1, keepdim=True)
        token_text_features = token_text_features / token_text_features.norm(dim=-1, keepdim=True)

        # Calculate similarity scores; both matrices are indexed [image][caption]
        logits_per_image = image_features @ text_features.T
        logits_per_text_tokens = compute_colbert_similarity(token_image_features, token_text_features)
        logits_per_image_tokens = logits_per_text_tokens.T

    clip_record = _score_record(example["id"], logits_per_image)
    colbert_record = _score_record(example["id"], logits_per_image_tokens)

    # Equal-weight mix of the global (CLIP) and token-level (ColBERT) scores
    combined_record = {"id": example["id"]}
    for key in ("c0_i0", "c0_i1", "c1_i0", "c1_i1"):
        combined_record[key] = 0.5 * clip_record[key] + 0.5 * colbert_record[key]

    winoground_clip_scores.append(clip_record)
    winoground_colbert_scores.append(colbert_record)
    winoground_combined_scores.append(combined_record)


100%|██████████| 400/400 [02:21<00:00,  2.83it/s]


In [24]:
# Recompute the combined score as 0.7 * CLIP + 0.3 * ColBERT.
# Rows of the two score lists are paired by position.
winoground_combined_scores = [
    {
        "id": clip_row["id"],
        **{
            key: 0.7 * clip_row[key] + 0.3 * colbert_row[key]
            for key in ("c0_i0", "c0_i1", "c1_i0", "c1_i1")
        },
    }
    for clip_row, colbert_row in zip(winoground_clip_scores, winoground_colbert_scores)
]


In [25]:
def text_correct(result):
    """Return True when, for each image, its own caption scores strictly higher
    than the other caption (keys are ``c<caption>_i<image>``)."""
    # image_0 must prefer caption_0 ...
    if not result["c0_i0"] > result["c1_i0"]:
        return False
    # ... and image_1 must prefer caption_1.
    return result["c1_i1"] > result["c0_i1"]

def image_correct(result):
    """Return True when, for each caption, its own image scores strictly higher
    than the other image (keys are ``c<caption>_i<image>``)."""
    # caption_0 must prefer image_0 ...
    if not result["c0_i0"] > result["c0_i1"]:
        return False
    # ... and caption_1 must prefer image_1.
    return result["c1_i1"] > result["c1_i0"]

def group_correct(result):
    """Return True only when the example passes both the image and the text criterion."""
    if not image_correct(result):
        return False
    return text_correct(result)

# Count, for each scoring method, how many examples pass each criterion.
text_correct_count = sum(1 for row in winoground_clip_scores if text_correct(row))
image_correct_count = sum(1 for row in winoground_clip_scores if image_correct(row))
group_correct_count = sum(1 for row in winoground_clip_scores if group_correct(row))

colbert_text_correct_count = sum(1 for row in winoground_colbert_scores if text_correct(row))
colbert_image_correct_count = sum(1 for row in winoground_colbert_scores if image_correct(row))
colbert_group_correct_count = sum(1 for row in winoground_colbert_scores if group_correct(row))

combined_text_correct_count = sum(1 for row in winoground_combined_scores if text_correct(row))
combined_image_correct_count = sum(1 for row in winoground_combined_scores if image_correct(row))
combined_group_correct_count = sum(1 for row in winoground_combined_scores if group_correct(row))

# All three reports are divided by the number of CLIP score rows.
denominator = len(winoground_clip_scores)

print("text score:", text_correct_count/denominator)
print("image score:", image_correct_count/denominator)
print("group score:", group_correct_count/denominator)

print("colbert text score:", colbert_text_correct_count/denominator)
print("colbert image score:", colbert_image_correct_count/denominator)
print("colbert group score:", colbert_group_correct_count/denominator)

print("combined text score:", combined_text_correct_count/denominator)
print("combined image score:", combined_image_correct_count/denominator)
print("combined group score:", combined_group_correct_count/denominator)






text score: 0.2775
image score: 0.0925
group score: 0.07
colbert text score: 0.1975
colbert image score: 0.125
colbert group score: 0.07
combined text score: 0.28
combined image score: 0.0925
combined group score: 0.06


# Baseline


In [5]:
import open_clip
import torch
from datasets import load_dataset

# Baseline: stock open_clip ViT-B-16 with the 'laion400m_e32' weights.
# This rebinds the globals `model` and `preprocess`; the third value returned
# by the factory (the validation transform) is the one kept as `preprocess`.
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-16', pretrained='laion400m_e32')
# Inference only: open_clip returns the model in train mode, so switch to eval.
# Assigning the result keeps the cell from echoing the model repr.
model = model.eval()

# The baseline scoring loop calls `tokenizer`; define it here so this section
# does not depend on a tokenizer left over from another cell.
tokenizer = open_clip.get_tokenizer('ViT-B-16')

winoground = load_dataset("facebook/winoground")

In [9]:
from tqdm import tqdm

# The baseline produces CLIP scores only. The ColBERT / combined score lists
# are intentionally not reset here: this cell never fills them, and clearing
# them would discard the ColXLIP results computed in the earlier section.
winoground_clip_scores = []

for example in tqdm(winoground['test']):
    # Process images and text
    images = torch.stack([
        preprocess(example["image_0"].convert("RGB")),
        preprocess(example["image_1"].convert("RGB")),
    ], dim=0)
    text = tokenizer([example["caption_0"], example["caption_1"]])

    # Get model outputs
    with torch.no_grad():
        image_features = model.encode_image(images)
        text_features = model.encode_text(text)

        # Normalize features so that dot products are cosine similarities
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # Indexed [image][caption]
        logits_per_image = image_features @ text_features.T

    # Store CLIP scores; key c{j}_i{k} is caption j scored against image k
    winoground_clip_scores.append({
        "id": example["id"],
        "c0_i0": logits_per_image[0][0].item(),
        "c0_i1": logits_per_image[1][0].item(),
        "c1_i0": logits_per_image[0][1].item(),
        "c1_i1": logits_per_image[1][1].item(),
    })


100%|██████████| 400/400 [02:19<00:00,  2.87it/s]


In [None]:
def text_correct(result):
    """Return True when, for each image, its own caption scores strictly higher
    than the other caption (keys are ``c<caption>_i<image>``).

    Same definition as in the ColXLIP evaluation section; repeated so the
    baseline section can be run on its own.
    """
    # image_0 must prefer caption_0 ...
    if not result["c0_i0"] > result["c1_i0"]:
        return False
    # ... and image_1 must prefer caption_1.
    return result["c1_i1"] > result["c0_i1"]

def image_correct(result):
    """Return True when, for each caption, its own image scores strictly higher
    than the other image (keys are ``c<caption>_i<image>``).

    Same definition as in the ColXLIP evaluation section; repeated so the
    baseline section can be run on its own.
    """
    # caption_0 must prefer image_0 ...
    if not result["c0_i0"] > result["c0_i1"]:
        return False
    # ... and caption_1 must prefer image_1.
    return result["c1_i1"] > result["c1_i0"]

def group_correct(result):
    """Return True only when the example passes both the image and the text criterion."""
    if not image_correct(result):
        return False
    return text_correct(result)

# Count how many baseline examples pass each criterion.
text_correct_count = sum(1 for row in winoground_clip_scores if text_correct(row))
image_correct_count = sum(1 for row in winoground_clip_scores if image_correct(row))
group_correct_count = sum(1 for row in winoground_clip_scores if group_correct(row))

# Report each count as a fraction of the scored examples.
denominator = len(winoground_clip_scores)
print("text score:", text_correct_count/denominator)
print("image score:", image_correct_count/denominator)
print("group score:", group_correct_count/denominator)