# DCTM for NeurIPS dataset

In [None]:
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

# Short aliases for the TensorFlow Probability submodules used below.
tfb = tfp.bijectors
tfd = tfp.distributions
tfk = tfp.math.psd_kernels

from matplotlib import pyplot as plt
from tqdm import tqdm
from sklearn import metrics
# `imp` is deprecated and was removed in Python 3.12; `importlib.reload` is
# its replacement and keeps the `reload` name that later cells call.
from importlib import reload
from scipy import sparse as sp

from dctm import correlated_topic_model as ctmd
from dctm import dynamic_correlated_topic_model as dctm

In [None]:
# download data with:
!curl https://archive.ics.uci.edu/ml/machine-learning-databases/00371/NIPS_1987-2015.csv -o NIPS_1987-2015.csv
# or 
# !wget https://archive.ics.uci.edu/ml/machine-learning-databases/00371/NIPS_1987-2015.csv

In [None]:
# Load the NeurIPS word counts, rescale the publication years to [-1, 1] and
# split the documents into train and test sets.
#
# You may need the following NLTK resources first:
# import nltk
# nltk.download('words')
# nltk.download('punkt')
# nltk.download('wordnet')
import pandas as pd
# Import the submodule explicitly: a bare `import sklearn` does not load
# `sklearn.preprocessing`; it only happened to be available because another
# sklearn import pulled it in.
import sklearn.preprocessing

from dctm import datasets

df, years, vocabulary = datasets.get_neurips('NIPS_1987-2015.csv')

# Keep only the vocabulary entries whose value exceeds 1700
# (presumably total word counts -- confirm against datasets.get_neurips).
vocabulary_subset = vocabulary[vocabulary > 1700].index

# After the transpose each row is one document; drop rows with missing values
# and rows whose retained-word counts sum to zero.
X_small = df.loc[vocabulary_subset].T.dropna()
X_small = X_small.loc[X_small.sum(axis=1) > 0]

# The text before the first '_' of each row label is used as the year.
year = np.array([x.split('_')[0] for x in X_small.index])
# Insert a singleton axis before the word axis.
X = np.expand_dims(X_small.values.astype(np.float64), -2)

# `feature_range` is given as a tuple keyword: recent scikit-learn releases
# validate the parameter type and reject a list.
scaler = sklearn.preprocessing.MinMaxScaler(feature_range=(-1, 1))
index_points = scaler.fit_transform(year.astype(int)[:, None])
# index_points = year.astype(np.float64)[:, None]

# Seed NumPy's global RNG before the split.
np.random.seed(42)
(X_tr, X_ts, index_tr, index_ts, X_tr_sorted, X_ts_sorted,
 index_tr_sorted, index_ts_sorted
) = datasets.train_test_split(X, index_points)

# Map rescaled index points back to datetimes (year resolution).
inverse_transform_fn = lambda x: pd.to_datetime(scaler.inverse_transform(x)[:, 0], format='%Y')
df_train = pd.DataFrame(X_tr_sorted[:, 0, :])
df_train['years'] = inverse_transform_fn(index_tr_sorted)

df_test = pd.DataFrame(X_ts_sorted[:, 0, :])
df_test['years'] = inverse_transform_fn(index_ts_sorted)

print("Dataset shape: \n tr: {} \n ts: {}".format(X_tr.shape, X_ts.shape))

In [None]:
# Shuffled, batched tf.data pipeline yielding (documents, index points) pairs.
batch_size = 100
n_train_samples = X_tr.shape[0]

documents_ds = tf.data.Dataset.from_tensor_slices(X_tr)
index_ds = tf.data.Dataset.from_tensor_slices(index_tr)
dataset = tf.data.Dataset.zip((documents_ds, index_ds))
# The shuffle buffer spans the whole training set and is redrawn every epoch.
dataset = dataset.shuffle(
    n_train_samples, reshuffle_each_iteration=True)
data_tr = dataset.batch(batch_size)

In [None]:
# Inducing index points for the beta, mu and ell processes: evenly spaced
# grids over [-1, 1], each shaped (n, 1).
inducing_index_points_beta = np.linspace(-1, 1, 15)[:, None]
inducing_index_points_mu = np.linspace(-1, 1, 20)[:, None]
inducing_index_points_ell = np.linspace(-1, 1, 15)[:, None]

dtype = np.float64


def _softplus_variable(initial_value, name):
    """Return a trainable variable constrained through a Softplus bijector."""
    return tfp.util.TransformedVariable(
        initial_value, bijector=tfb.Softplus(), dtype=dtype, name=name)


# beta uses a Matern-1/2 kernel; mu and ell use exponentiated-quadratic kernels.
amplitude_beta = _softplus_variable(1., 'amplitude_beta')
length_scale_beta = _softplus_variable(0.5, 'length_scale_beta')
kernel_beta = tfk.MaternOneHalf(
    amplitude=amplitude_beta, length_scale=length_scale_beta)

amplitude_mu = _softplus_variable(1., "amplitude_mu")
length_scale_mu = _softplus_variable(0.5, "length_scale_mu")
kernel_mu = tfk.ExponentiatedQuadratic(
    amplitude=amplitude_mu, length_scale=length_scale_mu)

amplitude_ell = _softplus_variable(1., 'amplitude_ell')
length_scale_ell = _softplus_variable(0.5, 'length_scale_ell')
kernel_ell = tfk.ExponentiatedQuadratic(
    amplitude=amplitude_ell, length_scale=length_scale_ell)

# Pick up any edits made to the dctm modules since they were imported.
reload(ctmd)
reload(dctm);

# Training histories (filled by the training loop) and the optimizer.
losses = []
perplexities = []
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

# All three processes are indexed at the distinct training index points.
train_index_grid = np.unique(index_tr)[:, None]

mdl = dctm.DCTM(
    n_topics=30,
    n_words=vocabulary_subset.size,
    # beta process
    kernel_beta=kernel_beta,
    index_points_beta=train_index_grid,
    inducing_index_points_beta=inducing_index_points_beta,
    jitter_beta=1e-6,
    # mu process
    kernel_mu=kernel_mu,
    index_points_mu=train_index_grid,
    inducing_index_points_mu=inducing_index_points_mu,
    jitter_mu=1e-5,
    # ell process
    kernel_ell=kernel_ell,
    index_points_ell=train_index_grid,
    inducing_index_points_ell=inducing_index_points_ell,
    jitter_ell=1e-6,
    # encoder
    layer_sizes=(500, 300, 200),
    encoder_jitter=1e-8,
    dtype=dtype,
)

In [None]:
# Train for `n_iter` epochs, recording one loss and one perplexity per epoch.
n_iter = 2
pbar = tqdm(range(n_iter), disable=False)

with tf.device('gpu'):
    for epoch in pbar:
        loss_value = 0.
        perplexity_value = 0.
        n_batches = 0

        for x_batch, index_points_batch in data_tr:
            loss, perpl = mdl.batch_optimize(
                x_batch,
                optimizer=optimizer,
                observation_index_points=index_points_batch,
                trainable_variables=None,
                # Scale the KL term by the fraction of the training set in this batch.
                kl_weight=float(x_batch.shape[0]) / float(n_train_samples))
            loss = tf.reduce_mean(loss, 0)
            # Accumulate Python floats so the history lists hold plain numbers
            # rather than tensors.
            loss_value += float(loss)
            perplexity_value += float(perpl)
            n_batches += 1

        # Report the mean per-batch perplexity. Summing the batch values (as
        # before) gives a number that grows with the number of batches and is
        # not comparable with the test perplexity computed later.
        # NOTE(review): assumes `perpl` is a per-batch perplexity -- confirm
        # against DCTM.batch_optimize.
        perplexity_value /= max(n_batches, 1)

        pbar.set_description(
            'loss {:.3e}, perpl {:.3e}'.format(loss_value, perplexity_value))

        losses.append(loss_value)
        perplexities.append(perplexity_value)

In [None]:
# Per-epoch training loss on a logarithmic y-axis.
plt.plot(losses)
plt.yscale('log');

In [None]:
# Per-epoch training perplexity on a logarithmic y-axis.
plt.plot(perplexities)
plt.yscale('log');

In [None]:
# Held-out evaluation: ELBO on the test split with the KL term switched off
# (kl_weight=0), then the perplexity derived from it.
with tf.device('gpu'):
    elbo = mdl.elbo(X_ts, index_ts, kl_weight=0.)
    perpl = mdl.perplexity(X_ts, elbo)

print(perpl)

In [None]:
# Overwrite the model's topic count with axis 1 of the beta surrogate
# posterior's batch shape.
# NOTE(review): a later cell reads the same quantity as batch_shape[-1];
# confirm both index the same axis.
mdl.n_topics = mdl.surrogate_posterior_beta.batch_shape[1]

In [None]:
def inverse_transform_fn(x):
    """Map rescaled index points back to year strings ('%Y')."""
    unscaled_years = scaler.inverse_transform(x)[:, 0]
    return pd.to_datetime(unscaled_years, format='%Y').strftime('%Y')


reload(dctm)
# Print the topics and keep the returned topic strings as an array for
# plot titles and legends.
tops = dctm.print_topics(
    mdl,
    index_points=index_tr,
    vocabulary=vocabulary_subset,
    inverse_transform_fn=inverse_transform_fn,
    top_n_topic=30,
    top_n_time=5,
)
topics = np.array(tops)

In [None]:
# One colour per topic, evenly spaced along the 'jet' colormap.
n_topics = mdl.surrogate_posterior_beta.batch_shape[-1]
jet_positions = np.linspace(0, 1, n_topics)
colors = plt.cm.jet(jet_positions)

In [None]:
# Grid of 100 evenly spaced index points over [-1, 1], as a column vector.
test_points = np.linspace(-1, 1, 100).reshape(-1, 1)

In [None]:
# Draw 1200 samples from the ell surrogate posterior at the test points and
# convert them to correlation / covariance samples.
ell_samples = mdl.surrogate_posterior_ell.sample(1200, index_points=test_points)
corr_sample, Sigma_sample = dctm.get_correlation(ell_samples)


def _percentile_band(samples):
    """Return the 5th, 50th and 95th percentiles taken over the sample axis."""
    return tuple(
        tfp.stats.percentile(samples, q, axis=0) for q in (5, 50, 95))


# The *_10p / *_90p names hold the 5th and 95th percentiles respectively.
corr_10p, corr, corr_90p = _percentile_band(corr_sample)
Sigma_10p, Sigma, Sigma_90p = _percentile_band(Sigma_sample)

In [None]:
from dctm import plotting

reload(plotting)
# Correlation samples for topic 11, with no restriction on the other topics.
plotting.plot_sigma(
    corr_sample, test_points, 11, topics, inverse_transform_fn,
    restrict_to=None,
    color_fn=plt.cm.tab20c,
    legend='right',
    plot_if_higher_of=0.1,
);

In [None]:
# Average the model's predictions over all documents that share an index
# point; the result has one row per topic and one column per distinct index point.
topic = mdl.predict(X).numpy()
tmp_df = pd.DataFrame(topic[:, 0, :], index=index_points[:, 0])
topics_per_time = tmp_df.groupby(level=0).mean().to_numpy().T

In [None]:
# Stacked plot of the median correlation between topic `topic_num` and every
# other topic whose mean absolute correlation is at least 0.15.
prev = 0
cm = plt.get_cmap('tab20c')
colors = cm(np.linspace(0,1,9))

topic_num = 11
plt.title("Topic {}: {}".format(topic_num, topics[topic_num][:35]))
c = 0
for t in range(n_topics):
    if t == topic_num:
        continue
    if tf.reduce_mean(np.abs(corr[:, topic_num, t])) < 0.15:
        continue
    # Each band is stacked on top of the previous one.
    curr = prev + corr[:, topic_num, t]
    plt.fill_between(test_points[:, 0],
                     prev, curr,
                     # Wrap around the 9-colour palette so that more than nine
                     # selected topics no longer raise an IndexError.
                     color=colors[c % len(colors)],
                     label='{}:{}'.format(t, topics[t][:20]))
    prev = curr
    c += 1

# Tick locations must be 1-D: newer Matplotlib releases reject an (n, 1) array.
plt.xticks(test_points[::10, 0], inverse_transform_fn(test_points)[::10], rotation=30);
plt.gca().legend(loc='center left', bbox_to_anchor=(1, 0.5));
f2 = plt.gcf()
plt.show()

In [None]:
# Stacked plot of the median correlation between topic `topic_num` and every
# other topic (no threshold applied here).
prev = 0
cm = plt.get_cmap('tab20c')
colors = cm(np.linspace(0,1,n_topics))

topic_num = 19
plt.title("Topic {}: {}".format(topic_num, topics[topic_num][:35]))
c = 0
for t in range(n_topics):
    if t == topic_num:
        continue
    # Optional filter, currently disabled:
    #     if tf.reduce_mean(np.abs(corr[:, topic_num, t])) < 0.15:
    #         continue
    # Each band is stacked on top of the previous one.
    curr = prev + corr[:, topic_num, t]
    plt.fill_between(test_points[:, 0], prev, curr,
                     color=colors[c], label='{}:{}'.format(t, topics[t][:20]))
    prev = curr
    c += 1

# Tick locations must be 1-D: newer Matplotlib releases reject an (n, 1) array.
plt.xticks(test_points[::10, 0], inverse_transform_fn(test_points)[::10], rotation=30);
#     plt.ylim([None,0.5])
plt.gca().legend(loc='center left', bbox_to_anchor=(1, 0.5));
f2 = plt.gcf()
plt.show()

In [None]:
# Correlation samples for topic 19, restricted to a hand-picked set of topics.
plotting.plot_sigma(
    corr_sample, test_points, 19, topics, inverse_transform_fn,
    restrict_to=[2, 5, 11, 12, 14, 15, 19],
    legend='bottom',
);

In [None]:
# f2.savefig('neurips_correlation_neuroscience_vertical.pdf', dpi=600, transparent=True, bbox_inches='tight')

In [None]:
# Covariance (Sigma) samples for topic 15 against topics 13 and 19.
plotting.plot_sigma(
    Sigma_sample, test_points, 15, topics, inverse_transform_fn,
    restrict_to=[13, 19],
    legend='bottom',
);

In [None]:
# f.savefig('class_correlation1.pdf', dpi=600, transparent=True, bbox_inches='tight')

For each topic, let's show its correlation with the other topics over time. $\Sigma$ with error bars

In [None]:
# One figure per topic: median correlation with every other topic, plotted
# against the position in the test grid.
tick_positions = range(test_points.size)[::10]

for topic_num in range(n_topics):
    plt.title("Topic {}: {}".format(topic_num, topics[topic_num][:30]))

    for t in range(n_topics):
        if t != topic_num:
            plt.plot(
                corr[:, topic_num, t],
                label='{}:{}'.format(t, topics[t][:20]),
                color=colors[t],
            )

    plt.xticks(tick_positions, inverse_transform_fn(test_points)[::10], rotation=45);
    # Start the x-axis at grid position 20.
    plt.xlim([20,None])
    plt.gca().legend(loc='center left', bbox_to_anchor=(1, 0.5));
    f = plt.gcf()
    plt.show()

In [None]:
# f.savefig('sample_correlation.pdf', dpi=600, transparent=True, bbox_inches='tight')

In [None]:
# Recompute the per-index-point average of the model's predictions
# (rows: topics, columns: distinct index points).
topic = mdl.predict(X)[:, 0, :].numpy()
tmp_df = pd.DataFrame(topic, index=index_points[:, 0])
per_index_mean = tmp_df.groupby(tmp_df.index).mean()
topics_per_time = per_index_mean.values.T

In [None]:
reload(plotting)
# Predictions for every topic; a candidate subset is [2,5,11,12,14,15,19].
f = plotting.plot_predictions(
    mdl,
    topics_per_time,
    index_points,
    topics,
    inverse_transform_fn,
    restrict_to=None,
)

In [None]:
reload(plotting)
# Same plot, restricted to a hand-picked subset of topics, legend at the bottom.
f = plotting.plot_predictions(
    mdl,
    topics_per_time,
    index_points,
    topics,
    inverse_transform_fn,
    restrict_to=[2, 5, 11, 12, 14, 15, 19],
    legend='bottom',
)

In [None]:
# f.savefig('neurips_topics_eta_vertical.pdf', dpi=600, transparent=True, bbox_inches='tight')

In [None]:
# For each topic, plot the softmax of the mu posterior mean with a sampled
# band, together with the per-index-point averages in `topics_per_time`.
colors = plt.cm.jet(np.linspace(0,1,n_topics))
mu = mdl.surrogate_posterior_mu.get_marginal_distribution(test_points)
# Softmax over the topic axis: axis 0 of the mean, axis 1 of the samples
# (which carry a leading sample axis).
mu_sm = tf.nn.softmax(mu.mean(), axis=0)
mu_sample = tf.nn.softmax(mu.sample(110), axis=1)
# The *_90p / *_10p names hold the 95th and 5th percentiles respectively.
mu_90p = tfp.stats.percentile(mu_sample, 95, axis=0)
mu_10p = tfp.stats.percentile(mu_sample, 5, axis=0)

for i in range(n_topics):
    # Skip topics whose mean softmax value is negligible.
    if tf.reduce_mean(tf.abs(mu_sm[i])) > 0.001:
        line, = plt.plot(test_points, mu_sm[i], label=topics[i], color=colors[i]);
        plt.fill_between(
                test_points[:, 0],
                mu_10p[i],
                mu_90p[i],
                color=line.get_color(),
                alpha=0.3,
                lw=1.5,
            )

        plt.plot(np.unique(index_points), topics_per_time[i], label='{}'.format(topics[i]), color=colors[i])

        # Tick locations must be 1-D: newer Matplotlib releases reject an
        # (n, 1) array.
        plt.xticks(test_points[::8, 0], inverse_transform_fn(test_points)[::8], rotation=45);
        plt.gca().legend(loc='center left', bbox_to_anchor=(1, 0.5));
        plt.ylim(0,.3);
        plt.show()

Probability of topics over time.

$\mu$ with error bars

In [None]:
# mu for every topic; the colour function returns None for each entry.
f = plotting.plot_mu(
    mdl,
    test_points,
    topics,
    inverse_transform_fn,
    restrict_to=None,
    color_fn=lambda x: [None] * len(x),
    figsize=(9, 5),
    plot_if_higher_of=0,
)

In [None]:
reload(plotting)
# Candidate legend strings for a subset of topics, kept for reference (unused).
# legends = [
#     '2:layer unit hidder ar',
#     '5:posterior bayesian g',
#     '11:dirichlet topic expe',
#     '12:theorem proof bound',
#     '14:estim densiti sampl',
#     '15:voltag channel signa',
#     '19:neuron synapt fire c'
# ]
# Number of draws taken from the mu and ell surrogate posteriors below.
sample_size = 1
# Stacked plot of topic proportions over `test_points`. The `mean` argument is
# built inside-out:
#   1. draw `sample_size` samples of mu from its marginal at `test_points`
#      (wrapped in Independent, then a Transpose bijector over the two
#      rightmost dimensions) and use them as the Gaussian location;
#   2. draw `sample_size` samples from the ell posterior at the same points
#      and use them as `scale_tril`;
#   3. sample once from the resulting MultivariateNormalTriL;
#   4. reverse all axes with tf.transpose, apply a softmax over axis 1, and
#      average over the last axis.
# NOTE(review): the tensor shapes are not visible from this cell -- confirm
# that axis 1 after the transpose is the intended softmax axis.
f = plotting.plot_mu_stacked(
    mean=tf.reduce_mean(
        tf.nn.softmax(
            tf.transpose(
                tfd.MultivariateNormalTriL(
                    loc=tfd.TransformedDistribution(
                        tfd.Independent(mdl.surrogate_posterior_mu.get_marginal_distribution(test_points), 1),
                        bijector=tfb.Transpose(rightmost_transposed_ndims=2),
                    ).sample(sample_size),
                    scale_tril=mdl.surrogate_posterior_ell.sample(sample_size, index_points=test_points),
                ).sample()
            ), axis=1),
        -1),
    test_points=test_points,
    topics=topics,
    inverse_transform_fn=inverse_transform_fn,
    restrict_to=None, color_fn=plt.cm.tab20c, figsize=(9,5), plot_if_higher_of=0
)

In [None]:
# f.savefig('neurips_posterior_mu_vertical_new_2.pdf', dpi=600, transparent=True, bbox_inches='tight')

Word probabilities within a topic over time. $\beta$ with error bars

In [None]:
reload(plotting)
# beta for topic 1, restricted to a hand-picked list of words; run on the CPU.
words_of_interest = [
    "lda", "topic", "document", "dirichlet", "hmm",
    "expert", "mixtur", "word", "latent",
]
with tf.device('CPU'):
    f = plotting.plot_beta_and_stacked(
        mdl,
        test_points,
        topic_num=1,
        vocabulary=vocabulary_subset,
        inverse_transform_fn=inverse_transform_fn,
        topics=topics,
        restrict_words_to=words_of_interest,
        figsize=(7, 7),
    )

In [None]:
# f.savefig('neurips_posterior_beta_lda_vertical_2.pdf', dpi=600, transparent=True, bbox_inches='tight')