In [1]:
import numpy as np
from scipy.special import gammaln, digamma

In [2]:
import bnpy
# Re-executes the top-level bnpy package module (Python 2 builtin); note that
# submodules are not reloaded by this call.
reload(bnpy)

# Short aliases for the two HDPTopicModel summary-statistic routines used below:
# the regular one, and the one that handles only expansion (new) topics.
_HDPTopicModel = bnpy.allocmodel.topics.HDPTopicModel
calcSummaryStats = _HDPTopicModel.calcSummaryStats
calcXSummaryStats = _HDPTopicModel.calcSummaryStats_expansion

In [3]:
# Scalar multiplier on the per-document topic probabilities beta when the
# cells below form theta = DocTopicCount + alpha * beta.
alpha = 1.

# Two simultaneous splits (one of topic 2, another of topic 3), each with a truncation that grows from document to document.

We illustrate a split in four parts.

1. First, showing example local parameters for a sample current configuration.
2. Second, we show the proposed local parameters. *The construction method is left for later work; we only discuss the constraints the proposed parameters must satisfy relative to the originals.*
3. Third, we show how to (trivially) obtain the relevant sufficient statistics for the proposal, via direct calculation. This is easy and affordable for small datasets, but in batch-by-batch processing we cannot touch all documents at once.
4. Finally, we show that by collecting batch-specific statistics, aggregating them across batches, and then manipulating the resulting sufficient statistics, we can create whole-dataset proposal statistics **identical** to those of the direct method in step #3 (with the exception of the `sumLogPiRemVec` field, discussed at the end).


### STEP 1: Original parameters

In [4]:
# Toy corpus of 3 documents. Word content does not matter: only the upper-level
# inference of topic probabilities is examined, so every token is word 0.
Data = bnpy.data.WordsData(
    vocab_size=1,
    word_id=np.zeros(3),     # the single vocabulary word
    word_count=np.ones(3),   # one count per token
    doc_range=np.arange(4),  # presumably doc d spans doc_range[d]:doc_range[d+1]
)

Current truncation: K=3 topics. For each of the 3 documents, we show its topic probabilities beta over the active topics (one row per document) and its document-topic assignment counts.

In [5]:
# Per-document probabilities of the K=3 active topics (one row per document).
# Each row sums to less than 1; the leftover is the remainder mass used below.
curDoc_beta = np.array([
    [0.2, 0.2, 0.2],
    [0.4, 0.1, 0.1],
    [0.1, 0.1, 0.1],
])

# Document-topic assignment counts (one row per document, one column per topic).
curDocTopicCount = np.array([
    [10, 20, 10],
    [50, 30, 0],
    [5, 10, 9],
])

In [6]:
# Build the current local parameters (curLP) from beta and the counts.
# The remainder mass 1 - sum_k beta_dk becomes thetaRem, which enters the
# digamma of the per-document total below; it must be strictly positive.
curDoc_betaRem = 1 - curDoc_beta.sum(axis=1)
assert np.all(curDoc_betaRem > 0), 'each row of curDoc_beta must sum to less than 1'
curTheta = curDocTopicCount + alpha * curDoc_beta
curThetaRem = alpha * curDoc_betaRem
# Per-document total = sum over active topics + THAT document's remainder.
# Sum over topics first: adding the (nDoc,) vector curThetaRem directly to the
# (nDoc, K) matrix curTheta would broadcast it across columns, not rows.
curdigammaSumTheta = digamma(curTheta.sum(axis=1) + curThetaRem)
curLP = dict(
    DocTopicCount=curDocTopicCount,
    theta=curTheta,
    thetaRem=curThetaRem,
    resp=np.random.rand(Data.nUniqueToken, curTheta.shape[1]),  # values irrelevant here
    digammaSumTheta=curdigammaSumTheta,
    )

In [7]:
# Whole-dataset summary statistics of the current (K=3) configuration.
# The flags are passed through to bnpy: precompute entropy terms, and track
# truncation growth (presumably what fills sumLogPiRemVec -- confirm in bnpy).
curSS = calcSummaryStats(
    Data,
    curLP,
    doPrecompEntropy=1,
    doTrackTruncationGrowth=1,
)

### STEP 2: Create proposal for expansion of topic 2 and topic 3

In [17]:
# Split of topic 2 (column index 1 of curDoc_beta and curDocTopicCount)

# K_d_2[d] = number of new topics that replace topic 2 in document d.
# The step-3 loop builds each document's combined truncation as
# 1 + K_d_2[d] + K_d_3[d] (topic 1 kept, topics 2 and 3 each replaced by their
# new topics). Split 2 grows from 2 new topics at doc 1, to 3 at doc 2,
# to 4 at doc 3.
K_d_2 = np.asarray([2, 3, 4])

xDoc_beta_2 = np.asarray([
    [0.09, 0.11, 0,    0],     # sums to 0.2
    [0.07, 0.01, 0.02, 0],     # sums to 0.1
    [0.02, 0.02, 0.02, 0.04],  # sums to 0.1
    ])
xDocTopicCount_2 = np.asarray([
    [10, 10, 0,  0],  # sums to 20
    [5,  0,  25, 0],  # sums to 30
    [0,  3,  3,  4],  # sums to 10
    ])
# A split must preserve each document's mass on the original topic.
assert np.allclose(xDoc_beta_2.sum(axis=1), curDoc_beta[:, 1])
assert np.allclose(xDocTopicCount_2.sum(axis=1), curDocTopicCount[:, 1])
# The step-3 loop keeps only the first K_d_2[d] columns of row d, so any mass
# beyond that truncation would be dropped silently: require zeros there.
assert K_d_2.max() <= xDoc_beta_2.shape[1]
beyondTrunc_2 = np.arange(xDoc_beta_2.shape[1])[np.newaxis, :] >= K_d_2[:, np.newaxis]
assert np.all(xDoc_beta_2[beyondTrunc_2] == 0)
assert np.all(xDocTopicCount_2[beyondTrunc_2] == 0)

In [18]:
# Split of topic 3 (column index 2 of curDoc_beta and curDocTopicCount)

# K_d_3[d] = number of new topics that replace topic 3 in document d.
# The step-3 loop builds each document's combined truncation as
# 1 + K_d_2[d] + K_d_3[d] (topic 1 kept, topics 2 and 3 each replaced by their
# new topics). Split 3 uses 2 new topics at docs 1 and 2, and 4 at doc 3.
K_d_3 = np.asarray([2, 2, 4])

xDoc_beta_3 = np.asarray([
    [0.1,  0.10, 0,    0],     # sums to 0.2
    [0.07, 0.03, 0,    0],     # sums to 0.1
    [0.02, 0.02, 0.03, 0.03],  # sums to 0.1
    ])

xDocTopicCount_3 = np.asarray([
    [3, 7, 0, 0],  # sums to 10
    [0, 0, 0, 0],  # sums to 0
    [0, 3, 3, 3],  # sums to 9
    ])
# A split must preserve each document's mass on the original topic.
assert np.allclose(xDoc_beta_3.sum(axis=1), curDoc_beta[:, 2])
assert np.allclose(xDocTopicCount_3.sum(axis=1), curDocTopicCount[:, 2])
# The step-3 loop keeps only the first K_d_3[d] columns of row d, so any mass
# beyond that truncation would be dropped silently: require zeros there.
assert K_d_3.max() <= xDoc_beta_3.shape[1]
beyondTrunc_3 = np.arange(xDoc_beta_3.shape[1])[np.newaxis, :] >= K_d_3[:, np.newaxis]
assert np.all(xDoc_beta_3[beyondTrunc_3] == 0)
assert np.all(xDocTopicCount_3[beyondTrunc_3] == 0)

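### STEP 3: Create suff stats for the *combined* proposal, expanding topic 2 and topic 3 directly

The same loop also accumulates, separately for each split, the expansion-only statistics used in STEP 4.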

In [23]:
uid_propTotalOrder = np.asarray([0, 200, 201, 300, 301, 202, 203, 302, 303])

uids_2 = [200,201,202,203]
uids_3 = [300,301,302,303]

# Loop over all docs
for d in range(Data.nDoc):
    # Grab the single document
    Data_b = Data.select_subset_by_mask([d])
    K_d = 1 + K_d_2[d] + K_d_3[d]
    
    # Do local and summary step for full combination of both splits
    propDocTopicCount_d = np.hstack([curDocTopicCount[d, :1], 
                                    xDocTopicCount_2[d, :K_d_2[d]],
                                    xDocTopicCount_3[d, :K_d_3[d]],
                                    ])[np.newaxis,:]
    propDoc_beta_d = np.hstack([curDoc_beta[d, :1],
                               xDoc_beta_2[d, :K_d_2[d]],
                               xDoc_beta_3[d, :K_d_3[d]], 
                               ])[np.newaxis,:]
    propLP_2and3_b = dict(
        DocTopicCount=propDocTopicCount_d,
        theta =propDocTopicCount_d+propDoc_beta_d,
        resp=np.random.rand(1, K_d), # doesnt matter
        thetaRem=curThetaRem[d],
        )
    propSS_2and3_b = calcSummaryStats(Data_b, propLP_2and3_b, doPrecompEntropy=1, doTrackTruncationGrowth=1)
    propSS_2and3_b.setUIDs([0] + uids_2[:K_d_2[d]] + uids_3[:K_d_3[d]])
    propSS_2and3_b.reorderComps(uids=uid_propTotalOrder[:K_d], fieldsToIgnore=['sumLogPiRemVec'])
    if d == 0:
        propSS_2and3 = propSS_2and3_b.copy()
    else:        
        Kextra = propSS_2and3_b.K - propSS_2and3.K
        if Kextra > 0:
            propSS_2and3.insertEmptyComps(Kextra, uid_propTotalOrder[K_d-Kextra:K_d])            
        propSS_2and3 += propSS_2and3_b

    # Do local and summary step using only expansion terms from state 2
    propLP_newonly_2b = dict(
        DocTopicCount=xDocTopicCount_2[d,:K_d_2[d]][np.newaxis,:],
        theta=xDocTopicCount_2[d, :K_d_2[d]][np.newaxis,:] + alpha * xDoc_beta_2[d,:K_d_2[d]][np.newaxis,:],
        resp=np.random.rand(1, K_d_2[d]), # doesnt matter
        thetaRem=curThetaRem[d],
        )
    propLP_newonly_2b['digammaSumTheta'] = np.asarray([curLP['digammaSumTheta'][d]])
    propSS_newonly_2b = calcXSummaryStats(Data_b, propLP_newonly_2b, uids=uids_2[:K_d_2[d]],
                                          doPrecompEntropy=1, doTrackTruncationGrowth=1)
    # Do local and summary step using only expansion terms from state 3
    propLP_newonly_3b = dict(
        DocTopicCount=xDocTopicCount_3[d,:K_d_3[d]][np.newaxis,:],
        theta=xDocTopicCount_3[d,:K_d_3[d]][np.newaxis,:] + alpha * xDoc_beta_3[d,:K_d_3[d]][np.newaxis,:],
        resp=np.random.rand(1, K_d_3[d]), # doesnt matter
        thetaRem=curThetaRem[d],
        )
    propLP_newonly_3b['digammaSumTheta'] = np.asarray([curLP['digammaSumTheta'][d]])
    propSS_newonly_3b = calcXSummaryStats(Data_b, propLP_newonly_3b,  uids=uids_3[:K_d_3[d]],
                                          doPrecompEntropy=1, doTrackTruncationGrowth=1)
    if d == 0:
        propSS_newonly_2 = propSS_newonly_2b.copy()
        propSS_newonly_3 = propSS_newonly_3b.copy()
    else:
        Kextra = propSS_newonly_2b.K - propSS_newonly_2.K
        if Kextra > 0:
            propSS_newonly_2.insertEmptyComps(Kextra, newuids=uids_2[K_d_2[d-1]:K_d_2[d]])
        propSS_newonly_2 += propSS_newonly_2b
        
        Kextra = propSS_newonly_3b.K - propSS_newonly_3.K
        if Kextra > 0:
            propSS_newonly_3.insertEmptyComps(Kextra, newuids=uids_3[K_d_3[d-1]:K_d_3[d]])
        propSS_newonly_3 += propSS_newonly_3b
        
    print ''
    print 'sumLogPiRem for doc %d, K[d]=%d' % (d+1, K_d)
    print '------'
    print '2and3 : ', propSS_2and3_b.sumLogPiRemVec
    print ' only2: ', propSS_newonly_2b.sumLogPiRemVec
    print ' only3: ', propSS_newonly_3b.sumLogPiRemVec


sumLogPiRem for doc 1, K[d]=5
------
2and3 :  [ 0.    0.    0.    0.   -6.26]
 only2:  [ 0.   -6.26]
 only3:  [ 0.   -6.26]

sumLogPiRem for doc 2, K[d]=6
------
2and3 :  [ 0.    0.    0.    0.    0.   -6.95]
 only2:  [ 0.    0.   -6.95]
 only3:  [ 0.   -6.95]

sumLogPiRem for doc 3, K[d]=9
------
2and3 :  [ 0.    0.    0.    0.    0.    0.    0.    0.   -4.42]
 only2:  [ 0.    0.    0.   -4.42]
 only3:  [ 0.    0.    0.   -4.42]


In [24]:
propSS_directFromLP = propSS_2and3
print propSS_directFromLP.uids
print uid_propTotalOrder
assert np.allclose(uid_propTotalOrder, propSS_directFromLP.uids)

[  0 200 201 300 301 202 203 302 303]
[  0 200 201 300 301 202 203 302 303]


In [25]:
# Inspect the aggregated sumLogPi and sumLogPiRemVec of the directly
# constructed combined proposal...
print propSS_2and3.sumLogPi
print propSS_2and3.sumLogPiRemVec

# ...and the sumLogPiRemVec tracked for the split of topic 2 alone.
print propSS_newonly_2.sumLogPiRemVec

[  -3.58  -58.05 -108.66  -75.62  -42.33   -3.46   -1.93   -2.26   -2.26]
[ 0.    0.    0.    0.   -6.26 -6.95  0.    0.   -4.42]
[ 0.   -6.26 -6.95 -4.42]


### STEP 4: Create sufficient stats for combined proposal by manipulating tracked stats

In [26]:
# Start from the current whole-dataset stats and swap in each split's tracked
# expansion-only stats in place of the original component (uid 1 and uid 2,
# i.e. topics 2 and 3, presuming curSS carries default uids 0..K-1).
propSS_fromXSS = curSS.copy()
propSS_fromXSS.replaceCompWithExpansion(uid=1, xSS=propSS_newonly_2)
propSS_fromXSS.replaceCompWithExpansion(uid=2, xSS=propSS_newonly_3)
# Derive the permutation into the combined order from the uids themselves,
# rather than hard-coding an index list that silently breaks if the splits change.
xUIDList = list(propSS_fromXSS.uids)
assert len(xUIDList) == len(uid_propTotalOrder)
order = [xUIDList.index(uid) for uid in uid_propTotalOrder]
propSS_fromXSS.reorderComps(order)
assert np.allclose(propSS_fromXSS.uids, uid_propTotalOrder)
assert np.allclose(propSS_directFromLP.uids, uid_propTotalOrder)

In [27]:
np.set_printoptions(linewidth=100, precision=2)
for key in ['sumLogPi', 'gammalnTheta', 'gammalnSumTheta', 'slackTheta', 'slackThetaRem']:
    print key
    if hasattr(propSS_fromXSS, key):
        arr_directFromLP = getattr(propSS_directFromLP, key)
        arr_fromXSS = getattr(propSS_fromXSS, key)
        
    elif propSS_fromXSS.hasELBOTerm(key):
        arr_directFromLP = propSS_directFromLP.getELBOTerm(key)
        arr_fromXSS = propSS_fromXSS.getELBOTerm(key)
        
    print '  %s  direct construction proposal' % (arr_directFromLP)
    print '  %s  from tracked expansion stats' % (arr_fromXSS)
    assert np.allclose(arr_fromXSS, arr_directFromLP)

sumLogPi
  [  -3.58  -58.05 -108.66  -75.62  -42.33   -3.46   -1.93   -2.26   -2.26]  direct construction proposal
  [  -3.58  -58.05 -108.66  -75.62  -42.33   -3.46   -1.93   -2.26   -2.26]  from tracked expansion stats
gammalnTheta
  [ 162.71   20.19   18.36    7.31   10.97   55.56    1.84    0.72    0.72]  direct construction proposal
  [ 162.71   20.19   18.36    7.31   10.97   55.56    1.84    0.72    0.72]  from tracked expansion stats
gammalnSumTheta
  438.778493399  direct construction proposal
  438.778493399  from tracked expansion stats
slackTheta
  [ 0.64  1.41  1.25  2.69  1.37  0.07  0.08  0.07  0.07]  direct construction proposal
  [ 0.64  1.41  1.25  2.69  1.37  0.07  0.08  0.07  0.07]  from tracked expansion stats
slackThetaRem
  8.37808027748  direct construction proposal
  8.37808027748  from tracked expansion stats


In [28]:
# sumLogPiRemVec of the directly constructed proposal (STEP 3); compare with the
# version built from tracked expansion stats in the next cell.
propSS_directFromLP.sumLogPiRemVec

array([ 0.  ,  0.  ,  0.  ,  0.  , -6.26, -6.95,  0.  ,  0.  , -4.42])

In [29]:
# sumLogPiRemVec of the proposal built from tracked expansion stats (STEP 4).
propSS_fromXSS.sumLogPiRemVec

array([  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  , -13.21,   0.  ,  -4.42])

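There is a complication in tracking how the truncation varies across batches: the two `sumLogPiRemVec` vectors above disagree.

- Direct construction records each document's remainder term at the last position of that document's combined truncation: indices 4, 5 and 8, for K[d] = 5, 6 and 9.
- The version built from tracked expansion stats puts -13.21 = -6.26 + -6.95 at index 6.
- Both vectors have the same total (-17.63).

It is therefore not possible to obtain the right per-position statistics purely by keeping an incrementally aggregated whole-dataset statistic for each split separately and then merging.

However, summing over all batches would do the right thing.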