### Adapting Pyro's Bayesian linear regression tutorial (http://pyro.ai/examples/bayesian_regression.html) to Bayesian logistic regression

In [1]:
import numpy as np
import scipy.special as ssp
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.distributions.constraints as constraints

from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

import pyro
import pyro.distributions as dist

from pyro.nn import PyroModule
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam, SGD

pyro.enable_validation(True)
torch.set_default_dtype(torch.double) # this was necessary on the CPU

# Seed the random number generators so that "Restart Kernel -> Run All"
# reproduces the same simulated dataset and the same SVI trajectory.
# NumPy is seeded explicitly because the dataset is simulated with np.random.
SEED = 0
np.random.seed(SEED)
pyro.set_rng_seed(SEED)

In [2]:
def build_logistic_dataset(N, p=1, noise_std=0.01, seed=None):
    """Simulate a binary-classification dataset from a logistic model.

    Features are standard-normal draws. The true weights are standard-normal
    draws pushed away from zero by 2 (so every ``|w_j| > 2``) and the
    intercept is fixed at 1. Labels are obtained by rounding
    ``expit(X @ w + 1 + noise)``, where ``noise`` is Gaussian with standard
    deviation ``noise_std`` and is added to the logit.

    Args:
        N: number of observations (rows).
        p: number of features (columns).
        noise_std: standard deviation of the Gaussian noise added to the logit.
        seed: optional integer seed. When given, a private
            ``np.random.RandomState(seed)`` is used so the call is
            reproducible and leaves NumPy's global state untouched. When
            ``None`` (default) the global ``np.random`` state is used, exactly
            as before this parameter existed.

    Returns:
        Tuple ``(X, y, w)``: ``X`` is an ``(N, p)`` tensor and ``y`` an
        ``(N, 1)`` tensor of 0/1 labels, both in torch's default dtype; ``w``
        is the NumPy array of the ``p`` true weights (the intercept of 1 is
        not included).
    """
    # np.random (module) and RandomState expose the same randn/normal API, so
    # the draws below happen in the same order on either path.
    rng = np.random if seed is None else np.random.RandomState(seed)

    X = rng.randn(N, p)

    w = rng.randn(p)
    w += 2 * np.sign(w)  # keep every weight at least 2 away from zero

    logits = np.matmul(X, w) + 1 + rng.normal(0, noise_std, size=N)
    y = np.round(ssp.expit(logits)).reshape(N, 1)

    # Convert to torch's current default floating dtype.
    dtype = torch.get_default_dtype()
    X = torch.tensor(X, dtype=dtype)
    y = torch.tensor(y, dtype=dtype)

    assert X.shape == (N, p) and y.shape == (N, 1)
    return X, y, w

In [3]:
from pyro.nn import PyroSample

class BayesianLogisticRegression(PyroModule):
    """Bayesian binary logistic regression built on a single linear layer.

    The weight matrix (shape ``[1, in_features]``) has a ``Normal(0, 1)``
    prior and the bias (shape ``[1]``) a ``Normal(0, 10)`` prior. The
    likelihood is a Bernoulli whose success probability is the sigmoid of the
    linear output.
    """

    def __init__(self, in_features):
        """
        :param int in_features: number of input features per observation.
        """
        super().__init__()
        self.linear = PyroModule[nn.Linear](in_features, 1)

        # Replace the point-estimate parameters of nn.Linear with priors.
        self.linear.weight = PyroSample(dist.Normal(0., 1.).expand([1, in_features]).to_event(2))
        self.linear.bias = PyroSample(dist.Normal(0., 10.).expand([1]).to_event(1))

        self.sigmoid = PyroModule[nn.Sigmoid]()

    def forward(self, x, y=None):
        """Sample the parameters, score ``y`` if given, return probabilities.

        :param x: feature tensor whose first dimension indexes observations.
        :param y: optional 0/1 labels with one entry per observation (any
            shape that flattens to ``x.shape[0]`` entries). When ``None`` the
            ``"obs"`` site is sampled instead of observed, which is what
            prediction needs.
        :return: tensor of success probabilities, one per observation.
        """
        model_output = self.sigmoid(self.linear(x)).squeeze(-1)

        # The signature advertises y=None, so guard it: the original called
        # y.squeeze() unconditionally and crashed without labels. reshape(-1)
        # also keeps a single-row batch 1-D instead of collapsing it to 0-d.
        observed = None if y is None else y.reshape(-1)

        with pyro.plate("data", x.shape[0]):
            pyro.sample("obs",
                        dist.Bernoulli(probs=model_output),
                        obs=observed)
        return model_output
    
class BayesianMultiLogisticRegression(PyroModule):
    """Bayesian multinomial (softmax) logistic regression on one linear layer.

    The weight matrix (shape ``[out_features, in_features]``) has a
    ``Normal(0, 1)`` prior and the bias (shape ``[out_features]``) a
    ``Normal(0, 10)`` prior. The likelihood is a Categorical over the softmax
    of the linear output.
    """

    def __init__(self, in_features, out_features):
        """
        :param int in_features: number of input features per observation.
        :param int out_features: number of classes.
        """
        super().__init__()
        self.linear = PyroModule[nn.Linear](in_features, out_features)

        # Replace the point-estimate parameters of nn.Linear with priors.
        self.linear.weight = PyroSample(dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2))
        self.linear.bias = PyroSample(dist.Normal(0., 10.).expand([out_features]).to_event(1))

        # Normalise over the class axis explicitly; nn.Softmax() without `dim`
        # is deprecated and guesses the axis from the input rank.
        self.softmax = PyroModule[nn.Softmax](dim=-1)

    def forward(self, x, y=None):
        """Sample the parameters, score ``y`` if given, return probabilities.

        :param x: feature tensor whose first dimension indexes observations.
        :param y: optional class labels in ``{0, ..., out_features - 1}`` with
            one entry per observation (any shape that flattens to
            ``x.shape[0]`` entries). When ``None`` the ``"obs"`` site is
            sampled instead of observed, which is what prediction needs.
        :return: class-probability tensor with the class axis last.
        """
        # No squeeze here: the last axis is the class axis and must survive
        # even when out_features == 1.
        model_output = self.softmax(self.linear(x))

        # The signature advertises y=None, so guard it: the original called
        # y.squeeze() unconditionally and crashed without labels. reshape(-1)
        # also keeps a single-row batch 1-D instead of collapsing it to 0-d.
        observed = None if y is None else y.reshape(-1)

        with pyro.plate("data", x.shape[0]):
            pyro.sample("obs",
                        dist.Categorical(probs=model_output),
                        obs=observed)
        return model_output

In [4]:
# Experiment settings: number of observations, number of features and
# number of SVI steps.
N = 5000
p = 3
num_iterations = 1500

# Simulate the dataset; `w` holds the true weights that generated the labels.
X_data, y_data, w = build_logistic_dataset(N=N, p=p)

In [15]:
from pyro.infer.autoguide import AutoDiagonalNormal

# Two-class softmax model over the p simulated features; the single-output
# Bernoulli variant is kept below as a drop-in alternative.
model = BayesianMultiLogisticRegression(in_features=p, out_features=2)
# model = BayesianLogisticRegression(p)

# Automatic guide over all latent sites of the model.
guide = AutoDiagonalNormal(model)

In [16]:
from pyro.infer import SVI, Trace_ELBO

# Adam with a fixed learning rate optimises the ELBO between model and guide.
adam = pyro.optim.Adam({"lr": 0.03})
svi = SVI(model=model, guide=guide, optim=adam, loss=Trace_ELBO())

In [17]:
# Empty Pyro's global parameter store, then run full-batch SVI.
pyro.clear_param_store()
n_observations = len(X_data)
for iteration in range(num_iterations):
    # One gradient step; svi.step returns the loss for this step.
    loss = svi.step(X_data, y_data)
    if iteration % 100:
        continue
    # Report the per-observation loss every 100 steps.
    print("[iteration %04d] loss: %.4f" % (iteration + 1, loss / n_observations))

[iteration 0001] loss: 1.7682
[iteration 0101] loss: 0.1524
[iteration 0201] loss: 0.1130
[iteration 0301] loss: 0.0961
[iteration 0401] loss: 0.0867
[iteration 0501] loss: 0.0781
[iteration 0601] loss: 0.0720
[iteration 0701] loss: 0.0701
[iteration 0801] loss: 0.0665
[iteration 0901] loss: 0.0647
[iteration 1001] loss: 0.0637
[iteration 1101] loss: 0.0624
[iteration 1201] loss: 0.0617
[iteration 1301] loss: 0.0614
[iteration 1401] loss: 0.0601


In [8]:
# True weights returned by build_logistic_dataset (the data-generating
# intercept of 1 is not part of this array).
print(w)

[-2.43097623  2.88986634 -2.95624688]


In [9]:
# Training is finished: stop tracking gradients on the guide's parameters.
guide.requires_grad_(False)

# Print every learned variational parameter by name.
param_store = pyro.get_param_store()
for name in param_store.keys():
    print(name, pyro.param(name))

AutoDiagonalNormal.loc Parameter containing:
tensor([ 4.4457, -5.1098,  5.1007, -4.0689,  4.9862, -5.2390, -0.9785,  2.3734])
AutoDiagonalNormal.scale tensor([0.1014, 0.1020, 0.1127, 0.0987, 0.1096, 0.1087, 0.0872, 0.0903])


In [10]:
# 0.5 quantile (median) of the guide for each latent site, keyed by site name.
guide.quantiles([0.5])

{'linear.weight': tensor([[[ 4.4457, -5.1098,  5.1007],
          [-4.0689,  4.9862, -5.2390]]]),
 'linear.bias': tensor([[-0.9785,  2.3734]])}

In [11]:
# Multiply the true weights by the first entry of the guide's median
# 'linear.bias'.
# NOTE(review): presumably a rescaling check of true vs. fitted weights (the
# data-generating intercept is 1) -- confirm the intent and which bias entry
# should be used.
w * float(guide.quantiles([0.5])['linear.bias'][0][0])

array([ 2.37880362, -2.82784521,  2.89280111])

In [51]:
# Display the simulated labels: an (N, 1) tensor of 0/1 values.
y_data

tensor([[1.],
        [0.],
        [1.],
        ...,
        [0.],
        [1.],
        [1.]])