# Utilizing Custom ONNX Models Stored in Hugging Face within HSSM
This guide demonstrates how to use a custom ONNX model stored on Hugging Face as the log-likelihood of an HSSM model.

## Colab Instructions

If you would like to run this tutorial on Google Colab, please click this [link](https://github.com/lnccbrown/HSSM/blob/main/docs/tutorial_notebooks/no_execute/getting_started.ipynb).

Once you are *in Colab*, follow the *installation instructions below* and then **restart your runtime**.

Just **uncomment the code in the next code cell** and run it!

**NOTE**:

You may want to *switch your runtime* to have a GPU or TPU. To do so, go to *Runtime* > *Change runtime type* and select the desired hardware accelerator.

Note that if you switch your runtime, you have to follow the installation instructions again.

```python
# !pip install numpy==1.23.4
# !pip install git+https://github.com/lnccbrown/hssm@main
# !pip install git+https://github.com/brown-ccv/hddm-wfpt@main
# !pip install numpyro
```

## Load Modules

```python
import arviz as az
import numpy as np
import pandas as pd
import pytensor

import hssm
import ssms.basic_simulators

# Run PyTensor computations in single precision
pytensor.config.floatX = "float32"
```

## Simulating Data
Start by simulating 1,000 trials from the Lévy flight model (`model="levy"`) with `ssm-simulators`:

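```python
# Simulate 1,000 trials from the Lévy flight model.
# Note: alpha_true = 0.5 lies outside the alpha bounds (1.0, 2.0) used in the
# model below; pick a value in that range if you want to recover parameters.
v_true, a_true, z_true, alpha_true, t_true = [0.5, 1.5, 0.5, 0.5, 0.3]
obs_levy = ssms.basic_simulators.simulator(
    [v_true, a_true, z_true, alpha_true, t_true], model="levy", n_samples=1000
)
# Stack reaction times and choices into a two-column array
obs_levy = np.column_stack([obs_levy["rts"][:, 0], obs_levy["choices"][:, 0]])
dataset_lan = pd.DataFrame(obs_levy, columns=["rt", "response"])
dataset_lan
```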

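|     | rt       | response |
|-----|----------|----------|
| 0   | 0.367000 | 1.0      |
| 1   | 1.430997 | 1.0      |
| 2   | 0.424000 | 1.0      |
| 3   | 1.593004 | 1.0      |
| 4   | 0.560000 | -1.0     |
| ... | ...      | ...      |
| 995 | 1.707010 | 1.0      |
| 996 | 0.654999 | 1.0      |
| 997 | 1.201992 | 1.0      |
| 998 | 4.502878 | 1.0      |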
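
HSSM expects the data as a DataFrame with an `rt` column and a `response` column, which is the layout built above. A minimal sketch of a pre-flight check for that layout follows; `check_rt_response_data` is a hypothetical helper, not part of HSSM, and it assumes responses are coded as -1/1 as in the simulated data.

```python
import numpy as np
import pandas as pd


def check_rt_response_data(df: pd.DataFrame) -> None:
    """Raise ValueError if df does not follow the rt/response layout.

    Hypothetical helper (not part of HSSM). Assumes positive reaction
    times in "rt" and responses coded as -1.0 / 1.0 in "response",
    as in the simulated dataset above.
    """
    missing = {"rt", "response"} - set(df.columns)
    if missing:
        raise ValueError(f"missing columns: {sorted(missing)}")
    if not (df["rt"] > 0).all():
        raise ValueError("all reaction times must be positive")
    if not df["response"].isin([-1.0, 1.0]).all():
        raise ValueError("responses must be coded as -1 or 1")


# Usage: a small hand-made frame built the same way as dataset_lan
example = pd.DataFrame(
    np.column_stack([[0.42, 1.10, 0.65], [1.0, -1.0, 1.0]]),
    columns=["rt", "response"],
)
check_rt_response_data(example)  # passes silently
```

Running the same check on `dataset_lan` before building the model catches a mis-ordered `column_stack` or a 0/1 response coding early.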


## Loading an ONNX Model from the Hugging Face Repository
If `levy.onnx` is stored in the [Hugging Face repository](https://huggingface.co/franklab/HSSM), you can use it in an HSSM model by passing its file name to `loglik`, as shown below:

```python
my_hssm = hssm.HSSM(
    data=dataset_lan,
    model="custom",
    model_config={
        # Parameters of the Lévy flight model
        "list_params": ["v", "a", "z", "alpha", "t"],
        "backend": "jax",
        "bounds": {
            "v": (-3.0, 3.0),
            "a": (0.3, 3.0),
            "z": (0.1, 0.9),
            "alpha": (1.0, 2.0),
            "t": (1e-3, 2.0),
        },
    },
    loglik_kind="approx_differentiable",
    # File name of the ONNX network in the franklab/HSSM repository
    loglik="levy.onnx",
)
```

This creates an HSSM object `my_hssm` that uses the custom ONNX model `levy.onnx` from the Hugging Face repository as its approximate, differentiable log-likelihood.
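Because the network only approximates the likelihood over the parameter range it was trained on, it is generally worth making sure every entry in `list_params` has a sensible bound. HSSM performs its own validation; the sketch below is only an optional pre-flight check, and `check_model_config` is a hypothetical helper, not part of HSSM.

```python
def check_model_config(model_config: dict) -> None:
    """Raise ValueError if list_params and bounds are inconsistent.

    Hypothetical helper (not part of HSSM): every parameter in
    "list_params" must have a (lower, upper) bound with lower < upper,
    and "bounds" must not name parameters missing from "list_params".
    """
    params = model_config["list_params"]
    bounds = model_config.get("bounds", {})
    for name in params:
        if name not in bounds:
            raise ValueError(f"no bounds given for parameter {name!r}")
        lower, upper = bounds[name]
        if not lower < upper:
            raise ValueError(f"invalid bounds for {name!r}: ({lower}, {upper})")
    unknown = sorted(set(bounds) - set(params))
    if unknown:
        raise ValueError(f"bounds given for unknown parameters: {unknown}")


# The same model_config as in the hssm.HSSM call above
levy_config = {
    "list_params": ["v", "a", "z", "alpha", "t"],
    "backend": "jax",
    "bounds": {
        "v": (-3.0, 3.0),
        "a": (0.3, 3.0),
        "z": (0.1, 0.9),
        "alpha": (1.0, 2.0),
        "t": (1e-3, 2.0),
    },
}
check_model_config(levy_config)  # passes silently
```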

```python
my_hssm.sample(cores=2, draws=500, tune=500, mp_ctx="forkserver")
```

```
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [alpha, z, t, a, v]

Sampling 2 chains for 500 tune and 500 draw iterations (1_000 + 1_000 draws total) took 35 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics
```


## Uploading ONNX Files to a Hugging Face Repository
If your ONNX file is not yet in the Hugging Face repository, you can add it by following the steps below:

1. Import the `HfApi` class from `huggingface_hub`:

```python
from huggingface_hub import HfApi
```

2. Upload the ONNX file with the `upload_file` method:

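```python
# Requires a Hugging Face access token: log in first
# (e.g. with `huggingface-cli login`) or pass token=... to HfApi
api = HfApi()
commit_info = api.upload_file(
    path_or_fileobj="test.onnx",  # local path to your ONNX file
    path_in_repo="test.onnx",  # file name inside the repository
    repo_id="franklab/HSSM",
    repo_type="model",
    create_pr=True,  # open a Pull Request instead of committing directly
)
print(commit_info.pr_url)  # link to the newly created PR
```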

Running these steps opens a Pull Request (PR) on the Hugging Face repository, which a member of our team will then review.

## Creating a Pull Request and a New ONNX Model

1. **Creating a Pull Request on Hugging Face**

   The upload above opens a PR on the repository. For example, the PR created for `test.onnx` is available here: [Hugging Face PR](https://huggingface.co/franklab/HSSM/blob/refs%2Fpr%2F1/test.onnx)

   Each **Pull Request on Hugging Face** is reviewed by our team members.

2. **Creating a Custom ONNX Model**

   ### Create Network Config and State Dictionary Files in PyTorch

   To build a custom model and save it as an ONNX file, you first need a network configuration file and a state dictionary file from PyTorch. Follow the instructions in the README of the [LANFactory package](LINK_TO_LANFACTORY_PACKAGE).

   ### Convert Network Config and State Dictionary Files to ONNX

   Once you have generated the network configuration and state dictionary files, you will need to **convert them to ONNX format**.


![onnx conversion](../images/onnx.png)