In [1]:
# Step up one directory from where the kernel started; the captured output
# shows the resulting working directory, which ends in /code.
# Presumably this is what makes the local `pdg` package importable — confirm.
%cd ..
%pwd
# %cd code   (alternative, presumably for a kernel started one level above /code)
# Enable autoreload so edited modules are re-imported before each cell executes.
%load_ext autoreload
%autoreload 2

/home/oli/Research/Joe/agent-goals/code


In [2]:
from pdg import PDG
from pdg.rv import Variable as Var, Unit
from pdg.dist import CPT, RawJointDist as RJD
from pdg.alg import interior_pt as ip

import numpy as np

In [3]:
# Build a PDG containing a single variable X and two separate distributions over it.
M = PDG()
# The arrays printed further down have 7 entries each, i.e. one per value of X.
X = Var.alph("X", 7)
M += X
# Two randomly generated CPTs from Unit to X.
# NOTE(review): no random seed is set in the visible cells, so the values printed
# below will presumably change on every fresh run — consider seeding; confirm how
# CPT.make_random draws its numbers.
p = CPT.make_random(Unit, X)
q = CPT.make_random(Unit, X)
# Add both CPTs to the PDG under the labels 'p' and 'q'
# (the label 'p' is what M.set_beta('p', ...) refers to later on).
M += 'p', p, 'q', q

In [4]:
# Show three significant decimals when printing arrays.
np.set_printoptions(precision=3)

# Flatten both CPTs into 1-D arrays (p first, then q) for the comparisons below.
pdat, qdat = (cpt.to_numpy().reshape(-1) for cpt in (p, q))

# Display each distribution next to the factor product computed by the PDG.
print('p: \t', pdat)
print('q: \t', qdat)
print('pq: \t', M.factor_product().data)


p: 	 [0.169 0.147 0.184 0.156 0.062 0.108 0.174]
q: 	 [0.095 0.228 0.254 0.146 0.075 0.163 0.039]
pq: 	 [0.109 0.227 0.314 0.154 0.031 0.119 0.046]


In [6]:
def renyi_div(pdata, qdata, alpha):
    """Rényi divergence of order ``alpha`` between two distributions, in bits.

    Computes ``log2(sum(p**alpha * q**(1 - alpha))) / (alpha - 1)``.

    Parameters
    ----------
    pdata, qdata : array_like
        Probability values of the two distributions; they are combined
        elementwise, so their shapes must be broadcast-compatible.
    alpha : float
        Order of the divergence. The general formula is 0/0 at ``alpha == 1``,
        so that order is evaluated as its limit, the Kullback-Leibler
        divergence ``sum(p * log2(p / q))``.

    Returns
    -------
    float
        The divergence in bits (``inf`` for ``alpha == 1`` when ``q`` is zero
        somewhere ``p`` is positive).
    """
    pdata = np.asarray(pdata)
    qdata = np.asarray(qdata)

    if alpha == 1:
        # Limit alpha -> 1: KL divergence, using the convention 0 * log(0 / q) = 0.
        with np.errstate(divide='ignore', invalid='ignore'):
            terms = np.where(pdata > 0, pdata * np.log2(pdata / qdata), 0.0)
        return terms.sum()

    return np.log((pdata ** alpha * qdata ** (1 - alpha)).sum()) / (alpha - 1) / np.log(2)


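The cell below shows that the interior-point method reproduces the Rényi divergences, as one would expect: with confidence $\beta$ set on `p`, the inconsistency `M.Inc` of the optimized joint distribution matches $D_\alpha(p \,\|\, q)$ for $\alpha = \beta / (\beta + 1)$.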

In [31]:

# Sweep the confidence beta on edge 'p' over many orders of magnitude. For each
# value, print the closed-form Rényi divergence of order alpha = b / (b + 1)
# next to the inconsistency of the joint found by the interior-point optimizer.
# NOTE(review): 1E6 is listed twice, so the final row is printed twice —
# presumably unintentional; confirm before removing.
for b in [1E-6, 1E-4, 1E-2, 0.5, 1, 2, 10, 100, 1E4, 1E6, 1E6]:
    M.set_beta('p', b)
    rjd = ip.cvx_opt_joint(M, also_idef=False)
    alpha = b / (b+1)

    # Centered "α=..." tag shown inside the parentheses of the renyi column.
    alpha_label = f'α={alpha:.3}'
    print("".join((
        f"b={b:<12e} \t ",
        f"renyi ({alpha_label:^9}) = {renyi_div(pdat, qdat, alpha):.3}  \t",
        f"Inc: {M.Inc(rjd).real:<10.3}",
    )))


b=1.000000e-06 	 renyi ( α=1e-06 ) = 2.01e-07  	Inc: 2.02e-07  
b=1.000000e-04 	 renyi (α=0.0001 ) = 2.01e-05  	Inc: 2.01e-05  
b=1.000000e-02 	 renyi (α=0.0099 ) = 0.002  	Inc: 0.002     
b=5.000000e-01 	 renyi ( α=0.333 ) = 0.074  	Inc: 0.074     
b=1.000000e+00 	 renyi (  α=0.5  ) = 0.116  	Inc: 0.116     
b=2.000000e+00 	 renyi ( α=0.667 ) = 0.163  	Inc: 0.163     
b=1.000000e+01 	 renyi ( α=0.909 ) = 0.238  	Inc: 0.238     
b=1.000000e+02 	 renyi ( α=0.99  ) = 0.265  	Inc: 0.265     
b=1.000000e+04 	 renyi (  α=1.0  ) = 0.268  	Inc: 0.268     
b=1.000000e+06 	 renyi (  α=1.0  ) = 0.268  	Inc: 0.268     
b=1.000000e+06 	 renyi (  α=1.0  ) = 0.268  	Inc: 0.268     
