# DATA LOAD

In [None]:
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

from sklearn.model_selection import GroupKFold
from sklearn.metrics import accuracy_score, mean_squared_error, confusion_matrix
from sklearn.preprocessing import StandardScaler

from utils.data_load import load_all_participants
from utils.cm_plot import plot_regression_results,plot_confusion_matrix

from models.cnn import MultiHeadConv1DModel
from models.mlp import MLPModel
from models.transformer import FeatureGroupTransformerModel
from models.additional import *

# Load the participants' data once, up front; every `run_model(...)` call in
# the cells below receives this same object as its first argument.
# NOTE(review): the structure of the returned object is defined by
# utils.data_load.load_all_participants and is not visible here.
all_participants_data = load_all_participants()




# MAIN

In [None]:
def run_model(all_participants_data, task_type='binary', model_type='transformer',
              batch_size=32, learning_rate=0.001, num_epochs=20):
    """Train and evaluate one model type with leave-one-subject-out CV.

    For every unique entry in ``groups`` a fresh model is trained on all other
    groups and evaluated on the held-out one. Per-fold and pooled metrics are
    printed, and a confusion matrix (classification) or regression plot is
    produced from the pooled predictions.

    Args:
        all_participants_data: Forwarded unchanged to
            ``prepare_data_for_pytorch`` and ``create_token_groups``
            (presumably provided by the ``models.additional`` star import --
            TODO confirm). Its structure is not visible from this function.
        task_type: ``'binary'``, ``'ternary'`` or ``'continuous'``. Any other
            value is treated as ``'continuous'`` (a warning is printed).
        model_type: ``'cnn'``, ``'transformer'`` or ``'mlp'``. Any other value
            is treated as ``'mlp'`` (a warning is printed).
        batch_size: Mini-batch size for both the train and test loaders.
        learning_rate: Initial Adam learning rate.
        num_epochs: Number of training epochs per fold.

    Returns:
        ``None`` when fewer than two unique groups are available; otherwise a
        dict with keys ``'model_type'``, ``'task_type'``, ``'fold_metrics'``
        (list of ``(subject, score)`` tuples), ``'predictions'``,
        ``'targets'`` and either ``'overall_accuracy'`` (classification) or
        ``'overall_rmse'`` (regression).
    """
    known_tasks = ('binary', 'ternary', 'continuous')
    known_models = ('cnn', 'transformer', 'mlp')

    # 1) Device setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    print(f"Model type: {model_type}, Task type: {task_type}")

    # Unknown names keep their historical fallbacks (regression / MLP) so
    # existing callers still work, but the fallback is no longer silent.
    if task_type not in known_tasks:
        print(f"Warning: unrecognized task_type {task_type!r}; falling back to 'continuous'.")
        task_type = 'continuous'
    if model_type not in known_models:
        print(f"Warning: unrecognized model_type {model_type!r}; falling back to 'mlp'.")
        model_type = 'mlp'

    is_classification = task_type in ('binary', 'ternary')

    # 2) Data preparation (the tabular arrays are the same for every model type)
    X, y_binary, y_ternary, y_continuous, groups, _ = prepare_data_for_pytorch(all_participants_data)
    model_name_for_plotting = model_type
    feature_groups = None
    if model_type in ('cnn', 'transformer'):
        feature_groups = create_token_groups(all_participants_data)
        if model_type == 'transformer':
            print("Feature groups created (total {}):".format(len(feature_groups)))
            for name in feature_groups:
                print(f"  {name}")
    else:
        print(f"Using MLP model with {X.shape[1]} input features")

    # 3) Select target
    y = {'binary': y_binary, 'ternary': y_ternary, 'continuous': y_continuous}[task_type]

    # 4) LOSO cross-validation setup: one fold per unique group
    unique_groups = np.unique(groups)
    n_splits = len(unique_groups)
    if n_splits < 2:
        print(f"Warning: Only {n_splits} unique groups; skipping CV.")
        return None
    cv_loso = GroupKFold(n_splits=n_splits)

    # Loss and output width depend only on the task, not on the fold.
    if task_type == 'binary':
        output_size = 1
        # BCELoss expects probabilities in [0, 1]
        # (assumes the model applies a sigmoid for 'binary' -- TODO confirm).
        criterion = nn.BCELoss()
    elif task_type == 'ternary':
        output_size = 3
        criterion = nn.CrossEntropyLoss()
    else:
        output_size = 1
        criterion = nn.MSELoss()

    metric_name = "Accuracy" if is_classification else "RMSE"
    pin_mem = torch.cuda.is_available()
    num_workers = 2 if pin_mem else 0

    all_predictions = []
    all_targets = []
    fold_metrics = []

    # 5) Loop over folds
    for fold, (train_idx, test_idx) in enumerate(cv_loso.split(X, y, groups), start=1):
        test_subject = groups[test_idx[0]]
        print(f"\nFold {fold}/{n_splits} — testing on subject: {test_subject}")

        X_train, X_test = X[train_idx], X[test_idx]
        y_train, y_test = y[train_idx], y[test_idx]

        # Fit the scaler on the training subjects only, then apply it to the
        # held-out subject, so no test statistics leak into the features.
        scaler = StandardScaler()
        X_train = scaler.fit_transform(X_train)
        X_test = scaler.transform(X_test)

        train_ds = create_torch_dataset(X_train, y_train, task_type)
        test_ds = create_torch_dataset(X_test, y_test, task_type)
        train_loader = torch.utils.data.DataLoader(train_ds, batch_size=batch_size,
                                                   shuffle=True, pin_memory=pin_mem,
                                                   num_workers=num_workers)
        test_loader = torch.utils.data.DataLoader(test_ds, batch_size=batch_size,
                                                  shuffle=False, pin_memory=pin_mem,
                                                  num_workers=num_workers)

        # A fresh model per fold so no weights carry over between subjects.
        if model_type == 'transformer':
            model = FeatureGroupTransformerModel(feature_groups, output_size, task_type)
        elif model_type == 'cnn':
            model = MultiHeadConv1DModel(feature_groups, output_size, task_type)
        else:
            model = MLPModel(X_train.shape[1], output_size, task_type)

        model.to(device)

        optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0.001)
        # Regression monitors a loss (lower is better); classification
        # monitors accuracy (higher is better).
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='max' if is_classification else 'min',
            factor=0.5,
            patience=3
        )

        # 6) Training loop
        for epoch in range(1, num_epochs + 1):
            model.train()
            running_loss = 0.0
            for inputs, targets in train_loader:
                inputs = inputs.to(device, non_blocking=pin_mem)
                targets = targets.to(device, non_blocking=pin_mem)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
            avg_loss = running_loss / len(train_loader)

            # NOTE(review): the "validation" metrics are computed on the
            # held-out subject (test_loader) and then drive the LR scheduler,
            # so the test fold influences training. Consider carving a
            # validation split out of the training subjects instead.
            val_loss, val_metric = evaluate_model(model, test_loader, criterion, device, task_type)
            print(f"  Epoch {epoch}/{num_epochs}, Loss: {avg_loss:.4f}, Val {metric_name}: {val_metric:.4f}")

            scheduler.step(val_metric if is_classification else val_loss)

        # 7) Final evaluation on the held-out subject
        model.eval()
        preds, tgts = [], []
        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs = inputs.to(device, non_blocking=pin_mem)
                outputs = model(inputs)
                if task_type == 'binary':
                    # Threshold the predicted probability at 0.5.
                    batch_preds = (outputs > 0.5).float().cpu().numpy()
                elif task_type == 'ternary':
                    batch_preds = torch.argmax(outputs, dim=1).cpu().numpy()
                else:
                    batch_preds = outputs.cpu().numpy()
                preds.extend(batch_preds.flatten())
                tgts.extend(targets.cpu().numpy().flatten())

        preds = np.array(preds)
        tgts = np.array(tgts)
        all_predictions.extend(preds)
        all_targets.extend(tgts)

        if is_classification:
            score = accuracy_score(tgts, preds)
            print(f"  Subject {test_subject} — Accuracy: {score:.4f}")
        else:
            score = np.sqrt(mean_squared_error(tgts, preds))
            print(f"  Subject {test_subject} — RMSE: {score:.4f}")
        fold_metrics.append((test_subject, score))

        # Drop per-fold objects before the next fold to release GPU memory.
        del model, optimizer, scheduler, train_loader, test_loader
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # 8) Pooled metrics over all held-out subjects
    all_preds = np.array(all_predictions)
    all_tgts = np.array(all_targets)

    if is_classification:
        overall = accuracy_score(all_tgts, all_preds)
        overall_key = 'overall_accuracy'
        print(f"\nOverall Accuracy: {overall:.4f}")
        print("Confusion Matrix:")
        print(confusion_matrix(all_tgts, all_preds))
        plot_confusion_matrix(all_tgts, all_preds, task_type, model_name_for_plotting)
    else:
        overall = np.sqrt(mean_squared_error(all_tgts, all_preds))
        overall_key = 'overall_rmse'
        print(f"\nOverall RMSE: {overall:.4f}")
        plot_regression_results(all_tgts, all_preds, model_name_for_plotting)

    return {
        'model_type': model_name_for_plotting,
        'task_type': task_type,
        'fold_metrics': fold_metrics,
        'predictions': all_preds,
        'targets': all_tgts,
        overall_key: overall,
    }


# USE

In [None]:
    
# Binary classification with the CNN: short 3-epoch run per held-out subject.
group_results = run_model(
    all_participants_data=all_participants_data,
    model_type='cnn',
    task_type='binary',
    num_epochs=3,
    batch_size=128,
    learning_rate=5e-5,
)

In [None]:
    
# Three-class classification with the feature-group transformer: short
# 3-epoch run per held-out subject.
group_results = run_model(
    all_participants_data=all_participants_data,
    model_type='transformer',
    task_type='ternary',
    num_epochs=3,
    batch_size=128,
    learning_rate=5e-5,
)

In [None]:
    
# Regression on the continuous target with the MLP: short 3-epoch run per
# held-out subject.
# The task name was previously misspelled ('continous') and only worked
# because run_model treats any task_type other than 'binary'/'ternary' as
# regression; it is now spelled correctly.
group_results = run_model(
    all_participants_data,
    task_type='continuous',
    model_type='mlp',
    batch_size=128,
    learning_rate=0.00005,
    num_epochs=3
)