# Download Data

In [1]:
import os
data_dir = 'data'

%pip install git+https://github.com/openai/CLIP.git open_clip_torch
%pip install git-lfs

annotations_folder = os.path.join(data_dir, "annotations")
if not os.path.exists(annotations_folder):
    print(f"{annotations_folder} folder does not exist, downloading...")
    !wget -P $data_dir http://images.cocodataset.org/annotations/annotations_trainval2017.zip
    !unzip $data_dir/annotations_trainval2017.zip -d $data_dir

val_folder = os.path.join(data_dir, "val2017")
if not os.path.exists(val_folder):
    print(f"{val_folder} folder does not exist, downloading...")
    !wget -P $data_dir http://images.cocodataset.org/zips/val2017.zip
    !unzip $data_dir/val2017.zip -d $data_dir

Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-f50l89ua
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-f50l89ua
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... [?25ldone
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [2]:
from transformers import AutoConfig, AutoModel, AutoProcessor
import open_clip
import sys
sys.path.append('../..')
import torch
from torchvision.datasets import CocoCaptions
from PIL import Image
import clip
from tqdm import tqdm
device = "cuda" if torch.cuda.is_available() else "cpu"

2025-02-25 17:41:43.854284: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


## Encoder Function

In [3]:
from transformers import AutoConfig, AutoModel
import os
import shutil
from safetensors.torch import load_file

def load_triplet_clip_encoders(
    repo_name="TripletCLIP/CC12M_TripletCLIP_ViTB12",
    text_path='',
    image_path='',
    work_dir="/content",
    device="cuda"
):
    repo_url = f"https://huggingface.co/{repo_name}"

    repo_folder_name = repo_name.split("/")[-1]  # e.g. CC12M_TripletCLIP_ViTB12
    target_path = os.path.join(work_dir, repo_folder_name)
    if os.path.exists(target_path):
        shutil.rmtree(target_path)

    !git clone $repo_url $target_path
    !cd $target_path
    !git lfs install

    text_config_path = os.path.join(target_path, text_path)
    image_config_path = os.path.join(target_path, image_path)

    text_config = AutoConfig.from_pretrained(text_config_path, local_files_only=True, trust_remote_code=True, device_map="auto")
    image_config = AutoConfig.from_pretrained(image_config_path, local_files_only=True, trust_remote_code=True, device_map="auto")

    text_encoder = AutoModel.from_config(text_config, trust_remote_code=True)
    image_encoder = AutoModel.from_config(image_config, trust_remote_code=True)

    text_encoder.load_state_dict(load_file(os.path.join(text_config_path, 'model.safetensors')), strict=False)
    image_encoder.load_state_dict(load_file(os.path.join(image_config_path, 'model.safetensors')), strict=False)

    text_encoder.to(device)
    image_encoder.to(device)

    text_encoder.eval()
    image_encoder.eval()

    return text_encoder, image_encoder

def transform(
    dataset,
    image_encoder,
    text_encoder,
    processor,
    tokenizer,
    device='cuda',
    sample_size=None
):
    from itertools import islice
    image_features = []
    text_features = []

    if sample_size is not None and sample_size <= 5000:
        dataset_iter = islice(dataset, sample_size)
    else:
        dataset_iter = dataset

    with torch.no_grad():
        for image, captions in tqdm(dataset_iter, total=sample_size):
            image_input = processor(images=image, return_tensors="pt").to(device)
            image_features.append(image_encoder(image_input))

            captions = captions[0:5]
            caption_input = tokenizer(captions).to(device)
            text_features.extend(text_encoder(caption_input).text_embeds)

        image_features = torch.stack(image_features).squeeze()
        image_features /= image_features.norm(dim=-1, keepdim=True)

        print(image_features.shape)
        text_features = torch.stack(text_features)
        text_features /= text_features.norm(dim=-1, keepdim=True)

    return image_features, text_features

# Dataset

In [4]:
coco = CocoCaptions(root=val_folder, annFile=os.path.join(annotations_folder, 'captions_val2017.json'), transform=None)

loading annotations into memory...
Done (t=0.03s)
creating index...
index created!


# Processor and Tokenizer

In [5]:
processor = AutoProcessor.from_pretrained('openai/clip-vit-base-patch32')
tokenizer = open_clip.get_tokenizer('ViT-B-32')

# Encoder

In [6]:
text_encoder_t, image_encoder_t = load_triplet_clip_encoders(repo_name='TripletCLIP/CC12M_TripletCLIP_ViTB12',text_path='text-encoder',image_path='vision-encoder', work_dir=data_dir, device=device)
text_encoder_n, image_encoder_n = load_triplet_clip_encoders(repo_name='TripletCLIP/CC12M_NegCLIP_ViTB12',text_path='text-encoder',image_path='vision-encoder', work_dir=data_dir, device=device)
# text_encoder_clip_base, image_encoder_clip_base = load_triplet_clip_encoders(repo_name='openai/clip-vit-base-patch32', work_dir='/content', device=device)
# text_encoder_clip_large, image_encoder_clip_large = load_triplet_clip_encoders(repo_name='openai/clip-vit-large-patch14', work_dir='/content', device=device)
text_encoder_l, image_encoder_l = load_triplet_clip_encoders(repo_name='TripletCLIP/CC12M_LaCLIP_ViTB12',text_path='text-encoder',image_path='vision-encoder', work_dir=data_dir, device=device)
text_encoder_npp, image_encoder_npp = load_triplet_clip_encoders(repo_name='TripletCLIP/CC12M_NegCLIPPP_ViTB12',text_path='text-encoder',image_path='vision-encoder', work_dir=data_dir, device=device)

Cloning into 'data/CC12M_TripletCLIP_ViTB12'...
remote: Enumerating objects: 12, done.[K
remote: Counting objects: 100% (9/9), done.[K
remote: Compressing objects: 100% (9/9), done.[K
remote: Total 12 (delta 1), reused 0 (delta 0), pack-reused 3 (from 1)[K
Unpacking objects: 100% (12/12), 5.43 KiB | 1.36 MiB/s, done.
Filtering content: 100% (2/2), 577.12 MiB | 16.46 MiB/s, done.
Updated Git hooks.
Git LFS initialized.
Cloning into 'data/CC12M_NegCLIP_ViTB12'...
remote: Enumerating objects: 12, done.[K
remote: Counting objects: 100% (9/9), done.[K
remote: Compressing objects: 100% (9/9), done.[K
remote: Total 12 (delta 1), reused 0 (delta 0), pack-reused 3 (from 1)[K
Unpacking objects: 100% (12/12), 5.43 KiB | 1.09 MiB/s, done.
Filtering content: 100% (2/2), 577.12 MiB | 16.35 MiB/s, done.
Updated Git hooks.
Git LFS initialized.
Cloning into 'data/CC12M_LaCLIP_ViTB12'...
remote: Enumerating objects: 12, done.[K
remote: Counting objects: 100% (9/9), done.[K
remote: Compressing 

## Feature Extraction

In [7]:
sample_size = 5000
image_features_t, text_features_t = transform(coco, image_encoder_t, text_encoder_t, processor, tokenizer, device, sample_size=sample_size)
image_features_n, text_features_n = transform(coco, image_encoder_n, text_encoder_n, processor, tokenizer, device, sample_size=sample_size)
# image_features_clip_base, text_features_clip_base = transform(coco, image_encoder_clip_base, text_encoder_clip_base, processor, tokenizer, device, sample_size=sample_size)
image_features_l, text_features_l = transform(coco, image_encoder_l, text_encoder_l, processor, tokenizer, device, sample_size=sample_size)
image_features_npp, text_features_npp = transform(coco, image_encoder_npp, text_encoder_npp, processor, tokenizer, device, sample_size=sample_size)

100%|██████████| 5000/5000 [01:24<00:00, 58.92it/s]


torch.Size([5000, 512])


100%|██████████| 5000/5000 [01:25<00:00, 58.42it/s]


torch.Size([5000, 512])


100%|██████████| 5000/5000 [01:24<00:00, 59.12it/s]


torch.Size([5000, 512])


100%|██████████| 5000/5000 [01:26<00:00, 58.10it/s]


torch.Size([5000, 512])


# Similarity

In [42]:
def rotate_and_calc_similarity(image_features, text_features):

    # ========== 1) 计算全局均值并归一化 ==========
    # image 全局均值: (D,)
    mean_image = image_features.mean(dim=0)
    # text 全局均值:  (D,)
    mean_text  = text_features.mean(dim=0)

    # 归一化 (若想直接点乘当作余弦相似度, image/text_features 本身也需归一化)
    mean_image_norm = mean_image / (mean_image.norm() + 1e-12)
    mean_text_norm  = mean_text  / (mean_text.norm()  + 1e-12)

    # ========== 2) 计算夹角 theta 并构造二维平面内的正交向量 ==========
    # cos_angle = a·b
    cos_angle = torch.dot(mean_image_norm, mean_text_norm)
    # 防止浮点误差导致 acos 输入超出 [-1,1]
    cos_angle = torch.clamp(cos_angle, -1.0, 1.0)
    theta = torch.acos(cos_angle)  # 弧度

    sin_angle = torch.sqrt(1 - cos_angle**2 + 1e-12)
    # 在 (mean_image_norm, mean_text_norm) 所张平面上，构造与 mean_image_norm 正交的单位向量 v
    v = (mean_text_norm - cos_angle * mean_image_norm) / (sin_angle + 1e-12)

    # ========== 3) 在 (mean_image_norm, v) 平面内对 image_features 做旋转 ==========
    # 投影系数
    proj_a = image_features @ mean_image_norm  # (N,)
    proj_v = image_features @ v                # (N,)

    # 2D 旋转
    # new_a = a*cosθ - v*sinθ
    # new_v = a*sinθ + v*cosθ
    rotated_proj_a = proj_a * torch.cos(theta) - proj_v * torch.sin(theta)  # (N,)
    rotated_proj_v = proj_a * torch.sin(theta) + proj_v * torch.cos(theta)  # (N,)

    # 在平面内的分量(旋转后)
    rotated_parallel = (rotated_proj_a.unsqueeze(1) * mean_image_norm.unsqueeze(0)
                      + rotated_proj_v.unsqueeze(1) * v.unsqueeze(0))

    # 原本在平面内的分量
    orig_parallel = (proj_a.unsqueeze(1) * mean_image_norm.unsqueeze(0)
                   + proj_v.unsqueeze(1) * v.unsqueeze(0))

    # 正交分量(不在平面内, 保持不变)
    orthogonal_component = image_features - orig_parallel

    # 旋转后的图像特征 (N, D)
    rotated_image_features = rotated_parallel + orthogonal_component

    # ========== 4) 计算相似度矩阵 (N, 5N) ==========
    # 如果想用余弦相似度, 这里要保证 rotated_image_features 与 text_features 已各自归一化。
    # 否则就是一般点乘。
    similarity_matrix = rotated_image_features @ text_features.T  # (N, 5N)

    # ========== 5) 通过广播为 "正例" 与 "负例" 构造布尔掩码 ==========
    # 对第 i 行(图像 i), 正例对应的列区间是 [5*i, ..., 5*i+4]
    #   => j_idx // 5 == i_idx

    N = image_features.size(0)
    i_idx = torch.arange(N, device=similarity_matrix.device).unsqueeze(1).expand(N, 5*N)
    j_idx = torch.arange(5*N, device=similarity_matrix.device).unsqueeze(0).expand(N, 5*N)

    pos_mask = (j_idx // 5 == i_idx)  # 形状 (N, 5N), True 表示匹配
    # 所有正例相似度 => (N*5,) 向量
    pos_sims = similarity_matrix[pos_mask]
    # 如果你想分别取 "每个图像 vs. 它的 5 条文本" 的平均，可以 reshape => (N,5) 再 mean
    # 这会先对每张图像5个caption做平均，再对 N 张图像平均
    mean_pos = pos_sims.view(N, 5).mean()

    # 负例相似度(不匹配)
    neg_sims = similarity_matrix[~pos_mask]  # (N*(5N-5),)
    mean_neg = neg_sims.mean()

    return mean_pos.item(), mean_neg.item()

In [59]:
sim_pos_t, sim_neg_t = rotate_and_calc_similarity(image_features_t, text_features_t)
sim_pos_n, sim_neg_n = rotate_and_calc_similarity(image_features_n, text_features_n)
sim_pos_l, sim_neg_l = rotate_and_calc_similarity(image_features_l, text_features_l)
sim_pos_npp, sim_neg_npp = rotate_and_calc_similarity(image_features_npp, text_features_npp)
print(f'TripletCLIP Similarity    - Positive: {sim_pos_t}, Negative: {sim_neg_t}, Pos-Neg: {sim_pos_t - sim_neg_t}')
print(f'NegativeCLIP Similarity   - Positive: {sim_pos_n}, Negative: {sim_neg_n}, Pos-Neg: {sim_pos_n - sim_neg_n}')
print(f'LaCLIP Similarity         - Positive: {sim_pos_l}, Negative: {sim_neg_l}, Pos-Neg: {sim_pos_l - sim_neg_l}')
print(f'NegativeCLIP++ Similarity - Positive: {sim_pos_npp}, Negative: {sim_neg_npp}, Pos-Neg: {sim_pos_npp - sim_neg_npp}')

TripletCLIP Similarity    - Positive: 0.7865129709243774, Negative: 0.6824867129325867, Pos-Neg: 0.10402625799179077
NegativeCLIP Similarity   - Positive: 0.6373680830001831, Negative: 0.47875773906707764, Pos-Neg: 0.15861034393310547
LaCLIP Similarity         - Positive: 0.5077755451202393, Negative: 0.3667873442173004, Pos-Neg: 0.14098820090293884
NegativeCLIP++ Similarity - Positive: 0.6485074162483215, Negative: 0.49330487847328186, Pos-Neg: 0.15520253777503967


# Retrieval

In [36]:
import torch

def evaluate_retrieval(image_features, text_features, num_captions=5):
    """
    评估 image-to-text (I2T) 和 text-to-image (T2I) 的检索性能指标。
    
    Args:
        image_features (torch.Tensor): shape 为 (N, D) 的图片特征张量，其中 N 为图片数量。
        text_features (torch.Tensor): shape 为 (N * num_captions, D) 的文本特征张量，每张图片对应 num_captions 个 caption。
        num_captions (int): 每张图片对应的 caption 数量（默认 5）。
    
    Returns:
        dict: 包含以下键值对的字典：
            {
                'I2T_top1': float,  # 图片检索文本 Top1 准确率
                'I2T_top5': float,  # 图片检索文本 Top5 准确率
                'I2T_top10': float, # 图片检索文本 Top10 准确率
                'T2I_top1': float,  # 文本检索图片 Top1 准确率
                'T2I_top5': float,  # 文本检索图片 Top5 准确率
                'T2I_top10': float, # 文本检索图片 Top10 准确率
            }
    """
    # 确保文本数量与图片数量及每张图片的 caption 数匹配
    num_images = image_features.size(0)
    assert text_features.size(0) == num_images * num_captions, "文本特征数量与图片数量不匹配！"

    # 若未归一化，则先归一化特征（如果已归一化可省略）
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    
    # 计算相似度矩阵：shape 为 (N, N * num_captions)
    similarity = image_features @ text_features.t()
    
    # 评估 image-to-text (I2T)
    I2T_top1, I2T_top5, I2T_top10 = 0, 0, 0
    for i in range(num_images):
        sim_i = similarity[i]  # 第 i 张图片与所有文本之间的相似度
        # 获取从大到小排序后的文本索引
        sorted_indices = torch.argsort(sim_i, descending=True)
        # 第 i 张图片的 ground truth 文本索引范围
        gt_indices = list(range(i * num_captions, i * num_captions + num_captions))
        # Top1 检索：如果 ground truth 中任一索引出现在前 1 个，则认为正确
        if any(idx in sorted_indices[:1] for idx in gt_indices):
            I2T_top1 += 1
        # Top5 检索
        if any(idx in sorted_indices[:5] for idx in gt_indices):
            I2T_top5 += 1
        # Top10 检索
        if any(idx in sorted_indices[:10] for idx in gt_indices):
            I2T_top10 += 1

    I2T_top1_score = I2T_top1 / num_images
    I2T_top5_score = I2T_top5 / num_images
    I2T_top10_score = I2T_top10 / num_images

    # 评估 text-to-image (T2I)
    # 这里可以利用相似度矩阵的转置，shape 为 (N * num_captions, N)
    similarity_t = similarity.t()
    T2I_top1, T2I_top5, T2I_top10 = 0, 0, 0
    for j in range(text_features.size(0)):
        sim_j = similarity_t[j]  # 第 j 个文本与所有图片的相似度
        sorted_indices = torch.argsort(sim_j, descending=True)
        # 对于第 j 个文本，其对应图片索引为 j // num_captions
        gt_image = j // num_captions
        if gt_image in sorted_indices[:1]:
            T2I_top1 += 1
        if gt_image in sorted_indices[:5]:
            T2I_top5 += 1
        if gt_image in sorted_indices[:10]:
            T2I_top10 += 1

    total_texts = text_features.size(0)
    T2I_top1_score = T2I_top1 / total_texts
    T2I_top5_score = T2I_top5 / total_texts
    T2I_top10_score = T2I_top10 / total_texts

    results = {
        'I2T_top1': I2T_top1_score,
        'I2T_top5': I2T_top5_score,
        'I2T_top10': I2T_top10_score,
        'T2I_top1': T2I_top1_score,
        'T2I_top5': T2I_top5_score,
        'T2I_top10': T2I_top10_score,
    }

    return results

In [57]:
results_t = evaluate_retrieval(image_features_t, text_features_t)
results_n = evaluate_retrieval(image_features_n, text_features_n)
results_l = evaluate_retrieval(image_features_l, text_features_l)
results_npp = evaluate_retrieval(image_features_npp, text_features_npp)
print(f'Image to Text (I2T) Top1  - TripletCLIP: {results_t["I2T_top1"]}, NegCLIP: {results_n["I2T_top1"]}, LaCLIP: {results_l["I2T_top1"]}, NegCLIP++: {results_npp["I2T_top1"]}')
print(f'Image to Text (I2T) Top5  - TripletCLIP: {results_t["I2T_top5"]}, NegCLIP: {results_n["I2T_top5"]}, LaCLIP: {results_l["I2T_top5"]}, NegCLIP++: {results_npp["I2T_top5"]}')
print(f'Image to Text (I2T) Top10 - TripletCLIP: {results_t["I2T_top10"]}, NegCLIP: {results_n["I2T_top10"]}, LaCLIP: {results_l["I2T_top10"]}, NegCLIP++: {results_npp["I2T_top10"]}')
print(f'Text to Image (T2I) Top1  - TripletCLIP: {results_t["T2I_top1"]}, NegCLIP: {results_n["T2I_top1"]}, LaCLIP: {results_l["T2I_top1"]}, NegCLIP++: {results_npp["T2I_top1"]}')
print(f'Text to Image (T2I) Top5  - TripletCLIP: {results_t["T2I_top5"]}, NegCLIP: {results_n["T2I_top5"]}, LaCLIP: {results_l["T2I_top5"]}, NegCLIP++: {results_npp["T2I_top5"]}')
print(f'Text to Image (T2I) Top10 - TripletCLIP: {results_t["T2I_top10"]}, NegCLIP: {results_n["T2I_top10"]}, LaCLIP: {results_l["T2I_top10"]}, NegCLIP++: {results_npp["T2I_top10"]}')

Image to Text (I2T) Top1  - TripletCLIP: 0.1416, NegCLIP: 0.1224, LaCLIP: 0.0982, NegCLIP++: 0.1098
Image to Text (I2T) Top5  - TripletCLIP: 0.3358, NegCLIP: 0.293, LaCLIP: 0.2552, NegCLIP++: 0.2756
Image to Text (I2T) Top10 - TripletCLIP: 0.4494, NegCLIP: 0.4014, LaCLIP: 0.3526, NegCLIP++: 0.3736
Text to Image (T2I) Top1  - TripletCLIP: 0.11348, NegCLIP: 0.08208, LaCLIP: 0.07116, NegCLIP++: 0.0836
Text to Image (T2I) Top5  - TripletCLIP: 0.27908, NegCLIP: 0.22212, LaCLIP: 0.1922, NegCLIP++: 0.2232
Text to Image (T2I) Top10 - TripletCLIP: 0.38332, NegCLIP: 0.31744, LaCLIP: 0.27956, NegCLIP++: 0.3184


# Uniformity and Alignment

In [53]:
# alignment metric
# features: (N, D)
import numpy as np
def compute_alignment(text_features, image_features, alpha=2, captions_per_image=5):
    # image_features: shape [N, D]
    # text_features: shape [N * captions_per_image, D]
    num_images, D = image_features.shape
    # 重塑为 [N, captions_per_image, D]
    text_features = text_features.reshape(num_images, captions_per_image, D)
    # 扩展 image_features 为 [N, 1, D] 以便广播
    image_features_expanded = image_features[:, None, :]
    # 计算每个 caption 与对应 image 的差异
    diff = text_features - image_features_expanded  # shape: [N, captions_per_image, D]
    # L2 范数，计算每个 caption 的距离
    dist = np.linalg.norm(diff, axis=2)  # shape: [N, captions_per_image]
    # 对所有距离的 alpha 次方取均值
    return np.mean(dist ** alpha)

# uniformity metric
def compute_uniformity_approx(features, t=2.0, num_samples=100000):
    """
    Approximate uniformity via random sampling of pairs.
    
    L = log( E_{x,y}[ exp(-t ||x-y||^2) ] )
    """
    N = features.shape[0]
    # Randomly sample pairs (i, j)
    idx1 = np.random.randint(0, N, size=num_samples)
    idx2 = np.random.randint(0, N, size=num_samples)
    
    # Compute squared distances for the sampled pairs
    diff = features[idx1] - features[idx2]
    dist2 = np.sum(diff * diff, axis=-1)
    
    # Compute exp(-t * dist^2) and take the average
    values = np.exp(-t * dist2)
    avg = np.mean(values)
    
    return np.log(avg)

In [58]:
alignment_t = compute_alignment(text_features_t.cpu().numpy(), image_features_t.cpu().numpy())
alignment_n = compute_alignment(text_features_n.cpu().numpy(), image_features_n.cpu().numpy())
alignment_l = compute_alignment(text_features_l.cpu().numpy(), image_features_l.cpu().numpy())
alignment_npp = compute_alignment(text_features_npp.cpu().numpy(), image_features_npp.cpu().numpy())
uniformity_text_t = compute_uniformity_approx(text_features_t.cpu().numpy())
uniformity_image_t = compute_uniformity_approx(image_features_t.cpu().numpy())
uniformity_text_n = compute_uniformity_approx(text_features_n.cpu().numpy())
uniformity_image_n = compute_uniformity_approx(image_features_n.cpu().numpy())
uniformity_text_l = compute_uniformity_approx(text_features_l.cpu().numpy())
uniformity_image_l = compute_uniformity_approx(image_features_l.cpu().numpy())
uniformity_text_npp = compute_uniformity_approx(text_features_npp.cpu().numpy())
uniformity_image_npp = compute_uniformity_approx(image_features_npp.cpu().numpy())
print("TripletCLIP - Alignment: {}, Uniformity Text: {}, Uniformity Image: {}".format(alignment_t, uniformity_text_t, uniformity_image_t))
print("NegCLIP     - Alignment: {}, Uniformity Text: {}, Uniformity Image: {}".format(alignment_n, uniformity_text_n, uniformity_image_n))
print("LaCLIP      - Alignment: {}, Uniformity Text: {}, Uniformity Image: {}".format(alignment_l, uniformity_text_l, uniformity_image_l))
print("NegCLIP++   - Alignment: {}, Uniformity Text: {}, Uniformity Image: {}".format(alignment_npp, uniformity_text_npp, uniformity_image_npp))

TripletCLIP - Alignment: 1.3321770429611206, Uniformity Text: -1.4073137044906616, Uniformity Image: -1.0404778718948364
NegCLIP     - Alignment: 1.3323582410812378, Uniformity Text: -1.21366286277771, Uniformity Image: -2.557321071624756
LaCLIP      - Alignment: 1.3822613954544067, Uniformity Text: -1.5379735231399536, Uniformity Image: -3.0081543922424316
NegCLIP++   - Alignment: 1.3408807516098022, Uniformity Text: -1.219635248184204, Uniformity Image: -2.4809818267822266


# CKA

In [18]:
def linear_cka(X, Y):
    """
    计算线性 CKA (PyTorch 版本)。
    X, Y 都是 torch.Tensor，形状 (n, d)。
    """
    # X^T Y
    XY = X.T @ Y  # (d_X, d_Y)

    # 分子
    numerator = torch.sum(XY**2)

    # 分母
    XX = X.T @ X
    YY = Y.T @ Y
    denominator = torch.sqrt(torch.sum(XX**2) * torch.sum(YY**2))

    return numerator / denominator

def replicate_features(X: torch.Tensor, replicate_times: int = 5) -> torch.Tensor:
    """
    将形状为 (N, d) 的特征张量 X 中的每个样本复制 replicate_times 次，
    返回形状为 (N * replicate_times, d)。
    """
    # X.shape = (N, d)
    N, d = X.shape

    X_extended = X.unsqueeze(1).expand(N, replicate_times, d).reshape(-1, d)

    return X_extended

In [19]:
cka_t = linear_cka(text_features_t, replicate_features(image_features_t, 5))
cka_n = linear_cka(text_features_n, replicate_features(image_features_n, 5))
cka_l = linear_cka(text_features_l, replicate_features(image_features_l, 5))
cka_npp = linear_cka(text_features_npp, replicate_features(image_features_npp, 5))
print("TripletCLIP CKA: {}".format(cka_t))
print("NegCLIP     CKA: {}".format(cka_n))
print("LaCLIP      CKA: {}".format(cka_l))
print("NegCLIP++   CKA: {}".format(cka_npp))

TripletCLIP CKA: 0.9902653098106384
NegCLIP     CKA: 0.9574718475341797
LaCLIP      CKA: 0.9180614948272705
NegCLIP++   CKA: 0.963750958442688


In [22]:
cka_text_t_n = linear_cka(text_features_t, text_features_n)
cka_text_t_l = linear_cka(text_features_t, text_features_l)
cka_text_t_npp = linear_cka(text_features_t, text_features_npp)
cka_text_n_l = linear_cka(text_features_n, text_features_l)
cka_text_n_npp = linear_cka(text_features_n, text_features_npp)
cka_text_l_npp = linear_cka(text_features_l, text_features_npp)
cka_image_t_n = linear_cka(image_features_t, image_features_n)
cka_image_t_l = linear_cka(image_features_t, image_features_l)
cka_image_t_npp = linear_cka(image_features_t, image_features_npp)
cka_image_n_l = linear_cka(image_features_n, image_features_l)
cka_image_n_npp = linear_cka(image_features_n, image_features_npp)
cka_image_l_npp = linear_cka(image_features_l, image_features_npp)

import pandas as pd

labels = ["t", "n", "l", "npp"]

# 构造文本特征的表格数据（转换为标量）
text_data = [
    [1,                   cka_text_t_n.item(),   cka_text_t_l.item(),    cka_text_t_npp.item()],
    [cka_text_t_n.item(), 1,                     cka_text_n_l.item(),    cka_text_n_npp.item()],
    [cka_text_t_l.item(), cka_text_n_l.item(),   1,                      cka_text_l_npp.item()],
    [cka_text_t_npp.item(), cka_text_n_npp.item(), cka_text_l_npp.item(),  1]
]

# 构造图像特征的表格数据
image_data = [
    [1,                    cka_image_t_n.item(),   cka_image_t_l.item(),    cka_image_t_npp.item()],
    [cka_image_t_n.item(), 1,                      cka_image_n_l.item(),    cka_image_n_npp.item()],
    [cka_image_t_l.item(), cka_image_n_l.item(),   1,                      cka_image_l_npp.item()],
    [cka_image_t_npp.item(), cka_image_n_npp.item(), cka_image_l_npp.item(),  1]
]

# 使用 DataFrame 显示
df_text = pd.DataFrame(text_data, index=labels, columns=labels)
df_image = pd.DataFrame(image_data, index=labels, columns=labels)

print("Text Features CKA:")
print(df_text)
print("\nImage Features CKA:")
print(df_image)

Text Features CKA:
            t         n         l       npp
t    1.000000  0.994721  0.992668  0.995431
n    0.994721  1.000000  0.993695  0.996771
l    0.992668  0.993695  1.000000  0.994085
npp  0.995431  0.996771  0.994085  1.000000

Image Features CKA:
            t         n         l       npp
t    1.000000  0.966823  0.930975  0.972535
n    0.966823  1.000000  0.954775  0.978458
l    0.930975  0.954775  1.000000  0.954309
npp  0.972535  0.978458  0.954309  1.000000


In [26]:
cka_t_img_n_text = linear_cka(replicate_features(image_features_t, 5), text_features_n)
cka_t_img_l_text = linear_cka(replicate_features(image_features_t, 5), text_features_l)
cka_t_img_npp_text = linear_cka(replicate_features(image_features_t, 5), text_features_npp)
cka_n_img_t_text = linear_cka(replicate_features(image_features_n, 5), text_features_t)
cka_n_img_l_text = linear_cka(replicate_features(image_features_n, 5), text_features_l)
cka_n_img_npp_text = linear_cka(replicate_features(image_features_n, 5), text_features_npp)
cka_l_img_t_text = linear_cka(replicate_features(image_features_l, 5), text_features_t)
cka_l_img_n_text = linear_cka(replicate_features(image_features_l, 5), text_features_n)
cka_l_img_npp_text = linear_cka(replicate_features(image_features_l, 5), text_features_npp)
cka_npp_img_t_text = linear_cka(replicate_features(image_features_npp, 5), text_features_t)
cka_npp_img_n_text = linear_cka(replicate_features(image_features_npp, 5), text_features_n)
cka_npp_img_l_text = linear_cka(replicate_features(image_features_npp, 5), text_features_l)

labels = ["t", "n", "l", "npp"]
# 构造文本特征的表格数据（转换为标量）
data = [
    [cka_t.item(), cka_t_img_n_text.item(), cka_t_img_l_text.item(), cka_t_img_npp_text.item()],
    [cka_n_img_t_text.item(), cka_n.item(), cka_n_img_l_text.item(), cka_n_img_npp_text.item()],
    [cka_l_img_t_text.item(), cka_l_img_n_text.item(), cka_l.item(), cka_l_img_npp_text.item()],
    [cka_npp_img_t_text.item(), cka_npp_img_n_text.item(), cka_npp_img_l_text.item(), cka_npp.item()]
]

df = pd.DataFrame(data, index=labels, columns=labels)
print("CKA between text and image features:")
print(df)

CKA between text and image features:
            t         n         l       npp
t    0.990265  0.990506  0.985396  0.991323
n    0.958043  0.957472  0.953378  0.958353
l    0.922640  0.921082  0.918061  0.921956
npp  0.963544  0.962884  0.958622  0.963751
