# Self-Attention

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch

def scaled_self_attention(queries, keys, values):
    print("queries.shape", queries.shape)
    print("keys.shape", keys.shape)
    print("values.shape", values.shape)
    
    # Q dot KT
    product = torch.matmul(queries, keys, transpose_b=True)
    print("product.shape", product.shape)
    
    # Divide by keys dimension.
    keys_dim = torch.cast(torch.shape(keys)[0], torch.float32)
    scaled_product = product / torch.math.sqrt(keys_dim)
    print("scaled_product.shape", scaled_product.shape)
    
    # Attention score.
    attention_score = tf.nn.softmax(scaled_product, axis=-1)
    print("attention_score.shape", attention_score.shape)
    
    # Multiply.
    attention = tf.matmul(attention_score, values)
    print("attention.shape", attention.shape)
    
    return attention, attention_score

In [2]:
queries = np.random.random((80, 8)).astype("float32")
keys = np.random.random((80, 8)).astype("float32")
values = np.random.random((80, 8)).astype("float32")

attention, attention_score = scaled_self_attention(queries, keys, values)

plt.subplot(1, 3, 1)
plt.imshow(queries, cmap="inferno")
plt.subplot(1, 3, 2)
plt.imshow(keys, cmap="inferno")
plt.subplot(1, 3, 3)
plt.imshow(values, cmap="inferno")
plt.show()
plt.close()

plt.imshow(attention_score, cmap="inferno")
plt.show()
plt.close()

queries.shape (80, 8)
keys.shape (80, 8)
values.shape (80, 8)


NameError: name 'tf' is not defined

In [None]:
from tensorflow.keras.datasets import imdb
from tensorflow.keras import preprocessing

vocabulary_size = 10000
sequence_length = 80
embedding_size = 32

(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=10000)
x_train = preprocessing.sequence.pad_sequences(
    x_train,
    maxlen=sequence_length,
    padding="post",
    truncating="post"
)

In [None]:
from tensorflow.keras import models, layers

model_input = layers.Input(shape=(sequence_length,))
embedding = layers.Embedding(vocabulary_size, embedding_size)(model_input)

queries = layers.Dense(8, activation="relu")(embedding)
keys = layers.Dense(8, activation="relu")(embedding)
values = layers.Dense(8, activation="relu")(embedding)

attention, attention_score = layers.Attention()(
    [queries, keys, values], 
    return_attention_scores=True
)
flatten = layers.Flatten()(attention)
dense = layers.Dense(1, activation="sigmoid")(flatten)

model = models.Model(model_input, dense)
model.summary()

attention_model = models.Model(model_input, attention)
attention_score_model = models.Model(model_input, attention_score)

In [None]:
model.compile(
    optimizer="adam",
    loss="binary_crossentropy",
    metrics=["accuracy"]
)

history = model.fit(
    x_train, y_train,
    epochs=20,
    batch_size=32,
    validation_split=0.2

In [None]:
def plot_attention(sample):
    sample = np.array([sample])

    plt.subplot(1, 2, 1)
    attention_score = attention_score_model.predict(sample)[0]
    plt.imshow(attention_score, cmap="inferno")

    plt.subplot(1, 2, 2)
    attention = attention_model.predict(sample)[0]
    plt.imshow(attention, cmap="inferno")
    
    plt.show()
    plt.close()
    
sample = random.choice(x_train)
plot_attention(sample)