#  Fine-tuning BERT with Tensorflow for Text Classification.

This is the classification based E-commerce text dataset for 4 categories - "Electronics", "Household", "Books" and "Clothing & Accessories" using Bert where almost cover 80% of any E-commerce websiteThe dataset is in ".csv" format with two columns - the first column is the class name and the second one is the datapoint of that class. The data point is the product and description from the e-commerce website.

The dataset has the following features :

Data Set Characteristics: Multivariate

Number of Instances: 50425



In [1]:
!pip install neattext



## Import libraries.

In [2]:
import spacy
import re
import nltk
import string
import sklearn
import neattext as nt
import neattext.functions as nfx
import pandas as pd 
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns
nltk.download('stopwords')
from nltk.corpus import stopwords
from collections import Counter
from tensorflow.keras import layers
from tensorflow.keras.layers import Input, Dense
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.optimizers import Adam
from sklearn.metrics import classification_report
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.initializers import TruncatedNormal
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.metrics import CategoricalAccuracy
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.preprocessing.sequence import pad_sequences

from transformers import AutoTokenizer, TFBertModel
tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
bert = TFBertModel.from_pretrained('bert-base-cased')



[nltk_data] Downloading package stopwords to /usr/share/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFBertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing TFBertModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFBertModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of TFBertModel were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertModel for predictions w

### Read data with pandas

In [3]:
import pandas as pd

data = pd.read_csv('../input/ecommerceDataset.csv')

In [4]:

# Get the feature names (column names)
feature_names = data.columns

# Print or use the feature names as needed
print("Feature Names:", feature_names)


Feature Names: Index(['Household', 'Paper Plane Design Framed Wall Hanging Motivational Office Decor Art Prints (8.7 X 8.7 inch) - Set of 4 Painting made up in synthetic frame with uv textured print which gives multi effects and attracts towards it. This is an special series of paintings which makes your wall very beautiful and gives a royal touch. This painting is ready to hang, you would be proud to possess this unique painting that is a niche apart. We use only the most modern and efficient printing technology on our prints, with only the and inks and precision epson, roland and hp printers. This innovative hd printing technique results in durable and spectacular looking prints of the highest that last a lifetime. We print solely with top-notch 100% inks, to achieve brilliant and true colours. Due to their high level of uv resistance, our prints retain their beautiful colours for many years. Add colour and style to your living space with this digitally printed painting. Some are for

In [5]:
import pandas as pd

# Assuming you've already read the CSV file into a DataFrame
 

# Get the current feature names (column names)
current_feature_names = data.columns

# Define the old feature names you want to rename
old_feature_name1 = 'Household'
old_feature_name2 =  'Paper Plane Design Framed Wall Hanging Motivational Office Decor Art Prints (8.7 X 8.7 inch) - Set of 4 Painting made up in synthetic frame with uv textured print which gives multi effects and attracts towards it. This is an special series of paintings which makes your wall very beautiful and gives a royal touch. This painting is ready to hang, you would be proud to possess this unique painting that is a niche apart. We use only the most modern and efficient printing technology on our prints, with only the and inks and precision epson, roland and hp printers. This innovative hd printing technique results in durable and spectacular looking prints of the highest that last a lifetime. We print solely with top-notch 100% inks, to achieve brilliant and true colours. Due to their high level of uv resistance, our prints retain their beautiful colours for many years. Add colour and style to your living space with this digitally printed painting. Some are for pleasure and some for eternal bliss.so bring home this elegant print that is lushed with rich colors that makes it nothing but sheer elegance to be to your friends and family.it would be treasured forever by whoever your lucky recipient is. Liven up your place with these intriguing paintings that are high definition hd graphic digital prints for home, office or any room.'
# Define the new feature names
new_feature_name1 = 'type'
new_feature_name2 = 'text'

# Rename the features using the rename method
data.rename(columns={old_feature_name1: new_feature_name1, old_feature_name2: new_feature_name2}, inplace=True)

# Print or use the DataFrame with the updated feature names
print("Updated DataFrame:")
print(data)


Updated DataFrame:
              type                                               text
0        Household  SAF 'Floral' Framed Painting (Wood, 30 inch x ...
1        Household  SAF 'UV Textured Modern Art Print Framed' Pain...
2        Household  SAF Flower Print Framed Painting (Synthetic, 1...
3        Household  Incredible Gifts India Wooden Happy Birthday U...
4        Household  Pitaara Box Romantic Venice Canvas Painting 6m...
...            ...                                                ...
50419  Electronics  Strontium MicroSD Class 10 8GB Memory Card (Bl...
50420  Electronics  CrossBeats Wave Waterproof Bluetooth Wireless ...
50421  Electronics  Karbonn Titanium Wind W4 (White) Karbonn Titan...
50422  Electronics  Samsung Guru FM Plus (SM-B110E/D, Black) Colou...
50423  Electronics                   Micromax Canvas Win W121 (White)

[50424 rows x 2 columns]


### Inspecting the data.

In [6]:
data.head()

Unnamed: 0,type,text
0,Household,"SAF 'Floral' Framed Painting (Wood, 30 inch x ..."
1,Household,SAF 'UV Textured Modern Art Print Framed' Pain...
2,Household,"SAF Flower Print Framed Painting (Synthetic, 1..."
3,Household,Incredible Gifts India Wooden Happy Birthday U...
4,Household,Pitaara Box Romantic Venice Canvas Painting 6m...


In [7]:
data.groupby('type').describe()

Unnamed: 0_level_0,text,text,text,text
Unnamed: 0_level_1,count,unique,top,freq
type,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
Books,11820,6256,Think & Grow Rich About the Author NAPOLEON HI...,30
Clothing & Accessories,8670,5674,Diverse Men's Formal Shirt Diverse is a wester...,23
Electronics,10621,5308,HP 680 Original Ink Advantage Cartridge (Black...,26
Household,19312,10563,Nilkamal Series-24 Chest of Drawers (Cream Tra...,13


In [8]:
data['type'].value_counts()

type
Household                 19312
Books                     11820
Electronics               10621
Clothing & Accessories     8671
Name: count, dtype: int64

In [9]:
data['type'].nunique()

4

In [10]:
data['type'].count()

50424

In [11]:
data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 50424 entries, 0 to 50423
Data columns (total 2 columns):
 #   Column  Non-Null Count  Dtype 
---  ------  --------------  ----- 
 0   type    50424 non-null  object
 1   text    50423 non-null  object
dtypes: object(2)
memory usage: 788.0+ KB


In [12]:
data.shape

(50424, 2)

### The ecommarce classes are aggregated together by type, so we need to shuffle them.

In [13]:
data.head()

Unnamed: 0,type,text
0,Household,"SAF 'Floral' Framed Painting (Wood, 30 inch x ..."
1,Household,SAF 'UV Textured Modern Art Print Framed' Pain...
2,Household,"SAF Flower Print Framed Painting (Synthetic, 1..."
3,Household,Incredible Gifts India Wooden Happy Birthday U...
4,Household,Pitaara Box Romantic Venice Canvas Painting 6m...


In [14]:
data.tail()

Unnamed: 0,type,text
50419,Electronics,Strontium MicroSD Class 10 8GB Memory Card (Bl...
50420,Electronics,CrossBeats Wave Waterproof Bluetooth Wireless ...
50421,Electronics,Karbonn Titanium Wind W4 (White) Karbonn Titan...
50422,Electronics,"Samsung Guru FM Plus (SM-B110E/D, Black) Colou..."
50423,Electronics,Micromax Canvas Win W121 (White)


In [15]:
data.isnull().sum()

type    0
text    1
dtype: int64

In [16]:
data.drop_duplicates(inplace=True)

In [17]:
data.shape

(27802, 2)

### Shuffle data and split it into training and test sets.

In [18]:
data_train, data_test = train_test_split(data, test_size = 0.3, random_state = 42, shuffle = True, stratify = data.type)

## Clean data using neattext library.

### Remove hashtags, multiple spaces and user-handles.

In [19]:
# Filter out non-string values in the 'text' column
data['text'] = data['text'].apply(lambda x: str(x) if pd.notnull(x) else '')
data_train, data_test = train_test_split(data, test_size=0.3, random_state=42, shuffle=True, stratify=data.type)

# Apply the remove_hashtags function to the 'text' column
data_train['text'] = data_train['text'].apply(nfx.remove_hashtags)
data_train.head()


Unnamed: 0,type,text
6400,Household,Cloth Fusion Olio Faux Silk Curtains with 2 Ti...
3277,Household,Rajasthani Ceramic 16 mm Ultrasonic Mist Maker...
45844,Electronics,RiaTech 2 in 1 Screen Cleaning Kit KCL-1042 fo...
50248,Electronics,"Samsung Galaxy J8 (Black, 4GB RAM, 64GB Storag..."
47866,Electronics,Rockville dB 14 4000 Watt/2000w RMS Mono Class...


In [20]:
# data_train['tweet_text'] = data_train['tweet_text'].apply(lambda x: nfx.remove_custom_pattern(x, term_pattern=r'&#\$ '))

data_train['text'] = data_train['text'].apply(nfx.remove_userhandles)

In [21]:
data_train.tail()

Unnamed: 0,type,text
27964,Books,An Introduction to Probability and Inductive L...
31726,Clothing & Accessories,Neo Flick Quick-Dry Athletic Shorts for Men Ne...
12576,Household,ATLANTIS Metal Mini 2 Lane Tea and Coffee Vend...
3912,Household,Seiko QHN006GLH Mantel Clock The Seiko QHN006G...
24314,Books,The Headspace Guide to... Mindfulness & Medita...


In [22]:
data_train[''] = data_train['text'].apply(nfx.remove_multiple_spaces)
data_train.head()

Unnamed: 0,type,text,Unnamed: 3
6400,Household,Cloth Fusion Olio Faux Silk Curtains with 2 Ti...,Cloth Fusion Olio Faux Silk Curtains with 2 Ti...
3277,Household,Rajasthani Ceramic 16 mm Ultrasonic Mist Maker...,Rajasthani Ceramic 16 mm Ultrasonic Mist Maker...
45844,Electronics,RiaTech 2 in 1 Screen Cleaning Kit KCL-1042 fo...,RiaTech 2 in 1 Screen Cleaning Kit KCL-1042 fo...
50248,Electronics,"Samsung Galaxy J8 (Black, 4GB RAM, 64GB Storag...","Samsung Galaxy J8 (Black, 4GB RAM, 64GB Storag..."
47866,Electronics,Rockville dB 14 4000 Watt/2000w RMS Mono Class...,Rockville dB 14 4000 Watt/2000w RMS Mono Class...


### We'll take only a section of our training data because we are running on CPU.

In [23]:
data_train  = data_train[:12000]
data_train.shape

(12000, 3)

In [24]:
data_train.tail()

Unnamed: 0,type,text,Unnamed: 3
18746,Household,TOOLSMART Brass Mini Portable 100 mm/inch Vern...,TOOLSMART Brass Mini Portable 100 mm/inch Vern...
46675,Electronics,Night Owl Optics Xgen Xgenpro 3X Digital Night...,Night Owl Optics Xgen Xgenpro 3X Digital Night...
42366,Electronics,zamp e commerce 15 Pin Male to Male 1.5 Meter ...,zamp e commerce 15 Pin Male to Male 1.5 Meter ...
45834,Electronics,Rnaux Smart Phone Perfumed Cleaning Gel Kit wi...,Rnaux Smart Phone Perfumed Cleaning Gel Kit wi...
23263,Books,The Art of Happiness: A Handbook for Living,The Art of Happiness: A Handbook for Living


In [25]:
data_train['text'] = data_train['text'].apply(nfx.remove_stopwords)
data_train.head()

Unnamed: 0,type,text,Unnamed: 3
6400,Household,Cloth Fusion Olio Faux Silk Curtains 2 Tie Bac...,Cloth Fusion Olio Faux Silk Curtains with 2 Ti...
3277,Household,Rajasthani Ceramic 16 mm Ultrasonic Mist Maker...,Rajasthani Ceramic 16 mm Ultrasonic Mist Maker...
45844,Electronics,RiaTech 2 1 Screen Cleaning Kit KCL-1042 LED L...,RiaTech 2 in 1 Screen Cleaning Kit KCL-1042 fo...
50248,Electronics,"Samsung Galaxy J8 (Black, 4GB RAM, 64GB Storag...","Samsung Galaxy J8 (Black, 4GB RAM, 64GB Storag..."
47866,Electronics,Rockville dB 14 4000 Watt/2000w RMS Mono Class...,Rockville dB 14 4000 Watt/2000w RMS Mono Class...


In [26]:
data_train['text'] = data_train['text'].apply(nfx.remove_urls)
data_train.head()

Unnamed: 0,type,text,Unnamed: 3
6400,Household,Cloth Fusion Olio Faux Silk Curtains 2 Tie Bac...,Cloth Fusion Olio Faux Silk Curtains with 2 Ti...
3277,Household,Rajasthani Ceramic 16 mm Ultrasonic Mist Maker...,Rajasthani Ceramic 16 mm Ultrasonic Mist Maker...
45844,Electronics,RiaTech 2 1 Screen Cleaning Kit KCL-1042 LED L...,RiaTech 2 in 1 Screen Cleaning Kit KCL-1042 fo...
50248,Electronics,"Samsung Galaxy J8 (Black, 4GB RAM, 64GB Storag...","Samsung Galaxy J8 (Black, 4GB RAM, 64GB Storag..."
47866,Electronics,Rockville dB 14 4000 Watt/2000w RMS Mono Class...,Rockville dB 14 4000 Watt/2000w RMS Mono Class...


## Check the different classes in our target variable for train and test sets.

In [27]:
data_train['type'].unique()

array(['Household', 'Electronics', 'Clothing & Accessories', 'Books'],
      dtype=object)

In [28]:
data_test['type'].unique()

array(['Clothing & Accessories', 'Household', 'Books', 'Electronics'],
      dtype=object)

In [29]:
# We'll take only a portion of the test set also.

data_test = data_test[:500]
data_test.shape

(500, 2)

### Encoding the target variable with scikit-learn label encoder. We do this for both train and test sets separately to avoid data leakage. 

In [30]:
label_enc = LabelEncoder() 


In [31]:
data_train['type'] = label_enc.fit_transform(data_train['type'])
data_train.head()

Unnamed: 0,type,text,Unnamed: 3
6400,3,Cloth Fusion Olio Faux Silk Curtains 2 Tie Bac...,Cloth Fusion Olio Faux Silk Curtains with 2 Ti...
3277,3,Rajasthani Ceramic 16 mm Ultrasonic Mist Maker...,Rajasthani Ceramic 16 mm Ultrasonic Mist Maker...
45844,2,RiaTech 2 1 Screen Cleaning Kit KCL-1042 LED L...,RiaTech 2 in 1 Screen Cleaning Kit KCL-1042 fo...
50248,2,"Samsung Galaxy J8 (Black, 4GB RAM, 64GB Storag...","Samsung Galaxy J8 (Black, 4GB RAM, 64GB Storag..."
47866,2,Rockville dB 14 4000 Watt/2000w RMS Mono Class...,Rockville dB 14 4000 Watt/2000w RMS Mono Class...


In [32]:
data_test['type'] = label_enc.transform(data_test['type'])


In [33]:
data_test['type'].unique()

array([1, 3, 0, 2])

In [34]:
data_train.dtypes

type     int64
text    object
        object
dtype: object

In [35]:
data_train['type'].unique()

array([3, 2, 1, 0])

In [36]:
data_train['type'].value_counts()

type
3    4551
0    2654
1    2491
2    2304
Name: count, dtype: int64

#### Label encoder encodes data by alphabetical order. 

### Tokenize train input text with bert's Autotokenizer that we imported earlier.

In [37]:
x_train = tokenizer(
#     text = x_train.tolist(),
    text = data_train['text'].tolist(),
    add_special_tokens = True,
    max_length = 100,
    truncation = True,
    padding = True,
    return_tensors = 'tf',
    return_token_type_ids = False,
    return_attention_mask = True,
    verbose = True
)

In [38]:
x_train['input_ids']

<tf.Tensor: shape=(12000, 100), dtype=int32, numpy=
array([[  101,   140,  7841, ...,  1116,   119,   102],
       [  101, 17988,  1182, ...,  7857,   107,   102],
       [  101,   155,  1465, ...,     0,     0,     0],
       ...,
       [  101,   195, 19471, ...,     0,     0,     0],
       [  101,   155,  1605, ...,     0,     0,     0],
       [  101,  2051, 25410, ...,     0,     0,     0]], dtype=int32)>

### Building the model and fine-tuning the model.

The first value returned by BERT model at index 0 is the last hidden state, 1 means pooler_output
We need only the hidden state, so that we can add more layers and fine-tune the model.
We'll use functional API

In [39]:
max_len = 100


input_ids = Input(shape=(max_len,), dtype=tf.int32, name='input_ids')
input_mask = Input(shape=(max_len,), dtype=tf.int32, name='attention_mask')

# 0 is the last hidden state, 1 means pooler_output
# We need only the hidden state, so that we can add more layers and fine-tune the model.
# We'll use functional API
embeddings = bert(input_ids, attention_mask=input_mask)[0]
out = tf.keras.layers.GlobalMaxPool1D()(embeddings)
out = Dense(128, activation='relu')(out)
out = tf.keras.layers.Dropout(0.1)(out)
out = Dense(32, activation='relu')(out)

y = Dense(4, activation='sigmoid')(out)

model = tf.keras.Model(inputs=[input_ids, input_mask], outputs=y)
model.layers[2].trainable = True

### Compile the model.

In [40]:
# Bert model requires a specific learning rate as stated in the huggingface website

optimizer = tf.keras.optimizers.legacy.Adam(
    learning_rate=5e-05,
    epsilon=1e-08,
    decay=0.01,
    clipnorm=1.0
)

loss = CategoricalCrossentropy(from_logits=True)
metric = CategoricalAccuracy('balanced_accuracy')

model.compile(
    optimizer=optimizer,
    loss=loss,
    metrics=metric)

model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_ids (InputLayer)      [(None, 100)]                0         []                            
                                                                                                  
 attention_mask (InputLayer  [(None, 100)]                0         []                            
 )                                                                                                
                                                                                                  
 tf_bert_model (TFBertModel  TFBaseModelOutputWithPooli   1083102   ['input_ids[0][0]',           
 )                           ngAndCrossAttentions(last_   72         'attention_mask[0][0]']      
                             hidden_state=(None, 100, 7                                       

**Tokenize test data.**

In [41]:
x_test = tokenizer(
    text = data_test['text'].tolist(), 
    add_special_tokens = True,
    max_length = 100,
    truncation = True,
    padding = True,
    return_tensors = 'tf',
    return_token_type_ids = False,
    return_attention_mask = True,
    verbose = True
)


In [42]:
bert_train = model.fit(
    x={'input_ids':x_train['input_ids'], 'attention_mask':x_train['attention_mask']},
    y=to_categorical(data_train.type),
    validation_data=(
        {'input_ids':x_test['input_ids'], 'attention_mask':x_test['attention_mask']},to_categorical(data_test.type)
    ),
    epochs=2,
    batch_size=36
)

Epoch 1/2


  output, from_logits = _get_logits(


Epoch 2/2


### We achieved balanced accuracy of approximately 95.86 % and validation balanced accuracy of approximately 95.20%.We can easily enhance the model performance by using more epochs and various thing.But now we doest not have that much configuration for running more largely.Thus here we only focused on Bert model finetuing

#### Tokenize test data.

In [49]:
 

model.save_weights('model.h5')

# To use the mode again, load it into memory
# model.load_weights('model_cyber.h5')

In [50]:
pred_raw = model.predict({'input_ids':x_test['input_ids'], 'attention_mask':x_test['attention_mask']})



We want to check the prediction of the first input in the test set. Bert gives the probability of each class. We'll use np.argmax to get the index oc the highest probability.

In [51]:
pred_raw[0]

array([0.06673667, 0.9845167 , 0.0194209 , 0.04064174], dtype=float32)

In [52]:
y_pred = np.argmax(pred_raw, axis=1)

In [53]:
data_test.type 

33602    1
36756    1
16798    3
5732     3
27364    0
        ..
23822    0
9646     3
22844    0
7328     3
26886    0
Name: type, Length: 500, dtype: int64

#### Checking the classification report.

These values in the classification report look good.

In [54]:
print(classification_report(data_test.type, y_pred))

              precision    recall  f1-score   support

           0       0.96      0.94      0.95       128
           1       0.97      0.98      0.97        91
           2       0.95      0.93      0.94       102
           3       0.94      0.96      0.95       179

    accuracy                           0.95       500
   macro avg       0.95      0.95      0.95       500
weighted avg       0.95      0.95      0.95       500

