In [1]:
# refs:
# https://arxiv.org/pdf/1511.06038.pdf
# https://github.com/ysmiao/nvdm/blob/master/nvdm.py
# https://github.com/smutahoang/ntm/blob/main/models/nvdm.py
# data: 
# https://github.com/ysmiao/nvdm/tree/master/data/20news

import os
import argparse
import sys
from datetime import datetime

import random
import math
import numpy as np
import torch
import torch.nn as nn

# from graphviz import Digraph
from torchviz import make_dot

# https://github.com/mblondel/fenchel-young-losses/tree/master
# from fyl_pytorch import SparsemaxLoss # takes label proportions
# from torch.autograd import Function

# !pip install entmax
from entmax.activations import sparsemax
from entmax.root_finding import entmax_bisect, sparsemax_bisect

from losses import entmax_loss
import torch
import torch.nn.functional as F

from entmax.activations import sparsemax, entmax15
from fyl_pytorch import SparsemaxLoss # takes label proportions
from losses import softmax_loss, entmax_loss

In [2]:
#-------------------------------
def data_set(data_url):
    """process data input."""
    data = []
    word_count = []
    fin = open(data_url)
    while True:
        line = fin.readline()
        if not line:
            break
        id_freqs = line.split()
        doc = {}
        count = 0
        for id_freq in id_freqs[1:]:
            items = id_freq.split(':')
            # python starts from 0
            doc[int(items[0])-1] = int(items[1])
            count += int(items[1])
        if count > 0:
            data.append(doc)
            word_count.append(count)
    fin.close()
    return data, word_count


def create_batches(data_size, batch_size, shuffle=True):
    """create index by batches."""
    batches = []
    ids = list(range(data_size))
    if shuffle:
        random.shuffle(ids)
    for i in list(range(data_size // batch_size)):
        start = i * batch_size
        end = (i + 1) * batch_size
        batches.append(ids[start:end])
  # the batch of which the length is less than batch_size
    rest = data_size % batch_size
    if rest > 0:
        batches.append(ids[-rest:] + [-1] * (batch_size - rest))  # -1 as padding
    return batches


def fetch_data(data, count, idx_batch, vocab_size):
    """fetch input data by batch."""
    batch_size = len(idx_batch)
    data_batch = np.zeros((batch_size, vocab_size))
    count_batch = []
    mask = np.zeros(batch_size)
    for i, doc_id in enumerate(idx_batch):
        if doc_id != -1:
            for word_id, freq in data[doc_id].items():
                data_batch[i, word_id] = freq
            count_batch.append(count[doc_id])
            mask[i]=1.0
        else:
            count_batch.append(0)
    return data_batch, count_batch, mask

In [3]:
#-------------------------------
class Encoder(nn.Module):
    def __init__(self, args):
        super(Encoder, self).__init__()
        self.args = args
        self.mlp = nn.Sequential(nn.Linear(args['n_input'], args['n_hidden']),
                                 nn.Tanh(),
                                 nn.Linear(args['n_hidden'], args['n_hidden']),
                                 nn.Tanh())
        # self.mlp_drop = nn.Dropout(0.8)
        self.mean_fc = nn.Linear(args['n_hidden'], args['n_topics'])
        self.logsigm_fc = nn.Linear(args['n_hidden'], args['n_topics'])
        # nn.init.normal_(self.mean_fc.weight)
        # nn.init.normal_(self.mean_fc.bias)
        # nn.init.zeros_(self.logsigm_fc.weight)  # cf. https://github.com/ysmiao/nvdm/blob/master/nvdm.py#L51
        # nn.init.zeros_(self.logsigm_fc.bias)

    def forward(self, doc_freq_vecs, train=True):

        en_vec = self.mlp(doc_freq_vecs)
        # print(en_vec)
        # if train:
        #     en_vec = self.mlp_drop(en_vec)

        mean = self.mean_fc(en_vec)
        logsigm = self.logsigm_fc(en_vec)  # 1/2 log sigma^2
        # print(mean.sum(), logsigm.sum())
        kld = -0.5 * torch.sum(1 - torch.square(mean) + 2 * logsigm - (2 * logsigm).exp(), 1)

        return kld, mean, logsigm
        
class Decoder(nn.Module):
    def __init__(self, args):
        super(Decoder, self).__init__()
        self.args = args
        self.decoder = nn.Linear(args['n_topics'], args['n_input'])  # 

    def forward(self, mean, logsigm, train=True):
        if train:
            eps = torch.randn(self.args['batch_size'], self.args['n_topics'])  # eps ~ N(0,1)
            z = mean + logsigm.exp() * eps  # reparam
            # print(z.shape)
            recon = self.decoder(z)
            return recon, self.decoder 

        else:
            recon = self.decoder(mean) # sample? 
            return recon


#-------------------------------
def make_optimizer(encoder, decoder, args):
    if args['optimizer'] == 'Adam':
        optimizer_enc = torch.optim.Adam(encoder.parameters(), 
        args['learning_rate'], betas=(args['momentum'], 0.999))
        optimizer_dec = torch.optim.Adam(decoder.parameters(),  
        args['learning_rate'], betas=(args['momentum'], 0.999))

    elif args['optimizer'] == 'SGD':
        optimizer_enc = torch.optim.SGD(encoder.parameters(), 
        args['learning_rate'], momentum=args['momentum'])
        optimizer_dec = torch.optim.SGD(decoder.parameters(), 
        args['learning_rate'], momentum=args['momentum'])

    return optimizer_enc, optimizer_dec

In [4]:
#-------------------------------
def train(args, data_dir, test=True, save=False, output_path=None):
    torch.manual_seed(args['seed'])

    timestamp = datetime.now().strftime('%Y_%m_%d_%H.%M') #%S
    # writer = SummaryWriter('runs/'.format(timestamp))
    best_vloss = 1_000_000.

    train_url = os.path.join(data_dir, 'train.feat')
    train_set, train_count = data_set(train_url)

    # test, dev batches
    test_url = os.path.join(data_dir, 'test.feat')
    test_set, test_count = data_set(test_url)
    dev_set = test_set[:50]
    dev_count = test_count[:50]
    test_set = test_set[50:]
    test_count[50:]
    test_batches = create_batches(len(test_set), 64, shuffle=False)
    dev_batches = create_batches(len(dev_set), 50, shuffle=False)
    
    # model
    encoder = Encoder(args)
    encoder.to(args['device'])
    decoder = Decoder(args)
    decoder.to(args['device'])

    optimizer_enc, optimizer_dec = make_optimizer(encoder, decoder, args)
    #-------------------------------
    # train    
    for epoch in range(args['n_epoch']):
        train_batches = create_batches(len(train_set), 64, shuffle=True)               
        for switch in list(range(0, 2)):
            if switch == 0:
                optimizer = optimizer_dec
                decoder.train()
                print_mode = 'updating decoder'
            else:
                optimizer = optimizer_enc
                encoder.train()
                print_mode = 'updating encoder'
            for i in list(range(args['n_alternating_epoch'])):
                loss_sum = 0.0
                ppx_sum = 0.0
                kld_sum = 0.0
                word_count = 0
                doc_count = 0
                error_sum = 0.0              
                mean_sum = 0.0
                logsigm_sum = 0.0

                for idx_batch in train_batches[:-1]:
                    data_batch, count_batch, mask = fetch_data(train_set, train_count, idx_batch, 2000)
                    batch_N = torch.tensor(data_batch.sum(-1), dtype=torch.float)
                    data_batch = torch.tensor((data_batch.T/batch_N), dtype=torch.float).T
                    
                    kld, mean, logsigm = encoder(data_batch)
                    recon, dec = decoder(mean, logsigm)

                    # fyl = SparsemaxLoss()
                    fy_loss = entmax_loss(recon, data_batch, alpha=1.5)*batch_N
                    loss = fy_loss + args['c']*kld #1e-5*
                    #nll + args.c *kld #1e-4*


                    error_sum += torch.abs(data_batch-entmax15(recon, dim=-1)).sum()
                    mean_sum += mean.mean()
                    logsigm_sum += logsigm.mean()
                
                    optimizer.zero_grad()      
                    loss.mean().backward()  
                    optimizer.step()  
                    
                    loss_sum += loss.sum().item() 
                    kld_sum += (kld.sum() / mask.sum()).item()       
                    word_count += np.sum(count_batch)
                    # per document loss
                    count_batch = np.add(count_batch, 1e-12)
                    ppx_sum += (loss.detach()/count_batch).sum().item()
                    doc_count += mask.sum().item()

                print_kld = kld_sum/len(train_batches[:-1])
                print_ppx = np.exp(loss_sum / word_count)
                print_ppx_perdoc = np.exp(ppx_sum / doc_count)

                print('post mean', (mean_sum/len(train_batches[:-1])).item(), 
                'logsigm', (logsigm_sum/len(train_batches[:-1])).item())                
                print('| Epoch train: {:d} |'.format(epoch+1), print_mode, '{:d}'.format(i),
                      '| Corpus ppx: {:.5f}'.format(print_ppx),  # perplexity for all docs
                      '| KLD: {:.5}'.format(print_kld))        
                av_error_train = error_sum.item()/2000
                print(av_error_train)
        #-------------------------------
        # dev
        encoder.eval()
        decoder.eval()
        with torch.no_grad():
            loss_sum = 0.0
            ppx_sum = 0.0
            kld_sum = 0.0
            word_count = 0
            doc_count = 0
            error_sum = 0.0                   
            mean_sum = 0.0
            logsigm_sum = 0.0
    
            for idx_batch in dev_batches:
                data_batch, count_batch, mask = fetch_data(dev_set, dev_count, idx_batch, 2000)
                batch_N = torch.tensor(data_batch.sum(-1), dtype=torch.float)
                # print(count_batch)
                # print(batch_N)

                data_batch = torch.tensor((data_batch.T/batch_N), dtype=torch.float).T
                        
                kld, mean, logsigm = encoder(data_batch)
                recon = decoder(mean, logsigm, False)
                
                # fyl = SparsemaxLoss()
                fy_loss = entmax_loss(recon, data_batch, alpha=1.5)*batch_N
                loss = fy_loss + args['c']*kld #1e-5*

                error_sum += torch.abs(data_batch-entmax15(recon, dim=-1)).sum()                        
                             
                loss_sum += loss.sum().item() 
                kld_sum += (kld.sum() / mask.sum()).item()

                word_count += np.sum(count_batch)       
                count_batch = np.add(count_batch, 1e-12)
                ppx_sum += (loss/count_batch).sum().item()
                doc_count += mask.sum().item()

            print_kld = kld_sum/len(dev_batches)
            print_vppx = np.exp(loss_sum / word_count)
            print_ppx_perdoc = np.exp(ppx_sum / doc_count)

            print('dev post mean', (mean_sum/len(dev_batches)), 
            'logsigm', (logsigm_sum/len(dev_batches)))            
            print('| Epoch eval: {:d}'.format(epoch+1), 
                  '| Corpus ppx: {:.5f}'.format(print_vppx),  # perplexity for all docs
                  '| KLD: {:.5}'.format(print_kld))        
            av_error = error_sum.item()/2000
            print(av_error)
            
            #-------------------------------
            # test
            if test: 
                loss_sum = 0.0
                ppx_sum = 0.0
                kld_sum = 0.0
                word_count = 0
                doc_count = 0
                error_sum = 0.0                       
                mean_sum = 0.0
                logsigm_sum = 0.0
        
                for idx_batch in test_batches[:-1]:
                    data_batch, count_batch, mask = fetch_data(test_set, test_count, idx_batch, 2000)
                    batch_N = torch.tensor(data_batch.sum(-1), dtype=torch.float)

                    data_batch = torch.tensor((data_batch.T/batch_N), dtype=torch.float).T
                            
                    kld, mean, logsigm = encoder(data_batch)
                    recon = decoder(mean, logsigm, False)
                    
                    # fyl = SparsemaxLoss()
                    fy_loss = entmax_loss(recon, data_batch, alpha=1.5)*batch_N
                    loss = fy_loss +  args['c'] *kld #1e-5*
    
                    error_sum += torch.abs(data_batch-entmax15(recon, dim=-1)).sum()               

                    loss_sum += loss.sum().item() 
                    kld_sum += (kld.sum() / mask.sum()).item()
                                
                    word_count += np.sum(count_batch)       
                    # per document loss
                    count_batch = np.add(count_batch, 1e-12)
                    ppx_sum += (loss/count_batch).sum().item()
                    doc_count += mask.sum().item()

                print_kld = kld_sum/len(test_batches[:-1])
                print_ppx = np.exp(loss_sum / word_count)
                print_ppx_perdoc = np.exp(ppx_sum / doc_count) # very different with and w/out the dev data

                print('post mean', (mean_sum/len(test_batches[:-1])), 
                'logsigm', (logsigm_sum/len(test_batches[:-1])))                
                print('| Epoch test: {:d}'.format(epoch+1), 
                        '| Corpus ppx: {:.5f}'.format(print_ppx),  # perplexity for all docs
                        '| KLD: {:.5}'.format(print_kld))                        
                av_error_test = error_sum.item()/2000
                print(av_error_test)

    return dec #, encoder, decoder

In [5]:
def default_params(n_epoch=200, batch_size=64, gpu_index=0):
    params = {'batch_size': batch_size,
              'optimizer': 'Adam',
              'learning_rate': 5e-4, #0.002,
              'momentum': 0.99,
              'n_epoch': n_epoch,
              'n_alternating_epoch': 10,
              'init_mult': 0, # 0.001,
              'device': torch.device('cpu'), #torch.device('cuda:{}'.format(gpu_index))}
              'n_input': 2000, 
              'n_topics': 50, 
              'n_hidden': 500, 
             'c': 0.01, 
             'seed': 42, 
             'save_path': ''} 
    return params

data_dir='./data/20news/'

decoder = train(default_params(), data_dir, test=True)

  data_batch = torch.tensor((data_batch.T/batch_N), dtype=torch.float).T


post mean -0.001577716669999063 logsigm -0.0005005676648579538
| Epoch train: 1 | updating decoder 0 | Corpus ppx: 3.66458 | KLD: 0.064708
11.1053388671875
post mean -0.001577716669999063 logsigm -0.0005005676648579538
| Epoch train: 1 | updating decoder 1 | Corpus ppx: 2.79444 | KLD: 0.064708
11.03869140625
post mean -0.001577716669999063 logsigm -0.0005005676648579538
| Epoch train: 1 | updating decoder 2 | Corpus ppx: 2.22783 | KLD: 0.064708
10.962232421875
post mean -0.001577716669999063 logsigm -0.0005005676648579538
| Epoch train: 1 | updating decoder 3 | Corpus ppx: 1.80358 | KLD: 0.064708
10.872931640625
post mean -0.001577716669999063 logsigm -0.0005005676648579538
| Epoch train: 1 | updating decoder 4 | Corpus ppx: 1.49254 | KLD: 0.064708
10.6717333984375
post mean -0.001577716669999063 logsigm -0.0005005676648579538
| Epoch train: 1 | updating decoder 5 | Corpus ppx: 1.31283 | KLD: 0.064708
10.2959150390625
post mean -0.001577716669999063 logsigm -0.0005005676648579538
| Epo

  data_batch = torch.tensor((data_batch.T/batch_N), dtype=torch.float).T
  data_batch = torch.tensor((data_batch.T/batch_N), dtype=torch.float).T


post mean 0.0 logsigm 0.0
| Epoch test: 1 | Corpus ppx: 1.12924 | KLD: 29.623
6.6737294921875
post mean 0.004440819378942251 logsigm -0.6943723559379578
| Epoch train: 2 | updating decoder 0 | Corpus ppx: 1.11916 | KLD: 29.031
9.9988369140625
post mean 0.004440819378942251 logsigm -0.6943723559379578
| Epoch train: 2 | updating decoder 1 | Corpus ppx: 1.11387 | KLD: 29.031
9.879228515625
post mean 0.004440819378942251 logsigm -0.6943723559379578
| Epoch train: 2 | updating decoder 2 | Corpus ppx: 1.11219 | KLD: 29.031
9.8751455078125
post mean 0.004440819378942251 logsigm -0.6943723559379578
| Epoch train: 2 | updating decoder 3 | Corpus ppx: 1.11160 | KLD: 29.031
9.8743310546875
post mean 0.004440819378942251 logsigm -0.6943723559379578
| Epoch train: 2 | updating decoder 4 | Corpus ppx: 1.11141 | KLD: 29.031
9.8739462890625
post mean 0.004440819378942251 logsigm -0.6943723559379578
| Epoch train: 2 | updating decoder 5 | Corpus ppx: 1.11138 | KLD: 29.031
9.8730341796875
post mean 0.0

In [6]:
decoder.weight.shape

torch.Size([2000, 50])

In [7]:
torch.topk(decoder.weight, 10, 0)[0].shape

torch.Size([10, 50])

In [8]:
def vocab(data_url):
  """process data input."""
  data = []
  i2w = {}
  fin = open(data_url)
  while True:
    line = fin.readline()
    if not line:
      break
    word_freqs = line.split()
    index = len(data)
    data.append(word_freqs[0])
    i2w[index] = word_freqs[0]
      
  fin.close()
  return i2w

voc = vocab('./data/20news/vocab.new')
# len(voc)

In [9]:
words_list = []
for t in range(50):
    words = torch.topk(decoder.weight, 10, 0)[1][:,t]  
    topic_list = []
    for w in words:
        i2w = voc[w.item()]
        topic_list.append(i2w)
    words_list.append(topic_list)
print(words_list)

[['flyers', 'uiuc', 'cso', 'disk', 'udel', 'meg', 'scsi', 'min', 'int', 'format'], ['leafs', 'roger', 'jewish', 'games', 'baseball', 'card', 'his', 'boston', 'kings', 'bc'], ['br', 'draft', 'isc', 'simms', 'rochester', 'insurance', 'azerbaijan', 'sandvik', 'cache', 'sex'], ['jpeg', 'christian', 'cd', 'file', 'cwru', 'pt', 'rom', 'church', 'season', 'christians'], ['gatech', 'audio', 'cwru', 'fi', 'utexas', 'weapon', 'freenet', 'columbia', 'tape', 'cleveland'], ['flyers', 'greek', 'xterm', 'church', 'br', 'lc', 'cpu', 'ac', 'leafs', 'water'], ['azerbaijan', 'greek', 'turkish', 'armenian', 'armenia', 'satellite', 'turks', 'port', 'visual', 'georgia'], ['jpeg', 'motif', 'coverage', 'server', 'crime', 'satellite', 'shuttle', 'health', 'russian', 'teams'], ['armenians', 'azerbaijan', 'civilians', 'arabs', 'armenian', 'pro', 'population', 'mcgill', 'children', 'israeli'], ['sgi', 'stephanopoulos', 'mr', 'morality', 'islam', 'ripem', 'xlib', 'vga', 'isc', 'president'], ['christians', 'atheism

In [10]:
words_list = []
for t in range(50):
    words = torch.topk(sparsemax(decoder.weight, dim=0), 10, 0)[1][:,t]  
    topic_list = []
    for w in words:
        i2w = voc[w.item()]
        topic_list.append(i2w)
    words_list.append(topic_list)
print(words_list)

[['flyers', 'uiuc', 'cso', 'disk', 'udel', 'meg', 'scsi', 'min', 'int', 'format'], ['leafs', 'roger', 'jewish', 'games', 'baseball', 'card', 'his', 'boston', 'kings', 'bc'], ['br', 'draft', 'isc', 'simms', 'rochester', 'insurance', 'azerbaijan', 'sandvik', 'cache', 'sex'], ['jpeg', 'christian', 'cd', 'file', 'cwru', 'pt', 'rom', 'church', 'season', 'christians'], ['gatech', 'audio', 'cwru', 'fi', 'utexas', 'weapon', 'freenet', 'columbia', 'tape', 'cleveland'], ['flyers', 'greek', 'xterm', 'church', 'br', 'lc', 'cpu', 'ac', 'leafs', 'water'], ['azerbaijan', 'greek', 'turkish', 'armenian', 'armenia', 'satellite', 'turks', 'port', 'visual', 'georgia'], ['jpeg', 'motif', 'coverage', 'server', 'crime', 'satellite', 'shuttle', 'health', 'russian', 'teams'], ['armenians', 'azerbaijan', 'civilians', 'arabs', 'armenian', 'pro', 'population', 'mcgill', 'children', 'israeli'], ['sgi', 'stephanopoulos', 'mr', 'morality', 'islam', 'ripem', 'xlib', 'vga', 'isc', 'president'], ['christians', 'atheism

In [11]:
sparsemax(decoder.weight)[0]

tensor([0.0131, 0.0476, 0.0000, 0.0600, 0.0050, 0.0037, 0.0186, 0.0505, 0.0117,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0074, 0.0411,
        0.0000, 0.0524, 0.0549, 0.0000, 0.0377, 0.0000, 0.0745, 0.0065, 0.0000,
        0.0157, 0.0371, 0.0000, 0.0292, 0.0000, 0.0387, 0.0000, 0.0448, 0.0090,
        0.0000, 0.0128, 0.0275, 0.0226, 0.0609, 0.0000, 0.0000, 0.0336, 0.0321,
        0.0150, 0.0504, 0.0266, 0.0000, 0.0593], grad_fn=<SelectBackward0>)

In [12]:
# L * Loss(model prob; label freqs)
# L * KL(label freqs || model prob)
# KL(label freqs || model prob)

#W = W - stepsize * (L * grad Loss(model prob; label freqs))

# don't normalize increase step size
# tune step

In [13]:
# entmax