In [None]:
# --- Lab 6: Decision Tree Classifier ---

# Step 0: Install necessary packages
!pip install graphviz pydotplus --quiet

# Step 1: Download the dataset
!wget -O drug200.csv https://cf-courses-data.s3.us.cloud-object-storage.appdomain.cloud/IBMDeveloperSkillsNetwork-ML0101EN-SkillsNetwork/labs/Module%203/data/drug200.csv

# Step 2: Imports
import pandas as pd
import matplotlib.pyplot as plt
import pydotplus
from sklearn import preprocessing, metrics
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_graphviz
from matplotlib.image import imread
from io import StringIO

# Step 3: Read the downloaded CSV into a DataFrame and preview its first rows
data = pd.read_csv("drug200.csv")
print("✅ Dataset loaded successfully", data.head(), sep="\n")

# Step 4: Select input features and target
# Column order here fixes the positional indices used by the encoders below.
X = data[['Age', 'Sex', 'BP', 'Cholesterol', 'Na_to_K']].values
y = data['Drug'].values

# Step 5: Encode categorical columns
# LabelEncoder numbers labels in sorted order, not in any clinical order.
# Assuming the columns hold F/M, HIGH/LOW/NORMAL and HIGH/NORMAL
# (TODO confirm against the CSV), the resulting codes are:
#   Sex (column 1):         F -> 0, M -> 1
#   BP (column 2):          HIGH -> 0, LOW -> 1, NORMAL -> 2
#   Cholesterol (column 3): HIGH -> 0, NORMAL -> 1
# A separate encoder is kept per column so codes can be mapped back later
# through each encoder's classes_ / inverse_transform.
le_sex = preprocessing.LabelEncoder()
le_BP = preprocessing.LabelEncoder()
le_chol = preprocessing.LabelEncoder()
for column_index, encoder in ((1, le_sex), (2, le_BP), (3, le_chol)):
    # Overwrite the categorical column in place with its integer codes.
    X[:, column_index] = encoder.fit_transform(X[:, column_index])

# Step 6: Hold out 30% of the rows for testing; the seed keeps the split repeatable
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.3,
    random_state=3,
)

# Step 7: Fit an entropy-criterion decision tree capped at depth 4
# NOTE(review): the classifier itself gets no random_state, so results may
# not be fully repeatable between runs - confirm whether that matters here.
model = DecisionTreeClassifier(criterion="entropy", max_depth=4).fit(X_train, y_train)

# Step 8: Predict the held-out rows and report the fraction classified correctly
predictions = model.predict(X_test)
accuracy = metrics.accuracy_score(y_true=y_test, y_pred=predictions)
print("\n✅ Model Evaluation:")
print("Accuracy:", accuracy)

# Step 9: Visualize the decision tree
# Export the fitted tree as Graphviz DOT text into an in-memory buffer.
dot_data = StringIO()
export_graphviz(model,
                out_file=dot_data,
                # Must stay in the same order as the columns selected for X.
                feature_names=['Age', 'Sex', 'BP', 'Cholesterol', 'Na_to_K'],
                # Pass a plain list of str: model.classes_ is a NumPy array,
                # which some scikit-learn releases reject for class_names.
                class_names=[str(label) for label in model.classes_],
                filled=True)

# Render the DOT text to a PNG on disk. This step needs the Graphviz `dot`
# executable on the system, not only the pip-installed Python packages.
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
graph.write_png('drug_tree.png')

# Display tree image
img = imread('drug_tree.png')
plt.figure(figsize=(18,12))
plt.imshow(img)
plt.axis('off')  # pixel axes carry no meaning for a rendered diagram
plt.title("Decision Tree for Drug Classification")
plt.show()