In [6]:
# Import models and utility functions
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import AdaBoostClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report

#inbult module
from data_processing import preprocess_data

# Setting SEED for reproducibility
SEED = 23
 
# Load the dataset
data = pd.read_csv('dementia_dataset.csv')

# Preprocess the data (excluding any non-feature columns like IDs or target variable)
X = preprocess_data(data.iloc[:, 3:])  # Assuming the first two columns are non-features
y = data.iloc[:, 2]  
 
# Splitting dataset
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.25, random_state = SEED)
 

adb = AdaBoostClassifier(
    DecisionTreeClassifier(max_depth=1), 
    n_estimators=100, 
    learning_rate=1.0, 
    random_state=42
)

# Train the model on the training data
adb.fit(X_train, y_train)

y_pred = adb.predict(X_test)

# Get unique classes from y for target_names, converted to string
target_names = [str(label) for label in np.unique(y)]

# Print the classification report
print(classification_report(y_test, y_pred, target_names=target_names, labels=np.unique(y)))
display(X)



              precision    recall  f1-score   support

   Converted       0.19      0.55      0.28        11
    Demented       0.95      0.95      0.95        38
 Nondemented       0.88      0.47      0.61        45

    accuracy                           0.67        94
   macro avg       0.67      0.65      0.61        94
weighted avg       0.82      0.67      0.71        94



Unnamed: 0,Visit,MR Delay,M/F,Age,EDUC,SES,MMSE,CDR,eTIV,nWBV,ASF
0,1,0,1,87,14,2.000000,27.0,0.0,1987,0.696,0.883
1,2,457,1,88,14,2.000000,30.0,0.0,2004,0.681,0.876
2,1,0,1,75,12,2.460452,23.0,0.5,1678,0.736,1.046
3,2,560,1,76,12,2.460452,28.0,0.5,1738,0.713,1.010
4,3,1895,1,80,12,2.460452,22.0,0.5,1698,0.701,1.034
...,...,...,...,...,...,...,...,...,...,...,...
368,2,842,1,82,16,1.000000,28.0,0.5,1693,0.694,1.037
369,3,2297,1,86,16,1.000000,26.0,0.5,1688,0.675,1.040
370,1,0,0,61,13,2.000000,30.0,0.0,1319,0.801,1.331
371,2,763,0,63,13,2.000000,30.0,0.0,1327,0.796,1.323
