## Install Requirements

In [None]:
!pip install transformers[torch] datasets -q

## Import Libraries

In [None]:
from datasets import load_from_disk
import requests
from PIL import Image
import torch
from transformers import BlipProcessor, BlipForQuestionAnswering,BlipImageProcessor, AutoProcessor
from transformers import BlipConfig
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Load the Dataset

In [None]:
# Load the dataset previously saved to disk (Kaggle input directory) and print
# it to inspect its contents.
dataset = load_from_disk("/kaggle/input/medical-vqa-dataset/VQA_Medical_Dataset")
print(dataset)

## Sample Visualization

In [None]:
# Show one training example: the image plus its question/answer pair.
sample = dataset['train'][1]

# Convert to RGB exactly once and plot that object (previously PIL_image was
# built through a numpy round-trip and then never used).
PIL_image = sample['image'].convert('RGB')
plt.imshow(PIL_image)

print(f"Question: {sample['question']}")
print(f"Answer: {sample['answer']}")

In [None]:
from transformers import BlipConfig

# Fetch the BLIP VQA configuration by model id; if that raises for any reason,
# report the error and retry restricted to files already available locally.
try:
    config = BlipConfig.from_pretrained("Salesforce/blip-vqa-base")
except Exception as load_error:
    print(f"Error occurred: {load_error}. Falling back to local files only.")
    config = BlipConfig.from_pretrained("Salesforce/blip-vqa-base", local_files_only=True)

# Either path leaves `config` bound; show what was loaded.
print("Config loaded successfully:", config)

In [None]:
# Short aliases for the two splits used to build the training and validation datasets.
train_sample, val_sample = dataset['train'], dataset['validation']


## Build Dataset and DataLoaders

In [None]:
class VQADataset(torch.utils.data.Dataset):
    """Map-style dataset that turns (image, question, answer) rows into model inputs.

    Each item is the text processor's encoding of the question (batch axis
    removed), extended with ``pixel_values`` for the resized image and
    ``labels`` holding the tokenized answer.

    Args:
        data: Indexable split. ``data['question']`` and ``data['answer']`` must
            return the full columns, and ``data[i]['image']`` an object that
            supports ``.convert('RGB')``.
        segment: Split name. Accepted for call-site compatibility; not used.
        text_processor: Callable processor that also exposes a ``tokenizer``
            attribute with an ``encode`` method.
        image_processor: Callable returning a mapping with ``pixel_values``.
        max_length: Length questions and answers are padded/truncated to.
        image_height: Target image height passed to the image processor.
        image_width: Target image width passed to the image processor.
    """

    def __init__(self, data, segment, text_processor, image_processor,
                 max_length=32, image_height=128, image_width=128):
        self.data = data
        # Cache the text columns once so __getitem__ only indexes into them.
        self.questions = data['question']
        self.answers = data['answer']
        self.text_processor = text_processor
        self.image_processor = image_processor
        self.max_length = max_length
        self.image_height = image_height
        self.image_width = image_width

    def __len__(self):
        """Return the number of rows in the underlying split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Build the encoded question, image tensor and answer labels for row ``idx``."""
        question = self.questions[idx]
        answer = self.answers[idx]
        image = self.data[idx]['image'].convert('RGB')

        # Explicit height/width mapping is the documented form of `size`.
        image_encoding = self.image_processor(
            image,
            do_resize=True,
            size={"height": self.image_height, "width": self.image_width},
            return_tensors="pt",
        )

        # Pass the question by keyword so this does not depend on the
        # positional order of the processor's (images, text) parameters.
        encoding = self.text_processor(
            text=question,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

        # Drop only the leading batch axis; a bare squeeze() would also
        # collapse the sequence axis when max_length == 1.
        for key in list(encoding.keys()):
            encoding[key] = encoding[key].squeeze(0)
        encoding["pixel_values"] = image_encoding["pixel_values"][0]

        # Answer token ids, padded/truncated to the same fixed length.
        encoding["labels"] = self.text_processor.tokenizer.encode(
            answer,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors='pt',
        )[0]

        return encoding

In [None]:
from transformers import BlipProcessor, BlipImageProcessor

def _load_with_local_fallback(loader_cls, model_id, label):
    """Load ``loader_cls.from_pretrained(model_id)``, retrying with local files only.

    Args:
        loader_cls: Class exposing a ``from_pretrained`` constructor.
        model_id: Model identifier passed to ``from_pretrained``.
        label: Lower-case human-readable name used in the status messages.

    Returns:
        The object returned by ``from_pretrained``.

    Raises:
        Whatever the local-files-only retry raises if it fails as well.
    """
    try:
        loaded = loader_cls.from_pretrained(model_id)
        print(f"{label.capitalize()} loaded successfully from the Hugging Face model hub.")
    except Exception as e:
        print(f"Error loading {label}: {e}. Falling back to local files only.")
        loaded = loader_cls.from_pretrained(model_id, local_files_only=True)
    return loaded


# Both processors come from the same checkpoint and share the same fallback policy.
text_processor = _load_with_local_fallback(BlipProcessor, "Salesforce/blip-vqa-base", "text processor")
image_processor = _load_with_local_fallback(BlipImageProcessor, "Salesforce/blip-vqa-base", "image processor")

# Print confirmation
print("Text processor and image processor loaded successfully.")

In [None]:
# Wrap each split in a VQADataset; both share the same text and image processors.
train_vqa_dataset, val_vqa_dataset = (
    VQADataset(
        data=split,
        segment=split_name,
        text_processor=text_processor,
        image_processor=image_processor,
    )
    for split, split_name in ((train_sample, 'train'), (val_sample, 'validation'))
)

In [None]:
# Sanity check: build the first training item and show it as the cell output.
train_vqa_dataset[0]

In [None]:
def collate_fn(batch):
    """Stack per-example tensors into one batch dictionary.

    Args:
        batch: Sequence of mappings, each holding ``input_ids``,
            ``attention_mask``, ``pixel_values`` and ``labels`` tensors.

    Returns:
        dict with the same four keys, each value being the examples' tensors
        stacked along a new leading dimension.
    """
    field_order = ('input_ids', 'attention_mask', 'pixel_values', 'labels')
    return {
        field: torch.stack([example[field] for example in batch])
        for field in field_order
    }

# Training batches are reshuffled every epoch so the model does not see the
# examples in the same fixed order each time.
train_dataloader = DataLoader(train_vqa_dataset,
                              collate_fn=collate_fn,
                              batch_size=32,
                              shuffle=True)
# Validation keeps a fixed order so per-epoch metrics are comparable.
val_dataloader = DataLoader(val_vqa_dataset,
                            collate_fn=collate_fn,
                            batch_size=32,
                            shuffle=False)

In [None]:
# Pull a single batch from the training loader and report each tensor's shape.
batch = next(iter(train_dataloader))
for field, tensor in batch.items():
    print(field, tensor.shape)

## Build Model

In [None]:
# Load the BLIP VQA model by id; on any failure, report it and retry using
# only files that are already available locally.
try:
    model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
except Exception as load_error:
    print(f"Error occurred: {load_error}. Falling back to local files only.")
    model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base", local_files_only=True)
else:
    print("Model loaded successfully from the Hugging Face model hub.")

# Place the weights on the selected device; as the last expression this also
# displays the model as the cell output.
model.to(device)

In [None]:
# Per-channel normalization statistics of the image processor, kept so that
# normalized pixel tensors can be converted back for display.
image_mean, image_std = image_processor.image_mean, image_processor.image_std

# AdamW over all model parameters.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=5e-5)

In [None]:
batch_idx = 1

# Undo the per-channel normalization (x * std + mean) and move channels last.
# Assumes pixel_values were normalized with image_mean / image_std.
unnormalized_image = (batch["pixel_values"][batch_idx].cpu().numpy() * np.array(image_std)[:, None, None]) + np.array(image_mean)[:, None, None]
unnormalized_image = np.moveaxis(unnormalized_image, 0, -1)
# Round and clip before the uint8 cast: plain truncation turns 254.9999 into
# 254, and slightly out-of-range floats would otherwise wrap around.
unnormalized_image = np.clip(np.rint(unnormalized_image * 255), 0, 255).astype(np.uint8)

# Skip special tokens so padding does not clutter the decoded text.
print("Question: ", text_processor.decode(batch["input_ids"][batch_idx], skip_special_tokens=True))
print("Answer: ", text_processor.decode(batch["labels"][batch_idx], skip_special_tokens=True))
plt.imshow(Image.fromarray(unnormalized_image))

## Model Training

In [None]:
from tqdm import tqdm
import torch
from torch.optim import AdamW
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import StepLR
from nltk.translate.bleu_score import sentence_bleu
import nltk

nltk.download('punkt')  # Ensure BLEU tokenizer is available

# Hyperparameters
epochs = 50
accumulation_steps = 4  # Gradients are accumulated over 4 mini-batches per optimizer step
max_grad_norm = 1.0
best_bleu = 0.0  # Best validation BLEU seen so far
fine_tuned_model_path = "/kaggle/working/medical_vqa_blip"

# Learning Rate Scheduler
# NOTE(review): with step_size=1 and gamma=0.1 the learning rate is multiplied
# by 0.1 after every epoch, so it becomes vanishingly small within a few of the
# 50 epochs. Left unchanged here; confirm this schedule is intended.
scheduler = StepLR(optimizer, step_size=1, gamma=0.1)

# Mixed Precision Training
scaler = GradScaler()

num_train_batches = len(train_dataloader)

# Start training
for epoch in range(epochs):
    print(f"Epoch: {epoch}")
    model.train()

    total_loss = []
    # Clear gradients once per epoch; inside the loop they are cleared only
    # after an optimizer step so they can accumulate across mini-batches.
    optimizer.zero_grad()
    for i, batch in enumerate(tqdm(train_dataloader, disable=False)):
        batch = {k: v.to(device) for k, v in batch.items()}

        with autocast():
            outputs = model(**batch)
            loss = outputs.loss

        # Scale the loss down so the accumulated gradient is the average over
        # the accumulation window rather than the sum.
        scaler.scale(loss / accumulation_steps).backward()

        # Step at the end of every window, and also on the final batch so a
        # trailing partial window is not silently discarded.
        is_last_batch = (i + 1) == num_train_batches
        if (i + 1) % accumulation_steps == 0 or is_last_batch:
            # Gradients must be unscaled before clipping; otherwise the norm
            # threshold is applied to loss-scaled values.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        # Log the unscaled per-batch loss.
        total_loss.append(loss.item())

    # Learning Rate Scheduling
    scheduler.step()

    # Validation with BLEU score
    model.eval()
    val_loss = 0
    total_bleu = 0
    count = 0

    with torch.no_grad():
        for batch in val_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}

            with autocast():
                # Compute loss separately
                outputs = model(**batch)
                val_loss += outputs.loss.item()

                # Generate predictions; pass the attention mask so padded
                # question tokens are ignored, as in the forward pass above.
                generated_outputs = model.generate(
                    pixel_values=batch['pixel_values'],
                    input_ids=batch['input_ids'],
                    attention_mask=batch['attention_mask'],
                    max_new_tokens=50  # Explicitly set max_new_tokens
                )

            # Decode every prediction/reference in the batch (not only the
            # first one) so BLEU covers the whole validation set.
            predicted_answers = text_processor.batch_decode(generated_outputs, skip_special_tokens=True)
            reference_answers = text_processor.batch_decode(batch['labels'], skip_special_tokens=True)

            for predicted_answer, reference_answer in zip(predicted_answers, reference_answers):
                reference_tokens = [reference_answer.split()]
                predicted_tokens = predicted_answer.split()
                total_bleu += sentence_bleu(reference_tokens, predicted_tokens)
                count += 1

    # Guard the averages against an empty loader.
    val_loss /= max(len(val_dataloader), 1)
    avg_bleu = total_bleu / max(count, 1)  # Average sentence-level BLEU
    # Mean per-batch training loss, so it is comparable with the validation mean.
    train_loss = sum(total_loss) / max(len(total_loss), 1)

    print(f"Epoch {epoch} - Training Loss: {train_loss}, Validation Loss: {val_loss}, BLEU Score: {avg_bleu:.4f}")

    # Save best model based on BLEU score
    if avg_bleu > best_bleu:
        best_bleu = avg_bleu

        # Save fine-tuned model
        model.save_pretrained(fine_tuned_model_path)

        # Save tokenizer and image processor
        text_processor.save_pretrained(f"{fine_tuned_model_path}/text_processor")
        image_processor.save_pretrained(f"{fine_tuned_model_path}/image_processor")
        print(f"Fine-tuned model and processor saved to {fine_tuned_model_path} with BLEU Score: {best_bleu:.4f}")

print("Training complete.")


## Inference

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# Take the normalization statistics from the image processor that produced
# the pixel tensors instead of hard-coding them, so the un-normalization below
# always matches the preprocessing.
image_mean = image_processor.image_mean
image_std = image_processor.image_std

# Make sure dropout etc. are disabled even if this cell is run on its own.
model.eval()

# Loop over (at most) the first 50 validation examples
for x in range(min(50, len(val_vqa_dataset))):
    sample = val_vqa_dataset[x]

    # Decode question
    question_text = text_processor.decode(sample['input_ids'], skip_special_tokens=True)

    # Add a batch dimension to every tensor and move it to the model's device
    sample = {k: v.unsqueeze(0).to(device) for k, v in sample.items()}

    # Generate prediction; the attention mask keeps padded question tokens
    # from being attended to.
    with torch.no_grad():
        outputs = model.generate(
            pixel_values=sample['pixel_values'],
            input_ids=sample['input_ids'],
            attention_mask=sample['attention_mask'],
        )

    # Decode predicted and actual answers
    predicted_answer = text_processor.decode(outputs[0], skip_special_tokens=True)
    actual_answer = text_processor.decode(sample['labels'][0], skip_special_tokens=True) if 'labels' in sample else "N/A"

    # Unnormalize image (x * std + mean per channel), then round and clip
    # before the uint8 cast so float error cannot truncate or wrap values.
    unnormalized_image = (sample["pixel_values"][0].cpu().numpy() * np.array(image_std)[:, None, None]) + np.array(image_mean)[:, None, None]
    unnormalized_image = np.clip(np.rint(unnormalized_image * 255), 0, 255).astype(np.uint8)
    unnormalized_image = np.moveaxis(unnormalized_image, 0, -1)  # Convert from (C, H, W) to (H, W, C)

    # Display the image with Matplotlib
    plt.figure(figsize=(6, 6))
    plt.imshow(unnormalized_image)
    plt.axis("off")
    plt.title(f"Question: {question_text}\nTrue Answer: {actual_answer}\nPredicted: {predicted_answer}")
    plt.show()

    print("###################################################################")
