In [None]:
!pip install -u timm
!pip install -u albumentations
!pip install -U torchtext
!pip install --no-cache-dir torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 torchtext==0.18.0 --index-url https://download.pytorch.org/whl/cu121

In [None]:
import torch
import torch.nn as nn
import albumentations as alb
from albumentations.pytorch import ToTensorV2
import timm
import cv2
from torch.utils.data import DataLoader, Dataset
import pandas as pd
import torchtext; torchtext.disable_torchtext_deprecation_warning()
from collections import Counter
from torchtext.data.utils import get_tokenizer

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Model definition: CNN image encoder feeding a Transformer caption decoder
class ImageCaptioner(nn.Module):
  """Image captioning model: a CNN image encoder feeding a Transformer decoder.

  The image is encoded by a pretrained EfficientNet-B0, linearly projected to
  `model_dim`, and exposed to the decoder as a single-token memory sequence.
  The decoder attends causally over the caption tokens and cross-attends to
  that image token.

  Args:
    context_lenght: maximum number of caption positions (size of the position
      embedding table).
    vocabulary_size: number of tokens in the vocabulary.
    num_blocks: number of stacked Transformer decoder layers.
    model_dim: embedding / hidden width of the decoder.
    num_heads: number of attention heads per decoder layer.
    prob: dropout probability used inside the decoder layers.
  """

  def __init__(self, context_lenght, vocabulary_size, num_blocks, model_dim, num_heads, prob):
    super().__init__()
    self.cnn_encoder = timm.create_model('efficientnet_b0', pretrained=True)

    # Probe the encoder once with a dummy image to discover its output width.
    # The probe runs in eval mode so the all-zero image does not leak into the
    # pretrained BatchNorm running statistics; the previous mode is restored.
    was_training = self.cnn_encoder.training
    self.cnn_encoder.eval()
    with torch.no_grad():
      in_features = self.cnn_encoder(torch.zeros(1, 3, 224, 224)).shape[1]
    self.cnn_encoder.train(was_training)

    self.project = nn.Linear(in_features, model_dim)

    self.word_embeddings = nn.Embedding(vocabulary_size, model_dim)
    self.position_embeddings = nn.Embedding(context_lenght, model_dim)

    block = nn.TransformerDecoderLayer(model_dim, num_heads, 2 * model_dim, dropout=prob, batch_first=True, norm_first=True)
    self.blocks = nn.TransformerDecoder(block, num_blocks)

    self.vocab_projection = nn.Linear(model_dim, vocabulary_size)

  def forward(self, image, true_labels):
    """Return per-position vocabulary logits for a batch of captions.

    Args:
      image: batch of images accepted by the CNN encoder.
      true_labels: integer token ids of shape (B, T); T must not exceed the
        size of the position embedding table.

    Returns:
      Tensor of shape (B, T, vocabulary_size) with unnormalised logits.
    """
    B, T = true_labels.shape
    # Build helper tensors on the inputs' own device instead of relying on a
    # module-level global, so the model also works after `.to(other_device)`.
    target_device = true_labels.device

    positions = torch.arange(T, device=target_device)
    total_embeddings = self.word_embeddings(true_labels) + self.position_embeddings(positions)

    # Only the CNN backbone is excluded from autograd. The projection must stay
    # OUTSIDE the no_grad block, otherwise it never receives gradients and
    # remains at its random initialisation.
    with torch.no_grad():
      image_features = self.cnn_encoder(image).view(B, -1)
    encoded_image = self.project(image_features)

    # The decoder expects memory shaped (B, S, model_dim); the image is one token (S = 1).
    img_for_attention = torch.unsqueeze(encoded_image, 1)

    # Causal mask: position t may only attend to positions <= t.
    attention_mask = nn.Transformer.generate_square_subsequent_mask(T).to(target_device)
    block_output = self.blocks(total_embeddings, img_for_attention, tgt_mask=attention_mask)

    return self.vocab_projection(block_output)

In [None]:
from google.colab import drive
from google.colab.patches import cv2_imshow
from torchtext.vocab import vocab

drive.mount('/content/drive')
%cd "/content/drive/MyDrive/flick8r"

caption_filename = 'captions.txt'
missing = "2258277193_586949ec62.jpg"

with open(caption_filename) as captions:
  lines = captions.readlines()

get_captions = {}
all_captions = [] #5 * len(get_captions)

for caption in lines:
  data = caption.rstrip('\n').split('.jpg,')
  img_name = data[0] + '.jpg'
  if img_name == missing:
    continue

  caption_list = get_captions.get(img_name, [])
  caption_list.append(data[1])
  get_captions[img_name] = caption_list
  all_captions.append(data[1])

df = pd.DataFrame(columns=['filename', 'caption'])
df['filename'] = get_captions.keys()
df['caption'] = df['filename'].map(lambda filename: get_captions[filename])

vocab_frequency = Counter()
word_tokenizer = get_tokenizer('basic_english')

for cap in all_captions:
  vocab_frequency.update(word_tokenizer(cap))

vocabulary = torchtext.vocab.vocab(vocab_frequency)
vocabulary.insert_token('<UNKNOWN>', 0)
vocabulary.insert_token('<PAD>', 1)
vocabulary.insert_token('<START>', 2)
vocabulary.insert_token('<END>', 3)
vocabulary.set_default_index(0)

context_lenght = 20

class ImageCaptioningDataset(Dataset):
  """Dataset yielding (image tensor, encoded caption) pairs.

  Each item corresponds to one image of the module-level `df`; one of that
  image's captions is drawn uniformly at random on every access.

  Args:
    split: when equal to 'training', random flip and colour-jitter
      augmentations are applied before normalisation.
  """

  def __init__(self, split):
    self.df = df
    self.img_size = 224

    # Resolve the special-token ids from the vocabulary instead of hard-coding them.
    self.pad_id = vocabulary['<PAD>']
    self.start_id = vocabulary['<START>']
    self.end_id = vocabulary['<END>']

    transformation_list = [alb.Resize(self.img_size, self.img_size)]
    if split == 'training':
      transformation_list.append(alb.HorizontalFlip())
      transformation_list.append(alb.ColorJitter())
    transformation_list.append(alb.Normalize())
    transformation_list.append(ToTensorV2())

    self.transformations = alb.Compose(transformation_list)

  def __len__(self):
    """Number of images in the dataset."""
    return len(self.df)

  def _encode_caption(self, caption):
    """Tokenise one caption into exactly `context_lenght` token ids.

    The sequence is wrapped in <START>/<END>, then right-padded with <PAD>;
    over-long captions are truncated so that the last token is still <END>.
    """
    token_ids = [self.start_id] + [vocabulary[word] for word in word_tokenizer(caption)] + [self.end_id]

    if len(token_ids) > context_lenght:
      token_ids = token_ids[:context_lenght - 1] + [self.end_id]
    else:
      token_ids += [self.pad_id] * (context_lenght - len(token_ids))

    return torch.tensor(token_ids, dtype=torch.long)

  def __getitem__(self, idx):
    """Return the transformed image at `idx` and one randomly chosen encoded caption.

    Raises:
      FileNotFoundError: if the image file cannot be read.
    """
    image_filename, captions = self.df.iloc[idx]
    image_path = 'Images/' + image_filename

    # cv2.imread signals failure by returning None; fail with a clear message
    # instead of an opaque cvtColor assertion.
    bgr_image = cv2.imread(image_path)
    if bgr_image is None:
      raise FileNotFoundError(f'Could not read image: {image_path}')
    actual_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)

    # No image display here: rendering every sample floods the notebook output
    # and slows data loading.
    transformed_img = self.transformations(image=actual_image)['image']

    # Sample among however many captions this image actually has, and encode
    # only the one that is returned.
    random_idx = torch.randint(len(captions), (1,)).item()
    return transformed_img, self._encode_caption(captions[random_idx])

# Training split (augmentations enabled), served one sample per step and
# reshuffled at the start of every epoch.
training_dataset = ImageCaptioningDataset(split='training')
training_data = DataLoader(dataset=training_dataset, shuffle=True, batch_size=1)

In [None]:
# --- Hyperparameters ---
context_lenght = 20                 # maximum caption length in tokens
vocabulary_size = len(vocabulary)   # includes the four special tokens
num_blocks = 6                      # stacked decoder layers
model_dim = 512                     # decoder width
num_heads = 16                      # attention heads per layer
prob = 0.5                          # dropout probability inside the decoder

model = ImageCaptioner(
  context_lenght=context_lenght,
  vocabulary_size=vocabulary_size,
  num_blocks=num_blocks,
  model_dim=model_dim,
  num_heads=num_heads,
  prob=prob,
).to(device)

# Freeze the CNN backbone: none of its parameters will receive gradients.
model.cnn_encoder.requires_grad_(False)

# Padding positions do not contribute to the loss.
loss_function = nn.CrossEntropyLoss(ignore_index=vocabulary['<PAD>'])
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

In [None]:
num_epochs = 20
num_iterations = 0

model.train()
# The CNN backbone is used as a fixed feature extractor, so keep it in eval mode:
# its BatchNorm layers then use the pretrained running statistics rather than
# per-batch statistics.
model.cnn_encoder.eval()

for epoch in range(num_epochs):
  for images, captions in training_data:
    images, captions = images.to(device), captions.to(device)

    # Teacher forcing with a one-token shift: the decoder sees tokens [0, T-1)
    # and must predict tokens [1, T). Feeding and scoring the SAME sequence
    # would let the model trivially copy its input, because the causal mask
    # still allows each position to attend to itself.
    decoder_input = captions[:, :-1]
    targets = captions[:, 1:]

    model_prediction = model(images, decoder_input)   # (B, T-1, vocabulary_size)
    # reshape (not view): the shifted target slice is not contiguous.
    loss = loss_function(
      model_prediction.reshape(-1, model_prediction.size(-1)),
      targets.reshape(-1),
    )

    optimizer.zero_grad()
    loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), 2.0)
    optimizer.step()

    # Report the training loss every 100 optimisation steps.
    if num_iterations % 100 == 0:
      print(loss.item())
    num_iterations += 1

In [None]:
# Qualitative check: print the model's teacher-forced next-token predictions for
# the first caption of every batch, then report the mean loss.
integer_to_word = vocabulary.get_itos()
end_token_id = vocabulary['<END>']
num_epochs = 1
num_iterations = 0
total_loss = 0.0

# Evaluate with dropout disabled and without building autograd graphs; nothing
# is optimised in this cell, so no backward pass is needed. The model's previous
# train/eval mode is restored afterwards.
was_training = model.training
model.eval()

with torch.no_grad():
  for epoch in range(num_epochs):
    for x, y in training_data:
      x, y = x.to(device), y.to(device)

      # One-token shift: position t of the output predicts token t + 1.
      decoder_input, targets = y[:, :-1], y[:, 1:]
      prediction = model(x, decoder_input)   # (B, T-1, vocabulary size)

      # Decode the first caption of the batch up to and including <END>.
      sentence = []
      for token_id in prediction[0].argmax(dim=-1).tolist():
        sentence.append(integer_to_word[token_id])
        if token_id == end_token_id:
          break
      print(' '.join(sentence))

      loss = loss_function(prediction.reshape(-1, prediction.size(-1)), targets.reshape(-1))
      total_loss += loss.item()
      num_iterations += 1

model.train(was_training)

if num_iterations:
  print(f'mean loss over {num_iterations} batches: {total_loss / num_iterations:.4f}')