In [None]:
from collections import Counter
from functools import reduce

from src.tree_parser import *
from src.tree_sequencing import *

def iter_trees(path="../tmp/im2latex_formulas.tree"):
    """Yield one parsed tree per record of a tree dump file.

    The file is opened in binary mode and split into records by
    ``iter_tree_lines``; each non-empty record is turned into a list with
    ``iter_elements``, and an empty record yields an empty list.

    Args:
        path: Location of the tree dump. Defaults to the path this notebook
            has always read from.

    Yields:
        list: The elements produced by ``iter_elements`` for one record.

    Raises:
        Exception: Whatever ``iter_elements`` raised; the raw record is
            printed first so the offending input can be inspected.
    """
    with open(path, "rb") as lines:
        for elements in iter_tree_lines(lines):
            try:
                tree = list(iter_elements(elements)) if elements else []
            except Exception:
                # Only genuine parse failures are reported. A bare ``except``
                # would also fire on GeneratorExit / KeyboardInterrupt.
                print(b''.join(elements).decode())
                raise
            # Yield outside the ``try`` so that closing the generator, or an
            # exception thrown into it by the consumer, is never mistaken for
            # a parse error of this record.
            yield tree


class TreeMatcherVisitor(Visitor):
    """Visitor that remembers whether every node it saw passed its predicate.

    ``is_ok`` starts out True and is switched to False the first time either
    predicate returns a falsy value; it is never switched back.
    """

    def __init__(self, with_children_filter=lambda _: True, symbol_filter=lambda _: True):
        # Both predicates accept everything unless the caller overrides them.
        self.with_children_filter = with_children_filter
        self.symbol_filter = symbol_filter
        self.is_ok = True

    def visit_with_children(self, with_children):
        """Check the node's ``header`` against ``with_children_filter``."""
        accepted = self.with_children_filter(with_children.header)
        if accepted:
            return
        self.is_ok = False

    def visit_symbol(self, symbol):
        """Check the node's ``symbol`` payload against ``symbol_filter``."""
        accepted = self.symbol_filter(symbol.symbol)
        if accepted:
            return
        self.is_ok = False
            
def tree_matches(elements, **kwargs):
    """Tell whether a tree satisfies the given node predicates.

    Args:
        elements: The list of elements handed to ``AllVisitor.visit_list``.
        **kwargs: Forwarded unchanged to ``TreeMatcherVisitor``
            (``with_children_filter`` and/or ``symbol_filter``).

    Returns:
        The matcher's ``is_ok`` flag after the traversal.
    """
    matcher = TreeMatcherVisitor(**kwargs)
    walker = AllVisitor(matcher)
    walker.visit_list(elements)
    return matcher.is_ok

class TotalAndDocumentFrequencyCounter:
    """Accumulates two tallies per key.

    ``total`` sums the counts passed in for a key; ``document_frequency``
    counts how many times the key was reported at all (one per call).
    """

    def __init__(self):
        self.total, self.document_frequency = Counter(), Counter()

    def __call__(self, key, count):
        """Record ``count`` occurrences of ``key`` and return ``self``.

        Returning the instance lets the object be chained or used as the
        accumulator of a fold.
        """
        self.total[key] = self.total[key] + count
        # Each call contributes exactly one to the document frequency.
        self.document_frequency.update((key,))
        return self

class StripInvisibleVisitor(Visitor):
    """Visitor whose symbol hook reports whether the symbol is invisible."""

    def visit_symbol(self, symbol):
        """Return the result of ``symbol_is_invisible`` for the node's payload."""
        raw_symbol = symbol.symbol
        return symbol_is_invisible(raw_symbol)

def strip_invisible(elements):
    """Return a new list without the elements flagged as invisible.

    Only the entries of ``elements`` themselves are examined; an entry is
    dropped when visiting it with a fresh ``StripInvisibleVisitor`` returns a
    truthy value. The input list is left untouched.
    """
    kept = []
    for element in elements:
        if element.visit(StripInvisibleVisitor()):
            continue
        kept.append(element)
    return kept

In [None]:
# Count the lines of the formula listing (presumably one formula per line).
with open("../tmp/im2latex_formulas.lst", "rb") as formulas:
    formulas_number = 0
    # enumerate(..., start=1) leaves the running line count in formulas_number.
    for formulas_number, _ in enumerate(formulas, start=1):
        pass
formulas_number

In [None]:
def symbol_is_invisible(symbol):
    r"""Tell whether a raw symbol line is one of the known invisible ones.

    A symbol counts as invisible when it starts with ``\write1{\newlabel{``
    or is exactly one of the zero-sized kern/glue lines listed below.

    Args:
        symbol: The raw symbol as bytes, including its trailing newline.

    Returns:
        bool: True for an invisible symbol, False otherwise.
    """
    if symbol.startswith(b'\\write1{\\newlabel{'):
        return True
    zero_sized = (
        b'\\kern 0.0\n',
        b'\\kern0.0\n',
        b'\\glue 0.0\n',
    )
    return symbol in zero_sized

def default_with_children_filter(header):
    """Accept every node header except those starting with a vbox or hbox.

    Args:
        header: The raw header bytes of a node that has children.

    Returns:
        bool: False for box headers, True for everything else.
    """
    box_prefixes = (b'\\vbox', b'\\hbox')
    return not header.startswith(box_prefixes)

            
def default_symbol_filter(symbol):
    """Decide whether a raw symbol line is acceptable in a tree.

    A symbol is rejected when it is invisible (per ``symbol_is_invisible``),
    when it starts with an hbox, vbox or kern command, or when it is a glue
    line that is not one of the explicitly allowed glue specifications.

    Args:
        symbol: The raw symbol as bytes, including its trailing newline.

    Returns:
        bool: True when the symbol is accepted, False when it is rejected.
    """
    # The only glue lines that are tolerated; any other glue is rejected.
    allowed_glue = (
        b'\\glue(\\mskip) 3.0mu\n',
        b'\\glue 3.33333 plus 1.66666 minus 1.11111\n',
        b'\\glue(\\mskip) 5.0mu plus 5.0mu\n',
        b'\\glue 10.00002\n',
        b'\\glue 20.00003\n',
        b'\\glue(\\mskip) -3.0mu\n',
        b'\\glue(\\mskip) 4.0mu plus 2.0mu minus 4.0mu\n',
        b'\\glue 0.0 plus 1.0fil\n',
        b'\\glue 28.45274\n',
        b'\\glue 14.22636\n',
        b'\\glue 56.9055\n',
    )
    rejected = (
        symbol_is_invisible(symbol)
        or symbol.startswith((b'\\hbox', b'\\vbox', b'\\kern'))
        or (symbol.startswith(b'\\glue') and symbol not in allowed_glue)
    )
    return not rejected

def base_tree_filter(tree):
    """Check a tree against the two default node predicates.

    Returns whatever ``tree_matches`` reports when it is given
    ``default_with_children_filter`` and ``default_symbol_filter``.
    """
    default_filters = {
        "with_children_filter": default_with_children_filter,
        "symbol_filter": default_symbol_filter,
    }
    return tree_matches(tree, **default_filters)

In [None]:
# Number of trees that pass the base filter once invisible elements are stripped.
sum(
    1
    for stripped in map(strip_invisible, iter_trees())
    if base_tree_filter(stripped)
)

In [None]:
class SymbolCounterVisitor(Visitor):
    """Visitor that tallies how often each symbol payload is visited."""

    def __init__(self):
        # Maps a symbol payload to the number of times it was visited.
        self.counter = Counter()

    def visit_symbol(self, symbol):
        """Add one occurrence of the visited node's ``symbol`` payload."""
        self.counter.update((symbol.symbol,))

def count_symbols(counter, elements):
    """Feed the per-tree symbol counts of ``elements`` into ``counter``.

    Args:
        counter: A callable invoked as ``counter(symbol, occurrences)`` once
            for every distinct symbol found in the tree.
        elements: The list of elements handed to ``AllVisitor.visit_list``.

    Returns:
        The same ``counter`` object, so the function can drive a fold.
    """
    tally = SymbolCounterVisitor()
    AllVisitor(tally).visit_list(elements)
    for symbol, occurrences in tally.counter.items():
        counter(symbol, occurrences)
    return counter

class FrequencyFilter:
    """Predicate accepting keys whose recorded count exceeds a limit.

    Args:
        counter: A mapping from key to count, looked up with ``counter[key]``.
        limit: The threshold; a key passes only when its count is strictly
            greater than this value.
    """

    def __init__(self, counter, limit):
        self.limit = limit
        self.counter = counter

    def __call__(self, key):
        """Return True when ``key`` was counted more than ``limit`` times."""
        frequency = self.counter[key]
        return frequency > self.limit

In [None]:
# Fold every stripped tree into one counter of symbol totals and document frequencies.
symbol_counter = reduce(
    count_symbols,
    (strip_invisible(tree) for tree in iter_trees()),
    TotalAndDocumentFrequencyCounter(),
)

In [None]:
# [document frequency, symbol] pairs, highest frequency first.
sorted(
    ([frequency, symbol] for symbol, frequency in symbol_counter.document_frequency.items()),
    reverse=True,
)

In [None]:
# Number of stripped trees that pass the base filter and contain only symbols
# whose document frequency exceeds 1000.
sum(
    1
    for stripped in (strip_invisible(tree) for tree in iter_trees())
    if base_tree_filter(stripped)
    if tree_matches(
        stripped,
        symbol_filter=FrequencyFilter(symbol_counter.document_frequency, 1000),
    )
)

In [None]:
# Recount symbols, this time only over the trees that survive both the base
# filter and the document-frequency (> 1000) filter.
symbol_counter2 = reduce(
    count_symbols,
    [
        stripped
        for stripped in (strip_invisible(tree) for tree in iter_trees())
        if base_tree_filter(stripped)
        if tree_matches(
            stripped,
            symbol_filter=FrequencyFilter(symbol_counter.document_frequency, 1000),
        )
    ],
    TotalAndDocumentFrequencyCounter(),
)

In [None]:
# [document frequency, symbol] pairs of the filtered recount, highest first.
sorted(
    [[frequency, symbol] for symbol, frequency in symbol_counter2.document_frequency.items()],
    reverse=True,
)

In [None]:
# Round-trip check: every tree that survives the filters must come back
# unchanged after being converted to a sequence and rebuilt from it.
#
# The trees are stripped of invisible elements first, exactly as in the
# counting cells. Without that step base_tree_filter rejects any tree that
# still contains an invisible symbol, so such trees would silently escape
# the check.
frequent_symbol_filter = FrequencyFilter(symbol_counter.document_frequency, 1000)
for tree in map(strip_invisible, iter_trees()):
    if base_tree_filter(tree) and tree_matches(tree, symbol_filter=frequent_symbol_filter):
        restored = list(elements_from_sequence(elements_to_sequence(tree)))
        assert restored == tree, "sequence round-trip did not reproduce the tree"