In [1]:
import logging

import pickle

import torch
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from torch.distributions import constraints
from torch import nn
import pyro
import pyro.distributions as dist
import pyro.optim as optim
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
from pyro.infer import Predictive
import seaborn as sns
import torch.nn.functional as F
from pyro import poutine
from sklearn import metrics
from pyro.infer.autoguide import AutoMultivariateNormal, AutoLowRankMultivariateNormal, init_to_mean,init_to_feasible,AutoNormal

In [2]:
pyro.set_rng_seed(10)
with open('data_all.pickle', 'rb') as handle:
    data = pickle.load(handle)
print(data.shape)

(1127, 5237)


In [3]:

class PMF_clusters(nn.Module):
    #bayesian non parametrics - dirichlet process

    #with multivariate gammas that are "somehow?" related through pyro's dependent dimension setting
    #how to define their covariance?
    def __init__(self, train, dim,p1,p2):
        super().__init__()
        """Build the Probabilistic Matrix Factorization model using pymc3.
        """
        self.dim = dim   
        self.data = train.copy()
        self.n, self.m = self.data.shape
        self.map = None
        self.bounds = (0,1)
        self.losses = None
        self.predictions = None
        self.t_drug = p1[0]
        self.phi_drug = p1[1]
        self.t_s = p2[0]
        self.phi_s = p2[1]


        # Perform mean value imputation
    
        
        # Low precision reflects uncertainty; prevents overfitting.
        # Set to the mean variance across users and items.
        self.alpha_u = (np.mean(self.data, axis=1).mean())**2 / np.std(self.data, axis=1).mean()
        self.alpha_v = (np.mean(self.data, axis=0).mean())**2 / np.std(self.data, axis=0).mean()

        self.beta_u = (np.mean(self.data, axis=1).mean()) / np.std(self.data, axis=1).mean()
        self.beta_v = (np.mean(self.data, axis=0).mean()) / np.std(self.data, axis=0).mean()
       
        self.bias = self.data.mean()
        self.num_clusters_drugs = 73
        self.num_clusters_se =  77
    def mix_weights(self,beta):
            beta1m_cumprod = (1 - beta).cumprod(-1)
            return F.pad(beta, (0, 1), value=1) * F.pad(beta1m_cumprod, (1, 0), value=1)

    def model(self, data):
            alpha = 10
            with pyro.plate("beta_drugs_plate", self.num_clusters_drugs-1):
                beta_drugs = pyro.sample("beta_drugs", dist.Beta(1, alpha))

            with pyro.plate("mu_drugs_plate", self.num_clusters_drugs):
                mu_drugs = pyro.sample("mu_drugs", dist.MultivariateNormal(torch.zeros(self.dim), 0.1 * torch.eye(self.dim)))
            #     cov = pyro.sample("cov", dist.Gamma(torch.ones(1), 2*torch.ones(1)).expand([self.dim]).to_event(1) )

            # cov_matrix = torch.ones(self.num_clusters_drugs,self.dim,self.dim)
            # for i in range(0,self.num_clusters_drugs):
            #     cov_matrix[i] = cov[i]*torch.eye(self.dim)

            with pyro.plate("data_drugs", self.n):
                z_d = pyro.sample("z_drugs", dist.Categorical(self.mix_weights(beta_drugs)))
                UA = pyro.sample("UA", dist.MultivariateNormal(mu_drugs[z_d],  0.05*torch.eye(self.dim)))

            with pyro.plate("beta_se_plate", self.num_clusters_se-1):
                beta_se = pyro.sample("beta_se", dist.Beta(1, alpha))

            with pyro.plate("mu_plate_se", self.num_clusters_se):
                mu_se = pyro.sample("mu_se", dist.MultivariateNormal(torch.zeros(self.dim), 0.1 * torch.eye(self.dim)))
            #     cov2 = pyro.sample("cov2", dist.Gamma(torch.ones(1), 2*torch.ones(1)).expand([self.dim]).to_event(1) )

            # cov_matrix2 = torch.ones(self.num_clusters_se,self.dim,self.dim)
            # for i in range(0,self.num_clusters_se):
            #     cov_matrix2[i] = cov2[i]*torch.eye(self.dim)

            with pyro.plate("data_sideeffects", self.m):
                z_se = pyro.sample("z_se", dist.Categorical(self.mix_weights(beta_se)))
                VA = pyro.sample("VA", dist.MultivariateNormal(mu_se[z_se], 0.05*torch.eye(self.dim)))#cov2[z_se]))
            
            u2_plate = pyro.plate("u2_plate", self.n, dim=-2)
            se2_plate = pyro.plate("se2_plate", self.m, dim=-1)

            with se2_plate, u2_plate: 
                Y = pyro.sample("target", dist.Poisson(torch.abs(UA@VA.T)), obs=data ) 
                return Y

    def guide(self,data=None):
            kappa = pyro.param('kappa_d', lambda: dist.Uniform(0, 2).sample([self.num_clusters_drugs-1]), constraint=constraints.positive)
            tau = pyro.param('tau_d', self.t_drug)
            phi = pyro.param('phi_d', self.phi_drug, constraint=constraints.simplex)
            # d_alpha = pyro.param('d_alpha', torch.ones(self.n,self.dim), constraint=constraints.positive)#*self.user_mean)
            # d_beta = pyro.param('d_beta', 0.5*torch.ones(self.n,self.dim), constraint=constraints.positive)


            with pyro.plate("beta_plate", self.num_clusters_drugs-1):
                beta_drugs = pyro.sample("beta_drugs", dist.Beta(torch.ones(self.num_clusters_drugs-1), kappa))

            with pyro.plate("mu_plate_drug", self.num_clusters_drugs):
                mu_drugs = pyro.sample("mu_drugs", dist.MultivariateNormal(tau, 0.01*torch.eye(self.dim)))
                

            with pyro.plate("data_drug", self.n):
                z_d = pyro.sample("z_drugs", dist.Categorical(phi))
                UA = pyro.sample("UA", dist.MultivariateNormal(mu_drugs[z_d], 0.1*torch.eye(self.dim)))

            
            kappa_s = pyro.param('kappa_s', lambda: dist.Uniform(0, 2).sample([self.num_clusters_se-1]), constraint=constraints.positive)
            tau_s = pyro.param('tau_s', self.t_s)
            phi_s = pyro.param('phi_s', self.phi_s, constraint=constraints.simplex)

            with pyro.plate("beta_se_plate", self.num_clusters_se-1):
                beta_se = pyro.sample("beta_se", dist.Beta(torch.ones(self.num_clusters_se-1), kappa_s))

            with pyro.plate("mu_plate_se", self.num_clusters_se):
                mu_se = pyro.sample("mu_se", dist.MultivariateNormal(tau_s, 0.01*torch.eye(self.dim)))

            with pyro.plate("data_sideeffects", self.m):
                z_se = pyro.sample("z_se", dist.Categorical(phi_s))
                VA = pyro.sample("VA", dist.MultivariateNormal(mu_se[z_se], 0.1*torch.eye(self.dim)))

    def train_SVI(self,train, nsteps=250, lr = 0.05, lrd = 1):
                logging.basicConfig(format='%(message)s', level=logging.INFO)
                svi = SVI(self.model,
                self.guide,
                optim.ClippedAdam({"lr": lr, "lrd": lrd}),
                loss=Trace_ELBO())
                losses = []
                for step in range(nsteps):
                    elbo = svi.step(torch.from_numpy(train).float())
                    losses.append(elbo)
                    if step % 10 == 0:
                        print("Elbo loss: {}".format(elbo))
                self.losses = losses
                #constrained_params = list(pyro.get_param_store().values())
                #PARAMS = [p.unconstrained() for p in constrained_params]
                #print(PARAMS)
                return losses

In [4]:
with open('data_all.pickle', 'rb') as handle:
    data = pickle.load(handle)
with open('drug_param.pickle', 'rb') as handle:
     param_drugs = pickle.load(handle)

with open('se_param.pickle', 'rb') as handle:
     param_se = pickle.load(handle)

test = PMF_clusters(train=data, dim=100, p1=param_drugs, p2=param_se)
test.train_SVI(data)



Elbo loss: 123953260.4874115
Elbo loss: 134619347.82818985


[E thread_pool.cpp:113] Exception in thread pool task: mutex lock failed: Invalid argument


KeyboardInterrupt: 

In [7]:

predictive_svi = Predictive(test.model, guide=test.guide, num_samples=700)(None )
for k, v in predictive_svi.items():
    print(f"{k}: {tuple(v.shape)}")
table = predictive_svi["target"].numpy()
print(table)


beta_drugs: (700, 1, 72)
mu_drugs: (700, 1, 73, 100)
z_drugs: (700, 1, 1127)
UA: (700, 1, 1127, 100)
beta_se: (700, 1, 76)
mu_se: (700, 1, 77, 100)
z_se: (700, 1, 5237)
VA: (700, 1, 5237, 100)
target: (700, 1127, 5237)
[[[ 8.  7.  4. ...  1. 31.  6.]
  [10. 11.  9. ... 12. 20.  2.]
  [22. 14. 18. ... 30.  0. 21.]
  ...
  [ 0.  1.  3. ... 10.  7.  8.]
  [ 0.  0.  2. ...  4.  9.  5.]
  [13.  7.  1. ... 17. 11.  6.]]

 [[ 2.  6.  3. ...  7.  2.  3.]
  [ 5.  6.  3. ...  9.  3.  8.]
  [ 9.  4.  4. ...  7. 11.  1.]
  ...
  [24. 14.  6. ...  5.  0. 14.]
  [ 3.  9. 40. ...  6.  0. 11.]
  [ 7.  2. 25. ... 22.  3. 13.]]

 [[14. 16. 26. ...  1.  9.  9.]
  [28. 10.  2. ... 14.  6.  4.]
  [ 5.  0. 10. ... 22. 15. 40.]
  ...
  [ 0.  1. 12. ...  0.  9. 16.]
  [14.  0.  6. ...  7.  1. 13.]
  [14.  4.  3. ... 26.  6. 19.]]

 ...

 [[20.  2.  9. ... 15.  6.  6.]
  [15. 19.  2. ... 15.  7. 15.]
  [ 9.  3. 13. ... 14. 10. 20.]
  ...
  [ 4. 14. 21. ... 17.  2. 19.]
  [ 1.  9. 15. ...  4.  7.  5.]
  [18.  7

In [7]:
table2 = predictive_svi["UA"].numpy()
print(table2)

[[[[-0.28535977  0.5407753  -0.11483517 ...  0.75336343  1.2187921
    -1.168079  ]
   [ 0.27519417  0.6142986  -0.06168953 ... -0.5462328   1.0791954
    -1.0290871 ]
   [-2.520281    0.49413767  1.2901305  ...  1.2930528  -0.40262035
    -0.7137808 ]
   ...
   [ 1.1134045   2.2058878  -0.10623908 ...  2.0770118  -1.7445196
    -0.13941166]
   [ 0.7838261   1.2989461  -1.6670617  ...  1.0509782  -0.63851213
    -1.2264577 ]
   [-2.2912352  -0.24184676  1.016104   ...  0.6750027   0.00472934
    -0.5812001 ]]]


 [[[-0.90443933 -0.8370635  -0.6256679  ...  1.0380534   0.35315022
    -0.15691854]
   [-0.79911983  0.4583026   0.43753034 ... -2.6858876   0.54812634
     0.34935674]
   [ 1.8308849  -0.90615886 -0.8976115  ...  0.7617015   0.5277555
     0.78012705]
   ...
   [ 1.0805224   1.0855103   0.46282822 ... -0.17598367 -1.0061806
    -0.9500762 ]
   [-1.1788591  -1.5199696   0.0466184  ...  2.2012267   2.1054943
     1.5023819 ]
   [-0.6371386   0.69335437  0.45942116 ... -1.665205

In [14]:
class PM_test_normals(nn.Module):
    # by default our latent space is 50-dimensional
    # and we use 400 hidden units
    def __init__(self, train, dim):
        super().__init__()
        """Build the Probabilistic Matrix Factorization model using pymc3.



        """
        self.dim = dim   
        self.data = train.copy()
        self.n, self.m = self.data.shape
        self.map = None
        self.bounds = (0,1)
        self.losses = None
        self.predictions = None
        self.returned = None


        # Perform mean value imputation
    
        
        # Low precision reflects uncertainty; prevents overfitting.
        # Set to the mean variance across users and items.
        self.alpha_u = np.mean(self.data, axis=1).mean()
        self.alpha_v =  np.std(self.data, axis=1).mean()
        
        self.beta_u = np.mean(self.data, axis=0).mean() 
        self.beta_v =  np.std(self.data, axis=0).mean()
        self.bias = self.data.mean()


    def model(self, train):

        drug_plate = pyro.plate("drug_latents", self.n, dim= -1) #independent users
        sideeffect_plate = pyro.plate("sideeffect_latents", self.m, dim= -1) #independent items

        with drug_plate: 
            UA = pyro.sample("UA", dist.Normal(self.alpha_u, self.beta_u).expand([self.dim]).to_event(1))
            #UA_int = pyro.sample("UAint", dist.Normal(0., 1.))
        
        with sideeffect_plate:
            VA = pyro.sample("VA", dist.Normal(self.alpha_v, self.beta_v).expand([self.dim]).to_event(1))
            #possibly add intercepts VA_int = pyro.sample("VA", dist.Normal(0., 1.).to_event(1))
       
        u2_plate = pyro.plate("u2_plate", self.n, dim=-2)

        with sideeffect_plate, u2_plate: 
         
             Y = pyro.sample("target", dist.Poisson(torch.abs(UA@VA.T)), obs=train ) 
             return Y
        

    def guide(self, train=None, mask=None):

        d_alpha = pyro.param('d_alpha', torch.ones(self.n,self.dim), constraint=constraints.positive)#*self.user_mean)
        d_beta = pyro.param('d_beta', 0.5*torch.ones(self.n,self.dim), constraint=constraints.positive)
       # int_mean = pyro.param('int_mean', torch.tensor(1.)*self.user_mean)
       # mov_cov = pyro.param('mov_cov', torch.tensor(1.)*0.1,
          #                  constraint=constraints.positive)

        s_alpha = pyro.param('s_alpha', torch.ones(self.m,self.dim), constraint=constraints.positive)#*self.item_mean)
        s_beta = pyro.param('s_beta', 0.5*torch.ones(self.m,self.dim), constraint=constraints.positive)
        drug_plate = pyro.plate("drug_latents", self.n, dim= -1) #independent users
        sideeffect_plate = pyro.plate("sideeffect_latents", self.m, dim= -1) #independent items

        with drug_plate: 
            UA = pyro.sample("UA", dist.Normal(d_alpha, d_beta).to_event(1))
           # UA_int = pyro.sample("UAint", dist.Normal(int_mean, mov_cov).to_event(1))
        with sideeffect_plate: 
            VA = pyro.sample("VA", dist.Normal(s_alpha, s_beta).to_event(1))
    
    def train_SVI(self,train,nsteps=250, lr = 0.05, lrd = 1):
        logging.basicConfig(format='%(message)s', level=logging.INFO)
        svi = SVI(self.model,
        self.guide,
        optim.ClippedAdam({"lr": lr, "lrd": lrd}),
        loss=Trace_ELBO())
        losses = []
        for step in range(nsteps):
            elbo = svi.step(torch.from_numpy(train).float())
            losses.append(elbo)
            if step % 10 == 0:
                print("Elbo loss: {}".format(elbo))
        self.losses = losses
        #constrained_params = list(pyro.get_param_store().values())
        #PARAMS = [p.unconstrained() for p in constrained_params]
        #print(PARAMS)
        return losses
    
    def sample_predict(self, nsamples=500 , verbose=True):
    
        predictive_svi = Predictive(self.model, guide=self.guide, num_samples=nsamples)(None )
        if (verbose):
            for k, v in predictive_svi.items():
                print(f"{k}: {tuple(v.shape)}")
        table = predictive_svi["target"].numpy()
        print(predictive_svi["UA"].numpy())
        print(table)
        self.returned = table
        mc_table = table.mean(axis = 0)
        mc_table_std = table.std(axis = 0)
        mc_table[mc_table < self.bounds[1]] = self.bounds[0]
        mc_table[mc_table >= self.bounds[1]] = self.bounds[1]
        self.predictions = mc_table
        
    
    def rmse(self,test):
        low, high = self.bounds
        test_data = test.copy()
        test_data[test_data < high] = low
        test_data[test_data >= high] = high
        sqerror = abs(test_data - self.predictions) ** 2  # squared error array
        mse = sqerror.sum()/(test_data.shape[0]*test_data.shape[1])
        print("PMF MAP training RMSE: %.5f" % np.sqrt(mse))
        fpr, tpr, thresholds = metrics.roc_curve(test_data.astype(int).flatten(),  self.predictions.astype(int).flatten(), pos_label=1)
        metrics.auc(fpr, tpr)
        print("AUC: %.5f" % metrics.auc(fpr, tpr))
        return np.sqrt(mse) , metrics.auc(fpr, tpr)

    def get_predictions(self):
        return (self.returned,self.predictions)

    
   
       

In [15]:
test = PM_test_normals(train=data, dim=100)
test.train_SVI(data)


Elbo loss: 591432476.8125
Elbo loss: 239103504.328125
Elbo loss: 116336450.234375
Elbo loss: 71069838.8984375
Elbo loss: 54330414.9921875
Elbo loss: 48362075.265625
Elbo loss: 45127749.2421875
Elbo loss: 43086197.15625
Elbo loss: 41732170.1171875
Elbo loss: 40314587.0625
Elbo loss: 39100278.40234375
Elbo loss: 37828989.0625
Elbo loss: 36709630.87890625
Elbo loss: 35421977.07421875
Elbo loss: 34185446.689453125
Elbo loss: 32923683.490234375
Elbo loss: 31648543.49609375
Elbo loss: 30293576.508789062
Elbo loss: 28964532.186523438
Elbo loss: 27587156.774902344
Elbo loss: 26430231.645263672
Elbo loss: 25193279.890625
Elbo loss: 24073525.458007812
Elbo loss: 23086155.828125
Elbo loss: 22172239.625


[591432476.8125,
 538834471.9375,
 488149915.3125,
 445523456.9765625,
 405506890.4765625,
 370707173.2578125,
 339093177.0390625,
 309681960.4375,
 284761439.375,
 260882675.9921875,
 239103504.328125,
 220628402.0234375,
 203509699.3671875,
 188415758.09375,
 173408269.0859375,
 161949297.359375,
 151487045.09375,
 140488381.3671875,
 131264792.8125,
 123462614.265625,
 116336450.234375,
 109566729.765625,
 102879449.4140625,
 97921070.40625,
 92545796.765625,
 87987672.2734375,
 83922075.40625,
 79997450.0625,
 76707045.796875,
 73702940.015625,
 71069838.8984375,
 68147953.3828125,
 66099763.75,
 64343995.2890625,
 62471817.765625,
 60501447.90625,
 59014716.46875,
 57949754.40625,
 56490802.9140625,
 55351322.5625,
 54330414.9921875,
 53132290.3828125,
 52497240.875,
 51740164.78125,
 51044806.859375,
 50319237.0859375,
 49870891.40625,
 49192331.0859375,
 48854842.734375,
 48421415.71875,
 48362075.265625,
 47505831.703125,
 47250801.5,
 47188658.3671875,
 46754873.3515625,
 4653

In [16]:
test.sample_predict(500)


UA: (500, 1, 1127, 100)
VA: (500, 1, 5237, 100)
target: (500, 1127, 5237)
[[[[ 6.40348867e-02 -4.75348890e-01 -7.94213340e-02 ... -2.59275168e-01
    -1.20288119e-01 -1.15572318e-01]
   [ 1.50117487e-01  1.90185070e-01  1.92879945e-01 ... -8.30131471e-02
     1.13322705e-01 -1.53372899e-01]
   [-1.21509857e-01 -1.11562379e-01  1.96243718e-01 ...  1.07064366e-01
     8.40026513e-02 -3.93198356e-02]
   ...
   [-6.50865138e-02 -2.03364015e-01  1.97081864e-01 ...  2.16839805e-01
    -1.65045746e-02 -1.98119804e-01]
   [ 9.77682173e-02  9.53350216e-02 -5.48548065e-03 ...  1.16732687e-01
     2.93995813e-03  3.96088883e-02]
   [-6.41719177e-02  1.12862371e-01 -4.03884985e-02 ... -5.06552532e-02
     1.90641165e-01 -8.10519084e-02]]]


 [[[ 9.96897370e-03  1.70645386e-01 -4.58481945e-02 ...  1.15541600e-01
     3.68814558e-01 -1.71744451e-01]
   [ 3.07305992e-01 -9.03509706e-02 -6.04870729e-05 ...  2.50956416e-02
    -1.87679842e-01  9.78012979e-02]
   [ 5.49563289e-01  2.13760942e-01 -9.4285

In [17]:
test.rmse(data)
print(test.get_predictions())

PMF MAP training RMSE: 0.46792
AUC: 0.80388
(array([[[ 0.,  2.,  0., ...,  1.,  4.,  2.],
        [ 4.,  1.,  1., ...,  3.,  7.,  9.],
        [ 1.,  0.,  0., ...,  0.,  7.,  6.],
        ...,
        [ 7.,  0.,  0., ...,  2., 26.,  0.],
        [ 1.,  1.,  0., ...,  9., 20.,  5.],
        [ 3.,  0.,  0., ...,  1.,  1.,  0.]],

       [[ 5.,  1.,  0., ...,  1.,  0.,  1.],
        [ 2.,  1.,  0., ...,  1., 19.,  5.],
        [ 2.,  1.,  1., ...,  2.,  5.,  5.],
        ...,
        [12.,  1.,  0., ...,  4., 32.,  2.],
        [12.,  1.,  0., ...,  8.,  8.,  6.],
        [ 1.,  0.,  0., ...,  0.,  3.,  2.]],

       [[ 0.,  0.,  0., ...,  0.,  2.,  0.],
        [ 1.,  1.,  0., ...,  0.,  5.,  5.],
        [ 1.,  0.,  2., ...,  9., 10.,  0.],
        ...,
        [17.,  0.,  0., ..., 14., 20.,  3.],
        [ 6.,  0.,  1., ...,  4., 41.,  3.],
        [ 1.,  0.,  0., ...,  0.,  7.,  3.]],

       ...,

       [[ 2.,  1.,  0., ...,  2.,  3.,  2.],
        [ 0.,  1.,  0., ...,  1.,  5.,  0.

In [4]:

class PMF_clusters_try_LJK_prior_on_covariance(nn.Module):
    #bayesian non parametrics - dirichlet process

    #with multivariate gammas that are "somehow?" related through pyro's dependent dimension setting
    #how to define their covariance?
    def __init__(self, train, dim):
        super().__init__()
        """Build the Probabilistic Matrix Factorization model using pymc3.



        """
        self.dim = dim   
        self.data = train.copy()
        self.n, self.m = self.data.shape
        self.map = None
        self.bounds = (0,1)
        self.losses = None
        self.predictions = None


        # Perform mean value imputation
    
        
        # Low precision reflects uncertainty; prevents overfitting.
        # Set to the mean variance across users and items.
        self.alpha_u = (np.mean(self.data, axis=1).mean())**2 / np.std(self.data, axis=1).mean()
        self.alpha_v = (np.mean(self.data, axis=0).mean())**2 / np.std(self.data, axis=0).mean()

        self.beta_u = (np.mean(self.data, axis=1).mean()) / np.std(self.data, axis=1).mean()
        self.beta_v = (np.mean(self.data, axis=0).mean()) / np.std(self.data, axis=0).mean()
       
        self.bias = self.data.mean()
        self.num_clusters_drugs = 1000
        self.num_clusters_se =  5000
    def mix_weights(self,beta):
            beta1m_cumprod = (1 - beta).cumprod(-1)
            return F.pad(beta, (0, 1), value=1) * F.pad(beta1m_cumprod, (1, 0), value=1)

    def model(self, data):
            alpha = 0.1
            with pyro.plate("beta_drugs_plate", self.num_clusters_drugs-1):
                beta_drugs = pyro.sample("beta_drugs", dist.Beta(1, alpha))

            with pyro.plate("mu_drugs_plate", self.num_clusters_drugs):
                mu_drugs = pyro.sample("mu_drugs", dist.MultivariateNormal(torch.zeros(self.dim), 0.5 * torch.eye(self.dim)))
                theta = pyro.sample("theta", dist.HalfCauchy(torch.ones( self.dim)).to_event(1))
                # Lower cholesky factor of a correlation matrix
                concentration = torch.ones(
                    (),
                )  # Implies a uniform distribution over correlation matrices
                L_omega = pyro.sample("L_omega", dist.LKJCholesky( self.dim, concentration))
            # Lower cholesky factor of the covariance matrix
                
    # For inference with SVI, one might prefer to use torch.bmm(theta.sqrt().diag_embed(), L_omega)
            with pyro.plate("data_drugs",  self.n):
                z_d = pyro.sample("z_drugs", dist.Categorical( self.mix_weights(beta_drugs)))
                L_Omega = torch.bmm(theta[z_d].sqrt().diag_embed(), L_omega[z_d])# torch.mm(torch.diag(theta.sqrt()), L_omega)
                UA = pyro.sample("UA", dist.MultivariateNormal(mu_drugs[z_d], scale_tril=L_Omega))

            with pyro.plate("beta_se_plate", self.num_clusters_se-1):
                beta_se = pyro.sample("beta_se", dist.Beta(1, alpha))

            with pyro.plate("mu_plate_se", self.num_clusters_se):
                mu_se = pyro.sample("mu_se", dist.MultivariateNormal(torch.zeros(self.dim), 0.5 * torch.eye(self.dim)))

            with pyro.plate("data_sideeffects", self.m):
                z_se = pyro.sample("z_se", dist.Categorical(self.mix_weights(beta_se)))
                VA = pyro.sample("VA", dist.MultivariateNormal(mu_se[z_se], torch.eye(self.dim)))
            
            u2_plate = pyro.plate("u2_plate", self.n, dim=-2)
            se2_plate = pyro.plate("se2_plate", self.m, dim=-1)

            with se2_plate, u2_plate: 
                Y = pyro.sample("target", dist.Poisson(torch.abs(UA@VA.T)), obs=data ) 
                return Y

    def guide(self,data=None):
            kappa = pyro.param('kappa_d', lambda: dist.Uniform(0, 2).sample([self.num_clusters_drugs-1]), constraint=constraints.positive)
            tau = pyro.param('tau_d', lambda: dist.MultivariateNormal(torch.zeros(self.dim), 0.5 * torch.eye(self.dim)).sample([self.num_clusters_drugs]))
            phi = pyro.param('phi_d', lambda: dist.Dirichlet(1/self.num_clusters_drugs * torch.ones(self.num_clusters_drugs)).sample([self.n]), constraint=constraints.simplex)
            cov1 =  pyro.param('cov1', 10*torch.ones(self.num_clusters_drugs,self.dim), constraint=constraints.positive)

            with pyro.plate("beta_plate", self.num_clusters_drugs-1):
                beta_drugs = pyro.sample("beta_drugs", dist.Beta(torch.ones(self.num_clusters_drugs-1), kappa))

            with pyro.plate("mu_plate_drug", self.num_clusters_drugs):
                mu_drugs = pyro.sample("mu_drugs", dist.MultivariateNormal(tau, torch.eye(self.dim)))
                theta = pyro.sample("theta", dist.HalfCauchy(cov1).to_event(1))
            # Lower cholesky factor of a correlation matrix
                concentration = torch.ones(
                (),
            )  # Implies a uniform distribution over correlation matrices
                L_omega = pyro.sample("L_omega", dist.LKJCholesky( self.dim, concentration))
            with pyro.plate("data_drug", self.n):
                z_d = pyro.sample("z_drugs", dist.Categorical(phi))
                L_Omega = torch.bmm(theta[z_d].sqrt().diag_embed(), L_omega[z_d])

                UA = pyro.sample("UA", dist.MultivariateNormal(mu_drugs[z_d],scale_tril= L_Omega))

            
            kappa_s = pyro.param('kappa_s', lambda: dist.Uniform(0, 2).sample([self.num_clusters_se-1]), constraint=constraints.positive)
            tau_s = pyro.param('tau_s', lambda: dist.MultivariateNormal(torch.zeros(self.dim), 0.5 * torch.eye(self.dim)).sample([self.num_clusters_se]))
            phi_s = pyro.param('phi_s', lambda: dist.Dirichlet(1/self.num_clusters_se * torch.ones(self.num_clusters_se)).sample([self.m]), constraint=constraints.simplex)

            with pyro.plate("beta_se_plate", self.num_clusters_se-1):
                beta_se = pyro.sample("beta_se", dist.Beta(torch.ones(self.num_clusters_se-1), kappa_s))

            with pyro.plate("mu_plate_se", self.num_clusters_se):
                mu_se = pyro.sample("mu_se", dist.MultivariateNormal(tau_s, torch.eye(self.dim)))

            with pyro.plate("data_sideeffects", self.m):
                z_se = pyro.sample("z_se", dist.Categorical(phi_s))
                VA = pyro.sample("VA", dist.MultivariateNormal(mu_se[z_se], torch.eye(self.dim)))

    def train_SVI(self,train, nsteps=250, lr = 0.05, lrd = 1):
                logging.basicConfig(format='%(message)s', level=logging.INFO)
                svi = SVI(self.model,
                self.guide,
                optim.ClippedAdam({"lr": lr, "lrd": lrd}),
                loss=Trace_ELBO())
                losses = []
                for step in range(nsteps):
                    elbo = svi.step(torch.from_numpy(train).float())
                    losses.append(elbo)
                    if step % 10 == 0:
                        print("Elbo loss: {}".format(elbo))
                self.losses = losses
                #constrained_params = list(pyro.get_param_store().values())
                #PARAMS = [p.unconstrained() for p in constrained_params]
                #print(PARAMS)
                return losses

In [5]:
with open('data_all.pickle', 'rb') as handle:
    data = pickle.load(handle)

test = PMF_clusters2(train=data, dim=99)
test.train_SVI(data)



Elbo loss: 522898932.63165283
Elbo loss: 472267225.5957184
Elbo loss: 394533633.65270996
Elbo loss: 369998733.354126
Elbo loss: 336892462.5012207
Elbo loss: 311479740.1257324
Elbo loss: 287125368.5058594
Elbo loss: 270561637.4732666
Elbo loss: 255268516.51623535
Elbo loss: 254280572.35302734


KeyboardInterrupt: 

In [1]:

predictive_svi = Predictive(test.model, guide=test.guide, num_samples=700)(None )
for k, v in predictive_svi.items():
    print(f"{k}: {tuple(v.shape)}")
table = predictive_svi["target"].numpy()
print(table)

NameError: name 'Predictive' is not defined