Commit

Merge branch 'main' into api
lukas-blecher committed Apr 27, 2022
2 parents c160ea0 + aa4093f commit 63787f5
Showing 13 changed files with 210 additions and 93 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -14,7 +14,7 @@ dist/
downloads/
eggs/
.eggs/
lib/
# lib/
lib64/
parts/
sdist/
6 changes: 3 additions & 3 deletions README.md
@@ -60,9 +60,9 @@ Don't forget to update the path to the tokenizer in the config file and set `num
The model consists of a ViT [[1](#References)] encoder with a ResNet backbone and a Transformer [[2](#References)] decoder.
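
For orientation, here is a minimal PyTorch sketch of that encoder-decoder layout (a ResNet feature extractor feeding a Transformer encoder, followed by an autoregressive Transformer decoder). It is an illustrative assumption, not the repository's actual implementation: the `Im2LatexSketch` name, all layer sizes, and the omission of positional embeddings are simplifications.

```python
# Sketch only: ResNet-backed "ViT" encoder + Transformer decoder.
# Assumes torchvision>=0.13; positional embeddings omitted for brevity.
import torch.nn as nn
import torchvision

class Im2LatexSketch(nn.Module):
    def __init__(self, vocab_size=8000, d_model=256, nhead=8, depth=4):
        super().__init__()
        resnet = torchvision.models.resnet18(weights=None)
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])  # CNN feature map
        self.proj = nn.Conv2d(512, d_model, kernel_size=1)            # to patch embeddings
        enc = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
        self.encoder = nn.TransformerEncoder(enc, depth)              # ViT-style encoder
        dec = nn.TransformerDecoderLayer(d_model, nhead, batch_first=True)
        self.decoder = nn.TransformerDecoder(dec, depth)
        self.embed = nn.Embedding(vocab_size, d_model)
        self.out = nn.Linear(d_model, vocab_size)

    def forward(self, images, tokens):
        feats = self.proj(self.backbone(images))                  # (B, d_model, H', W')
        memory = self.encoder(feats.flatten(2).transpose(1, 2))   # (B, H'*W', d_model)
        tgt = self.embed(tokens)
        mask = nn.Transformer.generate_square_subsequent_mask(tokens.size(1))
        return self.out(self.decoder(tgt, memory, tgt_mask=mask))
```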

### Performance
| BLEU score | normed edit distance |
| ---------- | -------------------- |
| 0.88 | 0.10 |
| BLEU score | normed edit distance | token accuracy |
| ---------- | -------------------- | -------------- |
| 0.88 | 0.10 | 0.60 |
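
As a rough guide, a hedged sketch of how these three numbers could be computed for a single prediction/reference pair; the repository's own evaluation code may differ, and the `metrics` helper and library choices (`nltk`, `python-Levenshtein`) below are assumptions.

```python
# One possible way to compute BLEU, normalized edit distance, and token accuracy.
from nltk.translate.bleu_score import sentence_bleu
from Levenshtein import distance as edit_distance

def metrics(pred_tokens, ref_tokens):
    bleu = sentence_bleu([ref_tokens], pred_tokens)
    pred_str, ref_str = ' '.join(pred_tokens), ' '.join(ref_tokens)
    norm_ed = edit_distance(pred_str, ref_str) / max(len(pred_str), len(ref_str), 1)
    matches = sum(p == r for p, r in zip(pred_tokens, ref_tokens))
    token_acc = matches / max(len(pred_tokens), len(ref_tokens), 1)
    return bleu, norm_ed, token_acc

print(metrics(r'\frac { a } { b }'.split(), r'\frac { a } { b }'.split()))
```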

## Data
We need paired data for the network to learn. Luckily there is a lot of LaTeX code on the internet, e.g. [wikipedia](https://www.wikipedia.org), [arXiv](https://www.arxiv.org). We also use the formulae from the [im2latex-100k](https://zenodo.org/record/56198#.V2px0jXT6eA) [[3](#References)] dataset.
2 changes: 2 additions & 0 deletions pix2tex/__init__.py
@@ -0,0 +1,2 @@
import os
os.environ['FOR_DISABLE_CONSOLE_CTRL_HANDLER'] = '1'
6 changes: 0 additions & 6 deletions pix2tex/dataset/__init__.py
@@ -1,6 +0,0 @@
import pix2tex.dataset.arxiv
import pix2tex.dataset.extract_latex
import pix2tex.dataset.latex2png
import pix2tex.dataset.render
import pix2tex.dataset.scraping
import pix2tex.dataset.dataset
113 changes: 69 additions & 44 deletions pix2tex/dataset/arxiv.py
@@ -1,7 +1,7 @@
# modified from https://github.com/soskek/arxiv_leaks

import argparse
import json
import subprocess
import os
import glob
import re
@@ -10,7 +10,6 @@
import logging
import tarfile
import tempfile
import chardet
import logging
import requests
import urllib.request
@@ -22,7 +21,7 @@

# logging.getLogger().setLevel(logging.INFO)
arxiv_id = re.compile(r'(?<!\d)(\d{4}\.\d{5})(?!\d)')
arxiv_base = 'https://arxiv.org/e-print/'
arxiv_base = 'https://export.arxiv.org/e-print/'


def get_all_arxiv_ids(text):
@@ -48,7 +47,7 @@ def download(url, dir_path='./'):
return 0


def read_tex_files(file_path):
def read_tex_files(file_path, demacro=False):
tex = ''
try:
with tempfile.TemporaryDirectory() as tempdir:
@@ -59,50 +58,59 @@ def read_tex_files(file_path):
texfiles = [os.path.abspath(x) for x in glob.glob(os.path.join(tempdir, '**', '*.tex'), recursive=True)]
except tarfile.ReadError as e:
texfiles = [file_path] # [os.path.join(tempdir, file_path+'.tex')]
if demacro:
ret = subprocess.run(['de-macro', *texfiles], cwd=tempdir, capture_output=True)
if ret.returncode == 0:
texfiles = glob.glob(os.path.join(tempdir, '**', '*-clean.tex'), recursive=True)
for texfile in texfiles:
try:
tex += open(texfile, 'r', encoding=chardet.detect(open(texfile, 'br').readline())['encoding']).read()
except UnicodeDecodeError:
ct = open(texfile, 'r', encoding='utf-8').read()
tex += ct
except UnicodeDecodeError as e:
logging.debug(e)
pass
tex = unfold(convert(tex))
except Exception as e:
logging.debug('Could not read %s: %s' % (file_path, str(e)))
pass
# remove comments
return re.sub(r'(?<!\\)%.*\n', '', tex)
raise e
tex = pydemacro(tex)
return tex


def download_paper(arxiv_id, dir_path='./'):
url = arxiv_base + arxiv_id
return download(url, dir_path)


def read_paper(targz_path, delete=True):
def read_paper(targz_path, delete=False, demacro=False):
paper = ''
if targz_path != 0:
paper = read_tex_files(targz_path)
paper = read_tex_files(targz_path, demacro=demacro)
if delete:
os.remove(targz_path)
return paper


def parse_arxiv(id):
tempdir = tempfile.gettempdir()
text = read_paper(download_paper(id, tempdir))
#print(text, file=open('paper.tex', 'w'))
#linked = list(set([l for l in re.findall(arxiv_id, text)]))
def parse_arxiv(id, save=None, demacro=True):
if save is None:
dir = tempfile.gettempdir()
else:
dir = save
text = read_paper(download_paper(id, dir), delete=save is None, demacro=demacro)

return find_math(text, wiki=False), []


if __name__ == '__main__':
# logging.getLogger().setLevel(logging.DEBUG)
parser = argparse.ArgumentParser(description='Extract math from arxiv')
parser.add_argument('-m', '--mode', default='top100', choices=['top100', 'ids', 'dir'],
help='Where to extract code from. top100: current 100 arxiv papers, id: specific arxiv ids. \
Usage: `python arxiv.py -m id id001 id002`, dir: a folder full of .tar.gz files. Usage: `python arxiv.py -m dir directory`')
parser.add_argument('-m', '--mode', default='top100', choices=['top', 'ids', 'dirs'],
                        help='Where to extract code from. top: the 100 most recent arxiv papers (use `-m top N` for any other number of papers), ids: specific arxiv ids. \
  Usage: `python arxiv.py -m ids id001 id002`, dirs: one or more folders of .tar.gz files. Usage: `python arxiv.py -m dirs directory`')
parser.add_argument(nargs='*', dest='args', default=[])
parser.add_argument('-o', '--out', default=os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data'), help='output directory')
parser.add_argument('-d', '--demacro', dest='demacro', action='store_true',
                        help='Deprecated: use de-macro (slows down extraction but may improve quality). Install https://www.ctan.org/pkg/de-macro')
parser.add_argument('-s', '--save', default=None, type=str, help='Where to save the .tar.gz files downloaded from arxiv. Default: only a temporary location')
args = parser.parse_args()
if '.' in args.out:
args.out = os.path.dirname(args.out)
@@ -111,30 +119,47 @@ def parse_arxiv(id):
skip = open(skips, 'r', encoding='utf-8').read().split('\n')
else:
skip = []
if args.mode == 'ids':
visited, math = recursive_search(parse_arxiv, args.args, skip=skip, unit='paper')
elif args.mode == 'top100':
url = 'https://arxiv.org/list/physics/pastweek?skip=0&show=100' #'https://arxiv.org/list/hep-th/2203?skip=0&show=100'
ids = get_all_arxiv_ids(requests.get(url).text)
math, visited = [], ids
for id in tqdm(ids):
m, _ = parse_arxiv(id)
math.extend(m)
elif args.mode == 'dir':
dirs = os.listdir(args.args[0])
math, visited = [], []
for f in tqdm(dirs):
try:
text = read_paper(os.path.join(args.args[0], f), False)
math.extend(find_math(text, wiki=False))
visited.append(os.path.basename(f))
except Exception as e:
logging.debug(e)
pass
else:
raise NotImplementedError
print('\n'.join(math))
sys.exit(0)
if args.save is not None:
os.makedirs(args.save, exist_ok=True)
try:
if args.mode == 'ids':
visited, math = recursive_search(parse_arxiv, args.args, skip=skip, unit='paper', save=args.save, demacro=args.demacro)
elif args.mode == 'top':
num = 100 if len(args.args) == 0 else int(args.args[0])
url = 'https://arxiv.org/list/physics/pastweek?skip=0&show=%i' % num # 'https://arxiv.org/list/hep-th/2203?skip=0&show=100'
ids = get_all_arxiv_ids(requests.get(url).text)
math, visited = [], ids
for id in tqdm(ids):
try:
m, _ = parse_arxiv(id, save=args.save, demacro=args.demacro)
math.extend(m)
except ValueError:
pass
elif args.mode == 'dirs':
files = []
for folder in args.args:
files.extend([os.path.join(folder, p) for p in os.listdir(folder)])
math, visited = [], []
for f in tqdm(files):
try:
text = read_paper(f, delete=False, demacro=args.demacro)
math.extend(find_math(text, wiki=False))
visited.append(os.path.basename(f))
except DemacroError as e:
logging.debug(f + str(e))
pass
except KeyboardInterrupt:
break
except Exception as e:
logging.debug(e)
raise e
else:
raise NotImplementedError
except KeyboardInterrupt:
pass
print('Found %i instances of math latex code' % len(math))
# print('\n'.join(math))
# sys.exit(0)
for l, name in zip([visited, math], ['visited_arxiv.txt', 'math_arxiv.txt']):
f = os.path.join(args.out, name)
if not os.path.exists(f):
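
A small usage sketch of the reworked `parse_arxiv` entry point shown above, assuming the signature in this diff (`parse_arxiv(id, save=None, demacro=True)`); the arXiv id is a placeholder, not one used by the project.

```python
# Illustrative only: download one paper from the export mirror and extract its math.
from pix2tex.dataset.arxiv import parse_arxiv

math, _ = parse_arxiv('2004.10151', save=None, demacro=False)  # placeholder id
print('Found %i instances of math latex code' % len(math))
```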
2 changes: 1 addition & 1 deletion pix2tex/dataset/dataset.py
@@ -254,7 +254,7 @@ def generate_tokenizer(equations, output, vocab_size):
tokenizer = Tokenizer(BPE())
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
trainer = BpeTrainer(special_tokens=["[PAD]", "[BOS]", "[EOS]"], vocab_size=vocab_size, show_progress=True)
tokenizer.train(trainer, equations)
tokenizer.train(equations, trainer)
tokenizer.save(path=output, pretty=False)
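
The reordered call matches the current Hugging Face `tokenizers` API, where `Tokenizer.train` takes the list of training files first and the trainer second. A standalone sketch of the same call order follows; the file name and vocabulary size are placeholders, not values prescribed by the project.

```python
# Minimal standalone sketch, assuming tokenizers>=0.10.
from tokenizers import Tokenizer, pre_tokenizers
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer

tokenizer = Tokenizer(BPE())
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
trainer = BpeTrainer(special_tokens=["[PAD]", "[BOS]", "[EOS]"], vocab_size=8000, show_progress=True)
tokenizer.train(["math.txt"], trainer)          # files first, then the trainer
tokenizer.save("tokenizer.json", pretty=False)  # placeholder output path
```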


86 changes: 62 additions & 24 deletions pix2tex/dataset/demacro.py
@@ -2,14 +2,20 @@

import argparse
import re
import logging
from collections import Counter
import time
from pix2tex.dataset.extract_latex import remove_labels


class DemacroError(Exception):
pass


def main():
args = parse_command_line()
data = read(args.input)
data = convert(data)
data = unfold(data)
data = pydemacro(data)
if args.output is not None:
write(args.output, data)
else:
@@ -28,16 +34,6 @@ def read(path):
return handle.read()


def convert(data):
return re.sub(
r'((?:\\(?:expandafter|global|long|outer|protected)'
r'(?: +|\r?\n *)?)*)?'
r'\\def *(\\[a-zA-Z]+) *(?:#+([0-9]))*\{',
replace,
data,
)


def bracket_replace(string: str) -> str:
'''
replaces all layered brackets with special symbols
@@ -66,7 +62,9 @@ def sweep(t, cmds):
nargs = int(c[1][1]) if c[1] != r'' else 0
optional = c[2] != r''
if nargs == 0:
t = re.sub(r'\\%s([\W_^\d])' % c[0], r'%s\1' % c[-1].replace('\\', r'\\'), t)
num_matches += len(re.findall(r'\\%s([\W_^\dĊ])' % c[0], t))
if num_matches > 0:
t = re.sub(r'\\%s([\W_^\dĊ])' % c[0], r'%s\1' % c[-1].replace('\\', r'\\'), t)
else:
matches = re.findall(r'(\\%s(?:\[(.+?)\])?' % c[0]+r'{(.+?)}'*(nargs-(1 if optional else 0))+r')', t)
num_matches += len(matches)
@@ -81,18 +79,49 @@


def unfold(t):
t = remove_labels(t).replace('\n', 'Ċ')

cmds = re.findall(r'\\(?:re)?newcommand\*?{\\(.+?)}\s*(\[\d\])?(\[.+?\])?{(.+?)}Ċ', t)
#t = queue.get()
t = t.replace('\n', 'Ċ')
t = bracket_replace(t)
commands_pattern = r'\\(?:re)?newcommand\*?{\\(.+?)}[\sĊ]*(\[\d\])?[\sĊ]*(\[.+?\])?[\sĊ]*{(.*?)}\s*(?:Ċ|\\)'
cmds = re.findall(commands_pattern, t)
t = re.sub(r'(?<!\\)'+commands_pattern, 'Ċ', t)
cmds = sorted(cmds, key=lambda x: len(x[0]))
for _ in range(10):
# check for up to 10 nested commands
t = bracket_replace(t)
t, N = sweep(t, cmds)
t = undo_bracket_replace(t)
if N == 0:
break
return t.replace('Ċ', '\n')
cmd_names = Counter([c[0] for c in cmds])
for i in reversed(range(len(cmds))):
if cmd_names[cmds[i][0]] > 1:
# something went wrong here. No multiple definitions allowed
del cmds[i]
elif '\\newcommand' in cmds[i][-1]:
logging.debug("Command recognition pattern didn't work properly. %s" % (undo_bracket_replace(cmds[i][-1])))
del cmds[i]
start = time.time()
try:
for i in range(10):
# check for up to 10 nested commands
if i > 0:
t = bracket_replace(t)
t, N = sweep(t, cmds)
            if time.time()-start > 5:  # not optimal; more sophisticated methods didn't work or were slow
raise TimeoutError
t = undo_bracket_replace(t)
if N == 0 or i == 9:
#print("Needed %i iterations to demacro" % (i+1))
break
elif N > 4000:
raise ValueError("Too many matches. Processing would take too long.")
except ValueError:
pass
except TimeoutError:
pass
except re.error as e:
raise DemacroError(e)
t = remove_labels(t.replace('Ċ', '\n'))
# queue.put(t)
return t


def pydemacro(t):
return unfold(convert(re.sub('\n+', '\n', re.sub(r'(?<!\\)%.*\n', '\n', t))))


def replace(match):
@@ -120,6 +149,15 @@ def replace(match):
return result


def convert(data):
data = re.sub(
r'((?:\\(?:expandafter|global|long|outer|protected)(?:\s+|\r?\n\s*)?)*)?\\def\s*(\\[a-zA-Z]+)\s*(?:#+([0-9]))*\{',
replace,
data,
)
return re.sub(r'\\let\s*(\\[a-zA-Z]+)\s*=?\s*(\\?\w+)*', r'\\newcommand*{\1}{\2}\n', data)


def write(path, data):
with open(path, mode='w') as handle:
handle.write(data)
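
A small usage sketch of the `pydemacro` entry point defined in this file, assuming the functions behave as shown in the diff; the input string below is made up for illustration.

```python
# Illustrative input with one \newcommand and one \def definition; pydemacro
# should strip the definitions and expand the macros in the body.
from pix2tex.dataset.demacro import pydemacro

src = '\n'.join([
    r'\newcommand*{\R}{\mathbb{R}}',
    r'\def\half{\frac{1}{2}}',
    r'$f:\R\to\R$, $\half x$',
])
print(pydemacro(src))
```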
2 changes: 1 addition & 1 deletion pix2tex/dataset/extract_latex.py
@@ -10,7 +10,7 @@
displaymath = re.compile(r'(\\displaystyle)(.{%i,%i}?)(\}(?:<|"))' % (1, MAX_CHARS))
outer_whitespace = re.compile(
r'^\\,|\\,$|^~|~$|^\\ |\\ $|^\\thinspace|\\thinspace$|^\\!|\\!$|^\\:|\\:$|^\\;|\\;$|^\\enspace|\\enspace$|^\\quad|\\quad$|^\\qquad|\\qquad$|^\\hspace{[a-zA-Z0-9]+}|\\hspace{[a-zA-Z0-9]+}$|^\\hfill|\\hfill$')
label_names = [re.compile(r'\\%s\s?\{(.*?)\}' % s) for s in ['ref', 'cite', 'label', 'caption', 'eqref']]
label_names = [re.compile(r'\\%s\s?\{(.*?)\}' % s) for s in ['ref', 'cite', 'label', 'eqref']]

def check_brackets(s):
a = []
4 changes: 3 additions & 1 deletion pix2tex/dataset/preprocessing/preprocess_latex.js
@@ -207,7 +207,9 @@ groupTypes.array = function(group, options) {
groupTypes.sqrt = function(group, options) {
var node;
if (group.value.index) {
norm_str = norm_str + "\\sqrt [ " + group.value.index + " ] ";
norm_str = norm_str + "\\sqrt [ ";
buildExpression(group.value.index.value, options);
norm_str = norm_str + "] ";
buildGroup(group.value.body, options);
} else {
norm_str = norm_str + "\\sqrt ";
