# In [None]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split, TimeSeriesSplit
from typing import Tuple, Optional
import logging
import yaml
from pathlib import Path
import sys

# In [None]:
# Configure logging: set up the root logger at INFO level (module-level side
# effect at import time) and create the logger used throughout this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class DataSplitter:
    """
    Train/test splitting utility for a car price prediction dataset.

    The splitter reads a CSV file, separates the target column from the
    features, splits the rows either randomly (optionally stratified) or
    chronologically, and writes the four resulting parts to disk.

    Features:
    - Multiple splitting strategies (random, temporal, stratified)
    - Optional chunked CSV reading
    - Configuration via YAML file, merged key-by-key over the defaults
    - Comprehensive logging

    Configuration layout:
    - ``splitting``: ``test_size``, ``random_state``, ``stratify``,
      ``temporal_split``, ``time_column``, ``chunksize`` and the optional
      ``cutoff_date``.
    - ``paths``: ``input_data`` and ``output_dir``.
    - ``target_column`` (required by ``split_data``) and the optional
      ``required_columns`` list live at the top level.
    """

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize splitter with configuration

        Args:
            config_path: Path to YAML configuration file. When omitted (or
                when the file cannot be loaded) the built-in defaults apply.

        Raises:
            ValueError: If the resulting configuration is invalid.
        """
        self.config = self._load_config(config_path)
        self._validate_config()

    @staticmethod
    def _deep_merge(base: dict, override: dict) -> dict:
        """
        Return a new dict with ``override`` merged recursively into ``base``.

        Nested dicts are merged key by key so that overriding a single
        setting does not discard its sibling defaults. Any non-dict value in
        ``override`` replaces the corresponding value from ``base``. Neither
        input is mutated.
        """
        merged = dict(base)
        for key, value in override.items():
            if isinstance(value, dict) and isinstance(merged.get(key), dict):
                merged[key] = DataSplitter._deep_merge(merged[key], value)
            else:
                merged[key] = value
        return merged

    def _load_config(self, config_path: Optional[str]) -> dict:
        """
        Load configuration from YAML file and merge it over the defaults.

        Loading is best-effort: if the file cannot be read or parsed, or does
        not contain a mapping, a warning is logged and the defaults are
        returned unchanged.
        """
        default_config = {
            'splitting': {
                'test_size': 0.2,
                'random_state': 42,
                'stratify': False,
                'temporal_split': False,
                'time_column': None,
                'chunksize': None
            },
            'paths': {
                'input_data': 'data/processed/cleaned_data.csv',
                'output_dir': 'data/splits/'
            }
        }

        if not config_path:
            return default_config

        try:
            with open(config_path, encoding="utf-8") as f:
                user_config = yaml.safe_load(f)
        except (OSError, ValueError, yaml.YAMLError) as e:
            logger.warning(f"Failed to load config: {e}. Using defaults")
            return default_config

        # An empty document loads as None and a scalar/list document is not a
        # usable configuration either; fall back instead of failing later.
        if not isinstance(user_config, dict):
            logger.warning(
                f"Config file {config_path} does not contain a mapping. Using defaults"
            )
            return default_config

        # A shallow merge would replace the whole 'splitting'/'paths' section
        # as soon as the user overrides one key inside it.
        return self._deep_merge(default_config, user_config)

    def _validate_config(self) -> None:
        """
        Validate configuration parameters

        Raises:
            ValueError: If a section is not a mapping, if a temporal split is
                requested without ``time_column``, or if ``test_size`` is not
                a number strictly between 0 and 1.
        """
        for section in ('splitting', 'paths'):
            if not isinstance(self.config.get(section), dict):
                raise ValueError(f"'{section}' configuration must be a mapping")

        split_cfg = self.config['splitting']
        if split_cfg['temporal_split'] and not split_cfg['time_column']:
            raise ValueError("time_column must be specified for temporal split")

        test_size = split_cfg['test_size']
        is_number = isinstance(test_size, (int, float)) and not isinstance(test_size, bool)
        if not is_number or not 0 < test_size < 1:
            raise ValueError("test_size must be between 0 and 1")

    def split_data(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.Series, pd.Series]:
        """
        Execute the data splitting process

        Returns:
            X_train, X_test, y_train, y_test DataFrames/Series

        Raises:
            KeyError: If ``target_column`` is not configured or is not a
                column of the loaded data.
        """
        # Fail fast, before any file is read, with a message that names the
        # missing setting.
        target = self.config.get('target_column')
        if not target:
            raise KeyError("target_column must be specified in the configuration")

        # Load data
        df = self._load_data()
        if target not in df.columns:
            raise KeyError(f"target_column '{target}' not found in input data")

        # Prepare features and target
        X = df.drop(columns=[target])
        y = df[target]

        # Select splitting strategy
        if self.config['splitting']['temporal_split']:
            return self._temporal_split(X, y)
        return self._random_split(X, y)

    def _load_data(self) -> pd.DataFrame:
        """
        Load and validate input data

        Reads the CSV at ``paths.input_data`` (in chunks when
        ``splitting.chunksize`` is set; the chunks are still concatenated
        into one frame) and checks the optional ``required_columns`` list.

        NOTE: any failure is logged and terminates the process with exit
        status 1 via ``sys.exit`` rather than propagating the exception.
        """
        input_path = self.config['paths']['input_data']
        chunksize = self.config['splitting']['chunksize']

        logger.info(f"Loading data from {input_path}")

        try:
            if chunksize:
                df = pd.concat(pd.read_csv(input_path, chunksize=chunksize), axis=0)
            else:
                df = pd.read_csv(input_path)

            # Validate required columns
            required_cols = self.config.get('required_columns', [])
            missing = set(required_cols or []) - set(df.columns)
            if missing:
                raise ValueError(f"Missing required columns: {missing}")

            return df

        except Exception as e:
            logger.error(f"Failed to load data: {e}")
            sys.exit(1)

    def _random_split(self, X: pd.DataFrame, y: pd.Series) -> Tuple:
        """
        Standard randomized train-test split

        Uses ``test_size`` and ``random_state`` from the configuration and
        passes ``y`` as the stratification labels when ``stratify`` is set.
        The resulting parts are saved to disk before being returned.

        NOTE(review): stratification uses the raw target values; with a
        continuous target this is likely to be rejected by
        ``train_test_split`` - confirm whether binning is wanted.
        """
        split_cfg = self.config['splitting']
        stratify = y if split_cfg['stratify'] else None

        X_train, X_test, y_train, y_test = train_test_split(
            X, y,
            test_size=split_cfg['test_size'],
            random_state=split_cfg['random_state'],
            stratify=stratify
        )

        self._save_splits(X_train, X_test, y_train, y_test)
        return X_train, X_test, y_train, y_test

    @staticmethod
    def _as_time_values(values: pd.Series) -> pd.Series:
        """
        Return ``values`` in a form whose ordering is chronological.

        Numeric and datetime columns are returned as-is. Anything else
        (typically date strings read from CSV) is parsed with
        ``pd.to_datetime``; if parsing fails the raw values are kept and a
        warning is logged, so the comparison falls back to the raw ordering.
        """
        if (pd.api.types.is_numeric_dtype(values)
                or pd.api.types.is_datetime64_any_dtype(values)):
            return values
        try:
            parsed = pd.to_datetime(values)
        except (ValueError, TypeError) as e:
            logger.warning(f"Could not parse time column as dates: {e}. Comparing raw values")
            return values
        if not pd.api.types.is_datetime64_any_dtype(parsed):
            return values
        return parsed

    def _temporal_mask(self, X: pd.DataFrame) -> pd.Series:
        """
        Build the boolean mask selecting the training rows of a temporal split.

        Rows whose time value is less than or equal to the cutoff are
        training rows. The cutoff is ``splitting.cutoff_date`` when given;
        otherwise it is the time value of the last row that belongs to the
        first ``1 - test_size`` share of the chronologically sorted data.
        Rows sharing the cutoff value all land in the training set.
        """
        split_cfg = self.config['splitting']
        times = self._as_time_values(X[split_cfg['time_column']])
        cutoff = split_cfg.get('cutoff_date')

        if not cutoff:
            # Auto-determine cutoff based on test_size. The cutoff is the
            # LAST training value (index n_train - 1) because the mask below
            # is inclusive; using index n_train would leak one test row into
            # the training set.
            n_train = max(1, int(len(X) * (1 - split_cfg['test_size'])))
            cutoff = times.sort_values().iloc[n_train - 1]
        elif pd.api.types.is_datetime64_any_dtype(times):
            # Config values may be strings or date objects; bring them to the
            # same type as the parsed column before comparing.
            cutoff = pd.Timestamp(cutoff)

        return times <= cutoff

    def _temporal_split(self, X: pd.DataFrame, y: pd.Series) -> Tuple:
        """
        Time-based splitting (chronological order)

        Rows up to and including the cutoff form the training set, later
        rows the test set; the original row order is preserved within each
        part. The resulting parts are saved to disk before being returned.
        """
        mask = self._temporal_mask(X)
        X_train, y_train = X[mask], y[mask]
        X_test, y_test = X[~mask], y[~mask]

        self._save_splits(X_train, X_test, y_train, y_test)
        return X_train, X_test, y_train, y_test

    def _save_splits(self,
                    X_train: pd.DataFrame,
                    X_test: pd.DataFrame,
                    y_train: pd.Series,
                    y_test: pd.Series) -> None:
        """
        Persist split data to disk

        Writes ``X_train.csv``, ``X_test.csv``, ``y_train.csv`` and
        ``y_test.csv`` (without the index) into ``paths.output_dir``,
        creating the directory if needed. Write failures are logged and not
        re-raised, so saving is best-effort.
        """
        output_dir = Path(self.config['paths']['output_dir'])
        output_dir.mkdir(parents=True, exist_ok=True)

        try:
            X_train.to_csv(output_dir / 'X_train.csv', index=False)
            X_test.to_csv(output_dir / 'X_test.csv', index=False)
            y_train.to_csv(output_dir / 'y_train.csv', index=False)
            y_test.to_csv(output_dir / 'y_test.csv', index=False)
            logger.info(f"Split data saved to {output_dir}")
        except Exception as e:
            logger.error(f"Failed to save splits: {e}")

# In [None]:
if __name__ == "__main__":
    # Initialize with configuration
    splitter = DataSplitter(config_path="config/split_config.yaml")

    # Execute splitting
    X_train, X_test, y_train, y_test = splitter.split_data()

    # Compute each part's share of the whole up front: the ratio needs an
    # explicit (train + test) denominator, and an empty result must not
    # divide by zero.
    n_train, n_test = len(X_train), len(X_test)
    n_total = n_train + n_test
    train_share = n_train / n_total if n_total else 0.0
    test_share = n_test / n_total if n_total else 0.0

    # Log summary statistics
    logger.info("\n=== Split Summary ===")
    logger.info(f"Training set: {n_train} samples ({train_share:.1%})")
    logger.info(f"Test set: {n_test} samples ({test_share:.1%})")
    logger.info(f"Target mean - Train: {y_train.mean():.2f}, Test: {y_test.mean():.2f}")