<a href="https://colab.research.google.com/github/jburchfield76/datasharing/blob/master/MIT_all_Stats_Tree_SA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Python port of tree.r (MIT "All of Statistics" materials), converted with Grok.
# The .dat file was first converted to CSV, but that is unnecessary:
# pd.read_csv("sa.data") reads it directly (see below).
import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.model_selection import cross_val_score
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

# Load the South African heart-disease data into a modelling frame `d`
# (target `chd` first, then the 9 named features) plus the list `names`.
# The first row of the file is skipped and columns are read positionally
# (header=None); fields are comma-separated.
# Note: You'll need to adjust the file path to your actual file location
data = pd.read_csv("sa.data", skiprows=1, header=None)

# The reshape mirrors an R-style matrix(..., ncol = 11, byrow = TRUE). read_csv
# already returns a 2-D table, so the reshape is a no-op only when the file has
# exactly 11 columns; any other width would silently re-flow values across rows.
# Fail loudly instead of producing a scrambled matrix.
assert data.shape[1] == 11, f"expected 11 columns in sa.data, got {data.shape[1]}"
X = data.values.reshape(-1, 11)

# Extract chd (11th column); Python is 0-based, so column 11 is index 10
chd = X[:, 10]

# Drop the 1st column (not one of the named features -- presumably a row id,
# TODO confirm) and the 11th column (chd, the target) from the feature matrix
X = np.delete(X, [0, 10], axis=1)

# Feature names for the 9 remaining columns, in file order
names = ["sbp", "tobacco", "ldl", "adiposity", "famhist", "typea", "obesity", "alcohol", "age"]

feature_df = pd.DataFrame(X, columns=names)

# Convert famhist to categorical
feature_df['famhist'] = feature_df['famhist'].astype('category')

# Target variable as categorical
chd_series = pd.Series(chd, name='chd').astype('category')

# Complete dataset: target first, then the features
d = pd.concat([chd_series, feature_df], axis=1)

print("Dataset shape:", d.shape)
print("\nDataset info:")
# DataFrame.info() prints its summary itself and returns None; wrapping it in
# print() emitted a stray "None" line after the summary.
d.info()

# Fit an unrestricted decision tree (sklearn's counterpart to R's tree()).
# NOTE: criterion='gini' is NOT the misclassification rate. R's tree() grows
# splits on deviance by default, whose closest sklearn analogue is
# criterion='entropy'. Gini is kept here to preserve this notebook's results;
# switch to 'entropy' for a more faithful port.
tree_model = DecisionTreeClassifier(random_state=42, criterion='gini')
tree_model.fit(d.drop('chd', axis=1), d['chd'])

print("\nTree Model Summary:")
print(f"Number of nodes: {tree_model.tree_.node_count}")
print(f"Maximum depth: {tree_model.tree_.max_depth}")
# sklearn marks a node with no left child by -1, i.e. a leaf.
print(f"Number of leaves: {np.sum(tree_model.tree_.children_left == -1)}")

# Plot 1: full tree, written to PNG; the figure is then closed.
fig, ax = plt.subplots(figsize=(12, 8))
plot_tree(tree_model,
          feature_names=names,
          # assumes classes_ are ordered [0, 1] = [no CHD, CHD] -- TODO confirm coding
          class_names=['No CHD', 'CHD'],
          filled=True,
          ax=ax,
          fontsize=8)
ax.set_title('South Africa CHD Decision Tree - Full Tree')
fig.savefig('south.africa.tree.plot1.png', dpi=300, bbox_inches='tight')
plt.close(fig)

# Select the tree depth by 10-fold cross-validation.
# max_depth is swept as a stand-in for pruning, and each candidate depth is
# scored by its mean CV misclassification error (1 - accuracy).
depths = range(1, 15)
cv_X, cv_y = d.drop('chd', axis=1), d['chd']

cv_scores = np.empty(len(depths))
for i, depth in enumerate(depths):
    model = DecisionTreeClassifier(criterion='gini', max_depth=depth, random_state=42)
    scores = cross_val_score(model, cv_X, cv_y, scoring='accuracy', cv=10)
    cv_scores[i] = 1 - np.mean(scores)  # accuracy -> misclassification error

# Aliases used by the plotting code below
size, score = depths, cv_scores

# The optimal depth is the one with the lowest CV error
min_score_idx = np.argmin(score)
optimal_k = size[min_score_idx]

print(f"\nCross-validation results:")
print(f"Optimal tree depth: {optimal_k}")
print(f"Minimum misclassification error: {score[min_score_idx]:.3f}")

# Plot 2: CV misclassification error for each max_depth, with the minimum marked.
# The x-axis is max_depth, not the number of terminal nodes that R's cv.tree
# reports as "size", so it is labelled as depth.
fig, ax = plt.subplots(figsize=(10, 6))
# The colour is given only via `color`. The previous 'r-' format string also set
# a colour, which made matplotlib warn that it was redundantly defined.
ax.plot(size, score, '-', linewidth=2, color='#d62728')
ax.scatter(optimal_k, score[min_score_idx], color='red', s=100, zorder=5)
ax.set_xlabel('Tree Depth (max_depth)')
ax.set_ylabel('Cross-Validation Misclassification Error')
ax.set_title('Cross-Validation Error vs Tree Depth')
ax.grid(True, alpha=0.3)
fig.savefig('south.africa.tree.plot2.png', dpi=300, bbox_inches='tight')
plt.close(fig)

# "Prune" by refitting on all rows with max_depth capped at the CV-selected
# depth. This is a depth limit, not cost-complexity pruning.
pruned_model = DecisionTreeClassifier(
    criterion='gini', max_depth=optimal_k, random_state=42
).fit(d.drop('chd', axis=1), d['chd'])

# Leaves are the nodes whose left-child index is -1
leaf_count = np.sum(pruned_model.tree_.children_left == -1)

print(f"\nPruned Tree Summary:")
print(f"Number of nodes: {pruned_model.tree_.node_count}")
print(f"Maximum depth: {pruned_model.tree_.max_depth}")
print(f"Number of leaves: {leaf_count}")

# Plot 3: the pruned tree, titled with its depth, written to PNG and closed.
fig, ax = plt.subplots(figsize=(12, 8))
plot_tree(pruned_model, ax=ax, feature_names=names,
          class_names=['No CHD', 'CHD'], filled=True, fontsize=8)
ax.set_title(f'South Africa CHD Decision Tree - Pruned (Depth={optimal_k})')
fig.savefig('south.africa.tree.plot3.png', dpi=300, bbox_inches='tight')
plt.close(fig)

# Rank features by the pruned tree's impurity-based importances, highest first.
print("\nFeature Importances:")
feature_importance = (
    pd.DataFrame({'feature': names})
    .assign(importance=pruned_model.feature_importances_)
    .sort_values(by='importance', ascending=False)
)
print(feature_importance)

# Optional: hold-out evaluation of a tree with the CV-selected depth.
# Caveat: optimal_k was chosen by cross-validation on ALL rows, including the
# rows that end up in the test split below. These scores are therefore likely
# somewhat optimistic; for an unbiased estimate, select the depth on X_train only.
from sklearn.base import clone
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.model_selection import train_test_split

# Stratified 80/20 split so both classes keep their proportions in the test set
X_train, X_test, y_train, y_test = train_test_split(
    d.drop('chd', axis=1), d['chd'], test_size=0.2, random_state=42, stratify=d['chd']
)

# Fit a fresh, unfitted copy with identical hyper-parameters. Refitting
# pruned_model itself would silently replace the full-data fit whose summary,
# plot and feature importances are reported above.
holdout_model = clone(pruned_model)
holdout_model.fit(X_train, y_train)

y_pred = holdout_model.predict(X_test)

print("\nClassification Report:")
print(classification_report(y_test, y_pred, target_names=['No CHD', 'CHD']))

print("\nConfusion Matrix:")
# Rows are true classes, columns are predicted classes (sklearn convention)
print(confusion_matrix(y_test, y_pred))

Dataset shape: (462, 10)

Dataset info:
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 462 entries, 0 to 461
Data columns (total 10 columns):
 #   Column     Non-Null Count  Dtype   
---  ------     --------------  -----   
 0   chd        462 non-null    category
 1   sbp        462 non-null    float64 
 2   tobacco    462 non-null    float64 
 3   ldl        462 non-null    float64 
 4   adiposity  462 non-null    float64 
 5   famhist    462 non-null    category
 6   typea      462 non-null    float64 
 7   obesity    462 non-null    float64 
 8   alcohol    462 non-null    float64 
 9   age        462 non-null    float64 
dtypes: category(2), float64(8)
memory usage: 30.1 KB
None

Tree Model Summary:
Number of nodes: 193
Maximum depth: 14
Number of leaves: 97

Cross-validation results:
Optimal tree depth: 3
Minimum misclassification error: 0.307


  plt.plot(size, score, 'r-', linewidth=2, color='#d62728')



Pruned Tree Summary:
Number of nodes: 15
Maximum depth: 3
Number of leaves: 8

Feature Importances:
     feature  importance
8        age    0.542465
1    tobacco    0.147369
4    famhist    0.128857
5      typea    0.120395
2        ldl    0.060914
0        sbp    0.000000
3  adiposity    0.000000
6    obesity    0.000000
7    alcohol    0.000000

Classification Report:
              precision    recall  f1-score   support

      No CHD       0.89      0.66      0.75        61
         CHD       0.56      0.84      0.68        32

    accuracy                           0.72        93
   macro avg       0.73      0.75      0.71        93
weighted avg       0.78      0.72      0.73        93


Confusion Matrix:
[[40 21]
 [ 5 27]]
