# What Do You See? Enhancing Zero-Shot Image Classification with Multimodal Large Language Models

Authors: Abdelrahman Abdelhamed, Mahmoud Afifi, Alec Go

Large language models (LLMs) have been effectively used for many computer vision tasks, including image classification. In this paper, we present a simple yet effective approach for zero-shot image classification using multimodal LLMs. By employing multimodal LLMs, we generate comprehensive textual representations from input images. These textual representations are then utilized to generate fixed-dimensional features in a cross-modal embedding space. Subsequently, these features are fused together to perform zero-shot classification using a linear classifier. Our method does not require prompt engineering for each dataset; instead, we use a single, straightforward set of prompts across all datasets. We evaluated our method on several datasets, and our results demonstrate its remarkable effectiveness, surpassing benchmark accuracy on multiple datasets. On average over ten benchmarks, our method achieved an accuracy gain of 4.1 percentage points, with an increase of 6.8 percentage points on the ImageNet dataset, compared to prior methods. Our findings highlight the potential of multimodal LLMs to enhance computer vision tasks such as zero-shot image classification, offering a significant improvement over traditional methods.


https://arxiv.org/abs/2405.15668
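
The fusion and linear-classification steps described in the abstract can be illustrated with toy vectors. The following is a minimal sketch, not the paper's code: the random vectors stand in for real embeddings, and the helper names `l2_normalize` and `fuse` are hypothetical. It assumes the sum-and-normalize fusion that this notebook's `classify` function uses.

```python
import numpy as np


def l2_normalize(vector: np.ndarray) -> np.ndarray:
    """Scale a vector to unit Euclidean length."""
    return vector / np.linalg.norm(vector)


def fuse(features: list[np.ndarray]) -> np.ndarray:
    """Sum unit-normalized features and re-normalize the sum."""
    return l2_normalize(np.sum([l2_normalize(f) for f in features], axis=0))


# Toy stand-ins (assumptions): random vectors instead of real embeddings.
rng = np.random.default_rng(0)
dim, num_classes = 8, 3
image_feature = rng.normal(size=dim)
prediction_feature = rng.normal(size=dim)
description_feature = rng.normal(size=dim)

# One unit-norm column per class, shape (dim, num_classes).
class_weights = np.stack(
    [l2_normalize(rng.normal(size=dim)) for _ in range(num_classes)], axis=1
)

query = fuse([image_feature, prediction_feature, description_feature])
scores = query @ class_weights  # cosine similarity between the query and each class
predicted_index = int(np.argmax(scores))
```

Because both the query and the class columns have unit length, each score is a cosine similarity, so the "linear classifier" here amounts to picking the class whose weight vector points most nearly in the query's direction.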

In [2]:
!pip install -qU langchain langchain-aws langchain-community

In [3]:
AWS_REGION = "us-east-1"

In [196]:
import base64
import boto3
from botocore.config import Config
import io
import json
from tqdm.notebook import tqdm

# import tensorflow as tf
import numpy as np

from langchain_aws.chat_models import ChatBedrock
from langchain_aws.embeddings import BedrockEmbeddings
from langchain_core.messages import HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser

from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseChatModel


First, let's create a few abstractions so that our implementation closely mirrors the one in the original paper.

In [210]:
def _encode_image(uri: str | io.BytesIO) -> str:
    """Return the base64 string for an image given as a file path or an in-memory buffer."""
    if isinstance(uri, io.BytesIO):
        uri.seek(0)
        return base64.b64encode(uri.read()).decode("utf-8")

    with open(uri, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")

class Encoder:
    embeddings: Embeddings
    dimensions: int
    
    def __init__(self, embeddings: Embeddings, dimensions: int = 1024):
        self.embeddings = embeddings
        self.dimensions = dimensions
    
    def encode_text(self, text: str) -> np.ndarray:
        return np.array(self.embeddings.embed_query(text))
    
    def encode_image(self, image: str | io.BytesIO) -> np.ndarray:
        b64_text = _encode_image(image)

        payload = {"inputImage": b64_text}
        body = json.dumps(payload)
        
        try:
            response = self.embeddings.client.invoke_model(
                body=body, modelId=self.embeddings.model_id, accept="application/json", contentType="application/json"
            )

            vector_json = json.loads(response["body"].read().decode("utf8"))

            return np.array(vector_json["embedding"])
        except Exception as e:
            raise ValueError(f"Error raised by inference endpoint: {e}") from e
    
    @property
    def output_feature_length(self) -> int:
        return self.dimensions
        

class LLM:
    chat_model: BaseChatModel
    
    def __init__(self, chat_model):
        self.chat_model = chat_model
    
    def process(self, prompt: str | tuple[str, str | io.BytesIO], temperature: float = 0) -> str:
        self.chat_model.model_kwargs["temperature"] = temperature
        
        if isinstance(prompt, str):
            chain = self.chat_model | StrOutputParser()
            
            return chain.invoke(prompt)
        
        sys_prompt, image = prompt
        
        final_prompt = ChatPromptTemplate.from_messages([
            ("system", sys_prompt),
            MessagesPlaceholder(variable_name="input")
        ])
        
        chain = final_prompt | self.chat_model | StrOutputParser()
        
        return chain.invoke({"input": [
            HumanMessage(
                content=[
                    {
                        "type": "image",
                        "source": {"type": "base64", "media_type": "image/jpeg", "data": _encode_image(image)},
                    },
                ]
            ),
        ]})
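
To exercise the rest of the pipeline without calling Bedrock, the `Encoder` interface above (`encode_text`, `encode_image`, `output_feature_length`) can be mimicked by a small stand-in. This is only a sketch: `StubEncoder` is a hypothetical helper that is not part of the notebook or the paper, it takes raw image bytes rather than a path, and its vectors are pseudo-random, so they carry no semantic meaning and are useful only for checking shapes and control flow.

```python
import hashlib

import numpy as np


class StubEncoder:
    """Hypothetical offline stand-in for the Encoder interface."""

    def __init__(self, dimensions: int = 16):
        self.dimensions = dimensions

    def _embed(self, data: bytes) -> np.ndarray:
        # Derive a deterministic seed from the input so equal inputs give equal vectors.
        seed = int.from_bytes(hashlib.sha256(data).digest()[:8], "big")
        return np.random.default_rng(seed).normal(size=self.dimensions)

    def encode_text(self, text: str) -> np.ndarray:
        return self._embed(text.encode("utf-8"))

    def encode_image(self, image_bytes: bytes) -> np.ndarray:
        # Assumption: the caller passes raw image bytes, not a file path.
        return self._embed(image_bytes)

    @property
    def output_feature_length(self) -> int:
        return self.dimensions


stub = StubEncoder(dimensions=16)
metal_vector = stub.encode_text("metal")
paper_vector = stub.encode_text("paper")
```

Swapping such a stub in for the real encoder makes it possible to check array shapes and normalization logic before spending any API calls.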

In [211]:
boto_session = boto3.session.Session(region_name=AWS_REGION)

bedrock_client = boto_session.client(
    "bedrock-runtime",
    config=Config(
        connect_timeout=120,
        read_timeout=120,
        retries={
            "max_attempts": 10,
            "mode": "adaptive",
        },
    ),
)

In [212]:
claude = ChatBedrock(
    client=bedrock_client,
    model_id="anthropic.claude-3-sonnet-20240229-v1:0",
    model_kwargs={},
)

titan_embeddings = BedrockEmbeddings(client=bedrock_client, model_id="amazon.titan-embed-image-v1")

In [213]:
encoder = Encoder(titan_embeddings, 1024)
llm = LLM(claude)

# Now, on to the implementation of the paper

I tried to stay as close as possible to the implementation on page 12 of the arXiv paper (the "Code 1" listing), except that I replaced TensorFlow with NumPy.

In [214]:
classification_p = "You are given an image and a list of class labels. Classify the image given the class labels. Answer using a single word if possible. Here are the class labels: {classes}"

description_p = "What do you see? Describe any object precisely, including its type or class."

class_ps = [
    "Describe what a {class_label} looks like in one or two sentences",
    "How can you identify a {class_label} in one or two sentences?",
    "What does a {class_label} look like? Respond with one or two sentences.",
    "Describe an image of a {class_label} that you know of from the internet. Respond with one or two sentences.",  # changed from the original paper; otherwise the LLM would simply respond "Sorry, but you did not provide an image..."
    "A short caption of an image of a {class_label}"
]

In [120]:
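def create_classifier(class_names, k: int = 5):
    """Build one unit-norm weight vector per class from LLM-generated descriptions."""
    # k is the total number of descriptions per class; it must be a positive multiple of len(class_ps).
    assert k >= len(class_ps)
    assert k % len(class_ps) == 0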

    weights = []
    
    for class_name in tqdm(class_names):
        class_name_feature = encoder.encode_text(class_name)
        template_feature = encoder.encode_text(f"A photo of {class_name}")
        
        llm_class_description = np.zeros(encoder.output_feature_length)
        for _ in (progress_bar := tqdm(range(k // len(class_ps)))):
            progress_bar.set_description(f"class: {class_name}")
            for class_p in class_ps:
                llm_class_feature = llm.process(class_p.format(class_label=class_name), temperature=0.99)
                print(llm_class_feature.split('\n', 1)[0])
                llm_class_description += encoder.encode_text(llm_class_feature)
        llm_class_description /= k
        
        class_feature = class_name_feature + template_feature + llm_class_description
        normalized_class_feature = class_feature / np.linalg.norm(class_feature)
        
        weights.append(np.squeeze(normalized_class_feature))
        
    model = {"weights": np.transpose(np.array(weights)), "class_names": class_names}

    return model

In [121]:
classifier = create_classifier(["plastic", "metal", "paper"])

Plastic is a synthetic material made from a wide range of organic polymers that can be molded into various shapes and forms. It can have different colors, textures, and levels of transparency, ranging from completely clear to opaque, and can be rigid or flexible depending on its composition and intended use.
You can identify different types of plastics by looking at the resin identification code, which is a number from 1 to 7 inside a triangle symbol stamped or printed on the plastic item.
Plastic typically has a smooth, shiny surface and can come in various colors, textures, and transparencies depending on its composition and intended use.
I don't actually have access to specific images from the internet. As an AI assistant without the ability to browse the web, I can only analyze and describe images that are provided to me directly.
Here are some potential short captions for an image of plastic:


A metal typically has a lustrous, shiny appearance that reflects light, giving it a distinctive metallic gleam. Metals can vary in color, ranging from the silvery hues of aluminum and steel to the warm yellow tones of gold or the reddish-brown of copper.
You can identify a metal by its characteristic properties such as a lustrous appearance, malleability (ability to be hammered into thin sheets), ductility (ability to be drawn into wires), and high electrical and thermal conductivity.
A metal typically has a shiny, lustrous appearance and a grayish color, though the exact shade can vary depending on the specific metal. It also has a high density and can conduct heat and electricity well.
Here is a one sentence description of an image of metal from the internet:
Here are some potential short captions for an image of metal:


A paper is a thin, flat material made from compressed plant fibers, typically rectangular or square in shape, with a smooth surface suitable for writing or printing on.
To identify a paper in one or two sentences, you can state the title of the paper, the authors, and the publication it appeared in (journal, conference proceedings, etc.). For example: "The paper 'Title of the Paper' by Author1, Author2, and Author3 was published in the Journal of X Research in 2020."
A paper typically appears as a thin, flat sheet or rectangular material made from wood pulp or other fibrous materials, usually white or off-white in color.
Here is a one sentence description of a well-known image of a paper from the internet: The famous "This is a Wireframe" joke image depicts a crudely drawn hand holding up a piece of lined paper with the text "THIS IS A WIREFRAME" scribbled on it.
Here are some potential short captions for an image of a paper:


In [229]:
def classify(image, _classifier):
    # Embed the image itself.
    image_feature = encoder.encode_image(image)
    image_feature /= np.linalg.norm(image_feature)
    
    # Embed the LLM's initial guess, using the classifier's own class labels.
    initial_prediction = llm.process((classification_p.format(classes=_classifier["class_names"]), image), temperature=0.3)
    print(f"Initial prediction: {initial_prediction}")
    prediction_feature = encoder.encode_text(initial_prediction)
    prediction_feature /= np.linalg.norm(prediction_feature)
    
    # Embed the LLM's free-form description of the image.
    image_description = llm.process((description_p, image), temperature=0)
    description_feature = encoder.encode_text(image_description)
    description_feature /= np.linalg.norm(description_feature)
    
    # Fuse the three features and score them against the class weights passed in.
    query_feature = image_feature + prediction_feature + description_feature
    query_feature /= np.linalg.norm(query_feature)
    likelihoods = np.matmul(query_feature, _classifier["weights"])
    print(likelihoods)
    index = np.argmax(likelihoods)
    
    return _classifier["class_names"][int(index)]

In [226]:
cashews_plastic = "./1701029658.jpg"
kitkat_plastic = "./1709498901.jpg"

result = classify(kitkat_plastic, classifier)

result

Initial prediction: The image depicts a plastic bag containing a red-wrapped food item or snack. Based on the class labels provided - 'plastic', 'metal', and 'paper' - the appropriate classification for this image would be plastic, as the primary visible object is a plastic bag or packaging material.
[0.59115122 0.50143578 0.54155003]


'plastic'

## Benchmarking with our dataset

In [230]:
RED = "\033[31m"  # Red text
GREEN = "\033[32m"  # Green text
RESET = "\033[0m"  # Reset to default color

for image, expected_material, label in [
        ("1709498901.jpg", "plastic", "kitkat"),
        ("1709498394.jpg", "metal", "coke can"),
        ("1709500856.jpg", "plastic", "water bottle"),
        ("1709660526.jpg", "paper", "toddynho"),
        ("1701029658.jpg", "plastic", "cashews"),
        ("1701222618.jpg", "paper", "ben and jerry's"),
    ]:
    material_type = classify(f"./trash/{image}", classifier)
    
    match = material_type == expected_material
    
    print(
        f"{GREEN if match else RED}For {image} [{label}] (expected {expected_material}, got {material_type}){RESET}"
    )

Initial prediction: Based on the image and the provided class labels, the material depicted is plastic. The image shows a crumpled red plastic bag or wrapper hanging or suspended against a bluish background.
[0.53298613 0.44400442 0.51629422]
For 1709498901.jpg [kitkat] (expected plastic, got plastic)
Initial prediction: Based on the image and the given class labels, the appropriate classification for this image would be metal. The image shows a red aluminum beverage can protruding through a torn piece of paper or cardboard material, indicating the can is made of metal.
[0.51483863 0.53598519 0.54020967]
For 1709498394.jpg [coke can] (expected metal, got paper)
Initial prediction: Based on the image, the material that best fits the given class labels is plastic. The image shows a clear plastic bag or wrapping material containing what appears to be a blue plastic bottle or container.
[0.64935216 0.49125745 0.54324783]
For 1709500856.jpg [water bottle] (expected plas