In [2]:
import numpy as np
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer
from collections import defaultdict, Counter
import random, time

# --- Experiment configuration ---------------------------------------------

# Corpus sampling and vocabulary pruning (the last three are handed to
# CountVectorizer unchanged).
sample_size = 2500    # documents to draw from the corpus; None means "use all"
max_features = 2000   # cap on the vocabulary size
min_df = 20           # CountVectorizer min_df cut-off
max_df = 0.5          # CountVectorizer max_df cut-off

# Topic model and sampler settings.
K = 20                # number of topics
iterations = 80       # full Gibbs sweeps over all tokens
alpha = 1.0           # smoothing added to document-topic counts
beta = 1.0            # smoothing added to topic-word counts

# Seed both generators so that document sampling and the Gibbs draws are
# repeatable from run to run.
random.seed(42)
np.random.seed(42)

# Load the 20 Newsgroups training split with headers, footers and quoted
# replies stripped from each post.
data = fetch_20newsgroups(subset='train', remove=('headers','footers','quotes'))
docs_all, labels_all, target_names = data.data, data.target, data.target_names
N_docs_total = len(docs_all)
print("Total documents available:", N_docs_total)

if sample_size is not None and sample_size < N_docs_total:
    # Work on a random subset (drawn without replacement) to keep the
    # pure-Python sampler tractable.
    idx = np.random.choice(N_docs_total, sample_size, replace=False)
    docs = [docs_all[doc_index] for doc_index in idx]
    true_labels = [labels_all[doc_index] for doc_index in idx]
else:
    # No sampling requested, or the request covers the whole corpus.
    docs = docs_all
    true_labels = labels_all

# Number of documents actually used by the model.
D = len(docs)

# Turn the documents into a sparse document-term count matrix, dropping
# English stop words and pruning the vocabulary with the configured limits.
vectorizer = CountVectorizer(
    stop_words='english',
    max_features=max_features,
    min_df=min_df,
    max_df=max_df,
)
X = vectorizer.fit_transform(docs)
Xcsr = X.tocsr()

# vocab[w] is the term string for column w of X; V is the vocabulary size.
vocab = np.array(vectorizer.get_feature_names_out())
V = vocab.shape[0]
print("Vocabulary size after pruning:", V)

# Sum of all counts in X, i.e. the number of token occurrences to sample.
total_tokens = int(X.sum())
print("Total tokens (after pruning):", total_tokens)

# Expand the document-term count matrix into one entry per token occurrence:
# word_ids[i] is the vocabulary index and doc_ids[i] the document index of
# token i. Tokens are laid out document by document, and within a document
# in the order of the stored CSR entries, each entry repeated `count` times.
#
# This is done directly on the CSR arrays with np.repeat instead of slicing
# one sparse row per document in a Python loop, which builds a new sparse
# matrix for every document.
entry_counts = Xcsr.data.astype(np.intp)          # count stored in each CSR entry
entries_per_doc = np.diff(Xcsr.indptr)            # number of stored entries per row
entry_docs = np.repeat(np.arange(D, dtype=np.int32), entries_per_doc)

word_ids = np.repeat(Xcsr.indices, entry_counts).astype(np.int32)
doc_ids = np.repeat(entry_docs, entry_counts)

# Number of token occurrences; equals total_tokens (the sum of all counts).
N = int(word_ids.shape[0])

print("Initializing topics (K={})...".format(K))
# Start from a uniformly random topic assignment for every token.
z = np.random.randint(0, K, size=N)

# Sufficient statistics of the collapsed sampler:
#   ndk[d, k] - tokens of document d currently assigned to topic k
#   nkw[k, w] - occurrences of word w currently assigned to topic k
#   nk[k]     - total tokens currently assigned to topic k
ndk = np.zeros((D, K), dtype=np.int32)
nkw = np.zeros((K, V), dtype=np.int32)
nk = np.zeros(K, dtype=np.int32)

# np.add.at is an unbuffered scatter-add, so index pairs that occur many
# times are each counted (plain `ndk[doc_ids, z] += 1` would count them once).
# It replaces a Python loop over every token with the same resulting counts.
np.add.at(ndk, (doc_ids, z), 1)
np.add.at(nkw, (z, word_ids), 1)
np.add.at(nk, z, 1)

# Normaliser of the smoothed topic-word distribution: sum of beta over V words.
beta_sum = V * beta

print("Start Gibbs sampling: N tokens =", N)
start = time.time()
last_topic = K - 1
for it in range(1, iterations + 1):
    t_iter0 = time.time()
    for token in range(N):
        word = word_ids[token]
        doc = doc_ids[token]
        old_topic = z[token]

        # Take the token out of the counts so the conditional below is
        # computed from all *other* assignments.
        ndk[doc, old_topic] -= 1
        nkw[old_topic, word] -= 1
        nk[old_topic] -= 1

        # Unnormalised conditional over topics:
        #   (ndk[doc, k] + alpha) * (nkw[k, word] + beta) / (nk[k] + beta_sum)
        doc_part = ndk[doc] + alpha
        word_part = (nkw[:, word] + beta) / (nk + beta_sum)
        weights = doc_part * word_part
        weight_total = weights.sum()
        # Fall back to a uniform distribution if the weights carry no mass.
        probs = np.ones(K) / K if weight_total <= 0 else weights / weight_total

        # Inverse-CDF draw; the clamp guards against the cumulative sum
        # ending marginally below the uniform draw.
        draw = np.random.rand()
        new_topic = min(int(np.searchsorted(np.cumsum(probs), draw)), last_topic)

        # Put the token back under its freshly sampled topic.
        z[token] = new_topic
        ndk[doc, new_topic] += 1
        nkw[new_topic, word] += 1
        nk[new_topic] += 1

    # Progress report on the first sweep and every tenth sweep.
    if it == 1 or it % 10 == 0:
        print(f"Iter {it}/{iterations} — iter time {time.time()-t_iter0:.1f}s — total elapsed {time.time()-start:.1f}s")
print("Gibbs finished; total time {:.1f}s".format(time.time()-start))

# Point estimates from the final counts.
# phi[k, w]: smoothed probability of word w under topic k.
phi = (nkw + beta) / (nk[:, np.newaxis] + beta_sum)
# theta[d, k]: smoothed share of topic k in document d (each row sums to 1).
theta = (ndk + alpha).astype(float)
theta /= theta.sum(axis=1, keepdims=True)

def top_words(topic, n=10):
    """Return the `n` most probable words of a topic and their probabilities.

    Reads the module-level `phi` (topic-word probabilities) and `vocab`
    (term strings indexed by word id).

    Args:
        topic: Row index into `phi`.
        n: Number of words to return.

    Returns:
        A pair `(words, probs)`: the entries of `vocab` and of `phi[topic]`
        for the selected word ids, ordered from most to least probable.
    """
    ascending = np.argsort(phi[topic])
    # Reverse to get descending probability, then keep the leading n ids.
    best_ids = ascending[::-1][:n]
    words = vocab[best_ids]
    probs = phi[topic, best_ids]
    return words, probs

print("\nTOP-10 words per topic:")
# topic_words[k] holds the ten highest-probability words of topic k; the
# probabilities returned alongside them are not needed here.
topic_words = [top_words(k, 10)[0] for k in range(K)]
for k, words in enumerate(topic_words):
    # Topics are numbered from 1 in the printed report.
    print(f"Topic {k+1}: {', '.join(words)}")

# Dominant topic of every document: the topic with the largest theta entry.
dominant_topic = np.argmax(theta, axis=1)

# A document with no in-vocabulary tokens has an all-zero ndk row, so its
# theta row is flat and argmax falls back to index 0. Counting such documents
# would credit all of them to the first topic and distort its label share,
# so they are left out of the grouping below.
has_tokens = ndk.sum(axis=1) > 0

# topic index -> list of document indices in which that topic dominates.
topic_to_docs = defaultdict(list)
for d, tk in enumerate(dominant_topic):
    if has_tokens[d]:
        topic_to_docs[int(tk)].append(d)

print("\nTopic -> most frequent true label among docs where topic dominates:")
n_empty = int((~has_tokens).sum())
if n_empty:
    print(f"(skipped {n_empty} documents with no in-vocabulary tokens)")

for k in range(K):
    docs_k = topic_to_docs.get(k, [])
    if not docs_k:
        print(f"Topic {k+1}: (no dominant docs)")
        continue
    # Majority newsgroup among the documents dominated by topic k, and the
    # fraction of those documents that carry it.
    labs = [true_labels[d] for d in docs_k]
    most_common_label, cnt = Counter(labs).most_common(1)[0]
    frac = cnt / len(labs)
    print(f"Topic {k+1}: {target_names[most_common_label]} (share {frac:.2f}) | top words: {', '.join(topic_words[k])}")

Total documents available: 11314
Vocabulary size after pruning: 1508
Total tokens (after pruning): 120332
Initializing topics (K=20)...
Start Gibbs sampling: N tokens = 120332
Iter 1/80 — iter time 6.3s — total elapsed 6.3s
Iter 10/80 — iter time 3.5s — total elapsed 45.0s
Iter 20/80 — iter time 2.9s — total elapsed 76.6s
Iter 30/80 — iter time 3.2s — total elapsed 108.1s
Iter 40/80 — iter time 2.8s — total elapsed 140.1s
Iter 50/80 — iter time 3.0s — total elapsed 171.2s
Iter 60/80 — iter time 3.0s — total elapsed 203.4s
Iter 70/80 — iter time 3.3s — total elapsed 235.5s
Iter 80/80 — iter time 2.9s — total elapsed 267.0s
Gibbs finished; total time 267.0s

TOP-10 words per topic:
Topic 1: edu, available, com, file, files, motif, window, version, server, ftp
Topic 2: people, israel, jews, israeli, did, men, killed, country, right, like
Topic 3: people, gun, law, guns, person, like, state, group, self, rights
Topic 4: new, war, time, south, military, secret, power, world, years, plan
Top