In [None]:
from datasets import load_dataset, Dataset

t_train_dataset = load_dataset("arampacha/rsicd", split="train")
t_valid_dataset = load_dataset("arampacha/rsicd", split="valid")


def _one_row_per_caption(batch):
    """Explode a columnar batch so every caption gets its own (image, caption) row.

    Args:
        batch: columnar batch with an ``"image"`` list and a ``"captions"`` list
            whose entries are lists of caption strings for the matching image.

    Returns:
        Columnar dict with equally long ``"image"`` and ``"captions"`` lists;
        each image is repeated once per caption it had, and each ``"captions"``
        entry is a single caption string.
    """
    images, captions = [], []
    for image, image_captions in zip(batch["image"], batch["captions"]):
        images.extend([image] * len(image_captions))
        captions.extend(image_captions)
    return {"image": images, "captions": captions}


def _explode_captions(ds):
    """Return ``ds`` flattened to one row per caption, keeping only image/captions columns."""
    # A batched map emits results chunk by chunk instead of collecting every
    # decoded image of the split in Python lists before building the dataset.
    # All original columns are removed because the row count changes.
    return ds.map(_one_row_per_caption, batched=True, remove_columns=ds.column_names)


# Datasets with one row per (image, caption) pair.
train_dataset = _explode_captions(t_train_dataset)
valid_dataset = _explode_captions(t_valid_dataset)

In [None]:
# print(valid_dataset)

In [None]:
import os
import datasets
import torch
from transformers import VisionEncoderDecoderModel, AutoFeatureExtractor,AutoTokenizer
# Switch off the Weights & Biases integration through its environment flag.
os.environ.update(WANDB_DISABLED="true")

In [None]:
import nltk
# nltk.sent_tokenize (used by postprocess_text for ROUGE-Lsum) needs the Punkt
# sentence models. Older NLTK loads them from "punkt", while recent releases load
# "punkt_tab", so make sure both are present.
# NOTE(review): on an NLTK version without "punkt_tab", nltk.download's default
# raise_on_error=False should report the unknown id instead of raising — confirm
# on the installed version.
for _punkt_resource in ("punkt", "punkt_tab"):
    try:
        nltk.data.find(f"tokenizers/{_punkt_resource}")
    except (LookupError, OSError):
        nltk.download(_punkt_resource, quiet=True)

In [None]:
from transformers import VisionEncoderDecoderModel, AutoTokenizer, AutoFeatureExtractor

# Pretrained image-captioning checkpoint that is fine-tuned below.
checkpoint_name = "nlpconnect/vit-gpt2-image-captioning"
model = VisionEncoderDecoderModel.from_pretrained(checkpoint_name)

In [None]:
# Image preprocessor: turns an image into the `pixel_values` the model consumes.
feature_extractor = AutoFeatureExtractor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
# Text tokenizer for captions (training labels and decoding generated ids).
tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

# If the checkpoint's tokenizer defines no pad token, padding="max_length" in the
# dataset would raise, and tokenizer.pad_token_id would be None in compute_metrics.
# Reuse EOS as the padding token in that case.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

In [None]:
class ImageCapationingDataset(torch.utils.data.Dataset):
    """Lazily encode (image, caption) rows into training examples.

    Each item is built on access with the module-level ``feature_extractor``
    (image -> ``pixel_values``) and ``tokenizer`` (caption -> ``labels``).

    Args:
        ds: indexable dataset whose rows have an ``"image"`` entry and a single
            caption string under ``"captions"`` (one caption per row).
        max_target_length: token length that labels are padded/truncated to.
    """

    def __init__(self, ds, max_target_length):
        self.ds = ds
        self.max_target_length = max_target_length

    def __getitem__(self, idx):
        """Return ``{"labels": list[int], "pixel_values": np.ndarray}`` for row ``idx``.

        Padding positions in ``labels`` are set to -100 so the loss ignores them.
        """
        # Look the row up once: indexing a datasets.Dataset row decodes all its
        # columns, so a second lookup would decode the image again.
        row = self.ds[idx]
        pixel_values = feature_extractor(images=row["image"], return_tensors="np").pixel_values[0]

        # Append EOS so the model still learns where a caption ends once the
        # padding after it is masked out (the pad token may equal EOS).
        caption = row["captions"] + (tokenizer.eos_token or "")
        encoded = tokenizer(
            caption,
            padding="max_length",
            truncation=True,
            max_length=self.max_target_length,
        )
        # Mask padding (attention_mask == 0) with -100, the index the loss
        # ignores and that compute_metrics maps back to the pad token.
        labels = [
            token_id if attended else -100
            for token_id, attended in zip(encoded.input_ids, encoded.attention_mask)
        ]

        return {"labels": labels, "pixel_values": pixel_values}

    def __len__(self):
        return len(self.ds)

In [None]:
# Both splits use the same caption length budget, in tokens.
max_target_length = 64
train_ds, eval_ds = (
    ImageCapationingDataset(split, max_target_length)
    for split in (train_dataset, valid_dataset)
)

In [None]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./image-captioning-output",
    # Schedule: 5 epochs, with an evaluation and a checkpoint after each one.
    num_train_epochs=5,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=5,  # upper bound on checkpoints kept on disk
    # Optimization.
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    learning_rate=5e-5,
    # NOTE(review): the trainer is given its own ReduceLROnPlateau scheduler
    # later; this setting is presumably kept so the Trainer treats the scheduler
    # as metric-driven. Confirm against the installed transformers version
    # before removing it.
    lr_scheduler_type="reduce_lr_on_plateau",
    # Evaluate by generating captions so compute_metrics can score text.
    predict_with_generate=True,
)

In [None]:
!pip install -q evaluate rouge_score

In [None]:
import evaluate
# ROUGE scorer used by compute_metrics to compare generated and reference captions.
metric = evaluate.load("rouge")

In [None]:
import numpy as np

# When True, compute_metrics replaces -100 in the label ids with the tokenizer's
# pad token id before decoding them.
ignore_pad_token_for_loss = True


def postprocess_text(preds, labels):
    """Normalize decoded predictions and references before ROUGE scoring.

    Each string is stripped of surrounding whitespace and then rewritten with
    one sentence per line, the layout rougeLsum expects.

    Returns:
        ``(preds, labels)`` as new lists of processed strings.
    """
    def one_sentence_per_line(text):
        return "\n".join(nltk.sent_tokenize(text.strip()))

    return (
        [one_sentence_per_line(text) for text in preds],
        [one_sentence_per_line(text) for text in labels],
    )


def compute_metrics(eval_preds):
    """Decode generated and reference captions and score them with ROUGE.

    Args:
        eval_preds: ``(predictions, label_ids)`` pair from the trainer.
            ``predictions`` holds generated token ids (or is a tuple whose first
            element holds them). Both arrays may contain -100 at padded/ignored
            positions.

    Returns:
        dict mapping each ROUGE key to its score scaled to 0-100 (rounded to 4
        decimals), plus ``"gen_len"``: the mean count of non-padding tokens per
        generated sequence.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    pad_id = tokenizer.pad_token_id

    # Generated sequences of different lengths can be padded with -100 when the
    # trainer concatenates eval batches. The tokenizer cannot decode -100, and
    # it is not a real token for gen_len either, so map it to the pad id.
    preds = np.where(preds != -100, preds, pad_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    if ignore_pad_token_for_loss:
        # Replace -100 in the labels as we can't decode them.
        labels = np.where(labels != -100, labels, pad_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds,
                                                     decoded_labels)

    result = metric.compute(predictions=decoded_preds,
                            references=decoded_labels,
                            use_stemmer=True)
    result = {k: round(v * 100, 4) for k, v in result.items()}
    prediction_lens = [np.count_nonzero(pred != pad_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    return result

In [None]:
from transformers import AdamW, get_scheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau

# AdamW optimizer. torch.optim.AdamW replaces the deprecated transformers.AdamW.
# eps=1e-6 and weight_decay=0.0 reproduce transformers.AdamW's defaults, so the
# update rule is unchanged (both apply bias correction).
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=training_args.learning_rate,
    eps=1e-6,
    weight_decay=0.0,
)

# Cut the learning rate 10x once the monitored value (mode='min', i.e. lower is
# better) has failed to improve for more than 3 consecutive step() calls. Which
# metric is passed to step() is decided by the Trainer. The deprecated `verbose`
# flag is omitted.
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.1,
    patience=3,
)

In [None]:
from transformers import default_data_collator
# Assemble the Seq2SeqTrainer from the pieces built in the cells above.
# NOTE(review): the image feature extractor is passed as `tokenizer`,
# presumably so it is saved alongside checkpoints — confirm this is intended.
trainer_kwargs = dict(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    data_collator=default_data_collator,
    compute_metrics=compute_metrics,
    tokenizer=feature_extractor,
    optimizers=(optimizer, scheduler),
)
trainer = Seq2SeqTrainer(**trainer_kwargs)

In [None]:
# Run fine-tuning with the configuration above; the returned value is shown as
# this cell's output.
trainer.train()

In [None]:
from matplotlib import pyplot as plt
from nltk.translate.bleu_score import sentence_bleu

# Per-image sentence BLEU of beam-search captions on the original validation
# split. Every human caption of an image serves as a reference.
res = []
max_length = 64
num_beams = 4
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}

# Run on whatever device the model is on instead of assuming CUDA, and switch
# off dropout explicitly rather than relying on the mode training left behind.
model.eval()
device = model.device

for example in t_valid_dataset:
    # Whitespace-tokenized references, one per human caption.
    reference = [caption.split() for caption in example["captions"]]
    pixel_values = feature_extractor(images=example["image"], return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)

    output_ids = model.generate(pixel_values, **gen_kwargs)

    # A single image goes in, so keep the first (only) decoded caption.
    prediction = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
    # NOTE: sentence_bleu with default arguments is unsmoothed BLEU-4, so short
    # captions without a 4-gram match score 0.
    res.append(sentence_bleu(reference, prediction.split()))

# Mean BLEU over the split (NaN if the split is empty instead of dividing by zero).
print(sum(res) / len(res) if res else float("nan"))