<a href="https://colab.research.google.com/github/AhmedMAbdelRashied/Medical-Visual-Question-Answering/blob/main/med_vqa_Final_Code.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Medical Visual Question Answering (VQA)
Medical VQA is a combination of medical artificial intelligence and popular VQA challenges. Given a medical image and a clinically relevant question in natural language, a medical VQA system is expected to predict a plausible and convincing answer.

In this notebook, I introduce a new, faster, and smaller multimodal architecture for VQA tasks
> that does not require a large amount of training.



In [None]:
!pip install peft
!pip install wandb
!pip install datasets

Collecting peft
  Downloading peft-0.9.0-py3-none-any.whl (190 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m190.9/190.9 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
Collecting accelerate>=0.21.0 (from peft)
  Downloading accelerate-0.27.2-py3-none-any.whl (279 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m280.0/280.0 kB[0m [31m7.1 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: accelerate, peft
Successfully installed accelerate-0.27.2 peft-0.9.0
Collecting wandb
  Downloading wandb-0.16.3-py3-none-any.whl (2.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m8.6 MB/s[0m eta [36m0:00:00[0m
Collecting GitPython!=3.1.29,>=1.0.0 (from wandb)
  Downloading GitPython-3.1.42-py3-none-any.whl (195 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m195.4/195.4 kB[0m [31m14.6 MB/s[0m eta [36m0:00:00[0m
Collecting sentry-sdk>=1.0.0 (from wandb)
  Downloading sentry_sdk-1.40

In [None]:
import json
import PIL
import numpy as np
import pandas as pd
import random
import copy
from PIL import Image
import tqdm

In [None]:
from transformers import AutoTokenizer, GPT2ForQuestionAnswering,GPT2LMHeadModel
from transformers import PreTrainedModel, PretrainedConfig, AutoImageProcessor
from transformers import AutoModel, AutoConfig
from transformers import GPT2ForQuestionAnswering,GPT2LMHeadModel
from transformers import TrainingArguments, Trainer

In [None]:
import torch
from torch import nn
from torch.utils.data import Dataset
from torchvision import transforms
import torch.nn.functional as F
from datasets import load_metric


In [None]:
from peft import LoraConfig, get_peft_model, TaskType,PeftType ,inject_adapter_in_model
import zipfile
import glob
import wandb

## Load the `PMC-VQA` dataset from Hugging Face

In [None]:
from huggingface_hub import snapshot_download
# Mirror the PMC-VQA dataset repository (CSVs + image zips) into /content/data.
snapshot_download(
    repo_id="xmcmic/PMC-VQA",
    local_dir='/content/data',
    repo_type="dataset",
)

Fetching 9 files:   0%|          | 0/9 [00:00<?, ?it/s]

README.md:   0%|          | 0.00/1.49k [00:00<?, ?B/s]

.gitattributes:   0%|          | 0.00/2.45k [00:00<?, ?B/s]

test_clean.csv:   0%|          | 0.00/419k [00:00<?, ?B/s]

test.csv:   0%|          | 0.00/10.7M [00:00<?, ?B/s]

test_2.csv:   0%|          | 0.00/12.4M [00:00<?, ?B/s]

images_2.zip:   0%|          | 0.00/2.21G [00:00<?, ?B/s]

images.zip:   0%|          | 0.00/18.9G [00:00<?, ?B/s]

train.csv:   0%|          | 0.00/38.1M [00:00<?, ?B/s]

train_2.csv:   0%|          | 0.00/56.7M [00:00<?, ?B/s]

'/content/data'

## Extract images from images.zip file

In [None]:
# ZipFile.extract without a target path writes relative to the current working
# directory; presumably this yields the /content/images/... paths that
# edit_csv_file uses by default — confirm the archive's internal layout.
with zipfile.ZipFile('/content/data/images.zip') as archive:
    entries = archive.infolist()
    for entry in tqdm.tqdm(entries, desc='Extracting '):
        archive.extract(entry)

Extracting : 100%|██████████| 149076/149076 [03:55<00:00, 631.69it/s]


## Helper functions
Functions used during data preparation and model validation.

In [None]:
def edit_image_paths(path:str,image:str):
    """
    Build the full path of an image by prefixing its file name with a directory.

        Parameters:
                path : directory that contains the images; no separator is
                       inserted, so it is expected to end with one
                image: the image file name

        Returns:
                the plain concatenation of ``path`` and ``image``
    """
    return "".join((path, image))

def print_number_of_trainable_model_parameters(model):
    """
    Summarize how many of a model's parameters are trainable.

    Despite the name, nothing is printed; the summary is returned as a string.

        Parameters:
                model : a torch ``nn.Module`` (e.g. a PEFT-wrapped LLM)

        Returns:
                a three-line string with the trainable parameter count, the
                total parameter count and the trainable percentage (0.00% for
                a model without parameters)
    """
    trainable_model_params = 0
    all_model_params = 0
    for _, param in model.named_parameters():
        count = param.numel()
        all_model_params += count
        if param.requires_grad:
            trainable_model_params += count
    # Guard against ZeroDivisionError for parameter-less modules.
    percentage = (100 * trainable_model_params / all_model_params) if all_model_params else 0.0
    return (
        f"trainable model parameters: {trainable_model_params}\n"
        f"all model parameters: {all_model_params}\n"
        f"percentage of trainable model parameters: {percentage:.2f}%"
    )
def edit_csv_file(csv_path,img_path='/content/images/'):
    """
    Load a PMC-VQA CSV file and clean it for the dataset loader.

    Steps:
    1. read the CSV file at ``csv_path``
    2. drop every row that has a missing value in any column
    3. drop every row whose answer contains "n/a" (case-insensitive substring
       match — NOTE(review): this also drops answers such as "sign/axis")
    4. prefix every ``Figure_path`` with ``img_path`` so images can be opened directly

        Parameters:
                csv_path: path of the CSV file
                img_path: directory prepended to the figure names in the CSV

        Returns:
         the cleaned DataFrame (original index labels are kept)
    """
    frame = pd.read_csv(csv_path).dropna()
    has_na_answer = frame['Answer'].str.lower().str.contains('n/a')
    frame = frame.drop(frame[has_na_answer].index, axis=0)
    frame['Figure_path'] = frame['Figure_path'].apply(
        lambda figure_name: edit_image_paths(img_path, figure_name)
    )
    return frame

In [None]:
def prepare_predictions(predictions, labels, eos_token_id=None):
    """
    Align predicted tokens with answer labels and drop ignored positions.

    Labels equal to -100 (image tokens, prompt, padding) are skipped, and
    processing stops at the first EOS label. A causal LM's output at position
    i-1 predicts the token at position i, so each kept label ``labels[i]`` is
    paired with ``predictions[i-1]``.

    param:
        predictions: sequence of predicted token ids (argmax of the logits)
        labels: sequence of label token ids (-100 = ignored)
        eos_token_id: id marking the end of the answer; defaults to the
            notebook-global ``tokenizer.eos_token_id``
    return:
        (pred, label) numpy arrays of equal length holding only answer tokens
    """
    if eos_token_id is None:
        eos_token_id = tokenizer.eos_token_id
    pred=[]
    label=[]
    for i in range(len(predictions)):
        if labels[i]==-100:
            continue # Ignore this token
        if labels[i]==eos_token_id:
            break # end of the answer
        if i == 0:
            # No earlier output predicts position 0; predictions[-1] would
            # silently wrap around to the last position.
            continue
        pred.append(predictions[i-1])
        label.append(labels[i])
    return np.array(pred),np.array(label)

def compute_bleu(predictions, labels):
    """
    Corpus BLEU between the decoded predicted answers and the decoded labels.
    param:
        predictions: per-sample sequences of predicted token ids
        labels: per-sample sequences of label token ids (-100 = ignored)
    return:
        the 'bleu' entry of the BLEU metric output
    """
    bleu_metric = load_metric('bleu')
    hypotheses = []
    references = []
    for sample_idx, sample_prediction in enumerate(predictions):
        # keep only the answer tokens, aligned for the causal shift
        cleaned_pred, cleaned_label = prepare_predictions(sample_prediction, labels[sample_idx])
        # whitespace-split decoded text; BLEU takes a list of references per hypothesis
        hypotheses.append(tokenizer.decode(cleaned_pred).split(' '))
        references.append([tokenizer.decode(cleaned_label).split(' ')])
    return bleu_metric.compute(predictions=hypotheses, references=references)['bleu']

def compute_accuracy(predictions, labels):
    """
    Mean per-sample token accuracy over the answer tokens.

    Each sample's accuracy is the fraction of answer positions whose predicted
    token equals the label; the result is the mean over samples. Samples with
    no answer tokens left after masking (e.g. the answer was truncated away)
    are skipped rather than contributing NaN.
    param:
        predictions: per-sample sequences of predicted token ids
        labels: per-sample sequences of label token ids (-100 = ignored)
    return:
        average accuracy as a float, or NaN if no sample had answer tokens
    """
    per_sample = []
    for i in range(len(predictions)):
        # remove the special tokens and align for the causal shift
        sample_predictions, sample_labels = prepare_predictions(predictions[i], labels[i])
        if sample_labels.size == 0:
            continue  # nothing to score for this sample
        # Same value as the 'accuracy' metric on 1-D token arrays, without
        # reloading the metric on every call.
        per_sample.append(float(np.mean(sample_predictions == sample_labels)))
    if not per_sample:
        return float('nan')
    return float(np.mean(per_sample))

def compute_metrics(p):
    """
    Metric hook passed to the Trainer.
    param:
        p: (predictions, label_ids) pair from the Trainer; predictions[0] is
           assumed to be the LM logits
    return:
        dict with 'accuracy' and 'bleu_score'
    """
    model_outputs, label_ids = p
    # argmax over the vocabulary axis turns logits into token ids
    predicted_ids = np.argmax(model_outputs[0], axis=-1)
    accuracy = compute_accuracy(predicted_ids, label_ids)
    bleu = compute_bleu(predicted_ids, label_ids)
    return {"accuracy": accuracy, 'bleu_score': bleu}

def compute_all_metrics(chunk_size=20):
    """
    Evaluate ``qa_model`` on the whole validation set in chunks.

    GPU RAM cannot hold the outputs for the full validation set at once, so the
    rows of ``test_csv`` are processed ``chunk_size`` at a time, each chunk with
    its own Trainer, and the per-chunk metrics are collected.

    param:
        chunk_size: number of validation rows per evaluation call (default 20)
    return:
        dict mapping 'eval_loss', 'eval_accuracy' and 'eval_bleu_score' to the
        list of per-chunk values. Note that a plain mean of these lists weights
        a short final chunk the same as a full one.
    raises:
        ValueError: if chunk_size is not positive
    """
    if chunk_size <= 0:
        raise ValueError(f'chunk_size must be positive, got {chunk_size}')
    results={
     'eval_loss': [],
     'eval_accuracy': [],
     'eval_bleu_score': []
    }
    n_rows = len(test_csv)
    # range() also yields the final, possibly partial, chunk.
    for start_index in range(0, n_rows, chunk_size):
        end_index = min(start_index + chunk_size, n_rows)
        print(f'From index :{start_index} to index :{end_index}')
        test_ds=PMC_VQA_dataset_loader(data_csv=test_csv.iloc[start_index:end_index],
                                image_processor=processor,
                                text_tokenizer=tokenizer,
                                 H=224,
                                 W=224,)
        trainer = Trainer(
            model=qa_model,
            args=args,
            eval_dataset=test_ds,
            compute_metrics=compute_metrics
        )
        result=trainer.evaluate()
        for metric_name in results:
            results[metric_name].append(result[metric_name])
    return results

## Reading CSV files
Load the CSV files that contain the data info,
such as:
1. image path
2. question
3. answer

In [None]:
# Clean both splits with the same helper.
train, test_csv = (edit_csv_file(f'/content/data/{split}.csv') for split in ('train', 'test'))

In [None]:
class PMC_VQA_dataset_loader(Dataset):
    """
        A PyTorch dataset that turns rows of the PMC-VQA CSV into model inputs.

    Attributes
    ----------
        data_csv: pandas DataFrame with the columns 'Figure_path', 'Question',
            'Choice A', 'Choice B', 'Choice C', 'Choice D' and 'Answer'
        text_tokenizer: LLM tokenizer object
        image_processor: ViT image processor object
        mode: 'Train' returns input_ids and labels; any other value returns
            only the question prompt (default 'Train')
        H: image height for the (unused) torchvision transform (default 512)
        W: image width for the (unused) torchvision transform (default 512)
        text_type: prompt style, one of 'random', 'blank', 'choice' (default 'random')
        image_tokens: number of image embeddings the model prepends to the
            text; that many -100 labels are prepended to every label row (default 257)
        seq_length: length the text prompt is padded/truncated to (default 512)
    """

    def __init__(self,data_csv,
                 text_tokenizer,
                 image_processor,
                 mode = 'Train',
                 H=512,
                 W=512,
                 text_type = 'random',
                 image_tokens= 257,
                 seq_length = 512
                 ):
        # One ignored label per image embedding placed in front of the text.
        self.img_padding = [-100] * image_tokens
        self.image_processor=image_processor
        self.text_tokenizer=text_tokenizer
        self.data=data_csv
        self.text_type = text_type
        self.seq_length=seq_length
        normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        # Kept for backward compatibility only; __getitem__ uses image_processor.
        self.transform = transforms.Compose([
                transforms.Resize((H,W), interpolation=Image.BICUBIC),
                transforms.ToTensor(),
                normalize,
            ])

        self.mode = mode

    def __len__(self):
        return len(self.data)

    def random_answer(self,Question, choice_list,Answer):
        """
        Build the question prompt and the prompt-plus-answer text.

        text_type 'blank' gives a question-only prompt, 'choice' adds the four
        choices, and 'random' picks the choice prompt when a uniform draw is
        >= 0.5 (otherwise blank). The exact whitespace below is the prompt
        format the released checkpoint was trained on; keep it byte-for-byte.

        Parameters:
                Question: the question text
                choice_list: the four answer choices
                Answer: the answer text

        Returns:
         pre_text: the prompt, ending with "The Answer is:"
         final_o : pre_text followed by the answer and a trailing space
        Raises:
         ValueError: for an unknown text_type
        """
        p = random.random()  # drawn for every call, as before
        Combined_choice = (
            "\n        Choices:"
            f"\n        {choice_list[0]}"
            f"\n        {choice_list[1]}"
            f"\n        {choice_list[2]}"
            f"\n        {choice_list[3]}"
            "\n        "
        )
        nl = "\n" + " " * 16
        blank_prompt = f"Question:{nl}{Question}{nl}The Answer is:"
        choice_prompt = (f"Question:{nl}{Question}{nl}Choices:{nl}"
                         f"{Combined_choice}{nl}The Answer is:")

        if self.text_type == 'random':
            use_choices = p >= 0.50
        elif self.text_type == 'blank':
            use_choices = False
        elif self.text_type == 'choice':
            use_choices = True
        else:
            raise ValueError(f"unknown text_type: {self.text_type!r}")

        pre_text = choice_prompt if use_choices else blank_prompt
        final_o = f"{pre_text}{Answer} "
        return pre_text,final_o

    def _pad_or_truncate(self, token_ids):
        """Right-pad ``token_ids`` with 0 to ``seq_length`` or cut it to ``seq_length`` (int64 array)."""
        ids = np.asarray(token_ids, dtype=np.int64)
        if len(ids) < self.seq_length:
            return np.pad(ids, (0, self.seq_length - len(ids)), 'constant', constant_values=0)
        return ids[:self.seq_length]

    def __getitem__(self, idx):
        """
        Build the sample at position ``idx``.

        Returns a dict with
            'images': squeezed ``pixel_values`` from the image processor
            'input_ids': int64 tensor of length ``seq_length`` (0-padded)
            'labels' (Train mode only): int64 tensor of length
                ``image_tokens + seq_length`` where image positions, prompt
                positions and padding are -100, so only the answer and EOS
                tokens are scored
        """
        sample = self.data.iloc[idx]
        Question = sample['Question']
        choice_list = [sample['Choice A'], sample['Choice B'],
                       sample['Choice C'], sample['Choice D']]
        Anwser = sample['Answer']

        # The context manager closes the file handle once the processor has
        # read the pixels; a bare Image.open keeps it open until GC.
        with Image.open(sample['Figure_path']) as img_source:
            image = self.image_processor(img_source, return_tensors="pt")
        pixel_values = torch.squeeze(image['pixel_values'])

        pre_text, final_o = self.random_answer(Question, choice_list, Anwser)
        prompt_ids = self.text_tokenizer(pre_text)['input_ids']

        if self.mode != 'Train':
            return {
                'images': pixel_values,
                'input_ids': torch.from_numpy(self._pad_or_truncate(prompt_ids)),
            }

        full_ids = list(self.text_tokenizer(final_o)['input_ids'])
        full_ids.append(self.text_tokenizer.eos_token_id)  # add end of sentence
        n_real = min(len(full_ids), self.seq_length)
        input_ids = self._pad_or_truncate(full_ids)

        label = input_ids.copy()
        # Mask padding by position, not by value: the pad value 0 is also a
        # real vocabulary id (GPT-2 '!'), which value-masking would drop.
        label[n_real:] = -100
        # Mask the prompt so the loss is computed on the answer only.
        label[:len(prompt_ids)] = -100
        # The model concatenates the image embeddings in front of the text,
        # so one ignored label is prepended per image token.
        label = np.concatenate([np.asarray(self.img_padding, dtype=label.dtype), label])

        return {
            'images': pixel_values,
            'input_ids': torch.from_numpy(input_ids),
            'labels': torch.from_numpy(label),
        }

In [None]:
class MyConfig(PretrainedConfig):
    """
    Configuration of the Med-VQA model (QA_model).

    Parameters:
        vit_model: Hugging Face id of the vision encoder
        llm_model: Hugging Face id of the GPT-2 decoder
        vit_ffn_dim: size of the vision encoder's last hidden state
        ffn_dim: width of the image projection MLP
        llm_ffn_dim: output size of the projection; must match the decoder's
            embedding size because both are concatenated
        **kwargs: forwarded to PretrainedConfig
    """
    # Plain string (a trailing comma here previously made it a tuple).
    _name_or_path = "ahmedabdelrashied/MedVQA"
    model_type = 'visual-question-answering'
    # Class-level defaults kept so that e.g. MyConfig.vit_model still resolves.
    vit_model='facebook/dinov2-base'
    llm_model='gpt2'
    vit_ffn_dim=768
    ffn_dim=1024
    llm_ffn_dim=768

    def __init__(self,
                 vit_model='facebook/dinov2-base',
                 llm_model='gpt2',
                 vit_ffn_dim=768,
                 ffn_dim=1024,
                 llm_ffn_dim=768,
                 **kwargs):
        # Stored on the instance so PretrainedConfig.to_dict / save_pretrained
        # serialize them (class attributes are not part of the instance dict).
        self.vit_model = vit_model
        self.llm_model = llm_model
        self.vit_ffn_dim = vit_ffn_dim
        self.ffn_dim = ffn_dim
        self.llm_ffn_dim = llm_ffn_dim
        super().__init__(**kwargs)

In [None]:
class QA_model(PreTrainedModel):
    """
    Med-VQA model.

    Images go through a LoRA-adapted vision encoder (``config.vit_model``). Its
    ``last_hidden_state`` is projected by a small MLP to ``config.llm_ffn_dim``
    and concatenated in front of the token embeddings of a LoRA-adapted GPT-2
    decoder (``config.llm_model``), which produces the answer.
    """
    config_class=MyConfig
    def __init__(self,
                 config
                 ):

        super().__init__(config)
        # Vision encoder, wrapped with LoRA adapters
        self.vit_model=AutoModel.from_pretrained(config.vit_model)
        self.vit_model=self.peft_vit(self.vit_model)
        # Language model, wrapped with LoRA adapters
        self.llm_model=GPT2LMHeadModel.from_pretrained(config.llm_model)
        self.llm_model=self.peft_llm(self.llm_model)
        self.vit_ffn_dim=config.vit_ffn_dim
        self.llm_ffn_dim=config.llm_ffn_dim
        # Plain int (a trailing comma here previously stored a 1-tuple).
        self.ffn_dim=config.ffn_dim
        self.relu=nn.ReLU()
        # The LLM's own token-embedding layer embeds the text part of the input
        self.text_embeddings=self.llm_model.get_input_embeddings()
        # Image -> text projection MLP (attribute names are checkpoint keys; do not rename)
        self.vit_ffn=nn.Linear(config.vit_ffn_dim,config.ffn_dim)
        self.hidden_ffn=nn.Linear(in_features=config.ffn_dim,
                           out_features=config.ffn_dim)
        self.image_to_text_ffn=nn.Linear(in_features=config.ffn_dim,
                           out_features=config.llm_ffn_dim)

        self.config=config

    def peft_llm(self,model):
        """
        Wrap the GPT-2 LLM with LoRA adapters.
        parms:
            model: GPT-2 model object that will be converted
        return:
            PEFT model
        """
        llm_lora_config = LoraConfig(
            r=32, # Rank
            lora_alpha=32,
            target_modules=['c_attn',
                            'c_proj',
                            'c_fc',
                            'qa_outputs'
                            ],
            lora_dropout=0.05,
            bias="none",
            peft_type=PeftType.LORA,
            # GPT-2 stores these projections as Conv1D with transposed weights
            fan_in_fan_out=True,
        )
        return get_peft_model(model,
                            llm_lora_config)

    def peft_vit(self,model):
        """
        Wrap the vision encoder with LoRA adapters on its attention projections.
        parms:
            model: vision model object that will be converted
        return:
            PEFT model
        """
        vit_lora_config=LoraConfig(
        r=32, # Rank
        lora_alpha=32,
        target_modules=['query',
                        'key',
                        'value'
                        ],
        lora_dropout=0.05,
        bias="none",
        peft_type=PeftType.LORA,
        )
        return get_peft_model(model,
                            vit_lora_config)

    def change_to_peft_lora(self,model,lora_config):
        """
        Wrap any model with the given LoRA configuration.
        parms:
            model: model object that will be converted
            lora_config: LoRA configuration
        return:
            PEFT model
        """
        return get_peft_model(model,
                            lora_config)

    def _encode_image(self, images):
        """
        Project vision-encoder features into the LLM embedding space.

        Shared by forward() and generate() so training and inference use the
        same graph. Returns a 3-D tensor; an un-batched (2-D) result gets a
        leading batch axis.
        """
        x=self.vit_model(images).last_hidden_state
        x=self.relu(self.vit_ffn(x))
        x=self.hidden_ffn(x)
        features=self.relu(self.image_to_text_ffn(x))
        if features.dim()==2:
            features=torch.unsqueeze(features,dim=0)
        return features

    def forward(self,
                input_ids,
                images,
                labels=None,
                ):
        """
        Run one forward step.
        parms:
            input_ids: text token ids
            images: 'pixel_values' of the images
            labels: label ids covering image + text positions (-100 = ignored)
        return:
            the LLM output (includes the loss when labels are given)
        """
        image_embeds=self._encode_image(images)
        text_features=self.text_embeddings(input_ids)
        # image tokens first, then the text prompt
        input_embedding=torch.cat((image_embeds,text_features),dim=1)
        return self.llm_model(inputs_embeds = input_embedding, labels = labels)

    def generate(self,
                input_ids,
                images,
                args=None
               ):
        """
        Generate answer tokens.

        Uses exactly the same image projection as forward(); the projection
        used to differ here by an extra ReLU after hidden_ffn.
        parms:
            input_ids: text token ids of the prompt
            images: 'pixel_values' of the images
            args: optional dict of keyword arguments for the LLM's generate()
        """
        with torch.no_grad():
            image_embeds=self._encode_image(images)
            text_features=self.text_embeddings(input_ids)
            input_embedding=torch.cat((image_embeds,text_features),dim=1)
            generate_kwargs = {} if args is None else args
            output = self.llm_model.generate(inputs_embeds = input_embedding,**generate_kwargs)
        return output

# Load the pretrained model

1. Note: this step is used after the first epoch.
2. The first epoch was trained from a freshly initialized model.

In [None]:
# Resume from the checkpoint published after the first training epoch.
qa_model = QA_model.from_pretrained(pretrained_model_name_or_path='ahmedabdelrashied/MedVQA')


config.json:   0%|          | 0.00/200 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/881M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/548 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/69.0 [00:00<?, ?B/s]

In [None]:
# Image processor for the DINOv2 encoder and tokenizer for the GPT-2 decoder.
processor, tokenizer = (
    AutoImageProcessor.from_pretrained('facebook/dinov2-base'),
    AutoTokenizer.from_pretrained("gpt2"),
)

preprocessor_config.json:   0%|          | 0.00/436 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

# Create dataset instances for training and validation

In [None]:
# Identical loader settings for both splits.
train_ds, test_ds = (
    PMC_VQA_dataset_loader(data_csv=frame,
                           image_processor=processor,
                           text_tokenizer=tokenizer,
                           H=224,
                           W=224)
    for frame in (train, test_csv)
)

In [None]:
wandb.init(mode='disabled') # disable Weights & Biases online logging



## Training

In [None]:
args = TrainingArguments(
    output_dir="MedVQA",
    seed=0,
    # optimisation schedule
    num_train_epochs=1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,  # effective train batch of 32 per device
    fp16=True,
    # evaluation, logging and checkpointing
    per_device_eval_batch_size=4,
    evaluation_strategy="epoch",
    logging_strategy='epoch',
    save_strategy="no",
    load_best_model_at_end=False,
)


In [None]:
trainer = Trainer(
    model=qa_model,
    args=args,
    compute_metrics=compute_metrics,
    train_dataset=train_ds,
    eval_dataset=test_ds,
)

In [None]:
# Train for args.num_train_epochs epochs, evaluating on test_ds at each epoch end.
trainer.train()

Epoch,Training Loss,Validation Loss


KeyboardInterrupt: 

# Model evaluation
In this section I evaluate the model on the `PMC-VQA` test set.
I use `BLEU` and `Accuracy` as the evaluation metrics.


In [None]:
args = TrainingArguments(
    output_dir="MedVQA",
    seed=0,
    # optimisation schedule (unused for evaluation-only runs)
    num_train_epochs=1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,
    fp16=True,
    # evaluation, logging and checkpointing
    per_device_eval_batch_size=4,
    evaluation_strategy="epoch",
    logging_strategy='epoch',
    save_strategy="no",
    load_best_model_at_end=False,
)


In [None]:
# Chunked evaluation of the test set; returns per-chunk metric lists.
results=compute_all_metrics()

In [None]:
# Collapse each per-chunk metric list into its mean.
results = {name: np.mean(results[name]) for name in ('eval_loss', 'eval_accuracy', 'eval_bleu_score')}
results