# Load data from flickr8k

In [1]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("adityajn105/flickr8k")

print("Path to dataset files:", path)

Downloading from https://www.kaggle.com/api/v1/datasets/download/adityajn105/flickr8k?dataset_version_number=1...


100%|██████████| 1.04G/1.04G [00:47<00:00, 23.4MB/s]

Extracting files...





Path to dataset files: /root/.cache/kagglehub/datasets/adityajn105/flickr8k/versions/1


In [2]:
import os


print("list:", os.listdir(path))

images_dir = os.path.join(path, "Images")
print("Images:", os.listdir(images_dir)[:10])

captions_file = os.path.join(path, "captions.txt")
with open(captions_file, 'r', encoding='utf-8') as f:
    lines = f.readlines()
    print("captions.txt top 3:", lines[:3])

list: ['captions.txt', 'Images']
Images: ['3333017828_b930b9d41b.jpg', '3248408149_41a8dd90d3.jpg', '255741044_1102982213.jpg', '464527562_a18f095225.jpg', '315436114_6d386b8c36.jpg', '3581818450_546c89ca38.jpg', '2904997007_23d4b94101.jpg', '2675397335_1dcdbd12f5.jpg', '3534824784_7133119316.jpg', '2766726291_b83eb5d315.jpg']
captions.txt top 3: ['image,caption\n', '1000268201_693b08cb0e.jpg,A child in a pink dress is climbing up a set of stairs in an entry way .\n', '1000268201_693b08cb0e.jpg,A girl going into a wooden building .\n']


# change the data to the way can be use in CLIP

In [3]:
import torch
from torch.utils.data import Dataset
from PIL import Image

class Flickr8kCLIPDataset(Dataset):
    def __init__(self, images_dir, captions_file, transform=None, tokenizer=None):
        self.images_dir = images_dir
        self.transform = transform
        self.tokenizer = tokenizer

        self.samples = []  # save (img_path, caption)

        with open(captions_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()

        # drop the fisrt line : "image,caption"
        for line in lines[1:]:
            line = line.strip()
            if not line:
                continue
            parts = line.split(',', 1)
            if len(parts) == 2:
                filename = parts[0].strip()
                caption = parts[1].strip()
                img_path = os.path.join(self.images_dir, filename)
                self.samples.append((img_path, caption))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, caption = self.samples[idx]

        # 1) load image
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)

        # 2) tokenize text
        if self.tokenizer is not None:

            text_enc = self.tokenizer(
                caption,
                padding="max_length",
                truncation=True,
                max_length=77,  # CLIP setting
                return_tensors="pt"
            )
            # squeeze first dimentio => [batch=1, seq_len] -> [seq_len]
            text_enc = {k: v.squeeze(0) for k, v in text_enc.items()}
        else:
            # or keep the original
            text_enc = caption

        return image, text_enc

In [None]:
import torchvision.transforms as T
from torch.utils.data import DataLoader
from transformers import BertTokenizer
#use BertTokenizer so it work well with Bert pretain model
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

transform = T.Compose([
    T.Resize((224,224)),
    T.ToTensor(),
    T.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                std=[0.26862954, 0.26130258, 0.27577711]),
])

dataset = Flickr8kCLIPDataset(images_dir, captions_file, transform=transform, tokenizer=tokenizer)
dataloader = DataLoader(dataset, batch_size=100, shuffle=True)

# check the shape
images, text_enc = next(iter(dataloader))
print("images.shape:", images.shape)  # [4, 3, 224, 224]
print("text_enc['input_ids'].shape:", text_enc["input_ids"].shape)   # [4, 77]
print("text_enc['attention_mask'].shape:", text_enc["attention_mask"].shape)

images.shape: torch.Size([100, 3, 224, 224])
text_enc['input_ids'].shape: torch.Size([100, 77])
text_enc['attention_mask'].shape: torch.Size([100, 77])


# CLIP structure

adapter

In [None]:
import torch.nn as nn
import torch.nn.functional as F
#set a simple Adapter, simple mlp and res net
class Adapter(nn.Module):
    def __init__(self, hidden_dim, adapter_dim=64):
        super().__init__()
        self.down = nn.Linear(hidden_dim, adapter_dim)
        self.up = nn.Linear(adapter_dim, hidden_dim)# give it same input shape as out put shape so it work well with pretrain model
        self.act = nn.ReLU()

    def forward(self, x):
        z = self.down(x)
        z = self.act(z)
        z = self.up(z)
        #res net
        return x + z

In [None]:
import types
#mix adaptor and bert
def add_adapters_to_bert(bert_model, adapter_dim=64):
    # bert_model => BertModel
    for layer in bert_model.encoder.layer:
        hidden_dim = layer.output.dense.out_features# get hidden_dim for adapter
        adapter = Adapter(hidden_dim, adapter_dim)
        layer.adapter = adapter

        old_forward = layer.forward

        def new_forward(self, hidden_states, *args, **kwargs):
            outputs = old_forward(hidden_states, *args, **kwargs)
            # we just want hidden_states here
            last_hidden = outputs[0]
            last_hidden = self.adapter(last_hidden)# add adapter after last_hidden, and get the result of thing go through the adapter

            outputs = (last_hidden,) + outputs[1:]# relace only the last_hidden with last_hidden + adapter
            return outputs

        layer.forward = types.MethodType(new_forward, layer)
    return bert_model

In [None]:
def add_adapters_to_vit(vit_model, adapter_dim=64):

    for layer in vit_model.encoder.layer:

        hidden_dim = layer.output.dense.out_features
        adapter = Adapter(hidden_dim, adapter_dim)
        layer.adapter = adapter

        old_forward = layer.forward

        def new_forward(self, hidden_states, *args, **kwargs):# same as bert

            outputs = old_forward(hidden_states, *args, **kwargs)
            layer_output = outputs[0]  # hidden_states
            # go though adapter
            layer_output = self.adapter(layer_output)
            # put back tuple
            outputs = (layer_output,) + outputs[1:]
            return outputs

        layer.forward = types.MethodType(new_forward, layer)

    return vit_model

In [None]:
from transformers import AutoModel

#  ViT pretrain
image_model_name = "google/vit-base-patch16-224-in21k"
vit_model = AutoModel.from_pretrained(image_model_name)  # ViTModel

#  BERT pretrain
text_model_name = "bert-base-uncased"
bert_model = AutoModel.from_pretrained(text_model_name)  # BertModel

# freeze all pretrain
for param in vit_model.parameters():
    param.requires_grad = False
for param in bert_model.parameters():
    param.requires_grad = False

# add Adapter
vit_model = add_adapters_to_vit(vit_model, adapter_dim=64)
bert_model = add_adapters_to_bert(bert_model, adapter_dim=64)

In [None]:
#get the cls token from vit and bert
class FrozenViTWithAdapter(nn.Module):
    def __init__(self, vit_model):
        super().__init__()
        self.vit = vit_model
    def forward(self, images):

        outputs = self.vit(images)

        cls_emb = outputs.last_hidden_state[:, 0, :]#take the cls
        return cls_emb

class FrozenBertWithAdapter(nn.Module):
    def __init__(self, bert_model):
        super().__init__()
        self.bert = bert_model
    def forward(self, input_ids, attention_mask=None):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)

        cls_emb = outputs.last_hidden_state[:, 0, :]#take the cls
        return cls_emb

CLIP

In [None]:
class MyCLIP(nn.Module):
    def __init__(self, vit_model, bert_model, embed_dim=256):
        super().__init__()
        self.image_encoder = FrozenViTWithAdapter(vit_model)
        self.text_encoder = FrozenBertWithAdapter(bert_model)

        # linear project for each model
        self.img_proj = nn.Linear(768, embed_dim)
        self.txt_proj = nn.Linear(768, embed_dim)

        # logit_scale
        self.logit_scale = nn.Parameter(torch.ones([]) * 1.0)

    def forward(self, images, input_ids, attention_mask=None):
        img_emb = self.image_encoder(images)           # [B, 768]
        txt_emb = self.text_encoder(input_ids, attention_mask)  # [B, 768]

        # linear project
        img_emb = self.img_proj(img_emb)  # [B, embed_dim]
        txt_emb = self.txt_proj(txt_emb)  # [B, embed_dim]

        # L2 normalize
        img_emb = img_emb / img_emb.norm(dim=-1, keepdim=True)
        txt_emb = txt_emb / txt_emb.norm(dim=-1, keepdim=True)

        return img_emb, txt_emb, self.logit_scale

Contrastive learning

In [None]:
def clip_contrastive_loss(img_emb, txt_emb, logit_scale):
    """
    img_emb: [B, embed_dim]
    txt_emb: [B, embed_dim]
    """
    batch_size = img_emb.size(0)
    # marix: [B, B]
    sim_matrix = img_emb @ txt_emb.t()  # 如果已经 L2 norm，那就是 cosine
    # 缩放
    sim_matrix = logit_scale.exp() * sim_matrix

    labels = torch.arange(batch_size, device=img_emb.device)

    import torch.nn.functional as F
    loss_img = F.cross_entropy(sim_matrix, labels)
    loss_txt = F.cross_entropy(sim_matrix.t(), labels)
    loss = (loss_img + loss_txt) / 2.0
    return loss


# training

In [None]:
def train_one_epoch(model, dataloader, optimizer, device="cuda"):
    model.train()
    total_loss = 0
    for batch_idx, (images, text_enc) in enumerate(dataloader):
        images = images.to(device)
        input_ids = text_enc["input_ids"].to(device)
        attention_mask = text_enc["attention_mask"].to(device)

        img_emb, txt_emb, logit_scale = model(images, input_ids, attention_mask)
        loss = clip_contrastive_loss(img_emb, txt_emb, logit_scale)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    return total_loss / len(dataloader)

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model = MyCLIP(vit_model, bert_model, embed_dim=256).to(device)

# just train  adapter + projection + temprature
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=1e-4)

print("trainable para:", sum(p.numel() for p in trainable_params))

for epoch in range(10):
    avg_loss = train_one_epoch(model, dataloader, optimizer, device)
    print(f"Epoch {epoch} - loss = {avg_loss:.4f}")


可训练参数量: 2772993
Epoch 0 - loss = 3.4731
Epoch 1 - loss = 3.2319
Epoch 2 - loss = 3.0959
Epoch 3 - loss = 2.9791
Epoch 4 - loss = 2.8731
Epoch 5 - loss = 2.7773
Epoch 6 - loss = 2.6791
Epoch 7 - loss = 2.5935
Epoch 8 - loss = 2.5053
Epoch 9 - loss = 2.4209
