Skip to content

Commit

Permalink
Merge pull request #140 from lukas-blecher/api
Browse files Browse the repository at this point in the history
Add API
  • Loading branch information
lukas-blecher committed Apr 27, 2022
2 parents aa4093f + 63787f5 commit 13d562a
Show file tree
Hide file tree
Showing 7 changed files with 213 additions and 109 deletions.
4 changes: 2 additions & 2 deletions notebooks/LaTeX_OCR_test.ipynb
Expand Up @@ -61,7 +61,7 @@
"\n",
"from pix2tex import cli as pix2tex\n",
"from PIL import Image\n",
"args = pix2tex.initialize()\n",
"model = pix2tex.LatexOCR()\n",
"\n",
"from IPython.display import HTML, Math\n",
"display(HTML(\"<script src='https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.3/\"\n",
Expand All @@ -76,7 +76,7 @@
"predictions = []\n",
"for name, f in imgs:\n",
" img = Image.open(f)\n",
" math = pix2tex.call_model(*args, img)\n",
" math = model(img)\n",
" print(math)\n",
" predictions.append('\\\\mathrm{%s} & \\\\displaystyle{%s}'%(name, math))\n",
"Math(table%'\\\\\\\\'.join(predictions))"
Expand Down
49 changes: 49 additions & 0 deletions pix2tex/api/app.py
@@ -0,0 +1,49 @@
# Adapted from https://github.com/kingyiusuen/image-to-latex/blob/main/api/app.py

from ctypes import resize
from http import HTTPStatus
from fastapi import FastAPI, File, UploadFile, Form
from PIL import Image
from io import BytesIO
from pix2tex.cli import LatexOCR

model = None
app = FastAPI(title='pix2tex API')


def read_imagefile(file) -> Image.Image:
image = Image.open(BytesIO(file))
return image


@app.on_event('startup')
async def load_model():
global model
if model is None:
model = LatexOCR()


@app.get('/')
def root():
'''Health check.'''
response = {
'message': HTTPStatus.OK.phrase,
'status-code': HTTPStatus.OK,
'data': {},
}
return response


@app.post('/predict/')
async def predict(file: UploadFile = File(...)):
global model
image = Image.open(file.file)
return model(image)


@app.post('/bytes/')
async def predict_from_bytes(file: bytes = File(...)): #, size: str = Form(...)
global model
#size = tuple(int(a) for a in size.split(','))
image = Image.open(BytesIO(file))
return model(image, resize=False)
21 changes: 21 additions & 0 deletions pix2tex/api/run.py
@@ -0,0 +1,21 @@
from multiprocessing import Process
import subprocess
import os


def start_api(path='.'):
subprocess.call(['uvicorn', 'app:app'], cwd=path)


def start_frontend(path='.'):
subprocess.call(['streamlit', 'run', 'streamlit.py'], cwd=path)


if __name__ == '__main__':
path = os.path.realpath(os.path.dirname(__file__))
api = Process(target=start_api, kwargs={'path': path})
api.start()
frontend = Process(target=start_frontend, kwargs={'path': path})
frontend.start()
api.join()
frontend.join()
33 changes: 33 additions & 0 deletions pix2tex/api/streamlit.py
@@ -0,0 +1,33 @@
from msilib.schema import Icon
import requests
from PIL import Image
import streamlit

if __name__ == '__main__':
streamlit.set_page_config(page_title='LaTeX-OCR')
streamlit.title('LaTeX OCR')
streamlit.markdown('Convert images of equations to corresponding LaTeX code.\n\nThis is based on the `pix2tex` module. Check it out [![github](https://img.shields.io/badge/LaTeX--OCR-visit-a?style=social&logo=github)](https://github.com/lukas-blecher/LaTeX-OCR)')

uploaded_file = streamlit.file_uploader(
'Upload an image an equation',
type=['png', 'jpg'],
)

if uploaded_file is not None:
image = Image.open(uploaded_file)
streamlit.image(image)
else:
streamlit.text('\n')

if streamlit.button('Convert'):
if uploaded_file is not None and image is not None:
with streamlit.spinner('Computing'):
response = requests.post('http://127.0.0.1:8000/predict/', files={'file': uploaded_file.getvalue()})
if response.ok:
latex_code = response.json()
streamlit.code(latex_code, language='latex')
streamlit.markdown(f'$\\displaystyle {latex_code}$')
else:
streamlit.error(response.text)
else:
streamlit.error('Please upload an image.')
165 changes: 81 additions & 84 deletions pix2tex/cli.py
Expand Up @@ -21,8 +21,6 @@
from pix2tex.utils import *
from pix2tex.model.checkpoints.get_latest_checkpoint import download_checkpoints

last_pic = None


def minmax_size(img, max_dimensions=None, min_dimensions=None):
if max_dimensions is not None:
Expand All @@ -40,79 +38,77 @@ def minmax_size(img, max_dimensions=None, min_dimensions=None):
return img


@in_model_path()
def initialize(arguments=None):
if arguments is None:
arguments = Munch({'config': 'settings/config.yaml', 'checkpoint': 'checkpoints/weights.pth', 'no_cuda': True, 'no_resize': False})
logging.getLogger().setLevel(logging.FATAL)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
with open(arguments.config, 'r') as f:
params = yaml.load(f, Loader=yaml.FullLoader)
args = parse_args(Munch(params))
args.update(**vars(arguments))
args.wandb = False
args.device = 'cuda' if torch.cuda.is_available() and not args.no_cuda else 'cpu'
if not os.path.exists(args.checkpoint):
download_checkpoints()
model = get_model(args)
model.load_state_dict(torch.load(args.checkpoint, map_location=args.device))

if 'image_resizer.pth' in os.listdir(os.path.dirname(args.checkpoint)) and not arguments.no_resize:
image_resizer = ResNetV2(layers=[2, 3, 3], num_classes=max(args.max_dimensions)//32, global_pool='avg', in_chans=1, drop_rate=.05,
preact=True, stem_type='same', conv_layer=StdConv2dSame).to(args.device)
image_resizer.load_state_dict(torch.load(os.path.join(os.path.dirname(args.checkpoint), 'image_resizer.pth'), map_location=args.device))
image_resizer.eval()
else:
image_resizer = None
tokenizer = PreTrainedTokenizerFast(tokenizer_file=args.tokenizer)
return args, model, image_resizer, tokenizer


@in_model_path()
def call_model(args, model, image_resizer, tokenizer, img=None):
global last_pic
encoder, decoder = model.encoder, model.decoder
if type(img) is bool:
img = None
if img is None:
if last_pic is None:
print('Provide an image.')
return ''
class LatexOCR:
image_resizer = None
last_pic = None

@in_model_path()
def __init__(self, arguments=None):
if arguments is None:
arguments = Munch({'config': 'settings/config.yaml', 'checkpoint': 'checkpoints/weights.pth', 'no_cuda': True, 'no_resize': False})
logging.getLogger().setLevel(logging.FATAL)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
with open(arguments.config, 'r') as f:
params = yaml.load(f, Loader=yaml.FullLoader)
self.args = parse_args(Munch(params))
self.args.update(**vars(arguments))
self.args.wandb = False
self.args.device = 'cuda' if torch.cuda.is_available() and not self.args.no_cuda else 'cpu'
if not os.path.exists(self.args.checkpoint):
download_checkpoints()
self.model = get_model(self.args)
self.model.load_state_dict(torch.load(self.args.checkpoint, map_location=self.args.device))

if 'image_resizer.pth' in os.listdir(os.path.dirname(self.args.checkpoint)) and not arguments.no_resize:
self.image_resizer = ResNetV2(layers=[2, 3, 3], num_classes=max(self.args.max_dimensions)//32, global_pool='avg', in_chans=1, drop_rate=.05,
preact=True, stem_type='same', conv_layer=StdConv2dSame).to(self.args.device)
self.image_resizer.load_state_dict(torch.load(os.path.join(os.path.dirname(self.args.checkpoint), 'image_resizer.pth'), map_location=self.args.device))
self.image_resizer.eval()
self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=self.args.tokenizer)

@in_model_path()
def __call__(self, img=None, resize=True):
if type(img) is bool:
img = None
if img is None:
if self.last_pic is None:
print('Provide an image.')
return ''
else:
img = self.last_pic.copy()
else:
self.last_pic = img.copy()
img = minmax_size(pad(img), self.args.max_dimensions, self.args.min_dimensions)
if (self.image_resizer is not None and not self.args.no_resize) and resize:
with torch.no_grad():
input_image = img.convert('RGB').copy()
r, w, h = 1, input_image.size[0], input_image.size[1]
for _ in range(10):
h = int(h * r) # height to resize
img = pad(minmax_size(input_image.resize((w, h), Image.BILINEAR if r > 1 else Image.LANCZOS), self.args.max_dimensions, self.args.min_dimensions))
t = test_transform(image=np.array(img.convert('RGB')))['image'][:1].unsqueeze(0)
w = (self.image_resizer(t.to(self.args.device)).argmax(-1).item()+1)*32
logging.info(r, img.size, (w, int(input_image.size[1]*r)))
if (w == img.size[0]):
break
r = w/img.size[0]
else:
img = last_pic.copy()
else:
last_pic = img.copy()
img = minmax_size(pad(img), args.max_dimensions, args.min_dimensions)
if image_resizer is not None and not args.no_resize:
img = np.array(pad(img).convert('RGB'))
t = test_transform(image=img)['image'][:1].unsqueeze(0)
im = t.to(self.args.device)

with torch.no_grad():
input_image = img.convert('RGB').copy()
r, w, h = 1, input_image.size[0], input_image.size[1]
for _ in range(10):
h = int(h * r) # height to resize
img = pad(minmax_size(input_image.resize((w, h), Image.BILINEAR if r > 1 else Image.LANCZOS), args.max_dimensions, args.min_dimensions))
t = test_transform(image=np.array(img.convert('RGB')))['image'][:1].unsqueeze(0)
w = (image_resizer(t.to(args.device)).argmax(-1).item()+1)*32
logging.info(r, img.size, (w, int(input_image.size[1]*r)))
if (w == img.size[0]):
break
r = w/img.size[0]
else:
img = np.array(pad(img).convert('RGB'))
t = test_transform(image=img)['image'][:1].unsqueeze(0)
im = t.to(args.device)

with torch.no_grad():
model.eval()
device = args.device
encoded = encoder(im.to(device))
dec = decoder.generate(torch.LongTensor([args.bos_token])[:, None].to(device), args.max_seq_len,
eos_token=args.eos_token, context=encoded.detach(), temperature=args.get('temperature', .25))
pred = post_process(token2str(dec, tokenizer)[0])
try:
clipboard.copy(pred)
except:
pass
return pred
self.model.eval()
device = self.args.device
encoded = self.model.encoder(im.to(device))
dec = self.model.decoder.generate(torch.LongTensor([self.args.bos_token])[:, None].to(device), self.args.max_seq_len,
eos_token=self.args.eos_token, context=encoded.detach(), temperature=self.args.get('temperature', .25))
pred = post_process(token2str(dec, self.tokenizer)[0])
try:
clipboard.copy(pred)
except:
pass
return pred


def output_prediction(pred, args):
Expand Down Expand Up @@ -144,7 +140,8 @@ def main():
parser.add_argument('--no-resize', action='store_true', help='Resize the image beforehand')
arguments = parser.parse_args()
with in_model_path():
args, *objs = initialize(arguments)
model = LatexOCR(arguments)
file = None
while True:
instructions = input('Predict LaTeX code for image ("?"/"h" for help). ')
possible_file = instructions.strip()
Expand Down Expand Up @@ -176,32 +173,32 @@ def main():
''')
continue
elif ins in ['show', 'katex', 'no_resize']:
setattr(args, ins, not getattr(args, ins, False))
print('set %s to %s' % (ins, getattr(args, ins)))
setattr(model.args, ins, not getattr(model.args, ins, False))
print('set %s to %s' % (ins, getattr(model.args, ins)))
continue
elif os.path.isfile(os.path.realpath(possible_file)):
args.file = possible_file
file = possible_file
else:
t = re.match(r't=([\.\d]+)', ins)
if t is not None:
t = t.groups()[0]
args.temperature = float(t)+1e-8
print('new temperature: T=%.3f' % args.temperature)
model.args.temperature = float(t)+1e-8
print('new temperature: T=%.3f' % model.args.temperature)
continue
try:
img = None
if args.file:
img = Image.open(args.file)
if file:
img = Image.open(file)
else:
try:
img = ImageGrab.grabclipboard()
except:
pass
pred = call_model(args, *objs, img=img)
output_prediction(pred, args)
pred = model(img)
output_prediction(pred, model.args)
except KeyboardInterrupt:
pass
args.file = None
file = None


if __name__ == "__main__":
Expand Down

0 comments on commit 13d562a

Please sign in to comment.