In [9]:
import torch
from data.basic_dataset import BasicDataset
import torch.nn as nn
from collections import defaultdict
import numpy as np
from utils import _utils

In [10]:
class BasicTrainer:
    """Generic train / inference loop for a topic model.

    From how it is used below, ``model`` is expected to:
      * return a dict containing a scalar ``'loss'`` tensor from ``model(batch)``;
      * expose ``get_theta(bow)`` returning per-document topic proportions;
      * expose ``get_beta()`` returning the topic-word matrix.
    ``dataset`` must provide ``train_data``, ``test_data``,
    ``train_dataloader`` and ``vocab``.
    """

    def __init__(self,
                 model : nn.Module,
                 dataset : "BasicDataset",
                 num_top_words = 15,
                 epochs = 200,
                 learning_rate = 0.002,
                 batch_size = 200,
                 verbose = False,
                 device = "cuda",
                 log_interval = 1):
        """
        Args:
            model: topic model to train. It is NOT moved to ``device`` here;
                the caller is responsible for that.
            dataset: data source (see class docstring). The annotation is a
                string forward reference so defining this class does not
                require the data package at definition time.
            num_top_words: default number of words per topic returned by
                ``get_top_words``.
            epochs: number of passes over ``dataset.train_dataloader``.
            learning_rate: Adam learning rate used by ``make_optimizer``.
            batch_size: batch size used by ``test`` for inference only; the
                training batch size is whatever ``dataset.train_dataloader``
                yields.
            verbose: forwarded to ``_utils.get_top_words``.
            device: device each batch is moved to before the forward pass.
            log_interval: print the mean loss every ``log_interval`` epochs.

        Raises:
            ValueError: if ``log_interval`` is smaller than 1.
        """
        if log_interval < 1:
            raise ValueError(f"log_interval must be >= 1, got {log_interval}")
        self.model = model
        self.dataset = dataset
        self.num_top_words = num_top_words
        self.epochs = epochs
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.verbose = verbose
        self.log_interval = log_interval
        self.data_size = len(self.dataset.train_data)
        self.device = device

    def make_optimizer(self):
        """Return a new Adam optimizer over all model parameters."""
        return torch.optim.Adam(self.model.parameters(), lr = self.learning_rate)

    def train(self, model_name = None, global_round = None):
        """Train ``self.model`` for ``self.epochs`` epochs.

        A fresh optimizer is created on every call, so optimizer state is
        not carried over between calls.

        Args:
            model_name: optional label; when given, a header line including
                ``global_round`` is printed before training starts.
            global_round: optional round number shown in that header.

        Returns:
            tuple: ``(top_words, train_theta)`` where ``top_words`` is
            ``self.get_top_words()`` and ``train_theta`` is
            ``self.test(self.dataset.train_data)``.
        """
        optimizer = self.make_optimizer()
        if model_name is not None:
            print(f"Client's model: {model_name} \t | Global round: {global_round}")
        for epoch in range(self.epochs):
            mean_loss = self._train_one_epoch(optimizer)
            if epoch % self.log_interval == 0:
                print(f"Epoch: {epoch:03d} | Loss: {mean_loss}")

        top_words = self.get_top_words()
        train_theta = self.test(self.dataset.train_data)

        return top_words, train_theta

    def _train_one_epoch(self, optimizer):
        """Run one pass over ``dataset.train_dataloader``; return the per-sample mean loss as a float."""
        self.model.train()
        total_loss = 0.0
        num_samples = 0
        for batch_data in self.dataset.train_dataloader:
            batch_data = batch_data.to(self.device)
            output = self.model(batch_data)

            batch_loss = output['loss']

            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            # Accumulate a Python float, not the loss tensor: summing tensors
            # would keep every batch's autograd graph alive until the end of
            # the epoch and leak memory.
            batch_len = len(batch_data)
            total_loss += batch_loss.item() * batch_len
            num_samples += batch_len

        # Divide by samples actually seen (equals len(train_data) for a
        # plain loader, but stays correct if the loader drops or skips rows).
        return total_loss / num_samples if num_samples else 0.0

    def test(self, bow):
        """Infer topic proportions for ``bow`` in batches of ``self.batch_size``.

        Runs under ``torch.no_grad()`` and leaves the model in eval mode.

        Args:
            bow: row-indexable by a tensor of indices and supporting ``.to``
                (e.g. a ``torch.Tensor``); presumably documents x vocab.

        Returns:
            np.ndarray: the concatenated ``model.get_theta`` outputs, one row
            per input row of ``bow``.
        """
        data_size = bow.shape[0]
        theta = list()
        all_idx = torch.split(torch.arange(data_size), self.batch_size)
        with torch.no_grad():
            self.model.eval()
            for idx in all_idx:
                batch_input = bow[idx]
                batch_input = batch_input.to(self.device)
                batch_theta = self.model.get_theta(batch_input)
                theta.extend(batch_theta.cpu().tolist())

        theta = np.asarray(theta)
        return theta

    def get_beta(self):
        """Return ``model.get_beta()`` detached and converted to a NumPy array."""
        beta = self.model.get_beta().detach().cpu().numpy()
        return beta

    def get_top_words(self, num_top_words=None):
        """Return the top words of each topic via ``_utils.get_top_words``.

        Args:
            num_top_words: words per topic; defaults to ``self.num_top_words``.
        """
        if num_top_words is None:
            num_top_words = self.num_top_words
        beta = self.get_beta()
        top_words = _utils.get_top_words(beta, self.dataset.vocab, num_top_words, self.verbose)
        return top_words

    def export_theta(self):
        """Return ``(train_theta, test_theta)`` for the dataset's train and test splits."""
        train_theta = self.test(self.dataset.train_data)
        test_theta = self.test(self.dataset.test_data)
        return train_theta, test_theta


In [11]:
### test
from model.ETM import ETM

# Seed every RNG the run touches so re-running the notebook reproduces the
# same training curve and topics.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Fall back to CPU so the notebook still runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

test_basic_dataset = BasicDataset(
    dataset_dir = "../data/20NG"
)
test_model = ETM(test_basic_dataset.vocab_size).to(device)
test_basic_trainer = BasicTrainer(
    model=test_model,
    dataset=test_basic_dataset,
    verbose=True,
    epochs=30,
    device=device,
)

train_size:  11314
test_size:  7532
vocab_size:  5000
average length: 110.543


In [12]:
# Train the model; BasicTrainer.train() returns (top_words, train_theta).
rst = test_basic_trainer.train(model_name=None, global_round=None)

Epoch: 000 | Loss: 2062.722900390625
Epoch: 001 | Loss: 1542.28955078125
Epoch: 002 | Loss: 1282.6055908203125
Epoch: 003 | Loss: 1137.4420166015625
Epoch: 004 | Loss: 1055.7479248046875
Epoch: 005 | Loss: 1007.5901489257812
Epoch: 006 | Loss: 976.6014404296875
Epoch: 007 | Loss: 955.3668212890625
Epoch: 008 | Loss: 939.692138671875
Epoch: 009 | Loss: 927.1049194335938
Epoch: 010 | Loss: 916.8370971679688
Epoch: 011 | Loss: 908.0200805664062
Epoch: 012 | Loss: 901.1065673828125
Epoch: 013 | Loss: 895.1832885742188
Epoch: 014 | Loss: 890.0108642578125
Epoch: 015 | Loss: 885.7984619140625
Epoch: 016 | Loss: 882.3370971679688
Epoch: 017 | Loss: 879.3961181640625
Epoch: 018 | Loss: 876.7074584960938
Epoch: 019 | Loss: 874.4373168945312
Epoch: 020 | Loss: 872.562744140625
Epoch: 021 | Loss: 870.8250122070312
Epoch: 022 | Loss: 869.222900390625
Epoch: 023 | Loss: 867.946533203125
Epoch: 024 | Loss: 866.9002075195312
Epoch: 025 | Loss: 865.8560180664062
Epoch: 026 | Loss: 865.0390625
Epoch: 0

In [13]:
########################### test new documents ####################################
from data.preprocess import Preprocess

preprocess = Preprocess()

new_docs = [
    "This is a new document about space, including words like space, satellite, launch, orbit.",
    "This is a new document about Microsoft Windows, including words like windows, files, dos."
]

# Vectorise the new documents against the training vocabulary so their
# columns line up with what the model was trained on.
parsed_new_docs, new_bow = preprocess.parse(new_docs, vocab=test_basic_dataset.vocab)
print(new_bow.shape)

print(new_bow.toarray())
# Use the trainer's configured device instead of hard-coding "cuda", and
# avoid shadowing the built-in `input`.
new_bow_tensor = torch.as_tensor(
    new_bow.toarray(), dtype=torch.float32, device=test_basic_trainer.device
)
new_theta = test_basic_trainer.test(new_bow_tensor)

# Index of the highest-weight topic for each new document.
print(new_theta.argmax(1))

parsing texts: 100%|██████████| 2/2 [00:00<00:00, 329.99it/s]


(2, 5000)
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]]
[26 27]


In [14]:
# Split the (top_words, train_theta) pair returned by BasicTrainer.train().
(top_words, train_theta) = rst

In [15]:
import evaluation.topic_coherence as tc
import evaluation.topic_diversity as td

In [None]:
# Score the learned topics. Coherence is computed from the training texts,
# the vocabulary and the top words; diversity from the top words alone.
coherence_score = tc._coherence(
    test_basic_dataset.train_texts,
    test_basic_dataset.vocab,
    top_words,
)
diversity_score = td._diversity(top_words)

print(f"Topic coherence: {coherence_score}")
print(f"Topic diversity: {diversity_score}")

In [None]:
# For each new document, show the top words of its highest-weight topic.
dominant_topics = new_theta.argmax(1)
for topic_idx in dominant_topics:
    print(top_words[topic_idx])

great space nntp american whether full video likely kill existence directory tom die language engineering
like post dos washington almost start talking together else friends shot title total chris tools
