In [None]:
import sys
sys.path.append("..")
from plot_utils import *

In [None]:
import numpy as np
import time
import torch
import torch.nn as nn
from torch.distributions import Beta
from torch.distributions.dirichlet import Dirichlet
from tqdm import tqdm
#
from utils import *
from functions import *

In [None]:
# Symmetric Beta(0.1, 0.1) distribution; the name `d` is referenced again by later cells.
a, b = torch.Tensor([.1]), torch.Tensor([.1])
d = Beta(a, b)
n_samples = 1000
# NOTE(review): `bn` is not used by any other visible cell.
bn = nn.BatchNorm1d(1, affine=False)
# Two independent draws of n_samples values each from the same distribution.
x1 = d.sample((n_samples, ))
x2 = d.sample((n_samples, ))
# simplex_plot / plot_beta_pdf come from the star imports above; presumably they
# visualise the samples and the density of `d` -- confirm against plot_utils.
simplex_plot(x1)
simplex_plot(x2)
plot_beta_pdf(d, "True")

# No Correlation

In [None]:
def cc_matrix(x1, x2):
    """Return ``x1.T @ x2`` divided by the number of rows of ``x1``.

    Both inputs go through ``reshape_`` first, so 1-D tensors are treated as
    single-column matrices. No centring or rescaling happens here; callers
    are expected to pass already-normalised data if they want a correlation.
    """
    left, right = reshape_(x1), reshape_(x2)
    n_rows = left.shape[0]
    # div_ works in place on the freshly created product and returns it.
    return (left.T @ right).div_(n_rows)

def reshape_(x):
    """Turn a 1-D input into a single-column 2-D one; pass anything else through."""
    is_vector = len(x.shape) == 1
    return x.reshape((-1, 1)) if is_vector else x

In [None]:
# Two independent samples of 1000 values, uniform on [-1, 1).
x1 = torch.rand((1000,)) * 2 - 1
x2 = torch.rand((1000,)) * 2 - 1
# Standardise each sample with its own mean and standard deviation.
x1 = (x1 - x1.mean()) / x1.std()
x2 = (x2 - x2.mean()) / x2.std()

In [None]:
# Cross-correlation of the two independently drawn, standardised samples (cell output).
cc_matrix(x1, x2)

# SINE CROSS-CORRELATION

In [None]:
def mean_var(x, axis=0):
    """Print the mean and variance of ``x`` with three decimals.

    Args:
        x: anything ``np.array`` accepts.
        axis: axis forwarded to the ``mean``/``var`` reductions (previously
            ignored in favour of a hard-coded 0). The reduction has to leave
            a single value, because the results are converted with ``float``.
    """
    x = np.array(x)
    print("mu={:.3f} var={:.3f}".format(float(x.mean(axis=axis)),
                                        float(x.var(axis=axis))))

def cc_matrix(x1, x2):
    """Compute ``x1.T @ x2`` scaled by ``1 / x1.shape[0]``.

    1-D inputs are converted to column matrices by ``reshape_`` before the
    product. The inputs are used as given (no centring or scaling).
    """
    a = reshape_(x1)
    b = reshape_(x2)
    product = a.T @ b
    # In-place division on the new product tensor; div_ returns that tensor.
    return product.div_(a.shape[0])

def reshape_(x):
    """Return ``x`` as a column matrix when it is 1-D, otherwise unchanged."""
    if len(x.shape) != 1:
        return x
    return x.reshape((-1, 1))

def cc_bn(x1, x2, debug=False, eps=1e-5):
    """Cross-correlation of two batches after BatchNorm standardisation.

    Each input is reshaped to 2-D with ``reshape_`` and passed through its own
    freshly constructed, non-affine ``BatchNorm1d`` layer; the normalised
    batches are then handed to ``cc_matrix``.

    Args:
        x1, x2: tensors; 1-D inputs are treated as single columns.
        debug: if true, print column means and variances after normalisation.
        eps: ``eps`` forwarded to both ``BatchNorm1d`` layers.

    Returns:
        The result of ``cc_matrix`` on the normalised inputs.
    """
    x1 = reshape_(x1)
    x2 = reshape_(x2)

    # Each layer is sized from the input it normalises. The second layer used
    # to be sized from x1, which is wrong when the column counts differ.
    x1 = torch.nn.BatchNorm1d(x1.shape[1], affine=False, eps=eps)(x1)
    x2 = torch.nn.BatchNorm1d(x2.shape[1], affine=False, eps=eps)(x2)
    if debug:
        print("bn(X1)", x1.mean(axis=0), x1.var(axis=0))
        # Report the variance of x2 (the original printed x1's variance here).
        print("bn(X2)", x2.mean(axis=0), x2.var(axis=0))
    return cc_matrix(x1, x2)

def cc_norm(x1, x2, debug=False, eps=1e-5):
    """Cross-correlation of two batches after per-column standardisation.

    Each input is reshaped to 2-D with ``reshape_``, centred with its column
    means and divided by its column standard deviation, then passed to
    ``cc_matrix``.

    Args:
        x1, x2: tensors; 1-D inputs are treated as single columns.
        debug: if true, print column means and variances after normalisation.
        eps: when positive, added to the column variance before the square
            root; otherwise the plain column standard deviation is used.

    Returns:
        The result of ``cc_matrix`` on the standardised inputs.
    """
    x1 = reshape_(x1)
    x2 = reshape_(x2)

    if eps > 0:
        x1 = (x1 - x1.mean(axis=0)) / torch.sqrt(x1.var(axis=0) + eps)
        x2 = (x2 - x2.mean(axis=0)) / torch.sqrt(x2.var(axis=0) + eps)
    else:
        # Per-column std, matching the eps > 0 branch. The original used the
        # std over all elements, which only agrees for single-column inputs.
        x1 = (x1 - x1.mean(axis=0)) / x1.std(axis=0)
        x2 = (x2 - x2.mean(axis=0)) / x2.std(axis=0)

    if debug:
        print("mv(X1)", x1.mean(axis=0), x1.var(axis=0))
        # Report the variance of x2 (the original printed x1's variance here).
        print("mv(X2)", x2.mean(axis=0), x2.var(axis=0))
    return cc_matrix(x1, x2)

def cc_norm2(x1, x2, debug=False, eps=1e-5):
    """Cross-correlation after centring first and scaling second.

    Each input is reshaped to 2-D with ``reshape_`` and centred with its
    column means. The centred data is then divided by its own column standard
    deviation and passed to ``cc_matrix``.

    Args:
        x1, x2: tensors; 1-D inputs are treated as single columns.
        debug: if true, print column means and variances after normalisation.
        eps: when positive, added to the column variance before the square
            root; otherwise the plain column standard deviation is used.

    Returns:
        The result of ``cc_matrix`` on the standardised inputs.
    """
    x1 = reshape_(x1)
    x2 = reshape_(x2)

    # recenter
    x1 = x1 - x1.mean(axis=0)
    x2 = x2 - x2.mean(axis=0)

    # unit variance
    if eps > 0:
        x1 = x1 / torch.sqrt(x1.var(axis=0) + eps)
        x2 = x2 / torch.sqrt(x2.var(axis=0) + eps)
    else:
        # Per-column std, matching the eps > 0 branch. The original used the
        # std over all elements, which only agrees for single-column inputs.
        x1 = x1 / x1.std(axis=0)
        x2 = x2 / x2.std(axis=0)

    if debug:
        print("mv(X1)", x1.mean(axis=0), x1.var(axis=0))
        # Report the variance of x2 (the original printed x1's variance here).
        print("mv(X2)", x2.mean(axis=0), x2.var(axis=0))
    return cc_matrix(x1, x2)

In [None]:
# Test signals on a shared grid over [0, 4*pi]: two phase-shifted sine waves
# with different amplitude/offset, plus uniform noise in [1, 3).
n_samples = 300
xx = torch.linspace(0, 4*np.pi, n_samples)
x_sin_1 = torch.sin(xx + 0.5) + 0.1
x_sin_2 = 2 * torch.sin(xx - 0.2) - 1
x_ran_1 = 2 * torch.rand((n_samples,)) + 1
# `scatter` comes from the star imports; presumably it plots each series -- confirm.
scatter([x_sin_1, x_sin_2, x_ran_1])
# Print mean and variance of each signal.
mean_var(x_sin_1)
mean_var(x_sin_2)
mean_var(x_ran_1)

In [None]:
# PERFECTLY CORRELATED: the same signal is used for both inputs.
x1 = x_sin_2
x2 = x_sin_2
scatter([x1, x2])
# Much smaller eps than the 1e-5 default of cc_bn / cc_norm.
eps = 1e-9

# BatchNorm-based normalisation (debug output enabled).
c_bn = cc_bn(x1, x2, True, eps)
print(c_bn)
# Hand-written normalisation on the same inputs, for comparison.
c_norm = cc_norm(x1, x2, True,eps)
print(c_norm)

In [None]:
# PERFECTLY NEGATIVELY CORRELATED: second input is the negated first input.
x1 = x_sin_2
x2 = -x_sin_2
scatter([x1, x2])
# BatchNorm-based normalisation, default eps, debug output enabled.
c_bn = cc_bn(x1, x2, True)
print(c_bn)
# Hand-written normalisation on the same inputs, for comparison.
c_norm = cc_norm(x1, x2, True)
print(c_norm)

In [None]:
# SOMEHOW CORRELATED: two sine waves that differ in phase, amplitude and offset.
x1 = x_sin_1
x2 = x_sin_2
scatter([x1, x2])
# BatchNorm-based normalisation, default eps, debug output enabled.
c_bn = cc_bn(x1, x2, True)
print(c_bn)
# Hand-written normalisation on the same inputs, for comparison.
c_norm = cc_norm(x1, x2, True)
print(c_norm)

In [None]:
# NOT CORRELATED: a sine wave against uniform random noise.
x1 = x_sin_2
x2 = x_ran_1
scatter([x1, x2])
# BatchNorm-based normalisation, default eps, debug output enabled.
c_bn = cc_bn(x1, x2, True)
print(c_bn)
# Hand-written normalisation on the same inputs, for comparison.
c_norm = cc_norm(x1, x2, True)
print(c_norm)

In [None]:
# Uniform noise with large constant offsets: x1 in [10, 12), x2 in [-3, -2.5).
n_samples  = 10000
x1 = torch.rand((n_samples,)) * 2 + 10
x2 = torch.rand((n_samples,)) / 2 - 3
# Plot both series.
scatter([x1, x2])
# Cross-correlation after standardisation (debug output off).
c_norm = cc_norm(x1, x2, False)
print(c_norm)

# CC

In [None]:
# Rebuild the sine / noise test signals on a grid over [0, 4*pi]
# (n_samples was changed by an earlier cell; x_ran_1 is drawn afresh).
n_samples = 300
xx = torch.linspace(0, 4*np.pi, n_samples)
x_sin_1 = torch.sin(xx + 0.5) + 0.1
x_sin_2 = 2 * torch.sin(xx - 0.2) - 1
x_ran_1 = 2 * torch.rand((n_samples,)) + 1
# `scatter` comes from the star imports; presumably it plots each series -- confirm.
scatter([x_sin_1, x_sin_2, x_ran_1])
# Print mean and variance of each signal.
mean_var(x_sin_1)
mean_var(x_sin_2)
mean_var(x_ran_1)

In [None]:
n_samples = 300
# Shared grid over [0, 4*pi].
xx = torch.linspace(0, 4*np.pi, n_samples)
# Reference series. NOTE(review): x_const and x_zero are not used in this cell.
x_const = torch.ones((n_samples,))
x_zero = torch.zeros((n_samples, ))
x_rand = torch.rand((n_samples,))
# Three periodic signals with different phase, amplitude and offset.
x1 = torch.sin(xx + 0.5) + 0.1
x2 = torch.sin(xx + 0.8) - 1
x3 = torch.cos(xx + 3) * 2 + 2
# Plot the signals and the sum x1 + x2.
scatter([x1,x2,x3, x1+x2])
# Correlation between the two sines, then of each sine (and their sum) with noise.
print(cc_norm(x1, x2))
print(cc_norm(x1, x_rand))
print(cc_norm(x2, x_rand))
print(cc_norm(x1 + x2, x_rand))

In [None]:
# Softmax over the five entries of a single row (dim=1); the result is the cell output.
z = torch.Tensor([[0.2, 0.4, 0.8, 2, 0.1]])
torch.nn.Softmax(dim=1)(z)

# Beta

In [None]:
# Asymmetric Beta(0.1, 0.9) distribution; `dist` is reused by the following cells.
a, b = torch.Tensor([.1]), torch.Tensor([.9])
dist = Beta(a, b)
n_samples = 512
x = dist.sample((n_samples, ))
# simplex_plot / plot_beta_pdf come from the star imports; presumably they
# visualise the samples and the density of `dist` -- confirm against plot_utils.
simplex_plot(x)
plot_beta_pdf(dist)

In [None]:
# Two independent draws of 10000 values each from `dist`.
n_samples = 10000
x1 = dist.sample((n_samples, ))
x2 = dist.sample((n_samples, ))

# Compare the distribution's own mean/variance with the empirical values of each sample.
print(dist.mean, dist.variance)
mean_var(x1)
mean_var(x2)

In [None]:
# Standardise both samples with the mean and variance reported by `dist`
# itself, rather than with statistics estimated from the samples.
xn1 = (x1 - dist.mean) / torch.sqrt(dist.variance)
xn2 = (x2 - dist.mean) / torch.sqrt(dist.variance)

In [None]:
# Cross-correlation of the samples standardised with the distribution's moments (cell output).
cc_matrix(xn1, xn2)

In [None]:
# Same raw samples, normalised by BatchNorm with batch statistics (cell output).
cc_bn(x1, x2)

In [None]:
# Same raw samples, normalised with their own column mean and variance (cell output).
cc_norm(x1, x2)

In [None]:
# Monte-Carlo check: for several sample sizes, draw two independent samples
# from `dist`, standardise them with the moments of `dist`, and record every
# run whose cross-correlation falls outside [-1, 1].
n_samples = [50, 100, 500, 1000]
n_simulations = 50000

# Moments of the distribution that is actually sampled. The original
# standardised with `d` (a different Beta defined in an earlier cell).
# They are loop-invariant, so compute them once.
dist_mean = dist.mean
dist_std = torch.sqrt(dist.variance)

errs = []
pbar = tqdm(total=len(n_samples) * n_simulations)
for n in n_samples:
    for _ in range(n_simulations):
        x1 = dist.sample((n, ))
        x2 = dist.sample((n, ))
        # Standardise with the distribution's moments, not sample statistics.
        xn1 = (x1 - dist_mean) / dist_std
        xn2 = (x2 - dist_mean) / dist_std
        cc = cc_matrix(xn1, xn2)
        if cc > 1 or cc < -1:
            errs.append((cc.item(), n))
        pbar.update(1)
pbar.close()

In [None]:
# Closes the progress bar again; presumably kept for cleaning up after the
# simulation cell was interrupted by hand -- confirm whether it is still needed.
pbar.close()

In [None]:
# List every recorded out-of-range correlation together with its sample size.
for cc, n in errs:
    print("cc={:.2f}, n={}".format(cc, n))

In [None]:
# Table of per-element weights for a dim x dim matrix when the diagonal and
# the off-diagonal entries are each scaled by 1 / (their own count).
ds = list(range(2, 11, 1)) + [128, 256, 512, 1024]
# The loop variable is `dim`, not `d`: `d` is the Beta distribution defined in
# an earlier cell and is still needed further down, so it must not be
# overwritten with an int here.
for dim in ds:
    n_on_diag = dim
    n_of_diag = dim**2 - dim
    # Weight per entry so that each group's weights sum to one.
    scale_on = 1 / n_on_diag
    scale_of = 1 / n_of_diag
    # Total weight of each group (count times per-entry weight).
    f_on = n_on_diag * scale_on
    f_of = n_of_diag * scale_of
    # Ratio of diagonal to off-diagonal entry counts.
    s = dim / (dim**2 - dim)
    print("{:>4d}:{:>8d}: {:.4f} {:.4f}: ({}:{}) {:.4f}".format(dim, dim**2 - dim, scale_on, scale_of, f_on, f_of, s))

In [None]:
# Scratch arithmetic (evaluates to 9.0).
72 * 0.125

# OLD

In [None]:
# Second moment of `x` about d.mean, averaged over the rows of `x` and then
# divided by d.variance (sqrt(...)**2 is the variance again).
# NOTE(review): relies on `x` and `d` still holding the samples and the Beta
# distribution from earlier cells -- verify that no cell in between rebinds them.
c = (x - d.mean).T @ (x - d.mean)
c.div_(x.shape[0])
c.div_(torch.sqrt(d.variance)**2)
print(c)

In [None]:
# Two independent samples of n values, uniform on [-1, 1).
n = 10000
x1 = torch.rand(n) * 2 - 1
x2 = torch.rand(n) * 2 - 1
# Print mean and variance of each sample.
mean_var(x1)
mean_var(x2)

In [None]:
# Correlation of a uniform sample with itself, followed by its squared distance from 1.
x1 = torch.rand(n) * 2 - 1
x2 = torch.rand(n) * 2 - 1
# Rebind x2 to x1, so the second draw above is discarded and both names refer to the same tensor.
x2 = x1

# Centred dot product, averaged over the sample size, then normalised by both standard deviations.
c = (x1 - x1.mean()).T @ (x2 - x2.mean())
c.div_(x1.shape[0])
c.div_(x1.std() * x2.std()) # NOTE(review): original comment said "in (0, 1)"; a correlation can also be negative
# Not in place: `c` keeps its value, (c - 1)**2 is only the cell output.
c.add(-1).pow(2)

In [None]:
# Signed deviation of `c` from 1 (cell output).
c - 1

In [None]:
# Six parameter pairs at once: first tensor and second tensor are read pairwise.
# presumably the two entries are the Beta parameters (a, b) -- confirm against kl_beta_beta.
ab1 = [
    torch.Tensor([0.1, 0.2, 0.5, 0.9, 1, 5]),
    torch.Tensor([0.9, 0.8, 0.5, 0.1, 5, 1])
]
# A single reference pair (0.1, 0.9).
ab2 = torch.Tensor([0.1, 0.9])
# kl_beta_beta comes from the star imports; evaluate it in both argument
# orders, then print the average of the two directions.
print(kl_beta_beta(ab1, ab2))
print(kl_beta_beta(ab2, ab1))
print(0.5 * (kl_beta_beta(ab1, ab2) + kl_beta_beta(ab2, ab1)))

In [None]:
def sample_ab(min_val, max_val):
    """Draw a random pair of values, each at least ``min_val``.

    Both components start as uniform draws from [0, 1). Independently for each
    component, a further uniform draw decides (when it exceeds 0.25) whether
    that component is multiplied by ``max_val``. The pair is finally clamped
    from below at ``min_val``.
    """
    params = torch.rand(2)
    for i in range(2):
        # One fresh draw per component decides whether to stretch it.
        if torch.rand(1) > 0.25:
            params[i] *= max_val
    floor = torch.Tensor([min_val])
    return torch.maximum(params, floor)

def dJD(ab1, ab2):
    """Symmetrised ``kl_beta_beta``: the average of both argument orders."""
    forward = kl_beta_beta(ab1, ab2)
    backward = kl_beta_beta(ab2, ab1)
    return 0.5 * (forward + backward)

In [None]:
# Evaluate kl_beta_beta and its symmetrised version dJD on random parameter pairs.
n_simulations  = 100
# Lower clamp and upper scale handed to sample_ab.
min_ab = 1e-3
max_ab = 50
dkl = []
djd = []
for idx in range(n_simulations):
    # Two independently sampled parameter pairs per iteration.
    ab1 = sample_ab(min_ab, max_ab)
    ab2 = sample_ab(min_ab, max_ab)
    dkl.append(kl_beta_beta(ab1, ab2))
    djd.append(dJD(ab1, ab2))

In [None]:
# Smallest and largest value seen for each of the two measures.
print(min(dkl), max(dkl))
print(min(djd), max(djd))