# 混合ロジスティックで混合比を一定にした場合の推定
+ $p(y=1|x,w) = \sum_{k=1}^K \frac{1}{K} r(b_k^T x)$の推定を行う
    + 特に$K$を大きくしたとき、どのようになるか調べる

In [2]:
%matplotlib inline

In [12]:
import sys
sys.path.append("../lib")

import numpy as np
import matplotlib.pyplot as plt
from scipy.special import expit

from util import logcosh

In [3]:
### problem setting
n = 1000
N = 1000
M = 2
X_domain = (-10, 10)
data_seed = 20191123
true_func = lambda x: (x * np.sin(x)).sum(axis = 1)

In [4]:
test = np.random.normal(size = (2,3))

In [5]:
### data generation
np.random.seed(data_seed)
def data_generation(n:int):
    ret_X = np.zeros((n, M))
    base_X = np.random.uniform(low = X_domain[0], high = X_domain[1], size = n)
    for j in range(M):
        ret_X[:,j] = base_X**j
    ret_func = true_func(ret_X)
    ret_prob = expit(ret_func)
    ret_Y = np.random.binomial(n = 1, p = ret_prob, size = n)

    return (ret_X, ret_Y, ret_func, ret_prob)
    
(train_X, train_Y, train_func, train_prob) = data_generation(n)
(test_X, test_Y, test_func, test_prob) = data_generation(N)

In [47]:
### learning setting
learning_seed = 20181123
iteration = 1000
K = 10
pri_beta = 0.0001

In [36]:
### initial learning
np.random.seed(learning_seed)
est_u_xi = np.random.dirichlet(alpha = np.ones(K), size = n)
est_g_eta = np.abs(np.random.normal(size = (n,K)))
est_v_eta = -est_u_xi*np.tanh(np.sqrt(est_g_eta)/2)/(4*np.sqrt(est_g_eta))

in_out_matrix = np.repeat((train_Y - 0.5),M).reshape(n,M) * train_X

In [37]:
sum([np.linalg.slogdet(est_beta[:,:,k])[1] for k in range(K)])

IndexError: index 10 is out of bounds for axis 2 with size 10

In [46]:
### iteration
for ite in range(iteration):
    ### update param posterior
    est_beta = np.repeat(pri_beta * np.eye(M), K).reshape(M,M,K)
    for i in range(M):
        for j in range(M):
            est_beta[i,j,:] += train_X[:,i] * train_X[:,j] @ (-2*est_v_eta)
    est_inv_beta = np.array([np.linalg.inv(est_beta[:,:, k]) for k in range(K)]).transpose((1,2,0))
    est_b = np.zeros((M,K))
    for j in range(M):
        est_b[j,:] = (est_inv_beta[j,:,:] * (in_out_matrix.T @ est_u_xi)).sum(axis = 0)
    
    ### update g_eta
    est_g_eta = np.zeros((n,K))
    for i in range(M):
        for j in range(M):
            est_g_eta += np.repeat(train_X[:,i] * train_X[:,j], K).reshape((n,K)) * np.repeat(est_b[i,:] * est_b[j,:] + est_inv_beta[i,j,:], n).reshape((K,n)).T
    sq_g_eta = np.sqrt(est_g_eta)
    
    ### update h_xi
    est_h_xi = -np.log(K) + in_out_matrix @ est_b - np.log(2) - logcosh(sq_g_eta/2)
    max_est_h_xi = est_h_xi.max(axis = 1)
    norm_est_h_xi = est_h_xi - np.repeat(max_est_h_xi, K).reshape(n,K)
    est_u_xi = np.exp(norm_est_h_xi) / np.repeat(np.exp(norm_est_h_xi).sum(axis = 1), K).reshape(n,K)
    
    ### energy calculation
    energy = n*np.log(K)
    energy += -(np.log(np.exp(norm_est_h_xi).sum(axis = 1)) + max_est_h_xi).sum()
    energy += (est_u_xi * (np.log(2) + logcosh(sq_g_eta/2) )).sum()
    energy += (est_u_xi * est_h_xi).sum()
    energy += (est_v_eta * est_g_eta).sum()
    for i in range(M):
        for j in range(M):
            energy += (-est_beta[i,j,:] * est_b[i,:] * est_b[j,:]/2).sum()
        pass
    energy += sum([np.linalg.slogdet(est_beta[:,:,k])[1]/2 for k in range(K)])-M*K/2*np.log(pri_beta)
    print(energy)
    pass

In [16]:
est_h_xi.shape

(1000, 10)

In [9]:
est_g_eta.shape

(1000, 10)

array([[ 4.78286009e-01, -1.53970557e-03],
       [-1.53970557e-03,  4.95664350e-06]])

In [47]:
debug_g_eta = np.zeros((n,K))
for k in range(K):
    debug_g_eta[:,k] = np.diag(train_X @ (est_b[:,0].reshape((M,1)) @ est_b[:,0].reshape((1,M)) + est_inv_beta[:,:,k]) @ train_X.T)
    pass

In [50]:
np.allclose(est_g_eta, debug_g_eta)

False

In [51]:
est_g_eta

array([[0.53230047, 0.32348865, 0.50860542, ..., 0.38101478, 0.44958491,
        0.50704554],
       [0.65357134, 0.467749  , 0.66127278, ..., 0.51911687, 0.38043045,
        0.58622026],
       [0.56143098, 0.36105132, 0.5489754 , ..., 0.415234  , 0.406066  ,
        0.52036624],
       ...,
       [0.53110276, 0.32179304, 0.50675343, ..., 0.37955351, 0.45277256,
        0.50679397],
       [0.51958419, 0.29683962, 0.4779551 , ..., 0.36239307, 0.56338151,
        0.5213054 ],
       [0.51959236, 0.30241936, 0.48504355, ..., 0.36440399, 0.51187079,
        0.51040381]])

In [48]:
debug_g_eta

array([[0.53230047, 0.53111938, 0.52956119, ..., 0.53068484, 0.53036172,
        0.53236387],
       [0.65357134, 0.64662563, 0.6388145 , ..., 0.64885839, 0.64709297,
        0.64941506],
       [0.56143098, 0.55877346, 0.55475537, ..., 0.55796354, 0.55855264,
        0.55928558],
       ...,
       [0.53110276, 0.52998715, 0.52857988, ..., 0.52962084, 0.5291948 ,
        0.5313191 ],
       [0.51958419, 0.51937075, 0.52226115, ..., 0.52267993, 0.51752407,
        0.52482272],
       [0.51959236, 0.51920296, 0.52025937, ..., 0.52056732, 0.51782072,
        0.52254274]])

In [36]:
est_inv_beta.transpose((1,2,0)).shape

(2, 2, 10)

In [34]:
est_inv_beta.shape

(10, 2, 2)