In [None]:
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
from PIL import Image
import os
import numpy as np
import difflib

# Checkpoint and loading options shared by the processor and the model.
# Both are loaded from the same Hugging Face repository with identical
# keyword arguments, so they are spelled out once and reused.
_MOLMO_CHECKPOINT = 'allenai/Molmo-7B-D-0924'
_MOLMO_LOAD_KWARGS = {
    'trust_remote_code': True,
    'torch_dtype': 'auto',
    'device_map': 'auto',
}

# Processor: turns an image plus a text prompt into model inputs
processor = AutoProcessor.from_pretrained(_MOLMO_CHECKPOINT, **_MOLMO_LOAD_KWARGS)

# Model: generates the transcription tokens
model = AutoModelForCausalLM.from_pretrained(_MOLMO_CHECKPOINT, **_MOLMO_LOAD_KWARGS)

# Input and output locations (relative to the current working directory)
input_folder = 'IAMa_cropped'  # Folder scanned for the images to transcribe
output_folder = 'transcriptions_IAM2_molmo'  # Root folder for results; each image gets its own subfolder with one text file per pass

# Create the output folder if it doesn't exist
os.makedirs(output_folder, exist_ok=True)

# Prompts for the initial transcription pass.
# The two are joined (system prompt first, blank line, then user prompt) into a
# single text prompt when the script runs.
USER_PROMPT = "Please transcribe as accurately as possible the handwritten portions of the provided image."

SYSTEM_PROMPT = """You are an AI assistant specialized in transcribing handwritten text from images. Please follow these guidelines:
1. Examine the image carefully and identify all handwritten text.
2. Transcribe ONLY the handwritten text. Ignore any printed or machine-generated text in the image.
3. Maintain the original structure of the handwritten text, including line breaks and paragraphs.
4. Do not attempt to correct spelling or grammar in the handwritten text. Transcribe it exactly as written.
5. Do not describe the image or its contents.
6. Do not introduce or contextualize the transcription.
Please begin your response directly with the transcribed text. Remember, your goal is to provide an accurate transcription of ONLY the handwritten portions of the text, preserving its original form as much as possible."""

# Refinement prompts, applied in list order with one generation pass each.
# Every pass is given the text produced by the pass before it, so the order
# here determines the order of the corrections (wording, then layout, then a
# final check).
refinement_prompts = [
    "Review the original image and your previous transcription. Focus on correcting any spelling errors, punctuation mistakes, or missed words. Ensure the transcription accurately reflects the handwritten text.",
    "Examine the structure of the transcription. Are paragraphs and line breaks correctly represented? Adjust the layout to match the original handwritten text more closely.",
    "Make a final pass over the transcription, comparing it closely with the original image. Make any last corrections or improvements to ensure the highest possible accuracy. Do not add any introduction or contextualization you might have added to the transcribed text. Start directly with the transcription."
]

def preprocess_image(image):
    """
    Return an RGB copy of the given image.

    Args:
        image: A PIL image in any mode that PIL can convert to 'RGB'
            (grayscale, palette, RGBA, ...).

    Returns:
        A new PIL image in mode 'RGB'. The input image is not modified.
    """
    # convert('RGB') always produces a three-channel image, whatever the
    # input mode, so no manual grayscale/single-channel expansion through
    # numpy is needed afterwards. It also returns a loaded copy when the
    # image is already RGB, so the result never depends on the source file
    # remaining open.
    return image.convert('RGB')

def print_differences(original, modified):
    """
    Build a line-by-line diff between two transcriptions.

    Despite the name, nothing is printed: the diff produced by
    difflib.ndiff is returned as a single newline-joined string, leaving
    it to the caller to display it.
    """
    old_lines = original.splitlines()
    new_lines = modified.splitlines()
    delta_lines = list(difflib.ndiff(old_lines, new_lines))
    return '\n'.join(delta_lines)

# File extensions (compared case-insensitively) that are treated as images
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.tiff', '.bmp', '.gif')


def _generate_text(image, prompt, max_new_tokens=200):
    """
    Run one generation pass of the model on an image/prompt pair.

    Args:
        image: Preprocessed PIL image passed to the processor.
        prompt: Full text prompt accompanying the image.
        max_new_tokens: Upper bound on the number of generated tokens.
            NOTE(review): with the default of 200, longer transcriptions
            may be cut off - confirm this limit suits the input pages.

    Returns:
        The decoded text of the newly generated tokens, with special
        tokens removed.
    """
    inputs = processor.process(images=[image], text=prompt)

    # Move every input tensor to the model's device and add a leading batch dimension
    inputs = {k: v.to(model.device).unsqueeze(0) for k, v in inputs.items()}

    output = model.generate_from_batch(
        inputs,
        GenerationConfig(max_new_tokens=max_new_tokens, stop_strings="<|endoftext|>"),
        tokenizer=processor.tokenizer
    )

    # Skip the first input_ids.size(1) positions (the prompt tokens) so that
    # only the newly generated tokens are decoded
    generated_tokens = output[0, inputs['input_ids'].size(1):]
    return processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)


def _save_text(path, text):
    """Write text to path as UTF-8, replacing any existing file."""
    with open(path, 'w', encoding='utf-8') as f:
        f.write(text)


# The initial prompt is the same for every image, so build it once
full_prompt = f"{SYSTEM_PROMPT}\n\n{USER_PROMPT}"

# Process each image in the input folder. os.listdir returns entries in
# arbitrary order; sorting makes the processing order reproducible.
for filename in sorted(os.listdir(input_folder)):
    if not filename.lower().endswith(_IMAGE_EXTENSIONS):
        continue

    image_path = os.path.join(input_folder, filename)

    # Create a folder for the current image's transcriptions
    image_output_folder = os.path.join(output_folder, os.path.splitext(filename)[0])
    os.makedirs(image_output_folder, exist_ok=True)

    # Open and preprocess the image. preprocess_image returns a new in-memory
    # image, so the file handle can be released right away.
    with Image.open(image_path) as raw_image:
        image = preprocess_image(raw_image)

    # Initial transcription (saved as step 0)
    generated_text = _generate_text(image, full_prompt)
    _save_text(os.path.join(image_output_folder, 'refinement_step_0.txt'), generated_text)

    print(f"Initial transcription for {filename}:\n{generated_text}\n")

    # Refinement passes: each one is given the text produced by the previous pass
    refined_text = generated_text

    for step, refinement_prompt in enumerate(refinement_prompts, start=1):
        previous_refined_text = refined_text

        refinement_full_prompt = f"{SYSTEM_PROMPT}\n\n{refinement_prompt}\n\nTranscription:\n{previous_refined_text}\n\nOriginal Image: {filename}"
        refined_text = _generate_text(image, refinement_full_prompt)

        # Save the refined transcription in the folder for this image
        _save_text(os.path.join(image_output_folder, f'refinement_step_{step}.txt'), refined_text)

        # Display the refined transcription and what changed in this pass
        differences = print_differences(previous_refined_text, refined_text)
        print(f"Refined transcription for {filename} (Step {step}):\n{refined_text}\n")
        print(f"Differences from previous transcription (Step {step - 1}):\n{differences}\n")

    print(f"All refinements for {filename} saved in {image_output_folder}")

print("All images processed with refinements.")
