In [None]:
'''
Get the API key: Colab "Secrets" panel (left sidebar, 4th item from the top) -> Gemini API keys,
and store it as the secret/env variable GOOGLE_API_KEY.
The Gemini API is unavailable in HK, so use a VPN there.
'''
# Query ipinfo.io for this runtime's public IP information (handy for checking
# which region the requests will come from).
!curl https://ipinfo.io/

In [None]:
%%capture
%%bash
# Install the Python packages used by the rest of the notebook.

# llama-index integrations: Gemini (multi-modal LLM, embeddings, LLM) and the Qdrant vector store
pip install llama-index-multi-modal-llms-gemini
pip install llama-index-vector-stores-qdrant
pip install llama-index-embeddings-gemini
pip install llama-index-llms-gemini

# llama-index itself, the Google generative-AI SDK, plotting and the Qdrant client
pip install llama-index 'google-generativeai>=0.3.0' matplotlib qdrant_client

# Hugging Face datasets (plus its "vision" extra)
pip install -U datasets datasets[vision]

# image processing
pip install Pillow #==11.2.1

# ColBERT prerequisites (original note: tpu/gpu); faiss-cpu is the CPU build of faiss
pip install gitpython
pip install faiss-cpu

# misc utilities: pandas, gdown (used later to download the image corpus), shortuuid (answer ids)
pip install pandas gdown shortuuid

In [None]:
%%capture

'''
setup colbert / plaidrepro
'''

# Start from a clean checkout.
# NOTE(review): after this `rm -rf` the `git pull` on the next line has nothing to
# pull from, so the `|| git clone` fallback is what actually fetches the repo.
!rm -rf ./ColBERT ./plaidrepro
!git -C ColBERT/ pull || git clone https://github.com/stanford-futuredata/ColBERT.git

#some error using plaidrepro
#!git -C ColBERT/ pull || git clone https://github.com/seanmacavaney/plaidrepro.git

#!mv ./plaidrepro ./ColBERT
# Make the cloned repository importable as `colbert`.
import sys; sys.path.insert(0, 'ColBERT/')

try: # When on google Colab, let's install all dependencies with pip.
    import google.colab
    !pip install -U pip
    !pip install -e ColBERT/['faiss-gpu','torch']
except Exception:
  # Reached when importing google.colab fails, i.e. outside Colab: only verify
  # that ColBERT can be imported and abort the cell otherwise.
  import sys; sys.path.insert(0, 'ColBERT/')
  try:
    from colbert import Indexer, Searcher
  except Exception:
    # NOTE(review): this cell starts with %%capture, so this hint is probably
    # swallowed and only the AssertionError below is visible — confirm.
    print("If you're running outside Colab, please make sure you install ColBERT in conda following the instructions in our README. You can also install (as above) with pip but it may install slower or less stable faiss or torch dependencies. Conda is recommended.")
    assert False

# Extra runtime dependencies, installed regardless of the branch taken above.
!pip install bitarray datasets gitpython ninja scipy spacy tqdm transformers ujson flask python-dotenv

In [None]:
'''
setup api keys
'''

# Read the key from Colab's user secrets and export it as the environment
# variable GOOGLE_API_KEY (presumably where the Gemini client looks for it — confirm).
import os
from google.colab import userdata
GOOGLE_API_KEY = userdata.get('GOOGLE_API_KEY')
os.environ["GOOGLE_API_KEY"] = GOOGLE_API_KEY


from llama_index.llms.gemini import Gemini
from llama_index.core.llms import ChatMessage

# Smoke test: a single chat round trip to confirm that the key and the model name work.
gemini = Gemini(model_name="models/gemma-3-27b-it")
response = gemini.chat(messages=[
    ChatMessage(role="user", content="Hello! What's your name?")
])
print(response.message.content)

In [None]:
%%capture
'''
load benchmark data
'''

# Load the "test" split of the uclanlp/MRAG-Bench dataset; later cells read the
# 'image', 'question', 'A'..'D', 'answer', ... fields of its rows.
from datasets import load_dataset
mrag_bench = load_dataset("uclanlp/MRAG-Bench", split="test")

In [None]:
'''
Wrapper for Gemini, with rate limiting

step-1, stage-1: LVM to describe images
Reference: https://docs.llamaindex.ai/en/stable/examples/multi_modal/gemini/
'''



from PIL import Image
import PIL
from llama_index.llms.gemini import Gemini
from llama_index.core.llms import ChatMessage, ImageBlock
from google.genai.types import HttpOptions

class LVM():
  '''
  Rate-limited wrapper around a llama-index `Gemini` chat model for prompts with images.

  model_name: model identifier passed to `Gemini(model_name=...)`
  rate_limit: minimum time (in seconds) to wait before every API call
  max_retries: how many times a failed request is retried before its last
    exception is re-raised; None (the default) keeps retrying forever
  '''
  def __init__(self, model_name, rate_limit=10, max_retries=None):
    self._llm = Gemini(model_name=model_name)
    self._rate_limit = rate_limit   # Stores the minimum time (in seconds) to wait between API calls
    self._max_retries = max_retries

  # ------------------ Private Methods ------------------

  def _inference(self, messages):
    '''
    Send `messages` to the model and return the text content of the reply.

    Every attempt is preceded by a sleep of `rate_limit` seconds plus 10 more
    seconds per earlier failure of this request. Failures are printed and the
    request is retried; if `max_retries` was given and is exceeded, the last
    exception is re-raised instead of looping forever.
    '''
    import time

    max_retries = getattr(self, '_max_retries', None)
    fail_count = 0
    while True:
      try:
        # The sleep doubles as rate limiter (first attempt) and linear back-off (retries).
        time.sleep(self._rate_limit + 10*fail_count)
        return self._llm.chat(messages=messages).message.content
      except Exception as e:
        fail_count += 1
        print(f'fail[{fail_count}]: {e}')
        if max_retries is not None and fail_count > max_retries:
          raise

  def _attatch_multiple_images(self, images:list[ImageBlock], message: ChatMessage):
    '''Append every image block to the single `message` and return that message.'''
    for img in images:
      message.blocks.append(img)
    return message

  def _attatch(self, images:list[ImageBlock], messages: list[ChatMessage]):
    '''Append the i-th image block to the i-th message (pairs are formed with zip) and return the messages.'''
    for img, msg in zip(images, messages):
      msg.blocks.append(img)
    return messages

  def _dataset_image_to_imageblock(self, images):
    '''Encode each image as PNG bytes in memory and wrap the bytes in an `ImageBlock`.'''
    from io import BytesIO

    imgs = []
    for img in images:
      buf = BytesIO()
      img.save(buf, format="PNG")
      imgs.append(ImageBlock(image=buf.getvalue(), image_mimetype="image/png"))

    return imgs

  # ------------------ Public Methods ------------------

  def images_to_text(self, images: list[PIL.Image.Image],
      prompt="What is inside the pictures? Give short keywords description in 15 words, separated by comma. Response directly, no need to show you understand this prompt."
  ):
    '''
    Describe each image separately: one request (prompt + one image) per image.

    Returns the list of reply texts, in the same order as `images`.
    '''
    from tqdm import tqdm

    images=self._dataset_image_to_imageblock(images)

    # A fresh message per image, so no request carries another image's block.
    messages = [ChatMessage(prompt) for _ in range(len(images))]
    messages = self._attatch(images,messages)

    return [
      self._inference(messages=[msg])
      for msg in tqdm(messages)
    ]

  def multi_images_chat(self, images: list[PIL.Image.Image], prompt: str):
    '''Send `prompt` together with all `images` in one request and return the reply text.'''
    images=self._dataset_image_to_imageblock(images)

    message = ChatMessage(prompt)
    message = self._attatch_multiple_images(images, message)
    return self._inference(messages=[message])


#-------------------------------[example usage]-----------------------------
#-------------------------------<images_to_text>---------------------------
# For model selection reference: https://ai.google.dev/gemini-api/docs/rate-limits#free-tier
#lvm = LVM(model_name="models/gemini-1.5-flash")
# rate_limit=5: wait at least 5 seconds before every request.
lvm = LVM(model_name="models/gemini-2.0-flash-lite", rate_limit=5)

# Describe the last two benchmark images, one request per image.
images = mrag_bench[-2:]['image']
messages = lvm.images_to_text(images=images)

display(messages)
print(f'length: {len(messages)}')

#-------------------------------[example usage]-----------------------------
#-------------------------------<multi_images_chat>---------------------------
# Send the same two images together in a single request.
images = mrag_bench[-2:]['image']
for img in images:
  display(img)
message = lvm.multi_images_chat(images=images,
  prompt="What is inside each photo? What are the connections between all photos? Answer in short."
)

print(message)

In [None]:
'''
for storing/loading (image_filename, description) pairs in json format
'''
class ImageDescriptionStorage():
  '''
  In-memory list of records (e.g. {'filename': ..., 'description': ...}) backed by a JSON file.

  path: JSON file read by load() and written by save()
  data: optional initial list of records; when omitted every instance gets its
    own new empty list
  '''
  def __init__(self, path='./data.json', data=None):
    self._path = path
    # A `data=[]` default would be one list shared by every instance created
    # without `data`; build a fresh list per instance instead.
    self._data = [] if data is None else data

  def load(self):
    '''Replace the in-memory records with the file's content (empty list if the file is missing) and return them.'''
    import json

    try:
      with open(self._path, 'r') as f:
        self._data = json.load(f)
    except FileNotFoundError:
      self._data = [] # Initialize with empty list if file not found
    return self._data

  def append(self, dict_obj):
    '''Add one record to the in-memory list; nothing is written until save().'''
    self._data.append(dict_obj)

  def save(self):
    '''
    Write all records to `path` as indented JSON.

    The JSON is written to a temporary sibling file first and then moved over
    `path`, so an interruption while writing cannot leave a truncated file
    behind (the file is what an interrupted run resumes from).
    '''
    import json
    import os

    tmp_path = f'{self._path}.tmp'
    try:
      with open(tmp_path, 'w') as f:
        json.dump(self._data, f, indent=4)
      os.replace(tmp_path, self._path)
    except BaseException:
      # Do not leave a half-written temp file around; the previous file is untouched.
      if os.path.exists(tmp_path):
        os.remove(tmp_path)
      raise

  def get_data(self):
    '''Return the in-memory list of records (the live list, not a copy).'''
    return self._data

#------------------------[Example usage]---------------------------
# Round trip: append two records, save them to disk, then read them back.
test_data=[
    {
        'filename': 'cat.jpg',
        'description': 'A cat is dancing.'
    },
    {
        'filename': 'dog.jpg',
        'description': 'A dog is swimming.'
    }
]

image_description_storage = ImageDescriptionStorage(path='./test_images_description.json')
for d in test_data:
  image_description_storage.append(d)
image_description_storage.save()

# load() re-reads the file that was just written.
retrieve=image_description_storage.load()
display(retrieve)

In [None]:
%%bash
# Download the pre-trained ColBERTv2 checkpoint into ./colbertv
# (the directory is deleted and re-created on every run).

rm -rf colbertv
mkdir -p colbertv
cd colbertv
#wget -qq https://downloads.cs.stanford.edu/nlp/data/colbert/colbertv2/colbertv2.0.tar.gz -O colbertv2.0.tar.gz
# Stream the archive to stdout and extract it on the fly, so no .tar.gz is kept on disk.
wget -qq -O- https://downloads.cs.stanford.edu/nlp/data/colbert/colbertv2/colbertv2.0.tar.gz | tar xvz


In [None]:
'''
step-1b: RAG system: Plaid
'''

class RAG():
  '''
  Text retrieval on top of a ColBERT index.

  The search result used here only carries passage ids, so the indexed texts
  and their annotations are mirrored in a JSON side file and looked up by id.

    root_dir: for Colbertv working dir
    storage_path: JSON file storing the input data array (missing from Colbertv)
      together with the annotations
  '''
  def __init__(self, root_dir='./experiments', storage_path='./rag_data.pt'):
    import os

    self._root_dir = root_dir
    self._storage_path = storage_path
    # Searcher built lazily by query() and reused by later queries.
    self._searcher = None

    if os.path.exists(self._storage_path):
        print(f"Loading data from {self._storage_path}")
        self._load_storage()
    else:
        print(f"Data file {self._storage_path} not found. Starting with empty data.")
        self._storage = {
            #image descriptions
            'data': [],
            #filenames
            'annotation': []
        }

  def _save_storage(self):
    '''Write the storage dict to `storage_path` as indented JSON.'''
    import json
    with open(self._storage_path, 'w') as f:
      json.dump(self._storage, f, indent=4)

  def _load_storage(self):
    '''Read the storage dict back from `storage_path`.'''
    import json
    with open(self._storage_path, 'r') as f:
      self._storage = json.load(f)

  def index(self, data: list[str], annotation: list[str], kmeans_niters=4,
    checkpoint_dir="./colbertv/colbertv2.0"
  ):
    '''
    Build (overwrite) the ColBERT index from `data` and persist data/annotation.

    data: texts to index; position i is the passage id used at query time
    annotation: annotation[i] belongs to data[i] (here: the image filename)
    kmeans_niters: forwarded to ColBERTConfig
    checkpoint_dir: ColBERT checkpoint handed to the Indexer

    Raises ValueError when `data` and `annotation` differ in length.

    Warning:
      data list should be at least 100 to make it work, otherwise it stucks
    '''
    # Checked before any heavy work: a mismatch would otherwise only surface
    # later as a wrong or missing annotation in query().
    if len(data) != len(annotation):
      raise ValueError(
        f"data and annotation must have the same length, got {len(data)} and {len(annotation)}"
      )

    from colbert.infra import Run, RunConfig, ColBERTConfig
    from colbert import Indexer

    with Run().context(RunConfig(nranks=1, experiment="msmarco")):
        config = ColBERTConfig(
            nbits=2,
            root=self._root_dir,
            kmeans_niters=kmeans_niters,
            avoid_fork_if_possible=True
        )
        indexer = Indexer(checkpoint=checkpoint_dir, config=config)
        indexer.index(name="msmarco.nbits.2", collection=data, overwrite=True)

    # The index on disk changed: drop the cached searcher so the next query reloads it.
    self._searcher = None

    self._storage['data'] = data
    self._storage['annotation'] = annotation
    self._save_storage()

    print('overwritten storage file!')

  def query(self, query: str, top_k=5):
    '''
    Search the index and return the `top_k` best passages.

    Returns a dict with
      'index': passage ids as returned by the searcher,
      'data': the stored texts for those ids,
      'annotation': the stored annotations for those ids.

    The Searcher is created on the first call and reused afterwards. It is
    cached per instance: an index rebuilt through a different RAG object is only
    picked up by a new RAG instance (which also reloads the storage file).
    '''
    from colbert.infra import Run, RunConfig, ColBERTConfig
    from colbert import Searcher

    with Run().context(RunConfig(nranks=1, experiment="msmarco")):
      if getattr(self, '_searcher', None) is None:
        config = ColBERTConfig(
            root=self._root_dir,
        )
        self._searcher = Searcher(index="msmarco.nbits.2", config=config)
      ranking = self._searcher.search(query, k=top_k)

    passage_ids = ranking[0]
    return {
      'index': passage_ids,
      'data': [ self._storage['data'][passage_id] for passage_id in passage_ids ],
      'annotation': [ self._storage['annotation'][passage_id] for passage_id in passage_ids ]
    }

#---------------------------[Example usage]----------------------------------

# 100 synthetic passages (the RAG.index warning asks for at least 100 items),
# each annotated with a matching fake filename.
test_data=[
  f'Person_{i} has {i+3} cats!\n'
    for i in range(100)
]
test_annotation=[
  f'image_{i}.png'
    for i in range(len(test_data))
]

print('--------------------------use new rag------------------------------------')

# Remove any previous storage file so RAG starts with empty data.
!rm -f ./test_rag_data.json

rag = RAG(
  storage_path='./test_rag_data.json'
)

query='Who has 11 cats?'

rag.index(data=test_data, annotation=test_annotation)
answer = rag.query(query=query)
print(f'query: {query}')
print(f'top-k answer: {answer}')
print(f'best answer: {answer["data"][0]} @{answer["annotation"][0]}')

print('--------------------------use old rag------------------------------------')
# A second instance on the same storage file: it loads the data saved by
# index() above and queries the existing index without re-indexing.
rag = RAG(
  storage_path='./test_rag_data.json'
)

query='Who has 3 cats?'

answer = rag.query(query=query)
print(f'query: {query}')
print(f'top-k answer: {answer}')
print(f'best answer: {answer["data"][0]} @{answer["annotation"][0]}')

In [None]:
########################[BELOW TODO]################################

In [None]:
%%capture
%%bash
# Download the dataset that is fed into the database of the RAG system.
# Remove leftovers of earlier runs first.
rm -rf ./image_corpus ./mrag_bench_image_corpus

# gdown downloads the Google Drive file with this id; the archive is then unpacked
# (presumably into ./image_corpus, which later cells read — confirm).
gdown 1atwkNXH3aEtCLuqimZoB1Mifj5CwL3CL
unzip mrag_bench_image_corpus

In [None]:
%%bash

# Copy a small sample of image_corpus into test_images_database for development/debugging.
# The full corpus may be wanted when deploying.

# Reset the sample directory and the test outputs derived from it.
rm -rf test_images_database
mkdir test_images_database
rm -f test_images_description.json test_rag_data.json

# Pick 10 entries whose names do not contain "input". shuf is fed a fixed byte
# stream as its random source, presumably so that the same files are chosen on
# every run.
(cd image_corpus && ls | grep -v input | shuf --random-source=<(yes 'i like cat') | head -n 10 | xargs -I{} cp '{}' ../test_images_database/)
ls test_images_database

In [None]:
'''
Merge Step-1:

Input: folder of images
Output: None
Side effects:
  1. RAG stores the indexed descriptions.
  2. rag_storage_path holds the (filename, description) pairs.

Steps:
1. append_to_description_file: use the LVM to generate descriptions and write them to a .json file
2. rag_indexing: use RAG to index the .json file
'''

class Pipeline_GenerateDescriptions():
  '''
  Describe every image of a folder with the LVM, then index the descriptions with RAG.

  description_storage_path: .json file for output description
  rag_storage_path: path for RAG database
  model_name, rate_limit: forwarded to LVM

  The description file doubles as the progress record: files already listed in
  it are skipped, which is how an interrupted run is resumed.
  '''
  def __init__(self, description_storage_path, rag_storage_path,
    model_name="models/gemma-3-27b-it",
    rate_limit=5
  ):
    self._lvm = LVM(model_name=model_name, rate_limit=rate_limit)
    self._image_description_storage = ImageDescriptionStorage(path=description_storage_path)
    self._rag = RAG(storage_path=rag_storage_path)

    self._image_description_storage.load()

  def append_to_description_file(self, image_dir: str, batch_size=100):
    '''
    Describe the not-yet-described images of `image_dir` and append the results
    to the description file, saving after every batch.

    image_dir: folder scanned (non-recursively) for .png/.jpg/.jpeg/.gif/.bmp files
    batch_size: number of files handled between two saves; must be >= 1

    Files that cannot be read are reported and left out of the description
    file, so a later run tries them again.

    Raises ValueError if `batch_size` is smaller than 1.
    '''
    from PIL import Image
    import os
    from tqdm import tqdm

    # A batch size of 0 would never make progress in the batch loop.
    if batch_size < 1:
      raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Sorting ensures consistent ordering across runs
    all_filenames = sorted(
      f for f in os.listdir(image_dir)
      if f.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp'))
    )

    # --- Resume Logic ---
    # Get the set of filenames already processed
    processed_filenames = {d['filename'] for d in self._image_description_storage.get_data()}

    # Filter out already processed filenames
    filenames_to_process = [f for f in all_filenames if f not in processed_filenames]
    # --- End Resume Logic ---

    print(f"Found {len(all_filenames)} images in total.")
    # Count only files of this folder; the description file may list others as well.
    print(f"Skipping {len(all_filenames) - len(filenames_to_process)} already processed images.")
    print(f"Processing {len(filenames_to_process)} new images.")

    batch_starts = range(0, len(filenames_to_process), batch_size)
    for batch_no, start_index in enumerate(batch_starts, start=1):
      batch_filenames = filenames_to_process[start_index:start_index + batch_size]

      # `loaded_filenames` and `images` stay index-aligned: a file that fails to
      # load is left out of BOTH lists, so every description is later paired
      # with the file it was generated from.
      loaded_filenames = []
      images = []
      for filename in tqdm(batch_filenames, desc=f"Loading images for batch {batch_no}"):
        img_path = os.path.join(image_dir, filename)
        img = None
        try:
          img = Image.open(img_path)
          # Read the pixel data now: unreadable files fail here (and are skipped)
          # instead of later while the image is being encoded for the model.
          img.load()
        except Exception as e:
          if img is not None:
            img.close()
          print(f"Error opening image {img_path}: {e}")
          continue
        loaded_filenames.append(filename)
        images.append(img)

      # Check if any images were successfully loaded in the batch
      if not images:
        print(f"No images successfully loaded in batch {batch_no}. Skipping VLM call.")
        continue # Skip to the next batch

      print(f"Processing batch {batch_no} with {len(images)} images.")
      try:
        descriptions = self._lvm.images_to_text(images)
      finally:
        # The images are not needed once described; release them per batch.
        for img in images:
          img.close()

      for filename, description in zip(loaded_filenames, descriptions):
        self._image_description_storage.append({
          'filename': filename,
          'description': description
        })
      self._image_description_storage.save()

  def rag_indexing(self):
    '''Reload the description file and index it: descriptions as data, filenames as annotation.'''
    data = self._image_description_storage.load()
    self._rag.index(
      data=[d['description'] for d in data],
      annotation=[d['filename'] for d in data]
    )


#---------------------------[Example usage]----------------------------------

# Test-sized run: describe the sampled images and index their descriptions.
pipeline = Pipeline_GenerateDescriptions(
    description_storage_path='./test_images_description.json',
    rag_storage_path='./test_rag_data.json'
)

pipeline.append_to_description_file('./test_images_database')
# the description .json file is complete at this point

pipeline.rag_indexing()
# RAG is ready for use

In [None]:
'''
Input: 1 image
Output: n images similar to this image
'''

import PIL

class Pipeline_RetrieveImages():
  '''
  Retrieve stored images similar to an input image: the LVM describes the
  input, RAG returns the closest stored descriptions, and the matching image
  files are opened from `image_dir`.

  rag_storage_path: storage file handed to RAG
  image_dir: folder containing the files named by the RAG annotations
  model_name, rate_limit: forwarded to LVM
  '''
  def __init__(self, rag_storage_path, image_dir,
    model_name="models/gemma-3-27b-it",
    rate_limit=5
  ):
    self._lvm = LVM(model_name=model_name, rate_limit=rate_limit)
    self._rag = RAG(storage_path=rag_storage_path)
    self._image_dir = image_dir

  def _rag_retrieve(self, image: PIL.Image.Image, n=5):
    '''Describe `image` and query RAG with that description for the top `n` matches.'''
    description = self._lvm.images_to_text([image])[0]
    answer = self._rag.query(query=description, top_k=n)
    return {
      'input_description': description,
      'filenames': answer['annotation'],
      'descriptions': answer['data']
    }

  def image_retrieve(self, image: PIL.Image.Image, n=5):
    '''
    Return the `n` stored images whose descriptions best match `image`.

    Returns a dict with
      'input_description': the LVM description of `image`,
      'images': the retrieved images (pixel data already loaded),
      'descriptions': the stored descriptions of the retrieved images.
    '''
    import os
    from PIL import Image

    rag_retrieved = self._rag_retrieve(image, n=n)

    images = []
    for fname in rag_retrieved['filenames']:
      img = Image.open(os.path.join(self._image_dir, fname))
      # Image.open is lazy and keeps the file open; loading the pixels right
      # away lets Pillow release the handle instead of holding one open file
      # per returned image.
      img.load()
      images.append(img)

    return {
      'input_description': rag_retrieved['input_description'],
      'images': images,
      'descriptions': rag_retrieved['descriptions']
    }

#-------------------------------[Example usage]--------------------------------
# Retrieve the 2 stored images closest to the first benchmark image.
pipeline = Pipeline_RetrieveImages(
  rag_storage_path='./test_rag_data.json',
  image_dir='./image_corpus'
)
rag_retrieved = pipeline.image_retrieve(mrag_bench[0]['image'], n=2)

print('input image is')
display(mrag_bench[0]['image'])
print(rag_retrieved['input_description'])

# Show every retrieved image followed by its stored description.
print('retrived image is')
for img, desc in zip(rag_retrieved['images'], rag_retrieved['descriptions']):
  display(img)
  print(desc)


In [None]:
'''
Given a question, an image, do

1. get relevant images from RAG
2. feed all images (input & from RAG) to LVM
'''
import PIL

class Pipeline_GenerateAnswers():
  '''
  Answer a multiple-choice question about an image, giving the LVM the input
  image plus similar images retrieved through Pipeline_RetrieveImages.
  '''
  def __init__(self, rag_storage_path, image_dir):
    self._lvm = LVM(model_name="models/gemma-3-27b-it", rate_limit=5)
    retriever = Pipeline_RetrieveImages(
      rag_storage_path=rag_storage_path,
      image_dir=image_dir
    )
    self._pipeline_retrieve_image = retriever

  def _get_answer(self, prompt: str, image: PIL.Image, n_retrieve_img=5):
    '''
    Retrieve `n_retrieve_img` similar images, then ask the LVM `prompt` about
    the input image followed by the retrieved ones.

    Returns the retrieval result, the prompt and the model's reply in one dict.
    '''
    retrieval = self._pipeline_retrieve_image.image_retrieve(image, n=n_retrieve_img)

    # Input image first, retrieved images after it.
    all_images = [image] + retrieval['images']
    reply = self._lvm.multi_images_chat(prompt=prompt, images=all_images)

    result = {}
    result['rag_retrived'] = retrieval
    result['prompt'] = prompt
    result['chat_result'] = reply
    return result

  def get_answer_from_dataset(self, data_point, n_retrieve_img=5,
    format_prompt='Strictly give your answer in format `<A/B/C/D> <Reason>`'
  ):
    '''
    Build the prompt from a benchmark row (its 'question' and the four choices
    'A'..'D') and answer it for the row's 'image'.
    '''
    question = data_point['question']
    image = data_point['image']

    # Line order: question, answer-format instruction, header, one line per choice.
    prompt_lines = [f"{question}", f"{format_prompt}", "Choices are:"]
    for letter in ('A', 'B', 'C', 'D'):
      prompt_lines.append(f"{letter}:{data_point[letter]}")
    prompt = "\n".join(prompt_lines)

    return self._get_answer(prompt=prompt, image=image, n_retrieve_img=n_retrieve_img)

#-------------------------------[Example usage]--------------------------------

# Answer the first benchmark question using the test-sized RAG storage.
pipeline_answer = Pipeline_GenerateAnswers(
  rag_storage_path='./test_rag_data.json',
  image_dir='./image_corpus'
)

d = mrag_bench[0]
display(d)

# n_retrieve_img is left at its default of 5.
lvm_ans = pipeline_answer.get_answer_from_dataset(data_point=d)
display(lvm_ans)

In [None]:
#-----------------------------------------[Main Code]---------------------------------

In [None]:
'''
Use the officially provided images as the database images.

Can be resumed after an interruption:
  put images_description.json at ./
  run this cell again

rag_data.json is required to get the image descriptions back when querying.
'''

# Full run over the whole corpus, with its own description / RAG storage files.
pipeline_description = Pipeline_GenerateDescriptions(
    description_storage_path='./images_description.json',
    rag_storage_path='./rag_data.json',
    rate_limit=5
)

# generate descriptions by calling the LVM on each image file in this directory
pipeline_description.append_to_description_file('./image_corpus')

pipeline_description.rag_indexing()
# RAG is ready for use

In [None]:
'''
Process the output from the LVM into the input format of the evaluation package.
'''

import shortuuid
import json

from tqdm import tqdm

pipeline_answer = Pipeline_GenerateAnswers(
  rag_storage_path='./rag_data.json',
  image_dir='./image_corpus'
)

# One record per answered question; written to answer.json after the loop.
answers=[]

#display(mrag_bench)

# Only a prefix of the benchmark is answered: the counter is incremented before
# the check, so the loop stops when it reaches the 20th data point and 19
# questions are answered.
# NOTE(review): confirm whether 19 or 20 (or the whole benchmark) was intended.
i=0
for d in tqdm(mrag_bench):
  i+=1
  if i >= 20:
    break

  lvm_ans = pipeline_answer.get_answer_from_dataset(data_point=d)

  # Fields are taken from the benchmark row, except prompt/output (from the LVM
  # call), the generated shortuuid and the fixed model id.
  answers.append({
    "qs_id": d['id'],
    "prompt": lvm_ans['prompt'],
    "output": lvm_ans['chat_result'],
    "gt_answer": d['answer'],
    "shortuuid": shortuuid.uuid(),
    "model_id": 'gemma-3-27b-it',
    "gt_choice": d['answer_choice'],
    "scenario": d['scenario'],
    "aspect": d['aspect'],
  })

  #display(d['image'])
  #for retrived_img, desc in zip(lvm_ans['rag_retrived']['images'],lvm_ans['rag_retrived']['descriptions']):
  #  display(retrived_img)
  #  print(desc)


# All answers are written in one go, only after the loop has finished.
with open('answer.json', 'w') as ans_file:
  json.dump(answers, ans_file, indent=4)
  ans_file.flush()


In [None]:
'''
Clean up the answers output by the LVM.
Otherwise the evaluation package encounters errors, or has to call the OpenAI API.

i.e.
 1. Convert the output string to a pure A/B/C/D choice.
 2. Use a random choice if no answer (A/B/C/D) is found.
'''
def extract_choice_from_output(original_file, extracted_file):
  '''
  Reduce every entry's free-text 'output' to a single choice letter.

  original_file: JSON file holding a list of answer dicts (read only)
  extracted_file: JSON file the cleaned list is written to

  For each entry the 'output' becomes
    - the first stand-alone capital A, B, C or D in the text, else
    - the first stand-alone a, b, c or d (upper-cased), else
    - a random choice followed by ' <random>' (also used when the output is
      missing or empty).
  "Stand-alone" means not part of a longer word, so the "T" of "The answer is B"
  or the "A" of "Answer: C" is never mistaken for the choice.
  '''
  import json
  import random
  import re
  from tqdm import tqdm

  choices = ['A', 'B', 'C', 'D']
  upper_choice = re.compile(r'\b[A-D]\b')
  any_case_choice = re.compile(r'\b[A-Da-d]\b')

  # Read data from answer.json
  with open(original_file, 'r') as f:
      answers = json.load(f)

  # Process each entry
  for entry in tqdm(answers):
      text = entry.get('output') or ''

      # Prefer a capital letter: a lower-case "a" is far more likely to be the
      # article than an answer, so lower case is only the fallback.
      match = upper_choice.search(text) or any_case_choice.search(text)
      if match:
          entry['output'] = match.group(0).upper()
      else:
          entry['output'] = random.choice(choices) + ' <random>'

  # Save the modified data to a new file (or overwrite the original if you prefer)
  with open(extracted_file, 'w') as f:
      json.dump(answers, f, indent=4)

# Clean the raw answers; the original answer.json is kept and the result goes to a new file.
extract_choice_from_output(
  original_file='answer.json',
  extracted_file='answer_purged.json'
)


In [None]:
# Fetch the MRAG-Bench repository, which contains the eval/score.py scoring script.
!git clone https://github.com/mragbench/MRAG-Bench.git

In [None]:
# Score the cleaned answers with the evaluation script of the cloned MRAG-Bench repository.
!python MRAG-Bench/eval/score.py -i "answer_purged.json"