# Linear Discriminant Analysis

In [6]:
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import train_test_split 
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import GridSearchCV
import pandas as pd 
import numpy as np 
import warnings 
warnings.filterwarnings('ignore', category=FutureWarning)

# Reading the data

In [7]:
data = pd.read_csv('siren_data_train_no_outliers.csv')
data.head()

Unnamed: 0,heard,building,noise,in_vehicle,asleep,no_windows,age,distance
0,1,0,0,0,0,0,59,901.283517
1,1,0,0,0,0,0,29,972.00626
2,1,0,0,0,0,0,32,872.340924
3,1,0,0,0,0,0,36,257.804449
4,1,0,0,0,0,0,55,529.686791


# Train the model

In [8]:
X = data.drop(["heard"], axis=1)
y = data["heard"]

X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

model = LinearDiscriminantAnalysis()
model.fit(X_train, y_train)

y_pred = model.predict(X_val)


score = cross_val_score(model, X, y, cv=5)
print(f"Mean accuracy: {score.mean()}")

pd.crosstab(y_val, y_pred, rownames=['True'], colnames=['Predicted'], margins=True)

Mean accuracy: 0.9196265613008523


Predicted,0,1,All
True,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
0,154,82,236
1,12,845,857
All,166,927,1093


# Tune the model

In [9]:
param_grid = {
    'solver': ['lsqr', 'eigen'],  
    'shrinkage': np.arange(0.0, 1.1, 0.1)  
}

grid_search = GridSearchCV(model, param_grid, cv=5)
grid_search.fit(X_train, y_train)

best_params = grid_search.best_params_
print("Best Parameters for LDA:", best_params)

best_lda = LinearDiscriminantAnalysis(**best_params)
best_lda.fit(X_train, y_train)

y_pred_lda = best_lda.predict(X_val)

scores_lda = cross_val_score(best_lda, X, y, cv=5)
print(f"Mean accuracy with tuning: {scores_lda.mean()}")

pd.crosstab(y_val, y_pred_lda, rownames=['True'], colnames=['Predicted'], margins=True)

Best Parameters for LDA: {'shrinkage': 0.0, 'solver': 'lsqr'}
Mean accuracy with tuning: 0.9196265613008523


Predicted,0,1,All
True,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
0,154,82,236
1,12,845,857
All,166,927,1093
