<a href="https://colab.research.google.com/github/casllmproject/bending_effect/blob/main/A2_1_SBERT_Topic_Modeling_Simul_data.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This solution uses the BERTopic library, a framework that combines SBERT embeddings, UMAP for dimensionality reduction, and HDBSCAN for clustering to discover topics. Because BERTopic clusters the SBERT embeddings directly, a single pipeline covers both text clustering with SBERT (step 7) and topic modeling (step 5).

In [None]:
# Mount Google Drive so files under /content/drive/MyDrive become readable
# from this Colab runtime.
from google.colab import drive

MOUNT_POINT = '/content/drive'
drive.mount(MOUNT_POINT)

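Block 1: Import Libraries
This block imports the necessary packages (bertopic, sentence-transformers, and their dependencies). The packages must already be installed in the runtime (e.g. `pip install bertopic sentence-transformers`); Google Drive is mounted by the cell above so the dataset can be accessed.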

In [None]:
# Import necessary libraries
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics import silhouette_score
from hdbscan import HDBSCAN
from bertopic import BERTopic
from google.colab import drive
import plotly.offline as pyo

# Enable Plotly's offline (notebook) mode so interactive figures can render
# inline in Colab. connected=True asks for the CDN-hosted plotly.js
# (NOTE(review): this needs internet access from the browser — confirm).
pyo.init_notebook_mode(connected=True)

Block 2: Load and Prepare Data
Here, we load the CSV file, drop rows whose text is missing, and confirm the data structure.

In [None]:
# Define the path to dataset
file_path = "/content/drive/MyDrive/CYON_Analysis_Materials/simulated_responses_re.csv"

# Columns the rest of the notebook relies on: the text column plus the three
# grouping variables analysed later.
required_columns = ['ed_generatedBody', 'gr_per', 'Group', 'DEM8']

# Load the dataset. Only the read itself is guarded, so problems during the
# preparation steps (e.g. a missing column) are reported precisely instead of
# being folded into a generic "error while loading" message.
try:
    df = pd.read_csv(file_path)
except FileNotFoundError:
    print(f"Error: File not found at {file_path}")
    print("Please check the file path and Google Drive permissions.")
except Exception as e:
    print(f"An error occurred while loading the data: {e}")
else:
    # --- Data Preparation ---
    print(f"Original dataset shape: {df.shape}")

    # Fail loudly and specifically: every later cell needs these columns.
    missing_columns = [col for col in required_columns if col not in df.columns]
    if missing_columns:
        raise KeyError(f"Dataset is missing required column(s): {missing_columns}")

    # Drop rows where the text data is missing
    df = df.dropna(subset=['ed_generatedBody'])

    # Ensure the text column is treated as a string
    df['ed_generatedBody'] = df['ed_generatedBody'].astype(str)

    print(f"Cleaned dataset shape: {df.shape}")

    # Display the first few rows and column info
    print("\nDataFrame Head:")
    display(df.head())

    print("\nDataFrame Info:")
    df.info()

    print("\nUnique values in grouping columns:")
    print(f"gr_per: {df['gr_per'].nunique()} categories")
    print(f"Group: {df['Group'].nunique()} categories")
    print(f"DEM8: {df['DEM8'].nunique()} categories")

Block 3: Initialize SBERT Model
Load the SBERT model once here. I'll use 'all-MiniLM-L6-v2', which is a high-performance, fast model suitable for this task.

In [None]:
# Name of the pre-trained SBERT checkpoint. A single shared instance of it
# turns every text in the notebook into a dense numerical embedding.
SBERT_MODEL_NAME = 'all-MiniLM-L6-v2'
sbert_model = SentenceTransformer(SBERT_MODEL_NAME)

print(f"SBERT model '{SBERT_MODEL_NAME}' loaded successfully.")

Block 4: Define Topic Modeling & Clustering Function
To avoid repeating code, I'll create a function that performs the complete analysis for any given grouping variable. This function will:

Loop through each unique value in the group column (e.g., 'Democrat', 'Republican').

Generate SBERT embeddings for that group's texts.

Use BERTopic to cluster the embeddings and identify topics.

Store all the results (model, topics, embeddings, etc.) in a dictionary.

In [None]:
def _drop_outliers(embeddings, topics, outlier_label=-1):
    """Return embeddings and labels of the documents not in the outlier topic.

    Args:
        embeddings (array-like): One embedding row per document.
        topics (Sequence[int]): Topic label per document (BERTopic returns a
            plain Python list here, not a NumPy array).
        outlier_label (int): Label that marks outlier documents.

    Returns:
        tuple[np.ndarray, np.ndarray]: (filtered_embeddings, filtered_labels).
    """
    # Convert first: `list != -1` is a single bool, not an element-wise mask,
    # and indexing with a scalar bool just adds an axis.
    labels = np.asarray(topics)
    mask = labels != outlier_label
    return np.asarray(embeddings)[mask], labels[mask]


def _count_topics(topics, outlier_label=-1):
    """Return the number of distinct non-outlier topics in `topics`."""
    return len(set(topics) - {outlier_label})


def _silhouette_ignoring_outliers(embeddings, topics):
    """Print and return the silhouette score over non-outlier documents.

    Returns None (after printing the reason) when the score is undefined:
    no clustered documents, fewer than two clusters, or sklearn rejects
    the label configuration.
    """
    filtered_embeddings, filtered_labels = _drop_outliers(embeddings, topics)
    if filtered_labels.size == 0:
        print("No clusters found (only outliers). Cannot calculate Silhouette Score.")
        return None
    if len(np.unique(filtered_labels)) < 2:
        print("Only one cluster found (or only outliers). Cannot calculate Silhouette Score.")
        return None
    try:
        score = silhouette_score(filtered_embeddings, filtered_labels)
    except ValueError as e:
        # e.g. as many clusters as samples, which silhouette_score rejects
        print(f"Could not calculate Silhouette Score: {e}")
        return None
    print(f"Silhouette Score (ignoring outliers): {score:.4f}")
    return score


def perform_topic_analysis(dataframe, text_column, group_column, embedding_model,
                           *, min_docs_required=15, min_cluster_size=5, min_samples=2):
    """
    Performs BERTopic analysis for each subgroup within a grouping column.

    For every unique value of `group_column` this:
    1. Embeds the group's texts with `embedding_model.encode`.
    2. Fits a BERTopic model that removes English stop words ("and", "in",
       "of", etc.) and clusters with an HDBSCAN model using 'leaf' selection
       to find many granular, detailed topics.
    3. Calculates and reports the Silhouette Score (outliers excluded).

    A fresh vectorizer and HDBSCAN model are built for each group, because
    BERTopic fits them in place; sharing one instance would leave earlier
    groups' stored models holding components refit on later groups.

    Args:
        dataframe (pd.DataFrame): The main DataFrame.
        text_column (str): The name of the column with text data.
        group_column (str): The name of the grouping column (e.g., 'DEM8').
        embedding_model (SentenceTransformer): The pre-loaded SBERT model.
        min_docs_required (int): Groups with fewer texts are skipped.
        min_cluster_size (int): HDBSCAN minimum cluster (topic) size.
        min_samples (int): HDBSCAN `min_samples`.

    Returns:
        dict: Keys are group values; values are dictionaries with 'model',
              'topics', 'probabilities', 'embeddings',
              'dataframe_with_topics', 'texts' and 'silhouette_score'
              (None when it could not be computed).
    """
    print(f"\n--- Starting Analysis for Grouping Variable: {group_column} ---")

    results_store = {}
    unique_groups = dataframe[group_column].unique()

    for group in unique_groups:
        print(f"\nProcessing group: '{group}' (from {group_column})")

        # Filter DataFrame for the current group
        group_df = dataframe[dataframe[group_column] == group].copy()
        texts = group_df[text_column].tolist()

        # Check if there is enough data for clustering to be meaningful
        if len(texts) < min_docs_required:
            print(f"Skipping group '{group}': only {len(texts)} documents.")
            print(f"Need at least {min_docs_required} for robust clustering.")
            continue

        # Generate SBERT Embeddings
        print(f"Generating embeddings for {len(texts)} documents...")
        embeddings = embedding_model.encode(texts, show_progress_bar=True)

        # Per-group sub-models (see docstring for why they are not shared).
        vectorizer = CountVectorizer(stop_words='english')
        hdbscan_model = HDBSCAN(
            min_cluster_size=min_cluster_size,
            min_samples=min_samples,
            cluster_selection_method='leaf',
            prediction_data=True,
        )
        topic_model = BERTopic(
            verbose=False,
            calculate_probabilities=True,
            vectorizer_model=vectorizer,  # Remove stop words
            hdbscan_model=hdbscan_model,
        )

        # Fit the model (this performs clustering)
        print("Fitting BERTopic model (clustering documents)...")
        topics, probs = topic_model.fit_transform(texts, embeddings=embeddings)

        silhouette_avg = _silhouette_ignoring_outliers(embeddings, topics)

        # Store all results for this group
        group_df['topic'] = topics
        results_store[group] = {
            'model': topic_model,
            'topics': topics,
            'probabilities': probs,
            'embeddings': embeddings,
            'dataframe_with_topics': group_df,
            'texts': texts,
            'silhouette_score': silhouette_avg,
        }

        # Count real topics directly: the outlier topic is not always present.
        num_topics = _count_topics(topics)
        print(f"Finished group '{group}'. Found {num_topics} topics.")

    print(f"\n--- Completed Analysis for {group_column} ---")
    return results_store

Block 5: Run Analysis for All 3 Grouping Variables.

In [None]:
# Run the full per-group topic analysis once per grouping variable. Each call
# returns its own dictionary, keyed by that variable's group values.
text_col = 'ed_generatedBody'

print("Starting analysis for 'gr_per'...")
gr_per_results = perform_topic_analysis(
    dataframe=df, text_column=text_col, group_column='gr_per', embedding_model=sbert_model
)

print("\nStarting analysis for 'Group'...")
group_results = perform_topic_analysis(
    dataframe=df, text_column=text_col, group_column='Group', embedding_model=sbert_model
)

print("\nStarting analysis for 'DEM8'...")
dem8_results = perform_topic_analysis(
    dataframe=df, text_column=text_col, group_column='DEM8', embedding_model=sbert_model
)

print("\n\n--- ALL ANALYSES COMPLETE ---")

In [None]:
# --- Let's inspect the results for the 'Democrat' group ---
target_group_name = 'Democrat'
target_results = dem8_results # Use dem8_results, group_results, or gr_per_results

if target_group_name in target_results:
    print(f"--- Interpreting Topic Modeling Results for: {target_group_name} ---")

    # Get the fitted model for this group
    model = target_results[target_group_name]['model']

    # 1. Get the main Topic Info DataFrame.
    # Topic -1, when present, holds outliers (texts that didn't fit any cluster).
    # It is absent when every text was clustered, so count real topics
    # explicitly instead of assuming one row is always the outlier row.
    topic_info_df = model.get_topic_info()
    is_outlier_row = topic_info_df['Topic'] == -1
    num_topics = int((~is_outlier_row).sum())
    print(f"Found {num_topics} topics for '{target_group_name}'.")
    print("Top 10 Topics (by size):")
    # Up to 10 real topics, plus the outlier row if there is one.
    display(topic_info_df.head(10 + int(is_outlier_row.any())))

    if num_topics >= 1:
        # 2. Get the words for a specific topic (e.g., Topic 0)
        print("\nWords and scores for Topic 0 (the largest topic):")
        print(model.get_topic(0))

        # 3. Visualize Topic Word Scores (Interactive Bar Chart)
        # This shows the most important words for each topic.
        print("\nGenerating interactive Topic Word Bar Chart...")
        # This plot is interactive: you can hover to see scores.
        fig_bar = model.visualize_barchart(top_n_topics=10) # Show top 10 topics
        fig_bar.show()
    else:
        print("\nNo topics found (only outliers); skipping topic words and bar chart.")

    # 4. Visualize Topic Hierarchy (Interactive Dendrogram)
    # This shows how topics (clusters) relate to each other; a dendrogram
    # needs at least two topics to link.
    if num_topics >= 2:
        print("\nGenerating interactive Topic Hierarchy...")
        fig_hierarchy = model.visualize_hierarchy(top_n_topics=20) # Show 20 topics
        fig_hierarchy.show()
    else:
        print("\nFewer than 2 topics; skipping the Topic Hierarchy.")

else:
    print(f"Group '{target_group_name}' was not processed.")
    print("This might be due to insufficient data (less than 15 documents).")
    print(f"Available processed groups: {list(target_results.keys())}")

In [None]:
# --- Let's inspect the results for the 'Republican' group ---
target_group_name = 'Republican'
target_results = dem8_results # Use dem8_results, group_results, or gr_per_results

if target_group_name in target_results:
    print(f"--- Interpreting Topic Modeling Results for: {target_group_name} ---")

    # Get the fitted model for this group
    model = target_results[target_group_name]['model']

    # 1. Get the main Topic Info DataFrame
    # This is the easiest way to see all topics.
    # Topic -1, when present, holds outliers (texts that didn't fit any cluster).
    # It is absent when every text was clustered, so count real topics
    # explicitly instead of assuming one row is always the outlier row.
    topic_info_df = model.get_topic_info()
    is_outlier_row = topic_info_df['Topic'] == -1
    num_topics = int((~is_outlier_row).sum())
    print(f"Found {num_topics} topics for '{target_group_name}'.")
    print("Top 10 Topics (by size):")
    # Up to 10 real topics, plus the outlier row if there is one.
    display(topic_info_df.head(10 + int(is_outlier_row.any())))

    if num_topics >= 1:
        # 2. Get the words for a specific topic (e.g., Topic 0)
        print("\nWords and scores for Topic 0 (the largest topic):")
        print(model.get_topic(0))

        # 3. Visualize Topic Word Scores (Interactive Bar Chart)
        # This shows the most important words for each topic.
        print("\nGenerating interactive Topic Word Bar Chart...")
        # This plot is interactive: you can hover to see scores.
        fig_bar = model.visualize_barchart(top_n_topics=10) # Show top 10 topics
        fig_bar.show()
    else:
        print("\nNo topics found (only outliers); skipping topic words and bar chart.")

    # 4. Visualize Topic Hierarchy (Interactive Dendrogram)
    # This shows how topics (clusters) relate to each other; a dendrogram
    # needs at least two topics to link.
    if num_topics >= 2:
        print("\nGenerating interactive Topic Hierarchy...")
        fig_hierarchy = model.visualize_hierarchy(top_n_topics=20) # Show 20 topics
        fig_hierarchy.show()
    else:
        print("\nFewer than 2 topics; skipping the Topic Hierarchy.")

else:
    print(f"Group '{target_group_name}' was not processed.")
    print("This might be due to insufficient data (less than 15 documents).")
    print(f"Available processed groups: {list(target_results.keys())}")

In [None]:
import matplotlib.pyplot as plt

# Line plot of the 'Count' column with the top and right frame lines hidden.
# NOTE(review): `_df_7` is not created anywhere in this notebook; it looks like a
# Colab auto-generated snapshot of a previously displayed DataFrame. Re-running
# this cell in a fresh runtime raises NameError — confirm which table it meant.
count_column = _df_7['Count']
count_column.plot(kind='line', figsize=(8, 4), title='Count')
current_axes = plt.gca()
current_axes.spines[['top', 'right']].set_visible(False)

In [None]:
from matplotlib import pyplot as plt
import seaborn as sns

# One horizontal violin per topic 'Name', showing the distribution of 'Count'.
# NOTE(review): `_df_9` is not created anywhere in this notebook; it looks like a
# Colab auto-generated snapshot of a displayed DataFrame with 'Name' and 'Count'
# columns. Re-running in a fresh runtime raises NameError — confirm its source.
figsize = (12, 1.2 * len(_df_9['Name'].unique()))
plt.figure(figsize=figsize)
# seaborn >= 0.13 deprecates `palette` without `hue`; mapping the y variable to
# hue with legend=False gives the same per-category colours without the warning.
sns.violinplot(_df_9, x='Count', y='Name', hue='Name', inner='stick', palette='Dark2', legend=False)
sns.despine(top=True, right=True, bottom=True, left=True)

In [None]:
from matplotlib import pyplot as plt
import seaborn as sns

# One horizontal violin per topic 'Name', showing the distribution of 'Topic'.
# NOTE(review): `_df_8` is not created anywhere in this notebook; it looks like a
# Colab auto-generated snapshot of a displayed DataFrame with 'Name' and 'Topic'
# columns. Re-running in a fresh runtime raises NameError — confirm its source.
figsize = (12, 1.2 * len(_df_8['Name'].unique()))
plt.figure(figsize=figsize)
# seaborn >= 0.13 deprecates `palette` without `hue`; mapping the y variable to
# hue with legend=False gives the same per-category colours without the warning.
sns.violinplot(_df_8, x='Topic', y='Name', hue='Name', inner='stick', palette='Dark2', legend=False)
sns.despine(top=True, right=True, bottom=True, left=True)

In [None]:
# --- Define the target results dictionary ---
target_results = group_results # Use the results from the 'Group' analysis
target_categories = [0, 1, 2, 3] # The categories you want to loop through

print("--- Starting Batch Analysis for 'Group' variable ---")

for group_category in target_categories:

    print("\n========================================================")
    print(f"--- Starting Analysis for Group Category: {group_category} ---")
    print("========================================================")

    if group_category in target_results:
        group_result = target_results[group_category]
        model = group_result['model']

        # --- Block 6: Interpret Topic Modeling Results ---
        print(f"\n--- Interpreting Topics for: {group_category} ---")

        # 1. Get the main Topic Info DataFrame.
        # Topic -1 (outliers) is only present when some texts were unclustered,
        # so count real topics explicitly rather than subtracting one row.
        topic_info_df = model.get_topic_info()
        is_outlier_row = topic_info_df['Topic'] == -1
        num_topics = int((~is_outlier_row).sum())
        print(f"Found {num_topics} topics for '{group_category}'.")
        print("Top 10 Topics (by size):")
        # Up to 10 real topics, plus the outlier row if there is one.
        display(topic_info_df.head(10 + int(is_outlier_row.any())))

        # 2. Visualize Topic Word Scores (Interactive Bar Chart)
        if num_topics >= 1:
            print("\nGenerating interactive Topic Word Bar Chart...")
            fig_bar = model.visualize_barchart(top_n_topics=10, title=f"Top Topics for Group {group_category}")
            fig_bar.show()
        else:
            print("\nNo topics found (only outliers); skipping the Topic Word Bar Chart.")

        # 3. Visualize Topic Hierarchy (Interactive Dendrogram);
        # a dendrogram needs at least two topics to link.
        if num_topics >= 2:
            print("\nGenerating interactive Topic Hierarchy...")
            fig_hierarchy = model.visualize_hierarchy(top_n_topics=20, title=f"Topic Hierarchy for Group {group_category}")
            fig_hierarchy.show()
        else:
            print("\nFewer than 2 topics; skipping the Topic Hierarchy.")

        # --- Block 7: Visualizing Text Clustering Results ---
        print(f"\n--- Visualizing Text Clusters for: {group_category} ---")

        # Reuse the texts and embeddings stored during the analysis so the
        # documents are not re-encoded.
        texts_to_plot = group_result['texts']
        embeddings_to_plot = group_result['embeddings']

        # 4. Visualize Document Clusters (Interactive 2D Scatter Plot)
        print("Generating interactive 2D cluster visualization...")
        fig_clusters = model.visualize_documents(
            texts_to_plot,
            embeddings=embeddings_to_plot,
            width=900,
            height=700,
            title=f"Text Clusters for Group {group_category}"
        )
        fig_clusters.show()

    else:
        print(f"\nGroup '{group_category}' was not processed.")
        print("This might be due to insufficient data.")

    print(f"--- Finished Analysis for Group Category: {group_category} ---")

print("\n--- All 'Group' variable analyses complete ---")