In [12]:
import pyro
import torch
from pyro.infer import SVI, infer_discrete, TraceEnum_ELBO, TraceGraph_ELBO
from anneal.Interface import Interface
from anneal.models.MixtureGaussian import MixtureGaussian
from anneal.models.MixtureDirichlet import MixtureDirichlet
from anneal.models.MixtureGaussianDMP import MixtureGaussianDMP
from anneal.models.MixtureCategorical import MixtureCategorical
from anneal.models.HmmMIxtureRNA import HmmMixtureRNA
from anneal.models.HmmSimple import HmmSimple

import pyro
from pyro.optim import ClippedAdam
import anneal.utils 

from importlib import reload  



# Assemble the inference interface from its three parts:
# the model class, the ELBO used as the loss, and the optimizer class.
interface = Interface()
interface.set_model(MixtureGaussianDMP)
interface.set_loss(TraceEnum_ELBO)
interface.set_optimizer(ClippedAdam)

# Load the simulated per-segment dataset; keeping the location in named
# constants makes it easy to point the notebook at another simulation.
SIM_DATA_DIR = "anneal/data"
SIM_NAME = "example1"
data_dict = anneal.utils.load_simulation_seg(SIM_DATA_DIR, SIM_NAME)

# Display the loaded dict as the cell output.
data_dict

{'data': tensor([[45., 36., 42.,  ..., 50., 47., 42.],
         [55., 55., 60.,  ..., 76., 88., 89.],
         [32., 14., 27.,  ..., 34., 32., 42.],
         ...,
         [43., 33., 43.,  ..., 53., 53., 51.],
         [13.,  3., 12.,  ..., 21., 19.,  9.],
         [28., 21., 40.,  ..., 45., 50., 46.]]),
 'pld': tensor([2.0000, 2.0000, 1.0000, 2.0000, 2.0000, 2.0000, 2.0000, 1.3000, 2.0000,
         2.0000, 1.0000, 1.7000, 2.0000, 2.0000, 3.0000, 2.0000, 2.0000, 2.0000,
         2.0000, 3.0000, 1.0000, 2.0000, 2.0000, 2.0000, 2.0000, 2.0000, 2.0000,
         2.0000, 2.0000, 2.0000, 1.0000, 2.0000]),
 'segments': 32,
 'mu': tensor([12, 17, 14, 13,  8, 12, 29, 11, 12, 24, 15, 15, 12, 15, 11,  9, 13,  9,
         12, 26, 20, 17, 16,  8,  8, 15, 16, 13,  9, 11,  7,  9])}

In [14]:
# Enable Pyro's runtime validation of distribution arguments/shapes.
pyro.enable_validation(True)

# Make the run reproducible and idempotent: seed all RNGs Pyro uses and drop any
# parameters left in Pyro's global param store by an earlier execution of this
# cell; otherwise a re-run would silently resume from the previous fit.
SEED = 0
pyro.set_rng_seed(SEED)
pyro.clear_param_store()

interface.initialize_model(data_dict)
# NOTE(review): 'T' presumably is the number of mixture components
# (the learned mixture_weights below have 4 entries) -- confirm in MixtureGaussianDMP.
interface.set_model_params({'T': 4})

N_STEPS = 300
loss = interface.run(N_STEPS, MAP=True, param_optimizer={"lr": 0.05}, verbose=False)

{'T': 4, 'cnv_mean': 2, 'cnv_var': 0.6, 'theta_scale': 3, 'theta_rate': 1, 'batch_size': None, 'mixture': tensor([1, 1]), 'alpha': 0.01, 'gamma_multiplier': 5}
Running MixtureGaussianDMP on 1000 cells for 300 steps
..........

..........

..........

Done!


In [15]:

# Fetch the fitted parameters with posterior=False and show them as plain text.
# Note: `lr` holds the learned-parameter dict here, not a learning rate.
lr = interface.learned_parameters(
    verbose=False,
    posterior=False,
)
print(lr)



{'mixture_weights': array([0.0176087 , 0.30270222, 0.66903806, 0.01065105], dtype=float32), 'cnv_probs': array([[2.0059266 , 2.007425  , 0.8362006 , 1.7767787 , 2.098364  ,
        1.9205471 , 2.0340385 , 0.8012872 , 1.8802614 , 1.9742701 ,
        0.95668095, 1.7591872 , 1.755937  , 1.8347664 , 2.8318205 ,
        1.7437901 , 1.54105   , 1.6155257 , 1.7160128 , 2.9334698 ,
        0.90487796, 1.8957801 , 1.9830414 , 1.7486382 , 1.8832546 ,
        2.0762794 , 2.1394954 , 1.9307585 , 2.0251215 , 1.832886  ,
        0.93737006, 1.7570152 ],
       [1.9852037 , 1.9680603 , 1.0017264 , 1.9793115 , 2.0112545 ,
        1.9943826 , 1.9747865 , 1.9790934 , 1.9436585 , 1.9709723 ,
        0.98152304, 0.96804297, 1.9936032 , 1.9430892 , 2.943163  ,
        1.9599857 , 1.9815218 , 1.9436476 , 1.983     , 2.9485545 ,
        1.0082628 , 1.9419303 , 1.970281  , 1.9682404 , 1.9229746 ,
        1.9503103 , 1.9698972 , 1.9787732 , 1.9773796 , 1.9387294 ,
        0.9811526 , 1.9941232 ],
       [1.921

In [16]:
import pyro.poutine as poutine

# Sample once from the trained guide, then replay the trained model against that
# sample so every latent site in the model takes the guide's value.
guide_trace = poutine.trace(interface._guide_trained).get_trace()
model_trace = poutine.trace(
    poutine.replay(interface._model_trained, trace=guide_trace)
).get_trace()

# Observation sites are named obs_0 .. obs_{S-1}. S is taken from the data
# instead of hard-coding 32.
# NOTE(review): assumes the model emits exactly one `obs_<i>` site per segment.
keys = ['obs_' + str(i) for i in range(data_dict['segments'])]
obs_node = [model_trace.nodes.get(key) for key in keys]
missing = [key for key, node in zip(keys, obs_node) if node is None]
if missing:
    raise KeyError(f"observation sites not found in model trace: {missing}")

# Log-likelihood of the observed data only. Each site's log_prob is reduced
# with Tensor.sum(), which gives a scalar whatever the site's shape.
# This uses the raw distribution, so any site-level scale/mask is not applied.
res = 0
for node in obs_node:
    res += node['fn'].log_prob(node['value']).sum()

print(res)
# Joint log-density of the whole trace (observed + latent sites), for comparison.
print(model_trace.log_prob_sum())


tensor(-114115.3047, grad_fn=<AddBackward0>)
tensor(-116457.5078, grad_fn=<AddBackward0>)


[117.5926640625, 116.7628125, 115.449171875, 113.0401875, 113.28453125, 113.30265625, 112.2823671875, 111.5015078125, 111.2852421875, 111.06546875, 110.887234375, 110.9455, 110.9530859375, 110.89953125, 110.8688515625, 110.754015625, 110.5925859375, 110.4706953125, 110.43103125, 110.4429453125, 110.442421875, 110.41996875, 110.38703125, 110.32421875, 110.247328125, 110.1177109375, 109.898390625, 109.6146328125, 109.309078125, 108.99796875, 108.6815390625, 108.3839296875, 108.12365625, 107.9287578125, 107.822140625, 107.7778671875, 107.7648359375, 107.763234375, 107.7582890625, 107.734328125, 107.702671875, 107.6671953125, 107.625890625, 107.584234375, 107.5477578125, 107.517265625, 107.4806796875, 107.443390625, 107.4180625, 107.4031328125, 107.398734375, 107.401015625, 107.4078515625, 107.413125, 107.4128125, 107.4054765625, 107.39878125, 107.3878671875, 107.3764609375, 107.36828125, 107.3688125, 107.370140625, 107.3721171875, 107.3691171875, 107.36603125, 107.3654375, 107.3668125, 10