In [None]:
import itertools
from collections import Counter
from functools import reduce
from sklearn.preprocessing import LabelEncoder

from src.tree_parser import *
from src.tree_sequencing import *
from src.model import *

def iter_trees(paths=("tmp/im2latex_formulas1.tree", "tmp/im2latex_formulas2.tree")):
    """Yield the parsed element list for every tree record in ``paths``.

    The files are opened in binary mode and chained in the given order, so the
    yielded items follow the order of the records produced by
    ``iter_tree_lines`` across all files.

    :param paths: tree files to read, in order. Defaults to the two im2latex
        dumps this notebook was originally written against.
    :yields: ``list(iter_elements(elements))`` for each record, or ``[]`` when
        the record is empty/falsy.
    :raises: re-raises whatever ``iter_elements`` raises on a malformed record,
        after printing the raw record for debugging.
    """
    from contextlib import ExitStack

    with ExitStack() as stack:
        files = [stack.enter_context(open(path, "rb")) for path in paths]
        for elements in iter_tree_lines(itertools.chain(*files)):
            # Parse inside the try but yield outside it, so an exception thrown
            # into this generator by the consumer is not misreported as a parse
            # failure of the current record.
            try:
                tree = list(iter_elements(elements)) if elements else []
            except Exception:
                print("WTF")
                # errors="replace": a non-UTF-8 record must not raise a new
                # UnicodeDecodeError here and mask the original parse error.
                print(b''.join(elements).decode(errors="replace"))
                raise
            yield tree


class TreeMatcherVisitor(Visitor):
    """Visitor that records whether every visited node passes two predicates.

    ``is_ok`` starts as True. It is latched to False the first time either
    ``with_children_filter`` rejects a node's ``header`` or ``symbol_filter``
    rejects a node's ``symbol``, and it is never reset to True afterwards.
    Both filters default to accepting everything.
    """

    def __init__(self, with_children_filter=lambda _: True, symbol_filter=lambda _: True):
        self.symbol_filter = symbol_filter
        self.with_children_filter = with_children_filter
        self.is_ok = True

    def _reject_unless(self, accepted):
        # An accepted node leaves the flag alone; a rejected one latches it off.
        if accepted:
            return
        self.is_ok = False

    def visit_with_children(self, with_children):
        self._reject_unless(self.with_children_filter(with_children.header))

    def visit_symbol(self, symbol):
        self._reject_unless(self.symbol_filter(symbol.symbol))
            
def tree_matches(elements, **kwargs):
    """Return True iff no node visited in ``elements`` is rejected by the filters.

    ``kwargs`` are forwarded to ``TreeMatcherVisitor`` (``with_children_filter``
    and/or ``symbol_filter``); an omitted filter accepts everything.
    """
    matcher = TreeMatcherVisitor(**kwargs)
    walker = AllVisitor(matcher)
    walker.visit_list(elements)
    return matcher.is_ok

class TotalAndDocumentFrequencyCounter:
    """Accumulate per-key occurrence totals and per-key call counts.

    Each call ``(key, count)`` adds ``count`` to ``total[key]`` and adds one to
    ``document_frequency[key]``. When called once per distinct key of each
    document, ``document_frequency`` is the number of documents containing the
    key. Calls return ``self`` so the object can be the accumulator of a
    ``functools.reduce`` fold.
    """

    def __init__(self):
        self.document_frequency = Counter()
        self.total = Counter()

    def __call__(self, key, count):
        self.total.update({key: count})
        # Wrap in a tuple so the key is counted once, not iterated element-wise.
        self.document_frequency.update((key,))
        return self

class StripInvisibleVisitor(Visitor):
    """Visitor whose ``visit_symbol`` reports whether a symbol is invisible.

    Only symbol nodes are overridden; other node kinds fall back to whatever
    the base ``Visitor`` returns.
    """

    def visit_symbol(self, symbol):
        raw = symbol.symbol
        return symbol_is_invisible(raw)

def strip_invisible(elements):
    """Return a new list with the elements of ``elements`` that are not invisible.

    Each element is visited with a fresh ``StripInvisibleVisitor``, and elements
    for which that visit returns a truthy value are dropped. This function does
    not recurse itself; whether nested nodes are examined depends on
    ``element.visit``.
    """
    kept = []
    for element in elements:
        is_invisible = element.visit(StripInvisibleVisitor())
        if is_invisible:
            continue
        kept.append(element)
    return kept

def symbol_is_invisible(symbol):
    """Return True if the raw bytes ``symbol`` line is treated as invisible.

    A symbol is invisible when it starts with ``\\write1{\\newlabel{`` or is
    exactly one of the zero-amount kern/glue lines listed below.
    """
    if symbol.startswith(b'\\write1{\\newlabel{'):
        return True
    zero_amount_lines = (
        b'\\kern 0.0\n',
        b'\\kern0.0\n',
        b'\\glue 0.0\n',
    )
    return symbol in zero_amount_lines

def default_with_children_filter(header):
    """Reject container headers starting with ``\\vbox`` or ``\\hbox``.

    Returns False for those headers and True for every other header.
    """
    return not header.startswith((b'\\vbox', b'\\hbox'))

            
# The only ``\glue`` symbol lines accepted by default_symbol_filter; any other
# line starting with ``\glue`` is rejected.
_ALLOWED_GLUE_SYMBOLS = (
    b'\\glue(\\mskip) 3.0mu\n',
    b'\\glue 3.33333 plus 1.66666 minus 1.11111\n',
    b'\\glue(\\mskip) 5.0mu plus 5.0mu\n',
    b'\\glue 10.00002\n',
    b'\\glue 20.00003\n',
    b'\\glue(\\mskip) -3.0mu\n',
    b'\\glue(\\mskip) 4.0mu plus 2.0mu minus 4.0mu\n',
    b'\\glue 0.0 plus 1.0fil\n',
    b'\\glue 28.45274\n',
    b'\\glue 14.22636\n',
    b'\\glue 56.9055\n',
)


def default_symbol_filter(symbol):
    """Return True if the raw bytes ``symbol`` line is acceptable in a tree.

    Rejected: invisible symbols (per ``symbol_is_invisible``), anything starting
    with ``\\hbox``, ``\\vbox`` or ``\\kern``, and any ``\\glue`` line not in
    ``_ALLOWED_GLUE_SYMBOLS``. Everything else is accepted.
    """
    if symbol_is_invisible(symbol):
        return False
    if symbol.startswith((b'\\hbox', b'\\vbox', b'\\kern')):
        return False
    if symbol.startswith(b'\\glue'):
        return symbol in _ALLOWED_GLUE_SYMBOLS
    return True

def base_tree_filter(tree):
    """Return True iff ``tree`` passes both default filters.

    Containers are checked with ``default_with_children_filter`` and symbols with
    ``default_symbol_filter``; a single rejection makes the whole tree fail.
    """
    filters = dict(
        with_children_filter=default_with_children_filter,
        symbol_filter=default_symbol_filter,
    )
    return tree_matches(tree, **filters)

class SymbolCounterVisitor(Visitor):
    """Visitor that tallies how many times each raw symbol is visited.

    Counts accumulate in ``self.counter`` (a ``collections.Counter`` keyed by
    ``symbol.symbol``) across every ``visit_symbol`` call on this instance.
    """

    def __init__(self):
        self.counter = Counter()

    def visit_symbol(self, symbol):
        key = symbol.symbol
        self.counter[key] = self.counter[key] + 1

def count_symbols(counter, elements):
    """Tally the symbols in ``elements`` and report each tally to ``counter``.

    ``counter`` is any callable taking ``(key, count)``, such as a
    ``TotalAndDocumentFrequencyCounter``. It is called once per distinct symbol,
    and the same object is returned, so this works as a ``functools.reduce`` step.
    """
    per_tree = SymbolCounterVisitor()
    walker = AllVisitor(per_tree)
    walker.visit_list(elements)
    for pair in per_tree.counter.items():
        counter(*pair)
    return counter

class FrequencyFilter:
    """Predicate accepting keys whose value in ``counter`` is strictly above ``limit``.

    ``counter[key]`` is looked up directly, so with a ``collections.Counter`` an
    unseen key counts as 0.
    """

    def __init__(self, counter, limit):
        self.limit = limit
        self.counter = counter

    def __call__(self, key):
        frequency = self.counter[key]
        return self.limit < frequency

In [None]:
# Corpus-wide symbol statistics over invisible-stripped trees: `total` holds
# occurrence counts and `document_frequency` the number of trees per symbol.
symbol_counter = reduce(
    count_symbols,
    (strip_invisible(tree) for tree in iter_trees()),
    TotalAndDocumentFrequencyCounter(),
)

In [None]:
def tree_is_sequenceable(tree, min_document_frequency=1000, max_length=100):
    """Return True if ``tree`` should be kept as a training example.

    A tree is kept when all of the following hold. They are checked in this
    order and checking stops at the first failure, so the sequence is built
    only when needed:

    * it passes ``base_tree_filter``;
    * every symbol in it occurs in more than ``min_document_frequency`` trees,
      according to the global ``symbol_counter.document_frequency``;
    * the sequence returned by ``elements_to_sequence`` has fewer than
      ``max_length`` items.

    :param min_document_frequency: exclusive lower bound on a symbol's document
        frequency (default 1000, the previously hard-coded value).
    :param max_length: exclusive upper bound on the sequence length
        (default 100, the previously hard-coded value).
    """
    if not base_tree_filter(tree):
        return False
    common_symbols_only = tree_matches(
        tree,
        symbol_filter=FrequencyFilter(
            symbol_counter.document_frequency, min_document_frequency
        ),
    )
    if not common_symbols_only:
        return False
    return len(elements_to_sequence(tree)) < max_length

In [None]:
# Pair each kept tree with its image file name. Assumes image_ids.lst lists one
# id per line, in the same order as the trees yielded by iter_trees().
with open('tmp/image_ids.lst') as image_ids:
    sequence_with_image_id = [
        # rstrip('\n') rather than [:-1]: the last line may lack a trailing
        # newline, and slicing would then chop a real character off the id.
        (image_id.rstrip('\n') + ".png", elements_to_sequence(tree))
        for image_id, tree in zip(image_ids, map(strip_invisible, iter_trees()))
        if tree_is_sequenceable(tree)
    ]

# Vocabulary: the EOS/SOS markers plus every token seen in the kept sequences.
# LabelEncoder stores its classes sorted, so the set's iteration order is irrelevant.
token_encoder = LabelEncoder()
token_encoder.fit(['EOS', 'SOS'] + list({
    token
    for _, sequence in sequence_with_image_id
    for token in sequence
}))

In [None]:
# Wrap the (image name, token sequence) pairs and the fitted vocabulary into the
# training data source; its max_sequence_length sizes the decoder's token input.
flow = TrainingSequence(
    sequence_with_image_id,
    token_encoder,
)
max_sequence_length = flow.max_sequence_length

In [None]:
# Image encoder. NOTE(review): `keras`, `images_seq` and `AttentionCell` are not
# imported explicitly here; presumably they come from the `src.*` star imports.
# Assumes images_seq returns a (batch, steps, features) tensor — TODO confirm.
images_input = keras.layers.Input(shape=(256, 256, 1), dtype='uint8')
seq = images_seq(images_input)
# Mean over axis 1 of `seq` (presumably the steps axis -> one vector per image).
# NOTE(review): Keras documents a tuple `output_shape` for Lambda as excluding the
# batch axis, but this passes (dims[0], dims[2]) where dims[0] looks like the batch
# axis — confirm it is intended or ignored by the backend in use.
seq_mean = keras.layers.Lambda(
    keras.backend.mean,
    arguments={'axis': 1},
    output_shape=(seq.shape.dims[0], seq.shape.dims[2]),
)(seq)
# Projections of the pooled image vector, presumably intended as the decoder's
# initial LSTM state. TODO: h0 and o0 are not used anywhere below yet.
h0 = keras.layers.Dense(512, activation='tanh')(seq_mean)
o0 = keras.layers.Dense(512, activation='tanh')(seq_mean)
vocab_size = len(token_encoder.classes_)

# 1. get token embeddings
dim = 80
sequences_input = keras.Input((max_sequence_length,), dtype='int32')
tok_embeddings = keras.layers.Embedding(vocab_size, dim, dtype='float32')(sequences_input)

# 2. add the special <sos> token embedding at the beginning of every formula
# TODO: not implemented — tok_embeddings is passed to the decoder unchanged.

# 3. decode
# RNN over the token embeddings with an AttentionCell wrapping a 512-unit
# LSTMCell; `seq` is passed to the cell (presumably as the attention memory).
# return_sequences=True -> one output per input time step.
attn_cell = keras.layers.RNN(
    AttentionCell(keras.layers.LSTMCell(512), seq, vocab_size),
    return_sequences=True,
)
model = keras.models.Model(
    [images_input, sequences_input],
    attn_cell(tok_embeddings),
)
# NOTE(review): categorical_crossentropy expects one-hot targets; presumably
# TrainingSequence yields them — confirm.
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
# NOTE(review): fit_generator is deprecated in TF >= 2.1 in favour of
# model.fit(flow); keep as-is until the Keras version in use is confirmed.
model.fit_generator(flow)