In [1]:
import copy
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.autograd as autograd
from torch.utils.data import Dataset
import gensim
import numpy as np
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
import string
import re
from sklearn.datasets import fetch_20newsgroups
import pandas as pd
from nltk.tokenize import word_tokenize
import swifter
from gensim import corpora, models

stop_words = list(set(stopwords.words('english')))

if torch.cuda.device_count() == 0:
    device = 'cpu'
else:
    device = 'cuda'

In [2]:
def preprocess_text(text):
    text = text.lower()
    text = re.sub(r'[{}0-9]'.format(string.punctuation), ' ', text)
    text=re.sub(r'[^A-Za-z0-9 ]+', ' ', text)
    text = word_tokenize(text)
    text = [word for word in text if word not in stop_words]
    text = [WordNetLemmatizer().lemmatize(word) for word in text]
    text = ' '.join(text)
    return text


newsgroups = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))
df=pd.DataFrame({"content":newsgroups["data"]})

#df=df.sample(1000)
df["content"]=df["content"].swifter.apply(lambda x: preprocess_text(x))
df['content_length'] = df['content'].str.len()

df = df[df['content_length'] > 100]
df = df[df['content_length'] < 2000]

df=df[["content"]].reset_index(drop=True).reset_index().rename(columns={"index":"id"})
df["tokenized_content"]=df["content"].apply(lambda content: word_tokenize(content))

dictionary = gensim.corpora.Dictionary(df['tokenized_content'])
bow_corpus = [dictionary.doc2bow(doc) for doc in df['tokenized_content']]

Pandas Apply:   0%|          | 0/18846 [00:00<?, ?it/s]

In [3]:
Tensor = torch.FloatTensor

class TopicModel(Dataset):
    def __init__(self, corpus = bow_corpus, len_dict = len(dictionary)):
        self.corpus = corpus
        self.len_dict = len_dict
    def __len__(self):
        return len(self.corpus)
    def __getitem__(self, idx):
        sample = np.zeros(self.len_dict)
        index, value = zip(*self.corpus[idx])
        sample[list(index)] = list(value)
        return sample

class Generator(nn.Module):
    def __init__(self, opt):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(opt['num_topics'], opt['enc_mid_layer']),
            nn.LeakyReLU(),
            # nn.LayerNorm(opt['enc_mid_layer']),
            nn.Linear(opt['enc_mid_layer'], opt['vocab_size']),
            nn.Softmax(-1)
        )

    def forward(self, z):
        z = self.model(z)
        return z


class Discriminator(nn.Module):
    def __init__(self, opt):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(opt['vocab_size'], opt['dec_mid_layer']),
            nn.LeakyReLU(),
            nn.Linear(opt['dec_mid_layer'], 1)
        )

    def forward(self, z):
        z = self.model(z)
        return z


def compute_gradient_penalty(D, real_samples, fake_samples):
    """Calculates the gradient penalty loss for WGAN GP"""
    # Random weight term for interpolation between real and fake samples
    alpha = Tensor(np.random.random((real_samples.size(0), 1))).to(device)
    # Get random interpolation between real and fake samples
    interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True).to(device)
    # print(interpolates.shape)
    d_interpolates = D(interpolates)
    # print(d_interpolates.shape)
    fake = Variable(Tensor(real_samples.shape[0], 1).fill_(1.0), requires_grad=False).to(device)
    # Get gradient w.r.t. interpolates
    gradients = autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=fake,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.kaiming_normal_(m.weight)

# ----------
#  Training
# ----------
def trainer(dataloader, opt, generator, discriminator, optimizer_G, optimizer_D):
    for epoch in range(opt['n_epochs']):
        for i, real_data in enumerate(dataloader):

            # Configure input
            real_data = real_data.float().to(device)
            # ---------------------
            #  Train Discriminator
            # ---------------------
            optimizer_D.zero_grad()
            dirichlet = torch.distributions.dirichlet.Dirichlet(torch.tensor([1/opt['num_topics'] for _ in range(opt['num_topics'])]))
            sample = dirichlet.sample()
            # Sample noise as generator input
            z = Variable(sample.repeat(real_data.shape[0],1)).to(device)
            # Generate a batch of images
            fake_data = generator(z)
            # Real images
            real_validity = discriminator(real_data)
            # Fake fake_data
            fake_validity = discriminator(fake_data)
            # Gradient penalty
            gradient_penalty = compute_gradient_penalty(discriminator, real_data.data, fake_data.data)
            # Adversarial loss
            d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + opt['lambda_gp'] * gradient_penalty
            wasserstein_d = -torch.mean(real_validity) + torch.mean(fake_validity)
            d_loss.backward()
            optimizer_D.step()
            optimizer_G.zero_grad()
            # Train the generator every n_critic steps
            if i % opt['n_critic'] == 0:
                # -----------------
                #  Train Generator
                # -----------------
                # Generate a batch of images
                fake_data = generator(z)
                # Loss measures generator's ability to fool the discriminator
                # Train on fake images
                fake_validity = discriminator(fake_data)
                g_loss = -torch.mean(fake_validity)
                g_loss.backward()
                optimizer_G.step()
                print(
                    "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [Wasserstein Distance: %f]"
                    % (epoch, opt['n_epochs'], i, len(dataloader), d_loss.item(), g_loss.item(), wasserstein_d.item())
                )


def ATM(dictionary,args):
    # Initialize generator and discriminator
    generator = Generator(args).to(device)
    discriminator = Discriminator(args).to(device)
    generator.apply(init_weights)
    discriminator.apply(init_weights)
    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(), lr = 0.0001, betas=(0, 0.9))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr = 0.0001, betas=(0, 0.9))
    trainer(loader, args, generator, discriminator, optimizer_G, optimizer_D)
    emb = [generator(Tensor([[0 if i != j else 1 for j in range(args['num_topics'])]]).to(device)).squeeze() for i in range(args['num_topics'])]
    feature_names = list(zip(*sorted(dictionary.items(), key=lambda x:x[0])))[1]
    topics = [[feature_names[j] for j in emb[i].argsort(descending=True)[:args['n_top_words'] - 1]] for i in range(len(emb))]
    disc_model = copy.deepcopy(discriminator.state_dict())
    enc_model = copy.deepcopy(generator.state_dict())

    return disc_model, enc_model,topics

In [4]:
args = {'num_topics': 5,"n_top_words": 10,'enc_mid_layer': 100, 'dec_mid_layer': 100, 'lambda_gp':10, 'vocab_size': len(dictionary), 'batch_size': 128, 'n_epochs': 50, 'n_critic': 5}
dataset = TopicModel()
loader = torch.utils.data.DataLoader(dataset = dataset,
                                      batch_size = args['batch_size'],
                                      shuffle = True)
_,_,topics=ATM(dictionary=dictionary,args=args)

[Epoch 0/50] [Batch 0/115] [D loss: 2.028338] [G loss: -0.044502] [Wasserstein Distance: 0.020775]
[Epoch 0/50] [Batch 5/115] [D loss: 0.271844] [G loss: -0.041338] [Wasserstein Distance: -0.074458]
[Epoch 0/50] [Batch 10/115] [D loss: -0.084459] [G loss: -0.038368] [Wasserstein Distance: -0.144568]
[Epoch 0/50] [Batch 15/115] [D loss: -0.213754] [G loss: -0.035595] [Wasserstein Distance: -0.236599]
[Epoch 0/50] [Batch 20/115] [D loss: -0.376561] [G loss: -0.033056] [Wasserstein Distance: -0.394404]
[Epoch 0/50] [Batch 25/115] [D loss: -0.455374] [G loss: -0.030523] [Wasserstein Distance: -0.463785]
[Epoch 0/50] [Batch 30/115] [D loss: -0.615328] [G loss: -0.028243] [Wasserstein Distance: -0.629372]
[Epoch 0/50] [Batch 35/115] [D loss: -0.651015] [G loss: -0.025972] [Wasserstein Distance: -0.661522]
[Epoch 0/50] [Batch 40/115] [D loss: -0.803578] [G loss: -0.023616] [Wasserstein Distance: -0.813273]
[Epoch 0/50] [Batch 45/115] [D loss: -0.843715] [G loss: -0.021251] [Wasserstein Distan

In [5]:
topics

[['note',
  'everyone',
  'though',
  'great',
  'bike',
  'still',
  'well',
  'please',
  'show'],
 ['actually',
  'rule',
  'oh',
  'area',
  'tell',
  'need',
  'read',
  'name',
  'although'],
 ['lot', 'oh', 'common', 'mail', 'ok', 'position', 'bike', 'b', 'book'],
 ['actually',
  'least',
  'ok',
  'everyone',
  'although',
  'general',
  'interested',
  'oh',
  'second'],
 ['got', 'old', 'oh', 'men', 'today', 'b', 'everyone', 'please', 'gun']]