# Implementation of BERTopic for Cluster Renaming
#### This notebook implements BERTopic to analyze, name, and visualize clusters.

# Cell 1: Setup and Import Libraries
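#### - This cell imports the necessary libraries: os for file paths and environment settings,
####   pandas for data manipulation, datetime for handling dates, BERTopic for topic modeling,
####   and plotly.express for interactive plots. It also selects the GPU to use via CUDA_VISIBLE_DEVICES.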

In [None]:
import os
import pandas as pd
from bertopic import BERTopic
from datetime import datetime
import plotly.express as px

# Make only GPU index 1 visible to CUDA-based libraries used in this notebook.
# NOTE(review): this runs after the BERTopic import above; it only takes effect
# if CUDA has not been initialised yet — confirm the intended GPU is used.
os.environ.update(CUDA_VISIBLE_DEVICES="1")

# Cell 2: Define File Paths and Load Cluster Data
#### - This cell sets the file path for the input data (the cluster data generated in previous steps,
####   whose file name is prefixed with today's date in YYYYMMDD format) and reads it into a pandas
####   DataFrame for further analysis. An error is raised if the file does not exist.
#### - The first few rows of the DataFrame are displayed to ensure the data is loaded correctly.

In [None]:
# Directory holding the input (and, later, the output) CSV files.
# Can be overridden with the FILESPATH environment variable; defaults to the NAS path.
FILESPATH = os.environ.get(
    "FILESPATH",
    "/home/tulipan16372/storage_NAS/Misc/Dani_Amaya/sentence-transformers/",
)
CLUSTERS_DATAFRAME_NAME = "updated_df_cluster.csv"

# Today's date (YYYYMMDD); also used later to name the output file.
current_date = datetime.now().strftime("%Y%m%d")

# Date prefix of the *input* file. Defaults to today, but can be set through the
# CLUSTER_DATE environment variable so the notebook still works when the
# clustering step that produced the file ran on an earlier day.
cluster_date = os.environ.get("CLUSTER_DATE", current_date)

# Construct the file path with the date prefix
df_cluster_path = os.path.join(FILESPATH, f"{cluster_date}_{CLUSTERS_DATAFRAME_NAME}")

# Fail early, with a hint on how to point at a file from another day
if not os.path.isfile(df_cluster_path):
    raise FileNotFoundError(
        f"Cluster data file {df_cluster_path} does not exist. "
        "Set CLUSTER_DATE=YYYYMMDD to load a file produced on another day."
    )

# Load the cluster data
df_cluster = pd.read_csv(df_cluster_path)

# Display the first few rows of the dataframe to verify the data
df_cluster.head()


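# Cell 3: Initialize BERTopic
#### - Initializes the BERTopic model that will be used to identify topics within each cluster.
#### - The same instance is refit on each cluster's documents, so afterwards it reflects only the last cluster processed.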

In [None]:
# One shared BERTopic instance with default settings. analyze_cluster_topics()
# calls fit_transform() on it once per cluster, so it is refit each time and,
# after that loop, reflects only the last cluster it was fitted on.
topic_model = BERTopic()

# Cell 4: Define Function to Analyze Topics in Clusters
#### - This cell defines a function `analyze_cluster_topics()` that loops through each unique cluster,
####   retrieves the documents associated with that cluster, and applies BERTopic to identify the topics.
#### - The function stores the most representative topic for each cluster and assigns it as the cluster name.
#### - Note: Cell 5 redefines a function with the same name, and that later version is the one actually called.

In [None]:
def analyze_cluster_topics(df_cluster, topic_model):
    """Name each cluster after the dominant BERTopic topic among its documents.

    For every distinct value in ``df_cluster['cluster']``, the non-null entries
    of ``df_cluster['documents']`` belonging to it are fitted with
    ``topic_model``. The cluster is then named after the most frequent topic
    that is not BERTopic's outlier topic (-1).

    NOTE: Cell 5 defines a function with the same name again; when the notebook
    runs top to bottom, that later definition replaces this one.

    Args:
        df_cluster: DataFrame with at least ``cluster`` and ``documents`` columns.
        topic_model: A BERTopic instance (anything exposing ``fit_transform``
            and ``get_topic_info``). It is refit for every non-empty cluster.

    Returns:
        dict mapping each cluster value to a name. The name is one of:
        - the chosen topic's ``Name``;
        - ``"No Data"`` for clusters without documents;
        - ``"Unnamed Cluster <n>"`` when only the outlier topic was found.
    """
    cluster_names = {}

    for cluster_number in df_cluster['cluster'].unique():
        print(f"\nAnalyzing Cluster {cluster_number}")

        documents_in_cluster = (
            df_cluster.loc[df_cluster['cluster'] == cluster_number, 'documents']
            .dropna()
            .tolist()
        )

        if not documents_in_cluster:
            cluster_names[cluster_number] = "No Data"
            continue

        # Refit on this cluster only; the per-document topics/probabilities are
        # not needed because get_topic_info() summarises the fitted topics.
        topic_model.fit_transform(documents_in_cluster)
        topic_info = topic_model.get_topic_info()

        # Topic -1 holds outliers and may or may not be present, so a fixed row
        # index is unreliable. Drop it and take the largest remaining topic.
        real_topics = topic_info[topic_info['Topic'] != -1]
        if real_topics.empty:
            top_topic = f"Unnamed Cluster {cluster_number}"
        else:
            top_topic = real_topics.nlargest(1, 'Count').iloc[0]['Name']

        cluster_names[cluster_number] = top_topic
        print(f"Top topic for Cluster {cluster_number}: {top_topic}")

    return cluster_names

# Cell 5: Apply BERTopic and Rename Clusters
#### - This cell redefines `analyze_cluster_topics()`, replacing the Cell 4 version with one that also
####   handles empty clusters and BERTopic `ValueError`s, and then calls it to analyze the clusters.
#### - The identified topics are mapped to their respective clusters and stored in a new `cluster_name` column.
#### - The updated DataFrame with cluster names is displayed for verification.

In [None]:
# Analyze the clusters and assign names
def analyze_cluster_topics(df_cluster, topic_model):
    cluster_names = {}

    # Get unique clusters
    unique_clusters = df_cluster['cluster'].unique()
    
    for cluster_number in unique_clusters:
        # Get the documents belonging to this cluster
        documents_in_cluster = df_cluster[df_cluster['cluster'] == cluster_number]['documents'].dropna().tolist()

        # Check if the cluster contains any documents
        if len(documents_in_cluster) == 0:
            print(f"Cluster {cluster_number} is empty, skipping.")
            cluster_names[cluster_number] = f"Cluster {cluster_number}"  # Default name for empty clusters
            continue

        # Apply BERTopic to identify topics
        try:
            topics, probs = topic_model.fit_transform(documents_in_cluster)
            # Get the top topic for this cluster
            topic_info = topic_model.get_topic_info()

            # Extract the name of the most frequent topic
            if len(topic_info) > 0:
                top_topic = topic_info.iloc[0]['Name']
                cluster_names[cluster_number] = top_topic
            else:
                cluster_names[cluster_number] = f"Unnamed Cluster {cluster_number}"  # Fallback name

        except ValueError as e:
            print(f"Error processing cluster {cluster_number}: {e}")
            cluster_names[cluster_number] = f"Error Cluster {cluster_number}"  # Error fallback name

    return cluster_names

# Name every cluster using the analyze_cluster_topics() defined above
# (returns a dict: cluster value -> name).
cluster_names = analyze_cluster_topics(df_cluster, topic_model)

# Map each row's cluster value to its name in a new 'cluster_name' column
df_cluster['cluster_name'] = df_cluster['cluster'].map(cluster_names)

# Display the first rows to check the new column
df_cluster.head()



# Cell 6: Save the Updated DataFrame with Cluster Names
#### - This cell saves the updated DataFrame, which includes the new cluster names, to a CSV file for future use.
#### - The file path of the saved DataFrame is printed for confirmation.

In [None]:
# Write the dataframe, now including 'cluster_name', next to the input data.
# The output file carries today's date as a prefix.
# NOTE(review): the "Matt_" part of the file name is hard-coded.
updated_cluster_name_path = os.path.join(
    FILESPATH,
    f"{current_date}_Matt_updated_cluster_names.csv",
)
df_cluster.to_csv(updated_cluster_name_path, index=False)

# Report where the file went
print(f"Updated cluster names saved to {updated_cluster_name_path}")


# Cell 7: Visualize the Renamed Clusters with Plotly
#### - This cell uses Plotly to create an interactive 2D scatter plot of the UMAP-reduced embeddings
####   (columns `umap_x` and `umap_y`).
#### - Each point in the plot is colored according to its cluster name, and you can hover over the points
####   to view the associated document and cluster name.

In [None]:
# Interactive 2D scatter plot of the UMAP-reduced embeddings, colored by cluster name.
# As in the 3D cell, check for the coordinate columns first, so a missing
# projection yields a readable message instead of an error raised inside Plotly.
umap_2d_columns = ['umap_x', 'umap_y']
missing_umap_columns = [col for col in umap_2d_columns if col not in df_cluster.columns]

if missing_umap_columns:
    print(f"2D UMAP projections not available (missing columns: {missing_umap_columns}).")
else:
    fig = px.scatter(
        df_cluster,
        x='umap_x',
        y='umap_y',
        color='cluster_name',
        hover_data=['documents', 'cluster_name'],  # Hover shows each point's document and cluster name
        title="Cluster Visualization with Renamed Clusters"
    )

    # Update layout for better visualization
    fig.update_layout(
        autosize=True,
        height=800,
        showlegend=True,
        legend_title="Cluster Names"
    )

    # Show the interactive plot
    fig.show()


# Cell 8: 3D Visualization (Optional)
#### - If 3D UMAP projections are available, this cell creates a dynamic 3D scatter plot of the clusters.
#### - Similar to the 2D plot, the points are colored by cluster name, and you can hover over them for more information.
#### - If 3D projections are not available, a message is printed instead.

In [None]:
# Optional 3D view: drawn only when a third UMAP coordinate column exists.
if 'umap_z' not in df_cluster.columns:
    print("3D UMAP projections not available.")
else:
    # Points are colored by cluster name; hovering shows the document and its name.
    fig_3d = px.scatter_3d(
        df_cluster,
        x='umap_x',
        y='umap_y',
        z='umap_z',
        color='cluster_name',
        hover_data=['documents', 'cluster_name'],
        title="3D Cluster Visualization with Renamed Clusters",
    )
    # Same layout settings as the 2D plot: tall figure with a titled legend.
    fig_3d.update_layout(
        autosize=True,
        height=800,
        showlegend=True,
        legend_title="Cluster Names",
    )
    fig_3d.show()
