hello, the Chinese onnx model inference seems not right. #9

Closed
lucasjinreal opened this issue Sep 29, 2022 · 12 comments

@lucasjinreal

Hi. I used this script to export the wenet model:

python export_onnx.py --bpe_model weights/icefall_asr_wenetspeech_pruned_transducer_stateless2/data/lang_char/ --pretrained_model weights/icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/pretrained_epoch_10_avg_2.pt

Inference with these weights in PyTorch gives correct results.

But when I run inference with the ONNX models, I get this result:

wav: ../../../data/test_data/16k16bit.wav search: greedy
loading ../../../weights/onnx/encoder.onnx
loading ../../../weights/onnx/decoder.onnx
loading ../../../weights/onnx/joiner.onnx
loading ../../../weights/onnx/joiner_encoder_proj.onnx
loading ../../../weights/onnx/joiner_decoder_proj.onnx
../../../data/test_data/16k16bit.wav
reading ../../../data/test_data/16k16bit.wav
wav shape: (1,112400)
Elapsed: 0.332 seconds
rtf: 0.0474286
Hyps:
3938-2261-519-3938-376-657-657-4250-4745-637-1449-1449-59-3978-376-249-5294-2480-1449-657-3786-519-249-519-249-519-249-519-249-519-249-1449-2480-1872-657-519-1713-3201-4186-3938-3938-4745-519-3938-376-519-249-3669-3938-4877-5525-3938-4745-519-3938-376-657-657-519-1713-3938-2261-1449-59-3978-376-249-5294-4745-1449-1449-3655-1449-3710-1449-1449-59-3978-376-249-5294-2480-1449-657-3786-519-249-519-249-519-249-519-249-519-249-519-249-519-249-519-249-519-249-519-249-519-249-519-249-1449-2480-3281-519-249-519-249-519-249-519-249-519-249-519-249-519-249-519-249-519-249-519-249-1449-2480-637-1449-1449-3978-858-1852-1574-249-2368-2480-1923-1449-2169-361-2480-1449-657-3786-519-249-519-249-519-249-519-249-1449-2480-1501-657-3786-3669-3938-4877-657-3938-4745-519-3938-376-|
诋犹估诋亲谅谅惬颔扑殊殊就孵亲务邗竖殊谅椎估务估务估务估务估务殊竖瞭谅估顿尬蜕诋诋颔估诋亲估务殓诋蹰蚡诋颔估诋亲谅谅估顿诋犹殊就孵亲务邗颔殊殊挎殊跺殊殊就孵亲务邗竖殊谅椎估务估务估务估务估务估务估务估务估务估务估务估务殊竖撇估务估务估务估务估务估务估务估务估务估务殊竖扑殊殊孵螃园径务咽竖脖殊液立竖殊谅椎估务估务估务估务殊竖迈谅椎殓诋蹰谅诋颔估诋亲

I checked the tokens and they seem normal.

What could be missing?

@lucasjinreal
Author

@csukuangfj @EmreOzkose Hello, can you help me out with this issue?

@csukuangfj
Collaborator

@EmreOzkose
Is this repo ready for recognition?

@csukuangfj
Collaborator

@jinfagang

Could you first use
https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py
to check that the exported models work when used in Python?

(Note: You need to make some changes to onnx_pretrained.py since it is for English models with bpe.model)
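For a char-based Chinese model like the wenetspeech one, a minimal sketch of that change, assuming the lang_char directory provides a tokens.txt file with lines of the form "<token> <id>" (the file name and format are assumptions, not verified against this checkpoint), could look like:

# Hypothetical replacement for the bpe.model-based decoding in onnx_pretrained.py
# when the model uses a lang_char token table instead of a SentencePiece model.
def load_token_table(tokens_txt: str) -> dict:
    """Read lines of the form '<token> <id>' into an id -> token dict."""
    id2token = {}
    with open(tokens_txt, encoding="utf-8") as f:
        for line in f:
            token, idx = line.split()
            id2token[int(idx)] = token
    return id2token


def ids_to_text(hyp, id2token: dict) -> str:
    """Join decoded token ids into a Chinese string (no spaces needed)."""
    return "".join(id2token[i] for i in hyp)

You would then call ids_to_text(hyp, id2token) wherever onnx_pretrained.py currently turns hypothesis ids back into text with the SentencePiece model.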

@EmreOzkose
Collaborator

@csukuangfj yes, the repo can decode English samples, but I haven't tested it with Chinese.

@jinfagang sorry for the late reply; I was mostly AFK for a few days. Can you check the models with onnx_check.py?
You can use https://github.com/EmreOzkose/icefall/blob/onnx_proj_exports/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py
Also, what is the output of onnx_pretrained.py?

@lucasjinreal
Author

Thank you all. I will try it.

@lucasjinreal
Author

@EmreOzkose @csukuangfj Hello. I have tried onnx_pretrained.py and got this error:

/onnx_pretrained.py", line 177, in greedy_search
    logits = joiner.run(
  File "/Users/xx/miniforge3/lib/python3.9/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 192, in run
    return self._sess.run(output_names, input_feed, run_options)
onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Invalid rank for input: encoder_out Got: 2 Expected: 4 Please fix either the inputs or the model.

Could you make it more consistent? onnx_pretrained.py only uses the joiner without the projection models, while the C++ code needs the projectors.

The code does not give me an end-to-end inference result. What am I missing here? (I just want to get the final ASR result.)
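A likely cause of the rank error: export_joiner_model_onnx in the script posted below traces the joiner with 4-D dummy inputs of shape (1, 1, 1, C), while onnx_pretrained.py feeds 2-D (N, C) tensors. A workaround sketch (the helper name is made up, and it assumes the exported joiner keeps the input/output names "encoder_out", "decoder_out" and "logit" from that export function):

import torch


def run_joiner_rank4(joiner, current_encoder_out: torch.Tensor, decoder_out: torch.Tensor) -> torch.Tensor:
    """Feed the joiner ONNX session 4-D inputs matching how it was exported.

    current_encoder_out / decoder_out are the 2-D (N, C) tensors that
    greedy_search in onnx_pretrained.py passes to joiner.run.
    """
    enc = current_encoder_out.unsqueeze(1).unsqueeze(1).numpy()  # (N, 1, 1, C_enc)
    dec = decoder_out.unsqueeze(1).unsqueeze(1).numpy()          # (N, 1, 1, C_dec)
    # If the exported graph also kept "project_input" as a real input,
    # it would have to be fed here as well.
    (logits,) = joiner.run(["logit"], {"encoder_out": enc, "decoder_out": dec})
    return torch.from_numpy(logits).squeeze(1).squeeze(1)        # (N, vocab_size)

Alternatively, re-exporting the joiner with 2-D dummy inputs (the shape onnx_pretrained.py expects) avoids the reshape entirely.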

@lucasjinreal
Author

I got an assertion error when trying onnx_check.py:

python check_onnx.py --jit-filename weights/icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/cpu_jit_epoch_10_avg_2_torch_1.11.0.pt --onnx-encoder-filename weights/onnx/encoder.onnx --onnx-decoder-filename weights/onnx/decoder.onnx --onnx-joiner-filename weights/onnx/joiner.onnx --onnx-joiner-encoder-proj-filename weights/onnx/joiner_encoder_proj.onnx --onnx-joiner-decoder-proj-filename weights/onnx/joiner_decoder_proj.onnx
2022-10-04 20:13:40,783 INFO [check_onnx.py:201] {'jit_filename': 'weights/icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/cpu_jit_epoch_10_avg_2_torch_1.11.0.pt', 'onnx_encoder_filename': 'weights/onnx/encoder.onnx', 'onnx_decoder_filename': 'weights/onnx/decoder.onnx', 'onnx_joiner_filename': 'weights/onnx/joiner.onnx', 'onnx_joiner_encoder_proj_filename': 'weights/onnx/joiner_encoder_proj.onnx', 'onnx_joiner_decoder_proj_filename': 'weights/onnx/joiner_decoder_proj.onnx'}
2022-10-04 20:13:40,966 INFO [check_onnx.py:209] Test encoder
N, T 1 12
Traceback (most recent call last):
  File "/check_onnx.py", line 252, in <module>
    main()
  File "/Users/xx/miniforge3/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "\
    test_encoder(model, encoder_session)
  File "//check_onnx.py", line 90, in test_encoder
    assert torch.allclose(encoder_out, torch_encoder_out, atol=1e-05), (
AssertionError: tensor(2.7750)
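For reference, a quick way to see how large the mismatch is, assuming the encoder_out and torch_encoder_out tensors inside test_encoder of check_onnx.py:

diff = (encoder_out - torch_encoder_out).abs()
print("max abs diff:", diff.max().item())    # ~2.775 here, far above atol=1e-05
print("mean abs diff:", diff.mean().item())
print("worst element:", diff.argmax().item())

A difference that large points to a genuine export mismatch rather than numerical noise.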

@csukuangfj
Collaborator

How did you get the onnx models? Are you using the same branch for both exporting the onnx models and running check_onnx.py?

@EmreOzkose
Collaborator

@jinfagang I will debug with the Chinese model today.

@lucasjinreal
Author

@EmreOzkose thank you! I exported the onnx models with the same export script.

like this:


import argparse
import logging
import os
from pathlib import Path

import onnx
import sentencepiece as spm
import torch
import torch.nn as nn

# add_model_arguments, get_default_params, str2bool, Lexicon,
# build_conformer_transducer_model and convert_scaled_to_non_scaled
# come from the icefall recipe this script is adapted from.


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument("-p", "--pretrained_model", type=str, help="pretrained model")
    parser.add_argument(
        "--bpe_model",
        type=str,
        default="data/lang_bpe_500/bpe.model",
        help="Path to the BPE model",
    )
    parser.add_argument(
        "--jit",
        type=str2bool,
        default=False,
        help="""True to save a model after applying torch.jit.script.
        It will generate 4 files:
         - encoder_jit_script.pt
         - decoder_jit_script.pt
         - joiner_jit_script.pt
         - cpu_jit.pt (which combines the above 3 files)
        Check ./jit_pretrained.py for how to use them.
        """,
    )
    parser.add_argument(
        "--jit-trace",
        type=str2bool,
        default=False,
        help="""True to save a model after applying torch.jit.trace.
        It will generate 3 files:
         - encoder_jit_trace.pt
         - decoder_jit_trace.pt
         - joiner_jit_trace.pt
        Check ./jit_pretrained.py for how to use them.
        """,
    )
    parser.add_argument(
        "--onnx",
        type=str2bool,
        default=True,
        help="""If True, --jit is ignored and it exports the model
        to onnx format. Three files will be generated:
            - encoder.onnx
            - decoder.onnx
            - joiner.onnx
        Check ./onnx_check.py and ./onnx_pretrained.py for how to use them.
        """,
    )
    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
    )
    parser.add_argument(
        "--streaming-model",
        type=str2bool,
        default=False,
        help="""Whether to export a streaming model, if the models in exp-dir
        are streaming model, this should be True.
        """,
    )
    add_model_arguments(parser)
    return parser


def export_encoder_model_jit_script(
    encoder_model: nn.Module,
    encoder_filename: str,
) -> None:
    """Export the given encoder model with torch.jit.script()
    Args:
      encoder_model:
        The input encoder model
      encoder_filename:
        The filename to save the exported model.
    """
    script_model = torch.jit.script(encoder_model)
    script_model.save(encoder_filename)
    logging.info(f"Saved to {encoder_filename}")


def export_decoder_model_jit_script(
    decoder_model: nn.Module,
    decoder_filename: str,
) -> None:
    """Export the given decoder model with torch.jit.script()
    Args:
      decoder_model:
        The input decoder model
      decoder_filename:
        The filename to save the exported model.
    """
    script_model = torch.jit.script(decoder_model)
    script_model.save(decoder_filename)
    logging.info(f"Saved to {decoder_filename}")


def export_joiner_model_jit_script(
    joiner_model: nn.Module,
    joiner_filename: str,
) -> None:
    """Export the given joiner model with torch.jit.trace()
    Args:
      joiner_model:
        The input joiner model
      joiner_filename:
        The filename to save the exported model.
    """
    script_model = torch.jit.script(joiner_model)
    script_model.save(joiner_filename)
    logging.info(f"Saved to {joiner_filename}")


def export_encoder_model_jit_trace(
    encoder_model: nn.Module,
    encoder_filename: str,
) -> None:
    """Export the given encoder model with torch.jit.trace()
    Note: The warmup argument is fixed to 1.
    Args:
      encoder_model:
        The input encoder model
      encoder_filename:
        The filename to save the exported model.
    """
    x = torch.zeros(1, 100, 80, dtype=torch.float32)
    x_lens = torch.tensor([100], dtype=torch.int64)

    traced_model = torch.jit.trace(encoder_model, (x, x_lens))
    traced_model.save(encoder_filename)
    logging.info(f"Saved to {encoder_filename}")


def export_decoder_model_jit_trace(
    decoder_model: nn.Module,
    decoder_filename: str,
) -> None:
    """Export the given decoder model with torch.jit.trace()
    Note: The argument need_pad is fixed to False.
    Args:
      decoder_model:
        The input decoder model
      decoder_filename:
        The filename to save the exported model.
    """
    y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
    need_pad = torch.tensor([False])

    traced_model = torch.jit.trace(decoder_model, (y, need_pad))
    traced_model.save(decoder_filename)
    logging.info(f"Saved to {decoder_filename}")


def export_joiner_model_jit_trace(
    joiner_model: nn.Module,
    joiner_filename: str,
) -> None:
    """Export the given joiner model with torch.jit.trace()
    Note: The argument project_input is fixed to True. A user should not
    project the encoder_out/decoder_out by himself/herself. The exported joiner
    will do that for the user.
    Args:
      joiner_model:
        The input joiner model
      joiner_filename:
        The filename to save the exported model.
    """
    encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
    decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
    encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
    decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)

    traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out))
    traced_model.save(joiner_filename)
    logging.info(f"Saved to {joiner_filename}")


def export_encoder_model_onnx(
    encoder_model: nn.Module,
    encoder_filename: str,
    opset_version: int = 11,
) -> None:
    """Export the given encoder model to ONNX format.
    The exported model has two inputs:
        - x, a tensor of shape (N, T, C); dtype is torch.float32
        - x_lens, a tensor of shape (N,); dtype is torch.int64
    and it has two outputs:
        - encoder_out, a tensor of shape (N, T, C)
        - encoder_out_lens, a tensor of shape (N,)
    Note: The warmup argument is fixed to 1.
    Args:
      encoder_model:
        The input encoder model
      encoder_filename:
        The filename to save the exported ONNX model.
      opset_version:
        The opset version to use.
    """
    x = torch.zeros(1, 100, 80, dtype=torch.float32)
    x_lens = torch.tensor([100], dtype=torch.int64)

    #  encoder_model = torch.jit.script(encoder_model)
    # It throws the following error for the above statement
    #
    # RuntimeError: Exporting the operator __is_ to ONNX opset version
    # 11 is not supported. Please feel free to request support or
    # submit a pull request on PyTorch GitHub.
    #
    # I cannot find which statement causes the above error.
    # torch.onnx.export() will use torch.jit.trace() internally, which
    # works well for the current reworked model
    warmup = 1.0
    torch.onnx.export(
        encoder_model,
        (x, x_lens, warmup),
        encoder_filename,
        verbose=False,
        opset_version=opset_version,
        input_names=["x", "x_lens", "warmup"],
        output_names=["encoder_out", "encoder_out_lens"],
        dynamic_axes={
            "x": {0: "N", 1: "T"},
            "x_lens": {0: "N"},
            "encoder_out": {0: "N", 1: "T"},
            "encoder_out_lens": {0: "N"},
        },
    )
    logging.info(f"Saved to {encoder_filename}")


def export_decoder_model_onnx(
    decoder_model: nn.Module,
    decoder_filename: str,
    opset_version: int = 11,
) -> None:
    """Export the decoder model to ONNX format.
    The exported model has one input:
        - y: a torch.int64 tensor of shape (N, decoder_model.context_size)
    and has one output:
        - decoder_out: a torch.float32 tensor of shape (N, 1, C)
    Note: The argument need_pad is fixed to False.
    Args:
      decoder_model:
        The decoder model to be exported.
      decoder_filename:
        Filename to save the exported ONNX model.
      opset_version:
        The opset version to use.
    """
    y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
    need_pad = False  # Always False, so we can use torch.jit.trace() here
    # Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script()
    # in this case
    torch.onnx.export(
        decoder_model,
        (y, need_pad),
        decoder_filename,
        verbose=False,
        opset_version=opset_version,
        input_names=["y", "need_pad"],
        output_names=["decoder_out"],
        dynamic_axes={
            "y": {0: "N"},
            "decoder_out": {0: "N"},
        },
    )
    logging.info(f"Saved to {decoder_filename}")


def export_joiner_model_onnx(
    joiner_model: nn.Module,
    joiner_filename: str,
    opset_version: int = 11,
) -> None:
    """Export the joiner model to ONNX format.
    The exported model has two inputs:
        - encoder_out: a tensor of shape (N, encoder_out_dim)
        - decoder_out: a tensor of shape (N, decoder_out_dim)
    and has one output:
        - joiner_out: a tensor of shape (N, vocab_size)
    Note: The argument project_input is fixed to True. A user should not
    project the encoder_out/decoder_out by himself/herself. The exported joiner
    will do that for the user.
    """
    encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
    decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
    encoder_out = torch.rand(1, 1, 1, encoder_out_dim, dtype=torch.float32)
    decoder_out = torch.rand(1, 1, 1, decoder_out_dim, dtype=torch.float32)

    project_input = True
    # Note: It uses torch.jit.trace() internally
    torch.onnx.export(
        joiner_model,
        (encoder_out, decoder_out, project_input),
        joiner_filename,
        verbose=False,
        opset_version=opset_version,
        input_names=["encoder_out", "decoder_out", "project_input"],
        output_names=["logit"],
        dynamic_axes={
            "encoder_out": {0: "N"},
            "decoder_out": {0: "N"},
            "logit": {0: "N"},
        },
    )
    torch.onnx.export(
        joiner_model.encoder_proj,
        (encoder_out.squeeze(0).squeeze(0)),
        str(joiner_filename).replace(".onnx", "_encoder_proj.onnx"),
        verbose=False,
        opset_version=opset_version,
        input_names=["encoder_out"],
        output_names=["encoder_proj"],
        dynamic_axes={
            "encoder_out": {0: "N"},
            "encoder_proj": {0: "N"},
        },
    )
    torch.onnx.export(
        joiner_model.decoder_proj,
        (decoder_out.squeeze(0).squeeze(0)),
        str(joiner_filename).replace(".onnx", "_decoder_proj.onnx"),
        verbose=False,
        opset_version=opset_version,
        input_names=["decoder_out"],
        output_names=["decoder_proj"],
        dynamic_axes={
            "decoder_out": {0: "N"},
            "decoder_proj": {0: "N"},
        },
    )
    logging.info(f"Saved to {joiner_filename}")


def export_all_in_one_onnx(
    encoder_filename: str,
    decoder_filename: str,
    joiner_filename: str,
    all_in_one_filename: str,
):
    encoder_onnx = onnx.load(encoder_filename)
    decoder_onnx = onnx.load(decoder_filename)
    joiner_onnx = onnx.load(joiner_filename)

    encoder_onnx = onnx.compose.add_prefix(encoder_onnx, prefix="encoder/")
    decoder_onnx = onnx.compose.add_prefix(decoder_onnx, prefix="decoder/")
    joiner_onnx = onnx.compose.add_prefix(joiner_onnx, prefix="joiner/")

    combined_model = onnx.compose.merge_models(encoder_onnx, decoder_onnx, io_map={})
    combined_model = onnx.compose.merge_models(combined_model, joiner_onnx, io_map={})
    onnx.save(combined_model, all_in_one_filename)
    logging.info(f"Saved to {all_in_one_filename}")


@torch.no_grad()
def main():
    args = get_parser().parse_args()

    params = get_default_params()
    params.update({"exp_dir": "weights/onnx"})
    params.exp_dir = Path(params.exp_dir)
    params.update(vars(args))

    os.makedirs(params.exp_dir, exist_ok=True)

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"device: {device}")

    if "bpe" in args.bpe_model:
        sp = spm.SentencePieceProcessor()
        sp.load(args.bpe_model)
        # <blk> is defined in local/train_bpe_model.py
        params.blank_id = sp.piece_to_id("<blk>")
        params.vocab_size = sp.get_piece_size()
    else:
        lexicon = Lexicon(args.bpe_model)
        # <blk> comes from the lang_char token table
        params.blank_id = lexicon.token_table["<blk>"]
        params.vocab_size = max(lexicon.tokens) + 1
        logging.info("Reading Lexicon...")

        sp = lexicon

    if params.streaming_model:
        assert params.causal_convolution

    logging.info(params)

    logging.info("About to create model")
    # model = get_transducer_model(params, enable_giga=False)
    model = build_conformer_transducer_model(sp, params)
    model.to(device)

    model.load_state_dict(
        torch.load(args.pretrained_model, "cpu"),
        strict=False,
    )
    model.to("cpu")
    model.eval()

    if params.onnx is True:
        convert_scaled_to_non_scaled(model, inplace=True)
        opset_version = 11
        logging.info("Exporting to onnx format")
        encoder_filename = params.exp_dir / "encoder.onnx"
        export_encoder_model_onnx(
            model.encoder,
            encoder_filename,
            opset_version=opset_version,
        )

        decoder_filename = params.exp_dir / "decoder.onnx"
        export_decoder_model_onnx(
            model.decoder,
            decoder_filename,
            opset_version=opset_version,
        )

        joiner_filename = params.exp_dir / "joiner.onnx"
        export_joiner_model_onnx(
            model.joiner,
            joiner_filename,
            opset_version=opset_version,
        )

        all_in_one_filename = params.exp_dir / "all_in_one.onnx"
        export_all_in_one_onnx(
            encoder_filename,
            decoder_filename,
            joiner_filename,
            all_in_one_filename,
        )
    elif params.jit is True:
        convert_scaled_to_non_scaled(model, inplace=True)
        logging.info("Using torch.jit.script()")
        # We won't use the forward() method of the model in C++, so just ignore
        # it here.
        # Otherwise, one of its arguments is a ragged tensor and is not
        # torch scriptabe.
        model.__class__.forward = torch.jit.ignore(model.__class__.forward)
        logging.info("Using torch.jit.script")
        model = torch.jit.script(model)
        filename = params.exp_dir / "cpu_jit.pt"
        model.save(str(filename))
        logging.info(f"Saved to {filename}")

        # Also export encoder/decoder/joiner separately
        encoder_filename = params.exp_dir / "encoder_jit_script.pt"
        export_encoder_model_jit_script(model.encoder, encoder_filename)

        decoder_filename = params.exp_dir / "decoder_jit_script.pt"
        export_decoder_model_jit_script(model.decoder, decoder_filename)

        joiner_filename = params.exp_dir / "joiner_jit_script.pt"
        export_joiner_model_jit_script(model.joiner, joiner_filename)

    elif params.jit_trace is True:
        convert_scaled_to_non_scaled(model, inplace=True)
        logging.info("Using torch.jit.trace()")
        encoder_filename = params.exp_dir / "encoder_jit_trace.pt"
        export_encoder_model_jit_trace(model.encoder, encoder_filename)

        decoder_filename = params.exp_dir / "decoder_jit_trace.pt"
        export_decoder_model_jit_trace(model.decoder, decoder_filename)

        joiner_filename = params.exp_dir / "joiner_jit_trace.pt"
        export_joiner_model_jit_trace(model.joiner, joiner_filename)


if __name__ == "__main__":
    main()

@lucasjinreal
Author

@EmreOzkose hello, were you able to get a correct result with the Chinese model?

@csukuangfj
Copy link
Collaborator

Please use the latest master.
