In [None]:
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Optional, List, Tuple
import numpy as np
import pandas as pd
from scipy import stats
from statsmodels.stats.multitest import multipletests
import seaborn as sns
import matplotlib.pyplot as plt

@dataclass
class AttentionAnalysisResults:
    """Stores results from attention weight analysis"""
    val_type_effects: pd.DataFrame
    val_class_effects: pd.DataFrame
    significance_tests: pd.DataFrame
    top_attention_shifts: pd.DataFrame

@dataclass
class DiagnosisTestResults:
    """Stores results from diagnosis statistical testing"""
    test_results: pd.DataFrame
    significant_diagnoses: List[str]

class ResultsAnalyzer:
    def __init__(self, results_dir: str):
        self.results_dir = Path(results_dir)
        self.diagnosis_dfs = {}
        self.attention_dfs = {}
        self.notes_dfs = {}
        
        self.VALS_TYPES = {
            "pejorative": frozenset(["non_compliant", "uncooperative", "resistant", "difficult"]),
            "laudatory": frozenset(["compliant", "cooperative", "pleasant", "respectful"]),
            "neutralval": frozenset(['neutral'])
        }
        self._load_results()
    
    def _load_results(self):
        """Load all results files from the directory"""
        self._load_csv_files("_diagnosis.csv", self.diagnosis_dfs)
        self._load_csv_files("_attention.csv", self.attention_dfs)
        self._load_csv_files("_clinical_notes.csv", self.notes_dfs)

    def _load_csv_files(self, file_pattern: str, dfs_dict: Dict[str, pd.DataFrame]):
        """Helper function to load CSV files matching the pattern"""
        for file in self.results_dir.glob(f"*{file_pattern}"):
            name = file.stem.split('_')[0]
            df = pd.read_csv(file)
            
            if 'AttentionWeight' in df.columns:
                df['AttentionWeight'] = pd.to_numeric(df['AttentionWeight'], errors='coerce')
            
            categorical_cols = ['Valence', 'Val_class', 'Word']
            for col in categorical_cols:
                if col in df.columns:
                    df[col] = df[col].astype('category')
            
            dfs_dict[name] = df
    
    def analyze_attention_patterns(self) -> AttentionAnalysisResults:
        """Analyze attention weight changes across valences and word types"""
        combined_attention = pd.concat(self.attention_dfs.values(), ignore_index=True)
        
        mask_neutralize = combined_attention["Val_class"] == "neutralize"
        baseline = (combined_attention[mask_neutralize]
                   .groupby("Word", observed=True)["AttentionWeight"]
                   .mean()
                   .reset_index())
        
        baseline_dict = dict(zip(baseline["Word"], baseline["AttentionWeight"]))
        baseline_words = set(baseline["Word"])
        
        val_effects = []
        
        for val_class, vals in self.VALS_TYPES.items():
            for val in vals:
                mask = combined_attention["Valence"].str.contains(val, case=False, na=False)
                filter_table = combined_attention[mask]
                common_words = list(set(filter_table["Word"]) & baseline_words)
                
                if not common_words:
                    continue
                
                val_data = filter_table[filter_table["Word"].isin(common_words)]
                
                if len(val_data) == 0:
                    continue

                try:
                    baseline_values = np.array([baseline_dict[word] for word in val_data["Word"]])
                    effect = self._calculate_effect_size(
                        val_data["AttentionWeight"].values,
                        baseline_values
                    )
                    val_effects.append({
                        "valence_type": val,
                        "val_class": val_class,
                        "effect_size": effect
                    })
                except KeyError:
                    continue

        vals_types_effects = pd.DataFrame(val_effects)
        
        class_effects = (vals_types_effects
                        .groupby("val_class", observed=True)["effect_size"]
                        .agg(["mean", "std", "count"])
                        .reset_index())
        
        significance_tests = self._run_significance_tests(combined_attention)
        top_shifts = self._analyze_attention_shifts(combined_attention, baseline["AttentionWeight"])

        return AttentionAnalysisResults(
            val_type_effects=vals_types_effects,
            val_class_effects=class_effects,
            significance_tests=significance_tests,
            top_attention_shifts=top_shifts
        )

    
    @staticmethod
    def _calculate_effect_size(group1: np.ndarray, group2: np.ndarray) -> float:
        """Calculate Cohen's d effect size between two groups"""
        n1, n2 = len(group1), len(group2)
        if n1 < 2 or n2 < 2:
            return 0.0
        
        var1 = np.var(group1, ddof=1)
        var2 = np.var(group2, ddof=1)
        pooled_se = np.sqrt(((n1 - 1) * var1 + (n2 - 1) * var2) / (n1 + n2 - 2))
        
        if pooled_se == 0:
            return 0.0
        
        return abs((np.mean(group1) - np.mean(group2))) / pooled_se
    
    def _run_significance_tests(self, attention_data: pd.DataFrame) -> pd.DataFrame:
        """Run statistical tests on attention weight differences"""
        neutralize_mask = attention_data.Val_class == "neutralize"
        neutralize_weights = attention_data[neutralize_mask]["AttentionWeight"].values
        
        results = []
        classes = attention_data.Val_class.unique()
        
        for val_class in classes:
            if val_class == "neutralize":
                continue
            
            class_mask = attention_data.Val_class == val_class
            class_weights = attention_data[class_mask]["AttentionWeight"].values
            
            t_stat, p_val = stats.ttest_ind(class_weights, neutralize_weights)
            
            results.append({
                "valence_class": val_class,
                "t_statistic": t_stat,
                "p_value": p_val
            })
        
        results_df = pd.DataFrame(results)
        results_df["adjusted_p"] = multipletests(results_df["p_value"], method="fdr_bh")[1]
        
        for col in ["p_value", "adjusted_p"]:
            results_df[col] = results_df[col].round(4)
        
        return results_df

    def _analyze_attention_shifts(self, attention_data: pd.DataFrame, baseline: pd.Series, top_n: int = 50) -> pd.DataFrame:
        """Identify words with largest attention weight changes"""
        shifts = (attention_data
                 .groupby(['Word', 'Val_class'], observed=True)['AttentionWeight']
                 .mean()
                 .unstack()
                 .fillna(0))

        neutralize_mask = ~shifts['neutralize'].isna()
        valid_words = shifts[neutralize_mask].index
        
        result_dfs = []
        
        for val_class in shifts.columns:
            if val_class == 'neutralize':
                continue
            
            shift_data = pd.DataFrame({
                'Word': valid_words,
                'Neutralize_Weight': shifts.loc[valid_words, 'neutralize'],
                'Shifted_Weight': shifts.loc[valid_words, val_class],
                'Valence_Class': val_class
            })
            
            shift_data['Absolute_Shift'] = abs(shift_data['Shifted_Weight'] - shift_data['Neutralize_Weight'])
            shift_data['Percentage_Change'] = (
                (shift_data['Shifted_Weight'] - shift_data['Neutralize_Weight']) / 
                shift_data['Neutralize_Weight'] * 100
            ).round(2)
            
            top_shifts = (shift_data
                         .nlargest(top_n, 'Absolute_Shift')
                         .reset_index(drop=True))
            
            result_dfs.append(top_shifts)
        
        final_shifts = pd.concat(result_dfs, ignore_index=True)
        
        formatted_shifts = (final_shifts
            .assign(
                Neutral_Weight=lambda x: x['Neutralize_Weight'].round(4),
                Shifted_Weight=lambda x: x['Shifted_Weight'].round(4),
                Absolute_Shift=lambda x: x['Absolute_Shift'].round(4)
            )
            .sort_values(['Valence_Class', 'Absolute_Shift'], ascending=[True, False])
            .reset_index(drop=True))
        
        formatted_shifts['Rank'] = (formatted_shifts
            .groupby('Valence_Class')
            .cumcount() + 1)
        
        return formatted_shifts[[
            'Rank',
            'Valence_Class', 
            'Word',
            'Neutral_Weight',
            'Shifted_Weight', 
            'Absolute_Shift',
            'Percentage_Change'
        ]]
    

    def test_diagnosis_significance(self) -> DiagnosisTestResults:
        """
        Perform statistical tests to compare diagnosis predictions across different valences
        Returns both test results and list of significant diagnoses
        """
        diagnoses = self.get_available_diagnoses()
        test_results = []
        
        for diagnosis in diagnoses:
            # Get neutral predictions as baseline
            neutralize_data = None
            for valence, df in self.diagnosis_dfs.items():
                if 'neutralize' in valence.lower() and diagnosis in df.columns:
                    neutralize_data = pd.to_numeric(df[diagnosis], errors='coerce').dropna()
                    break
            
            if neutralize_data is None or len(neutralize_data) < 2:
                continue
                
            # Compare each valence against baseline
            for valence, df in self.diagnosis_dfs.items():
                if 'neutraize' in valence.lower() or diagnosis not in df.columns:
                    continue
                    
                test_data = pd.to_numeric(df[diagnosis], errors='coerce').dropna()
                if len(test_data) < 2:
                    continue
                    
                # Perform t-test
                t_stat, p_val = stats.ttest_ind(neutralize_data, test_data)
                effect_size = self._calculate_effect_size(neutralize_data.values, test_data.values)
                
                test_results.append({
                    'diagnosis': diagnosis,
                    'valence': valence,
                    't_statistic': t_stat,
                    'p_value': p_val,
                    'effect_size': effect_size
                })
        
        if not test_results:
            return DiagnosisTestResults(pd.DataFrame(), [])
            
        # Create results DataFrame and adjust p-values
        results_df = pd.DataFrame(test_results)
        results_df['adjusted_p'] = multipletests(results_df['p_value'], method='fdr_bh')[1]
        
        # Get list of significant diagnoses
        alpha = 0.05
        significant_diagnoses = (results_df[results_df['adjusted_p'] < alpha]
                               ['diagnosis'].unique().tolist())
        
        # Round numeric columns
        for col in ['t_statistic', 'p_value', 'adjusted_p', 'effect_size']:
            results_df[col] = results_df[col].round(4)
            
        return DiagnosisTestResults(results_df, significant_diagnoses)
    

    def get_distribution_statistics(self, diagnosis: str) -> pd.DataFrame:
        """Calculate statistical summary for prediction distributions"""
        stats_data = []

        for valence, df in self.diagnosis_dfs.items():
            if diagnosis not in df.columns:
                continue

            numeric_data = pd.to_numeric(df[diagnosis], errors='coerce')
            df_cleaned = numeric_data.dropna()
            
            if len(df_cleaned) > 0:
                stats_dict = {
                    'valence': valence,
                    'count': len(df_cleaned),
                    'mean': df_cleaned.mean(),
                    'std': df_cleaned.std(),
                    'min': df_cleaned.min(),
                    'max': df_cleaned.max(),
                    'median': df_cleaned.median()
                }
                stats_data.append(stats_dict)

        return pd.DataFrame(stats_data) if stats_data else pd.DataFrame()


    def plot_prediction_distribution(self, diagnosis: str, save_path: str):
        """
        Plots the distribution of predictions for a specific diagnosis across different valences.
        
        Args:
            diagnosis (str): The diagnosis code to plot distributions for
            save_path (str): Path to save the generated plot
        """
        # Create figure
        plt.figure(figsize=(12, 6))
        sns.set_style("whitegrid")
        
        # Get statistical test results for this diagnosis
        test_results = self.test_diagnosis_significance()
        sig_results = test_results.test_results[
            test_results.test_results['diagnosis'] == diagnosis
        ]
        
        # Plot distribution for each valence
        legend_elements = []
        for valence, df in self.diagnosis_dfs.items():
            if diagnosis not in df.columns:
                continue
                
            # Convert to numeric and drop NaN values
            values = pd.to_numeric(df[diagnosis], errors='coerce').dropna()
            if len(values) < 2:
                continue
                
            # Get p-value for this valence if available
            p_val = 1.0
            if not sig_results.empty:
                val_result = sig_results[sig_results['valence'] == valence]
                if not val_result.empty:
                    p_val = val_result['adjusted_p'].iloc[0]
            
            # Create label with p-value for significant results
            label = valence
            if p_val < 0.05:
                label = f'{valence} (p={p_val:.4f})'
            
            # Plot the distribution
            sns.kdeplot(
                data=values,
                label=label,
                common_norm=False,
                fill=True,
                alpha=0.3
            )
            
        # Customize plot
        plt.title(f'Prediction Distribution for {diagnosis}')
        plt.xlabel('Predicted Probability')
        plt.ylabel('Density')
        plt.legend(title='Valence Categories')
        
        # Add distribution statistics as text
        stats_df = self.get_distribution_statistics(diagnosis)
        if not stats_df.empty:
            stats_text = "Distribution Statistics:\n"
            for _, row in stats_df.iterrows():
                stats_text += f"\n{row['valence']}:\n"
                stats_text += f"mean: {row['mean']:.3f}\n"
                stats_text += f"std: {row['std']:.3f}\n"
            
            plt.figtext(
                1.02, 0.5,
                stats_text,
                fontsize=8,
                va='center'
            )
        
        # Save plot
        plt.tight_layout()
        plt.savefig(save_path, bbox_inches='tight', dpi=300)
        plt.close()
        
        return sig_results

    
    def get_available_diagnoses(self) -> List[str]:
        """Get list of available diagnoses"""
        diagnoses = set()
        for df in self.diagnosis_dfs.values():
            exclude_cols = {'Valence', 'Val_class', 'Word', 'AttentionWeight'}
            diagnoses.update(set(df.columns) - exclude_cols)
        return list(diagnoses)
    
    def plot_attention_heatmap(self, top_n: int = 20, save_path: Optional[str] = None):
        """Plot heatmap of attention weight changes"""
        if 'neutralize' not in self.attention_dfs:
            return

        base_attention = self.attention_dfs['neutralize']
        top_words = base_attention.groupby('Word')['AttentionWeight'].mean().nlargest(top_n).index

        attention_matrix = []
        valences = []

        for valence, df in self.attention_dfs.items():
            word_attention = df[df['Word'].isin(top_words)].groupby('Word')['AttentionWeight'].mean()
            attention_matrix.append(word_attention)
            valences.append(valence)

        attention_df = pd.DataFrame(attention_matrix, index=valences, columns=top_words)

        plt.figure(figsize=(15, 8))
        sns.heatmap(attention_df, cmap='RdBu_r', center=0, annot=True, fmt='.2f')
        plt.title('Attention Weight Changes Across Valences')
        plt.xlabel('Words')
        plt.ylabel('Valence')

        if save_path:
            plt.savefig(save_path)
        else:
            plt.show()

        plt.close()


    def generate_all_visualizations(self, save_dir: str, diagnoses: Optional[List[str]] = None):
        """
        Generate visualizations for specified or significant diagnoses
        
        Args:
            save_dir (str): Directory to save visualizations
            diagnoses (Optional[List[str]]): List of diagnoses to visualize. If None, uses significant diagnoses.
        """
        save_path = Path(save_dir)
        save_path.mkdir(parents=True, exist_ok=True)

        # Get significant diagnoses
        test_results = self.test_diagnosis_significance()
        significant_diagnoses = test_results.significant_diagnoses

        if diagnoses is None:
            # Use all available diagnoses that show significant effects
            diagnoses = significant_diagnoses
        else:
            # Filter requested diagnoses to only include significant ones
            diagnoses = [d for d in diagnoses if d in significant_diagnoses]

        # Generate distribution plots
        for diagnosis in diagnoses:
            diagnosis_path = save_path / f"prediction_dist_{diagnosis}.png"
            try:
                sig_results = self.plot_prediction_distribution(
                    diagnosis,
                    str(diagnosis_path)
                )
                print(f"Generated distribution plot for {diagnosis}")
            except Exception as e:
                print(f"Error generating plot for {diagnosis}: {str(e)}")

        # Generate attention heatmap
        try:
            heatmap_path = save_path / "attention_heatmap.png"
            self.plot_attention_heatmap(save_path=str(heatmap_path))
            print("Generated attention heatmap")
        except Exception as e:
            print(f"Error generating attention heatmap: {str(e)}")

        # Generate summary report
        try:
            self._generate_summary_report(
                save_path / "analysis_report.md",
                test_results
            )
            print("Generated analysis report")
        except Exception as e:
            print(f"Error generating analysis report: {str(e)}")

            
    def _generate_summary_report(self, save_path: Path, diagnosis_tests: Optional[DiagnosisTestResults] = None):
        """Generate comprehensive analysis report including statistical test results"""
        analysis = self.analyze_attention_patterns()
        
        with open(save_path, "w") as f:
            f.write("# Analysis Report\n\n")
            
            if diagnosis_tests and not diagnosis_tests.test_results.empty:
                f.write("## Diagnosis Statistical Tests\n")
                f.write(diagnosis_tests.test_results.to_markdown() + "\n\n")
                
                f.write("### Significant Diagnoses\n")
                for diagnosis in diagnosis_tests.significant_diagnoses:
                    f.write(f"- {diagnosis}\n")
                f.write("\n")
            
            f.write("## Word Type Effects\n")
            if not analysis.val_type_effects.empty:
                f.write(analysis.val_type_effects.to_markdown() + "\n\n")
            else:
                f.write("No valid word type effects found.\n\n")
            
            f.write("## Word Class Effects\n")
            if not analysis.val_class_effects.empty:
                f.write(analysis.val_class_effects.to_markdown() + "\n\n")
            else:
                f.write("No valid word class effects found.\n\n")
            
            f.write("\n## Distribution Statistics by Significant Diagnosis\n")
            significant_diagnoses = (diagnosis_tests.significant_diagnoses 
                                   if diagnosis_tests else self.get_available_diagnoses())
            
            for diagnosis in significant_diagnoses:
                stats_df = self.get_distribution_statistics(diagnosis)
                if not stats_df.empty:
                    f.write(f"\n### {diagnosis}\n")
                    f.write(stats_df.to_markdown() + "\n")

                    
    def analyze_full_dataset(self, output_dir: str):
        """Perform comprehensive analysis of the entire dataset"""
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        
        self.generate_all_visualizations(str(output_path / "visualizations"))
        self._generate_summary_report(output_path / "full_analysis_report.md")


def main():
    results_dir = "result"
    output_dir = "analysis_result"

    # Initialize the ResultsAnalyzer
    analyzer = ResultsAnalyzer(results_dir)

    # Perform the analysis
    analyzer.analyze_full_dataset(output_dir)

    print(f"Analysis completed. Results saved to {output_dir}")

if __name__ == "__main__":
    main()

  

Generated distribution plot for 493
Generated distribution plot for 276
Generated distribution plot for 997
Generated distribution plot for V498
Generated distribution plot for 305
Generated distribution plot for 300
Generated distribution plot for 780
Generated distribution plot for 272


  top_words = base_attention.groupby('Word')['AttentionWeight'].mean().nlargest(top_n).index
  word_attention = df[df['Word'].isin(top_words)].groupby('Word')['AttentionWeight'].mean()


Generated attention heatmap
Generated analysis report
Analysis completed. Results saved to analysis_result


In [1]:
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Optional, List
import numpy as np
import pandas as pd
from scipy import stats
from statsmodels.stats.multitest import multipletests
import seaborn as sns
import matplotlib.pyplot as plt

@dataclass
class AttentionAnalysisResults:
    """Stores results from attention weight analysis"""
    val_type_effects: pd.DataFrame
    val_class_effects: pd.DataFrame
    significance_tests: pd.DataFrame
    top_attention_shifts: pd.DataFrame

class ResultsAnalyzer:
    def __init__(self, results_dir: str):
        self.results_dir = Path(results_dir)
        self.diagnosis_dfs = {}
        self.attention_dfs = {}
        self.notes_dfs = {}
        
        self.VALS_TYPES = {
            "pejorative": frozenset(["non_compliant", "uncooperative", "resistant", "difficult"]),
            "laudatory": frozenset(["compliant", "cooperative", "pleasant", "respectful"]),
            "neutralval": frozenset(['neutral'])
        }
        self._load_results()
    
    def _load_results(self):
        """Load all results files from the directory"""
        self._load_csv_files("_diagnosis.csv", self.diagnosis_dfs)
        self._load_csv_files("_attention.csv", self.attention_dfs)
        self._load_csv_files("_clinical_notes.csv", self.notes_dfs)

    def _load_csv_files(self, file_pattern: str, dfs_dict: Dict[str, pd.DataFrame]):
        """Helper function to load CSV files matching the pattern"""
        for file in self.results_dir.glob(f"*{file_pattern}"):
            name = file.stem.split('_')[0]
            df = pd.read_csv(file)
            
            if 'AttentionWeight' in df.columns:
                df['AttentionWeight'] = pd.to_numeric(df['AttentionWeight'], errors='coerce')
            
            categorical_cols = ['Valence', 'Val_class', 'Word']
            for col in categorical_cols:
                if col in df.columns:
                    df[col] = df[col].astype('category')
            
            dfs_dict[name] = df

    def analyze_attention_patterns(self) -> AttentionAnalysisResults:
        """Analyze attention weight changes across valences and word types"""
        combined_attention = pd.concat(self.attention_dfs.values(), ignore_index=True)
        
        mask_neutralize = combined_attention["Val_class"] == "neutralize"
        baseline = (combined_attention[mask_neutralize]
                   .groupby("Word", observed=True)["AttentionWeight"]
                   .mean()
                   .reset_index())
        
        baseline_dict = dict(zip(baseline["Word"], baseline["AttentionWeight"]))
        baseline_words = set(baseline["Word"])
        
        val_effects = []
        
        for val_class, vals in self.VALS_TYPES.items():
            for val in vals:
                mask = combined_attention["Valence"].str.contains(val, case=False, na=False)
                filter_table = combined_attention[mask]
                common_words = list(set(filter_table["Word"]) & baseline_words)
                
                if not common_words:
                    continue
                
                val_data = filter_table[filter_table["Word"].isin(common_words)]
                
                if len(val_data) == 0:
                    continue

                try:
                    baseline_values = np.array([baseline_dict[word] for word in val_data["Word"]])
                    effect = self._calculate_effect_size(
                        val_data["AttentionWeight"].values,
                        baseline_values
                    )
                    val_effects.append({
                        "valence_type": val,
                        "val_class": val_class,
                        "effect_size": effect
                    })
                except KeyError:
                    continue

        vals_types_effects = pd.DataFrame(val_effects)
        
        class_effects = (vals_types_effects
                        .groupby("val_class", observed=True)["effect_size"]
                        .agg(["mean", "std", "count"])
                        .reset_index())
        
        significance_tests = self._run_significance_tests(combined_attention)
        top_shifts = self._analyze_attention_shifts(combined_attention, baseline["AttentionWeight"])

        return AttentionAnalysisResults(
            val_type_effects=vals_types_effects,
            val_class_effects=class_effects,
            significance_tests=significance_tests,
            top_attention_shifts=top_shifts
        )

    @staticmethod
    def _calculate_effect_size(group1: np.ndarray, group2: np.ndarray) -> float:
        """Calculate Cohen's d effect size between two groups"""
        n1, n2 = len(group1), len(group2)
        if n1 < 2 or n2 < 2:
            return 0.0
        
        var1 = np.var(group1, ddof=1)
        var2 = np.var(group2, ddof=1)
        pooled_se = np.sqrt(((n1 - 1) * var1 + (n2 - 1) * var2) / (n1 + n2 - 2))
        
        if pooled_se == 0:
            return 0.0
        
        return abs((np.mean(group1) - np.mean(group2))) / pooled_se

    def _run_significance_tests(self, attention_data: pd.DataFrame) -> pd.DataFrame:
        """Run statistical tests on attention weight differences"""
        neutral_mask = attention_data.Val_class == "neutralize"
        neutral_weights = attention_data[neutral_mask]["AttentionWeight"].values
        
        results = []
        classes = attention_data.Val_class.unique()
        
        for val_class in classes:
            if val_class == "neutralize":
                continue
            
            class_mask = attention_data.Val_class == val_class
            class_weights = attention_data[class_mask]["AttentionWeight"].values
            
            t_stat, p_val = stats.ttest_ind(class_weights, neutral_weights)
            
            results.append({
                "valence_class": val_class,
                "t_statistic": t_stat,
                "p_value": p_val
            })
        
        results_df = pd.DataFrame(results)
        results_df["adjusted_p"] = multipletests(results_df["p_value"], method="fdr_bh")[1]
        
        for col in ["p_value", "adjusted_p"]:
            results_df[col] = results_df[col].round(4)
        
        return results_df

    def _analyze_attention_shifts(self, attention_data: pd.DataFrame, baseline: pd.Series, top_n: int = 20) -> pd.DataFrame:
        """Identify words with largest attention weight changes"""
        shifts = (attention_data
                 .groupby(['Word', 'Val_class'], observed=True)['AttentionWeight']
                 .mean()
                 .unstack()
                 .fillna(0))

        neutral_mask = ~shifts['neutralize'].isna()
        valid_words = shifts[neutral_mask].index
        
        result_dfs = []
        
        for val_class in shifts.columns:
            if val_class == 'neutralize':
                continue
            
            shift_data = pd.DataFrame({
                'Word': valid_words,
                'Neutral_Weight': shifts.loc[valid_words, 'neutralize'],
                'Shifted_Weight': shifts.loc[valid_words, val_class],
                'Valence_Class': val_class
            })
            
            shift_data['Absolute_Shift'] = abs(shift_data['Shifted_Weight'] - shift_data['Neutral_Weight'])
            shift_data['Percentage_Change'] = (
                (shift_data['Shifted_Weight'] - shift_data['Neutral_Weight']) / 
                shift_data['Neutral_Weight'] * 100
            ).round(2)
            
            top_shifts = (shift_data
                         .nlargest(top_n, 'Absolute_Shift')
                         .reset_index(drop=True))
            
            result_dfs.append(top_shifts)
        
        final_shifts = pd.concat(result_dfs, ignore_index=True)
        
        formatted_shifts = (final_shifts
            .assign(
                Neutral_Weight=lambda x: x['Neutral_Weight'].round(4),
                Shifted_Weight=lambda x: x['Shifted_Weight'].round(4),
                Absolute_Shift=lambda x: x['Absolute_Shift'].round(4)
            )
            .sort_values(['Valence_Class', 'Absolute_Shift'], ascending=[True, False])
            .reset_index(drop=True))
        
        formatted_shifts['Rank'] = (formatted_shifts
            .groupby('Valence_Class')
            .cumcount() + 1)
        
        return formatted_shifts[[
            'Rank',
            'Valence_Class', 
            'Word',
            'Neutral_Weight',
            'Shifted_Weight', 
            'Absolute_Shift',
            'Percentage_Change'
        ]]

    def plot_prediction_distribution(self, diagnosis: str, save_path: Optional[str] = None):
        """Plot distribution of predictions for a specific diagnosis"""
        plt.figure(figsize=(12, 6))
        
        for valence, df in self.diagnosis_dfs.items():
            if diagnosis not in df.columns:
                continue
            
            numeric_values = pd.to_numeric(df[diagnosis], errors='coerce')
            if numeric_values.notna().any():
                sns.kdeplot(data=numeric_values.dropna(), label=valence)

        plt.title(f'Prediction Distribution for {diagnosis}')
        plt.xlabel('Prediction Probability')
        plt.ylabel('Density')
        plt.legend()

        if save_path:
            plt.savefig(save_path)
        else:
            plt.show()

        plt.close()

    def get_distribution_statistics(self, diagnosis: str) -> pd.DataFrame:
        """Calculate statistical summary for prediction distributions"""
        stats_data = []

        for valence, df in self.diagnosis_dfs.items():
            if diagnosis not in df.columns:
                continue

            numeric_data = pd.to_numeric(df[diagnosis], errors='coerce')
            df_cleaned = numeric_data.dropna()
            
            if len(df_cleaned) > 0:
                stats_dict = {
                    'valence': valence,
                    'count': len(df_cleaned),
                    'mean': df_cleaned.mean(),
                    'std': df_cleaned.std(),
                    'min': df_cleaned.min(),
                    'max': df_cleaned.max(),
                    'median': df_cleaned.median()
                }
                stats_data.append(stats_dict)

        return pd.DataFrame(stats_data) if stats_data else pd.DataFrame()

    def plot_attention_heatmap(self, top_n: int = 20, save_path: Optional[str] = None):
        """Plot heatmap of attention weight changes"""
        if 'neutralize' not in self.attention_dfs:
            return

        base_attention = self.attention_dfs['neutralize']
        top_words = base_attention.groupby('Word')['AttentionWeight'].mean().nlargest(top_n).index

        attention_matrix = []
        valences = []

        for valence, df in self.attention_dfs.items():
            word_attention = df[df['Word'].isin(top_words)].groupby('Word')['AttentionWeight'].mean()
            attention_matrix.append(word_attention)
            valences.append(valence)

        attention_df = pd.DataFrame(attention_matrix, index=valences, columns=top_words)

        plt.figure(figsize=(15, 8))
        sns.heatmap(attention_df, cmap='RdBu_r', center=0, annot=True, fmt='.2f')
        plt.title('Attention Weight Changes Across Valences')
        plt.xlabel('Words')
        plt.ylabel('Valence')

        if save_path:
            plt.savefig(save_path)
        else:
            plt.show()

        plt.close()

    def get_available_diagnoses(self) -> List[str]:
        """Get list of available diagnoses"""
        diagnoses = set()
        for df in self.diagnosis_dfs.values():
            exclude_cols = {'Valence', 'Val_class', 'Word', 'AttentionWeight'}
            diagnoses.update(set(df.columns) - exclude_cols)
        return list(diagnoses)

    def generate_all_visualizations(self, save_dir: str, diagnoses: Optional[List[str]] = None):
        """Generate all visualizations"""
        save_path = Path(save_dir)
        save_path.mkdir(parents=True, exist_ok=True)

        if diagnoses is None:
            diagnoses = self.get_available_diagnoses()

        for diagnosis in diagnoses:
            diagnosis_path = save_path / f"prediction_dist_{diagnosis}.png"
            self.plot_prediction_distribution(diagnosis, str(diagnosis_path))

        heatmap_path = save_path / "attention_heatmap.png"
        self.plot_attention_heatmap(save_path=str(heatmap_path))

        self._generate_summary_report(save_path / "analysis_report.md")

    def _generate_summary_report(self, save_path: Path):
        """Generate comprehensive analysis report"""
        analysis = self.analyze_attention_patterns()
        
        with open(save_path, "w") as f:
            f.write("# Attention Analysis Report\n\n")
            
            f.write("## Word Type Effects\n")
            if not analysis.val_type_effects.empty:
                f.write(analysis.val_type_effects.to_markdown() + "\n\n")
            else:
                f.write("No valid word type effects found.\n\n")
            
            f.write("## Word Class Effects\n")
            if not analysis.val_class_effects.empty:
                f.write(analysis.val_class_effects.to_markdown() + "\n\n")
            else:
                f.write("No valid word class effects found.\n\n")
            
            f.write("\n## Distribution Statistics by Diagnosis\n")
            for diagnosis in self.get_available_diagnoses():
                stats_df = self.get_distribution_statistics(diagnosis)
                if not stats_df.empty:
                    f.write(f"\n### {diagnosis}\n")
                    f.write(stats_df.to_markdown() + "\n")

    def analyze_full_dataset(self, output_dir: str):
        """Perform comprehensive analysis of the entire dataset"""
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        
        self.generate_all_visualizations(str(output_path / "visualizations"))
        self._generate_summary_report(output_path / "full_analysis_report.md")

def main():
    """Main function to demonstrate usage"""
    analyzer = ResultsAnalyzer("result")
    analyzer.analyze_full_dataset("OUTPUT")

if __name__ == "__main__":
    main()

  top_words = base_attention.groupby('Word')['AttentionWeight'].mean().nlargest(top_n).index
  word_attention = df[df['Word'].isin(top_words)].groupby('Word')['AttentionWeight'].mean()
