In [1]:
import math

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyro
import pyro.distributions as dist
import torch
import torch.nn as nn
import torch.nn.functional as F
from pyro.infer import SVI, TraceMeanField_ELBO
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer
from tqdm import trange

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Fetch the 'all' subset of the 20 Newsgroups corpus.
news = fetch_20newsgroups(subset='all')

# Bag-of-words counts: English stop words are removed and the vocabulary is
# pruned with the max_df / min_df document-frequency thresholds below.
vectorizer = CountVectorizer(stop_words='english', min_df=20, max_df=0.5)
term_counts = vectorizer.fit_transform(news['data'])

# Dense (documents x words) count tensor consumed by the model.
docs = torch.from_numpy(term_counts.toarray())

# Lookup table pairing every vocabulary word with its row position.
vocab = pd.DataFrame(columns=['word', 'index'])
vocab['word'] = vectorizer.get_feature_names_out()
vocab['index'] = vocab.index

KeyboardInterrupt: 

In [5]:
# Quick sanity check of the preprocessing: vocabulary length and tensor shape.
dictionary_size = len(vocab)
corpus_shape = docs.shape
print('Dictionary size: %d' % dictionary_size)
print('Corpus size: {}'.format(corpus_shape))

Dictionary size: 12722
Corpus size: torch.Size([18846, 12722])


In [6]:
class Encoder(nn.Module):
    """Inference network producing the variational parameters of ``log_z``.

    Args:
        num_genes: Number of input features per sample (width of the count matrix).
        num_topics: Number of latent topics (width of both outputs).
        hidden: Width of the two hidden layers.
        dropout: Dropout probability applied to the hidden representation.
    """

    def __init__(self, num_genes, num_topics, hidden, dropout):
        super().__init__()
        # Dropout and the two batch norms below are here to avoid component collapse.
        self.drop = nn.Dropout(p=dropout)
        # Shared two-layer trunk.
        self.fc1 = nn.Linear(in_features=num_genes, out_features=hidden)
        self.fc2 = nn.Linear(in_features=hidden, out_features=hidden)
        # Separate linear heads for the mean and the log-variance.
        self.fcmu = nn.Linear(in_features=hidden, out_features=num_topics)
        self.fclv = nn.Linear(in_features=hidden, out_features=num_topics)
        # Non-affine batch norm on each head.
        self.bnmu = nn.BatchNorm1d(num_features=num_topics, affine=False)
        self.bnlv = nn.BatchNorm1d(num_features=num_topics, affine=False)

    def forward(self, inputs):
        """Return ``(loc, scale)`` for ``log_z``, each of width ``num_topics``."""
        trunk = F.softplus(self.fc2(F.softplus(self.fc1(inputs))))
        trunk = self.drop(trunk)
        loc = self.bnmu(self.fcmu(trunk))
        log_variance = self.bnlv(self.fclv(trunk))
        # exp(0.5 * log-variance) is the standard deviation and is always positive.
        scale = log_variance.mul(0.5).exp()
        return loc, scale


class Decoder(nn.Module):
    """Generative network mapping topic proportions to per-gene probabilities.

    Args:
        num_genes: Number of output features.
        num_topics: Number of latent topics (input width).
        dropout: Dropout probability applied to the incoming topic proportions.
    """

    def __init__(self, num_genes, num_topics, dropout):
        super().__init__()
        # Topic-to-gene linear map; its weight matrix is what ProdLDA.beta() exposes.
        self.beta = nn.Linear(in_features=num_topics, out_features=num_genes, bias=True)
        self.bn = nn.BatchNorm1d(num_features=num_genes, affine=False)
        self.drop = nn.Dropout(p=dropout)

    def forward(self, inputs):
        """Return a softmax over dim 1 of the batch-normalised linear output."""
        dropped = self.drop(inputs)
        logits = self.bn(self.beta(dropped))
        return F.softmax(logits, dim=1)


class ProdLDA(nn.Module):
    """Topic model with a logistic-normal latent and a Gamma-Poisson likelihood.

    Args:
        num_genes: Number of features (columns of the count matrix).
        num_topics: Number of latent topics.
        library_size: Factor multiplying the decoder's per-gene probabilities
            to obtain the expected counts; must broadcast against a
            ``(batch, num_genes)`` tensor.
        hidden: Hidden width of the encoder.
        dropout: Dropout probability shared by encoder and decoder.
        dispersion: Concentration of the Gamma-Poisson likelihood; must be
            positive.
    """

    def __init__(self, num_genes, num_topics, library_size, hidden, dropout, dispersion):
        super().__init__()
        self.num_genes = num_genes
        self.num_topics = num_topics
        self.library_size = library_size
        self.dispersion = dispersion
        self.encoder = Encoder(num_genes, num_topics, hidden, dropout)
        self.decoder = Decoder(num_genes, num_topics, dropout)

    def model(self, docs):
        """Generative model p(docs, log_z) for one mini-batch.

        Args:
            docs: Float count tensor of shape ``(batch, num_genes)``.
        """
        pyro.module("decoder", self.decoder)
        batch = docs.shape[0]
        with pyro.plate("documents", batch):
            # Logistic normal prior p(z | mu, Sigma): standard normal on log_z,
            # pushed through a softmax to land on the simplex.
            log_z_loc = docs.new_zeros((batch, self.num_topics))
            log_z_scale = docs.new_ones((batch, self.num_topics))
            log_z = pyro.sample(
                "log_z", dist.Normal(log_z_loc, log_z_scale).to_event(1))
            z = F.softmax(log_z, -1)

            # Per-gene probabilities scaled to expected counts.
            latent_exp = self.decoder(z)
            latent_exp_scaled = self.library_size * latent_exp

            # GammaPoisson(concentration, rate) has mean concentration / rate,
            # so this rate makes `latent_exp_scaled` the mean.
            # The gene axis must be declared an event dim: otherwise the
            # (batch, num_genes) batch shape clashes with the size-`batch`
            # "documents" plate at dim -1 and Pyro raises a shape mismatch.
            pyro.sample(
                'obs',
                dist.GammaPoisson(
                    self.dispersion, self.dispersion / latent_exp_scaled
                ).to_event(1),
                obs=docs
            )

    def guide(self, docs):
        """Amortised variational posterior q(log_z | docs) for one mini-batch.

        Args:
            docs: Float count tensor of shape ``(batch, num_genes)``.
        """
        pyro.module("encoder", self.encoder)
        with pyro.plate("documents", docs.shape[0]):
            log_z_loc, log_z_scale = self.encoder(docs)
            pyro.sample(
                "log_z", dist.Normal(log_z_loc, log_z_scale).to_event(1))

    def beta(self):
        """Return the decoder weights as a detached CPU ``(num_topics, num_genes)`` tensor."""
        # beta matrix elements are the weights of the FC layer on the decoder
        return self.decoder.beta.weight.cpu().detach().T

In [12]:
# --- Reproducibility: seed both torch and pyro with the same value ---
seed = 0
torch.manual_seed(seed)
pyro.set_rng_seed(seed)

# --- Hardware: first CUDA device when available, otherwise the CPU ---
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# --- Model and optimisation hyperparameters ---
num_topics = 10
batch_size = 32
learning_rate = 1e-3
num_epochs = 50

# Cast the count matrix to float and place it on the chosen device.
docs = docs.float()
docs = docs.to(device)

In [18]:
# training
pyro.clear_param_store()

# ProdLDA.__init__ requires `library_size` and `dispersion`; omitting them
# raises TypeError, so both are passed explicitly here.
# NOTE(review): the two values below are placeholders, not tuned settings.
# `library_size` is the mean total count per document (a scalar, so it
# broadcasts against every mini-batch); `dispersion` is the Gamma-Poisson
# concentration. Confirm or tune both before trusting results.
library_size = docs.sum(dim=1).float().mean().item()
dispersion = 1.0

prodLDA = ProdLDA(
    num_genes=docs.shape[1],
    num_topics=num_topics,
    library_size=library_size,
    hidden=100,
    dropout=0.2,
    dispersion=dispersion,
)
prodLDA.to(device)

optimizer = pyro.optim.Adam({"lr": learning_rate})
svi = SVI(prodLDA.model, prodLDA.guide, optimizer, loss=TraceMeanField_ELBO())
# math.ceil already returns an int; the last batch may be smaller than batch_size.
num_batches = math.ceil(docs.shape[0] / batch_size)

bar = trange(num_epochs)
for epoch in bar:
    # Sum over mini-batches of the per-document loss of each batch.
    running_loss = 0.0
    for i in range(num_batches):
        batch_docs = docs[i * batch_size:(i + 1) * batch_size, :]
        loss = svi.step(batch_docs)
        running_loss += loss / batch_docs.size(0)

    bar.set_postfix(epoch_loss='{:.2e}'.format(running_loss))


100%|██████████| 50/50 [04:37<00:00,  5.56s/it, epoch_loss=3.72e+05]


Extensions - apply a gene-set informed prior to the weights beta