In [None]:
import numpy as np
import time
import sys
from random import randint, sample
from datetime import datetime
import tensorflow as tf
import tensorflow_probability as tfp
from tqdm import tqdm, trange
from matplotlib import pyplot as plt
tfd, tfb = tfp.distributions, tfp.bijectors
tfm, tfla, tfr, tfs = tf.math, tf.linalg, tf.random, tf.sparse

In [None]:
def kron(A, B):
    """Kronecker-style block product of two 2-D tensors.

    With this argument order the result equals ``np.kron(B, A)``: for ``A`` of
    shape ``(a0, a1)``, entry ``[i*a0 + k, j*a1 + l]`` is ``B[i, j] * A[k, l]``.
    For example ``kron(T'T, iV)`` yields ``np.kron(iV, T'T)``.

    Returns:
        Tensor of shape ``[a0*b0, a1*b1]``.
    """
    rowsA, colsA = tf.shape(A)[0], tf.shape(A)[1]
    rowsB, colsB = tf.shape(B)[0], tf.shape(B)[1]
    # outer[i, k, j, l] = B[i, j] * A[k, l]; flattening (i, k) and (j, l) gives the blocks
    outer = B[:, None, :, None] * A[None, :, None, :]
    return tf.reshape(outer, [rowsA * rowsB, colsA * colsB])

In [None]:
def updateAlpha(params, dtype=np.float64):
    """Sample the spatial-scale grid index ``Alpha`` of every factor.

    For each random level with ``sDim[r] > 0`` and each factor h, an index g is
    drawn from the discrete conditional posterior with unnormalised log-weights

        log alphapw[g, 1] - 0.5*detWg[g] - 0.5*||LiWg[g] @ Eta[:, h]||^2.

    Non-spatial levels get all-zero indices.

    Args:
        params: dict with keys 'Eta', 'sDim', 'alphapw', 'LiWg', 'detWg'.
            NOTE(review): the quadratic form equals Eta_h' iWg_g Eta_h only if
            LiWg[g] is an upper factor R with R'R = iWg[g] (what R's chol()
            returns), and detWg[g] is presumably log|Wg[g]| — confirm.
        dtype: unused; kept for a uniform updater signature.

    Returns:
        List (one entry per level) of int64 tensors of shape [nf, 1].
    """
    EtaList = params['Eta']
    sDim = params['sDim']
    alphapwList = params['alphapw']
    LiWgList = params['LiWg']
    detWgList = params['detWg']

    nr = len(EtaList)
    AlphaList = [None] * nr
    for r, (Eta, LiWg, detWg, alphapw) in enumerate(zip(EtaList, LiWgList, detWgList, alphapwList)):
        nf = tf.shape(Eta)[1]
        if sDim[r] > 0:
            # [n_grid, nf]: squared norm of LiWg[g] @ Eta[:, h] for every grid point g
            EtaTiWEta = tf.reduce_sum(tf.matmul(LiWg, Eta)**2, axis=1)
            # [nf, n_grid] unnormalised log posterior weights
            logLike = tfm.log(alphapw[:, 1]) - 0.5*detWg - 0.5*tfla.matrix_transpose(EtaTiWEta)
            # tf.random.categorical takes *logits*; previously normalised probabilities
            # were passed here, which sampled from softmax(probabilities) instead.
            AlphaList[r] = tfr.categorical(logLike, 1, dtype=tf.int64)
        else:
            AlphaList[r] = tf.zeros([nf, 1], tf.int64)

    return AlphaList

In [None]:
def updateEta(params, dtype=np.float64):
    """Gibbs update of the latent factors ``Eta`` of every random level.

    Levels are updated one at a time. After level r is redrawn, its contribution
    ``Eta_r[Pi[:, r]] @ Lambda_r`` is refreshed, so later levels condition on the
    new value.

    Args:
        params: dict with keys 'sigma', 'sDim', 'Pi', 'Z', 'BetaLambda'
            (dict with 'Beta' and 'Lambda'), 'Eta', 'Alpha', 'X', 'iWg'.
        dtype: floating dtype forwarded to the per-level samplers (previously
            accepted but ignored, so the helpers always used float64).

    Returns:
        List with the new ``Eta`` tensor for each random level.
    """
    sigma = params['sigma']

    sDim = params['sDim']
    Pi = params['Pi']

    Z = params['Z']
    Beta = params['BetaLambda']['Beta']
    EtaList = params['Eta']
    LambdaList = params['BetaLambda']['Lambda']
    AlphaList = params['Alpha']
    X = params['X']
    iWgList = params['iWg']

    # number of units per level (Pi holds unit indices per level)
    npVec = tf.reduce_max(Pi, 0) + 1

    nr = len(LambdaList)

    LFix = tf.matmul(X, Beta)

    # current contribution of each random level to the linear predictor
    LRanLevelList = [None] * nr
    for r, (Eta, Lambda) in enumerate(zip(EtaList, LambdaList)):
        LRanLevelList[r] = tf.matmul(tf.gather(Eta, Pi[:, r]), Lambda)

    # residual precisions broadcast to the shape of Z
    iD = tf.ones_like(Z) * sigma**-2

    EtaListNew = [None] * nr

    for r, (Eta, Lambda, Alpha, iWg) in enumerate(
        zip(EtaList, LambdaList, AlphaList, iWgList)
    ):
        # partial residual: Z minus fixed effects and all *other* random levels
        S = (
            Z
            - LFix
            - sum([LRanLevelList[rInd] for rInd in np.setdiff1d(np.arange(nr), r)])
        )

        nf = tf.cast(tf.shape(Lambda)[-2], tf.int64)

        # per unit: sum over its rows of Lambda diag(iD[row]) Lambda'
        # (scatter_nd accumulates rows mapped to the same unit)
        LamInvSigLam = tf.scatter_nd(
            Pi[:, r, None],
            tf.einsum("hj,ij,kj->ihk", Lambda, iD, Lambda),
            tf.stack([npVec[r], nf, nf]),
        )

        # per unit: sum over its rows of Lambda diag(iD[row]) S[row]
        mu0 = tf.scatter_nd(
            Pi[:, r, None],
            tf.matmul(iD * S, Lambda, transpose_b=True),
            tf.stack([npVec[r], nf]),
        )

        if sDim[r] > 0:
            Eta = modelSpatialFull(
                Eta, Lambda, LamInvSigLam, mu0, Alpha, iWg, npVec[r], nf, dtype=dtype
            )
        else:
            Eta = modelNonSpatial(
                Eta, Lambda, LamInvSigLam, mu0, npVec[r], nf, dtype=dtype
            )

        EtaListNew[r] = Eta

        # refresh so that later levels condition on the new draw
        LRanLevelList[r] = tf.matmul(
            tf.gather(EtaListNew[r], Pi[:, r]), Lambda
        )

    return EtaListNew

def modelSpatialFull(Eta, Lambda, LamInvSigLam, mu0, Alpha, iWg, np, nf, dtype=np.float64):
    """Draw Eta for a spatial level using a dense joint precision.

    All np*nf latent values are stacked factor-major (flat index h*np + i for
    factor h and unit i). The joint precision is
      - iWs: block diagonal, with block h = iWg[Alpha[h]];
      - plus, for every unit i, LamInvSigLam[i][h, k] placed at
        (h*np + i, k*np + i).
    A draw from N(iUEta^{-1} mu0, iUEta^{-1}) is obtained via its Cholesky factor.

    Assumes iWg has shape (n_grid, np, np) and Alpha shape (nf, 1) — TODO confirm
    against callers. ``Eta`` and ``Lambda`` are accepted but unused. The parameter
    ``np`` (number of units) shadows the numpy module inside this function.

    Returns:
        Tensor of shape [np, nf].
    """
    # gather -> [nf, np, np]; transpose -> [np, np, nf]; diag -> [np, np, nf, nf];
    # transpose -> [nf, np, nf, np]: block-diagonal over factors
    iWs = tf.reshape(
        tf.transpose(
            tfla.diag(
                tf.transpose(tf.gather(iWg, tf.squeeze(Alpha, -1)), [1, 2, 0])
            ),
            [2, 0, 3, 1],
        ),
        [nf * np, nf * np],
    )
    # [np, nf, nf] -> [nf, nf, np, np] (diagonal over units) -> [nf, np, nf, np]
    iUEta = iWs + tf.reshape(
        tf.transpose(
            tfla.diag(tf.transpose(LamInvSigLam, [1, 2, 0])), [0, 2, 1, 3]
        ),
        [nf * np, nf * np],
    )
    LiUEta = tfla.cholesky(iUEta)
    # mu0 is [np, nf]; transposed so it matches the factor-major ordering
    mu1 = tfla.triangular_solve(
        LiUEta, tf.reshape(tf.transpose(mu0), [nf * np, 1])
    )
    # L^{-T}(L^{-1} mu0 + z): mean iUEta^{-1} mu0, covariance iUEta^{-1}
    eta = tfla.triangular_solve(
        LiUEta, mu1 + tfr.normal([nf * np, 1], dtype=dtype), adjoint=True
    )
    Eta = tf.transpose(tf.reshape(eta, [nf, np]))
    return Eta

def modelNonSpatial(Eta, Lambda, LamInvSigLam, mu0, np, nf, dtype=np.float64):
    """Draw Eta for a non-spatial level, independently for each unit.

    Each unit's factor vector has an N(0, I) prior, so its conditional precision
    is I + LamInvSigLam[i] and its mean is (I + LamInvSigLam[i])^{-1} mu0[i].

    ``Eta`` and ``Lambda`` are accepted but unused. The parameter ``np`` (number
    of units) shadows the numpy module inside this function.

    Returns:
        Tensor of shape [np, nf].
    """
    iV = tf.eye(nf, dtype=dtype) + LamInvSigLam
    # iV already contains the prior identity; it was previously added a second
    # time here, which doubled the prior precision.
    LiV = tfla.cholesky(iV)
    mu1 = tfla.triangular_solve(LiV, tf.expand_dims(mu0, -1))
    # L^{-T}(L^{-1} mu0 + z) has mean iV^{-1} mu0 and covariance iV^{-1}
    Eta = tf.squeeze(
        tfla.triangular_solve(
            LiV, mu1 + tfr.normal([np, nf, 1], dtype=dtype), adjoint=True
        ),
        -1,
    )
    return Eta

In [None]:
def updateGammaV(params, dtype=np.float64):
    """Gibbs update of the trait effects ``Gamma`` and the residual precision ``iV`` of Beta.

    Steps:
      1. iV ~ Wishart(f0 + ns, (V0 + E E')^{-1}) with E = Beta - Gamma T'.
      2. vec(Gamma) ~ N(Sigma (iUGamma mGamma + vec(iV Beta T)), Sigma), where
         Sigma^{-1} = iUGamma + kron(T'T, iV) and vec is row-major over [nc, nt].

    Args:
        params: dict with 'BetaLambda' ({'Beta'}), 'GammaV' ({'Gamma'}), 'T',
            'mGamma', 'iUGamma' (prior precision of vec(Gamma)), 'V0', 'f0'.
        dtype: floating dtype of the identity and the noise draws.

    Returns:
        dict with the new 'Gamma' ([nc, nt]) and 'iV' ([nc, nc]).
    """
    Beta = params['BetaLambda']['Beta']
    Gamma = params['GammaV']['Gamma']

    T = params['T']
    mGamma = params['mGamma']
    iUGamma = params['iUGamma']
    V0 = params['V0']
    f0 = params['f0']

    nc = tf.shape(Beta)[0]
    ns = tf.shape(Beta)[1]
    nt = Gamma.shape[-1]

    # --- iV | Beta, Gamma ---
    Mu = tf.matmul(Gamma, T, transpose_b=True)
    E = Beta - Mu
    A = tf.matmul(E, E, transpose_b=True)
    Vn = tfla.cholesky_solve(tfla.cholesky(A + V0), tf.eye(nc, dtype=dtype))
    LVn = tfla.cholesky(Vn)
    iV = tfp.distributions.WishartTriL(tf.cast(f0 + ns, dtype), LVn).sample()

    # --- Gamma | Beta, iV ---
    iSigmaGamma = iUGamma + kron(tf.matmul(T, T, transpose_a=True), iV)
    L = tfla.cholesky(iSigmaGamma)
    mg0 = tf.matmul(iUGamma, mGamma[:, None]) + tf.reshape(tf.matmul(iV, tf.matmul(Beta, T)), [nc*nt, 1])
    mg1 = tfla.triangular_solve(L, mg0)
    # The N(0, I) noise makes this a posterior draw; it was previously scaled by 0,
    # so Gamma was set to the posterior mean instead of being sampled.
    noise = tfr.normal([nc*nt, 1], dtype=dtype)
    Gamma = tf.reshape(tfla.triangular_solve(L, mg1 + noise, adjoint=True), [nc, nt])
    return {'Gamma': Gamma, 'iV': iV}

In [None]:
def updateZ(params, dtype=np.float64):
    """Sample the latent response ``Z`` (data augmentation).

    Handling by distribution code in ``distr[:, 0]``:
      * 1 (normal): Z = Y where observed; missing entries ~ N(L, sigma^2).
      * 2 (probit): Albert & Chib (1993) truncated-normal augmentation. Y == 1
        gives Z > 0, Y == 0 gives Z < 0, and missing entries are unconstrained.
    Columns with any other code are left as zeros in the returned tensor.

    Args:
        params: dict with 'BetaLambda' ({'Beta', 'Lambda'}), 'Eta', 'sigma',
            'Y' (NaN marks missing), 'X', 'Pi', 'distr'.
        dtype: floating dtype of Y and of the noise draws.

    Returns:
        Tensor with the same shape as Y.
    """
    Beta = params['BetaLambda']['Beta']
    EtaList = params['Eta']
    LambdaList = params['BetaLambda']['Lambda']
    sigma = params['sigma']

    Y = params['Y']
    X = params['X']
    Pi = params['Pi']
    distr = params['distr']

    ny, ns = Y.shape
    LFix = tf.matmul(X, Beta)
    LRanLevelList = [
        tf.matmul(tf.gather(Eta, Pi[:, r]), Lambda)
        for r, (Eta, Lambda) in enumerate(zip(EtaList, LambdaList))
    ]
    L = LFix + sum(LRanLevelList)
    Yo = tfm.logical_not(tfm.is_nan(Y))

    # Normal columns: no augmentation for observed values; impute missing ones.
    indColNormal = tf.squeeze(tf.where(distr[:, 0] == 1), -1)
    YN = tf.gather(Y, indColNormal, axis=-1)
    YoN = tf.gather(Yo, indColNormal, axis=-1)
    LN = tf.gather(L, indColNormal, axis=-1)
    sigmaN = tf.gather(sigma, indColNormal)
    ZNimputed = LN + tfr.normal([ny, tf.size(indColNormal)], dtype=dtype)*sigmaN
    # Select with tf.where rather than a 0/1 mask product: YN is NaN where
    # missing and NaN*0 is NaN, so the mask product leaked NaN into Z.
    ZN = tf.where(YoN, YN, ZNimputed)

    # Probit columns: Albert and Chib (1993) data augmentation.
    indColProbit = tf.squeeze(tf.where(distr[:, 0] == 2), -1)
    YP = tf.gather(Y, indColProbit, axis=-1)
    YoP = tf.gather(Yo, indColProbit, axis=-1)
    LP = tf.gather(L, indColProbit, axis=-1)
    sigmaP = tf.gather(sigma, indColProbit)
    low = tf.where(tfm.logical_or(YP == 0, tfm.logical_not(YoP)), tf.cast(-np.inf, dtype), tf.zeros_like(YP))
    high = tf.where(tfm.logical_or(YP == 1, tfm.logical_not(YoP)), tf.cast(np.inf, dtype), tf.zeros_like(YP))
    # NOTE(review): relies on TruncatedNormal accepting infinite bounds — confirm for the installed TFP.
    ZP = tfd.TruncatedNormal(loc=LP, scale=sigmaP, low=low, high=high, name='TruncatedNormal').sample()

    # Scatter the per-family columns back to their original positions.
    ZStack = tf.concat([ZN, ZP], -1)
    indColStack = tf.concat([indColNormal, indColProbit], 0)
    ZNew = tf.transpose(tf.scatter_nd(indColStack[:, None], tf.transpose(ZStack), Y.shape[::-1]))
    return ZNew

In [None]:
def updateBetaLambda(params, dtype=np.float64):
    """Jointly sample the fixed effects Beta and the factor loadings Lambda.

    For each column j (batched over all ns columns), the regression
    Z[:, j] ~ N(XE @ [Beta; Lambda][:, j], sigma[j]^2) is used. Here
    XE = [X, Eta_1[Pi[:, 0]], ..., Eta_nr[Pi[:, nr-1]]]. The Gaussian prior has
    mean [Gamma T'; 0] and block-diagonal precision: iV for the Beta rows and
    diag(Psi * cumprod(Delta)) for the Lambda rows.

    Args:
        params: dict with 'Z', 'GammaV' ({'Gamma', 'iV'}), 'Eta', 'PsiDelta'
            ({'Psi', 'Delta'}), 'sigma', 'X', 'T', 'Pi'.
        dtype: dtype of the zero prior mean for Lambda and of the noise draws.

    Returns:
        dict with 'Beta' ([nc, ns]) and 'Lambda' (list of [nf_r, ns] tensors).
    """
    Z = params['Z']
    Gamma = params['GammaV']['Gamma']
    iV = params['GammaV']['iV']
    EtaList = params['Eta']
    PsiList = params['PsiDelta']['Psi']
    DeltaList = params['PsiDelta']['Delta']
    sigma = params['sigma']
    X = params['X']
    T = params['T']
    Pi = params['Pi']

    ny, nc = X.shape
    _, ns = Z.shape
    nr = len(EtaList)
    # number of factors per level and in total
    nfVec = tf.stack([tf.shape(Eta)[-1] for Eta in EtaList])
    nfSum = tf.reduce_sum(nfVec)

    # expand each level's factors from units to sampling rows
    EtaListFull = [None] * nr
    for r, Eta in enumerate(EtaList):
        EtaListFull[r] = tf.gather(Eta, Pi[:,r])

    # design matrix: covariates followed by all levels' factors
    XE = tf.concat([X] + EtaListFull, axis=-1)
    # prior mean: Gamma T' for the Beta rows, zero for the Lambda rows
    GammaT = tf.matmul(Gamma, T, transpose_b=True)
    Mu = tf.concat([GammaT, tf.zeros([nfSum, ns], dtype)], axis=0)
    # Lambda prior precision: Psi times the cumulative product of Delta over factors
    LambdaPriorPrec = tf.concat([Psi * tfm.cumprod(Delta, -2) for Psi, Delta in zip(PsiList, DeltaList)], axis=-2)

    # batched (over ns) block-diagonal prior precision [[iV, 0], [0, diag(LambdaPriorPrec[:, j])]]
    iK11_op = tfla.LinearOperatorFullMatrix(iV)
    iK22_op = tfla.LinearOperatorDiag(tf.transpose(LambdaPriorPrec))
    iK = tfla.LinearOperatorBlockDiag([iK11_op, iK22_op]).to_dense()
    # posterior precision per column: prior precision + XE'XE / sigma_j^2
    iU = iK + tf.matmul(XE, XE, transpose_a=True)/(sigma**2)[:, None, None]
    LiU = tfla.cholesky(iU)
    # prior precision @ prior mean + XE' Z_j / sigma_j^2, shape [ns, nc+nfSum, 1]
    A = tf.matmul(iK, tf.transpose(Mu)[:,:,None]) + (tf.matmul(Z, XE, transpose_a=True)/(sigma**2)[:, None])[:,:,None]
    # posterior mean
    M = tfla.cholesky_solve(LiU, A)
    # M + L^{-T} z has covariance iU^{-1}; transposed back to [nc+nfSum, ns]
    BetaLambda = tf.transpose(tf.squeeze(M + tfla.triangular_solve(LiU, tf.random.normal(shape=[ns, nc+nfSum, 1], dtype=dtype), adjoint=True), -1))
    # split rows into Beta and each level's Lambda
    BetaLambdaList = tf.split(BetaLambda, tf.concat([tf.constant([nc], tf.int32), nfVec], -1), axis=-2)
    BetaNew, LambdaListNew = BetaLambdaList[0], BetaLambdaList[1:]
    return {'Beta': BetaNew, 'Lambda': LambdaListNew}


In [None]:
def updateSigma(params, dtype=np.float64):
    """Gibbs update of the residual standard deviations ``sigma``.

    For columns whose ``distr[:, 1] == 1`` the precision 1/sigma^2 is drawn from
    Gamma(aSigma + ny/2, bSigma + sum(residual^2)/2) (rate parameterisation).
    All other columns keep their current sigma.

    Returns:
        Tensor of shape [ns].
    """
    Z = params['Z']
    Beta = params['BetaLambda']['Beta']
    EtaList = params['Eta']
    LambdaList = params['BetaLambda']['Lambda']
    sigma = params['sigma']
    Y = params['Y']
    X = params['X']
    Pi = params['Pi']
    distr = params['distr']
    aSigma = params['aSigma']
    bSigma = params['bSigma']

    # 1.0 for columns whose sigma is sampled, 0.0 where it stays fixed
    varies = tf.cast(tf.equal(distr[:, 1], 1), dtype)

    fixedPart = tf.matmul(X, Beta)
    randomPart = sum(
        tf.matmul(tf.gather(Eta, Pi[:, r]), Lambda)
        for r, (Eta, Lambda) in enumerate(zip(EtaList, LambdaList))
    )
    residuals = Z - (fixedPart + randomPart)

    shapePost = aSigma + Y.shape[0]/2.
    ratePost = bSigma + tf.reduce_sum(residuals**2, axis=0)/2.
    precision = tfp.distributions.Gamma(concentration=shapePost, rate=ratePost).sample()
    return varies*tfm.rsqrt(precision) + (1-varies)*sigma

In [None]:
def updateLambdaPriors(params, dtype=np.float64):
    """Gibbs update of the shrinkage parameters ``Psi`` and ``Delta`` of the loadings.

    Multiplicative gamma process prior on each level's loadings:
    Lambda[h, j] ~ N(0, 1 / (Psi[h, j] * Tau[h])) with Tau = cumprod(Delta),
    Psi[h, j] ~ Gamma(nu/2, nu/2), Delta[0] ~ Gamma(a1, b1) and
    Delta[h > 0] ~ Gamma(a2, b2). Gammas use the rate parameterisation, as in
    tf.random.gamma.

    Args:
        params: dict with 'BetaLambda' ({'Lambda'}), 'PsiDelta' ({'Delta'}),
            and per-level arrays 'nu', 'a1', 'b1', 'a2', 'b2'.
        dtype: floating dtype of the draws.

    Returns:
        dict with lists 'Psi' ([nf, ns] each) and 'Delta' ([nf, 1] each).
    """
    LambdaList = params['BetaLambda']['Lambda']
    DeltaList = params['PsiDelta']['Delta']

    nu = params['nu']
    a1 = params['a1']
    b1 = params['b1']
    a2 = params['a2']
    b2 = params['b2']

    nr = len(LambdaList)
    PsiNew, DeltaNew = [None] * nr, [None] * nr
    for r, (Lambda, Delta) in enumerate(zip(LambdaList, DeltaList)):
        ns = Lambda.shape[-1]
        nf = tf.shape(Lambda)[0]
        if nf > 0:
            aDelta = tf.concat([a1[r]*tf.ones([1, 1], dtype), a2[r]*tf.ones([nf-1, 1], dtype)], 0)
            bDelta = tf.concat([b1[r]*tf.ones([1, 1], dtype), b2[r]*tf.ones([nf-1, 1], dtype)], 0)
            Lambda2 = Lambda**2
            Tau = tfm.cumprod(Delta, 0)
            # Psi[h, j] | . ~ Gamma(nu/2 + 1/2, nu/2 + Tau[h] * Lambda[h, j]^2 / 2).
            # The 0.5 on the data term was previously missing.
            aPsi = nu[r]/2. + 0.5
            bPsi = nu[r]/2. + 0.5*Lambda2*Tau
            PsiNew[r] = tf.squeeze(tfr.gamma([1], aPsi, bPsi, dtype=dtype), 0)
            M = PsiNew[r] * Lambda2
            rowSumM = tf.reduce_sum(M, 1)
            DeltaNew[r] = Delta
            # Update Delta[h] sequentially; Tau is recomputed with the latest Delta.
            for h in range(nf):
                Tau = tfm.cumprod(DeltaNew[r], 0)
                ad = aDelta[h, :] + 0.5*ns*tf.cast(nf-h, dtype)
                # Tau[h:] / Delta[h] is the cumulative product with Delta[h] excluded
                bd = bDelta[h, :] + 0.5*tf.reduce_sum(Tau[h:, :]*rowSumM[h:, None], 0) / DeltaNew[r][h, :]
                DeltaNew[r] = tf.tensor_scatter_nd_update(DeltaNew[r], [[h]], tfr.gamma([1], ad, bd, dtype=dtype))
        else:
            PsiNew[r] = tf.zeros([0, ns], dtype)
            DeltaNew[r] = tf.zeros([0, 1], dtype)
    return {'Psi': PsiNew, 'Delta': DeltaNew}

In [None]:
def updateNf(params, dtype=np.float64):
    """Adaptively add or drop latent factors of each random level.

    With probability ``exp(-(c0 + c1*iter))`` per level:
      * if no factor is redundant, fewer than nfMax factors exist and iter > 20,
        one factor is appended (Eta ~ N(0, 1), Lambda = 0, Psi and Delta drawn
        from their priors);
      * otherwise, if some factors are redundant (every |Lambda| in the row is
        below epsilon) and more than nfMin factors exist, the redundant factors
        are dropped, keeping at least nfMin.
    Otherwise the level is returned unchanged.

    Args:
        params: dict with 'Eta', 'BetaLambda' ({'Lambda'}), 'PsiDelta'
            ({'Psi', 'Delta'}), 'nu', 'a2', 'b2', 'nfMin', 'nfMax', and
            optionally 'iter' (iteration counter, default 1).
        dtype: floating dtype of the new draws.

    Returns:
        dict with lists 'Eta', 'Lambda', 'Psi', 'Delta'.
    """
    EtaList = params['Eta']
    LambdaList = params['BetaLambda']['Lambda']
    PsiList = params['PsiDelta']['Psi']
    DeltaList = params['PsiDelta']['Delta']

    # Iteration counter. The samplers in this notebook do not pass one yet, so the
    # previously hard-coded value 1 is the default; with it the "add a factor"
    # branch (it > 20) can never fire.
    it = params.get('iter', 1)

    nu = params['nu']
    a2 = params['a2']
    b2 = params['b2']
    nfMin = params['nfMin']
    nfMax = params['nfMax']

    c0 = 1
    c1 = 0.0005
    epsilon = 1e-3  # loadings with |value| below this count as small
    prop = 1.00  # fraction of small loadings in a row for the factor to be redundant
    prob = 1/tf.exp(c0 + c1*tf.cast(it, dtype))  # probability of adapting

    nr = len(LambdaList)
    EtaNew, LambdaNew, PsiNew, DeltaNew = [[None] * nr for _ in range(4)]
    for r, (Eta, Lambda, Psi, Delta) in enumerate(zip(EtaList, LambdaList, PsiList, DeltaList)):
        if tfr.uniform([], dtype=dtype) < prob:
            nf = tf.shape(Lambda)[0]
            _, ns = Lambda.shape
            nUnits = tf.shape(Eta)[0]
            smallLoadingProp = tf.reduce_mean(tf.cast(tfm.abs(Lambda) < epsilon, dtype=dtype), axis=1)
            indRedundant = smallLoadingProp >= prop
            numRedundant = tf.reduce_sum(tf.cast(indRedundant, dtype=dtype))

            if nf < nfMax[r] and it > 20 and numRedundant == 0:
                EtaNew[r] = tf.concat([Eta, tfr.normal([nUnits, 1], dtype=dtype)], axis=1)
                LambdaNew[r] = tf.concat([Lambda, tf.zeros([1, ns], dtype=dtype)], axis=0)
                PsiNew[r] = tf.concat([Psi, tfr.gamma([1, ns], nu[r]/2, nu[r]/2, dtype=dtype)], axis=0)
                DeltaNew[r] = tf.concat([Delta, tfr.gamma([1, 1], a2[r], b2[r], dtype=dtype)], axis=0)
            elif nf > nfMin[r] and numRedundant > 0:
                indRemain = tf.cast(tf.squeeze(tf.where(tfm.logical_not(indRedundant)), -1), tf.int32)
                if tf.shape(indRemain)[0] < nfMin[r]:
                    # Re-admit the last few *redundant* factors to reach nfMin. The
                    # old code appended nf-1, nf-2, ..., which could already be
                    # kept and produced duplicate columns. Sorting keeps the
                    # original factor order, which the cumulative shrinkage uses.
                    nMissing = nfMin[r] - tf.shape(indRemain)[0]
                    indRedundantIdx = tf.cast(tf.squeeze(tf.where(indRedundant), -1), tf.int32)
                    indRemain = tf.sort(tf.concat([indRemain, indRedundantIdx[-nMissing:]], axis=0))
                EtaNew[r] = tf.gather(Eta, indRemain, axis=1)
                LambdaNew[r] = tf.gather(Lambda, indRemain, axis=0)
                PsiNew[r] = tf.gather(Psi, indRemain, axis=0)
                DeltaNew[r] = tf.gather(Delta, indRemain, axis=0)
            else:
                EtaNew[r], LambdaNew[r], PsiNew[r], DeltaNew[r] = Eta, Lambda, Psi, Delta
        else:
            EtaNew[r], LambdaNew[r], PsiNew[r], DeltaNew[r] = Eta, Lambda, Psi, Delta
    return {'Eta': EtaNew, 'Lambda': LambdaNew, 'Psi': PsiNew, 'Delta': DeltaNew}

In [None]:
class GibbsParameter:
    """A model parameter that is resampled from its full conditional.

    Attributes:
        value: current state of the parameter (any object, e.g. a tensor, a list
            of tensors or a dict of tensors).
        conditional_posterior: callable that receives a dict of plain values and
            returns a new draw for ``value``.
        posterior_params: stored on the instance; not used by this class.
    """

    def __init__(self, value, conditional_posterior, posterior_params = None):
        self.value = value
        self.conditional_posterior = conditional_posterior
        self.posterior_params = posterior_params

    def __str__(self) -> str:
        # Previously `pass` (returned None), which made str()/print() raise TypeError.
        return str(self.value)

    def __repr__(self) -> str:
        return str(self.value)

    def sample(self, sample_params):
        """Draw a new value and store it on the instance.

        Args:
            sample_params: mapping of names to either GibbsParameter objects
                (replaced by their current ``.value``) or plain values.

        Returns:
            The new value.
        """
        post_params = {
            k: v.value if isinstance(v, GibbsParameter) else v
            for k, v in sample_params.items()
        }
        self.value = self.conditional_posterior(post_params)
        return self.value

In [None]:
class GibbsSampler:
    """Run Gibbs sweeps over the GibbsParameter entries of a parameter dict.

    ``params`` mixes GibbsParameter objects, which are resampled every sweep in
    dict order, with plain data and hyperparameters that are passed through.
    """

    def __init__(self, params):
        self.params = params

    def single_sample(self, param_name):
        """Resample one parameter, store the draw on it and return the draw."""
        target = self.params[param_name]
        drawn = target.sample(self.params)
        target.value = drawn
        return drawn

    @tf.function
    def sampling_routine(self, num_samples, sample_period = 1, sample_burnin = 0, sample_thining = 1, printRetraceFlag = True):
        """Run sample_burnin + num_samples*sample_thining sweeps.

        The draws of a sweep (dict name -> value) are recorded when the sweep
        index n satisfies n >= sample_burnin and n % sample_period == 0.

        NOTE(review): ``sample_thining`` only lengthens the run, while which sweeps
        are kept is controlled by ``sample_period`` (counted from 0, not from the
        end of burn-in). Confirm these are the intended thinning semantics.
        """
        if printRetraceFlag:
            print("retracing")

        params = self.params
        totalSteps = sample_burnin + num_samples*sample_thining
        history = []
        for step in range(totalSteps):
            draws = {}
            for name in list(params.keys()):
                if isinstance(params[name], GibbsParameter):
                    draws[name] = self.single_sample(name)
            pastBurnin = step >= sample_burnin
            onPeriod = step % sample_period == 0
            if pastBurnin & onPeriod:
                history.append(draws)
        return history

In [10]:
import numpy as np
import time
import sys
from random import randint, sample
from datetime import datetime
import tensorflow as tf
import tensorflow_probability as tfp
from tqdm import tqdm, trange
from matplotlib import pyplot as plt
tfd, tfb = tfp.distributions, tfp.bijectors
tfm, tfla, tfr, tfs = tf.math, tf.linalg, tf.random, tf.sparse

In [18]:
from hmsc.gibbs_sampler import GibbsParameter, GibbsSampler

from hmsc.updaters.updateEta import updateEta
from hmsc.updaters.updateAlpha import updateAlpha
from hmsc.updaters.updateBetaLambda import updateBetaLambda
from hmsc.updaters.updateLambdaPriors import updateLambdaPriors
from hmsc.updaters.updateNf import updateNf
from hmsc.updaters.updateGammaV import updateGammaV
from hmsc.updaters.updateSigma import updateSigma
from hmsc.updaters.updateZ import updateZ

In [19]:
# Run the hmsc-package Gibbs sampler with only Z as a sampled parameter.
# NOTE(review): Z, nChains and the *_params dicts are defined in later cells, so
# this only works when cells run out of order. The recorded output is a NameError
# raised inside hmsc/updaters/updateZ.py.
sampler_params = {
    'Z': GibbsParameter(Z, updateZ),
}

startTime = time.time()

gibbs = GibbsSampler(params = {**sampler_params, **prior_params, **random_level_params, **random_level_data_params, **model_data})

postList = [None] * nChains
for chain in range(nChains):
    postList[chain] = gibbs.sampling_routine(3)

elapsedTime = time.time() - startTime
print("\nTF decorated whole cycle elapsed %.1f" % elapsedTime)


retracing


NameError: in user code:

    File "/Users/anisjyu/Dropbox/hmsc-hpc/hmsc-hpc/hmsc/gibbs_sampler.py", line 55, in sampling_routine  *
        row[key] = self.single_sample(key)
    File "/Users/anisjyu/Dropbox/hmsc-hpc/hmsc-hpc/hmsc/gibbs_sampler.py", line 32, in single_sample  *
        value = self.params[param_name].sample(self.params)
    File "/Users/anisjyu/Dropbox/hmsc-hpc/hmsc-hpc/hmsc/gibbs_sampler.py", line 23, in sample  *
        self.value = self.conditional_posterior(post_params)
    File "/Users/anisjyu/Dropbox/hmsc-hpc/hmsc-hpc/hmsc/updaters/updateZ.py", line 19, in updateZ  *
        nr = len(EtaList)

    NameError: name 'EtaList' is not defined


In [None]:
# Full sweep: parameters are resampled in this dict order
# (Z, BetaLambda, GammaV, PsiDelta, Eta, sigma, Nf, Alpha).
# NOTE(review): updateNf's result is stored only under 'Nf'; the other updaters
# read 'Eta', 'BetaLambda' and 'PsiDelta', so factor additions/removals never
# propagate. TODO wire Nf's output back into those parameters.
sampler_params = {
    'Z': GibbsParameter(Z, updateZ),
    'BetaLambda': GibbsParameter({'Beta': Beta, 'Lambda': LambdaList}, updateBetaLambda),
    'GammaV': GibbsParameter({'Gamma': Gamma, 'iV': iV}, updateGammaV),
    'PsiDelta': GibbsParameter({'Psi': PsiList, 'Delta': DeltaList}, updateLambdaPriors),
    'Eta': GibbsParameter(EtaList, updateEta),
    'sigma': GibbsParameter(sigma, updateSigma),
    'Nf': GibbsParameter({'Eta': EtaList, 'Lambda': LambdaList, 'Psi': PsiList, 'Delta': DeltaList}, updateNf),
    'Alpha': GibbsParameter(AlphaList, updateAlpha),
}

startTime = time.time()

gibbs = GibbsSampler(params = {**sampler_params, **prior_params, **random_level_params, **random_level_data_params, **model_data})

postList = [None] * nChains
for chain in range(nChains):
    postList[chain] = gibbs.sampling_routine(3)

elapsedTime = time.time() - startTime
print("\nTF decorated whole cycle elapsed %.1f" % elapsedTime)


In [None]:
# Number of recorded draws for the last chain (`chain` is the leftover loop variable).
len(postList[chain])

In [None]:
# Alpha in the first recorded draw of that chain.
postList[chain][0]['Alpha']

In [None]:
# Scratch test updater reading sDim directly.
# NOTE(review): `params` is not defined at notebook scope in any visible cell,
# so the call below raises NameError on a fresh kernel.
def update_test(params):
    t1 = params['sDim']
    t2 = params['sDim'] + 1
    return {'t1': t1, 't2': t2}

update_test(params)

In [None]:
def update_test(params):
    """Toy updater: increment the 't1' entry of the 'test' parameter, keep 't2'."""
    state = params['test']
    return {'t1': state['t1'] + 1, 't2': state['t2']}

In [None]:
# Smoke-test the GibbsSampler with the toy `update_test` updater for 2 sweeps.
sampler_params = {
    'test': GibbsParameter({'t1': sDim, 't2': alphapw}, update_test),
}

gibbs = GibbsSampler(params = {**sampler_params, **prior_params, **random_level_params, **random_level_data_params, **model_data})
res = gibbs.sampling_routine(2)
res


In [None]:
# Current residual standard deviations.
sigma

In [None]:
def printFunction(i, samInd, LambdaShapeList):
    """Overwrite the current console line with sampler progress.

    Args:
        i: iteration counter (object with ``.numpy()``).
        samInd: sample slot; negative values mean the transient (burn-in) phase.
        LambdaShapeList: shapes of each level's Lambda (objects with ``.numpy()``).
    """
    iteration = str(i.numpy())
    slot = samInd.numpy()
    phase = " saving " + str(slot) if slot >= 0 else " transient"
    shapes = str([str(shape.numpy()) for shape in LambdaShapeList])
    message = "iteration " + iteration + phase + " Lambda shape " + shapes
    sys.stdout.write("\r" + message)

In [None]:
class GibbsSampler:
    """Work-in-progress tf.function-based Gibbs sampler driven by explicit tensors.

    NOTE(review): this redefines ``GibbsSampler`` from an earlier cell with a
    different constructor; whichever cell ran last is the one in use.
    """
    def __init__(self, nChains, samN, thinN):
        self.params = {}

        self.params["nChains"] = nChains
        self.params["samN"] = samN
        self.params["thinN"] = thinN

        # NOTE(review): burn-in length is tied to samN * thinN rather than being a
        # separate setting — confirm this is intended.
        self.params["transient"] = samN * thinN

    @tf.function
    def __call__(
        self,
        Z,
        Beta,
        Gamma,
        iV,
        EtaList,
        LambdaList,
        PsiList,
        DeltaList,
        AlphaList,
        sigma,
        modelData,
        priorHyperParList,
        rLParList,
        rLDataParList,
    ):
        # NOTE(review): inside @tf.function, time.time() and print() run only while
        # tracing, so the reported time is trace time, not sampling time.
        startTime = time.time()

        # Every chain starts from the same initial state passed in here.
        postList = [None] * self.params["nChains"]
        for chain in range(self.params["nChains"]):
            postList[chain] = self.do_sampling(
                Z,
                Beta,
                Gamma,
                iV,
                EtaList,
                LambdaList,
                PsiList,
                DeltaList,
                AlphaList,
                sigma,
                modelData,
                priorHyperParList,
                rLParList,
                rLDataParList,
            )

        elapsedTime = time.time() - startTime
        print("\nTF decorated whole cycle elapsed %.1f" % elapsedTime)

        return postList

    def sample(
        self,
        Z,
        Beta,
        Gamma,
        iV,
        EtaList,
        LambdaList,
        PsiList,
        DeltaList,
        AlphaList,
        sigma,
        itInd,
        modelData,
        priorHyperParList,
        rLParList,
        rLDataParList,
        printRetraceFlag=True,
    ):
        # One sweep. Only Eta is updated so far; all other state passes through.
        if printRetraceFlag:
            print("retracing")

        # NOTE(review): in its visible part, the UpdateEta class defined later in
        # this notebook takes (params, dtype) and defines no __call__, so this
        # construction/call does not match it — TODO reconcile.
        sampler = UpdateEta(sigma, rLParList['sDim'], modelData['Pi'])
        EtaListNew = sampler(
            Z, Beta, EtaList, LambdaList, AlphaList, modelData['X'], rLDataParList['iWg']
        )

        return (
            Z,
            Beta,
            Gamma,
            iV,
            EtaListNew,
            LambdaList,
            PsiList,
            DeltaList,
            AlphaList,
            sigma,
        )

    def do_sampling(
        self,
        Z,
        Beta,
        Gamma,
        iV,
        EtaList,
        LambdaList,
        PsiList,
        DeltaList,
        AlphaList,
        sigma,
        modelData,
        priorHyperParList,
        rLParList,
        rLDataParList,
    ):
        # Run one chain: transient + samN*thinN sweeps; returns stacked Eta samples only.
        # NOTE(review): `dtype` below is the notebook-level global, not a parameter.

        _, ns = modelData["Y"].shape
        nr = len(LambdaList)

        samplesGamma, samplesiV, samplesBeta, samplesSigma = [
            tf.TensorArray(dtype, size=self.params["samN"]) for i in range(4)
        ]
        samplesLambdaList = [
            tf.TensorArray(dtype, size=self.params["samN"]) for i in range(nr)
        ]
        samplesEtaList = [
            tf.TensorArray(dtype, size=self.params["samN"]) for i in range(nr)
        ]
        samplesPsiList = [
            tf.TensorArray(dtype, size=self.params["samN"]) for i in range(nr)
        ]
        samplesDeltaList = [
            tf.TensorArray(dtype, size=self.params["samN"]) for i in range(nr)
        ]

        for i in tf.range(
            self.params["transient"] + self.params["samN"] * self.params["thinN"]
        ):
            # Number of factors may change between sweeps, so leave that dimension free.
            tf.autograph.experimental.set_loop_options(
                shape_invariants=[
                    (EtaList, [tf.TensorShape([None, None])] * nr),
                    (Beta, tf.TensorShape([None, ns])),
                    (LambdaList, [tf.TensorShape([None, ns])] * nr),
                    (PsiList, [tf.TensorShape([None, ns])] * nr),
                    (DeltaList, [tf.TensorShape([None, 1])] * nr),
                    (AlphaList, [tf.TensorShape([None, 1])] * nr),
                ]
            )
            # Iteration index during transient; inf afterwards.
            if i < self.params["transient"]:
                itInd = tf.cast(i, dtype)
            else:
                itInd = tf.constant(np.inf, dtype)

            (
                Z,
                Beta,
                Gamma,
                iV,
                EtaList,
                LambdaList,
                PsiList,
                DeltaList,
                AlphaList,
                sigma,
            ) = self.sample(
                Z,
                Beta,
                Gamma,
                iV,
                EtaList,
                LambdaList,
                PsiList,
                DeltaList,
                AlphaList,
                sigma,
                itInd,
                modelData,
                priorHyperParList,
                rLParList,
                rLDataParList,
            )

            # Storage slot; negative during transient.
            samInd = tf.cast(
                (i - self.params["transient"] + 1) / self.params["thinN"] - 1, tf.int32
            )
            if i % self.params["thinN"] == 0:
                tf.py_function(
                    func=printFunction,
                    inp=[i, samInd, [tf.shape(Lambda) for Lambda in LambdaList]],
                    Tout=[],
                )

            if (
                i >= self.params["transient"]
                and (i - self.params["transient"] + 1) % self.params["thinN"] == 0
            ):
                samplesEtaList = [
                    samplesEta.write(samInd, Eta)
                    for samplesEta, Eta in zip(samplesEtaList, EtaList)
                ]
                """
                samplesGamma = samplesGamma.write(samInd, Gamma)
                samplesiV = samplesiV.write(samInd, iV)
                samplesBeta = samplesBeta.write(samInd, Beta)
                samplesSigma = samplesSigma.write(samInd, sigma)
                samplesLambdaList = [
                    samplesLambda.write(samInd, Lambda)
                    for samplesLambda, Lambda in zip(samplesLambdaList, LambdaList)
                ]
                samplesEtaList = [
                    samplesEta.write(samInd, Eta)
                    for samplesEta, Eta in zip(samplesEtaList, EtaList)
                ]
                samplesPsiList = [
                    samplesPsi.write(samInd, Psi)
                    for samplesPsi, Psi in zip(samplesPsiList, PsiList)
                ]
                samplesDeltaList = [
                    samplesDelta.write(samInd, Delta)
                    for samplesDelta, Delta in zip(samplesDeltaList, DeltaList)
                ]
                # print(samInd, samplesGamma.read(samInd))
                """
        """
        resList = [
            samples.stack()
            for samples in [samplesBeta, samplesGamma, samplesiV, samplesSigma]
        ]
        resList += [[samplesLambda.stack() for samplesLambda in samplesLambdaList]]
        resList += [[samplesEta.stack() for samplesEta in samplesEtaList]]
        resList += [[samplesPsi.stack() for samplesPsi in samplesPsiList]]
        resList += [[samplesDelta.stack() for samplesDelta in samplesDeltaList]]
        """
        # Only Eta samples are recorded for now (other storage is disabled above).
        resList = [[samplesEta.stack() for samplesEta in samplesEtaList]]

        return resList


# Run the work-in-progress tf.function sampler: 2 chains, 10 stored samples, thinning 10.
# NOTE(review): `GibbsSampler` here is whichever class definition ran last, so this
# cell depends on execution order.
sampler = GibbsSampler(nChains=2, samN=10, thinN=10)
postList = sampler(
    Z,
    Beta,
    Gamma,
    iV,
    EtaList,
    LambdaList,
    PsiList,
    DeltaList,
    AlphaList,
    sigma,
    modelData,
    priorHyperParList,
    rLParList,
    rLDataParList,
)

# print([item for item in res])

In [5]:
#path = "/Users/gtikhono/Downloads/importExport/"
# Directory with the exported R Hmsc model (obj-complete.json).
# NOTE(review): hard-coded absolute local path; consider an environment variable
# or a path relative to the notebook so others can run it.
path = '/users/anisjyu/Documents/demo-import/'

In [6]:
#
# Load the exported Hmsc model object (written from R via jsonify).
#

import json

with open(path + 'obj-complete.json') as json_file:
    obj = json.load(json_file)

# Top-level fields available in the export.
print(obj.keys())

dict_keys(['Y', 'XData', 'XFormula', 'X', 'XScaled', 'XRRRData', 'XRRRFormula', 'XRRRScaled', 'YScaled', 'XInterceptInd', 'studyDesign', 'ranLevels', 'ranLevelsUsed', 'dfPi', 'rL', 'Pi', 'TrData', 'TrFormula', 'Tr', 'TrScaled', 'TrInterceptInd', 'C', 'phyloTree', 'distr', 'ny', 'ns', 'nc', 'ncNRRR', 'ncRRR', 'ncORRR', 'ncsel', 'nr', 'nt', 'nf', 'ncr', 'ncs', 'np', 'spNames', 'covNames', 'trNames', 'rLNames', 'XScalePar', 'XRRRScalePar', 'YScalePar', 'TrScalePar', 'V0', 'f0', 'mGamma', 'UGamma', 'aSigma', 'bSigma', 'nu', 'a1', 'b1', 'a2', 'b2', 'rhopw', 'nuRRR', 'a1RRR', 'b1RRR', 'a2RRR', 'b2RRR', 'samples', 'transient', 'thin', 'verbose', 'adaptNf', 'initPar', 'repN', 'randSeed', 'postList', 'call', 'HmscVersion', 'repList'])


In [11]:
# Number of MCMC chains stored in the exported R object.
nChains = len(obj['postList'])
nChains

2

In [13]:
# Model dimensions, priors, placeholder spatial data and initial parameter values.
dtype = np.float64

# Seed both RNGs before *any* random draw so the whole cell is reproducible
# (previously alphapw, WgList and detWgList were drawn before seeding).
np.random.seed(1)
tfr.set_seed(1)

ny = int(np.squeeze(obj.get('ny'))) # 50
ns = int(np.squeeze(obj.get('ns'))) # 4
nc = int(np.squeeze(obj.get('nc'))) # 3
nt = int(np.squeeze(obj.get('nt'))) # 3
nr = int(np.squeeze(obj.get('nr'))) # 2

rL = obj.get('rL')

def _rL_param(name):
    """Collect the random-level hyperparameter `name` across all levels."""
    return np.squeeze([rL[key][name] for key in rL.keys()])

nu = _rL_param('nu')
a1 = _rL_param('a1')
b1 = _rL_param('b1')
a2 = _rL_param('a2')
b2 = _rL_param('b2')

nfMin = _rL_param('nfMin')
nfMax = _rL_param('nfMax')

sDim = _rL_param('sDim')

# TODO: read alphapw from obj['rL'][key]['alphapw'] once exported; random placeholder for now.
alphapw = [np.abs(np.random.normal(size=[101, 2])) for r in range(nr)]

distr = np.asarray(obj.get('distr')).astype(int)

X = np.asarray(obj.get('X'))
T = np.asarray(obj.get('Tr'))
Y = np.asarray(obj.get('Y'))

Pi = np.asarray(obj.get('Pi')).astype(int) - 1  # R indices are 1-based
npVec = Pi.max(axis=0) + 1

nfVec = 3 + np.arange(nr)  # initial number of factors per level

mGamma = np.asarray(obj.get('mGamma'))
# UGamma is the prior *covariance* of vec(Gamma); updateGammaV uses iUGamma as a precision.
iUGamma = np.linalg.inv(np.asarray(obj.get('UGamma')))

aSigma = np.asarray(obj.get('aSigma'))
bSigma = np.asarray(obj.get('bSigma'))

V0 = np.squeeze(obj.get('V0'))
f0 = int(np.squeeze(obj.get('f0')))

# Placeholder spatial structures on a grid of 101 values (matching alphapw rows).
WgList = [tfr.normal([101, npVec[r], npVec[r]], dtype=dtype) for r in range(nr)]
WgList = [tf.matmul(Wg, Wg, transpose_a=True) for Wg in WgList]  # these MUST be SPD matrices!
iWgList = [tfla.inv(Wg) for Wg in WgList]
# updateAlpha computes ||LiWg @ Eta||^2 and uses it as Eta' iWg Eta. That needs the
# upper factor R with R'R = iWg (what R's chol() returns); tf cholesky returns lower.
LiWgList = [tfla.matrix_transpose(tfla.cholesky(iWg)) for iWg in iWgList]
# updateAlpha subtracts 0.5*detWg next to 0.5*Eta' iWg Eta, i.e. the Gaussian log|Wg| term.
detWgList = [tfla.logdet(Wg) for Wg in WgList]

modelData = {'Y': Y, 'X': X, 'T': T, 'Pi': Pi, 'distr': distr}

priorHyperParams = {
    'mGamma': mGamma,
    'iUGamma': iUGamma,
    'f0': f0,
    'V0': V0,
    'aSigma': aSigma,
    'bSigma': bSigma,
}

rLDataParams = {'Wg': WgList, 'iWg': iWgList, 'LiWg': LiWgList, 'detWg': detWgList}

rLParams = {
    'nu': nu,
    'a1': a1,
    'b1': b1,
    'a2': a2,
    'b2': b2,
    'nfMin': nfMin,
    'nfMax': nfMax,
    'sDim': sDim,
    'alphapw': alphapw,
}

aDeltaList = [tf.concat([a1[r]*tf.ones([1,1], dtype), a2[r]*tf.ones([nfVec[r]-1,1], dtype)], 0) for r in range(nr)]
bDeltaList = [tf.concat([b1[r]*tf.ones([1,1], dtype), b2[r]*tf.ones([nfVec[r]-1,1], dtype)], 0) for r in range(nr)]

Beta = tfr.normal([nc,ns], dtype=dtype)
Gamma = tfr.normal([nc,nt], dtype=dtype)
iV = tf.ones([nc,nc], dtype=dtype) + tf.eye(nc, dtype=dtype)
EtaList = [tfr.normal([npVec[r],nfVec[r]], dtype=dtype) for r in range(nr)]
PsiList = [1 + tf.abs(tfr.normal([nfVec[r],ns], dtype=dtype)) for r in range(nr)]
# np.random.gamma takes (shape, scale); the model's b parameters are rates.
DeltaList = [np.random.gamma(aDeltaList[r], 1.0 / np.asarray(bDeltaList[r]), size=[nfVec[r],1]) for r in range(nr)]
LambdaList = [tfr.normal([nfVec[r],ns], dtype=dtype) for r in range(nr)]
AlphaList = [tf.zeros([nfVec[r],1], dtype=tf.int64) for r in range(nr)]
Z = tf.zeros_like(Y)

# sigma is random where distr[:, 1] == 1 and fixed at 1 where distr[:, 1] == 0.
sigma = tf.abs(tfr.normal([ns], dtype=dtype))*(distr[:,1]==1) + tf.ones([ns], dtype=dtype)*(distr[:,1]==0)
iSigma = 1/sigma


In [14]:
# snake_case aliases for the input dicts (same objects, not copies).
model_data, prior_params, random_level_data_params, random_level_params = (
    modelData,
    priorHyperParams,
    rLDataParams,
    rLParams,
)

In [None]:
%reset

In [None]:
class UpdateEta:
    """Sampler for the latent factors ``Eta`` of every random level.

    Object-oriented prototype of the ``updateEta`` function defined earlier in
    the notebook. :meth:`sampler` reads the keys ``"sigma"``, ``"sDim"``,
    ``"Pi"``, ``"Z"``, ``"Beta"``, ``"Eta"``, ``"Lambda"``, ``"Alpha"``,
    ``"X"`` and ``"iWg"`` from the ``params`` dict.

    Args:
        params: dict of current model state and data (keys above).
        dtype: floating dtype forwarded to the per-level samplers for their
            standard-normal draws and identity matrices.
    """

    def __init__(self, params, dtype=np.float64):
        self.params = params
        self.dtype = dtype

    def sampler(self, params=None):
        """Draw a new ``Eta`` for every random level, one level at a time.

        Each level is drawn conditionally on the fixed part ``X @ Beta`` and on
        the current contribution of every other level. The fresh draw then
        replaces that level's contribution before the next level is updated.

        Args:
            params: optional dict overriding ``self.params`` for this call;
                ``None`` (default) uses the dict given to the constructor.

        Returns:
            list with one ``Eta`` tensor of shape ``[npVec[r], nf_r]`` per
            random level ``r``, where ``npVec = reduce_max(Pi, 0) + 1`` and
            ``nf_r`` is the number of rows of ``Lambda[r]``.
        """
        if params is None:
            params = self.params

        sigma = params["sigma"]
        sDim = params["sDim"]
        Pi = params["Pi"]

        Z = params["Z"]
        Beta = params["Beta"]
        EtaList = params["Eta"]
        LambdaList = params["Lambda"]
        AlphaList = params["Alpha"]
        X = params["X"]
        iWgList = params["iWg"]

        # Pi holds 0-based unit indices per level (it is used with tf.gather),
        # so the number of units of level r is max(Pi[:, r]) + 1.
        npVec = tf.reduce_max(Pi, 0) + 1
        nr = len(LambdaList)

        LFix = tf.matmul(X, Beta)
        LRanLevelList = [
            tf.matmul(tf.gather(Eta, Pi[:, r]), Lambda)
            for r, (Eta, Lambda) in enumerate(zip(EtaList, LambdaList))
        ]

        # sigma**-2 broadcast to the shape of Z.
        iD = tf.ones_like(Z) * sigma**-2

        EtaListNew = [None] * nr
        for r, (Eta, Lambda, Alpha, iWg) in enumerate(
            zip(EtaList, LambdaList, AlphaList, iWgList)
        ):
            # Partial residual: remove the fixed part and every OTHER level.
            S = Z - LFix - sum(LRanLevelList[q] for q in range(nr) if q != r)

            nf = tf.cast(tf.shape(Lambda)[-2], tf.int64)

            # Row-wise Lambda diag(iD[i]) Lambda^T, summed into the row's unit
            # (tf.scatter_nd accumulates updates that share an index).
            LamInvSigLam = tf.scatter_nd(
                Pi[:, r, None],
                tf.einsum("hj,ij,kj->ihk", Lambda, iD, Lambda),
                tf.stack([npVec[r], nf, nf]),
            )
            mu0 = tf.scatter_nd(
                Pi[:, r, None],
                tf.matmul(iD * S, Lambda, transpose_b=True),
                tf.stack([npVec[r], nf]),
            )

            if sDim[r] > 0:
                Eta = self.modelSpatialFull(
                    Eta, Lambda, LamInvSigLam, mu0, Alpha, iWg, npVec[r], nf,
                    dtype=self.dtype,
                )
            else:
                Eta = self.modelNonSpatial(
                    Eta, Lambda, LamInvSigLam, mu0, npVec[r], nf, dtype=self.dtype
                )

            EtaListNew[r] = Eta
            # Refresh this level's contribution so later levels see the new draw.
            LRanLevelList[r] = tf.matmul(tf.gather(Eta, Pi[:, r]), Lambda)

        return EtaListNew

    @staticmethod
    def modelSpatialFull(
        Eta, Lambda, LamInvSigLam, mu0, Alpha, iWg, np, nf, dtype=np.float64
    ):
        """Draw ``Eta`` for a spatial level using dense ``iWg`` matrices.

        The factors are stacked factor-major into a vector of length
        ``nf * np``. Its precision ``P`` is block-diagonal over factors, with
        block ``h`` equal to ``iWg[Alpha[h]]``. The per-unit ``LamInvSigLam``
        terms are added to ``P`` in the same layout. With ``L L^T = P``
        (Cholesky), the draw ``L^{-T} (L^{-1} m + z)`` with ``z ~ N(0, I)``
        has mean ``P^{-1} m`` and covariance ``P^{-1}``, where
        ``m = vec(mu0^T)``.

        Args:
            Eta, Lambda: unused; kept for a uniform sampler signature.
            LamInvSigLam: ``[np, nf, nf]`` per-unit data term.
            mu0: ``[np, nf]`` per-unit linear term.
            Alpha: ``[nf, 1]`` integer indices into ``iWg`` (one per factor).
            iWg: stack of ``[np, np]`` matrices indexed along axis 0
                (presumably inverse spatial covariances on an alpha grid).
            np: number of units of this level (shadows numpy inside this body).
            nf: number of factors.
            dtype: dtype of the standard-normal draw.

        Returns:
            ``[np, nf]`` tensor with the new ``Eta``.
        """
        # Block-diagonal prior part: block h = iWg[Alpha[h]].
        iWs = tf.reshape(
            tf.transpose(
                tfla.diag(
                    tf.transpose(tf.gather(iWg, tf.squeeze(Alpha, -1)), [1, 2, 0])
                ),
                [2, 0, 3, 1],
            ),
            [nf * np, nf * np],
        )
        # Data part: LamInvSigLam[a, h, k] placed at row (h, a), column (k, a).
        iUEta = iWs + tf.reshape(
            tf.transpose(
                tfla.diag(tf.transpose(LamInvSigLam, [1, 2, 0])), [0, 2, 1, 3]
            ),
            [nf * np, nf * np],
        )
        LiUEta = tfla.cholesky(iUEta)
        mu1 = tfla.triangular_solve(LiUEta, tf.reshape(tf.transpose(mu0), [nf * np, 1]))
        eta = tfla.triangular_solve(
            LiUEta, mu1 + tfr.normal([nf * np, 1], dtype=dtype), adjoint=True
        )
        return tf.transpose(tf.reshape(eta, [nf, np]))

    @staticmethod
    def modelNonSpatial(Eta, Lambda, LamInvSigLam, mu0, np, nf, dtype=np.float64):
        """Draw ``Eta`` for a non-spatial level, independently for every unit.

        Unit ``i`` has precision ``P_i = I + LamInvSigLam[i]``; the identity
        presumably encodes an ``N(0, I)`` prior on ``Eta``. With
        ``L L^T = P_i``, the draw ``L^{-T} (L^{-1} mu0[i] + z)`` has mean
        ``P_i^{-1} mu0[i]`` and covariance ``P_i^{-1}``.

        Args:
            Eta, Lambda: unused; kept for a uniform sampler signature.
            LamInvSigLam: ``[np, nf, nf]`` per-unit data term.
            mu0: ``[np, nf]`` per-unit linear term.
            np: number of units of this level (shadows numpy inside this body).
            nf: number of factors.
            dtype: dtype of the identity and of the standard-normal draw.

        Returns:
            ``[np, nf]`` tensor with the new ``Eta``.
        """
        iV = tf.eye(nf, dtype=dtype) + LamInvSigLam
        # iV already contains the identity term; adding it again inside the
        # Cholesky (as before) double-counted it and over-shrank Eta.
        LiV = tfla.cholesky(iV)
        mu1 = tfla.triangular_solve(LiV, tf.expand_dims(mu0, -1))
        return tf.squeeze(
            tfla.triangular_solve(
                LiV, mu1 + tfr.normal([np, nf, 1], dtype=dtype), adjoint=True
            ),
            -1,
        )

    def modelSpatialNNGP(self, Eta, Lambda):
        # Not implemented yet (NNGP spatial variant).
        raise NotImplementedError

    def modelSpatialGPP(self, Eta, Lambda):
        # Not implemented yet (GPP spatial variant).
        raise NotImplementedError

    def __repr__(self):
        # Summarise rather than dump ``params``: its values may be large tensors.
        # Raising here (as before) broke Jupyter's display of the object.
        params = self.params
        keys = list(params) if isinstance(params, dict) else type(params).__name__
        return f"{type(self).__name__}(params={keys!r}, dtype={self.dtype!r})"

    def __str__(self):
        return self.__repr__()

    def __getitem__(self, position):
        raise NotImplementedError


# Example usage (`params` must provide the keys read by UpdateEta.sampler):
# obj = UpdateEta(params)
# res = obj.sampler()

# print([res[r].shape for r in range(len(res))])

# print([item for item in res])