In [None]:
import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.model_selection import GroupKFold, KFold, cross_val_score
from sklearn.multioutput import MultiOutputRegressor
from sklearn.svm import SVR

from preprocessing import preprocess_dataset

# Locations of the guided-task input (X) and target (Y) arrays.
# NOTE: these are absolute, machine-specific paths; adjust them when running elsewhere.
X_path = r'C:\Users\gianm\Documents\Uni\Big Data\F422\project\data\guided\guided_dataset_X.npy'
Y_path = r'C:\Users\gianm\Documents\Uni\Big Data\F422\project\data\guided\guided_dataset_Y.npy'

# preprocess_dataset returns a training and a validation split; only the
# training arrays are consumed in this section.
(
    X_train_features,
    Y_train,
    X_val_features,
    Y_val,
) = preprocess_dataset(X_path, Y_path, mode='features')

# Collapse the two leading axes into a single sample axis.
# Expected layouts (per the project's own note, not verified here):
#   X: (sessions, windows, electrodes, features) -> (samples, electrodes * features)
#   Y: (sessions, windows, joints)               -> (samples, joints)
flat_feature_dim = X_train_features.shape[2] * X_train_features.shape[3]
X_train = X_train_features.reshape((-1, flat_feature_dim))
Y_train = Y_train.reshape((-1, Y_train.shape[2]))

print(
    f"X_train shape: {X_train.shape}",
    f"Y_train shape: {Y_train.shape}",
    sep="\n",
)

# Set up cross-validation strategy
# X_train / Y_train were built by flattening (sessions, windows, ...) into
# samples. A shuffled KFold over those samples puts windows from the same
# recording session on both sides of every split; such windows are presumably
# strongly correlated (neighbouring segments of one recording), which makes
# the cross-validated RMSE optimistic. Hold out whole sessions instead.
n_sessions, n_windows = X_train_features.shape[:2]
# The C-order reshape keeps each session's windows contiguous, so sample i
# belongs to session i // n_windows.
session_groups = np.repeat(np.arange(n_sessions), n_windows)
# One fold per session (leave-one-session-out). GroupKFold raises ValueError
# if fewer than two sessions are available.
kf = GroupKFold(n_splits=n_sessions)

# --- Ridge Regression ---
print("\nTesting Ridge Regression")

# Model and parameters: one independent Ridge model is fitted per output column.
ridge = MultiOutputRegressor(Ridge(alpha=1.0))

# Cross-validation RMSE. The scorer returns negated RMSE (higher is better),
# hence the sign flip on the mean when reporting.
ridge_scores = cross_val_score(
    ridge,
    X_train,
    Y_train,
    scoring='neg_root_mean_squared_error',
    cv=kf,
    groups=session_groups,
)
print(f"Ridge Regression Cross-Validated RMSE: {-ridge_scores.mean():.4f} (+/- {ridge_scores.std():.4f})")

# --- Support Vector Regression (SVR) ---
print("\nTesting Support Vector Regression (SVR)")

# Model and parameters: SVR is single-output, so it is wrapped to fit one
# regressor per output column.
svr = MultiOutputRegressor(SVR(C=1.0, epsilon=0.2, kernel='rbf'))

# Cross-validation RMSE on the same session-wise folds as Ridge, so the two
# models are directly comparable.
svr_scores = cross_val_score(
    svr,
    X_train,
    Y_train,
    scoring='neg_root_mean_squared_error',
    cv=kf,
    groups=session_groups,
)
print(f"SVR Cross-Validated RMSE: {-svr_scores.mean():.4f} (+/- {svr_scores.std():.4f})")

# --- Hyperparameters explanation ---
# Emit the whole legend with a single print call; each entry lands on its own
# line, and the embedded leading "\n" characters produce the blank separator
# lines between sections.
print(
    "\n\nHyperparameter explanation:",
    "\nRidge Regression parameters:",
    "  - alpha: regularization strength (higher = more regularization, prevents overfitting)",
    "\nSVR parameters:",
    "  - C: regularization parameter (higher = less regularization)",
    "  - epsilon: no penalty if prediction is within epsilon of true value",
    "  - kernel: defines transformation (e.g., 'rbf' = radial basis function)",
    sep="\n",
)