In [1]:
import numpy as np
import matplotlib.pyplot as plt
from keras.datasets import imdb
from keras.preprocessing import sequence
from IPython.display import display, Markdown

%matplotlib inline

Using TensorFlow backend.


In [None]:
from keras import models
from keras import optimizers as opt
from keras.layers import Input
from keras.layers import Embedding
from keras.layers import LSTM
from keras.layers import Dense
from attention.layers import AttentionLayer

# Please note that this dataset is quite simple, so it is easy to overfit a model on it.
# We use it only because it comes bundled with Keras; the goal is just to showcase the layer.

In [3]:
VOCAB_SIZE = 10000   # keep only the VOCAB_SIZE most frequent words (imdb num_words)
MAX_LEN = 100        # every review is padded / truncated to this many tokens
HIDDEN_SIZE = 16     # width of both the embedding and the LSTM
DROPOUT = 0.5        # input and recurrent dropout of the LSTM

(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=VOCAB_SIZE)

# Pad and truncate at the end ('post') so each review keeps its opening words;
# the zero padding is masked later by the embedding layer (mask_zero=True).
x_train = sequence.pad_sequences(x_train, maxlen=MAX_LEN,
                                 padding='post', truncating='post')
x_test = sequence.pad_sequences(x_test, maxlen=MAX_LEN,
                                padding='post', truncating='post')

In [4]:
def build_model():
    """Builds the (uncompiled) sentiment classifier.

    Architecture: token ids -> embedding -> LSTM (full sequence) ->
    attention pooling -> single sigmoid unit.

    Returns:
        A ``keras.models.Model`` mapping ``(batch, MAX_LEN)`` token ids to a
        ``(batch, 1)`` output.
    """
    token_ids = Input((MAX_LEN, ), name='sentence-in')

    # Id 0 is the padding value produced by pad_sequences, so mask it out.
    embed = Embedding(VOCAB_SIZE, HIDDEN_SIZE, mask_zero=True, name='embedding')
    encoder = LSTM(HIDDEN_SIZE,
                   return_sequences=True,   # attention consumes every timestep
                   dropout=DROPOUT,
                   recurrent_dropout=DROPOUT,
                   name='ff-lstm')
    attend = AttentionLayer(name='attention')
    classify = Dense(1, activation='sigmoid', name='output')

    probability = classify(attend(encoder(embed(token_ids))))
    return models.Model(inputs=[token_ids], outputs=[probability])

In [None]:
# The LSTM emits (batch, timesteps, features) tensors, which the attention
# layer pools into a single vector per review before the sigmoid output.
model = build_model()
model.summary()

model.compile(optimizer='adam',
              loss='binary_crossentropy',
              metrics=['binary_accuracy'])
model.fit(x_train, y_train,
          batch_size=32,
          epochs=5,
          validation_split=0.2)

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
sentence-in (InputLayer)     (None, 100)               0         
_________________________________________________________________
embedding (Embedding)        (None, 100, 16)           160000    
_________________________________________________________________
ff-lstm (LSTM)               (None, 100, 16)           2112      
_________________________________________________________________
attention (AttentionLayer)   (None, 16)                288       
_________________________________________________________________
output (Dense)               (None, 1)                 17        
Total params: 162,417
Trainable params: 162,417
Non-trainable params: 0
_________________________________________________________________
Train on 20000 samples, validate on 5000 samples
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5

In [None]:
def build_viz_model(trained_model):
    """Builds a model that outputs the attention weights of ``trained_model``.

    Architecture: input -> embedding -> lstm -> attention (weights only).
    There is no sigmoid head. The attention layer is rebuilt with
    ``return_attention=True``, so the model outputs the attention weights
    (``alphas``) instead of the pooled sentence vector.

    Weights are copied from ``trained_model`` by layer name, using the names
    assigned in ``build_model``. The copy therefore does not depend on the
    position of each layer in ``trained_model.layers``.

    Args:
        trained_model: a (trained) model built by ``build_model``.

    Returns:
        An uncompiled ``keras.models.Model`` taking ``(batch, MAX_LEN)`` token ids.
    """
    def weights_of(layer_name):
        # Lookup by name rather than ``layers[i]``, which silently breaks (or
        # copies the wrong weights) if layers are added or reordered.
        return trained_model.get_layer(layer_name).get_weights()

    sentence_in = Input((MAX_LEN, ),
                        name='sentence-in')

    embedded = Embedding(VOCAB_SIZE,
                         HIDDEN_SIZE,
                         mask_zero=True,
                         weights=weights_of('embedding'),
                         name='embedding')(sentence_in)

    # Same configuration as in build_model so the copied weights fit.
    vectors = LSTM(HIDDEN_SIZE,
                   return_sequences=True,
                   dropout=DROPOUT,
                   recurrent_dropout=DROPOUT,
                   weights=weights_of('ff-lstm'),
                   name='ff-lstm')(embedded)

    alphas = AttentionLayer(weights=weights_of('attention'),
                            return_attention=True,
                            name='attention')(vectors)

    return models.Model(inputs=[sentence_in], outputs=[alphas])

In [None]:
def get_index2word():
    """Builds the reverse vocabulary mapping word id -> word.

    ``imdb.get_word_index()`` maps words to their rank. The ids inside the
    loaded reviews are those ranks shifted by an offset of 3, and ids 0, 1
    and 2 are labelled as padding, sequence start and unknown word.
    """
    offset = 3  # word index offset

    index2word = {}
    for word, rank in imdb.get_word_index().items():
        index2word[rank + offset] = word

    # Special tokens are written last so they win over any colliding entry.
    index2word.update({0: '[PAD]', 1: '[START]', 2: '[UNK]'})
    return index2word


def reconstruct(sample, index2word, unknown='[UNK]'):
    """Given a sequence of word ids, returns the list of corresponding words.

    Args:
        sample: iterable of word ids (e.g. one padded row of ``x_train``).
        index2word: mapping from id to word, e.g. from ``get_index2word``.
        unknown: token emitted for ids missing from ``index2word``. Without it,
            any id not present in the table would raise ``KeyError`` and
            abort the visualisation.

    Returns:
        A list of words with the same length as ``sample``.
    """
    return [index2word.get(word_id, unknown) for word_id in sample]

In [None]:
def _weight2color(brightness):
    """Converts a single (positive) attention weight to a shade of blue."""
    brightness = brightness.item()

    brightness = int(round(255 * brightness)) # convert from 0.0-1.0 to 0-255
    ints = (255 - brightness, 255 - brightness, 255)
#     return '#%02x%02x%02x' % ints
    return 'rgba({}, {}, {}, 0.6)'.format(*ints)


def print_sentence(label, predicted, sentence, weights):
    """Renders a sample, shading each word's background by its attention weight.

    The most attended words get the darkest (bluest) background. Rendering
    stops at the first '[PAD]' token, i.e. where the post-padding begins.

    Args:
        label: human-readable true label.
        predicted: human-readable predicted label.
        sentence: list of words (as returned by ``reconstruct``).
        weights: attention weights aligned with ``sentence``, expected in [0, 1].
    """
    import html  # escape words so characters like '<' or '&' render literally

    parts = ['<span style="padding:2px;">[actual: %10s >< pred: %10s]</span> ' % (label, predicted)]
    for word, weight in zip(sentence, weights):
        if word == '[PAD]':
            break
        # 'font-weight:bold' is valid CSS; "font-weight='bold'" was silently ignored.
        parts.append('<span style="background: {}; color:#000; padding:2px; font-weight:bold">{}</span>'.format(
            _weight2color(weight), html.escape(word)))

    display(Markdown(' '.join(parts)))


def plot_sentence(words, weights):
    """Plots the attention weight at each token position as a line chart.

    Args:
        words: list of words (as returned by ``reconstruct``).
        weights: attention weights aligned with ``words``; these are the
            values that get plotted.
    """
    # String x-values are treated as categories and equal strings are merged,
    # so prefix each word with its position to keep repeated words distinct.
    labels = [f'{i}_{word}' for i, word in enumerate(words)]

    plt.figure(figsize=(20, 2))
    plt.plot(labels, weights)
    plt.xticks(rotation=90)
    plt.grid(alpha=0.4)
    plt.ylabel('Attention')
    plt.show()

In [None]:
# Tables that turn ids back into readable words and labels, plus the
# attention-only model derived from the trained classifier.
index2word = get_index2word()
index2label = dict(enumerate(['negative', 'positive']))

viz_model = build_viz_model(model)

In [None]:
# `index2word`, `index2label` and `viz_model` come from the previous cell.

for index in range(5):  # visualize the first five training reviews
    # reconstructing sample from word ids to actual words
    words = reconstruct(x_train[index], index2word)
    sample = x_train[index:index + 1]  # keep the batch axis: shape (1, MAX_LEN)

    # getting prediction (model output has shape (1, 1)) and alphas
    prob = model.predict(sample)[0, 0]
    pred = int(prob >= 0.5)
    z = viz_model.predict(sample)[0]

    # rescaling to [0, 1] for visualization purposes; a constant vector has no
    # range to stretch, so map it to zeros instead of dividing by zero
    z_range = np.max(z) - np.min(z)
    w = (z - np.min(z)) / z_range if z_range > 0 else np.zeros_like(z)

    # ta-da
    print_sentence(index2label[y_train[index]], index2label[pred], words, w)
    plot_sentence(words, w)