In [23]:
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, accuracy_score, roc_auc_score

In [24]:
file_paths = {
    "telecom_churn": "telecom_churn.csv",
    "orange_telecom": "orange_telecom.csv",
    "internet_service_churn": "internet_service_churn.csv",
    "bank_churn": "Bank_churn.csv"
}

In [25]:
telecom_churn = pd.read_csv(file_paths["telecom_churn"])
orange_telecom = pd.read_csv(file_paths["orange_telecom"])
internet_service_churn = pd.read_csv(file_paths["internet_service_churn"])
bank_churn = pd.read_csv(file_paths["bank_churn"])

In [26]:
telecom_churn_info = telecom_churn.head(), telecom_churn.columns.tolist()
orange_telecom_info = orange_telecom.head(), orange_telecom.columns.tolist()
internet_service_churn_info = internet_service_churn.head(), internet_service_churn.columns.tolist()
bank_churn_info = bank_churn.head(), bank_churn.columns.tolist()

In [28]:
telecom_churn_info

(   churn  accountweeks  contractrenewal  dataplan  datausage  custservcalls  \
 0      0           128                1         1        2.7              1   
 1      0           107                1         1        3.7              1   
 2      0           137                1         0        0.0              0   
 3      0            84                0         0        0.0              2   
 4      0            75                0         0        0.0              3   
 
    daymins  daycalls  monthlycharge  overagefee  roammins  
 0    265.1       110           89.0        9.87      10.0  
 1    161.6       123           82.0        9.78      13.7  
 2    243.4       114           52.0        6.06      12.2  
 3    299.4        71           57.0        3.10       6.6  
 4    166.7       113           41.0        7.42      10.1  ,
 ['churn',
  'accountweeks',
  'contractrenewal',
  'dataplan',
  'datausage',
  'custservcalls',
  'daymins',
  'daycalls',
  'monthlycharge',
  'over

In [29]:
orange_telecom_info

(  state  account length  area code international plan voice mail plan  \
 0    LA             117        408                 No              No   
 1    IN              65        415                 No              No   
 2    NY             161        415                 No              No   
 3    SC             111        415                 No              No   
 4    HI              49        510                 No              No   
 
    number vmail messages  total day minutes  total day calls  \
 0                      0              184.5               97   
 1                      0              129.1              137   
 2                      0              332.9               67   
 3                      0              110.4              103   
 4                      0              119.3              117   
 
    total day charge  total eve minutes  total eve calls  total eve charge  \
 0             31.37              351.6               80             29.89   
 1    

In [30]:
internet_service_churn_info

(   id  is_tv_subscriber  is_movie_package_subscriber  subscription_age  \
 0  15                 1                            0             11.95   
 1  18                 0                            0              8.22   
 2  23                 1                            0              8.91   
 3  27                 0                            0              6.87   
 4  34                 0                            0              6.39   
 
    bill_avg  reamining_contract  service_failure_count  download_avg  \
 0        25                0.14                      0           8.4   
 1         0                 NaN                      0           0.0   
 2        16                0.00                      0          13.7   
 3        21                 NaN                      1           0.0   
 4         0                 NaN                      0           0.0   
 
    upload_avg  download_over_limit  churn  
 0         2.3                    0      0  
 1         0.0    

In [31]:
bank_churn_info

(   rownumber  customerid   surname  creditscore geography  gender  age  \
 0          1    15634602  Hargrave          619    France  Female   42   
 1          2    15647311      Hill          608     Spain  Female   41   
 2          3    15619304      Onio          502    France  Female   42   
 3          4    15701354      Boni          699    France  Female   39   
 4          5    15737888  Mitchell          850     Spain  Female   43   
 
    tenure    balance  numofproducts  hascrcard  isactivemember  \
 0       2       0.00              1          1               1   
 1       1   83807.86              1          0               1   
 2       8  159660.80              3          1               0   
 3       1       0.00              2          0               0   
 4       2  125510.82              1          1               1   
 
    estimatedsalary  churn  
 0        101348.88      1  
 1        112542.58      0  
 2        113931.57      1  
 3         93826.63      0  

In [32]:
def preprocess_data(df, target_col):

    #Dropping the unnecessary columns
    unnecessary_cols = ['rownumber', 'customerid', 'surname', 'id', 'number', 'state'] # Identifiers or non-numeric
    df = df.drop(columns=[col for col in unnecessary_cols if col in df.columns], errors='ignore')

    # Handle missing values with median value
    df = df.fillna(df.median(numeric_only=True))

    # Encode categorical variables
    categorical_cols = df.select_dtypes(include=['object', 'category']).columns
    for col in categorical_cols:
        df[col] = LabelEncoder().fit_transform(df[col].astype(str))

    # Separate features and target variable
    X = df.drop(columns=[target_col])
    y = df[target_col]

    # Scale features
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X)

    return X_scaled, y

In [33]:
datasets = {
    "telecom_churn": preprocess_data(telecom_churn, 'churn'),
    "orange_telecom": preprocess_data(orange_telecom, 'churn'),
    "internet_service_churn": preprocess_data(internet_service_churn, 'churn'),
    "bank_churn": preprocess_data(bank_churn, 'churn')
}

**Random Forest Model**

In [34]:
rf_model = RandomForestClassifier(random_state=42)

In [35]:
results = {}
for name, (X, y) in datasets.items():
    # Split data into training and testing sets
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

    # Train the model
    rf_model.fit(X_train, y_train)

    # Make predictions
    y_pred = rf_model.predict(X_test)
    y_proba = rf_model.predict_proba(X_test)[:, 1]

    # Evaluate the model
    accuracy = accuracy_score(y_test, y_pred)
    roc_auc = roc_auc_score(y_test, y_proba)
    report = classification_report(y_test, y_pred, zero_division=0)

    results[name] = {
        "accuracy": accuracy,
        "roc_auc": roc_auc,
        "classification_report": report
    }

In [36]:
for dataset_name, metrics in results.items():
    print(f"Results for {dataset_name}:\n")
    print(f"Accuracy: {metrics['accuracy']:.2f}")
    print(f"ROC AUC: {metrics['roc_auc']:.2f}")
    print("Classification Report:\n", metrics['classification_report'])
    print("\n" + "="*50 + "\n")

Results for telecom_churn:

Accuracy: 0.94
ROC AUC: 0.93
Classification Report:
               precision    recall  f1-score   support

           0       0.94      0.99      0.96       857
           1       0.90      0.62      0.74       143

    accuracy                           0.94      1000
   macro avg       0.92      0.81      0.85      1000
weighted avg       0.93      0.94      0.93      1000



Results for orange_telecom:

Accuracy: 0.94
ROC AUC: 0.95
Classification Report:
               precision    recall  f1-score   support

           0       0.93      0.99      0.96       173
           1       0.94      0.57      0.71        28

    accuracy                           0.94       201
   macro avg       0.94      0.78      0.84       201
weighted avg       0.94      0.94      0.93       201



Results for internet_service_churn:

Accuracy: 0.94
ROC AUC: 0.98
Classification Report:
               precision    recall  f1-score   support

           0       0.93      0.95 

In [38]:
!pip install fastapi

Collecting fastapi
  Downloading fastapi-0.115.0-py3-none-any.whl.metadata (27 kB)
Collecting starlette<0.39.0,>=0.37.2 (from fastapi)
  Downloading starlette-0.38.5-py3-none-any.whl.metadata (6.0 kB)
Downloading fastapi-0.115.0-py3-none-any.whl (94 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m94.6/94.6 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading starlette-0.38.5-py3-none-any.whl (71 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m71.4/71.4 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: starlette, fastapi
Successfully installed fastapi-0.115.0 starlette-0.38.5


In [40]:
!pip install lime

Collecting lime
  Downloading lime-0.2.0.1.tar.gz (275 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/275.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━[0m [32m204.8/275.7 kB[0m [31m6.5 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m275.7/275.7 kB[0m [31m4.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: lime
  Building wheel for lime (setup.py) ... [?25l[?25hdone
  Created wheel for lime: filename=lime-0.2.0.1-py3-none-any.whl size=283834 sha256=35c3aa8560dce95d42cd8af3937c3c70461116914a597c6baf052c4cbdd01711
  Stored in directory: /root/.cache/pip/wheels/fd/a2/af/9ac0a1a85a27f314a06b39e1f492bee1547d52549a4606ed89
Successfully built lime
Installing collected packages: lime
Successfully installed lime-0.2.0.1


In [42]:
!pip install shap

Collecting shap
  Downloading shap-0.46.0-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (24 kB)
Collecting slicer==0.0.8 (from shap)
  Downloading slicer-0.0.8-py3-none-any.whl.metadata (4.0 kB)
Downloading shap-0.46.0-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (540 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m540.1/540.1 kB[0m [31m10.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading slicer-0.0.8-py3-none-any.whl (15 kB)
Installing collected packages: slicer, shap
Successfully installed shap-0.46.0 slicer-0.0.8


In [43]:
import pandas as pd
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from xgboost import XGBClassifier
from sklearn.metrics import classification_report, accuracy_score, roc_auc_score
from lime.lime_tabular import LimeTabularExplainer
import shap
from sklearn.tree import DecisionTreeClassifier

In [44]:
app = FastAPI()

In [83]:
def preprocess_data(df, target_col):
    """
    Preprocesses the data by handling missing values, encoding categorical variables,
    and scaling numerical features.
    """
    unnecessary_cols = ['rownumber', 'customerid', 'surname', 'id', 'number', 'state']
    df = df.drop(columns=[col for col in unnecessary_cols if col in df.columns], errors='ignore')

    # Fill missing values
    df = df.fillna(df.median(numeric_only=True))

    # Encode categorical variables
    categorical_cols = df.select_dtypes(include=['object', 'category']).columns
    for col in categorical_cols:
        df[col] = LabelEncoder().fit_transform(df[col].astype(str))

    # Split features and target
    X = df.drop(columns=[target_col])
    y = df[target_col]

    # Scale features
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X)

    return X_scaled, y, X.columns, scaler


In [84]:
file_paths = {
    "telecom_churn": "telecom_churn.csv",
    "orange_telecom": "orange_telecom.csv",
    "internet_service_churn": "internet_service_churn.csv",
    "bank_churn": "Bank_churn.csv"
}

In [85]:
datasets = {
    "telecom_churn": pd.read_csv(file_paths["telecom_churn"]),
    "orange_telecom": pd.read_csv(file_paths["orange_telecom"]),
    "internet_service_churn": pd.read_csv(file_paths["internet_service_churn"]),
    "bank_churn": pd.read_csv(file_paths["bank_churn"])
}

In [86]:
models = {}

In [87]:
for name, df in datasets.items():
    X, y, feature_names, scaler = preprocess_data(df, 'churn')
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

    # Train multiple models
    rf_model = RandomForestClassifier(random_state=42)
    rf_model.fit(X_train, y_train)

    lr_model = LogisticRegression(max_iter=1000, random_state=42)
    lr_model.fit(X_train, y_train)

    xgb_model = XGBClassifier(random_state=42)
    xgb_model.fit(X_train, y_train)

    # Store the models
    models[name] = {
        "rf_model": rf_model,
        "lr_model": lr_model,
        "xgb_model": xgb_model,
        "scaler": scaler,
        "X_test": X_test,
        "y_test": y_test,
        "feature_names": feature_names
    }

In [91]:
from pydantic import BaseModel, Field
from typing import Dict, Union

class ChurnPredictionInput(BaseModel):
    dataset_name: str = Field(..., min_length=1)  # Ensure non-empty string
    model_name: str = Field(..., pattern='^(rf_model|lr_model|xgb_model)$')  # Validate model names using 'pattern' instead of 'regex'
    data: Dict[str, Union[int, float, str]]  # Allows for int, float, or string features


Field "model_name" in ChurnPredictionInput has conflict with protected namespace "model_".




In [92]:
# FastAPI Endpoint for Prediction
@app.post("/predict")
def predict(input_data: ChurnPredictionInput):
    """
    Endpoint to predict churn based on input features.
    """
    if input_data.dataset_name not in models:
        raise HTTPException(status_code=404, detail="Dataset not found.")

    if input_data.model_name not in ["rf_model", "lr_model", "xgb_model"]:
        raise HTTPException(status_code=400, detail="Model not found.")

    model_data = models[input_data.dataset_name]
    model = model_data[input_data.model_name]
    scaler = model_data["scaler"]
    feature_names = model_data["feature_names"]

    # Ensure the input data matches the feature names expected
    input_df = pd.DataFrame([input_data.data], columns=feature_names)

    # Scale the input data using the stored scaler
    input_df_scaled = scaler.transform(input_df)

    # Predict churn probability and class
    prediction = model.predict(input_df_scaled)[0]
    probability = model.predict_proba(input_df_scaled)[0][1]

    return {"prediction": int(prediction), "probability": probability}

In [93]:
# FastAPI Endpoint for Explainability
@app.post("/explain")
def explain(input_data: ChurnPredictionInput):
    """
    Endpoint to provide explanations for predictions using LIME and SHAP.
    """
    if input_data.dataset_name not in models:
        raise HTTPException(status_code=404, detail="Dataset not found.")

    if input_data.model_name not in ["rf_model", "lr_model", "xgb_model"]:
        raise HTTPException(status_code=400, detail="Model not found.")

    model_data = models[input_data.dataset_name]
    model = model_data[input_data.model_name]
    X_test = model_data["X_test"]
    feature_names = model_data["feature_names"]

    # Use LIME for local explanations
    explainer = LimeTabularExplainer(
        X_test,
        feature_names=feature_names,
        class_names=['No Churn', 'Churn'],
        mode='classification'
    )
    explanation = explainer.explain_instance(pd.Series(input_data.data).values, model.predict_proba)

    # Use SHAP for global explanation
    shap.initjs()
    shap_explainer = shap.Explainer(model, X_test)
    shap_values = shap_explainer(X_test)

    # Surrogate Model - Decision Tree for global interpretability
    surrogate_model = DecisionTreeClassifier(max_depth=3)
    surrogate_model.fit(X_test, model.predict(X_test))

    return {
        "lime_explanation": explanation.as_list(),
        "shap_summary_plot": "Generated (SHAP Summary Plot - handle this separately)",
        "surrogate_tree": surrogate_model.tree_.__str__()  # Simplified tree representation
    }

In [68]:
!pip install uvicorn

Collecting uvicorn
  Downloading uvicorn-0.30.6-py3-none-any.whl.metadata (6.6 kB)
Collecting h11>=0.8 (from uvicorn)
  Downloading h11-0.14.0-py3-none-any.whl.metadata (8.2 kB)
Downloading uvicorn-0.30.6-py3-none-any.whl (62 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.8/62.8 kB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading h11-0.14.0-py3-none-any.whl (58 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.3/58.3 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: h11, uvicorn
Successfully installed h11-0.14.0 uvicorn-0.30.6


In [None]:
if __name__ == "__main__":
    import uvicorn
    import asyncio

    # Use nest_asyncio to run the server within the existing event loop
    import nest_asyncio
    nest_asyncio.apply()

    uvicorn.run(app, host="0.0.0.0", port=8000)

INFO:     Started server process [231]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)


In [95]:
!pip install uvicorn fastapi pyngrok


Collecting pyngrok
  Downloading pyngrok-7.2.0-py3-none-any.whl.metadata (7.4 kB)
Downloading pyngrok-7.2.0-py3-none-any.whl (22 kB)
Installing collected packages: pyngrok
Successfully installed pyngrok-7.2.0


In [96]:
code = """
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import pandas as pd
import uvicorn
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report
from lime.lime_tabular import LimeTabularExplainer
import shap
from sklearn.tree import DecisionTreeClassifier
from typing import Dict, Union

app = FastAPI()

# Your model training code here...

# Define the input data model
class ChurnPredictionInput(BaseModel):
    dataset_name: str
    model_name: str
    data: dict

@app.post("/predict")
def predict(input_data: ChurnPredictionInput):
    # Your prediction code here...
    pass

@app.post("/explain")
def explain(input_data: ChurnPredictionInput):
    # Your explanation code here...
    pass

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
"""

with open("app.py", "w") as f:
    f.write(code)


In [97]:
!uvicorn app:app --host 0.0.0.0 --port 8000 --reload


[32mINFO[0m:     Will watch for changes in these directories: ['/content']
[32mINFO[0m:     Uvicorn running on [1mhttp://0.0.0.0:8000[0m (Press CTRL+C to quit)
[32mINFO[0m:     Started reloader process [[36m[1m17981[0m] using [36m[1mStatReload[0m

[32mINFO[0m:     Started server process [[36m17983[0m]
[32mINFO[0m:     Waiting for application startup.
[32mINFO[0m:     Application startup complete.
[32mINFO[0m:     Shutting down
[32mINFO[0m:     Waiting for application shutdown.
[32mINFO[0m:     Application shutdown complete.
[32mINFO[0m:     Finished server process [[36m17983[0m]
[32mINFO[0m:     Stopping reloader process [[36m[1m17981[0m]


In [99]:
!pip install pyngrok




In [105]:
from pyngrok import ngrok

ngrok.set_auth_token("2mHAO9RDan3kWDkWCU7N2Pk09YX_6VQy5wSs6gCbGE4ZVJgvg")

# Set up a tunnel to the FastAPI server
# Use addr instead of port and specify the port number as part of the address string
public_url = ngrok.connect(addr='localhost:8000')
print(f"Public URL: {public_url}")

Public URL: NgrokTunnel: "https://2537-35-196-129-43.ngrok-free.app" -> "http://localhost:8000"


In [None]:
curl -X POST "http://example.ngrok.io/predict" -H "Content-Type: application/json" -d '{
  "dataset_name": "telecom_churn",
  "model_name": "rf_model",
  "data": {"feature1": value1, "feature2": value2}
}'
