In [None]:
from datasets import load_dataset

# Load the "arampacha/rsicd" dataset through `datasets.load_dataset`. Later cells read
# its "train" and "valid" splits, and CustomDataset reads the "image" and "captions"
# fields of each row.
dataset = load_dataset("arampacha/rsicd")

In [None]:
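# Optional sanity check -- `train_dataset` is only created in a later cell, so run this
# after that cell has executed:
# print(train_dataset)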

In [None]:
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from accelerate import Accelerator, notebook_launcher
from tqdm import tqdm

In [None]:
class CustomDataset(Dataset):
    """Map-style dataset that encodes every caption of a row together with its image.

    Args:
        dataset: Indexable collection whose items expose an ``"image"`` entry and an
            iterable ``"captions"`` entry.
        processor: Callable invoked as
            ``processor(images=..., text=..., padding="max_length", return_tensors="pt")``
            that returns a mapping of tensors carrying a leading batch dimension of 1.
    """

    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        """Return the number of rows (images), not the number of captions."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return a list with one ``{name: tensor}`` dict per caption of row ``idx``.

        The image is encoded together with each caption, so every dict in the list
        holds the processor outputs for one (image, caption) pair.
        """
        item = self.dataset[idx]
        image = item["image"]
        encodings = []
        for caption in item["captions"]:
            encoding = self.processor(images=image, text=caption, padding="max_length", return_tensors="pt")
            # Remove only the leading batch dimension. A bare ``squeeze()`` would also
            # collapse any other size-1 axis (e.g. a one-token sequence or a
            # single-channel image) and silently change the tensor's rank.
            encodings.append({key: value.squeeze(0) for key, value in encoding.items()})
        return encodings

In [None]:
import torch
from transformers import BlipProcessor, BlipForConditionalGeneration
from transformers import AdamW, get_linear_schedule_with_warmup
from torch.optim.lr_scheduler import ReduceLROnPlateau

In [None]:
# Load the pretrained BLIP processor. CustomDataset calls it with both an image and a
# caption; the model itself is loaded later, inside training_loop.
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")

In [None]:
# Wrap the "train" and "valid" splits (in that order) so that indexing them yields
# processor-encoded image/caption pairs.
train_dataset, val_dataset = (
    CustomDataset(dataset[split_name], processor) for split_name in ("train", "valid")
)

In [None]:
def training_loop(mixed_precision="fp16", num_epochs=3, learning_rate=5e-5, batch_size=5):
    """Fine-tune the BLIP captioning model on ``train_dataset`` and validate on ``val_dataset``.

    Intended to be run through ``notebook_launcher``; each process builds its own
    ``Accelerator``, model, data loaders and optimizer. After every epoch the
    validation loss drives a ``ReduceLROnPlateau`` scheduler and a checkpoint is
    written to ``new_model_epoch_<n>``.

    Args:
        mixed_precision: Value forwarded to ``Accelerator(mixed_precision=...)``.
        num_epochs: Number of passes over the training loader.
        learning_rate: Initial learning rate for AdamW.
        batch_size: Number of dataset rows per loader batch. Every row contributes
            one forward/backward pass per caption, so the model sees
            ``batch_size`` image/caption pairs per optimizer step.
    """
    import math  # local import: only needed for the NaN guard below

    # Initialize accelerator
    accelerator = Accelerator(mixed_precision=mixed_precision)

    model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

    # Use DataLoader for efficient batching
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Set up the optimizer and learning rate scheduler
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3, verbose=True)

    model, optimizer, train_loader, val_loader = accelerator.prepare(model, optimizer, train_loader, val_loader)

    def _caption_loss(encoding):
        """Run one forward pass on a collated encoding dict and return its loss."""
        input_ids = encoding["input_ids"]
        attention_mask = encoding.get("attention_mask")
        labels = input_ids
        if attention_mask is not None:
            # Captions are padded to a fixed length, so exclude padded positions from
            # the loss (-100 is the default ignore_index of CrossEntropyLoss).
            labels = input_ids.masked_fill(attention_mask == 0, -100)
        outputs = model(input_ids=input_ids,
                        pixel_values=encoding["pixel_values"],
                        attention_mask=attention_mask,
                        labels=labels)
        return outputs.loss

    def _mean(values):
        """Arithmetic mean that returns NaN instead of raising on an empty list."""
        return sum(values) / len(values) if values else float("nan")

    # Only one process per machine draws progress bars, to avoid duplicated output.
    hide_progress = not accelerator.is_local_main_process

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        epoch_losses = []  # To store losses for each batch in the epoch

        for encodings in tqdm(train_loader, desc=f"Epoch {epoch + 1}", unit="batch", disable=hide_progress):
            # `encodings` holds one collated dict per caption slot of the batch rows.
            for encoding in encodings:
                loss = _caption_loss(encoding)
                epoch_losses.append(loss.item())  # Store the loss for this batch

                accelerator.backward(loss)

                optimizer.step()
                optimizer.zero_grad()

        # Calculate and print the average loss for the epoch
        average_loss = _mean(epoch_losses)
        accelerator.print(f"Average Training Loss for Epoch {epoch + 1}: {average_loss}")

        # Validation phase
        model.eval()
        val_losses = []

        with torch.no_grad():
            for val_encodings in tqdm(val_loader, desc="Validation", unit="batch", disable=hide_progress):
                for val_encoding in val_encodings:
                    val_losses.append(_caption_loss(val_encoding).item())

        # Each process only sees its own shard of the validation set. Average the
        # per-process means so every process steps its scheduler with the same value
        # and the learning rates cannot drift apart.
        local_val_loss = torch.tensor([_mean(val_losses)], device=accelerator.device)
        average_val_loss = accelerator.gather(local_val_loss).mean().item()
        accelerator.print(f"Average Validation Loss for Epoch {epoch + 1}: {average_val_loss}")

        # Update learning rate based on validation loss (skipped when there was no
        # validation data, since NaN would otherwise be counted as a "bad" epoch).
        if not math.isnan(average_val_loss):
            scheduler.step(average_val_loss)

        # Save a checkpoint of the fine-tuned model after every epoch. Wait until all
        # processes have finished the epoch before the main process writes to disk.
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(
            f"new_model_epoch_{epoch + 1}",
            is_main_process=accelerator.is_main_process,
            save_function=accelerator.save,
        )

In [None]:
# Positional arguments for training_loop: (mixed_precision, num_epochs, learning_rate).
# NOTE(review): 5e-7 is 100x smaller than training_loop's default of 5e-5 -- confirm
# this is intentional.
args = ("fp16", 5, 5e-7)

# Launch training_loop from the notebook in 2 processes.
notebook_launcher(function=training_loop, args=args, num_processes=2)