# Topic classification with NN

Steps:
1. Web scrape information from coronanet.org
2. Clean texts - get from coronanet project
3. Learn RNN process to fit model (https://www.youtube.com/watch?v=dkpS2g4K08s)
4. Get feedback

In [1]:
from selenium import webdriver
from parsel import Selector
import time
import tensorflow as tf
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
import tensorflow_hub as hub

import nltk
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
import re
import pycountry

## 1. Webscrape info from github with selenium 

In [2]:
driver = webdriver.Chrome('/Users/irenechang/Downloads/chromedriver')
driver.get('https://github.com/CoronaNetDataScience/corona_tscs/tree/master/data/CoronaNet/data_country/coronanet_release')

In [3]:
#Germany, USA, Spain, Australia
country_names = ['Germany', 'United States of America', 'Spain', 'Australia', 'India']
countries_to_scrape = []

#parse string together
for country in country_names:
    countries_to_scrape.append("coronanet_release_" + country + ".csv")


# go into each link - hit download
# '/a[@title="coronanet_release_United States of America.csv"]/@href'
urls = []
for country in countries_to_scrape:
    xPath = "//a[@title='"+ country + "']"
    print(xPath)
    link = driver.find_element_by_xpath(xPath).get_attribute("href")
    urls.append(link)

# get the csv
csv_urls = []
for url in urls:
    driver.get(url)
    sel = Selector(text=driver.page_source) # what is this for again?...
    raw_csv = driver.find_element_by_id('raw-url').get_attribute("href")
    csv_urls.append(raw_csv)

//a[@title='coronanet_release_Germany.csv']
//a[@title='coronanet_release_United States of America.csv']
//a[@title='coronanet_release_Spain.csv']
//a[@title='coronanet_release_Australia.csv']
//a[@title='coronanet_release_India.csv']


In [4]:
dfs = []
for csv in csv_urls:
    # read each csv into a separate dataframe
    dfs.append(pd.read_csv(csv))

big_frame = pd.concat(dfs, ignore_index=True)

In [5]:
big_frame.shape

(15252, 40)

In [6]:
big_frame.columns

Index(['record_id', 'policy_id', 'entry_type', 'correct_type', 'update_type',
       'update_level', 'description', 'date_announced', 'date_start',
       'date_end', 'country', 'ISO_A3', 'ISO_A2', 'init_country_level',
       'domestic_policy', 'province', 'ISO_L2', 'city', 'type', 'type_sub_cat',
       'type_text', 'institution_status', 'target_country',
       'target_geog_level', 'target_region', 'target_province', 'target_city',
       'target_other', 'target_who_what', 'target_direction',
       'travel_mechanism', 'compliance', 'enforcer', 'dist_index_high_est',
       'dist_index_med_est', 'dist_index_low_est', 'dist_index_country_rank',
       'link', 'date_updated', 'recorded_date'],
      dtype='object')

In [7]:
big_frame.head(3)

Unnamed: 0,record_id,policy_id,entry_type,correct_type,update_type,update_level,description,date_announced,date_start,date_end,...,travel_mechanism,compliance,enforcer,dist_index_high_est,dist_index_med_est,dist_index_low_est,dist_index_country_rank,link,date_updated,recorded_date
0,R_DT9wJ6cfpACiXyVNA,1475054,new_entry,original,,,In Baden-Württemberg (Germany) the city of Stu...,2020-03-17,2020-03-17,2020-03-17,...,,Mandatory (Unspecified/Implied),Provincial/State Government,57.590903,54.734064,51.729718,102.0,https://www.stuttgarter-nachrichten.de/inhalt....,2020-08-22,2020-08-22T14:59:54Z
1,R_5p95bNFstNDs9UdNA,5717782,new_entry,original,,,Bremen (Germany) informs about fake news about...,2020-03-26,2020-03-26,2020-03-26,...,,Mandatory (Unspecified/Implied),Provincial/State Government,68.460588,65.451526,62.56566,90.0,https://www.bremen-innovativ.de/2020/03/fake-n...,2020-08-18,2020-08-18T06:45:40Z
2,R_2Qgl6LnVGJvdYioNA,6291168,new_entry,original,,,The Thuringia government introduces citizens t...,2020-03-30,2020-03-30,,...,,Voluntary/Recommended but No Penalties,Provincial/State Government,69.580439,66.680645,64.0928,80.0,https://corona.thueringen.de/buerger-soziales/...,2020-09-03,2020-09-03T13:35:20Z


In [8]:
import gc
gc.collect()
del(dfs)

In [23]:
df = big_frame[["description", "type"]].drop_duplicates(subset=['description']).reset_index(drop=True)

In [24]:
df

Unnamed: 0,description,type
0,In Baden-Württemberg (Germany) the city of Stu...,Anti-Disinformation Measures
1,Bremen (Germany) informs about fake news about...,Anti-Disinformation Measures
2,The Thuringia government introduces citizens t...,Anti-Disinformation Measures
3,"On 4 May, The NRW State Criminal Police Office...",Anti-Disinformation Measures
4,Germany offers a website with information aro...,Anti-Disinformation Measures
...,...,...
10086,India's Delhi metro largest subway reopens,Social Distancing
10087,In the state of Assam in India wearing face ma...,Social Distancing
10088,The state of Assam in India establishes a sop ...,Social Distancing
10089,The government of Goa in India on 20th Oct ann...,Social Distancing


In [25]:
df['type'].value_counts()

Restriction and Regulation of Businesses                  1635
Restrictions of Mass Gatherings                           1048
Health Resources                                          1027
Closure and Regulation of Schools                          739
Restriction and Regulation of Government Services          726
Social Distancing                                          704
Quarantine                                                 702
Other Policy Not Listed Above                              608
Lockdown                                                   477
Public Awareness Measures                                  430
Health Testing                                             385
Declaration of Emergency                                   339
New Task Force, Bureau or Administrative Configuration     281
External Border Restrictions                               234
Internal Border Restrictions                               181
Health Monitoring                                      

## 2. Preprocess text data

In [26]:
# Special characters and punctuation
punc_list = list("?:!.,;()")
df["description_1"] = df["description"].str.replace("\n", " ")
df["description_1"] = df["description_1"].str.replace('"', " ")
df["description_1"] = df["description_1"].str.replace("'s", " ")
for punc in punc_list:
    df["description_1"] = df["description_1"].str.replace(punc, " ")
    
#lowering cases
df["description_1"] = df["description_1"].str.lower()

In [27]:
df

Unnamed: 0,description,type,description_1
0,In Baden-Württemberg (Germany) the city of Stu...,Anti-Disinformation Measures,in baden-württemberg germany the city of stu...
1,Bremen (Germany) informs about fake news about...,Anti-Disinformation Measures,bremen germany informs about fake news about...
2,The Thuringia government introduces citizens t...,Anti-Disinformation Measures,the thuringia government introduces citizens t...
3,"On 4 May, The NRW State Criminal Police Office...",Anti-Disinformation Measures,on 4 may the nrw state criminal police office...
4,Germany offers a website with information aro...,Anti-Disinformation Measures,germany offers a website with information aro...
...,...,...,...
10086,India's Delhi metro largest subway reopens,Social Distancing,india delhi metro largest subway reopens
10087,In the state of Assam in India wearing face ma...,Social Distancing,in the state of assam in india wearing face ma...
10088,The state of Assam in India establishes a sop ...,Social Distancing,the state of assam in india establishes a sop ...
10089,The government of Goa in India on 20th Oct ann...,Social Distancing,the government of goa in india on 20th oct ann...


In [29]:
#stemming and lemmatization
nltk.download('punkt')
nltk.download('wordnet')

wordnet_lemmatizer = WordNetLemmatizer()
nrows = len(df)
lemmatized_text_list = []

for row in range(0, nrows):
    
    # Create an empty list containing lemmatized words
    lemmatized_list = []
    
    # Save the text and its words into an object
    text = df.loc[row]['description_1']
    text_words = text.split(" ")

    # Iterate through every word to lemmatize
    for word in text_words:
        lemmatized_list.append(wordnet_lemmatizer.lemmatize(word, pos="v"))
        
    # Join the list
    lemmatized_text = " ".join(lemmatized_list)
    
    # Append to the list containing the texts
    lemmatized_text_list.append(lemmatized_text)


In [30]:
#stopwords
nltk.download('stopwords')
stop_words = list(stopwords.words('english'))

# include country names in stopwords
country_text = []
for text in df["description_1"].tolist():
    for c in pycountry.countries:
        if c.name.lower() in text:
            text = re.sub(c.name.lower(), '', text)
    country_text.append(text)

df["description_2"] = country_text

for stop_word in stop_words:

    regex_stopword = r"\b" + stop_word + r"\b"
    df['description_2'] = df['description_2'].str.replace(regex_stopword, '')

[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/irenechang/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [31]:
#remove numbers
pattern = r'[0-9]'
remove_number = []
for text in df["description_2"].tolist():
    remove_number.append(re.sub(pattern, '', text))
    
df["description_3"] = remove_number
df

Unnamed: 0,description,type,description_1,description_2,description_3
0,In Baden-Württemberg (Germany) the city of Stu...,Anti-Disinformation Measures,in baden-württemberg germany the city of stu...,baden-württemberg city stuttgart warns ...,baden-württemberg city stuttgart warns ...
1,Bremen (Germany) informs about fake news about...,Anti-Disinformation Measures,bremen germany informs about fake news about...,bremen informs fake news covid-19 march 26,bremen informs fake news covid- march
2,The Thuringia government introduces citizens t...,Anti-Disinformation Measures,the thuringia government introduces citizens t...,thuringia government introduces citizens ov...,thuringia government introduces citizens ov...
3,"On 4 May, The NRW State Criminal Police Office...",Anti-Disinformation Measures,on 4 may the nrw state criminal police office...,4 may nrw state criminal police office lka...,may nrw state criminal police office lka ...
4,Germany offers a website with information aro...,Anti-Disinformation Measures,germany offers a website with information aro...,offers website information around fake new...,offers website information around fake new...
...,...,...,...,...,...
10086,India's Delhi metro largest subway reopens,Social Distancing,india delhi metro largest subway reopens,delhi metro largest subway reopens,delhi metro largest subway reopens
10087,In the state of Assam in India wearing face ma...,Social Distancing,in the state of assam in india wearing face ma...,state assam wearing face masks social di...,state assam wearing face masks social di...
10088,The state of Assam in India establishes a sop ...,Social Distancing,the state of assam in india establishes a sop ...,state assam establishes sop cinemas pu...,state assam establishes sop cinemas pu...
10089,The government of Goa in India on 20th Oct ann...,Social Distancing,the government of goa in india on 20th oct ann...,government goa 20th oct announced partie...,government goa th oct announced parties ...


In [32]:
list_columns = ["description_3", "type"]
df2 = df[list_columns]

df2 = df2.rename(columns={'description_3': 'description'})

In [33]:
df2

Unnamed: 0,description,type
0,baden-württemberg city stuttgart warns ...,Anti-Disinformation Measures
1,bremen informs fake news covid- march,Anti-Disinformation Measures
2,thuringia government introduces citizens ov...,Anti-Disinformation Measures
3,may nrw state criminal police office lka ...,Anti-Disinformation Measures
4,offers website information around fake new...,Anti-Disinformation Measures
...,...,...
10086,delhi metro largest subway reopens,Social Distancing
10087,state assam wearing face masks social di...,Social Distancing
10088,state assam establishes sop cinemas pu...,Social Distancing
10089,government goa th oct announced parties ...,Social Distancing


#### Label encoding

In [47]:
category_codes = {
    'Anti-Disinformation Measures': 0,
    'Hygiene': 1,
    'Curfew': 2,
    'Closure and Regulation of Schools': 3,
    'Declaration of Emergency': 4,
    'External Border Restrictions': 5,
    'Health Monitoring': 6,
    'Health Resources': 7,
    'Health Testing': 8,
    'Internal Border Restrictions': 9,
    'Lockdown': 10,
    'New Task Force, Bureau or Administrative Configuration': 11,
    'COVID-19 Vaccines': 12,
    'Public Awareness Measures': 13,
    'Quarantine': 14,
    'Restriction and Regulation of Businesses': 15,
    'Restriction and Regulation of Government Services': 16,
    'Restrictions of Mass Gatherings':17,
    'Social Distancing':18
}

# Category mapping
df2['type_code'] = df2['type']
df2 = df2.replace({'type_code':category_codes})

df2.head()

Unnamed: 0,description,type,type_code
0,baden-württemberg city stuttgart warns ...,Anti-Disinformation Measures,0
1,bremen informs fake news covid- march,Anti-Disinformation Measures,0
2,thuringia government introduces citizens ov...,Anti-Disinformation Measures,0
3,may nrw state criminal police office lka ...,Anti-Disinformation Measures,0
4,offers website information around fake new...,Anti-Disinformation Measures,0


#### Split train-test sets

In [35]:
# imbalanced -- have to assign class weights
df2['type'].value_counts()

# split train - test
X_train, X_test, y_train, y_test = train_test_split(df2['description'], 
                                                    df2['type'], 
                                                    test_size=0.15, 
                                                    random_state=8)

In [62]:
df2

Unnamed: 0,description,type,type_code
0,baden-württemberg city stuttgart warns ...,Anti-Disinformation Measures,0
1,bremen informs fake news covid- march,Anti-Disinformation Measures,0
2,thuringia government introduces citizens ov...,Anti-Disinformation Measures,0
3,may nrw state criminal police office lka ...,Anti-Disinformation Measures,0
4,offers website information around fake new...,Anti-Disinformation Measures,0
...,...,...,...
10086,delhi metro largest subway reopens,Social Distancing,18
10087,state assam wearing face masks social di...,Social Distancing,18
10088,state assam establishes sop cinemas pu...,Social Distancing,18
10089,government goa th oct announced parties ...,Social Distancing,18


## 3. Fit the model

In [63]:
from sklearn.utils import class_weight
class_weights = list(class_weight.compute_class_weight('balanced', 
                                                       np.unique(df2['type']), 
                                                       df2['type']))
class_weights.sort()
class_weights

[0.3085932721712538,
 0.4814408396946565,
 0.49128529698149953,
 0.6827469553450609,
 0.6949724517906336,
 0.7166903409090909,
 0.7187321937321938,
 0.8298519736842105,
 1.0577568134171909,
 1.1733720930232558,
 1.3105194805194804,
 1.4883480825958701,
 1.7955516014234876,
 2.156196581196581,
 2.787569060773481,
 2.9505847953216375,
 3.1933544303797468,
 3.2977124183006534,
 10.091,
 11.733720930232558]

In [64]:
weights = {}
for index, weight in enumerate(class_weights) :
    weights[index] = weight
    
weights

{0: 0.3085932721712538,
 1: 0.4814408396946565,
 2: 0.49128529698149953,
 3: 0.6827469553450609,
 4: 0.6949724517906336,
 5: 0.7166903409090909,
 6: 0.7187321937321938,
 7: 0.8298519736842105,
 8: 1.0577568134171909,
 9: 1.1733720930232558,
 10: 1.3105194805194804,
 11: 1.4883480825958701,
 12: 1.7955516014234876,
 13: 2.156196581196581,
 14: 2.787569060773481,
 15: 2.9505847953216375,
 16: 3.1933544303797468,
 17: 3.2977124183006534,
 18: 10.091,
 19: 11.733720930232558}

In [65]:
dataset_train = tf.data.Dataset.from_tensor_slices((X_train.values, y_train.values))
dataset_test = tf.data.Dataset.from_tensor_slices((X_test.values, y_test.values))

In [66]:
for text, target in dataset_train.take(5):
    print('Desc: {}, label: {}'.format(text, target))

Desc: b' us embassy montevideo consular section  closed   routine consular services   notice    emergency situations   considered   time ', label: b'Restriction and Regulation of Government Services'
Desc: b' pennsylvania governor signed  senate bill    waives  requirement  schools    session  least  days  provides  continuity  education plans  ensures school employees  paid   closure   provides  secretary  education  authority  waive student teacher  standardized assessments   march  ', label: b'Closure and Regulation of Schools'
Desc: b"dumka   district   n state  jharkhand   defined  government services  would remain operational   lockdown   follows    law  order agencies -  function without  restrictions   officers attendance - compulsory  grade ''  'b' officers  reduced  %  grade 'c'      district administration  treasury officials -   function  restricted staff    wildlife  forest officers -  function  taking necessary precautions ", label: b'Restriction and Regulation of Governm

In [67]:
table = tf.lookup.StaticHashTable(
    initializer=tf.lookup.KeyValueTensorInitializer(
        keys=tf.constant(list(category_codes.keys())),
        values = tf.constant(list(category_codes.values()))
    ),
    default_value=tf.constant(-1),
    name="target_encoding"
)

@tf.function
def target(x):
    return table.lookup(x)

In [89]:
def fetch(text, labels):
    return text, tf.one_hot(target(labels), 19)

train_data_fetch = dataset_train.map(fetch)
test_data_fetch = dataset_test.map(fetch)

#### Start creating a model

In [140]:
embedding = "https://tfhub.dev/google/nnlm-en-dim128/2"
hub_layer = hub.KerasLayer(embedding, output_shape=[128],input_shape=[], dtype=tf.string,
                          trainable=True)
hub_layer(train_data[:1])

<tf.Tensor: shape=(1, 128), dtype=float32, numpy=
array([[ 1.85923278e-01,  3.82673025e-01,  8.69123638e-02,
        -2.36745372e-01, -1.19763926e-01, -5.65516986e-02,
         2.45870352e-01,  5.02816178e-02, -2.10541233e-01,
        -4.42932360e-02,  1.28366366e-01,  1.47269592e-01,
         1.41175740e-04,  4.45434526e-02,  2.13784329e-03,
         1.61750317e-01, -2.32903764e-01, -2.10702419e-01,
        -2.09106982e-01,  1.55449033e-01,  4.53584678e-02,
         4.31233309e-02,  1.48296393e-02, -1.68935359e-01,
         1.12579502e-01, -1.03304483e-01,  1.61703452e-01,
         2.13061482e-01, -4.74388264e-02,  1.27027377e-01,
        -3.04564610e-02, -1.92816645e-01, -3.22420187e-02,
         2.94271410e-01,  2.97213867e-02,  1.13602817e-01,
         8.43360722e-02, -1.42353237e-01,  1.92280009e-01,
         4.26607989e-02, -2.84466296e-02, -2.83433974e-01,
        -1.92027800e-02,  1.16621844e-01, -1.83381909e-03,
         2.30389148e-01,  1.43880561e-01,  1.20757513e-01,
      

In [91]:
#build the basic model without dropout
model_wo_dropout = tf.keras.Sequential()
model_wo_dropout.add(hub_layer)
for units in [128, 128, 64, 32]:
    model_wo_dropout.add(tf.keras.layers.Dense(units, activation='relu'))
model_wo_dropout.add(tf.keras.layers.Dense(19, activation='softmax'))

model_wo_dropout.summary()

Model: "sequential_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
keras_layer_3 (KerasLayer)   (None, 128)               124642688 
_________________________________________________________________
dense_25 (Dense)             (None, 128)               16512     
_________________________________________________________________
dense_26 (Dense)             (None, 128)               16512     
_________________________________________________________________
dense_27 (Dense)             (None, 64)                8256      
_________________________________________________________________
dense_28 (Dense)             (None, 32)                2080      
_________________________________________________________________
dense_29 (Dense)             (None, 19)                627       
Total params: 124,686,675
Trainable params: 124,686,675
Non-trainable params: 0
________________________________________

In [92]:
#compile the model
model_wo_dropout.compile(optimizer='adam', 
              loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

In [93]:
train_data_fetch = train_data_fetch.shuffle(70000).batch(512)
test_data_fetch = test_data_fetch.batch(512)

In [None]:
#fit the model
from keras import callbacks
earlystopping = callbacks.EarlyStopping(monitor="val_loss", mode="min", patience=5, restore_best_weights=True, verbose=1)
text_classifier_wo_dropout = model_wo_dropout.fit(train_data_fetch, epochs=25, validation_data = test_data_fetch,
                   verbose=1, class_weight=weights, callbacks =[earlystopping])

In [80]:
test_length = len(list(dataset_test))
results = model_wo_dropout.evaluate(dataset_test.map(fetch).batch(test_length), verbose=2)
print(results)

1/1 - 0s - loss: 1.1629 - accuracy: 0.6526
[1.1628882884979248, 0.6525759696960449]


In [81]:
test_data, test_labels = next(iter(dataset_test.map(fetch).batch(test_length)))
y_pred = model_wo_dropout.predict(test_data)

print(classification_report(test_labels.numpy().argmax(axis=1), y_pred.argmax(axis=1)))

              precision    recall  f1-score   support

           0       0.00      0.00      0.00       110
           1       0.00      0.00      0.00        29
           2       1.00      0.05      0.09        21
           3       0.92      0.79      0.85       126
           4       0.79      0.75      0.77        36
           5       0.63      0.41      0.50        29
           6       0.25      0.05      0.08        22
           7       0.51      0.60      0.55       161
           8       0.44      0.63      0.52        60
           9       0.80      0.19      0.31        21
          10       0.57      0.79      0.66        66
          11       0.46      0.45      0.46        53
          12       0.00      0.00      0.00         4
          13       0.37      0.69      0.49        59
          14       0.67      0.87      0.76       108
          15       0.79      0.88      0.83       243
          16       0.54      0.74      0.62       109
          17       0.88    

  _warn_prf(average, modifier, msg_start, len(result))


#### Model with Dropout layers

In [143]:
train_data_fetch = dataset_train.map(fetch)
test_data_fetch = dataset_test.map(fetch)

In [144]:
print(len(train_data_fetch))
print(len(test_data_fetch))

8577
1514


In [145]:
# build the model
model = tf.keras.Sequential()
model.add(hub_layer)
for units in [128, 128, 64, 32]:
    model.add(tf.keras.layers.Dense(units, activation='relu'))
    model.add(tf.keras.layers.Dropout(0.3))
model.add(tf.keras.layers.Dense(19, activation='softmax'))

model.summary()

Model: "sequential_12"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
keras_layer_5 (KerasLayer)   (None, 128)               124642688 
_________________________________________________________________
dense_60 (Dense)             (None, 128)               16512     
_________________________________________________________________
dropout_32 (Dropout)         (None, 128)               0         
_________________________________________________________________
dense_61 (Dense)             (None, 128)               16512     
_________________________________________________________________
dropout_33 (Dropout)         (None, 128)               0         
_________________________________________________________________
dense_62 (Dense)             (None, 64)                8256      
_________________________________________________________________
dropout_34 (Dropout)         (None, 64)              

In [146]:
#compile the model
model.compile(optimizer='adam', 
              loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

In [147]:
train_data_fetch = train_data_fetch.shuffle(10000).batch(1500)
test_data_fetch = test_data_fetch.batch(512)

In [148]:
#fit the model
text_classifier = model.fit(train_data_fetch, epochs=100, validation_data = test_data_fetch,
                   verbose=1, class_weight=weights, callbacks =[earlystopping])

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Restoring model weights from the end of the best epoch.
Epoch 00036: early stopping


In [149]:
test_length = len(list(dataset_test))
results = model.evaluate(dataset_test.map(fetch).batch(test_length), verbose=2)
print(results)

1/1 - 0s - loss: 1.4425 - accuracy: 0.6011
[1.442507028579712, 0.6010568141937256]


In [150]:
test_data, test_labels = next(iter(dataset_test.map(fetch).batch(test_length)))
y_pred = model.predict(test_data)

print(classification_report(test_labels.numpy().argmax(axis=1), y_pred.argmax(axis=1)))

              precision    recall  f1-score   support

           0       0.00      0.00      0.00       110
           1       0.00      0.00      0.00        29
           2       0.00      0.00      0.00        21
           3       0.62      0.79      0.70       126
           4       0.00      0.00      0.00        36
           5       0.00      0.00      0.00        29
           6       0.00      0.00      0.00        22
           7       0.53      0.61      0.57       161
           8       0.00      0.00      0.00        60
           9       0.00      0.00      0.00        21
          10       0.65      0.76      0.70        66
          11       0.18      0.40      0.25        53
          12       0.00      0.00      0.00         4
          13       0.26      0.59      0.36        59
          14       0.70      0.91      0.79       108
          15       0.76      0.88      0.81       243
          16       0.64      0.72      0.68       109
          17       0.77    

In [107]:
y_train_pred = model.predict(train_data)

print(classification_report(train_labels.numpy().argmax(axis=1), y_train_pred.argmax(axis=1)))

              precision    recall  f1-score   support

           3       1.00      1.00      1.00         1
          10       1.00      1.00      1.00         1
          15       1.00      1.00      1.00         1
          16       1.00      1.00      1.00         2

    accuracy                           1.00         5
   macro avg       1.00      1.00      1.00         5
weighted avg       1.00      1.00      1.00         5



In [108]:
train_labels

<tf.Tensor: shape=(5, 19), dtype=float32, numpy=
array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        1., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        1., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
        0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
        0., 0., 0.]], dtype=float32)>

In [109]:
train_data

<tf.Tensor: shape=(5,), dtype=string, numpy=
array([b' us embassy montevideo consular section  closed   routine consular services   notice    emergency situations   considered   time ',
       b' pennsylvania governor signed  senate bill    waives  requirement  schools    session  least  days  provides  continuity  education plans  ensures school employees  paid   closure   provides  secretary  education  authority  waive student teacher  standardized assessments   march  ',
       b"dumka   district   n state  jharkhand   defined  government services  would remain operational   lockdown   follows    law  order agencies -  function without  restrictions   officers attendance - compulsory  grade ''  'b' officers  reduced  %  grade 'c'      district administration  treasury officials -   function  restricted staff    wildlife  forest officers -  function  taking necessary precautions ",
       b'texas     reopening  non-essential businesses starting may      per executive order ga-    ha