In [None]:

import os
import pandas as pd
import numpy as np
import torch
from torch.distributions.categorical import Categorical
from tqdm import tqdm
from collections import Counter
import nltk
nltk.download('punkt_tab')
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize

# Make sure the NLTK resources used below exist locally; fetch any that are
# missing. nltk.data.find raises LookupError when a resource is absent, so
# only that exception triggers a download (a bare except would also swallow
# KeyboardInterrupt and genuine errors).
for _resource_path, _package in (
    ("tokenizers/punkt", "punkt"),
    ("corpora/stopwords", "stopwords"),
):
    try:
        nltk.data.find(_resource_path)
    except LookupError:
        nltk.download(_package)

# English stopwords as a set for O(1) membership tests during tokenization.
STOPWORDS = set(stopwords.words("english"))

# Prefer CUDA when torch reports it as available, otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

# -------------------------------------------------
# Preprocessing
# -------------------------------------------------
def preprocess_docs(docs, min_df=10):
    """Tokenize documents and prune rare words.

    Each string document is lower-cased, split with ``word_tokenize`` and
    reduced to purely alphabetic tokens that are not in ``STOPWORDS``.
    Non-string entries become empty token lists so the output stays aligned
    with ``docs``.

    Args:
        docs: Iterable of raw documents.
        min_df: Minimum number of occurrences, counted over all tokens in the
            corpus, that a word needs in order to be kept.

    Returns:
        A pair ``(tokenized, vocab)``: the per-document token lists restricted
        to the kept words, and the set of kept words.
    """

    def _clean(text):
        # Anything that is not a string contributes no tokens.
        if not isinstance(text, str):
            return []
        return [
            tok
            for tok in word_tokenize(text.lower())
            if tok.isalpha() and tok not in STOPWORDS
        ]

    cleaned = [_clean(text) for text in docs]

    # Corpus-wide occurrence counts drive the vocabulary cut-off.
    counts = Counter()
    for tokens in cleaned:
        counts.update(tokens)

    vocab = {word for word, n in counts.items() if n >= min_df}
    filtered = [[tok for tok in tokens if tok in vocab] for tokens in cleaned]
    return filtered, vocab


# -------------------------------------------------
# Biterm extraction
# -------------------------------------------------
def extract_biterms(tokenized_docs, window=None):
    """Collect unordered word pairs (biterms) from every document.

    Args:
        tokenized_docs: Sequence of token lists, one per document.
        window: When ``None`` every pair of positions in a document is
            considered. Otherwise the token at position ``i`` is only paired
            with tokens at positions ``i + 1`` up to ``i + window - 1``.

    Returns:
        A pair ``(biterms, doc_bit_indices)``. ``biterms`` is the flat list of
        ``(smaller, larger)`` word tuples over the whole corpus;
        ``doc_bit_indices[d]`` lists the positions in ``biterms`` that came
        from document ``d``.
    """
    biterms = []
    doc_bit_indices = []

    for tokens in tokenized_docs:
        n_tokens = len(tokens)
        positions = []

        for left, first in enumerate(tokens):
            if window is None:
                stop = n_tokens
            else:
                stop = min(n_tokens, left + window)

            for right in range(left + 1, stop):
                second = tokens[right]
                # A word paired with itself is not a biterm.
                if first == second:
                    continue
                # Store each pair in sorted order so (a, b) and (b, a) coincide.
                pair = (second, first) if second < first else (first, second)
                positions.append(len(biterms))
                biterms.append(pair)

        doc_bit_indices.append(positions)

    return biterms, doc_bit_indices


# -------------------------------------------------
# Vocabulary mapping
# -------------------------------------------------
def build_vocab_maps(vocab):
    """Build word <-> integer id lookups for a vocabulary.

    Words are numbered by their position in sorted order.

    Args:
        vocab: Iterable of words.

    Returns:
        A pair ``(w2id, id2w)``: word-to-id dict and its inverse.
    """
    ordered = sorted(vocab)
    w2id = dict(zip(ordered, range(len(ordered))))
    # Invert the forward map so both directions always agree.
    id2w = dict(zip(w2id.values(), w2id.keys()))
    return w2id, id2w


def encode_biterms(biterms, w2id):
    """Translate word-pair biterms into id-pair biterms.

    Args:
        biterms: Iterable of ``(word_a, word_b)`` pairs.
        w2id: Mapping from word to integer id; must contain every word used.

    Returns:
        List of ``(id_a, id_b)`` tuples in the same order as ``biterms``.
    """
    encoded = []
    for first, second in biterms:
        encoded.append((w2id[first], w2id[second]))
    return encoded


# -------------------------------------------------
# GPU-Accelerated BTM with PyTorch
# -------------------------------------------------
class BTM_GPU:
    """Biterm Topic Model fitted by collapsed Gibbs sampling on torch tensors.

    After ``fit`` the instance exposes:
        phi:   (vocab_size, K) tensor, smoothed word-given-topic probabilities.
        theta: (K,) tensor, smoothed corpus-level topic proportions.

    NOTE: sampling resamples one biterm at a time in a Python loop, so the
    cost of one sweep grows linearly with the number of biterms regardless of
    the device the tensors live on.
    """

    def __init__(self, K=20, alpha=None, beta=0.01, iterations=500, device=None):
        """Store hyper-parameters.

        Args:
            K: Number of topics.
            alpha: Topic smoothing; defaults to ``50 / K`` when ``None``.
            beta: Word smoothing.
            iterations: Number of Gibbs sweeps run by ``fit``.
            device: Torch device for all tensors. When ``None``, "cuda" is
                used if torch reports it as available, otherwise "cpu".
        """
        self.K = K
        self.alpha = alpha if alpha is not None else 50.0 / K
        self.beta = beta
        self.iterations = iterations
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

    def _count_assignments(self, wi_tensor, wj_tensor, z_b, vocab_size):
        """Return ``(n_z, n_wz)`` count tensors for the given assignments.

        ``n_z[k]`` is the number of biterms assigned to topic ``k`` and
        ``n_wz[w, k]`` the number of times word ``w`` occurs in biterms
        assigned to topic ``k``.
        """
        K = self.K
        dtype = torch.get_default_dtype()
        n_z = torch.bincount(z_b, minlength=K).to(dtype)
        # Flatten (word, topic) into a single index so one bincount covers
        # both words of every biterm.
        flat = torch.cat([wi_tensor, wj_tensor]) * K + torch.cat([z_b, z_b])
        n_wz = torch.bincount(flat, minlength=vocab_size * K)
        n_wz = n_wz.reshape(vocab_size, K).to(dtype)
        return n_z, n_wz

    def _gibbs_sweep(self, wi_tensor, wj_tensor, z_b, n_z, n_wz):
        """Resample the topic of every biterm once, updating counts in place."""
        alpha = self.alpha
        beta = self.beta
        smoothing = n_wz.shape[0] * beta

        for i in range(z_b.shape[0]):
            zi = z_b[i]
            wi = wi_tensor[i]
            wj = wj_tensor[i]

            # Remove the current assignment from the counts.
            n_z[zi] -= 1
            n_wz[wi, zi] -= 1
            n_wz[wj, zi] -= 1

            # Every biterm adds exactly two word counts to its topic, so the
            # per-topic word total is 2 * n_z. Computing it here (rather than
            # once per sweep) keeps the denominator in step with the counts
            # and excludes the biterm being resampled.
            denom = 2.0 * n_z + smoothing
            pz = (n_z + alpha) * (n_wz[wi] + beta) * (n_wz[wj] + beta) / (denom * denom)

            # torch.multinomial accepts unnormalized non-negative weights.
            new_z = torch.multinomial(pz, 1)[0]

            # Record the new assignment and add the counts back.
            z_b[i] = new_z
            n_z[new_z] += 1
            n_wz[wi, new_z] += 1
            n_wz[wj, new_z] += 1

    def fit(self, biterms, vocab_size, doc_biterms):
        """Run Gibbs sampling and store ``phi``, ``theta`` and training state.

        Args:
            biterms: Sequence of ``(word_id, word_id)`` pairs.
            vocab_size: Number of distinct word ids.
            doc_biterms: Per-document lists of indices into ``biterms``;
                stored unchanged on the instance.
        """
        B = len(biterms)
        M = vocab_size
        K = self.K
        alpha = self.alpha
        beta = self.beta
        dev = self.device

        # Word ids of both biterm slots as index tensors on the target device.
        wi_tensor = torch.tensor([b[0] for b in biterms], device=dev, dtype=torch.long)
        wj_tensor = torch.tensor([b[1] for b in biterms], device=dev, dtype=torch.long)

        # Random initial topic per biterm, then the matching count tensors.
        z_b = torch.randint(0, K, (B,), device=dev)
        n_z, n_wz = self._count_assignments(wi_tensor, wj_tensor, z_b, M)

        # Gibbs Sampling
        for _ in tqdm(range(self.iterations), desc="Training BTM (GPU)"):
            self._gibbs_sweep(wi_tensor, wj_tensor, z_b, n_z, n_wz)

        # Compute phi & theta
        self.phi = (n_wz + beta) / (n_wz.sum(0)[None, :] + M * beta)
        self.theta = (n_z + alpha) / (B + K * alpha)

        self.biterms = biterms
        self.doc_biterms = doc_biterms
        self.wi_tensor = wi_tensor
        self.wj_tensor = wj_tensor
        self.z_b = z_b
        self.vocab_size = M

    def infer_doc_topics(self, biterm_indices):
        """Return a document's topic distribution as a length-K numpy array.

        The distribution is the mean, over the document's biterms, of the
        normalized ``theta * phi[wi] * phi[wj]``. A document without biterms
        gets the uniform distribution.

        Args:
            biterm_indices: Indices into the biterms passed to ``fit``.
        """
        if len(biterm_indices) == 0:
            # Same return type as the non-empty path.
            return (torch.ones(self.K) / self.K).numpy()

        idx = torch.as_tensor(
            biterm_indices, dtype=torch.long, device=self.wi_tensor.device
        )
        # One row per biterm: unnormalized topic posterior, then row-normalize.
        pz = self.theta * self.phi[self.wi_tensor[idx]] * self.phi[self.wj_tensor[idx]]
        pz = pz / pz.sum(dim=1, keepdim=True)

        return pz.mean(dim=0).detach().cpu().numpy()

    def top_words(self, id2w, top_n=10):
        """Return, for each topic, its ``top_n`` most probable words.

        Args:
            id2w: Mapping from word id to word.
            top_n: Number of words per topic.

        Returns:
            List of K word lists, each ordered from most to least probable.
        """
        topics = []
        phi_cpu = self.phi.cpu().numpy()
        for k in range(self.K):
            top_idx = phi_cpu[:, k].argsort()[::-1][:top_n]
            topics.append([id2w[int(i)] for i in top_idx])
        return topics


# -------------------------------------------------
# MAIN PROGRAM (NO ARGUMENTS)
# -------------------------------------------------
if __name__ == "__main__":

    # Your local dataset path (edit as needed).
    # expanduser("~") resolves the home directory on every platform, whereas
    # the USERPROFILE environment variable is normally only set on Windows.
    csv_path = os.path.join(os.path.expanduser("~"), 'Desktop', 'ml_lab', 'AML_lab_papers', 'AML_lab_papers', 'BTM', 'archive(1)', 'tweets.csv')

    print("Loading CSV:", csv_path)
    df = pd.read_csv(csv_path)

    text_col = "content"   # <--- change to your actual column
    # Replace missing cells with "" first: astype(str) alone would turn them
    # into the literal text "nan", which would then be tokenized as a word.
    documents = df[text_col].fillna("").astype(str).tolist()
    print(f"num of docs: {len(documents)}")

    # Preprocess
    tokenized, vocab = preprocess_docs(documents, min_df=5)
    print(f"vocab: {len(vocab)}")

    # Biterms
    biterms_raw, doc_biterms = extract_biterms(tokenized, window=None)

    # Vocab
    w2id, id2w = build_vocab_maps(vocab)
    encoded_biterms = encode_biterms(biterms_raw, w2id)
    print(f"num of biterms: {len(encoded_biterms)}")

    # Train BTM on GPU
    model = BTM_GPU(K=50, iterations=500)
    model.fit(encoded_biterms, vocab_size=len(w2id), doc_biterms=doc_biterms)

    # Print topics
    topics = model.top_words(id2w, top_n=10)
    for i, t in enumerate(topics):
        print(f"Topic {i}: {t}")

    # Inference example (skipped when the dataset has no documents)
    if doc_biterms:
        print("\nDocument 0 topic distribution:")
        print(model.infer_doc_topics(doc_biterms[0]))

[nltk_data] Downloading package punkt_tab to
[nltk_data]     C:\Users\madhu\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


Using device: cpu
Loading CSV: C:\Users\madhu\Desktop\ml_lab\AML_lab_papers\AML_lab_papers\BTM\archive(1)\tweets.csv
num of docs: 52542
vocab: 8298
num of biterms: 1333785


Training BTM (GPU):   1%|          | 4/500 [14:57<31:39:15, 229.75s/it]