In [1]:
# import packages
import numpy as np
import seaborn as sns
from scipy.stats import wishart, dirichlet, expon, norm
import scipy.special as sc
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D 

In [2]:
def para_x(u, s2, tao):
    lambdax = np.dot(s2.T, u).reshape(d,d)
    sx = np.linalg.inv(lambdax) # variance of x
    taox = np.dot(tao.T, u) # sum_j(u_j * tao_j)
    mux = np.dot(sx, taox) # mu(x)
    return sx, mux 
def document_generator(a, rho, T, s2, tao, N):
    '''
    Given the corpus, generate more documents.
    All corpus-level parameters are given.
    N: the number of documents.
    
    output: 
    X: N*d, X[i] = document[i]
    Y: Y[i] = label[i]
    G: membership
    U: transformed membership
    '''

    nlabel = len(T) # number of y
    d = len(tao[0]) # dim(x)
    
    Y = np.random.choice(list(range(nlabel)),N) # labels
    G = np.random.dirichlet(a*rho,N)
    U = np.array([np.dot(T[Y[i]], G[i]) for i in range(N)])

    X = []
    
    for i in range(N):
        sx, mux = para_x(U[i], s2, tao)
        X.append(np.random.multivariate_normal(mux, sx))
    X = np.array(X)

    return X, Y, G, U

def check_p(a, rho, mu, s_inv):
    '''
    Input: the parameters used in the data
    1. Check whether a and rho will generate extremely small g
    2. See distributions of different topics
    '''
    ntopic = len(mu)
    g = np.random.dirichlet(a*rho, 100)
#     print("The minimum component of g is",g.min())
    
    print("Distribution of pure types")
    s = [np.linalg.inv(i) for i in s_inv]
    pX = np.concatenate(tuple([np.random.multivariate_normal(mu[i], s[i], 100) for i in range(ntopic)]))
    pY = [[i]*100 for i in range(ntopic)]
    fig_pure_type = plt.figure()
    ax = Axes3D(fig_pure_type)
    ax.scatter(pX[:, 0], pX[:, 1], pX[:, 2], c = pY)
    plt.show()

In [16]:
def simplex_proposal(g, sigma):
    '''
    g is in simplex
    z[i] = log g[i]/g[-1]
    draw newz~N(z, sigma*I)
    newz -> newg
    '''
    z = np.log(g[:-1]/g[-1])
    newz = np.random.multivariate_normal(z, sigma*np.eye(len(z)))
    newg = np.exp(np.append(newz,0))
    newg /= newg.sum()
    return newg

def px(x, u, s2, tao):
    # return 0.5*log|lambda_x|-0.5(x-mux)^T*lambda_x*(x-mux)
    
    lambdax = np.dot(s2.T, u).reshape(d,d)
    sx = np.linalg.inv(lambdax) # variance of x
    taox = np.dot(tao.T, u) # sum_j(u_j * tao_j)
    mux = np.dot(sx, taox) # mu(x)
    
    return 0.5*np.log(np.linalg.det(lambdax))-0.5*np.dot(np.dot((x-mux).T, lambdax), x-mux)

def predict_prob(x,y,a,rho,s2,tao,T,nconverge, nskip, nsave,sigmag):
    nlabel = len(T)
    dg = len(rho)
    ntopic = len(T[0])
    d = len(x)
    infinity = 10**(-5) # To avoid overflow, if we get g or rho smaller than this
    
    prob = []
    oldg = np.random.dirichlet(np.ones(dg))
    oldp = ((a*rho-1)*np.log(oldg)).sum() + px(x, np.dot(T[y], oldg), s2, tao)
    
    for t in range(nconverge):    
        newg = simplex_proposal(oldg, sigmag)
        if newg.min()<infinity: newg = (newg + infinity*np.ones(dg))/(1+infinity)
        newp = ((a*rho-1)*np.log(newg)).sum() + px(x, np.dot(T[y], newg), s2, tao)
        ag = min(1, np.exp(min(1,newp-oldp))) # avoid overflow when newp>>oldp
        if np.random.uniform() < ag:
            oldg = newg
            oldp = newp
            
        
    for step in range(nsave):
        for t in range(nskip):
            newg = simplex_proposal(oldg, sigmag)
            if newg.min()<infinity: newg = (newg + infinity*np.ones(dg))/(1+infinity)
            newp = ((a*rho-1)*np.log(newg)).sum() + px(x, np.dot(T[y], newg), s2, tao)
            ag = min(1, np.exp(min(1,newp-oldp))) # avoid overflow when newp>>oldp
            if np.random.uniform() < ag:
                oldg = newg
                oldp = newp  
        prob.append(newp) 
    
    return sum(prob)/nsave
# sss=[1,2,3]
# print(mean(sss))

In [13]:
d = 30
k0, k1 = 2,5
nlabel = 3
alpha = np.ones(7)
b = 0.1
para_topic = [np.zeros(d),0.1,np.eye(d)/3/d, 3*d]
N, sigmaa, sigmarho, sigmag, nconverge, nskip, nsave =  50, 1,1,1,10, 10,500
# define T
dg = k0 + k1 
ntopic = nlabel*k0+k1
T = []
for i in range(nlabel):
    tem = np.block([
        [np.zeros((k0*i,k0+k1))],
        [np.eye(k0), np.zeros((k0, k1))],
        [np.zeros((k0*(nlabel-i-1),k0+k1))],
        [np.zeros((k1,k0)), np.eye(k1)]
    ])
    T.append(tem)
# draw corpus-level parameters
rho = np.random.dirichlet(alpha, 1)[0]
a = np.random.exponential(1/b,1)[0]

mu0, lambda0, W, nu = para_topic
d = len(mu0)
s_inv = wishart.rvs(df = nu, scale = W, size=ntopic) # sigma inverse
mu = np.array([np.random.multivariate_normal(mu0, 1/lambda0*np.linalg.inv(i)) for i in s_inv])    
s2 = np.array([i.flatten() for i in s_inv]) # flatten s_inv, s2[i].reshape(d,d) = s_inv[i]
tao = np.array([np.dot(s_inv[i], mu[i]) for i in range(ntopic)]) # sigma^-1 mu    
# generate data
X, Y, G, U = document_generator(a, rho, T, s2, tao, N)
print("the value of a:", a)

the value of a: 13.24487022040325


In [17]:
prediction = np.zeros((N,nlabel))
for i in range(N):
    for j in range(nlabel):
        prediction[i][j] = predict_prob(X[i],j,a,rho,s2,tao,T,nconverge=500, nskip=10, nsave=100,sigmag=0.05)

In [21]:
pred = np.exp(prediction)
for i in range(N): pred[i]/=pred[i].sum()

In [29]:
trueY = np.zeros((N,nlabel))
# print(Y)
for i in range(N): trueY[i][Y[i]] = 1
# print(trueY)
# print(pred)
error = trueY-pred
error.sum()

5.999026145633857e-17

In [30]:
prediction = np.zeros((N,nlabel))
for i in range(N):
    for j in range(nlabel):
        prediction[i][j] = predict_prob(X[i],j,5,rho,s2,tao,T,nconverge=500, nskip=10, nsave=100,sigmag=0.05)
pred = np.exp(prediction)
for i in range(N): pred[i]/=pred[i].sum()
trueY = np.zeros((N,nlabel))
for i in range(N): trueY[i][Y[i]] = 1
error = trueY-pred
error.sum()

-4.673930909012429e-16

In [31]:
prediction = np.zeros((N,nlabel))
for i in range(N):
    for j in range(nlabel):
        prediction[i][j] = predict_prob(X[i],j,1,rho,s2,tao,T,nconverge=500, nskip=10, nsave=100,sigmag=0.05)
pred = np.exp(prediction)
for i in range(N): pred[i]/=pred[i].sum()
trueY = np.zeros((N,nlabel))
for i in range(N): trueY[i][Y[i]] = 1
error = trueY-pred
error.sum()

-1.1102230246251565e-16