In [1]:
import scanpy as sc
import pyro
import sys
sys.path.append("../")

## Example 4: Echidna & Pyro

In this tutorial, we show you how to use Echidna more flexibly by accessing the underlying data in the model. This is useful if you are already familiar with Pyro and want to do more custom work.

In [2]:
import echidna as ec

As in the previous example, we load the AnnData saved from our previous run. Keep in mind that the copy number data is now saved in `.obs`.

In [3]:
adata = sc.read_h5ad("data/R310_MT_SAVE.h5ad")

Loading a model reads in a pickled model object and loads the Pyro parameter store. We can see that this is true by looking at the keys in the global parameter store.

In [4]:
echidna = ec.tl.load_model(adata)

In [5]:
pyro.get_param_store().keys()

dict_keys(['eta_mean', 'c_shape', 'scale_shape', 'scale_rate', 'corr_loc', 'corr_scale'])

If you make changes to the model object or param store, you can save those changes with the following function:

In [6]:
ec.tl.save_model(adata, echidna, overwrite=False)

2024-07-24 03:04:03,272 | INFO : Saving echidna model with run_id 20240724-030403.


The `echidna` object contains the configuration, the model and guide functions for running a forward pass through the model, and Torch tensors for the ground truth $\eta$, $c$ and $\Sigma$ from training.

In [6]:
eta, c, cov = echidna.eta_posterior, echidna.c_posterior, echidna.cov_posterior
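
These tensors may live on the GPU and track gradients, so they usually need converting before use with NumPy-based tools. The following is a minimal sketch using only standard PyTorch calls; `to_numpy` is a hypothetical helper, and `example` is a stand-in for a tensor such as `eta`.

```python
import torch


def to_numpy(tensor):
    """Convert a tensor to a NumPy array, dropping gradient tracking and moving it to the CPU."""
    return tensor.detach().cpu().numpy()


# Stand-in tensor with requires_grad=True, mimicking a learned parameter.
example = torch.ones(2, 3, requires_grad=True)
example_np = to_numpy(example)
print(example_np.shape)
```

With the real objects, the same call would be `to_numpy(eta)`, `to_numpy(c)` or `to_numpy(cov)`.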

Let's also see how to build the data tensors for a forward pass. Echidna does this under the hood for safety, but it can be helpful to have access to the actual tensors.

In [7]:
data = ec.tl.build_torch_tensors(adata, echidna.config)

Passing the data through the model returns the X and W tensors after a full sampling of the model.

In [8]:
echidna.model(*data)

In [9]:
echidna.guide(*data)

The `get_learned_params` function uses poutine to trace the guide with the data, replay the model with that guide trace, and finally return nodes from the trace of the replay. See, for example, a sampled $\eta$ compared to the ground truth of many averaged samples.

In [10]:
learned_params = ec.tl.get_learned_params(echidna, data)

In [11]:
learned_params["eta"]["value"], eta

(tensor([[0.9383, 3.1680, 0.7946,  ..., 3.2735, 1.2649, 1.5849],
         [3.6889, 3.6525, 1.0541,  ..., 2.6823, 0.2618, 0.9038],
         [2.5289, 1.5187, 1.4478,  ..., 1.8989, 2.7365, 3.3961],
         ...,
         [4.6113, 0.2153, 2.0284,  ..., 4.4695, 1.0533, 1.8157],
         [0.9125, 2.0520, 3.6108,  ..., 4.8281, 0.9271, 1.8465],
         [1.0457, 3.7517, 1.2641,  ..., 0.6974, 0.5060, 0.4976]],
        device='cuda:0', grad_fn=<AddBackward0>),
 tensor([[2.0112, 2.0578, 2.5723,  ..., 2.6592, 2.1364, 2.1962],
         [2.6046, 2.1048, 2.1410,  ..., 1.9683, 2.3005, 2.4361],
         [2.1090, 1.9782, 2.1131,  ..., 1.7650, 2.2017, 2.3996],
         ...,
         [3.0124, 2.6707, 2.0432,  ..., 2.4128, 3.1363, 3.3743],
         [2.0109, 2.2372, 2.1957,  ..., 1.9801, 2.5291, 2.2600],
         [2.0114, 2.3963, 2.1225,  ..., 2.1832, 2.5045, 2.3407]],
        device='cuda:0', requires_grad=True))

The rest of the problem, as they say, is behind the keyboard.