In [None]:
import numpy as np
import random
from collections import Counter, defaultdict
from tqdm import tqdm
import math

from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics import normalized_mutual_info_score
import nltk
# Make sure the NLTK tokenizer model and stopword list are installed, downloading
# each one the first time it is found missing.
# nltk.data.find raises LookupError for a resource that is not installed; catching
# only that exception keeps KeyboardInterrupt and unrelated errors from being
# silently swallowed the way a bare `except:` would.
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt')
try:
    nltk.data.find('corpora/stopwords')
except LookupError:
    nltk.download('stopwords')
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize

# English stopword list from NLTK, held as a set; preprocess_documents tests each
# token for membership in it when remove_stopwords is enabled.
STOPWORDS = set(stopwords.words('english'))



def preprocess_documents(docs, min_df=5, lowercase=True, remove_stopwords=True):
    """Tokenize documents and drop tokens that are rare across the corpus.

    Each document is optionally lowercased, split with ``word_tokenize`` and
    reduced to purely alphabetic tokens (``str.isalpha``). When
    ``remove_stopwords`` is set, tokens present in ``STOPWORDS`` are removed
    too; the comparison is an exact (case-sensitive) membership test.

    Args:
        docs: Iterable of document strings.
        min_df: Minimum number of occurrences a token needs to stay in the
            vocabulary. The count is the total over all documents, not the
            number of documents containing the token.
        lowercase: Lowercase each document before tokenizing.
        remove_stopwords: Also discard tokens found in ``STOPWORDS``.

    Returns:
        A ``(tokenized, vocab)`` tuple: ``tokenized`` is one token list per
        input document, restricted to the vocabulary; ``vocab`` is the set of
        tokens that met the ``min_df`` threshold.
    """
    def _keep(token):
        # Alphabetic check first, then the optional stopword filter.
        if not token.isalpha():
            return False
        return not (remove_stopwords and token in STOPWORDS)

    per_doc_tokens = []
    for text in docs:
        text = text.lower() if lowercase else text
        per_doc_tokens.append([tok for tok in word_tokenize(text) if _keep(tok)])

    # Corpus-wide occurrence counts decide which tokens survive.
    counts = Counter()
    for tokens in per_doc_tokens:
        counts.update(tokens)
    vocab = {word for word, count in counts.items() if count >= min_df}

    pruned = [[word for word in tokens if word in vocab] for tokens in per_doc_tokens]
    return pruned, vocab


def extract_biterms(tokenized_docs, window=None):
    """Collect unordered pairs of distinct tokens (biterms) from each document.

    For every position ``i`` in a document, the token is paired with each later
    token at position ``j``; with a ``window`` only positions
    ``i < j < i + window`` are used, so the window length includes the first
    token itself. Pairs of identical tokens are skipped, and each pair is stored
    with the smaller token first. Repeated pairs are kept as separate entries.

    Args:
        tokenized_docs: Iterable of token sequences, one per document.
        window: Optional window length; ``None`` pairs each token with every
            later token in the same document.

    Returns:
        A ``(biterms, doc_biterm_idx)`` tuple: ``biterms`` is the flat list of
        pairs over all documents, and ``doc_biterm_idx[d]`` lists the positions
        in ``biterms`` that came from document ``d``.
    """
    biterms = []
    doc_biterm_idx = []
    for tokens in tokenized_docs:
        n_tokens = len(tokens)
        # Everything appended from here on belongs to the current document.
        first_new = len(biterms)
        for left, first in enumerate(tokens):
            stop = n_tokens if window is None else min(n_tokens, left + window)
            for right in range(left + 1, stop):
                second = tokens[right]
                if first != second:
                    ordered = (first, second) if first < second else (second, first)
                    biterms.append(ordered)
        doc_biterm_idx.append(list(range(first_new, len(biterms))))
    return biterms, doc_biterm_idx



def build_id_mappings(vocab):
    """Assign integer ids to vocabulary words in sorted order.

    Args:
        vocab: Iterable of words; they are sorted before ids are assigned.

    Returns:
        A ``(w2id, id2w)`` tuple: ``w2id`` maps each word to its position in
        the sorted order, and ``id2w`` is the inverse of ``w2id``.
    """
    ordered_words = sorted(vocab)
    w2id = dict(zip(ordered_words, range(len(ordered_words))))
    # Invert the finished word->id table rather than re-enumerating the words.
    id2w = dict(zip(w2id.values(), w2id.keys()))
    return w2id, id2w


def encode_biterms(biterms, w2id):
    """Translate word-pair biterms into id-pair biterms.

    Args:
        biterms: Iterable of two-element word pairs.
        w2id: Mapping from word to integer id.

    Returns:
        A list of ``(id_a, id_b)`` tuples in the same order as ``biterms``.

    Raises:
        KeyError: If a word in a pair is missing from ``w2id``.
    """
    encoded = []
    for first, second in biterms:
        encoded.append((w2id[first], w2id[second]))
    return encoded


class BTM:
    def __init__(self, K=50, alpha=None, beta=0.01, n_iter=500, random_state=0):
        """Store the sampler hyperparameters and seed NumPy's global RNG.

        Args:
            K: Number of topics.
            alpha: Topic prior; stored as given. ``fit`` substitutes
                ``50.0 / K`` when this is ``None``.
            beta: Word prior used by ``fit``.
            n_iter: Number of sampling iterations run by ``fit``.
            random_state: Seed passed to ``np.random.seed``. This reseeds the
                process-wide NumPy generator and is not kept on the instance.
        """
        np.random.seed(random_state)
        self.K, self.alpha = K, alpha
        self.beta, self.n_iter = beta, n_iter
    
    def fit(self, biterms, vocab_size, doc2biterms = None, burnin = 100, collect_last = 50, verbose = True):
        B = len(biterms)
        K = self.K
        M = vocab_size
        alpha = self.alpha if self.alpha is not None else 50.0 / K
        beta = self.beta

        n_z = np.zeros(K, dtype = np.int64)
        n_wz = np.zeros((M, K), dtype = np.int64)
        z_b = np.zeros(B, dtype = np.int64)

        for i in range(B):
            z = np.random.randint(0, K)
            z_b[i] = z
            n_z[z] += 1
            wi, wj = biterms[i]
            n_wz[wi, z] += 1
            n_wz[wj, z] += 1
        
        collect = 0
        phi_sum = np.zeros((M, K), dtype = np.float64)
        theta_sum = np.zeros(K, dtype = np.float64)
        it_range = range(self.n_iter)
        if verbose:
            it_range = tqdm(it_range, desc = "BTM Gibbs Sampling")
        for it in it_range:
            for b_idx in range(B):
                wi, wj = biterms[b_idx]
                cur_z = z_b[b_idx]

                n_z[cur_z] -= 1
                n_wz[wi, cur_z] -= 1
                n_wz[wj, cur_z] -= 1

                denom = (n_wz.sum(axis = 0) + M * beta)
                pz = (n_z + alpha) * ((n_wz[wi, :] + beta) / denom) * ((n_wz[wj, :] + beta) / denom)

                pz_sum = pz.sum()
                if pz_sum <= 0:
                    pz = np.ones(K) / K
                else:
                    pz = pz / pz_sum
                

                new_z = np.searchsorted(np.cumsum(pz), np.random.rand())
                if new_z >= K:
                    new_z = K - 1
                
                z_b[b_idx] = new_z, np.random.rand(, np.random.rand())

                n_z[new_z] += 1
                n_wz[wi, new_z] += 1
                n_wz[wj, new_z] += 1

            if it >= burnin:
                phi = (n_wz + beta) / (n_wz.sum(axis = 0)[None, :] + M * beta)
                theta = (n_z + alpha) / (B + K * alpha)

                phi_sum += phi
                theta_sum += theta
                collect += 1

        if collect == 0:
            self.phi = (n_)

