In [None]:
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
from PIL import Image
import os
import numpy as np

# Hugging Face Hub checkpoint shared by the processor and the model.
MODEL_ID = 'allenai/Molmo-7B-O-0924'

# Keyword arguments passed identically to both `from_pretrained` calls.
# trust_remote_code is required because this checkpoint supplies its own
# processor/model code (e.g. `processor.process`, `model.generate_from_batch`),
# and dtype / device placement are both left on 'auto'.
_LOAD_KWARGS = dict(trust_remote_code=True, torch_dtype='auto', device_map='auto')

# Processor (image preprocessing + `processor.tokenizer`) and the model itself.
processor = AutoProcessor.from_pretrained(MODEL_ID, **_LOAD_KWARGS)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **_LOAD_KWARGS)

# Images are read from `input_folder`; one transcription .txt per image is
# written to `output_folder`, which is created here if missing.
input_folder = 'aachen_validation_set'
output_folder = 'transcriptions_IAM_A-D_molmo'
os.makedirs(output_folder, exist_ok=True)

def preprocess_image(image):
    """Return an RGB copy of *image*, ready to be passed to the Molmo processor.

    ``Image.convert('RGB')`` always produces a three-channel image: grayscale,
    palette and other modes are expanded, and an image already in RGB mode
    comes back as a copy. No further channel fix-ups are therefore needed.

    Args:
        image: A PIL image in any mode PIL can convert to RGB.

    Returns:
        A new PIL image in ``'RGB'`` mode with the same pixel content.

    NOTE(review): images with an alpha channel (RGBA, LA, or palette images
    with transparency) are not composited onto a background here. If the
    dataset contains transparent PNGs, consider flattening them onto white
    first so transparent regions do not come out dark.
    """
    return image.convert('RGB')

# Instruction sent to Molmo together with every image (built once, outside the loop).
TRANSCRIPTION_PROMPT = (
    "You are an AI assistant specialized in transcribing handwritten text from images. "
    "Your task is to focus solely on the handwritten portions of the provided image and transcribe them accurately. "
    "Please follow these guidelines: "
    "1. Examine the image carefully and identify all handwritten text. "
    "2. Transcribe ONLY the handwritten text. Ignore any printed or machine-generated text in the image. "
    "3. Maintain the original structure of the handwritten text, including line breaks and paragraphs. "
    "4. Do not attempt to correct spelling or grammar in the handwritten text. Transcribe it exactly as written. "
    "Please begin your response directly with the transcribed text. "
    "Remember, your goal is to provide an accurate transcription of ONLY the handwritten portions of the text, "
    "preserving its original form as much as possible."
)

# File extensions (compared case-insensitively) treated as input images.
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.tiff', '.bmp', '.gif')

# Upper bound on newly generated tokens per image.
# NOTE(review): long multi-line pages may need more than 200 tokens; raise this
# if transcriptions come back cut off.
MAX_NEW_TOKENS = 200


def transcribe_image(image):
    """Run Molmo on one RGB PIL image and return the decoded transcription.

    Uses the module-level ``processor`` and ``model`` together with
    ``TRANSCRIPTION_PROMPT``. Generation stops at ``<|endoftext|>`` or after
    ``MAX_NEW_TOKENS`` new tokens. Only the tokens produced after the prompt
    are decoded, with special tokens stripped.
    """
    inputs = processor.process(images=[image], text=TRANSCRIPTION_PROMPT)

    # `processor.process` returns unbatched tensors: move each one to the
    # model's device and add a leading batch dimension of size 1.
    inputs = {k: v.to(model.device).unsqueeze(0) for k, v in inputs.items()}

    output = model.generate_from_batch(
        inputs,
        GenerationConfig(max_new_tokens=MAX_NEW_TOKENS, stop_strings="<|endoftext|>"),
        tokenizer=processor.tokenizer,
    )

    # The returned sequence starts with the prompt tokens; slice them off so
    # only the model's answer is decoded.
    generated_tokens = output[0, inputs['input_ids'].size(1):]
    return processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)


# Process every image in the input folder, in sorted order so runs are
# reproducible (os.listdir order is arbitrary).
for filename in sorted(os.listdir(input_folder)):
    if not filename.lower().endswith(IMAGE_EXTENSIONS):
        continue

    image_path = os.path.join(input_folder, filename)
    output_path = os.path.join(output_folder, f"{os.path.splitext(filename)[0]}_transcription.txt")

    # Open the file with a context manager so its handle is released right away.
    # preprocess_image returns a new image, so the result stays usable after
    # the file is closed. An unreadable or corrupt file (PIL raises OSError
    # subclasses) is reported and skipped instead of aborting the whole batch.
    try:
        with Image.open(image_path) as raw_image:
            image = preprocess_image(raw_image)
    except OSError as err:
        print(f"Skipping {filename}: could not read image ({err})")
        continue

    generated_text = transcribe_image(image)

    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(generated_text)

    print(f"Transcription for {filename} saved to {output_path}")

print("All images processed.")
