In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
#
from tqdm import tqdm
from torch.distributions.dirichlet import Dirichlet
import math

In [None]:
from utils import *

# Links
- https://arxiv.org/pdf/1901.02739.pdf
- https://www.tensorflow.org/probability/api_docs/python/tfp/distributions/Dirichlet?hl=de-de

In [None]:
alphas = [0.2] * 3

In [None]:
dirdist = Dirichlet(torch.Tensor(alphas))

In [None]:
n_samples = 100
samples = dirdist.sample((n_samples, ))
dir_plot3d(samples)

In [None]:
print(dirdist.log_prob(dirdist.sample((10, ))).mean())
print(dirdist.log_prob(dirdist.sample((100, ))).mean())
print(dirdist.log_prob(dirdist.sample((512, ))).mean())

In [None]:
dirdist_ske = Dirichlet(torch.Tensor([0.9, 0.05, 0.005]))
dirdist_ide = Dirichlet(torch.Tensor([0.1, 0.1, 0.1]))
dirdist_uni = Dirichlet(torch.Tensor([1, 1, 1]))
dirdist_con = Dirichlet(torch.Tensor([10, 6, 7]))
#
dir_dists = [dirdist_ske, dirdist_ide, dirdist_uni, dirdist_con]
dir_names = ["skewed", "sparse", "uniform", "concentrated"]
n_samples = 1000
dir_samples = [dist.sample((n_samples, )) for dist in dir_dists]
#
for x in dir_samples:
    dir_plot3d((x))

In [None]:
# Cross log-likelihood table for the 3D Dirichlets:
# row i    -> data sampled from dir_dists[i]
# column j -> mean log-density of that data under dir_dists[j]
n_samples = 100
lls = []
for data_dist in dir_dists:
    data = data_dist.sample((n_samples, ))
    lls.append([model_dist.log_prob(data).mean() for model_dist in dir_dists])
#
x = np.array(lls)
plot_mat(x, dir_names, dir_names, title="log likelihoods dist(data)", xlabel="Dist", ylabel="Data")

# 2D Dirichlet

In [None]:
dirdist_ske = Dirichlet(torch.Tensor([0.99, 0.005]))
dirdist_ide = Dirichlet(torch.Tensor([0.01, 0.01]))
dirdist_uni = Dirichlet(torch.Tensor([1, 1]))
dirdist_con = Dirichlet(torch.Tensor([10, 8,]))
#
dir_dists = [dirdist_ske, dirdist_ide, dirdist_uni, dirdist_con]
dir_names = ["skewed", "sparse", "uniform", "concentrated"]
n_samples = 100
dir_samples = [dist.sample((n_samples, )) for dist in dir_dists]

In [None]:
n_samples = 100
for dist in dir_dists:
    x = dist.sample((n_samples,))
    dir_plot2d(x)

In [None]:
# Cross log-likelihood table for the 2D Dirichlets:
# row i    -> data sampled from dir_dists[i]
# column j -> summed log-density of that data under dir_dists[j]
n_samples = 1000
lls = []
for data_dist in dir_dists:
    data = data_dist.sample((n_samples, ))
    lls.append([model_dist.log_prob(data).sum() for model_dist in dir_dists])
#
x = np.array(lls)
plot_mat(x, dir_names, dir_names, title="log likelihoods dist(data)", xlabel="Dist", ylabel="Data")

In [None]:
# For every distribution: bar-plot the empirical mean of a large sample and
# print empirical vs. analytic mean and variance side by side.
n_samples = 10000
for name, dist in zip(dir_names, dir_dists):
    samples = dist.sample((n_samples, ))
    mean = samples.mean(axis=0)
    ax = plt.gca()
    ax.set_ylim([0, 1])
    plt.bar(range(len(mean)), mean)
    plt.title(name)
    plt.show()
    print(mean, dist.mean)
    print(samples.var(axis=0), dist.variance)

# BETA

- http://bariskurt.com/kullback-leibler-divergence-between-two-dirichlet-and-beta-distributions/
- https://math.stackexchange.com/questions/257821/kullback-liebler-divergence
- https://dibyaghosh.com/blog/probability/kldivergence.html
-

In [None]:
a1, b1 = 2, 10
a2, b2 = 10, 2
a3, b3 = 0.4, 0.5
n_samples = 100
#

x1 = np.random.beta(a1, b1, size=n_samples)
x2 = np.random.beta(a2, b2, size=n_samples)
x3 = np.random.beta(a3, b3, size=n_samples)
#
beta_plot1d(x1)
beta_plot1d(x2)
beta_plot1d(x3)

In [None]:
a, b = 0.4, 0.5
n_samples = 100
#
x = np.random.beta(a, b, size=n_samples)
beta_plot1d(x)
beta_plot1d(1 - x)

In [None]:
a, b = 2, 5
n_samples = 100
#
x = np.random.beta(a, b, size=n_samples)
beta_plot1d(x)

In [None]:
x = np.random.beta(1/a, 1/b, size=n_samples)
beta_plot1d(x)

In [None]:
from torch.distributions import Beta

def beta_params(X):
    """Method-of-moments estimate of Beta(a, b) parameters from samples.

    Args:
        X: container of samples exposing ``.mean()`` and ``.var()``
           (the variance estimator is whatever ``X.var()`` implements).

    Returns:
        Tuple ``(a, b)`` with ``a = nu * mu`` and ``b = nu * (1 - mu)``,
        where ``nu = mu * (1 - mu) / var - 1``.
    """
    sample_mean = X.mean()
    sample_var = X.var()
    # Common factor shared by both parameters (equals a + b).
    nu = (sample_mean * (1 - sample_mean)) / sample_var - 1
    return nu * sample_mean, nu * (1 - sample_mean)

def beta_params2(X):
    """Alternative algebraic form of the method-of-moments Beta estimate.

    Args:
        X: container of samples exposing ``.mean()`` and ``.var()``.

    Returns:
        Tuple ``(a, b)`` with ``a = ((1 - mu) / var - 1 / mu) * mu**2`` and
        ``b = a * (1 / mu - 1)``.
    """
    sample_mean, sample_var = X.mean(), X.var()
    inv_mean = 1 / sample_mean
    alpha = ((1 - sample_mean) / sample_var - inv_mean) * sample_mean**2
    return alpha, alpha * (inv_mean - 1)

def kl_beta_beta_pt(p, q):
    """Closed-form KL(p || q) between two Beta distribution objects.

    Args:
        p, q: objects exposing tensor attributes ``concentration1`` and
              ``concentration0`` (as ``torch.distributions.Beta`` does).

    Returns:
        Tensor holding the divergence, broadcast over the parameter shapes.
    """
    p_c1, p_c0 = p.concentration1, p.concentration0
    q_c1, q_c0 = q.concentration1, q.concentration0
    p_total = p_c1 + p_c0
    q_total = q_c1 + q_c0
    # Log-normaliser pieces: (lgamma terms of q + lgamma(sum p)) minus the mirror image.
    lg_plus = q_c1.lgamma() + q_c0.lgamma() + p_total.lgamma()
    lg_minus = p_c1.lgamma() + p_c0.lgamma() + q_total.lgamma()
    # Digamma-weighted parameter differences.
    dg_c1 = (p_c1 - q_c1) * torch.digamma(p_c1)
    dg_c0 = (p_c0 - q_c0) * torch.digamma(p_c0)
    dg_total = (q_total - p_total) * torch.digamma(p_total)
    return lg_plus - lg_minus + dg_c1 + dg_c0 + dg_total

def kl_beta_beta(ab_aprx, ab_true, forward=True):
    """
    Closed-form KL divergence between two Beta distributions given by (a, b).

    Calculates either:
        Forward KL: D_kl(P||Q)
        Reverse KL: D_kl(Q||P)
    where:
        P ... True distribution          (parameters ``ab_true``)
        Q ... Approximation              (parameters ``ab_aprx``)
    Forward:
        - Mean seeking
        - Where pdf(P) is high, pdf(Q) must be high
    Reverse:
        - Mode seeking
        - where pdf(Q) is high, pdf(P) must be high

    Args:
        ab_aprx: pair ``(a, b)`` of the approximation Q (tensors or numbers).
        ab_true: pair ``(a, b)`` of the true distribution P (tensors or numbers).
        forward: ``True`` for D_kl(P||Q), ``False`` for D_kl(Q||P).

    Returns:
        Tensor with the requested divergence.
    """
    def _as_param(value):
        # Plain Python numbers are promoted so .lgamma()/digamma work on them.
        if torch.is_tensor(value):
            return value
        return torch.as_tensor(value, dtype=torch.get_default_dtype())

    aprx_a, aprx_b = (_as_param(v) for v in ab_aprx)
    true_a, true_b = (_as_param(v) for v in ab_true)

    # The formula below evaluates D_kl(first || second); the first slot is the
    # distribution the expectation is taken under.
    if forward:
        p_a, p_b = true_a, true_b      # D_kl(P || Q): expectation under the truth
        q_a, q_b = aprx_a, aprx_b
    else:
        p_a, p_b = aprx_a, aprx_b      # D_kl(Q || P): expectation under the approximation
        q_a, q_b = true_a, true_b
    #
    sum_pab = p_a + p_b
    sum_qab = q_a + q_b
    #
    t1 = q_b.lgamma() + q_a.lgamma() + (sum_pab).lgamma()
    t2 = p_b.lgamma() + p_a.lgamma() + (sum_qab).lgamma()
    t3 = (p_b - q_b) * torch.digamma(p_b)
    t4 = (p_a - q_a) * torch.digamma(p_a)
    t5 = (sum_qab - sum_pab) * torch.digamma(sum_pab)
    return t1 - t2 + t3 + t4 + t5

In [None]:
a1, b1 = torch.Tensor([8]), torch.Tensor([2])
a2, b2 = 1/a1, 1/b1
#
d1 = Beta(a1, b1)
d2 = Beta(a2, b2)
#
n_samples = 100
x1 = d1.sample((n_samples, ))
x2 = d2.sample((n_samples, ))
#
beta_plot1d(x1)
beta_plot1d(x2)

In [None]:
print(kl_beta_beta_pt(d2, d1))
print(kl_beta_beta_pt(d1, d2))

In [None]:
print(kl_beta_beta(ab_aprx=(a2, b2), ab_true=(a1, b1), forward=True))
print(kl_beta_beta(ab_aprx=(a2, b2), ab_true=(a1, b1), forward=False))
#kl_beta_beta2((a1, b1), (a2, b2))

#### Test

In [None]:
def plot_beta_pdf(dist, title=None):
    """Plot the density of a Beta distribution on the open interval (0, 1).

    Args:
        dist: distribution exposing ``log_prob``, ``concentration1`` and
              ``concentration0`` (as ``torch.distributions.Beta`` does); the
              concentrations must be convertible with ``float()``.
        title: optional first title line; the parameters are always shown.
    """
    # Drop both end points of the grid: the grid is restricted to the open interval.
    xx = torch.linspace(0, 1,200)[1:-1]
    plt.plot(xx, torch.exp(dist.log_prob(xx)))
    # torch's Beta(a, b) stores the first parameter in concentration1 and the
    # second in concentration0, so read them in that order for the labels.
    a, b = float(dist.concentration1), float(dist.concentration0)
    if title is not None:
        plt.title("{} \n a={:.3f}, beta={:.3f}".format(
            title, a, b))
    else:
        plt.title("a={:.3f}, beta={:.3f}".format(a, b))
    plt.show()

In [None]:
n_samples = 10000
# Target distribution P = Beta(a_true, b_true).
a_true, b_true = 0.2, 0.5
d_true = Beta(a_true, b_true)
# Sampling distribution with swapped, reciprocal parameters: Beta(1/b_true, 1/a_true).
d_aprx = Beta(1/b_true, 1/a_true)
x_aprx = d_aprx.sample((n_samples,))
# Moment-match the samples, then undo the swap/reciprocal mapping so the
# result is expressed in the same parametrisation as (a_true, b_true).
a_aprx, b_aprx = beta_params(x_aprx)
a_aprx, b_aprx = 1/b_aprx, 1/a_aprx
# Print the true parameters next to the recovered ones.
print("P = Beta({:.3f},{:.3f})".format(a_true, b_true))
print("Q = Beta({:.3f},{:.3f})".format(a_aprx, b_aprx))

In [None]:
# Densities of the target, the sampling distribution, and the distribution
# implied by the estimated parameters.
plot_beta_pdf(d_true, title="True")
plot_beta_pdf(d_aprx, title="Approx")
# Build the distribution from the estimated (a_aprx, b_aprx) instead of
# plotting d_aprx a second time under a different title.
# NOTE(review): assumes "Estimated" means the parameters printed as Q above — confirm.
plot_beta_pdf(Beta(a_aprx, b_aprx), title="Estimated")

In [None]:
a1, b1 = torch.Tensor([0.2]), torch.Tensor([0.8])
#a2, b2 = 1/a1, 1/b1
a3, b3 = 1/b1, 1/a1

d1 = Beta(a1, b1)
#d2 = Beta(a2, b2)
d3 = Beta(a3, b3)
#
plot_beta_pdf(d1)
#plot_beta_pdf(d2, title="GW")
plot_beta_pdf(d3, title="GW")

## Feature Loss

In [None]:
def dir_log_prob(x, alphas):
    """Log-density of a Dirichlet with concentration ``alphas`` evaluated at ``x``.

    Args:
        x: tensor whose last dimension indexes the components.
        alphas: concentration tensor, broadcastable against ``x``.

    Returns:
        Tensor of log-densities with the last dimension reduced.
    """
    # log of the normalising constant: lgamma(sum alpha) - sum lgamma(alpha)
    log_norm_const = torch.lgamma(alphas.sum(-1)) - torch.lgamma(alphas).sum(-1)
    # sum_k (alpha_k - 1) * log x_k
    weighted_logs = (alphas - 1.0) * torch.log(x)
    return weighted_logs.sum(-1) + log_norm_const

In [None]:
alpha = 0.01
alpha_factor = 10

alphas = torch.Tensor([alpha, alpha * alpha_factor])
n_samples = 500
dist = Dirichlet(alphas)
x = dist.sample((n_samples, ))
dir_plot2d(x)

In [None]:
x1 = x
x2 = torch.stack([x[:,0], 1-x[:,0]], dim=1)
x3 = torch.stack([1 - x[:,1], x[:,1]], dim=1)
#
#print(x1.sum(axis=1))
#print(x2.sum(axis=1))
#print(x3.sum(axis=1))

In [None]:
print(dist.log_prob(x1).sum())
print(dir_log_prob(x1, alphas).sum())

In [None]:
print(dist.log_prob(x2).sum())
print(dir_log_prob(x2, alphas).sum())

In [None]:
print(dist.log_prob(x3).sum())
print(dir_log_prob(x3, alphas).sum())

In [None]:
x = torch.rand((n_samples,))
x1 = torch.stack([x, 1-x], dim=1)
x2 = torch.stack([1 - x, x], dim=1)

In [None]:
print(dist.log_prob(x1).sum())
print(dist.log_prob(x2).sum())

# Gamma

In [None]:
dgamma = torch.distributions.Gamma(alphas[0], alphas[1])
x = dgamma.sample((n_samples, ))
dir_plot2d(torch.stack([x, 1-x], dim=1))
dir_plot2d(torch.stack([1 - x, x], dim=1))

# Test

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
#
from tqdm import tqdm
from torch.distributions.dirichlet import Dirichlet
import math
import torch.nn as nn

In [None]:
def augment(x, rate=0.05):
    """Jitter points on the probability simplex and keep them on the simplex.

    The first ``d - 1`` coordinates each receive independent uniform noise in
    ``[-rate, rate]`` and are clamped to ``[0, 1]``. The last coordinate is
    then set to whatever mass remains so that every row sums to one.

    Clamping each coordinate on its own does not bound their sum, so when the
    perturbed leading coordinates add up to more than one they are scaled back
    to sum to one; otherwise the last coordinate would become negative.

    Args:
        x: tensor of shape ``(n, d)``; it is not modified.
        rate: half-width of the uniform noise.

    Returns:
        New ``(n, d)`` tensor with non-negative rows summing to one.
    """
    n, d = x.shape
    #
    x_aug = x.clone()
    #
    for d_idx in range(d - 1):
        # Noise is created on the input's device so non-CPU inputs work too.
        eps = (torch.rand((n,), device=x.device) * 2 - 1) * rate
        x_aug[:, d_idx] = torch.clamp(x_aug[:, d_idx] + eps, min=0, max=1)
    # Rows whose leading mass is <= 1 are divided by exactly 1 (unchanged);
    # only overflowing rows are shrunk.
    head_sum = x_aug[:, :-1].sum(dim=1, keepdim=True)
    x_aug[:, :-1] /= torch.clamp(head_sum, min=1.0)
    # Remaining mass goes to the last coordinate; the clamp absorbs rounding.
    x_aug[:, -1] = torch.clamp(1 - x_aug[:, :-1].sum(dim=1), min=0)
    return x_aug

def calc_lambda(d):
    """Default off-diagonal loss weight for dimension ``d``: ``1 / ((d - 1) * 0.0244)``."""
    denominator = (d - 1) * 0.0244
    return 1 / denominator

def off_diagonal(x):
    """Return the off-diagonal entries of a square matrix as a flat tensor.

    Entries come out in row-major order. Raises ``AssertionError`` when the
    matrix is not square.
    """
    rows, cols = x.shape
    assert rows == cols
    # Without the final (diagonal) element, the remaining n*n - 1 values form
    # n - 1 rows of length n + 1, each of which starts with a diagonal entry.
    without_last = x.flatten()[:-1]
    shifted = without_last.view(rows - 1, rows + 1)
    # Dropping the first column removes every remaining diagonal entry.
    return shifted[:, 1:].flatten()

class Net(torch.nn.Module):
    """Encoder with a softmax output, trained with three loss terms.

    ``forward`` combines:
      * a cross-correlation loss between the projected embeddings of two
        views (returned as ``barlow_loss``),
      * the negative Dirichlet log-likelihood of the softmax outputs, and
      * the MSE between the per-dimension output variance and the variance
        of the target Dirichlet.

    Args:
        d_in: input feature size.
        d_out: size of the softmax output (representation).
        d_hid: width of the hidden layers.
        n_hid: number of hidden Linear/BatchNorm/ReLU stages (>= 0).
        d_proj: width of the projector layers.
        n_proj: number of projector Linear layers; ``0`` disables the projector.
        alphas: concentration tensor of the target Dirichlet.
        w_ll: weight of the negative log-likelihood term.
        w_var: weight of the variance-matching term.
        lambd: weight of the off-diagonal term; defaults to ``calc_lambda(d_out)``.
    """

    def __init__(self, d_in, d_out,
                 d_hid, n_hid,
                 d_proj, n_proj,
                 alphas,
                 w_ll,
                 w_var,
                 lambd=None):
        super().__init__()
        assert n_hid >= 0
        self.w_ll = w_ll
        self.w_var = w_var
        self.alphas = alphas
        self.dist = Dirichlet(alphas)
        self.lambd = calc_lambda(d_out) if lambd is None else lambd

        # Encoder: n_hid hidden stages followed by a softmax head.
        dims = [d_in] + [d_hid] * n_hid
        self.dims = dims
        layers = []
        for idx in range(len(dims) - 1):
            layers.extend([
                nn.Linear(dims[idx], dims[idx + 1]),
                nn.BatchNorm1d(dims[idx + 1]),
                nn.ReLU(inplace=True)
            ])
        layers.extend([
            nn.Linear(dims[-1], d_out),
            nn.Softmax(dim=1)
        ])
        self.ff = nn.Sequential(*layers)

        # The normalisation applied before the cross-correlation must match the
        # width of whatever feeds it: the projector output, or the raw
        # representation when no projector is used.
        self.bn = nn.BatchNorm1d(d_proj if n_proj > 0 else d_out, affine=False)

        if n_proj > 0:
            proj_dims = [d_out] + [d_proj] * n_proj
            layers = []
            for i in range(len(proj_dims) - 2):
                layers.append(nn.Linear(proj_dims[i], proj_dims[i + 1], bias=False))
                layers.append(nn.BatchNorm1d(proj_dims[i + 1]))
                layers.append(nn.ReLU(inplace=True))
            layers.append(nn.Linear(proj_dims[-2], proj_dims[-1], bias=False))
            self.projector = nn.Sequential(*layers)
        else:
            self.projector = nn.Identity()

    def representation(self, x):
        """Return the softmax output of the encoder for ``x``."""
        return self.ff(x)

    def forward(self, z1, z2):
        """Compute the training loss for two views of the same batch.

        Returns:
            Tuple ``(loss, barlow_loss, ll, l_var)`` where ``ll`` is the
            negative log-likelihood and
            ``loss = barlow_loss + w_ll * ll + w_var * l_var``.
        """
        z1 = self.ff(z1)
        z2 = self.ff(z2)
        z_all = torch.cat([z1, z2], dim=0)
        # Negative log-likelihood of both views under the target Dirichlet.
        ll = - 1 * self.dist.log_prob(z_all).sum()
        # Match the batch variance to the target's analytic variance. The
        # target is the module's own distribution, not a notebook global.
        # NOTE(review): a mean-matching term was computed here before but never
        # entered the loss; it is left out rather than silently changing the
        # objective — confirm whether it should be added.
        l_var = torch.nn.functional.mse_loss(z_all.var(dim=0), self.dist.variance)
        #
        z1 = self.projector(z1)
        z2 = self.projector(z2)
        # Cross-correlation matrix of the batch-normalised projections.
        c = self.bn(z1).T @ self.bn(z2)
        c.div_(z1.shape[0])
        # Diagonal is pulled towards 1, off-diagonal towards 0.
        on_diag = torch.diagonal(c).add_(-1).pow_(2).sum()
        off_diag = off_diagonal(c).pow_(2).sum()
        barlow_loss = on_diag + self.lambd * off_diag
        #
        loss = barlow_loss + ll * self.w_ll + self.w_var * l_var
        return loss, barlow_loss, ll, l_var

In [None]:
# Network sizes passed to Net below.
d_in = 3
d_out = 3
d_hid = 32
n_hid = 3
d_proj = 3 * d_out  # projector width is tied to the representation size
n_proj = 3
# Loss weights and augmentation strengths.
w_ll = 0.01   # weight of the negative log-likelihood term in Net.forward
lambd = 100   # weight of the off-diagonal term in Net.forward
aw1 = 0.02    # noise rate passed to augment() for the first view
aw2 = 0.02    # noise rate passed to augment() for the second view
w_var = 5000  # weight of the variance-matching term in Net.forward
# Dirichlet concentrations of the input (x) and target (z) distributions.
# NOTE(review): this random alphas_z is overwritten two lines below and never used.
alphas_z = torch.rand(d_in) * 0.1
alphas_x = torch.Tensor([8, 2, .1])
alphas_z = torch.Tensor([0.01, 0.9, 0.9])
# Distribution objects used for sampling inputs and as the training target.
x_dist = Dirichlet(alphas_x)
z_dist = Dirichlet(alphas_z)

In [None]:
x = x_dist.sample((10,))
#
x1 = augment(x, aw1)
x2 = augment(x, aw2)
#
dir_plot(x1)
dir_plot(x2)

In [None]:
if d_in == 3:
    n_samples = 1000
    for dist in [x_dist, z_dist]:
        dir_plot(dist.sample((n_samples,)))

In [None]:
model = Net(
    d_in = d_in,
    d_out = d_out,
    d_hid = d_hid,
    n_hid = n_hid,
    d_proj = d_proj,
    n_proj = n_proj,
    alphas = alphas_z,
    w_ll = w_ll,
    w_var = w_var,
    lambd=lambd
)
model

In [None]:
# Representation of a fresh batch under the untrained model.
# batch_size is set here so this cell runs on a fresh kernel; the training
# cell further down assigns the same value.
batch_size = 1024
model.eval()
x = x_dist.sample((batch_size, ))
with torch.no_grad():
    z = model.representation(x)
z_real = z_dist.sample((batch_size, ))
dir_plot(z)

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-6)
#
n_epochs = 10
n_steps = 1000
batch_size = 1024
#
for epoch_idx in range(n_epochs):
    desc = "Epoch [{:3}/{:3}] {}:".format(epoch_idx, n_epochs, 'train')
    pbar = tqdm(range(n_steps), bar_format= desc + '{bar:10}{r_bar}{bar:-10b}')
    # Running sums for the per-epoch averages printed below.
    epoch_loss = 0.
    epoch_loss_bt = 0.
    epoch_loss_ll = 0.
    epoch_loss_var = 0.
    epoch_step = 0
    # eval() is only entered after the last step of an epoch, so switching
    # back to training mode once per epoch is enough.
    model.train()
    for step_idx in pbar:
        # Fresh batch every step; two independently augmented views of it.
        x = x_dist.sample((batch_size, ))
        #
        x1 = augment(x, aw1)
        x2 = augment(x, aw2)
        #
        optimizer.zero_grad(set_to_none=True)
        loss, barlow_loss, ll_loss, l_var = model(x1, x2)
        loss.backward()
        optimizer.step()
        #
        epoch_step += 1
        epoch_loss += loss.item()
        epoch_loss_bt += barlow_loss.item()
        epoch_loss_ll += ll_loss.item()
        epoch_loss_var += l_var.item()
        #
        pbar.set_postfix({'loss': loss.item(),
                          'barlow': barlow_loss.item(),
                          'll': ll_loss.item(),
                          'lvar': l_var.item()
                         })

    # Modulus 1 reports every epoch; raise it to thin out the output.
    if epoch_idx % 1 == 0:
        # Avoid a division by zero when n_steps is set to 0.
        n_done = max(epoch_step, 1)
        print("   Loss: {:.2f} BL: {:.2f} LL: {:.2f} VL: {:.2f}".format(
            epoch_loss / n_done,
            epoch_loss_bt / n_done,
            epoch_loss_ll / n_done,
            epoch_loss_var / n_done
        ))
        if d_out == 3:
            # Plot the current representation of a fresh batch.
            model.eval()
            x = x_dist.sample((batch_size, ))
            with torch.no_grad():
                z = model.representation(x)
            z_real = z_dist.sample((batch_size, ))
            #dir_plot(x)
            dir_plot(z)
            #dir_plot(z_real)
            #dir_plot(z_real)

In [None]:
print(torch.nn.functional.mse_loss(x.mean(axis=0),z_dist.mean))
print(torch.nn.functional.mse_loss(x.var(axis=0), z_dist.variance))

In [None]:
x.var(axis=0), z_dist.variance

In [None]:
z_dist.mean

In [None]:
x.mean(axis=0)

In [None]:
n_samples = 1000
model.eval()
x = x_dist.sample((n_samples, ))
with torch.no_grad():
    z_pred = model.representation(x)
z_real = z_dist.sample((n_samples,))
#
dir_plot(x)
dir_plot(z_pred)
dir_plot(z_real)

In [None]:
ll_x = z_dist.log_prob(x).sum()
ll_z_pred = z_dist.log_prob(z_pred).sum()
ll_z_real = z_dist.log_prob(z_real).sum()

In [None]:
print(ll_x)

In [None]:
print(ll_z_pred)

In [None]:
print(ll_z_real)