In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [2]:
import scipy as sp
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import warnings
import seaborn as sns

from collections import defaultdict
from itertools import groupby
from sklearn import datasets
from numpy import random
from scipy.stats import dirichlet, norm, poisson

In [3]:
# from keras.datasets import reuters, imdb

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

In [5]:
import numpy as np
import os

In [6]:
from pathlib import Path
from collections import OrderedDict
import pickle, gzip, math, torch, matplotlib as mpl
import matplotlib.pyplot as plt

Path.ls = lambda x: list(x.iterdir())

## URSA Datasets 1K 

In [7]:
folder_ds_path = Path('../data/User Review Structure Analysis (URSA)/')
xml_path = (folder_ds_path/'Classified_Corpus.xml')
ds_path = (folder_ds_path/'1k')
sentence_npy_path = (folder_ds_path/'sentence.npy')
vocab_pkl_path = (ds_path/'vocab.pkl')
seed_words_path = (ds_path/'seed_words.txt')
train_filename = (ds_path/'train.txt.npy')

# log words not pass
aspect_tags = ['Food', 'Staff', 'Ambience']
polatiry_tags = ['Positive', 'Negative', 'Neutral']
xml_review_tag = './/Review'
log_np = [[], [], []]

# length allowed sentences
# length_allowed = [11, 7, 4]
# min_freq_allowed = -1

In [8]:
vocab2id = pickle.load(open(vocab_pkl_path, 'rb'))
vocab_size=len(vocab2id)

In [9]:
train_data = np.load((train_filename), allow_pickle=True)

In [10]:
p_sentence_list, label_list = train_data[:, 0], train_data[:, 1]

In [11]:
vocab = dict(map(reversed, vocab2id.items()))

In [12]:
vocab_size = len(vocab)

## Dataset

In [13]:
from sklearn.model_selection import train_test_split

In [14]:
x_, y_ = [], []
for p_sentence, label_ in zip(p_sentence_list, label_list): 
    x_.append(p_sentence)
    y_.append(label_)

In [15]:
len(x_) == len(y_)

True

In [16]:
train_x, test_x, train_y, test_y =  train_test_split(
    x_, y_, test_size=0.1, random_state=0)

In [17]:
print ('Data Loaded')
print ('Dim Training Data',len(train_x), vocab_size)
print ('Dim Test Data', len(test_x), vocab_size)

Data Loaded
Dim Training Data 3095 2772
Dim Test Data 344 2772


## Constants

In [18]:
bs = 200
en1_units=100
en2_units=100
num_topic=3
num_input=vocab_size
variance=0.995
init_mult=1.0
learning_rate=0.0005
batch_size=200
momentum=0.99
num_epoch=200
nogpu=True
drop_rate=0.6

## Topic Model Utility Functions

In [19]:
def read_file_seed_words(fn):
    with open(fn, "r") as fr:
        def p_string_sw(l):
            return l.replace('\n','').split(',')
        rl = [p_string_sw(l) for l in fr]
    return rl

## Seed words

In [20]:
seed_words = read_file_seed_words(seed_words_path)

In [21]:
print (seed_words)

[['food', 'sauc', 'chicken', 'shrimp', 'chees', 'potato', 'fri', 'tomato', 'roast', 'onion', 'pork', 'goat', 'grill', 'tuna', 'salad', 'beef', 'tapa'], ['staff', 'servic', 'friendli', 'rude', 'hostess', 'waiter', 'bartend', 'waitress', 'help', 'polit', 'bar', 'courteou', 'member', 'waitstaff', 'attitud', 'reserv', 'tip'], ['atmospher', 'scene', 'place', 'tabl', 'outsid', 'area', 'ambianc', 'outdoor', 'romant', 'cozi', 'decor', 'sit', 'wall', 'light', 'window', 'area', 'ceil', 'floor']]


In [22]:
def setup_prior(fn, n_k=3):
    gamma = torch.zeros((len(vocab),n_k))
    gamma_bin = torch.zeros((1, len(vocab),n_k))

    full_vocab = read_file_seed_words(fn)
    for k in range(len(full_vocab)):
        for idx in range(len(full_vocab[k])):
            ivocab = vocab2id[full_vocab[k][idx]]
            gamma[ivocab, k] = 1.0
            gamma_bin[:, ivocab, :] = 1.0

    return (gamma, gamma_bin)

In [23]:
def listify(o):
    if o is None: return []
    if isinstance(o, list): return o
    if isinstance(o, str): return [o]
    if isinstance(o, Iterable): return list(o)
    return [o]
def setify(o): return o if isinstance(o,set) else set(listify(o))
def compose(x, funcs, *args, order_key='_order', **kwargs):
    key = lambda o: getattr(o, order_key, 0)
    for f in sorted(listify(funcs), key=key): x = f(x, **kwargs)
    return x

In [24]:
def print_perp(model):
    cost = []
    model.eval()                        # switch to testing mode
    for x_test, y_test in test_dl:
        recon, loss = model(x_test, compute_loss=True, avg_loss=False)
        loss = loss.data
        counts = x_test.sum(1)
        cost.extend((loss / counts).data.cpu().tolist())
    print('The approximated perplexity is: ', (np.exp(np.mean(np.array(cost)))))

def print_top_words(beta, feature_names, n_top_words=10):
    print ('---------------Printing the Topics------------------')
    for i in range(len(beta)):
        line = " ".join([feature_names[j] 
                         for j in beta[i].argsort()[:-n_top_words - 1:-1]])
        print('{}'.format(line))
    print ('---------------End of Topics------------------')
    
def print_gamma(gamma, seed_words, vocab, vocab2id):
    sws = []        
    for k in range(len(seed_words)):
        for idx in range(len(seed_words[k])):
            w = seed_words[k][idx]
            sws.append((k, w))

    for idx in range(len(sws)):
        k, w = sws[idx]
        ivocab = vocab2id[w]
        mk = gamma[ivocab].argmax(-1)
        print (ivocab, w, k, mk, gamma[ivocab])

## Data Utility Functions

In [25]:
def collate(b):
    x, y = zip(*b)
    return torch.stack(x), torch.stack(y)

class IdifyAndLimitedVocab():
    _order=-1
    def __init__(self, vocab2id, limited_vocab):
        self.vocab2id = vocab2id
        self.limited_vocab = limited_vocab
    def __call__(self, item):
        idlist = [self.vocab2id[w] for w in item if self.vocab2id[w] < self.limited_vocab]
        return np.array(idlist)
    

class Numpyify():
    _order=0
    def __call__(self, item):
        return np.array(item)

class Onehotify():
    _order=1
    def __init__(self, vocab_size):
        self.vocab_size = vocab_size
    def __call__(self, item):
        return np.array(np.bincount(item.astype('int'), minlength=self.vocab_size))
    
class YToOnehot():
    _order=1
    def __init__(self, num_classes):
        self.num_classes = num_classes
    def __call__(self, item):
        categorical = np.zeros((1, self.num_classes))
        categorical[0, item] = 1
        return categorical

class Tensorify():
    _order=2
    def __call__(self, item):
        return torch.from_numpy(item)

class Floatify():
    _order=3
    def __call__(self, item):
        return item.float()
    
class CheckAndCudify():
    _order=100
    def __init__(self):
        self.ic = torch.cuda.is_available()
    def __call__(self, item):
        return item.cuda() if self.ic else item
    
class URSADataset(Dataset):
    def __init__(self, x, y, tfms_x, tfms_y): 
        self.x, self.y = x, y
        self.x_tfms = tfms_x
        self.y_tfms = tfms_y
    def __len__(self): 
        return len(self.x)
    def __getitem__(self, i): 
        return compose(self.x[i], self.x_tfms), compose(self.y[i], self.y_tfms)
    
class Sampler():
    def __init__(self, ds, bs, shuffle=False):
        self.n,self.bs,self.shuffle = len(ds),bs,shuffle
        
    def __iter__(self):
        self.idxs = torch.randperm(self.n) if self.shuffle else torch.arange(self.n)
        for i in range(0, self.n, self.bs): yield self.idxs[i:i+self.bs]

class DataLoader():
    def __init__(self, ds, sampler, collate_fn=collate):
        self.ds,self.sampler,self.collate_fn = ds,sampler,collate_fn
        
    def __iter__(self):
        for s in self.sampler: yield self.collate_fn([self.ds[i] for i in s])

## Load Data

In [26]:
num_classes = np.max(train_y) + 1

In [27]:
num_classes

3

In [28]:
tfms_x = [Numpyify(), Onehotify(vocab_size=vocab_size), Tensorify(), Floatify(), CheckAndCudify()]
tfms_y = [YToOnehot(num_classes=num_classes), Tensorify(), Floatify(), CheckAndCudify()]

In [29]:
train_ds = URSADataset(train_x, train_y, tfms_x=tfms_x, tfms_y=tfms_y)
test_ds = URSADataset(test_x, test_y, tfms_x=tfms_x, tfms_y=tfms_y)

In [30]:
train_samp = Sampler(train_ds, bs, shuffle=False)
test_samp = Sampler(test_ds, bs, shuffle=False)

In [31]:
train_dl = DataLoader(train_ds, sampler=train_samp, collate_fn=collate)
test_dl = DataLoader(test_ds, sampler=test_samp, collate_fn=collate)

In [32]:
gamma_prior = setup_prior(seed_words_path, 3)

In [33]:
gamma, gamma_bin = gamma_prior

## Define Model

In [34]:
class ProdLDA(nn.Module):
    def __init__(self, num_input, en1_units, en2_units, num_topic, drop_rate, init_mult, gamma_prior):
        super(ProdLDA, self).__init__()
        self.num_input, self.en1_units, self.en2_units, \
        self.num_topic, self.drop_rate, self.init_mult = num_input, en1_units, en2_units, \
                                                            num_topic, drop_rate, init_mult
        # gamma prior
        self.gamma_prior = gamma_prior
        
        # encoder
        self.en1_fc = nn.Linear(num_input, en1_units)
        self.en1_ac = nn.Softplus()
        self.en2_fc     = nn.Linear(en1_units, en2_units)
        self.en2_ac = nn.Softplus()
        self.en2_dr   = nn.Dropout(drop_rate)
        
        # mean, logvar
        self.mean_fc = nn.Linear(en2_units, num_topic)
        self.mean_bn = nn.BatchNorm1d(num_topic)
        self.logvar_fc = nn.Linear(en2_units, num_topic)
        self.logvar_bn = nn.BatchNorm1d(num_topic)

        # decoder
        self.de_ac1 = nn.Softmax(dim=-1)
        self.de_dr = nn.Dropout(drop_rate)
        self.de_fc = nn.Linear(num_topic, num_input)
        self.de_bn = nn.BatchNorm1d(num_input)
        self.de_ac2 = nn.Softmax(dim=-1)
        
        # prior mean and variance as constant buffers
        self.prior_mean   = torch.Tensor(1, num_topic).fill_(0)
        self.prior_var    = torch.Tensor(1, num_topic).fill_(variance)
        self.prior_mean   = nn.Parameter(self.prior_mean, requires_grad=False)
        self.prior_var    = nn.Parameter(self.prior_var, requires_grad=False)
        self.prior_logvar = nn.Parameter(self.prior_var.log(), requires_grad=False)
        # initialize decoder weight
        if init_mult != 0:
            #std = 1. / math.sqrt( init_mult * (num_topic + num_input))
            self.de_fc.weight.data.uniform_(0, init_mult)
        # remove BN's scale parameters
        for component in [self.mean_bn, self.logvar_bn, self.de_bn]:
            component.weight.requires_grad = False
            component.weight.fill_(1.0)
        
    def gamma(self):
        # this function have to run after self.encode
        encoder_w1 = self.en1_fc.weight
        encoder_b1 = self.en1_fc.bias
        encoder_w2 = self.en2_fc.weight
        encoder_b2 = self.en2_fc.bias
        mean_w = self.mean_fc.weight
        mean_b = self.mean_fc.bias
        mean_running_mean = self.mean_bn.running_mean
        mean_running_var = self.mean_bn.running_var
        logvar_w = self.logvar_fc.weight
        logvar_b = self.logvar_fc.bias
        logvar_running_mean = self.logvar_bn.running_mean
        logvar_running_var = self.logvar_bn.running_var
        
        w1 = F.softplus(encoder_w1.t() + encoder_b1)
        w2 = F.softplus(F.linear(w1, encoder_w2, encoder_b2))
        wdr = F.dropout(w2, self.drop_rate)
        wo_mean = F.softmax(F.batch_norm(F.linear(wdr, mean_w, mean_b), mean_running_mean, mean_running_var), dim=-1)
        wo_logvar = F.softmax(F.batch_norm(F.linear(wdr, logvar_w, logvar_b), logvar_running_mean, logvar_running_var), dim=-1)
        
        return wo_mean, wo_logvar
            
    def encode(self, input_):
        # encoder
        encoded1 = self.en1_fc(input_)
        encoded1_ac = self.en1_ac(encoded1)
        encoded2 = self.en2_fc(encoded1_ac)
        encoded2_ac = self.en2_ac(encoded2)
        encoded2_dr = self.en2_dr(encoded2_ac)
        
        encoded = encoded2_dr
        
        # hidden => mean, logvar
        mean_theta = self.mean_fc(encoded)
        mean_theta_bn = self.mean_bn(mean_theta)
        logvar_theta = self.logvar_fc(encoded)
        logvar_theta_bn = self.logvar_bn(logvar_theta)
        
        posterior_mean = mean_theta_bn
        posterior_logvar = logvar_theta_bn
        return encoded, posterior_mean, posterior_logvar
    
    def decode(self, input_, posterior_mean, posterior_var):
        # take sample
        eps = input_.data.new().resize_as_(posterior_mean.data).normal_() # noise 
        z = posterior_mean + posterior_var.sqrt() * eps                   # reparameterization
        # do reconstruction
        # decoder
        decoded1_ac = self.de_ac1(z)
        decoded1_dr = self.de_dr(decoded1_ac)
        decoded2 = self.de_fc(decoded1_dr)
        decoded2_bn = self.de_bn(decoded2)
        decoded2_ac = self.de_ac2(decoded2_bn)
        recon = decoded2_ac          # reconstructed distribution over vocabulary
        return recon
    
    def forward(self, input_, compute_loss=False, avg_loss=True):
        # compute posterior
        en2, posterior_mean, posterior_logvar = self.encode(input_) 
        posterior_var    = posterior_logvar.exp()
        
        recon = self.decode(input_, posterior_mean, posterior_var)
        if compute_loss:
            return recon, self.loss(input_, recon, posterior_mean, posterior_logvar, posterior_var, avg_loss)
        else:
            return recon

    def loss(self, input_, recon, posterior_mean, posterior_logvar, posterior_var, avg=True):
        # NL
        NL  = -(input_ * (recon + 1e-10).log()).sum(1)
        # KLD, see Section 3.3 of Akash Srivastava and Charles Sutton, 2017, 
        # https://arxiv.org/pdf/1703.01488.pdf
        prior_mean   = self.prior_mean.expand_as(posterior_mean)
        prior_var    = self.prior_var.expand_as(posterior_mean)
        prior_logvar = self.prior_logvar.expand_as(posterior_mean)
        var_division    = posterior_var  / prior_var
        diff            = posterior_mean - prior_mean
        diff_term       = diff * diff / prior_var
        logvar_division = prior_logvar - posterior_logvar
        # put KLD together
        KLD = 0.5 * ( (var_division + diff_term + logvar_division).sum(1) - self.num_topic)
        
        # gamma
        n, _ = input_.size()
        gamma_mean, gamma_logvar = self.gamma()
        gamma_prior, gammar_prior_bin = self.gamma_prior
        input_t = (input_ > 0).unsqueeze(dim=-1)
        input_bin = ((gammar_prior_bin.expand(n, -1, -1) == 1) & input_t)
        lambda_c = 20.0
        
        gamma_prior = gamma_prior.expand(n, -1, -1)      
        
        GL = lambda_c * ((gamma_prior - (input_bin.int()*gamma_mean))**2).sum((1, 2))
        
        # loss
        loss = (NL + KLD + GL)
        
        # in traiming mode, return averaged loss. In testing mode, return individual loss
        if avg:
            return loss.mean()
        else:
            return loss

## Train

In [36]:
from sklearn import metrics

In [37]:
def compute_accuracy(y_pred, y_true):
    accuracy = metrics.accuracy_score(y_true, y_pred)
    precision, recall, f1_score, _ = metrics.precision_recall_fscore_support(y_true=y_true, \
                                                     y_pred=y_pred, \
                                                     average=None)

    return (accuracy, precision, recall, f1_score)

In [38]:
model = ProdLDA(num_input, en1_units, en2_units, num_topic, drop_rate, init_mult, gamma_prior)
optimizer = torch.optim.Adam(model.parameters(), learning_rate, betas=(momentum, 0.999))

In [39]:
if torch.cuda.is_available():
    model = model.cuda()

In [54]:
for epoch in range(num_epoch):
    loss_epoch = 0.0
    model.train()                    # switch to training mode
    for input_, label_ in train_dl:
        recon, loss = model(input_, compute_loss=True)
        # optimize
        optimizer.zero_grad()        # clear previous gradients
        loss.backward()              # backprop
        optimizer.step()             # update parameters
        # report
        loss_epoch += loss.item()    # add loss to loss_epoch
    if (epoch + 1) % 10 == 0:
        model.eval()
        # Test Model
        pred_train = []
        label_train = []
        pred_test = []
        label_test = []
        
        for x_train, y_train in train_dl:
            encoded, theta_mean, theta_logvar = model.encode(x_train)
            temp_theta_mean = theta_mean.argmax(-1).int().data.cpu().tolist()
            temp_y_train = y_train.argmax(-1).flatten().data.cpu().tolist()
            
            pred_train.extend(temp_theta_mean)
            label_train.extend(temp_y_train)
        
        accuracy_train, precision_train, recall_train, f1_score_train = compute_accuracy(pred_train, label_train)
        
        for x_test, y_test in test_dl:
            encoded, theta_mean, theta_logvar = model.encode(x_test)
            temp_theta_mean = theta_mean.argmax(-1).int().data.cpu().tolist()
            temp_y_test = y_test.argmax(-1).flatten().data.cpu().tolist()
            
            pred_test.extend(temp_theta_mean)
            label_test.extend(temp_y_test)
        
        accuracy_test, precision_test, recall_test, f1_score_test = compute_accuracy(pred_test, label_test)
        print ("##################################################")
        print('Epoch {}, loss={}, accuracy_train={}, accuracy_test={}'.format(epoch, loss_epoch / len(input_), accuracy_train, accuracy_test))
        for k in range(num_topic):
            print ("precision_train{}".format(k), "=" , "{:.9f}".format(precision_train[k]), \
                 "recall_train{}".format(k), "=" , "{:.9f}".format(recall_train[k]), \
                 "f1_score_train{}".format(k), "=" , "{:.9f}".format(f1_score_train[k]))
            print ("precision_te{}".format(k), "=" , "{:.9f}".format(precision_test[k]), \
                 "recall_te{}".format(k), "=" , "{:.9f}".format(recall_test[k]), \
                 "f1_score_te{}".format(k), "=" , "{:.9f}".format(f1_score_test[k]))
        emb = model.de_fc.weight.data.detach().cpu().numpy().T
        print_top_words(emb, vocab, 50)
        print_perp(model)
        print ("##################################################")        

  _warn_prf(average, modifier, msg_start, len(result))


##################################################
Epoch 0, loss=199.263525390625, accuracy_train=0.3431340872374798, accuracy_test=0.3546511627906977
precision_train0 = 0.342820181 recall_train0 = 1.000000000 f1_score_train0 = 0.510597303
precision_te0 = 0.354651163 recall_te0 = 1.000000000 f1_score_te0 = 0.523605150
precision_train1 = 0.000000000 recall_train1 = 0.000000000 f1_score_train1 = 0.000000000
precision_te1 = 0.000000000 recall_te1 = 0.000000000 f1_score_te1 = 0.000000000
precision_train2 = 0.666666667 recall_train2 = 0.002096436 f1_score_train2 = 0.004179728
precision_te2 = 0.000000000 recall_te2 = 0.000000000 f1_score_te2 = 0.000000000
---------------Printing the Topics------------------
scallop itali ginger gem becom whole drop ala importantli villag trulli worth dig chain dip connoisseur profession popular inde vinyl contribut ice someon moist mall gossip dynamit carb chandeli radish stand spend peanut volum calm kitsch cucumb taco broadway member wont greasi clueless c

##################################################
Epoch 25, loss=197.63344341077303, accuracy_train=0.6995153473344103, accuracy_test=0.6976744186046512
precision_train0 = 0.723253758 recall_train0 = 0.771698113 f1_score_train0 = 0.746691009
precision_te0 = 0.744000000 recall_te0 = 0.762295082 f1_score_te0 = 0.753036437
precision_train1 = 0.736842105 recall_train1 = 0.647548566 f1_score_train1 = 0.689315608
precision_te1 = 0.733944954 recall_te1 = 0.645161290 f1_score_te1 = 0.686695279
precision_train2 = 0.638067061 recall_train2 = 0.678197065 f1_score_train2 = 0.657520325
precision_te2 = 0.609090909 recall_te2 = 0.683673469 f1_score_te2 = 0.644230769
---------------Printing the Topics------------------
itali scallop eel sausag either great trulli hearti rich style chines around hour pineappl becom chair nut villag worth peke tender abl whole care drop beauti sometim egg broadway ultim chandeli taco ice drape greasi gossip soundtrack radish stand boyfriend banana waitress forev filet 

The approximated perplexity is:  1.9576643174566857e+25
##################################################
##################################################
Epoch 50, loss=195.81094520970396, accuracy_train=0.7508885298869143, accuracy_test=0.7645348837209303
precision_train0 = 0.812680115 recall_train0 = 0.798113208 f1_score_train0 = 0.805330795
precision_te0 = 0.833333333 recall_te0 = 0.778688525 f1_score_te0 = 0.805084746
precision_train1 = 0.732283465 recall_train1 = 0.774283071 f1_score_train1 = 0.752697842
precision_te1 = 0.746268657 recall_te1 = 0.806451613 f1_score_te1 = 0.775193798
precision_train2 = 0.703622393 recall_train2 = 0.671907757 f1_score_train2 = 0.687399464
precision_te2 = 0.708333333 recall_te2 = 0.693877551 f1_score_te2 = 0.701030928
---------------Printing the Topics------------------
great sausag chicken scallop tender eel either roll pineappl style egg hearti itali peke sear nut rich chines loaf worth trulli lettuc filet grill taco salt known greasi boyfriend

##################################################
Epoch 75, loss=195.2543521278783, accuracy_train=0.7809369951534734, accuracy_test=0.7965116279069767
precision_train0 = 0.908888889 recall_train0 = 0.771698113 f1_score_train0 = 0.834693878
precision_te0 = 0.957894737 recall_te0 = 0.745901639 f1_score_te0 = 0.838709677
precision_train1 = 0.706140351 recall_train1 = 0.893617021 f1_score_train1 = 0.788893426
precision_te1 = 0.712500000 recall_te1 = 0.919354839 f1_score_te1 = 0.802816901
precision_train2 = 0.765417170 recall_train2 = 0.663522013 f1_score_train2 = 0.710836609
precision_te2 = 0.775280899 recall_te2 = 0.704081633 f1_score_te2 = 0.737967914
---------------Printing the Topics------------------
chicken great sausag roll tender scallop grill eel menu pineappl egg style sear either salt lettuc hearti filet kobe loaf top known citru nut dumpl rich peke octopu die chines taco brais greasi beef itali ice roast duck oil sichuan famou red list love root perfectli free trulli boyfrien

##################################################
Epoch 100, loss=194.62125436883224, accuracy_train=0.8048465266558966, accuracy_test=0.8052325581395349
precision_train0 = 0.941043084 recall_train0 = 0.783018868 f1_score_train0 = 0.854788877
precision_te0 = 0.978494624 recall_te0 = 0.745901639 f1_score_te0 = 0.846511628
precision_train1 = 0.708798883 recall_train1 = 0.938945421 f1_score_train1 = 0.807799443
precision_te1 = 0.698224852 recall_te1 = 0.951612903 f1_score_te1 = 0.805460751
precision_train2 = 0.827144686 recall_train2 = 0.677148847 f1_score_train2 = 0.744668588
precision_te2 = 0.829268293 recall_te2 = 0.693877551 f1_score_te2 = 0.755555556
---------------Printing the Topics------------------
chicken sausag roll great scallop grill tender menu eel sear roast pineappl beef style egg salt lettuc kobe either filet citru hearti dumpl loaf rich known brais octopu die nut greasi taco duck oil creami top homemad sichuan dessert shrimp love famou peke veget chines red gener sweet 

##################################################
Epoch 125, loss=194.31519968133225, accuracy_train=0.8177705977382875, accuracy_test=0.8313953488372093
precision_train0 = 0.956621005 recall_train0 = 0.790566038 f1_score_train0 = 0.865702479
precision_te0 = 0.989583333 recall_te0 = 0.778688525 f1_score_te0 = 0.871559633
precision_train1 = 0.709744299 recall_train1 = 0.950046253 f1_score_train1 = 0.812500000
precision_te1 = 0.703488372 recall_te1 = 0.975806452 f1_score_te1 = 0.817567568
precision_train2 = 0.862694301 recall_train2 = 0.698113208 f1_score_train2 = 0.771726535
precision_te2 = 0.921052632 recall_te2 = 0.714285714 f1_score_te2 = 0.804597701
---------------Printing the Topics------------------
chicken sausag roll scallop grill tender menu great roast eel sear beef pineappl kobe lettuc filet salt citru either hearti dumpl egg shrimp loaf style brais creami die rich homemad known octopu nut green greasi dessert oil duck gener sichuan veget taco famou spici root peke tofu perf

##################################################
Epoch 150, loss=194.06015496504935, accuracy_train=0.821324717285945, accuracy_test=0.8343023255813954
precision_train0 = 0.961009174 recall_train0 = 0.790566038 f1_score_train0 = 0.867494824
precision_te0 = 0.989583333 recall_te0 = 0.778688525 f1_score_te0 = 0.871559633
precision_train1 = 0.705119454 recall_train1 = 0.955596670 f1_score_train1 = 0.811468971
precision_te1 = 0.705202312 recall_te1 = 0.983870968 f1_score_te1 = 0.821548822
precision_train2 = 0.885224274 recall_train2 = 0.703354298 f1_score_train2 = 0.783878505
precision_te2 = 0.933333333 recall_te2 = 0.714285714 f1_score_te2 = 0.809248555
---------------Printing the Topics------------------
chicken sausag roll grill scallop tender menu roast eel sear beef pineappl great kobe filet lettuc citru hearti salt shrimp dumpl loaf brais creami either egg homemad duck dessert nut die greasi gener octopu oil green root tofu sichuan rich style veget perfectli known kielbasa miso fam

##################################################
Epoch 175, loss=193.91462659333882, accuracy_train=0.827140549273021, accuracy_test=0.8343023255813954
precision_train0 = 0.967853042 recall_train0 = 0.795283019 f1_score_train0 = 0.873122734
precision_te0 = 0.989583333 recall_te0 = 0.778688525 f1_score_te0 = 0.871559633
precision_train1 = 0.707934337 recall_train1 = 0.957446809 f1_score_train1 = 0.813999214
precision_te1 = 0.705202312 recall_te1 = 0.983870968 f1_score_te1 = 0.821548822
precision_train2 = 0.895013123 recall_train2 = 0.714884696 f1_score_train2 = 0.794871795
precision_te2 = 0.933333333 recall_te2 = 0.714285714 f1_score_te2 = 0.809248555
---------------Printing the Topics------------------
chicken sausag scallop tender roll grill menu roast sear eel kobe pineappl beef filet lettuc hearti citru creami shrimp salt brais loaf dumpl great duck root octopu homemad tofu either nut dessert greasi oil die egg gener foie sichuan rich miso green kielbasa perfectli veget salmon pie

## Test

In [55]:
model.eval()
gamma_mean, gamma_logvar = model.gamma()
gm, gl = gamma_mean.data.cpu().numpy(), gamma_logvar.data.cpu().numpy()
print_gamma(gm, seed_words, vocab, vocab2id)

935 food 0 1 [0.24964267 0.5249344  0.225423  ]
88 sauc 0 0 [0.48584357 0.2489847  0.26517174]
2681 chicken 0 0 [0.43324354 0.26412746 0.30262896]
2414 shrimp 0 1 [0.35139   0.4001237 0.2484864]
1381 chees 0 1 [0.31703988 0.37503672 0.30792344]
1496 potato 0 1 [0.30937544 0.46532995 0.22529462]
105 fri 0 2 [0.3354963  0.30596307 0.3585406 ]
546 tomato 0 2 [0.35161608 0.29045445 0.35792953]
1347 roast 0 0 [0.43419012 0.33292595 0.23288386]
642 onion 0 1 [0.33218953 0.43471566 0.23309481]
2272 pork 0 0 [0.35807046 0.35356387 0.2883656 ]
872 goat 0 1 [0.32960615 0.3879351  0.28245872]
1005 grill 0 1 [0.21808502 0.507333   0.27458197]
124 tuna 0 1 [0.35623774 0.42428207 0.21948016]
1159 salad 0 0 [0.4088579  0.31646934 0.27467275]
2188 beef 0 1 [0.34266505 0.4064603  0.2508747 ]
601 tapa 0 1 [0.29188886 0.39962614 0.30848497]
1991 staff 1 1 [0.20411542 0.6123133  0.18357132]
1425 servic 1 1 [0.16824763 0.595536   0.2362164 ]
1137 friendli 1 1 [0.15206479 0.622277   0.22565818]
1009 rude 1 

In [56]:
emb = model.de_fc.weight.data.cpu().numpy().T
print_top_words(emb, vocab, 50)
print_perp(model)

---------------Printing the Topics------------------
chicken sausag scallop tender roll grill roast menu sear eel kobe pineappl filet beef hearti lettuc citru loaf creami brais shrimp salt dumpl duck tofu root nut octopu either homemad greasi oil foie die gener miso kielbasa great dessert egg veget salmon chipotl perfectli rich pepper sichuan peke highlight curri
n waitress java take even reserv terribl away min without need staff charg attent tell smooth us constantli host bad effici drink order hour inattent sincer problem came mediocr waiter bold patient got spill go bother diet poor embarrass though confirm ruin anoth wo last wrong bill e forgotten kept
feel chair dark wooden lit paint color tone franchis soar cozi wood place deco atmospher mirror shini midtown tier accent ba mod lamp dim outdoor strike look stylish area silver ceil mismatch weather overhead enclos inspir spaciou carpet band neon chines rear villag bakeri cater rail poster drape vintag environ
---------------End of