# Text classification: classify spam email and non-spam email by LSTM model and BERT model in Tensorflow

Confirm that TensorFlow is using the GPU

In [3]:
import tensorflow as tf
print("The Number of GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

The Number of GPUs Available:  1


In [4]:
import numpy as np
from numpy import *
import pandas as pd
import os

import tensorflow as tf
from transformers import TFBertModel, BertTokenizer, TFBertForSequenceClassification, BertConfig
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import LSTM
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error
from keras.preprocessing.sequence import TimeseriesGenerator
from keras import callbacks
from keras.layers import Dropout
from sklearn.model_selection import RandomizedSearchCV
from sklearn.model_selection import GridSearchCV
from sklearn.preprocessing import LabelEncoder
import nltk
from nltk.corpus import wordnet
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from nltk import word_tokenize
from nltk import pos_tag
from nltk.corpus import wordnet
from nltk.stem import WordNetLemmatizer
nltk.download('stopwords')
nltk.download('punkt')
nltk.download('averaged_perceptron_tagger')
nltk.download('wordnet')
from sklearn.model_selection import train_test_split
import torch

from sklearn.metrics import classification_report, accuracy_score, confusion_matrix
from nltk.classify.util import accuracy
from sklearn.feature_extraction import DictVectorizer
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.ensemble import VotingClassifier
from scikeras.wrappers import KerasRegressor

from keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.layers import Embedding
from sklearn.utils.class_weight import compute_class_weight
from imblearn.over_sampling import SMOTE
import warnings
warnings.filterwarnings('ignore')
import time
from collections import Counter

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [5]:
directory = '/content/sample_data'
os.chdir(directory)

In [6]:
# Read the dataframe

df = pd.read_csv('Spam-Classification.csv')
df.iloc[:,0:2]

Unnamed: 0,v1,v2
0,ham,"Go until jurong point, crazy.. Available only ..."
1,ham,Ok lar... Joking wif u oni...
2,spam,Free entry in 2 a wkly comp to win FA Cup fina...
3,ham,U dun say so early hor... U c already then say...
4,ham,"Nah I don't think he goes to usf, he lives aro..."
...,...,...
5567,spam,This is the 2nd time we have tried 2 contact u...
5568,ham,Will Ì_ b going to esplanade fr home?
5569,ham,"Pity, * was in mood for that. So...any other s..."
5570,ham,The guy did some bitching but I acted like i'd...


In [7]:
# Label encoding

label = LabelEncoder().fit_transform(df.iloc[:,0])
label

array([0, 0, 1, ..., 0, 0, 0])

# Data cleaning

Remove or replace meaningless data using regular expressions

In [8]:
message=df.iloc[:,1].str.lower().copy()

message=message.str.replace(r'[^ ]+@[^\.]*(\.[a-z]{2,}){1,2}', 'emailaddress', regex=True)
message=message.str.replace(r'(?:http\:\/\/|www.){1}(?:http\:\/\/|www.)?(?![www])[a-zA-Z0-9\-]+(\.[a-zA-Z]{2,3}){1,2}(\/[^/& ]+)*', 'webaddress', regex=True)
message=message.str.replace(r'(£|\$|€|¥|₣|å£){1}\d+(.\d+)?', 'money', regex=True)
message=message.str.replace(r'\b\+?\d{1}[\d\s-]{5,13}\d{1}\b|\b[\d]{4}[\s-]?[\d]{3}[\s-]?[\d]{4}\b', 'phonenumber', regex=True)

# date
# day/month/year
message=message.str.replace(r'\b(3[01]|[12][0-9]|[0]?[1-9]){1}[\ |\-|\/]{1}(1[0-2]|[0]?[1-9]){1}[\ |\-|\/]{1}(\d{2}|\d{4})(?!\d{3})\b', 'date', regex=True)
# year/month/day
message=message.str.replace(r'\b(\d{2}|\d{4})(?!\d{3})[\ |\-|\/]{1}(1[0-2]|[0]?[1-9]){1}[\ |\-|\/]{1}(3[01]|[12][0-9]|[0]?[1-9]){1}\b', 'date', regex=True)
# day/month(english)/year
message=message.str.replace(r'\b(3[01]|[12][0-9]|[0]?[1-9]){1}([\ ]|st|nd|rd|th){1,2}(jan(?:uary)?|feb(?:ruary)?|mar(?:ch)?|apr(?:il)?|may|jun(?:e)?|jul(?:y)?|aug(?:ust)?|sep(?:tember)?|oct(?:ober)?|nov(?:ember)?|dec(?:ember)?){1}\b(?:([\ ]|,){0,2}(\d{4}|\d{2}))?\b', 'date', regex=True)
# month(english)/day/year
message=message.str.replace(r'\b(jan(?:uary)?|feb(?:ruary)?|mar(?:ch)?|apr(?:il)?|may|jun(?:e)?|jul(?:y)?|aug(?:ust)?|sep(?:tember)?|oct(?:ober)?|nov(?:ember)?|dec(?:ember)?){1}[\ ]?(3[01]|[12][0-9]|[0]?[1-9]){1}((?: )?st|(?: )?nd|(?: )?rd|(?: )?th){0,1}(?:([\ ]|,){0,2}(\d{4}|\d{2}))?\b', 'date', regex=True)


message=message.str.replace(r'\b\d+\b', 'number', regex=True)
message=message.str.replace(r'[^\w\d\s]+', ' ', regex=True)
message=message.str.replace(r'\s+', ' ', regex=True)
message=message.str.replace(r'^\s+|\s+?$', '', regex=True)

Remove remove common words like 'a', 'the' using stopwords

In [9]:
common_words=set(stopwords.words('english'))

message=message.apply(lambda m: ' '.join(word for word in m.split() if word not in common_words))

Change the form of words like 'running' to 'run' using Lemmatization

In [10]:
def convert_pos(tag):
    if tag in ['JJ', 'JJR', 'JJS']:
        return wordnet.ADJ
    elif tag in ['VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ']:
        return wordnet.VERB
    elif tag in ['NN', 'NNS', 'NNP', 'NNPS']:
        return wordnet.NOUN
    elif tag in ['RB', 'RBR', 'RBS']:
        return wordnet.ADV
    else:
        return wordnet.NOUN

def convert_pos_2(tag):
    if tag.startswith('J'):
        return wordnet.ADJ
    elif tag.startswith('V'):
        return wordnet.VERB
    elif tag.startswith('N'):
        return wordnet.NOUN
    elif tag.startswith('R'):
        return wordnet.ADV
    else:
        return wordnet.NOUN

In [11]:
for i in range(len(message)):
    tokens = word_tokenize(message[i])
    tagged = pos_tag(tokens)

    tokens_list_tokenized = []

    for tag in tagged:
        wordnet_pos = convert_pos(tag[1])
        tokens_list_tokenized.append(WordNetLemmatizer().lemmatize(tag[0], pos=wordnet_pos))

    tokens_to_string = ' '.join(tokens_list_tokenized)
    message[i]=tokens_to_string

In [12]:
i=0
print('Mail before cleaning: ', df.iloc[i,1].lower(), '\n')
print('Mail after cleaning: ', message[i])

Mail before cleaning:  go until jurong point, crazy.. available only in bugis n great world la e buffet... cine there got amore wat... 

Mail after cleaning:  go jurong point crazy available bugis n great world la e buffet cine get amore wat


# Tokenization

Count the total tokens of all data

In [13]:
all_words = []

for m in message:
    words = word_tokenize(m)
    for w in words:
        all_words.append(w)

all_words = nltk.FreqDist(all_words)

print('Total number of Token:', len(all_words))

Total number of Token: 7008


Shuffle data and split data into training set and testing set

In [14]:
df_message = list(zip(message, label))


np.random.shuffle(df_message)

message_split = [text[0] for text in df_message]
list_label = [text[1] for text in df_message]

x_training, x_test, y_training, y_test= train_test_split(message_split, list_label, test_size=0.2, random_state=1)

print('The number of training data: ', len(x_training))
print('The number of testing data: ', len(x_test))

The number of training data:  4457
The number of testing data:  1115


Tokenized all the words and choose the 1500 most common words

In [15]:
tokenizer = Tokenizer(num_words = 500, char_level = False, oov_token = '<OOV>')
tokenizer.fit_on_texts(x_training)

x_training_sequences = tokenizer.texts_to_sequences(x_training)
x_test_sequences = tokenizer.texts_to_sequences(x_test)

Find the max len of words of all messages

In [16]:
max_len=0

for i in range(len(x_training_sequences)):
    if len(x_training_sequences[i])>max_len:
        max_len=len(x_training_sequences[i])

for i in range(len(x_test_sequences)):
    if len(x_test_sequences[i])>max_len:
        max_len=len(x_test_sequences[i])

print('The max len of words of all messages: ', max_len)

The max len of words of all messages:  78


Using pad_sequences turns all the messages with same len of words

In [17]:
x_training_padded = pad_sequences(x_training_sequences, maxlen = max_len, padding='pre', truncating='pre')
x_test_padded = pad_sequences(x_test_sequences, maxlen = max_len, padding='pre', truncating='pre')

print('The shape of training data: ', x_training_padded.shape)
print('The shape of testing data: ', x_test_padded.shape)

The shape of training data:  (4457, 78)
The shape of testing data:  (1115, 78)


# Class imbalance problem

The dataset has a class imbalance problem, the number data with class ham is larger than the number data with class spam

In [18]:
df.iloc[:,0].value_counts()

ham     4825
spam     747
Name: v1, dtype: int64

Therefore, calculate the class weights by using sklearn.utils.class_weight.compute_class_weight as the class_weight parameter for training model to solve the class imbalance problem

In [19]:
class_weights = compute_class_weight(class_weight='balanced',classes=np.unique(y_training), y=y_training)
class_weight_dic = {0: class_weights[0], 1: class_weights[1]}
class_weight_dic

{0: 0.5768832513590474, 1: 3.7516835016835017}

# Hyperparameter optimization

Using Grid Search to optimize the two hyperparameters: neurons and batch size

In [20]:
start = time.time()

tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

def create_model(neurons):
    model = Sequential()
    model.add(Embedding(input_dim=500, output_dim=neurons, input_length=max_len))
    model.add(Dropout(0.1))
    model.add(LSTM(neurons, return_sequences=True))
    model.add(Dropout(0.1))
    model.add(LSTM(neurons))
    model.add(Dropout(0.1))
    model.add(Dense(units=(neurons/2), activation='relu'))
    model.add(Dropout(0.1))
    model.add(Dense(1, activation='sigmoid'))
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model

neurons = [128, 256, 512]
batch_size = [256, 512]
epochs = [30]
hypara_dict = dict(model__neurons=neurons, batch_size=batch_size, epochs=epochs)

class_weights = compute_class_weight(class_weight='balanced',classes=np.unique(y_training), y=y_training)
class_weight_dic = {0: class_weights[0], 1: class_weights[1]}

keras_model = KerasRegressor(model=create_model, verbose=0)
grid_model = GridSearchCV(estimator=keras_model, param_grid=hypara_dict, n_jobs=1, cv=3)
grid_model.fit(x_training_padded, array(y_training),
               validation_data = (x_test_padded, array(y_test)),
               class_weight=class_weight_dic)

print('time used: ' + str(round(time.time() - start, 2)) + ' seconds')
print('')
print('The optimized hyperparameters are: ')
print(grid_model.best_params_)

time used: 1092.16 seconds

The optimized hyperparameters are: 
{'batch_size': 512, 'epochs': 30, 'model__neurons': 512}


# Train lstm model using the optimal hyperparameters

1. Using callbacks techniques:
    
    callbacks.EarlyStopping: Stop model training when a monitored metric has stopped improving in 10 Epoch
        
    callbacks.ModelCheckpoint: Save a model and weights in a checkpoint file at some interval

        
2. Using class_weight parameter when doing model fitting to solve the class imbalance problem

In [21]:
#crete the lstm model using the best hyperparameters

best_parameter_lstm_model=create_model(grid_model.best_params_['model__neurons'])
best_parameter_lstm_model.summary()

Model: "sequential_19"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 embedding_19 (Embedding)    (None, 78, 512)           256000    
                                                                 
 dropout_76 (Dropout)        (None, 78, 512)           0         
                                                                 
 lstm_38 (LSTM)              (None, 78, 512)           2099200   
                                                                 
 dropout_77 (Dropout)        (None, 78, 512)           0         
                                                                 
 lstm_39 (LSTM)              (None, 512)               2099200   
                                                                 
 dropout_78 (Dropout)        (None, 512)               0         
                                                                 
 dense_38 (Dense)            (None, 256)             

In [22]:
best_parameter_lstm_model=create_model(grid_model.best_params_['model__neurons'])

earlystopping=callbacks.EarlyStopping(monitor='val_accuracy', mode='max', patience=10, restore_best_weights=False)
filepath='bestcheckpoint_nlp_spam_classification.hdf5'
checkpoint=callbacks.ModelCheckpoint(filepath, monitor='val_accuracy', mode='max', save_best_only=True)

class_weights = compute_class_weight(class_weight='balanced',classes=np.unique(y_training), y=y_training)
class_weight_dic = {0: class_weights[0], 1: class_weights[1]}

best_parameter_lstm_model.fit(x_training_padded, array(y_training), validation_data = (x_test_padded, array(y_test)),
                         batch_size = grid_model.best_params_['batch_size'],
                         epochs = grid_model.best_params_['epochs'],
                         verbose = 2, callbacks=[earlystopping, checkpoint], class_weight=class_weight_dic)

Epoch 1/30
9/9 - 9s - loss: 0.5725 - accuracy: 0.7689 - val_loss: 0.6681 - val_accuracy: 0.7363 - 9s/epoch - 960ms/step
Epoch 2/30
9/9 - 3s - loss: 0.3277 - accuracy: 0.8517 - val_loss: 0.1631 - val_accuracy: 0.9525 - 3s/epoch - 320ms/step
Epoch 3/30
9/9 - 3s - loss: 0.1386 - accuracy: 0.9675 - val_loss: 0.1080 - val_accuracy: 0.9677 - 3s/epoch - 339ms/step
Epoch 4/30
9/9 - 3s - loss: 0.0796 - accuracy: 0.9800 - val_loss: 0.1105 - val_accuracy: 0.9650 - 3s/epoch - 325ms/step
Epoch 5/30
9/9 - 3s - loss: 0.0570 - accuracy: 0.9823 - val_loss: 0.0670 - val_accuracy: 0.9830 - 3s/epoch - 339ms/step
Epoch 6/30
9/9 - 3s - loss: 0.0418 - accuracy: 0.9841 - val_loss: 0.0571 - val_accuracy: 0.9857 - 3s/epoch - 329ms/step
Epoch 7/30
9/9 - 3s - loss: 0.0279 - accuracy: 0.9942 - val_loss: 0.1039 - val_accuracy: 0.9695 - 3s/epoch - 316ms/step
Epoch 8/30
9/9 - 3s - loss: 0.0253 - accuracy: 0.9890 - val_loss: 0.0628 - val_accuracy: 0.9857 - 3s/epoch - 332ms/step
Epoch 9/30
9/9 - 3s - loss: 0.0187 - acc

<keras.src.callbacks.History at 0x7eb1ebff9540>

Load the model again by the saved file

In [23]:
best_parameter_lstm_model.load_weights(filepath)
best_parameter_lstm_model.evaluate(x_test_padded, array(y_test))



[0.08686558902263641, 0.9901345372200012]

Doing prediction for the test data

In [24]:
predictions = best_parameter_lstm_model.predict(x_test_padded)
predicted_labels = (predictions > 0.5).astype(np.int64)

label_list=['ham','spam']
predicted_label_list=[]
for i in range(len(predicted_labels.reshape(1,-1)[0])):
    predicted_label_list.append(label_list[predicted_labels.reshape(1,-1)[0][i]])

print('The predicted label for test data: ')
print(array(predicted_label_list))

The predicted label for test data: 
['ham' 'ham' 'ham' ... 'ham' 'ham' 'ham']


Calculate the accuracy and the number of error of the prediction by the LSTM model

In [25]:
error_num=0
error_num_class_0=0
error_num_class_1=0
for i in range(len(array(y_test))):
    if predicted_labels.reshape(1,-1)[0][i]!=array(y_test)[i]:
        error_num+=1
        if array(y_test)[i]==0:
            error_num_class_0+=1
        elif array(y_test)[i]==1:
            error_num_class_1+=1


print('')
y_test_count = Counter(y_test)
for element, count in y_test_count.items():
  print(f'number of class {element} in test data: {count}')
  if element==0:
    print('The num error of class 0: ', error_num_class_0)
    print('Accuracy of class 0: ', 1-error_num_class_0/count)
  if element==1:
    print('The num error of class 1: ', error_num_class_1)
    print('Accuracy of class 1: ', 1-error_num_class_1/count)
  print('')


print('')
print('The total num of error: ', error_num)
print('The num of test data: ', len(array(y_test)))
print('Total accuracy: ', 1-error_num/len(array(y_test)))


number of class 0 in test data: 962
The num error of class 0:  2
Accuracy of class 0:  0.997920997920998

number of class 1 in test data: 153
The num error of class 1:  9
Accuracy of class 1:  0.9411764705882353


The total num of error:  11
The num of test data:  1115
Total accuracy:  0.9901345291479821


# Fine-tuning BERT model

After training the LSTM model, lets try fine-tuning BERT model for Text Classification

Download the pre-trained BERT model and Tokenizer from Huggingface

In [26]:
bert_model = TFBertModel.from_pretrained('bert-base-uncased')
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFBertModel: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing TFBertModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFBertModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of TFBertModel were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertModel for predictions w

Data processing again without using stopwords and Lemmatization

In [27]:
message=df.iloc[:,1].str.lower().copy()

message=message.str.replace(r'[^ ]+@[^\.]*(\.[a-z]{2,}){1,2}', 'emailaddress', regex=True)
message=message.str.replace(r'(?:http\:\/\/|www.){1}(?:http\:\/\/|www.)?(?![www])[a-zA-Z0-9\-]+(\.[a-zA-Z]{2,3}){1,2}(\/[^/& ]+)*', 'webaddress', regex=True)
message=message.str.replace(r'(£|\$|€|¥|₣|å£){1}\d+(.\d+)?', 'money', regex=True)
message=message.str.replace(r'\b\+?\d{1}[\d\s-]{5,13}\d{1}\b|\b[\d]{4}[\s-]?[\d]{3}[\s-]?[\d]{4}\b', 'phonenumber', regex=True)

# date
# day/month/year
message=message.str.replace(r'\b(3[01]|[12][0-9]|[0]?[1-9]){1}[\ |\-|\/]{1}(1[0-2]|[0]?[1-9]){1}[\ |\-|\/]{1}(\d{2}|\d{4})(?!\d{3})\b', 'date', regex=True)
# year/month/day
message=message.str.replace(r'\b(\d{2}|\d{4})(?!\d{3})[\ |\-|\/]{1}(1[0-2]|[0]?[1-9]){1}[\ |\-|\/]{1}(3[01]|[12][0-9]|[0]?[1-9]){1}\b', 'date', regex=True)
# day/month(english)/year
message=message.str.replace(r'\b(3[01]|[12][0-9]|[0]?[1-9]){1}([\ ]|st|nd|rd|th){1,2}(jan(?:uary)?|feb(?:ruary)?|mar(?:ch)?|apr(?:il)?|may|jun(?:e)?|jul(?:y)?|aug(?:ust)?|sep(?:tember)?|oct(?:ober)?|nov(?:ember)?|dec(?:ember)?){1}\b(?:([\ ]|,){0,2}(\d{4}|\d{2}))?\b', 'date', regex=True)
# month(english)/day/year
message=message.str.replace(r'\b(jan(?:uary)?|feb(?:ruary)?|mar(?:ch)?|apr(?:il)?|may|jun(?:e)?|jul(?:y)?|aug(?:ust)?|sep(?:tember)?|oct(?:ober)?|nov(?:ember)?|dec(?:ember)?){1}[\ ]?(3[01]|[12][0-9]|[0]?[1-9]){1}((?: )?st|(?: )?nd|(?: )?rd|(?: )?th){0,1}(?:([\ ]|,){0,2}(\d{4}|\d{2}))?\b', 'date', regex=True)

message=message.str.replace(r'\b\d+\b', 'number', regex=True)
message=message.str.replace(r'[^\w\d\s\']+', ' ', regex=True)
message=message.str.replace(r'\s+', ' ', regex=True)
message=message.str.replace(r'^\s+|\s+?$', '', regex=True)

df_message = list(zip(message, label))
np.random.shuffle(df_message)
message_split = [text[0] for text in df_message]
list_label = [text[1] for text in df_message]
x_training, x_test, y_training, y_test= train_test_split(message_split, list_label, test_size=0.2, random_state=1)

max_len=0

for i in range(len(x_training)):
  if len(bert_tokenizer.tokenize(x_training[i]))>max_len:
    max_len=len(bert_tokenizer.tokenize(x_training[i]))
for i in range(len(x_test)):
  if len(bert_tokenizer.tokenize(x_test[i]))>max_len:
    max_len=len(bert_tokenizer.tokenize(x_test[i]))


Tokenize the data by the BERT Tokenizer and create train and test dataset

In [28]:
# Train dataset
tokens_training = bert_tokenizer.batch_encode_plus(x_training, max_length=max_len, padding='max_length', truncation=True)
x_training_encoded = np.array(tokens_training['input_ids'])
x_training_am = np.array(tokens_training['attention_mask'])

# Test dataset
tokens_test = bert_tokenizer.batch_encode_plus(x_test, max_length=max_len, padding='max_length', truncation=True)
x_test_encoded = np.array(tokens_test['input_ids'])
x_test_am = np.array(tokens_test['attention_mask'])

Fine tune pre-trained Bert mdoel by defining model layers and model fitting

In [29]:
input_layer_1 = tf.keras.Input(shape=(max_len,), dtype=tf.int32, name='input_ids')
input_layer_2 = tf.keras.Input(shape=(max_len,), dtype=tf.int32, name='attention_mask')

hidden_layer_1 = bert_model([input_layer_1, input_layer_2])[1]
hidden_layer_2 = Dense(256, activation='relu')(hidden_layer_1)
dropout_layer = Dropout(0.1)(hidden_layer_2)
output_layer = Dense(1, activation='sigmoid')(dropout_layer)

model = tf.keras.Model(inputs=[input_layer_1, input_layer_2], outputs=output_layer)
optimizer = tf.keras.optimizers.Adam(learning_rate=5e-5)
model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['accuracy'])
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_ids (InputLayer)      [(None, 198)]                0         []                            
                                                                                                  
 attention_mask (InputLayer  [(None, 198)]                0         []                            
 )                                                                                                
                                                                                                  
 tf_bert_model (TFBertModel  TFBaseModelOutputWithPooli   1094822   ['input_ids[0][0]',           
 )                           ngAndCrossAttentions(last_   40         'attention_mask[0][0]']      
                             hidden_state=(None, 198, 7                                       

In [30]:
earlystopping_ft=callbacks.EarlyStopping(monitor='val_accuracy', mode='max', patience=3, restore_best_weights=False)
filepath_ft='bestcheckpoint_fine_tune.hdf5'
checkpoint_ft=callbacks.ModelCheckpoint(filepath_ft, monitor='val_accuracy', mode='max', save_best_only=True)

model.fit([x_training_encoded, x_training_am], np.array(y_training), validation_data = ([x_test_encoded, x_test_am], np.array(y_test)),
      batch_size = 16,
      epochs = 15,
      verbose = 2, callbacks=[earlystopping_ft, checkpoint_ft])

Epoch 1/15
279/279 - 266s - loss: 0.0784 - accuracy: 0.9773 - val_loss: 0.0549 - val_accuracy: 0.9865 - 266s/epoch - 953ms/step
Epoch 2/15
279/279 - 225s - loss: 0.0211 - accuracy: 0.9937 - val_loss: 0.0449 - val_accuracy: 0.9901 - 225s/epoch - 808ms/step
Epoch 3/15
279/279 - 194s - loss: 0.0073 - accuracy: 0.9982 - val_loss: 0.0936 - val_accuracy: 0.9865 - 194s/epoch - 696ms/step
Epoch 4/15
279/279 - 195s - loss: 0.0318 - accuracy: 0.9933 - val_loss: 0.0678 - val_accuracy: 0.9821 - 195s/epoch - 700ms/step
Epoch 5/15
279/279 - 195s - loss: 0.0107 - accuracy: 0.9969 - val_loss: 0.1064 - val_accuracy: 0.9839 - 195s/epoch - 699ms/step


<keras.src.callbacks.History at 0x7eb208ca5000>

Load the model again by the saved file

In [31]:
model.load_weights(filepath_ft)
model.evaluate([x_test_encoded, x_test_am], np.array(y_test))



[0.04493414983153343, 0.9901345372200012]

Doing prediction for the test data

In [32]:
predictions_ft = model.predict([x_test_encoded, x_test_am])
predicted_labels_ft = (predictions_ft > 0.5).astype(np.int64)

label_list=['ham','spam']
predicted_label_list_ft=[]
for i in range(len(predicted_labels_ft.reshape(1,-1)[0])):
    predicted_label_list_ft.append(label_list[predicted_labels_ft.reshape(1,-1)[0][i]])

print('The predicted label for test data: ')
print(array(predicted_label_list_ft))

The predicted label for test data: 
['ham' 'ham' 'ham' ... 'spam' 'ham' 'spam']


In [33]:
error_num=0
error_num_class_0=0
error_num_class_1=0
for i in range(len(array(y_test))):
  if predicted_labels_ft.reshape(1,-1)[0][i]!=array(y_test)[i]:
    error_num+=1
    if array(y_test)[i]==0:
      error_num_class_0+=1
    elif array(y_test)[i]==1:
        error_num_class_1+=1


print('')
y_test_count = Counter(y_test)
for element, count in y_test_count.items():
  print(f'number of class {element} in test data: {count}')
  if element==0:
    print('The num error of class 0: ', error_num_class_0)
    print('Accuracy of class 0: ', 1-error_num_class_0/count)
  if element==1:
    print('The num error of class 1: ', error_num_class_1)
    print('Accuracy of class 1: ', 1-error_num_class_1/count)
  print('')


print('')
print('The total num of error: ', error_num)
print('The num of test data: ', len(array(y_test)))
print('Total accuracy: ', 1-error_num/len(array(y_test)))


number of class 0 in test data: 953
The num error of class 0:  1
Accuracy of class 0:  0.9989506820566632

number of class 1 in test data: 162
The num error of class 1:  10
Accuracy of class 1:  0.9382716049382716


The total num of error:  11
The num of test data:  1115
Total accuracy:  0.9901345291479821
