In [1]:
import logging

import pickle

import torch
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from torch.distributions import constraints
from torch import nn
import pyro
import pyro.distributions as dist
import pyro.optim as optim
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
from pyro.infer import Predictive
import seaborn as sns
import torch.nn.functional as F
from pyro import poutine
from sklearn import metrics
from pyro.infer.autoguide import AutoMultivariateNormal, AutoLowRankMultivariateNormal, init_to_mean,init_to_feasible,AutoNormal

In [2]:
# Fix the random seed through Pyro before anything stochastic runs.
pyro.set_rng_seed(10)

# Read the pickled matrix that every later cell refers to as `data`.
# NOTE(review): pickle.load executes arbitrary code from the file; only use a trusted 'data_all.pickle'.
with open('data_all.pickle', 'rb') as pickle_file:
    data = pickle.load(pickle_file)
print(data.shape)

(1127, 5237)


In [3]:

class PMF_clusters(nn.Module):
    """Poisson matrix factorisation with clustered latent factors, written in Pyro.

    Rows ("drugs") and columns ("side effects") of the data matrix each get a
    ``dim``-dimensional latent vector. Every latent vector is drawn around the
    mean of a cluster chosen through truncated stick-breaking weights
    (Dirichlet-process style, see ``mix_weights``). The observed matrix is
    modelled as Poisson with rate ``|UA @ VA^T|``.

    Open question kept from the original notes: how to relate the factors
    through a multivariate prior with a non-trivial covariance.
    """

    def __init__(self, train, dim):
        """Store the data and derive summary statistics.

        Args:
            train: 2-D array exposing ``copy()``, ``shape`` and ``mean()``
                (a NumPy array in this notebook); a copy is kept in ``self.data``.
            dim: Number of latent dimensions per row / column.
        """
        super().__init__()
        self.dim = dim
        self.data = train.copy()
        self.n, self.m = self.data.shape
        self.map = None
        self.bounds = (0, 1)
        self.losses = None
        self.predictions = None

        # Moment-style summaries of the data (mean^2 / std and mean / std, taken
        # over per-row and per-column statistics). They are stored as attributes
        # but are not read by `model` or `guide` in this class.
        row_mean = np.mean(self.data, axis=1).mean()
        row_std = np.std(self.data, axis=1).mean()
        col_mean = np.mean(self.data, axis=0).mean()
        col_std = np.std(self.data, axis=0).mean()

        self.alpha_u = row_mean ** 2 / row_std
        self.alpha_v = col_mean ** 2 / col_std

        self.beta_u = row_mean / row_std
        self.beta_v = col_mean / col_std

        self.bias = self.data.mean()
        # Truncation levels of the two stick-breaking mixtures.
        self.num_clusters_drugs = 5
        self.num_clusters_se = 10

    def mix_weights(self, beta):
        """Turn K-1 stick-breaking fractions into K mixture weights.

        Args:
            beta: Tensor whose last dimension holds the K-1 break fractions.

        Returns:
            Tensor with last dimension K: ``beta_k * prod_{j<k} (1 - beta_j)``,
            where the final entry takes whatever stick length remains.
        """
        beta1m_cumprod = (1 - beta).cumprod(-1)
        return F.pad(beta, (0, 1), value=1) * F.pad(beta1m_cumprod, (1, 0), value=1)

    def model(self, data):
        """Generative model.

        Args:
            data: Observed matrix passed as ``obs`` to the "target" site, or
                ``None`` to sample "target" instead.

        Returns:
            The value of the "target" sample site.
        """
        alpha = 0.1

        # --- row ("drug") side -------------------------------------------
        with pyro.plate("beta_drugs_plate", self.num_clusters_drugs - 1):
            beta_drugs = pyro.sample("beta_drugs", dist.Beta(1, alpha))

        with pyro.plate("mu_drugs_plate", self.num_clusters_drugs):
            mu_drugs = pyro.sample(
                "mu_drugs",
                dist.MultivariateNormal(torch.zeros(self.dim), 0.5 * torch.eye(self.dim)),
            )

        with pyro.plate("data_drugs", self.n):
            z_d = pyro.sample("z_drugs", dist.Categorical(self.mix_weights(beta_drugs)))
            UA = pyro.sample("UA", dist.MultivariateNormal(mu_drugs[z_d], torch.eye(self.dim)))

        # --- column ("side effect") side ---------------------------------
        with pyro.plate("beta_se_plate", self.num_clusters_se - 1):
            beta_se = pyro.sample("beta_se", dist.Beta(1, alpha))

        with pyro.plate("mu_plate_se", self.num_clusters_se):
            mu_se = pyro.sample(
                "mu_se",
                dist.MultivariateNormal(torch.zeros(self.dim), 0.5 * torch.eye(self.dim)),
            )

        with pyro.plate("data_sideeffects", self.m):
            z_se = pyro.sample("z_se", dist.Categorical(self.mix_weights(beta_se)))
            VA = pyro.sample("VA", dist.MultivariateNormal(mu_se[z_se], torch.eye(self.dim)))

        # --- likelihood ----------------------------------------------------
        u2_plate = pyro.plate("u2_plate", self.n, dim=-2)
        se2_plate = pyro.plate("se2_plate", self.m, dim=-1)

        with se2_plate, u2_plate:
            # abs() keeps the Poisson rate non-negative; transpose(-1, -2) is the
            # explicit form of `.T` for the (m, dim) factor matrix.
            rate = torch.abs(UA @ VA.transpose(-1, -2))
            Y = pyro.sample("target", dist.Poisson(rate), obs=data)
            return Y

    def guide(self, data=None):
        """Variational guide for ``model``.

        Learnable parameters are ``kappa_*`` (second Beta parameter of the
        stick fractions), ``tau_*`` (cluster-mean locations) and ``phi_*``
        (per-row / per-column cluster assignment probabilities). ``UA`` and
        ``VA`` are drawn around the sampled cluster mean with identity
        covariance and have no parameters of their own.

        Plate names match the ones used in ``model``.

        Args:
            data: Unused; accepted so the signature matches ``model``.
        """
        # --- row ("drug") side -------------------------------------------
        kappa = pyro.param(
            'kappa_d',
            lambda: dist.Uniform(0, 2).sample([self.num_clusters_drugs - 1]),
            constraint=constraints.positive,
        )
        tau = pyro.param(
            'tau_d',
            lambda: dist.MultivariateNormal(
                torch.zeros(self.dim), 0.5 * torch.eye(self.dim)
            ).sample([self.num_clusters_drugs]),
        )
        phi = pyro.param(
            'phi_d',
            lambda: dist.Dirichlet(
                1 / self.num_clusters_drugs * torch.ones(self.num_clusters_drugs)
            ).sample([self.n]),
            constraint=constraints.simplex,
        )

        with pyro.plate("beta_drugs_plate", self.num_clusters_drugs - 1):
            beta_drugs = pyro.sample(
                "beta_drugs", dist.Beta(torch.ones(self.num_clusters_drugs - 1), kappa)
            )

        with pyro.plate("mu_drugs_plate", self.num_clusters_drugs):
            mu_drugs = pyro.sample("mu_drugs", dist.MultivariateNormal(tau, torch.eye(self.dim)))

        with pyro.plate("data_drugs", self.n):
            z_d = pyro.sample("z_drugs", dist.Categorical(phi))
            UA = pyro.sample("UA", dist.MultivariateNormal(mu_drugs[z_d], torch.eye(self.dim)))

        # --- column ("side effect") side ---------------------------------
        kappa_s = pyro.param(
            'kappa_s',
            lambda: dist.Uniform(0, 2).sample([self.num_clusters_se - 1]),
            constraint=constraints.positive,
        )
        tau_s = pyro.param(
            'tau_s',
            lambda: dist.MultivariateNormal(
                torch.zeros(self.dim), 0.5 * torch.eye(self.dim)
            ).sample([self.num_clusters_se]),
        )
        phi_s = pyro.param(
            'phi_s',
            lambda: dist.Dirichlet(
                1 / self.num_clusters_se * torch.ones(self.num_clusters_se)
            ).sample([self.m]),
            constraint=constraints.simplex,
        )

        with pyro.plate("beta_se_plate", self.num_clusters_se - 1):
            beta_se = pyro.sample(
                "beta_se", dist.Beta(torch.ones(self.num_clusters_se - 1), kappa_s)
            )

        with pyro.plate("mu_plate_se", self.num_clusters_se):
            mu_se = pyro.sample("mu_se", dist.MultivariateNormal(tau_s, torch.eye(self.dim)))

        with pyro.plate("data_sideeffects", self.m):
            z_se = pyro.sample("z_se", dist.Categorical(phi_s))
            VA = pyro.sample("VA", dist.MultivariateNormal(mu_se[z_se], torch.eye(self.dim)))

    def train_SVI(self, train, nsteps=250, lr = 0.05, lrd = 1):
        """Fit the guide with stochastic variational inference.

        Args:
            train: Observed matrix (NumPy array or tensor); converted to a
                float tensor once before the optimisation loop.
            nsteps: Number of SVI steps.
            lr: Learning rate passed to ``ClippedAdam``.
            lrd: Learning-rate decay passed to ``ClippedAdam``.

        Returns:
            List with the loss returned by ``svi.step`` at every step; the same
            list is stored in ``self.losses``. Every 10th loss is printed.
        """
        logging.basicConfig(format='%(message)s', level=logging.INFO)
        svi = SVI(
            self.model,
            self.guide,
            optim.ClippedAdam({"lr": lr, "lrd": lrd}),
            loss=Trace_ELBO(),
        )
        # The observations never change between steps, so convert them once.
        observed = torch.as_tensor(train).float()
        losses = []
        for step in range(nsteps):
            elbo = svi.step(observed)
            losses.append(elbo)
            if step % 10 == 0:
                print("Elbo loss: {}".format(elbo))
        self.losses = losses
        return losses

In [4]:
# `data` is already loaded from 'data_all.pickle' by the earlier data-loading cell,
# so it is reused here instead of being unpickled a second time.
test = PMF_clusters(train=data, dim=100)
test.train_SVI(data)



Elbo loss: 170333202.72097492
Elbo loss: 171507606.7302152
Elbo loss: 165512936.2486418
Elbo loss: 169319632.4801597
Elbo loss: 162791561.27348012
Elbo loss: 164676929.1861047
Elbo loss: 164732843.07234585
Elbo loss: 162971541.6796273
Elbo loss: 164328151.93174267
Elbo loss: 159778966.43127453
Elbo loss: 161964433.89461613
Elbo loss: 158638836.18453467
Elbo loss: 161577749.12965137
Elbo loss: 162302984.18946004
Elbo loss: 158828284.97283804
Elbo loss: 163584730.48053694
Elbo loss: 163630120.66321015
Elbo loss: 158342146.59430647
Elbo loss: 166061768.09017003
Elbo loss: 162117679.09715545
Elbo loss: 160387243.80076563
Elbo loss: 164987436.0388283
Elbo loss: 157676786.8433683
Elbo loss: 160296463.0647326
Elbo loss: 163174325.3161726


[170333202.72097492,
 162292778.97869027,
 174552258.24824154,
 167429912.5915593,
 175191001.1995377,
 173403596.81074893,
 172334883.4760805,
 160031743.01762122,
 168377044.31872618,
 169331772.11819598,
 171507606.7302152,
 165704237.78437135,
 169407054.28472775,
 168554952.51679182,
 165937697.5300961,
 176425321.42185506,
 169783934.69481742,
 174450578.8925225,
 170786236.03162226,
 166968625.76068553,
 165512936.2486418,
 171885907.66150492,
 163180210.55908933,
 165949605.9802571,
 164210752.69084615,
 166332874.15098742,
 169706746.3094123,
 169146982.28950354,
 171235492.6609692,
 168086929.63003618,
 169319632.4801597,
 165892642.65986976,
 164604346.3941257,
 171689093.3460443,
 164034149.39280748,
 168924616.91286182,
 173896953.26908392,
 161910000.63772595,
 168543931.66266608,
 164839965.51895392,
 162791561.27348012,
 167677793.66685736,
 163397423.86693367,
 164430262.54720592,
 164264145.5830592,
 164890951.30644602,
 163418107.1546222,
 163779939.88256,
 165549376

In [5]:

# Draw 700 samples from the fitted guide and push them through the model.
# Passing None as the data means the "target" site is sampled, not observed.
posterior_sampler = Predictive(test.model, guide=test.guide, num_samples=700)
predictive_svi = posterior_sampler(None)

# Report the shape of every returned sample site.
for site_name, site_value in predictive_svi.items():
    print(f"{site_name}: {tuple(site_value.shape)}")

table = predictive_svi["target"].numpy()
print(table)


beta_drugs: (700, 1, 4)
mu_drugs: (700, 1, 5, 100)
z_drugs: (700, 1, 1127)
UA: (700, 1, 1127, 100)
beta_se: (700, 1, 9)
mu_se: (700, 1, 10, 100)
z_se: (700, 1, 5237)
VA: (700, 1, 5237, 100)
target: (700, 1127, 5237)
[[[24. 14. 38. ...  4. 35. 27.]
  [20. 27.  8. ... 10. 14.  9.]
  [ 2.  3. 11. ...  0.  8. 16.]
  ...
  [ 1.  1.  4. ... 11.  4. 12.]
  [35. 22. 30. ... 38. 15. 12.]
  [26. 27.  0. ... 41. 13. 12.]]

 [[28. 44. 14. ... 13. 23. 24.]
  [29.  2.  3. ... 16. 13. 59.]
  [21.  2. 22. ... 14. 36.  3.]
  ...
  [13.  3. 26. ... 28.  4. 34.]
  [ 8. 35. 34. ... 26. 24. 19.]
  [28.  0. 69. ...  2. 14. 22.]]

 [[ 5. 12.  6. ...  8. 27.  4.]
  [26. 52.  0. ...  9. 58. 30.]
  [26.  3.  5. ...  5.  8. 29.]
  ...
  [ 0. 13.  2. ... 10.  0. 35.]
  [ 1. 11.  5. ...  4.  7. 43.]
  [20.  2. 11. ... 15. 23. 23.]]

 ...

 [[ 4. 34. 58. ... 47.  7. 30.]
  [ 6.  3. 25. ... 44. 49.  0.]
  [ 9.  1.  2. ... 45. 15. 43.]
  ...
  [ 5.  7.  8. ...  4. 15. 22.]
  [15.  5. 43. ... 24.  6.  8.]
  [ 0. 24. 1

In [14]:
def model(self, data):
    """Module-level copy of the clustered Poisson factorisation model.

    The statements mirror ``PMF_clusters.model``. Because this is a plain
    function, ``self`` is an ordinary first argument that must provide
    ``dim``, ``n``, ``m``, ``num_clusters_drugs``, ``num_clusters_se`` and
    ``mix_weights``. ``data`` is passed as ``obs`` to the "target" site.

    Returns the value of the "target" sample site.
    """
    concentration = 0.1
    n_drug_clusters = self.num_clusters_drugs
    n_se_clusters = self.num_clusters_se

    # Row ("drug") side: stick fractions, cluster means, then per-row factors.
    with pyro.plate("beta_drugs_plate", n_drug_clusters - 1):
        stick_drugs = pyro.sample("beta_drugs", dist.Beta(1, concentration))

    with pyro.plate("mu_drugs_plate", n_drug_clusters):
        drug_mean_prior = dist.MultivariateNormal(torch.zeros(self.dim), 0.5 * torch.eye(self.dim))
        centers_drugs = pyro.sample("mu_drugs", drug_mean_prior)

    with pyro.plate("data_drugs", self.n):
        drug_weights = self.mix_weights(stick_drugs)
        drug_cluster = pyro.sample("z_drugs", dist.Categorical(drug_weights))
        drug_factors = pyro.sample(
            "UA", dist.MultivariateNormal(centers_drugs[drug_cluster], torch.eye(self.dim))
        )

    # Column ("side effect") side, same construction.
    with pyro.plate("beta_se_plate", n_se_clusters - 1):
        stick_se = pyro.sample("beta_se", dist.Beta(1, concentration))

    with pyro.plate("mu_plate_se", n_se_clusters):
        se_mean_prior = dist.MultivariateNormal(torch.zeros(self.dim), 0.5 * torch.eye(self.dim))
        centers_se = pyro.sample("mu_se", se_mean_prior)

    with pyro.plate("data_sideeffects", self.m):
        se_weights = self.mix_weights(stick_se)
        se_cluster = pyro.sample("z_se", dist.Categorical(se_weights))
        se_factors = pyro.sample(
            "VA", dist.MultivariateNormal(centers_se[se_cluster], torch.eye(self.dim))
        )

    # Likelihood over the full n x m grid.
    u2_plate = pyro.plate("u2_plate", self.n, dim=-2)
    se2_plate = pyro.plate("se2_plate", self.m, dim=-1)

    with se2_plate, u2_plate:
        poisson_rate = torch.abs(drug_factors @ se_factors.T)
        return pyro.sample("target", dist.Poisson(poisson_rate), obs=data)

[[ 1  0  0 ...  1  0  0]
 [ 0  0  0 ...  0  0  0]
 [ 1  0  0 ...  1  8  0]
 ...
 [ 8  0  0 ... 10 12  0]
 [ 1  0  0 ...  4 25  0]
 [ 0  0  0 ...  0  0  0]]


In [7]:
class PMF(nn.Module):
    """Poisson matrix factorisation with independent Normal factor priors (Pyro).

    Each row and each column of the data matrix gets a ``dim``-dimensional
    latent vector; the observed matrix is modelled as Poisson with rate
    ``|UA @ VA^T|``. The guide is a fully factorised Normal over ``UA`` and
    ``VA``.
    """

    def __init__(self, train, dim):
        """Store the data and derive the prior hyper-parameters.

        Args:
            train: 2-D array exposing ``copy()``, ``shape`` and ``mean()``
                (a NumPy array in this notebook); a copy is kept in ``self.data``.
            dim: Number of latent dimensions per row / column.
        """
        super().__init__()
        self.dim = dim
        self.data = train.copy()
        self.n, self.m = self.data.shape
        self.map = None
        # (low, high): values below `high` are mapped to `low`, the rest to `high`.
        self.bounds = (0, 1)
        self.losses = None
        self.predictions = None
        self.returned = None

        # Prior hyper-parameters taken from the data. `model` pairs them as
        # Normal(alpha_u, beta_u) for UA and Normal(alpha_v, beta_v) for VA:
        #   alpha_u = mean of the per-row means       -> UA prior loc
        #   beta_u  = mean of the per-column means    -> UA prior scale
        #   alpha_v = mean of the per-row std devs    -> VA prior loc
        #   beta_v  = mean of the per-column std devs -> VA prior scale
        # NOTE(review): a mean is used as a scale and a std as a loc, which looks
        # like a mix-up. Left unchanged because fixing it alters the fitted model.
        self.alpha_u = np.mean(self.data, axis=1).mean()
        self.alpha_v = np.std(self.data, axis=1).mean()

        self.beta_u = np.mean(self.data, axis=0).mean()
        self.beta_v = np.std(self.data, axis=0).mean()
        self.bias = self.data.mean()

    def model(self, train):
        """Generative model.

        Args:
            train: Observed matrix passed as ``obs`` to the "target" site, or
                ``None`` to sample "target" instead.

        Returns:
            The value of the "target" sample site.
        """
        drug_plate = pyro.plate("drug_latents", self.n, dim=-1)  # independent rows
        sideeffect_plate = pyro.plate("sideeffect_latents", self.m, dim=-1)  # independent columns

        with drug_plate:
            UA = pyro.sample(
                "UA", dist.Normal(self.alpha_u, self.beta_u).expand([self.dim]).to_event(1)
            )

        with sideeffect_plate:
            VA = pyro.sample(
                "VA", dist.Normal(self.alpha_v, self.beta_v).expand([self.dim]).to_event(1)
            )

        u2_plate = pyro.plate("u2_plate", self.n, dim=-2)

        with sideeffect_plate, u2_plate:
            # abs() keeps the Poisson rate non-negative; transpose(-1, -2) is the
            # explicit form of `.T` for the (m, dim) factor matrix.
            rate = torch.abs(UA @ VA.transpose(-1, -2))
            Y = pyro.sample("target", dist.Poisson(rate), obs=train)
            return Y

    def guide(self, train=None, mask=None):
        """Mean-field Normal guide over ``UA`` and ``VA``.

        Args:
            train: Unused; accepted so the signature matches ``model``.
            mask: Unused.
        """
        # Both the loc (`*_alpha`) and the scale (`*_beta`) parameters are
        # constrained to be positive.
        d_alpha = pyro.param('d_alpha', torch.ones(self.n, self.dim), constraint=constraints.positive)
        d_beta = pyro.param('d_beta', 0.5 * torch.ones(self.n, self.dim), constraint=constraints.positive)

        s_alpha = pyro.param('s_alpha', torch.ones(self.m, self.dim), constraint=constraints.positive)
        s_beta = pyro.param('s_beta', 0.5 * torch.ones(self.m, self.dim), constraint=constraints.positive)

        drug_plate = pyro.plate("drug_latents", self.n, dim=-1)  # independent rows
        sideeffect_plate = pyro.plate("sideeffect_latents", self.m, dim=-1)  # independent columns

        with drug_plate:
            UA = pyro.sample("UA", dist.Normal(d_alpha, d_beta).to_event(1))
        with sideeffect_plate:
            VA = pyro.sample("VA", dist.Normal(s_alpha, s_beta).to_event(1))

    def train_SVI(self, train, nsteps=250, lr = 0.05, lrd = 1):
        """Fit the guide with stochastic variational inference.

        Args:
            train: Observed matrix (NumPy array or tensor); converted to a
                float tensor once before the optimisation loop.
            nsteps: Number of SVI steps.
            lr: Learning rate passed to ``ClippedAdam``.
            lrd: Learning-rate decay passed to ``ClippedAdam``.

        Returns:
            List with the loss returned by ``svi.step`` at every step; the same
            list is stored in ``self.losses``. Every 10th loss is printed.
        """
        logging.basicConfig(format='%(message)s', level=logging.INFO)
        svi = SVI(
            self.model,
            self.guide,
            optim.ClippedAdam({"lr": lr, "lrd": lrd}),
            loss=Trace_ELBO(),
        )
        # The observations never change between steps, so convert them once.
        observed = torch.as_tensor(train).float()
        losses = []
        for step in range(nsteps):
            elbo = svi.step(observed)
            losses.append(elbo)
            if step % 10 == 0:
                print("Elbo loss: {}".format(elbo))
        self.losses = losses
        return losses

    def sample_predict(self, nsamples=500 , verbose=True):
        """Sample "target" from the fitted guide and store thresholded predictions.

        The raw samples are stored in ``self.returned``. Their mean over the
        sample axis is thresholded with ``self.bounds`` (below ``high`` becomes
        ``low``, everything else becomes ``high``) and stored in
        ``self.predictions``.

        Args:
            nsamples: Number of predictive samples to draw.
            verbose: When true, print the shape of every sample site and the
                raw "target" samples.
        """
        predictive_svi = Predictive(self.model, guide=self.guide, num_samples=nsamples)(None)
        table = predictive_svi["target"].numpy()
        if verbose:
            for k, v in predictive_svi.items():
                print(f"{k}: {tuple(v.shape)}")
            print(table)
        self.returned = table

        low, high = self.bounds
        mc_table = table.mean(axis=0)
        mc_table[mc_table < high] = low
        mc_table[mc_table >= high] = high
        self.predictions = mc_table

    def rmse(self, test):
        """Compare the stored predictions with a thresholded copy of ``test``.

        Args:
            test: Array with the same shape as ``self.predictions``; a copy is
                thresholded with ``self.bounds`` before the comparison.

        Returns:
            Tuple ``(rmse, auc)``. Both values are also printed.

        Raises:
            RuntimeError: If ``sample_predict`` has not been called yet.
        """
        if self.predictions is None:
            raise RuntimeError("No predictions available: call sample_predict() before rmse().")

        low, high = self.bounds
        test_data = test.copy()
        test_data[test_data < high] = low
        test_data[test_data >= high] = high

        sqerror = abs(test_data - self.predictions) ** 2  # squared error array
        mse = sqerror.sum() / (test_data.shape[0] * test_data.shape[1])
        rmse_value = np.sqrt(mse)
        print("PMF MAP training RMSE: %.5f" % rmse_value)

        # NOTE(review): the scores given to roc_curve are the already thresholded
        # predictions, so the curve has a single operating point.
        fpr, tpr, thresholds = metrics.roc_curve(
            test_data.astype(int).flatten(),
            self.predictions.astype(int).flatten(),
            pos_label=1,
        )
        auc_value = metrics.auc(fpr, tpr)
        print("AUC: %.5f" % auc_value)
        return rmse_value, auc_value

    def get_predictions(self):
        """Return ``(raw_samples, thresholded_predictions)`` as stored by ``sample_predict``."""
        return (self.returned, self.predictions)

    
   
       

In [14]:
# Fit the non-clustered factorisation with 100 latent dimensions on the full matrix.
# The cell's last expression is the list of per-step losses.
test = PMF(dim=100, train=data)
test.train_SVI(train=data)


Elbo loss: 11992020.8125
Elbo loss: 12044234.3125
Elbo loss: 11983409.609375
Elbo loss: 11980527.859375
Elbo loss: 11985525.8125
Elbo loss: 11964293.296875
Elbo loss: 11938020.515625
Elbo loss: 11923634.890625
Elbo loss: 11889246.625
Elbo loss: 11873478.9140625
Elbo loss: 11884241.9765625
Elbo loss: 11905652.8515625
Elbo loss: 11870566.46875
Elbo loss: 11875177.05859375
Elbo loss: 11843510.703125
Elbo loss: 11835140.0234375
Elbo loss: 11831297.0078125
Elbo loss: 11780705.17578125
Elbo loss: 11779127.166015625
Elbo loss: 11772117.734375
Elbo loss: 11767533.673828125
Elbo loss: 11775437.494140625
Elbo loss: 11745992.875
Elbo loss: 11768591.358398438
Elbo loss: 11722052.244140625


[11992020.8125,
 12008567.4375,
 12019911.5,
 12021809.34375,
 12042411.921875,
 12043122.7890625,
 12019932.7890625,
 12022563.4765625,
 12047215.4375,
 12036885.328125,
 12044234.3125,
 12045118.453125,
 12025997.5078125,
 12062301.4296875,
 12042643.2578125,
 12005705.15625,
 12015076.265625,
 12016554.7734375,
 12018766.203125,
 12001827.21875,
 11983409.609375,
 12008835.546875,
 12002125.0625,
 12003066.84375,
 11996001.84375,
 11994370.296875,
 12012471.53125,
 12007147.046875,
 11998281.1640625,
 11991081.0,
 11980527.859375,
 11990359.2734375,
 11991332.7734375,
 11987789.7890625,
 11989869.546875,
 12001762.6015625,
 11994243.28125,
 11997264.1953125,
 11981955.1875,
 11967238.015625,
 11985525.8125,
 11970913.109375,
 11988881.0546875,
 11961965.5625,
 11947875.8125,
 11958736.6015625,
 11962320.9921875,
 11963257.1484375,
 11965998.734375,
 11948119.3203125,
 11964293.296875,
 11950115.4453125,
 11963924.0625,
 11937081.0390625,
 11946569.828125,
 11947193.4296875,
 1192643

In [15]:
# Draw 1000 predictive samples and store the thresholded predictions on `test`.
test.sample_predict(nsamples=1000)


UA: (1000, 1, 1127, 100)
VA: (1000, 1, 5237, 100)
target: (1000, 1127, 5237)
[[[ 6.  1.  1. ...  0.  0.  0.]
  [ 0.  0.  0. ...  0.  0.  0.]
  [ 0.  0.  0. ...  0.  4.  0.]
  ...
  [ 8.  1.  0. ... 13.  5.  0.]
  [ 2.  0.  0. ...  1. 11.  0.]
  [ 1.  0.  0. ...  0.  0.  0.]]

 [[ 2.  0.  0. ...  1.  0.  4.]
  [ 1.  0.  0. ...  0.  2.  1.]
  [ 0.  0.  0. ...  0. 10.  0.]
  ...
  [10.  0.  0. ...  8.  6.  0.]
  [ 3.  0.  0. ...  1. 22.  0.]
  [ 0.  0.  0. ...  0.  2.  0.]]

 [[ 0.  0.  0. ...  0.  0.  0.]
  [ 2.  0.  1. ...  0.  0.  0.]
  [ 2.  0.  0. ...  0.  2.  2.]
  ...
  [ 5.  0.  0. ... 11. 17.  0.]
  [ 1.  0.  0. ...  2. 18.  0.]
  [ 1.  0.  0. ...  0.  0.  0.]]

 ...

 [[ 5.  1.  2. ...  1.  2.  1.]
  [ 0.  0.  0. ...  0.  1.  0.]
  [ 5.  0.  0. ...  6.  8.  0.]
  ...
  [ 7.  0.  0. ... 12. 13.  0.]
  [ 1.  0.  0. ...  2. 16.  0.]
  [ 0.  0.  1. ...  1.  0.  0.]]

 [[ 1.  0.  0. ...  0.  3.  0.]
  [ 2.  0.  0. ...  1.  3.  1.]
  [ 1.  0.  1. ...  1.  6.  2.]
  ...
  [10.  0.  0. 

In [16]:
# Score the stored predictions against the matrix the model was fitted on,
# then show the raw samples and the thresholded predictions.
test.rmse(test=data)
print(test.get_predictions())

PMF MAP training RMSE: 0.32909
AUC: 0.83589
(array([[[ 6.,  1.,  1., ...,  0.,  0.,  0.],
        [ 0.,  0.,  0., ...,  0.,  0.,  0.],
        [ 0.,  0.,  0., ...,  0.,  4.,  0.],
        ...,
        [ 8.,  1.,  0., ..., 13.,  5.,  0.],
        [ 2.,  0.,  0., ...,  1., 11.,  0.],
        [ 1.,  0.,  0., ...,  0.,  0.,  0.]],

       [[ 2.,  0.,  0., ...,  1.,  0.,  4.],
        [ 1.,  0.,  0., ...,  0.,  2.,  1.],
        [ 0.,  0.,  0., ...,  0., 10.,  0.],
        ...,
        [10.,  0.,  0., ...,  8.,  6.,  0.],
        [ 3.,  0.,  0., ...,  1., 22.,  0.],
        [ 0.,  0.,  0., ...,  0.,  2.,  0.]],

       [[ 0.,  0.,  0., ...,  0.,  0.,  0.],
        [ 2.,  0.,  1., ...,  0.,  0.,  0.],
        [ 2.,  0.,  0., ...,  0.,  2.,  2.],
        ...,
        [ 5.,  0.,  0., ..., 11., 17.,  0.],
        [ 1.,  0.,  0., ...,  2., 18.,  0.],
        [ 1.,  0.,  0., ...,  0.,  0.,  0.]],

       ...,

       [[ 5.,  1.,  2., ...,  1.,  2.,  1.],
        [ 0.,  0.,  0., ...,  0.,  1.,  0.