Text Generation using BLIP-2

In [1]:
import torch
from PIL import Image
import pandas as pd
from lavis.models import load_model_and_preprocess

# Run on the GPU when one is available, otherwise on the CPU. Building the
# torch.device from the chosen name keeps `device` the same type on both paths
# (previously it was a torch.device on GPU but a plain string on CPU).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
# Load the pre-trained BLIP-2 model ("blip2_t5" / "pretrain_flant5xxl") in
# evaluation mode on `device`, together with its image pre-processors.
# The third value returned by the loader is not needed here and is discarded.
model, vis_processors, _ = load_model_and_preprocess(
    name="blip2_t5",
    model_type="pretrain_flant5xxl",
    is_eval=True,
    device=device,
)

Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

In [2]:
# Read the diagram questions and discard the "Unnamed: 0" column stored in the CSV.
image_qa_df = pd.read_csv("../Dataset/test/DiagramQuestionsData.csv").drop(columns=["Unnamed: 0"])

# Empty placeholder column; the generation loop fills it in row by row.
image_qa_df["blip_2_generated_answers"] = None

image_qa_df.head()

Unnamed: 0,lesson_name,question_name,answer_choice_1,answer_choice_2,answer_choice_3,answer_choice_4,correct_answer,image_path,image_has_labels_to_guess,caption,blip_2_generated_answers
0,climate and its causes,Which label refers to rains?,V,T,U,E,E,../Dataset/test/abc_question_images/rain_shado...,Yes,a diagram of the water cycle,
1,climate and its causes,How does water from the clouds reach the land ...,ICE FROM THE MOUNTAIN PEAK,AS RAIN,WIND,GRASS,AS RAIN,../Dataset/test/abc_question_images/rain_shado...,Yes,a diagram of the water cycle,
2,climate and its causes,What letter represents the condensation process?,W,J,H,T,T,../Dataset/test/abc_question_images/rain_shado...,Yes,a diagram of the water cycle,
3,climate and its causes,Where can you find moist air?,W,A,H,J,J,../Dataset/test/abc_question_images/rain_shado...,Yes,a diagram of the water cycle,
4,climate and its causes,Where is condensation?,T,H,A,W,T,../Dataset/test/abc_question_images/rain_shado...,Yes,a diagram of the water cycle,


In [4]:
# Ask BLIP-2 every question that does not have a generated answer yet.
# Rows that already hold an answer are skipped, so re-running this cell after
# an interruption continues where it stopped instead of starting over.
#
# Iterate over the frame's own index labels: `.loc` is label-based, so the
# positional counter from `enumerate` only matches when the index is exactly
# 0..n-1.
for idx in image_qa_df.index:
    if not pd.isna(image_qa_df.loc[idx, "blip_2_generated_answers"]):
        continue

    question = image_qa_df.loc[idx, "question_name"]

    # Build answer choice string
    answer_choice_1 = image_qa_df.loc[idx, "answer_choice_1"]
    answer_choice_2 = image_qa_df.loc[idx, "answer_choice_2"]
    answer_choice_3 = image_qa_df.loc[idx, "answer_choice_3"]
    answer_choice_4 = image_qa_df.loc[idx, "answer_choice_4"]
    answer_string = f"\na. {answer_choice_1}\nb. {answer_choice_2}\nc. {answer_choice_3}\nd. {answer_choice_4}"

    # Prepare the image. The context manager releases the file handle as soon
    # as the pixel data has been converted; `convert` returns a separate image
    # object, so `raw_image` stays usable after the file is closed.
    with Image.open(image_qa_df.loc[idx, "image_path"]) as image_file:
        raw_image = image_file.convert("RGB")
    image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)

    # Build prompt
    prompt = f"\nQuestion: {question} Choose only one option.\n{answer_string}\n\nAnswer:"
    image_qa_df.loc[idx, "blip_2_generated_answers"] = model.generate({"image": image, "prompt": prompt})[0]

    # Continuously store answers in a CSV file in case of a session timeout
    image_qa_df.to_csv("blip2_answers.csv", index=False)

In [4]:
# Option letters used in the prompt, mapped to the column holding that choice.
_LABEL_TO_CHOICE_COLUMN = {
    "a": "answer_choice_1",
    "b": "answer_choice_2",
    "c": "answer_choice_3",
    "d": "answer_choice_4",
}


def _map_predicted_to_actual(row):
    """Translate a generated option label into the text of the chosen answer.

    Accepts the label on its own or wrapped in whitespace, parentheses or a
    trailing period ("a", "a.", "(a)", " a "), uniformly for all four options.
    Matching is deliberately case-sensitive: answer choices can themselves be
    capital letters (e.g. "A"), so an upper-case reply is not assumed to be an
    option label.

    Returns None when the generated answer is missing or is not a recognised label.
    """
    raw_answer = row['blip_2_generated_answers']
    if not isinstance(raw_answer, str):
        # Missing answers are read back from the CSV as NaN (a float).
        return None
    label = raw_answer.strip("(). \t\r\n")
    choice_column = _LABEL_TO_CHOICE_COLUMN.get(label)
    return None if choice_column is None else row[choice_column]


def get_predictions(filename):
    """Print how many generated answers in ``filename`` match the correct answer.

    The CSV must contain the columns ``blip_2_generated_answers``,
    ``answer_choice_1`` .. ``answer_choice_4`` and ``correct_answer``. Each
    generated option label is mapped to the text of that choice and compared
    case-insensitively with ``correct_answer``; answers that cannot be mapped
    count as incorrect.

    Prints the number of correct predictions and the accuracy in percent.
    A file without any rows is reported as 0 of 0 with an accuracy of 0.0
    instead of raising ZeroDivisionError. Returns None.
    """
    blip2_answers = pd.read_csv(filename)
    total = len(blip2_answers)

    if total == 0:
        print(f"Number of correct predictions: 0 out of a total of {total} questions.")
        print("Accuracy: 0.0")
        return

    blip2_answers['mapped_predictions'] = blip2_answers.apply(_map_predicted_to_actual, axis=1)
    correct_blip2_preds = blip2_answers[blip2_answers["mapped_predictions"].str.lower() == blip2_answers["correct_answer"].str.lower()]
    print(f"Number of correct predictions: {len(correct_blip2_preds)} out of a total of {total} questions.")
    print(f"Accuracy: {len(correct_blip2_preds) / total * 100}")

In [5]:
# Calculate accuracy for predictions where image descriptions are not provided.
# "blip2_answers.csv" is the file written by the generation loop, whose prompt
# contains only the question and its four answer choices.
get_predictions("blip2_answers.csv")

Number of correct predictions: 1805 out of a total of 3285 questions.
Accuracy: 54.946727549467276


In [6]:
# Calculate accuracy for predictions on a subset of images where image descriptions are provided.
# NOTE(review): "common_blip2_llava_gpt4_preds.csv" is not written by any cell shown here;
# presumably it is prepared elsewhere with the same columns get_predictions reads — confirm.
get_predictions("common_blip2_llava_gpt4_preds.csv")

Number of correct predictions: 212 out of a total of 403 questions.
Accuracy: 52.605459057071954
