In [None]:
from google.colab import drive
drive.mount('/content/drive')
# Mount Google Drive at /content/drive so the data folder stored there is accessible.

In [None]:
!git clone https://github.com/payalmohapatra/MAESTRO
# Clone the MAESTRO model repository into the Colab workspace.

In [None]:
!pip install mne
import os
import glob
import shutil
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import mne
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import classification_report
import math

# ==============================================================================
# CONFIGURATION (500 Hz pipeline)
# ==============================================================================
# Input location. May be either a single subject folder (name starts with
# 'sub-', as here) or a root folder containing several 'sub-*' folders;
# MaestroPreprocessor.run() handles both layouts.
DATA_ROOT = '/content/drive/MyDrive/sub-01'

# Where the processed .pt clips (one sub-folder per task) and the trained
# model checkpoints are written. NOTE: this directory is deleted and recreated
# at the start of every preprocessing run.
OUTPUT_DIR = '/content/MAESTRO_Project/processed_data'

# Signal processing constants
TARGET_SFREQ = 500  # sampling rate (Hz) every recording is resampled to
CLIP_LEN_SEC = 1.5
FIXED_LEN = int(TARGET_SFREQ * CLIP_LEN_SEC) # samples per clip: 1.5 s * 500 Hz = 750
CHANNELS = ['Fz', 'FCz', 'Pz', 'Oz', 'C3', 'C4', 'P3', 'P4', 'ECG1']

# Continuous-task recordings: file name -> class label
# (label 0 is reported as "Low Load", label 1 as "High Load" during evaluation).
NBACK_MAP = {'zeroBACK.set': 0, 'twoBACK.set': 1}
MATB_MAP  = {'MATBeasy.set': 0, 'MATBdiff.set': 1}

# ==============================================================================
# 1. PREPROCESSING ENGINE
# ==============================================================================
class MaestroPreprocessor:
    """Turn per-session EEGLAB recordings into fixed-length labelled clips.

    ``run`` walks ``root_dir`` (either a single ``sub-*`` folder or a folder
    containing several) and, for every ``ses-*/eeg`` directory, processes the
    PVT, Flanker, N-back and MATB recordings. Every clip is written as
    ``{'data': float32 tensor [FIXED_LEN, len(CHANNELS)], 'label': int}`` to
    ``output_dir/<TASK>/<sub>_<ses>_<label>_<idx>.pt``.

    Args:
        root_dir: Subject folder or folder containing ``sub-*`` folders.
        output_dir: Destination folder; it is wiped at the start of ``run``.
    """

    def __init__(self, root_dir, output_dir):
        self.root_dir = root_dir
        self.output_dir = output_dir

    def run(self):
        """Process every subject/session found under ``root_dir``.

        The output directory is deleted first so clips produced with a
        different clip length or channel set cannot be mixed with new ones.
        """
        print(f"üöÄ STARTING MAESTRO PREPROCESSING PIPELINE (500Hz)")
        print(f"üìÇ Scanning Path: {self.root_dir}")

        # Start from an empty output directory to avoid shape mismatches.
        if os.path.exists(self.output_dir):
            print("   ‚ö†Ô∏è Cleaning old output directory...")
            shutil.rmtree(self.output_dir)
        os.makedirs(self.output_dir, exist_ok=True)

        # root_dir may itself be a subject folder rather than the dataset root.
        folder_name = os.path.basename(self.root_dir.rstrip('/'))

        if folder_name.startswith('sub-'):
            print(f"   ‚úÖ Detected target as Subject Folder: {folder_name}")
            self._process_subject(self.root_dir, folder_name)
            return

        # Sorted so the processing order does not depend on the filesystem.
        subjects = sorted(glob.glob(os.path.join(self.root_dir, 'sub-*')))
        print(f"   Found {len(subjects)} Subjects inside root.")
        for sub_path in subjects:
            self._process_subject(sub_path, os.path.basename(sub_path))

    def _process_subject(self, sub_path, sub_id):
        """Run all four task processors on each ``ses-*/eeg`` folder of a subject."""
        sessions = sorted(glob.glob(os.path.join(sub_path, 'ses-*')))
        if not sessions:
            print(f"   ‚ö†Ô∏è No sessions (ses-*) found in {sub_id}")
            return

        for ses_path in sessions:
            ses_id = os.path.basename(ses_path)
            eeg_path = os.path.join(ses_path, 'eeg')

            if not os.path.exists(eeg_path):
                print(f"   ‚ö†Ô∏è No 'eeg' folder in {ses_id}")
                continue

            print(f"\n   Processing {sub_id} | {ses_id}...")
            self._process_pvt(eeg_path, sub_id, ses_id)
            self._process_flanker(eeg_path, sub_id, ses_id)
            self._process_continuous(eeg_path, sub_id, ses_id, 'NBACK', NBACK_MAP)
            self._process_continuous(eeg_path, sub_id, ses_id, 'MATB', MATB_MAP)

    def _save_tensor(self, data, label, task, sub, ses, idx):
        """Save one clip as ``{'data': [FIXED_LEN, channels], 'label': label}``.

        Args:
            data: Array-like of shape [channels, time].
            label: Class label stored next to the clip and used in the file name.
            task: Task name; also the sub-folder of ``output_dir``.
            sub, ses: Subject / session identifiers used in the file name.
            idx: Clip index (any value usable in a file name).
        """
        task_dir = os.path.join(self.output_dir, task)
        os.makedirs(task_dir, exist_ok=True)

        # Stored layout is [time, channels].
        tensor = torch.tensor(data, dtype=torch.float32).transpose(0, 1)

        # Force the fixed clip length: truncate long clips, zero-pad short ones
        # at the end of the time axis.
        n_samples = tensor.shape[0]
        if n_samples > FIXED_LEN:
            tensor = tensor[:FIXED_LEN, :]
        elif n_samples < FIXED_LEN:
            tensor = torch.nn.functional.pad(tensor, (0, 0, 0, FIXED_LEN - n_samples))

        fname = f"{sub}_{ses}_{label}_{idx}.pt"
        torch.save({'data': tensor, 'label': label}, os.path.join(task_dir, fname))

    def _load_raw(self, fpath):
        """Read an EEGLAB file; return the prepared Raw, or None if channels are missing."""
        raw = mne.io.read_raw_eeglab(fpath, preload=True, verbose=False)
        return raw if self._check_channels(raw) else None

    def _process_pvt(self, path, sub, ses):
        """Extract stimulus-locked PVT clips labelled by reaction time.

        A trial is a '13' event immediately followed by a '14' event; the time
        between them is the reaction time. Clips span -1.0 s to +0.5 s around
        the '13' event and are labelled 0 when the reaction time is below the
        session median, 1 otherwise.
        """
        fpath = os.path.join(path, 'PVT.set')
        if not os.path.exists(fpath):
            return
        try:
            raw = self._load_raw(fpath)
            if raw is None:
                return

            events, event_id = mne.events_from_annotations(raw, verbose=False)
            stim_id, resp_id = event_id.get('13'), event_id.get('14')
            # Compare against None explicitly rather than relying on the
            # truthiness of the integer event codes.
            if stim_id is None or resp_id is None:
                return

            sfreq = raw.info['sfreq']
            rts, valid_idx = [], []
            for i, (cur, nxt) in enumerate(zip(events[:-1], events[1:])):
                if cur[2] == stim_id and nxt[2] == resp_id:
                    rts.append((nxt[0] - cur[0]) / sfreq)
                    valid_idx.append(i)
            if not rts:
                return

            median_rt = np.median(rts)
            epochs = mne.Epochs(raw, events[valid_idx], event_id=stim_id,
                                tmin=-1.0, tmax=0.5, baseline=None, verbose=False)
            # Epochs whose window runs past the recording are dropped by MNE.
            # Drop them explicitly and use `selection` (indices of the kept
            # epochs into the events passed above) so each clip is paired
            # with its own reaction time instead of a shifted one.
            epochs.drop_bad()
            data = epochs.get_data()

            for i, (trial, d) in enumerate(zip(epochs.selection, data)):
                label = 0 if rts[trial] < median_rt else 1
                self._save_tensor(d, label, 'PVT', sub, ses, i)
            print(f"      ‚úÖ PVT: Extracted {len(data)} trials")
        except Exception as e:
            # Best effort: one unreadable recording must not stop the pipeline.
            print(f"      ‚ùå PVT Error: {e}")

    def _process_flanker(self, path, sub, ses):
        """Extract Flanker clips: marker '2511' -> label 0, marker '2521' -> label 1.

        Clips span -1.0 s to +0.5 s around each marker.
        """
        fpath = os.path.join(path, 'Flanker.set')
        if not os.path.exists(fpath):
            return
        try:
            raw = self._load_raw(fpath)
            if raw is None:
                return

            events, event_id = mne.events_from_annotations(raw, verbose=False)
            mapping = {'2511': 0, '2521': 1}

            total = 0
            for marker, label in mapping.items():
                if marker not in event_id:
                    continue
                epochs = mne.Epochs(raw, events, event_id=event_id[marker],
                                    tmin=-1.0, tmax=0.5, baseline=None, verbose=False)
                for i, d in enumerate(epochs.get_data()):
                    # The marker is part of the index so the two classes never
                    # collide on file name.
                    self._save_tensor(d, label, 'FLANKER', sub, ses, f"{marker}_{i}")
                    total += 1
            if total > 0:
                print(f"      ‚úÖ FLANKER: Extracted {total} trials")
        except Exception as e:
            # Best effort: one unreadable recording must not stop the pipeline.
            print(f"      ‚ùå FLANKER Error: {e}")

    def _process_continuous(self, path, sub, ses, task_name, file_map):
        """Cut continuous recordings into back-to-back clips of FIXED_LEN samples.

        Args:
            path: Session ``eeg`` directory.
            sub, ses: Subject / session identifiers used in file names.
            task_name: Output sub-folder and name used in log messages.
            file_map: Mapping ``file name -> class label``; every clip from a
                file receives that file's label. A trailing partial clip is
                discarded.
        """
        count = 0
        for fname, label in file_map.items():
            fpath = os.path.join(path, fname)
            if not os.path.exists(fpath):
                continue
            try:
                raw = self._load_raw(fpath)
                if raw is None:
                    continue

                data = raw.get_data()
                n_crops = data.shape[1] // FIXED_LEN
                for i in range(n_crops):
                    crop = data[:, i * FIXED_LEN:(i + 1) * FIXED_LEN]
                    self._save_tensor(crop, label, task_name, sub, ses, i)
                    count += 1
            except Exception as e:
                # Still best effort, but report the failure (as the PVT and
                # Flanker handlers do) instead of silently dropping the file.
                print(f"      ‚ùå {task_name} Error ({fname}): {e}")
        if count > 0:
            print(f"      ‚úÖ {task_name}: Extracted {count} trials")

    def _check_channels(self, raw):
        """Reduce ``raw`` in place to CHANNELS at TARGET_SFREQ.

        Returns:
            False (leaving ``raw`` untouched) when a required channel is
            missing, True once ``raw`` has been reduced and resampled.
        """
        missing = [ch for ch in CHANNELS if ch not in raw.ch_names]
        if missing:
            print(f"      ‚ö†Ô∏è Skipping recording, missing channels: {missing}")
            return False

        # Keep only the required channels, in the order listed in CHANNELS, so
        # every saved clip has the same column layout whatever the channel
        # order of the recording.
        raw.reorder_channels(CHANNELS)

        # Bring every recording to the common sampling rate.
        if raw.info['sfreq'] != TARGET_SFREQ:
            raw.resample(TARGET_SFREQ, npad="auto")

        return True

# ==============================================================================
# 2. TRAINING ENGINE (MAESTRO)
# ==============================================================================
class MAESTRO(nn.Module):
    """Transformer-encoder classifier over multichannel time-series clips.

    Each time step of the input ``[batch, time, n_channels]`` is projected to
    ``d_model`` features, a learned positional embedding is added, the
    sequence passes through a Transformer encoder, and the time-averaged
    representation is mapped to ``n_classes`` logits.

    Args:
        n_channels: Number of input channels per time step.
        d_model: Width of the Transformer representation.
        n_heads: Attention heads per encoder layer.
        ff_dim: Hidden size of each encoder layer's feed-forward network.
        n_layers: Number of encoder layers.
        n_classes: Number of output logits.
        dropout: Dropout probability inside the encoder layers.
        max_len: Longest supported sequence (size of the positional table).
            Defaults to the module-level ``FIXED_LEN``.

    With all defaults the architecture and state-dict keys are the same as
    the original fixed-size model (9 -> 64 projection, 2 layers, 4 heads).
    """

    def __init__(self, n_channels=9, d_model=64, n_heads=4, ff_dim=128,
                 n_layers=2, n_classes=2, dropout=0.3, max_len=None):
        super().__init__()
        if max_len is None:
            max_len = FIXED_LEN

        # Per-time-step projection: [batch, time, n_channels] -> [batch, time, d_model]
        self.input_fc = nn.Linear(n_channels, d_model)

        # Learned positional embedding, one vector per time step.
        self.pos_encoder = nn.Parameter(torch.randn(1, max_len, d_model))

        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=d_model, nhead=n_heads,
                                       dim_feedforward=ff_dim,
                                       batch_first=True, dropout=dropout),
            num_layers=n_layers
        )

        # Classification head applied to the pooled representation.
        self.decoder = nn.Linear(d_model, n_classes)

    def forward(self, x):
        """Return class logits ``[batch, n_classes]`` for ``x`` of shape ``[batch, time, n_channels]``.

        Raises:
            ValueError: If ``time`` exceeds the positional table length.
        """
        seq_len = x.size(1)
        max_len = self.pos_encoder.size(1)
        if seq_len > max_len:
            # Without this check the addition below fails with an opaque
            # broadcasting error.
            raise ValueError(
                f"Sequence length {seq_len} exceeds the maximum supported length {max_len}"
            )

        x = self.input_fc(x)

        # Shorter sequences use the leading part of the positional table.
        x = x + self.pos_encoder[:, :seq_len, :]

        x = self.transformer(x)

        # Global average pooling over time, then classify.
        return self.decoder(x.mean(dim=1))

class MaestroDataset(Dataset):
    """Dataset over the ``*.pt`` clips stored in one task folder.

    Each file is expected to hold ``{'data': tensor [time, channels],
    'label': int}``. Items are returned as ``(normalized_clip, label)``.

    Args:
        folder: Directory containing the ``.pt`` clip files.
    """

    def __init__(self, folder):
        # Sorted so that index -> file is reproducible across runs and
        # filesystems (glob order is otherwise arbitrary).
        self.files = sorted(glob.glob(os.path.join(folder, "*.pt")))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        # The payload is only a tensor and an int, so the restricted
        # unpickler (weights_only=True) is sufficient and avoids executing
        # arbitrary pickled code from a tampered file.
        d = torch.load(self.files[idx], weights_only=True)
        x = d['data']

        # Per-clip z-score of every channel along the time axis; the epsilon
        # keeps constant channels from dividing by zero.
        mean = x.mean(dim=0, keepdim=True)
        std = x.std(dim=0, keepdim=True) + 1e-8
        x = (x - mean) / std

        return x, d['label']

def train_and_evaluate(epochs=30, batch_size=32, lr=0.001):
    """Train and evaluate one MAESTRO model per task found in OUTPUT_DIR.

    For each of FLANKER, PVT, NBACK and MATB the clips in
    ``OUTPUT_DIR/<task>`` are split 80/20 at random, a fresh model is trained
    with inverse-frequency class weights, a classification report for the
    held-out split is printed, and the weights are saved to
    ``OUTPUT_DIR/maestro_<task>_model.pth``. Tasks with no clips, or fewer
    than 10, are skipped.

    Args:
        epochs: Number of training epochs per task.
        batch_size: Mini-batch size for training and evaluation.
        lr: AdamW learning rate.
    """
    tasks = ['FLANKER', 'PVT', 'NBACK', 'MATB']
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nüöÄ STARTING TRAINING ENGINE on {device}")
    print(f"‚ÑπÔ∏è  Model Configuration: Input Length={FIXED_LEN}, Channels={len(CHANNELS)}")

    for task in tasks:
        data_dir = os.path.join(OUTPUT_DIR, task)
        # glob returns an empty list for a missing directory as well.
        if not glob.glob(os.path.join(data_dir, "*.pt")):
            print(f"‚ö†Ô∏è Skipping {task} (No data found)")
            continue

        print(f"\n" + "="*40)
        print(f"üß† TRAINING TASK: {task}")
        print("="*40)

        dataset = MaestroDataset(data_dir)
        if len(dataset) < 10:
            print("‚ö†Ô∏è Not enough data to train.")
            continue

        train_len = int(0.8 * len(dataset))
        test_len = len(dataset) - train_len
        train_set, test_set = torch.utils.data.random_split(dataset, [train_len, test_len])

        # Inverse-frequency class weights computed on the training split only.
        # NOTE: this loads every training clip once, which is slow for very
        # large datasets.
        y_train = [train_set[i][1] for i in range(len(train_set))]
        count_0 = y_train.count(0)
        count_1 = y_train.count(1)

        if count_0 > 0 and count_1 > 0:
            n_train = count_0 + count_1
            w0 = n_train / (2.0 * count_0)
            w1 = n_train / (2.0 * count_1)
            class_weights = torch.tensor([w0, w1], dtype=torch.float32).to(device)
            print(f"   ‚öñÔ∏è Class Balance: Low={count_0}, High={count_1}")
            print(f"   ‚öñÔ∏è Applying Weights: Low={w0:.2f}, High={w1:.2f}")
        else:
            class_weights = None
            print("   ‚ö†Ô∏è Warning: One class is missing in training data!")

        train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
        test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

        model = MAESTRO().to(device)
        optimizer = optim.AdamW(model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss(weight=class_weights)

        for epoch in range(epochs):
            model.train()
            total_loss = 0
            for x, y in train_loader:
                x, y = x.to(device), y.to(device)
                optimizer.zero_grad()
                loss = criterion(model(x), y)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

            # Log the mean batch loss every 10th epoch.
            if (epoch + 1) % 10 == 0:
                print(f"   Epoch {epoch+1}/{epochs} | Loss: {total_loss/len(train_loader):.4f}")

        model.eval()
        preds, targets = [], []
        with torch.no_grad():
            for x, y in test_loader:
                x, y = x.to(device), y.to(device)
                out = model(x)
                preds.extend(torch.argmax(out, 1).cpu().numpy())
                targets.extend(y.cpu().numpy())

        # labels=[0, 1] pins the report to both classes; without it the call
        # raises when only one class occurs in the held-out split, because
        # target_names would then not match the number of classes found.
        print("\n" + classification_report(targets, preds, labels=[0, 1],
                                           target_names=['Low Load', 'High Load'],
                                           zero_division=0))

        save_path = os.path.join(OUTPUT_DIR, f"maestro_{task.lower()}_model.pth")
        torch.save(model.state_dict(), save_path)
        print(f"‚úÖ Model saved to: {save_path}")

# ==============================================================================
# 3. EXECUTION BLOCK
# ==============================================================================
if __name__ == "__main__":
    # Stage 1 - preprocessing: scan the data folders, cut the recordings into
    # clips and write them as tensors under OUTPUT_DIR.
    processor = MaestroPreprocessor(root_dir=DATA_ROOT, output_dir=OUTPUT_DIR)
    processor.run()

    # Stage 2 - training: load the saved clips and fit one model per task.
    train_and_evaluate()