In [1]:
# Churn prediction revenue
import argparse
import joblib
import os
import sys
import numpy as np
import pandas as pd
import boto3
import sagemaker
from sagemaker.sklearn.estimator import SKLearn
from sagemaker.inputs import TrainingInput
from sagemaker import get_execution_role

def model_fn(model_dir):
    """
    Model-loading hook for SageMaker inference.

    Reads the "model.joblib" artifact found in `model_dir` with joblib and
    returns the deserialized model object.
    """
    return joblib.load(os.path.join(model_dir, "model.joblib"))

def run_training(train_dir, model_dir):
    """
    Fit a RandomForest churn classifier and persist it to disk.

    Loads 'train.csv' from `train_dir`, fits the classifier on the
    'monthly_charges', 'tenure_months' and 'support_tickets' columns against
    the 'is_churn' column, prints the accuracy measured on that same data,
    and writes 'model.joblib' into `model_dir` (created if it does not exist).
    """
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.metrics import accuracy_score

    feature_columns = ['monthly_charges', 'tenure_months', 'support_tickets']

    # Read the CSV and split it into the feature matrix and the label column.
    frame = pd.read_csv(os.path.join(train_dir, "train.csv"))
    features = frame[feature_columns]
    labels = frame['is_churn']

    # Fixed random_state keeps the fitted forest reproducible between runs.
    classifier = RandomForestClassifier(n_estimators=100, random_state=42)
    classifier.fit(features, labels)

    # Scored on the training rows themselves (demonstration only), so the
    # printed number is not a held-out estimate.
    train_accuracy = accuracy_score(labels, classifier.predict(features))
    print(f"Training Accuracy: {train_accuracy:.2f}")

    # Make sure the output directory exists, then serialize the model.
    os.makedirs(model_dir, exist_ok=True)
    artifact_path = os.path.join(model_dir, "model.joblib")
    joblib.dump(classifier, artifact_path)
    print(f"Model saved to {artifact_path}")

def run_local_orchestration():
    """
    Drive the full demo pipeline from the caller's machine.

    Steps, in order: build a synthetic churn dataset, write it to
    'project_data/train.csv', upload that file to the session's default S3
    bucket, launch a SageMaker scikit-learn training job that uses this same
    script as its entry point, deploy the trained model to an endpoint, and
    print predictions for five sampled rows. The endpoint is left running.
    """
    sm_session = sagemaker.Session()
    execution_role = get_execution_role()
    region = sm_session.boto_region_name  # read but not used further below
    s3_bucket = sm_session.default_bucket()
    s3_prefix = "churn-filter-args-example"
    feature_columns = ['monthly_charges', 'tenure_months', 'support_tickets']

    # --- Step 1: synthetic customers (seeded so the draws are repeatable) ---
    np.random.seed(42)
    n_rows = 1000
    charges = np.random.randint(20, 150, n_rows)
    tenure = np.random.randint(1, 36, n_rows)
    tickets = np.random.randint(0, 5, n_rows)

    # Churn likelihood rises with charges and tickets, falls with tenure.
    charge_term = 0.001 * charges
    ticket_term = 0.03 * tickets
    tenure_term = 0.0005 * (36 - tenure)
    churn_chance = charge_term + ticket_term + tenure_term
    churned = (np.random.rand(n_rows) < churn_chance).astype(int)

    data = pd.DataFrame({
        'monthly_charges': charges,
        'tenure_months': tenure,
        'support_tickets': tickets,
        'is_churn': churned,
    })

    # Persist the dataset locally before shipping it to S3.
    os.makedirs("project_data", exist_ok=True)
    local_csv = os.path.join("project_data", "train.csv")
    data.to_csv(local_csv, index=False)
    print(f"Local CSV saved at: {local_csv}")

    # --- Step 2: upload the CSV to S3 ---
    s3_train_uri = sm_session.upload_data(
        path=local_csv, bucket=s3_bucket, key_prefix=s3_prefix
    )
    print(f"Training data uploaded to: {s3_train_uri}")

    # --- Step 3: estimator that runs this very file inside the container ---
    estimator_settings = dict(
        entry_point=__file__,
        role=execution_role,
        instance_count=1,
        instance_type="ml.m5.large",
        framework_version="1.2-1",
        py_version="py3",
        sagemaker_session=sm_session,
    )
    estimator = SKLearn(**estimator_settings)

    # --- Step 4: launch the training job on the uploaded channel ---
    train_channel = TrainingInput(s3_train_uri, content_type="text/csv")
    estimator.fit({"train": train_channel})

    # --- Step 5: stand up a real-time endpoint for the trained model ---
    predictor = estimator.deploy(
        initial_instance_count=1,
        instance_type="ml.m5.large",
        endpoint_name="churn-filter-args-endpoint",
    )

    # --- Step 6: smoke-test the endpoint with five random rows ---
    sample_rows = data.sample(5)
    payload = sample_rows[feature_columns].values.tolist()
    result = predictor.predict(payload)
    print("Sample input:\n", sample_rows)
    print("Predicted churn labels (0=no churn, 1=churn):", result)

    # Uncomment below to delete the endpoint when finished
    # predictor.delete_endpoint()

def parse_training_args(argv):
    """
    Parse the training-mode command-line options.

    Recognized options:
      --train      Directory containing train.csv. Defaults to the
                   SM_CHANNEL_TRAIN environment variable, else ./project_data.
      --model-dir  Directory for the saved model. Defaults to the
                   SM_MODEL_DIR environment variable, else ./model_output.

    Unrecognized arguments (for example Jupyter's '-f /path/to/kernel.json')
    are ignored via parse_known_args rather than being pre-filtered.
    Pre-filtering on a "--" prefix would also discard the *values* of
    space-separated options such as '--train /some/dir', making argparse
    abort with "expected one argument".

    Args:
        argv: List of argument strings, without the program name.

    Returns:
        argparse.Namespace with attributes `train` and `model_dir`.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--train",
        type=str,
        default=os.environ.get("SM_CHANNEL_TRAIN", os.path.abspath("project_data")),
        help="Path to training directory containing train.csv"
    )
    parser.add_argument(
        "--model-dir",
        type=str,
        default=os.environ.get("SM_MODEL_DIR", os.path.abspath("model_output")),
        help="Where to save the trained model artifacts"
    )

    # Unknown tokens are returned separately and deliberately dropped.
    known_args, _ignored = parser.parse_known_args(argv)
    return known_args


if __name__ == "__main__":
    # When run with 'python sagemaker_churn_pipeline.py run-local',
    # this executes run_local_orchestration().
    #
    # Otherwise it expects arguments consistent with SageMaker
    # (e.g., --train /opt/ml/input/data/train --model-dir /opt/ml/model),
    # or falls back to default project directories if none are provided.
    # Extraneous Jupyter arguments (-f /path/to/kernel.json) are ignored.
    if len(sys.argv) > 1 and sys.argv[1] == "run-local":
        # Local run: orchestrate data generation, training job, deployment, test
        run_local_orchestration()
    else:
        cli_args = parse_training_args(sys.argv[1:])
        run_training(cli_args.train, cli_args.model_dir)

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/studio-lab-user/.config/sagemaker/config.yaml
Training Accuracy: 1.00
Model saved to /home/studio-lab-user/model_output/model.joblib
