In [1]:
"""
This example demonstrates how to marginalize out discrete assignment variables
in a Pyro model.

Our example model is Latent Dirichlet Allocation. While the model in this
example does work, it is not the recommended way of coding up LDA in Pyro.
Whereas the model in this example treats documents as vectors of categorical
variables (vectors of word ids), it is usually more efficient to treat
documents as bags of words (histograms of word counts).
"""
from __future__ import absolute_import, division, print_function

import argparse
import functools
import logging

import torch
from torch import nn
from torch.distributions import constraints

import pyro
import pyro.distributions as dist
from pyro.infer import SVI, JitTraceEnum_ELBO, TraceEnum_ELBO
from pyro.optim import Adam

# Emit INFO-level records; each line is prefixed with the record's
# `relativeCreated` value (milliseconds, right-aligned in 9 columns).
logging.basicConfig(
    level=logging.INFO,
    format='%(relativeCreated) 9d %(message)s')

In [2]:
# This is a fully generative model of a batch of documents.
# data is a [num_words_per_doc, num_documents] shaped array of word ids
# (specifically it is not a histogram). We assume in this simple example
# that all documents have the same number of words.
def model(data=None, args=None, batch_size=None):
    """Fully generative LDA model of a batch of documents.

    Args:
        data: optional tensor of word ids with shape
            ``(args.num_words_per_doc, args.num_docs)`` (checked by an
            assertion). When ``None`` the words are sampled, which lets the
            model be used to generate synthetic documents.
        args: namespace providing ``num_topics``, ``num_words``, ``num_docs``
            and ``num_words_per_doc``.
        batch_size: accepted for signature compatibility with the guide; it is
            not read anywhere in this function.

    Returns:
        Tuple ``(topic_weights, topic_words, words)`` holding the values of the
        "topic_weights", "topic_words" and "doc_words" sample sites.
    """
    num_topics = args.num_topics
    vocab_size = args.num_words

    # Global variables: one entry per topic.
    with pyro.plate("topics", num_topics):
        # Per-topic weight; later used as the Dirichlet concentration of each
        # document's topic mixture.
        weight_prior = dist.Gamma(1. / num_topics, 1.)
        topic_weights = pyro.sample("topic_weights", weight_prior)
        # Per-topic distribution over the vocabulary, drawn from a symmetric
        # Dirichlet (every word has the same concentration 1 / num_words).
        word_prior = dist.Dirichlet(torch.ones(vocab_size) / vocab_size)
        topic_words = pyro.sample("topic_words", word_prior)

    # Local variables: one entry per document.
    with pyro.plate("documents", args.num_docs) as ind:
        observed_words = data
        if observed_words is not None:
            # The shape check is plain Python, so JIT tracing warnings about
            # it are suppressed.
            with pyro.util.ignore_jit_warnings():
                assert observed_words.shape == (args.num_words_per_doc, args.num_docs)
            # Keep only the document columns selected by the plate indices.
            observed_words = observed_words[:, ind]
        # Per-document mixture over topics.
        doc_topics = pyro.sample("doc_topics", dist.Dirichlet(topic_weights))
        # One entry per word position inside each document.
        with pyro.plate("words", args.num_words_per_doc):
            # The word_topics variable is marginalized out during inference,
            # achieved by specifying infer={"enumerate": "parallel"} and using
            # TraceEnum_ELBO for inference. Thus we can ignore this variable in
            # the guide.
            topic_assignment = dist.Categorical(doc_topics)
            word_topics = pyro.sample("word_topics", topic_assignment,
                                      infer={"enumerate": "parallel"})
            # Each word is drawn from the vocabulary distribution of its
            # assigned topic; when observations were supplied the site is
            # conditioned on them instead of being sampled.
            word_likelihood = dist.Categorical(topic_words[word_topics])
            words = pyro.sample("doc_words", word_likelihood, obs=observed_words)

    return topic_weights, topic_words, words


# We will use amortized inference of the local topic variables, achieved by a
# multi-layer perceptron. We'll wrap the guide in an nn.Module.
def make_predictor(args):
    """Build the MLP that maps word-count histograms to topic proportions.

    Args:
        args: namespace providing ``num_words`` (input width), ``layer_sizes``
            (hidden widths as a dash-separated string such as ``"100-100"``)
            and ``num_topics`` (output width).

    Returns:
        An ``nn.Sequential`` made of ``Linear`` + ``Sigmoid`` pairs (one pair
        per consecutive pair of sizes) followed by a ``Softmax`` over the last
        dimension.
    """
    layer_sizes = [args.num_words]
    layer_sizes.extend(int(width) for width in args.layer_sizes.split('-'))
    layer_sizes.append(args.num_topics)
    logging.info('Creating MLP with sizes {}'.format(layer_sizes))

    stack = []
    for idx in range(len(layer_sizes) - 1):
        linear = nn.Linear(layer_sizes[idx], layer_sizes[idx + 1])
        # Re-initialise weight, then bias, with small Gaussian noise
        # (mean 0, std 0.001).
        for tensor in (linear.weight, linear.bias):
            tensor.data.normal_(0, 0.001)
        stack.extend([linear, nn.Sigmoid()])
    # Normalise the final activations so each row sums to one.
    stack.append(nn.Softmax(dim=-1))
    return nn.Sequential(*stack)


def parametrized_guide(predictor, data, args, batch_size=None):
    """Variational guide: learned posteriors for topics, amortized for documents.

    Args:
        predictor: module mapping a ``(batch, args.num_words)`` histogram of
            word counts to ``(batch, args.num_topics)`` topic proportions.
        data: tensor of word ids, indexed here as ``data[:, ind]`` (documents
            along the second dimension, matching the model's
            ``[num_words_per_doc, num_docs]`` layout).
        args: namespace providing ``num_topics``, ``num_words`` and
            ``num_docs``.
        batch_size: optional subsample size passed to the "documents" plate.

    Returns:
        None. The function only registers parameters/modules and sample sites.
    """
    # Use a conjugate guide for global variables.
    topic_weights_posterior = pyro.param(
            "topic_weights_posterior",
            lambda: torch.ones(args.num_topics),
            constraint=constraints.positive)
    # NOTE(review): the reason for the 0.5 lower bound is not stated in this
    # file; presumably it keeps the Dirichlet concentrations away from very
    # small values - confirm before changing.
    topic_words_posterior = pyro.param(
            "topic_words_posterior",
            lambda: torch.ones(args.num_topics, args.num_words),
            constraint=constraints.greater_than(0.5))
    with pyro.plate("topics", args.num_topics):
        pyro.sample("topic_weights", dist.Gamma(topic_weights_posterior, 1.))
        pyro.sample("topic_words", dist.Dirichlet(topic_words_posterior))

    # Use an amortized guide for local variables.
    pyro.module("predictor", predictor)
    with pyro.plate("documents", args.num_docs, batch_size) as ind:
        # The neural network will operate on histograms rather than word
        # index vectors, so we'll convert the raw data to a histogram.
        doc_words = data[:, ind]
        if torch._C._get_tracing_state():
            # While tracing, build the histogram from one-hot rows instead of
            # scatter_add_ (presumably for JIT compatibility - confirm).
            # The identity matrix is sized by the vocabulary rather than a
            # hard-coded 1024 so that any --num-words value works.
            counts = torch.eye(args.num_words)[doc_words].sum(0).t()
        else:
            counts = torch.zeros(args.num_words, ind.size(0))
            # counts[doc_words[i][j]][j] += 1 for every word position i of
            # every selected document j. The source of ones is shaped like the
            # index tensor because scatter_add_ requires
            # index.size(d) <= src.size(d); expanding to counts.shape would
            # fail whenever documents are longer than the vocabulary.
            counts.scatter_add_(0, doc_words, torch.tensor(1.).expand(doc_words.shape))

        doc_topics = predictor(counts.transpose(0, 1))

        # Point-mass (Dirac delta) guide at the predicted topic proportions;
        # event_dim=1 treats the last dimension as a single event.
        pyro.sample("doc_topics", dist.Delta(doc_topics, event_dim=1))

In [3]:
def main(args):
    """Generate synthetic documents from the model and fit it with SVI.

    Args:
        args: namespace providing the model sizes consumed by ``model`` and
            ``make_predictor`` plus ``jit``, ``learning_rate``, ``num_steps``
            and ``batch_size``.

    Side effects:
        Seeds Pyro's RNG with 0, clears the parameter store, enables
        validation, and logs the loss every 10 steps and once at the end.
    """
    logging.info('Generating data')
    pyro.set_rng_seed(0)
    pyro.clear_param_store()
    pyro.enable_validation(True)

    # We can generate synthetic data directly by calling the model; it returns
    # (topic_weights, topic_words, data) and only the documents are used below.
    data = model(args=args)[2]

    # We'll train using SVI.
    logging.info('-' * 40)
    logging.info('Training on {} documents'.format(args.num_docs))
    amortized_guide = functools.partial(parametrized_guide, make_predictor(args))
    if args.jit:
        objective = JitTraceEnum_ELBO(max_plate_nesting=2)
    else:
        objective = TraceEnum_ELBO(max_plate_nesting=2)
    trainer = SVI(model, amortized_guide, Adam({'lr': args.learning_rate}), objective)

    logging.info('Step\tLoss')
    for step in range(args.num_steps):
        step_loss = trainer.step(data, args=args, batch_size=args.batch_size)
        if step % 10:
            continue
        logging.info('{: >5d}\t{}'.format(step, step_loss))

    # Final evaluation is called without batch_size, unlike the training steps.
    final_loss = objective.loss(model, amortized_guide, data, args=args)
    logging.info('final loss = {}'.format(final_loss))

In [6]:
# Default experiment configuration (see the option table below):
#   num_topics = 8, num_words = 1024, num_docs = 1000, num_words_per_doc = 64,
#   num_steps = 1000, layer_sizes = "100-100", learning_rate = 0.001,
#   batch_size = 32, jit = False (switched on by passing --jit).
# To time the whole run, put the %%time cell magic on the first line of this cell.

def _build_parser():
    """Create the argument parser holding the experiment's options and defaults."""
    parser = argparse.ArgumentParser(description="Amortized Latent Dirichlet Allocation")
    # (short flag, long flag, default, type); type None keeps the raw string.
    option_table = (
        ("-t", "--num-topics", 8, int),
        ("-w", "--num-words", 1024, int),
        ("-d", "--num-docs", 1000, int),
        ("-wd", "--num-words-per-doc", 64, int),
        ("-n", "--num-steps", 1000, int),
        ("-l", "--layer-sizes", "100-100", None),
        ("-lr", "--learning-rate", 0.001, float),
        ("-b", "--batch-size", 32, int),
    )
    for short_flag, long_flag, default, caster in option_table:
        parser.add_argument(short_flag, long_flag, default=default, type=caster)
    parser.add_argument('--jit', action='store_true')
    return parser


parser = _build_parser()
# An explicit argument list is parsed here instead of sys.argv.
args = parser.parse_args(["-t", "8"])
main(args)

   131822 Generating data
   138031 ----------------------------------------
   138031 Training on 1000 documents
   138031 Creating MLP with sizes [1024, 100, 100, 8]
   138074 Step	Loss
   139223     0	483160.40625
   143525    10	470484.21875
   147795    20	483758.96875
   152067    30	485707.21875
   156332    40	486927.15625
   160520    50	477152.90625
   164781    60	483504.375
   168987    70	495452.59375
   173187    80	491391.25
   177369    90	474378.1875
   181647   100	484284.96875
   185867   110	494794.75
   190025   120	480864.96875
   194189   130	480555.5625
   198492   140	472492.78125
   202747   150	472694.6875
   206918   160	465221.96875
   211180   170	494591.71875
   215430   180	486451.53125
   219728   190	475809.46875
   223996   200	468522.40625
   228214   210	472827.09375
   232435   220	488274.09375
   236652   230	478799.8125
   240975   240	477782.0
   245221   250	487383.4375
   249465   260	466331.6875
   253729   270	457770.96875
   257984   280	46