<a href="https://colab.research.google.com/github/lab-jianghao/spark_ml_sample/blob/main/05_feature_weights_SVM_classifier.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install a headless JDK 17 for Spark; -y answers apt-get's confirmation prompt,
# which cannot be answered interactively from a notebook cell.
!apt-get install -y openjdk-17-jdk-headless

# Download Spark 3.5.0 into /content (-P sets the target directory; a bare
# "/content" argument would be parsed by wget as a second URL) and unpack it
# there so that it matches the SPARK_HOME path used by this notebook.
# archive.apache.org keeps past releases; the dlcdn mirror only carries current ones.
!wget -nc -P /content https://archive.apache.org/dist/spark/spark-3.5.0/spark-3.5.0-bin-hadoop3.tgz
!tar xf /content/spark-3.5.0-bin-hadoop3.tgz -C /content

In [47]:
import os
# Tell PySpark where to find the JDK and the Spark distribution.
os.environ.update(
    {
        "JAVA_HOME": "/usr/lib/jvm/java-17-openjdk-amd64",
        "SPARK_HOME": "/content/spark-3.5.0-bin-hadoop3",
    }
)

In [None]:
# %pip (rather than !pip) installs into the environment of the running kernel.
%pip install -q pyspark==3.5.0

In [49]:
from pyspark.sql import SparkSession

# Start (or reuse) a local Spark session that uses every available core.
session_builder = SparkSession.builder.master("local[*]").appName("Colab")
spark = session_builder.getOrCreate()

In [50]:
from functools import wraps

def spark_sql_initializer(func):
    """Decorator that injects a SparkSession as the first argument of ``func``.

    The wrapper obtains a local session via ``getOrCreate()``, lowers the log
    level to WARN, calls ``func(spark, *args, **kwargs)`` and stops the session
    afterwards.

    Args:
        func: Callable whose first positional parameter receives the session.

    Returns:
        A wrapper that accepts the remaining arguments of ``func`` and returns
        whatever ``func`` returns.

    Note:
        ``getOrCreate()`` may hand back a session that is already running in
        this kernel; in that case ``stop()`` shuts that shared session down too.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        spark = (
            SparkSession.builder
            .appName("Colab_DT")
            .master("local[*]")
            .getOrCreate()
        )

        try:
            spark.sparkContext.setLogLevel("WARN")
            # Propagate the wrapped function's result instead of discarding it.
            return func(spark, *args, **kwargs)
        finally:
            # Always release the session, even when func raises.
            spark.stop()

    return wrapper

In [None]:
# %pip (rather than !pip) installs into the environment of the running kernel.
# TODO: pin the openml version once the release this notebook was run with is known.
%pip install -q openml

In [56]:
import openml
import pandas as pd

# Fetch OpenML dataset 31, version 1 (the version is an integer, not a float).
dataset = openml.datasets.get_dataset(31, version=1)

print("Name:", dataset.name)
print("Description:", dataset.description)
# Print the count rather than the whole feature mapping.
# NOTE(review): the count presumably includes the target attribute - confirm.
print("Number of features:", len(dataset.features))
print("Number of instances:", dataset.qualities["NumberOfInstances"])


# Split into features and target, then join them back into a single pandas
# frame with the target as the last column.
X, y, _, _ = dataset.get_data(target=dataset.default_target_attribute, dataset_format='dataframe')
credit_risk_df = pd.concat([X, y], axis=1)

print(credit_risk_df.head())


Name: credit-g
Description: **Author**: Dr. Hans Hofmann  
**Source**: [UCI](https://archive.ics.uci.edu/ml/datasets/statlog+(german+credit+data)) - 1994    
**Please cite**: [UCI](https://archive.ics.uci.edu/ml/citation_policy.html)

**German Credit dataset**  
This dataset classifies people described by a set of attributes as good or bad credit risks.

This dataset comes with a cost matrix: 
``` 
Good  Bad (predicted)  
Good   0    1   (actual)  
Bad    5    0  
```

It is worse to class a customer as good when they are bad (5), than it is to class a customer as bad when they are good (1).  

### Attribute description  

1. Status of existing checking account, in Deutsche Mark.  
2. Duration in months  
3. Credit history (credits taken, paid back duly, delays, critical accounts)  
4. Purpose of the credit (car, television,...)  
5. Credit amount  
6. Status of savings account/bonds, in Deutsche Mark.  
7. Present employment, in number of years.  
8. Installment rate in percentage of 



In [78]:
from functools import reduce

from pyspark.ml import Pipeline
from pyspark.ml.classification import LinearSVC
from pyspark.ml.feature import StringIndexer, VectorAssembler, IndexToString
from pyspark.ml.evaluation import MulticlassClassificationEvaluator


from pyspark.sql.functions import col

@spark_sql_initializer
def train_print_weights(spark, df):
    """Train a LinearSVC on ``df`` and print test accuracy and feature weights.

    The frame is converted to a Spark DataFrame, split 70/30 with a fixed seed,
    and run through a pipeline of per-column StringIndexers, a VectorAssembler,
    a label indexer, the SVM and an IndexToString stage. Accuracy on the test
    split, a sample of actual vs. predicted labels, and the SVM coefficients
    sorted by absolute value are printed.

    Args:
        spark: SparkSession, supplied by the ``spark_sql_initializer`` decorator.
        df: pandas DataFrame whose last column is the target and is named
            ``class``; every other column is used as a feature.

    Returns:
        None. All results are printed.

    Note:
        Every feature column is passed through StringIndexer, so each weight
        applies to the index value assigned to a column's entries.
        NOTE(review): numeric-looking columns (e.g. duration, credit_amount)
        are indexed as categories too - confirm this is intended.
    """
    credit_risk_df = spark.createDataFrame(df)
    credit_risk_df.show()

    # Everything except the last column (the target) is a feature.
    index_columns = credit_risk_df.columns[:-1]
    indexed_columns = [f"indexed_{name}" for name in index_columns]
    credit_risk_df = credit_risk_df.withColumnRenamed("class", "label")

    # Unfitted estimators: the pipeline below fits them on the training split,
    # so there is no need to fit/transform the full dataset beforehand.
    feature_indexes = [
        StringIndexer(inputCol=name, outputCol=indexed_name, handleInvalid="keep")
        for name, indexed_name in zip(index_columns, indexed_columns)
    ]

    vector_assembler = VectorAssembler(inputCols=indexed_columns, outputCol="features")

    # The label indexer is fitted up front because IndexToString needs its
    # label list to map predictions back to the original strings.
    label_indexer = StringIndexer(inputCol="label", outputCol="indexed_label").fit(credit_risk_df)
    label_converter = IndexToString(inputCol="prediction", outputCol="predictedLabel", labels=label_indexer.labels)

    svm_classifier = LinearSVC(labelCol="indexed_label", featuresCol="features")

    training_data, test_data = credit_risk_df.randomSplit([0.7, 0.3], seed=1122)

    svm_pipeline = Pipeline(
        stages=feature_indexes + [vector_assembler, label_indexer, svm_classifier, label_converter])
    svm_model = svm_pipeline.fit(training_data)

    svm_prediction = svm_model.transform(test_data)
    evaluator = MulticlassClassificationEvaluator(labelCol="indexed_label", predictionCol="prediction", metricName="accuracy")
    accuracy = evaluator.evaluate(svm_prediction)
    print("Test Accuracy = {:.2%}".format(accuracy))

    svm_prediction.select("label", "predictedLabel").show()

    # stages[-2] is the fitted SVM (the last stage is the IndexToString
    # converter); its coefficients follow the assembler's input column order.
    feature_weights = svm_model.stages[-2].coefficients
    sorted_weights = sorted(
        zip(index_columns, feature_weights),
        key=lambda item: abs(item[1]),
        reverse=True,
    )

    print("Feature Weights:")
    for feature, weight in sorted_weights:
        print(f"{feature}: {weight}")


In [79]:
# Train and evaluate on the pandas frame built from the OpenML dataset; the
# SparkSession argument is supplied by the function's decorator.
train_print_weights(credit_risk_df)

+---------------+--------+--------------------+-------------------+-------------+----------------+----------+----------------------+------------------+-------------+---------------+------------------+---+-------------------+--------+----------------+--------------------+--------------+-------------+--------------+-----+
|checking_status|duration|      credit_history|            purpose|credit_amount|  savings_status|employment|installment_commitment|   personal_status|other_parties|residence_since|property_magnitude|age|other_payment_plans| housing|existing_credits|                 job|num_dependents|own_telephone|foreign_worker|class|
+---------------+--------+--------------------+-------------------+-------------+----------------+----------+----------------------+------------------+-------------+---------------+------------------+---+-------------------+--------+----------------+--------------------+--------------+-------------+--------------+-----+
|             <0|       6|critical