# Topic Modelling via Clustering Embeddings

This is a very code-heavy notebook about topic modelling and the visualization thereof (mostly the visualization thereof). We will be making heavy use of the vectorizers library.

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import sklearn.datasets
import sklearn.feature_extraction.text
import sklearn.preprocessing
import scipy.sparse
import vectorizers
import vectorizers.transformers
import umap
import umap.plot
import pynndescent
import seaborn as sns
import matplotlib.colors
import hdbscan

sns.set()

First step, our test data: the standard 20 Newsgroups dataset, keeping only posts longer than 200 characters.

In [2]:
news = sklearn.datasets.fetch_20newsgroups(
    subset="all", remove=("headers", "footers", "quotes")
)
long_enough = [len(t) > 200 for t in news["data"]]
targets = np.array(news.target)[long_enough]
news_data = [t for t in news["data"] if len(t) > 200]

We'll do very simple tokenization, which is good enough to get the job done. A more advanced version of this might use something like SentencePiece to learn a tokenization instead.

In [3]:
%%time
cv = sklearn.feature_extraction.text.CountVectorizer(lowercase=True)
sk_word_tokenize = cv.build_tokenizer()
sk_preprocesser = cv.build_preprocessor()
tokenize = lambda doc: sk_word_tokenize(sk_preprocesser(doc))
tokenized_news = [tokenize(doc) for doc in news_data]

CPU times: user 1 s, sys: 50.1 ms, total: 1.05 s
Wall time: 1.05 s
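
For reference, the tokenizer built above amounts to "lowercase, then keep runs of two or more word characters". The sketch below restates that with the standard `re` module, assuming CountVectorizer's defaults (`lowercase=True`, `strip_accents=None`, `token_pattern=r"(?u)\b\w\w+\b"`); it is an illustration of the behaviour, not the sklearn code itself, and `simple_tokenize` is a hypothetical name.

```python
import re

# Assumption: mirrors CountVectorizer's default preprocessing and token_pattern
TOKEN_PATTERN = re.compile(r"(?u)\b\w\w+\b")


def simple_tokenize(doc):
    """Lowercase the text, then keep runs of two or more word characters."""
    return TOKEN_PATTERN.findall(doc.lower())


print(simple_tokenize("Hello, World! I'm here with a 42-inch monitor."))
# ['hello', 'world', 'here', 'with', '42', 'inch', 'monitor']
```

Note that single-character tokens such as "a" and the "i"/"m" of "I'm" are dropped entirely, which is usually fine for topic modelling.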


And now we'll use the ``TokenCooccurrenceVectorizer`` to generate word vectors learned directly from the corpus. This has the benefit that we learn idiomatic word usage. It has the downside that we can't learn good word vectors when there isn't enough text. In an ideal world we could use some pretrained material to make this more tractable, and indeed we could use a Bayesian prior (pre-trained on a larger corpus) on the co-occurrence matrix, but that isn't implemented yet. Fortunately, despite being quite a small dataset, 20 Newsgroups is "big enough" to learn reasonable word vectors.

In [4]:
%%time
word_vectorizer = vectorizers.TokenCooccurrenceVectorizer(
    min_document_occurrences=5,
    window_radii=10,
    window_functions="variable",
    kernel_functions="geometric",
    n_iter=0,
    normalize_windows=True,
).fit(tokenized_news)
word_vectors = word_vectorizer.reduce_dimension(dimension=160, algorithm="randomized")

CPU times: user 1min 58s, sys: 3.69 s, total: 2min 1s
Wall time: 1min 46s
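
To make the idea concrete, here is a deliberately simplified sketch of windowed co-occurrence counting with geometrically decaying weights, followed by a truncated SVD. It is a stand-in under stated assumptions, not the ``TokenCooccurrenceVectorizer`` implementation: it uses a single symmetric window, a fixed hypothetical `decay` factor, a plain `numpy.linalg.svd` instead of the randomized reduction, and no document-frequency filtering or window normalization.

```python
import numpy as np


def cooccurrence_word_vectors(tokenized_docs, window_radius=2, decay=0.5, dimension=2):
    """Toy word vectors: geometric-kernel co-occurrence counts reduced by SVD."""
    vocab = sorted({tok for doc in tokenized_docs for tok in doc})
    index = {tok: i for i, tok in enumerate(vocab)}
    counts = np.zeros((len(vocab), len(vocab)))
    for doc in tokenized_docs:
        ids = [index[tok] for tok in doc]
        for pos, center in enumerate(ids):
            for offset in range(1, window_radius + 1):
                weight = decay ** (offset - 1)  # weight shrinks geometrically with distance
                if pos + offset < len(ids):
                    counts[center, ids[pos + offset]] += weight
                if pos - offset >= 0:
                    counts[center, ids[pos - offset]] += weight
    # Normalize each row to a distribution over context words
    row_sums = counts.sum(axis=1, keepdims=True)
    probs = counts / np.where(row_sums > 0, row_sums, 1.0)
    # Truncated SVD: keep the top `dimension` components as the word vectors
    u, s, _ = np.linalg.svd(probs, full_matrices=False)
    return vocab, u[:, :dimension] * s[:dimension], counts


vocab, vecs, counts = cooccurrence_word_vectors([["the", "cat", "sat"], ["the", "dog", "sat"]])
```

In this toy corpus "cat" and "dog" appear in exactly the same contexts, so they end up with the same vector; that is the "idiomatic usage" effect described above, in miniature.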


Next we need document embeddings and, to power the topic modelling, word vectors that live in the same space as the document (and hence topic) vectors. Fortunately this is surprisingly easy to arrange: we treat each word as a document containing only that word (i.e. the identity matrix) and push that document matrix through the same pipeline.

In [5]:
%%time
doc_matrix = vectorizers.NgramVectorizer(
    token_dictionary=word_vectorizer.token_label_dictionary_
).fit_transform(tokenized_news)
info_transformer = vectorizers.transformers.InformationWeightTransformer(
    prior_strength=1e-1,
    approx_prior=False,
)
info_doc_matrix = info_transformer.fit_transform(doc_matrix)
lat_vectorizer = vectorizers.ApproximateWassersteinVectorizer(
    normalization_power=0.66,
    random_state=42,
)
lat_doc_vectors = lat_vectorizer.fit_transform(info_doc_matrix, vectors=word_vectors)
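# Treat each word as a one-word "document" (the identity matrix) so that the
# resulting word vectors land in the same space as the document vectors
lat_word_vectors = lat_vectorizer.transform(
    info_transformer.transform(scipy.sparse.eye(word_vectors.shape[0]))
)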

CPU times: user 20.2 s, sys: 387 ms, total: 20.6 s
Wall time: 17.7 s
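
The identity-matrix trick is easiest to see with a linear stand-in. In the sketch below a document vector is simply the L2-normalized, weight-averaged word vector of its tokens; this is an assumption for illustration, as ``ApproximateWassersteinVectorizer`` does something considerably more sophisticated. Pushing the identity matrix through the same function gives one "document" per word, and each such document lands exactly on its own (normalized) word vector.

```python
import numpy as np


def embed_documents(doc_word_weights, word_vectors):
    """Weighted average of word vectors per document, then L2 normalization."""
    doc_word_weights = np.asarray(doc_word_weights, dtype=float)
    row_sums = doc_word_weights.sum(axis=1, keepdims=True)
    averaged = (doc_word_weights / np.where(row_sums > 0, row_sums, 1.0)) @ word_vectors
    norms = np.linalg.norm(averaged, axis=1, keepdims=True)
    return averaged / np.where(norms > 0, norms, 1.0)


rng = np.random.default_rng(0)
word_vecs = rng.normal(size=(5, 3))           # 5 words, 3 dimensions (random toy data)
doc_matrix = rng.integers(0, 3, size=(4, 5))  # 4 documents as word counts
doc_vecs = embed_documents(doc_matrix, word_vecs)

# One "document" per word: the identity matrix goes through the same pipeline
word_as_doc_vecs = embed_documents(np.eye(5), word_vecs)
```

Because both sets of vectors come out of the same function, distances between a word and a document (or a cluster centroid of documents) are meaningful.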


Now we can do some topic modelling. Our goal is to cluster a low-dimensional representation of the document vectors and consider each cluster a "topic". We can then generate "topic words" for each topic by taking the words closest to the cluster centroid in the joint document-word vector space we just created. This is essentially what Top2Vec does, except that we aren't using doc2vec, and we'll use the HDBSCAN cluster hierarchy directly to get topics at varying granularity.

In [6]:
def document_cluster_tree(doc_vectors, min_cluster_size=50):
    low_dim_rep = umap.UMAP(
        metric="cosine", n_components=5, min_dist=1e-4, random_state=42, n_epochs=500
    ).fit_transform(doc_vectors)
    clusterer = hdbscan.HDBSCAN(min_cluster_size=min_cluster_size).fit(low_dim_rep)
    tree = clusterer.condensed_tree_.to_pandas()
    return tree

def get_points(tree, cluster_id):
    child_rows = tree[tree.parent == cluster_id]
    result_points = []
    result_lambdas = []
    for i, row in child_rows.iterrows():
        if row.child_size == 1:
            result_points.append(int(row.child))
            result_lambdas.append(row.lambda_val)
        else:
            points, lambdas = get_points(tree, row.child)
            result_points.extend(points)
            result_lambdas.extend(lambdas)
    return result_points, result_lambdas

def get_topic_words(tree, cluster_id, vectors, nn_index, index_to_word_fn):
    row_ids, _ = get_points(tree, cluster_id)
    centroid = np.mean(vectors[row_ids], axis=0)
    dists, inds = nn_index.kneighbors([centroid])
    # A centroid very close (in cosine distance) to the mean of all vectors is not
    # distinctive, so label the cluster as generic instead of listing its words
    if pynndescent.distances.cosine(centroid, np.mean(vectors, axis=0)) < 0.2:
        return ["☼Generic☼"], [np.mean(dists)], len(row_ids)
    keywords = [index_to_word_fn(x) for x in inds[0]]
    return keywords, dists[0], len(row_ids)
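
The topic-word lookup reduces to a cosine nearest-neighbour search from the cluster centroid. The numpy sketch below uses a brute-force search in place of the fitted `nn_index` (an assumption; any index using cosine distance would do) and reuses the same 0.2 cosine-distance threshold as `get_topic_words` to flag a centroid that sits too close to the global mean to be informative. The vectors and vocabulary are invented toy data.

```python
import numpy as np


def cosine_distance(a, b):
    """1 - cosine similarity between two vectors."""
    return 1.0 - np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))


def topic_words_for_cluster(member_doc_vectors, all_doc_vectors, word_vectors,
                            vocab, n_words=3, generic_threshold=0.2):
    """Nearest words (by cosine distance) to the centroid of a cluster's documents."""
    centroid = member_doc_vectors.mean(axis=0)
    # A centroid nearly parallel to the mean of all documents is not distinctive
    if cosine_distance(centroid, all_doc_vectors.mean(axis=0)) < generic_threshold:
        return ["☼Generic☼"]
    # Brute-force cosine distances from the centroid to every word vector
    unit_words = word_vectors / np.linalg.norm(word_vectors, axis=1, keepdims=True)
    dists = 1.0 - unit_words @ (centroid / np.linalg.norm(centroid))
    nearest = np.argsort(dists)[:n_words]
    return [vocab[i] for i in nearest]


# Toy 3D vectors: roughly one axis each for space, hockey and hardware
vocab = ["orbit", "rocket", "puck", "goal", "gpu", "driver"]
word_vectors = np.array([
    [1.0, 0.1, 0.0], [0.9, 0.2, 0.1],
    [0.1, 1.0, 0.0], [0.0, 0.9, 0.2],
    [0.0, 0.1, 1.0], [0.2, 0.0, 0.9],
])
space_docs = np.array([[1.0, 0.1, 0.05], [0.95, 0.15, 0.1]])
hockey_docs = np.array([[0.1, 1.0, 0.1], [0.05, 0.9, 0.1]])
hardware_docs = np.array([[0.1, 0.05, 1.0], [0.1, 0.1, 0.9]])
all_docs = np.vstack([space_docs, hockey_docs, hardware_docs])

print(topic_words_for_cluster(space_docs, all_docs, word_vectors, vocab, n_words=2))
# ['orbit', 'rocket']
```

Passing every document as the "cluster" makes the centroid coincide with the global mean, which is exactly the case the generic label is meant to catch.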

But how do we visualize all of this? We can generate topic words for any cluster, and we can also generate a low-dimensional representation of the documents, allowing us to place clusters in a 2D plot. The trick now is to represent each cluster by its topic words in a somewhat space-filling way. The tool for doing that is word clouds, specifically the wordcloud package, which supports shaped masks in word cloud generation. Our goal will be to generate a shape mask based on the cluster and then use the wordcloud package to pack the topic words into that shape. We can then overlay all these word clouds according to the 2D layout and hopefully produce a useful visualization. This is a surprisingly large amount of work (because there are a lot of parts here). We'll start with a bucket load of imports.

In [7]:
import sklearn.cluster
from sklearn.neighbors import NearestNeighbors, KernelDensity
from matplotlib.colors import rgb2hex, Normalize
from skimage.transform import rescale
import wordcloud

Now we need a lot of helper functions. Most of this is simply cobbled together as the simplest coding solution (rather than the most efficient implementation) to extract the relevant data. One thing we will need, to make visualizing the hierarchy at least somewhat tractable, is the ability to slice it into layers: for a given lambda value we want the clusters that were born at or before that value and have not yet split into child clusters (leaf clusters never split). Since we'll need both the clusters at a given layer and their epsilon values (for better estimating kernel bandwidths per cluster later on), we'll need a function for each. We'll also have to be able to convert those cluster selections into actual label vectors, as one might expect to get out of ``fit_predict``. None of this is that hard, just a little tedious; it helps to be somewhat familiar with the inner workings of HDBSCAN's condensed tree to make this easier to parse.

In [8]:
def clusters_at_level(tree_df, level):
    clusters = tree_df[tree_df.child_size > 1]
    cluster_born_before_level = clusters.child[clusters.lambda_val <= level]
    cluster_dies_after_level = clusters.parent[clusters.lambda_val > level]
    cluster_is_leaf = np.setdiff1d(clusters.child, clusters.parent)
    clusters_lives_after_level = np.union1d(cluster_dies_after_level, cluster_is_leaf)
    result = np.intersect1d(cluster_born_before_level, clusters_lives_after_level)
    return result

def clusters_eps_at_level(tree_df, level):
    # Same clusters as clusters_at_level; epsilon is 1 / lambda at each cluster's birth
    clusters = tree_df[tree_df.child_size > 1]
    chosen_clusters = clusters_at_level(tree_df, level)
    result = [
        (1.0 / clusters.lambda_val[clusters.child == cid].values[0])
        for cid in chosen_clusters
    ]
    return result

def create_labels(tree, cluster_ids, n_points):
    result = np.full(n_points, -1)
    for i, cluster_id in enumerate(cluster_ids):
        point_ids, _ = get_points(tree, cluster_id)
        result[point_ids] = i
    return result
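
The level-slicing logic is easier to check on a tiny hand-made condensed tree. The sketch below restates `clusters_at_level`, `get_points` and `create_labels` in plain Python over `(parent, child, lambda_val, child_size)` tuples, the same four columns that `condensed_tree_.to_pandas()` produces. The tree values are invented, the `toy_*` names are hypothetical, and the point collection is written iteratively rather than recursively, which is just a design choice of this sketch.

```python
# Toy condensed tree rows: (parent, child, lambda_val, child_size); values invented.
# Root 10 splits into 11 and 12 at lambda 0.5; 12 splits into 13 and 14 at 0.8.
TOY_TREE = [
    (10, 11, 0.5, 3), (10, 12, 0.5, 4),
    (11, 0, 1.0, 1), (11, 1, 1.2, 1), (11, 2, 1.5, 1),
    (12, 13, 0.8, 2), (12, 14, 0.8, 2),
    (13, 3, 1.1, 1), (13, 4, 1.3, 1),
    (14, 5, 1.4, 1), (14, 6, 1.6, 1),
]


def toy_clusters_at_level(tree, level):
    """Clusters born at or before `level` that have not split by `level`."""
    cluster_rows = [r for r in tree if r[3] > 1]
    children = {r[1] for r in cluster_rows}
    parents = {r[0] for r in cluster_rows}
    born = {r[1] for r in cluster_rows if r[2] <= level}
    splits_later = {r[0] for r in cluster_rows if r[2] > level}
    leaves = children - parents  # leaf clusters never split
    return sorted(born & (splits_later | leaves))


def toy_points_under(tree, cluster_id):
    """All point ids in the subtree below cluster_id."""
    points, stack = [], [cluster_id]
    while stack:
        node = stack.pop()
        for parent, child, _, size in tree:
            if parent == node:
                if size == 1:
                    points.append(child)
                else:
                    stack.append(child)
    return sorted(points)


def toy_labels(tree, cluster_ids, n_points):
    """Label vector: i for points under cluster_ids[i], -1 for everything else."""
    labels = [-1] * n_points
    for label, cid in enumerate(cluster_ids):
        for p in toy_points_under(tree, cid):
            labels[p] = label
    return labels


print(toy_clusters_at_level(TOY_TREE, 0.6))  # [11, 12]
print(toy_clusters_at_level(TOY_TREE, 0.9))  # [11, 13, 14]
print(toy_labels(TOY_TREE, [11, 13, 14], 7))  # [0, 0, 0, 1, 1, 2, 2]
```

At level 0.6 cluster 12 is still intact; by 0.9 it has split into 13 and 14, so the finer layer has three clusters.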

The next problem is colours. Colours are hard to get right. We want colours to be meaningful across different layers. The easiest way to do that is to assign a colour scheme to the leaves of the cluster tree in some vaguely consistent way, and then average colours together to get colours for clusters in higher layers. Since we will have a *lot* of leaf clusters, we'll need a huge palette. As a hack to ensure that the colour averaging doesn't produce nothing but muddy browns, we can use some thematic colour blends and ensure they are well placed with regard to our 2D layout of the leaf clusters. Since I put together four different colour blends, we can use KMeans to split the leaf clusters into four groups in 2D and then assign colours from a different blend palette within each group. This is perhaps more complicated than it needs to be, but it makes the aesthetics of the final plots a lot nicer.

In [9]:
def avg_colors(hex_colors, weights):
    rgb_colors = matplotlib.colors.to_rgba_array(hex_colors)
    result_color = np.sqrt(np.average(rgb_colors ** 2, weights=weights, axis=0))
    return matplotlib.colors.to_hex(result_color[:3])


def create_leaf_color_key(clusterer, doc_vectors):
    n_leaves = np.max(clusterer.labels_) + 1
    # 2D layout of the documents, used to place the leaf clusters
    embedding_rep = umap.UMAP(metric="cosine", random_state=42).fit_transform(
        doc_vectors
    )
    embedding_leaf_centroids = np.array(
        [np.mean(embedding_rep[clusterer.labels_ == i], axis=0) for i in range(n_leaves)]
    )
    leaf_nbrs = NearestNeighbors(n_neighbors=3, metric="euclidean").fit(
        embedding_leaf_centroids
    )
    # Split the leaf clusters into 4 spatial groups; each group gets its own blend palette
    kmeans_classes = sklearn.cluster.KMeans(n_clusters=4, random_state=42).fit_predict(
        embedding_leaf_centroids
    )
    palette_anchors = [
        ["#fbbabd", "#a566cc", "#51228d"],
        ["#ffefa0", "#fd7034", "#9d0d14"],
        ["#a0f0d0", "#4093bf", "#084d96"],
        ["#e0f3a4", "#66cc66", "#006435"],
    ]
    # Renumber leaves so each KMeans group occupies a consecutive block of indices,
    # matching the order in which the palettes are concatenated into color_key
    km_based_labelling = np.zeros(n_leaves, dtype=np.int64)
    color_key = []
    offset = 0
    for k, anchors in enumerate(palette_anchors):
        in_group = kmeans_classes == k
        n_in_group = int(np.sum(in_group))
        km_based_labelling[in_group] = np.arange(n_in_group) + offset
        offset += n_in_group
        color_key.extend(sns.blend_palette(anchors, n_in_group).as_hex())
    # Maps leaf cluster label -> index into color_key
    cluster_order = dict(zip(range(n_leaves), km_based_labelling))
    return color_key, cluster_order, leaf_nbrs, embedding_rep

def create_cluster_layer_color_key(tree, layer, embedding_rep, leaf_nbrs, leaf_color_key, leaf_dict):
    cluster_labels = create_labels(tree, layer, embedding_rep.shape[0])
    cluster_centroids = np.array(
        [
            np.mean(embedding_rep[cluster_labels == i], axis=0)
            for i in range(np.max(cluster_labels) + 1)
        ]
    )
    leaf_dists, leaf_inds = leaf_nbrs.kneighbors(cluster_centroids)
    leaf_dists += np.finfo(np.float32).eps
    color_key = [
        avg_colors(
            [leaf_color_key[leaf_dict[x]] for x in leaves],
            np.nan_to_num(1.0 / (leaf_dists[i])),
        )
        for i, leaves in enumerate(leaf_inds)
    ]
    return cluster_labels, color_key
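
`avg_colors` blends colours with a root-mean-square average (average the squared channels, then take the square root) rather than a plain average, and `create_cluster_layer_color_key` weights each of the three nearest leaf colours by inverse distance. The sketch below, with hypothetical helper names and plain-Python hex conversion in place of matplotlib, shows why the RMS blend avoids muddy results: sRGB values are roughly gamma-encoded, so squaring them approximates averaging in linear light (an approximation; sRGB's actual transfer curve is closer to gamma 2.2).

```python
import numpy as np


def hex_to_rgb(hex_color):
    """'#rrggbb' -> array of three floats in [0, 1]."""
    hex_color = hex_color.lstrip("#")
    return np.array([int(hex_color[i:i + 2], 16) / 255.0 for i in (0, 2, 4)])


def rgb_to_hex(rgb):
    """Array of three floats in [0, 1] -> '#rrggbb'."""
    return "#" + "".join(f"{int(round(c * 255)):02x}" for c in rgb)


def rms_average_colors(hex_colors, weights):
    """Root-mean-square blend, as in avg_colors."""
    rgb = np.array([hex_to_rgb(c) for c in hex_colors])
    return rgb_to_hex(np.sqrt(np.average(rgb ** 2, weights=weights, axis=0)))


def linear_average_colors(hex_colors, weights):
    """Naive per-channel weighted mean, for comparison."""
    rgb = np.array([hex_to_rgb(c) for c in hex_colors])
    return rgb_to_hex(np.average(rgb, weights=weights, axis=0))


print(rms_average_colors(["#ff0000", "#00ff00"], [1, 1]))     # -> #b4b400
print(linear_average_colors(["#ff0000", "#00ff00"], [1, 1]))  # -> #808000
```

The RMS blend of pure red and pure green is a brighter yellow, while the plain average drops to a dark olive, which is exactly the "muddy brown" effect the colour scheme tries to avoid.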

Now for plotting. To make this useful it really needs to be interactive. I experimented with a few options for this, but Bokeh was the easiest for me to get quick results with. Ideally PyDeck would do a good job of this, but I struggled to get the word clouds working well with PyDeck, likely due to my lack of expertise in deck.gl. So, Bokeh it is.

In [10]:
from bokeh.io import curdoc, show, output_notebook, output_file
from bokeh.models import (
    ColumnDataSource,
    Grid,
    LinearAxis,
    Plot,
    Text,
    CustomJS,
    ImageRGBA,
    Range1d,
    Slider,
    DataTable,
    TableColumn,
    HTMLTemplateFormatter,
    Div,
    LassoSelectTool,
    TapTool,
    BoxSelectTool,
    Button,
)
from bokeh.plotting import figure, Figure
from bokeh.layouts import column, row

Now we need some plotting helper functions. First something to generate the word cloud and populate data for a bokeh ColumnDataSource with the relevant information. We'll also need to be able to generate a KDE for each cluster and from that generate: a mask for the word cloud; and a "glow" effect based on the KDE. The latter is handled by a bokeh ``ImageRGBA`` class, but could equally well be handled by a ``contourf`` style effect in matplotlib or a ``Heatmap`` in PyDeck if we were using those instead.
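
The KDE-to-mask-and-glow step can be sketched with plain numpy. This is a brute-force stand-in for sklearn's ``KernelDensity`` (same 2D Gaussian normalization), it skips the 4x ``rescale`` upsampling, and the data, grid and function names below are invented for illustration. The 2e-2 density threshold and the peak alpha of 128 are the values used in ``kde_for_cluster``; wordcloud treats pure white (255) mask pixels as off-limits, which is why low-density pixels are set to 255.

```python
import numpy as np


def gaussian_kde_grid(points, bandwidth, xs, ys):
    """Density of an isotropic 2D Gaussian KDE evaluated on the grid xs x ys."""
    xv, yv = np.meshgrid(xs, ys[::-1])  # top row = largest y, as in an image
    grid = np.stack([xv.ravel(), yv.ravel()], axis=1)
    sq_dists = ((grid[:, None, :] - points[None, :, :]) ** 2).sum(axis=-1)
    kernel = np.exp(-0.5 * sq_dists / bandwidth ** 2) / (2 * np.pi * bandwidth ** 2)
    return kernel.mean(axis=1).reshape(xv.shape)


def wordcloud_mask(density, threshold=2e-2):
    """255 where density is low (blocked for words), 0 where words may be drawn."""
    return ((density < threshold) * 0xFF).astype(np.uint8)


def glow_alpha(density):
    """Alpha channel proportional to density, peaking at 128."""
    return np.round(128 * density / density.max()).astype(np.uint8)


rng = np.random.default_rng(42)
cluster = rng.normal(loc=0.0, scale=0.3, size=(200, 2))  # one toy 2D cluster
xs = np.linspace(-3, 3, 61)
ys = np.linspace(-3, 3, 61)
density = gaussian_kde_grid(cluster, bandwidth=0.2, xs=xs, ys=ys)
mask = wordcloud_mask(density)
alpha = glow_alpha(density)
```

`glow_alpha` is the same quantity as the notebook's `128 * exp(zv - max(zv))`, just written on the density rather than the log density.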

Lastly we have a giant function to handle generating all the data and plot pieces for a single cluster layer: the word clouds, the ``ImageRGBA`` for a glow effect, and the appropriately coloured scatterplot of the individual documents. We also need a custom JavaScript callback that scales the word cloud text size with the zoom level, so that we can "zoom in" and see the smaller words in the word clouds.

**Note**: the glow effect has been disabled in this version, as it significantly bloats the resulting HTML output (storing all that image data costs a lot of memory). It works fine locally, but is poor for putting on a remote site. The code has been left in (but commented out) so it can be re-enabled easily.

In [20]:
def add_word_cloud(column_data, word_cloud, size, extent, color, font_scaling=0.66):
    # size is the (height, width) of the word cloud canvas in pixels;
    # extent is (xmin, xmax, ymin, ymax) in plot (data) coordinates
    raw_height = size[0]
    raw_width = size[1]
    height = extent[3] - extent[2]
    width = extent[1] - extent[0]
    x_scaling = width / raw_width
    y_scaling = height / raw_height
    max_scaling = max(x_scaling, y_scaling)
    # Each layout_ entry is ((word, freq), font_size, (pixel_row, pixel_col), orientation, color)
    for (word, _), font_size, position, orientation, _ in word_cloud.layout_:
        rotated = orientation is not None
        # Font size in data units; the zoom callback converts it to pixels
        font_size_str = f"{font_size * font_scaling * max_scaling}px"
        column_data["x"].append(position[1] * x_scaling + extent[0])
        # Pixel rows grow downwards while plot y grows upwards, hence the flip
        column_data["y"].append((raw_height - position[0]) * y_scaling + extent[2])
        column_data["text"].append(word)
        column_data["angle"].append(np.pi / 2 if rotated else 0.0)
        column_data["align"].append("right" if rotated else "left")
        column_data["baseline"].append("top")
        column_data["color"].append(color)
        column_data["base_size"].append(font_size_str)
        column_data["current_size"].append(font_size_str)
    return column_data


def kde_for_cluster(
    cluster_embedding, approx_patch_size, cluster_epsilon, kernel_bandwidth_multiplier, color
):
    kernel_bandwidth = min(
        kernel_bandwidth_multiplier * np.power(cluster_epsilon, 0.75),
        kernel_bandwidth_multiplier,
    )
    xmin, xmax = (
        np.min(cluster_embedding.T[0]) - 8 * kernel_bandwidth,
        np.max(cluster_embedding.T[0]) + 8 * kernel_bandwidth,
    )
    ymin, ymax = (
        np.min(cluster_embedding.T[1]) - 8 * kernel_bandwidth,
        np.max(cluster_embedding.T[1]) + 8 * kernel_bandwidth,
    )
    width = xmax - xmin
    height = ymax - ymin
    aspect_ratio = width / height
    patch_size = min(
        max(max(width, height) * approx_patch_size / 6.0, approx_patch_size), 256
    )
    patch_width = int(patch_size * aspect_ratio)
    patch_height = int(patch_size)
    xs = np.linspace(xmin, xmax, patch_width)
    ys = np.linspace(ymin, ymax, patch_height)
    xv, yv = np.meshgrid(xs, ys[::-1])
    for_scoring = np.vstack([xv.ravel(), yv.ravel()]).T

    cluster_kde = KernelDensity(bandwidth=kernel_bandwidth, kernel="gaussian").fit(
        cluster_embedding
    )
    base_zv = cluster_kde.score_samples(for_scoring).reshape(xv.shape)
    zv = rescale(base_zv, 4)
    mask = (np.exp(zv) < 2e-2) * 0xFF

    img = np.empty((zv.shape[0], zv.shape[1]), dtype=np.uint32)
    view = img.view(dtype=np.uint8).reshape((zv.shape[0], zv.shape[1], 4))
    view[:, :, :] = (
        255
        * np.tile(
            matplotlib.colors.to_rgba(color), (zv.shape[0], zv.shape[1], 1),
        )
    ).astype(np.uint8)
    view[:, :, 3] = np.round(128 * np.exp(zv - np.max(zv))).astype(np.uint8)

    return mask, img, (xmin, xmax, ymin, ymax)


def topic_word_by_cluster_layer(
    plot,
    layer_idx,
    doc_vectors,
    word_vectors,
    cluster_labels,
    cluster_epsilons,
    umap_embedding,
    index_to_word_fn,
    color_key,
    n_neighbors=150,
    kernel_bandwidth_multiplier=0.2,
    approx_patch_size=64,
):
    unique_clusters = np.unique(cluster_labels)
    unique_clusters = unique_clusters[unique_clusters >= 0]
    word_nbrs = NearestNeighbors(n_neighbors=n_neighbors, metric="cosine").fit(
        word_vectors
    )
    cluster_centroids = [
        np.mean(doc_vectors[cluster_labels == label], axis=0)
        for label in unique_clusters
    ]
    topic_word_dists, topic_word_indices = word_nbrs.kneighbors(cluster_centroids)

    word_cloud_source = dict(
        x=[],
        y=[],
        text=[],
        angle=[],
        align=[],
        baseline=[],
        color=[],
        base_size=[],
        current_size=[],
    )

    img_source = dict(image=[], x=[], y=[], dh=[], dw=[])

    for i, label in enumerate(unique_clusters):
        topic_words_and_freqs = {
            index_to_word_fn(idx): (1.0 - topic_word_dists[i, j])
            for j, idx in enumerate(topic_word_indices[i])
        }

        cluster_embedding = umap_embedding[cluster_labels == label]

        mask, img, extent = kde_for_cluster(
            cluster_embedding,
            approx_patch_size,
            cluster_epsilons[i],
            kernel_bandwidth_multiplier,
            color_key[label],
        )

        img_source["image"].append(img[::-1])
        img_source["x"].append(extent[0])
        img_source["y"].append(extent[2])
        img_source["dw"].append(extent[1] - extent[0])
        img_source["dh"].append(extent[3] - extent[2])

        color_func = lambda *args, **kwargs: color_key[label]
        wc = wordcloud.WordCloud(
            font_path="/home/leland/.fonts/consola.ttf",
            mode="RGBA",
            relative_scaling=1,
            min_font_size=1,
            max_font_size=128,
            background_color=None,
            color_func=color_func,
            mask=mask,
        )
        wc.fit_words(topic_words_and_freqs)
        word_cloud_source = add_word_cloud(
            word_cloud_source,
            wc,
            color=color_key[label],
            size=img.shape[:2],
            extent=extent,
        )
    scatter_source = ColumnDataSource(
        dict(
            x=umap_embedding.T[0],
            y=umap_embedding.T[1],
            color=[
                color_key[label] if label >= 0 else "#aaaaaa"
                for label in cluster_labels
            ],
        )
    )
#     img_source = ColumnDataSource(img_source)
    word_cloud_source = ColumnDataSource(word_cloud_source)
    scatter_renderer = plot.circle(
        x="x",
        y="y",
        color="color",
        radius=2e-2,
        alpha=0.25,
        line_alpha=0.0,
        level="glyph",
        source=scatter_source,
        tags=[f"layer{layer_idx}"],
        selection_alpha=1.0,
    )
#     image_renderer = plot.image_rgba(
#         image="image",
#         x="x",
#         y="y",
#         dw="dw",
#         dh="dh",
#         source=img_source,
#         level="underlay",
#         tags=[f"layer{layer_idx}"],
#     )
    glyph = Text(
        x="x",
        y="y",
        text="text",
        angle="angle",
        text_color="color",
        text_font={"value": "Consolas"},
        text_font_size="current_size",
        text_align="align",
        text_baseline="baseline",
        text_line_height=1.0,
        tags=[f"layer{layer_idx}"],
    )
    text_renderer = plot.add_glyph(word_cloud_source, glyph, level="glyph")
    text_callback = CustomJS(
        args=dict(source=word_cloud_source),
        code="""
        var data = source.data;
        var scale = 1.0 / ((cb_obj.end - cb_obj.start) / 800);
        var base_size = data['base_size'];
        var current_size = data['current_size'];
        for (var i = 0; i < base_size.length; i++) {
            current_size[i] = (scale * parseFloat(base_size[i])) + "px";
        }
        source.change.emit();
    """,
    )
    plot.x_range.js_on_change("start", text_callback)
    plot.lod_threshold = None
    plot.background_fill_color = "black"
    plot.axis.ticker = []
    plot.grid.grid_line_color = None

#     return [text_renderer, scatter_renderer, image_renderer], scatter_source
    return [text_renderer, scatter_renderer], scatter_source
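
Two coordinate conversions make the word clouds line up with the scatterplot. ``add_word_cloud`` maps each word's pixel position on the word cloud canvas into plot coordinates (flipping the vertical axis) and stores its font size in data units, and the ``CustomJS`` callback multiplies that by 800 / (x_range.end - x_range.start), the number of screen pixels per data unit for an 800-pixel-wide plot. The sketch below restates both in Python; the function names are hypothetical.

```python
def pixel_to_data(pixel_row, pixel_col, canvas_shape, extent):
    """Map a word cloud pixel position to plot coordinates.

    canvas_shape is (height, width) in pixels; extent is (xmin, xmax, ymin, ymax).
    Pixel rows grow downwards while plot y grows upwards, hence the flip.
    """
    raw_height, raw_width = canvas_shape
    xmin, xmax, ymin, ymax = extent
    x = pixel_col * (xmax - xmin) / raw_width + xmin
    y = (raw_height - pixel_row) * (ymax - ymin) / raw_height + ymin
    return x, y


def zoomed_font_px(base_size_data_units, x_start, x_end, plot_width_px=800):
    """Screen font size: data units times screen pixels per data unit."""
    return base_size_data_units * plot_width_px / (x_end - x_start)


print(pixel_to_data(0, 0, (100, 200), (0.0, 2.0, 0.0, 1.0)))  # (0.0, 1.0)
print(zoomed_font_px(0.5, 0.0, 10.0))  # 40.0
print(zoomed_font_px(0.5, 0.0, 5.0))   # 80.0
```

Halving the visible x range doubles the on-screen font size, which is what lets the smaller topic words become readable when zooming in.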

Okay, now we are ready. We set an output for the plot.

In [21]:
output_file("bokeh_20newsgroups_topics_map_20210526_compressed.html", title="Topic Map of 20 Newsgroups")

From here it is just a matter of building the final plot layer by layer. We can add a slider to allow interactively moving through layers, a way to view selected posts in an HTML div, and, while we are at it, a button to download the contents of selected posts. This ends up being a lot of code, but much of it is plotting boilerplate and setting up the various interactions so that they are all handled by JavaScript callbacks, making the final HTML output entirely self-contained.

In [23]:
doc_vectors = lat_doc_vectors
low_dim_rep = umap.UMAP(
    metric="cosine", n_components=5, min_dist=1e-4, random_state=42, n_epochs=500
).fit_transform(doc_vectors)
clusterer = hdbscan.HDBSCAN(min_cluster_size=25, cluster_selection_method="leaf").fit(
    low_dim_rep
)
tree = clusterer.condensed_tree_.to_pandas()
max_lambda = tree.lambda_val[tree.child_size > 1].max()
min_lambda = tree.lambda_val[tree.child_size > 1].min()
# Seven evenly spaced lambda values strictly between the min and max split levels
levels = np.linspace(min_lambda, max_lambda, 9, endpoint=True)[1:-1]
layers = [clusters_at_level(tree, level) for level in levels]
epsilons = [clusters_eps_at_level(tree, level) for level in levels]
leaf_color_key, leaf_dict, leaf_nbrs, embedding_rep = create_leaf_color_key(
    clusterer, doc_vectors
)


layer_plot_elements = []
scatterplot_sources = []
plot = Figure(title="20-Newsgroups Topic Map Explorer", plot_width=800, plot_height=800)
lasso_selector = LassoSelectTool()
plot.add_tools(lasso_selector)
plot.add_tools(TapTool())
plot.add_tools(BoxSelectTool())

for layer_idx, layer in enumerate(layers):

    cluster_labels, color_key = create_cluster_layer_color_key(
        tree, layer, embedding_rep, leaf_nbrs, leaf_color_key, leaf_dict
    )
    layer_renderers, scatter_source = topic_word_by_cluster_layer(
        plot,
        layer_idx,
        lat_doc_vectors,
        lat_word_vectors,
        cluster_labels,
        epsilons[layer_idx],
        embedding_rep,
        lambda x: word_vectorizer.token_index_dictionary_[x],
        color_key,
        n_neighbors=int(2000 / (1 + layer_idx)),
    )
    layer_plot_elements.append(layer_renderers)
    scatterplot_sources.append(scatter_source)

for layer_elements in layer_plot_elements[1:]:
    for element in layer_elements:
        element.visible = False

document_source = ColumnDataSource(
    dict(document=news_data, newsgroup=[news.target_names[x] for x in targets])
)
div_of_text = Div(
    text="<h3 style='color:#2F2F2F;text-align:center;padding:150px 0px;'>up to 100 selected posts display here</h3>",
    width=800,
    height=600,
    style={"overflow-y": "scroll", "height": "350px", "width": "780px"},
)

slider_callback = CustomJS(
    args=dict(layers=layer_plot_elements),
    code="""
        var selected_layer = cb_obj.value;
        for (var i = 0; i < layers.length; i++) {
            for (var j = 0; j < layers[i].length; j++) {
                if (selected_layer - 1 == i) {
                    layers[i][j].visible = true;
                } else {
                    layers[i][j].visible = false;
                }
            }
        }
""",
)
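selection_callback_div = CustomJS(
    args=dict(document_source=document_source, div=div_of_text),
    code="""
        var inds = cb_obj.indices;
        var d1 = document_source.data;
        // Escape HTML special characters so post text is shown literally
        function escape_html(s) {
            return s.replace(/&/g, "&amp;").replace(/</g, "&lt;").replace(/>/g, "&gt;");
        }
        // Build the HTML once, rather than updating the div once per post
        var html = "";
        for (var i = 0; i < inds.length && i < 100; i++) {
            html += "<h3 style='text-align:center;color:#2F2F2F;'>" + d1['newsgroup'][inds[i]] + "</h3>";
            html += "<pre style='color:#444444;background-color:#dddddd;'>" + escape_html(d1['document'][inds[i]]) + "</pre><p/>";
        }
        div.text = html;
    """,
)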
for scatter_source in scatterplot_sources:
    scatter_source.selected.js_on_change("indices", selection_callback_div)

plot.title.text = "20-Newsgroups Topic Map Explorer"
plot.title.text_font_size = "26px"
plot.title.align = "center"
plot.title.text_color = "#3F3F3F"

download_callback = CustomJS(
    args=dict(document_source=document_source, scatter_sources=scatterplot_sources),
    code="""
function download(filename, content) {
  var element = document.createElement('a');
  element.setAttribute('href', 'data:text/plain;charset=utf-8,' + encodeURIComponent(content));
  element.setAttribute('download', filename);

  element.style.display = 'none';
  document.body.appendChild(element);

  element.click();

  document.body.removeChild(element);
}

// Find selection; get content
var csv_content = "";
var docs = document_source.data;
for (var i = 0; i < scatter_sources.length; i++) {
    var sel_inds = scatter_sources[i].selected.indices;
    for (var j = 0; j < sel_inds.length; j++) {
        var ind = sel_inds[j];
        var doc_content = docs['document'][ind].replace(/\\n/g, "\\\\n").replace(/"/g, "'")
        csv_content += ind.toString() + "," + docs['newsgroup'][ind] + ',"' + doc_content + '"\\n';
    }
}

// Start file download.
download("selected_posts.csv", csv_content);
""",
)

layer_slider = Slider(
    start=1,
    end=len(layers),
    value=1,
    step=1,
    title="Cluster Layer (deeper layers have finer clustering)",
)
layer_slider.js_on_change("value", slider_callback)
download_button = Button(label="Download selected posts", button_type="success")
download_button.js_on_click(download_callback)
layout = column(plot, layer_slider, div_of_text, download_button)
show(layout)
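
The download callback builds its CSV by hand, replacing double quotes with single quotes and newlines with a literal `\n`, which is lossy. If you post-process selections in Python instead, standard CSV quoting (doubling embedded quotes and keeping newlines inside quoted fields) preserves the text exactly. The sketch below is that offline alternative using Python's `csv` module, not what the notebook's JavaScript does; `selected_posts_csv` and the sample posts are hypothetical.

```python
import csv
import io


def selected_posts_csv(indices, documents, newsgroups):
    """Write index, newsgroup, document rows with standard CSV quoting."""
    buffer = io.StringIO()
    writer = csv.writer(buffer)  # default dialect doubles embedded quotes
    writer.writerow(["index", "newsgroup", "document"])
    for ind in indices:
        writer.writerow([ind, newsgroups[ind], documents[ind]])
    return buffer.getvalue()


docs = ['He said "hello"\nthen left.', "Plain post"]
groups = ["rec.autos", "sci.space"]
text = selected_posts_csv([0, 1], docs, groups)
rows = list(csv.reader(io.StringIO(text)))  # round trip back to fields
```

Reading the output back with `csv.reader` recovers the original post, quotes and line breaks included.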