Copyright 2024 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

# InkSight: Offline-to-Online Handwriting Conversion by Teaching Vision-Language Models to Read and Write
<p align="center">
  <a href="https://research.google/blog/a-return-to-hand-written-notes-by-learning-to-read-write/">
    <img src="https://img.shields.io/badge/Google_Research_Blog-333333?&logo=google&logoColor=white" alt="Google Research Blog">
  </a>
  <a href="https://arxiv.org/abs/2402.05804">
    <img src="https://img.shields.io/badge/Read_the_Paper-4CAF50?&logo=arxiv&logoColor=white" alt="Read the Paper">
  </a>
  <a href="https://huggingface.co/spaces/Derendering/Model-Output-Playground">
    <img src="https://img.shields.io/badge/Output_Playground-007acc?&logo=huggingface&logoColor=white" alt="Try Demo on Hugging Face">
  </a>
    <a href="https://charlieleee.github.io/publication/inksight/">
    <img src="https://img.shields.io/badge/🔗_Project_Page-FFA500?&logo=link&logoColor=white" alt="Project Page">
  </a>
  <a href="https://huggingface.co/datasets/Derendering/InkSight-Derenderings">
    <img src="https://img.shields.io/badge/Dataset-InkSight-40AF40?&logo=huggingface&logoColor=white" alt="Hugging Face Dataset">
  </a>
</p>


In [None]:
# @title Dependencies
from IPython.display import Markdown
import time
import os

display(Markdown("## 📦 Installing required packages...\nThis may take a minute. Please wait..."))

!sudo apt -qq install tesseract-ocr
!uv pip install -q --system "tensorflow[and-cuda]==2.17.0" "tensorflow-text==2.17.0" pytesseract "tf-keras==2.17.0" "python-doctr[tf,viz]==0.10.0"
display(Markdown("✅ **Installation complete!**"))
time.sleep(1)

display(Markdown("""
---
### 🔄 Restarting Runtime
To finalize the installation, we need to restart the Colab runtime.

> **Why?** TensorFlow and system-level packages need a restart to properly initialize with new dependencies.

⏳ Restarting in 3 seconds...

✅ Restartng done! **Please continue to the following cells**
"""))

time.sleep(3)

# Kill the current process to force a runtime restart
os.kill(os.getpid(), 9)

## 📦 Installing required packages...
This may take a minute. Please wait...

In [None]:
# @title docTR Preparation

from doctr.io import DocumentFile
from doctr.models import ocr_predictor
# Build the docTR OCR predictor with pretrained weights and keep it in the
# module-level name `predictor`.
predictor = ocr_predictor(pretrained=True)
print("doctr predictor loaded.")

In [None]:
# @title ####Utils
from IPython.display import SVG, display
import json
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import xml.etree.ElementTree as ET
from xml.dom import minidom
import gdown
import os
from PIL import Image, ImageEnhance, ImageDraw
import matplotlib.animation as animation
import copy
import colorsys
import matplotlib.colors as mcolors
from matplotlib.collections import LineCollection
from matplotlib.patheffects import withStroke
import random
import IPython
import warnings
warnings.filterwarnings("ignore")
import datetime
from pathlib import Path
import json
from matplotlib.figure import Figure
from io import BytesIO
from matplotlib.animation import FuncAnimation, FFMpegWriter, PillowWriter
import requests
import zipfile
import base64


def get_svg_content(svg_path):
    """Open ``svg_path`` in text mode and return everything it contains."""
    with open(svg_path, mode="r") as svg_handle:
        svg_text = svg_handle.read()
    return svg_text


def download_file(url, filename):
    """Download ``url`` with an HTTP GET and write the body to ``filename``.

    Args:
        url: Address to fetch.
        filename: Destination path; it is opened in binary write mode, so an
            existing file is overwritten.

    Raises:
        requests.HTTPError: If the server responds with a 4xx/5xx status.
    """
    response = requests.get(url)
    # Fail loudly on an error status instead of silently saving the server's
    # error page under `filename`.
    response.raise_for_status()
    with open(filename, "wb") as f:
        f.write(response.content)


def unzip_file(filename, extract_to="."):
    """Extract every member of the zip archive ``filename`` into ``extract_to``."""
    archive = zipfile.ZipFile(filename, mode="r")
    with archive:
        archive.extractall(path=extract_to)


def get_base64_encoded_gif(gif_path):
    """Return the bytes of the file at ``gif_path`` as a base64 text string."""
    with open(gif_path, "rb") as source:
        raw_bytes = source.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")


def load_and_pad_img_dir(file_dir):
    """Load an image file and letterbox it onto a white 224x224 RGB canvas.

    The image is resized with its aspect ratio preserved so that it fits
    inside 224x224, then pasted centered on a white background.

    Args:
        file_dir: Path of the image file to open (a file path, despite the
            parameter name).

    Returns:
        A 224x224 ``PIL.Image.Image`` in RGB mode.
    """
    target = 224
    image = Image.open(file_dir)
    width, height = image.size
    ratio = min(target / width, target / height)
    # max(1, ...) keeps a very elongated image from collapsing to a zero-sized
    # dimension, which resize() cannot handle.
    image = image.resize(
        (max(1, int(width * ratio)), max(1, int(height * ratio)))
    )
    width, height = image.size
    # Always allocate the full canvas and center on both axes. int() truncation
    # may leave both sides short of 224, in which case padding only one axis
    # would return an image that is not 224x224.
    padded_image = Image.new("RGB", (target, target), (255, 255, 255))
    padded_image.paste(image, ((target - width) // 2, (target - height) // 2))
    return padded_image


def plot_ink(ink, ax, lw=1.8, input_image=None, with_path=True, path_color="white"):
    """Draw every stroke of ``ink`` on the Matplotlib axes ``ax``.

    Each stroke gets its own colour from the "rainbow" colormap (darkened),
    and the alpha of its segments fades from 1.0 towards 0.5 along the stroke.

    Args:
        ink: Object exposing ``strokes``; each stroke exposes ``x`` and ``y``
            coordinate sequences.
        ax: Matplotlib axes to draw on.
        lw: Line width of the strokes.
        input_image: Optional PIL image shown, dimmed, behind the strokes.
        with_path: When true, outline each stroke with ``path_color``.
        path_color: Colour of the outline.
    """
    if input_image is not None:
        # Work on a copy so the caller's image is untouched, and dim it so the
        # strokes stand out.
        img = copy.deepcopy(input_image)
        img = ImageEnhance.Brightness(img).enhance(0.45)
        ax.imshow(img)

    num_strokes = len(ink.strokes)
    # pyplot.get_cmap is used because matplotlib.cm.get_cmap was removed in
    # Matplotlib 3.9.
    base_colors = plt.get_cmap("rainbow", num_strokes)

    for i, stroke in enumerate(ink.strokes):
        x, y = np.array(stroke.x), np.array(stroke.y)

        # Colormap indices are assigned in reverse stroke order.
        base_color = base_colors(num_strokes - 1 - i)
        hue, saturation, value = colorsys.rgb_to_hsv(*base_color[:3])
        darker_color = colorsys.hsv_to_rgb(hue, saturation, max(0, value * 0.65))
        colors = [
            mcolors.to_rgba(darker_color, alpha=1 - (0.5 * j / len(x)))
            for j in range(len(x))
        ]

        # Turn the N points into N-1 consecutive line segments.
        points = np.array([x, y]).T.reshape(-1, 1, 2)
        segments = np.concatenate([points[:-1], points[1:]], axis=1)

        lc = LineCollection(segments, colors=colors, linewidth=lw)
        if with_path:
            lc.set_path_effects(
                [withStroke(linewidth=lw * 1.25, foreground=path_color)]
            )
        ax.add_collection(lc)

    # Fixed 224x224 coordinate frame with the origin at the top-left.
    ax.set_xlim(0, 224)
    ax.set_ylim(0, 224)
    ax.invert_yaxis()


def plot_ink_to_video(
    ink, output_name, lw=1.8, input_image=None, path_color="white", fps=30
):
    """Render ``ink`` being drawn progressively and save it as a video.

    Frame ``k`` shows the first ``k`` line segments, counted across strokes in
    order. The animation is written with Matplotlib's ``FFMpegWriter``.

    Args:
        ink: Object exposing ``strokes``; each stroke exposes ``x`` and ``y``.
        output_name: Destination path of the video file.
        lw: Line width of the strokes.
        input_image: Optional PIL image shown, dimmed, behind the strokes.
        path_color: Colour of the outline drawn around each stroke.
        fps: Frames per second of the output video.
    """
    fig, ax = plt.subplots(figsize=(4, 4), dpi=150)

    img = None
    if input_image is not None:
        img = copy.deepcopy(input_image)
        img = ImageEnhance.Brightness(img).enhance(0.45)

    def reset_axes():
        # Shared by the initial setup and every frame (after ax.clear()).
        if img is not None:
            ax.imshow(img)
        ax.set_xlim(0, 224)
        ax.set_ylim(0, 224)
        ax.invert_yaxis()
        ax.axis("off")

    num_strokes = len(ink.strokes)
    # pyplot.get_cmap is used because matplotlib.cm.get_cmap was removed in
    # Matplotlib 3.9.
    base_colors = plt.get_cmap("rainbow", num_strokes)
    all_points = sum(len(stroke.x) for stroke in ink.strokes)

    # Segments and colours do not depend on the frame, so build them once
    # instead of on every frame.
    stroke_segments = []
    stroke_colors = []
    for stroke_index, stroke in enumerate(ink.strokes):
        x, y = np.array(stroke.x), np.array(stroke.y)
        points = np.array([x, y]).T.reshape(-1, 1, 2)
        stroke_segments.append(np.concatenate([points[:-1], points[1:]], axis=1))

        base_color = base_colors(num_strokes - 1 - stroke_index)
        hue, saturation, value = colorsys.rgb_to_hsv(*base_color[:3])
        stroke_colors.append(
            colorsys.hsv_to_rgb(hue, saturation, max(0, value * 0.65))
        )

    def update(frame):
        ax.clear()
        reset_axes()

        points_drawn = 0
        for segments, darker_color in zip(stroke_segments, stroke_colors):
            # frame - points_drawn is never negative here (the loop breaks as
            # soon as points_drawn reaches frame), so the slice yields either a
            # prefix of the stroke or the whole stroke.
            visible_segments = segments[: frame - points_drawn]

            if len(visible_segments) > 0:
                colors = [
                    mcolors.to_rgba(
                        darker_color, alpha=1 - (0.5 * j / len(visible_segments))
                    )
                    for j in range(len(visible_segments))
                ]
                lc = LineCollection(visible_segments, colors=colors, linewidth=lw)
                lc.set_path_effects(
                    [withStroke(linewidth=lw * 1.25, foreground=path_color)]
                )
                ax.add_collection(lc)

            points_drawn += len(segments)
            if points_drawn >= frame:
                break

    try:
        reset_axes()
        ani = FuncAnimation(fig, update, frames=all_points + 1, blit=False)
        writer = FFMpegWriter(fps=fps)
        plt.tight_layout()
        ani.save(output_name, writer=writer)
    finally:
        # Release the figure even when encoding fails.
        plt.close(fig)


class Stroke:
    """A pen stroke stored as two parallel coordinate lists, ``x`` and ``y``."""

    def __init__(self, list_of_coordinates=None) -> None:
        # Only the first two entries of each coordinate are kept.
        xs, ys = [], []
        for coordinate in list_of_coordinates or ():
            xs.append(coordinate[0])
            ys.append(coordinate[1])
        self.x = xs
        self.y = ys

    def __len__(self):
        """Number of points in the stroke."""
        return len(self.x)

    def __getitem__(self, index):
        """Return the ``(x, y)`` pair stored at ``index``."""
        x_value = self.x[index]
        y_value = self.y[index]
        return (x_value, y_value)


class Ink:
    """An ordered collection of strokes, exposed through ``strokes``."""

    def __init__(self, list_of_strokes=None) -> None:
        # A non-empty argument is stored as-is (not copied); anything falsy
        # becomes a fresh empty list.
        self.strokes = list_of_strokes if list_of_strokes else []

    def __len__(self):
        """Number of strokes."""
        return len(self.strokes)

    def __getitem__(self, index):
        """Return the stroke stored at ``index``."""
        return self.strokes[index]


def inkml_to_ink(inkml_file):
    """Parse an InkML file into an ``Ink``.

    Every ``trace`` element directly under the root becomes one ``Stroke``.
    The trace text is read as whitespace-separated ``x,y`` pairs.
    """
    namespaces = {"inkml": "http://www.w3.org/2003/InkML"}
    document_root = ET.parse(inkml_file).getroot()

    parsed_strokes = []
    for trace_element in document_root.findall("inkml:trace", namespaces):
        coordinates = []
        for token in trace_element.text.strip().split():
            x_text, y_text = token.split(",")
            coordinates.append((float(x_text), float(y_text)))
        parsed_strokes.append(Stroke(coordinates))
    return Ink(parsed_strokes)


def parse_inkml_annotations(inkml_file):
    """Collect the InkML ``annotation`` elements of a file into a dict.

    Returns a mapping from each annotation's ``type`` attribute to its text.
    When a type occurs more than once, the last occurrence wins.
    """
    document_root = ET.parse(inkml_file).getroot()
    annotation_elements = document_root.findall(
        ".//{http://www.w3.org/2003/InkML}annotation"
    )
    return {element.get("type"): element.text for element in annotation_elements}


def pregenerate_videos(video_cache_dir):
    """Render and cache a stroke animation for every available sample.

    Every dataset / model / query-mode combination under ``./derendering_supp``
    is visited. A video is written only when the matching ``.inkml`` file
    exists and the video is not already in the cache.

    Args:
        video_cache_dir: Directory that receives the ``.mp4`` files, given as a
            ``pathlib.Path`` or a string path.
    """
    cache_dir = Path(video_cache_dir)
    datasets = ["IAM", "IMGUR5K", "HierText"]
    models = ["Small-i", "Large-i", "Small-p"]
    query_modes = ["d+t", "r+d", "vanilla"]
    for dataset in datasets:
        # The sample directory depends only on the dataset, so check and list
        # it once rather than for every model/mode combination.
        path = f"./derendering_supp/{dataset}/images_sample"
        if not os.path.exists(path):
            continue
        samples = os.listdir(path)
        for model in models:
            inkml_path_base = f"./derendering_supp/{model.lower()}_{dataset}_inkml"
            for mode in query_modes:
                for name in tqdm(
                    samples, desc=f"Generating {model}-{dataset}-{mode} videos"
                ):
                    # Drop the file extension only. str.strip(".png") would
                    # remove any leading/trailing '.', 'p', 'n' or 'g'
                    # characters and corrupt ids such as "png_page.png".
                    example_id = os.path.splitext(name)[0]
                    inkml_file = os.path.join(
                        inkml_path_base, mode, f"{example_id}.inkml"
                    )
                    if not os.path.exists(inkml_file):
                        continue
                    video_filename = f"{model}_{dataset}_{mode}_{example_id}.mp4"
                    video_filepath = cache_dir / video_filename
                    if not video_filepath.exists():
                        img_path = os.path.join(path, name)
                        img = load_and_pad_img_dir(img_path)
                        ink = inkml_to_ink(inkml_file)
                        plot_ink_to_video(ink, str(video_filepath), input_image=img)



def show_system():
    """Display the SVG diagram stored at ``derendering_supp/derender_diagram.svg``."""
    diagram = SVG('derendering_supp/derender_diagram.svg')
    display(diagram)


def load_and_pad_img_dir(file_dir):
    """Load an image file and letterbox it onto a white 224x224 RGB canvas.

    The image is resized with its aspect ratio preserved so that it fits
    inside 224x224, then pasted centered on a white background.

    Args:
        file_dir: Path of the image file to open (a file path, despite the
            parameter name).

    Returns:
        A 224x224 ``PIL.Image.Image`` in RGB mode.
    """
    side = 224
    image = Image.open(file_dir)
    width, height = image.size
    ratio = min(side / width, side / height)
    # max(1, ...) keeps a very elongated image from collapsing to a zero-sized
    # dimension, which resize() cannot handle.
    image = image.resize(
        (max(1, int(width * ratio)), max(1, int(height * ratio)))
    )
    width, height = image.size
    # Always allocate the full canvas and center on both axes. int() truncation
    # may leave both sides short of 224, in which case padding only one axis
    # would return an image that is not 224x224.
    padded_image = Image.new('RGB', (side, side), (255, 255, 255))
    padded_image.paste(image, ((side - width) // 2, (side - height) // 2))
    return padded_image

def plot_ink(ink, ax, lw=1.8, input_image=None, with_path=True, path_color='white'):
    """Draw every stroke of ``ink`` on the Matplotlib axes ``ax``.

    Each stroke gets its own colour from the 'rainbow' colormap (darkened),
    and the alpha of its segments fades from 1.0 towards 0.5 along the stroke.

    Args:
        ink: Object exposing ``strokes``; each stroke exposes ``x`` and ``y``
            coordinate sequences.
        ax: Matplotlib axes to draw on.
        lw: Line width of the strokes.
        input_image: Optional PIL image shown, dimmed, behind the strokes.
        with_path: When true, outline each stroke with ``path_color``.
        path_color: Colour of the outline.
    """
    if input_image is not None:
        # Dim a copy of the image so the caller's object is left untouched.
        img = copy.deepcopy(input_image)
        enhancer = ImageEnhance.Brightness(img)
        img = enhancer.enhance(0.45)
        ax.imshow(img)

    stroke_count = len(ink.strokes)
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap
    # accepts the same (name, lut) arguments.
    base_colors = plt.get_cmap('rainbow', stroke_count)

    for i, stroke in enumerate(ink.strokes):
        x, y = np.array(stroke.x), np.array(stroke.y)

        # Colormap indices are assigned in reverse stroke order.
        base_color = base_colors(stroke_count - 1 - i)
        hsv_color = colorsys.rgb_to_hsv(*base_color[:3])
        darker_color = colorsys.hsv_to_rgb(
            hsv_color[0], hsv_color[1], max(0, hsv_color[2] * 0.65)
        )
        colors = [
            mcolors.to_rgba(darker_color, alpha=1 - (0.5 * j / len(x)))
            for j in range(len(x))
        ]

        # Turn the N points into N-1 consecutive line segments.
        points = np.array([x, y]).T.reshape(-1, 1, 2)
        segments = np.concatenate([points[:-1], points[1:]], axis=1)

        lc = LineCollection(segments, colors=colors, linewidth=lw)
        if with_path:
            lc.set_path_effects([withStroke(linewidth=lw * 1.25, foreground=path_color)])
        ax.add_collection(lc)

    # Fixed 224x224 coordinate frame with the origin at the top-left.
    ax.set_xlim(0, 224)
    ax.set_ylim(0, 224)
    ax.invert_yaxis()

def plot_ink_to_gif(ink, output_filename, lw=1.8, input_image=None, path_color='white', fps=30):
    """Animate ``ink`` segment by segment and save the result as a GIF.

    There is one frame per line segment. Each frame adds a collection holding
    the current stroke up to and including that segment.

    Args:
        ink: Object exposing ``strokes``; each stroke exposes ``x`` and ``y``.
        output_filename: Destination path of the GIF.
        lw: Line width of the strokes.
        input_image: Optional PIL image shown, dimmed, behind the strokes.
        path_color: Outline colour; a falsy value disables the outline.
        fps: Frames per second passed to the animation writer.
    """
    fig, ax = plt.subplots(figsize=(4, 4), dpi=150)
    try:
        if input_image is not None:
            img = copy.deepcopy(input_image)
            img = ImageEnhance.Brightness(img).enhance(0.45)
            ax.imshow(img)

        num_strokes = len(ink.strokes)
        # pyplot.get_cmap is used because matplotlib.cm.get_cmap was removed in
        # Matplotlib 3.9.
        base_colors = plt.get_cmap('rainbow', num_strokes)

        def get_segments(stroke):
            # N points -> N-1 consecutive line segments.
            x, y = np.array(stroke.x), np.array(stroke.y)
            points = np.array([x, y]).T.reshape(-1, 1, 2)
            return np.concatenate([points[:-1], points[1:]], axis=1)

        all_segments = [get_segments(stroke) for stroke in ink.strokes]
        max_frames = sum(len(segments) for segments in all_segments)

        def update(frame):
            current_frame = 0
            for i, segments in enumerate(all_segments):
                # Find the stroke that contains global segment index `frame`.
                if current_frame + len(segments) > frame:
                    segment_index = frame - current_frame
                    base_color = base_colors(num_strokes - 1 - i)
                    hue, saturation, value = colorsys.rgb_to_hsv(*base_color[:3])
                    darker_color = colorsys.hsv_to_rgb(
                        hue, saturation, max(0, value * 0.65))
                    # Only the colours of the visible segments are needed.
                    colors = [
                        mcolors.to_rgba(
                            darker_color, alpha=1 - (0.5 * j / len(segments)))
                        for j in range(segment_index + 1)
                    ]

                    lc = LineCollection(
                        segments[:segment_index + 1], colors=colors, linewidth=lw)
                    if path_color:
                        lc.set_path_effects(
                            [withStroke(linewidth=lw * 1.25, foreground=path_color)])

                    ax.add_collection(lc)
                    break

                current_frame += len(segments)

            return ax.collections

        ax.set_xlim(0, 224)
        ax.set_ylim(0, 224)
        ax.invert_yaxis()
        plt.tight_layout()
        ax.axis('off')
        ani = animation.FuncAnimation(fig, update, frames=max_frames, blit=True)
        ani.save(output_filename, writer='imagemagick', fps=fps)
    finally:
        # Release the figure even when saving fails.
        plt.close(fig)


class Stroke:
    """One pen stroke, kept as parallel ``x`` and ``y`` coordinate lists."""

    def __init__(self, list_of_coordinates=None) -> None:
        self.x, self.y = [], []
        if not list_of_coordinates:
            return
        # Only the first two entries of each point are used.
        for pt in list_of_coordinates:
            self.x.append(pt[0])
            self.y.append(pt[1])

    def __len__(self):
        """Return how many points the stroke holds."""
        return len(self.x)

    def __getitem__(self, index):
        """Return the point at ``index`` as an ``(x, y)`` tuple."""
        return self.x[index], self.y[index]


class Ink:
    """Container for a sequence of strokes, available as ``strokes``."""

    def __init__(self, list_of_strokes=None) -> None:
        # Falsy input (None or empty) becomes a new empty list; otherwise the
        # given object is stored without copying.
        self.strokes = list_of_strokes or []

    def __len__(self):
        """Return the number of strokes."""
        stroke_count = len(self.strokes)
        return stroke_count

    def __getitem__(self, index):
        """Return the stroke at ``index``."""
        return self.strokes[index]

def inkml_to_ink(inkml_file):
    """Read an InkML file and return its traces as an ``Ink``.

    Each ``trace`` child of the root yields one ``Stroke``. The trace text is
    split on whitespace into ``x,y`` pairs that are converted to floats.
    """
    ns = {'inkml': 'http://www.w3.org/2003/InkML'}
    root = ET.parse(inkml_file).getroot()

    def trace_to_stroke(trace):
        pairs = []
        for pair_text in trace.text.strip().split():
            x_str, y_str = pair_text.split(',')
            pairs.append((float(x_str), float(y_str)))
        return Stroke(pairs)

    return Ink([trace_to_stroke(trace) for trace in root.findall('inkml:trace', ns)])

def parse_inkml_annotations(inkml_file):
    """Return the InkML annotations of a file as ``{type: text}``.

    All ``annotation`` elements below the root are read. A repeated ``type``
    keeps the text of its last occurrence.
    """
    root = ET.parse(inkml_file).getroot()
    found = root.findall('.//{http://www.w3.org/2003/InkML}annotation')
    return {node.get('type'): node.text for node in found}

def plot_ink_to_video(
    ink, output_name, lw=1.8, input_image=None, path_color="white", fps=30
):
    """Render ``ink`` being drawn progressively and save it as a video.

    Frame ``k`` shows the first ``k`` line segments, counted across strokes in
    order. The animation is written with Matplotlib's ``FFMpegWriter``.

    Args:
        ink: Object exposing ``strokes``; each stroke exposes ``x`` and ``y``.
        output_name: Destination path of the video file.
        lw: Line width of the strokes.
        input_image: Optional PIL image shown, dimmed, behind the strokes.
        path_color: Colour of the outline drawn around each stroke.
        fps: Frames per second of the output video.
    """
    fig, ax = plt.subplots(figsize=(4, 4), dpi=150)

    img = None
    if input_image is not None:
        img = copy.deepcopy(input_image)
        img = ImageEnhance.Brightness(img).enhance(0.45)

    def reset_axes():
        # Shared by the initial setup and every frame (after ax.clear()).
        if img is not None:
            ax.imshow(img)
        ax.set_xlim(0, 224)
        ax.set_ylim(0, 224)
        ax.invert_yaxis()
        ax.axis("off")

    num_strokes = len(ink.strokes)
    # pyplot.get_cmap is used because matplotlib.cm.get_cmap was removed in
    # Matplotlib 3.9.
    base_colors = plt.get_cmap("rainbow", num_strokes)
    all_points = sum(len(stroke.x) for stroke in ink.strokes)

    # Segments and colours do not depend on the frame, so build them once
    # instead of on every frame.
    stroke_segments = []
    stroke_colors = []
    for stroke_index, stroke in enumerate(ink.strokes):
        x, y = np.array(stroke.x), np.array(stroke.y)
        points = np.array([x, y]).T.reshape(-1, 1, 2)
        stroke_segments.append(np.concatenate([points[:-1], points[1:]], axis=1))

        base_color = base_colors(num_strokes - 1 - stroke_index)
        hue, saturation, value = colorsys.rgb_to_hsv(*base_color[:3])
        stroke_colors.append(
            colorsys.hsv_to_rgb(hue, saturation, max(0, value * 0.65))
        )

    def update(frame):
        ax.clear()
        reset_axes()

        points_drawn = 0
        for segments, darker_color in zip(stroke_segments, stroke_colors):
            # frame - points_drawn is never negative here (the loop breaks as
            # soon as points_drawn reaches frame), so the slice yields either a
            # prefix of the stroke or the whole stroke.
            visible_segments = segments[: frame - points_drawn]

            if len(visible_segments) > 0:
                colors = [
                    mcolors.to_rgba(
                        darker_color, alpha=1 - (0.5 * j / len(visible_segments))
                    )
                    for j in range(len(visible_segments))
                ]
                lc = LineCollection(visible_segments, colors=colors, linewidth=lw)
                lc.set_path_effects(
                    [withStroke(linewidth=lw * 1.25, foreground=path_color)]
                )
                ax.add_collection(lc)

            points_drawn += len(segments)
            if points_drawn >= frame:
                break

    try:
        reset_axes()
        ani = animation.FuncAnimation(
            fig, update, frames=all_points + 1, blit=False
        )
        writer = animation.FFMpegWriter(fps=fps)
        ani.save(output_name, writer=writer)
    finally:
        # Release the figure even when encoding fails.
        plt.close(fig)

In [None]:
# @title #### Preparation
# Remove any previous download so the steps below start from a clean state.
!rm -rf derendering_supp/ derendering_supp.zip __MACOSX
# Fetch the supplementary archive and unpack it quietly into the working directory.
!wget https://storage.googleapis.com/derendering_model/derendering_supp.zip
!unzip -q derendering_supp.zip

In [None]:
# @title # InkSight Overview
# Display the diagram SVG from the unpacked derendering_supp directory.
show_system()

# Visualize the generated Digital Ink (Pre-saved Model Outputs)

We provide pre-saved outputs from the three model variants described in the paper (**Small-i**, **Small-p**, **Large-i**). Each column below corresponds to one of the three inference modes: Derender with Text, Recognize and Derender, and Vanilla Derender.

# Inference with the Public Small-p Model

This section demonstrates inference examples using the Small-p model from our paper, both with the full-page InkSight pipeline (using open-source Tesseract OCR or docTR) and at the word level.


In [None]:
# @title Notice
from IPython.display import HTML
# Render a highlighted heading followed by the notice text as inline HTML.
display(HTML('<p style="font-size:20px; font-weight:bold; color:red; background-color:lightgray; padding:10px; width:50%">Notice on Model Release</p>'))
display(HTML('<p style="font-size:16px; width:50%">Model download will be available once the release process is complete. For optimal performance, please use a T4 GPU runtime in colab. </p>'))


In [None]:
# @title Utils
import tensorflow as tf
import tensorflow_text
import json
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import xml.etree.ElementTree as ET
from xml.dom import minidom
import gdown
import os
import matplotlib.animation as animation
import copy
from PIL import ImageEnhance, Image, ImageDraw
import colorsys
import matplotlib.colors as mcolors
from matplotlib.collections import LineCollection
from matplotlib.patheffects import withStroke
import random
import warnings
import re
import time
import io
import pytesseract
from tqdm import tqdm
from copy import deepcopy
from doctr.io import DocumentFile
from doctr.models import ocr_predictor


def get_box(data_idx, data):
    """Return corner coordinates of the box stored at ``data_idx`` in ``data``.

    ``data`` is indexed as ``data[key][data_idx]`` for the keys ``'left'``,
    ``'top'``, ``'width'`` and ``'height'``. The rotation angle is fixed at
    zero, so the corners describe an axis-aligned rectangle.

    Returns:
        ``(min_x, min_y, s_x, s_y, f_x, f_y, max_x, max_y)``: the top-left
        corner, the end of the top edge, the end of the left edge, and the
        opposite corner.
    """
    left = data['left'][data_idx]
    top = data['top'][data_idx]
    width = data['width'][data_idx]
    height = data['height'][data_idx]

    # Zero degrees, converted to radians.
    angle = 0 / 180.0 * np.pi
    cos_a = np.cos(angle)
    sin_a = np.sin(angle)

    s_x = left + cos_a * width
    s_y = top - sin_a * width
    f_x = left + sin_a * height
    f_y = top + cos_a * height
    max_x = left + cos_a * width + sin_a * height
    max_y = top - sin_a * width + cos_a * height
    return left, top, s_x, s_y, f_x, f_y, max_x, max_y

def rotate_crop_scale_and_pad(original, data_idx, data, pad_black=True):
    """Crop one box out of ``original`` and letterbox it to 224x224.

    The box is read from ``data`` at ``data_idx`` (keys ``'left'``, ``'top'``,
    ``'width'``, ``'height'``). The rotation angle is fixed at zero.

    Args:
        original: Source PIL image.
        data_idx: Index of the box inside ``data``.
        data: Mapping of box fields to sequences.
        pad_black: Pad with black when true; otherwise pad with the average of
            four pixels sampled near the corners of the resized crop.

    Returns:
        ``(new_image, ratio, dx, dy, min_x, min_y, angle, crop)``: the padded
        224x224 image, the scale factor, the paste offsets, the crop origin in
        ``original``, the angle (always 0) and the unscaled crop.
    """
    angle = 0
    height = data['height'][data_idx]
    width = data['width'][data_idx]
    min_x, min_y, *_ = get_box(data_idx, data)
    max_x = min_x + width
    max_y = min_y + height

    output = original.rotate(angle, center=(min_x, min_y))
    crop = output.crop((min_x, min_y, max_x, max_y))

    ratio = min(224 / crop.width, 224 / crop.height)
    new_crop = crop.resize((int(crop.width * ratio), int(crop.height * ratio)))

    if pad_black:
        # No pixels are sampled on this path, so tiny crops cannot raise an
        # IndexError here.
        color = (0, 0, 0)
    else:
        new_crop_np = np.array(new_crop)
        last_row = new_crop_np.shape[0] - 1
        # Axis 1 is the width. shape[-1] would be the channel count for a
        # multi-channel image and would sample the wrong column.
        last_col = new_crop_np.shape[1] - 1
        first_row = min(1, last_row)
        first_col = min(1, last_col)
        corners = [
            new_crop_np[first_row, first_col],
            new_crop_np[first_row, last_col],
            new_crop_np[last_row, first_col],
            new_crop_np[last_row, last_col],
        ]
        avg = np.rint(np.mean(corners, axis=0)).astype(np.uint8)
        color = tuple(avg)

    new_image = Image.new(new_crop.mode, (224, 224), color)
    dx = (224 - new_crop.width) // 2
    dy = (224 - new_crop.height) // 2
    new_image.paste(new_crop, (dx, dy))
    return new_image, ratio, dx, dy, min_x, min_y, angle, crop


def extract_fullpage(img_source, option="tesseract"):
    """Detect words on a page and return one padded 224x224 crop per word.

    Args:
        img_source: The page as raw bytes/bytearray, a file path string, or a
            ``PIL.Image.Image``.
        option: OCR backend, ``"tesseract"`` or ``"doctr"``.

    Returns:
        ``(ret_imgs, img_info, image)``: the list of word crops, a parallel
        list of ``(ratio, dx, dy, min_x, min_y, angle)`` tuples, and a copy of
        the page with each word box outlined in red.

    Raises:
        TypeError: If ``img_source`` is not one of the supported types.
        ValueError: If ``option`` is not a supported backend.
    """
    ret_imgs = []
    img_info = []
    img_bbox = []
    if isinstance(img_source, (bytes, bytearray)):
        input_image = Image.open(io.BytesIO(img_source))
    elif isinstance(img_source, str):
        input_image = Image.open(img_source)
    elif isinstance(img_source, Image.Image):
        input_image = img_source
    else:
        raise TypeError("img_source must be bytes, str, or PIL.Image")

    if option not in ("tesseract", "doctr"):
        # An unknown backend used to fall through and return no crops at all.
        raise ValueError(
            f"option must be 'tesseract' or 'doctr', got {option!r}"
        )

    if option == "tesseract":
        data = pytesseract.image_to_data(input_image, output_type=pytesseract.Output.DICT)
        for i in tqdm(range(len(data['text']))):
            if data['text'][i].strip() != '':  # Filters out empty text results
                new_image, ratio, dx, dy, min_x, min_y, angle, _ = (
                    rotate_crop_scale_and_pad(input_image, i, data, pad_black=True)
                )
                x, y, w, h = data['left'][i], data['top'][i], data['width'][i], data['height'][i]
                ret_imgs.append(new_image)
                img_info.append((ratio, dx, dy, min_x, min_y, angle))
                img_bbox.append((x, y, w, h))
    else:
        # DocumentFile.from_images takes a path or bytes, so an in-memory PIL
        # image is serialized and a bytearray is converted first.
        if isinstance(img_source, Image.Image):
            buffer = io.BytesIO()
            input_image.convert("RGB").save(buffer, format="PNG")
            doctr_source = buffer.getvalue()
        elif isinstance(img_source, bytearray):
            doctr_source = bytes(img_source)
        else:
            doctr_source = img_source
        doc = DocumentFile.from_images(doctr_source)
        predictor = ocr_predictor(pretrained=True)
        print("doctr predictor loaded.")
        result = predictor(doc)

        for page in result.pages:
            for block in page.blocks:
                for line in block.lines:
                    for word in line.words:
                        if word.value.strip() == '':
                            continue
                        # The geometry is scaled by the page size to get pixel
                        # coordinates of two opposite corners.
                        coords = word.geometry
                        x0, y0 = int(coords[0][0] * input_image.width), int(coords[0][1] * input_image.height)
                        x1, y1 = int(coords[1][0] * input_image.width), int(coords[1][1] * input_image.height)
                        w, h = x1 - x0, y1 - y0

                        # Grow the box by 10% on each side, clamped to the page.
                        w_expand = w * 0.1
                        h_expand = h * 0.1

                        x0 = max(0, x0 - w_expand)
                        y0 = max(0, y0 - h_expand)
                        x1 = min(input_image.width, x1 + w_expand)
                        y1 = min(input_image.height, y1 + h_expand)

                        w = x1 - x0
                        h = y1 - y0

                        x0, y0, w, h = map(int, [x0, y0, w, h])

                        # Create a mock data dictionary similar to tesseract's output
                        mock_data = {
                            'left': [x0],
                            'top': [y0],
                            'width': [w],
                            'height': [h],
                            'conf': [1.0],  # doctr doesn't provide confidence scores in the same way
                            'text': [word.value]
                        }

                        # Use the same processing function as tesseract
                        new_image, ratio, dx, dy, min_x, min_y, angle, crop = (
                            rotate_crop_scale_and_pad(input_image, 0, mock_data, pad_black=True)
                        )

                        ret_imgs.append(new_image)
                        img_info.append((ratio, dx, dy, min_x, min_y, angle))
                        img_bbox.append((x0, y0, w, h))

    print('\nFinal length: ', len(ret_imgs))

    # Draw the bboxes
    image = deepcopy(input_image)
    draw = ImageDraw.Draw(image)
    for x, y, w, h in img_bbox:
        draw.rectangle([x, y, x + w, y + h], outline='red', width=2)

    return ret_imgs, img_info, image


# Suppress all warnings for the rest of the session.
warnings.filterwarnings("ignore")

def text_to_tokens(text) -> list[int]:
    """Extract the integer ids of every ``<ink_token_N>`` marker in ``text``.

    Ids are returned in order of appearance; text outside the markers is
    ignored.
    """
    token_ids = []
    for marker in re.finditer(r"<ink_token_(\d+)>", text):
        token_ids.append(int(marker.group(1)))
    return token_ids

def detokenize(tokens: list[int]) -> "Ink":
    """Convert a flat list of ink token ids into an ``Ink``.

    Token layout, for a 224-unit coordinate range:
      * ``0..224``   - x coordinate
      * ``225..449`` - y coordinate, offset by 225
      * ``450``      - start-of-stroke marker

    Tokens are read as (x, y) pairs. A pair is dropped when either value is out
    of range, and a lone x that is not followed by a y is skipped.

    Args:
        tokens: Token ids, each in ``0..450``.

    Returns:
        An ``Ink`` holding one ``Stroke`` per non-empty run of points.

    Raises:
        ValueError: If any token id is outside ``0..450``.
    """
    coordinate_length = 224
    num_token_per_dimension = coordinate_length + 1
    vocabulary_size = num_token_per_dimension * 2 + 1
    start_token = num_token_per_dimension * 2

    if any(t < 0 or t >= vocabulary_size for t in tokens):
        # The largest valid id is vocabulary_size - 1.
        raise ValueError(
            f"Ink token indices should be between 0 and {vocabulary_size - 1}"
        )

    strokes = []
    current_points = []
    idx = 0
    while idx < len(tokens):
        token = tokens[idx]
        if token == start_token:
            # A marker closes the stroke in progress, if it has any points.
            if current_points:
                strokes.append(Stroke(current_points))
            current_points = []
            idx += 1
        elif idx + 1 < len(tokens) and tokens[idx + 1] != start_token:
            # Read in x and y coordinates.
            x = token
            y = tokens[idx + 1] - num_token_per_dimension
            # If the coordinates are valid, add them to detokenization ink.
            if (0 <= x <= coordinate_length) and (0 <= y <= coordinate_length):
                current_points.append((x, y))
            idx += 2
        else:
            # If y doesn't exist or y is start_token, then skip this x.
            idx += 1
    if current_points:
        strokes.append(Stroke(current_points))

    return Ink(strokes)

def load_and_pad_img(image):
    """Letterbox a PIL image onto a white 224x224 RGB canvas.

    The image is resized with its aspect ratio preserved so that it fits
    inside 224x224, then pasted centered on a white background.

    Args:
        image: Source ``PIL.Image.Image``.

    Returns:
        A 224x224 ``PIL.Image.Image`` in RGB mode.
    """
    target = 224
    width, height = image.size
    ratio = min(target / width, target / height)
    # max(1, ...) keeps a very elongated image from collapsing to a zero-sized
    # dimension, which resize() cannot handle.
    image = image.resize(
        (max(1, int(width * ratio)), max(1, int(height * ratio)))
    )
    width, height = image.size
    # Always allocate the full canvas and center on both axes. int() truncation
    # may leave both sides short of 224, in which case padding only one axis
    # would return an image that is not 224x224.
    padded_image = Image.new('RGB', (target, target), (255, 255, 255))
    padded_image.paste(image, ((target - width) // 2, (target - height) // 2))
    return padded_image

def scale_and_pad(original, pad_black=True):
    """Scale a PIL image to fit 224x224 and center it on a padded canvas.

    Args:
        original: Source ``PIL.Image.Image``.
        pad_black: Pad with black when true; otherwise pad with the average of
            four pixels sampled near the corners of ``original``.

    Returns:
        ``(new_image, ratio, dx, dy, new_crop)``: the padded 224x224 image, the
        scale factor, the paste offsets and the resized image before padding.
    """
    ratio = min(224 / original.width, 224 / original.height)
    new_crop = original.resize((int(original.width * ratio), int(original.height * ratio)))

    if pad_black:
        # No pixels are sampled on this path, so tiny images cannot raise an
        # IndexError here.
        color = (0, 0, 0)
    else:
        original_np = np.array(original)
        last_row = original_np.shape[0] - 1
        # Axis 1 is the width. shape[-1] would be the channel count for a
        # multi-channel image and would sample the wrong column.
        last_col = original_np.shape[1] - 1
        first_row = min(1, last_row)
        first_col = min(1, last_col)
        corners = [
            original_np[first_row, first_col],
            original_np[first_row, last_col],
            original_np[last_row, first_col],
            original_np[last_row, last_col],
        ]
        avg = np.rint(np.mean(corners, axis=0)).astype(np.uint8)
        color = tuple(avg)

    new_image = Image.new(new_crop.mode, (224, 224), color)
    dx = (224 - new_crop.width) // 2
    dy = (224 - new_crop.height) // 2
    new_image.paste(new_crop, (dx, dy))
    return new_image, ratio, dx, dy, new_crop

def encode_images_in_batches(images, batch_size=32):
    """JPEG-encode ``images`` and group them into batches.

    Each image is converted to an array, cut to its first three channels and
    encoded with ``tf.io.encode_jpeg``.

    Returns:
        ``(encoded_batches, original_batches)``. For every batch the first
        list holds the stacked encoded tensors (each reshaped to ``(1,)``
        before stacking) and the second holds the stacked pixel arrays.
    """
    total = len(images)
    num_batches = total // batch_size + (1 if total % batch_size != 0 else 0)

    encoded_batches, original_batches = [], []
    for batch_index in range(num_batches):
        first = batch_index * batch_size
        chunk = images[first:min(first + batch_size, total)]

        rgb_arrays = [np.array(picture)[:, :, :3] for picture in chunk]
        jpeg_tensors = [
            tf.reshape(tf.io.encode_jpeg(rgb), (1,)) for rgb in rgb_arrays
        ]

        encoded_batches.append(tf.stack(jpeg_tensors))
        original_batches.append(np.stack(rgb_arrays))

    return encoded_batches, original_batches

def unpad_unscale_unrotate_uncrop(ink, ratio, dx, dy, min_x, min_y, angle):
    """Map ink from padded-crop coordinates back into page coordinates.

    Every point has the padding offset ``(dx, dy)`` removed, is divided by
    ``ratio`` and is then shifted by the crop origin ``(min_x, min_y)``.
    ``angle`` is accepted but not used by the computation.

    Returns:
        A new ``Ink`` made of new ``Stroke`` objects; ``ink`` is not modified.
    """
    def to_page(point):
        page_x = (point[0] - dx) / ratio + min_x
        page_y = (point[1] - dy) / ratio + min_y
        return (page_x, page_y)

    page_strokes = [
        Stroke([to_page(point) for point in stroke]) for stroke in ink
    ]
    return Ink(page_strokes)

In [None]:
# @title Model Preparation (hugging face model)

# Disabled alternative: fetch the zipped SavedModel directly and load it with
# tf.saved_model.load instead of going through the Hugging Face Hub.
# !wget https://storage.googleapis.com/derendering_model/small-p-cpu.zip
# !unzip small-p-cpu.zip
# model = tf.saved_model.load('small-p-cpu')

from huggingface_hub import from_pretrained_keras

# Load the "Derendering/InkSight-Small-p" checkpoint; later cells read
# `model.signatures['serving_default']` from this object.
model = from_pretrained_keras("Derendering/InkSight-Small-p")

In [None]:
# @title Word level Inference
# @markdown We use recognize and derender as default inference for demo and vanilla derender as fallback inference, check `Use_custom` to use your own **word-level** image
# Import everything this cell needs locally so it runs without first
# executing the full-page cell that also imports `io` and `requests`.
import io

import requests
from PIL import Image

try:
    from google.colab import files
    in_colab = True
except ImportError:
    in_colab = False

Use_custom = False # @param {type:"boolean"}

if in_colab:
    if Use_custom:
        # Use the first file the user uploaded.
        uploaded = files.upload()
        input_image = Image.open(io.BytesIO(uploaded[list(uploaded.keys())[0]]))
    else:
        url = "https://github.com/google-research/inksight/raw/main/test_inputs/word.jpg"
        response = requests.get(url)
        # Fail with a clear HTTP error rather than an opaque image-decoding
        # error when the download did not succeed.
        response.raise_for_status()
        input_image = Image.open(io.BytesIO(response.content))
else:
    file_path = 'test_inputs/word.jpg'
    input_image = Image.open(file_path)

# Normalise every input source to RGB. The inference cell slices the image
# array along a third (channel) axis, which a single-channel upload or local
# file would not have.
input_image = input_image.convert("RGB")

# Only the padded image is needed here; the remaining return values of
# scale_and_pad are discarded.
image, _, _, _, _ = scale_and_pad(input_image)

image

In [None]:
# @title Word-level inference
# Load the model and take its default serving signature as the callable.
model = from_pretrained_keras("Derendering/InkSight-Small-p")
cf = model.signatures['serving_default']
demo_prompt = "Recognize and derender."
fall_back_prompt = "Derender the ink."

# Build the request once: one prompt string plus the JPEG-encoded image
# reshaped to (1, 1).
input_text = tf.constant([demo_prompt], dtype=tf.string)
image_encoded = tf.reshape(tf.io.encode_jpeg(np.array(image)[:, :, :3]), (1, 1))
request = {'input_text': input_text, 'image/encoded': image_encoded}

output = cf(**request)
output_text = output["output_0"].numpy()[0][0].decode()
output_ink = detokenize(text_to_tokens(output_text))

# An ink without strokes means the first prompt produced nothing usable:
# repeat the call with the fallback prompt on the same encoded image.
if len(output_ink.strokes) == 0:
    print('Empty output, trying fallback prompt')
    retry_input_text = tf.constant([fall_back_prompt], dtype=tf.string)
    retry_output = cf(**{**request, 'input_text': retry_input_text})
    retry_text = retry_output["output_0"].numpy()[0][0].decode()
    output_ink = detokenize(text_to_tokens(retry_text))

In [None]:
# @title Result Visualization
# Plot the ink produced by the word-level inference cell. `output_ink` is used
# as computed there and is deliberately NOT re-derived from `output`: doing so
# would replace a successful fallback-prompt result with the (empty) ink of
# the first prompt.
fig, ax = plt.subplots()
plot_ink(output_ink, ax, input_image=load_and_pad_img(input_image))
plt.show()

In [None]:
# @title Full page pipeline with Tesseract or Doctr
# @markdown To reproduce the results presented in the paper, we recommend using the Google Cloud Vision API. However, for free alternatives, we provide guidance on achieving similar outcomes using Tesseract or Doctr.
# @markdown Check `Use_custom` to use your own **full-page** image.
from PIL import Image
import requests
from io import BytesIO
import io

# Detect Colab by whether its upload helper can be imported.
try:
    from google.colab import files
except ImportError:
    in_colab = False
else:
    in_colab = True

Use_custom = False  # @param {type:"boolean"}

# Each branch sets both `input_image` (for display) and `file_path` (read by
# the segmentation cell).
if not in_colab:
    file_path = 'test_inputs/page.jpg'
    input_image = Image.open(file_path)
elif Use_custom:
    uploaded = files.upload()
    # The first uploaded file name doubles as the path handed to later cells.
    file_path = list(uploaded.keys())[0]
    input_image = Image.open(io.BytesIO(uploaded[file_path]))
else:
    url = "https://github.com/google-research/inksight/raw/main/test_inputs/page.jpg"
    response = requests.get(url)
    response.raise_for_status()

    # Persist the download so it can be re-opened by path.
    file_path = "/content/page.jpg"
    with open(file_path, "wb") as f:
        f.write(response.content)
    input_image = Image.open(file_path)

input_image

In [None]:
# @title Select word segmentor
# Segmentation backend passed to extract_fullpage via its `option` argument.
Segmentor = "doctr" # @param ["tesseract", "doctr"]
# extract_fullpage returns three values: `word_imgs` and `word_info` are
# consumed by the batch-inference and visualization cells; `bbox_img` is
# displayed as this cell's output.
word_imgs, word_info, bbox_img = extract_fullpage(file_path, option=Segmentor)
bbox_img

In [None]:
# @title Full Page Batch Inference
# Runs the model over every segmented word image in batches and collects one
# ink per word in `output_inks`, in the same order as `word_imgs`.
batchsize=32
output_inks=[]
model = from_pretrained_keras("Derendering/InkSight-Small-p")
cf = model.signatures['serving_default']
demo_prompt = "Recognize and derender."
fall_back_prompt = "Derender the ink."

encode_word_imgs, original_word_imgs = encode_images_in_batches(word_imgs, batch_size=batchsize)
output_batches = []

# The fallback prompt is the same for every retry, so build its tensor once
# instead of once per empty result.
retry_input_text = tf.constant([fall_back_prompt], dtype=tf.string)

t1 = time.perf_counter()
for batch_img in tqdm(encode_word_imgs):
    num_imgs_in_batch = batch_img.shape[0]

    # One copy of the prompt per image in the batch.
    input_text = tf.constant([demo_prompt] * num_imgs_in_batch, dtype=tf.string)
    output = cf(**{'input_text': input_text, 'image/encoded': batch_img})
    output_batches.append(output)

    # Convert the batch output to NumPy once rather than once per image.
    batch_texts = output["output_0"].numpy()

    # Process each image in the batch
    for idx in range(batch_texts.shape[0]):
        output_text = batch_texts[idx][0].decode()
        output_ink = detokenize(text_to_tokens(output_text))

        # Check if the ink is empty; if so retry this single image with the
        # fallback prompt.
        if len(output_ink.strokes) == 0:
            retry_img = tf.expand_dims(batch_img[idx], 0)
            retry_output = cf(**{'input_text': retry_input_text, 'image/encoded': retry_img})
            retry_text = retry_output["output_0"].numpy()[0][0].decode()
            output_ink = detokenize(text_to_tokens(retry_text))

        output_inks.append(output_ink)
t2 = time.perf_counter()

In [None]:
# @title Result Visualization

from PIL import Image

def calculate_adaptive_line_width(word_info):
    """Derive a stroke line width from the sizes of the word crops.

    For every entry the original crop height and width are recovered as
    ``(224 - 2 * pad) / ratio``; the shorter of the two is kept, and the
    mean of these shorter sides divided by 150 is returned.

    Args:
        word_info: Iterable of ``(ratio, dx, dy, min_x, min_y, angle)``
            tuples. Only ``ratio``, ``dx`` and ``dy`` are used.

    Returns:
        The line width as computed by ``np.mean(...) / 150``.
    """
    shortest_sides = [
        min((224 - 2 * dy) / ratio, (224 - 2 * dx) / ratio)
        for ratio, dx, dy, _min_x, _min_y, _angle in word_info
    ]
    return np.mean(shortest_sides) / 150

# Upper bounds for the figure dimensions.
MAX_FIG_WIDTH = 80
MAX_FIG_HEIGHT = 80

# Outline colour drawn around each ink stroke.
path_color="white"

max_side = max(input_image.size)
lw = calculate_adaptive_line_width(word_info)

# Figure size follows the page's aspect ratio (height / width).
img_width, img_height = input_image.size
aspect_ratio = img_height / img_width

if 0.5 <= aspect_ratio <= 2:
    # Ordinary pages: one figure unit per 100 image pixels of width.
    fig_width = min(img_width / 100, MAX_FIG_WIDTH)
    fig_height = fig_width * aspect_ratio
elif aspect_ratio < 0.5:
    # Very wide pages: fix the height and derive a capped width.
    fig_height = 5
    fig_width = min(fig_height / aspect_ratio, MAX_FIG_WIDTH)
else:
    # Very tall pages: fix the width and derive a capped height.
    fig_width = min(MAX_FIG_WIDTH, 15)
    fig_height = min(fig_width * aspect_ratio, MAX_FIG_HEIGHT)

# Three stacked panels: input page, bounding boxes, rendered result.
fig, ax = plt.subplots(3, 1, figsize=(fig_width, fig_height))



def resize_image_if_needed(img, max_size=2000):
    """Downscale an image whose longer side exceeds ``max_size``.

    Args:
        img: Image exposing ``size`` as ``(width, height)`` and a
            ``resize`` method.
        max_size: Largest allowed length of the longer side.

    Returns:
        ``img`` itself when it already fits; otherwise the result of
        ``img.resize`` with both sides multiplied by the same factor
        (truncated to int) using ``Image.LANCZOS``.
    """
    width, height = img.size
    longest = max(width, height)
    if longest <= max_size:
        return img

    factor = max_size / longest
    target = (int(width * factor), int(height * factor))
    return img.resize(target, Image.LANCZOS)

def scale_ink_coordinates(ink, original_size, new_size):
    """Return a copy of ``ink`` rescaled from one image size to another.

    Args:
        ink: Object with a ``strokes`` iterable; every stroke exposes ``x``
            and ``y`` coordinate sequences.
        original_size: ``(width, height)`` the coordinates currently refer to.
        new_size: ``(width, height)`` the coordinates should refer to.

    Returns:
        A deep copy of ``ink`` whose stroke ``x`` / ``y`` are lists multiplied
        by the per-axis size ratio. The input ink is not modified.
    """
    factor_x = new_size[0] / original_size[0]
    factor_y = new_size[1] / original_size[1]

    def rescale(values, factor):
        return [value * factor for value in values]

    result = copy.deepcopy(ink)
    for stroke in result.strokes:
        stroke.x = rescale(stroke.x, factor_x)
        stroke.y = rescale(stroke.y, factor_y)
    return result

# Resize images for display
original_size = input_image.size
display_input = resize_image_if_needed(input_image)
display_bbox = resize_image_if_needed(bbox_img)
new_size = display_input.size

# Map each word's ink back to page coordinates, then into the coordinate
# frame of the (possibly downscaled) display image.
all_inks = []
for ink, (ratio, dx, dy, min_x, min_y, angle) in zip(output_inks, word_info):
    recover_ink = unpad_unscale_unrotate_uncrop(ink, ratio, dx, dy, min_x, min_y, angle)
    scaled_ink = scale_ink_coordinates(recover_ink, original_size, new_size)
    all_inks.append(scaled_ink)

ax[0].imshow(display_input)
ax[0].set_title('Input Full page image')
ax[1].imshow(display_bbox)
ax[1].set_title("Bounding boxes")
ax[2].set_title("InkSight Result (Public Small-p)")

# Create darkened background
enhancer = ImageEnhance.Brightness(display_input)
dark_img = enhancer.enhance(0.45)
ax[2].imshow(dark_img)

# Draw every stroke as a LineCollection on both the bounding-box panel and
# the darkened result panel.
for ink in tqdm(all_inks, desc="Drawing inks", total=len(all_inks)):
    # plt.get_cmap replaces plt.cm.get_cmap, which was removed from recent
    # Matplotlib releases. max(..., 1) keeps the colormap valid for inks
    # without strokes.
    base_colors = plt.get_cmap('rainbow', max(len(ink.strokes), 1))
    for i, stroke in enumerate(ink.strokes):
        x, y = np.array(stroke.x), np.array(stroke.y)

        # Skip if too few points to form a segment
        if len(x) < 2:
            continue

        # Colours are indexed in reverse stroke order and darkened by
        # scaling the HSV value channel.
        base_color = base_colors(len(ink.strokes) - 1 - i)
        hsv_color = colorsys.rgb_to_hsv(*base_color[:3])
        darker_color = colorsys.hsv_to_rgb(hsv_color[0], hsv_color[1], max(0, hsv_color[2] * 0.65))

        # Alpha ramps from 0.5 to 1.0 along the stroke.
        alpha_values = np.linspace(0.5, 1.0, len(x))
        colors = [mcolors.to_rgba(darker_color, alpha=a) for a in alpha_values]

        # Pair consecutive points into (start, end) segments.
        points = np.array([x, y]).T.reshape(-1, 1, 2)
        segments = np.concatenate([points[:-1], points[1:]], axis=1)

        # A collection can only be added to one Axes, hence two instances.
        lc = LineCollection(segments, colors=colors, linewidth=lw)
        lc_two = LineCollection(segments, colors=colors, linewidth=lw)

        # Outline the strokes so they stand out from the background.
        lc.set_path_effects([withStroke(linewidth=lw*1.25, foreground=path_color)])
        lc_two.set_path_effects([withStroke(linewidth=lw*1.8)])

        ax[2].add_collection(lc)
        ax[1].add_collection(lc_two)

plt.tight_layout()

# Save to disk and close the figure; the saved file is what gets displayed.
fig.savefig('result.jpg', dpi=300, bbox_inches='tight')
plt.close(fig)


# Import under an alias so the notebook-wide `Image` name keeps referring to
# PIL.Image, which the helper functions and other cells rely on.
from IPython.display import Image as IPyImage
IPyImage(filename='result.jpg')

# Dataset Release: Generated Dataset Visualization

In this section we show the generated digital inks of our model on 100 randomly selected test samples of publicly available IMGUR5K, IAM, HierText datasets, and also compare our models to "golden" human traced data.

Structures of the Supplementary Materials for each public dataset (example below for IMGUR5k dataset, similar structure for HierText and IAM datasets):
```
├── IMGUR5k
│   └── images_sample
│       ├── 0rMi6_45.png
│       ├── 0wxvqTL_23.png
│       ├── ...
├── large-i_IMGUR5K_inkml
│   └── d+t
│       ├── 0rMi6_45.inkml
│       ├── 0wxvqTL_23.inkml
│       ├── ...
│   └── vanilla
│   └── r+d
├── small-i_IMGUR5K_inkml
│   └── d+t
│   └── vanilla
│   └── r+d
└── small-p_IMGUR5K_inkml
│   └── d+t
│   └── vanilla
│   └── r+d
```
We store the raw input images in a folder named after each public dataset, and the corresponding `.inkml` files in folders that follow the naming convention `<model_name>_<dataset_name>_inkml`.


Under each inkml folder there are three subfolders, `d+t`, `vanilla`, and `r+d`, which contain the data generated with the inference modes `Derender with Text`, `Vanilla Derendering`, and `Recognize and Derender`, respectively.


In [None]:
# @title Notice
from IPython.display import display, HTML

# Render the licence heading, then the terms-of-use paragraph, in order.
for notice_html in (
    '<p style="font-size:20px; font-weight:bold; color:red; background-color:lightgray; padding:10px; width:50%">Licence and Terms of Use</p>',
    '<p style="font-size:16px; width:50%">Results of model inference on public datasets, the results of the human tracing, and the model itself are available under Apache V2 license for research, non-commerical usecases only, as the derivatives of non-commerical research datasets. </p>',
):
    display(HTML(notice_html))


In [None]:
# @title Comparison between Inference Tasks
# For a few random samples of one dataset, plot the input image next to the
# ink one model produced under each of the three inference modes.
from PIL import Image
Dataset = "HierText" # @param ["IMGUR5K", "IAM", "HierText"]
Num_samples = 3 # @param {type:"integer"}
Model = "Small-i" # @param ["Small-i", "Large-i", "Small-p"]

if Model not in ("Small-i", "Small-p", "Large-i"):
    raise ValueError('Now only supports Small-i, Small-p, Large-i.')
# Result folders are named "<lower-cased model>_<dataset>_inkml".
inkml_path = f"./derendering_supp/{Model.lower()}_{Dataset}_inkml"

path = f"./derendering_supp/{Dataset}/images_sample"
samples = os.listdir(path)
picked_samples = random.sample(samples, Num_samples)

# Title prefix per inference mode; the text stored in the inkml file is
# appended to it.
plot_title = {
    "r+d": "Recognized: ",
    "d+t": "OCR Input: ",
    "vanilla": "Vanilla"
}
query_modes = ["d+t", "r+d", "vanilla"]

for name in picked_samples:
    fig, ax = plt.subplots(1, 1+len(query_modes), figsize=(6*len(query_modes), 4))
    img = load_and_pad_img_dir(os.path.join(path, name))
    ax[0].set_xticks([])
    ax[0].set_yticks([])
    ax[0].imshow(img)
    ax[0].set_title('Input')

    # Drop only the file extension. str.strip('.png') would instead remove
    # any leading/trailing '.', 'p', 'n' or 'g' characters and corrupt ids
    # that start or end with one of them.
    example_id = os.path.splitext(name)[0]

    for i, mode in enumerate(query_modes):
        inkml_file = os.path.join(inkml_path, mode, example_id + '.inkml')
        ink = inkml_to_ink(inkml_file)
        text_field = parse_inkml_annotations(inkml_file)['textField']

        plot_ink(ink, ax[1+i], input_image=img, lw=1.8)
        ax[1+i].set_xticks([])
        ax[1+i].set_yticks([])
        ax[1+i].set_title(f'{plot_title[mode]}{text_field}')

    plt.show()

In [None]:
# @title Comparison between Models using Derendering with Text
# For a few random samples of one dataset, plot the input image next to the
# 'd+t' ink produced by each model.
from PIL import Image
Dataset = "IMGUR5K" # @param ["IMGUR5K", "IAM", "HierText"]
Num_samples = 3 # @param {type:"integer"}
model_selections = ["Small-p", "Small-i", "Large-i"]

path = f"./derendering_supp/{Dataset}/images_sample"
samples = os.listdir(path)
picked_samples = random.sample(list(samples), Num_samples)
mode = 'd+t'

# One panel for the input plus one per model. Sized from `model_selections`
# so this cell does not depend on `query_modes` from another cell.
num_panels = 1 + len(model_selections)

for name in picked_samples:
    fig, ax = plt.subplots(1, num_panels, figsize=(4*num_panels, 4))
    img = load_and_pad_img_dir(os.path.join(path, name))
    ax[0].set_xticks([])
    ax[0].set_yticks([])
    ax[0].imshow(img)
    ax[0].set_title('Input')

    # Drop only the file extension. str.strip('.png') would instead remove
    # any leading/trailing '.', 'p', 'n' or 'g' characters.
    example_id = os.path.splitext(name)[0]

    # The loop variable is `model_name`, not `model`, so the loaded model
    # object bound to `model` by the inference cells is not overwritten.
    for i, model_name in enumerate(model_selections):
        inkml_path = f"./derendering_supp/{model_name.lower()}_{Dataset}_inkml"
        inkml_file = os.path.join(inkml_path, mode, example_id + '.inkml')
        ink = inkml_to_ink(inkml_file)
        text_field = parse_inkml_annotations(inkml_file)['textField']

        plot_ink(ink, ax[1+i], input_image=img, lw=1.8)
        ax[1+i].set_xticks([])
        ax[1+i].set_yticks([])
        ax[1+i].set_title( model_name + " | OCR Input: " + text_field + ' ')

    plt.show()

In [None]:
# Compare each model's 'd+t' ink with the human-traced ink for a few random
# samples that have a human tracing.
from PIL import Image
# NOTE(review): allow_pickle=True unpickles arbitrary Python objects; only
# load this file from a trusted copy of the supplementary data.
human_ink = np.load('derendering_supp/human_tracing_hash_to_ink.npy', allow_pickle=True).item()
all_samples = human_ink.keys()

Dataset = "HierText"
Num_samples = 3 # @param {type:"integer"}
model_selections = ["Small-p", "Small-i", "Large-i"]

path = f"./derendering_supp/{Dataset}/images_sample"
picked_samples = random.sample(list(all_samples), Num_samples)
mode = 'd+t'

# Input panel + one panel per model + human-traced panel. Sized from
# `model_selections` so this cell does not depend on `query_modes` from
# another cell.
num_panels = 1 + len(model_selections) + 1

for name in picked_samples:
    fig, ax = plt.subplots(1, num_panels, figsize=(4*num_panels, 4))
    # `name` is a key of `human_ink` and carries no extension, so it is used
    # directly as the example id. Applying str.strip('.png') to it would
    # remove leading/trailing '.', 'p', 'n' or 'g' characters and break both
    # the file lookup and the `human_ink` lookup.
    img = load_and_pad_img_dir(os.path.join(path, name + '.png'))
    ax[0].set_xticks([])
    ax[0].set_yticks([])
    ax[0].imshow(img)
    ax[0].set_title('Input')

    # The loop variable is `model_name`, not `model`, so the loaded model
    # object bound to `model` by the inference cells is not overwritten.
    for i, model_name in enumerate(model_selections):
        inkml_path = f"./derendering_supp/{model_name.lower()}_{Dataset}_inkml"
        inkml_file = os.path.join(inkml_path, mode, name + '.inkml')
        ink = inkml_to_ink(inkml_file)
        text_field = parse_inkml_annotations(inkml_file)['textField']

        plot_ink(ink, ax[1+i], input_image=img, lw=1.8)
        ax[1+i].set_xticks([])
        ax[1+i].set_yticks([])
        ax[1+i].set_title("OCR Input: " + text_field + ' ' + model_name)

    # The human tracing does not depend on the model, so draw it once per
    # sample rather than once per model iteration.
    plot_ink(human_ink[name], ax[-1], input_image=img)
    ax[-1].set_xticks([])
    ax[-1].set_yticks([])
    ax[-1].set_title('Human Traced')

    plt.show()

In [None]:
# Compare each model's 'd+t' ink with the human-traced ink for a few random
# samples that have a human tracing.
# NOTE(review): this cell appears to repeat the preceding human-traced
# comparison cell - confirm whether both are needed.
# NOTE(review): allow_pickle=True unpickles arbitrary Python objects; only
# load this file from a trusted copy of the supplementary data.
human_ink = np.load('derendering_supp/human_tracing_hash_to_ink.npy', allow_pickle=True).item()
all_samples = human_ink.keys()

Dataset = "HierText"
Num_samples = 3 # @param {type:"integer"}
model_selections = ["Small-p", "Small-i", "Large-i"]

path = f"./derendering_supp/{Dataset}/images_sample"
picked_samples = random.sample(list(all_samples), Num_samples)
mode = 'd+t'

# Input panel + one panel per model + human-traced panel. Sized from
# `model_selections` so this cell does not depend on `query_modes` from
# another cell.
num_panels = 1 + len(model_selections) + 1

for name in picked_samples:
    fig, ax = plt.subplots(1, num_panels, figsize=(4*num_panels, 4))
    # `name` is a key of `human_ink` and carries no extension, so it is used
    # directly as the example id. Applying str.strip('.png') to it would
    # remove leading/trailing '.', 'p', 'n' or 'g' characters and break both
    # the file lookup and the `human_ink` lookup.
    img = load_and_pad_img_dir(os.path.join(path, name + '.png'))
    ax[0].set_xticks([])
    ax[0].set_yticks([])
    ax[0].imshow(img)
    ax[0].set_title('Input')

    # The loop variable is `model_name`, not `model`, so the loaded model
    # object bound to `model` by the inference cells is not overwritten.
    for i, model_name in enumerate(model_selections):
        inkml_path = f"./derendering_supp/{model_name.lower()}_{Dataset}_inkml"
        inkml_file = os.path.join(inkml_path, mode, name + '.inkml')
        ink = inkml_to_ink(inkml_file)
        text_field = parse_inkml_annotations(inkml_file)['textField']

        plot_ink(ink, ax[1+i], input_image=img, lw=1.8)
        ax[1+i].set_xticks([])
        ax[1+i].set_yticks([])
        ax[1+i].set_title("OCR Input: " + text_field + ' ' + model_name)

    # The human tracing does not depend on the model, so draw it once per
    # sample rather than once per model iteration.
    plot_ink(human_ink[name], ax[-1], input_image=img)
    ax[-1].set_xticks([])
    ax[-1].set_yticks([])
    ax[-1].set_title('Human Traced')

    plt.show()