# Copyright 2018 MLBenchmark Group. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of multiheaded attention and self-attention layers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from mlperf_compliance import mlperf_log
class Attention(tf.layers.Layer):
    """Multi-headed attention layer.

    Queries are projected from `x`, keys and values from `y`. The projections
    are split into `num_heads` heads, scaled dot-product attention is applied
    per head, and the heads are recombined and passed through a final linear
    projection.
    """

    def __init__(self, hidden_size, num_heads, attention_dropout, train):
        """Create the projection layers used by the attention computation.

        Args:
          hidden_size: output size of every projection layer; must be evenly
            divisible by `num_heads`.
          num_heads: number of heads the hidden dimension is split into.
          attention_dropout: dropout rate for the attention weights; the keep
            probability handed to `tf.nn.dropout` is `1.0 - attention_dropout`.
          train: when truthy, dropout is applied to the attention weights.

        Raises:
          ValueError: if `hidden_size` is not divisible by `num_heads`.
        """
        if hidden_size % num_heads != 0:
            # NOTE(review): the tail of this message was truncated in the
            # reviewed copy; "heads." completes the sentence -- confirm.
            raise ValueError(
                "Hidden size must be evenly divisible by the number of "
                "heads.")

        super(Attention, self).__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.attention_dropout = attention_dropout
        self.train = train

        # Layers for linearly projecting the queries, keys, and values.
        use_bias = False
        self.q_dense_layer = tf.layers.Dense(
            hidden_size, use_bias=use_bias, name="q")
        self.k_dense_layer = tf.layers.Dense(
            hidden_size, use_bias=use_bias, name="k")
        self.v_dense_layer = tf.layers.Dense(
            hidden_size, use_bias=use_bias, name="v")

        # NOTE(review): the `name` argument of this layer was truncated in the
        # reviewed copy. "output_transform" is assumed; the name determines
        # the variable scope, so confirm it against existing checkpoints.
        self.output_dense_layer = tf.layers.Dense(
            hidden_size, use_bias=use_bias, name="output_transform")

        # NOTE(review): only the dict entries of this statement were visible
        # in the reviewed copy. The call and tag name below are presumed from
        # the mlperf_compliance API -- confirm before merging.
        mlperf_log.transformer_print(
            key=mlperf_log.MODEL_HP_ATTENTION_DENSE,
            value={
                "hidden_size": hidden_size,
                "use_bias": use_bias,
                "num_heads": num_heads
            })

    def split_heads(self, x):
        """Split x into different heads, and transpose the resulting value.

        The tensor is transposed to ensure the inner dimensions hold the
        correct values during the matrix multiplication.

        Args:
          x: A tensor with shape [batch_size, length, hidden_size]

        Returns:
          A tensor with shape
          [batch_size, num_heads, length, hidden_size/num_heads]
        """
        with tf.name_scope("split_heads"):
            batch_size = tf.shape(x)[0]
            length = tf.shape(x)[1]

            # Calculate depth of last dimension after it has been split.
            depth = (self.hidden_size // self.num_heads)

            # Split the last dimension
            x = tf.reshape(x, [batch_size, length, self.num_heads, depth])

            # Transpose the result
            return tf.transpose(x, [0, 2, 1, 3])

    def combine_heads(self, x):
        """Combine tensor that has been split.

        Args:
          x: A tensor [batch_size, num_heads, length, hidden_size/num_heads]

        Returns:
          A tensor with shape [batch_size, length, hidden_size]
        """
        with tf.name_scope("combine_heads"):
            batch_size = tf.shape(x)[0]
            length = tf.shape(x)[2]
            # --> [batch, length, num_heads, depth]
            x = tf.transpose(x, [0, 2, 1, 3])
            return tf.reshape(x, [batch_size, length, self.hidden_size])

    def call(self, x, y, bias, cache=None):
        """Apply attention mechanism to x and y.

        Args:
          x: a tensor with shape [batch_size, length_x, hidden_size]
          y: a tensor with shape [batch_size, length_y, hidden_size]
          bias: attention bias that will be added to the result of the dot
            product.
          cache: (Used during prediction) dictionary with tensors containing
            results of previous attentions. The dictionary must have the
            items:
                {"k": tensor with shape [batch_size, i, key_channels],
                 "v": tensor with shape [batch_size, i, value_channels]}
            where i is the current decoded length. The dictionary is updated
            in place with the concatenated keys and values.

        Returns:
          Attention layer output with shape [batch_size, length_x, hidden_size]
        """
        # Linearly project the query (q), key (k) and value (v) using
        # different learned projections. This is in preparation of splitting
        # them into multiple heads. Multi-head attention uses multiple
        # queries, keys, and values rather than regular attention (which uses
        # a single q, k, v).
        q = self.q_dense_layer(x)
        k = self.k_dense_layer(y)
        v = self.v_dense_layer(y)

        if cache is not None:
            # Combine cached keys and values with new keys and values.
            k = tf.concat([cache["k"], k], axis=1)
            v = tf.concat([cache["v"], v], axis=1)

            # Update cache
            cache["k"] = k
            cache["v"] = v

        # Split q, k, v into heads.
        q = self.split_heads(q)
        k = self.split_heads(k)
        v = self.split_heads(v)

        # Scale q to prevent the dot product between q and k from growing too
        # large.
        depth = (self.hidden_size // self.num_heads)
        q *= depth ** -0.5

        # Calculate dot product attention
        logits = tf.matmul(q, k, transpose_b=True)
        logits += bias
        weights = tf.nn.softmax(logits, name="attention_weights")
        if self.train:
            # The second positional argument of tf.nn.dropout here is the
            # keep probability, hence 1 - dropout rate.
            weights = tf.nn.dropout(weights, 1.0 - self.attention_dropout)
        attention_output = tf.matmul(weights, v)

        # Recombine heads --> [batch_size, length, hidden_size]
        attention_output = self.combine_heads(attention_output)

        # Run the combined outputs through another linear projection layer.
        attention_output = self.output_dense_layer(attention_output)
        return attention_output
class SelfAttention(Attention):
    """Multi-headed attention where the input attends to itself."""

    def call(self, x, bias, cache=None):
        """Apply attention with `x` as both the query and key/value source.

        Args:
          x: input tensor, forwarded as both `x` and `y` of `Attention.call`.
          bias: attention bias, forwarded unchanged.
          cache: optional cache dictionary, forwarded unchanged.

        Returns:
          Whatever `Attention.call(x, x, bias, cache)` returns.
        """
        parent = super(SelfAttention, self)
        return parent.call(x, x, bias, cache)
