In [1]:
import math
import random
import sys
import time

import numpy as np

sys.path.append("../")
from inference import zstates_old_method as zs_old
from utils import hyperparameters
from utils import model

# Fix the global RNG seeds so the simulated data -- and every logML computed
# from it -- is reproducible under Restart Kernel -> Run All.
# NOTE(review): assumes model.simulate draws from numpy's and/or Python's
# global random state; confirm, or pass an explicit generator if it accepts one.
SEED = 100
np.random.seed(SEED)
random.seed(SEED)

# Hyperparameters used to simulate the data.
pi_init = 0.1
mu_init = 0.0
sigmabg_init = 0.001
sigma_init = 0.03
resid_sd_init = 0.005                           # standard deviation matching tau_init
tau_init = 1 / (resid_sd_init * resid_sd_init)  # precision: 1 / tau_init == resid_sd_init ** 2

x, y, csnps, v = model.simulate(pi=pi_init,
                                mu=mu_init,
                                sigma=sigma_init,
                                sigmabg=sigmabg_init,
                                tau=tau_init)

In [2]:
# Hyperparameter vector, in the order (pi, mu, sigma, sigmabg, tau) that the
# likelihood cells unpack.  Reuse pi_init instead of repeating the literal 0.1
# so the two values cannot drift apart.
params = np.array([pi_init, mu_init, sigma_init, sigmabg_init, tau_init])
nhyp = len(params)
nvar = x.shape[0]      # rows of x
nsample = x.shape[1]   # columns of x
scaledparams = hyperparameters.scale(params)
# NOTE(review): the positional arguments 1 and 0.98 are interpreted by
# zstates_old_method.create; give them named constants once confirmed.
zstates = zs_old.create(scaledparams, x, y, 1, nvar, 0.98)

In [48]:
# Hyperparameters at which the marginal likelihood is evaluated below.  pi and
# tau differ from the simulation values: tau_init / 900 corresponds to a
# standard deviation 30 times larger than the one used to simulate.
pi_eval = 0.005
tau_eval = tau_init / 900
params = np.array([pi_eval, mu_init, sigma_init, sigmabg_init, tau_eval])

In [49]:
# Direct evaluation of the log marginal likelihood
#
#     log p(y | theta) = log sum_{z in zstates} P(z | theta) N(y | Mz, Sz),
#     Mz = x^T muz,    Sz = x^T diag(sigz) x + I / tau,
#
# where muz[j] = mu and sigz[j] = sigma^2 for j in z, and muz[j] = 0,
# sigz[j] = sigmabg^2 otherwise.  Every Sz is factorised from scratch, so this
# is the slow but straightforward baseline.
# NOTE: the printed value is the *negated* log marginal likelihood.
start_time = time.time()

pi = params[0]
mu = params[1]
sigma = params[2]
sigmabg = params[3]
tau = params[4]
nvar = x.shape[0]
nsample = x.shape[1]
constpi = math.pi

sigma2 = sigma * sigma
sigmabg2 = sigmabg * sigmabg

# Loop-invariant pieces of log P(z | theta) and of the Gaussian normaliser.
log_pi = np.log(pi)
log_1mpi = np.log(1 - pi)
nlog2pi = nsample * np.log(2 * constpi)

lmlzlist = list()
for z in zstates:
    nz = len(z)
    muz = np.zeros(nvar)
    muz[z] = mu
    sigz = np.repeat(sigmabg2, nvar)
    sigz[z] = sigma2

    # log P(z | theta): the nz indices in z contribute log(pi) each, the
    # remaining nvar - nz contribute log(1 - pi) each.
    log_probz = nz * log_pi + (nvar - nz) * log_1mpi

    # log N(y | Mz, Sz).  Scaling the columns of x.T by sigz equals
    # x.T @ diag(sigz) without materialising a dense nvar x nvar matrix.
    Sz = np.dot(x.T * sigz, x)
    Sz[np.diag_indices_from(Sz)] += 1 / tau
    Mz = np.dot(muz, x)
    y_minus_M = y - Mz
    # Cholesky factor instead of slogdet + inv: numerically stabler, avoids an
    # explicit inverse, and a non-positive-definite Sz raises LinAlgError
    # rather than being accepted with a wrong-signed determinant.
    chol = np.linalg.cholesky(Sz)
    logdetS = 2 * np.sum(np.log(np.diag(chol)))
    whitened = np.linalg.solve(chol, y_minus_M)
    nterm = np.dot(whitened, whitened)
    log_normz = -0.5 * (logdetS + nlog2pi + nterm)

    lmlzlist.append(log_probz + log_normz)

# log-sum-exp over configurations, shifted by the maximum for stability.
lmlzarr = np.array(lmlzlist)
logk = np.max(lmlzarr)
marglik_k = np.sum(np.exp(lmlzarr - logk))
logmarglik = logk + np.log(marglik_k)

print("--- {:f} seconds ---\n".format(time.time() - start_time))
print("-logML = {:f}\n".format(-logmarglik))

--- 0.915421 seconds ---

logML = -141.226272



In [7]:
start_time = time.time()

def mat3mul(A, B, C):
    """Return the chained matrix product A . B . C.

    The right-hand product B . C is formed first, i.e. the result is
    A . (B . C).
    """
    BC = np.dot(B, C)
    return np.dot(A, BC)

# Fast evaluation of
#     log p(y | theta) = log sum_{z in zstates} P(z | theta) N(y | Mz, Sz),
#     Sz = x^T diag(sigz) x + I / tau,
# using B_z = diag(1 / sigz) + tau x x^T and the identities
#     Sz^{-1} = tau I - tau^2 x^T B_z^{-1} x                      (Woodbury)
#     log|Sz| = -nsample log(tau) + sum_j log(sigz[j]) + log|B_z|   (det. lemma)
# B_0 (every index at sigmabg^2) is inverted once.  Moving index j from
# sigmabg^2 to sigma^2 adds h = 1/sigma^2 - 1/sigmabg^2 to B[j, j], so each
# state is reached from B_0 by one Sherman-Morrison rank-1 update per j in z.
# NOTE: the printed value is the *negated* log marginal likelihood.
pi = params[0]
mu = params[1]
sigma = params[2]
sigmabg = params[3]
tau = params[4]
nvar = x.shape[0]
nsample = x.shape[1]
constpi = math.pi
lmlzlist = list()
BZinvlist = list()
Sinvlist = list()

sigma2 = sigma * sigma
sigmabg2 = sigmabg * sigmabg
h = 1 / sigma2 - 1 / sigmabg2

# Loop-invariant terms.
log_pi = np.log(pi)
log_1mpi = np.log(1 - pi)
log_sigma2 = np.log(sigma2)
log_sigmabg2 = np.log(sigmabg2)
neg_nlogtau = -nsample * np.log(tau)
nlog2pi = nsample * np.log(2 * constpi)
tauI = tau * np.identity(nsample)

# B_0 for the empty configuration (all indices at sigmabg^2).
sigz0 = np.repeat(sigmabg2, nvar)
B0 = np.diag(1 / sigz0) + tau * np.dot(x, x.T)
B0inv = np.linalg.inv(B0)
logB0det = np.linalg.slogdet(B0)[1]

for z in zstates:
    nz = len(z)

    # Rank-1 updates starting from B_0.  An empty z leaves B_0 untouched, so
    # the empty configuration needs no special case, wherever it appears in
    # zstates, and never reuses BZinv / logBZdet from a previous state.
    BZinv = B0inv
    logBZdet = logB0det
    Mz = np.zeros(nsample)
    for zpos in z:
        denom = 1 + h * BZinv[zpos, zpos]
        BZinv = BZinv - (h / denom) * np.outer(BZinv[:, zpos], BZinv[zpos, :])
        logBZdet = logBZdet + np.log(denom)
        Mz += mu * x[zpos, :]
    y_minus_M = y - Mz

    # log P(z | theta)
    log_probz = nz * log_pi + (nvar - nz) * log_1mpi

    # log N(y | Mz, Sz) via the Woodbury identity and determinant lemma.
    Sinv = tauI - tau * tau * mat3mul(x.T, BZinv, x)
    logdetS = neg_nlogtau + (nvar - nz) * log_sigmabg2 + nz * log_sigma2 + logBZdet
    nterm = np.dot(y_minus_M, np.dot(Sinv, y_minus_M))
    log_normz = -0.5 * (logdetS + nlog2pi + nterm)

    lmlzlist.append(log_probz + log_normz)
    BZinvlist.append(BZinv)
    Sinvlist.append(Sinv)

# log-sum-exp over configurations, shifted by the maximum for stability.
lmlzarr = np.array(lmlzlist)
logk = np.max(lmlzarr)
marglik_k = np.sum(np.exp(lmlzarr - logk))
logmarglik = logk + np.log(marglik_k)

print("--- {:f} seconds ---\n".format(time.time() - start_time))
print("-logML = {:f}\n".format(-logmarglik))

--- 0.411265 seconds ---

logML = 6087.048823

