In [None]:
# Streamlit Dashboard for Churn Prevention System
# Save this file as: dashboards/churn_dashboard.py
# Run with: streamlit run dashboards/churn_dashboard.py

import streamlit as st
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import joblib
from datetime import datetime

# Page configuration: wide layout with the sidebar open. Kept as the first
# st.* call in the script, which Streamlit has historically required for
# set_page_config.
_PAGE_CONFIG = dict(
    page_title="Churn Prevention Dashboard",
    page_icon="🎯",
    layout="wide",
    initial_sidebar_state="expanded",
)
st.set_page_config(**_PAGE_CONFIG)

# Custom CSS injected as raw HTML. `.main-header` styles the title banner;
# `.metric-card` is not referenced by the code visible here.
_CUSTOM_CSS = """
    <style>
    .main-header {
        font-size: 2.5rem;
        font-weight: bold;
        color: #1f77b4;
        text-align: center;
        padding: 1rem;
    }
    .metric-card {
        background-color: #f0f2f6;
        padding: 1rem;
        border-radius: 0.5rem;
        box-shadow: 0 2px 4px rgba(0,0,0,0.1);
    }
    </style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)

# =============================================================================
# LOAD DATA AND MODELS
# =============================================================================

def _resolve_project_path(*parts):
    """Return ``<project root>/<parts...>`` as a path string.

    This script is meant to be saved as ``dashboards/churn_dashboard.py`` and
    launched from the project root (``streamlit run dashboards/churn_dashboard.py``).
    A bare ``'../data/...'`` is resolved against the *current working
    directory*, which in that setup points outside the project, so loading
    failed with FileNotFoundError. Anchoring on this file's own location
    makes the paths independent of where the command is run.

    When ``__file__`` is not defined (e.g. the code runs as a notebook cell),
    fall back to the original CWD-relative ``..`` prefix.
    """
    from pathlib import Path

    script_path = globals().get('__file__')
    root = Path(script_path).resolve().parent.parent if script_path else Path('..')
    return str(root.joinpath(*parts))


@st.cache_data
def load_data():
    """Load the customer dataset from ``data/customer_churn_data.csv``.

    Cached by Streamlit so reruns do not re-read the CSV.

    Raises:
        FileNotFoundError: if the CSV has not been generated yet.
    """
    df = pd.read_csv(_resolve_project_path('data', 'customer_churn_data.csv'))
    return df


@st.cache_resource
def load_models():
    """Load the trained model, its scaler and the metadata dict from ``models/``.

    Returns:
        tuple: ``(model, scaler, model_info)`` as stored by joblib.

    Raises:
        FileNotFoundError: if any of the three artifacts is missing.
    """
    model = joblib.load(_resolve_project_path('models', 'churn_model.pkl'))
    scaler = joblib.load(_resolve_project_path('models', 'scaler.pkl'))
    model_info = joblib.load(_resolve_project_path('models', 'model_info.pkl'))
    return model, scaler, model_info

# Load the dataset and model artifacts up front. If any file is missing, the
# earlier notebooks have not produced it yet, so halt the app with a hint
# rather than a traceback.
try:
    df = load_data()
    model, scaler, model_info = load_models()
except FileNotFoundError:
    st.error("⚠️ Please run Notebooks 1 and 2 first to generate data and train models!")
    st.stop()
else:
    # Model input column list stored in model_info.pkl.
    feature_cols = model_info['feature_columns']

# Feature engineering — must stay in sync with the training notebook.
def engineer_features(df):
    """Return a copy of *df* with the model's engineered feature columns added.

    The input frame is not modified. Columns are appended in this order:
    engagement_score, health_score, activity_recency, usage_efficiency,
    support_intensity, high_risk_flag, tier_encoded, size_encoded,
    industry_encoded. Category values missing from the encoding maps
    become NaN.
    """
    out = df.copy()

    # Weighted blend of raw (un-normalised) usage signals.
    out['engagement_score'] = (
        out['logins_30d'] * 0.3
        + out['features_used'] * 0.3
        + out['session_duration_avg'] * 0.2
        + out['power_feature_usage'] * 0.2
    )

    # Composite health score; the term weights sum to 100. Usage terms are
    # scaled by the column maximum of *this* frame. Sentiment is mapped via
    # (x + 1) / 2 (presumably from [-1, 1]) and NPS via x / 10 (presumably 0-10).
    logins_term = out['logins_30d'] / out['logins_30d'].max() * 30
    session_term = out['session_duration_avg'] / out['session_duration_avg'].max() * 20
    features_term = out['features_used'] / out['features_used'].max() * 20
    sentiment_term = (out['ticket_sentiment'] + 1) / 2 * 15
    nps_term = out['nps_score'] / 10 * 15
    out['health_score'] = logins_term + session_term + features_term + sentiment_term + nps_term

    # Ratio features; the +1 in each denominator avoids division by zero.
    out['activity_recency'] = 1 / (out['days_since_last_login'] + 1)
    out['usage_efficiency'] = out['features_used'] / (out['logins_30d'] + 1)
    out['support_intensity'] = out['support_tickets_30d'] / ((out['tenure_days'] / 30) + 1)

    # 1 if any rule-based warning sign fires, else 0.
    stale = out['days_since_last_login'] > 14
    infrequent = out['logins_30d'] < 5
    ticket_heavy = out['support_tickets_30d'] > 3
    billing_trouble = out['payment_failures'] > 0
    out['high_risk_flag'] = (stale | infrequent | ticket_heavy | billing_trouble).astype(int)

    # Ordinal encodings for the categorical columns: target -> (source, mapping).
    encodings = {
        'tier_encoded': ('subscription_tier', {'free': 0, 'basic': 1, 'premium': 2}),
        'size_encoded': ('company_size', {'1-10': 0, '11-50': 1, '51-200': 2, '200+': 3}),
        'industry_encoded': ('industry', {'tech': 0, 'finance': 1, 'healthcare': 2, 'retail': 3, 'other': 4}),
    }
    for target, (source, mapping) in encodings.items():
        out[target] = out[source].map(mapping)

    return out

# Build the engineered features once and score every customer up front; the
# sections below only slice this scored frame.
df_processed = engineer_features(df)

# NOTE(review): `scaler` is loaded alongside the model but never applied
# before predict_proba. If the saved model was trained on scaled inputs, X
# needs scaler.transform(...) first — confirm against the training notebook.
X = df_processed[feature_cols]
# Column 1 of predict_proba is the second class in model.classes_ —
# presumably churned == 1; verify against the trained model.
churn_probabilities = model.predict_proba(X)[:, 1]
df_processed = df_processed.assign(
    churn_risk_score=churn_probabilities * 100,                # 0-100 scale for display
    churn_prediction=(churn_probabilities > 0.5).astype(int),  # fixed 0.5 cut-off
)

# =============================================================================
# HEADER
# =============================================================================

# Title banner; relies on the `.main-header` CSS class, hence raw HTML.
st.markdown(
    '<div class="main-header">🎯 Customer Churn Prevention Dashboard</div>',
    unsafe_allow_html=True,
)
for _header_line in ("### AI-Powered Early Warning System for Customer Retention", "---"):
    st.markdown(_header_line)

# =============================================================================
# SIDEBAR - FILTERS
# =============================================================================

st.sidebar.header("🔧 Filters & Settings")

# Risk threshold (percent). The comparisons below use
# `churn_risk_score >= risk_threshold`, so the help text says "at or above".
risk_threshold = st.sidebar.slider(
    "Churn Risk Threshold (%)",
    min_value=0,
    max_value=100,
    value=70,
    help="Customers at or above this threshold are flagged as high-risk"
)

# Tier filter (every tier selected by default)
tier_filter = st.sidebar.multiselect(
    "Subscription Tier",
    options=df['subscription_tier'].unique(),
    default=df['subscription_tier'].unique()
)

# Company size filter (every size selected by default)
size_filter = st.sidebar.multiselect(
    "Company Size",
    options=df['company_size'].unique(),
    default=df['company_size'].unique()
)

# Apply filters
filtered_df = df_processed[
    (df_processed['subscription_tier'].isin(tier_filter)) &
    (df_processed['company_size'].isin(size_filter))
]

st.sidebar.markdown("---")
st.sidebar.info(f"📊 Showing {len(filtered_df):,} of {len(df):,} customers")
# Clearing a multiselect leaves nothing to show; say so explicitly instead of
# leaving empty charts and blank metrics unexplained.
if filtered_df.empty:
    st.sidebar.warning("⚠️ No customers match the current filters. Select at least one tier and company size.")

# =============================================================================
# KEY METRICS
# =============================================================================

st.markdown("## 📊 Key Metrics")

col1, col2, col3, col4, col5 = st.columns(5)

# Rows at or above the sidebar threshold; computed once and reused below.
at_risk_mask = filtered_df['churn_risk_score'] >= risk_threshold

total_customers = len(filtered_df)
at_risk_customers = at_risk_mask.sum()
at_risk_pct = (at_risk_customers / total_customers * 100) if total_customers > 0 else 0
avg_health = filtered_df['health_score'].mean()
total_mrr = filtered_df['mrr'].sum()
mrr_at_risk = filtered_df.loc[at_risk_mask, 'mrr'].sum()

with col1:
    st.metric("Total Customers", f"{total_customers:,}")

with col2:
    st.metric("At-Risk Customers", f"{at_risk_customers:,}", 
              f"{at_risk_pct:.1f}%",
              delta_color="inverse")

with col3:
    # mean() of an empty (fully filtered-out) selection is NaN; show "N/A"
    # rather than the literal "nan/100".
    st.metric("Avg Health Score",
              f"{avg_health:.1f}/100" if pd.notna(avg_health) else "N/A")

with col4:
    st.metric("Total MRR", f"${total_mrr:,.0f}")

with col5:
    st.metric("MRR at Risk", f"${mrr_at_risk:,.0f}",
              f"{mrr_at_risk/total_mrr*100:.1f}%" if total_mrr > 0 else "0%",
              delta_color="inverse")

st.markdown("---")

# =============================================================================
# HIGH-RISK CUSTOMERS TABLE
# =============================================================================

st.markdown("## 🚨 High-Risk Customers Requiring Immediate Attention")

# Customers at or above the sidebar threshold, riskiest first. The export
# section further down reuses this frame.
high_risk_df = filtered_df[filtered_df['churn_risk_score'] >= risk_threshold].copy()
high_risk_df = high_risk_df.sort_values('churn_risk_score', ascending=False)

if len(high_risk_df) > 0:
    # Show only the 20 riskiest on screen; the CSV download contains all rows.
    display_df = high_risk_df[[
        'customer_id', 'subscription_tier', 'mrr', 'churn_risk_score',
        'health_score', 'days_since_last_login', 'support_tickets_30d',
        'logins_30d', 'payment_failures'
    ]].head(20).copy()
    
    # Round scores for display
    display_df['churn_risk_score'] = display_df['churn_risk_score'].round(1)
    display_df['health_score'] = display_df['health_score'].round(1)
    
    # Rename for display
    display_df.columns = [
        'Customer ID', 'Tier', 'MRR', 'Risk Score', 'Health Score',
        'Days Inactive', 'Support Tickets', 'Logins', 'Payment Fails'
    ]
    
    # MRR stays numeric and is formatted only at render time. Converting it to
    # "$123" strings made clicking the column header sort lexicographically
    # ("$1000" before "$200").
    st.dataframe(
        display_df,
        use_container_width=True,
        hide_index=True,
        column_config={'MRR': st.column_config.NumberColumn(format="$%.0f")},
    )
    
    # Download button (full high-risk list, all columns)
    csv = high_risk_df.to_csv(index=False)
    st.download_button(
        label="📥 Download High-Risk Customer List (CSV)",
        data=csv,
        file_name=f"high_risk_customers_{datetime.now().strftime('%Y%m%d')}.csv",
        mime="text/csv"
    )
else:
    st.success("✅ No high-risk customers at current threshold!")

st.markdown("---")

# =============================================================================
# RISK DISTRIBUTION CHARTS
# =============================================================================

st.markdown("## 📈 Churn Risk Analysis")

col1, col2 = st.columns(2)

with col1:
    # Histogram of risk scores for the filtered customers, with the sidebar
    # threshold drawn as a dashed red reference line.
    fig = px.histogram(
        filtered_df,
        x='churn_risk_score',
        nbins=50,
        title="Distribution of Churn Risk Scores",
        labels={'churn_risk_score': 'Churn Risk Score (%)'},
        color_discrete_sequence=['#1f77b4'],
    )
    fig.add_vline(
        x=risk_threshold,
        line_dash="dash",
        line_color="red",
        annotation_text="Risk Threshold",
        annotation_position="top right",
    )
    fig.update_layout(showlegend=False)
    st.plotly_chart(fig, use_container_width=True)

with col2:
    # Mean risk per subscription tier (groupby sorts the tier keys); each bar
    # is labelled with the number of customers in that tier.
    risk_by_tier = (
        filtered_df
        .groupby('subscription_tier')
        .agg(**{
            'Avg Risk Score': ('churn_risk_score', 'mean'),
            'Customer Count': ('customer_id', 'count'),
        })
        .reset_index()
        .rename(columns={'subscription_tier': 'Tier'})
    )

    fig = px.bar(
        risk_by_tier,
        x='Tier',
        y='Avg Risk Score',
        title="Average Risk Score by Subscription Tier",
        text='Customer Count',
        color='Avg Risk Score',
        color_continuous_scale='RdYlGn_r',
    )
    fig.update_traces(textposition='outside')
    st.plotly_chart(fig, use_container_width=True)

st.markdown("---")

# =============================================================================
# CUSTOMER SEGMENTATION
# =============================================================================

st.markdown("## 🎯 Customer Segmentation: Health vs Risk")

# Bubble chart of every filtered customer: health (x) against churn risk (y).
fig = px.scatter(
    filtered_df,
    x='health_score',
    y='churn_risk_score',
    size='mrr',
    color='subscription_tier',
    hover_data=['customer_id', 'logins_30d', 'support_tickets_30d', 'days_since_last_login'],
    title="Customer Health Score vs Churn Risk (bubble size = MRR)",
    labels={'health_score': 'Health Score (0-100)',
           'churn_risk_score': 'Churn Risk Score (%)'},
    color_discrete_sequence=px.colors.qualitative.Set2
)

# Quadrant boundaries: the sidebar risk threshold (horizontal) and the
# midpoint of the health scale (vertical).
health_split = 50
fig.add_hline(y=risk_threshold, line_dash="dash", line_color="red", opacity=0.5)
fig.add_vline(x=health_split, line_dash="dash", line_color="blue", opacity=0.5)

# Centre each label inside its quadrant. The labels were previously pinned at
# y=90 / y=20, so moving the threshold slider above 90 (or below 20) placed
# them in the wrong quadrant.
high_risk_y = (risk_threshold + 100) / 2
low_risk_y = risk_threshold / 2
left_x = health_split / 2
right_x = (health_split + 100) / 2

quadrant_labels = [
    (right_x, high_risk_y, "High Risk,<br>High Engagement<br>(Unexpected)", "rgba(255,200,200,0.8)"),
    (left_x, high_risk_y, "High Risk,<br>Low Engagement<br>(Critical)", "rgba(255,150,150,0.8)"),
    (right_x, low_risk_y, "Low Risk,<br>High Engagement<br>(Healthy)", "rgba(200,255,200,0.8)"),
    (left_x, low_risk_y, "Low Risk,<br>Low Engagement<br>(Dormant)", "rgba(255,255,200,0.8)"),
]
for label_x, label_y, label_text, label_bg in quadrant_labels:
    fig.add_annotation(x=label_x, y=label_y, text=label_text,
                       showarrow=False, bgcolor=label_bg, borderpad=4)

st.plotly_chart(fig, use_container_width=True)

st.markdown("---")

# =============================================================================
# FEATURE IMPORTANCE
# =============================================================================

st.markdown("## 🔍 Key Churn Indicators")

col1, col2 = st.columns([2, 1])

with col1:
    # Tree-based estimators expose `feature_importances_`; linear estimators
    # expose `coef_`. Absolute coefficients are used as a fallback ranking;
    # they are only comparable across features if the inputs share a scale.
    if hasattr(model, 'feature_importances_'):
        importances = np.ravel(model.feature_importances_)
    elif hasattr(model, 'coef_'):
        importances = np.abs(np.ravel(model.coef_))
    else:
        importances = None

    # The length guard rejects e.g. multi-class coefficient matrices, which
    # would otherwise fail when building the DataFrame.
    if importances is not None and len(importances) == len(feature_cols):
        feature_importance = pd.DataFrame({
            'Feature': feature_cols,
            'Importance': importances
        }).sort_values('Importance', ascending=False).head(10)
        
        fig = px.bar(
            feature_importance,
            y='Feature',
            x='Importance',
            orientation='h',
            title="Top 10 Features Predicting Churn",
            labels={'Importance': 'Importance Score', 'Feature': ''},
            color='Importance',
            color_continuous_scale='Viridis'
        )
        fig.update_layout(showlegend=False, yaxis={'categoryorder':'total ascending'})
        st.plotly_chart(fig, use_container_width=True)
    else:
        st.info("Feature importance is not available for this model type.")

with col2:
    st.markdown("### Key Insights")
    st.markdown("""
    **High Risk Indicators:**
    - 🔴 Days since last login > 14
    - 🔴 Logins < 5 per month
    - 🔴 Support tickets > 3
    - 🔴 Payment failures > 0
    - 🔴 Low engagement score
    - 🔴 Poor sentiment scores
    
    **Protective Factors:**
    - 🟢 Regular feature usage
    - 🟢 High NPS score
    - 🟢 Premium tier
    - 🟢 Long tenure
    """)

st.markdown("---")

# =============================================================================
# INDIVIDUAL CUSTOMER ANALYSIS
# =============================================================================

st.markdown("## 🔎 Individual Customer Deep Dive")

# Customer selector, riskiest first.
customer_list = filtered_df.sort_values('churn_risk_score', ascending=False)['customer_id'].tolist()
# id -> risk lookup built once (first row per id, matching how customer_data
# is picked below). Previously format_func re-scanned filtered_df for every
# option, which made rendering the selectbox quadratic in the customer count.
risk_by_customer = (
    filtered_df.drop_duplicates('customer_id')
    .set_index('customer_id')['churn_risk_score']
    .to_dict()
)
selected_customer = st.selectbox(
    "Select Customer for Detailed Analysis",
    options=customer_list,
    format_func=lambda x: f"{x} (Risk: {risk_by_customer[x]:.1f}%)"
)

# `is not None` rather than truthiness: an id of 0 or "" is a valid choice.
# With the default index the selectbox yields None only when there are no options.
if selected_customer is not None:
    customer_data = filtered_df[filtered_df['customer_id'] == selected_customer].iloc[0]
    
    # Customer overview
    col1, col2, col3, col4 = st.columns(4)
    
    with col1:
        risk_color = "🔴" if customer_data['churn_risk_score'] >= risk_threshold else "🟢"
        st.metric(f"{risk_color} Churn Risk", f"{customer_data['churn_risk_score']:.1f}%")
        st.metric("Health Score", f"{customer_data['health_score']:.1f}/100")
    
    with col2:
        st.metric("Subscription", customer_data['subscription_tier'].title())
        st.metric("MRR", f"${customer_data['mrr']:.0f}")
    
    with col3:
        st.metric("Tenure", f"{customer_data['tenure_days']} days")
        st.metric("Last Login", f"{customer_data['days_since_last_login']} days ago")
    
    with col4:
        st.metric("Logins (30d)", f"{customer_data['logins_30d']}")
        st.metric("Support Tickets", f"{customer_data['support_tickets_30d']}")
    
    # Customer profile radar chart
    st.markdown("### Customer Profile Analysis")
    
    # Every axis is mapped onto the radar's fixed 0-1 range. usage_vs_plan is
    # a raw column with no bound applied elsewhere, so it is clamped here (as
    # Engagement already was) to keep the polygon on the plotted scale.
    profile_metrics = {
        'Engagement': min(customer_data['engagement_score'] / 50, 1),
        'Usage': min(max(customer_data['usage_vs_plan'], 0), 1),
        'Sentiment': (customer_data['ticket_sentiment'] + 1) / 2,
        'Activity': max(0, 1 - customer_data['days_since_last_login'] / 30),
        'NPS': customer_data['nps_score'] / 10
    }
    
    fig = go.Figure(data=go.Scatterpolar(
        r=list(profile_metrics.values()),
        theta=list(profile_metrics.keys()),
        fill='toself',
        line_color='rgb(31, 119, 180)'
    ))
    
    fig.update_layout(
        polar=dict(
            radialaxis=dict(visible=True, range=[0, 1])
        ),
        showlegend=False,
        title="Customer Health Radar (0-1 scale)"
    )
    
    st.plotly_chart(fig, use_container_width=True)
    
    # Intervention recommendations: (priority, action, description) tuples,
    # appended in display order by the rules below.
    st.markdown("### 💡 Recommended Interventions")
    
    interventions = []
    
    # Critical interventions
    if customer_data['payment_failures'] > 0:
        interventions.append(("🔴 CRITICAL", "Billing Issue Resolution", 
                            "Immediate outreach to resolve payment failure within 24 hours"))
    
    if customer_data['days_since_last_login'] > 30:
        interventions.append(("🔴 CRITICAL", "Dormancy Re-engagement", 
                            "Multi-touch campaign to reactivate account immediately"))
    
    # High priority interventions
    if customer_data['support_tickets_30d'] > 3:
        interventions.append(("🟠 HIGH", "Customer Success Check-in", 
                            "Schedule call within 3 days to address pain points"))
    
    if customer_data['logins_30d'] < 5 and customer_data['tenure_days'] < 90:
        interventions.append(("🟠 HIGH", "Onboarding Enhancement", 
                            "Provide personalized onboarding session within 1 week"))
    
    # Fixed 75% cut-off, independent of the sidebar threshold.
    if customer_data['churn_risk_score'] > 75 and customer_data['mrr'] > 50:
        interventions.append(("🟠 HIGH", "Retention Offer", 
                            "Consider special pricing or feature upgrade within 2 days"))
    
    # Medium priority interventions
    if customer_data['usage_vs_plan'] < 0.4:
        interventions.append(("🟡 MEDIUM", "Feature Adoption Program", 
                            "14-day guided tour of underutilized features"))
    
    if customer_data['features_used'] < 5:
        interventions.append(("🟡 MEDIUM", "Product Education", 
                            "Share tutorials and best practices via email"))
    
    # Tally per priority, derived from the list instead of hand-kept counters.
    priority_count = {
        level: sum(1 for level_of, _, _ in interventions if level_of == level)
        for level in ('🔴 CRITICAL', '🟠 HIGH', '🟡 MEDIUM')
    }
    
    if interventions:
        for priority, action, description in interventions:
            st.markdown(f"**{priority} {action}**")
            st.markdown(f"→ {description}")
            st.markdown("")
    else:
        st.success("✅ Customer is healthy - no immediate interventions needed. Continue monitoring.")
    
    # Summary. The medium-priority note is shown only when there is at least
    # one suggestion, so healthy customers don't get "0 suggestions".
    if priority_count['🔴 CRITICAL'] > 0 or priority_count['🟠 HIGH'] > 0:
        st.error(f"⚠️ Action Required: {priority_count['🔴 CRITICAL']} critical + {priority_count['🟠 HIGH']} high priority interventions")
    elif priority_count['🟡 MEDIUM'] > 0:
        st.info(f"ℹ️ {priority_count['🟡 MEDIUM']} medium priority suggestions for optimization")

st.markdown("---")

# =============================================================================
# MODEL PERFORMANCE
# =============================================================================
# Summary of the loaded model. Unlike the sections above, the counts here use
# the whole scored dataset (df_processed), not the sidebar-filtered subset,
# and use the fixed 0.5 `churn_prediction` cut-off, not the risk slider.

st.markdown("## 📊 Model Performance Metrics")

col1, col2, col3 = st.columns(3)

# Metadata read from model_info.pkl.
with col1:
    st.metric("Model Type", model_info['model_type'])
    st.metric("AUC Score", f"{model_info['auc_score']:.4f}")

# NOTE(review): `churned` comes from the same CSV the model was presumably
# trained on, so these figures are likely in-sample and optimistic — confirm.
with col2:
    total_predicted_churn = df_processed['churn_prediction'].sum()
    actual_churn = df_processed['churned'].sum()
    st.metric("Customers Flagged", f"{total_predicted_churn:,}")
    st.metric("Actual Churners", f"{actual_churn:,}")

with col3:
    # Detection rate = flagged-and-churned / churned (i.e. recall); skipped
    # when there are no churners to avoid dividing by zero.
    if actual_churn > 0:
        detection_rate = (df_processed['churn_prediction'] & df_processed['churned']).sum() / actual_churn
        st.metric("Detection Rate", f"{detection_rate*100:.1f}%")
    st.metric("Total Features", len(feature_cols))

# Model update info.
# NOTE(review): the training date below is hard-coded text, not read from
# model_info — it will go stale after retraining. The trailing double spaces
# are markdown line breaks.
st.info("""
📅 **Model Training Date:** September 2024  
🔄 **Recommended Retraining:** Every 30 days  
📈 **Next Update:** Monitor for data drift and performance degradation
""")

st.markdown("---")

# =============================================================================
# EXPORT & ACTIONS
# =============================================================================

st.markdown("## 📤 Export & Actions")

col1, col2, col3 = st.columns(3)

with col1:
    # Scores for every customer matching the sidebar filters.
    export_df = filtered_df[['customer_id', 'subscription_tier', 'mrr', 
                             'churn_risk_score', 'health_score', 'churned']]
    csv_all = export_df.to_csv(index=False)
    st.download_button(
        label="📥 Export All Customer Scores",
        data=csv_all,
        file_name=f"all_customer_scores_{datetime.now().strftime('%Y%m%d')}.csv",
        mime="text/csv"
    )

with col2:
    # High-risk customers only (all columns, riskiest first)
    if len(high_risk_df) > 0:
        csv_high_risk = high_risk_df.to_csv(index=False)
        st.download_button(
            label="⚠️ Export High-Risk Customers",
            data=csv_high_risk,
            file_name=f"high_risk_only_{datetime.now().strftime('%Y%m%d')}.csv",
            mime="text/csv"
        )

with col3:
    # Intervention queue: the 50 riskiest high-risk customers with a priority band.
    if len(high_risk_df) > 0:
        # .copy() so adding 'priority' writes to an independent frame instead
        # of a slice of high_risk_df (avoids SettingWithCopyWarning).
        intervention_df = high_risk_df[['customer_id', 'subscription_tier', 'mrr', 
                                       'churn_risk_score', 'days_since_last_login',
                                       'support_tickets_30d']].head(50).copy()
        # Fixed bands (>85 critical, >70 high), independent of the sidebar threshold.
        intervention_df['priority'] = intervention_df['churn_risk_score'].apply(
            lambda x: 'CRITICAL' if x > 85 else 'HIGH' if x > 70 else 'MEDIUM'
        )
        csv_interventions = intervention_df.to_csv(index=False)
        st.download_button(
            label="🎯 Export Intervention Queue",
            data=csv_interventions,
            file_name=f"intervention_queue_{datetime.now().strftime('%Y%m%d')}.csv",
            mime="text/csv"
        )

st.markdown("---")

# =============================================================================
# FOOTER
# =============================================================================

# Static "about" blurb rendered as markdown (opens with its own rule).
_ABOUT_MARKDOWN = """
---
### 📖 About This Dashboard

This AI-powered churn prevention system uses machine learning to predict customer churn risk 
45-60 days in advance, enabling proactive intervention strategies.

**Key Features:**
- Real-time risk scoring for all customers
- Automated intervention recommendations
- Customer segmentation and health monitoring
- Exportable action lists for customer success teams

**Built with:** Python, Scikit-learn, XGBoost, Streamlit, Plotly

💡 **Tip:** Adjust the risk threshold slider to see how it affects the number of flagged customers.
"""
st.markdown(_ABOUT_MARKDOWN)

st.markdown("---")
# Render timestamp, via datetime's strftime-based format spec.
st.markdown(f"*Dashboard last updated: {datetime.now():%Y-%m-%d %H:%M:%S}*")