In [1]:
import sys
# NOTE(review): machine-specific absolute path -- adjust it when running the
# notebook from another checkout location.
project_root = "d:/MachineLearning/federated_vae"
# Register the project root only once so that re-running this cell does not
# pile up duplicate entries on sys.path.
if project_root not in sys.path:
    sys.path.append(project_root)

In [2]:
import torch
from main.data.basic_dataset import BasicDataset
import torch.nn as nn
from collections import defaultdict
import numpy as np
from main.utils import _utils

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
class BasicTrainer:
    """Train a topic model on a dataset and extract its topics and theta.

    As used by this class, the model must:
      * return a dict with a ``'loss'`` entry when called on a batch,
      * provide ``get_theta(batch)`` and ``get_beta()``.

    As used by this class, the dataset must provide ``train_data``,
    ``test_data``, ``train_dataloader`` and ``vocab``.

    The trainer moves *batches* to ``device`` but never moves the model
    itself; the caller is responsible for placing the model on ``device``.
    """

    def __init__(self, 
                 model : nn.Module,
                 dataset : "BasicDataset",
                 num_top_words = 15,
                 epochs = 200,
                 learning_rate = 0.002,
                 batch_size = 200,
                 verbose = False,
                 device = "cuda"):
        """Store the model, the dataset and the training hyper-parameters.

        Args:
            model: Module to optimise (see the class docstring for the
                methods it must provide).
            dataset: Data container (see the class docstring).
            num_top_words: Default number of words returned per topic.
            epochs: Number of passes over ``dataset.train_dataloader``.
            learning_rate: Learning rate handed to the Adam optimizer.
            batch_size: Batch size used by ``test`` when inferring theta.
            verbose: Forwarded to ``_utils.get_top_words``.
            device: Device every batch is moved to before the forward pass.
        """
        self.model = model
        self.dataset = dataset
        self.num_top_words = num_top_words
        self.epochs = epochs
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.verbose = verbose
        # The loss is logged every `log_interval` epochs.
        self.log_interval = 1
        self.data_size = len(self.dataset.train_data)
        self.device = device

    def make_optimizer(self):
        """Return a fresh Adam optimizer over all model parameters."""
        return torch.optim.Adam(self.model.parameters(), lr = self.learning_rate)

    def _train_one_epoch(self, optimizer):
        """Run one optimisation pass over the training dataloader.

        Returns:
            The sample-weighted mean batch loss as a Python float, or NaN
            when the dataloader yields no samples.
        """
        self.model.train()
        weighted_loss = 0.0
        num_samples = 0
        for batch_data in self.dataset.train_dataloader:
            batch_data = batch_data.to(self.device)
            batch_loss = self.model(batch_data)['loss']

            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            batch_len = len(batch_data)
            # Detach before accumulating: summing the live loss tensor would
            # keep extending the autograd history for the whole epoch.
            weighted_loss += batch_loss.detach() * batch_len
            num_samples += batch_len

        if num_samples == 0:
            return float("nan")
        # Normalise by the samples actually seen so the mean stays correct
        # even when the loader does not cover exactly `train_data`.
        return float(weighted_loss) / num_samples

    def train(self, model_name = None, global_round = None):
        """Train for ``self.epochs`` epochs and return topics plus train theta.

        Args:
            model_name: Optional label; when given, a header line with this
                name and ``global_round`` is printed before training.
            global_round: Value shown in that header line.

        Returns:
            Tuple ``(top_words, train_theta)`` where ``top_words`` comes from
            ``get_top_words()`` and ``train_theta`` from
            ``test(dataset.train_data)``.
        """
        optimizer = self.make_optimizer()
        if model_name is not None:
            print(f"Client's model: {model_name} \t | Global round: {global_round}")
        for epoch in range(self.epochs):
            epoch_loss = self._train_one_epoch(optimizer)

            if epoch % self.log_interval == 0:
                print(f"Epoch: {epoch:03d} | Loss: {epoch_loss}")

        top_words = self.get_top_words()
        train_theta = self.test(self.dataset.train_data)

        return top_words, train_theta

    def test(self, bow):
        """Infer theta for every row of ``bow`` in chunks of ``batch_size``.

        The model is switched to eval mode and left there; ``train`` puts it
        back into train mode at the start of every epoch.

        Args:
            bow: Tensor indexed along its first dimension (one row per
                document).

        Returns:
            ``numpy.ndarray`` stacking the ``get_theta`` output of all rows.
        """
        data_size = bow.shape[0]
        theta = []
        all_idx = torch.split(torch.arange(data_size), self.batch_size)
        with torch.no_grad():
            self.model.eval()
            for idx in all_idx:
                batch_input = bow[idx].to(self.device)
                batch_theta = self.model.get_theta(batch_input)
                theta.extend(batch_theta.cpu().tolist())

        return np.asarray(theta)

    def get_beta(self):
        """Return the model's beta as a detached NumPy array on the CPU."""
        beta = self.model.get_beta().detach().cpu().numpy()
        return beta

    def get_top_words(self, num_top_words=None):
        """Return the top words per topic via ``_utils.get_top_words``.

        Args:
            num_top_words: Words per topic; defaults to
                ``self.num_top_words`` when ``None``.
        """
        if num_top_words is None:
            num_top_words = self.num_top_words
        beta = self.get_beta()
        top_words = _utils.get_top_words(beta, self.dataset.vocab, num_top_words, self.verbose)
        return top_words

    def export_theta(self):
        """Return ``(train_theta, test_theta)`` inferred on both data splits."""
        train_theta = self.test(self.dataset.train_data)
        test_theta = self.test(self.dataset.test_data)
        return train_theta, test_theta


In [4]:
### test: build the dataset, an ETM model and a trainer for a smoke run
from main.model.ETM import ETM

# Fall back to the CPU so this cell still runs on machines without CUDA, and
# hand the same device to the trainer so model and batches always agree.
device = "cuda" if torch.cuda.is_available() else "cpu"

test_basic_dataset = BasicDataset(
    dataset_dir = "../../data/20NG"
)
test_model = ETM(test_basic_dataset.vocab_size).to(device)
test_basic_trainer = BasicTrainer(
    model=test_model,
    dataset=test_basic_dataset,
    verbose=True,
    epochs = 100,
    device=device
)

train_size:  11314
test_size:  7532
vocab_size:  5000
average length: 110.543


In [5]:
# Run the full training loop; `rst` holds the (top_words, train_theta) tuple
# returned by BasicTrainer.train().
rst = test_basic_trainer.train()

Epoch: 000 | Loss: 2013.8682861328125
Epoch: 001 | Loss: 1493.218017578125
Epoch: 002 | Loss: 1260.7904052734375
Epoch: 003 | Loss: 1136.22802734375
Epoch: 004 | Loss: 1062.514404296875
Epoch: 005 | Loss: 1013.9085083007812
Epoch: 006 | Loss: 980.9697875976562
Epoch: 007 | Loss: 957.2120971679688
Epoch: 008 | Loss: 940.1891479492188
Epoch: 009 | Loss: 926.9656372070312
Epoch: 010 | Loss: 916.5875244140625
Epoch: 011 | Loss: 908.2130126953125
Epoch: 012 | Loss: 901.1885375976562
Epoch: 013 | Loss: 895.0530395507812
Epoch: 014 | Loss: 890.1278076171875
Epoch: 015 | Loss: 885.719482421875
Epoch: 016 | Loss: 882.3067016601562
Epoch: 017 | Loss: 879.1509399414062
Epoch: 018 | Loss: 876.6295776367188
Epoch: 019 | Loss: 874.5859375
Epoch: 020 | Loss: 872.5487670898438
Epoch: 021 | Loss: 870.796875
Epoch: 022 | Loss: 869.3773803710938
Epoch: 023 | Loss: 867.9207763671875
Epoch: 024 | Loss: 866.8463134765625
Epoch: 025 | Loss: 865.6793212890625
Epoch: 026 | Loss: 864.95458984375
Epoch: 027 | Lo

In [6]:
########################### test new documents ####################################
from main.data.preprocess import Preprocess

preprocess = Preprocess()

new_docs = [
    "This is a new document about space, including words like space, satellite, launch, orbit.",
    "This is a new document about Microsoft Windows, including words like windows, files, dos."
]

# Parse the new documents against the training vocabulary so the resulting
# bag-of-words columns match what the model was trained on.
parsed_new_docs, new_bow = preprocess.parse(new_docs, vocab=test_basic_dataset.vocab)
print(new_bow.shape)

# Densify once and reuse the array for both the printout and the tensor.
new_bow_dense = new_bow.toarray()
print(new_bow_dense)

# Named `new_bow_tensor` rather than `input` so the builtin input() is not
# shadowed; placed on the trainer's own device instead of a hard-coded "cuda".
new_bow_tensor = torch.as_tensor(new_bow_dense, device=test_basic_trainer.device).float()
new_theta = test_basic_trainer.test(new_bow_tensor)

# Index of the highest-weighted topic for each new document.
print(new_theta.argmax(1))

parsing texts: 100%|██████████| 2/2 [00:00<00:00, 1947.22it/s]

(2, 5000)
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]]
[36 39]





In [7]:
# Unpack the training result: the per-topic top words and the theta inferred
# on the training data.
top_words, train_theta = rst

In [9]:
import main.evaluation.topic_coherence as tc
import main.evaluation.topic_diversity as td

In [10]:
# Score the learned topics: coherence is computed from the training texts,
# the vocabulary and the top words; diversity from the top words alone.
coherence_score = tc._coherence(test_basic_dataset.train_texts, test_basic_dataset.vocab, top_words)
diversity_score = td._diversity(top_words)

print(f"Topic coherence: {coherence_score}")
print(f"Topic diversity: {diversity_score}")

Topic coherence: 0.22844494565450063
Topic diversity: 0.956


In [8]:
# For each new document, print the top words of its highest-weighted topic.
for x in new_theta.argmax(1):
    print(top_words[x])

great space nntp american whether full video likely kill existence directory tom die language engineering
like post dos washington almost start talking together else friends shot title total chris tools
