## CLIP (Contrastive Language-Image Pre-Training)

![CLIP_architecture](assets/clip_image1.png)

In [1]:
from clip import Flickr8kValidationSet, VisualEncoder, TextEncoder, Flickr8kDataset, image_transform
from tqdm import tqdm
from datasets import load_dataset
from torchvision import transforms
import torch
import torch.nn.functional as F

In [2]:
from torch.utils.data import DataLoader
from util import get_device

device = get_device()

# Standalone preprocessing pipeline; the dataset below is built with
# `image_transform` from `clip`, so this one is not used for training.
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to 224x224
    transforms.ToTensor(),          # Convert to [0,1] tensor, shape: [C, H, W]
])

flickr8k = load_dataset("jxie/flickr8k", data_dir="data")
text_encoder = TextEncoder().to(device)
visual_encoder = VisualEncoder().to(device)

temperature = 0.02  # divides the similarity logits before the softmax
num_epochs = 2
batch_size = 64

train_data = flickr8k["train"]
train_dataset = Flickr8kDataset(train_data, transform=image_transform)

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
optimizer = torch.optim.AdamW(
    list(visual_encoder.parameters()) + list(text_encoder.parameters()),
    lr=1e-4,
    weight_decay=1e-5,
)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

for _ in range(num_epochs):
    pbar = tqdm(train_loader)
    for batch in pbar:
        image = batch[0].to(device)
        text = batch[1]  # captions, passed to the text encoder as-is

        # Call the modules directly (not .forward) so that hooks still run.
        visual_embed = visual_encoder(image)
        text_embed = text_encoder(text)

        # L2-normalize along the feature dimension: the dot product below
        # is then the cosine similarity of every image-caption pair.
        norm_visual_embed = F.normalize(visual_embed, dim=1)
        norm_text_embed = F.normalize(text_embed, dim=1)
        logits = norm_visual_embed @ norm_text_embed.T
        logits = logits / temperature

        # The matching caption for image i is at column i. Use the actual
        # batch length, since the last batch may be smaller than batch_size.
        n = logits.size(0)
        labels = torch.arange(n, device=logits.device)
        loss_i2t = F.cross_entropy(logits, labels)    # image as query, text as target
        loss_t2i = F.cross_entropy(logits.T, labels)  # text as query, image as target
        loss = (loss_i2t + loss_t2i) / 2

        loss.backward()
        pbar.set_description(f"loss:{loss.item()}")
        optimizer.step()
        optimizer.zero_grad()

loss:2.782959461212158:  40%|████      | 151/375 [01:28<02:11,  1.71it/s] 


KeyboardInterrupt: 
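
The loss computed inside the training loop above can be pulled out into a small function and sanity-checked on random tensors, without loading the encoders or the dataset. The cell below is only a sketch: the name `clip_contrastive_loss` and the toy embedding shapes are assumptions made for illustration and are not part of the `clip` module used in this notebook.

```python
import torch
import torch.nn.functional as F


def clip_contrastive_loss(image_embed, text_embed, temperature=0.02):
    """Symmetric cross-entropy over an image-text similarity matrix.

    Hypothetical helper: it mirrors the steps of the training loop above.
    """
    # L2-normalize along the feature dimension so that the dot product
    # equals cosine similarity.
    image_embed = F.normalize(image_embed, dim=-1)
    text_embed = F.normalize(text_embed, dim=-1)
    # logits[i, j]: similarity of image i and caption j, scaled by 1 / temperature.
    logits = image_embed @ text_embed.T / temperature
    # The matching caption for image i sits at column i.
    labels = torch.arange(logits.size(0), device=logits.device)
    loss_i2t = F.cross_entropy(logits, labels)    # image as query
    loss_t2i = F.cross_entropy(logits.T, labels)  # text as query
    return (loss_i2t + loss_t2i) / 2


torch.manual_seed(0)
embeds = torch.randn(8, 16)  # toy batch: 8 pairs, 16-dimensional embeddings

# Perfectly aligned pairs: every image matches its own caption.
loss_aligned = clip_contrastive_loss(embeds, embeds.clone())
# Misaligned pairs: captions shifted by one position within the batch.
loss_shuffled = clip_contrastive_loss(embeds, embeds.roll(shifts=1, dims=0))
# Uninformative embeddings: all similarities are equal.
loss_uniform = clip_contrastive_loss(torch.ones(8, 16), torch.ones(8, 16))

print(loss_aligned.item(), loss_shuffled.item(), loss_uniform.item())
```

When every similarity is equal, the softmax is uniform and the loss equals ln(N) for a batch of N pairs. For the batch size of 64 used above, that chance level is ln(64) ≈ 4.16, which gives a reference point for reading the loss shown in the progress bar.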