In [51]:
import pandas as pd
import numpy as np
import os

from nltk.tokenize import wordpunct_tokenize
from keras.utils.np_utils import to_categorical
from sklearn.metrics import *

from keras.utils.np_utils import to_categorical
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model, load_model
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import *
import tensorflow.keras

import warnings
warnings.filterwarnings('ignore')

# Implementing Fastformers
In this project, I'm going to be implementing Fastformers, as introduced in a research paper titled <b> Fastformer: Additive Attention Can Be All You Need </b> which was published on September 5th, 2021. The full paper can be found [here](https://arxiv.org/pdf/2108.09084v6.pdf).

Transformers have proven to be very successful in many state-of-the-art pre-trained language models in NLP and in vision-related tasks. Transformers implement a concept called self-attention, which allows them to model the contexts within an input sequence. One limitation of today's transformers is that, because they compute the dot-product between the input representations at each pair of positions, their complexity is quadratic to the input sequence length. This makes it difficult for standard Transformer models to efficiently handle long input sequences. The concept proposed in this research paper to solve the quadratic complexity problem is <b> additive attention</b>, which can achieve effective context modeling in linear complexity.

## The Dataset
I'm going to be using a News Category Dataset which can be found [here](https://www.kaggle.com/rmisra/news-category-dataset). This dataset contains around 200k news headlines and short descriptions from the year 2012 to 2018 obtained from HuffPost. It contains 41 separate categories

In [2]:
df = pd.read_json('archive/News_Category_Dataset_v2.json', lines=True)

In [3]:
df.head()

Unnamed: 0,category,headline,authors,link,short_description,date
0,CRIME,There Were 2 Mass Shootings In Texas Last Week...,Melissa Jeltsen,https://www.huffingtonpost.com/entry/texas-ama...,She left her husband. He killed their children...,2018-05-26
1,ENTERTAINMENT,Will Smith Joins Diplo And Nicky Jam For The 2...,Andy McDonald,https://www.huffingtonpost.com/entry/will-smit...,Of course it has a song.,2018-05-26
2,ENTERTAINMENT,Hugh Grant Marries For The First Time At Age 57,Ron Dicker,https://www.huffingtonpost.com/entry/hugh-gran...,The actor and his longtime girlfriend Anna Ebe...,2018-05-26
3,ENTERTAINMENT,Jim Carrey Blasts 'Castrato' Adam Schiff And D...,Ron Dicker,https://www.huffingtonpost.com/entry/jim-carre...,The actor gives Dems an ass-kicking for not fi...,2018-05-26
4,ENTERTAINMENT,Julianna Margulies Uses Donald Trump Poop Bags...,Ron Dicker,https://www.huffingtonpost.com/entry/julianna-...,"The ""Dietland"" actress said using the bags is ...",2018-05-26


In [4]:
cates = df.groupby('category')
print("total categories:", cates.ngroups)
print(cates.size())

total categories: 41
category
ARTS               1509
ARTS & CULTURE     1339
BLACK VOICES       4528
BUSINESS           5937
COLLEGE            1144
COMEDY             5175
CRIME              3405
CULTURE & ARTS     1030
DIVORCE            3426
EDUCATION          1004
ENTERTAINMENT     16058
ENVIRONMENT        1323
FIFTY              1401
FOOD & DRINK       6226
GOOD NEWS          1398
GREEN              2622
HEALTHY LIVING     6694
HOME & LIVING      4195
IMPACT             3459
LATINO VOICES      1129
MEDIA              2815
MONEY              1707
PARENTING          8677
PARENTS            3955
POLITICS          32739
QUEER VOICES       6314
RELIGION           2556
SCIENCE            2178
SPORTS             4884
STYLE              2254
STYLE & BEAUTY     9649
TASTE              2096
TECH               2082
THE WORLDPOST      3664
TRAVEL             9887
WEDDINGS           3651
WEIRD NEWS         2670
WELLNESS          17827
WOMEN              3490
WORLD NEWS         2177
WORLDPOST 

## Preprocessing

In [5]:
# as shown above, THE WORLDPOST and WORLDPOST should be the same category, so merge them
df.category = df.category.map(lambda x: "WORLDPOST" if x == "THE WORLDPOST" else x)

In [6]:
# using headlines and short_description as input X
df['text'] = df.headline + " " + df.short_description

In [7]:
df['text'][0]

'There Were 2 Mass Shootings In Texas Last Week, But Only 1 On TV She left her husband. He killed their children. Just another day in America.'

## Label Encoding and Tokenization

In [8]:
#creating a dictionary to map the category types to numbers and the numbers to category types.
# category to id
categories = df.groupby('category').size().index.tolist()

category_int = {}
int_category = {}
for i, k in enumerate(categories):
    category_int.update({k:i})
    int_category.update({i:k})
    
df['label'] = df['category'].apply(lambda x: category_int[x])

In [9]:
df = df[['text','label']]

In [10]:
average_length = df['text'].apply(len).mean()
max_length = df['text'].apply(len).max()
print(f'The average number of characters is: {round(average_length,2)}\nThe max number of characters is: {round(max_length,2)}')

The average number of characters is: 173.25
The max number of characters is: 1487


This takes every string in the 'text' column and converts it to tokens. The tokenizer `wordpunct_tokenize` is based on a simple regexp tokenization. It will split pretty much all special symbols and treat them as separate units.

In [11]:
label = list(df['label'])

text = []
for row in df['text']:
    text.append(wordpunct_tokenize(row.lower()))

In [12]:
# giving all unique tokens a number in a word_dict dictionary
word_dict={'PADDING':0}
for sent in text:
    for token in sent:
        if token not in word_dict:
            word_dict[token] = len(word_dict)
        

In [13]:
# finding the average number of tokens in each sentence
tokens_per_sentence = []
for sent in text:
    tokens_per_sentence.append(len(sent))

average_length = round(sum(tokens_per_sentence)/len(tokens_per_sentence),2)
average_length   

35.01

## Making all Sequences into Uniform Length

Here I am converting all of the tokenized sentences to their numerical representations. Then I am trimming / padding each sentence to be of the same uniform length.

In [14]:
MAX_SENT_LENGTH = 32
news_words = []
for sent in text:
    sample = []
    for token in sent:
        sample.append(word_dict[token])
    sample = sample[:MAX_SENT_LENGTH]
    news_words.append(sample+[0]*(MAX_SENT_LENGTH-len(sample)))

In [15]:
# converting the numerical tokens and labels to dtype int32
news_words = np.array(news_words,dtype='int32')
label = np.array(label,dtype='int32')

## Splitting into a Training and Test Set

In [16]:
train_size = int(len(label) * 0.9)
train_size

180767

In [17]:
# getting an array with numbers for the train and test index values
index = np.arange(len(label))
train_size = int(len(label) * 0.9)
train_index = index[:train_size]
np.random.shuffle(train_index)
test_index=index[train_size:]

# Modeling
## Creating the Fastformer Class

Here is the structure of a Fastformer as shown in the research paper: <img src = "Images/Fastformer_structure.png" width=400 height=400>

Here is a breakdown of the image:
<ol>
    <li> The input embedding matrix is transformed into the query, key, and value sequences using three independent linear transformation layers</li>
    <li> The additive attention mechanism is used to summarize the input attention query matrix into a global query vector </li>
    <li> The interaction between attention key and the global query vector is modeled via element-wise product to learn a global context-aware key matrix. </li>
        <i> Element-wise product is used instead of addition in order to model the non-linear relations between the two vectors because addition cannot differ the influence of the global query on different keys (which is not beneficial for context understanding) </i>
    <li> The global context-aware key matrix is summarized into a global key vector via additive attention</li>
    <li> The element-wise product is used to aggregate the global key and attention value</li>
    <li> The global key and attention value are further processed by a linear transformation to compute the global context-aware attention value</li>
    <li> The original attention query and the global context-aware attention value are added to form the final output</li>
</ol>

In [18]:
class Fastformer(Layer):

    def __init__(self, nb_head, size_per_head, **kwargs):
        self.nb_head = nb_head
        self.size_per_head = size_per_head
        self.output_dim = nb_head*size_per_head
        self.now_input_shape=None
        super(Fastformer, self).__init__(**kwargs)

    def build(self, input_shape):
        self.now_input_shape=input_shape
        self.WQ = self.add_weight(name='WQ', 
                                  shape=(input_shape[0][-1], self.output_dim),
                                  initializer='glorot_uniform',
                                  trainable=True)
        self.WK = self.add_weight(name='WK', 
                                  shape=(input_shape[1][-1], self.output_dim),
                                  initializer='glorot_uniform',
                                  trainable=True) 
        self.Wq = self.add_weight(name='Wq', 
                                  shape=(self.output_dim,self.nb_head),
                                  initializer='glorot_uniform',
                                  trainable=True)
        self.Wk = self.add_weight(name='Wk', 
                                  shape=(self.output_dim,self.nb_head),
                                  initializer='glorot_uniform',
                                  trainable=True)
        
        self.WP = self.add_weight(name='WP', 
                                  shape=(self.output_dim,self.output_dim),
                                  initializer='glorot_uniform',
                                  trainable=True)
        
        
        super(Fastformer, self).build(input_shape)
        
    def call(self, x):
        if len(x) == 2:
            Q_seq,K_seq = x
        elif len(x) == 4:
            Q_seq,K_seq,Q_mask,K_mask = x #different mask lengths, reserved for cross attention

        Q_seq = K.dot(Q_seq, self.WQ)        
        Q_seq_reshape = K.reshape(Q_seq, (-1, self.now_input_shape[0][1], self.nb_head*self.size_per_head))

        Q_att=  K.permute_dimensions(K.dot(Q_seq_reshape, self.Wq),(0,2,1))/ self.size_per_head**0.5

        if len(x)  == 4:
            Q_att = Q_att-(1-K.expand_dims(Q_mask,axis=1))*1e8

        Q_att = K.softmax(Q_att)
        Q_seq = K.reshape(Q_seq, (-1,self.now_input_shape[0][1], self.nb_head, self.size_per_head))
        Q_seq = K.permute_dimensions(Q_seq, (0,2,1,3))
        
        K_seq = K.dot(K_seq, self.WK)
        K_seq = K.reshape(K_seq, (-1,self.now_input_shape[1][1], self.nb_head, self.size_per_head))
        K_seq = K.permute_dimensions(K_seq, (0,2,1,3))

        Q_att = Lambda(lambda x: K.repeat_elements(K.expand_dims(x,axis=3),self.size_per_head,axis=3))(Q_att)
        global_q = K.sum(multiply([Q_att, Q_seq]),axis=2)
        
        global_q_repeat = Lambda(lambda x: K.repeat_elements(K.expand_dims(x,axis=2), self.now_input_shape[1][1],axis=2))(global_q)

        QK_interaction = multiply([K_seq, global_q_repeat])
        QK_interaction_reshape = K.reshape(QK_interaction, (-1, self.now_input_shape[0][1], self.nb_head*self.size_per_head))
        K_att = K.permute_dimensions(K.dot(QK_interaction_reshape, self.Wk),(0,2,1))/ self.size_per_head**0.5
        
        if len(x)  == 4:
            K_att = K_att-(1-K.expand_dims(K_mask,axis=1))*1e8
            
        K_att = K.softmax(K_att)

        K_att = Lambda(lambda x: K.repeat_elements(K.expand_dims(x,axis=3),self.size_per_head,axis=3))(K_att)

        global_k = K.sum(multiply([K_att, QK_interaction]),axis=2)
     
        global_k_repeat = Lambda(lambda x: K.repeat_elements(K.expand_dims(x,axis=2), self.now_input_shape[0][1],axis=2))(global_k)
        #Q=V
        QKQ_interaction = multiply([global_k_repeat, Q_seq])
        QKQ_interaction = K.permute_dimensions(QKQ_interaction, (0,2,1,3))
        QKQ_interaction = K.reshape(QKQ_interaction, (-1,self.now_input_shape[0][1], self.nb_head*self.size_per_head))
        QKQ_interaction = K.dot(QKQ_interaction, self.WP)
        QKQ_interaction = K.reshape(QKQ_interaction, (-1,self.now_input_shape[0][1], self.nb_head,self.size_per_head))
        QKQ_interaction = K.permute_dimensions(QKQ_interaction, (0,2,1,3))
        QKQ_interaction = QKQ_interaction+Q_seq
        QKQ_interaction = K.permute_dimensions(QKQ_interaction, (0,2,1,3))
        QKQ_interaction = K.reshape(QKQ_interaction, (-1,self.now_input_shape[0][1], self.nb_head*self.size_per_head))

        #many operations can be optimized if higher versions are used. 
        
        return QKQ_interaction
        
    def compute_output_shape(self, input_shape):
        return (input_shape[0][0], input_shape[0][1], self.output_dim)

In [52]:
tensorflow.keras.backend.clear_session() 

text_input = Input(shape=(MAX_SENT_LENGTH,), dtype='int32')
qmask=Lambda(lambda x:  K.cast(K.cast(x,'bool'),'float32'))(text_input)

word_emb = Embedding(len(word_dict),256, trainable=True)(text_input)
word_emb = Dropout(0.2)(word_emb)

hidden_word_emb = Fastformer(16,16)([word_emb,word_emb,qmask,qmask])
hidden_word_emb = Dropout(0.2)(hidden_word_emb)
hidden_word_emb = LayerNormalization()(add([word_emb,hidden_word_emb])) 

hidden_word_emb_layer2 = Fastformer(16,16)([hidden_word_emb,hidden_word_emb,qmask,qmask])
hidden_word_emb_layer2 = Dropout(0.2)(hidden_word_emb_layer2)
hidden_word_emb_layer2 = LayerNormalization()(add([hidden_word_emb,hidden_word_emb_layer2]))

word_att = Flatten()(Dense(1)(hidden_word_emb_layer2))
word_att = Activation('softmax')(word_att)
text_emb = Dot((1, 1))([hidden_word_emb_layer2 , word_att])
classifier = Dense(40, activation='softmax')(text_emb)
                                      
model = Model([text_input], [classifier])
model.compile(loss=['categorical_crossentropy'],optimizer=Adam(lr=0.001), metrics=['acc'])

for i in range(2):
    model.fit(news_words[train_index],to_categorical(label)[train_index], shuffle=True, batch_size=64, epochs=2, verbose=1)
    y_pred = model.predict([news_words[test_index] ], batch_size=128, verbose=1)
    y_pred = np.argmax(y_pred, axis=1)
    y_true = label[test_index]
    # 'weighted' means that it calculates metrics for each label and finds their average weighted by the
    # number of true instances for each label
    report = f1_score(y_true, y_pred, average='weighted')  
    print(report)

Train on 180767 samples
Epoch 1/2
Epoch 2/2


2021-09-14 22:35:53.515126: E tensorflow/core/grappler/optimizers/dependency_optimizer.cc:697] Iteration = 0, topological sort failed with message: The graph couldn't be sorted in topological order.
2021-09-14 22:35:53.515687: E tensorflow/core/grappler/optimizers/dependency_optimizer.cc:697] Iteration = 1, topological sort failed with message: The graph couldn't be sorted in topological order.
2021-09-14 22:35:53.516842: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:533] model_pruner failed: Invalid argument: MutableGraphView::MutableGraphView error: node 'fastformer_1/lambda_5/concat' has self cycle fanin 'fastformer_1/lambda_5/concat'.
2021-09-14 22:35:53.519253: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:533] remapper failed: Invalid argument: MutableGraphView::MutableGraphView error: node 'fastformer_1/lambda_5/concat' has self cycle fanin 'fastformer_1/lambda_5/concat'.
2021-09-14 22:35:53.519638: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:533]

0.7085544144846351
Train on 180767 samples
Epoch 1/2
Epoch 2/2
0.6706256158320152
