In [2]:
import logging

import pickle

import torch
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from torch.distributions import constraints
from torch import nn
import pyro
import pyro.distributions as dist
import pyro.optim as optim
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
from pyro.infer import Predictive
import seaborn as sns
from pyro import poutine
from sklearn import metrics

In [3]:
pyro.set_rng_seed(10)

In [4]:
class PMF(nn.Module):
    # by default our latent space is 50-dimensional
    # and we use 400 hidden units
    def __init__(self, train, dim):
        super().__init__()
        """Build the Probabilistic Matrix Factorization model using pymc3.



        """
        self.dim = dim   
        self.data = train.copy()
        self.n, self.m = self.data.shape
        self.map = None
        self.bounds = (0,1)
        self.losses = None
        self.predictions = None
        self.returned = None


        # Perform mean value imputation
    
        
        # Low precision reflects uncertainty; prevents overfitting.
        # Set to the mean variance across users and items.
        self.alpha_u = (np.mean(self.data, axis=1).mean())**2 / np.std(self.data, axis=1).mean()
        self.alpha_v = (np.mean(self.data, axis=0).mean())**2 / np.std(self.data, axis=0).mean()

        self.beta_u = (np.mean(self.data, axis=1).mean()) / np.std(self.data, axis=1).mean()
        self.beta_v = (np.mean(self.data, axis=0).mean()) / np.std(self.data, axis=0).mean()
       
        self.bias = self.data.mean()


    def model(self, train):

        drug_plate = pyro.plate("drug_latents", self.n, dim= -1) #independent users
        sideeffect_plate = pyro.plate("sideeffect_latents", self.m, dim= -1) #independent items

        with drug_plate: 
            UA = pyro.sample("UA", dist.Gamma(self.alpha_u, self.beta_u).expand([self.dim]).to_event(1))
            #UA_int = pyro.sample("UAint", dist.Normal(0., 1.))
        
        with sideeffect_plate:
            VA = pyro.sample("VA", dist.Gamma(self.alpha_v, self.beta_v).expand([self.dim]).to_event(1))
            #possibly add intercepts VA_int = pyro.sample("VA", dist.Normal(0., 1.).to_event(1))
       
        u2_plate = pyro.plate("u2_plate", self.n, dim=-2)

        with sideeffect_plate, u2_plate: 
            Y = pyro.sample("target", dist.Poisson(UA@VA.T), obs=train ) 
            return Y
        

    def guide(self, train=None, mask=None):

        d_alpha = pyro.param('d_alpha', torch.ones(self.n,self.dim), constraint=constraints.positive)#*self.user_mean)
        d_beta = pyro.param('d_beta', 0.5*torch.ones(self.n,self.dim), constraint=constraints.positive)
       # int_mean = pyro.param('int_mean', torch.tensor(1.)*self.user_mean)
       # mov_cov = pyro.param('mov_cov', torch.tensor(1.)*0.1,
          #                  constraint=constraints.positive)

        s_alpha = pyro.param('s_alpha', torch.ones(self.m,self.dim), constraint=constraints.positive)#*self.item_mean)
        s_beta = pyro.param('s_beta', 0.5*torch.ones(self.m,self.dim), constraint=constraints.positive)
        drug_plate = pyro.plate("drug_latents", self.n, dim= -1) #independent users
        sideeffect_plate = pyro.plate("sideeffect_latents", self.m, dim= -1) #independent items

        with drug_plate: 
            UA = pyro.sample("UA", dist.Gamma(d_alpha, d_beta).to_event(1))
           # UA_int = pyro.sample("UAint", dist.Normal(int_mean, mov_cov).to_event(1))
        with sideeffect_plate: 
            VA = pyro.sample("VA", dist.Gamma(s_alpha, s_beta).to_event(1))
    
    def train_SVI(self,train, nsteps=250, lr = 0.05, lrd = 1):
        logging.basicConfig(format='%(message)s', level=logging.INFO)
        svi = SVI(self.model,
        self.guide,
        optim.ClippedAdam({"lr": lr, "lrd": lrd}),
        loss=Trace_ELBO())
        losses = []
        for step in range(nsteps):
            elbo = svi.step(torch.from_numpy(train).float())
            losses.append(elbo)
            if step % 10 == 0:
                print("Elbo loss: {}".format(elbo))
        self.losses = losses
        #constrained_params = list(pyro.get_param_store().values())
        #PARAMS = [p.unconstrained() for p in constrained_params]
        #print(PARAMS)
        return losses
    
    def sample_predict(self, nsamples=500 , verbose=True):
        predictive_svi = Predictive(self.model, guide=self.guide, num_samples=nsamples)( None)
        if (verbose):
            for k, v in predictive_svi.items():
                print(f"{k}: {tuple(v.shape)}")
        table = predictive_svi["target"].numpy()
        print(table)
        self.returned = table
        mc_table = table.mean(axis = 0)
        mc_table_std = table.std(axis = 0)
        mc_table[mc_table < self.bounds[1]] = self.bounds[0]
        mc_table[mc_table >= self.bounds[1]] = self.bounds[1]
        self.predictions = mc_table
        
    
    def rmse(self,test):
        low, high = self.bounds
        test_data = test.copy()
        test_data[test_data < high] = low
        test_data[test_data >= high] = high
        sqerror = abs(test_data - self.predictions) ** 2  # squared error array
        mse = sqerror.sum()/(test_data.shape[0]*test_data.shape[1])
        print("PMF MAP training RMSE: %.5f" % np.sqrt(mse))
        fpr, tpr, thresholds = metrics.roc_curve(test_data.astype(int).flatten(),  self.predictions.astype(int).flatten(), pos_label=1)
        metrics.auc(fpr, tpr)
        print("AUC: %.5f" % metrics.auc(fpr, tpr))
        return np.sqrt(mse) , metrics.auc(fpr, tpr)

    def get_predictions(self):
        return (self.returned,self.predictions)

    
   
       

In [5]:

with open('data_all.pickle', 'rb') as handle:
    data = pickle.load(handle)
print(data.shape)

(1127, 5237)


In [6]:
test = PMF(train=data, dim=100)
test.train_SVI(data)


Elbo loss: 2336314634.109375
Elbo loss: 335497992.1171875
Elbo loss: 88186257.00390625
Elbo loss: 47303800.537109375
Elbo loss: 36057065.96875
Elbo loss: 32626540.515625
Elbo loss: 31529046.171875
Elbo loss: 30088244.25
Elbo loss: 27043836.09375
Elbo loss: 23925259.828125
Elbo loss: 21573661.546875
Elbo loss: 19923352.796875
Elbo loss: 19120000.640625
Elbo loss: 18205186.515625
Elbo loss: 17605883.859375
Elbo loss: 17287293.515625
Elbo loss: 16694075.484375
Elbo loss: 16594215.015625
Elbo loss: 16351819.015625
Elbo loss: 16263272.578125
Elbo loss: 15928297.859375
Elbo loss: 15722939.921875
Elbo loss: 15567714.6875
Elbo loss: 15421303.71875
Elbo loss: 15315426.65625


[2336314634.109375,
 1891504406.84375,
 1541580388.1875,
 1278080593.546875,
 1044348330.421875,
 857057030.40625,
 703238010.921875,
 582897138.8515625,
 480593259.6875,
 400234952.96875,
 335497992.1171875,
 279226896.578125,
 235646270.578125,
 201481908.36132812,
 173939746.76367188,
 151126704.34960938,
 132999739.81835938,
 118129568.57421875,
 106375419.33984375,
 96390129.375,
 88186257.00390625,
 81378336.90625,
 75930065.2109375,
 70438025.33007812,
 65844248.7265625,
 62455479.556640625,
 58810562.779052734,
 55688136.95703125,
 52644793.8515625,
 49810326.36328125,
 47303800.537109375,
 44933465.35546875,
 42958675.328125,
 41306958.78515625,
 39694746.2890625,
 38325722.5390625,
 37455367.51171875,
 36669173.4765625,
 36400280.1015625,
 36218159.6171875,
 36057065.96875,
 35973082.9765625,
 35646414.4453125,
 35419720.0234375,
 34836985.59375,
 34383767.0625,
 33893274.828125,
 33536042.2734375,
 33031913.734375,
 32834952.140625,
 32626540.515625,
 32519938.75,
 32464942.

In [7]:
test.sample_predict(1000)




UA: (1000, 1, 1127, 100)
VA: (1000, 1, 5237, 100)
target: (1000, 1127, 5237)
[[[ 2.  0.  2. ...  0.  2.  1.]
  [ 2.  0.  0. ...  1.  4.  0.]
  [ 1.  0.  0. ...  0.  4.  0.]
  ...
  [ 5.  0.  1. ... 17. 11.  0.]
  [ 2.  0.  0. ...  6. 23.  0.]
  [ 0.  0.  0. ...  1.  0.  0.]]

 [[ 0.  0.  0. ...  0.  2.  0.]
  [ 0.  0.  0. ...  0.  2.  3.]
  [ 1.  0.  0. ...  1.  6.  0.]
  ...
  [14.  0.  0. ... 60. 24.  1.]
  [ 1.  1.  0. ...  3.  5.  0.]
  [ 0.  0.  0. ...  0.  1.  0.]]

 [[ 0.  0.  2. ...  0.  1.  0.]
  [ 0.  0.  0. ...  0.  0.  1.]
  [ 1.  0.  0. ...  1.  2.  0.]
  ...
  [ 5.  0.  0. ... 14. 13.  1.]
  [ 5.  0.  0. ... 13. 17.  0.]
  [ 0.  0.  0. ...  0.  3.  0.]]

 ...

 [[ 1.  0.  0. ...  0.  0.  0.]
  [ 0.  0.  0. ...  0.  1.  0.]
  [ 1.  0.  0. ...  0.  3.  3.]
  ...
  [ 8.  0.  0. ... 19. 12.  0.]
  [ 3.  0.  0. ...  5. 25.  0.]
  [ 0.  0.  0. ...  2.  0.  0.]]

 [[ 1.  0.  0. ...  1.  3.  0.]
  [ 1.  0.  0. ...  0.  0.  0.]
  [ 1.  0.  0. ... 12.  3.  0.]
  ...
  [13.  0.  0. 

In [7]:
test.rmse(data)
print(test.get_predictions())
print(data)

PMF MAP training RMSE: 0.36516
AUC: 0.83883
[[1. 0. 0. ... 0. 1. 0.]
 [0. 0. 0. ... 0. 1. 0.]
 [1. 0. 0. ... 1. 1. 0.]
 ...
 [1. 0. 0. ... 1. 1. 0.]
 [1. 0. 0. ... 1. 1. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
[[ 1  0  0 ...  1  0  0]
 [ 0  0  0 ...  0  0  0]
 [ 1  0  0 ...  1  8  0]
 ...
 [ 8  0  0 ... 10 12  0]
 [ 1  0  0 ...  4 25  0]
 [ 0  0  0 ...  0  0  0]]


In [18]:
from sklearn import metrics
low, high = (1,1)
test_data = data.copy()
test_data[test_data < low] = 0
test_data[test_data >= high] = 1
preds = test.get_predictions()
preds[preds<low] = 0
preds[preds>=high] = 1
print(test_data.astype(int))
print(preds.astype(int))
fpr, tpr, thresholds = metrics.roc_curve(test_data.astype(int).flatten(), preds.astype(int).flatten(), pos_label=1)
metrics.auc(fpr, tpr)


[[1 0 0 ... 1 0 0]
 [0 0 0 ... 0 0 0]
 [1 0 0 ... 1 1 0]
 ...
 [1 0 0 ... 1 1 0]
 [1 0 0 ... 1 1 0]
 [0 0 0 ... 0 0 0]]
[[0 0 0 ... 1 1 0]
 [0 0 0 ... 0 0 0]
 [1 0 0 ... 1 1 1]
 ...
 [1 0 0 ... 1 1 0]
 [1 0 0 ... 1 1 0]
 [0 0 0 ... 0 0 0]]


0.8267106141796051

In [4]:
class PMF_intercepts(nn.Module):
    # by default our latent space is 50-dimensional
    # and we use 400 hidden units
    def __init__(self, train, dim):
        super().__init__()
        """Build the Probabilistic Matrix Factorization model using pymc3.



        """
        self.dim = dim   
        self.data = train.copy()
        self.n, self.m = self.data.shape
        self.map = None
        self.bounds = (0,1)
        self.losses = None
        self.predictions = None


        # Perform mean value imputation
    
        
        # Low precision reflects uncertainty; prevents overfitting.
        # Set to the mean variance across users and items.
        self.alpha_u = (np.mean(self.data, axis=1).mean())**2 / np.std(self.data, axis=1).mean()
        self.alpha_v = (np.mean(self.data, axis=0).mean())**2 / np.std(self.data, axis=0).mean()

        self.beta_u = (np.mean(self.data, axis=1).mean()) / np.std(self.data, axis=1).mean()
        self.beta_v = (np.mean(self.data, axis=0).mean()) / np.std(self.data, axis=0).mean()
       
        self.bias = self.data.mean()


    def model(self, train):

        drug_plate = pyro.plate("drug_latents", self.n, dim= -1) #independent users
        sideeffect_plate = pyro.plate("sideeffect_latents", self.m, dim= -1) #independent items

        with drug_plate: 
            UA = pyro.sample("UA", dist.Gamma(self.alpha_u, self.beta_u).expand([self.dim]).to_event(1))
            drug_intercept = pyro.sample("drug_int", dist.HalfNormal(0.5))
            #UA_int = pyro.sample("UAint", dist.Normal(0., 1.))
        
        with sideeffect_plate:
            sideeffect_intercept = pyro.sample("sf_int", dist.HalfNormal(0.5))
            VA = pyro.sample("VA", dist.Gamma(self.alpha_v, self.beta_v).expand([self.dim]).to_event(1))
            #possibly add intercepts VA_int = pyro.sample("VA", dist.Normal(0., 1.).to_event(1))
       
        u2_plate = pyro.plate("u2_plate", self.n, dim=-2)
       #with u2_plate:
           
            
        with sideeffect_plate, u2_plate: 
            #Y = pyro.sample("target", dist.Poisson(UA@VA.T  +sideeffect_intercept.T), obs=train ) z[:, np.newaxis] + x
            Y = pyro.sample("target", dist.Poisson(UA@VA.T  + (drug_intercept[:, np.newaxis] + sideeffect_intercept.T)), obs=train ) 
            return Y
        

    def guide(self, train=None, mask=None):

        d_alpha = pyro.param('d_alpha', torch.ones(self.n,self.dim), constraint=constraints.positive)#*self.user_mean)
        d_beta = pyro.param('d_beta', 0.5*torch.ones(self.n,self.dim), constraint=constraints.positive)
       # int_mean = pyro.param('int_mean', torch.tensor(1.)*self.user_mean)
       # mov_cov = pyro.param('mov_cov', torch.tensor(1.)*0.1,
          #                  constraint=constraints.positive)

        s_alpha = pyro.param('s_alpha', torch.ones(self.m,self.dim), constraint=constraints.positive)#*self.item_mean)
        s_beta = pyro.param('s_beta', 0.5*torch.ones(self.m,self.dim), constraint=constraints.positive)
        drug_plate = pyro.plate("drug_latents", self.n, dim= -1) #independent users
        sideeffect_plate = pyro.plate("sideeffect_latents", self.m, dim= -1) #independent items
        se_t = pyro.param("sef_int", 0.25*torch.ones(self.m), constraint=constraints.positive)
        drug_t = pyro.param("drug_int_p", 0.25*torch.ones(self.n), constraint=constraints.positive)
        u2_plate = pyro.plate("u2_plate", self.n, dim=-2)

       # with u2_plate:
          #  drug_intercept = pyro.sample("drug_int", dist.HalfNormal(drug_t).to_event(1))

        with drug_plate: 
            UA = pyro.sample("UA", dist.Gamma(d_alpha, d_beta).to_event(1))
            UA_int = pyro.sample("UAint", dist.HalfNormal(drug_t))
        with sideeffect_plate: 
            VA = pyro.sample("VA", dist.Gamma(s_alpha, s_beta).to_event(1))
            sideeffect_intercept =  pyro.sample("sf_int", dist.HalfNormal(se_t))
    
    def train_SVI(self,train, nsteps=250, lr = 0.05, lrd = 1):
        logging.basicConfig(format='%(message)s', level=logging.INFO)
        svi = SVI(self.model,
        self.guide,
        optim.ClippedAdam({"lr": lr, "lrd": lrd}),
        loss=Trace_ELBO())
        losses = []
        for step in range(nsteps):
            elbo = svi.step(torch.from_numpy(train).float())
            losses.append(elbo)
            if step % 10 == 0:
                print("Elbo loss: {}".format(elbo))
        self.losses = losses
        #constrained_params = list(pyro.get_param_store().values())
        #PARAMS = [p.unconstrained() for p in constrained_params]
        #print(PARAMS)
        return losses
    
    def sample_predict(self, nsamples=500 , verbose=True):
        predictive_svi = Predictive(self.model, guide=self.guide, num_samples=nsamples)( None)
        if (verbose):
            for k, v in predictive_svi.items():
                print(f"{k}: {tuple(v.shape)}")
        table = predictive_svi["target"].numpy()
        mc_table = table.mean(axis = 0)
        mc_table_std = table.std(axis = 0)
        mc_table[mc_table < self.bounds[1]] = self.bounds[0]
        mc_table[mc_table >= self.bounds[1]] = self.bounds[1]
        self.predictions = mc_table
        
    
    def rmse(self,test):
        low, high = self.bounds
        test_data = test.copy()
        test_data[test_data < high] = low
        test_data[test_data >= high] = high
        sqerror = abs(test_data - self.predictions) ** 2  # squared error array
        mse = sqerror.sum()/(test_data.shape[0]*test_data.shape[1])
        print("PMF MAP training RMSE: %.5f" % np.sqrt(mse))
        fpr, tpr, thresholds = metrics.roc_curve(test_data.astype(int).flatten(),  self.predictions.astype(int).flatten(), pos_label=1)
        metrics.auc(fpr, tpr)
        print("AUC: %.5f" % metrics.auc(fpr, tpr))
        return np.sqrt(mse) , metrics.auc(fpr, tpr)

    def get_predictions(self):
        return self.predictions

    
   
test = PMF_intercepts(data, 100)       

In [6]:

test.train_SVI(data)


Elbo loss: 16507036.471118927
Elbo loss: 16882178.70968628
Elbo loss: 16316151.067085266
Elbo loss: 16288115.850841522
Elbo loss: 16257655.711566925
Elbo loss: 16270344.331485748
Elbo loss: 16342345.294166565
Elbo loss: 15919929.547546387
Elbo loss: 16114426.330440521
Elbo loss: 15449073.024482727
Elbo loss: 15926813.614624977
Elbo loss: 15586750.751041412
Elbo loss: 15827920.701267242
Elbo loss: 15714445.02986908
Elbo loss: 15497542.100240707
Elbo loss: 15426100.48765564
Elbo loss: 15437820.758758545
Elbo loss: 15585684.149702072
Elbo loss: 15870064.738471985
Elbo loss: 15594255.186466217
Elbo loss: 15578450.936626434
Elbo loss: 15589500.542766571
Elbo loss: 15133441.503730774
Elbo loss: 15588512.186912537
Elbo loss: 15487589.70847702


[16507036.471118927,
 16496506.882324219,
 16725043.243679047,
 16757541.569824219,
 16918150.519577026,
 16833062.85477066,
 16704408.156257153,
 16910661.563728333,
 16673162.07164669,
 17017526.848858833,
 16882178.70968628,
 16528738.827140808,
 16472368.863552094,
 16625404.072410583,
 16564866.971096039,
 16570515.512786865,
 16734607.416477203,
 16479054.055652618,
 16483060.28685379,
 16532277.746879578,
 16316151.067085266,
 16505781.601371765,
 16345559.011214256,
 16548068.182533264,
 16689426.65473175,
 16392662.219470024,
 16446702.082932472,
 16294120.535728455,
 16434014.2137146,
 16128583.033908844,
 16288115.850841522,
 16416165.300584793,
 16276428.832092285,
 16078770.043344498,
 16182646.042268753,
 16126891.797328949,
 16419918.408088684,
 16246004.408210754,
 16403620.063171387,
 16164927.168022156,
 16257655.711566925,
 16272669.908325195,
 15973345.710792542,
 16108965.42982483,
 15875432.667793274,
 16125924.253845215,
 16135030.776966095,
 16107974.334237576,


In [7]:
test.sample_predict(1000)



UA: (1000, 1, 1127, 100)
drug_int: (1000, 1, 1127)
sf_int: (1000, 1, 5237)
VA: (1000, 1, 5237, 100)
target: (1000, 1127, 5237)


In [8]:
test.rmse(data)
print(test.get_predictions())
print(data)


PMF MAP training RMSE: 0.36189
AUC: 0.84249
[[1. 1. 0. ... 1. 1. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [1. 0. 0. ... 1. 1. 0.]
 ...
 [1. 0. 0. ... 1. 1. 0.]
 [1. 0. 0. ... 1. 1. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
[[ 1  0  0 ...  1  0  0]
 [ 0  0  0 ...  0  0  0]
 [ 1  0  0 ...  1  8  0]
 ...
 [ 8  0  0 ... 10 12  0]
 [ 1  0  0 ...  4 25  0]
 [ 0  0  0 ...  0  0  0]]


In [12]:
#Testing!!!

n,m = data.shape
dim=10
alpha_u = 4
alpha_v=5
beta_u=1
beta_v=1
def model():

        drug_plate = pyro.plate("drug_latents", n, dim= -1) #independent users
        sideeffect_plate = pyro.plate("sideeffect_latents", m, dim= -1) #independent items
   

        with drug_plate:
            
            UA = pyro.sample("UA2", dist.Gamma(alpha_u, beta_u).expand([dim]).to_event(1))
            #UA_int = pyro.sample("UAint", dist.Normal(0., 1.))
        
        with sideeffect_plate:
            sideeffect_intercept = pyro.sample("sf_int", dist.HalfNormal(0.5))

            VA = pyro.sample("VA2", dist.Gamma(alpha_v, beta_v).expand([dim]).to_event(1))
        
        u2_plate = pyro.plate("u2_plate", n, dim=-2)
        with u2_plate:
            drug_intercept = pyro.sample("drug_int", dist.HalfNormal(0.5))

        with sideeffect_plate, u2_plate: 
            Y = pyro.sample("target", dist.Poisson( UA@VA.T+sideeffect_intercept.T ) ) 
            return Y

def guide9():

        d_alpha = pyro.param('d_alpha', 5*torch.ones(n,dim), constraint=constraints.positive)#*self.user_mean)
        d_beta = pyro.param('d_beta', 0.05*torch.ones(n,dim), constraint=constraints.positive)
        # int_mean = pyro.param('int_mean', torch.tensor(1.)*self.user_mean)
        # mov_cov = pyro.param('mov_cov', torch.tensor(1.)*0.1,
            #                  constraint=constraints.positive)
        s_alpha = pyro.param('s_alpha', 5*torch.ones(m,dim), constraint=constraints.positive)#*self.item_mean)
        s_beta = pyro.param('s_beta', 0.05*torch.ones(m,dim), constraint=constraints.positive)
        drug_plate = pyro.plate("drug_latents",n, dim= -1) #independent users
        sideeffect_plate = pyro.plate("sideeffect_latents",m, dim= -1) #independent items
        se_t = pyro.param("sef_int", 0.25*torch.ones(m), constraint=constraints.positive)
        drug_t = pyro.param("drug_int_p", 0.25*torch.ones(n), constraint=constraints.positive)
        u2_plate = pyro.plate("u2_plate", n, dim=-2)
        with u2_plate:
            drug_intercept = pyro.sample("drug_int", dist.HalfNormal(drug_t).to_event(1))

        
        with drug_plate: 
            UA = pyro.sample("UA2", dist.Gamma(d_alpha, d_beta).to_event(1))
            # UA_int = pyro.sample("UAint", dist.Normal(int_mean, mov_cov).to_event(1))
     
        
        with sideeffect_plate: 
            VA = pyro.sample("VA2", dist.Gamma(s_alpha, s_beta).to_event(1))
            sideeffect_intercept =  pyro.sample("sf_int", dist.HalfNormal(se_t))


In [13]:
trace=poutine.trace(model).get_trace()
trace.compute_log_prob()
print(trace.format_shapes())

trace=poutine.trace(guide9).get_trace()
trace.compute_log_prob()
print(trace.format_shapes())

          Trace Shapes:               
           Param Sites:               
          Sample Sites:               
      drug_latents dist           |   
                  value      1127 |   
               log_prob           |   
sideeffect_latents dist           |   
                  value      5237 |   
               log_prob           |   
               UA2 dist      1127 | 10
                  value      1127 | 10
               log_prob      1127 |   
            sf_int dist      5237 |   
                  value      5237 |   
               log_prob      5237 |   
               VA2 dist      5237 | 10
                  value      5237 | 10
               log_prob      5237 |   
          u2_plate dist           |   
                  value      1127 |   
               log_prob           |   
          drug_int dist 1127    1 |   
                  value 1127    1 |   
               log_prob 1127    1 |   
            target dist 1127 5237 |   
                  value 1

In [17]:
a = torch.ones(n,m)
print(a.shape)
z = torch.ones(n)
x = torch.ones(m)
b = z[:, np.newaxis] + x
print(b.shape)

torch.Size([1127, 5237])
torch.Size([1127, 5237])
