# Assignment 3: Machine Learning for Huntington's Disease Prediction

---

**Objective:** Build and evaluate machine learning models to predict disease stage in Huntington's Disease patients using clinical, genetic, and molecular features.

**Dataset:** Huntington's Disease Dataset (48,536 patients, 13 clinical features)

**Target Variable:** Disease_Stage (5-class classification: Pre-symptomatic, Early Stage, Mid Stage, Late Stage, Advanced)

---

## Why This Matters

Accurate prediction of disease stage in Huntington's Disease enables:
- **Early Intervention:** Identify patients who would benefit from early treatment
- **Treatment Planning:** Tailor therapeutic strategies based on disease progression
- **Clinical Trials:** Stratify patients for more effective trial enrollment
- **Patient Counseling:** Provide evidence-based prognosis for personalized care

---

## Success Criteria

- High classification accuracy (>85%)
- Balanced precision and recall across all disease stages
- Interpretable models that align with clinical knowledge
- Robust generalization to unseen patient data
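The sketch below makes the first two criteria concrete. It computes overall accuracy alongside per-stage precision, recall, and F1 in plain Python. The function names `per_class_metrics` and `macro_average` and the ten toy labels are hypothetical, chosen only for illustration. They are not results from this dataset. In practice, `sklearn.metrics.classification_report` reports the same per-class figures plus macro and weighted averages.

```python
from collections import Counter


def per_class_metrics(y_true, y_pred):
    """Return overall accuracy and {label: (precision, recall, f1)} for each class."""
    labels = sorted(set(y_true) | set(y_pred))
    # true positives: positions where the prediction matches the true label
    tp = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    pred_counts = Counter(y_pred)  # denominators for precision
    true_counts = Counter(y_true)  # denominators for recall
    metrics = {}
    for label in labels:
        precision = tp[label] / pred_counts[label] if pred_counts[label] else 0.0
        recall = tp[label] / true_counts[label] if true_counts[label] else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        metrics[label] = (precision, recall, f1)
    accuracy = sum(tp.values()) / len(y_true)
    return accuracy, metrics


def macro_average(metrics):
    """Unweighted mean of precision, recall and F1 over classes (every stage counts equally)."""
    n = len(metrics)
    return tuple(sum(m[i] for m in metrics.values()) / n for i in range(3))


# Hypothetical toy labels: 8 Early Stage and 2 Late Stage patients, one Late case missed
y_true = ["Early Stage"] * 8 + ["Late Stage"] * 2
y_pred = ["Early Stage"] * 8 + ["Early Stage", "Late Stage"]

accuracy, metrics = per_class_metrics(y_true, y_pred)
macro_p, macro_r, macro_f1 = macro_average(metrics)
print(f"accuracy={accuracy:.2f}")
for label, (p, r, f) in metrics.items():
    print(f"{label}: precision={p:.2f} recall={r:.2f} f1={f:.2f}")
print(f"macro recall={macro_r:.2f}")
```

In this toy case, accuracy is 0.90, which is above the 85% target. Yet Late Stage recall is only 0.50, because half of the late-stage cases are missed. This is why the criteria ask for per-stage precision and recall rather than accuracy alone. It is also why a macro average, which weights each stage equally, is a useful single summary.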

---

## 1. Introduction & Setup

### 1.1 Import Libraries

In [None]:
#import core libraries for data manipulation, statistics, and visualization
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats

#utilities
import warnings
warnings.filterwarnings('ignore')

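#seed NumPy's global random number generator so random operations repeat across runs;
#for reliable reproducibility, also pass random_state=RANDOM_STATE to scikit-learn estimators and splitters
RANDOM_STATE = 42
np.random.seed(RANDOM_STATE)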

#configure display settings for better readability
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', 100)
pd.set_option('display.float_format', '{:.4f}'.format)

#plotting style for visualizations
plt.style.use('default')
sns.set_palette("husl")

print("✓ Core libraries imported successfully")
print(f"✓ Random seed set to {RANDOM_STATE} for reproducibility")
print("\nNote: Additional libraries will be imported in relevant sections as needed")

## 2. Load Data

In [None]:
#load the cleaned data from Assignment 2
#data preprocessing (removing irrelevant columns, handling duplicates, etc.) was completed in Assignment 2
df = pd.read_csv('data/Huntington_Disease_Cleaned.csv')

print(f"Data loaded: {df.shape[0]:,} patients, {df.shape[1]} columns (including the target)")
print("Target variable: Disease_Stage (multi-class classification)")

In [None]:
#quick overview
df.head()

In [None]:
#check target variable distribution
#check for class imbalance
print("Disease Stage Distribution:")
print(df['Disease_Stage'].value_counts())
print("\nClass balance (% of patients):")
print(df['Disease_Stage'].value_counts(normalize=True) * 100)
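The percentages above show whether some stages are rare. One way to put a number on the imbalance is the ratio of the largest class to the smallest. A common remedy is to weight each class inversely to its frequency. The sketch below uses the same formula that scikit-learn applies for `class_weight="balanced"`, namely `n_samples / (n_classes * count_c)`. The function name `imbalance_summary` and the toy labels are hypothetical and do not reflect this dataset's distribution.

```python
from collections import Counter


def imbalance_summary(labels):
    """Return (max/min class-count ratio, {class: inverse-frequency weight})."""
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    ratio = max(counts.values()) / min(counts.values())
    # same formula as scikit-learn's class_weight="balanced"
    weights = {label: n_samples / (n_classes * count) for label, count in counts.items()}
    return ratio, weights


# Hypothetical toy labels (not the real distribution)
labels = ["Early Stage"] * 6 + ["Mid Stage"] * 3 + ["Advanced"] * 1
ratio, weights = imbalance_summary(labels)
print(f"imbalance ratio = {ratio:.1f}")
for label, w in weights.items():
    print(f"{label}: weight {w:.3f}")
```

With these weights, every class contributes the same total weight (`n_samples / n_classes`), so a rare stage is not drowned out during training. On the real data, the same numbers come straight from pandas: compute `counts = df['Disease_Stage'].value_counts()`, then `len(df) / (counts.size * counts)`.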

In [None]:
#visualize class distribution
plt.figure(figsize=(10, 5))
df['Disease_Stage'].value_counts().plot(kind='bar', color='steelblue')
plt.title('Distribution of Disease Stages')
plt.xlabel('Disease Stage')
plt.ylabel('Number of Patients')
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()
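`value_counts()` sorts by frequency, so the bars above appear in order of patient count rather than clinical progression. Because the stages are ordinal, it helps to fix one explicit order and reuse it for both plotting and encoding the target. The sketch below assumes the CSV uses exactly the five label strings listed in the target description. `STAGE_ORDER`, `encode_stages`, and `counts_in_stage_order` are hypothetical names.

```python
from collections import Counter

# Assumption: the CSV spells the labels exactly as in the target description
STAGE_ORDER = ["Pre-symptomatic", "Early Stage", "Mid Stage", "Late Stage", "Advanced"]
STAGE_TO_CODE = {stage: code for code, stage in enumerate(STAGE_ORDER)}


def encode_stages(labels):
    """Map stage names to ordinal codes 0-4; an unexpected label raises KeyError."""
    return [STAGE_TO_CODE[label] for label in labels]


def counts_in_stage_order(labels):
    """Count patients per stage in clinical order, reporting 0 for absent stages."""
    counts = Counter(labels)
    return [(stage, counts[stage]) for stage in STAGE_ORDER]


# Hypothetical toy labels
sample = ["Mid Stage", "Early Stage", "Mid Stage", "Advanced"]
print(encode_stages(sample))
print(counts_in_stage_order(sample))
```

In the notebook itself, the equivalent pandas call is `df['Disease_Stage'].value_counts().reindex(STAGE_ORDER, fill_value=0).plot(kind='bar', color='steelblue')`. Labels not in `STAGE_ORDER` are dropped by `reindex`, so a stage showing a count of 0 is worth checking for a spelling mismatch. The integer codes preserve the ordering for any later ordinal analysis. A standard multi-class classifier, however, treats them as unordered labels.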