In [1]:
import torch
import torch.nn as nn
import torch.onnx
import numpy as np
import onnx
import onnxruntime

In [2]:
class Encoder(nn.Module):
    """Bidirectional LSTM encoder on top of a frozen, pretrained embedding table."""

    def __init__(self, embedding_tensor):
        """Build the embedding lookup and the LSTM.

        Args:
            embedding_tensor: 2-D tensor of pretrained embeddings. Its second
                dimension is used as both the LSTM input size and the LSTM
                hidden size. Row 0 is registered as the padding index.
        """
        super().__init__()
        embedding_dim = embedding_tensor.shape[1]
        self.embedding = nn.Embedding.from_pretrained(
            embedding_tensor,
            padding_idx=0,
            freeze=True,
        )
        self.rnn = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=embedding_dim,
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, x):
        """Embed the token indices in ``x`` and run them through the LSTM.

        Returns:
            ``(output, (h, c))`` exactly as produced by ``nn.LSTM``.
        """
        embedded = self.embedding(x)
        return self.rnn(embedded)

class Decoder(nn.Module):
    """LSTM decoder unrolled for a fixed number of steps from the encoder state.

    At every step the LSTM input is the current token embedding concatenated
    with a context vector built from the encoder's final hidden state; the
    LSTM output is concatenated with the same context and projected by a
    linear layer to one score per embedding row.
    """

    # Steps unrolled when ``forward`` is called without ``max_len``.
    # Defined on the class (not in __init__) so that instances restored from
    # a pickle, which does not re-run __init__, still pick it up.
    DEFAULT_DECODE_STEPS = 4

    def __init__(self, embedding_tensor):
        """Build the embedding lookup, the LSTM and the output projection.

        Args:
            embedding_tensor: 2-D tensor of pretrained embeddings
                (rows x embedding_dim). Row 0 is the padding index and is
                also the token fed to the first decoding step.
        """
        super().__init__()
        vocab_size, embedding_dim = embedding_tensor.shape
        self.embedding = nn.Embedding.from_pretrained(embedding_tensor, freeze=True, padding_idx=0)
        # LSTM input = context (2 * embedding_dim) + token embedding (embedding_dim).
        self.rnn = nn.LSTM(embedding_dim * 3, embedding_dim, batch_first=True, bidirectional=True)
        # Projection input = context (2 * embedding_dim) + bidirectional output (2 * embedding_dim).
        self.f = nn.Linear(embedding_dim * 4, vocab_size)
        # Context from the most recent forward() call. Kept only so that
        # forward_sub(x, h) without an explicit context keeps working.
        self.encoder_h_context = None

    def forward(self, encoder_output, encoder_hc, t=None, max_len=None):
        """Decode step by step, greedily or with teacher forcing.

        Args:
            encoder_output: encoder outputs; only the batch size (``shape[0]``)
                and the device are read from it.
            encoder_hc: ``(h, c)`` tuple used as the initial LSTM state. Slices
                ``h[0:1]`` and ``h[1:2]`` are concatenated on the last axis to
                form the context (assumes a single-layer bidirectional
                encoder, i.e. ``h`` has leading size 2 -- confirm for others).
            t: optional target tokens indexed as ``t[:, i]``. When given,
                column ``i`` replaces the argmax prediction as the input of
                the following step, so it needs at least as many columns as
                there are decoding steps.
            max_len: number of decoding steps; defaults to
                ``DEFAULT_DECODE_STEPS``.

        Returns:
            ``(scores, (h, c), None)`` where ``scores`` is the per-step output
            concatenated along dim 1 and ``(h, c)`` is the final LSTM state.

        Raises:
            ValueError: if the number of decoding steps is smaller than 1.
        """
        steps = self.DEFAULT_DECODE_STEPS if max_len is None else max_len
        if steps < 1:
            raise ValueError(f"max_len must be >= 1, got {steps}")

        encoder_h = encoder_hc[0]
        # (1, batch, 2 * hidden) -> (batch, 1, 2 * hidden) to match batch_first tensors.
        context = torch.cat([encoder_h[0:1, :, :], encoder_h[1:2, :, :]], dim=-1).transpose(0, 1)
        # Still recorded for callers that invoke forward_sub(x, h) directly.
        self.encoder_h_context = context

        # Batch size is taken from encoder_output on purpose: it is the only
        # place this argument is used.
        batch_size = encoder_output.shape[0]
        # Index 0 acts as the start token; allocate it directly as int64 on
        # the right device instead of creating floats and converting.
        decoder_input = torch.zeros(batch_size, 1, dtype=torch.long, device=encoder_output.device)
        decoder_hc = encoder_hc
        step_outputs = []

        for i in range(steps):
            step_output, decoder_hc = self.forward_sub(decoder_input, decoder_hc, context)
            step_outputs.append(step_output)

            if t is None:
                # Greedy decoding: feed back the highest-scoring index.
                decoder_input = step_output.argmax(dim=-1).detach()
            else:
                # Teacher forcing: feed the provided token for this position.
                decoder_input = t[:, i].unsqueeze(-1)

        return torch.cat(step_outputs, dim=1), decoder_hc, None

    def forward_sub(self, x, h, context=None):
        """Run a single decoding step.

        Args:
            x: token indices for this step.
            h: ``(h, c)`` LSTM state to start from.
            context: context tensor concatenated to both the LSTM input and
                the LSTM output. When omitted, the context stored by the last
                ``forward`` call is used.

        Returns:
            ``(scores, (h, c))`` for this step.

        Raises:
            RuntimeError: if no context is passed and none has been stored.
        """
        if context is None:
            context = self.encoder_h_context
        if context is None:
            raise RuntimeError(
                "forward_sub needs a context: pass `context` or call forward() first"
            )
        embedded = self.embedding(x)
        rnn_input = torch.cat([context, embedded], dim=-1)
        output, hc = self.rnn(rnn_input, h)
        output = torch.cat([context, output], dim=-1)
        return self.f(output), hc

In [3]:
# Load the trained encoder from disk and print its layer summary.
# The file holds a pickled module object rather than a plain state_dict
# (the printed result is an Encoder instance), so the Encoder class from the
# cell above has to be defined before this cell runs.
# NOTE(review): weights_only=False allows the unpickler to execute arbitrary
# code -- only load checkpoint files that come from a trusted source.
encoder = torch.load("num_encoder.pt", weights_only=False)
print(encoder)

Encoder(
  (embedding): Embedding(12, 11, padding_idx=0)
  (rnn): LSTM(11, 11, batch_first=True, bidirectional=True)
)


In [4]:
# Export the encoder to ONNX using a dummy batch of token ids.
# Only the batch axis ("b") is declared dynamic; every other dimension keeps
# the size of the dummy input.
encoder.eval()

x = torch.randint(0, 12, size=(1, 7), dtype=torch.long)
dynamic_axes = {
    "input": {0: "b"},
    "output": {0: "b"},
    "h": {1: "b"},
    "c": {1: "b"},
}

torch.onnx.export(
    encoder,
    x,
    "num_encoder.onnx",
    input_names=["input"],
    output_names=["output", "h", "c"],
    dynamic_axes=dynamic_axes,
    opset_version=10,
    export_params=True,
    do_constant_folding=True,
)



In [5]:
# Read the exported file back and run ONNX's model checker on it.
# NOTE: the verification cell further down rebinds `onnx_encoder` to an
# onnxruntime InferenceSession.
onnx_encoder = onnx.load("num_encoder.onnx")
onnx.checker.check_model(onnx_encoder)

In [6]:
# Compare ONNX Runtime against the PyTorch encoder on the same random input.
np_x = np.random.randint(0,12,size=(1,7)).astype(np.int64)
tensor_x = torch.tensor(np_x, dtype = torch.long)

onnx_encoder = onnxruntime.InferenceSession("num_encoder.onnx", providers=["CPUExecutionProvider"])
inputs = {onnx_encoder.get_inputs()[0].name : np_x}
np_y = onnx_encoder.run(None, inputs)

# Reference forward pass; no autograd graph is needed for a pure comparison.
with torch.no_grad():
    y, hc = encoder(tensor_x)

# One labelled check per exported output, so a failure names the tensor.
assert len(np_y) == 3, f"expected 3 ONNX outputs, got {len(np_y)}"
for output_name, torch_value, onnx_value in zip(("output", "h", "c"), (y, hc[0], hc[1]), np_y):
    np.testing.assert_allclose(
        torch_value.numpy(),
        onnx_value,
        rtol=1e-03,
        atol=1e-05,
        err_msg=f"encoder ONNX output '{output_name}' differs from PyTorch",
    )

In [7]:
# Load the trained decoder from disk and print its layer summary.
# The file holds a pickled module object rather than a plain state_dict
# (the printed result is a Decoder instance), so the Decoder class defined
# earlier in the notebook has to exist before this cell runs.
# NOTE(review): weights_only=False allows the unpickler to execute arbitrary
# code -- only load checkpoint files that come from a trusted source.
decoder = torch.load("num_decoder.pt", weights_only=False)
print(decoder)

Decoder(
  (embedding): Embedding(12, 11, padding_idx=0)
  (rnn): LSTM(33, 11, batch_first=True, bidirectional=True)
  (f): Linear(in_features=44, out_features=12, bias=True)
)


In [8]:
# Export the decoder to ONNX using random stand-ins for the encoder results.
# The trailing None in the argument tuple is the decoder's `t` argument, so
# the traced graph follows the greedy (argmax) path.
decoder.eval()

encoder_output = torch.randn(1, 7, 22)
encoder_h = torch.randn(2, 1, 11)
encoder_c = torch.randn(2, 1, 11)
encoder_hc = (encoder_h, encoder_c)

# Only the batch axis ("b") is dynamic on every input and output.
dynamic_axes = {
    "encoder_output": {0: 'b'},
    "encoder_h": {1: 'b'},
    "encoder_c": {1: 'b'},
    "output": {0: 'b'},
    "h": {1: 'b'},
    "c": {1: 'b'},
}

torch.onnx.export(
    decoder,
    (encoder_output, encoder_hc, None),
    "num_decoder.onnx",
    input_names=["encoder_output", "encoder_h", "encoder_c"],
    output_names=["output", "h", "c"],
    dynamic_axes=dynamic_axes,
    opset_version=10,
    export_params=True,
    do_constant_folding=True,
)

In [9]:
# Read the exported file back and run ONNX's model checker on it.
# NOTE: the verification cell further down rebinds `onnx_decoder` to an
# onnxruntime InferenceSession.
onnx_decoder = onnx.load("num_decoder.onnx")
onnx.checker.check_model(onnx_decoder)

In [11]:
# Compare ONNX Runtime against the PyTorch decoder on the same random inputs.
# `np_encoder_input` is fed to the decoder's first input ("encoder_output").
np_encoder_input = np.random.randn(1,7,22).astype(np.float32)
np_encoder_h = np.random.randn(2,1,11).astype(np.float32)
np_encoder_c = np.random.randn(2,1,11).astype(np.float32)

onnx_decoder = onnxruntime.InferenceSession("num_decoder.onnx", providers=["CPUExecutionProvider"])
inputs = {onnx_decoder.get_inputs()[0].name : np_encoder_input, onnx_decoder.get_inputs()[1].name : np_encoder_h, onnx_decoder.get_inputs()[2].name : np_encoder_c}
np_y = onnx_decoder.run(None, inputs)

tensor_encoder_input = torch.tensor(np_encoder_input, dtype = torch.float)
tensor_encoder_h = torch.tensor(np_encoder_h, dtype = torch.float)
tensor_encoder_c = torch.tensor(np_encoder_c, dtype = torch.float)

# Reference forward pass; no autograd graph is needed for a pure comparison.
with torch.no_grad():
    y, hc, _ = decoder(tensor_encoder_input, (tensor_encoder_h, tensor_encoder_c), None)

# One labelled check per exported output, so a failure names the tensor.
assert len(np_y) == 3, f"expected 3 ONNX outputs, got {len(np_y)}"
for output_name, torch_value, onnx_value in zip(("output", "h", "c"), (y, hc[0], hc[1]), np_y):
    np.testing.assert_allclose(
        torch_value.numpy(),
        onnx_value,
        rtol=1e-03,
        atol=1e-05,
        err_msg=f"decoder ONNX output '{output_name}' differs from PyTorch",
    )