In [1]:
!pip install onnxruntime
!pip install jaxwt

[0m

In [2]:
import os

import jax
import jax.numpy as jnp
import jax.random as random
import jaxwt as jwt
import numpy as np
import onnxruntime as ort
import orbax
import orbax.checkpoint
from flax.training import orbax_utils

from dataset.load_tfrecords import SimpleDataloader
from models.jax_vae import VAE

In [8]:
# Compare the JAX VAE (restored from an Orbax checkpoint) against its ONNX export
# on a single image. Two checks are made:
#   1. ONNX output 1 vs. the second output (`x_waves`) of the JAX model.
#   2. ONNX output 2 vs. the JAX Haar wavelet transform of the same image.
dataset_path = '/code/tfrecords'
checkpoint_path = '/code/checkpoints/vae'
onnx_model_path = '/code/models/vae.onnx'

# Load one shuffled batch and keep the first image, preserving a batch axis of 1.
dataloader = SimpleDataloader(tfrecord_pattern=os.path.join(dataset_path, "*.tfrecord"), batch_size=16)
jax_ds = dataloader.get_jax_iterator(shuffle=True)

imgs = next(jax_ds)['features']
single_img = imgs[0, ...][jnp.newaxis]

# Single-level Haar DWT over axes (1, 2). Sub-bands are stacked on the last axis
# in the order [approx, detail[1], detail[0], detail[2]]; keep this order, the
# model's channel layout depends on it.
transformed = jwt.wavedec2(imgs, "haar", level=1, mode="reflect", axes=(1, 2))
waves = jnp.concatenate(
    [transformed[0], transformed[1][1], transformed[1][0], transformed[1][2]],
    axis=-1,
)

single_example = waves[0, ...][jnp.newaxis]
print(single_example.shape)

vae = VAE(base_features=32, latent_dim=128)
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
raw_restored = orbax_checkpointer.restore(checkpoint_path)


@jax.jit
def run_inference(sample):
    """Apply the restored VAE to ``sample`` in inference mode.

    Uses the parameters in ``raw_restored['model']`` (captured by closure) and a
    fixed RNG key ``random.key(0)``.

    Returns:
        ``(x_recon, x_waves)`` -- the first two model outputs; ``mu`` and
        ``log_var`` are discarded.
    """
    x_recon, x_waves, _mu, _log_var = vae.apply(
        {'params': raw_restored['model']},
        sample,
        training=False,
        key=random.key(0),
    )
    return x_recon, x_waves


def assert_matches_onnx(jax_hwc, onnx_chw, decimal=2):
    """Print both shapes, then assert a channels-last JAX array matches a
    channels-first ONNX array (transposed CHW -> HWC) to ``decimal`` places."""
    print(jax_hwc.shape)
    print(onnx_chw.shape)
    np.testing.assert_almost_equal(jax_hwc, onnx_chw.transpose(1, 2, 0), decimal=decimal)


# Load the ONNX export and report its expected inputs.
session = ort.InferenceSession(onnx_model_path)
print("Model input info:")
for model_input in session.get_inputs():
    print(f"  {model_input.name}: {model_input.shape}")
input_name = session.get_inputs()[0].name

# The ONNX graph takes the raw image channels-first: NHWC -> NCHW.
inp = np.array(single_img.transpose(0, 3, 1, 2))
print(inp.shape)

# Run the ONNX session once and reuse its outputs for both checks.
onnx_outputs = session.run(None, {input_name: inp})

# Check 1: ONNX output 1 vs. JAX x_waves.
output_onnx = onnx_outputs[1][0, ...]
output_jax = np.array(run_inference(single_example)[1])[0, ...]
assert_matches_onnx(output_jax, output_onnx)

# Check 2: ONNX output 2 vs. the JAX wavelet transform of the same image.
output_onnx = onnx_outputs[2][0, ...]
output_jax = np.array(waves[0, ...])
assert_matches_onnx(output_jax, output_onnx)


Found 32 TFRecord files
(1, 128, 128, 4)
Model input info:
  input: [1, 1, 256, 256]
(1, 1, 256, 256)
(128, 128, 4)
(4, 128, 128)
(128, 128, 4)
(4, 128, 128)
