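# Convert from PyTorch to C++/Matlab
Convert the trained network to TorchScript so it can be loaded for inference with the PyTorch C++ API (LibTorch). Also save a few reference inputs/outputs so the C++/Matlab results can be checked against PyTorch.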

In [None]:
import torch
from utils import *

model_path = 'data/local/best_mse.pth'

# Build the network and run the utils-provided I/O check
# (presumably a shape/consistency test of the network interface -- TODO confirm).
net = LiuqeNet()
test_network_io(verbose=False)

# Load the trained weights on CPU; strict=True rejects missing/unexpected keys.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints
# (on torch>=1.13, weights_only=True restricts loading to tensors/containers).
print(f'loading {model_path}')
state_dict = torch.load(model_path, map_location='cpu')
net.load_state_dict(state_dict, strict=True)
net.eval()  # put dropout/batch-norm style layers in evaluation mode before export

# Compile to TorchScript and serialize it so LibTorch (C++) can load it.
scripted = torch.jit.script(net)
scripted.save('net.pt')
print('net.pt saved')

In [None]:
# Build a small example dataset and run inference on it, so the converted model can be verified later.
import numpy as np
from scipy.io import savemat
ds = LiuqeDataset(EVAL_DS_PATH, verbose=False)

ni = 5 # number of samples

# Draw ni *distinct* sample indices: drawing with replacement (randint) could
# repeat a sample and waste one of the few verification cases.
idx = np.random.choice(len(ds), ni, replace=False)
xs, _, rs, zs = ds[idx] # inputs

# Inference only: no_grad skips building an autograd graph, so the output
# does not require grad and can be converted to numpy directly.
with torch.no_grad():
    ys = net(xs, rs, zs) # outputs

print(f'torch -> xs: {xs.shape}, ys: {ys.shape}, rs: {rs.shape}, zs: {zs.shape}')

# Convert to numpy; each output sample is flattened to one row of NGR*NGZ values
# (the text dump below relies on this flattened layout).
xs, ys, rs, zs = xs.numpy(), ys.numpy().reshape(-1, NGR*NGZ), rs.numpy(), zs.numpy()
print(f'numpy -> xs: {xs.shape}, ys: {ys.shape}, rs: {rs.shape}, zs: {zs.shape}')

# Write the values to a txt file so they can be compared later.
# Every value is written with an explicit sign and 4 decimals.
def _format_row(row, chunk=None):
    """Return one row formatted as ' [ +v0 +v1 ... ]' followed by a newline.

    If `chunk` is given, a ' ] [' separator is inserted every `chunk` values,
    so a flattened 2-D array is printed as one bracket group per sub-row.
    """
    vals = [f' {v:+.4f}' for v in row]
    if chunk is None:
        groups = [vals]
    else:
        # `or [[]]` keeps the ' [ ]' output for an empty row
        groups = [vals[i:i + chunk] for i in range(0, len(vals), chunk)] or [[]]
    return ' [' + ' ] ['.join(''.join(g) for g in groups) + ' ]\n'


def _write_section(f, name, rows, chunk=None):
    """Write a '<name>:' header line followed by one formatted line per row."""
    f.write(f'{name}:\n')
    f.writelines(_format_row(row, chunk) for row in rows)


with open(f'{TEST_DIR}/test_inference.txt', 'w', encoding='utf-8') as f:
    _write_section(f, 'xs', xs)
    # each ys row holds NGR*NGZ values; print them in groups of NGR
    _write_section(f, 'ys', ys, chunk=NGR)
    _write_section(f, 'rs', rs)
    _write_section(f, 'zs', zs)

# Quick console sanity check: the first ns entries of the first sample of each array.
ns = 5
np.set_printoptions(precision=4, suppress=True, sign='+')
for label, arr in (('x', xs), ('y', ys), ('r', rs), ('z', zs)):
    print(f'{label} -> {arr[0,:ns]}')

# Also save unrounded copies to a .mat file (loadable from Matlab) for later comparison.
mat_path = f'{TEST_DIR}/test_inference.mat'
savemat(mat_path, {'xs': xs, 'ys': ys, 'rs': rs, 'zs': zs})