In [1]:
%config Completer.use_jedi = False

In [2]:
import os
import argparse
import sys
from gensim.models import Word2Vec
import re

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning import seed_everything
from torch.utils.data import DataLoader, TensorDataset, random_split, Dataset
from tqdm.auto import tqdm
from sklearn.feature_extraction.text import TfidfVectorizer
from collections import defaultdict

sys.path.append("../")

from utils.data_processing import get_process_data
from utils.data_loader import load_document
from utils.loss import ListNet
from utils.eval import retrieval_normalized_dcg_all, retrieval_precision_all

# fix random seed
def same_seeds(seed):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def get_freer_gpu():
    os.system('nvidia-smi -q -d Memory |grep -A4 GPU|grep Free > tmp')
    memory_available = [int(x.split()[2]) for x in open('tmp', 'r').readlines()]
    return int(np.argmax(memory_available))

In [3]:
device = f"cuda:{get_freer_gpu()}"
config = {
    "topk":[10,30,50]
}

In [4]:
seed = 123
seed_everything(123)

Global seed set to 123


123

In [5]:
from octis.dataset.dataset import Dataset
raw_dataset = Dataset()
raw_dataset.fetch_dataset("20NewsGroup")

In [6]:
corpus = raw_dataset.get_corpus()
raw_documents = [" ".join(i) for i in corpus]

In [7]:
docvec = np.load("../data/docvec_20NG_SBERT_768d.npy")
dim = 768

In [8]:
# get TF-IDF score
vectorizer = TfidfVectorizer(min_df=10,stop_words="english")
importance_score = vectorizer.fit_transform(raw_documents).todense()

In [9]:
vocab_size = len(vectorizer.vocabulary_)
print(f"Vocab size:{vocab_size}")

Vocab size:1588


In [10]:
class IDEDataset(Dataset):
    def __init__(self, 
                 doc_vectors,
                 weight_ans):
        
        assert len(doc_vectors) == len(weight_ans)

        self.doc_vectors = torch.FloatTensor(doc_vectors)
        self.weight_ans = torch.FloatTensor(weight_ans)        
        self.weight_ans_s = torch.argsort(self.weight_ans, dim=1, descending=True)
        self.topk = torch.sum(self.weight_ans > 0, dim=1)
        
    def __getitem__(self, idx):
        return self.doc_vectors[idx], self.weight_ans[idx], self.weight_ans_s[idx], self.topk[idx]

    def __len__(self):
        return len(self.doc_vectors)

In [11]:
dataset = IDEDataset(docvec, importance_score)
train_length = int(len(dataset)*0.6)
valid_length = int(len(dataset)*0.2)
test_length = len(dataset) - train_length - valid_length

full_loader = DataLoader(dataset, batch_size=128)
train_dataset, valid_dataset, test_dataset = random_split(
    dataset, lengths=[train_length, valid_length,test_length],
    generator=torch.Generator().manual_seed(42)
    )

In [12]:
train_loader = DataLoader(
    train_dataset, batch_size=128, 
    shuffle=True, pin_memory=True,
)
valid_loader = DataLoader(
    valid_dataset, batch_size=128, shuffle=False, pin_memory=True,drop_last=True)
test_loader = DataLoader(
    test_dataset, batch_size=128, shuffle=False)

In [13]:
class Decoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.word_embedding = nn.Parameter(torch.randn(hidden_dim, output_dim))
        self.transform = nn.Sequential(
            nn.Linear(input_dim,hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.Sigmoid(),
            nn.Linear(hidden_dim,hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.Sigmoid(),
        )
    def forward(self, x):
        docvec = self.transform(x)
        decoded = torch.sigmoid(torch.matmul(docvec,self.word_embedding))
        return decoded
    
    def load_pretrianed(self,word_embedding):
        self.word_embedding = nn.Parameter(torch.FloatTensor(word_embedding))

In [14]:
def evaluate_Decoder(model, data_loader):
    results = defaultdict(list)
    model.eval()
        
    # predict all data
    for data in data_loader:
        doc_embs, target, _, _ = data
        
        doc_embs = doc_embs.to(device)
        target = target.to(device)
                
        pred = model(doc_embs)
    
        # Precision
        precision_scores = retrieval_precision_all(pred, target, k=config["topk"])
        for k, v in precision_scores.items():
            results['precision@{}'.format(k)].append(v)
        
        # NDCG
        ndcg_scores = retrieval_normalized_dcg_all(pred, target, k=config["topk"])
        for k, v in ndcg_scores.items():
            results['ndcg@{}'.format(k)].append(v)
        
    for k in results:
        results[k] = np.mean(results[k])

    return results

In [15]:
decoder = Decoder(input_dim=dim,hidden_dim=1024,output_dim=vocab_size)
optimizer = torch.optim.Adam(decoder.parameters(), lr = 1e-4)
# initialize parameters
for p in decoder.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)
# decoder.load_pretrianed(pretrain_wordembedding)
decoder = decoder.to(device)

In [16]:
# early stop settings 
stop_rounds = 3
no_improvement = 0
best_score = None 


for epoch in range(100):
    # Training
    decoder.train()
    train_loss = []
    for batch in tqdm(train_loader, desc="Training"):
        batch = [i.to(device) for i in batch]
        doc_embs, target, _, _ = batch
        target = torch.nn.functional.normalize(target.to(device), dim=1)
        decoded = torch.nn.functional.normalize(decoder(doc_embs), dim=1)
        loss = ListNet(decoded, target)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        train_loss.append(loss.item())

    print(f"[Epoch {epoch+1:02d}]")
    res = evaluate_Decoder(decoder, valid_loader)
    for key,val in res.items():
        print(f"{key}:{val:.4f}")
        
     # early stopping 
    current_score = res["precision@10"]
    if best_score == None:
        best_score = current_score
        continue
    if current_score < best_score:
        no_improvement += 1
    if no_improvement >= stop_rounds:
        print("Early stopping...")
        break 
    if current_score > best_score:
        no_improvement = 0
        best_score = current_score

Training:   0%|          | 0/77 [00:00<?, ?it/s]

[Epoch 01]
precision@10:0.1843
precision@30:0.1287
precision@50:0.1082
ndcg@10:0.0938
ndcg@30:0.1157
ndcg@50:0.1360
ndcg@all:0.3672


Training:   0%|          | 0/77 [00:00<?, ?it/s]

[Epoch 02]
precision@10:0.1971
precision@30:0.1657
precision@50:0.1385
ndcg@10:0.0925
ndcg@30:0.1511
ndcg@50:0.1805
ndcg@all:0.3903


Training:   0%|          | 0/77 [00:00<?, ?it/s]

[Epoch 03]
precision@10:0.3149
precision@30:0.2076
precision@50:0.1642
ndcg@10:0.2159
ndcg@30:0.2541
ndcg@50:0.2800
ndcg@all:0.4662


Training:   0%|          | 0/77 [00:00<?, ?it/s]

[Epoch 04]
precision@10:0.3778
precision@30:0.2348
precision@50:0.1816
ndcg@10:0.3463
ndcg@30:0.3668
ndcg@50:0.3907
ndcg@all:0.5602


Training:   0%|          | 0/77 [00:00<?, ?it/s]

[Epoch 05]
precision@10:0.4224
precision@30:0.2534
precision@50:0.1938
ndcg@10:0.4019
ndcg@30:0.4169
ndcg@50:0.4397
ndcg@all:0.5987


Training:   0%|          | 0/77 [00:00<?, ?it/s]

[Epoch 06]
precision@10:0.4496
precision@30:0.2677
precision@50:0.2030
ndcg@10:0.4422
ndcg@30:0.4546
ndcg@50:0.4765
ndcg@all:0.6277


Training:   0%|          | 0/77 [00:00<?, ?it/s]

[Epoch 07]
precision@10:0.4733
precision@30:0.2785
precision@50:0.2097
ndcg@10:0.4728
ndcg@30:0.4823
ndcg@50:0.5036
ndcg@all:0.6487


Training:   0%|          | 0/77 [00:00<?, ?it/s]

[Epoch 08]
precision@10:0.4926
precision@30:0.2869
precision@50:0.2152
ndcg@10:0.4925
ndcg@30:0.5000
ndcg@50:0.5210
ndcg@all:0.6622


Training:   0%|          | 0/77 [00:00<?, ?it/s]

[Epoch 09]
precision@10:0.5105
precision@30:0.2950
precision@50:0.2202
ndcg@10:0.5191
ndcg@30:0.5240
ndcg@50:0.5446
ndcg@all:0.6804


Training:   0%|          | 0/77 [00:00<?, ?it/s]

[Epoch 10]
precision@10:0.5200
precision@30:0.2997
precision@50:0.2230
ndcg@10:0.5268
ndcg@30:0.5319
ndcg@50:0.5519
ndcg@all:0.6860


Training:   0%|          | 0/77 [00:00<?, ?it/s]

[Epoch 11]
precision@10:0.5319
precision@30:0.3057
precision@50:0.2271
ndcg@10:0.5431
ndcg@30:0.5465
ndcg@50:0.5665
ndcg@all:0.6972


Training:   0%|          | 0/77 [00:00<?, ?it/s]

[Epoch 12]
precision@10:0.5418
precision@30:0.3103
precision@50:0.2301
ndcg@10:0.5540
ndcg@30:0.5567
ndcg@50:0.5768
ndcg@all:0.7046


Training:   0%|          | 0/77 [00:00<?, ?it/s]

[Epoch 13]
precision@10:0.5483
precision@30:0.3142
precision@50:0.2318
ndcg@10:0.5641
ndcg@30:0.5663
ndcg@50:0.5855
ndcg@all:0.7119


Training:   0%|          | 0/77 [00:00<?, ?it/s]

[Epoch 14]
precision@10:0.5551
precision@30:0.3174
precision@50:0.2346
ndcg@10:0.5727
ndcg@30:0.5744
ndcg@50:0.5938
ndcg@all:0.7179


Training:   0%|          | 0/77 [00:00<?, ?it/s]

[Epoch 15]
precision@10:0.5605
precision@30:0.3191
precision@50:0.2358
ndcg@10:0.5798
ndcg@30:0.5808
ndcg@50:0.6000
ndcg@all:0.7229


Training:   0%|          | 0/77 [00:00<?, ?it/s]

[Epoch 16]
precision@10:0.5668
precision@30:0.3221
precision@50:0.2375
ndcg@10:0.5873
ndcg@30:0.5873
ndcg@50:0.6063
ndcg@all:0.7277


Training:   0%|          | 0/77 [00:00<?, ?it/s]

[Epoch 17]
precision@10:0.5716
precision@30:0.3245
precision@50:0.2388
ndcg@10:0.5902
ndcg@30:0.5905
ndcg@50:0.6093
ndcg@all:0.7298


Training:   0%|          | 0/77 [00:00<?, ?it/s]

[Epoch 18]
precision@10:0.5759
precision@30:0.3274
precision@50:0.2409
ndcg@10:0.5946
ndcg@30:0.5954
ndcg@50:0.6140
ndcg@all:0.7332


Training:   0%|          | 0/77 [00:00<?, ?it/s]

[Epoch 19]
precision@10:0.5810
precision@30:0.3293
precision@50:0.2420
ndcg@10:0.6016
ndcg@30:0.6010
ndcg@50:0.6197
ndcg@all:0.7376


Training:   0%|          | 0/77 [00:00<?, ?it/s]

[Epoch 20]
precision@10:0.5827
precision@30:0.3305
precision@50:0.2433
ndcg@10:0.6039
ndcg@30:0.6042
ndcg@50:0.6230
ndcg@all:0.7397


Training:   0%|          | 0/77 [00:00<?, ?it/s]

[Epoch 21]
precision@10:0.5890
precision@30:0.3327
precision@50:0.2441
ndcg@10:0.6111
ndcg@30:0.6091
ndcg@50:0.6276
ndcg@all:0.7436


Training:   0%|          | 0/77 [00:00<?, ?it/s]

[Epoch 22]
precision@10:0.5908
precision@30:0.3342
precision@50:0.2451
ndcg@10:0.6125
ndcg@30:0.6116
ndcg@50:0.6300
ndcg@all:0.7452


Training:   0%|          | 0/77 [00:00<?, ?it/s]

[Epoch 23]
precision@10:0.5947
precision@30:0.3355
precision@50:0.2457
ndcg@10:0.6191
ndcg@30:0.6168
ndcg@50:0.6351
ndcg@all:0.7494


Training:   0%|          | 0/77 [00:00<?, ?it/s]

[Epoch 24]
precision@10:0.5954
precision@30:0.3368
precision@50:0.2466
ndcg@10:0.6187
ndcg@30:0.6177
ndcg@50:0.6360
ndcg@all:0.7499


Training:   0%|          | 0/77 [00:00<?, ?it/s]

[Epoch 25]
precision@10:0.5997
precision@30:0.3380
precision@50:0.2477
ndcg@10:0.6221
ndcg@30:0.6199
ndcg@50:0.6384
ndcg@all:0.7513


Training:   0%|          | 0/77 [00:00<?, ?it/s]

[Epoch 26]
precision@10:0.6014
precision@30:0.3391
precision@50:0.2483
ndcg@10:0.6262
ndcg@30:0.6242
ndcg@50:0.6423
ndcg@all:0.7546


Training:   0%|          | 0/77 [00:00<?, ?it/s]

[Epoch 27]
precision@10:0.6037
precision@30:0.3399
precision@50:0.2498
ndcg@10:0.6275
ndcg@30:0.6251
ndcg@50:0.6441
ndcg@all:0.7553


Training:   0%|          | 0/77 [00:00<?, ?it/s]

[Epoch 28]
precision@10:0.6058
precision@30:0.3404
precision@50:0.2489
ndcg@10:0.6301
ndcg@30:0.6268
ndcg@50:0.6451
ndcg@all:0.7565


Training:   0%|          | 0/77 [00:00<?, ?it/s]

[Epoch 29]
precision@10:0.6070
precision@30:0.3419
precision@50:0.2502
ndcg@10:0.6319
ndcg@30:0.6295
ndcg@50:0.6478
ndcg@all:0.7582


Training:   0%|          | 0/77 [00:00<?, ?it/s]

[Epoch 30]
precision@10:0.6096
precision@30:0.3419
precision@50:0.2509
ndcg@10:0.6344
ndcg@30:0.6312
ndcg@50:0.6498
ndcg@all:0.7598


Training:   0%|          | 0/77 [00:00<?, ?it/s]

[Epoch 31]
precision@10:0.6123
precision@30:0.3428
precision@50:0.2511
ndcg@10:0.6376
ndcg@30:0.6336
ndcg@50:0.6518
ndcg@all:0.7615


Training:   0%|          | 0/77 [00:00<?, ?it/s]

[Epoch 32]
precision@10:0.6116
precision@30:0.3441
precision@50:0.2517
ndcg@10:0.6396
ndcg@30:0.6365
ndcg@50:0.6544
ndcg@all:0.7633


Training:   0%|          | 0/77 [00:00<?, ?it/s]

[Epoch 33]
precision@10:0.6116
precision@30:0.3449
precision@50:0.2517
ndcg@10:0.6373
ndcg@30:0.6350
ndcg@50:0.6528
ndcg@all:0.7619


Training:   0%|          | 0/77 [00:00<?, ?it/s]

[Epoch 34]
precision@10:0.6147
precision@30:0.3447
precision@50:0.2524
ndcg@10:0.6417
ndcg@30:0.6378
ndcg@50:0.6558
ndcg@all:0.7643


Training:   0%|          | 0/77 [00:00<?, ?it/s]

[Epoch 35]
precision@10:0.6140
precision@30:0.3453
precision@50:0.2530
ndcg@10:0.6413
ndcg@30:0.6379
ndcg@50:0.6563
ndcg@all:0.7643


Training:   0%|          | 0/77 [00:00<?, ?it/s]

[Epoch 36]
precision@10:0.6143
precision@30:0.3469
precision@50:0.2533
ndcg@10:0.6416
ndcg@30:0.6394
ndcg@50:0.6572
ndcg@all:0.7649


Training:   0%|          | 0/77 [00:00<?, ?it/s]

[Epoch 37]
precision@10:0.6155
precision@30:0.3457
precision@50:0.2530
ndcg@10:0.6440
ndcg@30:0.6406
ndcg@50:0.6588
ndcg@all:0.7661


Training:   0%|          | 0/77 [00:00<?, ?it/s]

[Epoch 38]
precision@10:0.6162
precision@30:0.3464
precision@50:0.2531
ndcg@10:0.6440
ndcg@30:0.6414
ndcg@50:0.6590
ndcg@all:0.7668


Training:   0%|          | 0/77 [00:00<?, ?it/s]

[Epoch 39]
precision@10:0.6178
precision@30:0.3471
precision@50:0.2538
ndcg@10:0.6453
ndcg@30:0.6418
ndcg@50:0.6597
ndcg@all:0.7672


Training:   0%|          | 0/77 [00:00<?, ?it/s]

[Epoch 40]
precision@10:0.6171
precision@30:0.3477
precision@50:0.2539
ndcg@10:0.6436
ndcg@30:0.6411
ndcg@50:0.6592
ndcg@all:0.7665


Training:   0%|          | 0/77 [00:00<?, ?it/s]

[Epoch 41]
precision@10:0.6181
precision@30:0.3480
precision@50:0.2541
ndcg@10:0.6457
ndcg@30:0.6429
ndcg@50:0.6606
ndcg@all:0.7676


Training:   0%|          | 0/77 [00:00<?, ?it/s]

[Epoch 42]
precision@10:0.6200
precision@30:0.3488
precision@50:0.2542
ndcg@10:0.6462
ndcg@30:0.6437
ndcg@50:0.6612
ndcg@all:0.7678


Training:   0%|          | 0/77 [00:00<?, ?it/s]

[Epoch 43]
precision@10:0.6170
precision@30:0.3480
precision@50:0.2543
ndcg@10:0.6452
ndcg@30:0.6430
ndcg@50:0.6613
ndcg@all:0.7676


Training:   0%|          | 0/77 [00:00<?, ?it/s]

[Epoch 44]
precision@10:0.6213
precision@30:0.3500
precision@50:0.2551
ndcg@10:0.6472
ndcg@30:0.6452
ndcg@50:0.6628
ndcg@all:0.7690


Training:   0%|          | 0/77 [00:00<?, ?it/s]

[Epoch 45]
precision@10:0.6212
precision@30:0.3490
precision@50:0.2546
ndcg@10:0.6479
ndcg@30:0.6452
ndcg@50:0.6629
ndcg@all:0.7691


Training:   0%|          | 0/77 [00:00<?, ?it/s]

[Epoch 46]
precision@10:0.6181
precision@30:0.3500
precision@50:0.2550
ndcg@10:0.6458
ndcg@30:0.6447
ndcg@50:0.6622
ndcg@all:0.7685


Training:   0%|          | 0/77 [00:00<?, ?it/s]

[Epoch 47]
precision@10:0.6195
precision@30:0.3494
precision@50:0.2551
ndcg@10:0.6494
ndcg@30:0.6472
ndcg@50:0.6648
ndcg@all:0.7703
Early stopping...


In [17]:
print("Testing...")
res = evaluate_Decoder(decoder, test_loader)
for key,val in res.items():
    print(f"{key}:{val:.4f}")

Testing...
precision@10:0.6097
precision@30:0.3402
precision@50:0.2491
ndcg@10:0.6482
ndcg@30:0.6449
ndcg@50:0.6634
ndcg@all:0.7690


## Empirical evaluation

In [18]:
import colored
from colored import stylize
from colored import fore, back, style
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [23]:
labels = raw_dataset.get_labels()
test_indices = test_dataset.indices
test_documents = [raw_documents[i] for i in test_indices]
test_labels = [labels[i] for i in test_indices]
test_embedding = torch.FloatTensor(docvec[test_indices])
test_importance = torch.FloatTensor(np.array(importance_score)[test_indices])
itos = vectorizer.get_feature_names()

### Train topic model for comparision

In [19]:
from contextualized_topic_models.utils.preprocessing import WhiteSpacePreprocessingStopwords

train_documents = [raw_documents[i] for i in train_dataset.indices]
sp = WhiteSpacePreprocessingStopwords(train_documents, "english")
preprocessed_documents, unpreprocessed_documents, vocab = sp.preprocess()

In [20]:
from contextualized_topic_models.models.ctm import CombinedTM, ZeroShotTM
from contextualized_topic_models.utils.data_preparation import TopicModelDataPreparation
from contextualized_topic_models.utils.data_preparation import bert_embeddings_from_file

qt = TopicModelDataPreparation("all-mpnet-base-v2")

training_dataset = qt.fit(text_for_contextual=unpreprocessed_documents, text_for_bow=preprocessed_documents)

ctm = CombinedTM(bow_size=len(qt.vocab), contextual_size=768, n_components=50) # 50 topics

ctm.fit(training_dataset) # run the model

Batches:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch: [100/100]	 Seen Samples: [978500/978500]	Train Loss: 341.2394142022228	Time: 0:00:05.163202: : 100it [08:30,  5.10s/it]
Sampling: [20/20]: : 20it [01:26,  4.33s/it]


In [21]:
topic_words = ctm.get_topics(10)

In [24]:
sp = WhiteSpacePreprocessingStopwords(test_documents, "english")
preprocessed_documents, unpreprocessed_documents, vocab = sp.preprocess()
test_dataset = qt.transform(text_for_contextual=unpreprocessed_documents, text_for_bow=preprocessed_documents)

Batches:   0%|          | 0/17 [00:00<?, ?it/s]

In [25]:
test_topics = ctm.get_doc_topic_distribution(test_dataset, n_samples=20) # returns a (n_documents, n_topics) matrix with the topic distribution of each document

Sampling: [20/20]: : 20it [01:20,  4.05s/it]


### Decoder output and analysis

In [26]:
check_id = 2 #1200
K = 10
raw_doc = test_documents[check_id]
category = test_labels[check_id]
embedding = test_embedding[check_id].unsqueeze(0).to(device)
prediction = torch.topk(decoder(embedding)[0],k=K).indices.detach().cpu().numpy()
important_words = torch.topk(test_importance[check_id],k=K).indices
print(category)

# calculate score
precision_scores = retrieval_precision_all(decoder(embedding).detach().cpu(), test_importance[check_id].unsqueeze(0), k=config["topk"])
ndcg_scores = retrieval_normalized_dcg_all(decoder(embedding).detach().cpu(), test_importance[check_id].unsqueeze(0), k=config["topk"])
print(precision_scores)
print(ndcg_scores)

# transform to word
prediction = [itos[i] for i in prediction]
important_words = [itos[i] for i in important_words]

sci.crypt
{10: 0.6000000238418579, 30: 0.36666667461395264, 50: 0.3199999928474426}
{10: 0.3357486426830292, 30: 0.3524261713027954, 50: 0.3837764859199524, 'all': 0.5948744416236877}


In [27]:
print(style.BOLD + back.LIGHT_SKY_BLUE_1 + 'Ground truth'+ style.RESET)
for word in raw_doc.split():
    if word in important_words:
        print (fore.LIGHT_BLUE + style.BOLD + word + style.RESET,end = " ")
    else:
        print(word, end=' ')

[1m[48;5;153mGround truth[0m
suggest common distribution private [38;5;12m[1mkey[0m public [38;5;12m[1mkey[0m system encrypt posting theory work fine long [38;5;12m[1mkey[0m remain secure [38;5;12m[1mpractice[0m [38;5;12m[1mgood[0m [38;5;12m[1midea[0m [38;5;12m[1mcheck[0m [38;5;12m[1mviolation[0m net rule [38;5;12m[1mpractice[0m [38;5;12m[1mgood[0m [38;5;12m[1midea[0m [38;5;12m[1mcheck[0m [38;5;12m[1mkey[0m [38;5;12m[1mgood[0m [38;5;12m[1midea[0m [38;5;12m[1mcheck[0m post forward site make [38;5;12m[1mchain[0m work [38;5;12m[1mproblem[0m discussion group travel [38;5;12m[1mfacility[0m control member [38;5;12m[1mproblem[0m [38;5;12m[1mmailing[0m list approach fun 

In [28]:
print(style.BOLD + back.LIGHT_PINK_1 + 'Prediction'+ style.RESET)
for word in raw_doc.split():
    if word in prediction:
        print (fore.LIGHT_RED + style.BOLD + word + style.RESET,end = " ")
    else:
        print(word, end=' ')

[1m[48;5;217mPrediction[0m
suggest common distribution private key public key system encrypt [38;5;9m[1mposting[0m theory [38;5;9m[1mwork[0m fine long key remain secure practice good [38;5;9m[1midea[0m check violation net rule practice good [38;5;9m[1midea[0m check key good [38;5;9m[1midea[0m check [38;5;9m[1mpost[0m forward site make chain [38;5;9m[1mwork[0m [38;5;9m[1mproblem[0m [38;5;9m[1mdiscussion[0m group travel facility control member [38;5;9m[1mproblem[0m mailing list approach fun 

In [29]:
pred_topic = np.argsort(test_topics[check_id])[-2:][::-1]
for topicID in pred_topic:
    print("Topic id: %s, Probability: %.2f" %(topicID, test_topics[check_id][topicID]))
    print(" ".join(topic_words[topicID]))

Topic id: 42, Probability: 0.14
privacy internet protect technology encryption ensure policy law security network
Topic id: 46, Probability: 0.13
key block bit chip encrypt number attack serial session message
