In [7]:
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import train_test_split 
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import GridSearchCV
import pandas as pd 
import numpy as np 
import warnings 
warnings.filterwarnings('ignore', category=FutureWarning)

In [8]:
# Load the training CSV (per its file name, outliers were already removed)
# into a DataFrame; the path is relative to the notebook's working directory.
datan = pd.read_csv(
    filepath_or_buffer='../Machine learning models/siren_data_train_no_outliers.csv',
)

In [10]:
# --- Baseline LDA -----------------------------------------------------------
# Separate the predictors from the "heard" target column.
X = datan.drop(["heard"], axis=1)
y = datan["heard"]

# Hold out 20% of the rows for validation; the fixed seed keeps the split
# reproducible across kernel restarts.
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

# Fit an LDA classifier with default settings on the training split only.
model = LinearDiscriminantAnalysis()
model.fit(X_train, y_train)

# Predictions for the held-out validation rows (used for the confusion matrix).
y_pred = model.predict(X_val)

# 5-fold cross-validation over the full data set.
scores = cross_val_score(model, X, y, cv=5)

# Mean of the five fold scores.
print(f"Mean accuracy: {scores.mean()}")

# Confusion matrix of the baseline model on the validation split.
# A bare expression is only rendered when it is the LAST statement of a cell;
# this one is followed by more code in the same cell, so it must be shown
# explicitly. `display` is provided as a builtin by the IPython/Jupyter kernel.
display(pd.crosstab(y_val, y_pred, rownames=['Actual'], colnames=['Predicted'], margins=True))



# --- Hyper-parameter search for LDA -------------------------------------------
# Search over the two solvers that accept a shrinkage value and over eleven
# evenly spaced shrinkage values from 0.0 to 1.0 (inclusive).
# np.linspace is used instead of np.arange with a float step: with arange the
# number of elements depends on floating-point rounding of (stop - start) / step,
# whereas linspace fixes the count and hits both end points exactly.
param_grid = {
    'solver': ['lsqr', 'eigen'],
    'shrinkage': np.linspace(0.0, 1.0, 11),
}

# Exhaustive 5-fold grid search, run on the training split only so the
# validation split stays unseen during tuning.
grid_search = GridSearchCV(model, param_grid, cv=5)
grid_search.fit(X_train, y_train)

# Best parameter combination found by the search.
best_params = grid_search.best_params_
print("Best Parameters for LDA:", best_params)

# GridSearchCV (refit=True by default) has already refitted an estimator with
# the best parameters on the whole training split, so reuse it instead of
# constructing and fitting an identical model a second time.
best_lda = grid_search.best_estimator_

# Predictions of the tuned model on the held-out validation rows.
y_pred_lda = best_lda.predict(X_val)

# 5-fold cross-validation of the tuned configuration over the full data set.
# NOTE(review): X contains the rows used for tuning above, so this estimate is
# not fully independent of the parameter selection — compare with
# grid_search.best_score_ or the validation-split results when reporting.
scores_lda = cross_val_score(best_lda, X, y, cv=5)

# Mean of the five fold scores.
print(f"Mean accuracy with best parameters: {scores_lda.mean()}")

# Confusion matrix of the tuned model on the validation split (last expression
# of the cell, so Jupyter renders it as the cell output).
pd.crosstab(y_val, y_pred_lda, rownames=['Actual'], colnames=['Predicted'], margins=True)

Mean accuracy: 0.9196265613008523
Best Parameters for LDA: {'shrinkage': 0.0, 'solver': 'lsqr'}
Mean accuracy: 0.9196265613008523


Predicted,0,1,All
Actual,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
0,154,82,236
1,12,845,857
All,166,927,1093
