In [1]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("adityajn105/flickr8k")

print("Path to dataset files:", path)

Downloading from https://www.kaggle.com/api/v1/datasets/download/adityajn105/flickr8k?dataset_version_number=1...


100%|██████████| 1.04G/1.04G [00:36<00:00, 30.9MB/s]

Extracting files...





Path to dataset files: /root/.cache/kagglehub/datasets/adityajn105/flickr8k/versions/1


In [2]:

import os

imgs_path = os.path.join(path, "Images")
captions_path = os.path.join(path, "captions.txt")

In [3]:
os.listdir(imgs_path)[:5]

['2130986011_47cb05c8c9.jpg',
 '3182161610_4d349b257f.jpg',
 '263854883_0f320c1562.jpg',
 '2689163361_4939875be5.jpg',
 '306318683_5f1f875191.jpg']

In [19]:
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms
import torch.nn as nn


In [5]:

class Flickr8kDataset(Dataset):
    def __init__(self, images_dir, captions_file, vocab=None, transform=None, max_length=None):
        """
        images_dir: directory with Flickr8k images
        captions_file: path to file containing image filenames and captions
        vocab: optional pre-built vocabulary (a dict mapping token to index)
        transform: image transformations (resize, normalize, etc.)
        max_length: max caption length (including <bos> and <eos>) for padding
        """
        self.images_dir = images_dir
        self.transform = transform if transform is not None else transforms.Compose([
            transforms.Resize((224, 224)),                     # resize image
            transforms.ToTensor(),                             # convert to tensor [C x H x W]
            transforms.Normalize(mean=[0.485, 0.456, 0.406],   # normalize to ImageNet mean/std
                                 std=[0.229, 0.224, 0.225])
        ])

        self.captions = []
        self.image_files = []
        self.vocab = {"<pad>": 0, "<bos>": 1, "<eos>": 2, "<unk>": 3}

        with open(captions_file, 'r') as file:
            next(file)
            for line in file:
                cleaned_line = line.strip() # .strip() removes leading/trailing whitespace
                line_list = cleaned_line.split('.jpg,')
                img_name, caption = line_list[0]+".jpg", line_list[1]

                tokens = ["<bos>"] + caption.lower().split() + ["<eos>"]
                for token in tokens:
                    if token not in self.vocab:
                        self.vocab[token] = len(self.vocab)

                self.captions.append(tokens)
                self.image_files.append(img_name)

        self.rev_vocab = {idx: tok for tok, idx in self.vocab.items()}
        self.max_length = max(len(tokens) for tokens in self.captions)
        assert(len(self.captions) == len(self.image_files))


    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        img_path = os.path.join(self.images_dir, self.image_files[idx])
        image = Image.open(img_path).convert("RGB")
        image_tensor = self.transform(image)

        # Convert caption tokens to indices and pad
        token_list = self.captions[idx]
        cap_indices = [self.vocab.get(tok, self.vocab.get("<unk>")) for tok in token_list]
        caption_tensor = torch.tensor(cap_indices, dtype=torch.long)
        return image_tensor, caption_tensor, token_list


In [6]:
data = Flickr8kDataset(imgs_path, captions_path)
max_length = data.max_length

40

In [None]:
import math
import timm

class ImageCaptioningModel(nn.Module):
    def __init__(self, vocab_size, embed_dim=384, decoder_dim=384, num_decoder_layers=6, n_heads=8, ff_dim=1024, dropout=0.1, max_length):
        super(ImageCaptioningModel, self).__init__()

        # small vision vit
        self.encoder = timm.create_model('vit_small_patch16_224', pretrained=True)
        # remove the head
        self.encoder.head = nn.Identity()
        for param in self.encoder.parameters():
            param.requires_grad = False  # freeze
        # when processes this: torch.randn(1, 3, 224, 224) gives torch.Size([1, 197, 384])


        self.embed_dim = embed_dim
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_embedding = nn.Embedding(max_length, embed_dim)


        decoder_layer = nn.TransformerDecoderLayer(d_model=decoder_dim, nhead=n_heads, dim_feedforward=ff_dim, dropout=dropout, batch_first=True)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)
        self.output_proj = nn.Linear(decoder_dim, vocab_size)


    def forward(self, images, captions):
        """
        images: batch of images as tensors [batch, 3, 224, 224]
        captions: batch of caption token indices [batch, seq_len]
        Returns: logits over vocabulary for each time step [batch, seq_len, vocab_size]
        """
        batch_size, seq_len = captions.size(0), captions.size(1)
        # Encode images with ViT
        image_features = self.encoder(images)
        # image_features shape: [batch, num_image_tokens, embed_dim]
        # If the encoder returns a tuple (some ViT models do), get the last_hidden_state
        if isinstance(image_features, tuple):
            image_features = image_features[0]
        # Prepare caption embeddings (shifted right for input with <bos>)
        # During training, we feed the entire caption including <bos> token. The target is shifted by one (<bos> predicts first word, etc.).
        # We assume captions already include <bos> at index 0.
        caption_embeddings = self.token_embedding(captions)  # [batch, seq_len, embed_dim]
        # Add positional encoding to caption embeddings
        # Create position indices for seq_len (0,1,...seq_len-1) and broadcast to batch
        positions = torch.arange(0, seq_len, device=captions.device).unsqueeze(0)  # shape [1, seq_len]
        pos_embeds = self.pos_embedding(positions)  # [1, seq_len, embed_dim]
        caption_embeddings = caption_embeddings + pos_embeds  # [batch, seq_len, embed_dim]
        # Generate causal mask for decoder (prevent it from looking ahead beyond current position)
        # shape: [seq_len, seq_len]
        device = captions.device
        tgt_mask = torch.triu(torch.ones(seq_len, seq_len, device=device) * float('-inf'), diagonal=1)
        # Generate key padding mask for captions (to ignore <pad> tokens in self-attention and cross-attention)
        # This mask is boolean of shape [batch, seq_len], True for positions to mask (pads)
        pad_token_id = 0  # we set <pad>=0 in vocab
        tgt_key_padding_mask = (captions == pad_token_id)  # bool tensor
        # We also generate key padding mask for image features if needed (not strictly necessary if all images produce same number of tokens)
        # ViT outputs a fixed-length sequence for all images (197 tokens for 224x224 with patch16), so we may not need an image padding mask.
        # If using variable image sizes or an adaptive encoder, we would mask non-existing tokens.
        # For completeness, we'll create one assuming all image feature sequences are full length (no padding needed):
        memory_key_padding_mask = None  # not required for ViT with fixed-length output
        # Decode: the Transformer decoder will attend to its own past (with mask) and the image features (memory)
        decoder_out = self.decoder(tgt=caption_embeddings, memory=image_features, tgt_mask=tgt_mask,
                                   tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask)
        # decoder_out: [batch, seq_len, decoder_dim]
        outputs = self.output_proj(decoder_out)  # [batch, seq_len, vocab_size]
        return outputs


In [8]:
import timm

a = timm.create_model('vit_small_patch16_224', pretrained=True)
a

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors:   0%|          | 0.00/88.2M [00:00<?, ?B/s]

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=384, out_features=1152, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=384, out_features=1536, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Identity(

In [20]:
a.head = nn.Identity()

In [21]:
a(torch.randn(1, 3, 224, 224)).shape

torch.Size([1, 197, 384])