In [None]:
import pandas as pd
import numpy as np
import re, os, ast, glob
from collections import defaultdict, Counter
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from multipec.simulation_utils import set_plotting_style

def find_repo_root(marker="setup.py"):
    """Return the closest ancestor of the working directory containing ``marker``.

    The search starts at the current working directory and moves upwards one
    directory at a time. When no directory on the way contains ``marker``,
    the top-most directory reached is returned.
    """
    start = Path.cwd()
    candidates = [start, *start.parents]
    for folder in candidates:
        if (folder / marker).exists():
            return folder
    # Marker not found anywhere: fall back to the last (top-most) candidate.
    return candidates[-1]


# Root folder from which the data and figure paths are built.
project_root = find_repo_root()

In [None]:
# Full list of 64 channel labels, including the reference electrodes A1 and A2,
# which sit at 0-based positions 30 and 31 (i.e. the 31st and 32nd entries).
channels_labels = [
    "Fp1", "Fp2", "F7", "F3", "Fz", "F4", "F8", "FC5", "FC1", "FC2", "FC6", "T7",
    "C3", "C4", "T8", "TP9", "CP5", "CP1", "CP2", "CP6", "TP10", "P7", "P3", "Pz",
    "P4", "P8", "O1", "Oz", "O2", "Iz", "A1", "A2", "AF7", "AF3", "AFz", "AF4",
    "AF8", "F5", "F1", "F2", "F6", "FT7", "FC3", "FCz", "FC4", "FT8", "C5", "C1",
    "C2", "C6", "TP7", "CP3", "CPz", "CP4", "TP8", "P5", "P1", "P2", "P6", "PO7",
    "PO3", "POz", "PO4", "PO8"
]

# Lookup table: channel label -> (brain region, description of associated functions).
# The description is a comma-separated string; it is split on commas further
# below to build the per-node keyword lists, so the exact wording matters.
eeg_info = {
    "Fp1": ("Frontal pole", "attention, emotion, executive function"),
    "Fp2": ("Frontal pole", "attention, emotion, executive function"),
    "F7": ("Frontal", "language processing, auditory processing"),
    "F3": ("Frontal", "working memory, decision making"),
    "Fz": ("Frontal midline", "executive control, motor planning"),
    "F4": ("Frontal", "working memory, decision making"),
    "F8": ("Frontal", "language, auditory response"),
    "FC5": ("Fronto-central", "speech perception, auditory-motor integration"),
    "FC1": ("Fronto-central", "motor planning, decision making"),
    "FC2": ("Fronto-central", "motor planning, decision making"),
    "FC6": ("Fronto-central", "auditory-motor integration"),
    "T7": ("Temporal", "primary auditory processing"),
    "C3": ("Central", "motor cortex - movement of right side"),
    "C4": ("Central", "motor cortex - movement of left side"),
    "T8": ("Temporal", "auditory association"),
    "TP9": ("Temporal-parietal", "sound localization, multisensory integration"),
    "CP5": ("Centro-parietal", "sensorimotor integration"),
    "CP1": ("Centro-parietal", "tactile processing, spatial attention"),
    "CP2": ("Centro-parietal", "tactile processing, spatial attention"),
    "CP6": ("Centro-parietal", "sensorimotor integration"),
    "TP10": ("Temporal-parietal", "sound localization, multisensory integration"),
    "P7": ("Parietal", "visual attention, spatial processing"),
    "P3": ("Parietal", "spatial awareness, somatosensory integration"),
    "Pz": ("Parietal midline", "visuospatial attention, awareness"),
    "P4": ("Parietal", "spatial awareness, somatosensory integration"),
    "P8": ("Parietal", "visual attention, spatial processing"),
    "O1": ("Occipital", "primary visual cortex (left visual field)"),
    "Oz": ("Occipital midline", "central visual processing"),
    "O2": ("Occipital", "primary visual cortex (right visual field)"),
    "Iz": ("Occipital", "visual association area"),
    "A1": ("Reference electrode", "reference"),
    "A2": ("Reference electrode", "reference"),
    "AF7": ("Anterior frontal", "language and auditory attention"),
    "AF3": ("Anterior frontal", "emotional control, working memory"),
    "AFz": ("Anterior frontal midline", "conflict monitoring, executive control"),
    "AF4": ("Anterior frontal", "emotional control, working memory"),
    "AF8": ("Anterior frontal", "language and auditory attention"),
    "F5": ("Frontal", "motor planning, cognitive control"),
    "F1": ("Frontal", "executive function, motor planning"),
    "F2": ("Frontal", "executive function, motor planning"),
    "F6": ("Frontal", "motor planning, cognitive control"),
    "FT7": ("Fronto-temporal", "auditory processing, speech perception"),
    "FC3": ("Fronto-central", "motor preparation, sensorimotor integration"),
    "FCz": ("Fronto-central midline", "motor control, attention"),
    "FC4": ("Fronto-central", "motor preparation, sensorimotor integration"),
    "FT8": ("Fronto-temporal", "auditory processing, speech perception"),
    "C5": ("Central", "motor cortex - movement of right limbs"),
    "C1": ("Central", "motor cortex - fine motor control (right)"),
    "C2": ("Central", "motor cortex - fine motor control (left)"),
    "C6": ("Central", "motor cortex - movement of left limbs"),
    "TP7": ("Temporal-parietal", "speech and language integration"),
    "CP3": ("Centro-parietal", "sensorimotor function, spatial processing"),
    "CPz": ("Centro-parietal midline", "sensorimotor integration"),
    "CP4": ("Centro-parietal", "sensorimotor function, spatial processing"),
    "TP8": ("Temporal-parietal", "speech and language integration"),
    "P5": ("Parietal", "visual attention, object recognition"),
    "P1": ("Parietal", "visual spatial processing"),
    "P2": ("Parietal", "visual spatial processing"),
    "P6": ("Parietal", "visual attention, object recognition"),
    "PO7": ("Parieto-occipital", "high-level visual processing"),
    "PO3": ("Parieto-occipital", "visual integration, object recognition"),
    "POz": ("Parieto-occipital midline", "visual attention and processing"),
    "PO4": ("Parieto-occipital", "visual integration, object recognition"),
    "PO8": ("Parieto-occipital", "high-level visual processing")
}

# Coarse functional groups and the substrings that assign a processing
# description to each group. map_processing_type lower-cases the description
# and tags it with every group for which at least one substring occurs in it.
function_map = {
    "visual": ["visual", "object recognition", "spatial", "association", "high-level"],
    "auditory": ["auditory", "sound"],
    "language": ["language", "speech"],
    "motor": ["motor", "movement", "fine motor"],
    "executive": ["executive", "decision", "working memory", "control", "conflict"],
    "attention": ["attention", "focus"],
    "sensorimotor": ["sensorimotor", "tactile", "somatosensory"],
    "emotion": ["emotion", "emotional"],
}

def get_original_index(modified_index):
    """Map a 0-based index of the 62-channel list (A1/A2 removed) to ``channels_labels``.

    ``channels_labels`` holds A1 and A2 at 0-based positions 30 and 31, so
    every reduced index from 30 upwards has to be shifted by two. With a
    threshold of 31, index 30 would resolve to A1 and AF7 (original index 32)
    could never be reached.
    """
    first_removed = 30  # 0-based position of A1 in the full 64-label list
    n_removed = 2  # A1 and A2
    if modified_index >= first_removed:
        return modified_index + n_removed
    return modified_index

def describe_eeg_channels(indices):
    """Describe each channel index of the reduced montage.

    Every index is converted with ``get_original_index``, resolved to its
    label in ``channels_labels`` and paired with the region/function entry
    stored for that label in ``eeg_info``.

    Returns a list of dicts (input order) with the keys ``index``,
    ``original_label``, ``brain_region`` and ``processing_type``.
    """
    def _describe(channel_idx):
        name = channels_labels[get_original_index(channel_idx)]
        # Labels without an eeg_info entry get placeholder descriptions.
        area, functions = eeg_info.get(name, ("Unknown region", "unknown"))
        return {
            "index": channel_idx,
            "original_label": name,
            "brain_region": area,
            "processing_type": functions,
        }

    return [_describe(channel_idx) for channel_idx in indices]

def map_processing_type(processing_str):
    """Return the functional groups whose keywords occur in ``processing_str``.

    Matching is a case-insensitive substring test against the keyword lists
    in ``function_map``; each group is reported at most once. The result is
    a list built from a set, so its order is not meaningful.
    """
    text = processing_str.lower()
    matched_groups = {
        group
        for group, keywords in function_map.items()
        if any(keyword in text for keyword in keywords)
    }
    return list(matched_groups)

# Parse a single Excel file
def parse_subject_file(filepath):
    """Read one network spreadsheet and prepare its columns.

    The 'Net Nodes' column is evaluated with ``ast.literal_eval`` (so it is
    expected to hold the textual form of a Python literal), and a new
    'Processing Tags' column is derived from 'Processing Type' through
    ``map_processing_type``.

    Returns the resulting DataFrame.
    """
    table = pd.read_excel(filepath)
    parsed_nodes = table['Net Nodes'].apply(ast.literal_eval)
    table['Net Nodes'] = parsed_nodes
    processing_tags = table['Processing Type'].apply(map_processing_type)
    table['Processing Tags'] = processing_tags
    return table

# Process all Excel files in folder
def process_all_subjects(data_folder='data'):
    """Collect the unique networks from every ``nets_<sigma>_<task>_<id>.xlsx`` file.

    Parameters:
    - data_folder: folder that is searched (non-recursively) for the files.

    Returns:
    - list of dicts with the keys ``task``, ``subject`` (``"sub<id>"``),
      ``net_nodes``, ``pec`` and ``tags`` (sorted list). Rows of one file that
      share the same nodes are merged: the PEC of the first such row is kept
      and the tags of all of them are combined.
    """
    name_pattern = re.compile(r'nets_(3down|23|12)_(S\d+)_(\d+)\.xlsx')
    all_networks = []

    # Sorting makes the result order independent of the directory listing order.
    for file in sorted(glob.glob(f"{data_folder}/nets_*_S*_*.xlsx")):
        match = name_pattern.search(file)
        if not match:
            continue

        _sigma, task, subject_number = match.groups()
        subject_id = f"sub{subject_number}"
        df = parse_subject_file(file)

        # Temp dict to collect the unique networks of this file
        net_dict = {}

        rows = zip(df['Net Nodes'], df['Last PEC Value'], df['Processing Tags'])
        for nodes, pec, tags in rows:
            # A node collection parsed as a list is unhashable, so it cannot be
            # part of a dict key directly; use the equivalent tuple instead.
            node_key = tuple(nodes) if isinstance(nodes, list) else nodes
            key = (task, subject_id, node_key)
            if key not in net_dict:
                net_dict[key] = {
                    'task': task,
                    'subject': subject_id,
                    'net_nodes': nodes,
                    'pec': pec,
                    'tags': set(tags),
                }
            else:
                net_dict[key]['tags'].update(tags)

        # Add unique entries to the result
        for net in net_dict.values():
            net['tags'] = sorted(net['tags'])
            all_networks.append(net)

    return all_networks

# Summarize per subject and group
def summarize_group(networks, top_n=3):
    """Count tags among each subject/task's ``top_n`` lowest-PEC networks.

    Parameters:
    - networks: list of network dicts (keys ``subject``, ``task``, ``pec``, ``tags``).
    - top_n: number of lowest-PEC networks considered per "<subject>_<task>".

    Returns:
    - per_subject_lowpec: dict "<subject>_<task>" -> tag -> number of selected
      networks carrying that tag.
    - group_tag_summary: tag -> list of the PEC values of all selected
      networks carrying that tag.
    """
    nets_by_subject_task = defaultdict(list)
    for entry in networks:
        nets_by_subject_task[f"{entry['subject']}_{entry['task']}"].append(entry)

    per_subject_lowpec = {}
    group_tag_summary = defaultdict(list)

    for label, entries in nets_by_subject_task.items():
        counts = defaultdict(int)
        lowest = sorted(entries, key=lambda item: item['pec'])[:top_n]
        for entry in lowest:
            pec_value = entry['pec']
            for tag in entry['tags']:
                counts[tag] += 1
                group_tag_summary[tag].append(pec_value)
        per_subject_lowpec[label] = counts

    return per_subject_lowpec, group_tag_summary

Check channel (node) functions:

In [None]:
# Generate node_functions from eeg_info
# Maps each position in channels_labels to the list of function keywords
# obtained by splitting that channel's eeg_info description on commas
# (empty pieces are dropped, surrounding whitespace is stripped).
node_functions = {
    position: [
        keyword.strip()
        for keyword in eeg_info.get(channel, ("Unknown", ""))[1].split(",")
        if keyword.strip()
    ]
    for position, channel in enumerate(channels_labels)
}

print("Node functions:", node_functions)

Define the paths to the network files and the figures folder, then load all networks.

In [None]:
# Folder scanned for the network files, and folder reserved for figures.
output_folder = project_root/"data/output/eeg/"
figures_folder = project_root/"data/figures/eeg/"
# Create the figures folder (and any missing parents) if it is not there yet.
figures_folder.mkdir(parents=True, exist_ok=True)

all_networks = process_all_subjects(data_folder=output_folder)
print(f"Found {len(all_networks)} total networks across all files.")

Check which subjects are missing nets per task.

In [None]:
def find_task_gaps_by_subject(directory, all_subjects=None, tasks=None):
    """Report, per task, which subjects have no matching network file.

    Parameters:
    - directory: folder whose file names are scanned.
    - all_subjects: iterable of subject ids as strings (default: "1".."24").
    - tasks: iterable of task labels (default: "S1".."S7").

    Returns:
    - defaultdict(list) mapping task -> ids of the subjects for which no file
      named ``nets_<sigma>_<task>_<id>.p...`` exists. Tasks without any gap
      are absent. Ids are listed in numeric order (non-numeric ids last).
    """
    if all_subjects is None:
        all_subjects = [str(i + 1) for i in range(24)]
    if tasks is None:
        tasks = [f"S{i}" for i in range(1, 8)]
    tasks = list(tasks)

    # NOTE(review): the pattern is not anchored at the end, so any extension
    # that starts with ".p" matches. An end-anchored variant used to be
    # compiled here but was never applied - confirm which one is intended.
    pattern = re.compile(r'nets_(3down|23|12)_(S\d+)_(\d+)\.p')

    def _subject_order(subject_id):
        text = str(subject_id)
        return (0, int(text), text) if text.isdigit() else (1, 0, text)

    subject_tasks = defaultdict(set)
    for filename in os.listdir(directory):
        match = pattern.search(filename)
        if match:
            _sigma, task, subject_id = match.groups()
            subject_tasks[subject_id].add(task)

    # Collect the missing subjects for each task
    task_missing_subjects = defaultdict(list)
    for subject_id in sorted(set(all_subjects), key=_subject_order):
        available = subject_tasks.get(subject_id, set())
        for task in tasks:
            if task not in available:
                task_missing_subjects[task].append(subject_id)

    return task_missing_subjects

missing_by_task = find_task_gaps_by_subject(output_folder)

# Report the gaps task by task, or confirm that nothing is missing.
if not missing_by_task:
    print("All tasks are available for all subjects.")
else:
    for task in sorted(missing_by_task):
        subjects = missing_by_task[task]
        if not subjects:
            continue
        print(f"Task {task} is missing in subjects: {', '.join(map(str, sorted(subjects)))}")


Summarize the networks per task (per subject summary, and then group summary).

In [None]:
def summarize_by_stimulus(all_networks, summarize_group_func, top_n=3):
    """
    Print a summary of the networks for every stimulus/task, in sorted task order.

    Parameters:
    - all_networks: list of network dictionaries (the 'task' entry is used for grouping)
    - summarize_group_func: callable invoked as summarize_group_func(networks, top_n=top_n);
      it must return (per-subject summary, group summary), both mappings
    - top_n: forwarded unchanged to summarize_group_func

    Prints:
    - Per stimulus: the per-subject tag counts, then for each tag the mean PEC
      (0 when its list is empty) and the number of PEC values
    """
    by_stimulus = defaultdict(list)
    for entry in all_networks:
        by_stimulus[entry['task']].append(entry)

    for stim_type in sorted(by_stimulus):
        print(f"\nStimulus: {stim_type}")
        subject_summary, group_summary = summarize_group_func(by_stimulus[stim_type], top_n=top_n)

        print("Per Subject Summary:")
        for key, tag_counts in subject_summary.items():
            print(f"    {key}: {dict(tag_counts)}")

        print("Group Summary (Avg PEC per function):")
        for tag, pecs in group_summary.items():
            if pecs:
                mean_pec = sum(pecs)/len(pecs)
            else:
                mean_pec = 0
            print(f"    {tag}: Mean PEC = {mean_pec:.4f}, Count = {len(pecs)}")

# Print the per-stimulus summaries, based on the three lowest-PEC networks
# of every subject/task combination.
summarize_by_stimulus(all_networks, summarize_group, top_n=3)

In [None]:
# Apply the plotting style provided by multipec.simulation_utils.
set_plotting_style()

def get_best_nets_by_subject(networks):
    """Keep, for every (subject, task) pair, the network(s) with the lowest PEC.

    Networks tied at the minimum are all kept. The result is a flat list:
    pairs appear in order of first occurrence, and the networks of a pair
    keep their input order.
    """
    grouped = defaultdict(list)
    for entry in networks:
        grouped[(entry['subject'], entry['task'])].append(entry)

    best_nets = []
    for candidates in grouped.values():
        lowest = min(candidate['pec'] for candidate in candidates)
        best_nets += [candidate for candidate in candidates if candidate['pec'] == lowest]

    return best_nets

# Get only best PEC networks per subject per stimulus
best_networks = get_best_nets_by_subject(all_networks)

# Group tags per stimulus type (stimuli whose best networks carry no tag get no entry)
stimulus_tag_summary = defaultdict(list)
for net in best_networks:
    if net['tags']:
        stimulus_tag_summary[net['task']].extend(net['tags'])

# Count and display tags per stimulus, most frequent tag first
print("\nBest Networks - Functional Tag Summary by Stimulus")
for stim_type in sorted(stimulus_tag_summary):
    tag_counts = Counter(stimulus_tag_summary[stim_type])
    print(f"\nStimulus {stim_type}:")
    for tag, count in sorted(tag_counts.items(), key=lambda item: -item[1]):
        print(f"  {tag}: {count}")

# Plot bar plots for each stimulus
def plot_tag_distributions(tag_summary):
    """Draw one bar chart per stimulus showing how often each tag occurs.

    Parameters:
    - tag_summary: mapping stimulus -> list of tag strings (repeats allowed).

    Stimuli are drawn top to bottom in sorted order, bars by decreasing count.
    Nothing is drawn when ``tag_summary`` is empty; a stimulus with an empty
    tag list gets an empty panel.
    """
    if not tag_summary:
        # plt.subplots cannot build a grid with zero rows.
        print("No tags to plot.")
        return

    num_stimuli = len(tag_summary)
    # squeeze=False always returns a 2-D array of axes, so a single stimulus
    # needs no special handling.
    fig, axs = plt.subplots(num_stimuli, 1, figsize=(10, 4 * num_stimuli), squeeze=False)

    for ax, (stimulus, tags) in zip(axs[:, 0], sorted(tag_summary.items())):
        # most_common() orders by decreasing count, ties in first-seen order.
        ranked = Counter(tags).most_common()

        if ranked:
            labels, counts = zip(*ranked)
            ax.bar(labels, counts, color='skyblue', edgecolor='black')
            # Restyle the existing tick labels instead of calling
            # set_xticklabels, which requires fixed tick positions.
            plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

        ax.set_title(f"Stimulus {stimulus} - Functional Tag Distribution")
        ax.set_ylabel("Count")
        ax.grid(axis='y', linestyle='--', alpha=0.6)

    plt.tight_layout()
    plt.show()

# Draw the per-stimulus tag distributions of the best (lowest-PEC) networks.
plot_tag_distributions(stimulus_tag_summary)


See PEC variation across tasks, per subject.

In [None]:
# Dictionary: { subject: {stimulus: [best_net(s)] } }
# For every subject/stimulus pair only the network(s) with the lowest PEC are
# kept; networks tied at that minimum are all kept.
best_nets_by_subject = defaultdict(lambda: defaultdict(list))

for net in all_networks:
    subj, stim, pec = net["subject"], net["task"], net["pec"]
    current_best = best_nets_by_subject[subj][stim]

    if not current_best or pec < current_best[0]['pec']:
        # First network of this pair, or a strictly better one: start over.
        best_nets_by_subject[subj][stim] = [net]
    elif pec == current_best[0]['pec']:
        best_nets_by_subject[subj][stim].append(net)

# Define stimulus groups
group_a = ['S1', 'S2', 'S3', 'S7']
group_b = ['S4', 'S5', 'S6', 'S7']

def extract_min_pecs(best_nets_by_subject, group):
    """Collect, per subject, the best PEC of every stimulus in ``group``.

    Parameters:
    - best_nets_by_subject: mapping subject -> stimulus -> list of best networks.
    - group: list of stimulus labels; it also defines the output order.

    Returns:
    - defaultdict(list) mapping subject -> list of (stimulus, PEC) pairs for
      the stimuli of ``group`` present for that subject. The PEC is the mean
      over the stored networks of the stimulus. Every subject gets an entry,
      possibly an empty list.
    """
    subj_pecs = defaultdict(list)

    for subj, stim_dict in best_nets_by_subject.items():
        pairs = [
            (stim, np.mean([n['pec'] for n in stim_dict[stim]]))  # mean handles ties
            for stim in group
            if stim in stim_dict
        ]
        subj_pecs[subj] = sorted(pairs, key=lambda pair: group.index(pair[0]))

    return subj_pecs

def plot_group_pec_trends(ax, subj_pecs, group, title):
    """Plot one line per subject showing its PEC across the stimuli of ``group``.

    Parameters:
    - ax: axes object to draw on.
    - subj_pecs: mapping subject -> list of (stimulus, PEC) pairs.
    - group: list of stimulus labels; a label's position is its x coordinate.
    - title: axes title.

    Subjects with fewer than two points are not drawn.
    """
    for subj, data in subj_pecs.items():
        if len(data) >= 2:
            positions = [group.index(stim) for stim, _ in data]
            values = [pec for _, pec in data]
            ax.plot(positions, values, marker='o', alpha=0.4, label=subj, linewidth=1)

    # One tick per stimulus, labelled with the stimulus name.
    ax.set_xticks(list(range(len(group))))
    ax.set_xticklabels(group)
    ax.set_xlabel("Stimulus")
    ax.set_ylabel("Min PEC")
    ax.set_title(title)
    ax.grid(True)

# Best PEC per subject for the stimuli of each group.
subj_pecs_a = extract_min_pecs(best_nets_by_subject, group_a)
subj_pecs_b = extract_min_pecs(best_nets_by_subject, group_b)

# Two panels side by side with a shared y axis, one per stimulus group.
fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharey=True)

# NOTE: the titles are hard-coded and must be kept in sync with group_a / group_b.
plot_group_pec_trends(axes[0], subj_pecs_a, group_a, "Group A (S1, S2, S3, S7)")
plot_group_pec_trends(axes[1], subj_pecs_b, group_b, "Group B (S4, S5, S6, S7)")

plt.tight_layout()
plt.show()


Count how many subjects have data for all four stimuli of each group.

In [None]:
# A subject counts as complete when all four stimuli of the group were found.
group_a_count = len([data for data in subj_pecs_a.values() if len(data) == 4])
group_b_count = len([data for data in subj_pecs_b.values() if len(data) == 4])

print(f"Subjects in Group A (S1, S2, S3, S7): {group_a_count}")
print(f"Subjects in Group B (S4, S5, S6, S7): {group_b_count}")


Calculate the PEC change between S1 and S3, and between S4 and S6.

In [None]:
# Calculate PEC changes between task pairs
# Each pair is (later task, earlier task); the change is PEC(later) - PEC(earlier).
task_pairs = [('S3', 'S1'), ('S6', 'S4')]
changes_data = []

for subj, pec_values in best_nets_by_subject.items():
    for later, earlier in task_pairs:
        later_nets = pec_values.get(later)
        earlier_nets = pec_values.get(earlier)
        # Skip subjects lacking either task of the pair.
        if not later_nets or not earlier_nets:
            continue
        # All stored best networks of a task share the same PEC, so the first
        # one is representative.
        changes_data.append({
            'subject': subj,
            'stimulus_group': f"{later}-{earlier}",
            'change_value': later_nets[0]['pec'] - earlier_nets[0]['pec'],
        })

# Explicit columns keep the frame usable even when no subject has a complete pair.
changes_df = pd.DataFrame(changes_data, columns=['subject', 'stimulus_group', 'change_value'])

# Direction of the change: 'positive' when the PEC increased; a change of
# exactly zero is counted as 'negative'.
changes_df['direction'] = [
    'positive' if change > 0 else 'negative' for change in changes_df['change_value']
]

if changes_df.empty:
    print("No subject has data for both tasks of any pair; nothing to plot.")
else:
    # Plot the count of positive and negative changes for each comparison
    plt.figure(figsize=(4, 3))
    sns.countplot(x='direction', hue='stimulus_group', data=changes_df)
    plt.title('Change in PEC across Tasks')
    plt.xlabel('Change Direction')
    plt.ylabel('Number of Subjects')
    plt.show()


In [None]:
from sklearn.model_selection import cross_val_score, cross_val_predict
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import LabelEncoder
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.neighbors import KNeighborsClassifier

def test_classifiers_on_tasks(best_networks, selected_tasks, all_nodes=None):
    """Cross-validate several classifiers that predict the task from node membership.

    Parameters:
    - best_networks: list of network dicts (the keys 'task' and 'net_nodes' are used).
    - selected_tasks: tasks to keep.
    - all_nodes: node indices used as feature columns (default: 0..61).

    Prints the mean 5-fold accuracy of every classifier and, when exactly two
    tasks are present, the AUC computed from cross-validated probabilities.

    Raises:
    - ValueError if no network belongs to ``selected_tasks``.
    """
    if all_nodes is None:
        all_nodes = list(range(62))

    # 1. Filter networks to selected tasks only
    filtered_nets = [net for net in best_networks if net['task'] in selected_tasks]
    if not filtered_nets:
        raise ValueError(f"No networks found for tasks: {selected_tasks}")

    # 2. Create feature matrix: one 0/1 column per node, one row per network
    X = []
    y = []

    for net in filtered_nets:
        node_set = set(net['net_nodes'])
        X.append([1 if node in node_set else 0 for node in all_nodes])
        y.append(net['task'])

    # 3. Encode class labels
    label_encoder = LabelEncoder()
    y_encoded = label_encoder.fit_transform(y)

    # 4. Define classifiers
    classifiers = {
        "Random Forest": RandomForestClassifier(n_estimators=200, random_state=42),
        "Logistic Regression": LogisticRegression(max_iter=1000, random_state=42),
        "SVM (Linear Kernel)": SVC(kernel='linear', probability=True, random_state=42),
        "Gradient Boosting": GradientBoostingClassifier(n_estimators=100, random_state=42),
        "KNN": KNeighborsClassifier(n_neighbors=5)
    }

    is_binary = len(set(y_encoded)) == 2

    # 5. Run and evaluate each classifier
    print(f"\nClassification on tasks: {selected_tasks}\n")
    for name, clf in classifiers.items():
        scores = cross_val_score(clf, X, y_encoded, cv=5)
        print(f"{name}: Accuracy = {np.mean(scores)*100:.2f}%")

        if is_binary:
            # Probability of the second encoded class, predicted out-of-fold.
            probs = cross_val_predict(clf, X, y_encoded, cv=5, method='predict_proba')[:, 1]
            auc = roc_auc_score(y_encoded, probs)
            print(f"         AUC = {auc:.2f}")

# Apply the plotting style provided by multipec.simulation_utils.
set_plotting_style()

def classify_tasks(best_networks, selected_tasks=None, classifier_name="Random Forest", channels_labels=None, all_nodes=None, top_n=20):
    """Classify the task from node membership and plot the most important nodes.

    Parameters:
    - best_networks: list of network dicts (the keys 'task' and 'net_nodes' are used).
    - selected_tasks: tasks to keep; None keeps every task.
    - classifier_name: "Random Forest", "Logistic Regression", "SVM",
      "Gradient Boosting" or "KNN"; any other name falls back to Random Forest.
    - channels_labels: optional sequence used to label the nodes in the plot.
    - all_nodes: node indices used as feature columns (default: 0..61).
    - top_n: number of nodes shown in the importance plot.

    Prints the mean 5-fold accuracy, fits the classifier on all data and, if
    the fitted model exposes importances or coefficients, plots the top nodes.

    Returns:
    - (fitted classifier, fitted LabelEncoder, all_nodes)

    Raises:
    - ValueError if no network is left after filtering.
    """
    if all_nodes is None:
        all_nodes = list(range(62))

    # None means "no filtering"; `in None` would raise a TypeError.
    if selected_tasks is not None:
        best_networks = [net for net in best_networks if net['task'] in selected_tasks]
    if not best_networks:
        raise ValueError(f"No networks found for tasks: {selected_tasks}")

    # Build feature matrix: one 0/1 column per node, one row per network
    X = []
    y = []

    for net in best_networks:
        node_set = set(net['net_nodes'])
        X.append([1 if node in node_set else 0 for node in all_nodes])
        y.append(net['task'])

    X = pd.DataFrame(X, columns=all_nodes)
    label_encoder = LabelEncoder()
    y_encoded = label_encoder.fit_transform(y)

    # Classifier options
    classifiers = {
        "Random Forest": RandomForestClassifier(n_estimators=200, random_state=42),
        "Logistic Regression": LogisticRegression(max_iter=1000, random_state=42),
        "SVM": SVC(kernel='linear', probability=True, random_state=42),
        "Gradient Boosting": GradientBoostingClassifier(n_estimators=100, random_state=42),
        "KNN": KNeighborsClassifier(n_neighbors=5)
    }

    if classifier_name not in classifiers:
        # Make the fallback visible so the report names the model actually used.
        print(f"Unknown classifier '{classifier_name}', falling back to Random Forest")
        classifier_name = "Random Forest"
    clf = classifiers[classifier_name]

    # Cross-validation
    scores = cross_val_score(clf, X, y_encoded, cv=5)
    print(f"Using {classifier_name}")
    print("Tasks:", sorted(set(y)))
    print("Mean Classification Accuracy: {:.2f}%".format(np.mean(scores) * 100))

    # Train on full data
    clf.fit(X, y_encoded)

    # Feature importance or coefficient magnitude, whichever the model offers
    if hasattr(clf, "feature_importances_"):
        importances = clf.feature_importances_
    elif hasattr(clf, "coef_"):
        importances = np.abs(clf.coef_).mean(axis=0)
    else:
        importances = None  # e.g. KNN has no coefficients/feature importances

    if importances is not None:
        top_indices = np.argsort(importances)[::-1][:top_n]
        top_nodes = [all_nodes[i] for i in top_indices]
        top_labels = [channels_labels[i] if channels_labels else str(i) for i in top_nodes]
        top_importances = [importances[i] for i in top_indices]

        # Plot
        plt.figure(figsize=(10, 6))
        plt.bar(top_labels, top_importances, color='mediumseagreen')
        plt.title(f"Top EEG Nodes by Importance ({classifier_name})")
        plt.ylabel("Importance / Coefficient Magnitude")
        plt.xticks(rotation=45, ha='right')
        plt.grid(axis='y', linestyle='--', alpha=0.6)
        plt.tight_layout()
        plt.show()

    return clf, label_encoder, all_nodes


In [None]:
# Compare all classifiers on the S1-vs-S2 problem (binary, so AUC is reported too).
test_classifiers_on_tasks(best_networks, selected_tasks=["S1", "S2"])

# Fit logistic regression on the same two tasks and plot the 15 nodes with the
# largest mean absolute coefficient.
clf, enc, nodes = classify_tasks(
    best_networks,
    selected_tasks=["S1", "S2"],
    classifier_name="Logistic Regression",
    channels_labels=channels_labels,
    top_n=15
)


In [None]:
# Define task groups
group_a = ['S1']
group_b = ['S4']

# Helper to collect node counts
def collect_node_counts(best_nets_by_subject, group):
    """Count how often each node occurs in the best networks of ``group``.

    Parameters:
    - best_nets_by_subject: mapping subject -> stimulus -> list of network dicts.
    - group: iterable of stimulus labels to include.

    Returns a Counter mapping node -> number of occurrences over all subjects
    and all stored networks of the listed stimuli.
    """
    return Counter(
        node
        for stim_dict in best_nets_by_subject.values()
        for stim in group
        if stim in stim_dict
        for net in stim_dict[stim]
        for node in net['net_nodes']
    )

# Count nodes in each group
node_counts_a = collect_node_counts(best_nets_by_subject, group_a)
node_counts_b = collect_node_counts(best_nets_by_subject, group_b)

# All unique nodes from both groups
all_nodes = set(node_counts_a) | set(node_counts_b)

# Sort by combined count (descending) and keep the 20 most frequent nodes
sorted_nodes = sorted(all_nodes, key=lambda n: node_counts_a[n] + node_counts_b[n], reverse=True)[:20]

# Plot: two bars per node, group A on the left and group B shifted by one bar width
plt.figure(figsize=(14, 6))
x = range(len(sorted_nodes))
a_vals = [node_counts_a[n] for n in sorted_nodes]
b_vals = [node_counts_b[n] for n in sorted_nodes]

bar_width = 0.35
plt.bar(x, a_vals, width=bar_width, label=group_a[0], color='steelblue')
plt.bar([i + bar_width for i in x], b_vals, width=bar_width, label=group_b[0], color='indianred')

# Ticks are centred between the two bars of a node.
# NOTE(review): node indices are used directly as positions in the 64-entry
# channels_labels (no get_original_index conversion) - confirm that the
# network node indices refer to the full list including A1/A2.
plt.xticks(
    [i + bar_width/2 for i in x],
    [channels_labels[n] for n in sorted_nodes],
    rotation=45, ha='right'
)
plt.ylabel("Node Frequency")
plt.title("EEG Node Frequency in Best Networks")
plt.legend()
plt.tight_layout()
plt.grid(axis='y', linestyle='--', alpha=0.6)
plt.show()


In [None]:
# Binary node presence (one 0/1 feature per node that occurs in any network)
def classify_by_node_presence(networks, node_labels, classifier, bar=False):
    """Predict the task from which nodes a network contains.

    Parameters:
    - networks: list of network dicts (the keys 'net_nodes' and 'task' are used).
    - node_labels: sequence indexed by node, used to name the feature columns.
    - classifier: estimator exposing ``fit`` and ``feature_importances_``.
    - bar: when equal to True, also draw a bar chart of the 20 largest importances.

    The classifier is fitted on all networks, its 5-fold cross-validated
    accuracy is printed, and the importances are shown as a one-row heatmap.
    """
    all_nodes = sorted({node for net in networks for node in net['net_nodes']})

    memberships = [set(net['net_nodes']) for net in networks]
    rows = [[int(node in members) for node in all_nodes] for members in memberships]
    targets = [net['task'] for net in networks]

    df_X = pd.DataFrame(rows, columns=[node_labels[i] for i in all_nodes])
    y_enc = LabelEncoder().fit_transform(targets)

    classifier.fit(df_X, y_enc)
    scores = cross_val_score(classifier, df_X, y_enc, cv=5)
    print(f"Node Presence Accuracy: {scores.mean() * 100:.2f}%")

    # Importances come from the fit on the full data set above.
    importances = classifier.feature_importances_
    feature_names = df_X.columns

    if bar == True:  # equality test kept on purpose (same truth table as before)
        top_n = 20
        ranked = np.argsort(importances)[::-1][:top_n]
        plt.figure(figsize=(10, 6))
        plt.bar([feature_names[i] for i in ranked], [importances[i] for i in ranked], color='teal')
        plt.xticks(rotation=45, ha='right')
        plt.title(f"Top {top_n} Feature Importances")
        plt.ylabel("Importance")
        plt.tight_layout()
        plt.show()

    plt.figure(figsize=(12, 1.5))
    sns.heatmap([importances], cmap='magma', xticklabels=feature_names)
    plt.title("Feature Importances (Heatmap)")
    plt.yticks([])
    plt.xticks(rotation=90)
    plt.show()


# Cognitive function counts (frequency of function keywords in the net nodes)
def classify_by_function_keywords(networks, node_functions, classifier):
    """Predict the task from how often each function keyword occurs in a network.

    Parameters:
    - networks: list of network dicts (the keys 'net_nodes' and 'task' are used).
    - node_functions: mapping node -> list of function keywords.
    - classifier: estimator exposing ``fit`` and ``feature_importances_``.

    The classifier is fitted on all networks, its 5-fold cross-validated
    accuracy is printed, and the importances are shown as a one-row heatmap.
    """
    all_functions = sorted({func for funcs in node_functions.values() for func in funcs})

    def _keyword_counts(net):
        # Nodes without an entry in node_functions contribute nothing.
        tally = Counter()
        for node in net['net_nodes']:
            tally.update(node_functions.get(node, []))
        return [tally.get(func, 0) for func in all_functions]

    df_X = pd.DataFrame([_keyword_counts(net) for net in networks], columns=all_functions)
    y_enc = LabelEncoder().fit_transform([net['task'] for net in networks])

    classifier.fit(df_X, y_enc)
    scores = cross_val_score(classifier, df_X, y_enc, cv=5)
    print(f"Function Keyword Accuracy: {scores.mean() * 100:.2f}%")

    # Importances come from the fit on the full data set above.
    importances = classifier.feature_importances_
    feature_names = df_X.columns

    plt.figure(figsize=(12, 1.5))
    sns.heatmap([importances], cmap='magma', xticklabels=feature_names)
    plt.title("Feature Importances (Heatmap)")
    plt.yticks([])
    plt.xticks(rotation=90)
    plt.show()


In [None]:
def load_subject(id=1, data_folder=output_folder):
    """
    Load all networks for a single subject, ensuring data exists for all tasks S1–S7.

    Parameters:
    - id (int): Subject ID to load. (The name shadows the builtin ``id``; it is
      kept because callers pass it by keyword.)
    - data_folder (str): Path to folder containing the .xlsx network files.

    Returns:
    - all_networks (list of dicts): Each dict has keys: subject, task, net_nodes, pec, tags.

    Raises:
    - ValueError if any of the tasks (S1–S7) are missing.
    """
    # range(1, 8) covers S1..S7; an upper bound of 7 would leave S7 unchecked.
    expected_tasks = {f"S{i}" for i in range(1, 8)}
    all_networks = []
    found_tasks = set()

    pattern = re.compile(r'nets_(3down|23|12)_(S\d+)_(\d+)\.xlsx$')
    # Sorting makes the result order independent of the directory listing order.
    files = sorted(glob.glob(os.path.join(data_folder, f"nets_*_S*_{id}.xlsx")))

    for file in files:
        match = pattern.search(file)
        if not match:
            continue

        _sigma, task, subject_number = match.groups()
        found_tasks.add(task)

        # NOTE(review): process_all_subjects labels subjects "sub<id>", this
        # function "subj<id>" - confirm whether the two should agree.
        subject_id = f"subj{subject_number}"
        df = parse_subject_file(file)

        # Temp dict to collect the unique networks of this file
        net_dict = {}

        rows = zip(df['Net Nodes'], df['Last PEC Value'], df['Processing Tags'])
        for nodes, pec, tags in rows:
            # A node collection parsed as a list is unhashable, so it cannot be
            # part of a dict key directly; use the equivalent tuple instead.
            node_key = tuple(nodes) if isinstance(nodes, list) else nodes
            key = (task, subject_id, node_key)
            if key not in net_dict:
                net_dict[key] = {
                    'task': task,
                    'subject': subject_id,
                    'net_nodes': nodes,
                    'pec': pec,
                    'tags': set(tags),
                }
            else:
                net_dict[key]['tags'].update(tags)

        # Add unique entries to the result
        for net in net_dict.values():
            net['tags'] = sorted(net['tags'])
            all_networks.append(net)

    # Check if all required tasks are present
    missing_tasks = expected_tasks - found_tasks
    if missing_tasks:
        raise ValueError(f"Subject {id} is missing data for task(s): {', '.join(sorted(missing_tasks))}")

    return all_networks



In [None]:
# Load subject network data
# subject_data stays None when load_subject raises because a task is missing;
# the message is printed and classification_data is then None as well.
subject_data = None

try:
    subject_data = load_subject(id=4)
except ValueError as e:
    print(e)

classification_data = subject_data

# Sorted list of every distinct function keyword across all nodes.
all_functions = sorted({func for funcs in node_functions.values() for func in funcs})

In [None]:
# One RandomForest instance is shared by both calls; each call refits it,
# so after this cell it holds the fit on the function-keyword features.
classifier = RandomForestClassifier(n_estimators=200, random_state=42)

classify_by_node_presence(classification_data, channels_labels, classifier)
classify_by_function_keywords(classification_data, node_functions, classifier)


In [None]:
# Hand-assigned base weight per function keyword. generate_function_vectors
# looks keywords up in this table and treats a keyword without an entry as 0.0.
# NOTE(review): some keys (e.g. 'emotional regulation', 'visual processing')
# do not occur among the keywords derived from eeg_info and therefore never
# match - confirm whether they are still needed.
function_weights = {
    # General / always-on (low base score)
    'attention': 0.2,
    'executive function': 0.2,
    'working memory': 0.2,

    'emotional control': 0.2,
    'spatial awareness': 0.4,
    'sensorimotor integration': 0.45,
    'visual attention': 0.45,

    # Task-specific (medium to high)
    'primary visual cortex (left visual field)': 0.5,
    'primary visual cortex (right visual field)': 0.5,
    'visual integration': 0.55,
    'auditory-motor integration': 0.65,
    'visual spatial processing': 0.7,
    'object recognition': 0.75,
    'high-level visual processing': 0.8,
    'auditory processing': 0.9,
    'primary auditory processing': 0.9,
    'speech perception': 0.9,
    'speech and language integration': 0.9,
    'language': 0.95,
    'language processing': 0.95,


    # Other (moderate)
    'decision making': 0.5,
    'motor planning': 0.75,
    'motor cortex - movement of left limbs': 0.5,
    'motor cortex - movement of right limbs': 0.5,
    'conflict monitoring': 0.6,
    'spatial processing': 0.65,
    'emotional regulation': 0.3,
    'tactile processing': 0.3,
    'sound localization': 0.8,
    'auditory association': 0.8,
    'visual processing': 0.6,
}



In [None]:
def generate_function_vectors(networks, node_functions, function_weights, plot_type='heatmap'):
    """
    Build one PEC-weighted function vector per network.

    Each network contributes, for every function attached to one of its nodes,
    `pec_weight * function_weights[function]`, summed over its nodes. The
    `pec_weight` is the network's PEC min-max normalised over `networks` and
    inverted, so the network with the lowest PEC gets weight 1 and the one
    with the highest PEC gets a weight close to 0.

    Args:
        networks: Sequence of dicts with keys 'pec', 'net_nodes', 'subject'
            and 'task'.
        node_functions: Mapping node name -> iterable of function names.
            Nodes missing from the mapping contribute nothing.
        function_weights: Mapping function name -> base weight. Functions
            missing from the mapping get weight 0.0.
        plot_type: 'heatmap' (default) or 'bar' to draw a figure; any other
            value (e.g. None) skips plotting.

    Returns:
        DataFrame with one row per network (index "<subject>_<task>") and one
        column per function found in `node_functions`, sorted by name.

    Raises:
        ValueError: If `networks` is empty.
    """
    # Fail with an explicit message instead of the cryptic error min()/max()
    # would raise on an empty sequence.
    if len(networks) == 0:
        raise ValueError("generate_function_vectors: 'networks' is empty, nothing to vectorise")

    # Column order: every function listed in node_functions, sorted by name.
    all_functions = sorted({func for funcs in node_functions.values() for func in funcs})

    # Normalise PECs across all networks.
    pecs = [net['pec'] for net in networks]
    min_pec = min(pecs)
    max_pec = max(pecs)
    # The small epsilon keeps the division defined when all PECs are equal.
    pec_span = max_pec - min_pec + 1e-6

    rows = []
    ids = []
    for net in networks:
        pec_weight = 1 - (net['pec'] - min_pec) / pec_span

        scores = defaultdict(float)
        for node in net['net_nodes']:
            for func in node_functions.get(node, []):
                scores[func] += pec_weight * function_weights.get(func, 0.0)

        rows.append([scores.get(func, 0.0) for func in all_functions])
        ids.append(f"{net['subject']}_{net['task']}")

    df = pd.DataFrame(rows, columns=all_functions, index=ids)

    if plot_type == 'bar':
        avg_weights = df.mean().sort_values(ascending=False)
        plt.figure(figsize=(12, 6))
        sns.barplot(x=avg_weights.values, y=avg_weights.index, palette="viridis")
        plt.title("Average PEC-weighted Function Contributions")
        plt.xlabel("Weighted Importance")
        plt.tight_layout()
        plt.show()

    elif plot_type == 'heatmap':
        plt.figure(figsize=(14, 10))
        sns.heatmap(df.T, cmap="mako", cbar_kws={'label': 'PEC-weighted score'})
        plt.title("Function Contributions Across Networks")
        plt.xlabel("Networks")
        plt.ylabel("Functions")
        plt.tight_layout()
        plt.show()

    return df


In [None]:
# Build the function vectors for the loaded subject; plot_type is left at its
# default, so this also draws the heatmap.
df_vectors = generate_function_vectors(classification_data, node_functions, function_weights)


In [None]:
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler

def plot_dimensionality_reduction(df_vectors, method='pca'):
    """
    Project the function vectors onto two dimensions and scatter-plot them by task.

    Args:
        df_vectors: DataFrame of function vectors. The task label is the
            second underscore-separated field of each index entry.
        method: 'pca' or 'tsne'.

    Raises:
        ValueError: If `method` is neither 'pca' nor 'tsne'.
    """
    if method not in ('pca', 'tsne'):
        raise ValueError("method must be 'pca' or 'tsne'")

    # Standardise every feature before the projection.
    X_scaled = StandardScaler().fit_transform(df_vectors.values)
    labels = [idx.split('_')[1] for idx in df_vectors.index]

    if method == 'pca':
        reducer = PCA(n_components=2)
        title = 'PCA of Function Vectors'
    else:
        # scikit-learn requires perplexity < n_samples, so the fixed value of
        # 30 fails on small inputs; clamp it just below the sample count.
        n_samples = X_scaled.shape[0]
        perplexity = min(30, n_samples - 1) if n_samples > 1 else 30
        reducer = TSNE(n_components=2, perplexity=perplexity, random_state=42)
        title = 't-SNE of Function Vectors'

    X_red = reducer.fit_transform(X_scaled)

    df_plot = pd.DataFrame(X_red, columns=['Dim1', 'Dim2'])
    df_plot['Task'] = labels

    plt.figure(figsize=(8, 6))
    sns.scatterplot(data=df_plot, x='Dim1', y='Dim2', hue='Task', palette='Set2', s=80)
    plt.title(title)
    plt.tight_layout()
    plt.show()

# Draw both 2-D projections of the function vectors built above.
plot_dimensionality_reduction(df_vectors, method='pca')
plot_dimensionality_reduction(df_vectors, method='tsne')


In [None]:
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import f1_score, roc_auc_score, confusion_matrix, ConfusionMatrixDisplay, classification_report, accuracy_score
from sklearn.preprocessing import label_binarize, LabelEncoder
from sklearn.svm import SVC

def classify_tasks_v2(df_vectors, classifier=None, plot=True):
    """
    Cross-validate a task classifier on the function vectors and report the results.

    Runs a stratified 5-fold cross-validation, prints the mean weighted F1 and
    ROC-AUC, the per-class accuracy and a classification report, and draws the
    confusion matrix summed over all folds.

    Args:
        df_vectors: DataFrame with one row per network. The task label is the
            second underscore-separated field of each index entry.
        classifier: Estimator exposing fit/predict and, optionally,
            predict_proba. Defaults to a linear SVC with probability estimates
            and balanced class weights. It is refitted in place on every fold.
        plot: Controls only the feature-importance heatmap; the confusion
            matrix is always drawn.

    Returns:
        None. All results are printed or plotted.
    """
    # Prepare features and labels
    X = df_vectors.values
    y_raw = [idx.split('_')[1] for idx in df_vectors.index]
    le = LabelEncoder()
    y = le.fit_transform(y_raw)
    class_labels = le.classes_
    n_classes = len(class_labels)
    # LabelEncoder produces the codes 0..n_classes-1, all of which occur in y.
    class_ids = np.arange(n_classes)

    # Initialize classifier
    if classifier is None:
        classifier = SVC(kernel='linear', probability=True, random_state=42, class_weight='balanced')
    # Not every estimator offers probabilities (e.g. SVC(probability=False)).
    has_proba = hasattr(classifier, "predict_proba")

    # K-Fold Cross Validation
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
    f1_scores = []
    auc_scores = []
    confusion_matrix_all = np.zeros((n_classes, n_classes))
    skipped_auc_folds = 0

    all_y_true = []
    all_y_pred = []

    for train_index, test_index in skf.split(X, y):
        X_train, X_test = X[train_index], X[test_index]
        y_train, y_test = y[train_index], y[test_index]

        classifier.fit(X_train, y_train)
        y_pred = classifier.predict(X_test)

        f1_scores.append(f1_score(y_test, y_pred, average='weighted'))
        all_y_true.extend(y_test)
        all_y_pred.extend(y_pred)

        # ROC-AUC is only computed when probabilities are available and every
        # class occurs in the test fold.
        fold_auc = None
        if has_proba and len(np.unique(y_test)) == n_classes:
            y_pred_prob = classifier.predict_proba(X_test)
            if y_pred_prob.shape[1] == n_classes:
                if n_classes == 2:
                    # Binary targets need a 1-D score: the probability of the
                    # class with the greater label.
                    fold_auc = roc_auc_score(y_test, y_pred_prob[:, 1])
                else:
                    y_test_bin = label_binarize(y_test, classes=class_ids)
                    fold_auc = roc_auc_score(y_test_bin, y_pred_prob, average='weighted', multi_class='ovr')
        if fold_auc is None:
            skipped_auc_folds += 1
        else:
            auc_scores.append(fold_auc)

        confusion_matrix_all += confusion_matrix(y_test, y_pred, labels=class_ids)

    print(f"\nAverage F1 Score (weighted): {np.mean(f1_scores):.2f}")
    if auc_scores:
        print(f"Average AUC (weighted): {np.mean(auc_scores):.2f}")
    if skipped_auc_folds:
        print(f"AUC not calculated for {skipped_auc_folds} fold(s) "
              "(missing classes in the test fold or no probability estimates).")

    # Confusion matrix
    disp = ConfusionMatrixDisplay(confusion_matrix=confusion_matrix_all.astype(int),
                                  display_labels=class_labels)
    disp.plot(cmap='Blues', values_format='d')
    plt.title('Confusion Matrix')
    plt.tight_layout()
    plt.show()

    # Class-wise accuracy and sample counts
    total_per_class = np.bincount(y)
    correct_per_class = np.diag(confusion_matrix_all)
    accuracy_per_class = correct_per_class / total_per_class

    print("\nClass-wise accuracy and sample count:")
    for label, acc, count in zip(class_labels, accuracy_per_class, total_per_class):
        print(f"  - {label}: {acc:.2f} accuracy over {int(count)} samples")

    # Classification report (precision, recall, f1 for each class).
    # Passing `labels` keeps target_names aligned with the class codes.
    print("\n📄 Classification report:")
    print(classification_report(all_y_true, all_y_pred, labels=class_ids, target_names=class_labels))

    # Feature importance heatmap. The coefficients come from the model fitted
    # on the last fold only.
    if plot and hasattr(classifier, 'coef_'):
        importances = np.mean(np.abs(classifier.coef_), axis=0)
        if importances.shape[0] == df_vectors.shape[1]:
            top_idx = np.argsort(importances)[::-1][:30]
            top_features = np.array(df_vectors.columns)[top_idx]
            top_importances = importances[top_idx]

            heat_df = pd.DataFrame({'Feature': top_features, 'Importance': top_importances})
            heat_df = heat_df.set_index('Feature').T

            plt.figure(figsize=(14, 2))
            sns.heatmap(heat_df, cmap='YlGnBu', annot=False, cbar=True)
            plt.title("Top 30 Feature Importances")
            plt.yticks(rotation=0)
            plt.tight_layout()
            plt.show()
        else:
            print(f"Skipped feature importance plot: {importances.shape[0]} model coefficients vs {df_vectors.shape[1]} feature columns.")


In [None]:
# Evaluate the default classifier (linear SVC) on the single-subject vectors.
classify_tasks_v2(df_vectors)


In [None]:
class LocalClassifierLogger:
    """Mutable record of the cross-validation outputs for one subject."""

    def __init__(self):
        self.y_true = []                 # true labels pooled over all test folds
        self.y_pred = []                 # predicted labels pooled over all test folds
        self.cm_all = None               # confusion matrix summed over folds
        self.f1_scores = []              # weighted F1, one entry per fold
        self.auc_scores = []             # weighted ROC-AUC, one entry per scored fold
        self.class_labels = []           # task names in label-encoder order
        self.class_counts = None         # number of samples per class
        self.accuracy_per_class = None   # diagonal of cm_all / class_counts
        self.feature_importances = None  # feature -> mean |coef|, descending


def _mean_or_zero(values):
    """Return the mean of `values` as a float; 0.0 when empty or NaN."""
    if len(values) == 0:
        # Avoids the RuntimeWarning np.mean emits on an empty sequence.
        return 0.0
    return float(np.nan_to_num(np.mean(values), nan=0.0))


def classify_and_capture(df_vectors, logger, classifier=None, plot=False, min_samples=5, n_splits=5):
    """
    Cross-validate a task classifier and store every result on `logger`.

    Args:
        df_vectors: DataFrame of function vectors. The task label is the
            second underscore-separated field of each index entry.
        logger: LocalClassifierLogger that receives the results.
        classifier: Estimator with fit/predict; defaults to a linear SVC with
            probability estimates and balanced class weights.
        plot: Unused; kept so existing calls keep working.
        min_samples: Classes with fewer rows than this are dropped.
        n_splits: Number of stratified folds.

    Raises:
        ValueError: Propagated from scikit-learn, e.g. when too few samples or
            classes remain after filtering.
    """
    y_raw = [idx.split('_')[1] for idx in df_vectors.index]

    # Remove under-represented classes
    class_counts_raw = pd.Series(y_raw).value_counts()
    valid_classes = set(class_counts_raw[class_counts_raw >= min_samples].index)
    mask = np.array([label in valid_classes for label in y_raw], dtype=bool)
    df_vectors = df_vectors.iloc[mask]
    X = df_vectors.values
    y_raw = np.array(y_raw)[mask]

    # Label encode
    le = LabelEncoder()
    y = le.fit_transform(y_raw)
    logger.class_labels = le.classes_
    n_classes = len(le.classes_)
    # LabelEncoder produces the codes 0..n_classes-1, all of which occur in y.
    class_ids = np.arange(n_classes)

    if classifier is None:
        classifier = SVC(kernel='linear', probability=True, random_state=42, class_weight='balanced')

    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
    logger.cm_all = np.zeros((n_classes, n_classes))

    for train_idx, test_idx in skf.split(X, y):
        X_train, X_test = X[train_idx], X[test_idx]
        y_train, y_test = y[train_idx], y[test_idx]

        classifier.fit(X_train, y_train)

        y_pred = classifier.predict(X_test)
        logger.y_true.extend(y_test)
        logger.y_pred.extend(y_pred)

        logger.f1_scores.append(f1_score(y_test, y_pred, average='weighted'))

        # ROC-AUC only when probabilities cover every class and the test fold
        # contains every class.
        if hasattr(classifier, "predict_proba") and len(np.unique(y_test)) == n_classes:
            y_pred_prob = classifier.predict_proba(X_test)
            if y_pred_prob.shape[1] == n_classes:
                if n_classes == 2:
                    # Binary targets need a 1-D score: the probability of the
                    # class with the greater label.
                    auc = roc_auc_score(y_test, y_pred_prob[:, 1])
                else:
                    y_test_bin = label_binarize(y_test, classes=class_ids)
                    auc = roc_auc_score(y_test_bin, y_pred_prob, average='weighted', multi_class='ovr')
                logger.auc_scores.append(auc)

        logger.cm_all += confusion_matrix(y_test, y_pred, labels=class_ids)

    total_per_class = np.bincount(y)
    logger.accuracy_per_class = np.diag(logger.cm_all) / total_per_class
    logger.class_counts = total_per_class

    # Feature importances (linear models only); the coefficients come from the
    # model fitted on the last fold.
    if hasattr(classifier, 'coef_'):
        importances = np.mean(np.abs(classifier.coef_), axis=0)
        if importances.shape[0] == df_vectors.shape[1]:
            order = np.argsort(importances)[::-1]
            logger.feature_importances = dict(zip(df_vectors.columns[order], importances[order]))


results = []

for subject_id in range(1, 25):
    print(f"\nSubject {subject_id}")
    try:
        subject_data = load_subject(id=subject_id)
        df_vectors = generate_function_vectors(subject_data, node_functions, function_weights, plot_type=None)

        logger = LocalClassifierLogger()
        classify_and_capture(df_vectors, logger)
    except ValueError as e:
        print(f"Skipping Subject {subject_id}: {e}")
        continue

    # Store results as builtin Python types. NumPy scalars would be written to
    # the CSV as text like "np.float64(0.5)" (NumPy >= 2), which cannot be
    # parsed back without eval.
    # NOTE(review): 'auc' is 0.0 when no fold produced an AUC, which cannot be
    # told apart from a genuine score of 0 — consider NaN.
    metrics = {
        'subject': subject_id,
        'f1_scores': [float(s) for s in logger.f1_scores],
        'auc_scores': [float(s) for s in logger.auc_scores],
        'class_accuracies': {
            str(label): float(acc)
            for label, acc in zip(logger.class_labels, logger.accuracy_per_class)
        },
        'class_counts': {
            str(label): int(count)
            for label, count in zip(logger.class_labels, logger.class_counts)
        },
        'feature_importances': {
            str(name): float(value)
            for name, value in (logger.feature_importances or {}).items()
        },
        'f1': _mean_or_zero(logger.f1_scores),
        'auc': _mean_or_zero(logger.auc_scores),
    }
    results.append(metrics)

# Example: convert to DataFrame
df_summary = pd.DataFrame(results)

# Optional: save to CSV (create the results folder on first use)
summary_path = project_root / 'data/results/eeg/subject_classification_summary.csv'
summary_path.parent.mkdir(parents=True, exist_ok=True)
df_summary.to_csv(summary_path, index=False)


More careful classification:

In [None]:
from sklearn.svm import SVC
from sklearn.metrics import (
    f1_score, roc_auc_score, average_precision_score,
    precision_score, recall_score, confusion_matrix
)
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import LabelEncoder, label_binarize
import numpy as np
import pandas as pd

# Per-subject summary rows for the detailed run below. This rebinds `results`,
# so the list built by the earlier per-subject loop is no longer reachable here.
results = []

def compute_classification_metrics(X, y_raw, classifier, le, n_classes):
    """
    Run a stratified 5-fold cross-validation and collect per-fold metrics.

    Args:
        X: Feature matrix, one row per sample.
        y_raw: Raw class labels, one per row of `X`.
        classifier: Estimator with fit/predict and, optionally, predict_proba.
            It is refitted in place on every fold.
        le: LabelEncoder already fitted on `y_raw`.
        n_classes: Number of classes known to `le`.

    Returns:
        Dict of lists with one entry per fold: 'f1_macro', 'f1_weighted',
        'precision_weighted', 'recall_weighted' and 'confusion_matrices'.
        'roc_auc' and 'pr_auc' only hold the folds for which they could be
        computed, so they may be shorter or empty.
    """
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
    y = le.transform(y_raw)
    class_ids = np.arange(n_classes)

    metrics = {
        'f1_macro': [],
        'f1_weighted': [],
        'roc_auc': [],
        'pr_auc': [],
        'precision_weighted': [],
        'recall_weighted': [],
        'confusion_matrices': []
    }

    for train_idx, test_idx in skf.split(X, y):
        X_train, X_test = X[train_idx], X[test_idx]
        y_train, y_test = y[train_idx], y[test_idx]

        classifier.fit(X_train, y_train)
        y_pred = classifier.predict(X_test)

        # AUC and PR-AUC. Both are skipped for a fold whose test part lacks a
        # class; raising there would make the caller discard the whole subject.
        if hasattr(classifier, "predict_proba") and len(np.unique(y_test)) == n_classes:
            y_prob = classifier.predict_proba(X_test)
            if y_prob.shape[1] == n_classes:
                if n_classes == 2:
                    # Binary targets need a 1-D score: the probability of the
                    # class with the greater label.
                    positive_prob = y_prob[:, 1]
                    roc = roc_auc_score(y_test, positive_prob)
                    pr = average_precision_score(y_test, positive_prob)
                else:
                    y_test_bin = label_binarize(y_test, classes=class_ids)
                    roc = roc_auc_score(y_test_bin, y_prob, average='weighted', multi_class='ovr')
                    pr = average_precision_score(y_test_bin, y_prob, average='weighted')
                metrics['roc_auc'].append(roc)
                metrics['pr_auc'].append(pr)

        # Core classification metrics
        metrics['f1_macro'].append(f1_score(y_test, y_pred, average='macro'))
        metrics['f1_weighted'].append(f1_score(y_test, y_pred, average='weighted'))
        metrics['precision_weighted'].append(precision_score(y_test, y_pred, average='weighted'))
        metrics['recall_weighted'].append(recall_score(y_test, y_pred, average='weighted'))
        metrics['confusion_matrices'].append(confusion_matrix(y_test, y_pred, labels=class_ids))

    return metrics


# Subject loop (simplified)
# For every subject: build the function vectors, drop rare task classes, run
# compute_classification_metrics and keep the mean of each per-fold metric.
for subject_id in range(1, 25):
    try:
        subject_data = load_subject(id=subject_id)
        df_vectors = generate_function_vectors(subject_data, node_functions, function_weights, plot_type=None)

        # Task label = second underscore-separated field of each row id.
        y_raw = [idx.split('_')[1] for idx in df_vectors.index]
        class_counts_raw = pd.Series(y_raw).value_counts()
        # Classes with fewer than `min_samples` rows are removed before cross-validation.
        min_samples = 5
        valid_classes = class_counts_raw[class_counts_raw >= min_samples].index.tolist()
        mask = [label in valid_classes for label in y_raw]
        df_vectors = df_vectors.iloc[mask]
        y_raw = np.array(y_raw)[mask]
        X = df_vectors.values

        # Classification needs at least two classes; otherwise skip the subject.
        if len(np.unique(y_raw)) < 2:
            print(f"Skipping Subject {subject_id}: Not enough valid classes")
            continue

        # The encoder is fitted on the filtered labels only.
        le = LabelEncoder()
        le.fit(y_raw)
        n_classes = len(le.classes_)
        classifier = SVC(kernel='linear', probability=True, random_state=42, class_weight='balanced')

        detailed_metrics = compute_classification_metrics(X, y_raw, classifier, le, n_classes)

        # Average each metric over the folds. roc_auc / pr_auc become NaN when
        # no fold produced a value.
        result = {
            'subject': subject_id,
            'f1_macro': float(np.nanmean(detailed_metrics['f1_macro'])),
            'f1_weighted': float(np.nanmean(detailed_metrics['f1_weighted'])),
            'precision_weighted': float(np.nanmean(detailed_metrics['precision_weighted'])),
            'recall_weighted': float(np.nanmean(detailed_metrics['recall_weighted'])),
            'roc_auc': float(np.nanmean(detailed_metrics['roc_auc'])) if detailed_metrics['roc_auc'] else np.nan,
            'pr_auc': float(np.nanmean(detailed_metrics['pr_auc'])) if detailed_metrics['pr_auc'] else np.nan,
        }

        results.append(result)

    except ValueError as e:
        # Raised by load_subject for subjects with missing tasks; errors from
        # the classification step that are ValueErrors end up here as well.
        print(f"Skipping Subject {subject_id}: {e}")
        continue

# Final results DataFrame
# NOTE(review): the next cell rebinds df_summary from the CSV written by the
# earlier loop, so this frame is replaced unless it is saved or renamed.
df_summary = pd.DataFrame(results)


In [None]:
# Reload the summary that the first per-subject loop saved to disk. This
# rebinds df_summary, replacing the frame built by the detailed run above.
# The dict-valued columns come back as strings and are parsed by the cells below.
df_summary = pd.read_csv(project_root/'data/results/eeg/subject_classification_summary.csv')

Check F1 and AUC:

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# Figure: per-subject F1 (top) and ROC-AUC (bottom) as two stacked bar charts.
set_plotting_style()

# Long-format copy of the two score columns; the plotting calls in this cell
# read df_summary directly and do not use it.
df_long = df_summary[['subject', 'f1', 'auc']].melt(id_vars='subject', var_name='metric', value_name='score')

fig, axes = plt.subplots(2, 1, figsize=(8, 8), sharex=True)

# Use gray to reduce unnecessary emphasis
bar_color = 'gray'

# F1 subplot
sns.barplot(data=df_summary, x='subject', y='f1', color=bar_color, ax=axes[0])
axes[0].set_title('Task Prediction Scores', pad=10)
axes[0].set_ylabel('F1 Score')
axes[0].set_ylim(0, 1.05)
axes[0].grid(axis='y', linestyle='--', alpha=0.4)
axes[0].set_xlabel('')  # Remove x-axis label

# AUC subplot
sns.barplot(data=df_summary, x='subject', y='auc', color=bar_color, ax=axes[1])
axes[1].set_ylabel('ROC-AUC')
axes[1].set_ylim(0, 1.05)
axes[1].grid(axis='y', linestyle='--', alpha=0.4)
axes[1].set_xlabel('')  # Remove default x-axis label

# Add single, centered x-axis label below both subplots
fig.text(0.5, 0.01, 'Subject ID', ha='center', va='center')

# Rotate ticks and finalize layout
plt.setp(axes[1].xaxis.get_majorticklabels(), rotation=90)
plt.tight_layout(rect=[0, 0.03, 1, 1])  # Leave space at bottom for xlabel

# Save and show
# `figures_folder` is not defined in this part of the file; it must already
# exist as a path-like object when this cell runs.
plt.savefig(figures_folder / "F1_AUC_CleanSubplots.png", dpi=300)
plt.show()




In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# Figure: F1 and ROC-AUC side by side for every subject (grouped bars).
set_plotting_style()

# One row per (subject, metric) so seaborn can group the bars by metric.
df_long = df_summary[['subject', 'f1', 'auc']].melt(id_vars='subject', var_name='metric', value_name='score')

# Replace the column names with the labels shown in the legend.
df_long['metric'] = df_long['metric'].map({'f1': 'F1', 'auc': 'ROC-AUC'})

# Plot
plt.figure(figsize=(12, 6))
sns.barplot(data=df_long, x='subject', y='score', hue='metric', palette='gray')

plt.xticks(rotation=90)
plt.xlabel('Subject ID', labelpad=15)
plt.ylabel('')
plt.ylim(0, 1.05)
plt.legend(loc='upper right')
plt.title('Task Prediction Scores', pad=10)
plt.grid(axis='y', linestyle='--', alpha=0.5)
plt.tight_layout()

# Save
# `figures_folder` is not defined in this part of the file.
plt.savefig(figures_folder / "F1_AUC_GroupedBar.png", dpi=300)
plt.show()


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Figure: distribution of F1 and ROC-AUC across subjects (box plot + points).
set_plotting_style()

# Prepare data in long format
df_long = df_summary[['subject', 'f1', 'auc']].melt(id_vars='subject', var_name='metric', value_name='score')

plt.figure(figsize=(4,5))

# Boxplot with datapoints overlay
# fliersize=0 hides the box plot's own outlier markers; the strip plot already
# shows every data point.
sns.boxplot(data=df_long, x='metric', y='score', color='lightgray', fliersize=0, width=0.5)
sns.stripplot(data=df_long, x='metric', y='score', color='black', size=6, jitter=True, alpha=0.7)

plt.title('Task Prediction Scores', pad=10)
plt.ylim(0, 1.05)
plt.ylabel('')
plt.xlabel('')
# Relabel the two categories; melt emits the 'f1' rows first, then 'auc'.
plt.xticks([0, 1], ['F1', 'ROC-AUC'], fontsize=14)
plt.tight_layout()

# `figures_folder` is not defined in this part of the file.
plt.savefig(figures_folder / "AUC_F1_Boxplot.png", dpi=300)
plt.show()


Check class accuracies:

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.patches import Patch
from scipy.stats import ttest_rel
from itertools import combinations

# 1. Prepare Accuracy Data
# Expand the per-subject 'class_accuracies' column into one row per (subject, task).
acc_records = []

for _, row in df_summary.iterrows():
    subject = row['subject']
    # SECURITY NOTE: eval executes whatever text is stored in the CSV cell.
    # `np` is exposed because the stored dict text may contain "np.float64(...)".
    # Only use this on summary files you produced yourself.
    acc_dict = eval(row['class_accuracies'], {"np": np})
    for task, acc in acc_dict.items():
        acc_records.append({'subject': subject, 'task': task, 'accuracy': acc})

acc_df = pd.DataFrame(acc_records)
# Treat infinite accuracies as missing values.
acc_df['accuracy'] = acc_df['accuracy'].replace([np.inf, -np.inf], np.nan)

# Fixed display order; task names outside this list become NaN in the
# categorical column.
task_order = ['S1', 'S2', 'S3', 'S4', 'S5', 'S6', 'S7']
acc_df['task'] = pd.Categorical(acc_df['task'], categories=task_order, ordered=True)

# Tick label shown for each task code.
task_labels = {
    'S1': 'Original',
    'S2': 'Scrambled',
    'S3': 'Randomized',
    'S4': 'Original',
    'S5': 'Scrambled',
    'S6': 'Randomized',
    'S7': 'Rest'
}

# 2. Define Color Palette
custom_palette = {
    'S1': '#a6dba0',  # light green
    'S2': '#5aae61',  # medium green
    'S3': '#1b7837',  # dark green
    'S4': '#80cdc1',  # light teal
    'S5': '#4393c3',  # medium blue
    'S6': '#2166ac',  # dark blue
    'S7': '#bdbdbd'   # gray
}

# 3. Compute Statistical Tests
# Every pair of tasks is compared on per-subject accuracy. The p-values are
# stored uncorrected.
# NOTE(review): 21 pairwise tests are run without a multiple-comparison correction.
task_pairs = list(combinations(task_order, 2))
p_values = []

from scipy.stats import ttest_rel, wilcoxon, mannwhitneyu, shapiro

# NOTE(review): p_values is initialised a second time here; harmless but redundant.
p_values = []
test_types = []

for task1, task2 in task_pairs:
    d1 = acc_df[acc_df['task'] == task1].sort_values('subject')
    d2 = acc_df[acc_df['task'] == task2].sort_values('subject')

    # Keep only subjects that have a value for both tasks. Both frames are
    # sorted by subject, so the value arrays line up row by row as long as each
    # subject has a single row per task.
    common_subjects = set(d1['subject']) & set(d2['subject'])
    d1_common = d1[d1['subject'].isin(common_subjects)]
    d2_common = d2[d2['subject'].isin(common_subjects)]

    d1_vals = d1_common['accuracy'].values
    d2_vals = d2_common['accuracy'].values

    # CASE 1: Paired design (same subjects in both)
    if len(d1_vals) > 0 and len(d1_vals) == len(d2_vals):
        diff = d1_vals - d2_vals

        # Check for normality of differences
        if len(diff) >= 3:  # Shapiro needs at least 3 samples
            _, p_norm = shapiro(diff)
        else:
            p_norm = 1.0  # Assume normal if too few samples

        # Normal differences -> paired t-test, otherwise Wilcoxon signed-rank.
        if p_norm > 0.05:
            stat, p = ttest_rel(d1_vals, d2_vals)
            test = 'paired t-test'
        else:
            try:
                stat, p = wilcoxon(d1_vals, d2_vals)
                test = 'wilcoxon'
            except ValueError:
                # A failed test is recorded as non-significant (p = 1).
                p = 1.0
                test = 'wilcoxon (fail)'

        p_values.append(((task1, task2), p))
        test_types.append(((task1, task2), test))

    # CASE 2: Unpaired design
    else:
        # Fall back to all available values of each task, not only the shared subjects.
        d1_vals = d1['accuracy'].values
        d2_vals = d2['accuracy'].values
        try:
            stat, p = mannwhitneyu(d1_vals, d2_vals, alternative='two-sided')
            test = 'mannwhitney'
        except ValueError:
            # A failed test is recorded as non-significant (p = 1).
            p = 1.0
            test = 'mannwhitney (fail)'

        p_values.append(((task1, task2), p))
        test_types.append(((task1, task2), test))


def pval_to_stars(p):
    """Return the significance marker for a p-value ('***', '**', '*' or 'n.s.')."""
    # Cut-offs are checked from the strictest to the most lenient.
    markers = ((0.001, '***'), (0.01, '**'), (0.05, '*'))
    for cutoff, stars in markers:
        if p < cutoff:
            return stars
    return 'n.s.'

def add_sig_bar(ax, x1, x2, y, h, p):
    """Draw a bracket of height h from x1 to x2 at level y and label it with the stars for p."""
    top = y + h
    # Bracket outline: up at x1, across to x2, back down.
    ax.plot([x1, x1, x2, x2], [y, top, top, y], c='k', lw=1)
    # Label centred just above the bracket.
    label = pval_to_stars(p)
    ax.text((x1 + x2)/2, top + 0.01, label, ha='center', va='bottom', fontsize=10)

# 4. Plotting
# Box plot of per-subject accuracy for each task, with the individual values
# drawn on top.
plt.figure(figsize=(14, 6))
ax = sns.boxplot(data=acc_df, x='task', y='accuracy', palette=custom_palette)
sns.stripplot(data=acc_df, x='task', y='accuracy', color='black', alpha=0.3, jitter=True)

# Custom X-axis labels
ax.set_xticklabels([task_labels[t] for t in task_order])

# Group dividers
# Drawn between the 3rd/4th and the 6th/7th category.
plt.axvline(x=2.5, color='gray', linestyle='--', linewidth=1)
plt.axvline(x=5.5, color='gray', linestyle='--', linewidth=1)

# Group titles
plt.text(1, 1.03, 'Audiobook', ha='center', fontsize=14, color='#1b7837', weight='bold')
plt.text(4, 1.03, 'Video', ha='center', fontsize=14, color='#2166ac', weight='bold')
plt.text(6, 1.03, 'Rest', ha='center', fontsize=14, color='#555555', weight='bold')

# (Optional) Subject-level trends
# for subject in acc_df['subject'].unique():
#     subject_data = acc_df[acc_df['subject'] == subject].sort_values('task')
#     plt.plot(subject_data['task'].cat.codes, subject_data['accuracy'], color='gray', alpha=0.2, linewidth=0.8)

# Add significance bars
start_y = 1.01
bar_height = 0.02
# Only pairs below the 0.05 threshold would get a bar.
visible_pvals = [((t1, t2), p) for ((t1, t2), p) in p_values if p < 0.05]

# NOTE(review): the drawing call is commented out, so this loop only computes
# positions and draws nothing, while the y-limit below still reserves room for
# the bars.
for idx, ((t1, t2), p) in enumerate(visible_pvals):
    x1 = task_order.index(t1)
    x2 = task_order.index(t2)
    y = start_y + idx * (bar_height + 0.005)
    # add_sig_bar(ax, x1, x2, y, bar_height, p)

# Final layout
plt.title('Class-wise Accuracy Across Subjects', pad=20)
plt.ylim(0, 1.05 + len(visible_pvals)*0.03)
plt.ylabel('Accuracy', labelpad=15)
plt.xlabel('')
plt.tight_layout()

# Save
# `figures_folder` is not defined in this part of the file.
plt.savefig(figures_folder / "Class_Accuracy_Final_with_Stats.png", dpi=300)
plt.show()


Check feature importances:

In [None]:
def load_dict_simple(dict_str):
    """
    Parse a flat dict-like string into a dictionary.

    Accepts text such as "{'key1': 0.5, 'key2': np.float64(0.25)}" or
    "key1: val1, key2: val2". Surrounding braces are optional and quotes
    around keys are removed. Each value is evaluated as a Python expression
    in which only `np` is available; a value that cannot be evaluated is
    reported with a warning and skipped.

    The text is split on commas, so keys or values that contain a comma are
    not supported.

    Anything that is not a string (for example the NaN pandas yields for an
    empty CSV cell) is treated as an empty dictionary.
    """
    result = {}
    if not isinstance(dict_str, str):
        return result

    # Remove possible surrounding braces and whitespace
    dict_str = dict_str.strip().strip('{}').strip()
    if not dict_str:
        return result

    # SECURITY NOTE: values are still evaluated with eval. Builtins are
    # disabled and only `np` is exposed, which narrows what a crafted cell can
    # do but is not a sandbox. Only parse files you produced yourself.
    eval_namespace = {"np": np, "__builtins__": {}}

    for item in dict_str.split(','):
        # Split on the first colon only, so values may contain further colons.
        key, colon, raw_value = item.strip().partition(':')
        if not colon:
            continue  # skip fragments that are not "key: value"

        key = key.strip().strip("'\"")  # Strip quotes from keys
        raw_value = raw_value.strip()
        try:
            value = eval(raw_value, eval_namespace)
        except Exception as e:
            print(f"Warning: failed to parse value '{raw_value}': {e}")
            continue
        result[key] = value
    return result

from collections import Counter

# Cumulative importance of every feature, summed over subjects. Only the
# first 10 entries of each subject's stored dict are counted.
all_features = Counter()

for fi in df_summary['feature_importances']:
    # Empty CSV cells are read back as NaN, not as text; nothing to parse.
    if not isinstance(fi, str):
        continue
    feat_dict = load_dict_simple(fi)
    for k, v in list(feat_dict.items())[:10]:
        try:
            all_features[k] += float(v)  # force conversion to numeric
        except (TypeError, ValueError):
            print(f"Skipping non-numeric feature: {k}={v}")

# Plot top 20 features by cumulative importance
top_features = all_features.most_common(20)

if not top_features:
    # zip(*[]) below would fail to unpack, so report and skip the figure.
    print("No feature importances available; skipping the feature-importance plot.")
else:
    feat_names, feat_vals = zip(*top_features)

    set_plotting_style()
    plt.figure(figsize=(10, 6))

    # Use a clean, elegant color palette
    sns.barplot(
        x=list(feat_vals),
        y=list(feat_names),
        palette=sns.color_palette("viridis", n_colors=len(feat_vals)),
        orient='h'
    )

    # Improve labels and aesthetics
    plt.xlabel('Cumulative Feature Importance', fontsize=12)
    plt.ylabel('Feature', fontsize=12)
    plt.title('Top 20 Features Across Subjects', fontsize=14)

    # Style tweaks
    sns.despine(left=True, bottom=True)
    plt.grid(axis='x', linestyle='--', alpha=0.4)
    plt.xticks(fontsize=10)
    plt.yticks(fontsize=10)
    plt.tight_layout()
    # Save high-res
    # `figures_folder` is not defined in this part of the file.
    plt.savefig(figures_folder / "Feature_Importance.png", dpi=300)

    plt.show()
