In [None]:
!pip install datasets
!pip install transformers
!pip install tensorflow==2.15
!pip install evaluate
!pip install rouge-score
!pip install accelerate
!pip install transformers[torch]

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive inside the Colab runtime.
drive.mount(mountpoint='/content/drive')

In [None]:
import os
import datasets
from transformers import VisionEncoderDecoderModel, AutoFeatureExtractor,AutoTokenizer
# Set WANDB_DISABLED to "true" in this process's environment.
os.environ.update(WANDB_DISABLED="true")

In [None]:
import nltk

# Make sure the sentence-tokenizer data needed by `nltk.sent_tokenize`
# (used in `postprocess_text`) is available before training starts.
# Older NLTK releases load the "punkt" resource, newer ones load "punkt_tab",
# so both are checked and fetched when missing. A resource that the installed
# NLTK version does not know about is simply not downloaded.
for _resource in ("punkt", "punkt_tab"):
    try:
        nltk.data.find(f"tokenizers/{_resource}")
    except (LookupError, OSError):
        nltk.download(_resource, quiet=True)

In [None]:
# Names of the pretrained checkpoints used for the two halves of the model.
image_encoder_model = "google/vit-base-patch16-224-in21k"
text_decode_model = "gpt2"

# Build one encoder-decoder model from the two pretrained checkpoints.
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    encoder_pretrained_model_name_or_path=image_encoder_model,
    decoder_pretrained_model_name_or_path=text_decode_model,
)

In [None]:
# Text side: tokenizer loaded from the decoder checkpoint.
tokenizer = AutoTokenizer.from_pretrained(text_decode_model)
# Image side: feature extractor loaded from the encoder checkpoint.
feature_extractor = AutoFeatureExtractor.from_pretrained(image_encoder_model)

In [None]:
# Reuse the tokenizer's end-of-sequence token as its padding token.
tokenizer.pad_token = tokenizer.eos_token

# Copy the tokenizer's special-token ids onto the model config.
model.config.pad_token_id = tokenizer.pad_token_id
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.eos_token_id = tokenizer.eos_token_id

In [None]:
# Write the freshly assembled model and its preprocessing objects to disk.
output_dir = "/content"
for _component in (model, feature_extractor):
    _component.save_pretrained(output_dir)
# Kept as the final expression so the cell still displays its return value.
tokenizer.save_pretrained(output_dir)

In [None]:
from datasets import load_from_disk

import zipfile

# Placeholders: the zipped dataset and the folder it is extracted into.
zip_file_path = 'link to processed_dataset.zip'
extracted_folder_path = 'make a new dir'

# Unpack the whole archive into the extraction folder.
with zipfile.ZipFile(zip_file_path, mode='r') as archive:
    archive.extractall(path=extracted_folder_path)

# Load the dataset from the extracted folder and display it.
processed_dataset = load_from_disk(extracted_folder_path)
processed_dataset

In [None]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

# Create the working directory (placeholder path). `exist_ok=True` keeps the
# cell re-runnable: a plain `os.mkdir` raises FileExistsError on a second run.
os.makedirs('new dir path', exist_ok=True)
# Directory handed to the training arguments as `output_dir` (placeholder path).
training_dir = "checkpoints dir"

# Training configuration: one epoch, batches of 4 for both training and
# evaluation, with evaluation and checkpoint saving both set to "epoch".
# NOTE(review): recent transformers releases rename `evaluation_strategy` to
# `eval_strategy` — confirm against the installed version.
training_args = Seq2SeqTrainingArguments(
    output_dir=training_dir,
    num_train_epochs=1,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    predict_with_generate=True,
)

In [None]:
import evaluate
# ROUGE scorer used by `compute_metrics`.
metric = evaluate.load(path="rouge")

In [None]:
import numpy as np

# When True, compute_metrics replaces the -100 entries in the labels with the
# tokenizer's pad token id before decoding them.
ignore_pad_token_for_loss = True


def postprocess_text(preds, labels):
    """Normalise decoded predictions and references for scoring.

    Each string is stripped of surrounding whitespace, split into sentences
    with ``nltk.sent_tokenize`` and re-joined with one sentence per line.

    Args:
        preds: iterable of decoded prediction strings.
        labels: iterable of decoded reference strings.

    Returns:
        A ``(preds, labels)`` tuple of two lists of newline-joined strings.
    """
    def _one_sentence_per_line(text):
        return "\n".join(nltk.sent_tokenize(text.strip()))

    formatted_preds = [_one_sentence_per_line(pred) for pred in preds]
    formatted_labels = [_one_sentence_per_line(label) for label in labels]
    return formatted_preds, formatted_labels


def compute_metrics(eval_preds):
    """Compute ROUGE scores and the mean generated length for one evaluation.

    Args:
        eval_preds: a ``(preds, labels)`` pair of token-id arrays. If
            ``preds`` is a tuple, only its first element is used.

    Returns:
        dict: every score returned by ``metric.compute`` multiplied by 100 and
        rounded to 4 decimals, plus ``"gen_len"``, the mean number of non-pad
        tokens per prediction.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    pad_id = tokenizer.pad_token_id

    # Predictions can carry -100 as padding, just like the labels. Those ids
    # cannot be decoded and must not count as generated tokens, so map them to
    # the pad id before decoding and before measuring lengths.
    preds = np.asarray(preds)
    preds = np.where(preds != -100, preds, pad_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

    if ignore_pad_token_for_loss:
        # -100 marks label positions ignored by the loss; swap in the pad id
        # so the labels can be decoded.
        labels = np.where(labels != -100, labels, pad_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    decoded_preds, decoded_labels = postprocess_text(decoded_preds,
                                                     decoded_labels)

    result = metric.compute(predictions=decoded_preds,
                            references=decoded_labels,
                            use_stemmer=True)
    # Report the scores as percentages.
    result = {k: round(v * 100, 4) for k, v in result.items()}

    prediction_lens = [np.count_nonzero(pred != pad_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    return result

In [None]:
from transformers import default_data_collator

# Wire the model, datasets, metrics and collator into the trainer.
# The feature extractor is what gets passed in the `tokenizer` slot here.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=processed_dataset['train'],
    eval_dataset=processed_dataset['validation'],
    data_collator=default_data_collator,
    compute_metrics=compute_metrics,
    tokenizer=feature_extractor,
)

In [None]:
from transformers.trainer_utils import get_last_checkpoint

# Resume from the newest checkpoint in `training_dir` when one exists;
# otherwise start a fresh run. Passing `resume_from_checkpoint=True`
# unconditionally fails on a first run because there is nothing to resume.
last_checkpoint = (
    get_last_checkpoint(training_dir) if os.path.isdir(training_dir) else None
)
trainer.train(resume_from_checkpoint=last_checkpoint)

In [None]:
# Save the trained model into the training directory.
trainer.save_model(output_dir=training_dir)

In [None]:
# Save the tokenizer next to the trained model.
tokenizer.save_pretrained(save_directory=training_dir)

In [None]:
# Move the checkpoint directory to Google Drive.
# Both paths are placeholders: set the source and the destination.

import shutil
shutil.move(src="dir", dst="dir")

In [None]:
# Load the model from a checkpoint for inference.

from transformers import AutoConfig
import torch
from transformers import VisionEncoderDecoderModel, VisionEncoderDecoderConfig


# Placeholder: path of the checkpoint directory to load.
checkpoint_path = 'latest checkpoint'

# Read the saved configuration, then load the model with it.
config = VisionEncoderDecoderConfig.from_pretrained(checkpoint_path)
model = VisionEncoderDecoderModel.from_pretrained(checkpoint_path, config=config)



In [None]:
from transformers import pipeline
# Image-to-text pipeline built from the loaded model, the tokenizer and the
# feature extractor.
image_captioner = pipeline(
    task="image-to-text",
    model=model,
    tokenizer=tokenizer,
    feature_extractor=feature_extractor,
)

In [None]:
import os
import matplotlib.pyplot as plt

def display_image_and_caption(image_path):
    """Show an image without axes, then print the captioner's output for it.

    Args:
        image_path: path of the image file; it is read with ``plt.imread``
            for display and passed unchanged to ``image_captioner``.

    Returns:
        None. The figure is shown and the raw captioner result is printed
        after the ``"Caption:"`` prefix.
    """
    pixels = plt.imread(image_path)
    plt.imshow(pixels)
    plt.axis('off')
    plt.show()

    # The captioner runs only after the figure has been shown.
    print("Caption:", image_captioner(image_path))

# Folder holding the images to caption (placeholder path).
content_dir = "test images dir"
files = os.listdir(content_dir)

# Keep only regular files with a common image extension. The previous identity
# comprehension passed every directory entry on, so sub-directories and
# non-image files crashed `plt.imread`. Sorting makes the order deterministic.
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tif", ".tiff", ".webp")
image_files = sorted(
    f for f in files
    if f.lower().endswith(IMAGE_EXTENSIONS)
    and os.path.isfile(os.path.join(content_dir, f))
)

# Display each image followed by its generated caption.
for image_file in image_files:
    image_path = os.path.join(content_dir, image_file)
    display_image_and_caption(image_path)
