# Plot Overlay


In [None]:
# Jinko specifics imports & initialization
import concurrent.futures
import os
import jinko_helpers as jinko
from google import genai
from google.genai import types

# Initialize the Jinko helper library before any jinko.* call in later cells
# (presumably reads connection/auth settings from the environment — confirm
# in jinko_helpers).
jinko.initialize()

In [None]:
# Cookbook specifics imports
import cv2
import matplotlib.pyplot as plt
import pandas as pd
from IPython.display import display
from overlay_utils import (
    _detect_axes,
    fetch_images_from_source_sid,
    get_calibration,
    normalize_axis_if_needed,
    normalize_result_to_series,
    process_image,
    show_extracted_series_overlay,
)
import re
import subprocess
import tempfile

## Step 0: Configure inputs


In [None]:
# Short id of the extraction the series were extracted from; used to download
# the source chart image for calibration and overlay.
EXTRACT_SID = "as-74EU-QHiY"

# Short id of the data table created by Kohai that holds the extracted
# series rows to overlay.
DATA_TABLE_SID = "dt-U9iA-WNom"

# Organization id passed to `jk-sso-login -o` when authenticating.
ORG_SID = "ad10916c-f4ac-4b53-8865-9ccde303e3bf"

In [None]:
# Base URL of the Jinko instance; the GraphQL endpoint lives under "/_api".
_jinko_base_url = os.getenv("JINKO_BASE_URL")
if not _jinko_base_url:
    # Fail with an actionable message instead of a TypeError from None + str.
    raise RuntimeError(
        "JINKO_BASE_URL is not set; it is needed to build the GraphQL URL."
    )
# rstrip avoids a double slash ("...//_api") when the base URL ends with "/".
graphql_url = _jinko_base_url.rstrip("/") + "/_api"

# Gemini client used for the axis-interpretation call. No API key is passed
# explicitly, so it relies on google-genai's default credential lookup —
# confirm the expected environment variable for your setup.
CLIENT = genai.Client()
MODEL_NAME = "gemini-3-flash-preview"

In [None]:
# Jinko Auth via jk-sso-login
# Runs the jinko-seeder `jk-sso-login` command inside `nix develop` for
# ORG_SID against graphql_url, and builds the `jinko-session=<token>` string
# that is later passed to fetch_images_from_source_sid.

cmd = [
    "nix", "develop",
    "git+ssh://git@git.novadiscovery.net/jinko/dorayaki/jinko-seeder#interactive",
    "--command",
    "jk-sso-login",
    "-o", ORG_SID,
    "-u", graphql_url,
]

# Run from an empty temporary directory (presumably so nix does not pick up a
# flake or config from the notebook's working directory).
with tempfile.TemporaryDirectory() as tmpdir:
    try:
        res = subprocess.run(
            cmd, check=True, capture_output=True, text=True, cwd=tmpdir)
    except subprocess.CalledProcessError as err:
        # With capture_output the tool's stderr is not shown by the default
        # CalledProcessError message; surface it so login failures are
        # diagnosable from the notebook.
        raise RuntimeError(
            f"jk-sso-login failed with exit code {err.returncode}:\n"
            f"{(err.stderr or '').strip()}"
        ) from err
out = res.stdout.strip()

# Adjust parsing if your command prints extra text.
# This tries to extract a JWT-like token first, else uses last non-empty line.
m = re.search(r"eyJ[A-Za-z0-9_\-\.]+", out)
if m:
    token = m.group(0)
else:
    non_empty_lines = [line.strip() for line in out.splitlines() if line.strip()]
    if not non_empty_lines:
        # Fail clearly instead of with an IndexError on empty output.
        raise RuntimeError(
            "jk-sso-login printed nothing on stdout; no session token found."
        )
    token = non_empty_lines[-1]

token = 'jinko-session=' + token

## Step 1: Axis calibration tool


In [None]:
def axis_properties_tool(ocr_results, image, axis_info):
    """Normalize the LLM-reported axes and derive x/y calibrations.

    ``axis_info`` is updated in place: its ``"x_axis"`` and ``"y_axis"``
    entries are replaced by the output of ``normalize_axis_if_needed``, so
    any caller holding the same dict sees the normalized values afterwards.

    Args:
        ocr_results: OCR tick output, forwarded to ``get_calibration`` as
            ``res``.
        image: Chart image, forwarded as ``image_path``. In this notebook it
            receives the in-memory ``img``, not a file path.
        axis_info: Mapping with ``"x_axis"`` and ``"y_axis"`` entries.

    Returns:
        Whatever ``get_calibration`` returns; the notebook unpacks it as
        ``(cal_x, cal_y)``.
    """
    raw_y, raw_x = axis_info["y_axis"], axis_info["x_axis"]
    for key, raw in (("x_axis", raw_x), ("y_axis", raw_y)):
        axis_info[key] = normalize_axis_if_needed(raw)
    return get_calibration(
        image_path=image, llm_axis_info=axis_info, res=ocr_results
    )

## Step 2: LLM calibration prompt and schema


In [None]:
# System instruction for the axis-calibration LLM call (passed as
# `system_instruction` in run_llm_job). It tells the model to answer only
# through the `axis_properties_tool` function call.
# NOTE(review): this prompt asks for a `breaks` field holding [start, end]
# pairs, but axis_properties_tool_def declares a singular `break` property
# typed as a flat array of numbers. Confirm which key/shape overlay_utils
# reads and align the prompt and the schema.
CALIBRATION_SYSTEM_PROMPT = """
You are an axis calibration specialist.

Your sole task is to infer accurate axis properties from the chart image.

### Your task:
- Identify X and Y axis properties
- Determine:
  - axis type (linear, logK with K being an integer so 10, 2, 5, ...)
  - visible tick values
  - approximate numeric range
  - presence of breaks or discontinuities (if any)

### Axis breaks:
- If an axis has a break, describe breaks using a separate `breaks` field.
- Each break must specify the numeric values it separates.
  Example:
  {
    "ticks": [0, 10, 50, 100],
    "breaks": [[10, 50], [50, 100]]
  }

### Rules:
- Always call `axis_properties_tool`.
- Provide the most complete AXIS_INFO possible.
- Do NOT output commentary.
- Do NOT repeat calibration after it has been done.
"""

# User turn sent together with the chart image in run_llm_job.
CALIBRATION_USER_PROMPT = """
Analyze the chart image and extract full axis calibration information.
Infer axis ticks and numeric range as precisely as possible.
"""


def create_function_declaration(tool_dict):
    """Wrap a plain-dict tool description into a google-genai ``types.Tool``.

    ``tool_dict`` must provide ``"name"``, ``"description"`` and
    ``"parameters"``; the parameters mapping is expanded as keyword arguments
    into ``types.Schema``. The returned tool holds exactly one function
    declaration.
    """
    declaration = types.FunctionDeclaration(
        name=tool_dict["name"],
        description=tool_dict["description"],
        parameters=types.Schema(**tool_dict["parameters"]),
    )
    return types.Tool(function_declarations=[declaration])


def _axis_schema():
    """Return a fresh schema dict describing one chart axis.

    Properties: ``type`` (string, required), ``ticks`` (array of numbers,
    required) and an optional ``break`` (array of numbers). A new dict is
    built on every call so the x and y entries never share objects.
    """
    # NOTE(review): CALIBRATION_SYSTEM_PROMPT asks for a `breaks` field of
    # [start, end] pairs, while this schema exposes `break` as a flat number
    # array. Confirm what overlay_utils expects before changing either side.
    return {
        "type": "OBJECT",
        "properties": {
            "type": {"type": "STRING"},
            "ticks": {"type": "ARRAY", "items": {"type": "NUMBER"}},
            "break": {"type": "ARRAY", "items": {"type": "NUMBER"}},
        },
        "required": ["type", "ticks"],
    }


# Function-calling declaration handed to the LLM: a single `axis_info`
# object argument with one identically-shaped entry per axis.
axis_properties_tool_def = {
    "name": "axis_properties_tool",
    "description": "Store axis properties and produce calibrated AxisCal objects.",
    "parameters": {
        "type": "OBJECT",
        "properties": {
            "axis_info": {
                "type": "OBJECT",
                "properties": {
                    "x_axis": _axis_schema(),
                    "y_axis": _axis_schema(),
                },
                "required": ["x_axis", "y_axis"],
            },
        },
        "required": ["axis_info"],
    },
}

## Step 3: Run extraction and display overlay


In [None]:
# Step 1: Fetch extracted series rows from the source data table
print('[1/8] Fetching extracted series rows from Jinko...')
# Resolve the data table's short id to its core item id and snapshot id, then
# export that snapshot's rows through the data-table manager API.
core_id = jinko.get_core_item_id(shortId=DATA_TABLE_SID)
result = jinko.make_request(
    f"/core/v2/data_table_manager/data_table/{core_id['id']}/snapshots/{core_id['snapshotId']}/export",
    method="POST",
).json()

# Drop rowId immediately (not used downstream, not shown to users).
# Each exported row is treated as a mapping (row.items()); presumably the
# export JSON is a list of column->value dicts — confirm against the API.
result = [{k: v for k, v in row.items() if k != 'rowId'} for row in result]
print(f"Fetched {len(result)} rows (without rowId).")

# Step 2: Download chart image used for calibration and overlay (in memory only)
# `token` is the session string built in the SSO login cell.
print('[2/8] Downloading source chart image (in memory)...')
img, image_bytes, image_mime = fetch_images_from_source_sid(
    extract_sid=EXTRACT_SID,
    token=token,
    graphql_url=graphql_url,
)

# Step 3: Show fetched inputs early (table + image)
print('[3/8] Displaying fetched inputs...')
print('Extracted source rows (without rowId):')
display(pd.DataFrame(result))

print('Fetched source image (used for OCR/calibration):')
plt.figure(figsize=(10, 6))
# The image is converted BGR->RGB for matplotlib, so `img` is presumably an
# OpenCV-style BGR array — confirm in overlay_utils.
plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
plt.title('Fetched source image')
plt.axis('off')
plt.show()

# Step 4: Build image payload for LLM and detect axis anchors from pixels
print('[4/8] Preparing image payload and detecting axes...')
# The downloaded bytes and MIME type are sent to Gemini as-is (no re-encoding).
image_part = types.Part.from_bytes(data=image_bytes, mime_type=image_mime)
# _detect_axes returns four values; only the first two are used here.
# NOTE(review): by name, x_axis_y is presumably the pixel row of the x axis and
# y_axis_x the pixel column of the y axis — confirm in overlay_utils.
x_axis_y, y_axis_x, _, _ = _detect_axes(img)
print(f'Detected axis anchors: x_axis_y={x_axis_y}, y_axis_x={y_axis_x}')

# Step 5: Define OCR and LLM jobs (executed in worker threads in Step 6)
print('[5/8] Defining OCR and LLM tasks...')


def run_ocr_job():
    """OCR tick-extraction job, submitted to a worker thread in Step 6.

    Uses the notebook globals ``img``, ``x_axis_y`` and ``y_axis_x`` from the
    axis-detection step and returns the raw output of ``process_image``.
    """
    print('  -> Running OCR tick extraction...')
    anchors = {"x_axis_y": x_axis_y, "y_axis_x": y_axis_x}
    ocr_output = process_image(image_path=img, **anchors)
    print('  -> OCR tick extraction done.')
    return ocr_output


def run_llm_job():
    """Ask the Gemini model to interpret the chart axes via function calling.

    Sends the chart image (``image_part``) and the calibration prompts to
    ``MODEL_NAME`` with ``axis_properties_tool`` as the only declared tool,
    and returns the ``axis_info`` argument of the model's tool call as a dict.

    Raises:
        RuntimeError: if the response contains no ``axis_properties_tool``
            call, or that call carries no ``axis_info`` argument.
    """
    print('  -> Running LLM axis interpretation...')
    response = CLIENT.models.generate_content(
        model=MODEL_NAME,
        contents=[image_part, CALIBRATION_USER_PROMPT],
        config=types.GenerateContentConfig(
            system_instruction=CALIBRATION_SYSTEM_PROMPT,
            thinking_config=types.ThinkingConfig(thinking_level="low"),
            tools=[create_function_declaration(axis_properties_tool_def)],
            # The prompt says "Always call axis_properties_tool"; mode ANY has
            # the API enforce a function call instead of relying on the prompt.
            tool_config=types.ToolConfig(
                function_calling_config=types.FunctionCallingConfig(
                    mode="ANY",
                    allowed_function_names=[axis_properties_tool_def["name"]],
                ),
            ),
        ),
    )

    for call in getattr(response, "function_calls", None) or []:
        if call.name != "axis_properties_tool":
            continue
        # call.args can be None (or lack the key) for a malformed call; report
        # that explicitly instead of a TypeError/KeyError.
        args = call.args or {}
        if "axis_info" not in args:
            raise RuntimeError(
                "LLM called axis_properties_tool without an axis_info argument"
            )
        print('  -> LLM axis interpretation done.')
        return dict(args["axis_info"])

    raise RuntimeError("LLM did not call axis_properties_tool")


# Step 6: Run OCR and LLM in parallel
import time

print('[6/8] Executing OCR and LLM in parallel...')
# Each job gets this many seconds of wall-clock time, counted from submission.
JOB_TIMEOUT_S = 120

# The executor is managed by hand rather than with `with`: leaving a `with`
# block calls shutdown(wait=True), which would keep the cell blocked on a hung
# job even after .result(timeout=...) has raised TimeoutError.
ex = concurrent.futures.ThreadPoolExecutor(max_workers=2)
try:
    deadline = time.monotonic() + JOB_TIMEOUT_S
    f_ocr = ex.submit(run_ocr_job)
    f_llm = ex.submit(run_llm_job)
    ocr_results = f_ocr.result(timeout=max(0.0, deadline - time.monotonic()))
    print('  -> OCR future resolved.')
    # Both jobs started together, so the LLM gets only what is left of the
    # same budget rather than a fresh full wait after OCR resolves.
    axis_info = f_llm.result(timeout=max(0.0, deadline - time.monotonic()))
    print('  -> LLM future resolved.')
finally:
    # Returns immediately; a job that is still running finishes in the
    # background instead of blocking the notebook.
    ex.shutdown(wait=False, cancel_futures=True)
print('OCR and LLM completed.')

# Step 7: Build calibrated axes from OCR + LLM outputs
print('[7/8] Building calibrated axes...')
# axis_properties_tool also normalizes axis_info["x_axis"]/["y_axis"] in
# place, so the overlay in Step 8 receives the normalized axis info.
cal_x, cal_y = axis_properties_tool(
    ocr_results=ocr_results,
    image=img,
    axis_info=axis_info,
)

print("=== FINAL AXIS CALIBRATION ===")
# NOTE(review): for log modes the mapping is presumably applied to the log of
# the value rather than the raw value — confirm in overlay_utils.
print("Technical Note: pixel_value = a * value +b")
print(f"X axis -> a={cal_x.a:.6g}, b={cal_x.b:.6g}, mode={cal_x.mode}")
print(f"Y axis -> a={cal_y.a:.6g}, b={cal_y.b:.6g}, mode={cal_y.mode}")

# Sanity checks. Explicit raises instead of `assert` so they still run when
# Python is started with -O (which strips assert statements).
# A zero slope would map every value to the same pixel.
if cal_x.a == 0:
    raise ValueError("Invalid X-axis calibration")
if cal_y.a == 0:
    raise ValueError("Invalid Y-axis calibration")

# Image pixel x grows to the right and pixel y grows downward, hence the
# opposite slope signs required for increasing log axes. `not (a > 0)` keeps
# the original assert semantics (a NaN slope is rejected).
if cal_x.mode.startswith("log") and not cal_x.a > 0:
    raise ValueError("Log X-axis must increase left->right")
if cal_y.mode.startswith("log") and not cal_y.a < 0:
    raise ValueError("Log Y-axis must increase upward")

# Step 8: Normalize extracted rows and display overlay in notebook
print('[8/8] Rendering overlay in notebook...')
normalized_result = normalize_result_to_series(result)
show_extracted_series_overlay(
    result=normalized_result,
    image_path=img,
    x_cal=cal_x,
    y_cal=cal_y,
    axis_info=axis_info,
)