In [127]:
import pyxdf
import numpy as np
import mne
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from psychopy_experiments.brand_associations.brands_wordlist import WORDLIST
from pathlib import Path

In [128]:
instagram_wordlist = WORDLIST['instagram']
linkedin_wordlist = WORDLIST['linkedin']
unrelated_wordlist = WORDLIST['unrelated']
CATEGORY_MAP = {item: k for k, v in WORDLIST.items() for item in v}

In [129]:
import pandas as pd
import pyxdf

def load_multiple_subjects(csv_paths, xdf_paths):
    if len(csv_paths) != len(xdf_paths):
        raise ValueError("Number of CSV and XDF files must match!")

    subjects = []  # list of dicts: subject, csv, streams
    for i, (csv_path, xdf_path) in enumerate(zip(csv_paths, xdf_paths), start=1):
        subject_id = i

        csv_df = pd.read_csv(csv_path)
        csv_df["subject"] = subject_id

        streams, _ = pyxdf.load_xdf(xdf_path)

        subjects.append({
            "subject": subject_id,
            "csv": csv_df,
            "streams": streams,
            "csv_path": csv_path,
            "xdf_path": xdf_path
        })

    return subjects


In [130]:
csv_paths = [
    'psychopy_experiments/brand_associations/levan_brand.csv',
    'psychopy_experiments/brand_associations/raindi_brand.csv',
    'psychopy_experiments/brand_associations/dato_brand.csv'
]

xdf_paths = [
    'psychopy_experiments/brand_associations/levan_brand.xdf',
    'psychopy_experiments/brand_associations/raindi_brand.xdf',
    'psychopy_experiments/brand_associations/dato_brand.xdf'
]

subjects = load_multiple_subjects(csv_paths, xdf_paths)

In [131]:
all_subjects = []

for subj_data in subjects:
    subject_id = subj_data["subject"]
    xdf_data = subj_data["streams"]
    csv_df = subj_data["csv"]

    markers = xdf_data[0]
    eeg = xdf_data[1]

    marker_time = markers["time_stamps"].copy()
    marker_data = np.array([x[0] for x in markers["time_series"]], dtype=int)


    TARGET_MARKER = 1
    keep_m = marker_data == TARGET_MARKER
    marker_time = marker_time[keep_m]
    marker_data = marker_data[keep_m]

    # time alignment
    time_offset = marker_time[0]
    marker_time = marker_time - time_offset

    eeg_time = eeg["time_stamps"].copy()
    CONST_OFFSET = 0.073
    eeg_time = eeg_time - time_offset - CONST_OFFSET

    eeg_data = eeg["time_series"][:, :8]


    info = mne.create_info(ch_names=['Fz', 'C3', 'Cz', 'C4', 'Pz', 'PO7', 'Oz', 'PO8'], ch_types=['eeg'] * 8,
                       sfreq=250)
    raw = mne.io.RawArray([1e-6 * eeg_data[:, i] for i in range(8)], info)
    raw.notch_filter(freqs=[50])
    raw.filter(0.5, 30, method="iir", iir_params=dict(order=4, ftype="butter"))


    samples = np.searchsorted(eeg_time, marker_time, side="left")
    samples = np.clip(samples, 0, len(eeg_time) - 1)

    # MNE events: [sample, 0, event_id]
    events = np.column_stack([samples, np.zeros(len(samples), dtype=int), marker_data]).astype(int)
    events = events[:-1]

    # epoching
    reject_criteria = dict(eeg=70e-6)  # 70 µV
    event_dict = dict(target=TARGET_MARKER)

    epochs = mne.Epochs(
        raw, events, event_id=event_dict,
        tmin=-0.1, tmax=1.0,
        baseline=(-0.1, 0.0),
        reject=reject_criteria,
        preload=True
    )

    all_subjects.append({
        "subject": subject_id,
        "raw": raw,
        "epochs": epochs,
        "events": events,
        "csv": csv_df
        # "const_offset": CONST_OFFSET
    })


Creating RawArray with float64 data, n_channels=8, n_times=113720
    Range : 0 ... 113719 =      0.000 ...   454.876 secs
Ready.
Filtering raw data in 1 contiguous segment
Setting up band-stop filter from 49 - 51 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 49.38
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 49.12 Hz)
- Upper passband edge: 50.62 Hz
- Upper transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 50.88 Hz)
- Filter length: 1651 samples (6.604 s)

Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 30 Hz

IIR filter parameters
---------------------
Butterworth bandpass zero-phase (two-pass forward and reverse) non-causal filter:
- Filter order 16 (effective, after forward-backward)
- Cutoffs at 0.50, 30.00 Hz: -6.02, -

In [132]:
for subj in all_subjects:
    epochs = subj["epochs"]
    events = subj["events"]
    csv_data = subj["csv"]

    # 1) Align to what survived rejection (epochs.selection refers to ORIGINAL events/csv)
    good_idx = epochs.selection
    csv_kept = csv_data.iloc[good_idx].reset_index(drop=True)
    csv_kept['category'] = csv_kept['target'].map(CATEGORY_MAP)
    # csv_kept["congruency"] = csv_kept['target'].isin(unrelated_wordlist).to_numpy()
    events_kept = events[good_idx]


    # 2) Now drop "no answer" trials using the KEPT csv (same index space as epochs)
    has_response = csv_kept["resp_key"].notna().to_numpy()

    epochs_resp = epochs[has_response]
    events_resp = events_kept[has_response]
    csv_resp = csv_kept.loc[has_response].reset_index(drop=True)

    epochs_bf = epochs_resp.copy().resample(80, npad="auto")

    # 3) Save back
    subj["epochs"] = epochs_resp
    subj["epochs_bf"] = epochs_bf
    subj["events"] = events_resp
    subj["csv"] = csv_resp


In [133]:
def make_temporal_window(times, tmin=0.30, tmax=0.45, kind="hann"):

    mask = (times >= tmin) & (times <= tmax)
    if not mask.any():
        raise ValueError("No time points inside the requested N400 window.")

    w = np.zeros_like(times, dtype=float)

    if kind == "hann":
        w[mask] = np.hanning(mask.sum())
    elif kind == "rect":
        w[mask] = 1.0
    else:
        raise ValueError("kind must be 'hann' or 'rect'")

    return w, mask


In [134]:
def build_n400_template_from_unrelated(
    epochs,
    unrelated_idx,
    tmin=0.30,
    tmax=0.45,
    window_kind="hann"
):
    """
    Build a GLOBAL spatio-temporal N400 template from unrelated trials.

    Parameters
    ----------
    epochs : mne.Epochs
        Epochs containing ONLY unrelated trials (or pooled unrelated trials)
    unrelated_idx : array-like
        Indices of unrelated trials inside epochs
    tmin, tmax : float
        N400 time window in seconds
    window_kind : {"hann", "rect"}
        Temporal weighting

    Returns
    -------
    s : ndarray, shape (n_channels * n_times,)
        Flattened spatio-temporal N400 template
    """

    # Epoch data: (n_trials, n_channels, n_times)
    X = epochs.get_data()
    times = epochs.times

    n_ch = X.shape[1]
    n_t = X.shape[2]

    # Temporal window
    w_t, _ = make_temporal_window(times, tmin=tmin, tmax=tmax, kind=window_kind)

    # Average ERP over unrelated trials
    erp_unrel = X[unrelated_idx].mean(axis=0)   # (n_ch, n_times)

    # Apply temporal weighting
    erp_weighted = erp_unrel * w_t[np.newaxis, :]

    # Flatten to spatio-temporal vector
    s = erp_weighted.reshape(n_ch * n_t)

    # Normalize (important for numerical stability)
    norm = np.linalg.norm(s)
    if norm > 0:
        s = s / norm

    return s


In [None]:
unrel_epochs_list = []
for subj in all_subjects:   # <-- use the list you updated
    epochs_bf = subj["epochs_bf"]
    csv = subj["csv"]
    unrel_mask = csv["target"].isin(unrelated_wordlist).to_numpy()
    unrel_epochs_list.append(epochs_bf[unrel_mask])

epochs_unrel_all = mne.concatenate_epochs(unrel_epochs_list)


# build global template s (spatio-temporal)
template_s = build_n400_template_from_unrelated(
    epochs_unrel_all,
    unrelated_idx=np.arange(len(epochs_unrel_all)),
    tmin=0.30, tmax=0.45,           # paper-aligned N400 emphasis
    window_kind="hann"
)


Not setting metadata
71 matching events found
Applying baseline correction (mode: mean)


In [None]:
# print(len(epochs_unrel_all))


71


In [None]:
# missing = [s["subject"] for s in all_subjects if "epochs_bf" not in s]
# missing


[]

In [139]:
from sklearn.covariance import LedoitWolf

def st_lcmv_scores(epochs, template_s, cov_shrinkage=True, eps=1e-8):
    """
    Spatio-temporal LCMV / matched-filter score:
        w = R^{-1} s / (s^T R^{-1} s)
        score_i = x_i^T w

    epochs: mne.Epochs
    template_s: (n_features,) where n_features = n_channels * n_times

    Returns:
        scores: (n_trials,) single scalar per epoch
    """
    X = epochs.get_data().reshape(len(epochs), -1)  # (n_trials, n_features)

    # Covariance of features across trials (robust shrinkage recommended)
    if cov_shrinkage:
        R = LedoitWolf().fit(X).covariance_
    else:
        R = np.cov(X, rowvar=False)

    # Regularize for stability
    R = R + eps * np.eye(R.shape[0])

    s = template_s.reshape(-1, 1)             # (n_features, 1)
    Rinv_s = np.linalg.solve(R, s)            # (n_features, 1)
    denom = float(s.T @ Rinv_s)               # scalar

    if np.isclose(denom, 0.0):
        raise ValueError("Template leads to near-zero denominator; check template/covariance.")

    w = (Rinv_s / denom).reshape(-1)          # (n_features,)
    scores = X @ w                             # (n_trials,)
    return scores


In [142]:
MIN_USEFUL = 30  # paper threshold (subjects with <30 useful trials excluded)

kept_subjects = []
dropped_subjects = []

for subj in all_subjects:   # or `subjects` if that is your list
    sid = subj["subject"]
    epochs_bf = subj["epochs_bf"]
    csv = subj["csv"].copy()

    # masks in current (already aligned) index space
    is_unrel = csv["target"].isin(unrelated_wordlist).to_numpy()
    keep_mask = ~is_unrel

    n_useful = int(keep_mask.sum())
    subj["n_useful"] = n_useful

    if n_useful < MIN_USEFUL:
        dropped_subjects.append((sid, n_useful))
        continue

    # Remove Unrelated trials from further N400 analysis (paper step)
    epochs_rel = epochs_bf[keep_mask]
    csv_rel = csv.loc[keep_mask].reset_index(drop=True)

    # stLCMV single-trial N400 strength
    n400_raw = st_lcmv_scores(epochs_rel, template_s, cov_shrinkage=True)

    # Optional sign fix: ensure incongruent is more negative than congruent (N400 effect direction)
    # cong = (csv_rel["congruency"] == "congruent").to_numpy()
    # incong = (csv_rel["congruency"] == "incongruent").to_numpy()
    # if cong.any() and incong.any():
    #     if n400_raw[incong].mean() > n400_raw[cong].mean():
    #         n400_raw *= -1.0

    # Per-subject z-score (Van Petten)
    n400_z = (n400_raw - n400_raw.mean()) / (n400_raw.std(ddof=0) + 1e-12)

    # Attach to CSV
    csv_rel["n400_raw"] = n400_raw
    csv_rel["n400_z"] = n400_z
    csv_rel["subject"] = sid

    # Save back
    subj["epochs_rel_bf"] = epochs_rel
    subj["csv_n400"] = csv_rel

    kept_subjects.append(subj)

print("Dropped subjects (<30 useful trials):", dropped_subjects)
print("Kept subjects:", [s["subject"] for s in kept_subjects])


  denom = float(s.T @ Rinv_s)               # scalar
  denom = float(s.T @ Rinv_s)               # scalar


Dropped subjects (<30 useful trials): []
Kept subjects: [1, 2, 3]


  denom = float(s.T @ Rinv_s)               # scalar


In [143]:
all_df = pd.concat([s["csv_n400"] for s in kept_subjects], ignore_index=True)
all_df.head()


Unnamed: 0,timestamp_iso,trial_index,brand,target,prime_time_s,target_time_s,resp_window_s,resp_key,rt_ms_from_target,subject,category,n400_raw,n400_z
0,2025-12-04T06:21:55.721,11,linkedin,რილსები,0.16,0.16,1.5,right,678.8,1,instagram,4.2e-05,0.729212
1,2025-12-04T06:22:07.444,15,instagram,კარიერა,0.16,0.16,1.5,right,1513.15,1,linkedin,-4.5e-05,-1.487989
2,2025-12-04T06:22:16.236,18,instagram,ეიჩარი,0.16,0.16,1.5,right,1578.52,1,linkedin,-1.2e-05,-0.653259
3,2025-12-04T06:22:33.738,24,linkedin,ფოტოები,0.16,0.16,1.5,right,1378.9,1,instagram,6.5e-05,1.325467
4,2025-12-04T06:22:54.137,31,instagram,ვიდეოები,0.16,0.16,1.5,right,748.56,1,instagram,4.4e-05,0.791671


In [146]:
word_subj = (
    all_df.groupby(["subject", "brand", "target"], as_index=False)
          .agg(
              n400_word=("n400_z", "mean"),
              n_rep=("n400_z", "size"),
              rt_mean=("rt_ms_from_target", "mean") if "rt_ms_from_target" in all_df.columns else ("n400_z","size")
          )
)
word_subj.head()



Unnamed: 0,subject,brand,target,n400_word,n_rep,rt_mean
0,1,instagram,ეიჩარი,-0.653259,1,1578.52
1,1,instagram,ვაკანსია,-0.412356,2,289.17
2,1,instagram,ვიდეოები,0.791671,1,748.56
3,1,instagram,კარიერა,-0.826926,2,1129.86
4,1,instagram,კონტაქტი,-1.002737,1,762.79


In [151]:
word_group = (
    word_subj.groupby(["brand", "target"], as_index=False)
             .agg(
                 n400_expected=("n400_word", "mean"),
                 n_sub=("n400_word", "count"),
                 n_rep_total=("n_rep", "sum")
             )
)
word_group


Unnamed: 0,brand,target,n400_expected,n_sub,n_rep_total
0,instagram,ეიჩარი,-0.511148,3,5
1,instagram,ვაკანსია,-0.05921,3,5
2,instagram,ვიდეოები,0.090635,2,3
3,instagram,ინფლუენსერი,-0.46565,2,3
4,instagram,კარიერა,-0.429936,3,5
5,instagram,კონტაქტი,0.073124,3,4
6,instagram,ლაიქი,0.266875,1,2
7,instagram,ლურჯი,-1.06364,2,3
8,instagram,მეგობრები,-0.494285,3,5
9,instagram,ნეთვორქინგი,0.90679,2,2


In [153]:
word_group.sort_values(["brand", "n400_expected"], ascending=[True, False])


Unnamed: 0,brand,target,n400_expected,n_sub,n_rep_total
17,instagram,სივი,1.246918,3,4
9,instagram,ნეთვორქინგი,0.90679,2,2
14,instagram,სამსახური,0.782469,3,4
12,instagram,რეკრუტერი,0.70111,3,5
6,instagram,ლაიქი,0.266875,1,2
15,instagram,სელფი,0.207767,3,6
10,instagram,პოსტები,0.185109,2,3
18,instagram,ფოტოები,0.130254,2,3
2,instagram,ვიდეოები,0.090635,2,3
5,instagram,კონტაქტი,0.073124,3,4


In [None]:
good_idx = epochs.selection 

csv_kept = csv_data.iloc[good_idx].reset_index(drop=True)
events_kept = events[good_idx]

In [None]:
# Read marker data
marker_time = markers['time_stamps']
time_offset = marker_time[0]
marker_time = marker_time - time_offset
marker_data = [x[0] for x in markers['time_series']]
CONST_OFFSET = 0.073 
eeg_time = eeg['time_stamps']
eeg_time = eeg_time - time_offset - CONST_OFFSET
eeg_data = eeg['time_series'][:, :8]

In [None]:
CONST_OFFSET = 0.073  # Delay measured with photodiode that day

In [None]:
eeg_time = eeg['time_stamps']
eeg_time = eeg_time - time_offset - CONST_OFFSET
eeg_data = eeg['time_series'][:, :8]

1) The EEG signal was re-referenced offline to a mastoid reference, then filtered using a 4th order
Butterworth filter with range 0.5–30 Hz and the initial 2048 Hz sampling rate downed to 256
Hz (including anti-aliasing). 

2) The EOG signal was used to remove eye artifacts following the
AAA method proposed in [50]. 

3) Trials with EEG amplitudes exceeding 70μV on any of the
channels were considered to be affected by muscle artifacts and were discarded

4) as well as trials lacking button press responses (i.e., “no answer”). EEG epochs were extracted from -100 to
1000 ms to the onset of the stimulus, with the 100ms pre-stimulus period used for baseline correction.

5) Before applying the multivariate analysis methods, the signal was further downsampled to 80Hz to reduce dimensionality, as suggested in [43].

In [None]:
info = mne.create_info(ch_names=['Fz', 'C3', 'Cz', 'C4', 'Pz', 'PO7', 'Oz', 'PO8'], ch_types=['eeg'] * 8,
                       sfreq=250)
raw = mne.io.RawArray([1e-6 * eeg_data[:, i] for i in range(8)], info)
raw.notch_filter(freqs=[50])

iir_params = dict(order=4, ftype="butter")

raw.notch_filter(freqs=[50])  

raw.filter(
    l_freq=0.5,
    h_freq=30,
    method="iir",
    iir_params=iir_params)


Creating RawArray with float64 data, n_channels=8, n_times=113720
    Range : 0 ... 113719 =      0.000 ...   454.876 secs
Ready.
Filtering raw data in 1 contiguous segment
Setting up band-stop filter from 49 - 51 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 49.38
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 49.12 Hz)
- Upper passband edge: 50.62 Hz
- Upper transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 50.88 Hz)
- Filter length: 1651 samples (6.604 s)

Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband a

Unnamed: 0,General,General.1
,MNE object type,RawArray
,Measurement date,Unknown
,Participant,Unknown
,Experimenter,Unknown
,Acquisition,Acquisition
,Duration,00:07:35 (HH:MM:SS)
,Sampling frequency,250.00 Hz
,Time points,113720
,Channels,Channels
,EEG,8


In [None]:
TARGET_MARKER = 1
RESPONSE_MARKER = 2

In [None]:
TARGET_MARKER = 1
RESPONSE_MARKER = 2
events = []
for i, marker in enumerate(marker_data):
    eeg_start_index = np.searchsorted(eeg_time, marker_time[i], side="left")
    eeg_start_index = np.clip(eeg_start_index, 0, len(eeg_time)-1)
    events.append([eeg_start_index, 0, marker])

events = np.asarray(events, dtype=int)
events = events[:-1]                
events = events[events[:, 2] == 1]  

In [None]:
# Define reject criteria
reject_criteria = dict(
    eeg=70e-6,  # 100 µV
)

In [None]:
event_dict = dict(target=1)
tmin, tmax = -0.1, 1.0 
baseline = (-0.1, 0.0)
epochs = mne.Epochs(raw.copy(), events, event_id=event_dict, tmin=tmin, tmax=tmax, preload=True,
                    baseline = baseline,  reject=reject_criteria)


Not setting metadata
120 matching events found
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 120 events and 276 original time points ...
    Rejecting  epoch based on EEG : ['Fz', 'C3', 'Cz', 'C4', 'Pz', 'PO7', 'Oz', 'PO8']
    Rejecting  epoch based on EEG : ['Fz', 'C3', 'Cz', 'C4', 'Pz', 'PO7', 'Oz']
    Rejecting  epoch based on EEG : ['Fz', 'C3', 'Cz', 'Pz']
    Rejecting  epoch based on EEG : ['Fz', 'C3', 'Cz', 'C4', 'Pz', 'PO7', 'Oz', 'PO8']
    Rejecting  epoch based on EEG : ['Fz', 'C3', 'Cz']
    Rejecting  epoch based on EEG : ['Fz', 'C3', 'Cz']
    Rejecting  epoch based on EEG : ['Fz', 'C3', 'Cz', 'C4', 'Pz', 'PO7', 'Oz', 'PO8']
    Rejecting  epoch based on EEG : ['Fz', 'C3', 'Cz', 'C4', 'Pz', 'PO7', 'Oz', 'PO8']
    Rejecting  epoch based on EEG : ['Fz', 'C3', 'Cz', 'C4', 'Pz']
    Rejecting  epoch based on EEG : ['Fz', 'C3', 'Cz', 'C4', 'Pz', 'PO7', 'Oz']
    Rejecting  epoch based on EEG : ['Fz', 'C3', 'Cz', 'C4

In [None]:
good_idx = epochs.selection 

csv_kept = csv_data.iloc[good_idx].reset_index(drop=True)
events_kept = events[good_idx]


In [None]:
csv_data = csv_kept
events = events_kept

In [None]:
csv_data = csv_kept
events = events_kept
has_response = csv_data["resp_key"].notna().to_numpy()

epochs_resp = epochs[has_response]
events_resp = events[has_response]
csv_resp = csv_data.loc[has_response].reset_index(drop=True)

In [None]:
epochs = epochs_resp
events = events_resp
csv_data = csv_resp


In [None]:
is_ig_brand = (csv_data["brand"] == "instagram").to_numpy()
is_li_brand = (csv_data["brand"] == "linkedin").to_numpy()

is_ig_word = csv_data["target"].isin(instagram_wordlist).to_numpy()
is_li_word = csv_data["target"].isin(linkedin_wordlist).to_numpy()
is_un_word = csv_data["target"].isin(unrelated_wordlist).to_numpy()

ig_ig_idx = np.where(is_ig_brand & is_ig_word)[0]
ig_li_idx = np.where(is_ig_brand & is_li_word)[0]
ig_un_idx = np.where(is_ig_brand & is_un_word)[0]

li_li_idx = np.where(is_li_brand & is_li_word)[0]
li_ig_idx = np.where(is_li_brand & is_ig_word)[0]
li_un_idx = np.where(is_li_brand & is_un_word)[0]

ig_instagram = epochs[ig_ig_idx]
ig_linkedin  = epochs[ig_li_idx]
ig_unrelated = epochs[ig_un_idx]


li_instagram = epochs[li_ig_idx]  
li_linkedin  = epochs[li_li_idx]  
li_unrelated = epochs[li_un_idx]  


is_unrelated = csv_data["target"].isin(unrelated_wordlist).to_numpy()
is_related = (
    (csv_data["brand"] == "instagram") & csv_data["target"].isin(instagram_wordlist) |
    (csv_data["brand"] == "linkedin") & csv_data["target"].isin(linkedin_wordlist)
)

unrelated = np.where(is_unrelated)[0]
related =   np.where(is_related)[0]

In [None]:
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from scipy import stats

# ============================================================================
# STEP 1: N400 AMPLITUDE EXTRACTION
# ============================================================================

def apply_stLCMV_beamformer(epochs, template_spatial, template_temporal):
    """
    Apply spatiotemporal LCMV beamformer to extract N400 responses.
    This is a placeholder - you'll need to implement based on your stLCMV method.
    """
    # This should implement your stLCMV beamformer from reference [43]
    # Returns single-trial N400 amplitudes
    pass

# Create templates using ONLY unrelated trials
unrelated_epochs = epochs[unrelated]

# Build spatial and temporal templates from unrelated trials
# (Implementation depends on your stLCMV method)
template_spatial = np.mean(unrelated_epochs.get_data(), axis=0).mean(axis=-1)
template_temporal = np.mean(unrelated_epochs.get_data(), axis=0).mean(axis=0)

# Remove unrelated trials from further analysis
related_epochs = epochs[related]
related_brands = csv_data.iloc[related]["brand"].values
related_targets = csv_data.iloc[related]["target"].values

# Extract N400 for each trial using stLCMV
n400_responses = []
for i in range(len(related_epochs)):
    # Apply beamformer to single trial
    n400_amp = apply_stLCMV_beamformer(
        related_epochs[i], 
        template_spatial, 
        template_temporal
    )
    n400_responses.append(n400_amp)

n400_responses = np.array(n400_responses)

# ============================================================================
# STEP 2: CONVERT TO Z-SCORES PER SUBJECT
# ============================================================================

# Get subject IDs for related trials
related_subjects = csv_data.values
unique_subjects = np.unique(related_subjects)

# Z-score normalization per subject
n400_zscored = np.zeros_like(n400_responses)
for subj in unique_subjects:
    subj_mask = related_subjects == subj
    subj_n400 = n400_responses[subj_mask]
    
    # Convert to z-scores
    n400_zscored[subj_mask] = stats.zscore(subj_n400)

# ============================================================================
# STEP 3: SINGLE WORD CALCULATION (ECM for missing data)
# ============================================================================

def compute_word_means_with_ecm(n400_data, subjects, targets, words_list):
    """
    Compute expected N400 response for each word using ECM to handle missing data.
    """
    from sklearn.impute import IterativeImputer
    
    word_means = {}
    
    for word in words_list:
        word_mask = targets == word
        word_n400 = n400_data[word_mask]
        word_subjects = subjects[word_mask]
        
        # Create subject-indexed array (some may be missing)
        n_subjects = len(unique_subjects)
        subject_responses = np.full(n_subjects, np.nan)
        
        for i, subj in enumerate(unique_subjects):
            subj_trials = word_n400[word_subjects == subj]
            if len(subj_trials) > 0:
                subject_responses[i] = np.mean(subj_trials)
        
        # Use ECM (via IterativeImputer) to estimate mean with missing data
        if np.isnan(subject_responses).any():
            imputer = IterativeImputer(random_state=42)
            subject_responses_2d = subject_responses.reshape(-1, 1)
            imputed = imputer.fit_transform(subject_responses_2d)
            word_means[word] = np.mean(imputed)
        else:
            word_means[word] = np.nanmean(subject_responses)
    
    return word_means

# Get all unique target words for each brand
all_words_ig = csv_data[csv_data["brand"] == "instagram"]["target"].unique()
all_words_li = csv_data[csv_data["brand"] == "linkedin"]["target"].unique()

# Compute word means for Instagram brand
ig_mask = related_brands == "instagram"
ig_word_means = compute_word_means_with_ecm(
    n400_zscored[ig_mask],
    related_subjects[ig_mask],
    related_targets[ig_mask],
    all_words_ig
)

# Compute word means for LinkedIn brand
li_mask = related_brands == "linkedin"
li_word_means = compute_word_means_with_ecm(
    n400_zscored[li_mask],
    related_subjects[li_mask],
    related_targets[li_mask],
    all_words_li
)

# ============================================================================
# STEP 4: SINGLE-BRAND CLUSTERING ANALYSIS
# ============================================================================

def find_optimal_clusters(data, max_k=10):
    """
    Automatic determination of optimal number of clusters using silhouette score.
    """
    silhouette_scores = []
    K_range = range(2, min(max_k + 1, len(data)))
    
    for k in K_range:
        kmeans = KMeans(n_clusters=k, random_state=42, n_init=10)
        labels = kmeans.fit_predict(data.reshape(-1, 1))
        score = silhouette_score(data.reshape(-1, 1), labels)
        silhouette_scores.append(score)
    
    optimal_k = K_range[np.argmax(silhouette_scores)]
    return optimal_k

# Cluster Instagram words
ig_n400_values = np.array(list(ig_word_means.values()))
ig_words_list = list(ig_word_means.keys())

optimal_k_ig = find_optimal_clusters(ig_n400_values)
kmeans_ig = KMeans(n_clusters=optimal_k_ig, random_state=42, n_init=10)
ig_clusters = kmeans_ig.fit_predict(ig_n400_values.reshape(-1, 1))

print(f"Instagram: Optimal number of clusters = {optimal_k_ig}")
for cluster_id in range(optimal_k_ig):
    cluster_words = [ig_words_list[i] for i in range(len(ig_words_list)) 
                     if ig_clusters[i] == cluster_id]
    cluster_mean = np.mean(ig_n400_values[ig_clusters == cluster_id])
    print(f"  Cluster {cluster_id}: {cluster_words} (mean N400: {cluster_mean:.3f})")

# Cluster LinkedIn words
li_n400_values = np.array(list(li_word_means.values()))
li_words_list = list(li_word_means.keys())

optimal_k_li = find_optimal_clusters(li_n400_values)
kmeans_li = KMeans(n_clusters=optimal_k_li, random_state=42, n_init=10)
li_clusters = kmeans_li.fit_predict(li_n400_values.reshape(-1, 1))

print(f"\nLinkedIn: Optimal number of clusters = {optimal_k_li}")
for cluster_id in range(optimal_k_li):
    cluster_words = [li_words_list[i] for i in range(len(li_words_list)) 
                     if li_clusters[i] == cluster_id]
    cluster_mean = np.mean(li_n400_values[li_clusters == cluster_id])
    print(f"  Cluster {cluster_id}: {cluster_words} (mean N400: {cluster_mean:.3f})")

# ============================================================================
# STEP 5: MULTI-BRAND COMPARISON
# ============================================================================

# Find common words between brands
common_words = set(ig_words_list) & set(li_words_list)

if len(common_words) > 0:
    print(f"\n\nMulti-brand comparison for {len(common_words)} common words:")
    
    for word in common_words:
        ig_n400 = ig_word_means[word]
        li_n400 = li_word_means[word]
        difference = ig_n400 - li_n400
        
        print(f"  {word}: IG={ig_n400:.3f}, LI={li_n400:.3f}, diff={difference:.3f}")

# Visualize results
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Instagram clustering
axes[0].scatter(range(len(ig_n400_values)), ig_n400_values, c=ig_clusters, cmap='viridis', s=100)
axes[0].set_xlabel('Word Index')
axes[0].set_ylabel('N400 Amplitude (z-score)')
axes[0].set_title(f'Instagram: {optimal_k_ig} Clusters')
axes[0].grid(True, alpha=0.3)

# LinkedIn clustering
axes[1].scatter(range(len(li_n400_values)), li_n400_values, c=li_clusters, cmap='viridis', s=100)
axes[1].set_xlabel('Word Index')
axes[1].set_ylabel('N400 Amplitude (z-score)')
axes[1].set_title(f'LinkedIn: {optimal_k_li} Clusters')
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

TypeError: '<' not supported between instances of 'float' and 'str'