# In [None]:
# notebooks/01_data_exploration.ipynb
"""
Data Exploration for WGAN-GP Control Data
Purpose: Analyze control data characteristics and prepare for WGAN-GP training
"""

import pickle
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats

# Add the project root to sys.path so the local `src.*` modules resolve.
# This notebook is expected to run from notebooks/, so the root is the parent
# of the working directory.
project_root = Path.cwd().parent
if str(project_root) not in sys.path:
    # Prepend so the local `src` package takes precedence over any installed
    # package of the same name; the membership check keeps repeated cell
    # executions from piling duplicate entries onto sys.path.
    sys.path.insert(0, str(project_root))

from src.data.preprocessing import DataProcessor
from src.evaluation.metrics import EvaluationMetrics

# Set plotting style.
# Matplotlib 3.6 renamed its bundled seaborn styles to "seaborn-v0_8*" and 3.8
# removed the old "seaborn" name (using it raises OSError), so pick whichever
# name this matplotlib install actually provides.
_seaborn_style = 'seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn'
plt.style.use(_seaborn_style)
sns.set_palette("husl")

# Load data.
# Resolve against the project root (computed above from the notebook's working
# directory) instead of the working directory itself: when run from
# notebooks/, a bare "data/..." path points at a non-existent notebooks/data/.
data_path = project_root / "data" / "data_combined_controls.csv"
# NOTE(review): `processor` is not used anywhere in this section; kept in case
# later cells rely on it.
processor = DataProcessor()
control_data = pd.read_csv(data_path)

# 1. Basic Data Overview: row/column counts plus pandas' summary statistics.
n_samples, n_features = control_data.shape
print("=== Dataset Overview ===")
print(f"Number of samples: {n_samples}")
print(f"Number of features: {n_features}")
print("\nFeature Statistics:")
print(control_data.describe())

# 2. Missing Value Analysis: per-column null counts; only columns with at
# least one null are printed.
missing_values = control_data.isnull().sum()
print("\n=== Missing Values ===")
if missing_values.any():
    print(missing_values[missing_values > 0])
else:
    print("No missing values")

# 3. Feature Distribution Analysis
# Means, standard deviations and correlations are only defined for numeric
# columns; since pandas 2.0, DataFrame.mean/std/corr raise TypeError when
# non-numeric columns are present, so restrict to the numeric subset.
numeric_data = control_data.select_dtypes(include=np.number)

plt.figure(figsize=(15, 5))

# Feature means: one value per numeric column.
plt.subplot(131)
sns.histplot(numeric_data.mean(), kde=True)
plt.title('Feature Means Distribution')
plt.xlabel('Mean Value')

# Feature standard deviations: one value per numeric column.
plt.subplot(132)
sns.histplot(numeric_data.std(), kde=True)
plt.title('Feature Standard Deviations')
plt.xlabel('Standard Deviation')

# Feature correlations: one value per distinct feature pair.
plt.subplot(133)
corrmat = numeric_data.corr()
# Take the strict upper triangle so the diagonal (always 1.0) and the mirrored
# lower half do not inflate the histogram.
pair_corrs = corrmat.values[np.triu_indices_from(corrmat.values, k=1)]
# Constant columns yield NaN correlations; drop them explicitly.
pair_corrs = pair_corrs[~np.isnan(pair_corrs)]
sns.histplot(pair_corrs, kde=True)
plt.title('Feature Correlations')
plt.xlabel('Correlation Coefficient')

plt.tight_layout()
plt.show()

# 4. Correlation Analysis
plt.figure(figsize=(10, 8))
# Select the (up to) 50 highest-variance features. Variance is only defined for
# numeric columns, and since pandas 2.0 DataFrame.var raises on non-numeric ones.
top_features = control_data.select_dtypes(include=np.number).var().nlargest(50).index
correlation_matrix = control_data[top_features].corr()
# Pin the colour scale to the full correlation range [-1, 1] so colours mean
# the same thing regardless of the observed min/max.
sns.heatmap(correlation_matrix, cmap='coolwarm', center=0, vmin=-1, vmax=1)
# Report the actual count: fewer than 50 numeric features may exist.
plt.title(f'Feature Correlation Matrix (Top {len(top_features)} Features)')
plt.show()

# 5. Save exploration results
exploration_results = {
    'basic_stats': control_data.describe(),
    'missing_values': missing_values,
    'correlation_matrix': correlation_matrix,
    'top_features': top_features.tolist(),
}

# Write under the project root (not the notebook's working directory) and
# create results/ if needed, so the dump does not fail with FileNotFoundError
# on a fresh checkout.
results_dir = project_root / 'results'
results_dir.mkdir(parents=True, exist_ok=True)
# NOTE: pickle files execute code on load; only reload this from trusted locations.
with open(results_dir / 'exploration_results.pkl', 'wb') as f:
    pickle.dump(exploration_results, f)