# Cancer Classification with TabNet

This notebook implements a TabNet deep learning model for cancer type classification. TabNet is a deep learning architecture designed specifically for tabular data, using sequential attention to choose which features to reason from at each decision step.
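To make "sequential attention" concrete, here is a minimal NumPy sketch of sparsemax, the normalization the original TabNet design uses to turn per-feature scores into a sparse selection mask. pytorch_tabnet also offers entmax, which this notebook selects via `mask_type`. This is an illustration, not pytorch_tabnet's implementation. The scores are made-up numbers standing in for the learned attentive transformer's output. The relaxation factor `relax_gamma = 1.3` and the prior update `prior *= (relax_gamma - mask)` are a simplified reading of the TabNet paper.

```python
import numpy as np

def sparsemax(z):
    """Euclidean projection of a 1-D score vector onto the probability simplex."""
    z = np.asarray(z, dtype=float)
    z_sorted = np.sort(z)[::-1]                 # scores in descending order
    cumsum = np.cumsum(z_sorted)
    k = np.arange(1, z.size + 1)
    support = 1 + k * z_sorted > cumsum         # True for entries that stay non-zero
    k_max = k[support][-1]                      # size of the support
    tau = (cumsum[k_max - 1] - 1) / k_max       # threshold subtracted from every score
    return np.maximum(z - tau, 0.0)

# Toy attentive steps over 5 features. The scores are made up and, for simplicity,
# reused at every step (in TabNet they are recomputed from the previous step's output).
relax_gamma = 1.3                               # relaxation factor (illustrative value)
prior = np.ones(5)                              # every feature starts equally available
demo_scores = np.array([2.0, 1.8, 0.3, -0.5, 0.1])
for step in range(2):
    mask = sparsemax(prior * demo_scores)
    print(f"step {step}: mask = {np.round(mask, 3)}")
    prior = prior * (relax_gamma - mask)        # heavily used features are down-weighted next step
```

Unlike softmax, sparsemax assigns exactly zero weight to low-scoring features, so each decision step reasons from only a few inputs.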

## Overview
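1. Data loading and preprocessing
2. Handling class imbalance with SMOTE
3. TabNet model training
4. Model evaluation (accuracy, classification report, top-k and per-cancer-type accuracy)
5. Confusion matrix visualization
6. Analysis of misclassifications
7. Conclusion and next steps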

In [None]:
# Import required libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import LabelEncoder, StandardScaler
from imblearn.over_sampling import SMOTE
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from pytorch_tabnet.tab_model import TabNetClassifier
import torch

## 1. Data Loading and Preprocessing

We start by loading the cancer classification dataset and preparing it for model training.

In [None]:
# Load dataset
print("Loading dataset...")
file_path = "cancer_classification_dataset.csv"
df = pd.read_csv(file_path)

# Display the first few rows to understand the data structure
print("Dataset shape:", df.shape)
df.head()

In [None]:
# Check target variable distribution
plt.figure(figsize=(12, 6))
cancer_counts = df['cancer_type'].value_counts()
sns.barplot(x=cancer_counts.index, y=cancer_counts.values)
plt.title('Cancer Type Distribution')
plt.xlabel('Cancer Type')
plt.ylabel('Count')
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.show()

In [None]:
# Separate features and target
X = df.drop(columns=['cancer_type'])
y = df['cancer_type']

# Encode categorical target variable
label_encoder = LabelEncoder()
y_encoded = label_encoder.fit_transform(y)

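# Train-test split first, so the test set plays no part in fitting the scaler.
# stratify keeps each cancer type's proportion the same in both splits.
X_train, X_test, y_train, y_test = train_test_split(
    X, y_encoded, test_size=0.2, stratify=y_encoded, random_state=42
)

# Feature scaling: learn mean/std from the training set only, then apply to both sets
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

print(f"Training set shape: {X_train.shape}")
print(f"Testing set shape: {X_test.shape}")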
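`StandardScaler.fit` estimates each feature's mean and standard deviation. To keep the test set unseen, those statistics should come from the training rows only and then be reused, unchanged, on the test rows. The sketch below reproduces what `fit` and `transform` compute using plain NumPy on synthetic data. The array, its shape, and the 80/20 split are illustrative assumptions; the `demo_` names avoid overwriting the notebook's variables.

```python
import numpy as np

rng = np.random.default_rng(0)
demo_X = rng.normal(loc=5.0, scale=2.0, size=(100, 3))  # synthetic stand-in for a feature matrix

# 1. Split first: shuffle row indices and hold out 20% of rows as the test set
perm = rng.permutation(len(demo_X))
demo_test, demo_train = demo_X[perm[:20]], demo_X[perm[20:]]

# 2. "fit": per-feature statistics from the training rows only
#    (population std, ddof=0, which is what StandardScaler uses)
mu = demo_train.mean(axis=0)
sigma = demo_train.std(axis=0)

# 3. "transform": apply the same training statistics to both splits
demo_train_scaled = (demo_train - mu) / sigma
demo_test_scaled = (demo_test - mu) / sigma

print("train means:", demo_train_scaled.mean(axis=0).round(3))  # zero by construction
print("test means: ", demo_test_scaled.mean(axis=0).round(3))   # generally not exactly zero
```

The test means deviating slightly from zero is expected: the test set is treated like new data the model has never seen.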

## 2. Handling Class Imbalance with SMOTE

We'll use the Synthetic Minority Over-sampling Technique (SMOTE) to address class imbalance in our training data. SMOTE is applied only to the training set, so the test set contains only real samples with their original class proportions.
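The core of SMOTE is interpolation: each synthetic sample lies on the line segment between a minority-class sample and one of its nearest minority-class neighbours. The NumPy sketch below shows only that step. It is a simplification, not imbalanced-learn's implementation: it uses brute-force distances, and it does not decide how many samples each class needs (imblearn's `SMOTE` handles that through `sampling_strategy`, and its `k_neighbors` defaults to 5). The function name `smote_sketch` and the toy points are hypothetical.

```python
import numpy as np

def smote_sketch(X_min, n_new, k=5, seed=0):
    """Generate n_new synthetic rows by interpolating between minority samples and
    their k nearest minority-class neighbours (simplified SMOTE)."""
    rng = np.random.default_rng(seed)
    X_min = np.asarray(X_min, dtype=float)
    # Pairwise Euclidean distances within the minority class
    d = np.linalg.norm(X_min[:, None, :] - X_min[None, :, :], axis=-1)
    np.fill_diagonal(d, np.inf)                               # a sample is not its own neighbour
    neighbours = np.argsort(d, axis=1)[:, :k]                 # k nearest neighbours per sample
    base = rng.integers(0, len(X_min), size=n_new)            # random anchor samples
    nn = neighbours[base, rng.integers(0, k, size=n_new)]     # one random neighbour per anchor
    lam = rng.random((n_new, 1))                              # interpolation weight in [0, 1)
    return X_min[base] + lam * (X_min[nn] - X_min[base])

# Four minority samples at the corners of the unit square
minority = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
synthetic = smote_sketch(minority, n_new=6, k=2)
print(synthetic.round(2))
```

Because every synthetic row is a convex combination of two real rows, all of the new points fall inside the unit square spanned by the originals.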

In [None]:
# Check class distribution before SMOTE
print("Class distribution before SMOTE:")
for label, count in zip(*np.unique(y_train, return_counts=True)):
    print(f"Class {label_encoder.inverse_transform([label])[0]}: {count} samples")

In [None]:
# Apply SMOTE only on the training set
smote = SMOTE(random_state=42)
X_train_resampled, y_train_resampled = smote.fit_resample(X_train, y_train)

# Check class distribution after SMOTE
print("Class distribution after SMOTE:")
for label, count in zip(*np.unique(y_train_resampled, return_counts=True)):
    print(f"Class {label_encoder.inverse_transform([label])[0]}: {count} samples")

# Convert to NumPy arrays (pytorch_tabnet expects arrays rather than pandas objects)
X_train_resampled, X_test = np.array(X_train_resampled), np.array(X_test)
y_train_resampled, y_test = np.array(y_train_resampled), np.array(y_test)

## 3. TabNet Model Training

Now we'll initialize and train the TabNet classifier on our balanced training data.

In [None]:
# Initialize TabNet model
tabnet_clf = TabNetClassifier(
    optimizer_fn=torch.optim.Adam,
    optimizer_params=dict(lr=2e-2),
    scheduler_params={"step_size": 10, "gamma": 0.9},  # multiply lr by 0.9 every 10 scheduler steps
    scheduler_fn=torch.optim.lr_scheduler.StepLR,
    mask_type='entmax'  # feature-selection mask: 'sparsemax' or 'entmax'
)
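`StepLR` multiplies the learning rate by `gamma` every `step_size` scheduler steps. The short calculation below evaluates the closed form for the settings above (`lr=2e-2`, `step_size=10`, `gamma=0.9`). It assumes pytorch_tabnet steps the scheduler once per epoch, which is worth checking for your installed version. Under that assumption, even at `max_epochs=100` the rate only decays to 0.9**9, about 39% of its starting value.

```python
# Closed form of torch.optim.lr_scheduler.StepLR:
#   lr(epoch) = base_lr * gamma ** (epoch // step_size)
# assuming one scheduler step per epoch.
base_lr, lr_step_size, lr_gamma = 2e-2, 10, 0.9

def step_lr(epoch):
    """Learning rate in effect during `epoch` (0-based) under StepLR."""
    return base_lr * lr_gamma ** (epoch // lr_step_size)

for epoch in (0, 9, 10, 50, 99):
    print(f"epoch {epoch:>3}: lr = {step_lr(epoch):.5f}")
```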

In [None]:
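# Hold out a validation split from the original (pre-SMOTE) training data.
# Early stopping monitors the last eval_set entry, so validating on the test set
# would let the test data choose the final weights and inflate test metrics.
X_fit, X_val, y_fit, y_val = train_test_split(
    X_train, y_train, test_size=0.1, stratify=y_train, random_state=42
)
# Oversample only the portion used for gradient updates; X_val stays real data
X_fit_res, y_fit_res = SMOTE(random_state=42).fit_resample(X_fit, y_fit)

# Train model
print("Training TabNet model...")
tabnet_clf.fit(
    X_fit_res, y_fit_res,
    eval_set=[(X_fit_res, y_fit_res), (X_val, y_val)],
    eval_name=['train', 'valid'],
    eval_metric=['accuracy'],
    max_epochs=100,
    patience=10,  # stop after 10 epochs without improvement on 'valid'
    batch_size=64,
    virtual_batch_size=64,
    num_workers=0
)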
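`virtual_batch_size` controls Ghost Batch Normalization (GBN). Within each training batch, TabNet normalizes smaller "virtual" batches separately. With `batch_size=64` and `virtual_batch_size=64`, as above, each batch is a single virtual batch, so GBN reduces to ordinary batch normalization. The NumPy sketch below shows the difference. It is a simplification: it covers training-mode normalization only, with no learnable scale/shift or running statistics. It also slices the batch into consecutive chunks, which may differ in detail from how pytorch_tabnet chunks it. The function name `ghost_batch_norm` is hypothetical.

```python
import numpy as np

def ghost_batch_norm(x, virtual_batch_size, eps=1e-5):
    """Normalize each consecutive chunk of `virtual_batch_size` rows with that
    chunk's own mean and variance (training mode, no affine parameters)."""
    x = np.asarray(x, dtype=float)
    out = np.empty_like(x)
    for start in range(0, len(x), virtual_batch_size):
        chunk = x[start:start + virtual_batch_size]
        out[start:start + virtual_batch_size] = (
            (chunk - chunk.mean(axis=0)) / np.sqrt(chunk.var(axis=0) + eps)
        )
    return out

rng = np.random.default_rng(1)
demo_batch = rng.normal(loc=3.0, scale=2.0, size=(64, 4))    # one batch: 64 rows, 4 features
plain = ghost_batch_norm(demo_batch, virtual_batch_size=64)  # 1 virtual batch == ordinary BN
ghost = ghost_batch_norm(demo_batch, virtual_batch_size=16)  # 4 virtual batches of 16 rows
print("max difference:", np.abs(plain - ghost).max().round(3))
```

Smaller virtual batches yield noisier statistics, which act as a regularizer. Equal sizes, as in this notebook, turn that effect off.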

In [None]:
# Plot training history
plt.figure(figsize=(10, 6))
plt.plot(tabnet_clf.history['train_accuracy'], label='Train')
plt.plot(tabnet_clf.history['valid_accuracy'], label='Valid')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('TabNet Training History')
plt.legend()
plt.grid(True)
plt.show()
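The curves end where early stopping triggered. With `patience=10`, training halts once the monitored validation metric has gone 10 consecutive epochs without beating its best value. pytorch_tabnet then reports the best epoch and uses that epoch's weights. The pure-Python sketch below applies a simplified version of that rule to a list of scores. The function name and the toy scores are hypothetical, and the library's exact rule may differ in details such as tolerance handling.

```python
def early_stopping_epoch(val_scores, patience):
    """Return (best_epoch, stop_epoch) for a maximize-the-metric rule: stop once
    `patience` consecutive epochs fail to beat the best score so far."""
    best_epoch, best_score, wait = 0, float("-inf"), 0
    for epoch, score in enumerate(val_scores):
        if score > best_score:                 # strict improvement resets the counter
            best_epoch, best_score, wait = epoch, score, 0
        else:
            wait += 1
            if wait >= patience:
                return best_epoch, epoch
    return best_epoch, len(val_scores) - 1     # ran out of epochs without stopping

demo_valid_acc = [0.50, 0.62, 0.70, 0.71, 0.70, 0.69, 0.71, 0.70]
print(early_stopping_epoch(demo_valid_acc, patience=3))
```

Here epoch 3 reaches 0.71. Epochs 4 to 6 fail to exceed it, since a tie does not count as an improvement, so training stops at epoch 6 and keeps epoch 3.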

## 4. Model Evaluation

We evaluate the trained TabNet model on the held-out test set using overall accuracy, a per-class classification report, top-k accuracy, and per-cancer-type accuracy.

In [None]:
# Make predictions on original test set
y_pred = tabnet_clf.predict(X_test)
y_proba = tabnet_clf.predict_proba(X_test)

# Basic performance metrics
accuracy = accuracy_score(y_test, y_pred)
print(f"TabNet Accuracy on Original Test Set: {accuracy:.4f}")
print("\nClassification Report on Original Test Set:\n",
      classification_report(y_test, y_pred,
                            labels=np.arange(len(label_encoder.classes_)),  # keep rows aligned with target_names
                            target_names=label_encoder.classes_))

In [None]:
# Function to compute top-k accuracy
def top_k_accuracy(y_true, y_pred_proba, k=2):
    top_k = np.argsort(y_pred_proba, axis=1)[:, -k:]  # Get top-k predictions
    correct = np.array([y_true[i] in top_k[i] for i in range(len(y_true))])
    return np.mean(correct)

# Compute top-2 and top-3 accuracy
top_2_acc = top_k_accuracy(y_test, y_proba, k=2)
top_3_acc = top_k_accuracy(y_test, y_proba, k=3)

print(f"TabNet Top-2 Accuracy: {top_2_acc:.4f}")
print(f"TabNet Top-3 Accuracy: {top_3_acc:.4f}")
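The loop in `top_k_accuracy` checks one row at a time. A vectorized alternative, sketched below with made-up probabilities, uses `np.argpartition` to find the k highest-probability columns without a full sort. scikit-learn's `top_k_accuracy_score` is another off-the-shelf option. Tie-breaking between equal probabilities can differ from the `argsort` version. The function name `top_k_accuracy_vectorized` is hypothetical.

```python
import numpy as np

def top_k_accuracy_vectorized(y_true, proba, k=2):
    """Fraction of rows whose true class is among the k highest-probability columns."""
    y_true = np.asarray(y_true)
    top_k = np.argpartition(proba, -k, axis=1)[:, -k:]   # k best columns per row (unordered)
    return np.mean((top_k == y_true[:, None]).any(axis=1))

demo_proba = np.array([[0.7, 0.2, 0.1],
                       [0.1, 0.3, 0.6],
                       [0.5, 0.4, 0.1],
                       [0.2, 0.3, 0.5]])
demo_true = np.array([0, 1, 1, 0])
for k in (1, 2, 3):
    print(f"top-{k}: {top_k_accuracy_vectorized(demo_true, demo_proba, k):.2f}")
```

With k equal to the number of classes, every true label is covered, so top-k accuracy is always 1.0.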

In [None]:
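# Per-cancer-type accuracy: the share of each type's test samples that were
# classified correctly (this equals per-class recall)
print("\nPer-Cancer-Type Accuracy (recall):")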
for i, label in enumerate(label_encoder.classes_):
    correct = np.sum((y_pred == i) & (y_test == i))  # Correctly classified instances
    total = np.sum(y_test == i)  # Total instances of this cancer type
    class_accuracy = correct / total if total > 0 else 0  # Avoid division by zero
    print(f"Accuracy for cancer type '{label}': {class_accuracy:.4f} ({correct}/{total})")
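Per-cancer-type accuracy as computed above is the same quantity as per-class recall: correct predictions for a class divided by that class's true instances. scikit-learn exposes it as `recall_score(y_true, y_pred, average=None)`. The NumPy sketch below computes it for all classes at once with `np.bincount`, using toy labels. The function name `per_class_recall` is hypothetical.

```python
import numpy as np

def per_class_recall(y_true, y_pred, n_classes):
    """Correct predictions per class divided by true instances per class (recall)."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    correct = np.bincount(y_true[y_true == y_pred], minlength=n_classes)
    total = np.bincount(y_true, minlength=n_classes)
    # Classes absent from y_true get 0.0, mirroring the loop's division-by-zero guard
    return np.divide(correct, total, out=np.zeros(n_classes), where=total > 0)

demo_true = np.array([0, 0, 1, 1, 1, 2])
demo_pred = np.array([0, 1, 1, 1, 0, 2])
print(per_class_recall(demo_true, demo_pred, n_classes=4).round(4))
```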

## 5. Confusion Matrix Visualization

In [None]:
# Confusion Matrix Visualization
conf_matrix = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(12, 10))
sns.heatmap(conf_matrix, annot=True, cmap='Greens', fmt='d',
            xticklabels=label_encoder.classes_, yticklabels=label_encoder.classes_)
plt.xlabel("Predicted")
plt.ylabel("True")
plt.title("Confusion Matrix for TabNet on Original Test Set")
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
plt.show()
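Raw counts make large classes dominate the heatmap. Dividing each row by its total shows, for every true cancer type, the fraction of samples predicted as each type. The diagonal then equals the per-type accuracy computed earlier. scikit-learn can do this directly with `confusion_matrix(y_true, y_pred, normalize='true')`, plotted with `fmt='.2f'`. The NumPy sketch below spells out the arithmetic on toy labels. The helper names are hypothetical.

```python
import numpy as np

def confusion_counts(y_true, y_pred, n_classes):
    """Rows = true class, columns = predicted class (same layout as sklearn's confusion_matrix)."""
    cm = np.zeros((n_classes, n_classes), dtype=int)
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return cm

def row_normalize(cm):
    """Divide each row by its total so entries are fractions of that true class."""
    totals = cm.sum(axis=1, keepdims=True)
    return np.divide(cm, totals, out=np.zeros(cm.shape), where=totals > 0)

demo_true = [0, 0, 0, 0, 1, 1, 2, 2]
demo_pred = [0, 0, 0, 1, 1, 2, 2, 2]
demo_cm = confusion_counts(demo_true, demo_pred, 3)
print(demo_cm)
print(row_normalize(demo_cm).round(2))
```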

## 6. Analysis of Misclassifications

Let's analyze which cancer types are most frequently confused with each other.

In [None]:
# Find the most common misclassifications
print("Most frequent misclassifications (true type -> predicted type):")
misclassifications = []
for i in range(len(label_encoder.classes_)):
    for j in range(len(label_encoder.classes_)):
        if i != j and conf_matrix[i, j] > 0:
            misclassifications.append({
                'true': label_encoder.classes_[i],
                'predicted': label_encoder.classes_[j],
                'count': conf_matrix[i, j]
            })

# Sort by count in descending order
misclassifications.sort(key=lambda x: x['count'], reverse=True)

# Display top misclassifications
for item in misclassifications[:10]:  # Show top 10
    print(f"{item['true']} misclassified as {item['predicted']}: {item['count']} samples")
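The list above is directional: "A predicted as B" and "B predicted as A" appear separately. To rank pairs of cancer types that are confused *with each other*, one option is to add the confusion matrix to its transpose and read each unordered pair once from the upper triangle. The sketch below does this on a made-up 3×3 matrix. The function name and class names are hypothetical.

```python
import numpy as np

def top_confused_pairs(cm, class_names, n=10):
    """Rank unordered class pairs by total confusion in both directions (C[i, j] + C[j, i])."""
    cm = np.asarray(cm)
    sym = cm + cm.T                                  # combine A->B and B->A errors
    i_idx, j_idx = np.triu_indices(len(cm), k=1)     # each unordered pair once, diagonal excluded
    counts = sym[i_idx, j_idx]
    order = np.argsort(counts)[::-1]                 # largest totals first
    return [(class_names[i_idx[o]], class_names[j_idx[o]], int(counts[o]))
            for o in order[:n] if counts[o] > 0]

demo_cm = np.array([[50, 4, 1],
                    [6, 40, 0],
                    [2, 0, 30]])
for a, b, c in top_confused_pairs(demo_cm, ["A", "B", "C"]):
    print(f"{a} <-> {b}: {c} misclassifications")
```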

## 7. Conclusion and Next Steps

In this notebook, we've implemented a TabNet deep learning model for cancer type classification. Here's a summary of what we've accomplished:

1. Processed and standardized the cancer classification dataset
2. Addressed class imbalance using SMOTE on the training data
3. Trained a TabNet model with manually chosen hyperparameters (Adam, step learning-rate decay, entmax masks) and early stopping
4. Evaluated the model using multiple metrics including per-cancer-type accuracy
5. Analyzed misclassifications to identify commonly confused cancer types

Potential next steps:
- Hyperparameter tuning to further optimize the TabNet model
- Feature selection to identify the most discriminative genes for cancer classification
- Ensemble with other models (SVM, XGBoost) to improve overall performance
- Investigate specific misclassifications and potential biological reasons
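For the feature-selection step, a fitted pytorch_tabnet model exposes global feature importances through its `feature_importances_` attribute. A natural first step is to rank features by that vector. The sketch below does the ranking on toy values. The gene names and importance numbers are made up; with the fitted model, one would pass `tabnet_clf.feature_importances_` and `list(X.columns)`. The function name `top_features` is hypothetical.

```python
import numpy as np

def top_features(importances, feature_names, n=10):
    """Return the n (name, importance) pairs with the largest importance scores."""
    importances = np.asarray(importances, dtype=float)
    order = np.argsort(importances)[::-1][:n]        # indices sorted by descending importance
    return [(feature_names[i], float(importances[i])) for i in order]

# Toy values standing in for tabnet_clf.feature_importances_ and the column names
demo_importances = np.array([0.05, 0.40, 0.15, 0.30, 0.10])
demo_names = ["gene_a", "gene_b", "gene_c", "gene_d", "gene_e"]
print(top_features(demo_importances, demo_names, n=3))
```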