# DATA LOAD

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
from sklearn.model_selection import GroupKFold
from sklearn.metrics import accuracy_score, mean_squared_error, confusion_matrix
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import seaborn as sns
from utils.data_load import load_all_participants

# Load the combined table for every participant once, at notebook start.
# The rest of this notebook reads its `.columns` and expects the label/grouping
# columns "perclos", "quantized_perclos" and "Participant Name" to be present;
# all remaining columns are used as model features.
# NOTE(review): presumably a pandas DataFrame - confirm in utils.data_load.
all_participants_data = load_all_participants()




# FUNCTIONS

In [None]:

class FeatureGroupTransformerModel(nn.Module):
    """Transformer over per-group feature tokens with attention pooling.

    The flat input vector is cut into contiguous column groups. Each group is
    projected to one token by its own small MLP, learned positional and
    group-type embeddings are added, the token sequence goes through a
    Transformer encoder, and the encoded tokens are combined with learned
    softmax attention weights before the output head.

    Args:
        feature_groups: Mapping ``group name -> (start column, width)``
            describing contiguous slices of the input's second axis. Token
            order follows the mapping's iteration order.
        output_size: Width of the final linear layer.
        task_type: ``'ternary'`` leaves the head output unactivated (raw
            class scores); every other value applies a sigmoid.
    """

    def __init__(self, feature_groups, output_size, task_type='binary'):
        super().__init__()

        self.feature_groups = feature_groups
        self.group_names = list(feature_groups.keys())
        self.num_groups = len(feature_groups)
        self.task_type = task_type

        self.embedding_dim = 64
        self.num_heads = 8
        self.num_encoder_layers = 3

        # One independent encoder per group, because group widths differ.
        self.feature_encoders = nn.ModuleDict()
        for group_name in self.group_names:
            _, group_size = feature_groups[group_name]
            self.feature_encoders[group_name] = nn.Sequential(
                nn.Linear(group_size, 64),
                nn.ReLU(),
                nn.Linear(64, self.embedding_dim)
            )

        # Two learned additive embeddings, one row per token position.
        self.pos_encoding = nn.Parameter(torch.randn(1, self.num_groups, self.embedding_dim) * 0.1)
        self.group_type_embedding = nn.Parameter(torch.randn(1, self.num_groups, self.embedding_dim) * 0.1)

        encoder_layers = nn.TransformerEncoderLayer(
            d_model=self.embedding_dim,
            nhead=self.num_heads,
            dim_feedforward=128,
            dropout=0.5,
            activation='relu',
            batch_first=True,
            norm_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layers,
            num_layers=self.num_encoder_layers,
            norm=nn.LayerNorm(self.embedding_dim)
        )

        # Scores each encoded token; softmax over tokens gives pooling weights.
        self.attention_scorer = nn.Sequential(
            nn.Linear(self.embedding_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 1)
        )

        self.output_layers = nn.Sequential(
            nn.Linear(self.embedding_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, output_size)
        )

        # Only the ternary head returns raw scores; all other task types
        # squash the output into (0, 1).
        self.output_activation = None if task_type == 'ternary' else nn.Sigmoid()

    def forward(self, x, return_attention_weights=False):
        """Run the model on a batch of flat feature vectors.

        Args:
            x: 2-D tensor ``(batch, num_features)``; columns are sliced
                according to ``feature_groups``.
            return_attention_weights: When true, also return the pooling
                weights so callers can inspect per-group importance.

        Returns:
            The head output of shape ``(batch, output_size)``, or a tuple
            ``(output, weights)`` with ``weights`` of shape
            ``(batch, num_groups)`` when ``return_attention_weights`` is set.
        """
        group_embeddings = []
        for group_name in self.group_names:
            start_idx, group_size = self.feature_groups[group_name]
            group_features = x[:, start_idx:start_idx + group_size]
            group_embeddings.append(self.feature_encoders[group_name](group_features))

        # (batch, num_groups, embedding_dim)
        encoded_groups = torch.stack(group_embeddings, dim=1)
        encoded_groups = encoded_groups + self.pos_encoding + self.group_type_embedding
        transformer_output = self.transformer_encoder(encoded_groups)

        # Attention pooling: weights are normalised across the token axis.
        attn_scores = self.attention_scorer(transformer_output)
        attn_weights = torch.softmax(attn_scores, dim=1)
        pooled_output = torch.sum(transformer_output * attn_weights, dim=1)

        final_output = self.output_layers(pooled_output)
        if self.output_activation is not None:
            final_output = self.output_activation(final_output)

        if return_attention_weights:
            return final_output, attn_weights.squeeze(-1)
        return final_output

    @staticmethod
    def _contiguous_span(group_name, indices):
        """Return ``(start, width)`` for *indices*, requiring one unbroken run.

        ``forward`` reads a group as the slice ``[start, start + width)``, so
        a group whose columns are scattered would silently pick up columns
        that belong to other groups. Fail loudly instead.

        Raises:
            ValueError: If *indices* do not form a gap-free, duplicate-free run.
        """
        start, width = min(indices), len(indices)
        if sorted(indices) != list(range(start, start + width)):
            raise ValueError(
                f"Columns of feature group '{group_name}' are not contiguous "
                f"(indices {sorted(indices)}); reorder the feature columns so "
                f"each group occupies one unbroken block."
            )
        return start, width

    @staticmethod
    def _channel_groups(columns, col_to_index, feature_type, channel_token_pos):
        """Build one group per channel for columns starting with *feature_type*.

        The channel id is read from the ``'_'``-separated token at
        *channel_token_pos*, which must start with ``'ch'``.
        """
        typed_columns = [col for col in columns if col.startswith(feature_type)]
        channel_ids = sorted({
            parts[channel_token_pos].replace('ch', '')
            for parts in (col.split('_') for col in typed_columns)
            if len(parts) > channel_token_pos and parts[channel_token_pos].startswith('ch')
        })

        groups = {}
        for ch in channel_ids:
            members = [col for col in typed_columns if f"_ch{ch}_" in col]
            if not members:
                continue
            name = f"{feature_type}_Channel_{ch}"
            groups[name] = FeatureGroupTransformerModel._contiguous_span(
                name, [col_to_index[col] for col in members]
            )
        return groups

    @staticmethod
    def create_token_groups(all_participants_data):
        """Derive the ``feature_groups`` mapping from the table's column names.

        Columns "perclos", "quantized_perclos" and "Participant Name" are
        excluded; the remaining columns are indexed in table order. Groups are
        created per channel for the "EEG_2Hz", "EEG_5Bands",
        "Forehead_EEG_2Hz" and "Forehead_EEG_5Bands" prefixes, plus one "EOG"
        group for every column starting with "EOG".

        Args:
            all_participants_data: Table exposing ``.columns`` with string
                column names.

        Returns:
            Dict ``group name -> (start index, width)``.

        Raises:
            ValueError: If the columns of any group are not adjacent.
        """
        columns_to_drop = ["perclos", "quantized_perclos", "Participant Name"]
        columns_to_process = [col for col in all_participants_data.columns if col not in columns_to_drop]
        col_to_X_index = {col: i for i, col in enumerate(columns_to_process)}

        # (prefix, position of the "ch<N>" token after splitting on '_')
        channel_layouts = [
            ("EEG_2Hz", 2),
            ("EEG_5Bands", 2),
            ("Forehead_EEG_2Hz", 3),
            ("Forehead_EEG_5Bands", 3),
        ]

        feature_groups = {}
        for feature_type, token_pos in channel_layouts:
            feature_groups.update(
                FeatureGroupTransformerModel._channel_groups(
                    columns_to_process, col_to_X_index, feature_type, token_pos
                )
            )

        eog_features = [col for col in columns_to_process if col.startswith("EOG")]
        if eog_features:
            feature_groups["EOG"] = FeatureGroupTransformerModel._contiguous_span(
                "EOG", [col_to_X_index[col] for col in eog_features]
            )

        return feature_groups


def prepare_data_for_pytorch(all_participants_data):
    """Split the participant table into a feature matrix, targets and groups.

    Every column except "perclos", "quantized_perclos" and "Participant Name"
    becomes a feature, kept in table order.

    Returns:
        Tuple ``(X, y_binary, y_ternary, y_continuous, groups)`` of arrays:
        the feature values, ``perclos > 0.5`` as 0/1 integers, the
        "quantized_perclos" column, the raw "perclos" column, and the
        "Participant Name" column.
    """
    non_feature_columns = ("perclos", "quantized_perclos", "Participant Name")

    feature_columns = []
    for column in all_participants_data.columns:
        if column not in non_feature_columns:
            feature_columns.append(column)

    perclos = all_participants_data["perclos"]

    return (
        all_participants_data[feature_columns].values,
        (perclos > 0.5).astype(int).values,
        all_participants_data["quantized_perclos"].values,
        perclos.values,
        all_participants_data["Participant Name"].values,
    )

def create_torch_dataset(X, y, task_type):
    """Wrap features and targets in a ``TensorDataset``.

    Features are always float tensors. For ``task_type == 'ternary'`` the
    targets are a 1-D long tensor of class indices; for every other task type
    they are a float column vector of shape ``(n, 1)``.
    """
    features = torch.FloatTensor(X)
    if task_type == 'ternary':
        labels = torch.LongTensor(y)
    else:
        # Binary and regression targets share the same (n, 1) float layout.
        labels = torch.FloatTensor(y).view(-1, 1)
    return torch.utils.data.TensorDataset(features, labels)


def evaluate_model(model, test_loader, criterion, device, task_type):
    """Score *model* on every batch of *test_loader* without tracking gradients.

    Predictions are thresholded at 0.5 for ``'binary'``, taken as the argmax
    over dim 1 for ``'ternary'``, and used as-is for any other task type.

    Returns:
        ``(mean batch loss, metric)`` where the metric is accuracy for the
        two classification task types and RMSE otherwise.
    """
    model.eval()
    loss_total = 0.0
    predicted = []
    expected = []

    with torch.no_grad():
        for batch_inputs, batch_targets in test_loader:
            batch_inputs = batch_inputs.to(device)
            # The loss needs the targets on the model's device.
            device_targets = batch_targets.to(device)

            outputs = model(batch_inputs)
            loss_total += criterion(outputs, device_targets).item()

            if task_type == 'ternary':
                batch_preds = torch.argmax(outputs, dim=1)
            elif task_type == 'binary':
                batch_preds = (outputs > 0.5).float()
            else:
                batch_preds = outputs

            predicted.extend(batch_preds.cpu().numpy().flatten())
            expected.extend(batch_targets.cpu().numpy().flatten())

    y_pred = np.array(predicted)
    y_true = np.array(expected)

    if task_type in ('binary', 'ternary'):
        score = accuracy_score(y_true, y_pred)
    else:
        score = np.sqrt(mean_squared_error(y_true, y_pred))
    return loss_total / len(test_loader), score


def plot_confusion_matrix(all_targets, all_predictions, task_type, model_name):
    """Show the confusion matrix of predictions against targets as a heatmap.

    Args:
        all_targets: 1-D sequence of true class labels.
        all_predictions: 1-D sequence of predicted class labels.
        task_type: ``'binary'`` selects the blue palette and binary title;
            any other value is drawn as the ternary plot.
        model_name: Used (upper-cased) in the figure title.
    """
    def _class_name(value):
        # Labels may arrive as floats (e.g. 0.0 / 1.0); show integral values
        # without the trailing ".0" so ticks read "Class 0", "Class 1", ...
        try:
            if float(value).is_integer():
                return f"Class {int(value)}"
        except (TypeError, ValueError):
            pass
        return f"Class {value}"

    cm = confusion_matrix(all_targets, all_predictions)

    # With no explicit `labels`, confusion_matrix uses the sorted labels that
    # occur in either array. Build the tick labels from that same set so their
    # count and order always match the matrix, even when a class is missing.
    present_labels = np.unique(np.concatenate((all_targets, all_predictions)))
    labels = [_class_name(value) for value in present_labels]

    if task_type == 'binary':
        cmap = "Blues"
        title = f"{model_name.upper()} Binary Classification Confusion Matrix"
    else:  # ternary
        cmap = "Oranges"
        title = f"{model_name.upper()} Ternary Classification Confusion Matrix"

    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt="d", cmap=cmap, xticklabels=labels, yticklabels=labels)
    plt.title(title)
    plt.xlabel("Predicted Label")
    plt.ylabel("True Label")
    plt.tight_layout()
    plt.show()

def plot_regression_results(all_targets, all_predictions, model_name):
    """Print the overall RMSE and scatter predicted against actual PERCLOS.

    A dashed red identity line spanning the joint value range is drawn for
    reference.

    Returns:
        The RMSE between *all_targets* and *all_predictions*.
    """
    overall_rmse = np.sqrt(mean_squared_error(all_targets, all_predictions))
    print(f"\nOverall Continuous Regression RMSE: {overall_rmse:.4f}")

    plt.figure(figsize=(8, 6))
    plt.scatter(all_targets, all_predictions, alpha=0.5)

    # Identity line limits; fall back to the unit interval for empty input.
    if len(all_targets) > 0:
        lower = min(np.min(all_targets), np.min(all_predictions))
        upper = max(np.max(all_targets), np.max(all_predictions))
    else:
        lower, upper = 0, 1
    diagonal = [lower, upper]
    plt.plot(diagonal, diagonal, 'r--')

    plt.title(f"{model_name.upper()} Regression: Predicted vs Actual PERCLOS")
    plt.xlabel("Actual PERCLOS")
    plt.ylabel("Predicted PERCLOS")
    plt.tight_layout()
    plt.show()
    return overall_rmse




# MAIN

In [None]:
def run_feature_group_model(all_participants_data, task_type='binary',
                            batch_size=32, learning_rate=0.001, num_epochs=20):
    """Train and evaluate FeatureGroupTransformerModel with group-wise cross-validation.

    The data is split with GroupKFold using one split per unique
    "Participant Name" value (intended as leave-one-subject-out). For every
    fold a fresh scaler, model, Adam optimizer and ReduceLROnPlateau scheduler
    are created, the model is trained for ``num_epochs`` epochs, and the
    held-out fold is predicted. Predictions from all folds are pooled for the
    overall metric and plot.

    Args:
        all_participants_data: Table with feature columns plus "perclos",
            "quantized_perclos" and "Participant Name".
        task_type: 'binary' or 'ternary' for classification. Any other value
            is treated as regression and rewritten to 'continuous'.
        batch_size: Batch size for both data loaders.
        learning_rate: Initial Adam learning rate.
        num_epochs: Training epochs per fold.

    Returns:
        None when fewer than two unique groups exist. Otherwise a dict with
        'model_type', 'task_type', 'fold_metrics' (list of
        ``(subject, metric)``), 'predictions', 'targets', and either
        'overall_accuracy' + 'confusion_matrix' (classification) or
        'overall_rmse' (regression).

    NOTE(review): the per-epoch evaluation and the LR scheduler both use the
    held-out fold, so test data influences the learning-rate schedule. The
    reported per-fold metrics are therefore mildly optimistic; consider
    carving a validation split out of the training subjects.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Create feature groups using the static method from the model class
    feature_groups = FeatureGroupTransformerModel.create_token_groups(all_participants_data)
    print("Feature groups created (total {}):".format(len(feature_groups)))
    for group_name in feature_groups:
        print(group_name)

    X, y_binary, y_ternary, y_continuous, groups = prepare_data_for_pytorch(all_participants_data)

    if task_type == 'binary':
        y = y_binary
    elif task_type == 'ternary':
        y = y_ternary
    else:
        # NOTE(review): any unrecognised task_type (including typos) silently
        # becomes regression here.
        y = y_continuous
        task_type = 'continuous'

    unique_groups = np.unique(groups)

    # One split per group, so each fold holds out a single participant.
    n_splits = len(unique_groups)
    if n_splits < 2:
        print(f"Warning: Only {n_splits} unique groups found. Cannot perform GroupKFold with n_splits < 2.")
        return None  # nothing to cross-validate
    cv_loso = GroupKFold(n_splits=n_splits)


    # Pooled across folds for the overall metric and plot.
    all_predictions = []
    all_targets = []
    fold_metrics = []

    # Name used in the results dict and in plot titles.
    group_strategy = "channel_eeg_eog_attn_pool" # Example name
    model_name_for_plotting = f"feature_group_transformer_{group_strategy}"

    # Use X (unscaled) for splitting, scale inside the loop
    for fold, (train_idx, test_idx) in enumerate(cv_loso.split(X, y, groups)):
        # All test rows of a fold share one group, so the first row names it.
        test_subject_identifier = groups[test_idx[0]]
        print(f"\nFold {fold+1}/{n_splits} - Testing on subject: {test_subject_identifier}")

        X_train, X_test = X[train_idx], X[test_idx]
        y_train, y_test = y[train_idx], y[test_idx]

        # Scale data within the fold: statistics come from the training rows
        # only, so the held-out subject does not leak into the scaler.
        scaler = StandardScaler()
        X_train = scaler.fit_transform(X_train)
        X_test = scaler.transform(X_test) # Use transform only on test

        train_dataset = create_torch_dataset(X_train, y_train, task_type)
        test_dataset = create_torch_dataset(X_test, y_test, task_type)

        # Pinned memory and worker processes are only enabled when CUDA is available.
        pin_memory_flag = torch.cuda.is_available()
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True,
            pin_memory=pin_memory_flag, num_workers=2 if pin_memory_flag else 0
        )
        test_loader = torch.utils.data.DataLoader(
            test_dataset, batch_size=batch_size, shuffle=False,
            pin_memory=pin_memory_flag, num_workers=2 if pin_memory_flag else 0
        )

        # The loss must match the model's output activation: sigmoid output
        # for binary (BCELoss) and regression (MSELoss), raw scores for
        # ternary (CrossEntropyLoss).
        if task_type == 'binary':
            output_size = 1
            criterion = nn.BCELoss()
        elif task_type == 'ternary':
            output_size = 3 # Ensure this matches the number of classes in y_ternary
            criterion = nn.CrossEntropyLoss()
        else: # continuous
            output_size = 1
            criterion = nn.MSELoss()

        # Fresh model/optimizer/scheduler per fold so folds are independent.
        model = FeatureGroupTransformerModel(feature_groups, output_size, task_type).to(device)
        optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0.001)
        # Regression monitors a loss (lower is better); classification
        # monitors accuracy (higher is better).
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min' if task_type == 'continuous' else 'max',
            factor=0.5, patience=3
        )

        for epoch in range(num_epochs):
            model.train()
            running_loss = 0.0
            for inputs, targets in train_loader:
                inputs = inputs.to(device, non_blocking=pin_memory_flag)
                targets = targets.to(device, non_blocking=pin_memory_flag)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
            train_loss = running_loss / len(train_loader)

            # Evaluate every epoch.
            # NOTE(review): this is the held-out test fold, and its metric
            # drives the scheduler below.
            val_loss, metric = evaluate_model(model, test_loader, criterion, device, task_type)
            if task_type == 'continuous':
                scheduler.step(val_loss)
            else:
                scheduler.step(metric)

            metric_name = "Accuracy" if task_type in ['binary', 'ternary'] else "RMSE"
            print(f"  Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, "
                  f"Val Loss: {val_loss:.4f}, Val {metric_name}: {metric:.4f}, "
                  f"LR: {optimizer.param_groups[0]['lr']:.6f}")

        # Final pass over the held-out fold to collect the raw predictions
        # (evaluate_model only returns aggregate numbers).
        model.eval()
        fold_preds = []
        fold_targets = []
        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs = inputs.to(device, non_blocking=pin_memory_flag)
                outputs = model(inputs)
                if task_type == 'binary':
                    preds = (outputs > 0.5).float().cpu().numpy()
                elif task_type == 'ternary':
                    preds = torch.argmax(outputs, dim=1).cpu().numpy()
                else: # continuous
                    preds = outputs.cpu().numpy()
                fold_preds.extend(preds.flatten())
                fold_targets.extend(targets.cpu().numpy().flatten())

        fold_preds_np = np.array(fold_preds)
        fold_targets_np = np.array(fold_targets)

        if task_type in ['binary', 'ternary']:
            fold_acc = accuracy_score(fold_targets_np, fold_preds_np)
            fold_metrics.append((test_subject_identifier, fold_acc))
            print(f"  Subject {test_subject_identifier} - Final Accuracy: {fold_acc:.4f}")
        else:
            fold_rmse = np.sqrt(mean_squared_error(fold_targets_np, fold_preds_np))
            fold_metrics.append((test_subject_identifier, fold_rmse))
            print(f"  Subject {test_subject_identifier} - Final RMSE: {fold_rmse:.4f}")

        all_predictions.extend(fold_preds_np)
        all_targets.extend(fold_targets_np)

        # Drop per-fold objects before the next fold allocates its own, and
        # release cached GPU memory when CUDA is in use.
        del model, optimizer, scheduler, train_dataset, test_dataset, train_loader, test_loader
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    all_predictions = np.array(all_predictions)
    all_targets = np.array(all_targets)

    results = {
        'model_type': model_name_for_plotting,
        'task_type': task_type,
        'fold_metrics': fold_metrics,
        'predictions': all_predictions,
        'targets': all_targets
    }

    # Overall metric over the pooled out-of-fold predictions.
    if task_type in ['binary', 'ternary']:
        overall_acc = accuracy_score(all_targets, all_predictions)
        cm = confusion_matrix(all_targets, all_predictions)
        print(f"\nOverall {task_type.capitalize()} Classification Accuracy: {overall_acc:.4f}")
        print("Overall Confusion Matrix:")
        print(cm)
        plot_confusion_matrix(all_targets, all_predictions, task_type, model_name_for_plotting)
        results['overall_accuracy'] = overall_acc
        results['confusion_matrix'] = cm
    else:
        overall_rmse = plot_regression_results(all_targets, all_predictions, model_name_for_plotting)
        results['overall_rmse'] = overall_rmse

    return results

# USE

In [None]:
# Binary run: targets are `perclos > 0.5`. Only 2 epochs per fold (the
# function default is 20), so this is a short run.
# Note that every run cell in this section rebinds `group_results`.
group_results = run_feature_group_model(
    all_participants_data, 
    task_type='binary',
    batch_size=64,
    learning_rate=0.0001,
    num_epochs=2
)

In [None]:
# Ternary run: targets come from the "quantized_perclos" column and the model
# is built with a 3-way output. Rebinds `group_results`.
group_results = run_feature_group_model(
    all_participants_data, 
    task_type='ternary',
    batch_size=64,
    learning_rate=0.0001,
    num_epochs=2
)

In [None]:
# Regression run on the raw "perclos" value. The task type is spelled
# 'continuous' to match the name the function uses internally; the previous
# misspelling only worked because unknown task types fall back to regression.
# Rebinds `group_results`.
group_results = run_feature_group_model(
    all_participants_data,
    task_type='continuous',
    batch_size=64,
    learning_rate=0.0001,
    num_epochs=2
)