
# Cancer Classification using Decision Trees - Mathematical Approach

## 1. Introduction to Decision Trees

### Mathematical Foundation

Decision trees choose their splits using the following quantities:

1) **Information Entropy**:
   $$H(S) = -\sum_{i=1}^{c} p_i \log_2(p_i)$$
   where $p_i$ is the proportion of samples in $S$ that belong to class $i$, and $c$ is the number of classes

2) **Information Gain**:
   $$IG(S, A) = H(S) - \sum_{v \in Values(A)} \frac{|S_v|}{|S|} H(S_v)$$
   where $S_v$ is the subset of S where attribute A has value v

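3) **Gini Impurity**:
   $$Gini(S) = 1 - \sum_{i=1}^{c} p_i^2$$
   with $p_i$ as above. Like entropy, it is 0 for a pure node and largest when the classes are evenly mixed.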
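To make the three formulas concrete, here is a minimal NumPy sketch that computes them from raw label arrays. The function names (`entropy`, `gini`, `information_gain`) and the toy labels and genotype codes are illustrative only; they are not part of the dataset or of scikit-learn.

```python
import numpy as np

def class_proportions(labels):
    """Return p_i: the fraction of samples in S that belong to each class present."""
    _, counts = np.unique(labels, return_counts=True)
    return counts / counts.sum()

def entropy(labels):
    """H(S) = -sum_i p_i log2(p_i); classes absent from S contribute nothing."""
    p = class_proportions(labels)
    return float(-np.sum(p * np.log2(p)))

def gini(labels):
    """Gini(S) = 1 - sum_i p_i^2."""
    p = class_proportions(labels)
    return float(1.0 - np.sum(p ** 2))

def information_gain(labels, attribute):
    """IG(S, A): entropy of S minus the size-weighted entropy of each subset S_v."""
    labels = np.asarray(labels)
    attribute = np.asarray(attribute)
    weighted = 0.0
    for v in np.unique(attribute):
        subset = labels[attribute == v]  # S_v: samples where attribute A equals v
        weighted += len(subset) / len(labels) * entropy(subset)
    return entropy(labels) - weighted

labels = np.array(["breast", "breast", "lung", "lung"])
informative_snp = np.array([1, 1, 2, 3])    # genotype codes that separate the classes perfectly
uninformative_snp = np.array([1, 2, 1, 2])  # genotype codes unrelated to the class

print(entropy(labels))                               # 1.0 (two classes, 50/50)
print(gini(labels))                                  # 0.5
print(information_gain(labels, informative_snp))     # 1.0
print(information_gain(labels, uninformative_snp))   # 0.0
```

A SNP whose genotypes split the samples into pure groups recovers the full parent entropy as gain, while one whose groups keep the original class mix gains nothing.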

### Why Decision Trees for SNP Analysis?
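- Capture non-linear effects and interactions between SNPs
- Handle multiple cancer types (multi-class targets) natively
- Provide a built-in feature-importance ranking of SNPs
- Produce human-readable if-then rules, especially when the tree is shallow
- Work directly with discrete genotype codes, without scaling or one-hot encoding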

In [None]:
# Import required libraries
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.metrics import confusion_matrix, accuracy_score, classification_report
import seaborn as sns
import matplotlib.pyplot as plt

# Impurity of a two-class node as a function of p, the proportion of class 1
p = np.linspace(0, 1, 100)
# The 1e-10 offset avoids evaluating log2(0) at p = 0 and p = 1
entropy = -p * np.log2(p + 1e-10) - (1-p) * np.log2(1-p + 1e-10)
gini = 1 - (p**2 + (1-p)**2)

plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.plot(p, entropy)
plt.title('Binary Entropy Function')
plt.xlabel('Proportion of class 1 (p)')
plt.ylabel('Entropy')
plt.grid(True)

plt.subplot(1, 2, 2)
plt.plot(p, gini)
plt.title('Gini Impurity Function')
plt.xlabel('Proportion of class 1 (p)')
plt.ylabel('Gini Impurity')
plt.grid(True)

plt.tight_layout()
plt.show()

## 2. Data Loading and Understanding

### SNP Data Structure
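Here each SNP genotype is encoded as an integer:
- 1: Homozygous reference (AA)
- 2: Heterozygous (AB)
- 3: Homozygous alternate (BB)

Many genetics tools use the equivalent 0/1/2 coding (the number of alternate alleles). Either way, the codes are ordered by alternate-allele dosage, which is what makes threshold splits such as $X_i \leq 1.5$ meaningful.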
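If raw data stores genotypes as allele strings rather than integers, a conversion to the coding above could look like the following sketch. The column names and the `GENOTYPE_CODE` mapping are hypothetical; real files may also contain unordered heterozygote labels (e.g. `BA`), which this mapping would turn into `NaN`.

```python
import pandas as pd

# Hypothetical raw genotype calls for two SNPs (column names are illustrative only)
raw = pd.DataFrame({
    "rs_example_1": ["AA", "AB", "BB", "AB"],
    "rs_example_2": ["BB", "AA", "AA", "AB"],
})

# Map each genotype string to the 1/2/3 code used in this notebook
GENOTYPE_CODE = {"AA": 1, "AB": 2, "BB": 3}
encoded = raw.apply(lambda col: col.map(GENOTYPE_CODE))

# Equivalent 0/1/2 "alternate-allele count" coding used by many genetics tools
allele_count = encoded - 1

print(encoded)
print(allele_count)
```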

In [None]:
# Load the data
data = pd.read_csv('common_cancers.csv')

print("Dataset Information:")
print(f"Number of samples: {len(data)}")
print(f"Number of SNPs: {len(data.columns)-1}")

# Display first few rows
print("\nFirst few rows of the data:")
display(data.head())

# Plot distribution of cancer types
plt.figure(figsize=(12, 6))
cancer_counts = data.iloc[:, 0].value_counts()
sns.barplot(x=cancer_counts.index, y=cancer_counts.values)
plt.title('Distribution of Cancer Types')
plt.xlabel('Cancer Type')
plt.ylabel('Count')
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

## 3. Data Preprocessing

### Split and Prepare Data
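Unlike distance- and gradient-based models (k-NN, SVMs, regularized logistic regression), decision trees do not require feature scaling. Every split is a threshold test on a single feature, so any order-preserving rescaling of that feature changes the threshold $t$ but not which samples go left or right: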

$$\text{Split Rule}: X_i \leq t$$

where:
- $X_i$ is feature i
- $t$ is the threshold value
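One quick way to check the scaling claim is to fit the same tree on a synthetic genotype matrix and on an affinely rescaled copy of it. In this sketch the data (`X_sim`, `y_sim`) is simulated for illustration, not taken from `common_cancers.csv`; the rescaling should only move the learned thresholds (e.g. 1.5 becomes 1505), leaving the partitions and predictions unchanged.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

rng = np.random.default_rng(0)
# Simulated genotype matrix: 200 samples x 10 SNPs coded 1/2/3 (illustrative only)
X_sim = rng.integers(1, 4, size=(200, 10))
# Simulated label that depends on two SNPs, so the tree has real structure to find
y_sim = np.where((X_sim[:, 0] == 3) | (X_sim[:, 4] == 1), "type_A", "type_B")

# A positive affine rescaling preserves the ordering of every feature
X_rescaled = X_sim * 1000 + 5

tree_raw = DecisionTreeClassifier(random_state=0).fit(X_sim, y_sim)
tree_scaled = DecisionTreeClassifier(random_state=0).fit(X_rescaled, y_sim)

same = np.array_equal(tree_raw.predict(X_sim), tree_scaled.predict(X_rescaled))
print("Identical predictions after rescaling:", same)  # expected: True
```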

In [None]:
# Split features and target
X = data.iloc[:, 1:]  # SNP features
y = data.iloc[:, 0]   # Cancer types

# Split into training (80%) and testing (20%) sets
X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                    test_size=0.2,
                                                    random_state=42,
                                                    stratify=y)

print("Training set shape:", X_train.shape)
print("Testing set shape:", X_test.shape)

# Count how many SNPs take 1, 2, or 3 distinct genotype values in the training set
print("\nNumber of SNPs by count of distinct genotype values:")
print(X_train.nunique().value_counts().sort_index())

## 4. Model Training

### Decision Tree Algorithm

The tree is built recursively using these steps:

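1. For each feature $f$ and threshold $t$, split $S$ into $S_{left} = \{x \in S : x_f \leq t\}$ and $S_{right} = S \setminus S_{left}$, then calculate the information gain:
   $$IG(S, f, t) = H(S) - \frac{|S_{left}|}{|S|}H(S_{left}) - \frac{|S_{right}|}{|S|}H(S_{right})$$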

2. Choose the split that maximizes information gain:
   $$(f^*, t^*) = \arg\max_{f,t} IG(S, f, t)$$

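3. Repeat recursively for each child node until a stopping criterion is met (e.g. the node is pure, `max_depth` is reached, or the node has fewer than `min_samples_split` samples)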
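The split search in steps 1 and 2 can be sketched for a single node in a few lines of NumPy. This is a simplified illustration, assuming candidate thresholds are the midpoints between consecutive distinct values of each feature; scikit-learn's `DecisionTreeClassifier` uses a more optimized implementation of the same idea. The toy genotype matrix is made up.

```python
import numpy as np

def entropy(y):
    """Shannon entropy (base 2) of a label array."""
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return -np.sum(p * np.log2(p))

def best_split(X, y):
    """Exhaustively search (feature, threshold) pairs for the largest information gain."""
    parent_entropy = entropy(y)
    n = len(y)
    best_feature, best_threshold, best_gain = None, None, 0.0
    for f in range(X.shape[1]):
        values = np.unique(X[:, f])
        # Candidate thresholds: midpoints between consecutive distinct values
        for t in (values[:-1] + values[1:]) / 2:
            left = X[:, f] <= t          # S_left = {x : x_f <= t}
            n_left = left.sum()
            gain = (parent_entropy
                    - n_left / n * entropy(y[left])
                    - (n - n_left) / n * entropy(y[~left]))
            if gain > best_gain:
                best_feature, best_threshold, best_gain = f, float(t), float(gain)
    return best_feature, best_threshold, best_gain

# Toy genotype matrix: SNP 0 separates the classes, SNP 1 does not
X_toy = np.array([[1, 1],
                  [1, 3],
                  [2, 1],
                  [3, 3]])
y_toy = np.array(["A", "A", "B", "B"])
print(best_split(X_toy, y_toy))  # feature 0 at threshold 1.5, gain 1.0
```

A full tree builder would call `best_split` on the root, partition the samples, and recurse on each child until one of the stopping criteria in step 3 holds.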

In [None]:
# Create and train the model
model = DecisionTreeClassifier(
    criterion='entropy',     # Use information gain
    max_depth=5,             # Limit tree depth to prevent overfitting
    min_samples_split=20,    # Minimum samples required to split a node
    min_samples_leaf=10,     # Minimum samples in each leaf node
    random_state=42          # Fix tie-breaking between equally good splits for reproducibility
)

print("Training the model...")
model.fit(X_train, y_train)

# Visualize the tree structure
plt.figure(figsize=(20, 10))
plot_tree(model,
          feature_names=list(X.columns),
          class_names=[str(c) for c in model.classes_],  # plot_tree expects string class names
          filled=True,
          rounded=True)
plt.title('Decision Tree Structure')
plt.show()

## 5. Model Evaluation

### Key Metrics

1. **Accuracy**:
   $$\text{Accuracy} = \frac{\text{Correct Predictions}}{\text{Total Predictions}}$$

2. **Class-wise Precision**:
   $$\text{Precision}_i = \frac{\text{True Positives}_i}{\text{True Positives}_i + \text{False Positives}_i}$$

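3. **Class-wise Recall**:
   $$\text{Recall}_i = \frac{\text{True Positives}_i}{\text{True Positives}_i + \text{False Negatives}_i}$$

4. **Class-wise F1-score** (also shown in the classification report):
   $$F1_i = 2 \cdot \frac{\text{Precision}_i \cdot \text{Recall}_i}{\text{Precision}_i + \text{Recall}_i}$$

For class $i$, a true positive is a sample of class $i$ predicted as $i$, a false positive is a sample of another class predicted as $i$, and a false negative is a sample of class $i$ predicted as some other class.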
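The per-class formulas can be read straight off a confusion matrix: precision divides the diagonal by the column totals (everything predicted as class $i$), and recall divides it by the row totals (everything that truly is class $i$). The sketch below checks this on made-up labels against scikit-learn's `precision_score` and `recall_score` with `average=None`.

```python
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix, precision_score, recall_score

# Made-up true and predicted labels for three classes
y_true = np.array(["a", "a", "b", "b", "c", "c"])
y_hat = np.array(["a", "b", "b", "b", "c", "a"])
labels = np.unique(y_true)

cm = confusion_matrix(y_true, y_hat, labels=labels)  # rows = actual, columns = predicted
tp = np.diag(cm)                  # correctly classified samples per class
precision = tp / cm.sum(axis=0)   # TP_i / (TP_i + FP_i): column totals = all predicted as i
recall = tp / cm.sum(axis=1)      # TP_i / (TP_i + FN_i): row totals = all truly in class i
accuracy = tp.sum() / cm.sum()    # correct predictions / total predictions

print("precision:", precision)
print("recall:   ", recall)
print("accuracy: ", accuracy)
```

If a class is never predicted, its column total is 0 and this manual division produces `nan` with a runtime warning, whereas `classification_report` reports 0 for that class by default.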

In [None]:
# Make predictions
y_pred = model.predict(X_test)
y_pred_prob = model.predict_proba(X_test)  # Class probabilities; columns follow model.classes_ (not used below)

# Calculate metrics
accuracy = accuracy_score(y_test, y_pred)
print(f"Model Accuracy: {accuracy:.2%}")

print("\nDetailed Classification Report:")
print(classification_report(y_test, y_pred))

# Create confusion matrix (rows/columns ordered as model.classes_, matching the tick labels below)
conf_matrix = confusion_matrix(y_test, y_pred, labels=model.classes_)

# Plot confusion matrix
plt.figure(figsize=(12, 10))
sns.heatmap(conf_matrix,
            annot=True,
            fmt='d',
            cmap='Blues',
            xticklabels=model.classes_,
            yticklabels=model.classes_)
plt.title('Confusion Matrix')
plt.xlabel('Predicted Cancer Type')
plt.ylabel('Actual Cancer Type')
plt.tight_layout()
plt.show()

## 6. Feature Importance Analysis

### Mathematical Interpretation

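Feature importance in decision trees (as computed by scikit-learn) is the total weighted impurity decrease contributed by each feature:

$$\text{Importance}(f) = \sum_{n \in \text{nodes splitting on } f} w_n \Delta I_n$$

where:
- $w_n = N_n / N$ is the fraction of the (weighted) training samples that reach node $n$
- $\Delta I_n = I(n) - \frac{N_{left}}{N_n} I(left) - \frac{N_{right}}{N_n} I(right)$ is the impurity decrease at node $n$, with $I$ the split criterion (entropy here)

The importances are then normalized to sum to 1, so each value is that SNP's share of the tree's total impurity reduction.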
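To see that `feature_importances_` follows this formula, the sketch below recomputes it from the fitted tree's node arrays (`tree_.children_left`, `tree_.children_right`, `tree_.feature`, `tree_.impurity`, `tree_.weighted_n_node_samples`) and compares the result with scikit-learn's value. The synthetic data from `make_classification` and the helper name `manual_importance` are illustrative, not part of the cancer analysis.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

# Synthetic 3-class data standing in for the SNP matrix
X_demo, y_demo = make_classification(n_samples=300, n_features=6, n_informative=3,
                                     n_classes=3, random_state=0)
clf = DecisionTreeClassifier(criterion="entropy", max_depth=4, random_state=0).fit(X_demo, y_demo)

def manual_importance(fitted):
    """Sum w_n * Delta I_n over the split nodes of each feature, then normalize to 1."""
    t = fitted.tree_
    importance = np.zeros(fitted.n_features_in_)
    n_total = t.weighted_n_node_samples[0]  # N: samples at the root
    for node in range(t.node_count):
        left, right = t.children_left[node], t.children_right[node]
        if left == -1:  # leaf node: no split, so no impurity decrease
            continue
        n_node = t.weighted_n_node_samples[node]
        # N_n * I(n) - N_left * I(left) - N_right * I(right) equals N_n * Delta I_n
        decrease = (n_node * t.impurity[node]
                    - t.weighted_n_node_samples[left] * t.impurity[left]
                    - t.weighted_n_node_samples[right] * t.impurity[right])
        importance[t.feature[node]] += decrease / n_total  # (N_n / N) * Delta I_n
    return importance / importance.sum()

manual = manual_importance(clf)
print(np.round(manual, 4))
print(np.round(clf.feature_importances_, 4))
```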

In [None]:
# Calculate feature importance
feature_importance = pd.DataFrame({
    'SNP': X.columns,
    'Importance': model.feature_importances_
})

# Sort by importance
feature_importance = feature_importance.sort_values('Importance', ascending=False)

# Plot top 20 features (max_depth=5 allows at most 31 splits, so most SNPs will have zero importance)
plt.figure(figsize=(12, 6))
sns.barplot(data=feature_importance.head(20), x='Importance', y='SNP')
plt.title('Top 20 Most Important SNP Markers')
plt.xlabel('Feature Importance')
plt.ylabel('SNP Marker')
plt.tight_layout()
plt.show()

# Print top 10 SNPs
print("\nTop 10 Most Important SNP Markers:")
print(feature_importance.head(10))