In [1]:
import sys
import mlflow
import pandas as pd
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Tuple
from imblearn.over_sampling import SMOTE
from dataclasses import dataclass

sys.path.append('..')
from src.mlflow_utils import configure_mlflow, find_latest_run_id_by_experiment_and_stage, get_targets, get_data, load_config

In [2]:
# Project configuration shared by the cells below. The code in this notebook reads
# the "experiment_names", "run_names", "dataset", "artifacts" and "resampling" keys.
CONFIG = load_config()

In [3]:
@dataclass
class ResampledData:
    """Result of a resampling step.

    Bundles the resampled training set together with class-distribution
    statistics computed on the targets before and after resampling.
    """
    # Resampled feature matrix.
    X_train: pd.DataFrame
    # Resampled target labels.
    y_train: pd.Series
    # Class statistics of the original targets; as built by DataBalancer the
    # keys are "class_0", "class_1" and "imbalance_ratio".
    pre_stats: dict
    # The same statistics computed on the resampled targets.
    post_stats: dict

In [4]:
class DataBalancer:
    """Handles data resampling pipeline.

    Loads the training features and targets produced by earlier MLflow stages,
    validates them, oversamples them with SMOTE and logs the result through
    MLflow. The methods rely on state set by the previous step, so the intended
    call order is ``load_data`` -> ``validate_data`` -> ``apply_smote`` ->
    ``log_artifacts``.

    Attributes:
        config: Project configuration dictionary.
        raw_data: ``{"X_train": ..., "y_train": ...}`` after ``load_data``,
            otherwise ``None``.
        resampled: ``ResampledData`` after ``apply_smote``, otherwise ``None``.
        smote: The ``SMOTE`` instance created by ``apply_smote``, otherwise
            ``None``.
    """
    def __init__(self, config: dict):
        """Store the configuration; no data is loaded until ``load_data``."""
        self.config = config
        self.raw_data = None
        self.resampled = None
        self.smote = None

    def load_data(self) -> None:
        """Load data from feature selection and preprocessing stages.

        Looks up the latest run of each upstream stage, reads ``X_train`` from
        the feature-selection run and ``y_train`` from the preprocessing run
        into ``self.raw_data``, then logs both upstream run ids as MLflow
        parameters so the lineage of this run is traceable.
        """
        fs_run_id = find_latest_run_id_by_experiment_and_stage(
            self.config["experiment_names"]["feature_selection"],
            self.config["run_names"]["feature_selection"]
        )
        preprocessing_run_id = find_latest_run_id_by_experiment_and_stage(
            self.config["experiment_names"]["preprocessing"],
            self.config["run_names"]["preprocessing"]
        )

        self.raw_data = {
            "X_train": get_data(fs_run_id, self.config["dataset"], self.config["artifacts"]["data"]["selected"])["X_train"],
            "y_train": get_targets(preprocessing_run_id, self.config["dataset"], "processed")["y_train"]
        }

        mlflow.log_params({
            "feature_selection_run_id": fs_run_id,
            "preprocessing_run_id": preprocessing_run_id
        })

    def validate_data(self) -> None:
        """Validate input data quality.

        Raises:
            ValueError: If features or targets are empty, or if their row
                counts differ.
        """
        if self.raw_data["X_train"].empty or self.raw_data["y_train"].empty:
            raise ValueError("Empty training data received")

        if len(self.raw_data["X_train"]) != len(self.raw_data["y_train"]):
            raise ValueError("Mismatch between features and target row counts")

    def _calculate_class_stats(self, y: pd.Series) -> dict:
        """Calculate class distribution statistics for a binary 0/1 target.

        Args:
            y: Target labels. A DataFrame is also accepted and is flattened
                the same way ``apply_smote`` flattens the targets.

        Returns:
            Dict with ``class_0`` and ``class_1`` (plain ``int`` counts, 0 when
            the class is absent) and ``imbalance_ratio`` (``class_0 / class_1``
            as a ``float``, or ``0.0`` when there are no class-1 samples).
        """
        if isinstance(y, pd.DataFrame):
            # value_counts() on a frame counts rows, not labels; flatten first.
            y = pd.Series(y.values.ravel())

        counts = y.value_counts()
        # Plain Python numbers keep the stats serialisable and comparable.
        class_0 = int(counts.get(0, 0))
        class_1 = int(counts.get(1, 0))
        return {
            "class_0": class_0,
            "class_1": class_1,
            # Guard on the count itself: class 0 may be missing entirely, and
            # class 1 may be listed with a zero count (e.g. categorical dtype).
            "imbalance_ratio": class_0 / class_1 if class_1 else 0.0
        }

    def apply_smote(self) -> None:
        """Apply SMOTE resampling with configurable parameters.

        Reads ``random_state``, ``sampling_strategy`` and ``k_neighbors`` from
        ``config["resampling"]["smote"]``, resamples the loaded training data
        and stores the result, with before/after class statistics, in
        ``self.resampled``.
        """
        params = self.config["resampling"]["smote"]
        self.smote = SMOTE(
            random_state=params["random_state"],
            sampling_strategy=params["sampling_strategy"],
            k_neighbors=params["k_neighbors"]
        )

        X_res, y_res = self.smote.fit_resample(
            self.raw_data["X_train"],
            self.raw_data["y_train"].values.ravel()
        )

        # NOTE(review): the Series built from y_res below is presumably unnamed,
        # so log_artifacts would write it as a column labelled 0. Consider
        # carrying over the original target name - confirm what downstream
        # readers of the artifact expect.
        self.resampled = ResampledData(
            X_train=pd.DataFrame(X_res, columns=self.raw_data["X_train"].columns),
            y_train=pd.Series(y_res),
            pre_stats=self._calculate_class_stats(self.raw_data["y_train"]),
            post_stats=self._calculate_class_stats(pd.Series(y_res))
        )

    def log_artifacts(self) -> None:
        """Log resampled data and artifacts to MLflow.

        Logs sample counts and imbalance ratios as metrics, writes the
        resampled features and targets as parquet files under
        ``<config["artifacts"]["data"]["resampled"]>/training``, and logs the
        SMOTE ``k_neighbors`` and ``sampling_strategy`` as parameters.
        """
        pre_stats = self.resampled.pre_stats
        post_stats = self.resampled.post_stats

        # Log dataset statistics
        mlflow.log_metrics({
            "original_samples": pre_stats["class_0"] + pre_stats["class_1"],
            "resampled_samples": post_stats["class_0"] + post_stats["class_1"],
            "original_imbalance_ratio": pre_stats["imbalance_ratio"],
            "new_imbalance_ratio": post_stats["imbalance_ratio"]
        })

        artifact_dir = f"{self.config['artifacts']['data']['resampled']}/training"

        # Log resampled datasets; the files only need to exist until they are
        # handed to MLflow, hence the temporary directory.
        with TemporaryDirectory() as tmp_dir:
            # Features
            x_path = Path(tmp_dir) / "X_train_resampled.parquet"
            self.resampled.X_train.to_parquet(x_path)
            mlflow.log_artifact(x_path, artifact_dir)

            # Targets
            y_path = Path(tmp_dir) / "y_train_resampled.parquet"
            self.resampled.y_train.to_frame().to_parquet(y_path)
            mlflow.log_artifact(y_path, artifact_dir)

        # Log SMOTE parameters
        mlflow.log_params({
            "smote_k_neighbors": self.smote.k_neighbors,
            "smote_sampling_strategy": str(self.smote.sampling_strategy)
        })

In [5]:
class ResamplingPipeline:
    """Drives a ``DataBalancer`` through the resampling stages end to end."""
    def __init__(self, config: dict):
        """Keep the configuration and build the balancer that does the work."""
        self.config = config
        self.balancer = DataBalancer(config)

    def run(self) -> None:
        """Run load, validate, resample and log in that order.

        An exception raised by any stage propagates and skips the rest.
        """
        balancer = self.balancer
        stages = (
            balancer.load_data,
            balancer.validate_data,
            balancer.apply_smote,
            balancer.log_artifacts,
        )
        for stage in stages:
            stage()

In [6]:
if __name__ == "__main__":
    # Entry point: run the resampling stage inside a single MLflow run.
    experiment_name = CONFIG["experiment_names"]["resampling"]
    run_name = CONFIG["run_names"]["resampling"]

    configure_mlflow(experiment_name)

    # The run context ends the run on exit and marks it failed when an exception
    # propagates. Failure details must therefore be recorded *inside* the
    # context: once it has exited there is no active run, and fluent logging
    # calls would implicitly start a new, unrelated run.
    with mlflow.start_run(run_name=run_name):
        try:
            mlflow.set_tags({
                "stage": "resampling",
                "resampling_method": "SMOTE",
                "task": "classification"
            })

            # Log full configuration
            mlflow.log_dict(CONFIG, "resampling_config.yaml")

            # Execute pipeline
            pipeline = ResamplingPipeline(CONFIG)
            pipeline.run()

            mlflow.set_tag("status", "completed")
            print(f"Resampling completed. Run ID: {mlflow.active_run().info.run_id}")

        except Exception as e:
            # MLflow limits parameter value length (the limit depends on the
            # version); truncate so logging the error cannot itself fail and
            # mask the original exception.
            mlflow.log_param("error", str(e)[:250])
            mlflow.set_tag("status", "failed")
            # Re-raise so the run context closes this run as failed.
            raise

Resampling completed. Run ID: 9cde454b91d04ffa94db2fb1c7e02ed1
