## Bankruptcy Prediction with LightGBM Classification

#### Read dataset

In [None]:
# Load the bankruptcy CSV from blob storage, using the first row as the header.
# No schema or inferSchema option is set on this read.
dataset = (
    spark.read.format("csv")
    .option("header", True)
    .load("wasbs://publicwasb@mmlspark.blob.core.windows.net/company_bankruptcy_prediction_data.csv")
)
# Report how many rows were loaded.
row_total = dataset.count()
print("records read: " + str(row_total))

In [None]:
# Convert every column of the dataset to double type.
from pyspark.sql.functions import col
from pyspark.sql.types import DoubleType

# Build all casts in one projection instead of calling withColumn once per
# column: each withColumn call adds another projection to the query plan,
# which grows with the number of columns. Column names and order are kept.
dataset = dataset.select(
    [col(colName).cast(DoubleType()).alias(colName) for colName in dataset.columns]
)
print("Schema: ")
dataset.printSchema()

In [None]:
# Preview the first 3 rows, one field per line, without truncating values.
dataset.show(3, truncate=False, vertical=True)

#### Split the dataset into train and test

In [None]:
# 85% / 15% random split; the fixed seed keeps the split repeatable.
(train, test) = dataset.randomSplit(weights=[0.85, 0.15], seed=1)

#### Add featurizer to convert features to vector

In [None]:
from pyspark.ml.feature import VectorAssembler

# Every column except the first is used as a model input.
# NOTE(review): this relies on the label column 'Bankrupt?' being the first
# column of the dataset - confirm against the CSV header.
feature_cols = dataset.columns[1:]
featurizer = VectorAssembler(inputCols=feature_cols, outputCol='features')

# Assemble the input columns into a single 'features' vector column, then keep
# only the label and that vector.
train_data = featurizer.transform(train).select('Bankrupt?', 'features')
test_data = featurizer.transform(test).select('Bankrupt?', 'features')

In [None]:
# Show the first 10 rows of the assembled training data.
train_data.show(n=10)

#### Check if the data is unbalanced

In [None]:
# Count training rows per label value to see how skewed the classes are.
(
    train_data
    .groupBy("Bankrupt?")
    .count()
    .show()
)

#### Model Training

In [None]:
from mmlspark.lightgbm import LightGBMClassifier

# Configure (but do not yet train) a LightGBM classifier with a binary
# objective, reading inputs from 'features' and the label from 'Bankrupt?'.
# isUnbalance is switched on, presumably because of the class imbalance
# inspected earlier in the notebook.
model = LightGBMClassifier(
    objective="binary",
    featuresCol="features",
    labelCol="Bankrupt?",
    isUnbalance=True,
)

In [None]:
# Train on the assembled training set. The estimator previously bound to
# `model` is replaced by the fitted model that fit() returns.
fitted_model = model.fit(train_data)
model = fitted_model

In [None]:
from mmlspark.lightgbm import LightGBMClassificationModel

# Write the fitted model out in native format, then rebind `model` to the
# copy loaded back from that same location.
native_model_path = "/lgbmcmodel"
model.saveNativeModel(native_model_path)
model = LightGBMClassificationModel.loadNativeModelFromFile(native_model_path)

In [None]:
# Print the feature importances reported by the loaded model.
feature_importances = model.getFeatureImportances()
print(feature_importances)

#### Model Prediction

In [None]:
# Score the held-out test set with the model.
predictions = model.transform(test_data)
# Pull the first 10 scored rows into pandas; as the last expression in the
# cell, the resulting frame is what the notebook renders.
preview_rows = predictions.limit(10)
preview_rows.toPandas()

In [None]:
# Import here so this cell runs on its own: without it, executing the
# notebook top to bottom fails with NameError on ComputeModelStatistics.
from mmlspark.train import ComputeModelStatistics

# Compute classification metrics by comparing the 'Bankrupt?' label with the
# 'prediction' column of the scored test set. Left as the final expression so
# the notebook renders the result.
ComputeModelStatistics(
    evaluationMetric="classification",
    labelCol="Bankrupt?",
    scoredLabelsCol="prediction",
).transform(predictions)

In [None]:
from mmlspark.train import ComputeModelStatistics

# Set up the metrics stage: classification metrics computed from the
# 'Bankrupt?' label against the 'prediction' column.
metrics_stage = ComputeModelStatistics(
    evaluationMetric="classification",
    labelCol='Bankrupt?',
    scoredLabelsCol='prediction',
)
# Apply it to the scored test set and render the resulting metrics table.
metrics = metrics_stage.transform(predictions)
display(metrics)