In [1]:
import sys
# NOTE: machine-specific absolute path to the project checkout; adjust it when
# running this notebook on another machine.
project_root = "d:/MachineLearning/federated_vae"
# Guard the append so re-running this cell does not stack duplicate entries
# on sys.path.
if project_root not in sys.path:
    sys.path.append(project_root)

In [2]:
import torch
from main.data.basic_dataset import BasicDataset
import torch.nn as nn
from collections import defaultdict
import numpy as np
from main.utils import _utils

In [44]:
class BasicTrainer:
    """Train a model on a dataset and extract its top words and theta values.

    As used by the methods below, the model must:
      * return a mapping with a scalar ``'loss'`` tensor when called on a batch;
      * provide ``get_theta(batch)`` and ``get_beta()``, both returning tensors.

    The dataset must provide ``train_data``, ``test_data``,
    ``train_dataloader`` and ``vocab``.
    """

    def __init__(self, 
                 model : nn.Module,
                 dataset : "BasicDataset",
                 num_top_words = 15,
                 epochs = 200,
                 learning_rate = 0.002,
                 batch_size = 200,
                 verbose = False):
        """Store the training configuration.

        Args:
            model: Module to optimise.
            dataset: Data holder; see the class docstring for the attributes read.
            num_top_words: Default number of words requested per topic by
                ``get_top_words``.
            epochs: Number of passes over ``dataset.train_dataloader``.
            learning_rate: Learning rate handed to the Adam optimizer.
            batch_size: Chunk size used by ``test`` when computing theta.
                Training batches come from ``dataset.train_dataloader`` instead.
            verbose: Forwarded to ``_utils.get_top_words``.
        """
        self.model = model
        self.dataset = dataset
        self.num_top_words = num_top_words
        self.epochs = epochs
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.verbose = verbose
        # Log the epoch loss every ``log_interval`` epochs.
        self.log_interval = 1
        # Denominator for the per-epoch average loss reported by ``train``.
        self.data_size = len(self.dataset.train_data)

    def make_optimizer(self):
        """Return an Adam optimizer over the model parameters."""
        return torch.optim.Adam(self.model.parameters(), lr = self.learning_rate)
    
    def train(self):
        """Run the training loop and return ``(top_words, train_theta)``.

        Prints the size-weighted average loss every ``log_interval`` epochs.
        The model is left in eval mode, because ``test`` is called at the end.

        Returns:
            tuple: the result of ``get_top_words()`` and the theta array
            computed by ``test`` on ``dataset.train_data``.
        """
        optimizer = self.make_optimizer()

        for epoch in range(self.epochs):
            self.model.train()
            total_loss = 0.0
            for batch_data in self.dataset.train_dataloader:
                
                output = self.model(batch_data)

                batch_loss = output['loss']

                optimizer.zero_grad()
                batch_loss.backward()
                optimizer.step()

                # Accumulate a plain Python float: summing the loss tensor itself
                # would keep autograd history attached to the running total for
                # the whole epoch.
                total_loss += batch_loss.item() * len(batch_data)

            if (epoch % self.log_interval == 0):    
                print(f"Epoch: {epoch:03d} | Loss: {total_loss / self.data_size}")

        top_words = self.get_top_words()
        train_theta = self.test(self.dataset.train_data)

        return top_words, train_theta

    def test(self, bow):
        """Compute theta for every row of ``bow`` without tracking gradients.

        Rows are processed in chunks of ``batch_size``. The model is switched
        to eval mode and stays in eval mode afterwards.

        Args:
            bow: Indexable input with a ``shape`` attribute; rows are selected
                with an index tensor (presumably a bag-of-words matrix; confirm
                against the dataset class).

        Returns:
            numpy.ndarray: stacked ``model.get_theta`` outputs, one row per
            input row.
        """
        data_size = bow.shape[0]
        theta = list()
        all_idx = torch.split(torch.arange(data_size), self.batch_size)

        with torch.no_grad():
            self.model.eval()
            for idx in all_idx:
                batch_input = bow[idx]
                batch_theta = self.model.get_theta(batch_input)
                theta.extend(batch_theta.cpu().tolist())

        theta = np.asarray(theta)
        return theta

    def get_beta(self):
        """Return ``model.get_beta()`` as a detached NumPy array on the CPU."""
        beta = self.model.get_beta().detach().cpu().numpy()
        return beta

    def get_top_words(self, num_top_words=None):
        """Return the top words per topic via ``_utils.get_top_words``.

        Args:
            num_top_words: Words per topic; falls back to ``self.num_top_words``
                when ``None``.
        """
        if num_top_words is None:
            num_top_words = self.num_top_words
        beta = self.get_beta()
        top_words = _utils.get_top_words(beta, self.dataset.vocab, num_top_words, self.verbose)
        return top_words

    def export_theta(self):
        """Return ``(train_theta, test_theta)`` computed with ``test``."""
        train_theta = self.test(self.dataset.train_data)
        test_theta = self.test(self.dataset.test_data)
        return train_theta, test_theta


In [47]:
### test
from main.model.ETM import ETM

# Smoke-test wiring: load the 20NG data, build an ETM sized to its vocabulary,
# and wrap both in a trainer configured for a single epoch.
# The dataset path is relative to the notebook's working directory.
test_basic_dataset = BasicDataset(dataset_dir="../../data/20NG")
test_model = ETM(test_basic_dataset.vocab_size)
test_basic_trainer = BasicTrainer(
    dataset=test_basic_dataset,
    model=test_model,
    epochs=1,
    verbose=True,
)

train_size:  11314
test_size:  7532
vocab_size:  5000
average length: 110.543


In [48]:
# Run the one-epoch smoke test; the returned (top_words, train_theta) tuple is
# echoed as the cell output.
test_basic_trainer.train()

Epoch: 000 | Loss: 2018.4725341796875
Topic 0: people replies surface sea lists eliminate simultaneously floppy sensitive determine ian mormons purdue stayed saved
Topic 1: shown stealth stolen gotten clients transmitted kim taylor erik focus seats speaks glory workstations symptoms
Topic 2: software class prejudice attempts manufacturer culture irvine bosnia utterly interest extremely giant laughed market hence
Topic 3: printer clinton quebec biggest convention tapped remains sacred operation inherently management burns enforcement quick volume
Topic 4: absolutely authorized promote village initial dump nasty vincent involving richardson investment won store nothing water
Topic 5: dan hunting rich audience explanation dogs modems imply view reserve university hayes saying statistical gives
Topic 6: equivalent staff apologies clark creator satisfy function document quick member schools skin shit aix workstations
Topic 7: afraid conner jpeg developers baker pack pitching giant reach sav

(['people replies surface sea lists eliminate simultaneously floppy sensitive determine ian mormons purdue stayed saved',
  'shown stealth stolen gotten clients transmitted kim taylor erik focus seats speaks glory workstations symptoms',
  'software class prejudice attempts manufacturer culture irvine bosnia utterly interest extremely giant laughed market hence',
  'printer clinton quebec biggest convention tapped remains sacred operation inherently management burns enforcement quick volume',
  'absolutely authorized promote village initial dump nasty vincent involving richardson investment won store nothing water',
  'dan hunting rich audience explanation dogs modems imply view reserve university hayes saying statistical gives',
  'equivalent staff apologies clark creator satisfy function document quick member schools skin shit aix workstations',
  'afraid conner jpeg developers baker pack pitching giant reach saving beautiful hole mode deliberately passes',
  'suck police abiding use