In [1]:
import logging

import pickle

import torch
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from torch.distributions import constraints
from torch import nn
import pyro
import pyro.distributions as dist
import pyro.optim as optim
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
from pyro.infer import Predictive
import seaborn as sns
from pyro import poutine
from sklearn import metrics

In [2]:
pyro.set_rng_seed(10)

In [3]:

with open('data_all.pickle', 'rb') as handle:
    data = pickle.load(handle)
print(data.shape)

nan_mask = np.isnan(data) #when calculating the train/test set to "nan" all the examples that are for testing so that you do not train on them 
print(torch.from_numpy(nan_mask) )

(1127, 5237)
tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]])


In [11]:
class PMF_zero_NB(nn.Module):

    # by default our latent space is 50-dimensional
    # and we use 400 hidden units
    def __init__(self, train, dim):
        super().__init__()
        """Build the Probabilistic Matrix Factorization model using pymc3.



        """
        self.dim = dim   
        self.data = train.copy()
        self.n, self.m = self.data.shape
        self.map = None
        self.bounds = (0,1)
        self.losses = None
        self.predictions = None
        self.returned = None


        # Perform mean value imputation
    
        
        # Low precision reflects uncertainty; prevents overfitting.
        # Set to the mean variance across users and items.
        self.alpha_u = (np.mean(self.data, axis=1).mean())**2 / np.std(self.data, axis=1).mean()
        self.alpha_v = (np.mean(self.data, axis=0).mean())**2 / np.std(self.data, axis=0).mean()
        
        self.beta_u = (np.mean(self.data, axis=1).mean()) / np.std(self.data, axis=1).mean()
        self.beta_v = (np.mean(self.data, axis=0).mean()) / np.std(self.data, axis=0).mean()
        self.bias = self.data.mean()
        self.alpha = 1


    def model(self, train, mask):
        alpha = 1
        beta = 1

        drug_plate = pyro.plate("drug_latents", self.n, dim= -1) #independent users
        sideeffect_plate = pyro.plate("sideeffect_latents", self.m, dim= -1) #independent items

        with drug_plate: 
            UA = pyro.sample("UA", dist.Gamma(self.alpha_u, self.beta_u).expand([self.dim]).to_event(1))
            #alpha = pyro.sample("alpha", dist.Poisson(self.alpha))
            #tendacy of people not reporting side effects
            p = pyro.sample("p", dist.Beta(alpha, beta))
        
        with sideeffect_plate:
            VA = pyro.sample("VA", dist.Gamma(self.alpha_v, self.beta_v).expand([self.dim]).to_event(1))
            #possibly add intercepts VA_int = pyro.sample("VA", dist.Normal(0., 1.).to_event(1))
       
        u2_plate = pyro.plate("u2_plate", self.n, dim=-2)

        with sideeffect_plate, u2_plate: 
           # with pyro.poutine.mask(mask=mask):
             Y = pyro.sample("target", dist.ZeroInflatedDistribution( base_dist= dist.NegativeBinomial(alpha, UA@VA.T/( UA@VA.T+alpha)) ,gate = p[:, np.newaxis]), obs=train ) 
             return Y
        

    def guide(self, train=None, mask=None):

        d_alpha = pyro.param('d_alpha', torch.ones(self.n,self.dim), constraint=constraints.positive)#*self.user_mean)
        d_beta = pyro.param('d_beta', 0.5*torch.ones(self.n,self.dim), constraint=constraints.positive)
        rate_alpha = pyro.param('rate_alpha', torch.ones(self.n), constraint=constraints.positive)
        rate_beta = pyro.param('rate_beta', torch.ones(self.n), constraint=constraints.positive)


        s_alpha = pyro.param('s_alpha', torch.ones(self.m,self.dim), constraint=constraints.positive)#*self.item_mean)
        s_beta = pyro.param('s_beta', 0.5*torch.ones(self.m,self.dim), constraint=constraints.positive)
        drug_plate = pyro.plate("drug_latents", self.n, dim= -1) #independent users
        sideeffect_plate = pyro.plate("sideeffect_latents", self.m, dim= -1) #independent items

        with drug_plate: 
            UA = pyro.sample("UA", dist.Gamma(d_alpha, d_beta).to_event(1))
            p = pyro.sample("p", dist.Beta(rate_beta,rate_alpha))

        with sideeffect_plate: 
            VA = pyro.sample("VA", dist.Gamma(s_alpha, s_beta).to_event(1))
    
    def train_SVI(self,train,mask, nsteps=250, lr = 0.05, lrd = 1):
        logging.basicConfig(format='%(message)s', level=logging.INFO)
        svi = SVI(self.model,
        self.guide,
        optim.ClippedAdam({"lr": lr, "lrd": lrd}),
        loss=Trace_ELBO())
        losses = []
        for step in range(nsteps):
            elbo = svi.step(torch.from_numpy(train).float(), mask)
            losses.append(elbo)
            if step % 10 == 0:
                print("Elbo loss: {}".format(elbo))
        self.losses = losses
        #constrained_params = list(pyro.get_param_store().values())
        #PARAMS = [p.unconstrained() for p in constrained_params]
        #print(PARAMS)
        return losses
    
    def sample_predict(self, nsamples=500 , verbose=True):
        unmasked =torch.ones((self.n,self.m), dtype=torch.bool)
        predictive_svi = Predictive(self.model, guide=self.guide, num_samples=nsamples)(None , unmasked)
        if (verbose):
            for k, v in predictive_svi.items():
                print(f"{k}: {tuple(v.shape)}")
        table = predictive_svi["target"].numpy()
        print(table)
        self.returned = table
        mc_table = table.mean(axis = 0)
        mc_table_std = table.std(axis = 0)
        mc_table[mc_table < self.bounds[1]] = self.bounds[0]
        mc_table[mc_table >= self.bounds[1]] = self.bounds[1]
        self.predictions = mc_table
        
    
    def rmse(self,test):
        low, high = self.bounds
        test_data = test.copy()
        test_data[test_data < high] = low
        test_data[test_data >= high] = high
        sqerror = abs(test_data - self.predictions) ** 2  # squared error array
        mse = sqerror.sum()/(test_data.shape[0]*test_data.shape[1])
        print("PMF MAP training RMSE: %.5f" % np.sqrt(mse))
        fpr, tpr, thresholds = metrics.roc_curve(test_data.astype(int).flatten(),  self.predictions.astype(int).flatten(), pos_label=1)
        metrics.auc(fpr, tpr)
        print("AUC: %.5f" % metrics.auc(fpr, tpr))
        return np.sqrt(mse) , metrics.auc(fpr, tpr)

    def get_predictions(self):
        return (self.returned,self.predictions)

    

In [15]:
nan_mask = np.isnan(data) #when calculating the train/test set to "nan" all the examples that are for testing so that you do not train on them 
print(torch.from_numpy(nan_mask) )
test = PMF_zero_NB(train=data, dim=100)
test.train_SVI(data, ~torch.from_numpy(nan_mask))


tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]])
Elbo loss: 6012035.456176758
Elbo loss: 6026213.563171387
Elbo loss: 5999806.37689209
Elbo loss: 6007662.76953125
Elbo loss: 5995980.518127441
Elbo loss: 5969965.423828125
Elbo loss: 6014862.355224609
Elbo loss: 5978878.61340332
Elbo loss: 5958227.224975586
Elbo loss: 5949028.366088867
Elbo loss: 5967524.559936523
Elbo loss: 5937743.723571777
Elbo loss: 5964058.8966674805
Elbo loss: 5960032.478393555
Elbo loss: 5957602.663635254
Elbo loss: 5935067.991943359
Elbo loss: 5938172.410766602
Elbo loss: 5915344.634277344
Elbo loss: 5933078.581726074
Elbo loss: 5929822.35534668
Elbo loss: 5946331.634033203
Elbo loss: 5940296.6384887695
El

[6012035.456176758,
 6012342.945617676,
 6004319.902587891,
 6017493.768127441,
 6034775.2158203125,
 6007978.318664551,
 6030311.8955078125,
 5986345.342956543,
 6015944.429016113,
 6018410.323974609,
 6026213.563171387,
 6028207.652770996,
 6013876.738891602,
 6020635.438964844,
 6039289.8212890625,
 5998328.465454102,
 6028851.170043945,
 5993492.186157227,
 6026789.848510742,
 6003249.84185791,
 5999806.37689209,
 6022679.270263672,
 6039606.224121094,
 5974590.749328613,
 5992812.397521973,
 6003235.42980957,
 6026491.765380859,
 6026796.878845215,
 6000504.59576416,
 5976955.481811523,
 6007662.76953125,
 5970413.448547363,
 5988363.527587891,
 5985468.2060546875,
 5962649.883850098,
 6008944.888000488,
 6005564.186889648,
 5974745.263000488,
 5984125.41796875,
 6015425.143432617,
 5995980.518127441,
 5986274.057678223,
 5970616.791931152,
 5985542.175231934,
 5971055.80456543,
 5971207.098693848,
 5976890.8572387695,
 5988524.072509766,
 6006659.8359375,
 5955468.06842041,
 5969

In [16]:
test.sample_predict(1000)

UA: (1000, 1, 1127, 100)
p: (1000, 1, 1127)
VA: (1000, 1, 5237, 100)
target: (1000, 1127, 5237)
[[[ 0.  0.  0. ...  0.  3.  0.]
  [ 0.  0.  0. ...  0.  0.  0.]
  [ 0.  0.  0. ...  0.  0.  0.]
  ...
  [25.  0.  1. ...  5. 14.  1.]
  [ 0.  0.  0. ...  1.  6.  1.]
  [ 0.  0.  0. ...  0.  0.  0.]]

 [[ 0.  1.  0. ...  0.  2.  0.]
  [ 0.  0.  0. ...  0.  0.  0.]
  [ 2.  0.  0. ...  0.  0.  0.]
  ...
  [ 3.  0.  0. ... 13.  7.  0.]
  [ 0.  0.  0. ...  1.  3.  0.]
  [ 0.  0.  0. ...  0.  0.  0.]]

 [[ 2.  1.  0. ...  0.  0.  0.]
  [ 0.  0.  0. ...  0.  0.  0.]
  [ 0.  1.  0. ...  7.  0.  0.]
  ...
  [ 0.  2.  0. ... 10. 49.  0.]
  [12.  0.  0. ...  3. 39.  0.]
  [ 0.  0.  0. ...  0.  0.  0.]]

 ...

 [[ 0.  0.  0. ...  0.  0.  0.]
  [ 0.  0.  0. ...  0.  0.  0.]
  [ 2.  1.  1. ...  2.  7.  0.]
  ...
  [ 0.  0.  0. ... 10.  0.  0.]
  [ 2.  0.  0. ...  0. 13.  0.]
  [ 0.  0.  0. ...  0.  0.  0.]]

 [[ 0.  0.  0. ...  0.  0.  1.]
  [ 0.  0.  0. ...  0.  0.  0.]
  [ 0.  0.  0. ...  0.  0.  1.]
  

In [17]:
test.rmse(data)
print(test.get_predictions())
print(data)

PMF MAP training RMSE: 0.34168
AUC: 0.83550
(array([[[ 0.,  0.,  0., ...,  0.,  3.,  0.],
        [ 0.,  0.,  0., ...,  0.,  0.,  0.],
        [ 0.,  0.,  0., ...,  0.,  0.,  0.],
        ...,
        [25.,  0.,  1., ...,  5., 14.,  1.],
        [ 0.,  0.,  0., ...,  1.,  6.,  1.],
        [ 0.,  0.,  0., ...,  0.,  0.,  0.]],

       [[ 0.,  1.,  0., ...,  0.,  2.,  0.],
        [ 0.,  0.,  0., ...,  0.,  0.,  0.],
        [ 2.,  0.,  0., ...,  0.,  0.,  0.],
        ...,
        [ 3.,  0.,  0., ..., 13.,  7.,  0.],
        [ 0.,  0.,  0., ...,  1.,  3.,  0.],
        [ 0.,  0.,  0., ...,  0.,  0.,  0.]],

       [[ 2.,  1.,  0., ...,  0.,  0.,  0.],
        [ 0.,  0.,  0., ...,  0.,  0.,  0.],
        [ 0.,  1.,  0., ...,  7.,  0.,  0.],
        ...,
        [ 0.,  2.,  0., ..., 10., 49.,  0.],
        [12.,  0.,  0., ...,  3., 39.,  0.],
        [ 0.,  0.,  0., ...,  0.,  0.,  0.]],

       ...,

       [[ 0.,  0.,  0., ...,  0.,  0.,  0.],
        [ 0.,  0.,  0., ...,  0.,  0.,  0.