# 1 Prepare Dataset

### 1.1 Load samples

In [None]:
from sklearn.datasets import fetch_20newsgroups

# Raw post bodies from 20 Newsgroups with headers, signature footers and
# quoted replies removed; the labels returned alongside are discarded.
samples = fetch_20newsgroups(
    return_X_y=True,
    random_state=1,
    shuffle=True,
    remove=('headers', 'footers', 'quotes'),
)[0]

### 1.2 Tokenize samples

In [None]:
import re

# Lower-case every post, then keep only whole words made entirely of the
# letters a-z. A word that also contains digits, underscores or accented
# letters yields no token at all (it is dropped, not truncated).
pattern = re.compile(r'\b[a-z]+\b')
samples = list(map(lambda text: pattern.findall(text.lower()), samples))

### 1.3 Filter stopwords

In [None]:
# One stopword per line. Entries are stripped and lower-cased so they can
# actually match the lower-case a-z tokens produced in 1.2, and blank lines
# are skipped. The encoding is explicit so the result does not depend on the
# platform's default locale.
with open("stopwords.txt", "r", encoding="utf-8") as f:
    stopword_list = {line.strip().lower() for line in f if line.strip()}

samples = [[w for w in s if w not in stopword_list] for s in samples]

### 1.4 Use only a subset

In [None]:
# Work on the first N_SAMPLES documents only, so the pure-Python sampler
# finishes in reasonable time; a bound past the end just keeps every document.
N_SAMPLES = 2000
subset = samples[slice(N_SAMPLES)]

# 2 Write a Python/NumPy Implementation of LDA

### 2.1 Define a class for LDA with collapsed Gibbs sampling

In [None]:
from numpy import argsort, ones, zeros, zeros_like, cumsum, random, searchsorted, log
import itertools

class PythonLDA:
    """Latent Dirichlet Allocation fitted by collapsed Gibbs sampling (pure Python/NumPy).

    Parameters
    ----------
    corpus : iterable of iterable of str
        Tokenized documents. Word types are mapped to integer ids in order of
        first appearance; the corpus is read exactly once.
    T : int
        Number of topics.
    S : int
        Number of Gibbs sweeps run by ``fit`` after the initial assignment.
    beta : float
        Symmetric Dirichlet prior on each topic's word distribution.
    alpha : float
        Symmetric Dirichlet prior on each document's topic distribution.

    Attributes
    ----------
    corpus : list of list of int
        Documents as word ids.
    idx_to_word : dict
        Word id -> word string.
    nwt : ndarray, shape (W, T)
        Number of tokens of word ``w`` currently assigned to topic ``t``.
    nt : ndarray, shape (T,)
        Number of tokens currently assigned to topic ``t``.
    ntd : ndarray, shape (T, D)
        Number of tokens of document ``d`` currently assigned to topic ``t``.
    z : list of ndarray of int
        ``z[d][n]`` is the topic of the n-th token of document ``d``.

    Sampling draws from NumPy's global RNG (``numpy.random``); seed it before
    ``fit`` for reproducible results.
    """

    def __init__(self, corpus, T, S, beta, alpha):
        self._init_corpus(corpus)

        self.D = D = len(self.corpus)
        self.W = W = len(self.idx_to_word)
        self.T = T
        self.S = S

        # Symmetric priors; the sums are the normalizers of the collapsed
        # conditionals.
        self.beta_arr = beta * ones(W)
        self.beta_sum = beta * W

        self.alpha_arr = alpha * ones(T)
        self.alpha_sum = alpha * T

        self.nwt = zeros((W, T), dtype=float)
        self.nt = zeros(T, dtype=float)
        self.ntd = zeros((T, D), dtype=float)

        # Topic assignments; all zero until fit() draws the initial sample.
        # Built from self.corpus so a one-shot iterable corpus is not re-read.
        self.z = [zeros(len(doc), dtype=int) for doc in self.corpus]

    def _init_corpus(self, corpus):
        """Map word types to contiguous integer ids in first-seen order.

        Sets ``self.corpus`` (documents as lists of ids) and
        ``self.idx_to_word`` (id -> word) in a single pass over ``corpus``.
        """
        word_map = {}
        # setdefault hands out the next free id (current map size) only the
        # first time a word is seen, and returns the existing id afterwards.
        self.corpus = [
            [word_map.setdefault(w, len(word_map)) for w in doc]
            for doc in corpus
        ]
        self.idx_to_word = {idx: word for word, idx in word_map.items()}

    def _log_prob(self):
        """Return log p(w, z) of the current assignment under the collapsed model.

        Tokens are replayed in corpus order against fresh count arrays, so
        each factor is the predictive probability of a token's word and topic
        given only the tokens before it (chain rule).
        """
        nwt = zeros_like(self.nwt)
        nt = zeros_like(self.nt)
        ntd = zeros_like(self.ntd)

        lp = 0.0
        for d, (doc, zd) in enumerate(zip(self.corpus, self.z)):
            for n, (w, t) in enumerate(zip(doc, zd)):
                # p(word w | topic t, earlier tokens)
                first_term = (nwt[w, t] + self.beta_arr[w]) / (nt[t] + self.beta_sum)
                # p(topic t | document d, earlier tokens); n tokens of d are
                # already counted at this point.
                second_term = (ntd[t, d] + self.alpha_arr[t]) / (n + self.alpha_sum)
                lp += log(first_term * second_term)

                nwt[w, t] += 1
                nt[t] += 1
                ntd[t, d] += 1

        return lp

    def _sample_topics(self, init=False):
        """Run one Gibbs sweep, resampling every token's topic in place.

        With ``init=True`` the counts are assumed to be empty: each token is
        assigned sequentially without first removing its previous topic.
        """
        for d, (doc, zd) in enumerate(zip(self.corpus, self.z)):
            for n, (w, t) in enumerate(zip(doc, zd)):
                if not init:
                    # Take the token out of the counts before resampling it.
                    self.nwt[w, t] -= 1
                    self.nt[t] -= 1
                    self.ntd[t, d] -= 1

                # Unnormalized collapsed conditional over topics; the
                # document-length denominator is the same for every topic.
                first_term = (self.nwt[w, :] + self.beta_arr[w]) / (self.nt + self.beta_sum)
                second_term = self.ntd[:, d] + self.alpha_arr
                dist = first_term * second_term

                # Inverse-CDF draw. side='right' selects the first topic whose
                # cumulative mass exceeds r, so a zero-mass topic is never
                # chosen (even if random() returns exactly 0).
                dist_sum = cumsum(dist)
                r = random.random() * dist_sum[-1]
                t = searchsorted(dist_sum, r, side='right')

                self.nwt[w, t] += 1
                self.nt[t] += 1
                self.ntd[t, d] += 1

                zd[n] = t

    def fit(self):
        """Draw an initial assignment, then run ``S`` Gibbs sweeps.

        The log joint probability is printed after initialization and about
        every tenth of the sweeps. Counts are reset first, so calling ``fit``
        again starts a fresh chain instead of double-counting tokens.
        """
        self.nwt.fill(0)
        self.nt.fill(0)
        self.ntd.fill(0)

        self._sample_topics(init=True)
        lp = self._log_prob()
        print('Iteration %s: %s' % (0, lp))

        # max(1, ...) avoids a modulo by zero when S < 10.
        report_every = max(1, self.S // 10)
        for s in range(1, self.S + 1):
            self._sample_topics()
            if s % report_every == 0:
                lp = self._log_prob()
                print('Iteration %s: %s' % (s, lp))

        print()

    def print_topics(self, num=20):
        """Print each topic's ``num`` highest-count words, highest first."""
        for t in range(self.T):
            highest_prob_words = argsort(self.nwt[:, t] + self.beta_arr)
            sorted_types = [self.idx_to_word[i] for i in highest_prob_words]
            # Reverse first, then take the head: unlike [-num:], this is
            # empty for num == 0 instead of the whole vocabulary.
            print('Topic %s: %s' % (t + 1, ' '.join(sorted_types[::-1][:num])))

### 2.2 Test LDA

In [None]:
import cProfile
import pstats
from pstats import SortKey

import numpy as np

# The sampler draws from NumPy's global RNG; seed it (same value as the
# dataset shuffle) so topic assignments and printed topics are reproducible.
np.random.seed(1)

py_lda = PythonLDA(corpus=subset, T=20, S=100, beta=0.01, alpha=0.1)
# Profile the full fit; raw stats are written to py_stats.txt.
cProfile.runctx('py_lda.fit()', globals(), locals(), filename="py_stats.txt")
py_lda.print_topics()
# Top 10 functions by internal time; the trailing ';' hides the Stats repr.
pstats.Stats('py_stats.txt').strip_dirs().sort_stats(SortKey.TIME).print_stats(10);

# 3 Write a Cython Implementation of LDA

### 3.1 Prepare notebook for Cython

In [None]:
# %pip installs into the environment of the running kernel, unlike !pip,
# which uses whichever pip comes first on PATH. TODO: pin a version.
%pip install wurlitzer

In [2]:
# Register the IPython extensions used in section 3 (wurlitzer and Cython).
%load_ext wurlitzer
%load_ext cython

### 3.2 Build the Cython extension (from a separate `cy_setup.py` because of IPython quirks)

In [1]:
# Build with the kernel's own interpreter so the extension matches the Python
# that will import it (a bare `python3` may resolve to another installation).
# The path is quoted in case it contains spaces.
import sys
!"{sys.executable}" cy_setup.py build_ext --inplace --force

running build_ext
building 'cy_lda' extension
x86_64-linux-gnu-gcc -pthread -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -g -fstack-protector-strong -Wformat -Werror=format-security -Wdate-time -D_FORTIFY_SOURCE=2 -fPIC -I/home/joshua/repos/cython-lda/venv/include -I/usr/include/python3.8 -c cy_lda.cpp -o build/temp.linux-x86_64-3.8/cy_lda.o
x86_64-linux-gnu-g++ -pthread -shared -Wl,-O1 -Wl,-Bsymbolic-functions -Wl,-Bsymbolic-functions -Wl,-z,relro -g -fwrapv -O2 -Wl,-Bsymbolic-functions -Wl,-z,relro -g -fwrapv -O2 -g -fstack-protector-strong -Wformat -Werror=format-security -Wdate-time -D_FORTIFY_SOURCE=2 build/temp.linux-x86_64-3.8/cy_lda.o -o /home/joshua/repos/cython-lda/cy_lda.cpython-38-x86_64-linux-gnu.so


### 3.3 Repeat steps 1.1–1.4, load the Cython extension, and test LDA

In [6]:
%%cython

# NOTE(review): this cell contains only plain Python but runs under %%cython,
# so it is compiled as a Cython module rather than executed directly.
# Steps 1.1-1.4 repeat section 1 verbatim, presumably because the kernel was
# restarted to build the extension (the build cell ran as In [1]).
# TODO confirm, and consider a shared helper instead of copied code.

# 1.1 Load samples: 20 Newsgroups post bodies with headers, footers and
# quoted replies removed; the labels are discarded.
from sklearn.datasets import fetch_20newsgroups

samples, _ = fetch_20newsgroups(
    remove=('headers', 'footers', 'quotes'),
    shuffle=True,
    random_state=1,
    return_X_y=True
)

# 1.2 Tokenize samples: lower-case, then keep whole words made only of a-z.
import re

pattern = re.compile(r'\b[a-z]+\b')
samples = [pattern.findall(s.lower()) for s in samples]

# 1.3 Filter stopwords read one per line from stopwords.txt.
with open("stopwords.txt", "r") as f:
    stopword_list = set(f.read().splitlines())
samples = [[w for w in s if w not in stopword_list] for s in samples]

# 1.4 Use only the first N_SAMPLES documents.
N_SAMPLES = 2000
subset = samples[:N_SAMPLES]

# Test LDA: profile CythonLDA.fit() with the same settings as the pure-Python
# run in section 2 (T=20, S=100, beta=0.01, alpha=0.1), then print the 10
# functions with the highest internal time.
# NOTE(review): the instance is bound to `cy_lda`, the extension module's name.
# Only CythonLDA is imported from the module, so nothing is shadowed here,
# but the reuse is easy to misread.
# NOTE(review): no RNG seed is set and topics are not printed here. TODO:
# confirm that is intended. How much of CythonLDA's internals cProfile can
# see depends on how cy_lda was compiled, which is not visible from here.
from cy_lda import CythonLDA
import cProfile
import pstats
from pstats import SortKey

cy_lda = CythonLDA(corpus=subset, T=20, S=100, beta=0.01, alpha=0.1)
cProfile.runctx('cy_lda.fit()', globals(), locals(), filename="cy_stats.txt")
pstats.Stats('cy_stats.txt').strip_dirs().sort_stats(SortKey.TIME).print_stats(10)