In [1]:
#%%capture
#!wget "https://repo1.maven.org/maven2/org/apache/spark/spark-sql-kafka-0-10_2.12/3.5.0/spark-sql-kafka-0-10_2.12-3.5.0.jar"
#!wget "https://repo1.maven.org/maven2/org/apache/spark/spark-streaming-kafka-0-10_2.12/3.5.0/spark-streaming-kafka-0-10_2.12-3.5.0.jar"

In [2]:
import os
os.environ['PYSPARK_SUBMIT_ARGS'] = '--packages org.apache.spark:spark-streaming-kafka-0-10_2.12:3.5.0,org.apache.spark:spark-sql-kafka-0-10_2.12:3.5.0 pyspark-shell'

In [3]:
import pandas as pd
import os
import pickle
import json
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *

spark = SparkSession.builder.appName("OCR").getOrCreate()
spark.sparkContext.setLogLevel("WARN")

In [4]:
!pip install mistralai



In [5]:
# Initialize Mistral client with API key
from mistralai import Mistral

api_key = 'vrujwhuozfZgDYqUe3yPvpfyM0vkOTjU'
client = Mistral(api_key=api_key)

In [6]:
#Need to restart kernel once after finishing install mistral

In [7]:
# 1. Read Stream from Kafka
raw_stream = spark \
  .readStream \
  .format("kafka") \
  .option("kafka.bootstrap.servers", "192.168.1.197:8097") \
  .option("subscribe", "input") \
  .option("startingOffsets", "latest") \
  .load()

raw_stream.printSchema()

root
 |-- key: binary (nullable = true)
 |-- value: binary (nullable = true)
 |-- topic: string (nullable = true)
 |-- partition: integer (nullable = true)
 |-- offset: long (nullable = true)
 |-- timestamp: timestamp (nullable = true)
 |-- timestampType: integer (nullable = true)



In [8]:
# For debugging
if False:
    query = raw_stream \
      .selectExpr("CAST(key AS STRING)") \
      .writeStream \
      .outputMode("update") \
      .format("console") \
      .option("truncate", False) \
      .start()
    
    query.awaitTermination()

In [9]:
processed_data = raw_stream.select(
    col("value").cast("BINARY").alias("png_data"),
    col("key").cast(StringType()).alias("filename")
)

In [10]:
#import requests
from pyspark.sql import Row
from mistralai import DocumentURLChunk, ImageURLChunk, TextChunk
import json
import base64
import os
import time

# --- Configuration ---
#MISTRAL_OCR_ENDPOINT = "https://api.mistral.ai/v1/ocr" # Replace with actual API endpoint
MISTRAL_API_KEY = ""
OUTPUT_TOPIC_FOR_OCR_RESULTS = "ocr_results_topic"

# ---------------------

def call_mistral_ocr(row: Row):

    filename = row['filename']
    png_bytes = row['png_data']
    
    # Encode image as base64 for API
    encoded = base64.b64encode(png_bytes).decode()
    base64_data_url = f"data:image/jpeg;base64,{encoded}"
    
    # Process image with OCR
    image_response = client.ocr.process(
        document=ImageURLChunk(image_url=base64_data_url),
        model="mistral-ocr-latest"
    )
    
    # Convert response to JSON
    response_dict = json.loads(image_response.model_dump_json())
    json_string = json.dumps(response_dict, indent=4)
    #print(json_string)
    ocr_result = json_string
    return filename, ocr_result


def process_batch(df, batch_id):

    start_time_full = time.time()
    """
    Function executed for every micro-batch of the Structured Stream.
    Collects the rows and processes them with the OCR function.
    """
    # ⚠️ WARNING: .collect() brings ALL data to the driver program. 
    # Use for testing/low-volume only. For high-volume, consider a custom Kafka Connect Sink.
    rows = df.collect()
    
    results = []
    for row in rows:
        filename, ocr_text = call_mistral_ocr(row)
        if filename is None:
            filename = f"image_{batch_id}"
        results.append((filename, ocr_text))
        print(f"OCR Result for {filename}: {ocr_text}") # Log first 80 chars

    schema = StructType([
        StructField("filename", StringType(), False),
        StructField("ocr_text", StringType(), False) 
    ])

    if results:
        results_df = spark.createDataFrame(results, schema)

        # 3. Write results to JSON files (The Solution)
        # Use the batch_id to create a unique subdirectory for the JSON files
        batch_output_path = os.path.join("/home/jovyan/json", f"batch_{batch_id}")

        results_df.write \
            .format("json") \
            .mode("overwrite") \
            .save(batch_output_path)
        
        print(f"Successfully wrote OCR results for Batch {batch_id} to: {batch_output_path}")
        end_time_full = time.time()
        total_duration = end_time_full - start_time_full
        print(f"   > Total Function Duration: {total_duration:.4f} seconds")
    

In [12]:
query = processed_data.writeStream \
    .foreachBatch(process_batch) \
    .start() 

query.awaitTermination()

OCR Result for image_1: {
    "pages": [
        {
            "index": 0,
            "markdown": "PLACE FACE UP ON DASH\nCITY OF PALO ALTO\nNOT VALID FOR\nONSTREET PARKING\n\nExpiration Date/Time\n11:59 PM\nAUG 19, 2024\n\nPurchase Date/Time: 01:34pm Aug 19, 2024\nTotal Due: $15.00\nTotal Paid: $15.00\nTicket #: 00005883\nS/N #: 520117260957\nSetting: Permit Machines\nMach Name: Civic Center\n\n#****-1224, Visa\nDISPLAY FACE UP ON DASH\n\nPERMIT EXPIRES\nAT MIDNIGHT",
            "images": [],
            "dimensions": {
                "dpi": 200,
                "height": 1353,
                "width": 738
            }
        }
    ],
    "model": "mistral-ocr-2505-completion",
    "usage_info": {
        "pages_processed": 1,
        "doc_size_bytes": 771523
    },
    "document_annotation": null
}
Successfully wrote OCR results for Batch 1 to: /home/jovyan/json/batch_1
   > Total Function Duration: 8.0596 seconds


ERROR:root:KeyboardInterrupt while sending command.
Traceback (most recent call last):
  File "/usr/local/spark/python/lib/py4j-0.10.9.7-src.zip/py4j/java_gateway.py", line 1038, in send_command
    response = connection.send_command(command)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/spark/python/lib/py4j-0.10.9.7-src.zip/py4j/clientserver.py", line 511, in send_command
    answer = smart_decode(self.stream.readline()[:-1])
                          ^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/socket.py", line 706, in readinto
    return self._sock.recv_into(b)
           ^^^^^^^^^^^^^^^^^^^^^^^
KeyboardInterrupt


KeyboardInterrupt: 

In [None]:
for q in spark.streams.active:
    q.stop()