In [1]:
!pip install peft

Collecting peft
  Obtaining dependency information for peft from https://files.pythonhosted.org/packages/14/0b/8402305043884c76a9d98e5e924c3f2211c75b02acd5b742e6c45d70506d/peft-0.6.2-py3-none-any.whl.metadata
  Downloading peft-0.6.2-py3-none-any.whl.metadata (23 kB)
Downloading peft-0.6.2-py3-none-any.whl (174 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m174.7/174.7 kB[0m [31m5.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: peft
Successfully installed peft-0.6.2


In [2]:
from torch.utils.data import Dataset
from PIL import Image
import os
from sklearn.metrics import accuracy_score
import torch
from transformers import BlipForQuestionAnswering, AdamW, AutoProcessor
from torch.utils.data import DataLoader
from peft import get_peft_model, LoraConfig
from tqdm import tqdm
import wandb
from torchvision import transforms

# Login to wandb
wandb.login(key="11045189c6a87e054bc175e57214d6d03c4d47b3")

# Initialize a wandb run
wandb.init(project="MediVQA", entity="dineshbond1453")

[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mdineshbond1453[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
from torch.utils.data import Dataset
from PIL import Image
import os

class VQAMedDataset(Dataset):
    def __init__(self, qa_pairs_path, image_dir, transform=None):
        """
        Args:
            qa_pairs_path (str): Path to the file containing QA pairs.
            image_dir (str): Directory with all the images.
            transform (callable, optional): Optional transform to be applied on the image.
        """
        with open(qa_pairs_path, 'r', encoding="utf-8") as f:
            lines = f.readlines()
            self.data = [line.strip().split('|') for line in lines]

        self.image_dir = image_dir
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image_id, question, answer = self.data[idx]
        image_path = os.path.join(self.image_dir, f"{image_id}.jpg")
        image = Image.open(image_path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, question, answer

In [4]:
# the path to the training images
train_image_dir = "/kaggle/input/no-val-med/NO_VAL_DATA/Training"
qa_pairs_path = "/kaggle/input/no-val-med/NO_VAL_DATA/TRAIN.txt"

# Instantiate the dataset (without image transformations for now)
sample_dataset = VQAMedDataset(qa_pairs_path, train_image_dir)

# Check a sample from the dataset
sample_dataset[0]

(<PIL.Image.Image image mode=RGB size=1024x659>,
 'what kind of image is this?',
 'cta - ct angiography')

In [5]:
sample_dataset[3]

(<PIL.Image.Image image mode=RGB size=432x709>,
 'is this a noncontrast mri?',
 'no')

In [6]:
# Custom collate function
def vqa_collate(batch):
    """
    Custom collate function for our VQA dataset.
    
    Args:
        batch (list): List of samples fetched from the VQAMedDataset.
    
    Returns:
        tuple: Contains batched images, questions, and answers.
    """
    # Unzip the batch data
    images, questions, answers = zip(*batch)
    
    # Stack images
    images = torch.stack(images, 0)
    
    return images, questions, answers

# Transformations for the images
transform = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]),
])

In [7]:
# Create the dataset and data loader
train_dataset = VQAMedDataset(qa_pairs_path, train_image_dir, transform=transform)
data_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=vqa_collate)

In [8]:
# Initialize model and optimizer
model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
processor = AutoProcessor.from_pretrained("Salesforce/blip-vqa-base")

Downloading config.json:   0%|          | 0.00/4.56k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.54G [00:00<?, ?B/s]

Downloading (…)rocessor_config.json:   0%|          | 0.00/445 [00:00<?, ?B/s]

Downloading tokenizer_config.json:   0%|          | 0.00/592 [00:00<?, ?B/s]

Downloading vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading tokenizer.json:   0%|          | 0.00/711k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

In [9]:
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )

print_trainable_parameters(model)

trainable params: 384672572 || all params: 384672572 || trainable%: 100.0


In [10]:
# Define LoRA configuration for the BLIP VQA model
config = LoraConfig(
    r=16,  # Rank of LoRA, adjust as needed
    lora_alpha=32,  # Scaling factor, adjust as needed
    lora_dropout=0.1,  # Dropout for LoRA layers, adjust as needed
    bias="none",  # Bias configuration for LoRA layers
    target_modules=["query", "value"]  # Target modules in the Transformer to apply LoRA
 )

# Acquire the LoRA-adapted model
peft_model = get_peft_model(model, config)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
peft_model.to(device)
print_trainable_parameters(peft_model)

trainable params: 2359296 || all params: 387031868 || trainable%: 0.6095870120958619


In [11]:
# Initialize processor and tokenizer
optimizer = AdamW(peft_model.parameters(), lr=5e-5)
tokenizer = processor.tokenizer



In [12]:
best_loss = float('inf')  # Initialize the best loss to a very high value
best_model_path = "best_model.pth"  # Path where the best model will be saved

for epoch in range(5):
    progress_bar = tqdm(data_loader, desc=f"Epoch {epoch + 1}/{5}")
    total_loss = 0
    num_batches = 0

    for batch in progress_bar:
        images, questions, answers = batch

        # Tokenize questions and answers
        inputs = tokenizer(questions, return_tensors="pt", padding=True, truncation=True)
        targets = tokenizer(answers, return_tensors="pt", padding=True, truncation=True)

        # Move to device
        images = images.to(device)
        inputs = {key: val.to(device) for key, val in inputs.items()}
        targets = targets["input_ids"].to(device)

        # Forward pass
        optimizer.zero_grad()
        outputs = peft_model(pixel_values=images, input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], labels=targets)
        loss = outputs.loss

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        # Update total loss and batch count
        total_loss += loss.item()
        num_batches += 1

        # Update progress bar
        progress_bar.set_postfix({'loss': f'{loss.item():.4f}'})

    # Calculate average loss for the epoch
    epoch_loss = total_loss / num_batches
    wandb.log({'epoch': epoch, 'loss': epoch_loss})

    # Save the best model
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        torch.save(peft_model.state_dict(), best_model_path)
        print(f"Epoch {epoch+1}: New best model saved with loss {best_loss:.4f}")

Epoch 1/5:   0%|          | 0/1411 [00:00<?, ?it/s]We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.
Epoch 1/5: 100%|██████████| 1411/1411 [15:49<00:00,  1.49it/s, loss=3.7894]


Epoch 1: New best model saved with loss 7.6012


Epoch 2/5: 100%|██████████| 1411/1411 [15:35<00:00,  1.51it/s, loss=2.2171]


Epoch 2: New best model saved with loss 7.1296


Epoch 3/5: 100%|██████████| 1411/1411 [15:30<00:00,  1.52it/s, loss=3.1936]


Epoch 3: New best model saved with loss 7.0291


Epoch 4/5: 100%|██████████| 1411/1411 [15:32<00:00,  1.51it/s, loss=5.4841]


Epoch 4: New best model saved with loss 6.9709


Epoch 5/5: 100%|██████████| 1411/1411 [15:30<00:00,  1.52it/s, loss=5.0804]


Epoch 5: New best model saved with loss 6.9274


In [13]:
# # Save the model locally
# save_directory = "/kaggle/working/"
# peft_model.save_pretrained(save_directory)
# tokenizer.save_pretrained(save_directory)

In [35]:
# Preprocess the Image
def preprocess_image(image_path):
    image = Image.open(image_path).convert("RGB")
    
    # Define the transformations: resize the image, convert to tensor, and normalize
    transform = transforms.Compose([
        transforms.Resize((384, 384)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]),
    ])
    
    image_tensor = transform(image).unsqueeze(0)  # Adding batch dimension
    return image_tensor

In [37]:
# Time to test
image_path = "/kaggle/input/combined-all-data/Data/MED/Test_Images/synpic54082.jpg"
image_tensor = preprocess_image(image_path)

# Tokenize the Question
question = "Which modality is displayed?"
inputs = tokenizer(question, return_tensors="pt", padding=True, truncation=True)
inputs = {key: val.to(device) for key, val in inputs.items()}

In [38]:
# Model Inference
with torch.no_grad():
    image_tensor = image_tensor.to(device)
    generated_ids = peft_model.generate(input_ids=inputs["input_ids"].to(device), 
                                   attention_mask=inputs["attention_mask"].to(device), 
                                   pixel_values=image_tensor)



In [39]:
# Decode the Answer
predicted_answer = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

print(question)
print(predicted_answer)

Which modality is displayed?
ct


In [42]:
test_dataset = VQAMedDataset("/kaggle/input/test-text/vqa_rad_test_converted.txt", "/kaggle/input/combined-all-data/Data/RAD/Images", transform=transform)
test_data_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

test_dataset[0]

(tensor([[[-1.6609, -1.6609, -1.6463,  ..., -1.5879, -1.5733, -1.5879],
          [-1.6609, -1.6463, -1.6463,  ..., -1.5879, -1.5733, -1.5879],
          [-1.6463, -1.6463, -1.6317,  ..., -1.5879, -1.5733, -1.5879],
          ...,
          [-1.7777, -1.7777, -1.7777,  ..., -1.7777, -1.7777, -1.7777],
          [-1.7777, -1.7777, -1.7777,  ..., -1.7777, -1.7777, -1.7777],
          [-1.7777, -1.7777, -1.7777,  ..., -1.7777, -1.7777, -1.7777]],
 
         [[-1.6170, -1.6170, -1.6020,  ..., -1.5420, -1.5270, -1.5420],
          [-1.6170, -1.6020, -1.6020,  ..., -1.5420, -1.5270, -1.5420],
          [-1.6020, -1.6020, -1.5870,  ..., -1.5420, -1.5270, -1.5420],
          ...,
          [-1.7371, -1.7371, -1.7371,  ..., -1.7371, -1.7371, -1.7371],
          [-1.7371, -1.7371, -1.7371,  ..., -1.7371, -1.7371, -1.7371],
          [-1.7371, -1.7371, -1.7371,  ..., -1.7371, -1.7371, -1.7371]],
 
         [[-1.3522, -1.3522, -1.3380,  ..., -1.2811, -1.2669, -1.2811],
          [-1.3522, -1.3380,

In [43]:
def predict_answers(model, data_loader, device):
    model.eval()
    predictions = []
    with torch.no_grad():
        for images, questions, _ in data_loader:
            inputs = tokenizer(questions, return_tensors="pt", padding=True, truncation=True)
            images = images.to(device)
            inputs = {key: val.to(device) for key, val in inputs.items()}
            
            outputs = model.generate(pixel_values=images, input_ids=inputs["input_ids"])
            decoded_predictions = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
            predictions.extend(decoded_predictions)
    return predictions

predicted_answers = predict_answers(peft_model, test_data_loader, device)

In [44]:
def classify_answer(answer):
    return 'OPEN' if answer.lower() not in ['yes', 'no'] else 'CLOSED'

predicted_answer_types = [classify_answer(ans) for ans in predicted_answers]

In [45]:
ground_truth = []
with open("/kaggle/input/combined-all-data/Data/RAD/rad_test.txt", 'r', encoding="utf-8") as file:
    for line in file:
        _, _, _, answer_type = line.strip().split('|')
        ground_truth.append(answer_type)

In [46]:
open_accuracy = accuracy_score(
    [gt for gt, pred in zip(ground_truth, predicted_answer_types) if gt == 'OPEN'],
    [pred for gt, pred in zip(ground_truth, predicted_answer_types) if gt == 'OPEN']
)
closed_accuracy = accuracy_score(
    [gt for gt, pred in zip(ground_truth, predicted_answer_types) if gt == 'CLOSED'],
    [pred for gt, pred in zip(ground_truth, predicted_answer_types) if gt == 'CLOSED']
)
overall_accuracy = accuracy_score(ground_truth, predicted_answer_types)

print(f"Open Accuracy: {open_accuracy}")
print(f"Closed Accuracy: {closed_accuracy}")
print(f"Overall Accuracy: {overall_accuracy}")

Open Accuracy: 0.9876543209876543
Closed Accuracy: 0.8611111111111112
Overall Accuracy: 0.9066666666666666


In [47]:
# To login in HF
!huggingface-cli login --token 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

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [48]:
# Push the model in HF
model.push_to_hub("Final-BLIP-LORA")
tokenizer.push_to_hub("Final-BLIP-LORA")
processor.push_to_hub("Final-BLIP-LORA")

model.safetensors:   0%|          | 0.00/1.54G [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/dineshcr7/Final-BLIP-LORA/commit/29c02628a2b8b3e2740a450f6fd6dd7f87a352b6', commit_message='Upload processor', commit_description='', oid='29c02628a2b8b3e2740a450f6fd6dd7f87a352b6', pr_url=None, pr_revision=None, pr_num=None)