In [5]:
import pickle
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text
import pandas as pd
from sklearn.model_selection import train_test_split
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix


In [6]:
file = open('data.pkl', 'rb')
# dump information to that file
data = pickle.load(file)

In [7]:
data

Unnamed: 0,body_basic,label
0,Fly fishing is an angling method that uses a ...,fly_fishing
1,Simms GORE TEX ExStream Cap 89 95 Simms Dockwe...,fly_fishing
2,Article NPS Photo Neal Herbert The thick yello...,fly_fishing
3,Fly Fishing is a technique for catching fish w...,fly_fishing
4,11 Tips to Help You Sell a Boat Online Fishin...,fly_fishing
...,...,...
232,We use cookies to give you a better experience...,machine_learning
233,This post is part one in a three part series o...,machine_learning
234,A leading edge research firm focused on digita...,machine_learning
235,Explore Northeastern s first international cam...,machine_learning


In [8]:
data['label'].unique()

array(['fly_fishing', 'ice_hockey', 'machine_learning'], dtype=object)

In [9]:
data['category']=data['label'].apply(lambda x: 0 if x=='fly_fishing' else(1 if x=='ice_hockey' else 2))

In [10]:
data

Unnamed: 0,body_basic,label,category
0,Fly fishing is an angling method that uses a ...,fly_fishing,0
1,Simms GORE TEX ExStream Cap 89 95 Simms Dockwe...,fly_fishing,0
2,Article NPS Photo Neal Herbert The thick yello...,fly_fishing,0
3,Fly Fishing is a technique for catching fish w...,fly_fishing,0
4,11 Tips to Help You Sell a Boat Online Fishin...,fly_fishing,0
...,...,...,...
232,We use cookies to give you a better experience...,machine_learning,2
233,This post is part one in a three part series o...,machine_learning,2
234,A leading edge research firm focused on digita...,machine_learning,2
235,Explore Northeastern s first international cam...,machine_learning,2


In [11]:
X_train, X_test, y_train, y_test = train_test_split(data['body_basic'],data['category'],test_size=0.1)

In [12]:
bert_preprocess = hub.KerasLayer("https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3")
bert_encoder = hub.KerasLayer("https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4")

In [13]:
text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')
preprocessed_text = bert_preprocess(text_input)
outputs = bert_encoder(preprocessed_text)

In [14]:
l = tf.keras.layers.Dropout(0.1, name="dropout")(outputs['pooled_output'])
l = tf.keras.layers.Dense(1, activation='sigmoid', name="output")(l)

In [15]:
model = tf.keras.Model(inputs=[text_input], outputs = [l])

In [16]:
model.summary()

Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 text (InputLayer)              [(None,)]            0           []                               
                                                                                                  
 keras_layer_2 (KerasLayer)     {'input_mask': (Non  0           ['text[0][0]']                   
                                e, 128),                                                          
                                 'input_type_ids':                                                
                                (None, 128),                                                      
                                 'input_word_ids':                                                
                                (None, 128)}                                                

In [17]:
METRICS = [
      tf.keras.metrics.BinaryAccuracy(name='accuracy'),
      tf.keras.metrics.Precision(name='precision'),
      tf.keras.metrics.Recall(name='recall')
]

model.compile(optimizer='adam',
 loss='binary_crossentropy',
 metrics=METRICS)

In [18]:
model.fit(X_train, y_train, epochs=1)



<keras.callbacks.History at 0x7f0931732250>

In [None]:
model.save('saved_model/bert_model')

In [19]:
loaded_model = tf.keras.models.load_model('saved_model/bert_model')
y_predicted = loaded_model.predict(X_test)
y_predicted = y_predicted.flatten()

In [20]:
y_predicted = np.where(y_predicted > 0.5, 1, 0)
y_predicted

array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1])

In [21]:
print(classification_report(y_test, y_predicted))

              precision    recall  f1-score   support

           0       0.00      0.00      0.00         9
           1       0.33      1.00      0.50         8
           2       0.00      0.00      0.00         7

    accuracy                           0.33        24
   macro avg       0.11      0.33      0.17        24
weighted avg       0.11      0.33      0.17        24



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
