<a href="https://colab.research.google.com/github/garryDCU/CA4022/blob/main/pyspark_fraud.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Implementation of Naive Bayes Classifier on PySpark

In this notebook, we investigated how well a naive Bayes classifier could predict fraud in an anonymised bank dataset taken from [Kaggle](https://www.kaggle.com/volodymyrgavrysh/fraud-detection-bank-dataset-20k-records-binary).

The notebook was created on Google Colab.

#### Necessary installs and imports

In [2]:
! pip install pyspark

Collecting pyspark
  Downloading pyspark-3.2.0.tar.gz (281.3 MB)
Collecting py4j==0.10.9.2
  Downloading py4j-0.10.9.2-py2.py3-none-any.whl (198 kB)
Building wheels for collected packages: pyspark
  Building wheel for pyspark (setup.py) ... done
  Created wheel for pyspark: filename=pyspark-3.2.0-py2.py3-none-any.whl size=281805912 sha256=256c35c5e545f53d59b1f3047c10cd97674499679cdde44bc08979ad60ff27d9
  Stored in directory: /root/.cache/pip/wheels/0b/de/d2/9be5d59d7331c6c2a7c1b6d1a4f463ce107332b1ecd4e80718
Successfully built pyspark
Installing collected packages: py4j, pyspark
Successfully installed py4j-0.10.9.2 pyspark-3.2.0


In [3]:
import os       #importing os to set environment variable
def install_java():
  !apt-get install -y openjdk-8-jdk-headless -qq > /dev/null      #install openjdk
  os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-8-openjdk-amd64"     #set environment variable
  !java -version       #check java version
install_java()

openjdk version "11.0.11" 2021-04-20
OpenJDK Runtime Environment (build 11.0.11+9-Ubuntu-0ubuntu2.18.04)
OpenJDK 64-Bit Server VM (build 11.0.11+9-Ubuntu-0ubuntu2.18.04, mixed mode, sharing)


In [4]:
import pandas as pd
from google.colab import drive 
from sklearn.preprocessing import MinMaxScaler
from pyspark.sql import SparkSession 
from pyspark.sql.functions import * 
from pyspark.ml import Pipeline 
from pyspark.ml.feature import VectorAssembler 
from pyspark.ml.feature import StringIndexer 
from pyspark.ml.classification import NaiveBayes 
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.evaluation import BinaryClassificationEvaluator

In [5]:
drive.mount('/content/drive/')

Mounted at /content/drive/


#### Read in the data

In [6]:
data=pd.read_csv('/content/drive/MyDrive/CA4022FinalAssignment/fraud_detection_bank_dataset.csv')
del data['Unnamed: 0']
data.head()

Unnamed: 0,col_0,col_1,col_2,col_3,col_4,col_5,col_6,col_7,col_8,col_9,col_10,col_11,col_12,col_13,col_14,col_15,col_16,col_17,col_18,col_19,col_20,col_21,col_22,col_23,col_24,col_25,col_26,col_27,col_28,col_29,col_30,col_31,col_32,col_33,col_34,col_35,col_36,col_37,col_38,col_39,...,col_73,col_74,col_75,col_76,col_77,col_78,col_79,col_80,col_81,col_82,col_83,col_84,col_85,col_86,col_87,col_88,col_89,col_90,col_91,col_92,col_93,col_94,col_95,col_96,col_97,col_98,col_99,col_100,col_101,col_102,col_103,col_104,col_105,col_106,col_107,col_108,col_109,col_110,col_111,targets
0,9,1354,0,18,0,1,7,9,0,0,0,0,0,0,1,0,1,0,0,0,0,0,9,74,19,25,0,0,1,3,24,0,0,0,2,0,97,0,981,0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,4,0,0,0,1,1,0,0,0,49,1
1,0,239,0,1,0,1,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,7,1,0,0,0,1,1,0,0,0,0,0,0,0,0,18,0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,1,0,0,0,0,1,0,0,0,0,0,0,55,1
2,0,260,0,4,0,3,6,0,0,0,0,0,0,0,1,1,0,0,0,0,0,0,0,13,7,5,0,0,1,5,0,0,0,0,0,0,5,0,91,0,...,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,1,0,0,0,1,0,0,0,1,1,0,0,0,56,1
3,17,682,0,1,0,0,8,17,0,0,0,0,0,0,0,0,1,0,0,0,0,0,23,52,1,7,0,0,1,1,1,9,0,0,0,0,3,0,26,0,...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,1,0,1,0,1,1,0,0,0,65,1
4,1,540,0,2,0,1,7,1,0,0,0,0,0,0,1,0,1,0,0,0,0,0,1,20,3,11,0,0,1,4,20,0,0,0,0,0,52,0,669,0,...,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,1,0,0,0,1,1,0,0,0,175,1


Next, we check for missing values to see whether any data cleaning is needed.

In [7]:
data.isna().sum().sum()

0

There are no null values, so we can proceed to the next phase.

#### Scale the data to improve prediction accuracy

In [8]:
# separate the features from the label column
labels=data['targets'].to_list()
feats = data.drop(columns='targets')

# scale the dataframe
scaler = MinMaxScaler()
feats[feats.columns] = scaler.fit_transform(feats[feats.columns])

# put the df back together
data=feats
data['targets']=labels
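
As a rough illustration of what min-max scaling does to a single column, the sketch below rescales a list of values to the range [0, 1] in plain Python. The helper name `min_max_scale` is hypothetical and is not part of sklearn, and the sample holds only the five `col_1` values from the preview above, so the results differ from what the scaler produces when it is fitted on the full dataset.

```python
def min_max_scale(values):
    """Rescale values linearly so the smallest maps to 0.0 and the largest to 1.0."""
    lo, hi = min(values), max(values)
    span = hi - lo
    if span == 0:
        # Constant column: map every value to 0.0 (a choice made for this sketch).
        return [0.0 for _ in values]
    return [(v - lo) / span for v in values]


# First five values of col_1 from the preview of the raw data
col_1_sample = [1354, 239, 260, 682, 540]
scaled = min_max_scale(col_1_sample)
print(scaled)
```

The ordering of the values is preserved; only their range changes.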

#### Set up Spark and create a spark dataframe

In [9]:
#Sets the Spark master URL to run locally. 
spark = SparkSession.builder.master("local[*]").getOrCreate()

In [10]:
#Create DataFrame 
fraud_df = spark.createDataFrame(data)
fraud_df.show(5)


In [11]:
# one StringIndexer per feature column (col_0 ... col_111), plus one that maps "targets" to "label"
indexers = [
    StringIndexer(inputCol=f"col_{i}", outputCol=f"col_{i}_index")
    for i in range(112)
]
indexers.append(StringIndexer(inputCol="targets", outputCol="label"))

In [12]:
pipeline = Pipeline(stages=indexers)

In [13]:
# Fit the indexers on the input dataset and add the indexed columns
indexed_fraud_df = pipeline.fit(fraud_df).transform(fraud_df)
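
To make the indexing step more concrete, here is a minimal sketch of frequency-based indexing in plain Python, which is how `StringIndexer` orders values under its default settings as far as we understand them: values are treated as strings and the most frequent value receives index 0. The function name `fit_frequency_index`, the tie-breaking rule and the sample column are assumptions made for this illustration.

```python
from collections import Counter


def fit_frequency_index(values):
    """Map each distinct value to an index, most frequent value first.

    Values are compared as strings; ties are broken by string order
    (an assumption of this sketch).
    """
    counts = Counter(str(v) for v in values)
    ordered = sorted(counts, key=lambda s: (-counts[s], s))
    return {s: float(i) for i, s in enumerate(ordered)}


# Hypothetical scaled column, not taken from the dataset
column = [0.5, 0.5, 0.5, 0.0, 1.0, 1.0]
mapping = fit_frequency_index(column)
indexed = [mapping[str(v)] for v in column]
print(mapping)
print(indexed)
```

Note that the resulting index reflects how often a value occurs, not how large it is.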

In [14]:
indexed_fraud_df.show(5,False)
# Passing False turns off the default truncation of long values


In [15]:
# assemble the 112 indexed feature columns (col_0_index ... col_111_index) into one vector
vectorAssembler = VectorAssembler(
    inputCols=[f"col_{i}_index" for i in range(112)],
    outputCol="features",
)

In [16]:
indexed_fraud_df = vectorAssembler.transform(indexed_fraud_df) 
indexed_fraud_df.show(5, False)


#### Split data for training. 

We use an 80/20 train/test split, which is a common choice.

In [17]:
splits = indexed_fraud_df.randomSplit([0.8,0.2], 42) 
# optional value 42 is seed for sampling 
train_df = splits[0] 
test_df = splits[1]
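
The sketch below shows, in plain Python, the general idea behind a seeded weighted random split: each row gets one random draw and lands in the split whose weight range contains that draw. It is an illustration under our own assumptions (the function name `random_split` is hypothetical and this is not Spark's implementation), but it shows why the resulting sizes are only approximately 80/20 and why a fixed seed makes the split repeatable.

```python
import random


def random_split(rows, weights, seed):
    """Assign each row to exactly one split using an independent random draw."""
    total = sum(weights)
    # Cumulative boundaries in [0, 1], e.g. [0.8, 1.0] for weights [0.8, 0.2]
    bounds = []
    running = 0.0
    for w in weights:
        running += w / total
        bounds.append(running)

    rng = random.Random(seed)
    splits = [[] for _ in weights]
    for row in rows:
        draw = rng.random()
        for i, bound in enumerate(bounds):
            # The last split catches any draw not claimed earlier
            if draw < bound or i == len(bounds) - 1:
                splits[i].append(row)
                break
    return splits


train, test = random_split(list(range(1000)), [0.8, 0.2], seed=42)
print(len(train), len(test))
```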

#### Apply Naive Bayes

We select the Gaussian model type. Other options for naive Bayes include Bernoulli (which requires all features to be 0 or 1) and multinomial.
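
For intuition, a Gaussian naive Bayes model can be sketched in a few lines of plain Python: it stores a prior, a mean and a variance per feature for each class, and predicts the class with the highest log posterior. This is a simplified illustration, not Spark's implementation; the function names, the toy data and the `var_floor` term (added to avoid division by zero) are assumptions made here.

```python
import math
from collections import defaultdict


def fit_gaussian_nb(X, y, var_floor=1e-9):
    """Estimate per-class priors, feature means and feature variances."""
    by_class = defaultdict(list)
    for row, label in zip(X, y):
        by_class[label].append(row)

    model = {}
    for label, rows in by_class.items():
        n = len(rows)
        columns = list(zip(*rows))
        means = [sum(col) / n for col in columns]
        variances = [
            sum((v - m) ** 2 for v in col) / n + var_floor
            for col, m in zip(columns, means)
        ]
        model[label] = (n / len(X), means, variances)
    return model


def predict(model, row):
    """Return the class with the highest log posterior (up to a constant)."""
    def log_score(prior, means, variances):
        score = math.log(prior)
        for v, m, var in zip(row, means, variances):
            # Log of the Gaussian density for one feature
            score += -0.5 * math.log(2 * math.pi * var) - (v - m) ** 2 / (2 * var)
        return score

    return max(model, key=lambda label: log_score(*model[label]))


# Toy data with two well-separated classes (hypothetical)
X = [[0.10, 0.20], [0.15, 0.25], [0.80, 0.90], [0.85, 0.95]]
y = [0, 0, 1, 1]
model = fit_gaussian_nb(X, y)
print(predict(model, [0.12, 0.22]), predict(model, [0.90, 0.80]))
```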

In [18]:
nb = NaiveBayes(modelType="gaussian")

#### Train the model using the training data specified above

In [19]:
nbmodel = nb.fit(train_df)

#### Predict with test data

In [20]:
predictions_df = nbmodel.transform(test_df)
predictions_df.show(5, True)


#### Compute the AUC on the test set

In [21]:
# areaUnderROC is computed here from the hard 0/1 values in the "prediction" column
evaluator = BinaryClassificationEvaluator(labelCol="label", rawPredictionCol="prediction", metricName="areaUnderROC")
nbauc = evaluator.evaluate(predictions_df)
print("Test AUC = " + str(nbauc))

Test AUC = 0.655890870973493


These results are poor compared with our implementation of the same model in sklearn, which achieved an AUC score of 0.75.
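
One detail worth keeping in mind when reading this number: the evaluator above reads the `prediction` column, which holds hard 0/1 labels rather than continuous scores. The sketch below uses made-up labels and scores (not from this dataset) and a hypothetical helper `roc_auc` to illustrate that AUC computed from thresholded predictions can differ from AUC computed on the underlying scores. It says nothing about how large the difference would be for this model.

```python
def roc_auc(labels, scores):
    """AUC as the chance that a random positive outranks a random negative.

    Ties between a positive and a negative count as half a win.
    """
    pos = [s for l, s in zip(labels, scores) if l == 1]
    neg = [s for l, s in zip(labels, scores) if l == 0]
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


# Hypothetical labels and continuous scores
labels = [1, 1, 1, 0, 0, 0]
scores = [0.9, 0.6, 0.4, 0.55, 0.3, 0.1]
# Hard predictions obtained by thresholding the scores at 0.5
hard = [1.0 if s >= 0.5 else 0.0 for s in scores]

print("AUC from scores:", roc_auc(labels, scores))
print("AUC from hard predictions:", roc_auc(labels, hard))
```

With hard predictions the ROC curve has a single interior point, so the AUC reduces to the average of the true positive rate and the true negative rate.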