In [None]:
!pip install torchsummary

In [7]:
import cv2
import pandas.io.clipboard as clipboard
from PIL import ImageGrab
from PIL import Image
import os
import sys
import argparse
import logging
import yaml
import re

import numpy as np
import torch
from torchvision import transforms
from munch import Munch
from transformers import PreTrainedTokenizerFast
from timm.models.resnetv2 import ResNetV2
from timm.models.layers import StdConv2dSame
from dataset.dataset import test_transform

from dataset.latex2png import tex2pil
from models import get_model
from utils import *

last_pic = None

In [20]:
import os
os.chdir("/home/lap14784/Downloads/LaTeX_OCR")
os.getcwd()

'/home/lap14784/Downloads/LaTeX_OCR'

## Define model

In [64]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# from x_transformers import *
from x_transformers import TransformerWrapper, Decoder
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper, top_k, top_p, entmax, ENTMAX_ALPHA
from timm.models.vision_transformer import VisionTransformer
from timm.models.vision_transformer_hybrid import HybridEmbed
from timm.models.resnetv2 import ResNetV2
from timm.models.layers import StdConv2dSame
from einops import rearrange, repeat


class CustomARWrapper(AutoregressiveWrapper):
    def __init__(self, *args, **kwargs):
        super(CustomARWrapper, self).__init__(*args, **kwargs)

    @torch.no_grad()
    def generate(self, start_tokens, seq_len, eos_token=None, temperature=1., filter_logits_fn=top_k, filter_thres=0.9, **kwargs):
        device = start_tokens.device
        was_training = self.net.training
        num_dims = len(start_tokens.shape)

        if num_dims == 1:
            start_tokens = start_tokens[None, :]

        b, t = start_tokens.shape

        self.net.eval()
        out = start_tokens
        mask = kwargs.pop('mask', None)
        if mask is None:
            mask = torch.full_like(out, True, dtype=torch.bool, device=out.device)

        for _ in range(seq_len):
            x = out[:, -self.max_seq_len:]
            mask = mask[:, -self.max_seq_len:]
            # print('arw:',out.shape)
            logits = self.net(x, mask=mask, **kwargs)[:, -1, :]

            if filter_logits_fn in {top_k, top_p}:
                filtered_logits = filter_logits_fn(logits, thres=filter_thres)
                probs = F.softmax(filtered_logits / temperature, dim=-1)

            elif filter_logits_fn is entmax:
                probs = entmax(logits / temperature, alpha=ENTMAX_ALPHA, dim=-1)

            sample = torch.multinomial(probs, 1)

            out = torch.cat((out, sample), dim=-1)
            mask = F.pad(mask, (0, 1), value=True)

            if eos_token is not None and (torch.cumsum(out == eos_token, 1)[:, -1] >= 1).all():
                break

        out = out[:, t:]

        if num_dims == 1:
            out = out.squeeze(0)

        self.net.train(was_training)
        return out


class CustomVisionTransformer(VisionTransformer):
    def __init__(self, img_size=224, patch_size=16, *args, **kwargs):
        super(CustomVisionTransformer, self).__init__(img_size=img_size, patch_size=patch_size, *args, **kwargs)
        self.height, self.width = img_size
        self.patch_size = patch_size

    def forward_features(self, x):
        print(np.shape(x))
        B, c, h, w = x.shape
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        h, w = h//self.patch_size, w//self.patch_size
        pos_emb_ind = repeat(torch.arange(h)*(self.width//self.patch_size-w), 'h -> (h w)', w=w)+torch.arange(h*w)
        pos_emb_ind = torch.cat((torch.zeros(1), pos_emb_ind+1), dim=0).long()
        x += self.pos_embed[:, pos_emb_ind]
        #x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        return x


class Model(nn.Module):
    def __init__(self, encoder: CustomVisionTransformer, decoder: CustomARWrapper, args, temp: float = .333):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.bos_token = args.bos_token
        self.eos_token = args.eos_token
        self.max_seq_len = args.max_seq_len
        self.temperature = temp

    @torch.no_grad()
    def forward(self, x: torch.Tensor):
        print("forward", x)
        device = x.device
        encoded = self.encoder(x.to(device))
        dec = self.decoder.generate(torch.LongTensor([self.bos_token]*len(x))[:, None].to(device), self.max_seq_len,
                                    eos_token=self.eos_token, context=encoded, temperature=self.temperature)
        return dec


def get_model(args, training=False):
    backbone = ResNetV2(
        layers=args.backbone_layers, num_classes=0, global_pool='', in_chans=args.channels,
        preact=False, stem_type='same', conv_layer=StdConv2dSame)
    min_patch_size = 2**(len(args.backbone_layers)+1)

    def embed_layer(**x):
        ps = x.pop('patch_size', min_patch_size)
        assert ps % min_patch_size == 0 and ps >= min_patch_size, 'patch_size needs to be multiple of %i with current backbone configuration' % min_patch_size
        return HybridEmbed(**x, patch_size=ps//min_patch_size, backbone=backbone)

    encoder = CustomVisionTransformer(img_size=(args.max_height, args.max_width),
                                      patch_size=args.patch_size,
                                      in_chans=args.channels,
                                      num_classes=0,
                                      embed_dim=args.dim,
                                      depth=args.encoder_depth,
                                      num_heads=args.heads,
                                      embed_layer=embed_layer
                                      ).to(args.device)

    decoder = CustomARWrapper(
        TransformerWrapper(
            num_tokens=args.num_tokens,
            max_seq_len=args.max_seq_len,
            attn_layers=Decoder(
                dim=args.dim,
                depth=args.num_layers,
                heads=args.heads,
                **args.decoder_args
            )),
        pad_value=args.pad_token
    ).to(args.device)

    model = Model(encoder, decoder, args)
#     if training:
#         # check if largest batch can be handled by system
#         im = torch.empty(args.batchsize, args.channels, args.max_height, args.min_height, device=args.device).float()
#         seq = torch.randint(0, args.num_tokens, (args.batchsize, args.max_seq_len), device=args.device).long()
#         decoder(seq, context=encoder(im)).sum().backward()
#         model.zero_grad()
#         torch.cuda.empty_cache() 
#         del im, seq
    return model


In [26]:
from torchsummary import summary
from dataset.dataset import test_transform
import cv2
import pandas.io.clipboard as clipboard
from PIL import ImageGrab
from PIL import Image
import os
import sys
import argparse
import logging
import yaml
import re

import numpy as np
import torch
from torchvision import transforms
from munch import Munch
from transformers import PreTrainedTokenizerFast
from timm.models.resnetv2 import ResNetV2
from timm.models.layers import StdConv2dSame

from dataset.latex2png import tex2pil
from models import get_model
from utils import *
last_pic = None


In [65]:
# if arguments is None:
# arguments = Munch({'config': 'settings/config.yaml', 'checkpoint': 'checkpoints/weights.pth', 'no_cuda': True, 'no_resize': False})
arguments = Munch({'epoch': 0, 'backbone_layers': [2, 3, 7], 'betas': [0.9, 0.999], 'batchsize': 10, 'bos_token': 1, 'channels': 1, 'data': 'dataset/data/train.pkl', 'debug': False, 'decoder_args': {'attn_on_attn': True, 'cross_attend': True, 'ff_glu': True, 'rel_pos_bias': False, 'use_scalenorm': False}, 'dim': 256, 'encoder_depth': 4, 'eos_token': 2, 'epochs': 10, 'gamma': 0.9995, 'heads': 8, 'id': None, 'load_chkpt': None, 'lr': 0.001, 'lr_step': 30, 'max_height': 192, 'max_seq_len': 512, 'max_width': 672, 'min_height': 32, 'min_width': 32, 'model_path': 'checkpoints', 'name': 'pix2tex', 'num_layers': 4, 'num_tokens': 8000, 'optimizer': 'Adam', 'output_path': 'outputs', 'pad': False, 'pad_token': 0, 'patch_size': 16, 'sample_freq': 3000, 'save_freq': 5, 'scheduler': 'StepLR', 'seed': 42, 'temperature': 0.2, 'test_samples': 5, 'testbatchsize': 20, 'tokenizer': 'dataset/tokenizer.json', 'valbatches': 100, 'valdata': 'dataset/data/val.pkl', 'wandb': False, 'device': 'cpu', 'max_dimensions': [672, 192], 'min_dimensions': [32, 32], 'out_path': 'checkpoints/pix2tex', 'config': 'settings/config.yaml', 'checkpoint': 'checkpoints/weights.pth', 'no_cuda': False, 'no_resize': False})
# logging.getLogger().setLevel(logging.FATAL)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
with open(arguments.config, 'r') as f:
    params = yaml.load(f, Loader=yaml.FullLoader)
args = parse_args(Munch(params))
args.update(**vars(arguments))
# args.device = "cpu"
args.device = 'cuda' if torch.cuda.is_available() and not args.no_cuda else 'cpu'

model = get_model(args)
# summary(model, (1, 1, 64, 352))
# model.load_state_dict(torch.load(args.checkpoint, map_location=args.device))
# summary(model, (1, 1, 64, 352))

# if 'image_resizer.pth' in os.listdir(os.path.dirname(args.checkpoint)) and not arguments.no_resize:
#     image_resizer = ResNetV2(layers=[2, 3, 3], num_classes=max(args.max_dimensions)//32, global_pool='avg', in_chans=1, drop_rate=.05,
#                              preact=True, stem_type='same', conv_layer=StdConv2dSame).to(args.device)
#     image_resizer.load_state_dict(torch.load(os.path.join(os.path.dirname(args.checkpoint), 'image_resizer.pth'), map_location=args.device))
#     image_resizer.eval()
# else:
#     image_resizer = None
# tokenizer = PreTrainedTokenizerFast(tokenizer_file=args.tokenizer)
# return args, model, image_resizer, tokenizer

forward tensor([[[[[0.6461, 0.9796, 0.5770,  ..., 0.0884, 0.3716, 0.7028],
           [0.8204, 0.4483, 0.2467,  ..., 0.2225, 0.6587, 0.7426],
           [0.6052, 0.4622, 0.7299,  ..., 0.4954, 0.2070, 0.9071],
           ...,
           [0.0564, 0.4330, 0.6895,  ..., 0.0263, 0.9023, 0.3936],
           [0.2982, 0.2558, 0.8864,  ..., 0.4704, 0.9282, 0.7937],
           [0.4343, 0.1507, 0.2430,  ..., 0.8863, 0.6885, 0.9142]]]],



        [[[[0.6531, 0.7680, 0.2409,  ..., 0.6531, 0.3143, 0.2850],
           [0.4735, 0.0383, 0.5654,  ..., 0.0679, 0.4456, 0.6908],
           [0.7435, 0.8656, 0.9616,  ..., 0.0239, 0.5107, 0.7927],
           ...,
           [0.8872, 0.7160, 0.3235,  ..., 0.5862, 0.6248, 0.2348],
           [0.8961, 0.3489, 0.6081,  ..., 0.1254, 0.4765, 0.9102],
           [0.1613, 0.3081, 0.9532,  ..., 0.6353, 0.5779, 0.8079]]]]])
torch.Size([2, 1, 1, 64, 352])


ValueError: too many values to unpack (expected 4)

In [31]:
from PIL import Image
img = Image.open("./dataset/sample/1000a29807.png")

In [52]:
torch.Size([1, 1, 64, 352])

torch.Size([1, 1, 64, 352])

In [34]:
encoder, decoder = model.encoder, model.decoder
if type(img) is bool:
    img = None
if img is None:
    if last_pic is None:
        print('Provide an image.')
    else:
        img = last_pic.copy()
else:
    last_pic = img.copy()
img = minmax_size(pad(img), args.max_dimensions, args.min_dimensions)
if image_resizer is not None and not args.no_resize:
    with torch.no_grad():
        input_image = img.convert('RGB').copy()
        r, w, h = 1, input_image.size[0], input_image.size[1]
        for _ in range(10):
            h = int(h * r)  # height to resize
            img = pad(minmax_size(input_image.resize((w, h), Image.BILINEAR if r > 1 else Image.LANCZOS), args.max_dimensions, args.min_dimensions))
            t = test_transform(image=np.array(img.convert('RGB')))['image'][:1].unsqueeze(0)
            w = (image_resizer(t.to(args.device)).argmax(-1).item()+1)*32
            logging.info(r, img.size, (w, int(input_image.size[1]*r)))
            if (w == img.size[0]):
                break
            r = w/img.size[0]
else:
    img = np.array(pad(img).convert('RGB'))
    t = test_transform(image=img)['image'][:1].unsqueeze(0)
im = t.to(args.device)

with torch.no_grad():
    model.eval()
    device = args.device
    encoded = encoder(im.to(device))
    dec = decoder.generate(torch.LongTensor([args.bos_token])[:, None].to(device), args.max_seq_len,
                           eos_token=args.eos_token, context=encoded.detach(), temperature=args.get('temperature', .25))
    pred = post_process(token2str(dec, tokenizer)[0])
try:
    clipboard.copy(pred)
except:
    pass

In [35]:
pred

'\\left\\{\\begin{array}{r c l}{{\\delta_{\\epsilon}B}}&{{\\sim}}&{{\\epsilon F\\,,}}\\\\ {{\\delta_{\\epsilon}F}}&{{\\sim}}&{{\\partial\\epsilon+\\epsilon B\\,,}}\\end{array}\\right.'

In [38]:
prediction = pred.replace('<', '\\lt ').replace('>', '\\gt ')
prediction

'\\left\\{\\begin{array}{r c l}{{\\delta_{\\epsilon}B}}&{{\\sim}}&{{\\epsilon F\\,,}}\\\\ {{\\delta_{\\epsilon}F}}&{{\\sim}}&{{\\partial\\epsilon+\\epsilon B\\,,}}\\end{array}\\right.'

In [42]:
html = str('\\left\\{\\begin{array}{r c l}{{\\delta_{\\epsilon}B}}&{{\\sim}}&{{\\epsilon F\\,,}}\\\\ {{\\delta_{\\epsilon}F}}&{{\\sim}}&{{\\partial\\epsilon+\\epsilon B\\,,}}\\end{array}\\right.')
html

'\\left\\{\\begin{array}{r c l}{{\\delta_{\\epsilon}B}}&{{\\sim}}&{{\\epsilon F\\,,}}\\\\ {{\\delta_{\\epsilon}F}}&{{\\sim}}&{{\\partial\\epsilon+\\epsilon B\\,,}}\\end{array}\\right.'

In [43]:
from bs4 import BeautifulSoup
soup = BeautifulSoup(html)
print(soup.get_text())

\left\{\begin{array}{r c l}{{\delta_{\epsilon}B}}&{{\sim}}&{{\epsilon F\,,}}\\ {{\delta_{\epsilon}F}}&{{\sim}}&{{\partial\epsilon+\epsilon B\,,}}\end{array}\right.


In [45]:
pageSource = """
        <html>
        <head><script id="MathJax-script" src="qrc:MathJax.js"></script>
        <script>
        MathJax.Hub.Config({messageStyle: 'none',tex2jax: {preview: 'none'}});
        MathJax.Hub.Queue(
            function () {
                document.getElementById("equation").style.visibility = "";
            }
            );
        </script>
        </head> """ + """
        <body>
        <div id="equation" style="font-size:1em; visibility:hidden">$${equation}$$</div>
        </body>
        </html>
            """.format(equation=prediction)

In [46]:
from IPython.core.display import display, HTML
display(HTML(pageSource))