In [None]:
import torch
from PIL import Image
from transformers import AutoModelForCausalLM
import os
from tqdm import tqdm

# Fixed paths and prompt.
# Folder scanned for input images (a relative path, resolved against the
# current working directory).
INPUT_FOLDER = "aachen_validation_set"
# Folder that receives one .txt transcription per processed image.
OUTPUT_FOLDER = "transcriptions_IAM_ovis"
# Instruction sent to the model alongside every image. The adjacent string
# literals below are joined at compile time into a single one-line string.
TRANSCRIPTION_PROMPT = (
    "You are an AI assistant specialized in transcribing handwritten text "
    "from images. Your task is to focus solely on the handwritten portions "
    "of the provided image and transcribe them accurately. "
    "Please follow these guidelines: "
    "1. Examine the image carefully and identify all handwritten text. "
    "2. Transcribe ONLY the handwritten text. Ignore any printed or "
    "machine-generated text in the image. "
    "3. Maintain the original structure of the handwritten text, including "
    "line breaks and paragraphs. "
    "4. Do not attempt to correct spelling or grammar in the handwritten "
    "text. Transcribe it exactly as written. "
    "Please begin your response directly with the transcribed text. "
    "Remember, your goal is to provide an accurate transcription of ONLY "
    "the handwritten portions of the text, preserving its original form "
    "as much as possible."
)

# Load model: Ovis 1.6 (Gemma2-9B backbone) with weights in bfloat16.
# trust_remote_code=True lets transformers execute the model code published in
# the Hub repository (presumably where the Ovis classes and their
# get_*_tokenizer helpers are defined — confirm against the model card).
model = AutoModelForCausalLM.from_pretrained(
    "AIDC-AI/Ovis1.6-Gemma2-9B",
    torch_dtype=torch.bfloat16,
    # Ovis-specific option; presumably caps the combined image+text token
    # sequence length — TODO confirm in the model repo.
    multimodal_max_length=8192,
    trust_remote_code=True,
)
# nn.Module.cuda() moves the parameters to the current CUDA device and
# returns the same module.
model = model.cuda()

# The model hands out separate tokenizers for its text and vision inputs.
text_tokenizer = model.get_text_tokenizer()
visual_tokenizer = model.get_visual_tokenizer()

# Create output folder if it doesn't exist
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

# File extensions treated as images (matched case-insensitively).
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.gif', '.bmp')

# The query and the decoding settings are identical for every image, so they
# are built once here rather than on every loop iteration.
query = f'<image>\n{TRANSCRIPTION_PROMPT}'
gen_kwargs = dict(
    max_new_tokens=1024,
    do_sample=False,  # greedy decoding
    # Explicit None values, presumably to override any sampling defaults in
    # the model's generation_config — confirm if warnings appear.
    top_p=None,
    top_k=None,
    temperature=None,
    repetition_penalty=None,
    eos_token_id=model.generation_config.eos_token_id,
    pad_token_id=text_tokenizer.pad_token_id,
    use_cache=True,
)


def transcribe_image(image):
    """Transcribe the handwritten text in one image with the loaded Ovis model.

    Uses the module-level ``model``, ``text_tokenizer``, ``visual_tokenizer``,
    ``query`` and ``gen_kwargs``.

    Args:
        image: A loaded PIL image.

    Returns:
        The decoded model output with special tokens removed.
    """
    # Format conversation: Ovis produces the prompt token ids and the image
    # tensor together from the query and the image list.
    _, input_ids, pixel_values = model.preprocess_inputs(query, [image])
    # Attend to every token that is not padding.
    attention_mask = torch.ne(input_ids, text_tokenizer.pad_token_id)
    # Add a leading batch dimension of size 1 and move to the model's device.
    input_ids = input_ids.unsqueeze(0).to(device=model.device)
    attention_mask = attention_mask.unsqueeze(0).to(device=model.device)
    # One pixel_values entry per batch item, cast to the visual tokenizer's
    # dtype and device.
    pixel_values = [pixel_values.to(dtype=visual_tokenizer.dtype, device=visual_tokenizer.device)]

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            **gen_kwargs,
        )[0]
    # NOTE(review): decoding the whole first row assumes Ovis's generate()
    # returns only newly generated tokens (as in the upstream usage example);
    # re-check if the model is swapped out.
    return text_tokenizer.decode(output_ids, skip_special_tokens=True)


# Sort so files are processed in a reproducible order (os.listdir order is
# arbitrary), and keep only regular files with an image extension so tqdm's
# total is accurate and directories are never opened.
image_filenames = sorted(
    name for name in os.listdir(INPUT_FOLDER)
    if name.lower().endswith(IMAGE_EXTENSIONS)
    and os.path.isfile(os.path.join(INPUT_FOLDER, name))
)

skipped = []       # input files PIL could not open or decode
written_from = {}  # output .txt filename -> input image it was written from

# Process each image in the input folder
for filename in tqdm(image_filenames):
    image_path = os.path.join(INPUT_FOLDER, filename)
    try:
        # The context manager closes the file handle; convert() yields an
        # in-memory RGB copy that stays usable after the file is closed.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
    except OSError as err:  # includes PIL.UnidentifiedImageError
        # One bad file should not abort a long GPU run: report it and move on.
        tqdm.write(f"Skipping {filename}: could not read image ({err})")
        skipped.append(filename)
        continue

    output = transcribe_image(image)

    # Save output to a text file
    output_filename = os.path.splitext(filename)[0] + '.txt'
    if output_filename in written_from:
        # e.g. "a.png" and "a.jpg" share a stem and would silently overwrite.
        tqdm.write(
            f"Warning: {filename} overwrites {output_filename} "
            f"previously written from {written_from[output_filename]}"
        )
    written_from[output_filename] = filename
    output_path = os.path.join(OUTPUT_FOLDER, output_filename)
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(output)

print(f"Processing complete. Output files have been saved to {OUTPUT_FOLDER}")
if skipped:
    print(f"Skipped {len(skipped)} unreadable image(s): {', '.join(skipped)}")