In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [2]:
import scipy as sp
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import warnings
import seaborn as sns

from collections import defaultdict
from itertools import groupby
from sklearn import datasets
from numpy import random
from scipy.stats import dirichlet, norm, poisson

In [3]:
# from keras.datasets import reuters, imdb

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

In [5]:
import numpy as np
import os

In [6]:
from pathlib import Path
from collections import OrderedDict
import pickle, gzip, math, torch, matplotlib as mpl
import matplotlib.pyplot as plt

Path.ls = lambda x: list(x.iterdir())

## URSA Datasets 1K 

In [7]:
folder_ds_path = Path('./data/User Review Structure Analysis (URSA)/')
xml_path = (folder_ds_path/'Classified_Corpus.xml')
ds_path = (folder_ds_path/'10k')
sentence_npy_path = (folder_ds_path/'sentence.npy')
vocab_pkl_path = (ds_path/'vocab.pkl')
seed_words_path = (ds_path/'seed_words.txt')
train_filename = (ds_path/'train.txt.npy')

# log words not pass
aspect_tags = ['Food', 'Staff', 'Ambience']
polatiry_tags = ['Positive', 'Negative', 'Neutral']
xml_review_tag = './/Review'
log_np = [[], [], []]

# length allowed sentences
# length_allowed = [11, 7, 4]
# min_freq_allowed = -1

In [8]:
vocab2id = pickle.load(open(vocab_pkl_path, 'rb'))
vocab_size=len(vocab2id)

In [9]:
train_data = np.load((train_filename), allow_pickle=True)

In [10]:
p_sentence_list, label_list = train_data[:, 0], train_data[:, 1]

In [11]:
vocab = dict(map(reversed, vocab2id.items()))

In [12]:
vocab_size = len(vocab)

## Dataset

In [13]:
from sklearn.model_selection import train_test_split

In [14]:
x_, y_ = [], []
for p_sentence, label_ in zip(p_sentence_list, label_list): 
    x_.append(p_sentence)
    y_.append(label_)

In [15]:
len(x_) == len(y_)

True

In [16]:
train_x, test_x, train_y, test_y =  train_test_split(
    x_, y_, test_size=0.1, random_state=0)

In [17]:
print ('Data Loaded')
print ('Dim Training Data',len(train_x), vocab_size)
print ('Dim Test Data', len(test_x), vocab_size)

Data Loaded
Dim Training Data 29125 18073
Dim Test Data 3237 18073


## Constants

In [18]:
bs = 200
en1_units=100
en2_units=100
num_topic=3
num_input=vocab_size
variance=0.995
init_mult=1.0
learning_rate=0.0005
batch_size=200
momentum=0.99
num_epoch=200
nogpu=True
drop_rate=0.6

## Topic Model Utility Functions

In [19]:
def read_file_seed_words(fn):
    with open(fn, "r") as fr:
        def p_string_sw(l):
            return l.replace('\n','').split(',')
        rl = [p_string_sw(l) for l in fr]
    return rl

In [20]:
seed_words = read_file_seed_words(seed_words_path)

## Toy Gamma

In [21]:
num_topic_toy = 3
batch_size_toy = 2
seed_words_toy = [['a'], ['c'], ['f']]
print (seed_words_toy)

[['a'], ['c'], ['f']]


In [22]:
vocab2id_toy = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, 'f': 5}

In [23]:
vocab_size_toy = len(vocab2id_toy)

In [24]:
gamma_toy = np.zeros((vocab_size_toy, num_topic_toy))
print (gamma_toy)

[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]


In [25]:
gamma_bin_toy = np.zeros((batch_size_toy, vocab_size_toy, num_topic_toy))
print (gamma_bin_toy)

[[[0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]]

 [[0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]]]


In [26]:
for k in range(len(seed_words_toy)): # number of topics
    print ("k:", k)
    for idx in range(len(seed_words_toy[k])): # number of words
        print ("idx:", idx)
        ivocab = vocab2id_toy[seed_words_toy[k][idx]]
        gamma_toy[ivocab, k] = 1.0
        gamma_bin_toy[:, ivocab, :] = 1.0

k: 0
idx: 0
k: 1
idx: 0
k: 2
idx: 0


In [27]:
gamma_toy.shape

(6, 3)

In [28]:
gamma_toy

array([[1., 0., 0.],
       [0., 0., 0.],
       [0., 1., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 1.]])

In [29]:
gamma_bin_toy.shape

(2, 6, 3)

In [30]:
gamma_bin_toy

array([[[1., 1., 1.],
        [0., 0., 0.],
        [1., 1., 1.],
        [0., 0., 0.],
        [0., 0., 0.],
        [1., 1., 1.]],

       [[1., 1., 1.],
        [0., 0., 0.],
        [1., 1., 1.],
        [0., 0., 0.],
        [0., 0., 0.],
        [1., 1., 1.]]])

In [31]:
(gamma_bin_toy == 1.) & (gamma_bin_toy == 1.) 

array([[[ True,  True,  True],
        [False, False, False],
        [ True,  True,  True],
        [False, False, False],
        [False, False, False],
        [ True,  True,  True]],

       [[ True,  True,  True],
        [False, False, False],
        [ True,  True,  True],
        [False, False, False],
        [False, False, False],
        [ True,  True,  True]]])

In [32]:
at = torch.tensor((gamma_bin_toy == 1.) & (gamma_bin_toy == 1.) )

In [33]:
at.int().float()

tensor([[[1., 1., 1.],
         [0., 0., 0.],
         [1., 1., 1.],
         [0., 0., 0.],
         [0., 0., 0.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [0., 0., 0.],
         [1., 1., 1.],
         [0., 0., 0.],
         [0., 0., 0.],
         [1., 1., 1.]]])

In [34]:
at.int().shape

torch.Size([2, 6, 3])

In [35]:
at.int().sum((1, 2))

tensor([9, 9])

## Seed words

In [36]:
print (seed_words)

[['food', 'sauc', 'chicken', 'shrimp', 'chees', 'potato', 'fri', 'tomato', 'roast', 'onion', 'pork', 'goat', 'grill', 'tuna', 'salad', 'beef', 'tapa'], ['staff', 'servic', 'friendli', 'rude', 'hostess', 'waiter', 'bartend', 'waitress', 'help', 'polit', 'bar', 'courteou', 'member', 'waitstaff', 'attitud', 'reserv', 'tip'], ['atmospher', 'scene', 'place', 'tabl', 'outsid', 'area', 'ambianc', 'outdoor', 'romant', 'cozi', 'decor', 'sit', 'wall', 'light', 'window', 'area', 'ceil', 'floor']]


In [37]:
def setup_prior(fn, n_k=3):
    gamma = torch.zeros((len(vocab),n_k))
    gamma_bin = torch.zeros((1, len(vocab),n_k))

    full_vocab = read_file_seed_words(fn)
    for k in range(len(full_vocab)):
        for idx in range(len(full_vocab[k])):
            ivocab = vocab2id[full_vocab[k][idx]]
            gamma[ivocab, k] = 1.0
            gamma_bin[:, ivocab, :] = 1.0

    return (gamma, gamma_bin)

In [38]:
def listify(o):
    if o is None: return []
    if isinstance(o, list): return o
    if isinstance(o, str): return [o]
    if isinstance(o, Iterable): return list(o)
    return [o]
def setify(o): return o if isinstance(o,set) else set(listify(o))
def compose(x, funcs, *args, order_key='_order', **kwargs):
    key = lambda o: getattr(o, order_key, 0)
    for f in sorted(listify(funcs), key=key): x = f(x, **kwargs)
    return x

In [39]:
def print_perp(model):
    cost = []
    model.eval()                        # switch to testing mode
    for x_test, y_test in test_dl:
        recon, loss = model(x_test, compute_loss=True, avg_loss=False)
        loss = loss.data
        counts = x_test.sum(1)
        cost.extend((loss / counts).data.cpu().tolist())
    print('The approximated perplexity is: ', (np.exp(np.mean(np.array(cost)))))

def print_top_words(beta, feature_names, n_top_words=10):
    print ('---------------Printing the Topics------------------')
    for i in range(len(beta)):
        line = " ".join([feature_names[j] 
                         for j in beta[i].argsort()[:-n_top_words - 1:-1]])
        print('{}'.format(line))
    print ('---------------End of Topics------------------')
    
def print_gamma(gamma, seed_words, vocab, vocab2id):
    sws = []        
    for k in range(len(seed_words)):
        for idx in range(len(seed_words[k])):
            w = seed_words[k][idx]
            sws.append((k, w))

    for idx in range(len(sws)):
        k, w = sws[idx]
        ivocab = vocab2id[w]
        mk = gamma[ivocab].argmax(-1)
        print (ivocab, w, k, mk, gamma[ivocab])

## Data Utility Functions

In [40]:
def collate(b):
    x, y = zip(*b)
    return torch.stack(x), torch.stack(y)

class IdifyAndLimitedVocab():
    _order=-1
    def __init__(self, vocab2id, limited_vocab):
        self.vocab2id = vocab2id
        self.limited_vocab = limited_vocab
    def __call__(self, item):
        idlist = [self.vocab2id[w] for w in item if self.vocab2id[w] < self.limited_vocab]
        return np.array(idlist)
    

class Numpyify():
    _order=0
    def __call__(self, item):
        return np.array(item)

class Onehotify():
    _order=1
    def __init__(self, vocab_size):
        self.vocab_size = vocab_size
    def __call__(self, item):
        return np.array(np.bincount(item.astype('int'), minlength=self.vocab_size))
    
class YToOnehot():
    _order=1
    def __init__(self, num_classes):
        self.num_classes = num_classes
    def __call__(self, item):
        categorical = np.zeros((1, self.num_classes))
        categorical[0, item] = 1
        return categorical

class Tensorify():
    _order=2
    def __call__(self, item):
        return torch.from_numpy(item)

class Floatify():
    _order=3
    def __call__(self, item):
        return item.float()
    
class CheckAndCudify():
    _order=100
    def __init__(self):
        self.ic = torch.cuda.is_available()
    def __call__(self, item):
        return item.cuda() if self.ic else item
    
class URSADataset(Dataset):
    def __init__(self, x, y, tfms_x, tfms_y): 
        self.x, self.y = x, y
        self.x_tfms = tfms_x
        self.y_tfms = tfms_y
    def __len__(self): 
        return len(self.x)
    def __getitem__(self, i): 
        return compose(self.x[i], self.x_tfms), compose(self.y[i], self.y_tfms)
    
class Sampler():
    def __init__(self, ds, bs, shuffle=False):
        self.n,self.bs,self.shuffle = len(ds),bs,shuffle
        
    def __iter__(self):
        self.idxs = torch.randperm(self.n) if self.shuffle else torch.arange(self.n)
        for i in range(0, self.n, self.bs): yield self.idxs[i:i+self.bs]

class DataLoader():
    def __init__(self, ds, sampler, collate_fn=collate):
        self.ds,self.sampler,self.collate_fn = ds,sampler,collate_fn
        
    def __iter__(self):
        for s in self.sampler: yield self.collate_fn([self.ds[i] for i in s])

## Load Data

In [41]:
num_classes = np.max(train_y) + 1

In [42]:
num_classes

3

In [43]:
tfms_x = [Numpyify(), Onehotify(vocab_size=vocab_size), Tensorify(), Floatify(), CheckAndCudify()]
tfms_y = [YToOnehot(num_classes=num_classes), Tensorify(), Floatify(), CheckAndCudify()]

In [44]:
train_ds = URSADataset(train_x, train_y, tfms_x=tfms_x, tfms_y=tfms_y)
test_ds = URSADataset(test_x, test_y, tfms_x=tfms_x, tfms_y=tfms_y)

In [45]:
train_samp = Sampler(train_ds, bs, shuffle=False)
test_samp = Sampler(test_ds, bs, shuffle=False)

In [46]:
train_dl = DataLoader(train_ds, sampler=train_samp, collate_fn=collate)
test_dl = DataLoader(test_ds, sampler=test_samp, collate_fn=collate)

In [47]:
gamma_prior = setup_prior(seed_words_path, 3)

In [48]:
gamma, gamma_bin = gamma_prior

## Define Model

In [49]:
class ProdLDA(nn.Module):
    def __init__(self, num_input, en1_units, en2_units, num_topic, drop_rate, init_mult, gamma_prior):
        super(ProdLDA, self).__init__()
        self.num_input, self.en1_units, self.en2_units, \
        self.num_topic, self.drop_rate, self.init_mult = num_input, en1_units, en2_units, \
                                                            num_topic, drop_rate, init_mult
        # gamma prior
        self.gamma_prior = gamma_prior
        
        # encoder
        self.en1_fc = nn.Linear(num_input, en1_units)
        self.en1_ac = nn.Softplus()
        self.en2_fc     = nn.Linear(en1_units, en2_units)
        self.en2_ac = nn.Softplus()
        self.en2_dr   = nn.Dropout(drop_rate)
        
        # mean, logvar
        self.mean_fc = nn.Linear(en2_units, num_topic)
        self.mean_bn = nn.BatchNorm1d(num_topic)
        self.logvar_fc = nn.Linear(en2_units, num_topic)
        self.logvar_bn = nn.BatchNorm1d(num_topic)

        # decoder
        self.de_ac1 = nn.Softmax(dim=-1)
        self.de_dr = nn.Dropout(drop_rate)
        self.de_fc = nn.Linear(num_topic, num_input)
        self.de_bn = nn.BatchNorm1d(num_input)
        self.de_ac2 = nn.Softmax(dim=-1)
        
        # prior mean and variance as constant buffers
        self.prior_mean   = torch.Tensor(1, num_topic).fill_(0)
        self.prior_var    = torch.Tensor(1, num_topic).fill_(variance)
        self.prior_mean   = nn.Parameter(self.prior_mean, requires_grad=False)
        self.prior_var    = nn.Parameter(self.prior_var, requires_grad=False)
        self.prior_logvar = nn.Parameter(self.prior_var.log(), requires_grad=False)
        # initialize decoder weight
        if init_mult != 0:
            #std = 1. / math.sqrt( init_mult * (num_topic + num_input))
            self.de_fc.weight.data.uniform_(0, init_mult)
        # remove BN's scale parameters
        for component in [self.mean_bn, self.logvar_bn, self.de_bn]:
            component.weight.requires_grad = False
            component.weight.fill_(1.0)
        
    def gamma(self):
        # this function have to run after self.encode
        encoder_w1 = self.en1_fc.weight
        encoder_b1 = self.en1_fc.bias
        encoder_w2 = self.en2_fc.weight
        encoder_b2 = self.en2_fc.bias
        mean_w = self.mean_fc.weight
        mean_b = self.mean_fc.bias
        mean_running_mean = self.mean_bn.running_mean
        mean_running_var = self.mean_bn.running_var
        logvar_w = self.logvar_fc.weight
        logvar_b = self.logvar_fc.bias
        logvar_running_mean = self.logvar_bn.running_mean
        logvar_running_var = self.logvar_bn.running_var
        
        w1 = F.softplus(encoder_w1.t() + encoder_b1)
        w2 = F.softplus(F.linear(w1, encoder_w2, encoder_b2))
        wdr = F.dropout(w2, self.drop_rate)
        wo_mean = F.softmax(F.linear(wdr, mean_w, mean_b), dim=-1)
        wo_logvar = F.softmax(F.batch_norm(F.linear(wdr, logvar_w, logvar_b), logvar_running_mean, logvar_running_var), dim=-1)
        
        return wo_mean, wo_logvar
            
    def encode(self, input_):
        # encoder
        encoded1 = self.en1_fc(input_)
        encoded1_ac = self.en1_ac(encoded1)
        encoded2 = self.en2_fc(encoded1_ac)
        encoded2_ac = self.en2_ac(encoded2)
        encoded2_dr = self.en2_dr(encoded2_ac)
        
        encoded = encoded2_dr
        
        # hidden => mean, logvar
        mean_theta = self.mean_fc(encoded)
        mean_theta_bn = self.mean_bn(mean_theta)
        logvar_theta = self.logvar_fc(encoded)
        logvar_theta_bn = self.logvar_bn(logvar_theta)
        
        posterior_mean = mean_theta_bn
        posterior_logvar = logvar_theta_bn
        return encoded, posterior_mean, posterior_logvar
    
    def decode(self, input_, posterior_mean, posterior_var):
        # take sample
        eps = input_.data.new().resize_as_(posterior_mean.data).normal_() # noise 
        z = posterior_mean + posterior_var.sqrt() * eps                   # reparameterization
        # do reconstruction
        # decoder
        decoded1_ac = self.de_ac1(z)
        decoded1_dr = self.de_dr(decoded1_ac)
        decoded2 = self.de_fc(decoded1_dr)
        decoded2_bn = self.de_bn(decoded2)
        decoded2_ac = self.de_ac2(decoded2_bn)
        recon = decoded2_ac          # reconstructed distribution over vocabulary
        return recon
    
    def forward(self, input_, compute_loss=False, avg_loss=True):
        # compute posterior
        en2, posterior_mean, posterior_logvar = self.encode(input_) 
        posterior_var    = posterior_logvar.exp()
        
        recon = self.decode(input_, posterior_mean, posterior_var)
        if compute_loss:
            return recon, self.loss(input_, recon, posterior_mean, posterior_logvar, posterior_var, avg_loss)
        else:
            return recon

    def loss(self, input_, recon, posterior_mean, posterior_logvar, posterior_var, avg=True):
        # NL
        NL  = -(input_ * (recon + 1e-10).log()).sum(1)
        # KLD, see Section 3.3 of Akash Srivastava and Charles Sutton, 2017, 
        # https://arxiv.org/pdf/1703.01488.pdf
        prior_mean   = self.prior_mean.expand_as(posterior_mean)
        prior_var    = self.prior_var.expand_as(posterior_mean)
        prior_logvar = self.prior_logvar.expand_as(posterior_mean)
        var_division    = posterior_var  / prior_var
        diff            = posterior_mean - prior_mean
        diff_term       = diff * diff / prior_var
        logvar_division = prior_logvar - posterior_logvar
        # put KLD together
        KLD = 0.5 * ( (var_division + diff_term + logvar_division).sum(1) - self.num_topic)
        
        # gamma
        n, _ = input_.size()
        gamma_mean, gamma_logvar = self.gamma()
        gamma_prior, gammar_prior_bin = self.gamma_prior
        input_t = (input_ > 0).unsqueeze(dim=-1)
        input_bin = ((gammar_prior_bin.expand(n, -1, -1) == 1) & input_t)
        lambda_c = 20.0
        
        gamma_prior = gamma_prior.expand(n, -1, -1)      
        
        GL = lambda_c * ((gamma_prior - (input_bin.int()*gamma_mean))**2).sum((1, 2))
        
        # loss
        loss = (NL + KLD + GL)
        
        # in traiming mode, return averaged loss. In testing mode, return individual loss
        if avg:
            return loss.mean()
        else:
            return loss

## Train

In [50]:
from sklearn import metrics

In [51]:
def compute_accuracy(y_pred, y_true):
    accuracy = metrics.accuracy_score(y_true, y_pred)
    precision, recall, f1_score, _ = metrics.precision_recall_fscore_support(y_true=y_true, \
                                                     y_pred=y_pred, \
                                                     average=None)

    return (accuracy, precision, recall, f1_score)

In [52]:
model = ProdLDA(num_input, en1_units, en2_units, num_topic, drop_rate, init_mult, gamma_prior)
optimizer = torch.optim.Adam(model.parameters(), learning_rate, betas=(momentum, 0.999))

In [53]:
if torch.cuda.is_available():
    model = model.cuda()

In [54]:
for epoch in range(num_epoch):
    loss_epoch = 0.0
    model.train()                    # switch to training mode
    for input_, label_ in train_dl:
        recon, loss = model(input_, compute_loss=True)
        # optimize
        optimizer.zero_grad()        # clear previous gradients
        loss.backward()              # backprop
        optimizer.step()             # update parameters
        # report
        loss_epoch += loss.item()    # add loss to loss_epoch
    if epoch % 5 == 0:
        model.eval()
        # Test Model
        pred_train = []
        label_train = []
        pred_test = []
        label_test = []
        
        for x_train, y_train in train_dl:
            encoded, theta_mean, theta_logvar = model.encode(x_train)
            temp_theta_mean = theta_mean.argmax(-1).int().data.cpu().tolist()
            temp_y_train = y_train.argmax(-1).flatten().data.cpu().tolist()
            
            pred_train.extend(temp_theta_mean)
            label_train.extend(temp_y_train)
        
        accuracy_train, precision_train, recall_train, f1_score_train = compute_accuracy(pred_train, label_train)
        
        for x_test, y_test in test_dl:
            encoded, theta_mean, theta_logvar = model.encode(x_test)
            temp_theta_mean = theta_mean.argmax(-1).int().data.cpu().tolist()
            temp_y_test = y_test.argmax(-1).flatten().data.cpu().tolist()
            
            pred_test.extend(temp_theta_mean)
            label_test.extend(temp_y_test)
        
        accuracy_test, precision_test, recall_test, f1_score_test = compute_accuracy(pred_test, label_test)
        print ("##################################################")
        print('Epoch {}, loss={}, accuracy_train={}, accuracy_test={}'.format(epoch, loss_epoch / len(input_), accuracy_train, accuracy_test))
        for k in range(num_topic):
            print ("precision_train{}".format(k), "=" , "{:.9f}".format(precision_train[k]), \
                 "recall_train{}".format(k), "=" , "{:.9f}".format(recall_train[k]), \
                 "f1_score_train{}".format(k), "=" , "{:.9f}".format(f1_score_train[k]))
            print ("precision_te{}".format(k), "=" , "{:.9f}".format(precision_test[k]), \
                 "recall_te{}".format(k), "=" , "{:.9f}".format(recall_test[k]), \
                 "f1_score_te{}".format(k), "=" , "{:.9f}".format(f1_score_test[k]))
        emb = model.de_fc.weight.data.detach().cpu().numpy().T
        print_top_words(emb, vocab, 50)
#         print_perp(model)
        print ("##################################################")        

##################################################
Epoch 0, loss=1333.859552734375, accuracy_train=0.6132875536480686, accuracy_test=0.6184738955823293
precision_train0 = 0.725344481 recall_train0 = 0.451218338 f1_score_train0 = 0.556347227
precision_te0 = 0.761904762 recall_te0 = 0.453608247 f1_score_te0 = 0.568659128
precision_train1 = 0.496794452 recall_train1 = 0.828948805 f1_score_train1 = 0.621262323
precision_te1 = 0.494206257 recall_te1 = 0.841222880 f1_score_te1 = 0.622627737
precision_train2 = 0.756504065 recall_train2 = 0.582715792 f1_score_train2 = 0.658333825
precision_te2 = 0.759168704 recall_te2 = 0.586402266 f1_score_te2 = 0.661694193
---------------Printing the Topics------------------
balsam fratti spr avacado tendenc street wth tequila antithesi basqu thumb condit heat banana away candlelit white concept trifl unn boat outsid warmth glove color group meze beg snack bond sip class strip lob vile eaten know everi hop achil closest swell seafood obv jet section someon i

##################################################
Epoch 25, loss=1284.889060546875, accuracy_train=0.8563776824034335, accuracy_test=0.8683966635773864
precision_train0 = 0.911062454 recall_train0 = 0.841567948 f1_score_train0 = 0.874937419
precision_te0 = 0.929038282 recall_te0 = 0.854810997 f1_score_te0 = 0.890380313
precision_train1 = 0.806023589 recall_train1 = 0.835498308 f1_score_train1 = 0.820496328
precision_te1 = 0.822857143 recall_te1 = 0.852071006 f1_score_te1 = 0.837209302
precision_train2 = 0.851763299 recall_train2 = 0.892391191 f1_score_train2 = 0.871604057
precision_te2 = 0.853046595 recall_te2 = 0.898961284 f1_score_te2 = 0.875402299
---------------Printing the Topics------------------
seafood goat wrap pepper crab tasti calamari mozzarella dumpl mushroom curri sandwich perfectli mango cook starter creami vegetarian steam fri pad grill bland garlic good banana chorizo bass tempura salt sushi snapper oliv cream cake lobster milk cold salsa delici toast highlight delic 

##################################################
Epoch 50, loss=1277.0822890625, accuracy_train=0.8851158798283262, accuracy_test=0.8897126969416126
precision_train0 = 0.946783749 recall_train0 = 0.861889627 f1_score_train0 = 0.902344341
precision_te0 = 0.950897073 recall_te0 = 0.865120275 f1_score_te0 = 0.905982906
precision_train1 = 0.826668017 recall_train1 = 0.891278245 f1_score_train1 = 0.857758168
precision_te1 = 0.829690346 recall_te1 = 0.898422091 f1_score_te1 = 0.862689394
precision_train2 = 0.884544712 recall_train2 = 0.904394113 f1_score_train2 = 0.894359292
precision_te2 = 0.890740741 recall_te2 = 0.908404155 f1_score_te2 = 0.899485741
---------------Printing the Topics------------------
goat seafood mozzarella pepper calamari wrap mushroom crab curri mango dumpl creami sandwich perfectli tasti starter pad steam banana chorizo cream garlic milk vegetarian snapper bass tempura cook oliv sushi fieri lobster highlight juic vinegar toast delic pineappl butternut gra chunk plu

##################################################
Epoch 75, loss=1275.4836005859374, accuracy_train=0.8846008583690987, accuracy_test=0.8928019771393265
precision_train0 = 0.960719409 recall_train0 = 0.838582298 f1_score_train0 = 0.895505502
precision_te0 = 0.971400394 recall_te0 = 0.846219931 f1_score_te0 = 0.904499541
precision_train1 = 0.829369079 recall_train1 = 0.898264382 f1_score_train1 = 0.862443012
precision_te1 = 0.849673203 recall_te1 = 0.897435897 f1_score_te1 = 0.872901679
precision_train2 = 0.870611440 recall_train2 = 0.921406951 f1_score_train2 = 0.895289286
precision_te2 = 0.863715278 recall_te2 = 0.939565628 f1_score_te2 = 0.900045228
---------------Printing the Topics------------------
mozzarella pepper goat calamari seafood wrap crab mango creami starter mushroom curri steam banana pad dumpl chorizo perfectli milk snapper cream bass tasti tempura oliv sandwich fieri plum pineappl highlight vinegar delic gra toast lobster chunk butternut fri aromat garlic tartar sizz

##################################################
Epoch 100, loss=1275.21734765625, accuracy_train=0.879793991416309, accuracy_test=0.8894037689218413
precision_train0 = 0.958221024 recall_train0 = 0.821727824 f1_score_train0 = 0.884741017
precision_te0 = 0.971000000 recall_te0 = 0.834192440 f1_score_te0 = 0.897412200
precision_train1 = 0.826587858 recall_train1 = 0.900665866 f1_score_train1 = 0.862038343
precision_te1 = 0.850746269 recall_te1 = 0.899408284 f1_score_te1 = 0.874400767
precision_train2 = 0.863463229 recall_train2 = 0.922763803 f1_score_train2 = 0.892129162
precision_te2 = 0.854935622 recall_te2 = 0.940509915 f1_score_te2 = 0.895683453
---------------Printing the Topics------------------
mozzarella calamari starter pepper mango crab creami curri wrap seafood chorizo steam banana pad goat dumpl milk snapper mushroom tempura fieri cream perfectli pineappl bass butternut aromat oliv urchin highlight plum vinegar sesam delic chunk sizzl gra tartar boil flaki garlicki carpacc

##################################################
Epoch 125, loss=1275.0457958984375, accuracy_train=0.8741974248927039, accuracy_test=0.8856966326845845
precision_train0 = 0.956310130 recall_train0 = 0.813733988 f1_score_train0 = 0.879279842
precision_te0 = 0.972672065 recall_te0 = 0.825601375 f1_score_te0 = 0.893122677
precision_train1 = 0.825704759 recall_train1 = 0.892042353 f1_score_train1 = 0.857592612
precision_te1 = 0.847441860 recall_te1 = 0.898422091 f1_score_te1 = 0.872187650
precision_train2 = 0.850572501 recall_train2 = 0.922659430 f1_score_train2 = 0.885150696
precision_te2 = 0.847529813 recall_te2 = 0.939565628 f1_score_te2 = 0.891177788
---------------Printing the Topics------------------
mozzarella mango starter pepper calamari creami chorizo curri crab snapper steam wrap banana pad seafood milk dumpl fieri butternut aromat goat tempura pineappl plum urchin cream mushroom bass flaki sesam vinegar sizzl garlicki boil highlight gra avacado oliv pistachio tartar rosemari

##################################################
Epoch 150, loss=1274.9696484375, accuracy_train=0.8697682403433477, accuracy_test=0.8792091442693852
precision_train0 = 0.958764071 recall_train0 = 0.803910238 f1_score_train0 = 0.874535073
precision_te0 = 0.973251029 recall_te0 = 0.812714777 f1_score_te0 = 0.885767790
precision_train1 = 0.816334661 recall_train1 = 0.894662155 f1_score_train1 = 0.853705536
precision_te1 = 0.839037928 recall_te1 = 0.894477318 f1_score_te1 = 0.865871122
precision_train2 = 0.846806051 recall_train2 = 0.917336395 f1_score_train2 = 0.880661323
precision_te2 = 0.838682432 recall_te2 = 0.937677054 f1_score_te2 = 0.885421311
---------------Printing the Topics------------------
mozzarella mango starter chorizo calamari pepper snapper steam curri fieri creami butternut pad crab milk aromat banana dumpl wrap urchin pineappl boil flaki plum sizzl seafood garlicki cream pistachio vinegar gra tempura bass avacado rosemari goat carpaccio ham mushroom highlight chunk 

##################################################
Epoch 175, loss=1274.97136328125, accuracy_train=0.8679828326180258, accuracy_test=0.8804448563484708
precision_train0 = 0.956591640 recall_train0 = 0.802272946 f1_score_train0 = 0.872662511
precision_te0 = 0.974226804 recall_te0 = 0.811855670 f1_score_te0 = 0.885660731
precision_train1 = 0.816493610 recall_train1 = 0.892697304 f1_score_train1 = 0.852896699
precision_te1 = 0.842592593 recall_te1 = 0.897435897 f1_score_te1 = 0.869149952
precision_train2 = 0.843380444 recall_train2 = 0.915562050 f1_score_train2 = 0.877990191
precision_te2 = 0.838247683 recall_te2 = 0.939565628 f1_score_te2 = 0.886019590
---------------Printing the Topics------------------
mozzarella mango starter fieri chorizo curri butternut snapper calamari pepper aromat steam milk crab banana urchin flaki wrap creami pad plum sizzl pineappl dumpl garlicki gra pistachio avacado vinegar boil goat seafood rosemari bass carpaccio cream mushroom gratin ham tempura gazpacho

## Test

In [55]:
model.eval()
gamma_mean, gamma_logvar = model.gamma()
gm, gl = gamma_mean.data.cpu().numpy(), gamma_logvar.data.cpu().numpy()
print_gamma(gm, seed_words, vocab, vocab2id)

5998 food 0 0 [9.9996567e-01 3.3058648e-05 1.2778987e-06]
13729 sauc 0 0 [1.0000000e+00 1.4041303e-08 6.2098975e-09]
2883 chicken 0 0 [1.0000000e+00 7.3674580e-14 4.3292796e-14]
14277 shrimp 0 0 [1.0000000e+00 1.1627485e-09 9.4916874e-10]
2826 chees 0 0 [1.0000000e+00 7.8425894e-10 4.3580079e-12]
12158 potato 0 0 [1.0000000e+00 6.3853653e-13 1.0989002e-13]
6169 fri 0 0 [9.9999905e-01 2.7589283e-07 6.9605949e-07]
16268 tomato 0 0 [1.000000e+00 9.556484e-10 2.038501e-09]
13325 roast 0 0 [1.0000000e+00 1.2562194e-09 3.1371628e-10]
10925 onion 0 0 [1.0000000e+00 3.1480988e-12 2.1347347e-12]
12107 pork 0 0 [9.9999440e-01 5.0415756e-06 6.0911532e-07]
6656 goat 0 0 [1.0000000e+00 5.3420615e-11 2.0737238e-10]
6882 grill 0 0 [1.0000000e+00 1.8845795e-12 3.9701306e-12]
16543 tuna 0 0 [1.0000000e+00 2.0551364e-12 7.7016041e-13]
13592 salad 0 0 [1.000000e+00 4.192013e-09 8.473717e-10]
1411 beef 0 0 [1.000000e+00 6.034032e-10 1.889462e-10]
15764 tapa 0 0 [9.9999976e-01 1.0917542e-07 1.3662320e-07]


In [56]:
emb = model.de_fc.weight.data.cpu().numpy().T
print_top_words(emb, vocab, 50)
print_perp(model)

---------------Printing the Topics------------------
mozzarella mango chorizo fieri aromat calamari butternut curri starter snapper flaki steam milk pepper urchin banana wrap avacado plum crab pad pistachio creami pineappl gra sizzl vinegar gratin dumpl goat gazpacho carpaccio spaghetti gai massaman boil cream seafood chunk mushroom mahi garlicki flav monkfish peppercorn rosemari gorgonzola scrambl highlight tempura
adult frustrat apologet friendlier disinterest credit storm appolog clearer flag blame smirk unapologet bill repli ct meanwhil automat rectifi patient unforget inqu consult unattent children comp fume displeasur charg respond preoccupi question misogyni unheard waitor bordeaux shoddi umpteen steward palaetswer lousi goodby substitu rudest didnt taff unprofession deni cash obnoxi
glow column brightli softli antiqu nobl furnitur lit rockwel ceil fixtur wall rail courtyard hip sofa lair arti darker lobbi centuri cosi cer dark linen flicker booth vault spotless artifact country