Deployment of the Classification and Regression Models Using Gradio

In [7]:
import gradio as gr
import joblib
import pandas as pd

# --- Reference dataset: read here only so sliders can take their min/max from it ---
df = pd.read_csv('cleaned_data.csv')

# --- Classifier plus the preprocessing objects that were saved during training ---
model = joblib.load("results/model_RF_class.joblib")
scaler = joblib.load("results/scaler.joblib")
encoders = joblib.load("results/input_encoders.joblib")
feature_cols_class = joblib.load("results/feature_cols.joblib")

# Column order expected by the spend (regression) model
feature_cols_reg = [
    'avg_session_length',
    'sessions_per_month',
    'engagement_index',
    'total_session_time',
]

# Every column that has a fitted encoder is treated as categorical
categorical_input_cols = list(encoders)

# --- Human-readable labels for features ---
# Maps a feature/column name to the caption displayed on its input widget.
FEATURE_LABELS = {
    "age": "Age",
    "tenure_months": "Tenure (Months)",
    "avg_session_length": "Avg Session Length (min)",
    "sessions_per_month": "Sessions per Month",
    "support_tickets": "Support Tickets (Last 30 Days)",
    "num_devices": "Number of Registered Devices",
    "email_click_rate": "Email Click Rate",
    "referral_count": "Referral Count (Lifetime)",
    "discount_rate": "Discount Rate",
    "satisfaction_score": "Satisfaction Score (1-5)",
    "monthly_spend": "Monthly Spend (RM)",
    "complaint_rate": "Complaint Rate",
    "discount_sensitivity": "Discount Sensitivity",
    "engagement_index": "Engagement Index",
    "country": "Country",
    "city": "City",
    "gender": "Gender",
    "membership_tier": "Membership Tier",
    "last_payment_method": "Last Payment Method",
    "is_mobile_user": "Is Mobile User?",
    "signup_month": "Signup Month",
    "signup_dayofweek": "Signup Day of Week",
    "total_session_time": "Total Session Time (min)",
    "clicks_per_session": "Avg Clicks per Session",
    "device_per_session": "Avg Devices per Session",
}

# --- Discrete (integer) features ---
# Features whose sliders use integer bounds and a step of 1.
DISCRETE_FEATURES = {
    "age",
    "tenure_months",
    "sessions_per_month",
    "total_session_time",
    "support_tickets",
    "num_devices",
    "referral_count",
    "satisfaction_score",
    "signup_month",
    "signup_dayofweek",
    "avg_session_length",
    "engagement_index",
}

# --- Helper for min/max (with custom bounds for cyclical features) ---
def get_min_max(col, fallback_min=0.0, fallback_max=100.0):
    """Return the ``(min, max)`` pair used to bound the input for ``col``.

    Args:
        col: Feature/column name.
        fallback_min: Lower bound returned when ``col`` is not a column of ``df``.
        fallback_max: Upper bound returned when ``col`` is not a column of ``df``.

    Returns:
        A 2-tuple. ``signup_month`` and ``signup_dayofweek`` always get fixed
        bounds; any other column present in ``df`` gets its observed minimum
        and maximum as floats; otherwise a warning is printed and the
        fallback pair is returned unchanged.
    """
    # Fixed ranges take precedence over whatever the data happens to contain.
    fixed_bounds = {
        'signup_month': (1, 12),
        'signup_dayofweek': (0, 6),
    }
    if col in fixed_bounds:
        return fixed_bounds[col]

    if col not in df.columns:
        print(f"[WARNING] Column '{col}' not found. Using fallback.")
        return fallback_min, fallback_max

    observed = df[col]
    return float(observed.min()), float(observed.max())

# --- Churn Prediction Function ---
def predict_churn(*inputs):
    """Predict churn for one customer with the saved classifier.

    Args:
        *inputs: One raw value per entry of ``feature_cols_class``, in the
            same order (the UI passes its widgets' values this way).

    Returns:
        tuple: ``(label, probability)`` where ``probability`` is the value in
        column 1 of ``model.predict_proba`` rounded to two decimals and
        ``label`` is ``"Yes"`` when the unrounded probability is at least
        0.5, else ``"No"``. On any failure the error is printed and
        ``("Error", 0.0)`` is returned instead of raising.
    """
    try:
        # zip() would silently drop surplus values and pair the rest with the
        # wrong features, so a length mismatch is rejected up front.
        if len(inputs) != len(feature_cols_class):
            raise ValueError(
                f"expected {len(feature_cols_class)} inputs, got {len(inputs)}"
            )

        data = {col: [val] for col, val in zip(feature_cols_class, inputs)}
        X_new = pd.DataFrame(data)

        for col in categorical_input_cols:
            if col not in X_new.columns:
                continue
            encoder = encoders[col]
            known = set(encoder.classes_)
            fallback = encoder.classes_[0]
            # Cast to str BEFORE the membership test: the encoder is fed
            # strings, so the raw value must be compared in the same form.
            # Testing the raw value first would replace e.g. a non-string
            # input with the fallback even though its text form is known.
            as_text = X_new[col].astype(str)
            as_text = as_text.map(lambda v: v if v in known else fallback)
            X_new[col] = encoder.transform(as_text)

        X_input = X_new[feature_cols_class].astype(float)
        X_scaled = scaler.transform(X_input)
        # NOTE(review): column 1 is assumed to be the "churn" class — confirm
        # against model.classes_.
        proba = float(model.predict_proba(X_scaled)[0, 1])
        label = "Yes" if proba >= 0.5 else "No"
        return label, round(proba, 2)

    except Exception as e:
        # UI boundary: report the problem and return a sentinel the wrapper
        # can detect, rather than letting the exception reach the user.
        print(f"[PREDICTION ERROR] {e}")
        return "Error", 0.0

# --- Spend Prediction Function ---
_SPEND_MODEL_PATH = "results/model_RF_reg.joblib"
_spend_model = None


def _get_spend_model():
    """Return the spend regression model, reading it from disk only once."""
    global _spend_model
    if _spend_model is None:
        _spend_model = joblib.load(_SPEND_MODEL_PATH)
    return _spend_model


def predict_spend(avg_session_length, sessions_per_month, engagement_index, total_session_time):
    """Predict average monthly spend from four engagement metrics.

    Args:
        avg_session_length: Value convertible to ``float``.
        sessions_per_month: Value convertible to ``float``.
        engagement_index: Value convertible to ``float``.
        total_session_time: Value convertible to ``float``.

    Returns:
        str: ``"Predicted Average Monthly Spend: RM<value>"`` with the
        prediction formatted to two decimals, or
        ``"Error in spend prediction"`` if anything fails (the error is
        printed, not raised).

    Note:
        The model file is loaded on the first call and reused afterwards
        instead of being deserialized on every prediction; a file replaced on
        disk is therefore only picked up after this code is re-executed.
    """
    try:
        x = pd.DataFrame([{
            'avg_session_length': float(avg_session_length),
            'sessions_per_month': float(sessions_per_month),
            'engagement_index': float(engagement_index),
            'total_session_time': float(total_session_time),
        }])
        # Put the columns in the order listed in feature_cols_reg; any listed
        # feature that is not supplied above is filled with 0.0.
        x = x.reindex(columns=feature_cols_reg, fill_value=0.0)
        # Loading stays inside the try so a missing/corrupt model file still
        # yields the error string rather than an exception.
        model_reg = _get_spend_model()
        pred = model_reg.predict(x)[0]
        return f"Predicted Average Monthly Spend: RM{pred:.2f}"
    except Exception as e:
        print(f"[SPEND PREDICTION ERROR] {e}")
        return "Error in spend prediction"

# --- Wrapper with Notifications ---
def predict_churn_with_format(*inputs):
    """Run ``predict_churn`` and emit a Gradio notification about the outcome.

    Args:
        *inputs: Forwarded unchanged to ``predict_churn``.

    Returns:
        tuple: ``(label, probability_text)``. The probability is formatted
        with two decimals; when the label is ``"Error"`` the text is
        ``"0.00"`` and a warning is shown instead of the success message.
    """
    verdict, probability = predict_churn(*inputs)
    succeeded = verdict != "Error"
    if succeeded:
        gr.Info("✅ Churn prediction completed!")
        return verdict, f"{probability:.2f}"

    gr.Warning("❌ Churn prediction failed! Please check your inputs.")
    return verdict, "0.00"

def predict_spend_with_status(avg_session_length, sessions_per_month, engagement_index, total_session_time):
    """Run ``predict_spend`` and emit a Gradio notification about the outcome.

    The four arguments are forwarded positionally to ``predict_spend``. A
    result containing the text ``"Error"`` triggers a warning; anything else
    triggers the success message.

    Returns:
        str: The string produced by ``predict_spend``, unchanged.
    """
    message = predict_spend(avg_session_length, sessions_per_month, engagement_index, total_session_time)
    if "Error" in message:
        notify, text = gr.Warning, "❌ Spend prediction failed! Please check your inputs."
    else:
        notify, text = gr.Info, "✅ Spend prediction completed!"
    notify(text)
    return message

# --- Build Gradio UI ---
def _churn_input_for(col):
    """Create the churn-tab input widget for one classifier feature.

    Columns with an encoder get a dropdown of the encoder's classes,
    ``is_mobile_user`` gets a checkbox, and every other column gets a slider
    bounded by ``get_min_max``.
    """
    caption = FEATURE_LABELS.get(col, col.replace('_', ' ').title())
    if col in categorical_input_cols:
        options = sorted(encoders[col].classes_.tolist())
        return gr.Dropdown(choices=options, label=caption)
    if col == 'is_mobile_user':
        return gr.Checkbox(label=caption)

    lower, upper = get_min_max(col, 0, 100)
    if col in DISCRETE_FEATURES:
        # Discrete features: integer bounds, unit step.
        return gr.Slider(int(lower), int(upper), step=1, label=caption)
    # Other features: 0.01 steps when the maximum exceeds 10, else 0.001.
    granularity = 0.01 if upper > 10 else 0.001
    return gr.Slider(lower, upper, step=granularity, label=caption)


def _integer_slider(col, fallback_min, fallback_max, fallback_label):
    """Create a unit-step slider for ``col`` with integer-truncated bounds."""
    lower, upper = get_min_max(col, fallback_min, fallback_max)
    return gr.Slider(
        int(lower), int(upper),
        step=1,
        label=FEATURE_LABELS.get(col, fallback_label)
    )


with gr.Blocks(title="Customer Churn & Spend Predictor") as demo:
    gr.Markdown('<div style="text-align: center; font-size: 2rem; font-weight: bold;">Customer Churn & Spend Predictor</div>')

    with gr.Tabs():
        # === Tab 1: churn classifier ===
        with gr.Tab("📉 Churn Prediction"):
            gr.Markdown("Enter customer details to predict if they will churn.")

            with gr.Row():
                with gr.Column():
                    # One widget per classifier feature, in feature order, so
                    # the values line up with what predict_churn expects.
                    input_components = [_churn_input_for(col) for col in feature_cols_class]

                    btn_cls = gr.Button("Predict Churn", variant="primary")

                with gr.Column():
                    out_label = gr.Textbox(label="Churn Prediction", interactive=False)
                    out_proba = gr.Textbox(label="Churn Probability", interactive=False)

            btn_cls.click(
                fn=predict_churn_with_format,
                inputs=input_components,
                outputs=[out_label, out_proba]
            )

        # === Tab 2: spend regressor ===
        with gr.Tab("💰 Spend Prediction"):
            gr.Markdown("Predict average monthly spend based on engagement metrics.")
            with gr.Row():
                with gr.Column():
                    # Sliders are created in display order; the click handler
                    # below lists them in the order the function expects.
                    avg_session_length = _integer_slider(
                        'avg_session_length', 5, 60, 'Avg Session Length (min)'
                    )
                    total_session_time = _integer_slider(
                        'total_session_time', 0, 5000, 'Total Session Time (min)'
                    )
                    sessions_per_month = _integer_slider(
                        'sessions_per_month', 1, 50, 'Sessions per Month'
                    )
                    engagement_index = _integer_slider(
                        'engagement_index', 0, 100, 'Engagement Index'
                    )

                    btn_reg = gr.Button("Predict Spend", variant="primary")
                with gr.Column():
                    out_spend = gr.Textbox(label="Prediction Result", interactive=False)
            btn_reg.click(
                fn=predict_spend_with_status,
                inputs=[avg_session_length, sessions_per_month, engagement_index, total_session_time],
                outputs=out_spend
            )

# Start the Gradio server only when this code runs as the main program
# (not when it is imported as a module).
if __name__ == "__main__":
    demo.launch()

* Running on local URL:  http://127.0.0.1:7866
* To create a public link, set `share=True` in `launch()`.


