In [1]:
import torch
import torch.distributions as torchdist

In [2]:
def toy_poly(n_samples=100, noise_scale=1.0):
    """
    Generate a synthetic quadratic-regression data set.

    Inputs are drawn uniformly from [0, 5) and targets are sampled as
    y ~ Normal(-3 - 4*x + x**2, noise_scale).

    arg: n_samples - number of (x, y) pairs to draw (default 100)
    arg: noise_scale - standard deviation of the observation noise,
         must be positive (default 1.0)
    return: (x, y), two tensors of shape (n_samples, 1)
    """
    x = 5 * torch.rand(n_samples, 1)
    # Ground-truth polynomial: w0 = -3, w1 = -4, w2 = 1.
    mean = -3 - 4*x + 1*x**2
    # .sample() draws without tracking gradients; the data are constants.
    y = torchdist.Normal(mean, noise_scale).sample()
    return x, y

# Seed the global RNG before the first random draw so that the synthetic data
# and every later random sample in the notebook are reproducible when the
# notebook is re-run from a fresh kernel.
torch.manual_seed(0)
x_train, y_train = toy_poly()

In [3]:
def log_joint_prob(w0, w1, w2, x, y):
    """
    Log joint density of the weights and the observations.

    Each weight has an independent Normal(0, 10) prior; the observations are
    Normal with mean w0 + w1*x + w2*x**2 and unit standard deviation.

    arg: w0, w1, w2 - regression weights
    arg: x, y - inputs and targets
    return: the three prior log-densities plus the log-likelihood summed
            over all elements of y
    """
    # All three priors are the same distribution, so one object serves them all.
    weight_prior = torchdist.Normal(torch.tensor(0.), torch.tensor(10.))

    mean = w0 + w1*x + w2*x**2
    log_likelihood = torchdist.Normal(mean, torch.ones_like(mean)).log_prob(y).sum()

    log_prior = (
        weight_prior.log_prob(w0)
        + weight_prior.log_prob(w1)
        + weight_prior.log_prob(w2)
    )
    return log_prior + log_likelihood

In [4]:
# Variational parameters "eta": a mean and a log standard deviation for each
# weight, all starting at zero. Every entry gets its own Parameter object so
# the optimizer updates them independently.
variational_params = {
    key: torch.nn.Parameter(torch.zeros(()))
    for key in (
        "w0_loc",
        "w0_scale_log",
        "w1_loc",
        "w1_scale_log",
        "w2_loc",
        "w2_scale_log",
    )
}

def variational_model(variational_params):
    """
    Variational model q(w; eta): one independent Normal per weight.
    arg: variational parameters "eta", a mapping holding "<w>_loc" and
         "<w>_scale_log" entries for w0, w1 and w2
    return: tuple (w0_q, w1_q, w2_q) of Normal distributions
    """
    key_pairs = (
        ("w0_loc", "w0_scale_log"),
        ("w1_loc", "w1_scale_log"),
        ("w2_loc", "w2_scale_log"),
    )

    factors = []
    for loc_key, log_scale_key in key_pairs:
        # The scale is stored as a log so its exponential is always positive.
        scale = variational_params[log_scale_key].exp()
        factors.append(torchdist.Normal(variational_params[loc_key], scale))

    return tuple(factors)

In [5]:
def kl_divergence(variational_params, x, y):
    """
    Single-sample estimate of log q(w; eta) - log p(y, w | x).

    One reparameterised draw is taken from each variational factor, so the
    returned scalar is differentiable with respect to the variational
    parameters.

    arg: variational parameters "eta", inputs x and targets y
    return: log q(w) - log joint, evaluated at the drawn weights
    """
    factors = variational_model(variational_params)

    # rsample (not sample) keeps the draw differentiable w.r.t. loc and scale.
    draws = [factor.rsample() for factor in factors]

    log_q = factors[0].log_prob(draws[0])
    for factor, draw in zip(factors[1:], draws[1:]):
        log_q = log_q + factor.log_prob(draw)

    log_p = log_joint_prob(*draws, x, y)

    return log_q - log_p

In [6]:
# Plain SGD over all entries of the variational parameter dictionary.
optimizer = torch.optim.SGD(variational_params.values(), lr=1e-4)

for i in range(9000):
    # Gradients accumulate across backward() calls, so clear them each step.
    optimizer.zero_grad()
    loss_value = kl_divergence(variational_params, x_train, y_train)
    loss_value.backward()
    optimizer.step()

    # Report the loss on the first step and on every 300th step after that.
    if i == 0 or (i + 1) % 300 == 0:
        print(loss_value.detach().numpy())

2010.088
216.45891
151.53426
170.11266
169.32898
149.42986
216.23782
154.24156
156.65796
157.64983
155.47226
169.41145
160.99359
190.32835
200.18675
157.82788
169.57982
158.5978
160.85092
161.54916
158.56343
166.40855
155.73099
169.33304
154.00766
150.02861
162.02048
166.96443
155.15277
155.5115
155.15538


In [7]:
variational_params

{'w0_loc': Parameter containing:
 tensor(-3.2885, requires_grad=True), 'w0_scale_log': Parameter containing:
 tensor(-2.2080, requires_grad=True), 'w1_loc': Parameter containing:
 tensor(-3.9548, requires_grad=True), 'w1_scale_log': Parameter containing:
 tensor(-3.3054, requires_grad=True), 'w2_loc': Parameter containing:
 tensor(1.0232, requires_grad=True), 'w2_scale_log': Parameter containing:
 tensor(-4.7245, requires_grad=True)}