From f002a65791cc54bade6d473a94fe03edc3e55745 Mon Sep 17 00:00:00 2001
From: Lukas Blecher
Date: Wed, 20 Apr 2022 18:10:07 +0200
Subject: [PATCH 1/3] add api

---
Review note (not part of the commit message): the unused
`from msilib.schema import Icon` line was dropped from
pix2tex/api/streamlit.py. `Icon` is never referenced, and `msilib` is a
Windows-only standard-library module, so the import made the frontend fail
with ImportError on other platforms. The streamlit.py hunk header and the
diffstat below were adjusted to match (33 -> 32 added lines); the blob id on
that file's `index` line is still the one recorded by the original commit.

 pix2tex/api/app.py       | 44 ++++++++++++++++++++++++++++++++++++++++
 pix2tex/api/run.py       | 21 +++++++++++++++++++
 pix2tex/api/streamlit.py | 32 +++++++++++++++++++++++++++++
 setup.py                 | 21 +++++++++++++------
 4 files changed, 112 insertions(+), 6 deletions(-)
 create mode 100644 pix2tex/api/app.py
 create mode 100644 pix2tex/api/run.py
 create mode 100644 pix2tex/api/streamlit.py

diff --git a/pix2tex/api/app.py b/pix2tex/api/app.py
new file mode 100644
index 0000000..b79941f
--- /dev/null
+++ b/pix2tex/api/app.py
@@ -0,0 +1,44 @@
+from http import HTTPStatus
+from fastapi import FastAPI, File, UploadFile
+from PIL import Image
+from io import BytesIO
+from pix2tex.cli import initialize, call_model
+
+model = None
+app = FastAPI(title='pix2tex API')
+
+
+def read_imagefile(file) -> Image.Image:
+    image = Image.open(BytesIO(file))
+    return image
+
+
+@app.on_event('startup')
+async def load_model():
+    global model
+    if model is None:
+        model = initialize()
+
+
+@app.get('/')
+def root():
+    '''Health check.'''
+    response = {
+        'message': HTTPStatus.OK.phrase,
+        'status-code': HTTPStatus.OK,
+        'data': {},
+    }
+    return response
+
+
+@app.post('/predict/')
+async def predict(file: UploadFile = File(...)):
+    global model
+    image = Image.open(file.file)
+    pred = call_model(*model, img=image)
+    response = {
+        'message': HTTPStatus.OK.phrase,
+        'status-code': HTTPStatus.OK,
+        'data': pred,
+    }
+    return response
diff --git a/pix2tex/api/run.py b/pix2tex/api/run.py
new file mode 100644
index 0000000..e265b2f
--- /dev/null
+++ b/pix2tex/api/run.py
@@ -0,0 +1,21 @@
+from multiprocessing import Process
+import subprocess
+import os
+
+
+def start_api(path='.'):
+    subprocess.call(['uvicorn', 'app:app'], cwd=path)
+
+
+def start_frontend(path='.'):
+    subprocess.call(['streamlit', 'run', 'streamlit.py'], cwd=path)
+
+
+if __name__ == '__main__':
+    path = os.path.realpath(os.path.dirname(__file__))
+    api = Process(target=start_api, kwargs={'path': path})
+    api.start()
+    frontend = Process(target=start_frontend, kwargs={'path': path})
+    frontend.start()
+    api.join()
+    frontend.join()
diff --git a/pix2tex/api/streamlit.py b/pix2tex/api/streamlit.py
new file mode 100644
index 0000000..44735d6
--- /dev/null
+++ b/pix2tex/api/streamlit.py
@@ -0,0 +1,32 @@
+import requests
+from PIL import Image
+import streamlit
+
+if __name__ == '__main__':
+    streamlit.set_page_config(page_title='LaTeX-OCR')
+    streamlit.title('LaTeX OCR')
+    streamlit.markdown('Convert images of equations to corresponding LaTeX code.\n\nThis is based on the `pix2tex` module. Check it out [![github](https://img.shields.io/badge/LaTeX--OCR-visit-a?style=social&logo=github)](https://github.com/lukas-blecher/LaTeX-OCR)')
+
+    uploaded_file = streamlit.file_uploader(
+        'Upload an image an equation',
+        type=['png', 'jpg'],
+    )
+
+    if uploaded_file is not None:
+        image = Image.open(uploaded_file)
+        streamlit.image(image)
+    else:
+        streamlit.text('\n')
+
+    if streamlit.button('Convert'):
+        if uploaded_file is not None and image is not None:
+            with streamlit.spinner('Computing'):
+                response = requests.post('http://127.0.0.1:8000/predict/', files={'file': uploaded_file.getvalue()})
+            if response.ok:
+                latex_code = response.json()['data']
+                streamlit.code(latex_code, language='latex')
+                streamlit.markdown(f'$\\displaystyle {latex_code}$')
+            else:
+                streamlit.error(response.text)
+        else:
+            streamlit.error('Please upload an image.')
diff --git a/setup.py b/setup.py
index 689f857..4fb6213 100644
--- a/setup.py
+++ b/setup.py
@@ -7,6 +7,18 @@
 this_directory = Path(__file__).parent
 long_description = (this_directory / "README.md").read_text()
 
+gui = [
+    "PyQt5",
+    "PyQtWebEngine",
+    "pynput",
+    "screeninfo",
+]
+api = [
+    "streamlit>=1.8.1",
+    "fastapi>=0.75.2",
+    "uvicorn[standard]"
+]
+
 setuptools.setup(
     name='pix2tex',
     version='0.0.12',
@@ -53,12 +65,9 @@
         "python-Levenshtein>=0.12.2",
     ],
     extras_require={
-        "gui": [
-            "PyQt5",
-            "PyQtWebEngine",
-            "pynput",
-            "screeninfo",
-        ]
+        "all": gui+api,
+        "gui": gui,
+        "api": api
     },
     entry_points={
         'console_scripts': [

From 5c63c3b68b0b1d560d70a76b013b5817cafe1cb1 Mon Sep 17 00:00:00 2001
From: Lukas Blecher
Date: Thu, 21 Apr 2022 11:15:21 +0200
Subject: [PATCH 2/3] update

---
 setup.py | 23 ++++++++++++-----------
 1 file changed, 12 insertions(+), 11 deletions(-)

diff --git a/setup.py b/setup.py
index 8f37a8a..41a5d79 100644
--- a/setup.py
+++ b/setup.py
@@ -8,20 +8,21 @@
 long_description = (this_directory / 'README.md').read_text()
 
 gui = [
-    "PyQt5",
-    "PyQtWebEngine",
-    "pynput",
-    "screeninfo",
+    'PyQt5',
+    'PyQtWebEngine',
+    'pynput',
+    'screeninfo',
 ]
 api = [
-    "streamlit>=1.8.1",
-    "fastapi>=0.75.2",
-    "uvicorn[standard]"
+    'streamlit>=1.8.1',
+    'fastapi>=0.75.2',
+    'uvicorn[standard]',
+    'python-multipart'
 ]
 
 setuptools.setup(
     name='pix2tex',
-    version='0.0.14',
+    version='0.0.15',
     description='pix2tex: Using a ViT to convert images of equations into LaTeX code.',
     long_description=long_description,
     long_description_content_type='text/markdown',
@@ -64,9 +65,9 @@
         'imagesize>=1.2.0',
     ],
     extras_require={
-        "all": gui+api,
-        "gui": gui,
-        "api": api
+        'all': gui+api,
+        'gui': gui,
+        'api': api
     },
     entry_points={
         'console_scripts': [

From c160ea0a126f44b53e009d2bc9db7c6b4c48ce1b Mon Sep 17 00:00:00 2001
From: Lukas Blecher
Date: Fri, 22 Apr 2022 17:07:50 +0200
Subject: [PATCH 3/3] better inference with model class

---
 notebooks/LaTeX_OCR_test.ipynb |   4 +-
 pix2tex/api/app.py             |  25 +++--
 pix2tex/api/streamlit.py       |   2 +-
 pix2tex/cli.py                 | 165 ++++++++++++++++-----------------
 pix2tex/gui.py                 |  26 ++----
 5 files changed, 109 insertions(+), 113 deletions(-)

diff --git a/notebooks/LaTeX_OCR_test.ipynb b/notebooks/LaTeX_OCR_test.ipynb
index 52846e6..a1c800f 100644
--- a/notebooks/LaTeX_OCR_test.ipynb
+++ 
b/notebooks/LaTeX_OCR_test.ipynb @@ -61,7 +61,7 @@ "\n", "from pix2tex import cli as pix2tex\n", "from PIL import Image\n", - "args = pix2tex.initialize()\n", + "model = pix2tex.LatexOCR()\n", "\n", "from IPython.display import HTML, Math\n", "display(HTML(\"