# Attention Scoring Functions
:label:`sec_attention-scoring-functions`

In :numref:`sec_nadaraya-waston`,
we used a Gaussian kernel to model
interactions between queries and keys.
Treating the exponent of the Gaussian kernel
in :eqref:`eq_nadaraya-waston-gaussian`
as an *attention scoring function* (or *scoring function* for short),
the results of this function were
essentially fed into
a softmax operation.
As a result,
we obtained
a probability distribution (attention weights)
over values that are paired with keys.
In the end,
the output of the attention pooling
is simply a weighted sum of the values
based on these attention weights.
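To make this recap concrete, here is a minimal plain-Java sketch (no DJL; the class name `GaussianAttentionPooling` is hypothetical). It performs scalar attention pooling with the Gaussian-kernel exponent $a(q, k) = -(q - k)^2/2$ as the scoring function, assuming the unit-bandwidth kernel $K(u) \propto \exp(-u^2/2)$.

```java
import java.util.function.DoubleBinaryOperator;

/** Plain-Java sketch: scalar attention pooling with a Gaussian-kernel scoring function (no DJL). */
final class GaussianAttentionPooling {
    // Scoring function: exponent of the Gaussian kernel, a(q, k) = -(q - k)^2 / 2
    static final DoubleBinaryOperator GAUSSIAN_SCORE = (q, k) -> -0.5 * (q - k) * (q - k);

    /** Returns sum_i softmax_i(a(q, k_i)) * v_i for scalar keys and values. */
    static double pool(double query, double[] keys, double[] values) {
        double[] weights = new double[keys.length];
        double max = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < keys.length; i++) {
            weights[i] = GAUSSIAN_SCORE.applyAsDouble(query, keys[i]);
            max = Math.max(max, weights[i]);
        }
        double sum = 0;
        for (int i = 0; i < weights.length; i++) {
            weights[i] = Math.exp(weights[i] - max); // subtract the max for numerical stability
            sum += weights[i];
        }
        double out = 0;
        for (int i = 0; i < values.length; i++) {
            out += weights[i] / sum * values[i]; // attention weight times value
        }
        return out;
    }

    public static void main(String[] args) {
        double[] keys = {0.0, 1.0, 2.0};
        double[] values = {10.0, 20.0, 30.0};
        System.out.println(pool(1.0, keys, values));
        System.out.println(pool(100.0, keys, values));
    }
}
```

A query equidistant from two keys weights them equally, while a query far to one side puts almost all of its weight on the nearest key.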

At a high level,
we can use the above algorithm
to instantiate the framework of attention mechanisms
in :numref:`fig_qkv`.
Denoting an attention scoring function by $a$,
:numref:`fig_attention_output`
illustrates how the output of attention pooling
can be computed as a weighted sum of values.
Since attention weights are
a probability distribution,
the weighted sum is essentially
a weighted average.

![Computing the output of attention pooling as a weighted average of values.](https://raw.githubusercontent.com/d2l-ai/d2l-en/master/img/attention-output.svg)
:label:`fig_attention_output`



Mathematically,
suppose that we have
a query $\mathbf{q} \in \mathbb{R}^q$
and $m$ key-value pairs $(\mathbf{k}_1, \mathbf{v}_1), \ldots, (\mathbf{k}_m, \mathbf{v}_m)$, where any $\mathbf{k}_i \in \mathbb{R}^k$ and any $\mathbf{v}_i \in \mathbb{R}^v$.
The attention pooling $f$
is instantiated as a weighted sum of the values:

$$f(\mathbf{q}, (\mathbf{k}_1, \mathbf{v}_1), \ldots, (\mathbf{k}_m, \mathbf{v}_m)) = \sum_{i=1}^m \alpha(\mathbf{q}, \mathbf{k}_i) \mathbf{v}_i \in \mathbb{R}^v,$$
:eqlabel:`eq_attn-pooling`

where
the attention weight (scalar) for the query $\mathbf{q}$
and key $\mathbf{k}_i$
is computed by
the softmax operation of
an attention scoring function $a$ that maps two vectors to a scalar:

$$\alpha(\mathbf{q}, \mathbf{k}_i) = \mathrm{softmax}(a(\mathbf{q}, \mathbf{k}_i)) = \frac{\exp(a(\mathbf{q}, \mathbf{k}_i))}{\sum_{j=1}^m \exp(a(\mathbf{q}, \mathbf{k}_j))} \in \mathbb{R}.$$
:eqlabel:`eq_attn-scoring-alpha`

As we can see,
different choices of the attention scoring function $a$
lead to different behaviors of attention pooling.
In this section,
we introduce two popular scoring functions
that we will use to develop more
sophisticated attention mechanisms later.
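The two equations above can be sketched in plain Java (no DJL; the class name `AttentionPooling` is hypothetical). The scoring function $a$ is passed in as a parameter. This makes explicit that only $a$ changes between the attention variants in this section, while the softmax and the weighted sum stay the same.

```java
import java.util.function.ToDoubleBiFunction;

/** Plain-Java sketch of attention pooling with a pluggable scoring function `a` (no DJL). */
final class AttentionPooling {
    /** alpha = softmax(scores), shifted by the maximum score for numerical stability. */
    static double[] softmax(double[] scores) {
        double max = Double.NEGATIVE_INFINITY;
        for (double s : scores) {
            max = Math.max(max, s);
        }
        double sum = 0;
        double[] alpha = new double[scores.length];
        for (int i = 0; i < scores.length; i++) {
            alpha[i] = Math.exp(scores[i] - max);
            sum += alpha[i];
        }
        for (int i = 0; i < alpha.length; i++) {
            alpha[i] /= sum;
        }
        return alpha;
    }

    /** f(q, (k_1, v_1), ..., (k_m, v_m)) = sum_i alpha(q, k_i) v_i. */
    static double[] pool(double[] q, double[][] keys, double[][] values,
                         ToDoubleBiFunction<double[], double[]> a) {
        double[] scores = new double[keys.length];
        for (int i = 0; i < keys.length; i++) {
            scores[i] = a.applyAsDouble(q, keys[i]); // `a` maps two vectors to a scalar
        }
        double[] alpha = softmax(scores);
        double[] out = new double[values[0].length];
        for (int i = 0; i < values.length; i++) {
            for (int j = 0; j < out.length; j++) {
                out[j] += alpha[i] * values[i][j];
            }
        }
        return out;
    }

    public static void main(String[] args) {
        // Example scoring function: plain dot product (requires equal query and key lengths)
        ToDoubleBiFunction<double[], double[]> dot = (q, k) -> {
            double s = 0;
            for (int i = 0; i < q.length; i++) {
                s += q[i] * k[i];
            }
            return s;
        };
        double[][] keys = {{1, 0}, {0, 1}};
        double[][] values = {{1}, {0}};
        System.out.println(pool(new double[] {1, 0}, keys, values, dot)[0]);
    }
}
```

Because the weights sum to one, a constant scoring function reduces the output to the plain average of the values.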


In [None]:
%load ../utils/djl-imports
%load ../utils/plot-utils
%load ../utils/Functions.java
%load ../utils/PlotUtils.java

In [None]:
NDManager manager = NDManager.newBaseManager(Functions.tryGpu(0));

## Masked Softmax Operation

As we just mentioned,
a softmax operation is used to
output a probability distribution as attention weights.
In some cases,
not all the values should be fed into attention pooling.
For instance,
for efficient minibatch processing in :numref:`sec_machine_translation`,
some text sequences are padded with
special tokens that do not carry meaning.
To get an attention pooling
over
only meaningful tokens as values,
we can specify a valid sequence length (in number of tokens)
to filter out those beyond this specified range
when computing softmax.
In this way,
we can implement such a *masked softmax operation*
in the following `maskedSoftmax` function,
where any value beyond the valid length
is replaced by a large negative number ($-10^6$) before the softmax,
so that its attention weight becomes zero.
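Before the DJL version, here is a plain-Java sketch of the same idea applied to a single row (no DJL; the class name `MaskedSoftmaxSketch` is hypothetical). As in `maskedSoftmax`, entries at positions greater than or equal to the valid length are overwritten with $-10^6$ before the softmax, so their exponentials underflow to zero.

```java
/** Plain-Java sketch of a masked softmax over one row (no DJL). */
final class MaskedSoftmaxSketch {
    /** Softmax of `row` after replacing every entry at index >= validLen with -1e6. */
    static double[] maskedSoftmax(double[] row, int validLen) {
        double[] masked = row.clone();
        for (int i = validLen; i < masked.length; i++) {
            masked[i] = -1e6; // exp(-1e6 - max) underflows to 0
        }
        double max = Double.NEGATIVE_INFINITY;
        for (double x : masked) {
            max = Math.max(max, x);
        }
        double sum = 0;
        double[] out = new double[masked.length];
        for (int i = 0; i < masked.length; i++) {
            out[i] = Math.exp(masked[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= sum;
        }
        return out;
    }

    public static void main(String[] args) {
        double[] w = maskedSoftmax(new double[] {0.1, 0.5, 0.3, 0.9}, 2);
        System.out.println(java.util.Arrays.toString(w));
    }
}
```

Using a large negative number rather than $-\infty$ keeps the arithmetic finite. The trade-off is that a row with valid length zero gets uniform weights instead of an undefined result.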


In [None]:
public static NDArray maskedSoftmax(NDArray X, NDArray validLens) {
    /* Perform softmax operation by masking elements on the last axis. */
    // `X`: 3D NDArray, `validLens`: 1D or 2D NDArray
    if (validLens == null) {
        return X.softmax(-1);
    } else {
        Shape shape = X.getShape();
        if (validLens.getShape().dimension() == 1) {
            validLens = validLens.repeat(shape.get(1));
        } else {
            validLens = validLens.reshape(-1);
        }
        // On the last axis, replace masked elements with a very large negative
        // value, whose exponentiation outputs 0
        X =
                X.reshape(new Shape(-1, shape.get(shape.dimension() - 1)))
                        .sequenceMask(validLens, (float) -1E6);
        return X.softmax(-1).reshape(shape);
    }
}

To demonstrate how this function works,
consider a minibatch of two $2 \times 4$ matrix examples,
where the valid lengths for these two examples
are two and three, respectively.
As a result of the masked softmax operation,
values beyond the valid lengths
are all masked as zero.


In [None]:
System.out.println(
        maskedSoftmax(
                manager.randomUniform(0, 1, new Shape(2, 2, 4)),
                manager.create(new float[] {2, 3})));

Similarly, we can also
use a two-dimensional NDArray
to specify valid lengths
for every row in each matrix example.
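Inside `maskedSoftmax`, both forms of `validLens` are turned into one length per row before `sequenceMask` is applied. In the 1D case, `repeat` gives every row of an example that example's length. In the 2D case, `reshape(-1)` flattens the per-row lengths. The bookkeeping is sketched below in plain Java with `int` arrays instead of NDArrays (the class name `ValidLens` is hypothetical).

```java
/** Plain-Java sketch of how valid lengths are expanded to one length per row (no DJL). */
final class ValidLens {
    /** 1D case: one length per example, shared by all `rowsPerExample` rows of that example. */
    static int[] expand(int[] perExample, int rowsPerExample) {
        int[] out = new int[perExample.length * rowsPerExample];
        for (int b = 0; b < perExample.length; b++) {
            for (int r = 0; r < rowsPerExample; r++) {
                out[b * rowsPerExample + r] = perExample[b];
            }
        }
        return out;
    }

    /** 2D case: one length per row, flattened in row-major order. */
    static int[] flatten(int[][] perRow) {
        int total = 0;
        for (int[] row : perRow) {
            total += row.length;
        }
        int[] out = new int[total];
        int idx = 0;
        for (int[] row : perRow) {
            for (int len : row) {
                out[idx++] = len;
            }
        }
        return out;
    }

    public static void main(String[] args) {
        System.out.println(java.util.Arrays.toString(expand(new int[] {2, 3}, 2)));
        System.out.println(java.util.Arrays.toString(flatten(new int[][] {{1, 3}, {2, 4}})));
    }
}
```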


In [None]:
System.out.println(
        maskedSoftmax(
                manager.randomUniform(0, 1, new Shape(2, 2, 4)),
                manager.create(new float[][] {{1, 3}, {2, 4}})));

## Additive Attention
:label:`subsec_additive-attention`

In general,
when queries and keys are vectors of different lengths,
we can use additive attention
as the scoring function.
Given a query $\mathbf{q} \in \mathbb{R}^q$
and a key $\mathbf{k} \in \mathbb{R}^k$,
the *additive attention* scoring function is

$$a(\mathbf q, \mathbf k) = \mathbf w_v^\top \tanh(\mathbf W_q\mathbf q + \mathbf W_k \mathbf k) \in \mathbb{R},$$
:eqlabel:`eq_additive-attn`

where the
learnable parameters are
$\mathbf W_q\in\mathbb R^{h\times q}$, $\mathbf W_k\in\mathbb R^{h\times k}$, and $\mathbf w_v\in\mathbb R^{h}$.
Equivalently to :eqref:`eq_additive-attn`,
the query and the key are concatenated
and fed into an MLP with a single hidden layer
whose weight matrix is $[\mathbf W_q, \mathbf W_k] \in \mathbb R^{h\times(q+k)}$
and whose number of hidden units $h$ is a hyperparameter.
By using $\tanh$ as the activation function and disabling
bias terms,
we implement additive attention in the following.
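The equivalence between :eqref:`eq_additive-attn` and a bias-free MLP on the concatenated input can be checked numerically. The plain-Java sketch below (no DJL; the class and method names are hypothetical) computes the score both ways: once with separate $\mathbf W_q$ and $\mathbf W_k$, and once with the stacked matrix $[\mathbf W_q, \mathbf W_k]$ applied to $[\mathbf q; \mathbf k]$. The query and key deliberately have different lengths.

```java
/** Plain-Java sketch of the additive attention score a(q, k) (no DJL). */
final class AdditiveScore {
    /** Returns W x for a dense row-major matrix W. */
    static double[] matVec(double[][] W, double[] x) {
        double[] y = new double[W.length];
        for (int i = 0; i < W.length; i++) {
            for (int j = 0; j < x.length; j++) {
                y[i] += W[i][j] * x[j];
            }
        }
        return y;
    }

    /** a(q, k) = w_v^T tanh(W_q q + W_k k), with W_q: h x q, W_k: h x k, w_v: length h. */
    static double score(double[][] Wq, double[][] Wk, double[] wv, double[] q, double[] k) {
        double[] hq = matVec(Wq, q);
        double[] hk = matVec(Wk, k);
        double s = 0;
        for (int i = 0; i < wv.length; i++) {
            s += wv[i] * Math.tanh(hq[i] + hk[i]);
        }
        return s;
    }

    /** Same score as a bias-free one-hidden-layer MLP on [q; k] with weight [W_q, W_k]. */
    static double scoreConcat(double[][] Wq, double[][] Wk, double[] wv, double[] q, double[] k) {
        int h = wv.length;
        double[][] W = new double[h][q.length + k.length];
        for (int i = 0; i < h; i++) {
            System.arraycopy(Wq[i], 0, W[i], 0, q.length);
            System.arraycopy(Wk[i], 0, W[i], q.length, k.length);
        }
        double[] x = new double[q.length + k.length];
        System.arraycopy(q, 0, x, 0, q.length);
        System.arraycopy(k, 0, x, q.length, k.length);
        double[] hidden = matVec(W, x);
        double s = 0;
        for (int i = 0; i < h; i++) {
            s += wv[i] * Math.tanh(hidden[i]); // output layer w_v has a single unit
        }
        return s;
    }
}
```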


In [None]:
/* Additive attention. */
public class AdditiveAttention extends AbstractBlock {
    private static final byte VERSION = 1;
    private Linear W_k;
    private Linear W_q;
    private Linear W_v;
    private Dropout dropout;
    public NDArray attentionWeights;

    public AdditiveAttention(int numHiddens, float dropout) {
        super(VERSION);
        this.W_k = Linear.builder().setUnits(numHiddens).optBias(false).build();
        this.addChildBlock("W_k", this.W_k);

        this.W_q = Linear.builder().setUnits(numHiddens).optBias(false).build();
        this.addChildBlock("W_q", this.W_q);

        this.W_v = Linear.builder().setUnits(1).optBias(false).build();
        this.addChildBlock("W_v", this.W_v);

        this.dropout = Dropout.builder().optRate(dropout).build();
        this.addChildBlock("dropout", this.dropout);
    }

    @Override
    protected NDList forwardInternal(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
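        // Shape of `queries`: (`batchSize`, no. of queries, query size)
        // Shape of `keys`: (`batchSize`, no. of key-value pairs, key size)
        // Shape of `validLens`: (`batchSize`,) or (`batchSize`, no. of queries)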
        NDArray queries = inputs.get(0);
        NDArray keys = inputs.get(1);
        NDArray values = inputs.get(2);
        NDArray validLens = inputs.get(3);

        queries = this.W_q.forward(parameterStore, new NDList(queries), training, params).head();
        keys = this.W_k.forward(parameterStore, new NDList(keys), training, params).head();
        // After dimension expansion, shape of `queries`: (`batchSize`, no. of
        // queries, 1, `numHiddens`) and shape of `keys`: (`batchSize`, 1,
        // no. of key-value pairs, `numHiddens`). Sum them up with
        // broadcasting
        NDArray features = queries.expandDims(2).add(keys.expandDims(1));
        features = features.tanh();
        // There is only one output of `this.W_v`, so we remove the last
        // one-dimensional entry from the shape. Shape of `scores`:
        // (`batchSize`, no. of queries, no. of key-value pairs)
        NDArray result =
                this.W_v.forward(parameterStore, new NDList(features), training, params).head();
        NDArray scores = result.squeeze(-1);
        this.attentionWeights = maskedSoftmax(scores, validLens);
        // Shape of `values`: (`batchSize`, no. of key-value pairs, value
        // dimension)
        return new NDList(
                this.dropout
                        .forward(
                                parameterStore, new NDList(this.attentionWeights), training, params)
                        .head()
                        .batchDot(values));
    }

    @Override
    public Shape[] getOutputShapes(Shape[] inputShapes) {
        throw new UnsupportedOperationException("Not implemented");
    }

    @Override
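    public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {
        // `inputShapes` holds the shapes of `queries` and `keys`
        W_q.initialize(manager, dataType, inputShapes[0]);
        W_k.initialize(manager, dataType, inputShapes[1]);
        // `W_v` is applied to the tanh features, whose last axis has `numHiddens`
        // entries, i.e., the same last axis as the output of `W_q`
        W_v.initialize(manager, dataType, W_q.getOutputShapes(new Shape[] {inputShapes[0]})[0]);
    }
}

Let us demonstrate the above `AdditiveAttention` class
with a toy example,
where shapes (batch size, number of steps or sequence length in tokens, feature size)
of queries, keys, and values
are ($2$, $1$, $20$), ($2$, $10$, $2$),
and ($2$, $10$, $4$), respectively.
The attention pooling output
has a shape of (batch size, number of steps for queries, feature size for values),
which is ($2$, $1$, $4$) here.
Since the shapes of the learnable parameters depend on the feature sizes of the queries and keys,
we initialize the block with their shapes before the forward pass.


In [None]:
NDArray queries = manager.randomNormal(0, 1, new Shape(2, 1, 20), DataType.FLOAT32);
NDArray keys = manager.ones(new Shape(2, 10, 2));
// The two value matrices in the `values` minibatch are identical
NDArray values = manager.arange(40f).reshape(1, 10, 4).repeat(0, 2);
NDArray validLens = manager.create(new float[] {2, 6});

AdditiveAttention attention = new AdditiveAttention(8, 0.1f);
// Initialize `W_q`, `W_k`, and `W_v` from the shapes of the queries and keys
attention.initialize(manager, DataType.FLOAT32, queries.getShape(), keys.getShape());
attention
        .forward(
                new ParameterStore(manager, false),
                new NDList(queries, keys, values, validLens),
                false)
        .head();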

Although additive attention contains learnable parameters,
since every key is the same in this example,
the attention weights are uniform,
determined by the specified valid lengths.
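The reason is independent of the learned parameters. If every key is identical, then $a(\mathbf q, \mathbf k_i)$ is the same number for every $i$, so the masked softmax assigns $1/\textrm{valid length}$ to each valid position. The plain-Java sketch below checks this for an arbitrary, made-up scoring function (no DJL; the class name `UniformWeights` is hypothetical). Instead of writing $-10^6$ into masked positions, it simply leaves them at zero weight, which gives the same result for valid lengths of at least one.

```java
import java.util.function.ToDoubleBiFunction;

/** Plain-Java sketch: masked softmax weights over keys for one query (no DJL). */
final class UniformWeights {
    static double[] weights(double[] q, double[][] keys, int validLen,
                            ToDoubleBiFunction<double[], double[]> a) {
        double[] w = new double[keys.length]; // entries at index >= validLen stay 0
        double max = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < validLen; i++) {
            w[i] = a.applyAsDouble(q, keys[i]);
            max = Math.max(max, w[i]);
        }
        double sum = 0;
        for (int i = 0; i < validLen; i++) {
            w[i] = Math.exp(w[i] - max);
            sum += w[i];
        }
        for (int i = 0; i < validLen; i++) {
            w[i] /= sum;
        }
        return w;
    }
}
```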


In [None]:
PlotUtils.showHeatmaps(
        attention.attentionWeights.reshape(1, 1, 2, 10),
        "Keys",
        "Queries",
        new String[] {""},
        500,
        700);

## Scaled Dot-Product Attention

A more computationally efficient
design for the scoring function is
simply the dot product.
However,
the dot product operation
requires that both the query and the key
have the same vector length, say $d$.
Assume that
all the elements of the query and the key
are independent random variables
with zero mean and unit variance.
Then the dot product of
both vectors is a sum of $d$ independent terms $q_i k_i$,
each with zero mean and unit variance,
so it has zero mean and a variance of $d$.
To ensure that the variance of the dot product
still remains one regardless of vector length,
the *scaled dot-product attention* scoring function


$$a(\mathbf q, \mathbf k) = \mathbf{q}^\top \mathbf{k}  /\sqrt{d}$$

divides the dot product by $\sqrt{d}$.
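The variance argument can be checked numerically. This plain-Java sketch (no DJL; the class name `DotProductVariance` is hypothetical) draws query and key entries from a standard normal distribution. That is one particular choice of zero-mean, unit-variance variables. The sketch then estimates the variance of $\mathbf q^\top \mathbf k$ with and without the $1/\sqrt{d}$ factor.

```java
import java.util.Random;

/** Plain-Java sketch: empirical variance of (scaled) dot products of random vectors. */
final class DotProductVariance {
    /** Sample variance of q.k (divided by sqrt(d) if `scaled`) over `trials` random pairs. */
    static double sampleVariance(int d, int trials, boolean scaled, long seed) {
        Random rng = new Random(seed);
        double sum = 0;
        double sumSq = 0;
        for (int t = 0; t < trials; t++) {
            double dot = 0;
            for (int i = 0; i < d; i++) {
                dot += rng.nextGaussian() * rng.nextGaussian(); // q_i * k_i, both ~ N(0, 1)
            }
            if (scaled) {
                dot /= Math.sqrt(d);
            }
            sum += dot;
            sumSq += dot * dot;
        }
        double mean = sum / trials;
        return sumSq / trials - mean * mean;
    }

    public static void main(String[] args) {
        System.out.println(sampleVariance(64, 20000, false, 42L));
        System.out.println(sampleVariance(64, 20000, true, 42L));
    }
}
```

With $d = 64$, one should expect estimates close to 64 and close to 1, respectively, up to sampling noise.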
In practice,
we often think in minibatches
for efficiency,
such as computing attention
for
$n$ queries and $m$ key-value pairs,
where queries and keys are of length $d$
and values are of length $v$.
The scaled dot-product attention
of queries $\mathbf Q\in\mathbb R^{n\times d}$,
keys $\mathbf K\in\mathbb R^{m\times d}$,
and values $\mathbf V\in\mathbb R^{m\times v}$
is


$$ \mathrm{softmax}\left(\frac{\mathbf Q \mathbf K^\top }{\sqrt{d}}\right) \mathbf V \in \mathbb{R}^{n\times v}.$$
:eqlabel:`eq_softmax_QK_V`
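The following plain-Java sketch (no DJL, no batching, no masking; the class name `ScaledDotProductAttention` is hypothetical) implements :eqref:`eq_softmax_QK_V` row by row. Each query row produces scores against all $m$ keys, applies a softmax, and takes the weighted sum of the rows of $\mathbf V$. The result is an $n \times v$ matrix.

```java
/** Plain-Java sketch of softmax(Q K^T / sqrt(d)) V for a single (unbatched) example. */
final class ScaledDotProductAttention {
    static double[][] attend(double[][] Q, double[][] K, double[][] V) {
        int n = Q.length, m = K.length, d = Q[0].length, v = V[0].length;
        double[][] out = new double[n][v];
        for (int i = 0; i < n; i++) {
            // Scores of query i against every key, scaled by 1/sqrt(d)
            double[] s = new double[m];
            double max = Double.NEGATIVE_INFINITY;
            for (int j = 0; j < m; j++) {
                double dot = 0;
                for (int t = 0; t < d; t++) {
                    dot += Q[i][t] * K[j][t];
                }
                s[j] = dot / Math.sqrt(d);
                max = Math.max(max, s[j]);
            }
            // Row-wise softmax
            double sum = 0;
            for (int j = 0; j < m; j++) {
                s[j] = Math.exp(s[j] - max);
                sum += s[j];
            }
            // Weighted sum of the value rows
            for (int j = 0; j < m; j++) {
                for (int c = 0; c < v; c++) {
                    out[i][c] += s[j] / sum * V[j][c];
                }
            }
        }
        return out;
    }
}
```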

In the following implementation of the scaled dot product attention, we use dropout for model regularization.
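The text only says that dropout is applied to the attention weights. The sketch below assumes the standard inverted-dropout formulation and does not reflect DJL's `Dropout` internals: during training, each weight is zeroed with probability $p$ and the survivors are scaled by $1/(1-p)$, while at inference the weights pass through unchanged. The class name `AttentionDropout` is hypothetical. During training, the dropped-out weights of a row generally no longer sum to one.

```java
import java.util.Random;

/** Plain-Java sketch of (assumed) inverted dropout applied to one row of attention weights. */
final class AttentionDropout {
    static double[] apply(double[] weights, double p, boolean training, Random rng) {
        if (!training || p == 0) {
            return weights.clone(); // inference: weights are left as-is
        }
        double[] out = new double[weights.length];
        for (int i = 0; i < weights.length; i++) {
            // Drop with probability p, otherwise rescale to keep the expected value unchanged
            out[i] = rng.nextDouble() < p ? 0.0 : weights[i] / (1 - p);
        }
        return out;
    }
}
```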


In [None]:
/* Scaled dot product attention. */
public class DotProductAttention extends AbstractBlock {
    private static final byte VERSION = 1;
    private Dropout dropout;
    public NDArray attentionWeights;

    public DotProductAttention(float dropout) {
        super(VERSION);

        this.dropout = Dropout.builder().optRate(dropout).build();
        this.addChildBlock("dropout", this.dropout);
    }

    @Override
    protected NDList forwardInternal(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        // Shape of `queries`: (`batchSize`, no. of queries, `d`)
        // Shape of `keys`: (`batchSize`, no. of key-value pairs, `d`)
        // Shape of `values`: (`batchSize`, no. of key-value pairs, value
        // dimension)
        // Shape of `validLens`: (`batchSize`,) or (`batchSize`, no. of queries)
        NDArray queries = inputs.head();
        NDArray keys = inputs.get(1);
        NDArray values = inputs.get(2);
        NDArray validLens = inputs.get(3);

        // `d` is the feature size shared by queries and keys
        long d = queries.getShape().get(queries.getShape().dimension() - 1);
        // Swap the last two dimensions of `keys`, perform batchDot, and scale
        // by 1/sqrt(d)
        NDArray scores = queries.batchDot(keys.swapAxes(1, 2)).div(Math.sqrt(d));
        attentionWeights = maskedSoftmax(scores, validLens);
        return new NDList(
                this.dropout
                        .forward(
                                parameterStore, new NDList(this.attentionWeights), training, params)
                        .head()
                        .batchDot(values));
    }

    @Override
    public Shape[] getOutputShapes(Shape[] inputShapes) {
        throw new UnsupportedOperationException("Not implemented");
    }

    @Override
    public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {}
}

To demonstrate the above `DotProductAttention` class,
we use the same keys, values, and valid lengths from the earlier toy example
for additive attention.
For the dot product operation,
we make the feature size of queries
the same as that of keys.


In [None]:
queries = manager.randomNormal(0, 1, new Shape(2, 1, 2), DataType.FLOAT32);
DotProductAttention productAttention = new DotProductAttention(0.5f);
productAttention
        .forward(
                new ParameterStore(manager, false),
                new NDList(queries, keys, values, validLens),
                false)
        .head();

Same as in the additive attention demonstration,
since `keys` contains the same element
that cannot be differentiated by any query,
uniform attention weights are obtained.


In [None]:
PlotUtils.showHeatmaps(
        productAttention.attentionWeights.reshape(1, 1, 2, 10),
        "Keys",
        "Queries",
        new String[] {""},
        500,
        700);

## Summary

* We can compute the output of attention pooling as a weighted average of values, where different choices of the attention scoring function lead to different behaviors of attention pooling.
* When queries and keys are vectors of different lengths, we can use the additive attention scoring function. When they are the same, the scaled dot-product attention scoring function is more computationally efficient.



## Exercises

1. Modify keys in the toy example and visualize attention weights. Do additive attention and scaled dot-product attention still output the same attention weights? Why or why not?
1. Using matrix multiplications only, can you design a new scoring function for queries and keys with different vector lengths?
1. When queries and keys have the same vector length, is vector summation a better design than dot product for the scoring function? Why or why not?
