# DATA LOAD

In [None]:
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

from sklearn.model_selection import GroupKFold
from sklearn.metrics import accuracy_score, mean_squared_error, confusion_matrix
from sklearn.preprocessing import StandardScaler

from utils.data_load import load_all_participants
from utils.cm_plot import plot_regression_results,plot_confusion_matrix

from models.cnn import MultiHeadConv1DModel
from models.mlp import MLPModel
from models.resnet import ResNet1DModel
from models.tcn import TCNModel
from models.transformer import FeatureGroupTransformerModel
from models.additional import *

# Load every participant's data once; the run_model(...) calls in the cells below all reuse this object.
all_participants_data = load_all_participants()




# MAIN: leave-one-subject-out training and evaluation

In [None]:
def _remap_feature_groups(original_feature_groups, selected_idx):
    """Re-express token groups in the column space left after feature selection.

    ``original_feature_groups`` maps a group name to ``(start, size)`` over the
    original feature columns; ``selected_idx`` holds the ascending original
    column indices kept by the selector. Groups with no surviving column are
    dropped. Each group is a contiguous column range and ``selected_idx`` is
    sorted, so a group's surviving columns stay contiguous after selection and
    ``(first_new_position, count)`` describes them exactly.
    """
    # Built once for all groups (previously rebuilt inside the per-group loop).
    pos_map = {int(col): pos for pos, col in enumerate(selected_idx)}
    remapped = {}
    for name, (start, size) in original_feature_groups.items():
        kept = np.intersect1d(np.arange(start, start + size), selected_idx)
        if kept.size == 0:
            continue
        new_pos = [pos_map[int(col)] for col in kept]
        remapped[name] = (min(new_pos), len(new_pos))
    return remapped


def _select_features(X_tr, X_te, y_tr, task_type, feature_selection, original_feature_groups):
    """Keep the top ``feature_selection`` fraction of columns, scored on the training fold only.

    Returns ``(X_tr, X_te, feature_groups)``. When ``feature_selection`` is not
    strictly between 0 and 1 the inputs are returned unchanged together with
    ``original_feature_groups``.
    """
    if not 0 < feature_selection < 1:
        return X_tr, X_te, original_feature_groups
    # Local import: these names were used but never imported in this notebook.
    from sklearn.feature_selection import SelectKBest, f_classif, f_regression
    k = max(1, int(X_tr.shape[1] * feature_selection))
    score_func = f_classif if task_type in ('binary', 'ternary') else f_regression
    selector = SelectKBest(score_func=score_func, k=k)
    selector.fit(X_tr, y_tr)
    selected_idx = np.where(selector.get_support())[0]
    return (selector.transform(X_tr),
            selector.transform(X_te),
            _remap_feature_groups(original_feature_groups, selected_idx))


def _loss_for_task(task_type):
    """Return ``(output_size, criterion)`` for the given (already normalised) task type."""
    if task_type == 'binary':
        return 1, nn.BCELoss()
    if task_type == 'ternary':
        return 3, nn.CrossEntropyLoss()
    return 1, nn.MSELoss()


def _build_model(model_type, feature_groups, n_features, output_size, task_type):
    """Instantiate the requested architecture; ``model_type`` is matched case-insensitively.

    Unrecognised names fall back to ``MLPModel``.
    """
    key = model_type.lower()
    if key == 'transformer':
        return FeatureGroupTransformerModel(feature_groups, output_size, task_type)
    if key == 'cnn':
        return MultiHeadConv1DModel(feature_groups, output_size, task_type)
    if key == 'resnet':
        return ResNet1DModel(n_features, output_size, task_type)
    if key == 'tcn':
        return TCNModel(n_features, output_size, task_type)
    return MLPModel(n_features, output_size, task_type)


def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss."""
    model.train()
    total_loss = 0.0
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        loss = criterion(model(xb), yb)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)


def _predict(model, loader, device, task_type):
    """Return flattened ``(predictions, targets)`` numpy arrays for every batch in ``loader``.

    Binary outputs are thresholded at 0.5, ternary outputs are arg-maxed over
    dim 1, and continuous outputs are returned as-is.
    """
    model.eval()
    preds, tgts = [], []
    with torch.no_grad():
        for xb, yb in loader:
            out = model(xb.to(device))
            if task_type == 'binary':
                p = (out > 0.5).float()
            elif task_type == 'ternary':
                p = torch.argmax(out, 1)
            else:
                p = out
            preds.extend(p.cpu().numpy().flatten())
            tgts.extend(yb.numpy().flatten())
    return np.array(preds), np.array(tgts)


def run_model(all_participants_data,
              task_type: str = 'binary',
              model_type: str = 'transformer',
              batch_size: int = 32,
              learning_rate: float = 0.001,
              num_epochs: int = 20,
              subject_calibration: float = 0.0,
              feature_selection: float = 0.0,
              seed: "int | None" = None):
    """Train and evaluate one model per held-out group (leave-one-subject-out CV).

    Uses GroupKFold with as many splits as there are unique groups, so each fold
    holds out one group. Within a fold, the function:

    1. Optionally moves a fraction of the held-out samples into training.
    2. Optionally applies univariate feature selection.
    3. Standardises features with training-fold statistics.
    4. Trains for ``num_epochs`` epochs (Adam, weight decay 1e-3).
    5. Scores the held-out samples.

    Args:
        all_participants_data: passed to ``prepare_data_for_pytorch`` and
            ``create_token_groups``.
        task_type: 'binary' or 'ternary' for classification. Any other value
            selects the continuous target and is reported as 'continuous'.
        model_type: 'transformer', 'cnn', 'resnet' or 'tcn', matched
            case-insensitively. Anything else falls back to ``MLPModel``.
        batch_size: mini-batch size for training and evaluation.
        learning_rate: Adam learning rate.
        num_epochs: number of training epochs per fold.
        subject_calibration: fraction of the held-out subject's samples moved
            into training. It is capped so at least one sample stays in the
            test set.
        feature_selection: fraction of columns kept by ``SelectKBest`` when
            strictly between 0 and 1; otherwise selection is disabled.
        seed: optional seed for torch and for calibration sampling. ``None``
            keeps the previous unseeded, global-RNG behaviour.

    Returns:
        ``None`` if fewer than two groups exist. Otherwise a dict with these
        keys: 'model_type', 'task_type', 'fold_metrics' (a list of
        ``(subject, score)``), 'predictions', 'targets', and either
        'overall_accuracy' or 'overall_rmse'.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device} | Model: {model_type} | Task: {task_type} | Subject calibration: {subject_calibration*100}% ")
    print(f"Batch size: {batch_size} | LR: {learning_rate} | Number of epochs: {num_epochs} | Feature selection: {feature_selection*100}%")

    if seed is not None:
        torch.manual_seed(seed)
        rng = np.random.default_rng(seed)
    else:
        rng = np.random  # legacy global RNG: identical behaviour when no seed is given

    X, y_binary, y_ternary, y_continuous, groups, _feature_columns = prepare_data_for_pytorch(all_participants_data)
    original_feature_groups = create_token_groups(all_participants_data)

    if task_type == 'binary':
        y = y_binary
    elif task_type == 'ternary':
        y = y_ternary
    else:
        y, task_type = y_continuous, 'continuous'

    unique_groups = np.unique(groups)
    if len(unique_groups) < 2:
        print(f"Only {len(unique_groups)} groups—skipping CV.")
        return None

    cv = GroupKFold(n_splits=len(unique_groups))
    metric_name = "Accuracy" if task_type != 'continuous' else "RMSE"
    pin = torch.cuda.is_available()
    num_workers = 3 if pin else 0
    all_preds, all_tgts, fold_metrics = [], [], []

    for fold, (train_idx, test_idx) in enumerate(cv.split(X, y, groups), start=1):
        subject = groups[test_idx[0]]
        print(f"\nFold {fold}: test subject {subject}")

        if subject_calibration > 0:
            # Cap so at least one held-out sample remains for testing.
            n_calib = min(int(len(test_idx) * subject_calibration), len(test_idx) - 1)
            if n_calib > 0:
                calib = rng.choice(test_idx, n_calib, replace=False)
                train_idx = np.concatenate([train_idx, calib])
                test_idx = np.setdiff1d(test_idx, calib)
                print(f"  + calibrated {n_calib} samples")

        X_tr, X_te = X[train_idx], X[test_idx]
        y_tr, y_te = y[train_idx], y[test_idx]
        X_tr, X_te, feature_groups = _select_features(
            X_tr, X_te, y_tr, task_type, feature_selection, original_feature_groups)

        # Fit scaling on the training fold only so the held-out subject does not leak into it.
        scaler = StandardScaler()
        X_tr = scaler.fit_transform(X_tr)
        X_te = scaler.transform(X_te)

        train_ds = create_torch_dataset(X_tr, y_tr, task_type)
        test_ds = create_torch_dataset(X_te, y_te, task_type)
        train_loader = torch.utils.data.DataLoader(
            train_ds, batch_size=batch_size, shuffle=True, pin_memory=pin, num_workers=num_workers)
        test_loader = torch.utils.data.DataLoader(
            test_ds, batch_size=batch_size, shuffle=False, pin_memory=pin, num_workers=num_workers)

        output_size, criterion = _loss_for_task(task_type)
        model = _build_model(model_type, feature_groups, X_tr.shape[1], output_size, task_type)
        model.to(device)
        optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-3)
        # NOTE(review): the plateau scheduler is stepped on the held-out subject's metric,
        # so test data influences the LR schedule. A validation split carved from the
        # training fold would avoid this leakage.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min' if task_type == 'continuous' else 'max', factor=0.5, patience=3)

        for epoch in range(1, num_epochs + 1):
            avg = _train_one_epoch(model, train_loader, criterion, optimizer, device)
            val_loss, val_metric = evaluate_model(model, test_loader, criterion, device, task_type)
            print(f" Epoch {epoch}: train_loss={avg:.4f}, val_{metric_name}={val_metric:.4f}")
            scheduler.step(val_loss if task_type == 'continuous' else val_metric)

        preds, tgts = _predict(model, test_loader, device, task_type)
        all_preds.extend(preds)
        all_tgts.extend(tgts)
        if task_type != 'continuous':
            score = accuracy_score(tgts, preds)
            print(f" Subject {subject} acc={score:.4f}")
        else:
            score = np.sqrt(mean_squared_error(tgts, preds))
            print(f" Subject {subject} rmse={score:.4f}")
        fold_metrics.append((subject, score))

        # Release per-fold GPU memory before the next fold.
        del model, optimizer, scheduler, train_loader, test_loader
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    all_preds, all_tgts = np.array(all_preds), np.array(all_tgts)
    if task_type != 'continuous':
        overall = accuracy_score(all_tgts, all_preds)
        print(f"\nOverall Accuracy: {overall:.4f}")
        print(confusion_matrix(all_tgts, all_preds))
        plot_confusion_matrix(all_tgts, all_preds, task_type, model_type)
        results_key = 'overall_accuracy'
    else:
        overall = np.sqrt(mean_squared_error(all_tgts, all_preds))
        print(f"\nOverall RMSE: {overall:.4f}")
        plot_regression_results(all_tgts, all_preds, model_type)
        results_key = 'overall_rmse'
    return {
        'model_type': model_type,
        'task_type': task_type,
        'fold_metrics': fold_metrics,
        'predictions': all_preds,
        'targets': all_tgts,
        results_key: overall
    }


# USAGE: example runs

In [None]:
    
# Binary task with the TCN model. Each run gets its own name so later cells do
# not overwrite it; `group_results` is kept as an alias for the most recent run.
binary_tcn_results = group_results = run_model(
    all_participants_data,
    task_type='binary',
    model_type='TCN',
    batch_size=64,
    learning_rate=0.0001,
    num_epochs=3,
    subject_calibration=0.15,
    feature_selection=0.00
)

In [None]:
    
# Ternary task with the feature-group transformer. Stored under its own name so
# it is not overwritten; `group_results` is kept as an alias for the most recent run.
ternary_transformer_results = group_results = run_model(
    all_participants_data,
    task_type='ternary',
    model_type='transformer',
    batch_size=64,
    learning_rate=0.0001,
    num_epochs=3,
    subject_calibration=0.15,
    feature_selection=0.00
)

In [None]:
    
# Continuous (regression) task with the multi-head CNN. Stored under its own name
# so it is not overwritten; `group_results` is kept as an alias for the most recent run.
continuous_cnn_results = group_results = run_model(
    all_participants_data,
    task_type='continuous',
    model_type='cnn',
    batch_size=64,
    learning_rate=0.0001,
    num_epochs=3,
    subject_calibration=0.15,
    feature_selection=0.00
)