In [None]:
import os
import requests
from PIL import Image
import torch
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorWithPadding
from peft import LoraConfig, get_peft_model
from datasets import Dataset
import pandas as pd
import matplotlib.pyplot as plt
from kaggle.api.kaggle_api_extended import KaggleApi

# Pick the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Download Flickr8k dataset from Kaggle
def download_flickr8k(dataset_path='flickr8k_download'):
    """Make sure the Flickr8k dataset is available locally and return its folder.

    The dataset counts as present only when both ``Images`` and
    ``captions.txt`` exist inside ``dataset_path`` (the two entries the rest
    of this script reads). Checking for the files rather than just the
    directory means an interrupted earlier download is retried instead of
    being mistaken for a complete one.

    Kaggle authentication is performed only when a download is actually
    required, so a cached dataset can be reused without Kaggle credentials.

    Args:
        dataset_path: Directory that holds (or will hold) the unzipped
            dataset. Defaults to ``'flickr8k_download'``.

    Returns:
        The ``dataset_path`` that was passed in.
    """
    required_entries = (
        os.path.join(dataset_path, 'Images'),
        os.path.join(dataset_path, 'captions.txt'),
    )
    if all(os.path.exists(entry) for entry in required_entries):
        print("Flickr8k dataset already downloaded.")
        return dataset_path

    print("Authenticating with Kaggle API...")
    api = KaggleApi()
    api.authenticate()
    # exist_ok: the directory may be left over from an incomplete download
    os.makedirs(dataset_path, exist_ok=True)
    print("Downloading Flickr8k dataset...")
    api.dataset_download_files('adityajn105/flickr8k', path=dataset_path, unzip=True)
    return dataset_path

dataset_path = download_flickr8k()
captions_file = os.path.join(dataset_path, 'captions.txt')
images_dir = os.path.join(dataset_path, 'Images')

# Read the caption table; the file's own header row is replaced by our column names
print("Loading captions...")
captions = pd.read_csv(
    captions_file,
    sep=',',
    header=0,
    names=['image', 'caption'],
)
print("Captions loaded successfully.")

# Fetch the pre-trained captioning checkpoint together with its image
# processor and tokenizer, then place the model on the selected device
print("Loading pre-trained model, feature extractor, and tokenizer...")
model_name = "nlpconnect/vit-gpt2-image-captioning"
model = VisionEncoderDecoderModel.from_pretrained(model_name)
model = model.to(device)
feature_extractor = ViTImageProcessor.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
print("Model, feature extractor, and tokenizer loaded.")

# Prepare the dataset
def preprocess_function(examples, max_length=64):
    """Turn a batch of (image file name, caption) rows into model inputs.

    Args:
        examples: Batch dict with an ``"image"`` list of file names relative
            to ``images_dir`` and a ``"caption"`` list of strings.
        max_length: Number of tokens every caption is padded or truncated to.
            Without an explicit value, ``padding="max_length"`` pads each
            caption to the tokenizer's model maximum, which makes every
            label row far longer than a caption needs.

    Returns:
        Dict with ``"pixel_values"`` and ``"labels"``, each a list holding one
        tensor per input row.
    """
    images = []
    for file_name in examples["image"]:
        # Release each file handle as soon as the pixels are decoded;
        # convert() returns an independent image, so it stays usable afterwards.
        with Image.open(os.path.join(images_dir, file_name)) as img:
            images.append(img.convert("RGB"))
    pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values
    labels = tokenizer(
        examples["caption"],
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    ).input_ids
    # Split the batched tensors back into one entry per example
    return {"pixel_values": list(pixel_values), "labels": list(labels)}

print("Creating dataset from captions...")
dataset = Dataset.from_pandas(captions)
print("Dataset created. Preprocessing dataset...")
# The raw columns are dropped so only the model-ready fields remain
dataset = dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=["image", "caption"],
)
print("Dataset preprocessed.")

# Hold out 10% of the rows for validation
print("Splitting dataset into train and validation sets...")
train_test_split = dataset.train_test_split(test_size=0.1)
train_dataset, val_dataset = train_test_split['train'], train_test_split['test']
print("Dataset split complete.")

# Build the LoRA configuration; the targets are the attention projections of
# the first encoder layer and the cross-attention of the first decoder block
print("Setting up LoRA configuration...")
config = LoraConfig(
    target_modules=[
        "encoder.encoder.layer.0.attention.attention.query",
        "encoder.encoder.layer.0.attention.attention.key",
        "encoder.encoder.layer.0.attention.attention.value",
        "encoder.encoder.layer.0.attention.output.dense",
        "decoder.transformer.h.0.crossattention.c_attn",
        "decoder.transformer.h.0.crossattention.q_attn",
        "decoder.transformer.h.0.crossattention.c_proj",
    ],
    r=8,  # rank of the LoRA matrices
    lora_alpha=32,  # scaling factor applied to the LoRA update
    lora_dropout=0.1,
)
print("LoRA configuration set.")

# Wrap the base model with the adapters and keep it on the selected device
print("Creating the PEFT model...")
model = get_peft_model(model, config).to(device)
print("PEFT model created and moved to device.")

# Define training arguments
print("Setting up training arguments...")
training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=2,  # Small batch to limit memory use
    gradient_accumulation_steps=4,  # Effective batch of 8 per device
    num_train_epochs=3,
    logging_dir="./logs",
    logging_steps=10,
    save_total_limit=1,
    evaluation_strategy="steps",
    eval_steps=100,
    save_steps=100,
    learning_rate=5e-5,
    weight_decay=0.01,
    predict_with_generate=True,
    # fp16 mixed precision needs a CUDA device; enable it only when one is
    # available so the CPU fallback chosen at the top of the script still runs.
    fp16=torch.cuda.is_available(),
)
print("Training arguments set.")

# Create a custom data collator
class DataCollatorForImageCaptioning:
    """Collate preprocessed captioning examples into one training batch.

    Each example is a mapping with ``"pixel_values"`` and ``"labels"``; the
    values may be nested lists or tensors, as long as they have the same
    shape across the batch.
    """

    def __init__(self, feature_extractor, tokenizer):
        # Stored for reference only; __call__ reads just tokenizer.pad_token_id
        self.feature_extractor = feature_extractor
        self.tokenizer = tokenizer

    def __call__(self, batch):
        """Stack the examples and replace padding label ids with -100.

        Args:
            batch: Sequence of examples as described on the class.

        Returns:
            Dict with stacked ``"pixel_values"`` and ``"labels"`` tensors. Label
            positions equal to the tokenizer's pad token id are set to -100.
            When the tokenizer has no pad token id, labels are returned
            unmasked.
        """
        # as_tensor accepts lists and existing tensors alike, without the
        # copy-construct warning torch.tensor() emits for tensor inputs.
        pixel_values = torch.stack([torch.as_tensor(item["pixel_values"]) for item in batch])
        labels = torch.stack([torch.as_tensor(item["labels"]) for item in batch])
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is not None:
            labels = labels.masked_fill(labels == pad_token_id, -100)
        return {"pixel_values": pixel_values, "labels": labels}

# Wire the model, datasets and collator into a Seq2SeqTrainer
print("Initializing the Trainer...")
data_collator = DataCollatorForImageCaptioning(feature_extractor, tokenizer)
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)
print("Trainer initialized.")

# Run fine-tuning
print("Starting fine-tuning...")
# Release cached GPU memory before training starts
torch.cuda.empty_cache()
trainer.train()
print("Fine-tuning complete.")

# Save the fine-tuned model and tokenizer
print("Saving the fine-tuned model and tokenizer...")
# Saving the PEFT wrapper directly writes only the LoRA adapter files, but the
# next step reloads this directory as a full VisionEncoderDecoderModel. Fold
# the LoRA weights into the base model first so a complete checkpoint is
# written. The hasattr guard keeps the cell re-runnable on an already-merged
# (plain) model.
if hasattr(model, "merge_and_unload"):
    model = model.merge_and_unload()
model.save_pretrained("./fine-tuned-model")
tokenizer.save_pretrained("./fine-tuned-model")
print("Fine-tuned model and tokenizer saved.")

# Load and use the fine-tuned model for prediction
print("Loading the fine-tuned model for prediction...")
model = VisionEncoderDecoderModel.from_pretrained("./fine-tuned-model").to(device)
tokenizer = AutoTokenizer.from_pretrained("./fine-tuned-model")
print("Fine-tuned model loaded.")

# Example image URL for prediction
url = "https://images.freeimages.com/images/large-previews/429/plane-1449679.jpg"
print("Fetching example image for prediction...")
# The timeout keeps the cell from hanging on a stalled connection, and the
# context manager releases the connection once the image is decoded.
with requests.get(url, stream=True, timeout=30) as response:
    # Fail with a clear HTTP error instead of an image-decoding error
    response.raise_for_status()
    # Have the raw stream undo any gzip/deflate transfer encoding
    response.raw.decode_content = True
    image = Image.open(response.raw).convert("RGB")
image.show()

# Preprocess the image for prediction
print("Preprocessing the image for prediction...")
inputs = feature_extractor(images=image, return_tensors="pt").to(device)

# Generate caption (no gradients are needed for inference)
print("Generating caption...")
with torch.no_grad():
    pixel_values = inputs.pixel_values
    generated_ids = model.generate(pixel_values)
    generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

# Print the generated caption
print(f"Generated Caption: {generated_text}")

# Display the image with the caption
plt.imshow(image)
plt.title(generated_text)
plt.axis("off")
plt.show()
print("Prediction complete.")


In [None]:
# Load the original and fine-tuned models for comparison
print("Loading the original and fine-tuned models for comparison...")
original_model = VisionEncoderDecoderModel.from_pretrained(model_name).to(device)
fine_tuned_model = VisionEncoderDecoderModel.from_pretrained("./fine-tuned-model").to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
print("Models loaded.")

# List of image URLs for prediction
image_urls = [
    "https://media.istockphoto.com/id/1500816306/photo/adult-black-male-admiring-the-streets-of-london-on-a-sunny-day-while-holding-a-smartphone-in.webp?s=2048x2048&w=is&k=20&c=0aHww9bA2AdPEUyPkbTiZBFM-ZNeSUY_oZ2nGWAorXI=",
    "https://images.freeimages.com/images/large-previews/71f/my-new-bicycle-1431529.jpg",
    "https://images.freeimages.com/images/large-previews/f02/computer-room-1242684.jpg",
    "https://images.freeimages.com/images/large-previews/4ca/tree-1552037.jpg",
    "https://images.freeimages.com/images/large-previews/15b/lap-cat-1243719.jpg",
    "https://images.freeimages.com/images/large-previews/647/snowy-mountain-1378865.jpg",
    "https://images.freeimages.com/images/large-previews/792/captiol-building-1228390.jpg",
    "https://images.freeimages.com/images/large-previews/429/plane-1449679.jpg"
]


def _fetch_rgb_image(image_url, timeout=30):
    """Download ``image_url`` and return it as an RGB PIL image.

    Raises ``requests.RequestException`` on network/HTTP failure and
    ``OSError`` when the payload cannot be decoded as an image.
    """
    with requests.get(image_url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        # Have the raw stream undo any gzip/deflate transfer encoding
        response.raw.decode_content = True
        return Image.open(response.raw).convert("RGB")


def _generate_caption(captioning_model, pixel_values):
    """Return the first caption ``captioning_model`` generates for ``pixel_values``."""
    with torch.no_grad():
        generated_ids = captioning_model.generate(pixel_values)
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]


# Iterate through the list of image URLs and make predictions with both models
for url in image_urls:
    print(f"Fetching image from URL: {url}")
    try:
        image = _fetch_rgb_image(url)
    except (requests.RequestException, OSError) as exc:
        # One unreachable or undecodable image should not abort the comparison
        print(f"Skipping image that could not be loaded: {exc}")
        continue

    # Preprocess the image for prediction
    print("Preprocessing the image for prediction...")
    inputs = feature_extractor(images=image, return_tensors="pt").to(device)
    pixel_values = inputs.pixel_values

    # Generate caption with original model
    print("Generating caption with original model...")
    original_generated_text = _generate_caption(original_model, pixel_values)

    # Generate caption with fine-tuned model
    print("Generating caption with fine-tuned model...")
    fine_tuned_generated_text = _generate_caption(fine_tuned_model, pixel_values)

    # Print the generated captions
    print(f"Original Model Caption: {original_generated_text}")
    print(f"Fine-tuned Model Caption: {fine_tuned_generated_text}")

    # Display the image with both captions
    plt.imshow(image)
    plt.title(f"Original: {original_generated_text}\nFine-tuned: {fine_tuned_generated_text}")
    plt.axis("off")
    plt.show()
    print("Prediction complete.")

In [None]:
# List every submodule name so LoRA target modules can be picked by name
print("Inspecting model layers to identify target modules for LoRA...")
for layer_name, _module in model.named_modules():
    print(layer_name)