In [7]:
# Import necessary libraries
import pickle
from pathlib import Path

import numpy as np
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

from preprocess import get_data

In [8]:
# Load the dataset through the project-local `preprocess.get_data` helper.
# NOTE(review): what loading/preprocessing `get_data` performs is not visible
# in this notebook; the cells below only rely on the result being indexable
# by the "values" and "labels" keys.
data = get_data()

# Features that are later split into train/test sets and scaled.
X = data["values"]
X  # bare trailing expression: rendered as this cell's output

In [10]:
# Function to rename labels
def rename_labels(labels: np.ndarray, positive_labels=(1, 2, 3)) -> np.ndarray:
    """Collapse multi-class labels into a binary target.

    Every entry whose value is in ``positive_labels`` becomes 1 ("seizure or
    risk"); every other entry (labels 4 and 5 with the default setting,
    "without risk") becomes 0.

    Args:
        labels: Array-like of raw class labels. Lists and tuples are accepted
            and converted with ``np.asarray``.
        positive_labels: Raw label values mapped to 1. Defaults to
            ``(1, 2, 3)``.

    Returns:
        A new array with the same shape and dtype as ``labels`` containing
        only 0 and 1. The input is not modified.
    """
    # Convert first: comparing a plain Python list with `==` yields a single
    # bool rather than an element-wise mask, which would silently leave every
    # output entry at 0.
    labels = np.asarray(labels)
    transformed_labels = np.zeros_like(labels)

    # Mark every entry belonging to one of the positive classes.
    transformed_labels[np.isin(labels, positive_labels)] = 1
    return transformed_labels

# Collapse the raw labels in data["labels"] into binary targets via
# rename_labels (raw labels 1-3 -> 1, everything else -> 0).
Y = rename_labels(data["labels"])
Y  # bare trailing expression: rendered as this cell's output

array([0, 0, 0, ..., 1, 1, 1], dtype=int32)

In [11]:
# Hold out 10% of the samples for testing; the fixed random_state keeps the
# split reproducible across runs.
# NOTE(review): the split is not stratified — consider whether stratify=Y is
# wanted given the class imbalance visible in the classification report.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    Y,
    test_size=0.1,
    random_state=42,
)

# Fit the scaler on the training split only, then apply that same fitted
# transformation to both splits so no test-set statistics leak into training.
scaler = StandardScaler().fit(X_train)
X_train_scaled = scaler.transform(X_train)
X_test_scaled = scaler.transform(X_test)

In [12]:
# Fit an RBF-kernel support vector classifier on the scaled training split.
# `fit` returns the estimator itself, so construction and training are chained.
svm_model = SVC(C=69, kernel='rbf').fit(X_train_scaled, y_train)

# Predict on the held-out split, then report overall accuracy followed by the
# per-class precision/recall/F1 table.
y_pred = svm_model.predict(X_test_scaled)
print("Accuracy:", accuracy_score(y_true=y_test, y_pred=y_pred))
print(
    "\nClassification Report:\n",
    classification_report(y_true=y_test, y_pred=y_pred),
)

Accuracy: 0.9780509218612818

Classification Report:
               precision    recall  f1-score   support

           0       0.98      0.99      0.99       915
           1       0.97      0.92      0.94       224

    accuracy                           0.98      1139
   macro avg       0.97      0.96      0.96      1139
weighted avg       0.98      0.98      0.98      1139



In [13]:
# Save the trained model and scaler
def save_model(model=None, fitted_scaler=None, output_dir="../trained"):
    """Pickle the trained model and its scaler into ``output_dir``.

    Two files are written, overwriting any existing ones:
    ``svm_model.pkl`` (the model) and ``scaler.pkl`` (the scaler).

    Args:
        model: Object to store in ``svm_model.pkl``. When ``None``, the
            notebook-level ``svm_model`` is used.
        fitted_scaler: Object to store in ``scaler.pkl``. When ``None``, the
            notebook-level ``scaler`` is used.
        output_dir: Directory that receives both files. It is created
            (including missing parents) if it does not exist yet. Defaults to
            ``"../trained"``.

    Raises:
        NameError: If an argument is omitted and the corresponding
            notebook-level variable has not been defined yet.
    """
    # Fall back to the notebook globals so the zero-argument call keeps working.
    if model is None:
        model = svm_model
    if fitted_scaler is None:
        fitted_scaler = scaler

    # open(..., 'wb') does not create missing directories, so make sure the
    # target exists instead of failing with FileNotFoundError on a fresh checkout.
    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    with open(target_dir / "svm_model.pkl", "wb") as file:
        pickle.dump(model, file)

    with open(target_dir / "scaler.pkl", "wb") as file:
        pickle.dump(fitted_scaler, file)

# Persist the fitted svm_model and scaler as pickle files under ../trained/
# (existing files are overwritten).
save_model()