# Image Captioning

This notebook contains the scripts used in the image captioning pipeline:

1) Image classification
2) Document summarization (focused on extracting key chemical concepts from raw text)
3) Image caption generation

In [None]:
#%pip install openai tiktoken pillow jinja2 python-dotenv pandas

In [None]:
import io
import os
import re
import uuid
import json
import base64
import logging
import pandas as pd
from PIL import Image
from pathlib import Path
from jinja2 import Template
from datetime import datetime
from dotenv import load_dotenv
from tiktoken.core import Encoding
from openai import AzureOpenAI
from typing import Dict, List, Optional, Tuple

In [None]:
load_dotenv("../my.env")

In [None]:
AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT")
AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY")
AZURE_OPENAI_DEPLOYMENT_NAME = os.getenv("AZURE_OPENAI_DEPLOYMENT")
AZURE_OPENAI_API_VERSION = os.getenv("AZURE_OPENAI_API_VERSION")
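
`os.getenv` returns `None` for any variable that is not set, so a missing setting may only surface later as a less obvious error. A minimal sketch of an early check, assuming the four variable names read above; the helper name `find_missing_settings` is hypothetical:

```python
import os
from typing import Iterable, List, Mapping

# Variable names read by the cell above.
REQUIRED_SETTINGS = (
    "AZURE_OPENAI_ENDPOINT",
    "AZURE_OPENAI_API_KEY",
    "AZURE_OPENAI_DEPLOYMENT",
    "AZURE_OPENAI_API_VERSION",
)


def find_missing_settings(
    env: Mapping[str, str],
    required: Iterable[str] = REQUIRED_SETTINGS,
) -> List[str]:
    """Return the names of required settings that are unset or blank in `env`."""
    return [name for name in required if not (env.get(name) or "").strip()]


missing = find_missing_settings(os.environ)
if missing:
    print(f"Missing environment variables: {', '.join(missing)}")
```

Passing the environment in as a mapping keeps the check easy to exercise with a plain dictionary.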

In [None]:
DOCUMENT_SUMMARIZATION_SYSTEM_PROMPT = "You are a chemistry research assistant specializing in chemical document summarization."
DOCUMENT_SUMMARIZATION_USER_PROMPT = """
As a chemistry research assistant, summarize the document text by extracting only chemical information with accuracy and clarity.
Focus on:
- **Chemical Entities:** List all elements, compounds, and molecules (with IUPAC and common names if relevant).
- **Reactions and Mechanisms:** Summarize key reactions, reactants, products, intermediates, and catalysts. Specify reaction types (e.g., oxidation, reduction) and mechanistic insights.
- **Functional Groups and Stereochemistry:** Mention relevant functional groups, stereochemical details, and molecular structures.
- **Experimental Conditions:** Include essential parameters like temperature, pressure, solvents, and concentrations.
Avoid summarizing non-chemical content. Ensure professional terminology, structured clarity, and accurate representation of chemical concepts.
The summary must be in English and well-organized for readability.
DOCUMENT_TEXT: {{DOCUMENT_TEXT}}
"""


IMAGE_CLASSIFICATION_SYSTEM_PROMPT = """
You are an expert image classification assistant specializing in scientific and chemical images. Your task is to analyze an input image and classify it into **one of the following categories**, returning a structured JSON response.
### **Categories:**
1 → Logos  
2 → Chemical molecules, reactions, formulas
3 → Table images  
4 → Charts (line plots, bar charts, etc., with numerical data representation)  
5 → Process diagrams, schemas (e.g., chemical lab setup)  
6 → Regular photos (e.g., field, lab, general photography)  
7 → Measurements (e.g., microscopic images)  
8 → Chemical manual writing (e.g., handwritten diagrams, chemical reactions)
9 → Hand-written signatures
10 → Other 

Always return **only a single category** that best matches the image. Prioritize **chemically and scientifically relevant** labels when applicable. Your response must be in **valid JSON format**, containing:
- `"image_class"` - The assigned category number (integer from 1-10).
- `"probability_score"` - Confidence level (float between 0 and 1).

**Do not include any explanations, extra text, or formatting beyond the required JSON output.**
"""
IMAGE_CLASSIFICATION_USER_PROMPT = """
Classify the following image into one of the predefined categories and return a **JSON object** strictly in the following format:
{
  "image_class": <category_number>,
  "probability_score": <confidence_value>
}
"""


IMAGE_CAPTIONING_SYSTEM_PROMPT = "You are a chemistry research assistant specializing in generating precise captions for chemical images."
IMAGE_CAPTIONING_USER_PROMPT = """
# IMPORTANT PRINCIPLES:
## Accuracy & Completeness:
- Use IMAGE_TITLE and DOCUMENT_CONTEXT to find additional information for generating a precise caption.
- Always ensure chemical names, reactions, functional groups, and molecular structures are correctly described.
- Avoid ambiguous descriptions; explicitly state what is shown in the image.
- Provide molecular names using both IUPAC nomenclature and common names (if applicable).
- If oxidation states, valency, or reaction mechanisms are involved, verify correctness before captioning.

## Consistency with Image Data:
- Do not assume missing elements; caption exactly what is visible.
- If multiple interpretations are possible, specify the conditions under which each occurs.
- Use correct chemical terminology, including stereochemical descriptors where needed.

## Response Formatting Based on Image Type:
- Follow structured output rules for different categories of chemical images.
- Ensure chemical equations, reaction conditions, tables, and plots are formatted precisely using Markdown where applicable.


# IMAGE CATEGORY-SPECIFIC CAPTIONING RULES:
## Chemical Molecules and Reactions:
- Identify the molecule's IUPAC name and common name (if applicable).
- Clearly state functional groups, stereochemistry, and substituents' positions.
- Clearly state the meaning of each color if relevant, e.g. colored spheres representing atoms.
- If the image contains a reaction, list reactants, intermediates, and products, specifying the reaction type (e.g., oxidation, reduction, condensation).
- If the image contains a molecule, explain its structure in detail, including all atoms and layers seen in the image; if multiple colors are used, explain the meaning of each one if relevant.
- Search for additional details (names of chemical elements) in IMAGE_TITLE and DOCUMENT_CONTEXT.
- Mention relevant catalysts, solvents, temperature, and pressure conditions.
- Ensure electron flow and mechanistic pathways (if applicable) are correctly represented.

## Microscopy & Crystallographic Images:
- Specify the chemical composition of observed structures.
- Indicate phase, crystal system, and lattice parameters where possible.
- Mention any defects, grain boundaries, or other notable structural features.
- Identify scale bars, magnification, and imaging techniques (e.g., SEM, TEM, XRD, AFM).

## Experimental Laboratory Setups:
- Provide a clear description of the setup, including essential apparatus and chemicals involved.
- Explain experimental parameters such as temperature, pressure, pH, and reagent concentrations.
- If the image contains a reaction, summarize its purpose and expected outcome.

## Tables (Tabular Data Representation):
- Extract all table contents exactly as shown, maintaining structure and formatting in Markdown:
| Column 1 | Column 2 | Column 3 |
|----------|----------|----------|
| Data 1   | Data 2   | Data 3   |
- Ensure all abbreviations are correctly expanded unless universally understood.
- In case of complex table structure, ensure that markdown columns are properly shifted and completely match the table image.

## Graphs, Charts, and Plots:
- Given a line plot, bar chart, scatter plot, or any other similar chart, first of all ALWAYS accurately reconstruct it in a tabular (Markdown) format (if numerical points are limited in an image, ALWAYS extrapolate them precisely to get a table), for example:
| Time [min] (X axis) | Amount gas [ml] (Y axis) | Legend                 |
|---------------------|--------------------------|------------------------|
| 0                   | 0                        | Jaegers-L-00239 (Blue) |
| 100                 | 200                      | Jaegers-L-00239 (Blue) |
- Don't describe visual aspects of graphs, charts, and plots. Focus on quantitative data.
- Summarize (numerically) notable trends, peak values, inflection points, and outliers.
- Include equation-based descriptions for any regression models, best-fit lines, or calculated values.

## Handwritten Notes (Formulas, Reactions, Calculations):
- Transcribe chemical equations, formulas, and reaction mechanisms with precision.
- Ensure correct subscripts, superscripts, charges, and reaction arrows.
- If handwritten content is ambiguous, provide a clarification note while maintaining the original meaning.


# COMMON ERRORS TO AVOID:
## Missing Key Observations:
- Ensure all significant features (e.g., acetate presence in crystal images) are described.

## Ambiguous Abbreviations:
- Expand uncommon abbreviations unless contextually evident.

## Incorrect Functional Group Assignments:
- Validate oxidation states, functional group transformations, and reaction mechanisms.
- Avoid confusion between aldehydes, ketones, alcohols, and carboxylic acids.

## Incomplete Molecular Descriptions:
- Always specify where substituents are attached (e.g., "X is on the 3-position of the phenyl ring").


# FINAL OUTPUT FORMAT REQUIREMENTS
- Responses must be structured in Markdown where applicable.
- Chemical equations and formulas should be formatted correctly, using LaTeX notation where necessary.
- Tables and numerical data must be accurately transcribed in Markdown table format.
- Captions must be precise, detailed, and contextually appropriate for a professional audience.
- Avoid assumptions and describe only the observable features in the image.

IMAGE_TITLE: {{IMAGE_TITLE}}
DOCUMENT_CONTEXT: {{DOCUMENT_CONTEXT}}
"""

In [None]:
def validate_image(image_path: str, max_image_size: int = 20 * 1024 * 1024) -> bool:
    """Validate if the image is suitable for processing."""
    try:
        if not os.path.exists(image_path):
            print(f"Image file does not exist: {image_path}")
            return False

        # Check file size
        file_size = os.path.getsize(image_path)
        if file_size > max_image_size:
            print(f"Image too large ({file_size} bytes): {image_path}")
            return False

        # Verify image can be opened and is valid
        with Image.open(image_path) as img:
            img.verify()
        # Check if image has valid dimensions
        with Image.open(image_path) as img:
            width, height = img.size
            if width == 0 or height == 0:
                print(f"Invalid image dimensions: {image_path}")
                return False

            if width > 32768 or height > 32768:
                print(f"Image dimensions too large: {width}x{height}")
                return False
        return True
    except Exception as e:
        print(f"Image validation failed for {image_path}: {str(e)}")
        return False

    
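def preprocess_image(image_path: str) -> Optional[Image.Image]:
    """Convert the image to RGB or grayscale, downscale it to at most 2048 px per side, and re-encode it as JPEG."""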
    try:
        with Image.open(image_path) as img:
            if img.mode not in ('RGB', 'L'):
                img = img.convert('RGB')

            max_dimension = 2048
            if img.width > max_dimension or img.height > max_dimension:
                ratio = min(max_dimension / img.width, max_dimension / img.height)
                # max(1, ...) avoids a zero-pixel side for images with an extreme aspect ratio
                new_size = (max(1, int(img.width * ratio)), max(1, int(img.height * ratio)))
                img = img.resize(new_size, Image.Resampling.LANCZOS)

            img_byte_arr = io.BytesIO()
            img.save(img_byte_arr, format='JPEG', quality=85)
            img_byte_arr.seek(0)

            return Image.open(img_byte_arr)

    except Exception as e:
        print(f"Image preprocessing failed for {image_path}: {str(e)}")
        return None

    
def encode_image(image_path: str) -> Optional[str]:
    """Encode image as base64 with proper validation and preprocessing."""
    try:
        if not validate_image(image_path):
            return None

        processed_img = preprocess_image(image_path)
        if processed_img is None:
            return None

        img_byte_arr = io.BytesIO()
        processed_img.save(img_byte_arr, format='JPEG', quality=85)
        img_byte_arr.seek(0)
        base64_encoded = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8')

        try:
            base64.b64decode(base64_encoded)
            return base64_encoded
        except Exception as e:
            print(f"Base64 validation failed for {image_path}: {str(e)}")
            return None

    except Exception as e:
        print(f"Image encoding failed for {image_path}: {str(e)}")
        return None


def get_image_classification(image_path, client, deployment_name):
    """Classify an image into one of the predefined categories; return the parsed JSON dict or None."""
    try:
        base64_image = encode_image(image_path)
        if not base64_image:
            print(f"Skipping classification for invalid image: {image_path}")
            return None
        
        max_retries = 3
        for attempt in range(max_retries):
            try:
                response = client.chat.completions.create(
                    model = deployment_name,
                    response_format={ "type": "json_object" },
                    messages=[
                        {
                            "role": "system",
                            "content": IMAGE_CLASSIFICATION_SYSTEM_PROMPT
                        },
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", 
                                 "text": IMAGE_CLASSIFICATION_USER_PROMPT
                                },
                                {
                                    "type": "image_url",
                                    "image_url": {
                                        "url": f"data:image/jpeg;base64,{base64_image}"
                                    }
                                }
                            ]
                        }
                    ],
                    temperature=0.0,
                    max_tokens=100
                )
                return json.loads(response.choices[0].message.content)

            except Exception as e:
                if attempt < max_retries - 1:
                    print(f"Image classification attempt {attempt + 1} failed: {str(e)}")
                    continue
                else:
                    print(f"All image classification attempts failed for {image_path}")
                    return None

    except Exception as e:
        print(f"Image classification failed: {str(e)}")
        return None
    

def get_summary(raw_text, client, deployment_name):
    """Summarize the chemical content of a document; return the summary text or an empty string."""
    try:
        template = Template(DOCUMENT_SUMMARIZATION_USER_PROMPT)
        user_prompt = template.render(DOCUMENT_TEXT=raw_text)
        
        max_retries = 3
        for attempt in range(max_retries):
            try:
                response = client.chat.completions.create(
                    model = deployment_name,
                    messages=[
                        {
                            "role": "system",
                            "content": DOCUMENT_SUMMARIZATION_SYSTEM_PROMPT
                        },
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", 
                                 "text": user_prompt
                                }
                            ]
                        }
                    ],
                    temperature=0.0,
                    max_tokens=4000
                )
                return response.choices[0].message.content

            except Exception as e:
                if attempt < max_retries - 1:
                    print(f"Summary generation attempt {attempt + 1} failed: {str(e)}")
                    continue
                else:
                    # image_path is not defined in this function, so it must not appear in the message
                    print("All summary generation attempts failed")
                    return ""

    except Exception as e:
        print(f"Summary generation failed: {str(e)}")
        return ""
    
    
def generate_caption(image_path, figure_title, document_summary, vision_client, vision_deployment_name) -> str:
    """Generate image caption using GPT-4 Vision."""
    try:
        base64_image = encode_image(image_path)
        if not base64_image:
            print(f"Skipping caption generation for invalid image: {image_path}")
            return ""
        
        template = Template(IMAGE_CAPTIONING_USER_PROMPT)
        user_prompt = template.render(IMAGE_TITLE=figure_title, DOCUMENT_CONTEXT=document_summary)
        
        max_retries = 3
        for attempt in range(max_retries):
            try:
                response = vision_client.chat.completions.create(
                    model = vision_deployment_name,
                    messages=[
                        {
                            "role": "system",
                            "content": IMAGE_CAPTIONING_SYSTEM_PROMPT
                        },
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", 
                                 "text": user_prompt
                                },
                                {
                                    "type": "image_url",
                                    "image_url": {
                                        "url": f"data:image/jpeg;base64,{base64_image}"
                                    }
                                }
                            ]
                        }
                    ],
                    temperature=0.0,
                    max_tokens=2000
                )
                return response.choices[0].message.content

            except Exception as e:
                if attempt < max_retries - 1:
                    print(f"Caption generation attempt {attempt + 1} failed: {str(e)}")
                    continue
                else:
                    print(f"All caption generation attempts failed for {image_path}")
                    return ""

    except Exception as e:
        print(f"Caption generation failed for {image_path}: {str(e)}")
        return ""
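
The downscaling step in `preprocess_image` keeps the aspect ratio by applying a single scale factor to both sides. The arithmetic can be sketched in isolation, without any image library; `fit_within` is a hypothetical helper name, and the `max(1, ...)` guard is an added assumption for images with an extreme aspect ratio:

```python
from typing import Tuple


def fit_within(width: int, height: int, max_dimension: int = 2048) -> Tuple[int, int]:
    """Return (width, height) scaled so that neither side exceeds max_dimension."""
    if width <= max_dimension and height <= max_dimension:
        return width, height  # already small enough, leave unchanged
    # The smaller of the two ratios guarantees that both sides fit.
    ratio = min(max_dimension / width, max_dimension / height)
    # max(1, ...) keeps a very thin image from collapsing to a zero-pixel side.
    return max(1, int(width * ratio)), max(1, int(height * ratio))


print(fit_within(4096, 1024))  # (2048, 512)
print(fit_within(800, 600))    # (800, 600)
```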

In [None]:
aoai_vision_client = AzureOpenAI(
    api_key=AZURE_OPENAI_API_KEY,
    azure_endpoint=AZURE_OPENAI_ENDPOINT,
    api_version=AZURE_OPENAI_API_VERSION
)
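
The helper functions above retry a failed request up to three times with no pause between attempts. If the failures come from rate limiting, waiting between attempts may help. A minimal sketch of exponential backoff around any zero-argument callable; `call_with_backoff` and its parameters are hypothetical and not part of the `openai` package:

```python
import time
from typing import Callable, TypeVar

T = TypeVar("T")


def call_with_backoff(
    func: Callable[[], T],
    max_retries: int = 3,
    base_delay: float = 1.0,
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    """Call func, retrying on any exception with delays of base_delay * 2**attempt."""
    for attempt in range(max_retries):
        try:
            return func()
        except Exception as e:
            if attempt == max_retries - 1:
                raise  # out of attempts: let the caller decide what to do
            delay = base_delay * (2 ** attempt)
            print(f"Attempt {attempt + 1} failed: {e}; retrying in {delay:.1f}s")
            sleep(delay)
    raise ValueError("max_retries must be at least 1")
```

A request could then be wrapped in a `lambda` and passed as `func`. The `sleep` parameter exists only so the delay can be replaced when experimenting.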

### Image classification

In [None]:
image_path = "<IMAGE PATH>"
result = get_image_classification(image_path, aoai_vision_client, AZURE_OPENAI_DEPLOYMENT_NAME)
result
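
`get_image_classification` returns whatever JSON the model produced, so the value is worth checking before it is used downstream. A minimal sketch, assuming the ten categories and the two keys defined in the classification prompt; `IMAGE_CLASS_LABELS` and `parse_classification` are hypothetical names:

```python
from typing import Any, Optional, Tuple

# Labels copied from the category list in IMAGE_CLASSIFICATION_SYSTEM_PROMPT.
IMAGE_CLASS_LABELS = {
    1: "Logos",
    2: "Chemical molecules, reactions, formulas",
    3: "Table images",
    4: "Charts",
    5: "Process diagrams, schemas",
    6: "Regular photos",
    7: "Measurements",
    8: "Chemical manual writing",
    9: "Hand-written signatures",
    10: "Other",
}


def parse_classification(result: Any) -> Optional[Tuple[int, str, float]]:
    """Return (class_id, label, score) if result is well-formed, else None."""
    if not isinstance(result, dict):
        return None
    image_class = result.get("image_class")
    score = result.get("probability_score")
    # bool is a subclass of int, so it is excluded explicitly.
    if isinstance(image_class, bool) or not isinstance(image_class, int):
        return None
    if image_class not in IMAGE_CLASS_LABELS:
        return None
    if isinstance(score, bool) or not isinstance(score, (int, float)):
        return None
    if not 0.0 <= score <= 1.0:
        return None
    return image_class, IMAGE_CLASS_LABELS[image_class], float(score)


print(parse_classification({"image_class": 4, "probability_score": 0.93}))
# (4, 'Charts', 0.93)
```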

### Document summarization

In [None]:
document_raw_text = "<DOCUMENT RAW TEXT>"
document_summary = get_summary(document_raw_text, aoai_vision_client, AZURE_OPENAI_DEPLOYMENT_NAME)
document_summary
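
`get_summary` sends the whole document in a single request, which may be too much input for very long documents. One simple mitigation is to cap the input length before calling it. A minimal sketch that truncates by characters at a whitespace boundary; the helper name and the 40,000-character default are arbitrary assumptions, and a token-based limit would be more exact:

```python
def truncate_text(text: str, max_chars: int = 40_000) -> str:
    """Return text cut to at most max_chars characters, preferring a whitespace boundary."""
    if len(text) <= max_chars:
        return text
    cut = text[:max_chars]
    # Back up to the last space or newline so that a word is not split in half.
    boundary = max(cut.rfind(" "), cut.rfind("\n"))
    if boundary > 0:
        cut = cut[:boundary]
    return cut.rstrip()


print(truncate_text("alpha beta gamma", max_chars=12))  # alpha beta
```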

### Image captioning

In [None]:
image_path = "<IMAGE PATH>"
document_summary = "<DOCUMENT SUMMARY>"
figure_title = "<FIGURE TITLE>" # can be empty

image_caption = generate_caption(image_path, figure_title, document_summary, aoai_vision_client, AZURE_OPENAI_DEPLOYMENT_NAME)
image_caption
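
The three steps can be chained so that only images worth describing reach the captioning call. A minimal sketch under stated assumptions: the classifier and captioner are passed in as callables (for example thin wrappers around `get_image_classification` and `generate_caption`), and skipping logos (1) and hand-written signatures (9) is an illustrative policy, not something this notebook prescribes. `caption_if_relevant` and `SKIPPED_CLASSES` are hypothetical names:

```python
from typing import Callable, Optional

# Assumed policy: categories 1 (Logos) and 9 (Hand-written signatures) are not captioned.
SKIPPED_CLASSES = frozenset({1, 9})


def caption_if_relevant(
    image_path: str,
    figure_title: str,
    document_summary: str,
    classify: Callable[[str], Optional[dict]],
    caption: Callable[[str, str, str], str],
) -> str:
    """Classify the image first and caption it only if its class is not skipped."""
    result = classify(image_path)
    if not result:
        return ""  # classification failed; use the notebook's empty-string convention
    if result.get("image_class") in SKIPPED_CLASSES:
        return ""
    return caption(image_path, figure_title, document_summary)


# With the notebook's functions, the callables could be built like this:
# classify = lambda p: get_image_classification(p, aoai_vision_client, AZURE_OPENAI_DEPLOYMENT_NAME)
# caption = lambda p, t, s: generate_caption(p, t, s, aoai_vision_client, AZURE_OPENAI_DEPLOYMENT_NAME)
```

Injecting the two callables keeps the routing logic independent of the Azure client, so it can be tried out with stand-in functions before any request is sent.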