<a href="https://colab.research.google.com/github/indrad123/imagecaptioning/blob/main/fin_blip_image_caption.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install datasets
!pip install pandas
!pip install scikit-learn
!pip install torch
!pip install torch
!pip install transformers
!pip install tqdm
!pip install pillow
!pip install torchvision
!pip install h5py
!pip install datasets pandas scikit-learn torch transformers tqdm pillow torchvision h5py


In [None]:
from datasets import load_dataset
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoProcessor, BlipForConditionalGeneration
from tqdm import tqdm
import os
import io
from PIL import Image
from torchvision import transforms
import random
import h5py
from collections import defaultdict



In [None]:
# Fetch the Flickr30k-derived captions dataset from the Hugging Face Hub.
dataset = load_dataset("indrad123/flickr30k-transformed-captions-indonesia")

# Only the "test" split is used: pair each image with the caption stored in
# column "original_alt_text_id".
# NOTE(review): later cells read image["bytes"], so this assumes the "image"
# column yields {"bytes": ...} dicts rather than decoded PIL images — confirm.
test_split = dataset["test"]
columns = {
    "image": test_split["image"],
    "caption": test_split["original_alt_text_id"],
}
data = pd.DataFrame(columns)

# Hold out 5% of the rows for validation; the fixed seed keeps the split reproducible.
train_data, val_data = train_test_split(data, random_state=42, test_size=0.05)

# Define captions field
class Vocab:
    """Bare attribute container for the caption vocabulary.

    Instances start empty; callers attach ``itos`` (index -> token list) and
    ``stoi`` (token -> index mapping) after construction.
    """

captions = Vocab()
# Special tokens occupy fixed indices 0-3; CaptioningData.collate_fn pads with
# zeros, which therefore corresponds to '<pad>'.
captions.itos = ['<pad>', '<start>', '<end>', '<unk>']  # Initial tokens

# Build vocabulary from the lower-cased, whitespace-split training captions.
# The tokens are sorted because iterating a set of str depends on
# PYTHONHASHSEED. Unsorted, the token -> index mapping (and thus any saved
# checkpoint) would differ from run to run.
# Special tokens are excluded so a caption containing e.g. "<unk>" cannot
# create a duplicate entry.
all_captions = train_data['caption'].tolist()
all_tokens = [w.lower() for c in all_captions for w in c.split()]
captions.itos.extend(sorted(set(all_tokens) - set(captions.itos)))

# Unknown words map to '<unk>'. Its index is computed once here rather than
# via list.index() on every miss.
_unk_index = captions.itos.index('<unk>')
captions.stoi = defaultdict(lambda: _unk_index, {s: i for i, s in enumerate(captions.itos)})

# Custom dataset class
class CaptioningData(Dataset):
    """Map-style dataset of (image tensor, caption token ids) pairs.

    Each row of ``df`` must provide ``image`` (a dict whose ``'bytes'`` entry
    holds encoded image data) and ``caption`` (a string). Captions are
    lower-cased, split on whitespace and wrapped in ``<start>``/``<end>`` ids
    taken from ``vocab.stoi``.

    Args:
        df: DataFrame with ``image`` and ``caption`` columns.
        vocab: object whose ``stoi`` mapping contains ``'<start>'``,
            ``'<end>'`` and ``'<unk>'``.
        transform: optional image transform. When None, a default torchvision
            pipeline (resize 224, crop 224, ToTensor, ImageNet normalization) is built.
        train: keyword-only. When True (default) the default transform uses
            RandomCrop + RandomHorizontalFlip. When False it uses a
            deterministic CenterCrop, so validation samples are reproducible.
        device: keyword-only. Device that ``collate_fn`` moves batches to.
            When None, it falls back to the module-level ``device`` variable,
            as before.
    """

    def __init__(self, df, vocab, transform=None, *, train=True, device=None):
        self.df = df.reset_index(drop=True)
        self.vocab = vocab
        self.device = device
        if transform is not None:
            self.transform = transform
        else:
            steps = [
                transforms.Resize(224),
                transforms.RandomCrop(224) if train else transforms.CenterCrop(224),
            ]
            if train:
                steps.append(transforms.RandomHorizontalFlip())
            steps += [
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            ]
            self.transform = transforms.Compose(steps)

    def _encode_caption(self, text):
        """Return ``[<start>, token ids..., <end>]`` for ``text`` as a list of ints.

        Tokens missing from ``vocab.stoi`` map to ``'<unk>'``. Membership is
        tested explicitly so a ``defaultdict`` vocabulary is not mutated by
        lookups of unseen words.
        """
        stoi = self.vocab.stoi
        unk = stoi['<unk>']
        ids = [stoi[token] if token in stoi else unk for token in text.lower().split()]
        return [stoi['<start>'], *ids, stoi['<end>']]

    def __getitem__(self, index):
        row = self.df.iloc[index]
        image = Image.open(io.BytesIO(row['image']['bytes'])).convert('RGB')
        image = self.transform(image)
        target = torch.tensor(self._encode_caption(row['caption']), dtype=torch.long)
        return image, target

    def __len__(self):
        return len(self.df)

    def collate_fn(self, data):
        """Pad a list of (image, target) samples into batch tensors.

        ``data`` is sorted in place by descending caption length. Targets are
        right-padded with 0.

        Returns:
            ``(images, padded_targets, lengths)``. ``images`` is stacked along a
            new dim 0, ``padded_targets`` has shape ``(batch, max_len)`` and
            dtype long, and ``lengths`` is a long tensor of unpadded lengths.
            All three are moved to the target device.
        """
        data.sort(key=lambda x: len(x[1]), reverse=True)
        images, targets = zip(*data)
        images = torch.stack(images, 0)
        lengths = [len(tar) for tar in targets]
        padded_targets = torch.zeros(len(targets), max(lengths), dtype=torch.long)
        for i, (tar, length) in enumerate(zip(targets, lengths)):
            padded_targets[i, :length] = tar
        # Fall back to the notebook-level `device` for backward compatibility.
        target_device = self.device if self.device is not None else device
        return (
            images.to(target_device),
            padded_targets.to(target_device),
            torch.tensor(lengths, dtype=torch.long).to(target_device),
        )



In [None]:
# Initialize the processor and model
processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")


class _BlipCaptionRows(Dataset):
    """Expose rows of an (image, caption) DataFrame as plain dicts.

    Encoding is deferred to ``_blip_collate`` so the BLIP processor can batch
    the images and pad each batch's captions to its own longest caption.
    """

    def __init__(self, df):
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        row = self.df.iloc[index]
        return {"image": row["image"], "caption": row["caption"]}


def _blip_collate(rows):
    """Encode a list of row dicts into a BLIP batch via ``processor``.

    ``CaptioningData`` cannot feed this loop for two reasons. Its
    ``collate_fn`` returns an ``(images, targets, lengths)`` tuple, which has
    no ``.pop("input_ids")``. Its targets are ids from the custom word
    vocabulary in ``captions``, not from the tokenizer that BLIP's text
    decoder (and ``processor.batch_decode``) uses.

    Returns the processor's dict-like output. It presumably holds
    ``pixel_values``, ``input_ids`` and ``attention_mask``; the loop below
    pops exactly those keys.
    """
    images = [Image.open(io.BytesIO(r["image"]["bytes"])).convert("RGB") for r in rows]
    texts = [r["caption"] for r in rows]
    return processor(images=images, text=texts, padding=True, truncation=True, return_tensors="pt")


# Create the dataset and dataloader
train_dataset = _BlipCaptionRows(train_data)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=2, collate_fn=_blip_collate)

val_dataset = _BlipCaptionRows(val_data)
val_dataloader = DataLoader(val_dataset, shuffle=False, batch_size=2, collate_fn=_blip_collate)

# Define optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# Move model to the appropriate device
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Checkpoint directory
checkpoint_dir = "/content/drive/MyDrive/BlipModel/model"
os.makedirs(checkpoint_dir, exist_ok=True)

# Load from checkpoint if exists
start_epoch = 0
checkpoint_path = os.path.join(checkpoint_dir, "checkpoint.pth")
if os.path.exists(checkpoint_path):
    # map_location lets a checkpoint written on a GPU runtime be resumed on CPU.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    print(f"Resuming training from epoch {start_epoch}")

# Training loop with checkpoint saving
num_epochs = 50
model.train()
for epoch in range(start_epoch, num_epochs):
    print("Epoch:", epoch)
    for batch in train_dataloader:
        input_ids = batch.pop("input_ids").to(device)
        pixel_values = batch.pop("pixel_values").to(device)
        attention_mask = batch.pop("attention_mask").to(device)

        # Exclude padding positions from the loss. This assumes BLIP's LM loss
        # uses CrossEntropyLoss's default ignore_index (-100) — confirm for the
        # installed transformers version.
        labels = input_ids.masked_fill(attention_mask == 0, -100)

        outputs = model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            labels=labels,
        )
        loss = outputs.loss

        print("Loss:", loss.item())

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    # Save checkpoint at the end of every epoch so training can be resumed.
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, checkpoint_path)

# Export the fine-tuned model and its processor into one directory so both can
# be restored later with from_pretrained().
final_model_path = os.path.join(checkpoint_dir, "final_model")
for artifact in (model, processor):
    artifact.save_pretrained(final_model_path)

# Reload the saved processor/model pair from disk and use it to caption an example.
loaded_processor = AutoProcessor.from_pretrained(final_model_path)
loaded_model = BlipForConditionalGeneration.from_pretrained(final_model_path).to(device)

# Caption the first test example with the reloaded model.
example = dataset["test"][0]
image_bytes = example["image"]["bytes"]
image = Image.open(io.BytesIO(image_bytes)).convert('RGB')

inputs = loaded_processor(images=image, return_tensors="pt").to(device)
pixel_values = inputs.pixel_values
generated_ids = loaded_model.generate(pixel_values=pixel_values, max_length=50)

decoded = loaded_processor.batch_decode(generated_ids, skip_special_tokens=True)
generated_caption = decoded[0]
print(generated_caption)

# Visualization of generated captions for the first few test images (2x3 grid).
# NOTE(review): `plt` is never imported in this notebook; the import cell needs
# `import matplotlib.pyplot as plt` for this cell to run.
num_examples = 6  # matches the 2x3 subplot grid below
test_split = dataset["test"]
fig = plt.figure(figsize=(18, 14))

# Fetch rows one at a time. Slicing a datasets.Dataset (dataset["test"][:6])
# returns a dict of columns, so iterating that slice yields column names
# instead of examples. min() guards splits with fewer rows than the grid.
for i in range(min(num_examples, len(test_split))):
    example = test_split[i]
    image = Image.open(io.BytesIO(example["image"]["bytes"])).convert('RGB')
    inputs = loaded_processor(images=image, return_tensors="pt").to(device)
    pixel_values = inputs.pixel_values

    generated_ids = loaded_model.generate(pixel_values=pixel_values, max_length=50)
    generated_caption = loaded_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    fig.add_subplot(2, 3, i + 1)
    plt.imshow(image)
    plt.axis("off")
    plt.title(f"Generated caption: {generated_caption}")

plt.show()
