<a href="https://colab.research.google.com/github/efo-anopa/nlp/blob/main/ag_news_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [37]:
# Import libraries ------- [keras, numpy, tensorflow, tensorflow_datasets, matplotlib]
import keras
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt

In [38]:
# Import Dataset
dataset = tfds.load('ag_news_subset')

In [39]:
# Preview Data
dataset

{'train': <_PrefetchDataset element_spec={'description': TensorSpec(shape=(), dtype=tf.string, name=None), 'label': TensorSpec(shape=(), dtype=tf.int64, name=None), 'title': TensorSpec(shape=(), dtype=tf.string, name=None)}>,
 'test': <_PrefetchDataset element_spec={'description': TensorSpec(shape=(), dtype=tf.string, name=None), 'label': TensorSpec(shape=(), dtype=tf.int64, name=None), 'title': TensorSpec(shape=(), dtype=tf.string, name=None)}>}

In [40]:
type(dataset)

dict

In [41]:
# Separate Training part and Testing part of Dataset

In [42]:
train_data = dataset['train']
test_data = dataset['test']

classNames = ['World', 'Sport', 'Business', 'Sci/Tech']
for i, x in zip(range(10), train_data):
    print(f"{x['label']}:{classNames[x['label']]} --> {x['title']} --> {x['description']}")
    print('-'*90)

3:Sci/Tech --> b'AMD Debuts Dual-Core Opteron Processor' --> b'AMD #39;s new dual-core Opteron chip is designed mainly for corporate computing applications, including databases, Web services, and financial transactions.'
------------------------------------------------------------------------------------------
1:Sport --> b"Wood's Suspension Upheld (Reuters)" --> b'Reuters - Major League Baseball\\Monday announced a decision on the appeal filed by Chicago Cubs\\pitcher Kerry Wood regarding a suspension stemming from an\\incident earlier this season.'
------------------------------------------------------------------------------------------
2:Business --> b'Bush reform may have blue states seeing red' --> b'President Bush #39;s  quot;revenue-neutral quot; tax reform needs losers to balance its winners, and people claiming the federal deduction for state and local taxes may be in administration planners #39; sights, news reports say.'
-----------------------------------------------------

In [43]:
# Build vectorizer

In [44]:
#from keras.layers.experimental.preprocessing import TextVectorization
#vectorizer = TextVectorization(max_tokens=5000)

vectorizer = tf.keras.layers.TextVectorization(max_tokens = 10000, output_sequence_length=300)
vectorizer.adapt(train_data.take(500).map(lambda x: x['title']+ " "+ x['description']))

In [45]:
vocabs = vectorizer.get_vocabulary()

In [46]:
vocab_size = len(vocabs)
print(vocab_size)

5335


In [47]:
print(vocabs)



In [48]:
vectorizer('Ghana Black Stars is playing Comoros today').numpy()

array([4171,  493,    1,   17,    1,    1,   96,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,   

In [50]:
train_data_tk = train_data.map(lambda x:(vectorizer(x['title']+ " " + x['description']), x['label'])).batch(128)
#train_data_ps = train_data_tk.map(lambda x: pad_input(x))
#train_data_tk = train_data.map(lambda x: ps(vectorizer(x['title'] + x['description']).numpy(), maxlen = 300, padding = 'post',truncating = 'post')).batch(128)

In [51]:
train_data_tk

<_BatchDataset element_spec=(TensorSpec(shape=(None, None), dtype=tf.int64, name=None), TensorSpec(shape=(None,), dtype=tf.int64, name=None))>

In [52]:
test_data_tk = test_data.map(lambda x:(vectorizer(x['title'] + " " + x['description']), x['label'])).batch(128)

In [54]:
# Build model

In [55]:
from keras.models import Sequential
from keras.layers import Dense, Embedding, Flatten, Bidirectional, LSTM

In [56]:
from keras.layers import Dropout

In [66]:
model = Sequential()
model.add(Embedding(vocab_size, 512))
model.add(Dropout(0.1))
model.add(Bidirectional(LSTM(32)))
model.add(Flatten())
model.add(Dropout(0.3))
model.add(Dense(32, activation = 'relu'))
model.add(Dense(4, activation='softmax'))

In [67]:
# Compile Model
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

In [71]:
# Fit Model
model.fit(train_data_tk, validation_data=test_data_tk, epochs=10)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.src.callbacks.History at 0x780c3075b820>

In [62]:
# Make Predictions

In [72]:
def prediction_report(news):
    classNames = ['World', 'Sport', 'Business', 'Sci/Tech']
    x = vectorizer(news)
    x_ = model.predict(x)
    print(x_)
    y2 = np.argmax(x_)
    return classNames[y2]

In [74]:
news = ['Impending alien invasion! Is this the end of humanity?']
prediction_report(news)

[[0.39584708 0.0133239  0.29989192 0.29093713]]


'World'

In [75]:
news = ["Humans discover alien life on an exoplanet"]
prediction_report(news)

[[8.4448420e-02 3.7903083e-05 8.5497975e-02 8.3001566e-01]]


'Sci/Tech'

In [76]:
n1 = ["Australia loses war with the emus... Again"]
prediction_report(n1)

[[0.8530336  0.05248579 0.05824916 0.03623144]]


'World'

In [77]:
n2 = ["Australian mosquitos may feed frog nostrils for blood"]
prediction_report(n2)

[[9.4550538e-01 5.2333862e-04 6.2824958e-03 4.7688752e-02]]


'World'

In [78]:
model.evaluate(test_data_tk)



[0.49840715527534485, 0.8882894515991211]

In [79]:
model.save('ag_news_model.keras')