In [1]:
import argparse
import os
import glob
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from dateutil.relativedelta import relativedelta
import pprint

import pyspark
import pyspark.sql.functions as F
from pyspark.sql.functions import col

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.metrics import roc_auc_score, f1_score, accuracy_score, precision_score, recall_score

import xgboost as xgb
from sklearn.model_selection import RandomizedSearchCV

import mlflow
import mlflow.sklearn
from mlflow.models import infer_signature

In [2]:
# Initialize SparkSession optimized for 32GB RAM
spark = pyspark.sql.SparkSession.builder \
    .appName("model_training") \
    .master("local[*]") \
    .config("spark.driver.memory", "20g") \
    .config("spark.driver.maxResultSize", "8g") \
    .config("spark.sql.shuffle.partitions", "32") \
    .config("spark.default.parallelism", "16") \
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") \
    .config("spark.sql.adaptive.enabled", "true") \
    .config("spark.sql.adaptive.coalescePartitions.enabled", "true") \
    .getOrCreate()

spark.sparkContext.setLogLevel("ERROR")

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/11/10 10:38:39 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
25/11/10 10:38:40 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


In [3]:
# Set MLflow tracking URI
mlflow_tracking_uri = os.getenv('MLFLOW_TRACKING_URI', './mlruns')
mlflow.set_tracking_uri(mlflow_tracking_uri)
mlflow.set_experiment("loan_default_prediction")

print(f"MLflow tracking URI: {mlflow_tracking_uri}")

2025/11/10 10:38:42 INFO mlflow.tracking.fluent: Experiment with name 'loan_default_prediction' does not exist. Creating a new experiment.


MLflow tracking URI: http://mlflow:5000


In [4]:
snapshotdate = '2024-09-01'

In [5]:
FEATURE_DIR = "/app/datamart/gold/feature_store/"
LABEL_DIR = "/app/datamart/gold/label_store/"
SILVER_DIR = "/app/datamart/silver/"
APP_STORE_PATH = "/app/datamart/gold/application_store/"

In [6]:
# --- set up config ---
model_train_date_str = snapshotdate
train_test_period_months = 12
oot_period_months = 2
train_test_ratio = 0.8

config = {}
config["model_train_date_str"] = model_train_date_str
config["train_test_period_months"] = train_test_period_months
config["oot_period_months"] = oot_period_months
config["model_train_date"] = datetime.strptime(model_train_date_str, "%Y-%m-%d")
config["oot_end_date"] = config['model_train_date'] - timedelta(days=1)
config["oot_start_date"] = config['model_train_date'] - relativedelta(months=oot_period_months)
config["train_test_end_date"] = config["oot_start_date"] - timedelta(days=1)
config["train_test_start_date"] = config["oot_start_date"] - relativedelta(months=train_test_period_months)
config["train_test_ratio"] = train_test_ratio

print("\n" + "="*80)
print("CONFIGURATION")
print("="*80)
pprint.pprint(config)
print("="*80 + "\n")


CONFIGURATION
{'model_train_date': datetime.datetime(2024, 9, 1, 0, 0),
 'model_train_date_str': '2024-09-01',
 'oot_end_date': datetime.datetime(2024, 8, 31, 0, 0),
 'oot_period_months': 2,
 'oot_start_date': datetime.datetime(2024, 7, 1, 0, 0),
 'train_test_end_date': datetime.datetime(2024, 6, 30, 0, 0),
 'train_test_period_months': 12,
 'train_test_ratio': 0.8,
 'train_test_start_date': datetime.datetime(2023, 7, 1, 0, 0)}



In [7]:
# --- get labels ---
print("Loading labels from gold layer...")
label_files_pattern = LABEL_DIR + "gold_label_store_*.parquet"
label_store_sdf = spark.read.parquet(label_files_pattern)
print(f"Total labels loaded: {label_store_sdf.count():,}")

# Filter labels by date range
labels_sdf = label_store_sdf.filter(
    (col("snapshot_date") >= config["train_test_start_date"]) & 
    (col("snapshot_date") <= config["oot_end_date"])
)
print(f"Filtered labels: {labels_sdf.count():,} (from {config['train_test_start_date']} to {config['oot_end_date']})")

Loading labels from gold layer...


                                                                                

Total labels loaded: 8,974




Filtered labels: 6,961 (from 2023-07-01 00:00:00 to 2024-08-31 00:00:00)


                                                                                

In [17]:
labels_sdf.columns

['loan_id', 'Customer_ID', 'label', 'label_def', 'snapshot_date']

In [8]:
FEATURE_DIR = "/app/datamart/gold/feature_store"
# ✅ This automatically reads all partitions (all application_date values)
features_sdf = spark.read.parquet(FEATURE_DIR)
features_sdf.show(5)

+-----------+-------------+----------------+--------------------+----------+--------------------+------------------+-----------+----------------------+-------------------+-------------------+------------------+-------------------------------------------------+------------------------------------------------+-------------------------------------------------+------------------------------------------------+--------------------------------------------------+-------------------------------------------------+-------------------+---------------+--------------+----------------------+--------------------------------+--------------------------+-----------------------------+--------------------------+-------------------------+------------------------------------+------------------------+-------------+----------------+--------------+--------------+--------------+--------------+-----------+------------+------------+------------+------------+------------+------------+------------+------------+-----

In [14]:
features_sdf.columns

['Customer_ID',
 'Annual_Income',
 'Outstanding_Debt',
 'Payment_Behaviour',
 'Credit_Mix',
 'Type_of_Loan',
 'Credit_History_Age',
 'Num_of_Loan',
 'Num_of_Delayed_Payment',
 'Delay_from_due_date',
 'DTI',
 'log_Annual_Income',
 'Payment_Behaviour_High_spent_Small_value_payments',
 'Payment_Behaviour_Low_spent_Large_value_payments',
 'Payment_Behaviour_Low_spent_Medium_value_payments',
 'Payment_Behaviour_Low_spent_Small_value_payments',
 'Payment_Behaviour_High_spent_Medium_value_payments',
 'Payment_Behaviour_High_spent_Large_value_payments',
 'Credit_Mix_Standard',
 'Credit_Mix_Good',
 'Credit_Mix_Bad',
 'Type_of_Loan_Auto_Loan',
 'Type_of_Loan_Credit_Builder_Loan',
 'Type_of_Loan_Personal_Loan',
 'Type_of_Loan_Home_Equity_Loan',
 'Type_of_Loan_Mortgage_Loan',
 'Type_of_Loan_Student_Loan',
 'Type_of_Loan_Debt_Consolidation_Loan',
 'Type_of_Loan_Payday_Loan',
 'is_occu_known',
 'age_band_Unknown',
 'age_band_18_24',
 'age_band_25_34',
 'age_band_35_44',
 'age_band_45_54',
 'age_band

In [16]:
features_sdf.columns

['Customer_ID',
 'Annual_Income',
 'Outstanding_Debt',
 'Payment_Behaviour',
 'Credit_Mix',
 'Type_of_Loan',
 'Credit_History_Age',
 'Num_of_Loan',
 'Num_of_Delayed_Payment',
 'Delay_from_due_date',
 'DTI',
 'log_Annual_Income',
 'Payment_Behaviour_High_spent_Small_value_payments',
 'Payment_Behaviour_Low_spent_Large_value_payments',
 'Payment_Behaviour_Low_spent_Medium_value_payments',
 'Payment_Behaviour_Low_spent_Small_value_payments',
 'Payment_Behaviour_High_spent_Medium_value_payments',
 'Payment_Behaviour_High_spent_Large_value_payments',
 'Credit_Mix_Standard',
 'Credit_Mix_Good',
 'Credit_Mix_Bad',
 'Type_of_Loan_Auto_Loan',
 'Type_of_Loan_Credit_Builder_Loan',
 'Type_of_Loan_Personal_Loan',
 'Type_of_Loan_Home_Equity_Loan',
 'Type_of_Loan_Mortgage_Loan',
 'Type_of_Loan_Student_Loan',
 'Type_of_Loan_Debt_Consolidation_Loan',
 'Type_of_Loan_Payday_Loan',
 'is_occu_known',
 'age_band_Unknown',
 'age_band_18_24',
 'age_band_25_34',
 'age_band_35_44',
 'age_band_45_54',
 'age_band

In [9]:
# Filter features by date range
features_sdf = features_sdf.filter(
    (col("snapshot_date") >= config["train_test_start_date"]) & 
    (col("snapshot_date") <= config["oot_end_date"])
)
print(f"Filtered features: {features_sdf.count():,} (from {config['train_test_start_date']} to {config['oot_end_date']})")

Filtered features: 6,937 (from 2023-07-01 00:00:00 to 2024-08-31 00:00:00)


                                                                                

In [10]:
features_sdf.toPandas()

                                                                                

Unnamed: 0,Customer_ID,Annual_Income,Outstanding_Debt,Payment_Behaviour,Credit_Mix,Type_of_Loan,Credit_History_Age,Num_of_Loan,Num_of_Delayed_Payment,Delay_from_due_date,...,fe_12,fe_13,fe_14,fe_15,fe_16,fe_17,fe_18,fe_19,fe_20,snapshot_date
0,CUS_0x102e,50807.441406,869.590027,High_spent_Medium_value_payments,_,"Mortgage Loan, Not Specified, Home Equity Loan...",274.0,4,9,12,...,191,90,131,52,28,19,129,266,159,2024-04-01
1,CUS_0x1226,11421.509766,1278.119995,Unknown,_,"Student Loan, Payday Loan, and Home Equity Loan",362.0,3,15,19,...,221,294,29,8,46,74,159,239,104,2024-04-01
2,CUS_0x1343,125280.359375,1388.430054,High_spent_Large_value_payments,Good,"Debt Consolidation Loan, Student Loan, Home Eq...",267.0,4,5,2,...,55,158,-88,178,224,-34,184,126,-57,2024-04-01
3,CUS_0x13c7,69341.523438,302.750000,Unknown,Standard,Not Specified,190.0,1,9,13,...,-78,88,90,134,-2,26,234,99,18,2024-04-01
4,CUS_0x13d5,124183.562500,247.960007,High_spent_Large_value_payments,Good,"Payday Loan, and Student Loan",376.0,2,17,12,...,195,103,137,141,242,166,28,78,272,2024-04-01
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
6932,CUS_0xc4d7,35296.320312,902.229980,Low_spent_Large_value_payments,Standard,"Credit-Builder Loan, Debt Consolidation Loan, ...",276.0,4,12,23,...,0,0,0,0,0,0,0,0,0,2024-07-01
6933,CUS_0xc620,42769.781250,308.910004,High_spent_Large_value_payments,_,"Auto Loan, Mortgage Loan, Debt Consolidation L...",282.0,4,11,10,...,0,0,0,0,0,0,0,0,0,2024-07-01
6934,CUS_0xc6c2,49056.089844,1336.770020,Low_spent_Medium_value_payments,Bad,"Payday Loan, Student Loan, and Payday Loan",83.0,3,20,27,...,0,0,0,0,0,0,0,0,0,2024-07-01
6935,CUS_0xc80,48706.101562,1083.760010,Unknown,Standard,"Student Loan, Payday Loan, and Personal Loan",209.0,3,10,24,...,0,0,0,0,0,0,0,0,0,2024-07-01


In [11]:
labels_sdf.toPandas()

                                                                                

Unnamed: 0,loan_id,Customer_ID,label,label_def,snapshot_date
0,CUS_0x1037_2023_01_01,CUS_0x1037,0,30dpd_6mob,2023-07-01
1,CUS_0x1069_2023_01_01,CUS_0x1069,0,30dpd_6mob,2023-07-01
2,CUS_0x114a_2023_01_01,CUS_0x114a,0,30dpd_6mob,2023-07-01
3,CUS_0x1184_2023_01_01,CUS_0x1184,0,30dpd_6mob,2023-07-01
4,CUS_0x1297_2023_01_01,CUS_0x1297,1,30dpd_6mob,2023-07-01
...,...,...,...,...,...
6956,CUS_0xdf6_2023_09_01,CUS_0xdf6,0,30dpd_6mob,2024-03-01
6957,CUS_0xe23_2023_09_01,CUS_0xe23,0,30dpd_6mob,2024-03-01
6958,CUS_0xe4e_2023_09_01,CUS_0xe4e,0,30dpd_6mob,2024-03-01
6959,CUS_0xedd_2023_09_01,CUS_0xedd,0,30dpd_6mob,2024-03-01


In [12]:
# --- prepare data for modeling ---
print("\nJoining features and labels...")
data_sdf = features_sdf.join(labels_sdf,["Customer_ID", "snapshot_date"],"inner")

print(f"Joined data: {data_sdf.count():,}")

# Convert to pandas
data_pdf = data_sdf.toPandas()

# Identify feature columns (exclude identifiers and label)
exclude_cols = ['loan_id', 'Customer_ID', 'application_date', 'snapshot_date', 'label', 'label_def']
feature_cols = [col for col in data_pdf.columns if col not in exclude_cols]

print(f"\nTotal feature columns: {len(feature_cols)}")
print(f"Feature columns: {feature_cols[:10]}... (showing first 10)")


Joining features and labels...


                                                                                

Joined data: 0





Total feature columns: 75
Feature columns: ['Annual_Income', 'Outstanding_Debt', 'Payment_Behaviour', 'Credit_Mix', 'Type_of_Loan', 'Credit_History_Age', 'Num_of_Loan', 'Num_of_Delayed_Payment', 'Delay_from_due_date', 'DTI']... (showing first 10)


                                                                                

In [13]:
data_pdf.head()

Unnamed: 0,Customer_ID,snapshot_date,Annual_Income,Outstanding_Debt,Payment_Behaviour,Credit_Mix,Type_of_Loan,Credit_History_Age,Num_of_Loan,Num_of_Delayed_Payment,...,fe_14,fe_15,fe_16,fe_17,fe_18,fe_19,fe_20,loan_id,label,label_def
