In [None]:
import json
import os
import traceback

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics import accuracy_score
from tqdm import tqdm

# Absolute, machine-specific input locations (Windows paths); adjust before
# running this analysis on another machine.
PRED_CSV_PATH = r"C:\dev\vea\analysis\late_fusion_lstm_test_predictions.csv"
SEMANTIC_DIR = r"C:\dev\vea\data\features_semantic"
TEXT_JSON_DIR = r"C:\dev\vea\data\semantic\frame_description"

# Global plot styling. The explicit rcParams are applied after the seaborn
# theme, so they override whatever the theme set for these keys.
sns.set_theme(style="whitegrid")
plt.rcParams.update({
    'figure.figsize': (12, 10),
    'font.size': 12,
})

def _read_description(json_path):
    """Return the space-joined ``description`` fields from a frame JSON file.

    Returns an empty string when *json_path* does not exist. Entries whose
    ``description`` is missing or ``null`` contribute an empty string, so the
    join never fails on them.
    """
    if not os.path.exists(json_path):
        return ""
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    # NOTE(review): assumes the file holds a list of dicts (one per frame);
    # any other layout raises and the caller drops the sample — confirm.
    parts = []
    for item in data:
        desc = item.get('description')
        parts.append("" if desc is None else str(desc))
    return " ".join(parts)


def load_analysis_data(csv_path, semantic_dir, text_json_dir):
    """Load prediction rows and attach the per-video frame descriptions.

    A CSV row is kept only if ``{video_id}.npy`` exists in *semantic_dir* and
    loads without error, and ``{video_id}.json`` in *text_json_dir* (when it
    exists) parses. Rows without a JSON file are kept with an empty
    ``raw_text``. Rows that fail to load are reported and skipped.

    Args:
        csv_path: Path to the predictions CSV; must contain a ``video_id``
            column.
        semantic_dir: Directory holding one ``.npy`` feature file per video.
        text_json_dir: Directory holding one description ``.json`` per video.

    Returns:
        The kept CSV rows (re-indexed 0..n-1) with an added ``raw_text``
        column.

    Raises:
        FileNotFoundError: If *csv_path* does not exist.
    """
    if not os.path.exists(csv_path):
        raise FileNotFoundError(f"Prediction CSV not found at: {csv_path}")

    print(f"Loading predictions from: {csv_path}")
    df = pd.read_csv(csv_path)
    print(f"Initial samples: {len(df)}")

    raw_texts = []
    kept_positions = []

    print("Loading semantic features and raw texts...")
    # Track row *positions* (not index labels) so the .iloc selection below
    # is correct regardless of the DataFrame's index.
    for pos, vid in enumerate(tqdm(df['video_id'], total=len(df))):
        npy_path = os.path.join(semantic_dir, f"{vid}.npy")
        if not os.path.exists(npy_path):
            continue
        json_path = os.path.join(text_json_dir, f"{vid}.json")

        try:
            # The feature values are not used by this analysis; loading them
            # only confirms the file is readable, so the kept samples are the
            # ones with usable semantic features.
            np.load(npy_path)
            description_str = _read_description(json_path)
        except Exception as e:
            # Best-effort loading: report and skip the sample.
            print(f"Error processing {vid}: {e}")
            continue

        # Append only after every step succeeded so the lists stay aligned.
        raw_texts.append(description_str)
        kept_positions.append(pos)

    df_clean = df.iloc[kept_positions].reset_index(drop=True)
    df_clean['raw_text'] = raw_texts

    print(f"Successfully loaded {len(df_clean)} samples.")
    return df_clean

def analyze_keyword_performance(df, min_df=30, max_features=100):
    """Measure valence-prediction accuracy on the samples mentioning each keyword.

    Keywords are the vocabulary selected by ``CountVectorizer`` over
    ``df['raw_text']``, using a custom stop-word list plus *min_df* and
    *max_features*. For each keyword, the subset is the samples whose text
    contains it at least once. Accuracy on that subset is
    ``mean(valence_true == valence_pred)``, matching ``accuracy_score``.

    Args:
        df: Frame with ``raw_text``, ``valence_true`` and ``valence_pred``
            columns.
        min_df: Forwarded to ``CountVectorizer`` (minimum document frequency).
        max_features: Forwarded to ``CountVectorizer`` (vocabulary size cap;
            ``None`` keeps every surviving term).

    Returns:
        ``(metrics_df, global_acc)``. ``metrics_df`` has columns ``Keyword``,
        ``Count`` (number of samples containing it), ``Accuracy``, and
        ``Delta`` (``Accuracy - global_acc``), in vocabulary order.
        ``global_acc`` is the accuracy over all of *df*.
    """
    custom_stop_words = [
        'the', 'is', 'in', 'and', 'on', 'at', 'of', 'with', 'a', 'an', 'to', 'for',
        'image', 'frame', 'video', 'clip', 'shows', 'picture', 'shot', 'scene',
        'man', 'woman', 'person', 'people', 'male', 'female',
        'standing', 'sitting', 'looking', 'wearing'
    ]

    print("Extracting high-frequency keywords...")
    vectorizer = CountVectorizer(
        stop_words=custom_stop_words,
        min_df=min_df,
        max_features=max_features
    )

    X_counts = vectorizer.fit_transform(df['raw_text'])
    feature_names = vectorizer.get_feature_names_out()

    print(f"Identified {len(feature_names)} keywords for analysis.")

    # Computed first so invalid label types surface via sklearn's checks.
    global_acc = accuracy_score(df['valence_true'], df['valence_pred'])

    # Sparse 0/1 presence matrix (samples x keywords). It stays sparse, so
    # memory does not grow with N * vocabulary size.
    presence = (X_counts > 0).astype(np.int64)
    correct = (
        df['valence_true'].to_numpy() == df['valence_pred'].to_numpy()
    ).astype(np.int64)

    # Per keyword: number of samples containing it, and how many of those
    # were predicted correctly.
    counts = np.asarray(presence.sum(axis=0)).ravel()
    hits = np.asarray(presence.T @ correct).ravel()

    # Skip keywords with no supporting samples (same guard as a per-word loop).
    mask = counts > 0
    accuracy = hits[mask] / counts[mask]

    metrics_df = pd.DataFrame({
        'Keyword': feature_names[mask],
        'Count': counts[mask].astype(int),
        'Accuracy': accuracy,
        'Delta': accuracy - global_acc,
    })
    return metrics_df, global_acc

def plot_semantic_impact(metrics_df, global_acc, top_n=15):
    """Draw a horizontal bar chart of the best- and worst-scoring keywords.

    Bars are green when a keyword's ``Delta`` is >= 0 and red otherwise. A
    dashed vertical line marks *global_acc*. Each bar is annotated with its
    accuracy and sample count. Bars are ordered by accuracy, highest at the
    top.

    Args:
        metrics_df: Frame with ``Keyword``, ``Count``, ``Accuracy`` and
            ``Delta`` columns.
        global_acc: Overall accuracy drawn as the reference line.
        top_n: How many keywords to take from each end of the ranking.

    Returns:
        ``(top_performers, worst_performers)``: the *top_n* rows with the
        highest and the lowest ``Accuracy``, respectively (these may overlap
        when there are fewer than ``2 * top_n`` keywords).
    """
    top_performers = metrics_df.sort_values('Accuracy', ascending=False).head(top_n)
    worst_performers = metrics_df.sort_values('Accuracy', ascending=True).head(top_n)

    # The two lists overlap when there are fewer than 2 * top_n keywords.
    # Keep each keyword once so bars, colours and text labels stay aligned.
    plot_df = (
        pd.concat([top_performers, worst_performers])
        .drop_duplicates(subset='Keyword')
        .sort_values('Accuracy', ascending=False)
        .reset_index(drop=True)
    )

    colors = ['#2ecc71' if x >= 0 else '#e74c3c' for x in plot_df['Delta']]
    positions = np.arange(len(plot_df))

    fig, ax = plt.subplots(figsize=(12, 10))

    # Drawn with matplotlib directly: per-bar colours need no hue mapping, which
    # avoids seaborn's deprecated palette-without-hue path.
    ax.barh(positions, plot_df['Accuracy'], color=colors)
    ax.set_yticks(positions)
    ax.set_yticklabels(plot_df['Keyword'])
    ax.invert_yaxis()  # first (highest-accuracy) row at the top

    ax.axvline(global_acc, color='black', linestyle='--', linewidth=1.5,
               label=f'Global Avg ({global_acc:.1%})')

    ax.set_title('Semantic Concepts Impact on Valence Prediction Accuracy', fontsize=16)
    ax.set_xlabel('Subset Prediction Accuracy', fontsize=14)
    ax.set_ylabel('Semantic Keyword', fontsize=14)
    ax.legend()

    for pos, value, count in zip(positions, plot_df['Accuracy'], plot_df['Count']):
        ax.text(value + 0.005, pos, f"{value:.1%} (N={count})", va='center', fontsize=10)

    ax.set_xlim(0, 1.05)
    fig.tight_layout()
    plt.show()

    return top_performers, worst_performers

if __name__ == "__main__":
    # Pipeline: load predictions + descriptions -> per-keyword accuracy ->
    # plot and print the extremes.
    try:
        df_clean = load_analysis_data(PRED_CSV_PATH, SEMANTIC_DIR, TEXT_JSON_DIR)

        metrics_df, global_accuracy = analyze_keyword_performance(df_clean)

        top_k, bottom_k = plot_semantic_impact(metrics_df, global_accuracy)

        report_cols = ['Keyword', 'Accuracy', 'Count', 'Delta']

        print("\n=== Top 10 High-Confidence Semantic Concepts ===")
        print(top_k[report_cols].reset_index(drop=True).head(10))

        print("\n=== Bottom 10 Confusing Semantic Concepts ===")
        print(bottom_k[report_cols].reset_index(drop=True).head(10))

    except Exception as e:
        print(f"An error occurred during execution: {e}")
        # The one-line message alone hides where the failure happened; keep
        # the stack trace for diagnosis.
        traceback.print_exc()