In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

In [3]:
import numpy as np
import math
import pickle
import argparse
import os
import math
import matplotlib.pyplot as plt

In [4]:
from pathlib import Path
from collections import OrderedDict
import pickle, gzip, math, torch, matplotlib as mpl
import matplotlib.pyplot as plt

Path.ls = lambda x: list(x.iterdir())

## Constants

In [5]:
bs = 200
en1_units=100
en2_units=100
num_topic=50
num_input=1995
variance=0.995
init_mult=1.0
learning_rate=0.002
batch_size=200
momentum=0.99
num_epoch=100
nogpu=True
drop_rate=0.2

## Data Utility Functions

In [6]:
def listify(o):
    if o is None: return []
    if isinstance(o, list): return o
    if isinstance(o, str): return [o]
    if isinstance(o, Iterable): return list(o)
    return [o]
def setify(o): return o if isinstance(o,set) else set(listify(o))
def compose(x, funcs, *args, order_key='_order', **kwargs):
    key = lambda o: getattr(o, order_key, 0)
    for f in sorted(listify(funcs), key=key): x = f(x, **kwargs)
    return x

def collate(b):
    return torch.stack(b)

class Onehotify():
    _order=1
    def __init__(self, vocab_size):
        self.vocab_size = vocab_size
    def __call__(self, item):
        return np.array(np.bincount(item.astype('int'), minlength=self.vocab_size))

class Tensorify():
    _order=2
    def __call__(self, item):
        return torch.from_numpy(item)

class Floatify():
    _order=3
    def __call__(self, item):
        return item.float()
    
class News20Dataset(Dataset):
    def __init__(self, x, tfms): 
        self.x = x
        self.x_tfms = tfms
    def __len__(self): 
        return len(self.x)
    def __getitem__(self, i): 
        return compose(self.x[i], self.x_tfms)
    
class Sampler():
    def __init__(self, ds, bs, shuffle=False):
        self.n,self.bs,self.shuffle = len(ds),bs,shuffle
        
    def __iter__(self):
        self.idxs = torch.randperm(self.n) if self.shuffle else torch.arange(self.n)
        for i in range(0, self.n, self.bs): yield self.idxs[i:i+self.bs]

class DataLoader():
    def __init__(self, ds, sampler, collate_fn=collate):
        self.ds,self.sampler,self.collate_fn = ds,sampler,collate_fn
        
    def __iter__(self):
        for s in self.sampler: yield self.collate_fn([self.ds[i] for i in s])

## Topic Model Utility Functions

In [7]:
def print_perp(model):
    cost=[]
    model.eval()                        # switch to testing mode
    input_ = tensor_te
    recon, loss = model(input_, compute_loss=True, avg_loss=False)
    loss = loss.data
    counts = tensor_te.sum(1)
    avg = (loss / counts).mean()
    print('The approximated perplexity is: ', math.exp(avg))

def visualize():
    global recon
    input_ = tensor_te[:10]
    register_vis_hooks(model)
    recon = model(input_, compute_loss=False)
    remove_vis_hooks()
    save_visualization('pytorch_model', 'png')
    
def print_perp(model):
    cost=[]
    model.eval()                        # switch to testing mode
    input_ = next(iter(test_dl))
    recon, loss = model(input_, compute_loss=True, avg_loss=False)
    loss = loss.data
    counts = input_.sum(1)
    avg = (loss / counts).mean()
    print('The approximated perplexity is: ', math.exp(avg))

def visualize():
    global recon
    input_ = tensor_te[:10]
    register_vis_hooks(model)
    recon = model(input_, compute_loss=False)
    remove_vis_hooks()
    save_visualization('pytorch_model', 'png')

def print_top_words(beta, feature_names, n_top_words=10):
    print ('---------------Printing the Topics------------------')
    for i in range(len(beta)):
        line = " ".join([feature_names[j] 
                         for j in beta[i].argsort()[:-n_top_words - 1:-1]])
        print('{}'.format(line))
    print ('---------------End of Topics------------------')

## Load Data

In [8]:
path = Path('data/20news_clean/')

In [9]:
path.ls()

[PosixPath('data/20news_clean/test.txt.npy'),
 PosixPath('data/20news_clean/train.txt.npy'),
 PosixPath('data/20news_clean/valid.txt.npy'),
 PosixPath('data/20news_clean/vocab.pkl')]

In [10]:
path_train = path/'train.txt.npy'
path_test = path/'test.txt.npy'
path_vocab = path/'vocab.pkl'

In [11]:
data_tr = np.load(path_train, encoding="latin1")
data_te = np.load(path_test, encoding="latin1")
vocab_file = open(path_vocab,'rb')
vocab = pickle.load(vocab_file)
vocab_size=len(vocab)

In [12]:
data_tr = np.array([doc for doc in data_tr if np.sum(doc)!=0])
data_te = np.array([doc for doc in data_te if np.sum(doc)!=0])

In [13]:
tfms = [Onehotify(vocab_size=vocab_size), Tensorify(), Floatify()]

In [14]:
train_ds = News20Dataset(data_tr, tfms=tfms)
test_ds = News20Dataset(data_te, tfms=tfms)

In [15]:
train_samp = Sampler(train_ds, bs, shuffle=True)
test_samp = Sampler(test_ds, bs, shuffle=False)

In [16]:
train_dl = DataLoader(train_ds, sampler=train_samp, collate_fn=collate)
test_dl = DataLoader(test_ds, sampler=test_samp, collate_fn=collate)

## Define Model

In [17]:
class ProdLDA(nn.Module):
    def __init__(self, num_input, en1_units, en2_units, num_topic, drop_rate, init_mult):
        super(ProdLDA, self).__init__()
        self.num_input, self.en1_units, self.en2_units, \
        self.num_topic, self.drop_rate, self.init_mult = num_input, en1_units, en2_units, \
                                                            num_topic, drop_rate, init_mult
        # encoder
        self.en = nn.Sequential(OrderedDict([
            ('linear1', nn.Linear(num_input, en1_units)),
            ('act1', nn.Softplus()),
            ('linear2', nn.Linear(en1_units, en2_units)),
            ('act2', nn.Softplus()),
            ('dropout', nn.Dropout(drop_rate))
        ]))
        self.mean = nn.Sequential(OrderedDict([
            ('linear', nn.Linear(en2_units, num_topic)),
            ('batchnorm', nn.BatchNorm1d(num_topic))
        ]))
        self.logvar = nn.Sequential(OrderedDict([
            ('linear', nn.Linear(en2_units, num_topic)),
            ('batchnorm', nn.BatchNorm1d(num_topic))
        ]))
        # decoder
        self.de = nn.Sequential(OrderedDict([
            ('act1', nn.Softmax(dim=-1)),
            ('dropout', nn.Dropout(drop_rate)),
            ('linear', nn.Linear(num_topic, num_input)),
            ('batchnorm', nn.BatchNorm1d(num_input)),
            ('act2', nn.Softmax(dim=-1))
        ]))
        # prior mean and variance as constant buffers
        self.prior_mean   = torch.Tensor(1, num_topic).fill_(0)
        self.prior_var    = torch.Tensor(1, num_topic).fill_(variance)
        self.prior_mean = nn.Parameter(self.prior_mean, requires_grad=False)
        self.prior_var = nn.Parameter(self.prior_var, requires_grad=False)
        self.prior_logvar = nn.Parameter(self.prior_var.log(), requires_grad=False)
        # initialize decoder weight
        if init_mult != 0:
            #std = 1. / math.sqrt( init_mult * (num_topic + num_input))
            self.de.linear.weight.data.uniform_(0, init_mult)
        # remove BN's scale parameters
        for component in [self.mean, self.logvar, self.de]:
            component.batchnorm.weight.requires_grad = False
            component.batchnorm.weight.fill_(1.0)

    def encode(self, input_):
        encoded = self.en(input_)
        posterior_mean = self.mean(encoded)
        posterior_logvar = self.logvar(encoded)
        return encoded, posterior_mean, posterior_logvar
    
    def decode(self, input_, posterior_mean, posterior_var):
        # take sample
        eps = input_.data.new().resize_as_(posterior_mean.data).normal_() # noise 
        z = posterior_mean + posterior_var.sqrt() * eps                   # reparameterization
        # do reconstruction
        recon = self.de(z)          # reconstructed distribution over vocabulary
        return recon
    
    def forward(self, input_, compute_loss=False, avg_loss=True):
        # compute posterior
        en2, posterior_mean, posterior_logvar = self.encode(input_) 
        posterior_var    = posterior_logvar.exp()
        
        recon = self.decode(input_, posterior_mean, posterior_var)
        if compute_loss:
            return recon, self.loss(input_, recon, posterior_mean, posterior_logvar, posterior_var, avg_loss)
        else:
            return recon

    def loss(self, input_, recon, posterior_mean, posterior_logvar, posterior_var, avg=True):
        # NL
        NL  = -(input_ * (recon + 1e-10).log()).sum(1)
        # KLD, see Section 3.3 of Akash Srivastava and Charles Sutton, 2017, 
        # https://arxiv.org/pdf/1703.01488.pdf
        prior_mean   = self.prior_mean.expand_as(posterior_mean)
        prior_var    = self.prior_var.expand_as(posterior_mean)
        prior_logvar = self.prior_logvar.expand_as(posterior_mean)
        var_division    = posterior_var  / prior_var
        diff            = posterior_mean - prior_mean
        diff_term       = diff * diff / prior_var
        logvar_division = prior_logvar - posterior_logvar
        # put KLD together
        KLD = 0.5 * ( (var_division + diff_term + logvar_division).sum(1) - self.num_topic)
        # loss
        loss = (NL + KLD)
        # in traiming mode, return averaged loss. In testing mode, return individual loss
        if avg:
            return loss.mean()
        else:
            return loss

## Train

In [18]:
model = ProdLDA(num_input, en1_units, en2_units, num_topic, drop_rate, init_mult)

In [19]:
optimizer = torch.optim.Adam(model.parameters(), learning_rate, betas=(momentum, 0.999))

In [20]:
for epoch in range(num_epoch):
    loss_epoch = 0.0
    model.train()                    # switch to training mode
    for input_ in train_dl:
        recon, loss = model(input_, compute_loss=True)
        # optimize
        optimizer.zero_grad()        # clear previous gradients
        loss.backward()              # backprop
        optimizer.step()             # update parameters
        # report
        loss_epoch += loss.item()    # add loss to loss_epoch
    if epoch % 5 == 0:
        print('Epoch {}, loss={}'.format(epoch, loss_epoch / len(input_)))

Epoch 0, loss=746.7075184789197
Epoch 5, loss=665.4300326643319
Epoch 10, loss=656.3798186203529
Epoch 15, loss=640.3864467226226
Epoch 20, loss=635.5135119207974
Epoch 25, loss=627.7520083723397
Epoch 30, loss=619.6737539357152
Epoch 35, loss=615.8232500799771
Epoch 40, loss=613.9448368467133
Epoch 45, loss=610.6923949142982
Epoch 50, loss=608.6022507240033
Epoch 55, loss=606.7663190118197
Epoch 60, loss=604.5161595837824
Epoch 65, loss=604.0063718598465
Epoch 70, loss=602.491992292733
Epoch 75, loss=605.2028750715584
Epoch 80, loss=599.282444394868
Epoch 85, loss=599.0273027091191
Epoch 90, loss=598.9137646905307
Epoch 95, loss=598.7541203992121


## Test

In [21]:
emb = model.de.linear.weight.data.cpu().numpy().T
print_top_words(emb, list(zip(*sorted(vocab.items(), key=lambda x:x[1])))[0])
print_perp(model)
# visualize()

---------------Printing the Topics------------------
offense hitter pitcher team player defensive season roger career braves
launch rocket spacecraft satellite nuclear fund moon lunar orbit administration
thanks advance hus honda appreciate surrender _eos_ boot jeff wonder
microsoft printer modem hello thanks advance appreciate postscript fax greatly
bio mb jumper connector controller rom hd drive floppy interface
greece jews turkish turks jew armenian constitution palestinian professor handgun
anonymous abuse privacy threat security electronic social rights responsibility militia
armenian neighbor apartment beat floor woman armenians doctor stephanopoulos azerbaijan
annual bmw rider club ride shipping cd green motorcycle organize
st pp calgary philadelphia rangers detroit louis winnipeg pittsburgh jose
surrender gordon bank keith ignorance kent thanks uucp associate georgia
existence doctrine faith absolute truth belief revelation biblical interpretation conclusion
flyers puck neighbo