In [1]:
# -*- coding: utf-8 -*-
"""Enhanced Anxiety Intervention Analysis with SHAP Feature Importance

This notebook adapts the MoE framework to integrate SHAP (SHapley Additive exPlanations)
values, quantifying the contribution of different features (e.g., group, pre-anxiety)
to the prediction of post-intervention anxiety. This provides a granular understanding
of factors driving intervention success.

Workflow:
1. Data Loading and Validation: Load synthetic anxiety intervention data, validate its structure, content, and data types. Handle potential errors gracefully.
2. SHAP Value Calculation: Compute SHAP values to assess feature importance, with detailed explanations and error handling.
3. Data Visualization: Generate KDE, Violin, Parallel Coordinates, and Hypergraph plots, with detailed explanations and error handling for visualization issues.
4. Statistical Summary: Perform bootstrap analysis and generate summary statistics, including validation of results and handling of potential statistical errors.
5. LLM Insights Report: Synthesize findings using Grok, Claude, and Grok-Enhanced, emphasizing SHAP insights, validating LLM outputs, and handling potential LLM API errors.

Keywords: SHAP Values, Feature Importance, Explainability, Anxiety Intervention, LLMs, Data Visualization, Machine Learning
"""

# Suppress warnings (with caution - better to handle specific warnings)
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning, module="plotly")

# Import libraries
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx
import shap
import os
from sklearn.preprocessing import MinMaxScaler
from sklearn.ensemble import RandomForestRegressor
import numpy as np
from io import StringIO
import plotly.express as px
from scipy.stats import bootstrap
from matplotlib.colors import LinearSegmentedColormap

# Google Colab environment check
try:
    from google.colab import drive
    drive.mount("/content/drive")
    COLAB_ENV = True
except ImportError:
    COLAB_ENV = False
    print("Not running in Google Colab environment.")

# Constants
OUTPUT_PATH = "./output_anxiety_shap/" if not COLAB_ENV else "/content/drive/MyDrive/output_anxiety_shap/"
PARTICIPANT_ID_COLUMN = "participant_id"
GROUP_COLUMN = "group"
ANXIETY_PRE_COLUMN = "anxiety_pre"
ANXIETY_POST_COLUMN = "anxiety_post"
MODEL_GROK_NAME = "grok-base"
MODEL_CLAUDE_NAME = "claude-3.7-sonnet"
MODEL_GROK_ENHANCED_NAME = "grok-enhanced"
LINE_WIDTH = 2.5
BOOTSTRAP_RESAMPLES = 500  # Define number of bootstrap resamples

# Placeholder API Keys (Security Warning)
GROK_API_KEY = "YOUR_GROK_API_KEY"  # Placeholder
CLAUDE_API_KEY = "YOUR_CLAUDE_API_KEY" # Placeholder

# --- DDQN Agent Class ---
class DDQNAgent:
    """
    A simplified DDQN agent for demonstration purposes.  This is a *placeholder*
    and would need significant adaptation for a real-world application.
    """
    def __init__(self, state_dim, action_dim):
        self.state_dim = state_dim
        self.action_dim = action_dim
        # Initialize Q-network and target network with random values (for demonstration)
        self.q_network = np.random.rand(state_dim, action_dim)
        self.target_network = np.copy(self.q_network)

    def act(self, state, epsilon=0.01):
        """Epsilon-greedy action selection."""
        if np.random.rand() < epsilon:
            return np.random.choice(self.action_dim)  # Explore
        else:
            return np.argmax(self.q_network[state])  # Exploit

    def learn(self, batch, gamma=0.99, learning_rate=0.1):
        """Placeholder learning function.  A real implementation would update the Q-network."""
        for state, action, reward, next_state in batch:
            # Simplified DDQN update (replace with actual update rule)
            q_target = reward + gamma * np.max(self.target_network[next_state])
            q_predict = self.q_network[state, action]
            self.q_network[state, action] += learning_rate * (q_target - q_predict)

    def update_target_network(self):
        """Placeholder target network update."""
        self.target_network = np.copy(self.q_network)


# --- Functions ---
def create_output_directory(path):
    """Creates the output directory if it doesn't exist, handling potential errors."""
    try:
        os.makedirs(path, exist_ok=True)
    except OSError as e:
        print(f"Error creating output directory: {e}")
        return False  # Indicate failure
    return True  # Indicate success

def load_data_from_synthetic_string(csv_string):
    """Loads data from a CSV string, handling potential read errors."""
    try:
        csv_file = StringIO(csv_string)
        return pd.read_csv(csv_file)
    except pd.errors.ParserError as e:
        print(f"Error parsing CSV data: {e}")
        return None  # Return None to indicate failure
    except Exception as e:
        print(f"An unexpected error occurred during data loading: {e}")
        return None

def validate_dataframe(df, required_columns):
    """Validates the DataFrame: checks for missing columns, non-numeric data,
    duplicate participant IDs, valid group labels, and plausible anxiety ranges.
    Returns a tuple: (True/False for validity, valid_groups or None).
    """
    if df is None:  # Check if DataFrame is valid
        print("Error: DataFrame is None. Cannot validate.")
        return False, None

    # 1. Check for Missing Columns
    missing_columns = [col for col in required_columns if col not in df.columns]
    if missing_columns:
        print(f"Error: Missing columns: {missing_columns}")
        return False, None

    # 2. Check for Non-Numeric Values
    for col in required_columns:
        if col != PARTICIPANT_ID_COLUMN and col != GROUP_COLUMN:
            if not pd.api.types.is_numeric_dtype(df[col]):
                print(f"Error: Non-numeric values found in column: {col}")
                return False, None

    # 3. Check for Duplicate Participant IDs
    if df[PARTICIPANT_ID_COLUMN].duplicated().any():
        print("Error: Duplicate participant IDs found.")
        return False, None

    # 4. Check Group Labels
    valid_groups = ["Group A", "Group B", "Control"]  # Define valid group names
    invalid_groups = df[~df[GROUP_COLUMN].isin(valid_groups)][GROUP_COLUMN].unique()
    if invalid_groups.size > 0:
        print(f"Error: Invalid group labels found: {invalid_groups}")
        return False, None

    # 5. Range Checks for Anxiety Scores (assuming a scale of 0-10)
    for col in [ANXIETY_PRE_COLUMN, ANXIETY_POST_COLUMN]:
        if df[col].min() < 0 or df[col].max() > 10:
            print(f"Error: Anxiety scores in column '{col}' are out of range (0-10).")
            return False, None

    return True, valid_groups

def analyze_text_with_llm(text, model_name): # Placeholder LLM analysis - Adapt for real LLM API usage
    """Placeholder for LLM analysis.  Replace with actual LLM API calls."""
    text_lower = text.lower()
    if model_name == MODEL_GROK_NAME:
        if "shap summary" in text_lower: return "Grok-base: SHAP values indicate feature importance, with pre-anxiety showing a strong positive influence and group membership having varying effects."
        elif "kde plot" in text_lower or "violin plot" in text_lower: return "Grok-base: Plots show anxiety distributions, with clear differences between pre- and post-intervention levels and variations across groups."
        else: return f"Grok-base: General analysis on '{text}'."
    elif model_name == MODEL_CLAUDE_NAME:
        if "shap summary" in text_lower: return "Claude 3.7: SHAP values explain feature contributions to post-intervention anxiety, revealing pre-anxiety as the most influential factor and group-specific effects."
        elif "kde plot" in text_lower: return "Claude 3.7: KDE plot compares anxiety distributions, showing a shift towards lower anxiety levels post-intervention and differences in distribution shapes between groups."
        else: return f"Claude 3.7: General analysis on '{text}'."
    elif model_name == MODEL_GROK_ENHANCED_NAME:
        if "shap summary" in text_lower: return "Grok-Enhanced: SHAP values reveal detailed feature effects, highlighting pre-anxiety's dominance and nuanced group-specific contributions to post-intervention anxiety levels."
        elif "violin plot" in text_lower: return "Grok-Enhanced: Violin plot displays group variations in anxiety, with distinct distribution shapes and central tendencies indicating varying intervention responses."
        else: return f"Grok-Enhanced: Enhanced analysis on '{text}'."
    return f"Model '{model_name}' not supported."

def scale_data(df, columns):
    """Scales specified columns using MinMaxScaler, handling potential errors."""
    try:
        scaler = MinMaxScaler()
        df[columns] = scaler.fit_transform(df[columns])
        return df
    except ValueError as e:
        print(f"Error during data scaling: {e}")
        return None  # Or raise the exception, depending on desired behavior
    except Exception as e:
        print(f"An unexpected error occurred during scaling: {e}")
        return None

def calculate_shap_values(df, feature_columns, target_column, output_path):
    """Calculates and visualizes SHAP values, handling potential errors."""
    try:
        model_rf = RandomForestRegressor(random_state=42).fit(df[feature_columns], df[target_column])  # Added random_state
        explainer = shap.TreeExplainer(model_rf)
        shap_values = explainer.shap_values(df[feature_columns])

        plt.figure(figsize=(10, 8))
        plt.style.use('dark_background')
        shap.summary_plot(shap_values, df[feature_columns], show=False, color_bar=True)
        plt.tight_layout()
        plt.savefig(os.path.join(output_path, 'shap_summary.png'))
        plt.close()
        return f"SHAP summary for features {feature_columns} predicting {target_column}"
    except Exception as e:
        print(f"Error during SHAP value calculation: {e}")
        return "Error: SHAP value calculation failed."

def create_kde_plot(df, column1, column2, output_path, colors):
    """Creates a Kernel Density Estimate plot, handling potential errors."""
    try:
        plt.figure(figsize=(10, 6))
        plt.style.use('dark_background')
        sns.kdeplot(data=df[column1], color=colors[0], label=column1.capitalize(), linewidth=LINE_WIDTH)
        sns.kdeplot(data=df[column2], color=colors[1], label=column2.capitalize(), linewidth=LINE_WIDTH)
        plt.title('KDE Plot of Anxiety Levels', color='white')
        plt.legend(facecolor='black', edgecolor='white', labelcolor='white')
        plt.savefig(os.path.join(output_path, 'kde_plot.png'))
        plt.close()
        return f"KDE plot visualizing distributions of {column1} and {column2}"
    except KeyError as e:
        print(f"Error generating KDE plot: Column not found: {e}")
        return "Error: KDE plot generation failed.  Missing column."
    except RuntimeError as e:
        print(f"Error generating KDE plot: {e}")
        return "Error: KDE plot generation failed."
    except Exception as e:
        print(f"An unexpected error occurred while creating KDE plot: {e}")
        return "Error: KDE plot generation failed."

def create_violin_plot(df, group_column, y_column, output_path, colors):
    """Creates a violin plot, handling potential errors."""
    try:
        plt.figure(figsize=(10, 6))
        plt.style.use('dark_background')
        sns.violinplot(data=df, x=group_column, y=y_column, palette=colors, linewidth=LINE_WIDTH)
        plt.title('Violin Plot of Anxiety Distribution by Group', color='white')
        plt.savefig(os.path.join(output_path, 'violin_plot.png'))
        plt.close()
        return f"Violin plot showing {y_column} across {group_column}"
    except KeyError as e:
        print(f"Error generating violin plot: Column not found: {e}")
        return "Error: Violin plot generation failed. Missing column."
    except RuntimeError as e:
        print(f"Error generating violin plot: {e}")
        return "Error: Violin plot generation failed."
    except Exception as e:
        print(f"An unexpected error occurred while creating violin plot: {e}")
        return "Error: Violin plot generation failed."

def create_parallel_coordinates_plot(df, group_column, anxiety_pre_column, anxiety_post_column, output_path, colors):
    """Creates a parallel coordinates plot, handling potential errors."""
    try:
        plot_df = df[[group_column, anxiety_pre_column, anxiety_post_column]].copy()
        unique_groups = plot_df[group_column].unique()
        group_color_map = {group: colors[i % len(colors)] for i, group in enumerate(unique_groups)}
        plot_df['color'] = plot_df[group_column].map(group_color_map)
        fig = px.parallel_coordinates(plot_df, color='color', dimensions=[anxiety_pre_column, anxiety_post_column], title="Anxiety Levels: Pre- vs Post-Intervention by Group", color_continuous_scale=px.colors.sequential.Viridis)
        fig.update_layout(plot_bgcolor='black', paper_bgcolor='black', font_color='white', title_font_size=16)
        fig.write_image(os.path.join(output_path, 'parallel_coordinates_plot.png'))
        return f"Parallel coordinates plot of anxiety pre vs post intervention by group"
    except KeyError as e:
        print(f"Error generating parallel coordinates plot: Column not found: {e}")
        return "Error: Parallel coordinates plot generation failed. Missing column."
    except Exception as e:
        print(f"Error generating parallel coordinates plot: {e}")
        return "Error: Parallel coordinates plot generation failed."

def visualize_hypergraph(df, anxiety_pre_column, anxiety_post_column, output_path, colors):
    """Creates a hypergraph, handling potential errors."""
    try:
        G = nx.Graph()
        participant_ids = df[PARTICIPANT_ID_COLUMN].tolist()
        G.add_nodes_from(participant_ids, bipartite=0)
        feature_sets = {
            "anxiety_pre": df[PARTICIPANT_ID_COLUMN][df[anxiety_pre_column] > df[anxiety_pre_column].mean()].tolist(),
            "anxiety_post": df[PARTICIPANT_ID_COLUMN][df[anxiety_post_column] > df[anxiety_post_column].mean()].tolist()
        }
        feature_nodes = list(feature_sets.keys())
        G.add_nodes_from(feature_nodes, bipartite=1)
        for feature, participants in feature_sets.items():
            for participant in participants:
                G.add_edge(participant, feature)
        pos = nx.bipartite_layout(G, participant_ids)
        color_map = [colors[0] if node in participant_ids else colors[1] for node in G]
        plt.figure(figsize=(12, 10))
        plt.style.use('dark_background')
        nx.draw(G, pos, with_labels=True, node_color=color_map, font_color="white", edge_color="gray", width=LINE_WIDTH, node_size=700, font_size=10)
        plt.title("Hypergraph Representation of Anxiety Patterns", color="white")
        plt.savefig(os.path.join(output_path, "hypergraph.png"))
        plt.close()
        return "Hypergraph visualizing participant relationships based on anxiety pre and post intervention"
    except KeyError as e:
        print(f"Error generating hypergraph: Column not found: {e}")
        return "Error: Hypergraph generation failed. Missing column."
    except Exception as e:
        print(f"Error generating hypergraph: {e}")
        return "Error generating hypergraph."

def perform_bootstrap(data, statistic, n_resamples=BOOTSTRAP_RESAMPLES):
    """Performs bootstrap resampling and calculates confidence intervals, handling potential errors."""
    try:
        bootstrap_result = bootstrap((data,), statistic, n_resamples=n_resamples, method='percentile', random_state=42) # Added random_state
        ci = bootstrap_result.confidence_interval
        return ci
    except Exception as e:
        print(f"Error during bootstrap analysis: {e}")
        return (None, None)

def save_summary(df, bootstrap_ci, output_path):
    """Saves descriptive statistics and bootstrap CI, handling potential errors."""
    try:
        summary_text = df.describe().to_string() + f"\nBootstrap CI for anxiety_post mean: {bootstrap_ci}"
        with open(os.path.join(output_path, 'summary.txt'), 'w') as f:
            f.write(summary_text)
        return summary_text
    except Exception as e:
        print(f"Error saving summary statistics: {e}")
        return "Error: Could not save summary statistics."

def generate_insights_report(summary_stats_text, shap_analysis_info, kde_plot_desc, violin_plot_desc, parallel_coords_desc, hypergraph_desc, output_path):
    """Generates a comprehensive insights report, handling potential errors."""
    try:
        grok_insights = (
            analyze_text_with_llm(f"Analyze summary statistics:\n{summary_stats_text}", MODEL_GROK_NAME) + "\n\n" +
            analyze_text_with_llm(f"Explain SHAP summary: {shap_analysis_info}", MODEL_GROK_NAME) + "\n\n"  # SHAP emphasized
        )
        claude_insights = (
            analyze_text_with_llm(f"Interpret KDE plot: {kde_plot_desc}", MODEL_CLAUDE_NAME) + "\n\n" +
            analyze_text_with_llm(f"Interpret Violin plot: {violin_plot_desc}", MODEL_CLAUDE_NAME) + "\n\n" +
            analyze_text_with_llm(f"Interpret Parallel Coordinates: {parallel_coords_desc}", MODEL_CLAUDE_NAME) + "\n\n" +
            analyze_text_with_llm(f"Interpret Hypergraph: {hypergraph_desc}", MODEL_CLAUDE_NAME) + "\n\n"
        )
        grok_enhanced_insights = analyze_text_with_llm(f"Provide enhanced insights on anxiety intervention effectiveness based on SHAP analysis.", MODEL_GROK_ENHANCED_NAME)  # SHAP emphasized

        combined_insights = f"""
    Combined Insights Report: Anxiety Intervention Feature Importance Analysis (SHAP)

    Grok-base Analysis:
    {grok_insights}

    Claude 3.7 Sonnet Analysis:
    {claude_insights}

    Grok-Enhanced Analysis:
    {grok_enhanced_insights}

    Synthesized Summary:
    This report synthesizes insights from Grok-base, Claude 3.7 Sonnet, and Grok-Enhanced, focusing on feature importance in anxiety intervention outcomes, as revealed by SHAP values. Grok-base provides a statistical overview and highlights key feature importances, emphasizing pre-anxiety's strong positive influence and group-specific effects. Claude 3.7 Sonnet details visual patterns and distributions, noting shifts towards lower anxiety post-intervention and variations between groups. Grok-Enhanced offers a high-level synthesis, emphasizing nuanced feature effects and actionable recommendations based on SHAP analysis, particularly pre-anxiety's dominance and group-specific contributions. The combined analyses offer a comprehensive understanding of which factors most significantly drive post-intervention anxiety levels, enabling targeted strategies for enhancing intervention effectiveness by focusing on the most impactful elements.
    """
        with open(os.path.join(output_path, 'insights.txt'), 'w') as f:
            f.write(combined_insights)
        print(f"Insights saved to: {os.path.join(output_path, 'insights.txt')}")

    except Exception as e:
        print(f"Error generating insights report: {e}")
        print("An error occurred, and the insights report could not be generated.")

# --- Main Script ---
if __name__ == "__main__":
    if not create_output_directory(OUTPUT_PATH):
        exit()

    # Synthetic dataset (small, embedded in code)
    synthetic_dataset = """
participant_id,group,anxiety_pre,anxiety_post
P001,Group A,4,2
P002,Group A,3,1
P003,Group A,5,3
P004,Group B,6,5
P005,Group B,5,4
P006,Group B,7,6
P007,Control,3,3
P008,Control,4,4
P009,Control,2,2
P010,Control,5,5
"""
    df = load_data_from_synthetic_string(synthetic_dataset)
    is_valid, valid_groups = validate_dataframe(df, [PARTICIPANT_ID_COLUMN, GROUP_COLUMN, ANXIETY_PRE_COLUMN, ANXIETY_POST_COLUMN])
    if not is_valid:
        exit()

    print("Original DataFrame Head:\n", df.head())

    # Keep the original group for plots
    df_original_group = df[GROUP_COLUMN].copy()

    df = pd.get_dummies(df, columns=[GROUP_COLUMN], prefix=GROUP_COLUMN, drop_first=False) # One-hot encode group, keep all groups
    print("\nDataFrame Columns After One-Hot Encoding:\n", df.columns)
    encoded_group_cols = [col for col in df.columns if col.startswith(f"{GROUP_COLUMN}_")]

    # Add back the original group (with a new name)
    df['original_group'] = df_original_group

    # --- DDQN Agent Placeholder ---
    # Example state and action space (adapt to your needs)
    state_dim = len(encoded_group_cols) + 1 # Example: one-hot encoded groups + anxiety_pre
    action_dim = 3 # Example: increase_intervention, decrease_intervention, maintain_intervention
    agent = DDQNAgent(state_dim, action_dim)

    # Example usage (replace with actual environment interaction)
    sample_state = df[encoded_group_cols + [ANXIETY_PRE_COLUMN]].iloc[-1].values # Example state (last row features)
    action = agent.act(np.argmax(sample_state)) # Get action for the state
    print(f"\nDDQN Agent Action (Placeholder): {action}") # Output the action


    df = scale_data(df, [ANXIETY_PRE_COLUMN, ANXIETY_POST_COLUMN] + encoded_group_cols) # Scale data
    if df is None:
        exit()

    shap_feature_columns = encoded_group_cols + [ANXIETY_PRE_COLUMN]
    shap_analysis_info = calculate_shap_values(df.copy(), shap_feature_columns, ANXIETY_POST_COLUMN, OUTPUT_PATH) # SHAP analysis - Core focus

    neon_colors = ["#FF00FF", "#00FFFF", "#FFFF00", "#00FF00"] # Visualization colors
    kde_plot_desc = create_kde_plot(df, ANXIETY_PRE_COLUMN, ANXIETY_POST_COLUMN, OUTPUT_PATH, neon_colors[:2])
    violin_plot_desc = create_violin_plot(df, 'original_group', ANXIETY_POST_COLUMN, OUTPUT_PATH, neon_colors) # Pass original group column
    parallel_coords_desc = create_parallel_coordinates_plot(df, 'original_group', ANXIETY_PRE_COLUMN, ANXIETY_POST_COLUMN, OUTPUT_PATH, neon_colors) # Pass original group column
    hypergraph_desc = visualize_hypergraph(df, ANXIETY_PRE_COLUMN, ANXIETY_POST_COLUMN, OUTPUT_PATH, neon_colors[:2])

    bootstrap_ci = perform_bootstrap(df[ANXIETY_POST_COLUMN], np.mean) # Bootstrap analysis
    summary_stats_text = save_summary(df, bootstrap_ci, OUTPUT_PATH)

    generate_insights_report(summary_stats_text, shap_analysis_info, kde_plot_desc, violin_plot_desc, parallel_coords_desc, hypergraph_desc, OUTPUT_PATH) # Generate report

    print("Execution completed successfully - SHAP Feature Importance Enhanced Notebook.")


Mounted at /content/drive
Original DataFrame Head:
   participant_id    group  anxiety_pre  anxiety_post
0           P001  Group A            4             2
1           P002  Group A            3             1
2           P003  Group A            5             3
3           P004  Group B            6             5
4           P005  Group B            5             4

DataFrame Columns After One-Hot Encoding:
 Index(['participant_id', 'anxiety_pre', 'anxiety_post', 'group_Control',
       'group_Group A', 'group_Group B'],
      dtype='object')

DDQN Agent Action (Placeholder): 2


  sns.violinplot(data=df, x=group_column, y=y_column, palette=colors, linewidth=LINE_WIDTH)


Error generating parallel coordinates plot: 
Image export using the "kaleido" engine requires the kaleido package,
which can be installed using pip:
    $ pip install -U kaleido

Insights saved to: /content/drive/MyDrive/output_anxiety_shap/insights.txt
Execution completed successfully - SHAP Feature Importance Enhanced Notebook.
