In [1]:
# -*- coding:utf-8 -*-
# create: @time: 10/8/23 11:47
import argparse

import torch
from PIL import Image
from transformers import VisionEncoderDecoderModel
from transformers.models.nougat import NougatTokenizerFast
from nougat_latex.util import process_raw_latex_code
from nougat_latex import NougatLaTexProcessor


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def parse_option(argv=None):
    """Build the CLI parser for Nougat-LaTeX inference and parse the arguments.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) keeps the
            original behaviour of reading ``sys.argv[1:]``. Pass an explicit list
            when calling from a notebook, where ``sys.argv`` typically holds the
            kernel's own launch arguments rather than this script's options.

    Returns:
        argparse.Namespace with ``pretrained_model_name_or_path``, ``img_path``
        and ``device`` attributes.

    Raises:
        SystemExit: if ``--img_path`` is missing or an argument is invalid
            (standard argparse behaviour).
    """
    parser = argparse.ArgumentParser(
        prog="nougat inference config",
        description="Transcribe an image of a LaTeX formula into LaTeX source with Nougat-LaTeX",
    )
    parser.add_argument("--pretrained_model_name_or_path", default="Norm/nougat-latex-base",
                        help="Hugging Face model id or local checkpoint directory")
    parser.add_argument("--img_path", help="path to latex image segment", required=True)
    # NOTE(review): "gpu" is not a valid torch.device string (torch expects e.g. "cuda",
    # "cuda:0" or "cpu"), so whoever consumes args.device must map it. The default is kept
    # unchanged for backward compatibility with existing callers.
    parser.add_argument("--device", default="gpu", help="compute device (default: %(default)s)")
    return parser.parse_args(argv)


def _select_device(device=None):
    """Return ``torch.device(device)``, or cuda:0 if available (else cpu) when ``device`` is None."""
    if device is not None:
        return torch.device(device)
    # The original hard-coded "cuda:0"; fall back to CPU so the cell still runs on GPU-less hosts.
    return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def _load_rgb_image(img_path):
    """Open ``img_path``, return an in-memory RGB copy, and close the file handle."""
    with Image.open(img_path) as img:
        # convert() decodes the pixels and returns a new image (a copy even when the mode
        # is already RGB), so the result stays usable after the file is closed.
        return img.convert("RGB")


def _strip_special_tokens(text, tokens):
    """Remove every occurrence of each token string in ``tokens``; unset (None/empty) tokens are skipped."""
    for token in tokens:
        if token:  # str.replace(None, "") would raise TypeError
            text = text.replace(token, "")
    return text


def run_nougat_latex(model_name="Norm/nougat-latex-base", img_path="image/test01.png", device=None):
    """Transcribe one image of a LaTeX formula into LaTeX source, print it and return it.

    Args:
        model_name: Hugging Face model id or local path. It is used to load the
            model, the tokenizer and the image processor.
        img_path: path to the formula image. It is converted to RGB before preprocessing.
        device: torch device spec (e.g. "cuda:0", "cpu"). ``None`` picks "cuda:0"
            when CUDA is available and "cpu" otherwise.

    Returns:
        The LaTeX string after special-token removal and ``process_raw_latex_code``.
    """
    device = _select_device(device)

    # init model
    model = VisionEncoderDecoderModel.from_pretrained(model_name).to(device)

    # init processor
    tokenizer = NougatTokenizerFast.from_pretrained(model_name)
    latex_processor = NougatLaTexProcessor.from_pretrained(model_name)

    # run test
    image = _load_rgb_image(img_path)
    pixel_values = latex_processor(image, return_tensors="pt").pixel_values

    # Generation is seeded with just the BOS token as the decoder prompt.
    task_prompt = tokenizer.bos_token
    decoder_input_ids = tokenizer(task_prompt, add_special_tokens=False,
                                  return_tensors="pt").input_ids
    with torch.no_grad():
        outputs = model.generate(
            pixel_values.to(device),
            decoder_input_ids=decoder_input_ids.to(device),
            max_length=model.decoder.config.max_length,
            early_stopping=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            use_cache=True,
            num_beams=1,
            bad_words_ids=[[tokenizer.unk_token_id]],  # never emit <unk>
            return_dict_in_generate=True,
        )
    sequence = tokenizer.batch_decode(outputs.sequences)[0]
    # Same removal order as before: eos, then pad, then bos.
    sequence = _strip_special_tokens(sequence, (tokenizer.eos_token, tokenizer.pad_token, tokenizer.bos_token))
    sequence = process_raw_latex_code(sequence)
    print(sequence)
    return sequence


if __name__ == "__main__":
    # Notebook cells also execute with __name__ == "__main__" (this cell's output
    # below shows it), so the demo runs both as a script and inside Jupyter.
    run_nougat_latex()


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


\delta^{L}=\left[\delta_{1}^{L},\delta_{2}^{L},\cdots,\delta_{j}^{L},\cdots\right]^{T}=\left[\frac{\partial{\cal C}}{\partial a_{1}^{L}}\cdot\sigma^{\prime}(z_{1}^{L}),\frac{\partial{\cal C}}{\partial a_{2}^{L}}\cdot\sigma^{\prime}(z_{2}^{L}),\cdots\frac{\partial{\cal C}}{\partial a_{j}^{L}}\cdot\sigma^{\prime}(z_{j}^{L}),\cdots\right]^{T}
