## Install Depedencies

In [7]:
!pip install transformers torchvision datasets pillow evaluate rouge_score pycocoevalcap bert_score --quiet

In [2]:
import os
import random
from glob import glob
from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import torchvision.transforms as T


from transformers import (
    CLIPVisionModel, CLIPProcessor,
    OPTForCausalLM, AutoTokenizer,
    get_linear_schedule_with_warmup
)



############### Flickr 8K Class for All Captions ####################


class Flickr8kDataset(Dataset):
    def __init__(self, images_dir, captions_file, tokenizer,
                 prefix="a photo of", max_length=40, image_size=224):
        self.images_dir = images_dir
        self.tokenizer  = tokenizer
        self.prefix     = prefix
        self.max_length = max_length

        self.samples = []  
        with open(captions_file, 'r') as f:
            for line in f:
                key, cap = line.strip().split('\t', 1)
                filename = key.split('#')[0]
                self.samples.append((filename, cap))

        # Precompute the image transform
        self.transform = T.Compose([
            T.Resize((image_size, image_size)),
            T.ToTensor(),
            T.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                        std=(0.26862954, 0.26130258, 0.27577711)),
        ])

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        filename, caption = self.samples[idx]
        img_path     = os.path.join(self.images_dir, filename)
        image        = Image.open(img_path).convert('RGB')
        pixel_values = self.transform(image)  # (3, H, W)

        full = f"{self.prefix} {caption}"
        toks = self.tokenizer(
            full,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )
        input_ids      = toks.input_ids.squeeze(0)      # (T,)
        attention_mask = toks.attention_mask.squeeze(0) # (T,)

        labels = input_ids.clone()
        prefix_ids = self.tokenizer(
            self.prefix,
            add_special_tokens=False
        )['input_ids']
        labels[:len(prefix_ids)] = -100

        return pixel_values, input_ids, attention_mask, labels



class QFormer(nn.Module):
    def __init__(self, image_feat_dim=768, query_dim=768,
                 num_queries=32, num_heads=8, num_layers=4):
        super().__init__()
        self.num_queries = num_queries
        self.query_dim = query_dim
        self.queries = nn.Parameter(torch.randn(num_queries, query_dim))
        self.cross_attn = nn.ModuleList([
            nn.MultiheadAttention(query_dim, num_heads)
            for _ in range(num_layers)
        ])
        self.self_attn = nn.ModuleList([
            nn.MultiheadAttention(query_dim, num_heads)
            for _ in range(num_layers)
        ])
        self.proj = nn.Linear(image_feat_dim, query_dim)
        self.ffns = nn.ModuleList([
            nn.Sequential(
                nn.Linear(query_dim, query_dim * 4),
                nn.GELU(),
                nn.Linear(query_dim * 4, query_dim)
            )
            for _ in range(num_layers)
        ])
        self.norm1 = nn.ModuleList([nn.LayerNorm(query_dim) for _ in range(num_layers)])
        self.norm2 = nn.ModuleList([nn.LayerNorm(query_dim) for _ in range(num_layers)])
        self.norm3 = nn.ModuleList([nn.LayerNorm(query_dim) for _ in range(num_layers)])

    def forward(self, image_feats):
        """
        image_feats: (B, S, D_img)
        returns: (B, num_queries, query_dim)
        """
        B, S, D = image_feats.size()
        proj_feats = self.proj(image_feats)          # (B, S, query_dim)
        proj_feats = proj_feats.permute(1, 0, 2)      # (S, B, Q)
        q = self.queries.unsqueeze(1).repeat(1, B, 1) # (num_q, B, Q)
        for i in range(len(self.cross_attn)):
            q2, _ = self.cross_attn[i](
                query=q,
                key=proj_feats,
                value=proj_feats
            )
            q = self.norm1[i](q + q2)
            q2, _ = self.self_attn[i](
                query=q, key=q, value=q
            )
            q = self.norm2[i](q + q2)
            q2 = self.ffns[i](q)
            q = self.norm3[i](q + q2)
        q = q.permute(1, 0, 2)  # (B, num_q, Q)
        return q


class BLIP2Captioning(nn.Module):
    def __init__(self,
                 vision_model_name="openai/clip-vit-base-patch32",
                 llm_model_name="facebook/opt-125m",
                 num_queries=32,
                 query_dim=768):
        super().__init__()
        self.num_queries = num_queries

        self.vision = CLIPVisionModel.from_pretrained(vision_model_name)
        for p in self.vision.parameters():
            p.requires_grad = False

        img_feat_dim = self.vision.config.hidden_size
        self.qformer = QFormer(
            image_feat_dim=img_feat_dim,
            query_dim=query_dim,
            num_queries=num_queries,
            num_heads=8,
            num_layers=4
        )

        self.llm = OPTForCausalLM.from_pretrained(llm_model_name)
        for p in self.llm.parameters():
            p.requires_grad = False
        self.tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token

        lm_dim = self.llm.config.hidden_size
        self.vis_proj = nn.Linear(query_dim, lm_dim)

    def forward(self, pixel_values, input_ids, attention_mask, labels=None):
        v = self.vision(pixel_values=pixel_values).last_hidden_state  # (B, S, D_img)

        q = self.qformer(v)                                           # (B, Q, query_dim)

        vis_emb = self.vis_proj(q)                                    # (B, Q, lm_dim)

        tok_emb = self.llm.model.decoder.embed_tokens(input_ids)      # (B, T, lm_dim)

        inputs_embeds = torch.cat([vis_emb, tok_emb], dim=1)          # (B, Q+T, lm_dim)

        vp_mask = inputs_embeds.new_ones((inputs_embeds.size(0), vis_emb.size(1)))
        attn_mask = torch.cat([vp_mask, attention_mask], dim=1)       # (B, Q+T)

        if labels is not None:
            B, T = labels.size()
            Q = self.num_queries
            pad = torch.full((B, Q), -100, device=labels.device, dtype=labels.dtype)
            labels = torch.cat([pad, labels], dim=1)                  # (B, Q+T)

        out = self.llm(
            inputs_embeds=inputs_embeds,
            attention_mask=attn_mask,
            labels=labels,
        )
        return out







  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from torch.utils.data import Subset
device       = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
images_dir = "Dataset/Images"
captions_file= "Dataset/Flickr8k.token.txt"

# hyperparams
batch_size   = 64
lr           = 1e-4
epochs       = 100
max_len      = 40
warmup_steps = 1000


model = BLIP2Captioning().to(device)
tokenizer = model.tokenizer
ds = Flickr8kDataset(images_dir, captions_file,
                      tokenizer,
                      max_length=max_len,
                      image_size=224)
ds_small = Subset(ds, list(range(10000)))
print(len(ds))

def train():
    # paths

    loader = DataLoader(ds_small, batch_size=batch_size, shuffle=True, num_workers=2)

    optimizer = torch.optim.AdamW(model.qformer.parameters(), lr=lr)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps,
        num_training_steps=epochs * len(loader)
    )

    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        for pix, ids, am, labs in loader:
            pix, ids, am, labs = pix.to(device), ids.to(device), am.to(device), labs.to(device)
            out = model(pix, ids, am, labels=labs)
            loss = out.loss
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.qformer.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}/{epochs} — avg loss: {total_loss/len(loader):.4f}")


40455


In [None]:
train() ## train the model


In [None]:
torch.save(model.state_dict(), "blip2_flickr8k_10kexamples.pth") ## save the model



## Inference


In [4]:


def load_model(checkpoint_path, device):
    model = BLIP2Captioning().to(device)
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()
    return model

def preprocess_image(image_path, image_size, device):
    transform = T.Compose([
        T.Resize((image_size, image_size)),
        T.ToTensor(),
        T.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                    std=(0.26862954, 0.26130258, 0.27577711)),
    ])
    img = Image.open(image_path).convert('RGB')
    pix = transform(img).unsqueeze(0).to(device)  # (1, 3, H, W)
    return pix

@torch.no_grad()
def generate_caption(model, image_path,
                     device,
                     prefix="a photo of",
                     max_len=40,
                     num_beams=3):
    pix = preprocess_image(image_path, image_size=224, device=device)

    v = model.vision(pixel_values=pix).last_hidden_state      # (1, S, D_img)

    q = model.qformer(v)                                      # (1, Q, query_dim)

    vis_emb = model.vis_proj(q)                               # (1, Q, lm_dim)

    # tok = model.tokenizer(prefix, return_tensors='pt').to(device)
    tok = model.tokenizer(
    prefix,
    return_tensors='pt',
    add_special_tokens=False     
      ).to(device)
    prefix_emb = model.llm.model.decoder.embed_tokens(tok.input_ids)
    inputs_embeds = torch.cat([vis_emb, prefix_emb], dim=1)   # (1, Q+P, lm_dim)

    vp_mask = torch.ones((1, vis_emb.size(1)), device=device)
    attn_mask = torch.cat([vp_mask, tok.attention_mask], dim=1)

    out = model.llm.generate(
        inputs_embeds=inputs_embeds,
        attention_mask=attn_mask,
        # max_length=vis_emb.size(1) + tok.input_ids.size(1) + max_len,
        max_new_tokens=max_len,             # generate up to `max_len` new tokens
        num_beams=num_beams,
        eos_token_id=model.tokenizer.eos_token_id,
        pad_token_id=model.tokenizer.pad_token_id,
        early_stopping=True,
        do_sample = True,
        top_p = 0.9,
        # top_k = 50,
        temperature = 0.7,
    )  


    q, p = vis_emb.size(1), tok.input_ids.size(1)
    gen_ids = out

    caption = model.tokenizer.decode(gen_ids[0], skip_special_tokens=True).strip()

    return caption

# Example usage:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = load_model("blip2_flickr8k_10kexamples.pth", device)

for img_path in [
    "Dataset/Images/1000268201_693b08cb0e.jpg",
    "Dataset/Images/10815824_2997e03d76.jpg",
    "Dataset/Images/3726168984_1fa2c8965b.jpg",
    "Dataset/Images/3744832122_2f4febdff6.jpg",
    "Dataset/Images/3726168984_1fa2c8965b.jpg",
   


]:
    cap = generate_caption(model, img_path, device)
    print(img_path, "→", cap)

Dataset/Images/1000268201_693b08cb0e.jpg → A little girl climbing into a wooden playhouse .
Dataset/Images/10815824_2997e03d76.jpg → A girl and her horse stand by a fire .
Dataset/Images/3726168984_1fa2c8965b.jpg → Two black dogs running in the grass .
Dataset/Images/3744832122_2f4febdff6.jpg → A boy plays a baseball game .
Dataset/Images/3726168984_1fa2c8965b.jpg → Two black dogs running in the grass .


## Evaluation Metrics

In [7]:
import os, random, torch
from collections import defaultdict
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction

import evaluate
from pycocoevalcap.spice.spice import Spice
from pycocoevalcap.cider.cider import Cider
from transformers import CLIPProcessor, CLIPModel


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
smooth = SmoothingFunction().method4

model = load_model("blip2_flickr8k_10kexamples.pth", device)
model.eval()

full_ds = Flickr8kDataset(
    images_dir=images_dir,
    captions_file=captions_file,
    tokenizer=model.tokenizer,
    prefix="a photo of",
    max_length=40,
    image_size=224
)

img2refs = defaultdict(list)
for fname, cap in full_ds.samples:  # full_ds.samples is list of (filename, caption)
    img2refs[fname].append(cap)
unique_fnames = list(img2refs.keys())

n_caption_val = 5000
# each image has 5 captions → how many *unique* images?
n_unique_val = n_caption_val // 5
val_fnames = unique_fnames[-n_unique_val:]

val_pairs = [(fn, random.choice(img2refs[fn])) for fn in val_fnames]

val_loss = 0.0
for fn, gt in val_pairs:
    img_path = os.path.join(images_dir, fn)
    pix = preprocess_image(img_path, image_size=224, device=device)
    full = full_ds.prefix + " " + gt
    toks = model.tokenizer(full,
                           padding='max_length',
                           truncation=True,
                           max_length=full_ds.max_length,
                           return_tensors='pt').to(device)
    input_ids = toks.input_ids
    am        = toks.attention_mask
    labels = input_ids.clone()
    prefix_len = len(model.tokenizer(full_ds.prefix,
                                     add_special_tokens=False)['input_ids'])
    labels[:, :prefix_len] = -100
    out = model(pix, input_ids, am, labels=labels)
    val_loss += out.loss.item()

avg_val_loss = val_loss / len(val_pairs)
print(f"Validation loss over {len(val_pairs)} images: {avg_val_loss:.4f}")

bleu_hf   = evaluate.load("bleu")
meteor    = evaluate.load("meteor")
rouge     = evaluate.load("rouge")
spice     = Spice()
cider     = Cider()
bertscore = evaluate.load("bertscore")
clip_model= CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_proc = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

preds = []
refs = []  
for fn in val_fnames:
    img_path = os.path.join(images_dir, fn)
    pred = generate_caption(
        model, img_path, device,
        prefix=full_ds.prefix,
        max_len=40,
        num_beams=3
    )
    preds.append(pred)
    refs.append(img2refs[fn])

bleu_res   = bleu_hf.compute(predictions=preds, references=refs)
meteor_res = meteor.compute(predictions=preds,
                            references=[r[0] for r in refs])
rouge_res  = rouge.compute(predictions=preds,
                           references=[r[0] for r in refs])
print(f"HF-BLEU-4: {bleu_res['bleu']:.4f}")
print(f"METEOR : {meteor_res['meteor']:.4f}")
print(f"ROUGE-L: {rouge_res['rougeL']:.4f}")

pred_toks = [p.split() for p in preds]
ref_toks  = [[r.split() for r in rs] for rs in refs]
bleu1 = corpus_bleu(ref_toks, pred_toks, weights=(1,0,0,0), smoothing_function=smooth)
bleu2 = corpus_bleu(ref_toks, pred_toks, weights=(0.5,0.5,0,0), smoothing_function=smooth)
bleu3 = corpus_bleu(ref_toks, pred_toks, weights=(0.33,0.33,0.33,0), smoothing_function=smooth)
bleu4 = corpus_bleu(ref_toks, pred_toks, weights=(0.25,0.25,0.25,0.25), smoothing_function=smooth)
print(f"BLEU-1: {bleu1:.4f}  BLEU-2: {bleu2:.4f}")
print(f"BLEU-3: {bleu3:.4f}  BLEU-4: {bleu4:.4f}")

cider_score, _ = cider.compute_score(
    {i: refs[i]    for i in range(len(refs))},
    {i: [preds[i]] for i in range(len(preds))}
)
spice_score, _ = spice.compute_score(
    {i: refs[i]    for i in range(len(refs))},
    {i: [preds[i]] for i in range(len(preds))}
)
print(f"CIDEr : {cider_score:.4f}")
print(f"SPICE : {spice_score:.4f}")

bert_res = bertscore.compute(
    predictions=preds,
    references=[r[0] for r in refs],
    model_type="bert-base-uncased",
    device=device
)
f1 = sum(bert_res["f1"]) / len(bert_res["f1"])
print(f"BERTScore-F1: {f1:.4f}")

all_sims = []
batch_size = 32
for i in range(0, len(val_fnames), batch_size):
    chunk = val_fnames[i : i + batch_size]
    imgs  = [Image.open(os.path.join(images_dir, fn)).convert("RGB") for fn in chunk]
    txts  = [preds[val_fnames.index(fn)] for fn in chunk]
    inputs = clip_proc(text=txts, images=imgs, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        iv = clip_model.get_image_features(pixel_values=inputs.pixel_values)
        tv = clip_model.get_text_features(input_ids=inputs.input_ids,
                                          attention_mask=inputs.attention_mask)
    iv = iv / iv.norm(dim=-1, keepdim=True)
    tv = tv / tv.norm(dim=-1, keepdim=True)
    sims = (iv * tv).sum(dim=-1)
    all_sims.append(sims.cpu())
clip_score = torch.cat(all_sims).mean().item()
print(f"CLIPScore: {clip_score:.4f}")

Validation loss over 1000 images: 1.2649


Downloading builder script: 100%|██████████| 5.94k/5.94k [00:00<00:00, 778kB/s]
Downloading extra modules: 4.07kB [00:00, 1.61MB/s]                   
Downloading extra modules: 100%|██████████| 3.34k/3.34k [00:00<?, ?B/s]
Downloading builder script: 100%|██████████| 7.02k/7.02k [00:00<?, ?B/s]
[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\Paracha\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt_tab to
[nltk_data]     C:\Users\Paracha\AppData\Roaming\nltk_data...
[nltk_data]   Unzipping tokenizers\punkt_tab.zip.
[nltk_data] Downloading package omw-1.4 to
[nltk_data]     C:\Users\Paracha\AppData\Roaming\nltk_data...
Downloading builder script: 100%|██████████| 6.27k/6.27k [00:00<00:00, 6.29MB/s]


Downloading stanford-corenlp-3.6.0 for SPICE ...
Progress: 384.5M / 384.5M (100.0%)
Extracting stanford-corenlp-3.6.0 ...
Done.


Downloading builder script: 100%|██████████| 7.95k/7.95k [00:00<?, ?B/s]
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


HF-BLEU-4: 0.1809
METEOR : 0.3031
ROUGE-L: 0.3446
BLEU-1: 0.6378  BLEU-2: 0.4274
BLEU-3: 0.2840  BLEU-4: 0.1809


CalledProcessError: Command '['java', '-jar', '-Xmx8G', 'spice-1.0.jar', 'c:\\Users\\Paracha\\miniconda3\\Lib\\site-packages\\pycocoevalcap\\spice\\tmp\\tmp22nzhjrd', '-cache', 'c:\\Users\\Paracha\\miniconda3\\Lib\\site-packages\\pycocoevalcap\\spice\\cache', '-out', 'c:\\Users\\Paracha\\miniconda3\\Lib\\site-packages\\pycocoevalcap\\spice\\tmp\\tmp9aier1tg', '-subset', '-silent']' returned non-zero exit status 1.