## Inference

In [1]:
import os

import torch
import torch.nn as nn


In [2]:
# Raw training checkpoint (a state dict). map_location='cpu' lets a checkpoint
# saved from a GPU run load on a CPU-only machine.
path = 'ckpt/test/epoch21-0.8867.pth'
ckpt = torch.load(path, map_location='cpu')

In [None]:
class Net(nn.Module):
    """Single-step recurrent inference network rebuilt from a training checkpoint.

    One call computes::

        i = ReLU6(i_proj(x))                 # last dim 8 -> 64
        h = alpha * hx + (1 - alpha) * i     # per-unit mix of old state and input
        y = o_proj(ReLU6(h))                 # last dim 64 -> 1

    Args:
        path: Checkpoint file to read when ``ckpt`` is not given. Defaults to
            ``DEFAULT_CKPT`` (the path the original notebook hard-coded).
        ckpt: Optional already-loaded state dict with keys ``'0.weight'``,
            ``'3.alpha'``, ``'3.x0'`` and ``'6.weight'``. When given, no file
            is read.
    """

    DEFAULT_CKPT = 'ckpt/test/epoch21-0.8867.pth'

    def __init__(self, path=DEFAULT_CKPT, ckpt=None):
        super().__init__()
        self.i_proj = nn.Linear(8, 64, False)
        # Random initial values below are placeholders; load() overwrites them.
        self.alpha = nn.Parameter(torch.rand(64))
        self.h0 = nn.Parameter(torch.rand(1, 64))
        self.o_proj = nn.Linear(64, 1, False)
        self.act = nn.ReLU6()
        self.load(path, ckpt)

    @torch.no_grad()
    def load(self, path=DEFAULT_CKPT, ckpt=None):
        """Copy checkpoint tensors into this module's parameters in place.

        Args:
            path: Checkpoint file, used only when ``ckpt`` is None.
            ckpt: Optional preloaded state dict (see class docstring).
        """
        if ckpt is None:
            # Map to CPU so GPU-saved checkpoints load anywhere; copy_ below
            # moves values onto whatever device the parameters live on.
            ckpt = torch.load(path, map_location='cpu')

        self.i_proj.weight.copy_(ckpt['0.weight'])
        # The stored alpha has singleton dims 1 and 2 (e.g. (64, 1, 1) — TODO
        # confirm); drop them to match the (64,) parameter.
        self.alpha.copy_(ckpt['3.alpha'].squeeze([1, 2]))
        # Depends on alpha being loaded first: h0 is x0 pre-scaled by (1 - alpha).
        self.h0.copy_(ckpt['3.x0'] * (1 - self.alpha))
        self.o_proj.weight.copy_(ckpt['6.weight'])

    @torch.no_grad()
    def forward(self, x, hx=None):
        """Run one recurrent step.

        Args:
            x: Input features with last dimension 8.
            hx: Previous hidden state with last dimension 64; ``None`` starts
                from the checkpoint-derived ``h0``.

        Returns:
            ``(y, h)``: the output (last dimension 1) and the new hidden state,
            to be passed back as ``hx`` on the next call.
        """
        g = self.alpha
        i = self.act(self.i_proj(x))
        if hx is None:
            hx = self.h0
        h = g * hx + (1 - g) * i
        y = self.o_proj(self.act(h))
        return y, h


In [9]:
# Build the model from the checkpoint and switch it to inference (eval) mode.
net = Net().eval()

In [13]:
# First input step: one sample of 8 features; unsqueeze(0) adds the batch dim.
x = torch.tensor([4.882802, 0.4116477, 0.2582662, 1.1432298,
                  0.34995735, 0.15397726, 0.07350865, 0.]).unsqueeze(0)

In [None]:
# First step: no previous hidden state, so the model starts from h0.
y, h = net(x, hx=None)
y

In [17]:
# Second input step, fed together with the hidden state from the first step.
x = torch.tensor([3.9512436, 0.4121284, 0.57202524, 0.9942352,
                  0.38036877, 0.11640994, 0.07208236, 0.08725858]).unsqueeze(0)

In [18]:
# Recurrent step: feed back the hidden state returned by the previous call.
y, h = net(x, hx=h)
y

tensor([[2.5762]])

## ONNX

In [26]:
# Example inputs used to trace the graph: (batch, 8) features and a (batch, 64)
# hidden state. A local seeded generator keeps them reproducible across
# Restart & Run All without touching the global RNG.
gen = torch.Generator().manual_seed(0)
dummy = (torch.randn(1, 8, generator=gen), torch.randn(1, 64, generator=gen))

os.makedirs('test', exist_ok=True)  # export fails if the target directory is missing
torch.onnx.export(
    net,
    dummy, 'test/model.onnx',
    input_names=['input', 'hx'],
    output_names=['output', 'hy'],
    # Keep dim 0 symbolic so the exported model is not pinned to batch size 1.
    dynamic_axes={name: {0: 'batch'} for name in ('input', 'hx', 'output', 'hy')},
    opset_version=13,
)

  torch.onnx.export(


In [28]:
import onnx, onnxruntime as ort, numpy as np

# Structural validation of the exported graph.
onnx_model = onnx.load('test/model.onnx')
onnx.checker.check_model(onnx_model)

sess = ort.InferenceSession('test/model.onnx', providers=["CPUExecutionProvider"])
out = sess.run(None, {'input': dummy[0].numpy(), 'hx': dummy[1].numpy()})

# Compare with the PyTorch outputs. forward returns (y, h); sess.run returns
# [output, hy] in the order given by output_names at export time.
pt_out = [t.detach().numpy() for t in net(*dummy)]
for name, ort_arr, pt_arr in zip(['output', 'hy'], out, pt_out):
    np.testing.assert_allclose(ort_arr, pt_arr, rtol=1e-4, atol=1e-5, err_msg=name)

In [30]:
# PyTorch reference outputs (y, h) for the same dummy inputs, for comparison with `out`.
net(dummy[0], hx=dummy[1])

(tensor([[0.2246]]),
 tensor([[-0.0215, -0.8248,  0.0657,  0.0301,  0.4198, -0.3752,  0.4744,  0.0853,
          -0.4335, -0.0017,  1.3332,  0.3892, -0.1745,  0.1359, -0.9661, -0.2100,
           0.0559,  0.4735, -0.8220,  0.1195,  0.0377, -0.1647,  0.8174,  0.0791,
          -0.4159, -0.3635, -0.0192,  0.1987, -0.8256,  0.2248, -0.2913,  0.5854,
           0.4149, -0.6019,  0.0531, -0.1081,  0.2635,  0.0284,  1.6606,  0.0615,
           0.1233, -0.0058, -0.1948,  0.9033,  0.0801,  0.0900,  0.1275,  0.4002,
           0.0910, -0.4605,  1.2778, -0.3346, -1.0528,  0.0038, -1.6449,  0.4506,
           0.3029, -0.7454, -0.0880, -0.0716,  1.0241, -0.1311,  0.0194, -0.6278]]))

In [29]:
# ONNX Runtime results: [output, hy], in the order of output_names used at export.
out

[array([[0.22462046]], dtype=float32),
 array([[-0.02149462, -0.8247864 ,  0.065715  ,  0.03011797,  0.41984096,
         -0.37517273,  0.47437143,  0.08528044, -0.4335284 , -0.0017328 ,
          1.333152  ,  0.38923585, -0.17446381,  0.13587855, -0.96608984,
         -0.21001492,  0.05589665,  0.47351813, -0.8220247 ,  0.11953185,
          0.03770047, -0.16471067,  0.81736815,  0.07914691, -0.4159286 ,
         -0.36345395, -0.0192393 ,  0.1987017 , -0.82560915,  0.22475633,
         -0.29132307,  0.5853786 ,  0.41486594, -0.6019309 ,  0.05312075,
         -0.10809331,  0.26349932,  0.02837995,  1.660605  ,  0.06152598,
          0.12331674, -0.00579717, -0.1947561 ,  0.9032636 ,  0.08008309,
          0.09000917,  0.12745585,  0.4002468 ,  0.09102561, -0.46048966,
          1.2777627 , -0.3346157 , -1.0528045 ,  0.00380314, -1.6448549 ,
          0.45057708,  0.30287167, -0.74543256, -0.08795726, -0.07163236,
          1.0241157 , -0.13109833,  0.01939382, -0.6278176 ]],
       dty

In [23]:
# NOTE(review): duplicate of the `out` cell above. Its saved result (a single 1x1
# array, execution count 23) comes from an earlier run and is stale; consider deleting.
out

array([[1.3314825]], dtype=float32)