# In [None]:
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.decomposition import TruncatedSVD
from wordcloud import WordCloud

# ------------------------
# 1. Read Data and Construct Graph
# ------------------------
filename = 'updated_chatgpt_reddit_comments.csv'
# Read both ID columns as text. Inferring their type and then calling astype(str)
# breaks matching for numeric IDs: a column with any blank cell is parsed as float,
# so a parent ID "123" would become "123.0" and no longer equal comment_id "123".
# Blank IDs stay NaN here instead of turning into the literal string 'nan'.
data = pd.read_csv(filename, dtype={'comment_id': str, 'comment_parent_id': str})

# Extract columns
comment_ids = data['comment_id']
parent_ids = data['comment_parent_id']

# Collect all unique, non-missing IDs as nodes (a parent ID need not appear as a
# comment_id row, so both columns contribute nodes)
all_nodes = pd.unique(pd.concat([comment_ids, parent_ids]).dropna())

# Create a directed graph with one edge per comment, pointing child -> parent.
# Rows missing either ID are skipped so unrelated comments are not all joined
# through a shared placeholder node.
G = nx.DiGraph()
G.add_nodes_from(all_nodes)
has_both_ids = comment_ids.notna() & parent_ids.notna()
edges = list(zip(comment_ids[has_both_ids], parent_ids[has_both_ids]))
G.add_edges_from(edges)
# Remove self loops (materialized first so the graph is not mutated while the
# self-loop generator is still iterating over it)
G.remove_edges_from(list(nx.selfloop_edges(G)))

# ------------------------
# 2. Plot the Graph
# ------------------------
# Force-directed (spring) layout; the fixed seed makes node positions reproducible
pos = nx.spring_layout(G, seed=42)

fig, ax = plt.subplots(figsize=(12, 8))
# Small nodes and faint, arrowless edges keep a dense comment graph readable
nx.draw_networkx_nodes(G, pos, ax=ax, node_size=50)
nx.draw_networkx_edges(G, pos, ax=ax, alpha=0.3, arrows=False)
ax.set_title('Comment Thread Graph')
ax.axis('off')

# Shared state read by the click callback: the graph, the comment table and the
# node positions used to map a click back to a node
global_data = dict(G=G, data=data, pos=pos)

# ------------------------
# 3. Define Mouse Click Callback Function
# ------------------------
def _nearest_node(pos, x, y, threshold=0.05):
    """Return the node in ``pos`` closest to ``(x, y)``, or None if none is within ``threshold``.

    ``pos`` maps node -> 2-D position, as produced by the networkx layout used for
    the plot. ``threshold`` is in the layout's data coordinates and may need tuning
    depending on the layout scale.
    """
    if not pos:
        return None
    nodes = list(pos.keys())
    coords = np.array(list(pos.values()))
    distances = np.linalg.norm(coords - np.array([x, y]), axis=1)
    nearest = int(np.argmin(distances))
    if distances[nearest] > threshold:
        return None
    return nodes[nearest]


def _show_comment_details(sub_data):
    """Open a non-blocking figure listing comment_id, serial_number and body of each row."""
    details_str = "".join(
        f"comment_id: {rec['comment_id']} | serial_number: {rec['serial_number']}\n"
        f"Comment: {rec['comment_body']}\n\n"
        for _, rec in sub_data.iterrows()
    )
    fig_details, ax_details = plt.subplots(figsize=(8, 6))
    ax_details.text(0.01, 0.99, details_str, va='top', ha='left', fontsize=8, wrap=True)
    ax_details.axis('off')
    fig_details.canvas.manager.set_window_title("Comment Details")
    plt.show(block=False)


def _first_lsa_topic(corpus, n_components=5, top_n=20):
    """Return ``(words, weights)`` for the ``top_n`` strongest terms of the first LSA topic.

    The corpus is TF-IDF vectorized (English stop words removed) and factorized with
    TruncatedSVD; terms are ranked by the absolute weight of the first component.

    Raises:
        ValueError: from TfidfVectorizer when the corpus yields no terms at all
            (e.g. empty text or only stop words).
    """
    vectorizer = TfidfVectorizer(stop_words='english')
    tfidf_mat = vectorizer.fit_transform(corpus)
    words = vectorizer.get_feature_names_out()
    n_terms = len(words)

    if n_terms < 2:
        # A one-term vocabulary is trivially its own topic; SVD cannot be fit on it.
        topic_weights = np.ones(n_terms)
    else:
        # TruncatedSVD rejects n_components larger than the number of terms (some
        # scikit-learn versions require strictly fewer), so clamp it. A short comment
        # with fewer than 5 distinct non-stop-words would otherwise raise.
        svd = TruncatedSVD(n_components=min(n_components, n_terms - 1), random_state=42)
        svd.fit(tfidf_mat)
        # components_ has shape (n_components, n_terms); row 0 is the first topic
        topic_weights = svd.components_[0]

    # Rank by absolute weight: the sign of an SVD component is arbitrary
    order = np.argsort(np.abs(topic_weights))[::-1][:min(top_n, n_terms)]
    return words[order], topic_weights[order]


def on_click(event):
    """Handle a mouse click on the comment-thread graph.

    Finds the graph node nearest the click. If one is close enough, it prints that
    node's comment_id and serial_number. It then collects the comment plus every
    reply below it in the thread and shows their text in a separate figure. Finally
    it plots a word cloud of the first latent-semantic (TF-IDF + truncated SVD)
    topic of those comments.

    Reads the module-level ``global_data`` dict (keys 'G', 'data', 'pos').

    Args:
        event: matplotlib mouse event; only ``xdata``/``ydata`` are used.
    """
    # Ignore clicks outside the axes
    if event.xdata is None or event.ydata is None:
        return

    graph = global_data['G']
    table = global_data['data']

    clicked_node = _nearest_node(global_data['pos'], event.xdata, event.ydata)
    if clicked_node is None:
        print("Click was not close enough to any node.")
        return
    print("Clicked comment_id:", clicked_node)

    # Retrieve corresponding serial_number from the data table (if available)
    matches = table[table['comment_id'] == clicked_node]
    if not matches.empty:
        print("Serial number for clicked comment:", matches.iloc[0]['serial_number'])
    else:
        print("Clicked comment_id not found in data table.")

    # Graph edges point child -> parent, so reversing them gives parent -> child
    # edges; descendants in the reversed graph are all (nested) replies to the node.
    descendant_nodes = nx.descendants(graph.reverse(copy=False), clicked_node)
    print("Descendant nodes:", descendant_nodes)

    all_ids = {clicked_node} | descendant_nodes
    sub_data = table[table['comment_id'].isin(all_ids)]
    if sub_data.empty:
        print("No comments for this node in the data table; skipping topic analysis.")
        return
    print(f"Performing topic analysis on {len(sub_data)} comments...")

    _show_comment_details(sub_data)

    # Lowercase the text; missing bodies become '' rather than the token 'nan'
    corpus = sub_data['comment_body'].fillna('').astype(str).str.lower().tolist()
    try:
        top_words, top_weights = _first_lsa_topic(corpus)
    except ValueError as err:
        print(f"Topic analysis skipped: {err}")
        return

    # Word cloud sized by absolute topic weight
    freq_dict = {word: float(abs(weight)) for word, weight in zip(top_words, top_weights)}
    wc = WordCloud(width=800, height=400, background_color='white')
    wc.generate_from_frequencies(freq_dict)

    fig_wc, ax_wc = plt.subplots()
    ax_wc.imshow(wc, interpolation='bilinear')
    ax_wc.axis('off')
    ax_wc.set_title(f"Latent Semantic Topic for comment_id: {clicked_node}")
    # Non-blocking, like the details figure: this runs inside the GUI event loop
    plt.show(block=False)

    # Print top words and their weights
    print("Top words in the extracted topic:")
    for word, weight in zip(top_words, top_weights):
        print(f"{word} (weight: {weight:.3f})")

# ------------------------
# 4. Attach the Click Callback to the Figure
# ------------------------
# Route every mouse-button press on the graph figure to on_click; the returned
# connection id can be passed to mpl_disconnect to detach the handler later.
cid = fig.canvas.mpl_connect(s='button_press_event', func=on_click)
# Display the figure; with an interactive backend, clicks are dispatched from here
plt.show()
