# Visualization

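This notebook visualizes the sample dataset and a baseline model trained on it. It covers the user interaction network, temporal propagation patterns, feature correlations, and model performance (confusion matrix, ROC curve, feature importance). It can also export an interactive HTML view of the network.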


In [None]:
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.model_selection import train_test_split

sys.path.append(str(Path().resolve().parent))
from src import visualization, data_preprocessing, network_builder, feature_extractor, models
import config

# Matplotlib >= 3.6 ships the bundled seaborn styles as 'seaborn-v0_8*' (the old
# 'seaborn' names were deprecated and later removed); fall back to the old name so
# the notebook still runs on older Matplotlib versions.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')
sns.set_palette('Set2')


## Load Data and Build Network


In [None]:
# Synthetic dataset of 1000 samples that every plot below is built from
df = data_preprocessing.create_sample_dataset(n_samples=1000)

# User-interaction graph plus its community partition (used for coloring when no labels exist)
G = network_builder.build_interaction_graph(df, user_column="user_id")
communities = network_builder.detect_communities(G)

n_nodes, n_edges = G.number_of_nodes(), G.number_of_edges()
print(f"Data loaded: {len(df)} samples")
print(f"Network: {n_nodes} nodes, {n_edges} edges")


## Network Visualizations


In [None]:
# Color nodes by label (misinformation vs legitimate) when labels exist,
# otherwise by detected community.
MAX_NODES_VIZ = 100  # static spring layouts get unreadable for larger graphs

if 'label' in df.columns:
    # One color per user: the user's first label in `df` decides it (1 -> red, else blue).
    user_labels = df.groupby('user_id')['label'].first().to_dict()
    node_colors = {
        node: 'red' if user_labels.get(node, 0) == 1 else 'blue'
        for node in G.nodes()
    }
    title = "Social Network - Misinformation Spreaders (Red) vs Legitimate (Blue)"
else:
    # Color by community, cycling through Matplotlib's 10 default colors C0..C9
    node_colors = {node: f"C{comm_id % 10}" for node, comm_id in communities.items()}
    title = "Social Network with Communities"

# Plot network (sample the first MAX_NODES_VIZ nodes if too large)
if G.number_of_nodes() > MAX_NODES_VIZ:
    G_viz = G.subgraph(list(G.nodes())[:MAX_NODES_VIZ])
    viz_title = f"{title} (Sample of {MAX_NODES_VIZ} nodes)"
else:
    G_viz = G
    viz_title = title

# Build the color map over exactly the plotted nodes, in graph order; any node
# without an assigned color (e.g. absent from `communities`) falls back to gray.
node_colors_viz = {node: node_colors.get(node, 'gray') for node in G_viz.nodes()}

visualization.plot_network_graph(
    G_viz,
    node_colors=node_colors_viz,
    title=viz_title,
    layout="spring",
)


## Temporal Propagation Patterns


In [None]:
# Temporal propagation needs both a time axis and the ground-truth label column.
has_temporal_inputs = all(col in df.columns for col in ("timestamp", "label"))
if not has_temporal_inputs:
    print("Timestamp or label column not found. Skipping temporal visualization.")
else:
    visualization.plot_temporal_propagation(df, "timestamp", "label")


## Feature Visualizations


In [None]:
# Extract features on the full dataset and inspect how they correlate.
extractor = feature_extractor.FeatureExtractor(use_bert=False)
features = extractor.extract_all_features(df, text_column="text", user_column="user_id", timestamp_column="timestamp")

# Correlation only makes sense for numeric columns and needs at least two of them;
# cap the number plotted so the heatmap stays legible.
MAX_CORR_FEATURES = 10
numeric_features = features.select_dtypes(include=[np.number])
if numeric_features.shape[1] > 1:
    key_features = numeric_features.iloc[:, :MAX_CORR_FEATURES]
    visualization.plot_feature_correlations(key_features)
else:
    print("Fewer than two numeric features found. Skipping correlation plot.")


## Model Performance Visualizations


In [None]:
# Train a quick baseline model for visualization.
# (`train_test_split` is imported in the top import cell.)
train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)

feature_kwargs = dict(text_column="text", user_column="user_id", timestamp_column="timestamp")
train_features = extractor.extract_all_features(train_df, **feature_kwargs)
test_features = extractor.extract_all_features(test_df, **feature_kwargs)

# Features are extracted per split, so the two frames can end up with different or
# differently ordered columns. Align test to the training schema so the model sees
# the same feature layout (missing columns are filled with 0 — TODO confirm 0 is a
# sensible neutral value for every extracted feature).
if not test_features.columns.equals(train_features.columns):
    missing = set(train_features.columns) - set(test_features.columns)
    extra = set(test_features.columns) - set(train_features.columns)
    print(f"Aligning test features to train schema ({len(missing)} missing, {len(extra)} extra columns)")
    test_features = test_features.reindex(columns=train_features.columns, fill_value=0)

X_train, y_train = train_features.values, train_df['label'].values
X_test, y_test = test_features.values, test_df['label'].values

# Train model
# NOTE(review): no random_state is passed, so forest results may vary run to run —
# check whether TraditionalMLModel forwards one.
rf_model = models.TraditionalMLModel("random_forest", n_estimators=50)
rf_model.train(X_train, y_train)

# Visualize
y_pred = rf_model.predict(X_test)
# Column 1 is assumed to be the probability of label 1 ("Fake"), as with scikit-learn's
# sorted classes_ — confirm for TraditionalMLModel.
y_proba = rf_model.predict_proba(X_test)[:, 1]

visualization.plot_confusion_matrix(y_test, y_pred, class_names=["Real", "Fake"])
visualization.plot_roc_curve(y_test, y_proba, "Random Forest")

# Feature importance (names follow the training columns, which test is aligned to)
feature_names = train_features.columns.tolist()
visualization.plot_feature_importance(rf_model.model, feature_names, top_n=10)


## Interactive Network Visualization (Optional)


In [None]:
# Export an interactive (HTML) version of the network view, if Plotly is available.
INTERACTIVE_NETWORK_PATH = Path("../data/networks/interactive_network.html")
# Same threshold as the static plot: large graphs use the sampled subgraph and its colors.
use_sample = G.number_of_nodes() > 100

try:
    # Create the output folder up front so a missing directory is not misreported
    # as a Plotly problem.
    INTERACTIVE_NETWORK_PATH.parent.mkdir(parents=True, exist_ok=True)
    visualization.create_interactive_network(
        G_viz if use_sample else G,
        node_colors=node_colors_viz if use_sample else node_colors,
        output_path=INTERACTIVE_NETWORK_PATH,
    )
    print(f"Interactive network saved to {INTERACTIVE_NETWORK_PATH}")
except ImportError as e:
    print(f"Could not create interactive visualization: {e}")
    print("This requires Plotly to be installed.")
except Exception as e:
    # This step is optional: report any other failure without stopping the notebook.
    print(f"Could not create interactive visualization: {e}")
