In [None]:
import numpy as np
import time
import torch
import torch.nn as nn
from torch.distributions import Beta
from tqdm import tqdm
#
from utils import *

In [None]:
def beta_plot1d(x, figsize=(10, 1)):
    """Show 1-D samples as dots on a horizontal strip over the unit interval.

    Args:
        x: Array-like of samples; after squeezing it must be one-dimensional.
        figsize: Figure size forwarded to ``plt.figure``.
    """
    samples = np.squeeze(np.array(x))
    assert samples.ndim == 1
    plt.figure(figsize=figsize)
    # Every sample sits at height 1 so only the horizontal position matters.
    heights = np.ones(samples.shape[0])
    plt.scatter(samples, heights)
    # Reference segment marking [0, 1].
    plt.plot([0, 1], [1, 1])
    plt.xlim([-0.2, 1.2])
    plt.show()


def plot_beta_pdf(dist, title=None):
    """Plot the density of a Beta distribution on a grid inside (0, 1).

    Args:
        dist: A ``torch.distributions.Beta`` whose two concentration
            parameters each hold a single value (they are converted with
            ``float`` for the title).
        title: Optional heading shown on its own line above the parameter
            values.
    """
    # 200-point grid on [0, 1] with both endpoints dropped.
    xx = torch.linspace(0, 1, 200)[1:-1]
    # Evaluate without autograd so that parameters carrying gradients can
    # still be handed to matplotlib.
    with torch.no_grad():
        density = torch.exp(dist.log_prob(xx))
    plt.plot(xx, density)
    # torch's Beta(concentration1, concentration0) keeps the FIRST shape
    # parameter in ``concentration1`` and the second in ``concentration0``.
    a, b = float(dist.concentration1), float(dist.concentration0)
    if title is not None:
        plt.title("{} \n a={:.3f}, beta={:.3f}".format(
            title, a, b))
    else:
        plt.title("a={:.3f}, beta={:.3f}".format(a, b))
    plt.show()


def beta_params(X):
    """Estimate Beta shape parameters from samples by matching moments.

    Args:
        X: Object exposing ``mean()`` and ``var()`` (e.g. a tensor of samples).

    Returns:
        Tuple ``(a, b)`` computed from the sample mean and variance.
    """
    mean = X.mean()
    variance = X.var()
    # Factor shared by both estimates: mean * (1 - mean) / variance - 1.
    scale = (mean * (1 - mean)) / variance - 1
    return scale * mean, scale * (1 - mean)

def beta_params2(X):
    """Moment-based Beta parameter estimate, written in an alternative form.

    Args:
        X: Object exposing ``mean()`` and ``var()`` (e.g. a tensor of samples).

    Returns:
        Tuple ``(a, b)``; ``b`` is derived from ``a`` and the sample mean.
    """
    mean = X.mean()
    variance = X.var()
    inv_mean = 1 / mean
    a = ((1 - mean) / variance - inv_mean) * mean**2
    b = a * (inv_mean - 1)
    return a, b

def kl_beta_beta_pt(p, q):
    """Closed-form KL divergence D_kl(p || q) between two Beta distributions.

    Args:
        p: Distribution exposing ``concentration1`` / ``concentration0`` tensors.
        q: Distribution exposing ``concentration1`` / ``concentration0`` tensors.

    Returns:
        Tensor holding the divergence, broadcast over the parameter shapes.
    """
    p1, p0 = p.concentration1, p.concentration0
    q1, q0 = q.concentration1, q.concentration0
    p_total = p1 + p0
    q_total = q1 + q0
    # Log-gamma terms (difference of the two log normalisers).
    lg_plus = q1.lgamma() + q0.lgamma() + p_total.lgamma()
    lg_minus = p1.lgamma() + p0.lgamma() + q_total.lgamma()
    # Digamma terms, one per parameter plus one for the parameter sums.
    dg_first = (p1 - q1) * torch.digamma(p1)
    dg_second = (p0 - q0) * torch.digamma(p0)
    dg_total = (q_total - p_total) * torch.digamma(p_total)
    return lg_plus - lg_minus + dg_first + dg_second + dg_total

def kl_beta_beta(ab_aprx, ab_true, forward=True):
    """
    Closed-form KL divergence between two Beta distributions given as
    parameter pairs.

    Calculates either:
        Forward KL: D_kl(P||Q)   (forward=True)
        Reverse KL: D_kl(Q||P)   (forward=False)
    where:
        P ... True distribution
        Q ... Approximation
    Forward:
        - Mean seeking
        - Where pdf(P) is high, pdf(Q) must be high
    Reverse:
        - Mode seeking
        - where pdf(Q) is high, pdf(P) must be high

    Args:
        ab_aprx: Pair ``(a, b)`` of tensors parameterising Q.
        ab_true: Pair ``(a, b)`` of tensors parameterising P.
        forward: Selects forward (default) or reverse KL as defined above.

    Returns:
        Tensor holding the selected divergence.
    """
    # The formula below evaluates D_kl(lhs || rhs). Forward KL puts the true
    # distribution P on the left; reverse KL puts the approximation Q there.
    if forward:
        lhs_a, lhs_b = ab_true
        rhs_a, rhs_b = ab_aprx
    else:
        lhs_a, lhs_b = ab_aprx
        rhs_a, rhs_b = ab_true
    #
    sum_lhs = lhs_a + lhs_b
    sum_rhs = rhs_a + rhs_b
    #
    t1 = rhs_b.lgamma() + rhs_a.lgamma() + (sum_lhs).lgamma()
    t2 = lhs_b.lgamma() + lhs_a.lgamma() + (sum_rhs).lgamma()
    t3 = (lhs_b - rhs_b) * torch.digamma(lhs_b)
    t4 = (lhs_a - rhs_a) * torch.digamma(lhs_a)
    t5 = (sum_rhs - sum_lhs) * torch.digamma(sum_lhs)
    return t1 - t2 + t3 + t4 + t5

In [None]:
# First Beta uses (8, 2); the second uses the element-wise reciprocals.
a1 = torch.Tensor([8])
b1 = torch.Tensor([2])
a2 = 1/a1
b2 = 1/b1

d1, d2 = Beta(a1, b1), Beta(a2, b2)

# Draw equally many samples from each distribution (d1 first, then d2).
n_samples = 100
draw_shape = (n_samples, )
x1 = d1.sample(draw_shape)
x2 = d2.sample(draw_shape)

# One strip plot per sample set.
for samples in (x1, x2):
    beta_plot1d(samples)

In [None]:
# Closed-form KL between the two Betas via the distribution-based helper,
# evaluated in both argument orders: (d2, d1) first, then (d1, d2).
print(kl_beta_beta_pt(d2, d1))
print(kl_beta_beta_pt(d1, d2))

In [None]:
# Same parameter pairs through the tuple-based helper, once with forward=True
# and once with forward=False.
print(kl_beta_beta(ab_aprx=(a2, b2), ab_true=(a1, b1), forward=True))
print(kl_beta_beta(ab_aprx=(a2, b2), ab_true=(a1, b1), forward=False))

In [None]:
# Number of draws used for the moment-based estimate.
n_samples = 10000

# Target distribution P.
a_true, b_true = 0.2, 0.5
d_true = Beta(a_true, b_true)

# Proxy distribution: P's parameters swapped and inverted.
d_aprx = Beta(1/b_true, 1/a_true)
x_aprx = d_aprx.sample((n_samples,))

# Fit the proxy samples by moment matching, then undo the swap/inversion so
# the estimate is expressed in the same parameterisation as P.
a_fit, b_fit = beta_params(x_aprx)
a_aprx = 1/b_fit
b_aprx = 1/a_fit

print("P = Beta({:.3f},{:.3f})".format(a_true, b_true))
print("Q = Beta({:.3f},{:.3f})".format(a_aprx, b_aprx))

In [None]:
# Target distribution and the proxy that was sampled from.
plot_beta_pdf(d_true, title="True")
plot_beta_pdf(d_aprx, title="Approx")
# The estimate is the Beta built from the recovered (a_aprx, b_aprx);
# plotting d_aprx again would only repeat the "Approx" figure.
plot_beta_pdf(Beta(a_aprx, b_aprx), title="Estimated")

In [None]:
# Compare Beta(0.2, 0.8) with the Beta whose parameters are swapped and inverted.
# NOTE(review): this rebinds a1, b1 and d1 from an earlier cell; re-running the
# earlier KL cells afterwards will use these values.
a1, b1 = torch.Tensor([0.2]), torch.Tensor([0.8])
# Disabled variant: plain reciprocals without the swap.
#a2, b2 = 1/a1, 1/b1
a3, b3 = 1/b1, 1/a1

d1 = Beta(a1, b1)
#d2 = Beta(a2, b2)
d3 = Beta(a3, b3)
# Plot the original first, then the transformed distribution.
plot_beta_pdf(d1)
#plot_beta_pdf(d2, title="GW")
plot_beta_pdf(d3, title="GW")

In [None]:
class BasicBlock(nn.Module):
    """Linear layer followed by an optional BatchNorm1d and an in-place ReLU.

    The sub-modules are stored in ``self.ff`` in the order
    Linear, (BatchNorm1d), ReLU.
    """

    def __init__(self, d_in, d_out, batch_norm):
        """Build the stack mapping ``d_in`` features to ``d_out`` features."""
        super().__init__()
        stages = [nn.Linear(d_in, d_out)]
        if batch_norm:
            stages.append(nn.BatchNorm1d(d_out))
        stages.append(nn.ReLU(inplace=True))
        self.ff = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the sequential stack to ``x``."""
        return self.ff(x)

def init_weights(m):
    """Initialise an ``nn.Linear`` module in place; other modules are ignored.

    Intended to be passed to ``Module.apply``. Weights get Xavier-uniform
    initialisation and the bias, when the layer has one, is set to 0.01.

    Args:
        m: Any module; only ``nn.Linear`` instances are modified.
    """
    if not isinstance(m, nn.Linear):
        return
    torch.nn.init.xavier_uniform_(m.weight)
    # nn.Linear(..., bias=False) stores None as its bias, so there is nothing
    # to fill in that case.
    if m.bias is not None:
        torch.nn.init.constant_(m.bias, 0.01)
    
class Net(nn.Module):
    """MLP made of ``n_hid`` BasicBlocks, a linear head and a final sigmoid."""

    def __init__(self, d_in, d_hid, n_hid, d_out, batch_norm=False):
        """Store the sizes and assemble the layer stack in ``self.backbone``."""
        super().__init__()
        self.d_in, self.d_out = d_in, d_out
        self.n_hid, self.d_hid = n_hid, d_hid

        # Layer widths: the input width followed by n_hid hidden widths.
        self.dims = [d_in] + [d_hid] * n_hid

        # One BasicBlock per consecutive pair of widths.
        hidden = [
            BasicBlock(width_in, width_out, batch_norm)
            for width_in, width_out in zip(self.dims[:-1], self.dims[1:])
        ]
        head = [nn.Linear(self.dims[-1], d_out), nn.Sigmoid()]
        self.backbone = nn.Sequential(*hidden, *head)

    def forward(self, x):
        """Return the sigmoid-squashed output of the backbone for ``x``."""
        return self.backbone(x)
    
def beta_params(X):
    """Estimate Beta shape parameters from samples by matching moments.

    Args:
        X: Object exposing ``mean()`` and ``var()`` (e.g. a tensor of samples).

    Returns:
        Tuple ``(a, b)`` computed from the sample mean and variance.
    """
    mean = X.mean()
    variance = X.var()
    # Factor shared by both estimates: mean * (1 - mean) / variance - 1.
    scale = (mean * (1 - mean)) / variance - 1
    return scale * mean, scale * (1 - mean)

def kl_beta_beta(ab_aprx, ab_true, forward=True):
    """
    Closed-form KL divergence between two Beta distributions given as
    parameter pairs.

    Calculates either:
        Forward KL: D_kl(P||Q)   (forward=True)
        Reverse KL: D_kl(Q||P)   (forward=False)
    where:
        P ... True distribution
        Q ... Approximation
    Forward:
        - Mean seeking
        - Where pdf(P) is high, pdf(Q) must be high
    Reverse:
        - Mode seeking
        - where pdf(Q) is high, pdf(P) must be high

    Args:
        ab_aprx: Pair ``(a, b)`` of tensors parameterising Q.
        ab_true: Pair ``(a, b)`` of tensors parameterising P.
        forward: Selects forward (default) or reverse KL as defined above.

    Returns:
        Tensor holding the selected divergence.
    """
    # The formula below evaluates D_kl(lhs || rhs). Forward KL puts the true
    # distribution P on the left; reverse KL puts the approximation Q there.
    if forward:
        lhs_a, lhs_b = ab_true
        rhs_a, rhs_b = ab_aprx
    else:
        lhs_a, lhs_b = ab_aprx
        rhs_a, rhs_b = ab_true
    #
    sum_lhs = lhs_a + lhs_b
    sum_rhs = rhs_a + rhs_b
    #
    t1 = rhs_b.lgamma() + rhs_a.lgamma() + (sum_lhs).lgamma()
    t2 = lhs_b.lgamma() + lhs_a.lgamma() + (sum_rhs).lgamma()
    t3 = (lhs_b - rhs_b) * torch.digamma(lhs_b)
    t4 = (lhs_a - rhs_a) * torch.digamma(lhs_a)
    t5 = (sum_rhs - sum_lhs) * torch.digamma(sum_lhs)
    return t1 - t2 + t3 + t4 + t5

In [None]:
# Network hyper-parameters.
d_in, d_hid, d_out, n_hid = 1, 16, 1, 4
batch_norm = True

# Target Beta parameters (both 0.8), unpacked into two scalar tensors.
a_true, b_true = torch.Tensor([0.8, 0.8])

# "Prox" parameterisation: swap the two parameters and invert them.
a_true_prox, b_true_prox = 1 / b_true, 1 / a_true

dist_true = Beta(a_true, b_true)
dist_true_prox = Beta(a_true_prox, b_true_prox)

# Show both target densities, each under its own title.
for target, label in ((dist_true, "True"), (dist_true_prox, "Prox")):
    plot_beta_pdf(target, label)

# Network inputs are drawn uniformly from [0, 1).
dist_in = torch.distributions.Uniform(0, 1)
x = dist_in.sample((100,))

In [None]:
# Build a network and initialise its linear layers (apply returns the module).
model = Net(
    d_in=d_in, d_hid=d_hid, n_hid=n_hid, d_out=d_out, batch_norm=batch_norm
).apply(init_weights)
x = torch.rand(512, d_in)

# Forward pass in eval mode, without tracking gradients.
model.eval()
with torch.no_grad():
    z = model(x)

# Outputs first, then the inputs that produced them.
for values in (z.detach().numpy(), x.numpy()):
    beta_plot1d(values)
print(z.min(), z.max())

In [None]:
# Fresh network and optimizer for the training loop.
# NOTE(review): unlike the preview cell, init_weights is not applied here, so
# the layers keep their default initialisation — confirm this is intended.
model = Net(d_in=d_in, d_hid=d_hid, n_hid=n_hid, d_out=d_out, batch_norm=batch_norm)
# Disabled alternative: plain SGD with the same learning rate.
#optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [None]:
batch_size = 1024
num_steps = 200
num_epochs = 10
for epoch_idx in range(1, num_epochs + 1, 1):
    model.train()
    # Running sum of the per-step losses for this epoch. It must be reset
    # here: `+=` on a name that was never assigned raises NameError.
    epoch_loss = 0.0
    desc = "Epoch [{:3}/{:3}] {}:".format(epoch_idx, num_epochs, 'train')
    pbar = tqdm(range(num_steps), bar_format= desc + '{bar:10}{r_bar}{bar:-10b}')
    for step in pbar:
        # Fresh batch of inputs, reshaped to a single feature column.
        x = dist_in.sample((batch_size, )).view((-1, 1))
        # ##########
        # TRAIN
        # ##########
        # Clear gradients by dropping them instead of zero-filling.
        for param in model.parameters():
            param.grad = None
        z = model(x)
        # Moment-matched Beta parameters of the network outputs, mapped to
        # the swapped/inverted "prox" parameterisation used by the target.
        a_z, b_z = beta_params(z)
        a_aprx_prox = 1 / b_z
        b_aprx_prox = 1 / a_z
        loss =  kl_beta_beta((a_aprx_prox,b_aprx_prox),
                                 (a_true_prox,b_true_prox),forward=True)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        pbar.set_postfix({'loss': loss.item(), 'a': a_z.item(), 'b': b_z.item()})
        # NOTE(review): presumably gives the progress bar time to render —
        # confirm the pause is still needed.
        time.sleep(0.01)
    print("Epoch [{:3}/{:3}] mean loss: {:.6f}".format(
        epoch_idx, num_epochs, epoch_loss / num_steps))
    ############
    # EVAL
    ############
    model.eval()
    with torch.no_grad():
        x = dist_in.sample((batch_size, )).view((-1, 1))
        z = model(x)
        # Beta fitted to the current outputs, plotted once per epoch.
        a_z, b_z = beta_params(z)
        dist_aprx = Beta(a_z, b_z)
        plot_beta_pdf(dist_aprx, "aprx")
        time.sleep(0.1)

In [None]:
# Strip plot of the model outputs for a small batch of uniform inputs.
model.eval()
n_samples = 64
# No gradients are needed for plotting; this also lets z.numpy() be called directly.
with torch.no_grad():
    x = dist_in.sample((n_samples, )).view((-1, 1))
    z = model(x)
beta_plot1d(z.numpy())

In [None]:
n_samples = 100000
# Select inference mode explicitly instead of relying on an earlier cell
# having done so; the call is idempotent.
model.eval()
with torch.no_grad():
    x = dist_in.sample((n_samples, )).view((-1, 1))
    z = model(x)

# Moment-matched Beta parameters of the model outputs on the large sample.
a_z, b_z = beta_params(z)
#
dist_aprx = Beta(a_z, b_z)
plot_beta_pdf(dist_aprx, "aprx")

In [None]:
# Display the Beta parameters fitted to the model outputs.
a_z, b_z

In [None]:
# Display the target Beta parameters for comparison.
a_true, b_true