In [4]:
from IPython.display import display
import ipywidgets as widgets
from ipycanvas import Canvas
import torch
from model import Model
from tokenizer import vocab_size, decode, special_token_ids
from PIL import Image
from torchvision import transforms

# torchvision transform applied to the PIL image before it is handed to the model.
to_tensor = transforms.ToTensor()

# model
# Hyper-parameters passed to Model; they have to match the checkpoint loaded
# below or load_state_dict will reject the weights.
n_embed = 512
n_head = 8
n_layer = 2
model = Model(vocab_size=vocab_size, n_embed=n_embed, n_head=n_head, n_layer=n_layer)
# map_location forces the checkpoint onto the CPU regardless of where it was saved.
model.load_state_dict(torch.load('models/hand_math19_finetuned.bin', map_location=torch.device('cpu')))
# The model is only used for recognition in this notebook: switch to inference
# mode so train-only layers (dropout, batch-norm statistics) behave deterministically.
# NOTE(review): assumes Model is a torch.nn.Module (it exposes load_state_dict) — confirm.
model.eval()

# canvas
# The drawing surface is `scale` times larger than the (model_w, model_h) size
# the image is resized to before recognition.
scale = 2
model_w = 256
model_h = 128
canvas = Canvas(
    width=model_w * scale,
    height=model_h * scale,
    sync_image_data=True,
)

# Paint the whole surface white so strokes are drawn on a white background.
canvas.fill_style = 'white'
canvas.fill_rect(0, 0, canvas.width, canvas.height)

# Pen settings for free-hand strokes.
canvas.line_cap = 'round'
canvas.line_width = 8.0

# Last pointer position of the stroke in progress; None while no stroke is active.
last_x = last_y = None

def canvas_mouse_down(x, y):
    """Begin a stroke by remembering where the pointer was pressed.

    Args:
        x: Pointer x coordinate reported by the canvas.
        y: Pointer y coordinate reported by the canvas.
    """
    global last_x, last_y
    last_x = x
    last_y = y

def canvas_mouse_up(x, y):
    """End the current stroke by forgetting the last pointer position.

    Args:
        x: Pointer x coordinate reported by the canvas (unused).
        y: Pointer y coordinate reported by the canvas (unused).
    """
    global last_x, last_y
    last_x = last_y = None

def canvas_mouse_move(x, y):
    """Extend the stroke in progress with a segment to the new pointer position.

    Does nothing while no stroke is active (``last_x`` is None). Otherwise a
    line is stroked from the remembered position to ``(x, y)`` and ``(x, y)``
    becomes the new remembered position.

    Args:
        x: Pointer x coordinate reported by the canvas.
        y: Pointer y coordinate reported by the canvas.
    """
    global last_x, last_y
    if last_x is not None:
        start = (last_x, last_y)
        canvas.begin_path()
        canvas.move_to(*start)
        canvas.line_to(x, y)
        canvas.stroke()
        last_x = x
        last_y = y

# Wire pointer events to the stroke handlers defined above.
canvas.on_mouse_down(canvas_mouse_down)
canvas.on_mouse_up(canvas_mouse_up)
canvas.on_mouse_move(canvas_mouse_move)
# If the button is released outside the canvas no mouse-up event arrives, which
# would leave the stroke "active" and draw a stray line when the pointer comes
# back. Treat leaving the canvas as the end of the stroke.
canvas.on_mouse_out(canvas_mouse_up)

display(canvas)

# upload
# Single-file picker restricted to image MIME types; the chosen file is painted
# onto the canvas by upload_widget_changed.
upload_widget = widgets.FileUpload(accept='image/*', multiple=False)
def upload_widget_changed(change):
    """Draw the newly uploaded image onto the canvas, stretched to fill it.

    Args:
        change: Traitlets change dict for the upload widget's ``value``;
            ``change['new']`` is the sequence of uploaded files.
    """
    new_value = change['new']
    # The value can change to an empty sequence (nothing selected / widget
    # reset); there is no file to draw in that case.
    if not new_value:
        return
    content = new_value[0].content
    img = widgets.Image(value=content)
    canvas.draw_image(img, 0, 0, width=model_w*scale, height=model_h*scale)

# Only react to changes of the widget's `value` trait (the uploaded files).
upload_widget.observe(upload_widget_changed, names='value')

# clear
# Button that repaints the canvas white (see clear_button_clicked).
clear_button = widgets.Button(description='Clear')
def clear_button_clicked(b):
    """Erase the drawing by painting the entire canvas white.

    Args:
        b: The clicked button (unused).
    """
    canvas.fill_style = 'white'
    full_w, full_h = canvas.width, canvas.height
    canvas.fill_rect(0, 0, full_w, full_h)
clear_button.on_click(clear_button_clicked)


# recognize
# Text box that receives the decoded recognition result; 'Output' is just the
# initial placeholder content.
output_text = widgets.Text(value='Output')
# Token id that terminates generation in recognize_button_clicked.
eos_token_id = special_token_ids['<eos>']
# Upper bound on the number of decoding steps per recognition.
max_gen_len = 32
recognize_button = widgets.Button(description='Recognize')
def recognize_button_clicked(b):
    """Recognize the canvas contents with greedy decoding and show the text.

    The canvas pixels are resized to ``(model_w, model_h)``, converted to RGB,
    turned into a tensor with a leading batch dimension, and passed to ``model``
    together with the token sequence generated so far. At each step the
    highest-scoring next token is appended; decoding stops at ``eos_token_id``
    or after ``max_gen_len`` steps. The decoded text (without the leading
    ``<begin>`` token) is written to ``output_text``.

    Args:
        b: The clicked button (unused).
    """
    img = Image.fromarray(canvas.get_image_data())
    img = img.resize((model_w, model_h)).convert('RGB')
    img = to_tensor(img).unsqueeze(0)  # add batch dimension
    idx = torch.tensor([[special_token_ids['<begin>']]], dtype=torch.long)
    # Pure inference: disable autograd so no graph is built (and kept alive)
    # for every decoding step.
    with torch.no_grad():
        for _ in range(max_gen_len):
            logits = model(idx, img)
            # Only the prediction at the last position decides the next token.
            next_id = logits[:, -1, :].argmax(dim=-1)
            if next_id.item() == eos_token_id:
                break
            idx = torch.cat([idx, next_id.unsqueeze(0)], dim=-1)
    # Drop the leading <begin> token before decoding.
    output_text.value = decode(idx.tolist()[0][1:])

recognize_button.on_click(recognize_button_clicked)

# First toolbar row: upload / clear / recognize.
bar1 = widgets.HBox([upload_widget, clear_button, recognize_button])
display(bar1)

# NOTE(review): the star import dumps every public sympy name into the notebook
# namespace and can silently shadow earlier names. Only `sympify` is used in
# this cell; other cells may rely on further sympy names, so it is left as is.
from sympy import *
# Name of the notebook variable the recognized expression will be stored under.
to_set_text = widgets.Text(placeholder='Value name to set')
set_button = widgets.Button(description='Set')
def set_button_clicked(b):
    """Store the recognized expression in a notebook-level variable.

    The text in ``output_text`` is parsed with ``sympify`` and assigned to the
    global whose name is typed into ``to_set_text``. Surrounding whitespace in
    the name is ignored. Nothing is assigned when the name is empty, is not a
    valid Python identifier, or is a reserved keyword, because such a global
    could never be referenced from another cell.

    Args:
        b: The clicked button (unused).
    """
    import keyword

    name = to_set_text.value.strip()
    if not name.isidentifier() or keyword.iskeyword(name):
        return
    # NOTE(review): sympify evaluates its string argument; the text box is
    # user-editable, so only use this with input you trust.
    globals()[name] = sympify(output_text.value)
set_button.on_click(set_button_clicked)
# Second toolbar row: target variable name, recognized text, and the Set button.
bar2 = widgets.HBox([to_set_text, output_text, set_button])
display(bar2)

Canvas(height=256, sync_image_data=True, width=512)

HBox(children=(FileUpload(value=(), accept='image/*', description='Upload'), Button(description='Clear', style…

HBox(children=(Text(value='', placeholder='Value name to set'), Text(value='Output'), Button(description='Set'…

In [2]:
# NOTE(review): `x` is not assigned by any visible cell code; presumably it was
# created by typing "x" into the name box above and clicking Set. On a fresh
# kernel (Restart & Run All) this cell would raise NameError — confirm.
x

4*H*L*u - k