<a href="https://colab.research.google.com/github/francisco-renteria-rios/VQA_Assistance/blob/main/VQA_Grocery_Assistant_App.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# VQA Grocery Assistant

This notebook implements an app that reads a labeled image, asks a question about it, and returns the predicted answer.
<br>The app uses a strong pre-trained model (BLIP-2) to answer questions about grocery labels, but it first extracts the label text with EasyOCR and includes that text in the model's prompt.
This is intended to improve accuracy when reading small, dense text such as ingredient lists.

## Install Dependencies
We need `transformers` for the VQA model (BLIP-2), `accelerate` for faster, lower-memory model loading, `easyocr` for text extraction, and `datasets` to load the VizWiz data. The cell also reads a Hugging Face access token stored in Colab's secrets as `HF_TOKEN`.

In [None]:
# Install necessary libraries
# EasyOCR downloads its language models on first use.
!pip install -q transformers datasets accelerate torch torchvision pillow tqdm easyocr
import os
import random

import easyocr
import torch
from datasets import load_dataset
from google.colab import userdata
from PIL import Image
from tqdm.auto import tqdm
from transformers import Blip2Processor, Blip2ForConditionalGeneration

# Load the 'HF_TOKEN' secret (added via Colab's Secrets panel) and store it in an
# environment variable, which the Hugging Face libraries read automatically
os.environ["HF_TOKEN"] = userdata.get('HF_TOKEN')

## Setup and Model Loading
We will load the pre-trained BLIP-2 model and the EasyOCR reader.

In [None]:
import torch
import easyocr
from transformers import Blip2Processor, Blip2ForConditionalGeneration

MODEL_ID = "Salesforce/blip2-opt-2.7b"

device = "cuda" if torch.cuda.is_available() else "cpu"
# float16 roughly halves GPU memory; CPU inference should use float32
dtype = torch.float16 if device == "cuda" else torch.float32
print(f"Using device: {device} ({dtype})")

# Load BLIP-2 model and processor
processor = Blip2Processor.from_pretrained(MODEL_ID)
try:
    model = Blip2ForConditionalGeneration.from_pretrained(MODEL_ID, torch_dtype=dtype)
    model.to(device)
except Exception as e:  # e.g. CUDA out of memory
    print(f"Error loading BLIP-2 on {device}. Falling back to CPU/float32. Error: {e}")
    # Actually switch to CPU and full precision for the fallback
    device, dtype = "cpu", torch.float32
    model = Blip2ForConditionalGeneration.from_pretrained(MODEL_ID, torch_dtype=dtype)
    model.to(device)
model.eval()

# Load EasyOCR Reader (English, common for grocery labels); use the GPU only if the model does
ocr_reader = easyocr.Reader(['en'], gpu=(device == "cuda"))
print("Models and OCR reader loaded successfully!")
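
The cell above loads BLIP-2 in float16 on the GPU to reduce memory use. The reason is simple arithmetic: the weights alone take about (number of parameters) × (bytes per parameter), so halving the bytes per parameter halves the weight memory. The helper below is a hypothetical sketch of that estimate; it ignores activations, the generation cache, and the EasyOCR models, so real memory use will be higher.

```python
# Bytes per parameter for a few common dtypes
BYTES_PER_PARAM = {"float32": 4, "float16": 2, "bfloat16": 2, "int8": 1}


def estimate_weight_memory_gib(num_params: int, dtype_name: str) -> float:
    """Hypothetical helper: approximate size of the model weights alone, in GiB."""
    return num_params * BYTES_PER_PARAM[dtype_name] / 1024**3


# Example: a hypothetical 1-billion-parameter model in full and half precision
for name in ("float32", "float16"):
    print(f"{name}: {estimate_weight_memory_gib(1_000_000_000, name):.2f} GiB")
```

For the loaded model, `estimate_weight_memory_gib(model.num_parameters(), str(model.dtype).removeprefix("torch."))` applies the same estimate to the actual parameter count.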

## OCR + Image Pipeline
These Python functions implement the key logic: extract the label text with OCR, then build an OCR-augmented prompt for BLIP-2.

**Prompt Structure:**
<br>`Question: [User Question]? Extracted Text: [OCR Text] Answer:`
<br>Supplying the extracted text directly as context is meant to help the VQA model with small, dense text that it may not read reliably from the image alone.
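
The prompt structure above is plain string formatting. The sketch below builds the same string with two safeguards that are our own additions, not part of the original pipeline: it avoids a doubled `??` when the user's question already ends with a question mark, and it collapses whitespace and caps the OCR text at a hypothetical `max_ocr_chars` limit so that very long OCR output keeps the prompt at a bounded length.

```python
def build_prompt(question: str, ocr_text: str, max_ocr_chars: int = 1000) -> str:
    """Hypothetical helper: build the OCR-augmented prompt described above."""
    # Drop any trailing "?" so the template's own "?" is not doubled
    question = question.strip().rstrip("?")
    # Collapse newlines and repeated spaces from the OCR output
    ocr_text = " ".join(ocr_text.split())
    # Cap the OCR text, cutting back to the last full word (assumed limit)
    if len(ocr_text) > max_ocr_chars:
        ocr_text = ocr_text[:max_ocr_chars].rsplit(" ", 1)[0]
    return f"Question: {question}? Extracted Text: {ocr_text} Answer:"


print(build_prompt("What brand is this?", "ACME\n  Rolled   Oats"))
```

The cap is a design choice rather than a model requirement; a larger or smaller limit may work better depending on how long typical labels are.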

In [None]:
import numpy as np

def get_ocr_text(image):
    """Extracts all text from a PIL Image using EasyOCR and returns a single string."""
    # EasyOCR expects a file path, bytes, or a numpy array, so convert the PIL Image to an RGB array
    image_np = np.array(image.convert("RGB"))
    # readtext returns a list of (bbox, text, confidence) tuples
    results = ocr_reader.readtext(image_np)
    # Concatenate all detected text into one string
    return " ".join(text for (_bbox, text, _conf) in results)

def hybrid_vqa(image, question):
    """Performs VQA using the image and an OCR-augmented prompt."""
    # 1. OCR extraction
    ocr_output = get_ocr_text(image)
    # 2. Hybrid prompt construction (strip any trailing "?" so it is not doubled)
    prompt = f"Question: {question.strip().rstrip('?')}? Extracted Text: {ocr_output} Answer:"
    # 3. Model inference: match the model's device and dtype (float16 on GPU, float32 on CPU)
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device, model.dtype)
    generated_ids = model.generate(**inputs, max_new_tokens=50, num_beams=5, early_stopping=True)
    # 4. Decode the answer; depending on the transformers version the output may echo
    #    the prompt, so keep only the text after the last "Answer:"
    answer = processor.decode(generated_ids[0], skip_special_tokens=True)
    answer = answer.split("Answer:")[-1].strip()
    return answer, ocr_output
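
`get_ocr_text` joins every detection in the order EasyOCR returns it and keeps low-confidence fragments. A possible refinement, sketched below under our own assumptions, drops detections under a hypothetical confidence threshold `min_conf` and sorts the rest into approximate reading order (by line, then left to right) using each box's top-left corner. It assumes EasyOCR's result format of `(bbox, text, confidence)` tuples, with `bbox` given as four `[x, y]` corner points starting at the top-left; the `line_height` bucketing is a crude heuristic, not an EasyOCR feature.

```python
def join_ocr_results(results, min_conf=0.3, line_height=20):
    """Hypothetical helper: join EasyOCR (bbox, text, confidence) tuples into one string.

    Detections below min_conf are dropped; the rest are ordered by a coarse line
    index (top-left y // line_height) and then by top-left x.
    """
    kept = [(bbox, text) for bbox, text, conf in results if conf >= min_conf]
    # bbox[0] is assumed to be the top-left corner as [x, y]
    kept.sort(key=lambda item: (int(item[0][0][1] // line_height), item[0][0][0]))
    return " ".join(text for _, text in kept)


# Hand-made detections in EasyOCR's result format
fake_results = [
    ([[60, 12], [120, 12], [120, 30], [60, 30]], "sugar,", 0.91),
    ([[0, 10], [55, 10], [55, 30], [0, 30]], "Ingredients:", 0.97),
    ([[0, 45], [30, 45], [30, 60], [0, 60]], "#@!", 0.05),
    ([[0, 47], [40, 47], [40, 62], [0, 62]], "salt", 0.88),
]
print(join_ocr_results(fake_results))
```

To try it in the pipeline, the last line of `get_ocr_text` could return `join_ocr_results(results)` instead. Fixed-size line buckets can split a text line whose boxes straddle a bucket boundary, so `line_height` may need tuning per image resolution.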

Download the first 100 entries of the VizWiz VQA validation split from Hugging Face and save their metadata to a CSV file.

In [None]:
# @title Python Code to Load the First 100 VizWiz Entries (Hugging Face)
!pip install -q datasets pandas
import pandas as pd
from datasets import load_dataset

# Define the dataset name and split
HF_DATASET_NAME = "lmms-lab/VizWiz-VQA"  # Example dataset ID; other VizWiz uploads use different IDs and split names
SPLIT_NAME = "val"
SUBSET_SIZE = 100

print(f"Loading first {SUBSET_SIZE} entries from {HF_DATASET_NAME} ({SPLIT_NAME} split)...")

try:
    # Select the first 100 rows of the 'val' split. Slicing is applied after the split's
    # files are downloaded, so this still fetches (and caches) the whole split once.
    full_dataset_split = load_dataset(HF_DATASET_NAME, split=f"{SPLIT_NAME}[:{SUBSET_SIZE}]")

    print(f"Successfully loaded {len(full_dataset_split)} entries.")

    # Keep only the metadata: drop the image column before converting to pandas
    image_cols = [c for c in full_dataset_split.column_names if c == "image"]
    df_subset = full_dataset_split.remove_columns(image_cols).to_pandas()

    # Save the metadata to a CSV file in Colab
    output_filename_csv = "vizwiz_metadata_subset_100.csv"
    df_subset.to_csv(output_filename_csv, index=False)
    print(f"Metadata saved to '{output_filename_csv}'.")

except Exception as e:
    print(f"\nAn error occurred: {e}")
    print("Please check the exact dataset ID and split names on Hugging Face, as variations exist.")


## Load small VizWiz Dataset Samples
To speed up testing and avoid downloading the entire VizWiz validation set, we will use the `streaming=True` feature and then take a small random sample of the first few hundred entries. This is much faster than downloading the whole split.
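
The approach described here keeps all `MAX_INITIAL_FETCH` streamed samples, including their decoded images, in memory before picking 10. Reservoir sampling (Algorithm R) is a standard alternative that draws a uniform random sample of `k` items from a stream in a single pass while holding only `k` items. The function below is a generic sketch; `reservoir_sample` is a hypothetical name, not part of the `datasets` library.

```python
import random
from typing import Iterable, List, TypeVar

T = TypeVar("T")


def reservoir_sample(stream: Iterable[T], k: int, seed: int = 42) -> List[T]:
    """Uniformly sample k items from a stream of unknown length (Algorithm R)."""
    rng = random.Random(seed)  # private RNG so the global random state is untouched
    reservoir: List[T] = []
    for i, item in enumerate(stream):
        if i < k:
            # Fill the reservoir with the first k items
            reservoir.append(item)
        else:
            # Keep item i with probability k / (i + 1), replacing a random slot
            j = rng.randint(0, i)  # inclusive on both ends
            if j < k:
                reservoir[j] = item
    return reservoir


print(reservoir_sample(range(1000), 5))
```

With the streamed split, `reservoir_sample(vizwiz_stream.take(MAX_INITIAL_FETCH), NUM_SAMPLES)` could replace steps 2 and 3 of the sampling code. Every fetched sample is still downloaded; only the memory footprint shrinks.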

In [None]:
import random
from datasets import load_dataset

NUM_SAMPLES = 10         # number of examples to test
MAX_INITIAL_FETCH = 500  # fetch up to 500 initial samples to choose 10 random ones
random.seed(42)          # for reproducibility

print(f"Downloading a small, random sample of {NUM_SAMPLES} images from VizWiz...")
# 1. Load the validation split in streaming mode. The VQA version of VizWiz provides
#    'question' and 'answers' fields; the captioning version (VizWiz-Caps) does not.
vizwiz_stream = load_dataset("lmms-lab/VizWiz-VQA", split="val", streaming=True)
# 2. Take the first MAX_INITIAL_FETCH samples and convert them to a list
initial_samples = list(vizwiz_stream.take(MAX_INITIAL_FETCH))
# 3. Randomly select the desired number of samples from the fetched list
if len(initial_samples) < NUM_SAMPLES:
    # Fallback if the initial fetch was too small (unlikely for VizWiz)
    test_samples = initial_samples
    print(f"Warning: Only found {len(initial_samples)} samples.")
else:
    test_samples = random.sample(initial_samples, NUM_SAMPLES)
    print(f"Successfully loaded and sampled {len(test_samples)} examples from the streamed data.")

## Automated Testing Loop
This loop runs the `hybrid_vqa` function on the 10 selected VizWiz samples and prints the results, showing the predicted answer against the ground truth answer (the most frequent answer provided by human annotators).

In [None]:
from collections import Counter
from tqdm.auto import tqdm

def most_common_answer(answers):
    """Returns the most frequent human answer for a VizWiz question."""
    # Depending on the dataset version, answers are plain strings or
    # {'answer': ..., 'answer_confidence': ...} dicts
    texts = [a["answer"] if isinstance(a, dict) else a for a in (answers or [])]
    return Counter(texts).most_common(1)[0][0] if texts else "(no ground truth)"

print("\n" + "*"*60)
print(f"STARTING VQA TEST ON {len(test_samples)} VIZWIZ SAMPLES")
print("*"*60)
# Inspect the fields available in a sample
print(test_samples[0].keys())
for i, sample in tqdm(enumerate(test_samples), total=len(test_samples)):
    # Get the image (PIL format) and question
    image = sample['image'].convert('RGB')
    question = sample['question']
    # The ground truth is the most frequent answer from the human annotators
    ground_truth = most_common_answer(sample.get('answers'))
    try:
        # Run the hybrid VQA pipeline
        predicted_answer, ocr_output = hybrid_vqa(image, question)
        # Print results
        print("\n" + "-"*50)
        print(f"TEST {i+1}:")
        print(f"  Question: {question}")
        print(f"  Ground Truth: {ground_truth}")
        print(f"  OCR Output (snippet): {ocr_output[:70]}...")
        print(f"  PREDICTED: {predicted_answer}")
        print("-"*50)
    except Exception as e:
        print(f"An error occurred during VQA for sample {i+1}: {e}")
        continue
print("\n" + "*"*60)
print("TESTING COMPLETE")
print("*"*60)
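
The loop prints predictions next to the ground truth but does not score them. A common way to score VQA predictions against the 10 VizWiz annotations is the VQA accuracy rule, under which an answer counts as fully correct if at least 3 annotators gave it: accuracy = min(#matching human answers / 3, 1). The sketch below implements a simplified version with light normalization (lowercasing, removing punctuation and articles); the official evaluation additionally averages over annotator subsets and normalizes answers more thoroughly, so its numbers will differ somewhat.

```python
import re
import string


def normalize_answer(text: str) -> str:
    """Light normalization: lowercase, drop punctuation and articles, squeeze spaces."""
    text = text.lower().translate(str.maketrans("", "", string.punctuation))
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def vqa_accuracy(predicted: str, human_answers: list) -> float:
    """Simplified VQA accuracy: min(#annotators who gave the predicted answer / 3, 1)."""
    pred = normalize_answer(predicted)
    matches = sum(normalize_answer(a) == pred for a in human_answers)
    return min(matches / 3, 1.0)


print(vqa_accuracy("Corn Flakes", ["corn flakes"] * 4 + ["cereal"] * 6))
```

Inside the loop, `vqa_accuracy(predicted_answer, [a["answer"] if isinstance(a, dict) else a for a in (sample.get("answers") or [])])` would give a per-sample score that can be averaged. Because BLIP-2 generates free-form text, exact matching can undercount answers that are correct but phrased differently.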

## Conclusion
This notebook demonstrates how to build a specialized VQA system for label reading by incorporating text extraction (OCR) into prompt construction. Supplying the OCR text as context is intended to help on detail-oriented tasks like reading ingredient lists or nutrition facts, where the image alone may not be enough.
The automated testing loop lets you quickly spot-check the model's answers against human annotations on real-world images from the VizWiz dataset.