# Clustering polls into topics
> Using titles/descriptions of polls for clustering. The goal is an unsupervised grouping into rough topics, based on [this gensim tutorial](https://radimrehurek.com/gensim/auto_examples/core/run_topics_and_transformations.html#sphx-glr-auto-examples-core-run-topics-and-transformations-py).

> Note: You may need to run `python -m spacy download de_core_news_sm`, if not already done, to process the German language.

In [None]:
# Enable IPython's autoreload extension in mode 2, so that imported modules are
# reloaded before each cell execution.
%load_ext autoreload
%autoreload 2

In [None]:
import pprint

import pandas as pd
import polars as pl
from bundestag.fine_logging import setup_logging
import logging
from bundestag.paths import get_paths
from bundestag.data.transform.abgeordnetenwatch.transform import get_polls_parquet_path
from bundestag.ml.poll_clustering import (
    SpacyTransformer,
    clean_text,
    compare_word_frequencies,
)
from gensim.models.coherencemodel import CoherenceModel
from plotnine import (
    aes,
    geom_histogram,
    ggplot,
    labs,
    scale_fill_manual,
    geom_line,
    geom_area,
    geom_bar,
    geom_point,
    scale_x_continuous,
)
from functools import partial

# Notebook-level logger; logging is configured through the project's helper
# with INFO passed as the level.
logger = logging.getLogger(__name__)
setup_logging(logging.INFO)

# Project path container built from "../data" (presumably the data root,
# relative to this notebook - confirm against `get_paths`).
paths = get_paths("../data")
# Bare expression so the notebook displays the resolved paths.
paths

In [None]:
# Legislature ids whose preprocessed poll files are loaded and stacked.
legislature_ids = [67, 83, 97, 111, 132, 161]

# Read one parquet file per legislature and concatenate them into a single frame.
df_polls = pl.concat(
    [
        pl.read_parquet(
            get_polls_parquet_path(lid, paths.preprocessed_abgeordnetenwatch)
        )
        for lid in legislature_ids
    ]
)
df_polls.head(3)

In [None]:
# Parse "poll_date" from text into a Date column. The dtype guard keeps this
# cell re-runnable: calling `.str.to_date()` on an already converted column
# would raise, because the `.str` namespace only applies to String columns.
if df_polls.schema["poll_date"] == pl.String:
    df_polls = df_polls.with_columns(pl.col("poll_date").str.to_date())

## Clustering based on poll title



Sanity checking word counts, longest and shortest titles

### Cleaning using spacy

https://towardsdatascience.com/end-to-end-topic-modeling-in-python-latent-dirichlet-allocation-lda-35ce4ed6b3e0

In [None]:
# !python -m spacy download de_core_news_sm

In [None]:
# Text column used for clustering, and the name of the derived column that
# will hold its NLP-processed version.
col = "poll_title"
nlp_col = col + "_nlp_processed"

In [None]:
# Project transformer; its `nlp` object is handed to `clean_text` for every row.
st = SpacyTransformer()

# Row-wise cleaning of `col`; each result is declared as a list of strings.
clean_with_nlp = partial(clean_text, nlp=st.nlp)
cleaned_expr = pl.col(col).map_elements(
    clean_with_nlp, return_dtype=pl.List(pl.String)
)

# Store the cleaned output next to the original text under `nlp_col`.
df_polls = df_polls.with_columns(cleaned_expr.alias(nlp_col))

In [None]:
# Preview the first rows, now including the processed-text column.
df_polls.head(3)

### Inspecting word frequencies

In [None]:
# Project helper called with the raw text column `col` and its processed
# counterpart `nlp_col` (presumably compares their word frequencies - see
# bundestag.ml.poll_clustering for details).
compare_word_frequencies(df_polls, col, nlp_col)

The word count distribution shifted to lower values, as could be expected, but no documents were left without any words.

### Transforming using LDA

Let's first find a suitable `num_topics` value

In [None]:
# Candidate topic counts to evaluate: 5 through 15 inclusive.
num_topics_grid = list(range(5, 16))


def _score_num_topics(n_topics):
    """Fit LDA on the processed titles with `n_topics` topics and score it.

    Returns one row holding the topic count, the u_mass coherence and the
    value of gensim's `log_perplexity` on the training corpus.
    """
    st.fit_lda(df_polls[nlp_col].to_list(), num_topics=n_topics)
    # compute metrics - https://radimrehurek.com/gensim/models/coherencemodel.html
    coherence_model = CoherenceModel(
        model=st.lda_model, corpus=st.corpus, coherence="u_mass"
    )
    return {
        "num topics": n_topics,
        "coherence": coherence_model.get_coherence(),
        "log perplexity": st.lda_model.log_perplexity(st.corpus),
    }


# One row per (num topics, metric) pair - long format, convenient for plotting.
stats = pl.from_dicts([_score_num_topics(k) for k in num_topics_grid]).unpivot(
    index="num topics", variable_name="metric", value_name="value"
)

In [None]:
# Preview the long-format metrics table (columns: "num topics", "metric", "value").
stats.head()

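Visualize the metrics: a spike in coherence together with a drop in perplexity indicates an interesting number of topics. Note that gensim's `log_perplexity` returns the per-word likelihood bound, so a perplexity drop shows up as a rise in the "log perplexity" curve.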

In [None]:
# One line per metric against the number of topics, with a tick at every grid value.
metrics_plot = ggplot(stats, aes(x="num topics", y="value", color="metric"))
metrics_plot = metrics_plot + geom_line()
metrics_plot = metrics_plot + scale_x_continuous(breaks=num_topics_grid)
metrics_plot

Fit again with the chosen `num_topics` value

In [None]:
# Topic count chosen after inspecting the metrics above.
num_topics = 8
processed_titles = df_polls[nlp_col].to_list()
st.fit_lda(processed_titles, num_topics=num_topics)

In [None]:
# Apply the fitted transformer to the processed-text column; the columns
# listed in `st.nlp_cols` are selected from the result further below
# (presumably the per-topic weights added here - confirm in SpacyTransformer).
df_polls = st.transform(df_polls, col=nlp_col)
df_polls.head()

In [None]:
# Long format: one row per (poll, topic) with the value of that topic column.
id_cols = ["poll_date", "poll_id"]
long_weights = df_polls.select(id_cols + st.nlp_cols).unpivot(
    index=id_cols, variable_name="topic", value_name="weight"
)

# Per date and topic: number of distinct polls and the summed weight ...
per_date = long_weights.group_by(["poll_date", "topic"]).agg(
    pl.col("poll_id").n_unique().alias("n polls"),
    pl.col("weight").sum().alias("weight"),
)
# ... then the summed weight divided by the number of polls.
tmp = per_date.with_columns(
    (pl.col("weight") / pl.col("n polls")).alias("normalized weight")
)
tmp.head()

In [None]:
# Scatter of the normalized topic weight per poll date, colored by topic.
date_mapping = aes(
    x="poll_date", y="normalized weight", color="topic", fill="topic"
)
ggplot(tmp, date_mapping) + geom_point()

In [None]:
# Long format: one row per (poll, topic) with the value of that topic column.
id_cols = ["poll_date", "poll_id"]
long_weights = df_polls.select(id_cols + st.nlp_cols).unpivot(
    index=id_cols, variable_name="topic", value_name="weight"
)

# Group by the year of the poll date (the group key keeps the name
# "poll_date") and by topic: distinct poll count and summed weight ...
poll_year = pl.col("poll_date").dt.year()
per_year = long_weights.group_by([poll_year, "topic"]).agg(
    pl.col("poll_id").n_unique().alias("n polls"),
    pl.col("weight").sum().alias("weight"),
)
# ... then the summed weight divided by the number of polls.
tmp = per_year.with_columns(
    (pl.col("weight") / pl.col("n polls")).alias("normalized weight")
)
tmp.head()

In [None]:
# Area chart of the normalized topic weight per year; "poll_date" holds the
# year here because it was the grouping key of the preceding aggregation.
year_mapping = aes(
    x="poll_date", y="normalized weight", color="topic", fill="topic"
)
ggplot(tmp, year_mapping) + geom_area()

In [None]:
# Pretty-print the transformer's `lda_topics` attribute for manual inspection.
print("Discovered topics:")
pprint.pprint(st.lda_topics)