---
title: "Classify images with Gemini Flash 1.5"
date: "10/08/2024"
date-modified: last-modified
description-meta: "Learn how to use In-Context Learning (ICL) to classify images using Gemini Flash 1.5"
toc: true
toc-depth: 3
lightbox: true
fig-cap-location: margin
categories:
  - mllms
  - in-context-learning
  - gemini
author:
  - name: Dylan Castillo
    url: https://dylancastillo.co
    affiliation: Iwana Labs
    affiliation-url: https://iwanalabs.com
citation: true
comments:
  utterances:
    repo: dylanjcastillo/blog_comments
    theme: dark-blue
    issue-term: pathname
draft: true
---

One overlooked capability of Multimodal Large Language Models (MLLMs) is image classification through In-Context Learning (ICL). With ICL, you give the model a small number of labeled example images at inference time, and it uses them to make predictions on previously unseen images.

This approach has been demonstrated to work quite well for image classification tasks in the literature (see [here](https://arxiv.org/abs/2405.09798) and [here](https://arxiv.org/abs/2403.07407)), and I've also had success with it in the past. While you're unlikely to achieve state-of-the-art results with it, it can often give you pretty good results with very little effort and data.

I recently worked on a project for a client that made use of this approach, so I thought it'd be fun to write a short tutorial about it.

## Prerequisites

To follow this tutorial you'll need to:

- Generate a key in [Google AI Studio](https://aistudio.google.com/app/apikey)
- Download [EuroSAT](https://github.com/phelber/EuroSAT)
- Create a virtual environment and install the requirements:

```bash
python -m venv venv
source venv/bin/activate
pip install pandas numpy scikit-learn nest-asyncio google-generativeai pillow
```


## Why Gemini Flash 1.5?

You can use any MLLM for this task, but I like Gemini Flash 1.5 because:

1. It's cheaper than [Gemini Pro 1.5](https://ai.google.dev/pricing), [GPT-4o](https://platform.openai.com/pricing), and [Sonnet 3.5](https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs). For an image of 512x512 pixels, Gemini Flash 1.5 is 66x cheaper than Gemini Pro 1.5, 32x cheaper than GPT-4o, and 52x cheaper than Sonnet 3.5[^longnote].
2. It lets you use up to 3,000 images per request. By trial and error, I found that GPT-4o seems to have a hard limit at 250 images per request and Sonnet 3.5's documentation mentions a limit of 20 images per request.
3. It works well enough for this task. If you really want to squeeze the last bit of performance out of your model, you can use a bigger model, but for the purposes of this tutorial, Gemini Flash 1.5 will do just fine.

Regardless of the model you choose, this tutorial will be a good starting point for you to classify images using ICL.

[^longnote]: Estimated costs as of August 10, 2024:

    | Model | Cost per 512x512 image |
    |-------|------------------------|
    | Gemini Flash 1.5 | $0.00002 |
    | Gemini Pro 1.5 | $0.000064 |
    | GPT-4o | $0.000638 |
    | Sonnet 3.5 | $0.001047 |
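
To get a feel for what a per-image price means for a whole request, here is a back-of-the-envelope sketch. It takes the Gemini Flash 1.5 figure from the table above at face value and ignores text and output tokens. The `estimate_image_cost` helper and the image counts are illustrative assumptions, not part of the original code.

```python
# Per-image price for Gemini Flash 1.5, copied from the footnote table above
# (an estimate as of August 10, 2024; check current pricing before relying on it).
FLASH_COST_PER_IMAGE = 0.00002


def estimate_image_cost(n_context: int, n_input: int, cost_per_image: float) -> float:
    """Rough image-only cost of one request, ignoring text and output tokens."""
    return (n_context + n_input) * cost_per_image


# Assumed sizes: 10 classes x 5 context images, plus 100 images to classify.
cost = estimate_image_cost(n_context=50, n_input=100, cost_per_image=FLASH_COST_PER_IMAGE)
print(f"Estimated image cost per request: ${cost:.4f}")
```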

## Set up

In [None]:
#| output: false
#| echo: false
%load_ext dotenv
%dotenv
%load_ext autoreload
%autoreload 2

Jupyter already runs its own event loop, so calls such as `asyncio.run` fail inside a notebook. `nest_asyncio` patches the loop to allow nested use. You can enable it by running the following cell:

In [None]:
import nest_asyncio

nest_asyncio.apply()

## Import libraries and configure the API

In [None]:
import asyncio
import base64
import json
import os
import warnings
from datetime import datetime

import google.generativeai as genai
import numpy as np
import pandas as pd
from IPython.display import display, HTML
from sklearn.metrics import accuracy_score, f1_score
from PIL import Image as PILImage
from pathlib import Path

warnings.filterwarnings("ignore")

np.random.seed(42)

In [None]:
genai.configure(api_key=os.environ["GEMINI_API_KEY"])

## Create the train and test dataframes

In [None]:
def create_datasets(data_dir, n_train=5, n_test=10):
    """Build the context (train) and input (test) dataframes from one folder per class."""
    train_data = []
    test_data = []

    # Only subdirectories count as classes, so stray files don't shift the class ids
    class_dirs = sorted(p for p in Path(data_dir).iterdir() if p.is_dir())

    for class_id, class_dir in enumerate(class_dirs):
        class_letter = chr(ord("A") + class_id)  # A, B, C, ...
        generated_class_name = f"class_{class_letter}"
        image_files = sorted(class_dir.glob('*.jpg'))

        # Train dataset: first n_train images of each class
        for img_path in image_files[:n_train]:
            train_data.append({
                'image_path': str(img_path),
                'class_id': generated_class_name,
                'class_name': class_dir.name
            })

        # Test dataset: next n_test images of each class
        for img_path in image_files[n_train:n_train+n_test]:
            test_data.append({
                'image_path': str(img_path),
                'class_id': generated_class_name,
                'class_name': class_dir.name
            })

    df_train = pd.DataFrame(train_data)
    # Shuffle the test set so images of the same class are not sent back to back
    df_test = pd.DataFrame(test_data).sample(frac=1).reset_index(drop=True)

    return df_train, df_test
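
The anonymized labels (`class_A`, `class_B`, ...) are derived from the position of each class folder in sorted order. Here is a minimal sketch of that mapping, assuming the ten EuroSAT class folder names and labels that start at `A`. `label_for`, `NAME_TO_LABEL`, and `LABEL_TO_NAME` are hypothetical helpers for illustration.

```python
# The ten EuroSAT class folders, sorted the same way the dataset code sorts them.
EUROSAT_CLASSES = sorted([
    "AnnualCrop", "Forest", "HerbaceousVegetation", "Highway", "Industrial",
    "Pasture", "PermanentCrop", "Residential", "River", "SeaLake",
])


def label_for(class_index: int) -> str:
    """Map 0 -> class_A, 1 -> class_B, and so on."""
    return f"class_{chr(ord('A') + class_index)}"


NAME_TO_LABEL = {name: label_for(i) for i, name in enumerate(EUROSAT_CLASSES)}
# Inverse map, to translate anonymized predictions back to readable names
LABEL_TO_NAME = {label: name for name, label in NAME_TO_LABEL.items()}

print(NAME_TO_LABEL["AnnualCrop"])
print(LABEL_TO_NAME["class_J"])
```

The inverse map is only needed if you send the anonymized labels to the model instead of the real class names.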

#### Read the dataframes

In [None]:
data_dir = "../data/EuroSAT_RGB/"

df_train, df_test = create_datasets(data_dir)

## Gemini

### Prompts

In [None]:
CLASSIFIER_SYSTEM_PROMPT = """You are a satellite imagery classification expert. Your task is to match input images from a satellite and assign a label based on the most similar context image.

Provided with a list of context images and a list of input images, compare each input image to all context images for each class and determine the best matching context image for each input image and assign a label and confidence score. 

Provide your output as a JSON object in the following format:

{
    "number_of_labeled_images": <integer>,
    "output": [
        {
            "image_id": <image id, integer, starts at 0>,
            "confidence": <number between 0 and 10, the higher the more confident, integer>,
            "correct_label": <label of the most similar context image, string>
        }, 
        ...
    ]
}

## Instructions 

1. Carefully examine each input image.
2. Compare it to all context images for each class.
3. Determine the most similar context image for each input image and assign that label to the input image.
4. Assign a confidence score between 0 and 10, where 10 indicates the highest confidence in the match.

## Guidelines

- ALWAYS produce valid JSON.
- Generate ONLY a single prediction per input image. DO NOT produce duplicated predictions for any input image.
- Make a prediction for ALL input images. If there's no matching label, classify it with the most similar label.
- Each input image must be assigned a SINGLE label based on the context images. 
- You MUST only predict the context image labels provided.
- The `number_of_labeled_images` MUST be the same as the number of input images.

## Example

This is an example of a valid output:
```
{
  "number_of_labeled_images": 5,
  "output": [
      {
        "image_id": 0,
        "confidence": 10,
        "correct_label": "class_B"
      },
      {
        "image_id": 1,
        "confidence": 9,
        "correct_label": "class_C"
      },
      {
        "image_id": 2,
        "confidence": 4,
        "correct_label": "class_A"
      },
      {
        "image_id": 3,
        "confidence": 2,
        "correct_label": "class_B"
      },
      {
        "image_id": 4,
        "confidence": 10,
        "correct_label": "class_C"
      }
  ]
}
```
""".strip()

### Models

In [None]:
generation_config = {
  "temperature": 1,
  "max_output_tokens": 8192,
  "response_mime_type": "application/json",
}
classification_model = genai.GenerativeModel(
    "gemini-1.5-flash",
    system_instruction=CLASSIFIER_SYSTEM_PROMPT,
    generation_config=generation_config
)

#### Generate predictions

In [None]:
from io import BytesIO

def image_to_blob(img_path):
    # Re-encode the image as PNG and wrap the raw bytes in an inline-data dict.
    # A bare base64 string would be sent to the model as plain text, not as an image.
    img = PILImage.open(img_path)
    buffered = BytesIO()
    img.save(buffered, format="PNG")
    return {"mime_type": "image/png", "data": buffered.getvalue()}

def create_context_images_message(df):
    # All context images of a class, followed by the label they share
    messages = ["Context images:"]
    grouped = df.groupby('class_name')
    for class_name, group in grouped:
        for _, row in group.iterrows():
            messages.append(image_to_blob(row["image_path"]))
        messages.append(f"correct_label: {class_name}")
    return messages

def create_input_images_message(df):
    # Each input image is followed by the id the model must echo back
    messages = ["Input images:"]
    for i, image_path in enumerate(df.image_path):
        messages.extend([image_to_blob(image_path), f"input_image_id: {i}"])
    messages.append(f"Please correctly classify all {df.shape[0]} images.")
    return messages
    
train_images_message = create_context_images_message(df_train)
test_images_message = create_input_images_message(df_test)
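
It can help to see the shape of the final message list before sending it. The sketch below rebuilds the same layout with placeholder strings standing in for the encoded images. The file names and the `build_layout` helper are made up for illustration.

```python
def build_layout(context, inputs):
    """Mirror the message layout: context images grouped by label, then numbered inputs.

    `context` maps a label to its image placeholders; `inputs` is a list of placeholders.
    """
    messages = ["Context images:"]
    for label in sorted(context):
        messages.extend(context[label])  # all images of one class...
        messages.append(f"correct_label: {label}")  # ...followed by their shared label
    messages.append("Input images:")
    for i, image in enumerate(inputs):
        messages.extend([image, f"input_image_id: {i}"])
    messages.append(f"Please correctly classify all {len(inputs)} images.")
    return messages


layout = build_layout(
    context={"Forest": ["<forest_1.jpg>", "<forest_2.jpg>"], "River": ["<river_1.jpg>"]},
    inputs=["<unknown_1.jpg>", "<unknown_2.jpg>"],
)
for part in layout:
    print(part)
```

The label comes after the images it describes, and each input image is immediately followed by the id the model is asked to return.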

In [None]:
df_test.head()

In [None]:
def display_images_by_class(df, class_name, max_images=5):
    # Show up to `max_images` images of the given class side by side
    images_of_class = df[df['class_name'] == class_name]
    html_content = '<div style="display: flex; flex-wrap: wrap;">'
    for i, (_, row) in enumerate(images_of_class.iterrows()):
        if i >= max_images:
            break
        html_content += f'''
        <div style="margin: 10px;">
            <img src="{row['image_path']}" width="300">
            <p>Image path: {row['image_path']}</p>
        </div>
        '''
    html_content += '</div>'

    display(HTML(html_content))

class_name = "River"
display_images_by_class(df_test, class_name)

In [None]:
test_images_message[:3]

In [None]:
# Send every input image; slicing the list would drop most of the test set
contents = train_images_message + test_images_message
response = await classification_model.generate_content_async(
    contents=contents
)
response_json = json.loads(response.text)
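
With a larger test set, one option is to split the input images into batches and send them concurrently. The sketch below illustrates that idea and is not what the code above does. `classify_batch` is a stand-in that fakes the model call, and the batch size and concurrency limit are arbitrary assumptions.

```python
import asyncio


def chunked(items, size):
    """Yield consecutive slices of `items` with at most `size` elements."""
    for start in range(0, len(items), size):
        yield items[start:start + size]


async def classify_batch(batch, semaphore):
    """Stand-in for a real model call; returns a dummy label per image."""
    async with semaphore:  # cap the number of requests in flight
        await asyncio.sleep(0)  # a real version would await the model call here
        return [{"image_id": i, "correct_label": "class_A"} for i, _ in enumerate(batch)]


async def classify_all(image_paths, batch_size=20, max_concurrency=3):
    semaphore = asyncio.Semaphore(max_concurrency)
    batches = list(chunked(image_paths, batch_size))
    results = await asyncio.gather(*(classify_batch(b, semaphore) for b in batches))
    # image_id restarts at 0 in every batch, so shift it back to a global index
    merged = []
    for batch_index, result in enumerate(results):
        for item in result:
            merged.append({**item, "image_id": item["image_id"] + batch_index * batch_size})
    return merged


predictions = asyncio.run(classify_all([f"img_{i}.jpg" for i in range(45)]))
print(len(predictions), predictions[-1]["image_id"])
```

`asyncio.gather` returns results in the order the batches were submitted, which is what makes the index shift safe.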

In [None]:
response_json
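
The system prompt asks for one prediction per input image, known labels only, and a confidence between 0 and 10. The model can still break those rules, so it may be worth checking the response before scoring it. This is a minimal sketch of such a check. `validate_response` is a hypothetical helper, and the example response is made up.

```python
def validate_response(response_json, n_inputs, allowed_labels):
    """Return a list of problems found in the model output (empty if none)."""
    problems = []
    output = response_json.get("output", [])

    if response_json.get("number_of_labeled_images") != n_inputs:
        problems.append("number_of_labeled_images does not match the number of inputs")

    ids = [item.get("image_id") for item in output]
    if len(ids) != len(set(ids)):
        problems.append("duplicated image_id")
    missing = set(range(n_inputs)) - set(ids)
    if missing:
        problems.append(f"missing predictions for image ids {sorted(missing)}")

    for item in output:
        label = item.get("correct_label")
        if label not in allowed_labels:
            problems.append(f"unknown label {label!r} for image {item.get('image_id')}")
        confidence = item.get("confidence")
        if not isinstance(confidence, int) or not 0 <= confidence <= 10:
            problems.append(f"invalid confidence {confidence!r} for image {item.get('image_id')}")

    return problems


# Made-up response with a duplicated id, a missing id, a bad confidence and an unknown label
example = {
    "number_of_labeled_images": 3,
    "output": [
        {"image_id": 0, "confidence": 9, "correct_label": "River"},
        {"image_id": 1, "confidence": 12, "correct_label": "Forest"},
        {"image_id": 1, "confidence": 5, "correct_label": "Desert"},
    ],
}
problems = validate_response(example, n_inputs=3, allowed_labels={"River", "Forest"})
for problem in problems:
    print(problem)
```

In the notebook, the call would look like `validate_response(response_json, n_inputs=len(df_test), allowed_labels=set(df_train.class_name))`.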

In [None]:
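# Sort by image_id so the predictions line up with the rows of df_test
outputs = sorted(response_json["output"], key=lambda item: item["image_id"])
predictions = [item["correct_label"] for item in outputs]
# Score only the images the model actually labeled
y_true = df_test.class_name.iloc[[item["image_id"] for item in outputs]]

accuracy = accuracy_score(y_true, predictions)
print(f"Accuracy: {accuracy:.4f}")

f1 = f1_score(y_true, predictions, average="weighted")
print(f"F1-score: {f1:.4f}")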
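
A single accuracy number can hide classes the model struggles with. Here is a small sketch of what accuracy measures, plus a per-class breakdown, in plain Python. The labels below are toy values, not results from this tutorial, and both helpers are illustrative.

```python
from collections import Counter


def accuracy(y_true, y_pred):
    """Share of predictions that match the true label."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def per_class_accuracy(y_true, y_pred):
    """For each true class, the share of its images that were labeled correctly."""
    correct, total = Counter(), Counter()
    for t, p in zip(y_true, y_pred):
        total[t] += 1
        correct[t] += t == p
    return {label: correct[label] / total[label] for label in total}


# Toy labels for illustration only
y_true = ["River", "River", "Forest", "Forest", "SeaLake"]
y_pred = ["River", "SeaLake", "Forest", "Forest", "SeaLake"]

print(accuracy(y_true, y_pred))
print(per_class_accuracy(y_true, y_pred))
```

Here the per-class view shows that all of the errors come from `River` images, which the overall score alone would not reveal.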