# Spam Text Classification

In this demonstration, a BERT model is used to classify text messages as legitimate or spam.

The dataset used in this notebook is available on Kaggle at https://www.kaggle.com/datasets/team-ai/spam-text-message-classification/data under the CC0: Public Domain license.

The dataset contains 5,572 rows and two columns:

| Name      | Model Role | Measurement Level | Description                          |
|:----------|:-----------|:------------------|:-------------------------------------|
| Category  | Target     | Binary            | 1 = spam, 0 = legitimate message     |
| Message   | Input      | Text              | The text content of a given message  |

## Loading Packages and Data Exploration

In [1]:
# Imports the necessary packages

import swat
import pandas as pd

In [2]:
# Creates the connection object

conn = swat.CAS("server.demo.sas.com", 30570, 'student', 'Metadata0')

In [3]:
# Extends the session timeout to 12 hours (the time is given in seconds)

conn.session.timeout(time=60*60*12)

In [4]:
# Creates a CASlib that points to the directory where the dataset is located

conn.table.addCaslib(name='mycl', path='/workshop/winsas/VOSI', subDirectories = True, dataSource='PATH', activeOnAdd=True)

NOTE: 'mycl' is now the active caslib.
NOTE: Cloud Analytic Services added the caslib 'mycl'.


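|   | Name | Type | Description | Path | Definition | Subdirs | Local | Active | Personal | Hidden | Transient | TableRedistUpPolicy |
|:--|:-----|:-----|:------------|:-----|:-----------|:--------|:------|:-------|:---------|:-------|:----------|:--------------------|
| 0 | mycl | PATH |  | /workshop/winsas/VOSI/ |  | 1.0 | 1.0 | 1.0 | 0.0 | 0.0 | 0.0 | Not Specified |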


In [5]:
# Loads the spam dataset into memory and then creates a CASTable object

conn.table.loadtable(caslib = "mycl", path = "spam_texts.csv", casout = dict(name = "spam", replace = True))
spam_data = conn.CASTable("spam")

NOTE: Cloud Analytic Services made the file spam_texts.csv available as table SPAM in caslib mycl.


In [6]:
# Displays the spam dataset column information

spam_data.info()

CASTable('spam')
Data columns (total 3 columns):
             N   Miss     Type
Category  5572  False   double
Message   5572  False  varchar
docid     5572  False   double
dtypes: double(2), varchar(1)
data size: 626820
vardata size: 448516
memory usage: 626936


In [7]:
# Displays a sample of the dataset

spam_data.head()

|   | Category | Message | docid |
|:--|:---------|:--------|:------|
| 0 | 0.0 | Go until jurong point, crazy.. Available only ... | 0.0 |
| 1 | 0.0 | Ok lar... Joking wif u oni... | 1.0 |
| 2 | 1.0 | Free entry in 2 a wkly comp to win FA Cup fina... | 2.0 |
| 3 | 0.0 | U dun say so early hor... U c already then say... | 3.0 |
| 4 | 0.0 | Nah I don't think he goes to usf, he lives aro... | 4.0 |


In [8]:
# Displays an example spam message (docid 8)

spam_data["Message"].query("docid = 8").head()[0]

'WINNER!! As a valued network customer you have been selected to receivea £900 prize reward! To claim call 09061701461. Claim code KL341. Valid 12 hours only.'

In [9]:
# Displays the percentage of spam messages and text messages

spam_data["Category"].value_counts()*100/spam_data.shape[0]

0.0    86.593683
1.0    13.406317
dtype: float64
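
The class split matters when interpreting accuracy later: a trivial model that labels every message as legitimate would already be right about 86.6% of the time. A minimal plain-Python sketch of that majority-class baseline, using the class counts implied by the percentages above (4,825 legitimate and 747 spam messages out of 5,572); the helper name `majority_baseline` is hypothetical:

```python
from collections import Counter

def majority_baseline(labels):
    """Hypothetical helper: return (majority_label, accuracy) for always predicting the most common label."""
    counts = Counter(labels)
    label, n = counts.most_common(1)[0]
    return label, n / len(labels)

# Label list reproducing the counts implied by the output above
# (0 = legitimate message, 1 = spam)
labels = [0] * 4825 + [1] * 747

label, acc = majority_baseline(labels)
print(label, round(acc * 100, 4))  # prints: 0 86.5937
```

Any useful classifier has to clear this baseline by a wide margin, which is why recall and precision on the spam class are reported later rather than accuracy alone.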

In [10]:
# Loads in the action sets that are going to be used

conn.loadactionset("fedSQL")
conn.loadactionset("sampling")
conn.loadactionset("percentile")
conn.loadactionset("textClassifier")

NOTE: Added action set 'fedSQL'.
NOTE: Added action set 'sampling'.
NOTE: Added action set 'percentile'.
NOTE: Added action set 'textClassifier'.


## Data Partitioning

In [11]:
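# Partitions the data into training (70%) and validation (30%) sets, stratified by Category.
# The SPAM table is overwritten with a copy that adds the _PartInd1_ indicator (1 = training, 0 = validation).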

conn.sampling.stratified(table = "SPAM", 
                         seed = 42, 
                         target = "Category", 
                         samppct = 70, 
                         partind = True,
                         output= dict(casOut=dict(name = "SPAM", replace = True), copyVars = "ALL")
                         )

NOTE: Stratified sampling is in effect.
NOTE: Using SEED=42 for sampling.


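|   | TargetName | TargetLevel | NObs | NSamp |
|:--|:-----------|:------------|:-----|:------|
| 0 | Category |  | 5572 | 3900 |

|   | PartIndName | TargetName |
|:--|:------------|:-----------|
| 0 | `_PartInd1_` | Category |

|   | casLib | Name | Label | Rows | Columns | casTable |
|:--|:-------|:-----|:------|:-----|:--------|:---------|
| 0 | mycl | SPAM |  | 5572 | 4 | CASTable('SPAM', caslib='mycl') |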


In [12]:
# Displays the partitioned dataset

spam_data.head()

|   | Category | Message | docid | `_PartInd1_` |
|:--|:---------|:--------|:------|:-------------|
| 0 | 0.0 | Go until jurong point, crazy.. Available only ... | 0.0 | 1.0 |
| 1 | 0.0 | Ok lar... Joking wif u oni... | 1.0 | 0.0 |
| 2 | 1.0 | Free entry in 2 a wkly comp to win FA Cup fina... | 2.0 | 1.0 |
| 3 | 0.0 | U dun say so early hor... U c already then say... | 3.0 | 0.0 |
| 4 | 0.0 | Nah I don't think he goes to usf, he lives aro... | 4.0 | 0.0 |


In [13]:
# Displays the distribution of the partition indicator

spam_data["_PartInd1_"].value_counts(normalize = True)

1.0    0.699928
0.0    0.300072
dtype: float64
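
Stratified sampling draws the 70% training sample separately within each level of `Category`, so both partitions keep roughly the same spam rate. The sketch below illustrates the idea in plain Python under that assumption; it is not how the CAS `sampling.stratified` action is implemented, Python's random generator will not reproduce the CAS rows even with the same seed, and the helper name `stratified_partition` is hypothetical. It returns a 1/0 flag per row, mirroring `_PartInd1_`.

```python
import random
from collections import defaultdict

def stratified_partition(labels, samp_pct=70, seed=42):
    """Return one flag per row: 1 = training, 0 = validation.

    Hypothetical helper: within each label level, round(n * samp_pct / 100)
    rows are drawn at random for training, so every level is split in roughly
    the same proportion. This is not the CAS sampling.stratified algorithm.
    """
    rng = random.Random(seed)
    rows_by_level = defaultdict(list)
    for i, label in enumerate(labels):
        rows_by_level[label].append(i)

    flags = [0] * len(labels)
    for rows in rows_by_level.values():
        n_train = round(len(rows) * samp_pct / 100)
        for i in rng.sample(rows, n_train):
            flags[i] = 1
    return flags

# Class counts implied by the earlier output: 4,825 legitimate (0), 747 spam (1)
labels = [0] * 4825 + [1] * 747
flags = stratified_partition(labels)

# Share of each class that lands in the training partition
for level in (0, 1):
    level_flags = [f for f, y in zip(flags, labels) if y == level]
    print(level, round(sum(level_flags) / len(level_flags), 4))
```

In the CAS output, 3,900 of the 5,572 rows are flagged for training, leaving 5,572 - 3,900 = 1,672 rows for validation, which is the row count of the scored predictions table later on.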

## Model Training

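**NOTE:** Because of time constraints, the BERT model is not trained live in this demonstration. The training cell is left commented out (its logged output from an earlier run is shown for reference), and the previously trained model table is loaded into memory instead.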

In [14]:
# # Trains a BERT text classifier

# conn.textclassifier.traintextclassifier(table = dict(name = "SPAM", where = "_PartInd1_ = 1"),
#                                         validtable = dict(name = "SPAM", where = "_PartInd1_ = 0"),
#                                         target = "Category",
#                                         text = "Message",
#                                         gpu = True,
#                                         seed = 42,
#                                         modelOut = dict(name = "bert_classifier", replace = True))

NOTE: Using GPU 0 on controller.
NOTE: Train Loss   Train Accuracy  Validation Loss    Validation Accuracy
NOTE: 0.169  93.231%  0.058  98.684%
NOTE: 0.036  99.179%  0.025  99.103%
NOTE: 0.010  99.667%  0.025  99.043%
NOTE: 0.005  99.692%  0.025  99.043%
NOTE: trainTextClassifier completed successfully.


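|   | `_epoch_` | `_train_loss_` | `_train_accuracy_` | `_validation_loss_` | `_validation_accuracy_` |
|:--|:----------|:---------------|:-------------------|:--------------------|:------------------------|
| 0 | 1 | 0.168657 | 93.230772 | 0.057536 | 98.68421 |
| 1 | 2 | 0.036237 | 99.179488 | 0.025145 | 99.102873 |
| 2 | 3 | 0.010406 | 99.666667 | 0.02461 | 99.043059 |
| 3 | 4 | 0.005277 | 99.692309 | 0.024757 | 99.043059 |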
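
The logged history shows training loss still falling at epoch 4 while validation loss levels off after epoch 2 (0.0251, 0.0246, 0.0248). A plain-Python sketch, with illustrative variable names, that finds the epoch with the lowest validation loss from the values in the training history and prints the validation-minus-training loss gap per epoch:

```python
# (epoch, train_loss, validation_loss) copied from the training history above
history = [
    (1, 0.168657, 0.057536),
    (2, 0.036237, 0.025145),
    (3, 0.010406, 0.024610),
    (4, 0.005277, 0.024757),
]

# Epoch whose validation loss is smallest
best_epoch, _, best_val_loss = min(history, key=lambda row: row[2])
print(best_epoch, best_val_loss)

# Gap between validation and training loss per epoch; a widening gap is a
# common sign that the model is starting to overfit the training partition.
for epoch, train_loss, val_loss in history:
    print(epoch, round(val_loss - train_loss, 6))
```

Here the differences between epochs 2 to 4 are small and validation accuracy stays at about 99%, so the choice of epoch has little practical effect on this dataset.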


In [None]:
# # Saves the model table 

# conn.table.save(caslib="mycl", table = "bert_classifier", name="bert_classifier", replace = True)

In [14]:
# Loads the model table into memory

conn.table.loadtable(caslib = "mycl", path = "bert_classifier.sashdat", casout = dict(name = "bert_classifier", replace = True))

NOTE: Cloud Analytic Services made the file bert_classifier.sashdat available as table BERT_CLASSIFIER in caslib mycl.


## BERT Model Scoring

In [15]:
# Scores the validation partition (_PartInd1_ = 0) with the BERT classifier

conn.textclassifier.scoretextclassifier(table = dict(name = "SPAM", where = "_PartInd1_ = 0"),
                                        text = "Message",
                                        docid = "docid",
                                        model = "bert_classifier",
                                        casOut = dict(name = "bert_preds", replace = True),
                                        gpu = True)

NOTE: Using GPU 0 on controller.
NOTE: scoreTextClassifier completed successfully.


In [16]:
# Creates a CASTable object of the predictions table

preds = conn.CASTable("bert_preds")
preds.shape

(1672, 3)

In [17]:
# Performs an inner join on docid to add the actual outcome (Category) and the message text to the predictions

conn.fedSQL.execdirect(query = 
                      """
                      CREATE TABLE JOINED AS
                      SELECT PRED.*, ACT.Category, ACT.Message
                      FROM mycl.BERT_PREDS AS PRED 
                      INNER JOIN mycl.SPAM as ACT ON
                          (PRED.docid = ACT.docid);
                      """
                      )

NOTE: Table JOINED was created in caslib mycl with 1672 rows returned.
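
The FedSQL inner join keeps only the documents that appear in both tables, matched on `docid`. Because the joined result is small enough to be pulled to the client anyway, the same join could also be done locally. A minimal plain-Python sketch of the inner-join logic; the helper `inner_join` and the rows standing in for `BERT_PREDS` and `SPAM` are hypothetical:

```python
def inner_join(left, right, key):
    """Inner join two lists of dicts on `key`; left rows without a match are dropped.

    Assumes `key` is unique in `right` (as docid is unique per message here).
    """
    right_by_key = {row[key]: row for row in right}
    return [{**row, **right_by_key[row[key]]} for row in left if row[key] in right_by_key]

# Hypothetical stand-ins for the predictions and source tables
preds = [{"docid": 1, "_label_": "0"}, {"docid": 2, "_label_": "1"}]
spam = [
    {"docid": 0, "Category": 0, "Message": "first message"},
    {"docid": 1, "Category": 0, "Message": "second message"},
    {"docid": 2, "Category": 1, "Message": "third message"},
]

joined = inner_join(preds, spam, "docid")
print(len(joined), joined[1]["Category"])  # prints: 2 1
```

With pandas DataFrames, `pd.merge(preds_df, spam_df, on="docid", how="inner")` expresses the same operation (the DataFrame names are illustrative).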


In [18]:
# Displays the information for the new predictions table

preds = conn.CASTable("JOINED").to_frame()
preds.info()

<class 'swat.dataframe.SASDataFrame'>
Index: 1672 entries, 0 to 1671
Data columns (total 5 columns):
 #   Column    Non-Null Count  Dtype  
---  ------    --------------  -----  
 0   docid     1672 non-null   float64
 1   _label_   1672 non-null   object 
 2   _score_   1672 non-null   float64
 3   Category  1672 non-null   float64
 4   Message   1672 non-null   object 
dtypes: float64(3), object(2)
memory usage: 78.4+ KB


In [19]:
# Converts _label_ from character to float64 so it can be compared with the numeric Category column

preds["_label_"] = preds["_label_"].astype("float64")
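
The cast is needed because `_label_` comes back as a character column (dtype `object` above) while `Category` is numeric, and in Python a string never compares equal to a number, so every row would look like a misclassification. A small illustration, assuming labels stored as strings such as `"1"` and `"0"`:

```python
labels = ["1", "0", "1"]        # hypothetical string labels, as in an object column
category = [1.0, 0.0, 0.0]      # numeric actual outcomes

# Comparing str to float is always False, so every row would look "incorrect"
matches_raw = [lab == cat for lab, cat in zip(labels, category)]

# After converting the labels to float, the comparison behaves as intended
matches_cast = [float(lab) == cat for lab, cat in zip(labels, category)]

print(matches_raw)   # [False, False, False]
print(matches_cast)  # [True, True, False]
```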

## BERT Model Assessment

**NOTE:** The percentile action set (as used in the GAN demo) could also be used to assess these predictions, but the data would first need some manipulation. Writing your own assessment function, as below, is a valid alternative.

In [20]:
# Defines a function that computes a variety of performance metrics

def calc_performance(table, actual, pred):
    """
    Calculates the following performance metrics:
    * True positive count
    * False positive count
    * True negative count
    * False negative count
    * Recall
    * Precision
    * Specificity
    * Balanced Accuracy
    * KS
    * F-Score
    
    Parameters
    ----------
    table: DataFrame
        pandas DataFrame that contains the actual and predicted values
    actual: string
        Name of column containing actual outcomes
    pred: string
        Name of column containing predictions
        
    Returns
    -------
    df
        Pandas DataFrame object containing the metrics listed above
    """
    # Copies the dataframe and creates count metrics
    
    data = table.copy()
    data[pred] = data[pred].astype("float64")
    data["TP"], data["FP"], data["TN"], data["FN"] = [0, 0, 0, 0]
    
    # Assigns values to the 4 count metric columns
    
    data.loc[(data[actual] == 1) & (data[pred] == 1), "TP"] = 1
    data.loc[(data[actual] == 0) & (data[pred] == 1), "FP"] = 1
    data.loc[(data[actual] == 0) & (data[pred] == 0), "TN"] = 1
    data.loc[(data[actual] == 1) & (data[pred] == 0), "FN"] = 1
    
    # Computes the column-wise totals
    
    TP = data["TP"].sum()
    FP = data["FP"].sum()
    TN = data["TN"].sum()
    FN = data["FN"].sum()   
    
    # Creates the new data frame containing the performance metrics
    
    df = pd.DataFrame(data = {"TP":[TP], "FP":[FP], "TN":[TN], "FN":[FN]})
    
    # Computes the remaining summary statistics
    
    df["recall"] = df["TP"]/(df["TP"] + df["FN"])
    df["precision"] = df["TP"]/(df["TP"] + df["FP"])
    df["specificity"] = df["TN"]/(df["TN"] + df["FP"])
    df["balanced_accuracy"] = (df["recall"] + df["specificity"])/2
    df["KS"] = df["recall"] + df["specificity"] - 1
    df["F_score"] = 2*((df["precision"]*df["recall"])/(df["precision"] + df["recall"]))
    
    return df
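
The same metrics can be computed without building indicator columns, by counting the four outcome types directly. A plain-Python sketch of that approach (the function names are hypothetical), which should agree with `calc_performance` for 0/1 labels with spam as the positive class:

```python
def confusion_counts(actual, pred):
    """Count TP, FP, TN, FN for binary 0/1 labels (1 = spam is the positive class)."""
    tp = sum(1 for a, p in zip(actual, pred) if a == 1 and p == 1)
    fp = sum(1 for a, p in zip(actual, pred) if a == 0 and p == 1)
    tn = sum(1 for a, p in zip(actual, pred) if a == 0 and p == 0)
    fn = sum(1 for a, p in zip(actual, pred) if a == 1 and p == 0)
    return tp, fp, tn, fn

def metrics_from_counts(tp, fp, tn, fn):
    """Derive the same summary statistics as calc_performance from raw counts."""
    recall = tp / (tp + fn)
    precision = tp / (tp + fp)
    specificity = tn / (tn + fp)
    return {
        "recall": recall,
        "precision": precision,
        "specificity": specificity,
        "balanced_accuracy": (recall + specificity) / 2,
        "KS": recall + specificity - 1,
        "F_score": 2 * precision * recall / (precision + recall),
    }

actual = [1, 1, 0, 0, 1, 0]
pred = [1, 0, 0, 1, 1, 0]
counts = confusion_counts(actual, pred)
print(counts)  # (2, 1, 2, 1)
print(metrics_from_counts(*counts))
```

Like `calc_performance`, the sketch assumes both classes appear in `actual` and at least one row is predicted as spam; otherwise a denominator is zero (here a `ZeroDivisionError`).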

In [21]:
# Applies the function to calculate performance metrics

performance = calc_performance(preds, "Category", "_label_")
performance.head()

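|   | TP | FP | TN | FN | recall | precision | specificity | balanced_accuracy | KS | F_score |
|:--|:---|:---|:---|:---|:-------|:----------|:------------|:------------------|:---|:--------|
| 0 | 192 | 7 | 1464 | 9 | 0.955224 | 0.964824 | 0.995241 | 0.975233 | 0.950465 | 0.96 |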
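
As a check on the table, each metric follows directly from the four counts (192 + 7 + 1464 + 9 = 1672 validation messages). The overall accuracy, which the function does not report, agrees with the final validation accuracy logged during training (99.043%):

```latex
\begin{aligned}
\text{recall} &= \frac{TP}{TP+FN} = \frac{192}{192+9} \approx 0.9552 \\
\text{precision} &= \frac{TP}{TP+FP} = \frac{192}{192+7} \approx 0.9648 \\
\text{specificity} &= \frac{TN}{TN+FP} = \frac{1464}{1464+7} \approx 0.9952 \\
\text{balanced accuracy} &= \frac{\text{recall}+\text{specificity}}{2} \approx 0.9752 \\
\text{KS} &= \text{recall}+\text{specificity}-1 \approx 0.9505 \\
F &= \frac{2\,\text{precision}\cdot\text{recall}}{\text{precision}+\text{recall}} = \frac{2\,TP}{2\,TP+FP+FN} = \frac{384}{400} = 0.96 \\
\text{accuracy} &= \frac{TP+TN}{TP+FP+TN+FN} = \frac{1656}{1672} \approx 0.9904
\end{aligned}
```

The KS value here is computed at the model's assigned labels only (a quantity also known as Youden's J); the usual Kolmogorov-Smirnov statistic is the maximum of recall + specificity - 1 over all score cutoffs.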


In [22]:
# Displays some of the incorrect predictions

false_preds = preds.copy()
false_preds = false_preds.loc[false_preds["_label_"] != false_preds["Category"], :]
false_preds.head()

|     | docid | `_label_` | `_score_` | Category | Message |
|:----|:------|:----------|:----------|:---------|:--------|
| 231 | 752.0 | 0.0 | 0.53485 | 1.0 | You have an important customer service announc... |
| 245 | 822.0 | 1.0 | 0.415425 | 0.0 | On the road so cant txt |
| 296 | 989.0 | 1.0 | 0.14747 | 0.0 | Yun ah.the ubi one say if ü wan call by tomorr... |
| 436 | 1430.0 | 0.0 | 0.599274 | 1.0 | For sale - arsenal dartboard. Good condition b... |
| 483 | 1612.0 | 1.0 | 1.132281 | 0.0 | 645 |
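
The misclassifications fall into two groups: false negatives (spam predicted as legitimate, `_label_` 0 with `Category` 1) and false positives (legitimate messages flagged as spam). A plain-Python sketch that tallies them for the five incorrect rows displayed above; the helper name `error_type` is hypothetical:

```python
from collections import Counter

def error_type(actual, pred):
    """Hypothetical helper: label a misclassified row as a false positive or false negative (1 = spam)."""
    if pred == 1 and actual == 0:
        return "false positive"
    if pred == 0 and actual == 1:
        return "false negative"
    raise ValueError("row is not a misclassification")

# (docid, _label_, Category) for the incorrect predictions displayed above
rows = [(752, 0, 1), (822, 1, 0), (989, 1, 0), (1430, 0, 1), (1612, 1, 0)]

tally = Counter(error_type(cat, lab) for _, lab, cat in rows)
print(tally)
```

Over the whole validation partition, the same split is FP = 7 and FN = 9, as reported by `calc_performance`.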


In [23]:
# Ends the session and frees up the resources from memory

conn.session.endsession()