---
title: "Image Classification using Gemini"
date: "10/08/2024"
date-modified: last-modified
description-meta: "Learn how to use In-Context Learning (ICL) to classify images using Gemini"
toc: true
toc-depth: 3
lightbox: true
fig-cap-location: margin
categories:
  - mllms
  - in-context-learning
  - gemini
author:
  - name: Dylan Castillo
    url: https://dylancastillo.co
    affiliation: Iwana Labs
    affiliation-url: https://iwanalabs.com
citation: true
comments:
  utterances:
    repo: dylanjcastillo/blog_comments
    theme: dark-blue
    issue-term: pathname
draft: true
---

One often-overlooked capability of Multimodal Large Language Models (MLLMs) is using In-Context Learning (ICL) to classify images. With ICL, you provide a small number of labeled example images in the prompt at inference time, and the model uses them to label images it hasn't seen before, without any fine-tuning.

I recently worked on a project for a client that made use of this approach, so I thought it'd be fun to write a short tutorial about it.

## Prerequisites

To follow this tutorial you'll need to:

- Generate an API key in [Google AI Studio](https://aistudio.google.com/app/apikey) and store it in a `GEMINI_API_KEY` environment variable (for example, in a `.env` file)
- Download the [Danish Fungi 2024 – Mini dataset](https://github.com/BohemianVRA/DanishFungiDataset)
- Create a virtual environment and install the requirements:

```bash
python -m venv venv
source venv/bin/activate
pip install pandas numpy scikit-learn nest-asyncio google-generativeai jupyter python-dotenv
```


## Why Gemini Flash 1.5?

I'm going to use Gemini Flash 1.5 for three reasons:

- It's a lot cheaper than [GPT-4o](https://platform.openai.com/pricing) and [Sonnet 3.5](https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs). For an image of 512x512 pixels, Gemini Flash 1.5 is 66x cheaper than Gemini Pro 1.5, 32x cheaper than GPT-4o, and 52x cheaper than Sonnet 3.5[^longnote].
- It lets you use up to 3,000 images per request. By trial and error, I found that GPT-4o seems to have a hard limit at 250 images per request and Sonnet 3.5's documentation mentions a limit of 20 images per request.
- It works well enough for this task. If you really want to squeeze the last bit of performance out of your model, you can use GPT-4o, Sonnet 3.5, or Gemini Pro 1.5.

Regardless of the model you choose, this tutorial will be a good starting point for classifying images using ICL.
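The ratios in the first bullet follow directly from the per-image prices in the footnote. Below is a small sketch that recomputes them for GPT-4o and Sonnet 3.5 and estimates the image-input cost of a request. The prices are the footnote's estimates. `request_cost` and `times_cheaper` are hypothetical helpers, not part of any SDK, and they ignore text tokens.

```python
# Estimated cost per 512x512 image in USD, taken from the footnote
PRICE_PER_IMAGE = {
    "Gemini Flash 1.5": 0.00002,
    "GPT-4o": 0.000638,
    "Sonnet 3.5": 0.001047,
}


def request_cost(model: str, n_images: int) -> float:
    """Hypothetical helper: image-input cost of one request (text tokens ignored)."""
    return PRICE_PER_IMAGE[model] * n_images


def times_cheaper(model: str, baseline: str = "Gemini Flash 1.5") -> int:
    """How many times more `model` costs per image than `baseline`, rounded."""
    return round(PRICE_PER_IMAGE[model] / PRICE_PER_IMAGE[baseline])


for model in ("GPT-4o", "Sonnet 3.5"):
    print(f"Gemini Flash 1.5 is {times_cheaper(model)}x cheaper than {model}")
```

For example, 200 images cost roughly $0.004 with Gemini Flash 1.5 and roughly $0.13 with GPT-4o.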

[^longnote]: Estimated costs as of August 10, 2024:

    | Model | Cost per 512x512 image |
    |-------|------------------------|
    | Gemini Flash 1.5 | $0.00002 |
    | Gemini Pro 1.5 | $0.001315 |
    | GPT-4o | $0.000638 |
    | Sonnet 3.5 | $0.001047 |

## Set up

```python
# Load environment variables (e.g., GEMINI_API_KEY) from a .env file
%load_ext dotenv
%dotenv
# Automatically reload edited modules
%load_ext autoreload
%autoreload 2
```

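Jupyter already runs an asyncio event loop, so calls such as `asyncio.run()` fail inside a notebook. `nest_asyncio` patches asyncio to allow nested event loops, which makes it possible to run async code in Jupyter. Enable it by running the following cell: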

```python
import nest_asyncio

nest_asyncio.apply()
```
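The following stdlib-only sketch shows the problem `nest_asyncio` works around. A Jupyter cell already runs inside an event loop, and inside a running loop a nested `asyncio.run()` call is rejected, while awaiting the coroutine directly works. `double` and `notebook_cell` are hypothetical names used only for illustration.

```python
import asyncio


async def double(x: int) -> int:
    await asyncio.sleep(0)
    return 2 * x


async def notebook_cell() -> tuple[str, int]:
    # Simulates code running inside Jupyter's already-running event loop
    coro = double(21)
    try:
        asyncio.run(coro)  # without nest_asyncio, a nested asyncio.run() is rejected
        error = ""
    except RuntimeError as exc:
        coro.close()  # avoid a "coroutine was never awaited" warning
        error = str(exc)
    # Awaiting directly inside the running loop works
    result = await double(21)
    return error, result


error, result = asyncio.run(notebook_cell())
print(error)   # asyncio.run() cannot be called from a running event loop
print(result)  # 42
```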

## Setting up the system

```python
import asyncio
import base64
import json
import os
import warnings
from datetime import datetime

import google.generativeai as genai
import numpy as np
import pandas as pd
from IPython.display import Image, display, HTML
from sklearn.metrics import accuracy_score, f1_score

warnings.filterwarnings("ignore")

np.random.seed(42)
```



```python
genai.configure(api_key=os.environ["GEMINI_API_KEY"])
```

## Setting up the dataframes

### Read the dataframes and check that their classes match

```python
data_dir = "../data/"
metadata_dir = data_dir + "DF20M-metadata/"
images_dir = data_dir + "DF20M/"

df_train = pd.read_csv(metadata_dir + "DanishFungi2024-Mini-train-subset.csv")
df_test = pd.read_csv(metadata_dir + "DanishFungi2024-Mini-pubtest.csv")
```

```python
# Check that every class appears in both the train and test splits
assert (
    df_test["class_id"].isin(df_train["class_id"]).all()
), "Some test classes are missing from the train set"
assert (
    df_train["class_id"].isin(df_test["class_id"]).all()
), "Some train classes are missing from the test set"

df_test["class_id"].nunique(), df_train["class_id"].nunique()
```

```
(182, 182)
```
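The assertions only tell you *that* the splits disagree. If one fails, a set difference shows *which* classes are missing. This is a minimal sketch: `class_mismatch` is a hypothetical helper, and the class lists in the example are made up.

```python
def class_mismatch(train_ids, test_ids):
    """Return (classes only in train, classes only in test), each sorted."""
    train, test = set(train_ids), set(test_ids)
    return sorted(train - test), sorted(test - train)


only_train, only_test = class_mismatch([2, 50, 100, 7], [2, 50, 100, 9])
print(only_train, only_test)  # [7] [9]
```

In the notebook you could call it as `class_mismatch(df_train["class_id"], df_test["class_id"])`, because `set()` accepts a pandas Series.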

### Build the full image paths

```python
# Keep only the file name from image_path and prefix it with the local images directory
df_train["full_image_path"] = df_train["image_path"].apply(lambda x: f"{images_dir}{x.split('/')[-1]}")
df_test["full_image_path"] = df_test["image_path"].apply(lambda x: f"{images_dir}{x.split('/')[-1]}")
```
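The lambda keeps the file name from `image_path` and joins it to `images_dir`. An equivalent sketch with `pathlib` avoids splitting on `/` by hand. `full_image_path` is a hypothetical helper, and the example path is made up.

```python
from pathlib import Path, PurePosixPath

images_dir = Path("../data/DF20M/")


def full_image_path(image_path: str) -> str:
    # PurePosixPath reads the metadata path with "/" separators on any OS;
    # keep only its file name and join it to the local images directory
    return str(images_dir / PurePosixPath(image_path).name)


print(full_image_path("some/remote/dir/photo_001.JPG"))  # on Linux/macOS: ../data/DF20M/photo_001.JPG
```

You would apply it with `df_train["image_path"].apply(full_image_path)`.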

## Gemini functions

### Prompts

````python
CLASSIFIER_SYSTEM_PROMPT = """You are an expert mycologist specialized in classifying mushrooms from images. Your task is to match input images of mushrooms and assign a label based on the most similar context image (a labeled picture of a mushroom).

Given a list of context images and a list of input images, compare each input image to all context images, determine the best-matching context image for each input image, and assign a label and confidence score.

Provide your output as a JSON object in the following format:

{
    "number_of_classified_images": <integer>,
    "output": [
        {
            "id": <input_image_id>,
            "class_id": <class_id>,
            "confidence": <integer, between 0 and 10, the higher the more confident>
        },
        ...
    ]
}

## Instructions

1. Carefully examine each input image.
2. Compare it to all context images.
3. Determine the most similar context image and assign its label to the input image.
4. Assign a confidence score between 0 and 10, where 10 indicates the highest confidence in the match.

## Guidelines

- ALWAYS produce valid JSON.
- Generate ONLY a single prediction per input image. DO NOT produce duplicate predictions for any input image.
- Make a prediction for ALL input images. If there's no matching label, classify it with the most similar label.
- Each input image must be assigned a SINGLE label based on the context images.
- You MUST only predict the context image labels provided.

## Example

Given these context images:

```
- <images from class 1>, class_id: class_1
- <images from class 2>, class_id: class_2
- <images from class 3>, class_id: class_3
```

These input images:

```
- <image to classify 1>, input_image_id: 0
- <image to classify 2>, input_image_id: 1
- <image to classify 3>, input_image_id: 2
- <image to classify 4>, input_image_id: 3
- <image to classify 5>, input_image_id: 4
```

And this total number of input images: 5

This is an example of a valid output:

```
{
  "number_of_classified_images": 5,
  "output": [
      {
        "id": "0",
        "class_id": "class_2",
        "confidence": 10
      },
      {
        "id": "1",
        "class_id": "class_3",
        "confidence": 9
      },
      {
        "id": "2",
        "class_id": "class_3",
        "confidence": 4
      },
      {
        "id": "3",
        "class_id": "class_1",
        "confidence": 2
      },
      {
        "id": "4",
        "class_id": "class_2",
        "confidence": 10
      }
  ]
}
```
""".strip()
````
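The prompt asks for this JSON shape, but nothing guarantees the model will follow it. Below is a minimal sketch of a check you could run on the parsed response before trusting it. `validate_output` is a hypothetical helper whose rules mirror the format and guidelines in the prompt. It is not a feature of the Gemini SDK.

```python
def validate_output(response_json: dict, allowed_class_ids: set[str]) -> list[str]:
    """Return a list of problems found in a parsed classifier response (empty if none)."""
    problems = []
    items = response_json.get("output", [])
    declared = response_json.get("number_of_classified_images")
    if declared != len(items):
        problems.append(f"declared {declared} images but got {len(items)} predictions")
    for item in items:
        missing = {"id", "class_id", "confidence"} - set(item)
        if missing:
            problems.append(f"prediction {item} is missing keys {sorted(missing)}")
            continue
        if item["class_id"] not in allowed_class_ids:
            problems.append(f"id {item['id']}: unknown class_id {item['class_id']!r}")
        if not 0 <= int(item["confidence"]) <= 10:
            problems.append(f"id {item['id']}: confidence {item['confidence']} out of range")
    return problems
```

For the three-class subset used below, you could call `validate_output(response_json, {f"class_{c}" for c in (2, 50, 100)})` and expect an empty list.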

### Models

```python
generation_config = {
    "temperature": 1,
    "max_output_tokens": 8192,
    "response_mime_type": "application/json",
}
classification_model = genai.GenerativeModel(
    "gemini-1.5-flash-exp-0827",
    system_instruction=CLASSIFIER_SYSTEM_PROMPT,
    generation_config=generation_config,
)
```

### Generate predictions

```python
# Keep only three classes to keep the example small
df_test_subset = df_test[df_test.class_id.isin([2, 50, 100])]
df_train_subset = df_train[df_train.class_id.isin([2, 50, 100])]
```

```python
def create_context_images_message(df):
    """Interleave the training (context) images with their labels.

    All images of a class are added first, followed by a single
    "class_id: class_<id>" text part that labels them.
    """
    messages = []
    for class_id, group in df.groupby("class_id"):
        for image_path in group["full_image_path"]:
            messages.append(Image(image_path))
        messages.append(f"class_id: class_{class_id}")
    return messages


def create_input_images_message(df):
    """Pair each test (input) image with a sequential id that the model echoes back."""
    messages = []
    for i, image_path in enumerate(df["full_image_path"]):
        messages.extend([Image(image_path), f"input_image_id: {i}"])
    return messages


train_images_message = create_context_images_message(df_train_subset)
test_images_message = create_input_images_message(df_test_subset)
```
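To make the prompt layout concrete, here is a stdlib-only sketch that builds the same interleaved sequence with placeholder strings in place of `Image` objects. `build_context_parts` and `build_input_parts` are hypothetical names, and each `<image ...>` string stands in for an actual image.

```python
from itertools import groupby


def build_context_parts(rows: list[tuple[int, str]]) -> list[str]:
    """rows are (class_id, image_path) pairs; each class's images are followed by its label."""
    parts = []
    # Stable sort by class id, similar to DataFrame.groupby's sorted group keys
    ordered = sorted(rows, key=lambda r: r[0])
    for class_id, group in groupby(ordered, key=lambda r: r[0]):
        parts.extend(f"<image {path}>" for _, path in group)
        parts.append(f"class_id: class_{class_id}")
    return parts


def build_input_parts(paths: list[str]) -> list[str]:
    """Each input image is immediately followed by its sequential id."""
    parts = []
    for i, path in enumerate(paths):
        parts.extend([f"<image {path}>", f"input_image_id: {i}"])
    return parts


print(build_context_parts([(50, "a.jpg"), (2, "b.jpg"), (50, "c.jpg")]))
print(build_input_parts(["x.jpg", "y.jpg"]))
```

The model sees each group of context images followed by its label, then each input image followed by its id. This is why the prompt refers to `class_id` and `input_image_id`.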

```python
def display_images_by_class(df, class_id, max_images=6):
    """Render up to `max_images` images of a given class side by side."""
    images_of_class = df[df["class_id"] == class_id].head(max_images)

    html_content = '<div style="display: flex; flex-wrap: wrap;">'
    for _, row in images_of_class.iterrows():
        html_content += f'''
        <div style="margin: 10px;">
            <img src="{row['full_image_path']}" width="300">
            <p>Image path: {row['full_image_path']}</p>
        </div>
        '''
    html_content += "</div>"

    display(HTML(html_content))

# The subsets only contain classes 2, 50, and 100
class_id = 2
display_images_by_class(df_train_subset, class_id)
```

```python
class_id = 2
display_images_by_class(df_test_subset, class_id)
```

```python
contents = train_images_message + test_images_message
```

```python
response = await classification_model.generate_content_async(
    contents=contents
)
response_json = json.loads(response.text)
```

```python
response_json
```

```
{'number_of_classified_images': 40,
 'output': [{'id': '0', 'class_id': 'class_50', 'confidence': 9},
  {'id': '1', 'class_id': 'class_50', 'confidence': 8},
  {'id': '2', 'class_id': 'class_2', 'confidence': 7},
  {'id': '3', 'class_id': 'class_50', 'confidence': 9},
  {'id': '4', 'class_id': 'class_100', 'confidence': 10},
  {'id': '5', 'class_id': 'class_2', 'confidence': 9},
  {'id': '6', 'class_id': 'class_2', 'confidence': 8},
  {'id': '7', 'class_id': 'class_50', 'confidence': 9},
  {'id': '8', 'class_id': 'class_2', 'confidence': 7},
  {'id': '9', 'class_id': 'class_50', 'confidence': 9},
  {'id': '10', 'class_id': 'class_100', 'confidence': 9},
  {'id': '11', 'class_id': 'class_50', 'confidence': 8},
  {'id': '12', 'class_id': 'class_50', 'confidence': 8},
  {'id': '13', 'class_id': 'class_50', 'confidence': 8},
  {'id': '14', 'class_id': 'class_2', 'confidence': 7},
  {'id': '15', 'class_id': 'class_50', 'confidence': 8},
  {'id': '16', 'class_id': 'class_2', 'confidence': 9}
```
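Scoring the predictions against `df_test_subset` only makes sense if the model returned exactly one prediction per input image and you can line each prediction up with its input. Below is a defensive sketch that orders predictions by `input_image_id` and fails loudly on gaps or duplicates. `align_predictions` is a hypothetical helper, not part of the original notebook.

```python
def align_predictions(response_json: dict, n_inputs: int) -> list[int]:
    """Return predicted class ids ordered by input_image_id.

    Raises ValueError if an input is missing, duplicated, or out of range.
    """
    by_id: dict[int, int] = {}
    for item in response_json["output"]:
        idx = int(item["id"])
        if not 0 <= idx < n_inputs:
            raise ValueError(f"unexpected input id {idx}")
        if idx in by_id:
            raise ValueError(f"duplicate prediction for input {idx}")
        # "class_50" -> 50, the same parsing used for the metrics
        by_id[idx] = int(item["class_id"].split("_")[-1])
    missing = sorted(set(range(n_inputs)) - set(by_id))
    if missing:
        raise ValueError(f"missing predictions for inputs {missing}")
    return [by_id[i] for i in range(n_inputs)]
```

In the notebook this would be `predictions = align_predictions(response_json, len(df_test_subset))`.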

```python
# Extract predictions from the response, ordered by input image id
outputs = sorted(response_json["output"], key=lambda item: int(item["id"]))
predictions = [int(item["class_id"].split("_")[-1]) for item in outputs]

# Calculate accuracy
accuracy = accuracy_score(df_test_subset.class_id, predictions)
print(f"Accuracy: {accuracy:.4f}")

# Calculate the weighted F1-score (per-class F1 averaged by class support)
f1 = f1_score(df_test_subset.class_id, predictions, average="weighted")
print(f"F1-score: {f1:.4f}")
```

```
Accuracy: 0.9750
F1-score: 0.9747
```
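With `average="weighted"`, scikit-learn computes an F1 score for each class and averages the scores, weighting each class by its number of true examples (its support). The from-scratch sketch below makes that concrete. It follows the same definition but is not scikit-learn's implementation, and `weighted_f1` is a hypothetical name.

```python
from collections import Counter


def weighted_f1(y_true: list[int], y_pred: list[int]) -> float:
    """F1 per class, averaged with weights equal to each class's support in y_true."""
    support = Counter(y_true)
    pred_counts = Counter(y_pred)
    total = len(y_true)
    score = 0.0
    for label, n_true in support.items():
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == p == label)
        n_pred = pred_counts[label]
        precision = tp / n_pred if n_pred else 0.0
        recall = tp / n_true
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        score += f1 * n_true / total
    return score


print(round(weighted_f1([2, 2, 50, 50], [2, 50, 50, 50]), 4))  # 0.7333
```

A class that appears only in the predictions has zero support, so it adds nothing to the weighted average.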
