In [1]:
import asyncio

import pandas as pd
import json
import os
import openai
import matplotlib.pyplot as plt
import base64
import io

from typing import Union, Dict, Optional, Any, List, Tuple, Sequence, Literal
from PIL import Image
from time import time

from llama_index import SimpleDirectoryReader

from llama_index.schema import ImageDocument
from llama_index.multi_modal_llms.openai import OpenAIMultiModal
from llama_index.multi_modal_llms.base import ChatMessage

In [190]:
# MAIN_DIR is the parent of the current working directory; all data lives under "<MAIN_DIR>/db".
MAIN_DIR = os.path.dirname(os.getcwd())
DATA_DIR = os.path.join(MAIN_DIR, "db")

# Sub-folders of DATA_DIR used throughout the notebook.
DATABASE_DIR, REFERENCE_DIR, CONSUMER_DIR = (
    os.path.join(DATA_DIR, folder_name)
    for folder_name in ("database", "reference", "consumer")
)

# Read the API keys from "<MAIN_DIR>/auth/api_keys.json".
with open(os.path.join(MAIN_DIR, "auth", "api_keys.json"), "r") as f:
    api_keys = json.load(f)

# Publish the same key through the process environment and the openai module attribute.
os.environ["OPENAI_API_KEY"] = openai.api_key = api_keys["LEQUAN_OPENAI_KEY"]

# Experiment metadata table; the consumer_1..consumer_5 and
# first_reference/second_reference columns are read below.
metadata_df = pd.read_csv(os.path.join(DATA_DIR, "exp_metadata.csv"))

# Use a context manager so the JSON file handle is closed as soon as it is parsed.
with open(os.path.join(MAIN_DIR, "data", "pill", "master_metadata.json")) as f:
    master_metadata = json.load(f)

# Concatenate the columns one after another (all of consumer_1, then consumer_2, ...).
all_consumer_images = [
    image
    for column in ("consumer_1", "consumer_2", "consumer_3", "consumer_4", "consumer_5")
    for image in metadata_df[column].tolist()
]

all_reference_images = [
    image
    for column in ("first_reference", "second_reference")
    for image in metadata_df[column].tolist()
]

print("Total Number of consumer images:", len(all_consumer_images))
print("Total Number of reference images:", len(all_reference_images))

# with open(os.path.join(DATA_DIR, "consumer_list.txt"), "w") as f:
#     for image in all_consumer_images:
#         f.write(image + "\n")

# with open(os.path.join(DATA_DIR, "reference_list.txt"), "w") as f:
#     for image in all_reference_images:
#         f.write(image + "\n")

# Load the saved image lists (one entry per line), stripping surrounding
# whitespace such as the trailing newline from every entry.
with open(os.path.join(DATA_DIR, "consumer_list.txt"), "r") as f:
    consumer_images = [line.strip() for line in f]

with open(os.path.join(DATA_DIR, "reference_list.txt"), "r") as f:
    reference_images = [line.strip() for line in f]

Total Number of consumer images: 10000


In [191]:
def generate_img_url(image_path: str, resize: Optional[Union[int, Tuple[int, int], Literal["auto"]]] = None):
    """Build a base64 ``data:`` URL for the image stored at ``image_path``.

    Args:
        image_path: Path of the image file to embed.
        resize: Forwarded unchanged to ``encode_image``.

    Returns:
        A string of the form ``data:image/jpeg;base64,<payload>``. The
        ``image/jpeg`` media type is used unconditionally.
    """
    return f"data:image/jpeg;base64,{encode_image(image_path, resize=resize)}"

def resize_with_same_aspect_ratio(height: int, width: int, max_dimension: int = 1024) -> Tuple[int, int]:
    """Compute a target size whose longer side is ``max_dimension``, keeping the aspect ratio.

    Args:
        height: Current size along the first dimension.
        width: Current size along the second dimension.
        max_dimension: Length the longer side is scaled to. Sizes whose two
            sides are both strictly smaller than this value are returned
            unchanged.

    Returns:
        ``(new_height, new_width)``. The shorter side is truncated to an
        integer and never drops below 1.
    """
    # Already small enough on both sides: keep the size untouched.
    if height < max_dimension and width < max_dimension:
        return (height, width)

    # Scale the longer side to max_dimension (previously hard-coded to 512,
    # which silently ignored the max_dimension argument).
    if height > width:
        return (max_dimension, max(1, int(width * max_dimension / height)))
    return (max(1, int(height * max_dimension / width)), max_dimension)

def encode_image(image_path: str, resize: Optional[Union[int, Tuple[int, int], Literal["auto"]]] = None):
    """Return the base64 text encoding of an image file, optionally resizing it first.

    Args:
        image_path: Path of the image file to read.
        resize: Controls resizing.
            * ``None`` (or any falsy value): the file's bytes are encoded as-is.
            * ``int``: used as ``max_dimension`` for
              ``resize_with_same_aspect_ratio``.
            * ``"auto"``: same as passing ``512``.
            * tuple: passed directly to ``PIL.Image.Image.resize``.

    Returns:
        The base64 payload decoded as a UTF-8 string. When resizing, the image
        is re-encoded as PNG regardless of the source file's format.
    """
    if not resize:
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")

    # Open through a context manager so the underlying file handle is released
    # once the resized copy has been produced.
    with Image.open(image_path) as img:
        # PIL reports size as (width, height) and resize() expects the same order.
        width, height = img.size
        if isinstance(resize, int):
            new_height, new_width = resize_with_same_aspect_ratio(height, width, resize)
            resize = (new_width, new_height)
        elif resize == "auto":
            new_height, new_width = resize_with_same_aspect_ratio(height, width, 512)
            resize = (new_width, new_height)

        resized_img = img.resize(resize)

    # Serialize the resized copy into an in-memory buffer and encode its bytes.
    buffer = io.BytesIO()
    resized_img.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")

def generate_openai_vision_llamaindex_chat_message(
    prompt: str,
    role: str,
    image_documents: Optional[Sequence[ImageDocument]] = None,
    image_detail: Optional[str] = "high",
    resize: Optional[Union[int, Tuple[int, int], Literal["auto"]]] = None
) -> ChatMessage:
    """Build a chat message holding ``prompt`` and, optionally, inline base64 images.

    Args:
        prompt: Text of the message.
        role: Role assigned to the returned ``ChatMessage``.
        image_documents: Images to attach. ``None`` produces a text-only
            message whose content is the plain prompt string.
        image_detail: Value stored under ``"detail"`` for every attached image.
        resize: Forwarded unchanged to ``encode_image`` for every image.

    Returns:
        A ``ChatMessage`` whose content is either the prompt string (no image
        documents) or a list starting with a ``"text"`` part followed by one
        ``"image_url"`` part per image that has a usable path.
    """
    # No image documents supplied: return a text-only chat message.
    if image_documents is None:
        return ChatMessage(role=role, content=prompt)

    completion_content: List[Dict[str, Any]] = [{"type": "text", "text": prompt}]
    for image_document in image_documents:
        if image_document.image_path:
            source_path = image_document.image_path
            mimetype = image_document.image_mimetype or "image/jpeg"
        elif "file_path" in image_document.metadata and image_document.metadata["file_path"]:
            source_path = image_document.metadata["file_path"]
            # This fallback has always labelled the payload as JPEG.
            mimetype = "image/jpeg"
        else:
            # Nothing to encode for this document. Skip it instead of
            # appending an empty content part to the message.
            continue

        base64_image = encode_image(source_path, resize=resize)
        completion_content.append(
            {
                "type": "image_url",
                "image_url": {
                    "url": f"data:{mimetype};base64,{base64_image}",
                    "detail": image_detail,
                },
            }
        )
    return ChatMessage(role=role, content=completion_content)

def plot_images(images: List[Union[str, ImageDocument]]):
    """Show the given images side by side in a single row, without axis ticks.

    Args:
        images: File paths (``str``) and/or ``ImageDocument`` objects; for the
            latter the path is read from ``metadata["file_path"]``.

    Raises:
        ValueError: If an element is neither a ``str`` nor an ``ImageDocument``.
    """
    image_paths = []
    for image in images:
        if isinstance(image, ImageDocument):
            image_paths.append(image.metadata["file_path"])
        elif isinstance(image, str):
            image_paths.append(image)
        else:
            # The exception used to be constructed but never raised.
            raise ValueError("Invalid type of image")

    # Nothing to draw; avoid creating a zero-sized figure.
    if not image_paths:
        return

    # The subplots form one row with one column per image, so the figure
    # width (not its height) grows with the number of images.
    plt.figure(figsize=(8 * len(image_paths), 8))
    for position, img_path in enumerate(image_paths, start=1):
        plt.subplot(1, len(image_paths), position)
        # Close the file handle once matplotlib has read the pixel data.
        with Image.open(img_path) as image:
            plt.imshow(image)
        plt.xticks([])
        plt.yticks([])

In [None]:
# Two images (paths relative to REFERENCE_DIR) used as running examples below.
sample_images = [
    "PillProjectDisc37/images/BK0UDTM343KHX1I-ANSI_RI-DR_U_6I.JPG",
    "PillProjectDisc110/images/_CNSLF343EJTA_F25S1SJG9-EIKPQ7.JPG",
]

# Resolve the full paths and load each file as a document.
sample_paths = [
    os.path.join(REFERENCE_DIR, relative_path)
    for relative_path in sample_images
]
sample_docs = SimpleDirectoryReader(input_files=sample_paths).load_data()
sample_docs

# GPT-4V LLM

## Zero-shot prompt

In [116]:
CAPTION_EXTRACTION_PROMPT = (
    "You are given an image of a pill. Extract the relevant visual features from the given image,"
    "including shape, colors, imprint, imprint color, symbols, if any and whether the pill can be broken down into smaller parts.\n"
    "=========\n"
    "OUTPUT INSTRUCTIONS: Each line of the final output should contain description about an individual visual feature of the pill. If the visual characteristics cannot be inferred from the image, return N/A for that features\n"
    "EXAMPLE:\n"
    "Image features of the pill:\n"
    "- Shape: CAPSULE - Has the shape of a capsule when viewing the capsule from a point perpendicular to its longest side. The capsule shape is reserved for two-part capsules and banded two-part capsules.\n"
    "- Color: WHITE\n"
    "- Imprint: Lilly;3227;10;mg\n"
    "- ImprintColor: BLACK\n"
    "- ImprintType: PRINTED - Imprint is printed onto the surface of the pill\n"
    "- Score: Pill is not scored to break into smaller dosage parts\n"
    "- Other visual features: 4-leave flower"
)

In [114]:
# GPT-4V client used for the zero-shot calls: temperature 0, at most 256 new
# tokens, high image detail, up to 3 retries.
lmm = OpenAIMultiModal(
    model="gpt-4-vision-preview",
    temperature=0,
    max_new_tokens=256,
    image_detail="high",
    max_retries=3,
)

In [None]:
# Zero-shot caption for the first sample image, followed by the metadata
# stored for that image in master_metadata and the image itself.
response = lmm.complete(
    prompt=CAPTION_EXTRACTION_PROMPT,
    image_documents=[sample_docs[0]],
)

print(response.text, end="\n\n")
print(master_metadata[sample_images[0]]["metadata"])
plot_images([sample_docs[0]])

In [None]:
# Zero-shot caption for the second sample image, followed by the metadata
# stored for that image in master_metadata and the image itself.
response = lmm.complete(
    prompt=CAPTION_EXTRACTION_PROMPT,
    image_documents=[sample_docs[1]],
)

print(response.text, end="\n\n")
print(master_metadata[sample_images[1]]["metadata"])
plot_images([sample_docs[1]])

## Few-shot prompt

In [159]:
# GPT-4V client for the few-shot calls: temperature 0, at most 256 new
# tokens, high image detail, up to 3 retries.
lmm = OpenAIMultiModal(
    model="gpt-4-vision-preview",
    temperature=0,
    max_new_tokens=256,
    image_detail="high",
    max_retries=3,
)

In [166]:
# System prompt for the few-shot setup: task description plus output
# instructions. Adjacent string literals are concatenated implicitly, so every
# fragment must carry its own trailing space or newline.
SYSTEM_PROMPT = (
    "You are given an image of a pill. Extract the relevant visual features from the given image, "
    "including shape, colors, imprint, imprint color, symbols, if any and whether the pill can be broken down into smaller parts.\n"
    "=====\n"
    "OUTPUT INSTRUCTIONS: Each line of the final output should contain description about an individual visual feature of the pill. If the visual characteristics cannot be inferred from the image, return N/A for that features\n"
    )

# Few-shot examples: each entry pairs an image (path relative to REFERENCE_DIR)
# with the expected textual description, one feature per line. Every fragment
# ends with "\n" because adjacent string literals are concatenated implicitly.
EXAMPLES = [
    {
        "image_file": "PillProjectDisc37/images/BK0UDTM343KHX1I-ANSI_RI-DR_U_6I.JPG",
        "image_content": (
            "Image features of the pill:\n"
            "- Shape: ROUND - The pill has a circular shape when viewed from above\n"
            "- Color: WHITE\n"
            "- Imprint: N/A\n"
            "- Imprint Color: N/A\n"
            "- Imprint Type: N/A\n"
            "- Score: N/A - There is no visible score line indicating the pill can be broken into smaller parts.\n"
            "- Other visual symbols: N/A\n"
        )
    },
    {
        "image_file": "PillProjectDisc110/images/_CNSLF343EJTA_F25S1SJG9-EIKPQ7.JPG",
        "image_content": (
            "Image features of the pill:\n"
            "- Shape: CAPSULE - Has the shape of a capsule when viewing the capsule from a point perpendicular to its longest side.\n"
            "- Color: ORANGE\n"
            "- Imprint: barr;936\n"
            "- Imprint Color: BROWN\n"
            "- Imprint Type: PRINTED - Imprint is printed onto the surface of the pill\n"
            "- Score: Pill is not scored to break into smaller dosage parts\n"
            "- Other visual symbols: N/A\n"
        )
    },
]

# Header prepended to the text of every few-shot example message.
EXAMPLE_PROMPT = "=====\nEXAMPLE:\n"
# Text of the final user message that accompanies the query image.
QUERY_PROMPT = "=====\nReturn the visual features of this pill"
# EXAMPLES = []

In [167]:
## Few-shot prompt
def generate_llamaindex_chat_messages(
    query_file_path: str,
    system_prompt: str = SYSTEM_PROMPT,
    examples: Optional[List[Dict]] = None,
    image_detail: Literal["high", "low"] = "high",
    resize: Optional[Union[int, Tuple[int, int], Literal["auto"]]] = None
):
    """Assemble the chat messages for one query image, optionally with few-shot examples.

    The result is, in order: one text-only system message, one system message
    per example (example text plus the example image) and one user message
    carrying ``QUERY_PROMPT`` and the query image.

    Args:
        query_file_path: Path of the image to describe.
        system_prompt: Text of the first system message.
        examples: Dicts with an ``"image_file"`` key (path relative to
            ``REFERENCE_DIR``) and an ``"image_content"`` key (expected
            description). ``None`` or an empty list adds no example messages.
        image_detail: Forwarded to every image-bearing message.
        resize: Forwarded to every image-bearing message.

    Returns:
        The list of chat messages built by
        ``generate_openai_vision_llamaindex_chat_message``.
    """
    message_dicts = [
        {"prompt": system_prompt, "role": "system", "image_documents": None}
    ]

    # `examples` defaults to None, which must behave like "no examples"
    # instead of failing when iterated.
    for example in examples or []:
        example_path = os.path.join(REFERENCE_DIR, example["image_file"])
        example_images = SimpleDirectoryReader(input_files=[example_path]).load_data()
        message_dicts.append(
            {
                "prompt": EXAMPLE_PROMPT + example["image_content"] + "\n",
                "role": "system",
                "image_documents": example_images,
            }
        )

    query_images = SimpleDirectoryReader(input_files=[query_file_path]).load_data()
    message_dicts.append({"prompt": QUERY_PROMPT, "role": "user", "image_documents": query_images})

    return [
        generate_openai_vision_llamaindex_chat_message(
            prompt=message_dict["prompt"],
            role=message_dict["role"],
            image_documents=message_dict["image_documents"],
            image_detail=image_detail,
            resize=resize,
        )
        for message_dict in message_dicts
    ]

### Sequential

In [208]:
resize = 1024
image_detail = "high"

start = time()

# One entry per processed image: the chat response, or None when the request failed.
chat_responses = []

for image in consumer_images[:10]:
    try:
        query_path = os.path.join(CONSUMER_DIR, image)
        print(query_path)
        # This is the few-shot section, so send the EXAMPLES defined above
        # along with the query image.
        message_list = generate_llamaindex_chat_messages(
            query_path, system_prompt=SYSTEM_PROMPT, examples=EXAMPLES,
            image_detail=image_detail, resize=resize
        )
        chat_response = lmm.chat(message_list)
    except Exception as e:
        # Best effort: report the failure and keep the list aligned with the images.
        print(f"Image {image} gets exception {e}")
        chat_responses.append(None)
    else:
        chat_responses.append(chat_response)

print(time() - start)

/content/drive/MyDrive/LLM/data/pill/consumer/PillProjectDisc21/images/B5T5HI5XI8X2HSBJL-TGDHET4YG5C5F.JPG
/content/drive/MyDrive/LLM/data/pill/consumer/PillProjectDisc92/images/K6QQET!08476TI3778Y4EJ3ILMZ!CR.JPG
/content/drive/MyDrive/LLM/data/pill/consumer/PillProjectDisc35/images/BINZI-PDPZ36VT4NFBE_5CNNF6N3676.JPG
/content/drive/MyDrive/LLM/data/pill/consumer/PillProjectDisc106/images/WGK1UE6H7T-EARPNYO4LB09!T_VF9_.JPG
/content/drive/MyDrive/LLM/data/pill/consumer/PillProjectDisc2/images/-CYJI2!27!DZ9MOE83-5K-54DDSKRY.JPG
/content/drive/MyDrive/LLM/data/pill/consumer/PillProjectDisc40/images/BNH0_FQ-T4G9AL_H45VENPU7LK-!GH_.JPG
/content/drive/MyDrive/LLM/data/pill/consumer/PillProjectDisc17/images/B1ZEV_ZLK!76VL9N6!Y3D2G70NOECMX.JPG
Image PillProjectDisc17/images/B1ZEV_ZLK!76VL9N6!Y3D2G70NOECMX.JPG gets exception Error code: 400 - {'error': {'message': 'Your input image may contain content that is not allowed by our safety system.', 'type': 'invalid_request_error', 'param': None, 'c

In [196]:
# Print every available response together with its image and the metadata
# stored for that image in master_metadata.
for image, chat_response in zip(consumer_images[:10], chat_responses):
    # Falsy entries (None is appended when a request failed) have nothing to show.
    if not chat_response:
        continue

    print(image)
    print(chat_response.message.content)
    plot_images([os.path.join(CONSUMER_DIR, image)])
    print(master_metadata[image]["metadata"])
    print()

### Asynchronous

In [210]:
import nest_asyncio
nest_asyncio.apply()


async def _raise_later(error: Exception):
    """Re-raise ``error`` when awaited, so a failure keeps its slot among the awaitables."""
    raise error


resize = 1024
image_detail = "high"

start = time()

# One awaitable per image. Nothing is sent here: `lmm.achat` only creates the
# coroutine, which runs once it is awaited.
chat_responses = []

for image in consumer_images[:10]:
    try:
        query_path = os.path.join(CONSUMER_DIR, image)
        print(query_path)
        # This is the few-shot section, so send the EXAMPLES defined above
        # along with the query image.
        message_list = generate_llamaindex_chat_messages(
            query_path, system_prompt=SYSTEM_PROMPT, examples=EXAMPLES,
            image_detail=image_detail, resize=resize
        )
        chat_response = lmm.achat(message_list)
        chat_responses.append(chat_response)
    except Exception as e:
        # Only failures while building the messages or creating the coroutine
        # land here. Store an awaitable (not None, which asyncio.gather
        # rejects) so the list stays aligned with the images.
        print(f"Image {image} gets exception {e}")
        chat_responses.append(_raise_later(e))

# Measures only the time needed to prepare the requests.
print(time() - start)

/content/drive/MyDrive/LLM/data/pill/consumer/PillProjectDisc21/images/B5T5HI5XI8X2HSBJL-TGDHET4YG5C5F.JPG
/content/drive/MyDrive/LLM/data/pill/consumer/PillProjectDisc92/images/K6QQET!08476TI3778Y4EJ3ILMZ!CR.JPG
/content/drive/MyDrive/LLM/data/pill/consumer/PillProjectDisc35/images/BINZI-PDPZ36VT4NFBE_5CNNF6N3676.JPG
/content/drive/MyDrive/LLM/data/pill/consumer/PillProjectDisc106/images/WGK1UE6H7T-EARPNYO4LB09!T_VF9_.JPG
/content/drive/MyDrive/LLM/data/pill/consumer/PillProjectDisc2/images/-CYJI2!27!DZ9MOE83-5K-54DDSKRY.JPG
/content/drive/MyDrive/LLM/data/pill/consumer/PillProjectDisc40/images/BNH0_FQ-T4G9AL_H45VENPU7LK-!GH_.JPG
/content/drive/MyDrive/LLM/data/pill/consumer/PillProjectDisc17/images/B1ZEV_ZLK!76VL9N6!Y3D2G70NOECMX.JPG
/content/drive/MyDrive/LLM/data/pill/consumer/PillProjectDisc13/images/9_KKSUTJTJXL5UO4_-Z6F3H51RIRO4.JPG
/content/drive/MyDrive/LLM/data/pill/consumer/PillProjectDisc102/images/SRWU_F3Z088KVB4QZ4XMC_SMD5HJP0.JPG
/content/drive/MyDrive/LLM/data/pill/cons

In [211]:
start = time()
chat_responses = await asyncio.gather(*chat_responses, return_exceptions=True)

print(time() - start)

12.114030838012695


In [212]:
# Display the gathered results (an exception object stands in for any request that failed).
chat_responses

[ChatResponse(message=ChatMessage(role=<MessageRole.ASSISTANT: 'assistant'>, content='Image features of the pill:\n- Shape: CAPSULE - The pill has an elongated shape with rounded ends, typical of a capsule.\n- Color: GRAY\n- Imprint: N/A\n- Imprint Color: N/A\n- Imprint Type: N/A\n- Score: N/A - There is no visible score line indicating the pill can be broken into smaller parts.\n- Other visual symbols: N/A', additional_kwargs={}), raw={'id': 'chatcmpl-8mlkXDbM4jkdWEFRXBtfprnzoYoOH', 'choices': [Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content='Image features of the pill:\n- Shape: CAPSULE - The pill has an elongated shape with rounded ends, typical of a capsule.\n- Color: GRAY\n- Imprint: N/A\n- Imprint Color: N/A\n- Imprint Type: N/A\n- Score: N/A - There is no visible score line indicating the pill can be broken into smaller parts.\n- Other visual symbols: N/A', role='assistant', function_call=None, tool_calls=None))], 'created': 1706633985,