In [1]:
import torch

<h1>3-1. Latent Dirichlet Allocation<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Data:-Reuters" data-toc-modified-id="Data:-Reuters-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Data: Reuters</a></span></li><li><span><a href="#Model:-Basic-LDA" data-toc-modified-id="Model:-Basic-LDA-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Model: Basic LDA</a></span><ul class="toc-item"><li><span><a href="#Training" data-toc-modified-id="Training-2.1"><span class="toc-item-num">2.1&nbsp;&nbsp;</span>Training</a></span></li></ul></li><li><span><a href="#Model:-Smoothed-LDA" data-toc-modified-id="Model:-Smoothed-LDA-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>Model: Smoothed LDA</a></span><ul class="toc-item"><li><span><a href="#Collaped-Gibbs-Sampling" data-toc-modified-id="Collaped-Gibbs-Sampling-3.1"><span class="toc-item-num">3.1&nbsp;&nbsp;</span>Collaped Gibbs Sampling</a></span></li><li><span><a href="#Variational-EM" data-toc-modified-id="Variational-EM-3.2"><span class="toc-item-num">3.2&nbsp;&nbsp;</span>Variational EM</a></span><ul class="toc-item"><li><span><a href="#E-step" data-toc-modified-id="E-step-3.2.1"><span class="toc-item-num">3.2.1&nbsp;&nbsp;</span>E-step</a></span></li><li><span><a href="#M-step" data-toc-modified-id="M-step-3.2.2"><span class="toc-item-num">3.2.2&nbsp;&nbsp;</span>M-step</a></span></li></ul></li></ul></li></ul></div>

## Data: Reuters

Reuters is a multi-class, multi-label dataset.

* 90 classes
* 10788 documents
    * 7769 training documents
    * 3019 testing documents

In [7]:
from nltk.corpus import reuters
from nltk.corpus import stopwords

* train-test split
: The data is already splitted. Just sort it out.

In [13]:
stops = stopwords.words("english")
stops += [
    "a", "about", "above", "across", "after", "afterwards", "again", "against",
    "all", "almost", "alone", "along", "already", "also", "although", "always",
    "am", "among", "amongst", "amoungst", "amount", "an", "and", "another",
    "any", "anyhow", "anyone", "anything", "anyway", "anywhere", "are",
    "around", "as", "at", "back", "be", "became", "because", "become",
    "becomes", "becoming", "been", "before", "beforehand", "behind", "being",
    "below", "beside", "besides", "between", "beyond", "bill", "both",
    "bottom", "but", "by", "call", "can", "cannot", "cant", "co", "con",
    "could", "couldnt", "cry", "de", "describe", "detail", "do", "done",
    "down", "due", "during", "each", "eg", "eight", "either", "eleven", "else",
    "elsewhere", "empty", "enough", "etc", "even", "ever", "every", "everyone",
    "everything", "everywhere", "except", "few", "fifteen", "fifty", "fill",
    "find", "fire", "first", "five", "for", "former", "formerly", "forty",
    "found", "four", "from", "front", "full", "further", "get", "give", "go",
    "had", "has", "hasnt", "have", "he", "hence", "her", "here", "hereafter",
    "hereby", "herein", "hereupon", "hers", "herself", "him", "himself", "his",
    "how", "however", "hundred", "i", "ie", "if", "in", "inc", "indeed",
    "interest", "into", "is", "it", "its", "itself", "keep", "last", "latter",
    "latterly", "least", "less", "ltd", "made", "many", "may", "me",
    "meanwhile", "might", "mill", "mine", "more", "moreover", "most", "mostly",
    "move", "much", "must", "my", "myself", "name", "namely", "neither",
    "never", "nevertheless", "next", "nine", "no", "nobody", "none", "noone",
    "nor", "not", "nothing", "now", "nowhere", "of", "off", "often", "on",
    "once", "one", "only", "onto", "or", "other", "others", "otherwise", "our",
    "ours", "ourselves", "out", "over", "own", "part", "per", "perhaps",
    "please", "put", "rather", "re", "same", "see", "seem", "seemed",
    "seeming", "seems", "serious", "several", "she", "should", "show", "side",
    "since", "sincere", "six", "sixty", "so", "some", "somehow", "someone",
    "something", "sometime", "sometimes", "somewhere", "still", "such",
    "system", "take", "ten", "than", "that", "the", "their", "them",
    "themselves", "then", "thence", "there", "thereafter", "thereby",
    "therefore", "therein", "thereupon", "these", "they", "thick", "thin",
    "third", "this", "those", "though", "three", "through", "throughout",
    "thru", "thus", "to", "together", "too", "top", "toward", "towards",
    "twelve", "twenty", "two", "un", "under", "until", "up", "upon", "us",
    "very", "via", "was", "we", "well", "were", "what", "whatever", "when",
    "whence", "whenever", "where", "whereafter", "whereas", "whereby",
    "wherein", "whereupon", "wherever", "whether", "which", "while", "whither",
    "who", "whoever", "whole", "whom", "whose", "why", "will", "with",
    "within", "without", "would", "yet", "you", "your", "yours", "yourself",
    "yourselves"]

In [15]:
reuters.words()[:10]

['ASIAN', 'EXPORTERS', 'FEAR', 'DAMAGE', 'FROM', 'U', '.', 'S', '.-', 'JAPAN']

In [17]:
trainset, testset = [], []
vocab = []
for file_id in reuters.fileids():
    if file_id.startswith("train"):
        trainset.append([w.lower() for w in reuters.words(file_id) if w.lower() not in stops])
        vocab += trainset[-1]
    else:
        testset.append([w.lower() for w in reuters.words(file_id) if w.lower() not in stops])

In [18]:
vocab = list(set(vocab))
word_to_ix = {w: i for i, w in enumerate(vocab)}

In [19]:
def seq_to_ix(seq, vocab=vocab):
    # len(vocab), which is the last index, is for the <unk> (unknown) token
    unk_idx = len(vocab)
    return torch.tensor(list(map(lambda w: word_to_ix.get(w, unk_idx), seq)))

data = {
    "train": list(map(seq_to_ix, trainset)),
    "test": list(map(seq_to_ix, testset))
}

In [20]:
data["train"][0][:5]  # word indices

tensor([20903, 16228, 13181,  3299,  3881])

## Model: Basic LDA

For each document $\mathbf{w}$ in a corpus $D$, generate

$$
N \sim \mathcal{P}(\xi) \\
\theta \sim \text{Dir}(\alpha)
$$

and for $n = 1, \cdots, N$, generate

$$
z_n \sim \text{Multi}(\theta) \\
w_n \sim P(w_n | z_n, \beta)
$$

where $\beta \in \mathbb{R}^{k \times V}$, $\beta_{ij} = P(w^j = 1| z^i = 1)$.

* $\alpha, \beta$: hyperparameters (Dirichlet, Multinomial).
* $N$: The number of words in the document. (ancillary variable)
* $\theta$: A topic mixture.
* (For $i$ in $1\cdots N$)
  * $z_n$: A topic variable.
  * $w_n$: A generated word.

In [38]:
def init_lda(docs, n_topic, random_state=0):
    global V, k, N, M, alpha, beta, gamma, phi
    torch.manual_seed(random_state)
    
    V = len(vocab)
    k = n_topic  # number of topics
    N = torch.tensor([doc.shape[0] for doc in docs])
    M = len(docs)

    print(f"V: {V}\nk: {k}\nN: {N[:10]}...\nM: {M}")

    # initialize α, β
    alpha = torch.rand(k)
    beta = torch.ones((k, V)) / V

    print(f"α: dim {alpha.shape}\nβ: dim {beta.shape}")

    # initialize ϕ, γ
    ## ϕ: (M x max(N) x k) arrays with zero paddings on the right
    gamma = alpha + torch.ones((M, k)) * N.reshape(-1, 1) / k

    phi = torch.ones((M, max(N), k)) / k
    for m, N_d in enumerate(N):
        phi[m, N_d:, :] = 0  # zero padding for vectorized operations

    print(f"γ: dim {gamma.shape}\nϕ: dim ({len(phi)}, N_d, {phi[0].shape[1]})")

In [39]:
from tqdm.auto import tqdm

def E_step(docs, phi, gamma, alpha, beta):
    """
    Minorize the joint likelihood function via variational inference.
    This is the E-step of variational EM algorithm for (smoothed) LDA.
    """
    # optimize phi
    for m in range(M):
        phi[m, :N[m], :] = (beta[:, docs[m]] * (
            torch.exp(torch.digamma(gamma[m, :]) - torch.digamma(gamma[m, :].sum())).reshape(-1, 1))
        ).T

        # Normalize phi
        phi[m, :N[m]] /= phi[m, :N[m]].sum(axis=1).reshape(-1, 1)
        if torch.any(torch.isnan(phi)):
            raise ValueError("phi nan")

    # optimize gamma
    gamma = alpha + phi.sum(dim=1)

    return phi, gamma

In [41]:
import warnings

def _update(var, vi_var, const, max_iter=10000, tol=1e-4):
    """
    From appendix A.2 of Blei et al., 2003.
    For hessian with shape `H = diag(h) + 1z1'`
    
    To update alpha, input var=alpha and vi_var=gamma, const=M.
    To update eta, input var=eta and vi_var=lambda, const=k.
    """
    for _ in range(max_iter):
        # store old value
        var0 = var.detach().clone()
        
        # g: gradient 
        psi_sum = torch.digamma(vi_var.sum(dim=1)).reshape(-1, 1)
        g = const * (torch.digamma(var.sum()) - torch.digamma(var)) \
            + (torch.digamma(vi_var) - psi_sum).sum(dim=0)

        # H = diag(h) + 1z1'
        z = const * torch.polygamma(1, var.sum())  # z: Hessian constant component
        h = -const * torch.polygamma(1, var)       # h: Hessian diagonal component
        c = (g / h).sum() / (1./z + (1./h).sum())

        # update var
        var -= (g - c) / h
        
        # check convergence
        err = torch.mean((var - var0) ** 2)
        crit = err < tol ** 2
        if crit:
            break
    else:
        warnings.warn(f"max_iter={max_iter} reached: values might not be optimal.")
    
    #print(err)
    return var

In [76]:
def _inner_sum(docs, phi, m, j):
    # doc = np.zeros(docs[m].shape[0] * V, dtype=int)
    # doc[np.arange(0, docs[m].shape[0] * V, V) + docs[m]] = 1
    # doc = doc.reshape(-1, V)
    # lam += phi[m, :N[m], :].T @ doc
    return (docs[m] == j).float() @ phi[m, :N[m], :]

def M_step(docs, phi, gamma, alpha, beta, M):
    """
    maximize the lower bound of the likelihood.
    This is the M-step of variational EM algorithm for (smoothed) LDA.
    
    update of alpha follows from appendix A.2 of Blei et al., 2003.
    """
    # update alpha
    alpha = _update(alpha, gamma, M)
    
    # update beta
    for j in range(V):
        beta[:, j] = torch.stack([_inner_sum(docs, phi, m, j) for m in range(M)]).sum(dim=0)
    beta /= beta.sum(dim=1).reshape(-1, 1)

    return alpha, beta

In [87]:
def dg(gamma, d, i):
    """
    E[log θ_t] where θ_t ~ Dir(gamma)
    """
    return torch.digamma(gamma[d, i]) - torch.digamma(torch.sum(gamma[d, :]))


def dl(lam, i, w_n):
    """
    E[log β_t] where β_t ~ Dir(lam)
    """
    return torch.digamma(lam[i, w_n]) - torch.digamma(torch.sum(lam[i, :]))


def vlb(docs, phi, gamma, alpha, beta, M, N, k):
    a, b, c, _d = 0, 0, 0, 0
    for d in range(M):
        a += (
            torch.lgamma(alpha.sum())
            - torch.lgamma(alpha).sum()
            + torch.tensor([(alpha[i] - 1) * dg(gamma, d, i) for i in range(k)]).sum()
        )

        _d += (
            torch.lgamma(gamma[d, :].sum())
            - torch.lgamma(gamma[d, :]).sum()
            + torch.tensor([(gamma[d, i] - 1) * dg(gamma, d, i) for i in range(k)]).sum()
        )

        for n in range(N[d]):
            w_n = int(docs[d][n])

            b += torch.tensor([phi[d][n, i] * dg(gamma, d, i) for i in range(k)]).sum()
            c += torch.tensor([phi[d][n, i] * torch.log(beta[i, w_n]) for i in range(k)]).sum()
            _d += torch.tensor([phi[d][n, i] * torch.log(phi[d][n, i]) for i in range(k)]).sum()

    return a + b + c - _d

### Training

Only on 100 documents

In [97]:
docs = data["train"][:100]

In [98]:
init_lda(docs, n_topic=3)

V: 26118
k: 3
N: tensor([442, 173,  96, 100,  74,  24, 141,  80, 155, 100])...
M: 100
α: dim torch.Size([3])
β: dim torch.Size([3, 26118])
γ: dim torch.Size([100, 3])
ϕ: dim (100, N_d, 3)


In [99]:
%%time

N_EPOCH = 1000
TOL = 5

verbose = True
lb = -float("inf")

with torch.
for epoch in range(N_EPOCH): 
    # store old value
    lb_old = lb
    
    # Variational EM
    phi, gamma = E_step(docs, phi, gamma, alpha, beta)
    alpha, beta = M_step(docs, phi, gamma, alpha, beta, M)
    
    # check anomaly
    if torch.any(torch.isnan(alpha)):
        print("NaN detected: terminating")
        break
    
    # check convergence
    lb = vlb(docs, phi, gamma, alpha, beta, M, N, k)
    
    if verbose:
        print(f"{epoch: 04}:  variational lower bound: {lb: .3f}")
    
    if abs(lb - lb_old) < TOL:
        break
else:
    warnings.warn(f"max_iter reached: values might not be optimal.")

 000:  variational lower bound: -60571.770


KeyboardInterrupt: 

* Training result

## Model: Smoothed LDA

For each document $\mathbf{w}$ in a corpus $D$, generate

$$
N \sim \mathcal{P}(\xi) \\
\beta \sim \text{Dir}(\lambda) \\
\theta \sim \text{Dir}(\alpha)
$$

and for $n = 1, \cdots, N$, generate

$$
z_n \sim \text{Multi}(\theta) \\
w_n \sim P(w_n | z_n, \beta)
$$

where $\beta \in \mathbb{R}^{k \times V}$, $\beta_{ij} = P(w^j = 1| z^i = 1)$.

* $\alpha, \eta$: Dirichlet hyperparameters.
* $\beta$: Unsmoothed multinomial hyperparameter.
* $N$: The number of words in the document. (ancillary variable)
* $\theta$: A topic mixture.
* (For $i$ in $1\cdots N$)
  * $z_n$: A topic variable.
  * $w_n$: A generated word.

### Collaped Gibbs Sampling

### Variational EM

#### E-step

Let $\phi_d \in \mathbb{R}^{N \times k}, \gamma_d \in \mathbb{R}^k, \lambda \in \mathbb{R}^{k \times V}$ be variational parameters for $\alpha, \beta, \eta$.  
Suppose further that for $\beta \in \mathbb{R}^{k \times V}$, $\beta_i^0 \sim \text{Dir}(\lambda^0)$ where $\lambda_i^0 = \eta$ for all $i$.

For a document $\mathbf{w}_d$, $d = 1,\cdots,M$,

1. initialize $\phi_{dni}^0 := 1/k$ for all $i,n$.
2. initialize $\gamma_{di} := \alpha_i + N/k$ for all $i$.
3. **repeat until** convergence
    1. for $n=1$ to $N$
        1. for $i=1$ to $k$
            1. $\phi_{dni}^{t+1} := \exp\left(\Psi(\lambda_{iw_{dn}}^t) - \Psi(\sum_{j=1}^V \lambda_{ij}^t) + \Psi(\gamma_{di}^t) - \Psi(\sum_{j=1}^k \gamma_{dj}^t)\right)$
            1. for $j=1$ to $V$
                1. $\lambda_{ij} = \eta + \sum_{d=1}^M \sum_{n=1}^{N_d} \phi_{dni} w_{dn}^j$
        2. normalize $\phi_{dn}^{t+1}$ to sum to 1
    2. $\gamma_d^{t+1} := \alpha + \sum_{n=1}^N \phi_{dn}^{t+1}$
    
where $\Psi$ is the first derivative of the $\log\Gamma$ function.

#### M-step

$$
\beta_{ij} \propto \sum_{d=1}^M \sum_{n=1}^N \phi_{dni} \mathbf{w}_{dn}^j
$$

$\alpha$ is updated via Newton-Raphson method:

$$
\frac{\partial L}{\partial \alpha_i} 
  = M\left( \Psi\left(\sum_{j=1}^k \alpha_j\right) - \Psi(\alpha_i) \right)
    - \sum_{d=1}^M \left( \Psi(\gamma_{di}) - \Psi\left(\sum_{j=1}^k \gamma_{dj}\right) \right) \\
\frac{\partial^2 L}{\partial \alpha_i \alpha_j} = M \left( \Psi'\left(\sum_{j=1}^k \alpha_j\right) - \delta(i,j) \Psi'(\alpha_i) \right)
$$

where $\delta(i,j) = 1$ if $i=j$, $0$ otherwise.

In [185]:
class SmoothedLDA:
    
    def __init__(self, docs, vocab, k):
        self.docs = docs
        
        self.V = len(vocab)
        self.k = k  # number of topics
        self.N = np.array([doc.shape[0] for doc in docs])
        self.M = len(docs)
        
        V = self.V
        N = self.N
        M = self.M
        
        # initialize model parameters
        ##self.beta = np.ones((k, V)) / V
        self.alpha = np.random.gamma(100, 0.01, k)
        self.eta = np.ones(V)

        # initialize variational parameters
        # ϕ: (M x max(N) x k) arrays with zero paddings on the right
        self.phi = [np.ones((N[d], k)) / k for d in range(M)]
        self.gamma = alpha + (N / k).reshape(-1, 1)
        self.lam = np.random.gamma(shape=100, scale=0.01, size=(k, V))
        
        
    def _update_phi(self):
        """
        Update variational parameter phi
        ϕ_{n, j} ∝ e^[ (Ψ(λ_j) - Ψ(Σλ_j)) + ( Ψ(γ_j) - Ψ(Σγ_j) ) ]
        """
        M = self.M
        N = self.N
        k = self.k

        phi = self.phi
        gamma = self.gamma
        docs = self.docs

        for d in range(M):
            for n in range(N[d]):
                for i in range(k):
                    w_n = int(docs[d][n])
                    phi[d][n, i] = np.exp(dl(lam, i, w_n) + dg(gamma, d, i))

                # Normalize over topics
                phi[d][n, :] = phi[d][n, :] / np.sum(phi[d][n, :])
                
        return phi
    
    def _update_gamma(self):
        """
        Update variational parameter gamma
        γ_t = α_t + Σ_{n=1}^{N_d} ϕ_{t, n}
        """
        M = self.M
        phi = self.phi
        alpha = self.alpha

        gamma = alpha + np.array(
            list(map(lambda x: x.sum(axis=0), phi))
        )
        
        return gamma
    
    
    def _update_lam(self):
        V = self.V
        N = self.N
        M = self.M
        phi = self.phi
        lam = self.lam
        eta = self.eta
        docs = self.docs
        
        lam[:] = eta
        for d in range(M):  #, desc="MINORIZE lam"):
            doc = np.zeros(N[d] * V, dtype=int)
            doc[np.arange(0, N[d] * V, V) + docs[d]] = 1
            doc = doc.reshape(-1, V)
            
            lam += phi[d].T @ doc
        
        return lam
        
    
    def _update_alpha(self, max_iter=1000, tol=0.1):
        """
        Update alpha with linear time Newton-Raphson.
        """
        M = self.M
        k = self.k

        alpha = self.alpha
        gamma = self.gamma

        for _ in range(max_iter):
            alpha_old = alpha

            #  Calculate gradient
            g = M * (psi(np.sum(alpha)) - psi(alpha)) +\
                (psi(gamma) - psi(np.sum(gamma, axis=1)).reshape(-1, 1)).sum(axis=0)

            #  Calculate Hessian diagonal component
            h = -M * polygamma(1, alpha)

            #  Calculate Hessian constant component
            z = M * polygamma(1, np.sum(alpha))

            #  Calculate constant
            c = np.sum(g / h) / (z ** (-1.0) + np.sum(h ** (-1.0)))

            #  Update alpha
            alpha = alpha - (g - c) / h
            
            #  Check convergence
            if np.sqrt(np.mean(np.square(alpha - alpha_old))) < tol:
                break
        else:
            warnings.warn("_update_alpha(): max_iter reached.")

        return alpha
    
    
    def _update_eta(self, max_iter=1000, tol=0.1):
        """
        Update eta with linear time Newton-Raphson.
        """
        M = self.M
        k = self.k

        eta = self.eta
        lam = self.lam

        for _ in range(max_iter):
            eta_old = eta

            #  Calculate gradient
            g = k * (psi(np.sum(eta)) - psi(eta)) +\
                (psi(lam) - psi(np.sum(lam, axis=1)).reshape(-1, 1)).sum(axis=0)

            #  Calculate Hessian diagonal component
            h = -k * polygamma(1, eta)

            #  Calculate Hessian constant component
            z = k * polygamma(1, np.sum(eta))

            #  Calculate constant
            c = np.sum(g / h) / (z ** (-1.0) + np.sum(h ** (-1.0)))

            #  Update alpha
            eta = eta - (g - c) / h

            #  Check convergence
            if np.sqrt(np.mean(np.square(eta - eta_old))) < tol:
                break
        else:
            warnings.warn("_update_eta(): max_iter reached.")

        return eta
    
    
    def _E_step(self):
        """
        E-step of the variational EM algorithm.
        Update ϕ, γ, λ.
        """
        self.phi = self._update_phi()
        self.gamma = self._update_gamma()
        self.lam = self._update_lam()
        
        
    def _M_step(self):
        """
        M-step of the variational EM algorithm.
        Update α, η.
        """
        self.alpha = self._update_alpha()
        self.eta = self._update_eta()
    
    
    def vlb(self):
        """
        lower bound from variational inference
        """
        phi = self.phi
        gamma = self.gamma
        lam = self.lam
        alpha = self.alpha
        eta = self.eta
        docs = self.docs
        
        M = self.M
        k = self.k
        N = self.N
        
        a0, a1, a2, a3_1, a3_2, a4, a5 = 0., 0., 0., 0., 0., 0., 0.
        for d in range(M):
            a0 += (
                k * (
                    gammaln(np.sum(eta)) 
                    - np.sum(gammaln(eta))
                )
                + np.sum([(eta[j] - 1) * dl(lam, i, j) for j in range(V) for i in range(k)])
            )
            a1 += (
                gammaln(np.sum(alpha))
                - np.sum(gammaln(alpha))
                + np.sum([(alpha[i] - 1) * dg(gamma, d, i) for i in range(k)])
            )

            a4 += (
                gammaln(np.sum(gamma[d, :]))
                - np.sum(gammaln(gamma[d, :]))
                + np.sum([(gamma[d, i] - 1) * dg(gamma, d, i) for i in range(k)])
            )
            
            for i in range(k):
                for j in range(V):
                    a3_2 += (
                        gammaln(np.sum(lam[i, j]))
                        - np.sum(gammaln(lam[i, :]))
                        + np.sum((lam[i, j] - 1) * dl(lam, i, j))
                    )

            for n in range(N[d]):
                w_n = int(docs[d][n])
                a2 += np.sum([phi[d][n, i] * dg(gamma, d, i) for i in range(k)])
                a3_1 += np.sum([phi[d][n, i] * dl(lam, i, w_n) for i in range(k)])
                a5 += np.sum([phi[d][n, i] * np.log(phi[d][n, i]) for i in range(k)])

        return a0 + a1 + a2 + a3_1 - a3_2 - a4 - a5
    
    
    def train(self, max_iter=1000, tol=5, verbose=True):
        vlb = -np.inf
        
        for it in range(max_iter):
            old_vlb = vlb
            self._E_step()
            self._M_step()
            
            vlb = self.vlb()
            err = vlb - old_vlb
            
            if verbose:
                print(f"Iteration {it+1}: {vlb: .3f} (delta: {err: .2f})") 
            
            if err < tol:
                break
        else:
            warnings.warn("max_iter reached.")
            

    
def dg(gamma, d, i):
    """
    E[log θ_t] where θ_t ~ Dir(gamma)
    """
    return psi(gamma[d, i]) - psi(np.sum(gamma[d, :]))


def dl(lam, i, w_n):
    """
    E[log β_t] where β_t ~ Dir(lam)
    """
    return psi(lam[i, w_n]) - psi(np.sum(lam[i, :]))

```python
lda = SmoothedLDA(docs, vocab, k)
lda.train()
```