# Hyrax Model Export

In this getting-started notebook we'll create a Hyrax object, train a built-in model on the CIFAR training dataset, and then show how to export the model's weights for inspection or evaluation.

For batch evaluation, we recommend using Hyrax's `infer` verb.

## Train a Hyrax model

We will configure the Hyrax object to use the built-in `HyraxAutoencoder` model and immediately run training on the sample CIFAR dataset. Using
the `prepare` and `infer` verbs, we will also save the input dataset and the post-training latent space representation for
future exploration.

In [None]:
import hyrax
import torch

h = hyrax.Hyrax()
h.config["model"]["name"] = "HyraxAutoencoder"

dataset = h.prepare()
model = h.train()
latent_space = h.infer()

## Inspect the model

The return value from `train` is the torch model with its weights set by training. 

We can run individual data through the model in order to see the output. Note that whenever we pass a single tensor to our model, we must do `torch.stack([<our data>])`. This is because torch module functions accept and return batches of data rather than individual items.

In [None]:
test_batch = torch.stack([dataset[0]["image"]])
encoded = model.forward(test_batch)
encoded[0]
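
As a minimal illustration of the batching requirement, independent of Hyrax, the sketch below uses a random tensor in place of a dataset item. The 3x32x32 shape is an assumption chosen to resemble a CIFAR image. It shows that `torch.stack` on a one-element list adds a leading batch dimension, which is equivalent to calling `unsqueeze(0)`.

```python
import torch

# Stand-in for a single image; the 3x32x32 shape is an assumption for illustration.
single_image = torch.rand(3, 32, 32)

# Stacking a one-element list adds a leading batch dimension of size 1.
batch = torch.stack([single_image])
print(batch.shape)

# unsqueeze(0) produces the same batched tensor.
same_as_unsqueeze = torch.equal(batch, single_image.unsqueeze(0))

# Indexing the batch recovers the original item, which is why the
# cell above reads the result with `encoded[0]`.
recovers_item = torch.equal(batch[0], single_image)
```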

`HyraxAutoencoder` has private members `_eval_encoder` and `_eval_decoder` in addition to the normal `forward` function required of a model class. We can call these to see the decoded version of the model's latent space.

In [None]:
decoded = model._eval_decoder(encoded)
decoded

## Export the model

The model is already exported in the most recent results directory in two forms:

1. A PyTorch weights file, `example_model.pth`
2. An ONNX weights file, `example_model_opset_##.onnx`

This directory is shown in the training output; we can also list its contents programmatically:

In [None]:
import os

results_dir = hyrax.config_utils.find_most_recent_results_dir(h.config, "train")
print(results_dir)
os.listdir(results_dir)
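
If you want the exported files grouped by type rather than a flat listing, `pathlib` globbing is one option. The sketch below does not touch a real results directory: it creates a temporary directory with placeholder files whose names are made up for illustration, and `find_exports` is a hypothetical helper, not part of Hyrax.

```python
import tempfile
from pathlib import Path


def find_exports(results_dir):
    """Return exported weight files in results_dir, grouped by extension."""
    results_dir = Path(results_dir)
    return {
        "pth": sorted(results_dir.glob("*.pth")),
        "onnx": sorted(results_dir.glob("*.onnx")),
    }


# Build a throwaway directory with placeholder file names (assumed, not real output).
with tempfile.TemporaryDirectory() as tmp:
    for name in ("example_model.pth", "example_model_opset_NN.onnx", "training.log"):
        (Path(tmp) / name).touch()

    exports = find_exports(tmp)
    found_names = {kind: [p.name for p in paths] for kind, paths in exports.items()}

print(found_names)
```

With the real directory from the cell above, the call would be `find_exports(results_dir)`.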

## Running your trained model outside Hyrax

If you want to run your trained model without importing Hyrax, we highly recommend using ONNX Runtime. Loading the PyTorch weights file instead has significant drawbacks, which we address in the next section.

Evaluating a model with ONNX Runtime in Python is quite simple.

Further information can be found in the [ONNX Runtime documentation](https://onnxruntime.ai/docs/).

In [None]:
# Get the filename from the results directory
onnx_model_filename = [filename for filename in os.listdir(results_dir) if filename.endswith(".onnx")][0]
onnx_model_path = results_dir / onnx_model_filename
print(f"Onnx model filename: {onnx_model_path}")

# Run our single datum with ONNX
import onnxruntime as ort

test_batch = torch.stack([dataset[0]["image"]])
# Pass the path as a str; some older onnxruntime releases reject pathlib.Path objects
ort_sess = ort.InferenceSession(str(onnx_model_path))
outputs = ort_sess.run(None, {"input": test_batch.numpy()})
outputs
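
The cell above hard-codes the input name `"input"`. If the exported model uses a different name, ONNX Runtime's `InferenceSession.get_inputs()` reports the actual one. The sketch below wraps that lookup in `single_input_feed`, a hypothetical helper that is not part of Hyrax or ONNX Runtime, and assumes the model has exactly one input.

```python
def single_input_feed(session, batch):
    """Build the feed dict for a model that has exactly one input.

    `session` is expected to behave like onnxruntime.InferenceSession:
    get_inputs() returns objects that each have a `.name` attribute.
    """
    inputs = session.get_inputs()
    if len(inputs) != 1:
        raise ValueError(f"Expected exactly one model input, found {len(inputs)}")
    # Key the batch by whatever name the exported graph declares.
    return {inputs[0].name: batch}
```

With the session from the cell above, this would be used as `ort_sess.run(None, single_input_feed(ort_sess, test_batch.numpy()))`.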

## Running your trained model with PyTorch
### (not recommended)

To load a PyTorch weights file, the class structure of the model must be exactly the same at load time as it was at save time. This means you need a full, up-to-date copy of your Python model class available in the program that loads the weights. You will also need to ensure that the Python and PyTorch versions match those used when the weights were saved.

If these things are true, loading the model is relatively straightforward:

In [None]:
# Get the filename from the results directory
pth_model_filename = [filename for filename in os.listdir(results_dir) if filename.endswith(".pth")][0]
pth_model_path = results_dir / pth_model_filename
print(f"PyTorch weights filename: {pth_model_path}")


from hyrax.models.hyrax_autoencoder import HyraxAutoencoder

test_batch = torch.stack([dataset[0]["image"]])
imported_model = HyraxAutoencoder(dataset=dataset, config=h.config)
imported_model.load(pth_model_path)
imported_model.to(device="cpu")

encoded_from_import = imported_model.forward(test_batch)
encoded_from_import[0]
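
The class-structure requirement described at the start of this section can be seen in a small sketch that does not involve Hyrax. It assumes the weights are stored as a standard PyTorch state dict, which this notebook does not confirm for `example_model.pth`. `TinyModel` and `RenamedModel` are hypothetical classes invented for the illustration: weights saved from one load cleanly into an identical class, but fail to load into a class whose layer has a different attribute name.

```python
import io

import torch
from torch import nn


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 2)

    def forward(self, x):
        return self.layer(x)


class RenamedModel(nn.Module):
    """Same computation, but the layer attribute has a different name."""

    def __init__(self):
        super().__init__()
        self.other_layer = nn.Linear(4, 2)

    def forward(self, x):
        return self.other_layer(x)


# Save the weights to an in-memory buffer instead of a file on disk.
original = TinyModel()
buffer = io.BytesIO()
torch.save(original.state_dict(), buffer)
buffer.seek(0)
state = torch.load(buffer)

# An identical class accepts the weights and reproduces the outputs.
restored = TinyModel()
restored.load_state_dict(state)
x = torch.rand(1, 4)
outputs_match = torch.equal(original(x), restored(x))

# A class with different parameter names rejects the same weights.
try:
    RenamedModel().load_state_dict(state)
    mismatch_error = None
except RuntimeError as err:
    mismatch_error = err

print(outputs_match, type(mismatch_error).__name__)
```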