# Plot data extraction demo

This notebook demonstrates how to extract data points from a plot using the plot data extraction endpoint.

The current version only supports line plots without a legend, similar to the example.

In [None]:
from jupyter_innotater import *
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw
import numpy as np
import os
from cognite.client import CogniteClient
import json
from getpass import getpass
import pprint
# Pretty-printer with a 2-space indent, available for displaying nested structures.
pp = pprint.PrettyPrinter(indent=2)

## Create a client

To get access to your project, replace `your_project` with your project name in the next cell.

When you create the CogniteClient below, getpass will ask for your API key in an extra password field. Simply paste it in and press shift+enter.

In [None]:
# Name of the Cognite project to connect to -- replace the placeholder with your own.
project = 'your_project'
# Read the key through a hidden prompt instead of hard-coding it in the notebook.
api_key = getpass("Please enter API key: ")
client = CogniteClient(project=project, api_key=api_key, client_name="dshub")

## Load an example plot image from Google Drive
Load an image that contains a plot, draw a bounding box that coincides with the axes of the plot, and copy the bounding box values (four numbers separated by commas) into the cropping cell further below.

In [None]:
!pip install gdown
!gdown https://drive.google.com/uc?id=1nVmRoXM5ZYYRTAQAiJbwH09bZcurZuBk

In [None]:
# Example image, expected in the working directory (presumably the file downloaded above).
image = "./choke_flow_curves.jpg"
# One row of four zeros: storage for a single bounding box.
annotations = np.zeros((1, 4))

# Build the two widget parts separately, then keep the Innotater call as the
# last expression of the cell so that the widget is displayed.
image_input = ImageInnotation([image], path="./")
bbox_target = BoundingBoxInnotation(annotations)
Innotater(image_input, bbox_target)

## Crop the plot indicated by the bounding box from the image

In [None]:
# Bounding box values copied from the widget above: left, top, width, height.
x1, y1, w, h = 195, 914, 329, 264
# Value ranges of the x-axis and y-axis; the numbers below are the defaults for the example plot.
x_min, x_max = 0, 100
y_min, y_max = 0, 40
# Number of curves in the plot (the default example has 3 curves).
num_curves = 3

# The crop box is given as (left, upper, right, lower), so width/height are turned into corners.
crop_box = (x1, y1, x1 + w, y1 + h)
plot_image = Image.open(image).crop(crop_box).convert("RGB")
plot_image

## Convert the image to base64 string

In [None]:
from io import BytesIO
import base64
def image_to_base64_str(image: "Image.Image") -> str:
    """Encode a PIL image as a base64 string of its JPEG representation.

    Parameters
    ----------
    image : PIL.Image.Image
        Image to encode. Images whose mode carries an alpha channel or a
        palette ("RGBA", "LA", "P", "PA") are converted to "RGB" first,
        because they cannot be written as JPEG.

    Returns
    -------
    str
        The JPEG bytes of ``image``, base64-encoded and decoded as UTF-8 text.
    """
    # JPEG has no alpha channel or palette; saving such an image would raise.
    if image.mode in ("RGBA", "LA", "P", "PA"):
        image = image.convert("RGB")

    with BytesIO() as buffer:
        image.save(buffer, format="JPEG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

# Encode the cropped plot; this string is sent as "plotImage" in the request body below.
image_string = image_to_base64_str(plot_image)

## Plot data extraction endpoint

### Post a request

In [None]:
# Plot extractor endpoints under the playground API of the selected project.
api_url = f"/api/playground/projects/{project}/context/plotextractor"
extractdata_url = f"{api_url}/extractdata"

# Request payload: the encoded plot, the axis ranges and the number of curves.
plot_axes = {"xMin": x_min, "xMax": x_max, "yMin": y_min, "yMax": y_max}
post_body = {"plotImage": image_string, "plotAxes": plot_axes, "numCurves": num_curves}

# Submit the job and keep its id so the result can be fetched afterwards.
res = client.post(url=extractdata_url, json=post_body)
job_id = json.loads(res.text)["jobId"]
print("jobId:", job_id)

### Get a job

In [None]:
import time

# Poll the job endpoint every two seconds until the job reports "Completed".
job_url = f"{api_url}/{job_id}"
while True:
    res = client.get(url=job_url)
    job = json.loads(res.text)  # parse each response once and reuse it below
    status = job["status"]
    if status == "Completed":
        break
    # Stop on a failed job instead of waiting forever for a "Completed" that never comes.
    # NOTE(review): the "Failed" status and the "errorMessage" field are assumed from the
    # usual Cognite job response format -- confirm for this endpoint.
    if str(status).lower() == "failed":
        error_message = job.get("errorMessage", "no error message returned")
        raise RuntimeError(f"Plot extraction job {job_id} failed: {error_message}")
    time.sleep(2)

# `res` still holds the final response; later cells read it again.
result = job["items"]
print(f"{job_id}  {status}")

### Print the extracted points

In [None]:
# Raw text of the final job response, shown as the cell output.
res.text

### Plot curves using the extracted points
The plot below is reconstructed from the extracted points.

In [None]:
# Each entry of "items" holds the values of one curve; plot every curve in turn.
extracted_curves = json.loads(res.text)['items']
for curve in extracted_curves:
    x_values = curve.get("xValues")
    y_values = curve.get("yValues")
    plt.plot(x_values, y_values)

### Plot the extracted points on the original plot
You can verify the extracted points by drawing them on the original plot.

In [None]:
def image_with_prediction(
    image,
    curves,
    point_size: int = 1,
    colors: list = ["blue", "green", "magenta", "cyan", "orange"],
    include_image: bool = True,
    noise_color=None,
    noise=None,
):
    """Draw extracted curve points onto a copy of ``image`` or onto a blank canvas.

    Parameters
    ----------
    image : PIL.Image.Image
        Plot image the points were extracted from. It is not modified.
    curves : iterable of dict
        One mapping per curve with the keys "xPositions" and "yPositions".
    point_size : int
        Half the side length of the box each point's ellipse is drawn in.
    colors : list or str
        Fill colors for the curves. A single string is used for every curve.
        A list is cycled when there are more curves than colors, so no curve
        is dropped.
    include_image : bool
        If True, draw on an RGB copy of ``image``; otherwise draw on a white
        canvas of the same size.
    noise_color : optional
        Fill color for the ``noise`` points. Only used when ``noise`` is given.
    noise : dict, optional
        Extra points in the same format as one curve. Drawn only when
        ``noise_color`` is also set.

    Returns
    -------
    PIL.Image.Image
        The new image with the points drawn on it.

    Raises
    ------
    ValueError
        If ``curves`` is None.
    """
    if curves is None:
        raise ValueError("No points are extracted.")

    # The default list is only read, never mutated, so sharing it between calls is safe.
    palette = [colors] if isinstance(colors, str) else list(colors)

    if include_image:
        canvas = image.copy().convert("RGB")
    else:
        canvas = Image.new("RGB", image.size, (255, 255, 255))
    draw = ImageDraw.Draw(canvas)

    def plot_points(curve, color):
        # A missing coordinate list is treated as "no points" rather than an error.
        x_positions = curve.get("xPositions") or []
        y_positions = curve.get("yPositions") or []
        for x, y in zip(x_positions, y_positions):
            draw.ellipse((x - point_size, y - point_size, x + point_size, y + point_size), fill=color)

    # Cycle through the palette so curves beyond len(colors) are still drawn.
    if palette:
        for index, curve in enumerate(curves):
            plot_points(curve, palette[index % len(palette)])

    if noise_color is not None and noise is not None:
        plot_points(noise, noise_color)

    return canvas

In [None]:
# Overlay the extracted points on the cropped plot; the returned image is the cell output.
image_with_prediction(plot_image, extracted_curves)