In [None]:
from mistralai import Mistral
import base64
import os
import json
from tqdm import tqdm
from PIL import Image
import time
from IPython.display import clear_output

api_key = "xxx"
client = Mistral(api_key=api_key)
ocr_model = "mistral-ocr-latest"

In [3]:
import base64
from io import BytesIO
from PIL import Image

# def encode_image_data(image_data):
#     try:
#         # Ensure image_data is bytes
#         if isinstance(image_data, bytes):
#             # Directly encode bytes to base64
#             return base64.b64encode(image_data).decode('utf-8')
#         else:
#             # Convert image data to bytes if it's not already
#             buffered = BytesIO()
#             image_data.save(buffered, format="JPEG")
#             return base64.b64encode(buffered.getvalue()).decode('utf-8')
#     except Exception as e:
#         print(f"Error encoding image: {e}")
#         return None

def encode_image_data(image_path):
    """Encode a local image to base64 with proper MIME type."""
    try:
        file_extension = os.path.splitext(image_path)[1].lower()
        mime_type = "image/jpeg"  # Default
        
        if file_extension == ".png":
            mime_type = "image/png"
        elif file_extension in [".jpg", ".jpeg"]:
            mime_type = "image/jpeg"
        elif file_extension == ".gif":
            mime_type = "image/gif"
        elif file_extension == ".bmp":
            mime_type = "image/bmp"
        elif file_extension == ".tiff" or file_extension == ".tif":
            mime_type = "image/tiff"
            
        with open(image_path, "rb") as image_file:
            encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
            return f"data:{mime_type};base64,{encoded_string}"
    except Exception as e:
        print(f"Error encoding image {image_path}: {e}")
        return None

In [4]:
def get_image_paths(folder_path):
    """Get paths of all image files in the folder."""
    image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.gif']
    image_paths = []
    
    for filename in os.listdir(folder_path):
        if any(filename.lower().endswith(ext) for ext in image_extensions):
            image_paths.append(os.path.join(folder_path, filename))
    
    return image_paths

data_folder = "1a0acabc-5140-425b-9eb8-f90e1721a6c3"
image_paths = get_image_paths(data_folder)
#image_paths = image_paths[:10]
print(f"Found {len(image_paths)} images in the data folder.")

Found 4066 images in the data folder.


In [None]:
# def create_batch_file_from_local(image_paths, output_file):
#     with open(output_file, 'w') as file:
#         for index, path in enumerate(tqdm(image_paths)):
#             image_url = encode_image_data(path)
#             if image_url:
#                 entry = {
#                     "custom_id": f"{index}_{os.path.basename(path)}",
#                     "body": {
#                         "document": {
#                             "type": "image_url",
#                             "image_url": image_url
#                         },
#                         "include_image_base64": True
#                     }
#                 }
#                 file.write(json.dumps(entry) + '\n')

# batch_file = "local_batch_file.jsonl"
# create_batch_file_from_local(image_paths, batch_file)


def create_batch_files(image_paths, max_file_size_mb=512):
    batch_files = []
    batch_data = []
    batch_size = 0  # Track current batch size
    batch_index = 1
    max_file_size_bytes = max_file_size_mb * 1024 * 1024
    
    for index, path in enumerate(tqdm(image_paths)):
        image_data = encode_image_data(path)
        if not image_data:
            continue
        
        entry = {
            "custom_id": f"{index}_{os.path.basename(path)}",
            "body": {
                "document": {
                    "type": "image_url",
                    "image_url": image_data
                },
                "include_image_base64": True
            }
        }
        entry_size = len(json.dumps(entry).encode('utf-8'))
        
        if batch_size + entry_size > max_file_size_bytes:
            # Save the current batch and start a new one
            batch_filename = f"batch_file_{batch_index}.jsonl"
            with open(batch_filename, 'w') as file:
                for item in batch_data:
                    file.write(json.dumps(item) + '\n')
            batch_files.append(batch_filename)
            
            # Reset batch
            batch_index += 1
            batch_data = []
            batch_size = 0
        
        batch_data.append(entry)
        batch_size += entry_size
    
    # Save the last batch
    if batch_data:
        batch_filename = f"batch_file_{batch_index}.jsonl"
        with open(batch_filename, 'w') as file:
            for item in batch_data:
                file.write(json.dumps(item) + '\n')
        batch_files.append(batch_filename)
    
    return batch_files

batch_files = create_batch_files(image_paths)

100%|██████████| 4066/4066 [00:18<00:00, 221.48it/s]


In [22]:
# for batch_file in tqdm(batch_files):
#     batch_data = client.files.upload(
#         file={
#             "file_name": batch_file,
#             "content": open(batch_file, "rb")
#         },
#         purpose="batch"
#     )
#     print(f"Uploaded {batch_file}")

#     created_job = client.batch.jobs.create(
#         input_files=[batch_data.id],
#         model=ocr_model,
#         endpoint="/v1/ocr",
#         metadata={"job_type": "local_files_ocr2"}
#     )

#     retrieved_job = client.batch.jobs.get(job_id=created_job.id)

batch_data = client.files.upload(
    file={
        "file_name": batch_files[3],
        "content": open(batch_files[1], "rb")
    },
    purpose="batch"
)
print(f"Uploaded {batch_files[3]}")

created_job = client.batch.jobs.create(
    input_files=[batch_data.id],
    model=ocr_model,
    endpoint="/v1/ocr",
    metadata={"job_type": "local_files_ocr2"}
)

retrieved_job = client.batch.jobs.get(job_id=created_job.id)

Uploaded batch_file_4.jsonl


In [23]:
while retrieved_job.status in ["QUEUED", "RUNNING"]:
    retrieved_job = client.batch.jobs.get(job_id=created_job.id)
    
    clear_output(wait=True)  # Clear the previous output
    print(f"Status: {retrieved_job.status}")
    print(f"Total requests: {retrieved_job.total_requests}")
    print(f"Failed requests: {retrieved_job.failed_requests}")
    print(f"Successful requests: {retrieved_job.succeeded_requests}")
    print(
        f"Percent done: {round((retrieved_job.succeeded_requests + retrieved_job.failed_requests) / retrieved_job.total_requests, 4) * 100}%"
    )
    time.sleep(2)

Status: SUCCESS
Total requests: 1153
Failed requests: 0
Successful requests: 1153
Percent done: 100.0%


In [7]:
# batch_data = client.files.upload(
#     file={
#         "file_name": batch_file,
#         "content": open(batch_file, "rb")
#     },
#     purpose="batch"
# )

In [None]:

# created_job = client.batch.jobs.create(
#     input_files=[batch_data.id],
#     model=ocr_model,
#     endpoint="/v1/ocr",
#     metadata={"job_type": "local_files_ocr2"}
# )

In [None]:
# retrieved_job = client.batch.jobs.get(job_id=created_job.id)
# while retrieved_job.status in ["QUEUED", "RUNNING"]:
#     retrieved_job = client.batch.jobs.get(job_id=created_job.id)
    
#     clear_output(wait=True)  # Clear the previous output
#     print(f"Status: {retrieved_job.status}")
#     print(f"Total requests: {retrieved_job.total_requests}")
#     print(f"Failed requests: {retrieved_job.failed_requests}")
#     print(f"Successful requests: {retrieved_job.succeeded_requests}")
#     print(
#         f"Percent done: {round((retrieved_job.succeeded_requests + retrieved_job.failed_requests) / retrieved_job.total_requests, 4) * 100}%"
#     )
#     time.sleep(2)

Status: SUCCESS
Total requests: 10
Failed requests: 0
Successful requests: 10
Percent done: 100.0%


In [8]:
retrieved_job.output_file

'a6f1a923-5989-479f-a00e-57be9472479b'

In [2]:
client.files.download(file_id = 'a6f1a923-5989-479f-a00e-57be9472479b')

<Response [200 OK]>

In [5]:
output_file = client.files.download(file_id = '8709905a-84a4-4758-bc7d-b18ade0d85b3')

with open('file.jsonl', 'wb') as f:
    f.write(output_file.read())