In [1]:
import sys
sys.path.append('../')
import attention
import jax.numpy as jnp
import numpy as np

In [2]:
from torch.nn import MultiheadAttention
import torch

# Reference module: single-head torch attention over 512-dim embeddings, no biases.
attn = MultiheadAttention(embed_dim=512, num_heads=1, bias=False)

# Show the packed input-projection weight and one third of its row count.
print(attn.in_proj_weight.shape)
print(attn.in_proj_weight.shape[0]/3)

# Torch stores the query/key/value projections stacked row-wise in a single
# matrix of size (embed_dim * 3, embed_dim).
x = torch.randn(1, 512)

# Draw the three per-role matrices in q, k, v order, then stack them in the
# same order to build a replacement for the packed weight.
W_q, W_k, W_v = (torch.randn(512, 512) for _ in range(3))
W = torch.cat((W_q, W_k, W_v))
attn.in_proj_weight = torch.nn.Parameter(W)

# Self-attention: the same tensor is used as query, key and value.
out, _ = attn(x, x, x)

# List every parameter that takes part in training, with its shape.
for name, param in attn.named_parameters():
    if not param.requires_grad:
        continue
    print(name, param.data.shape)

torch.Size([1536, 512])
512.0
in_proj_weight torch.Size([1536, 512])
out_proj.weight torch.Size([512, 512])


In [3]:
from attention import *

# Draw a (10, 2) sample with torch, hand it to JAX via NumPy, and show the
# per-column totals (reduction over the first axis).
x = jnp.array(torch.randn(10, 2).numpy())
x.sum(axis=0)

Array([-0.03500485,  2.0492291 ], dtype=float32)

In [4]:
import math
from typing import Optional, List

import torch
from torch import nn

# from labml import tracker


class PrepareForMultiHeadAttention(nn.Module):
    """
    <a id="PrepareMHA"></a>

    ## Prepare for multi-head attention

    Projects the last dimension of the input with a single linear layer and
    then reshapes that dimension into `heads` separate vectors of size `d_k`.
    The same module type is used for the **key**, **query**, and **value**
    inputs.
    """

    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
        """
        * `d_model` is the size of the last dimension of the input.
        * `heads` is the number of heads to split into.
        * `d_k` is the number of features per head.
        * `bias` controls whether the linear layer has a bias term.
        """
        super().__init__()
        # Head layout used by `forward` when reshaping
        self.heads = heads
        self.d_k = d_k
        # One projection producing all heads at once (`heads * d_k` outputs)
        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)

    def forward(self, x: torch.Tensor):
        """
        `x` has shape `[seq_len, batch_size, d_model]` or `[batch_size, d_model]`.

        Returns a tensor of shape `[seq_len, batch_size, heads, d_k]` or
        `[batch_size, heads, d_k]` respectively.
        """
        # Every dimension except the last is carried over unchanged
        leading_dims = x.shape[:-1]

        # Project, then break the last dimension up into `(heads, d_k)`
        projected = self.linear(x)
        return projected.view(*leading_dims, self.heads, self.d_k)

In [5]:
# Smallest possible configuration: one head with a single feature.
d_model, n_heads, d_k = 1, 1, 1  # d_model: embedding dimension, split among all heads
seq_len, batch_size = 10, 2

# The module accepts `[seq_len, batch_size, d_model]` or 2-D input. The tensor
# built below is 2-D, `(seq_len, d_model)`; `batch_size` is not used in this cell.

# Build the projection module first, then draw the input.
forward1 = PrepareForMultiHeadAttention(d_model, n_heads, d_k, bias=False)

x = torch.randn(seq_len, d_model, requires_grad=False)
y_torch = forward1(x)

In [6]:
from attention import PreAttention

# JAX counterpart of the torch projection module, configured identically.
preattn = PreAttention(emb_size=d_model, n_heads=n_heads, d_k=d_k, bias=False)

# Copy the torch projection weight and the torch input over to JAX arrays.
W = jnp.array(forward1.linear.weight.detach())
xjnp = jnp.array(x)

# Second DenseState argument is None: the torch layer was built with bias=False.
state = DenseState(W, None)
print(f"W.shape: {W.shape}")
y_jax = preattn(state, xjnp)

# Displayed as the cell result: do the torch and JAX outputs agree?
np.allclose(y_torch.detach().numpy(), y_jax, atol=1e-6)

W.shape: (1, 1)


True

In [7]:
class MultiHeadAttention(nn.Module):
    """
    Multi-head scaled dot-product attention that records its intermediate tensors.

    Computes `softmax(Q K^T / sqrt(d_k)) V` per head and applies a final linear
    layer. Every intermediate result of `forward` is stored in
    `self.saved_steps` (keyed by step name) so that it can be compared
    step-by-step against another implementation, and the tensor shapes are
    printed along the way.
    """

    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
        """
        * `heads` is the number of heads.
        * `d_model` is the number of features in the `query`, `key` and `value` vectors.
        * `dropout_prob` is the dropout probability applied to the attention weights.
        * `bias` controls whether the `query` and `key` projections have a bias term.
        """

        super().__init__()

        # Number of features per head. `d_model` should be divisible by `heads`;
        # otherwise `heads * d_k != d_model` and the concatenated heads will not
        # match the input width of the output layer.
        self.d_k = d_model // heads
        # Number of heads
        self.heads = heads

        # These transform the `query`, `key` and `value` vectors for multi-headed attention.
        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        # The value projection always has a bias, independent of the `bias` argument.
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)

        # Softmax for attention along the time dimension of `key`
        # (axis `j` of the `[i, j, b, h]` score tensor)
        self.softmax = nn.Softmax(dim=1)

        # Output layer (uses `nn.Linear`'s default bias setting)
        self.output = nn.Linear(d_model, d_model)
        # Dropout
        self.dropout = nn.Dropout(dropout_prob)
        # Scaling factor before the softmax
        self.scale = 1 / math.sqrt(self.d_k)

        # We store attentions so that it can be used for logging, or other computations if needed
        self.attn = None
        # Intermediate tensors of the most recent `forward` call, keyed by step name
        self.saved_steps = {}

    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
        """
        Calculate $Q K^\\top$, i.e. $S_{ijbh} = \\sum_d Q_{ibhd} K_{jbhd}$.

        `query` and `key` have shape `[seq_len, batch_size, heads, d_k]`;
        the result has shape `[seq_len_q, seq_len_k, batch_size, heads]`.
        """
        return torch.einsum('ibhd,jbhd->ijbh', query, key)

    def prepare_mask(self, mask: torch.Tensor, query_shape: List[int], key_shape: List[int]):
        """
        `mask` has shape `[seq_len_q, seq_len_k, batch_size]`, where first dimension is the query dimension.
        If the query dimension is equal to $1$ it will be broadcasted.

        Raises `AssertionError` if the mask dimensions do not match the query/key shapes.
        """

        assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
        assert mask.shape[1] == key_shape[0]
        assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]

        # Same mask applied to all heads.
        mask = mask.unsqueeze(-1)

        # resulting mask has shape `[seq_len_q, seq_len_k, batch_size, 1]`,
        # which broadcasts over the heads dimension
        return mask

    def forward(self, *,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask: Optional[torch.Tensor] = None):
        """
        `query`, `key` and `value` are the tensors that store
        collection of *query*, *key* and *value* vectors.
        They have shape `[seq_len, batch_size, d_model]`.

        `mask` has shape `[seq_len, seq_len, batch_size]` and
        `mask[i, j, b]` indicates whether for batch `b`,
        query at position `i` has access to key-value at position `j`.

        Returns a tensor of shape `[seq_len, batch_size, d_model]`. As a side
        effect, prints tensor shapes and fills `self.saved_steps` and `self.attn`.
        """
        print(f"query.shape: {query.shape}")
        print(f"key.shape: {key.shape}")
        print(f"value.shape: {value.shape}\n")

        self.saved_steps['input_query'] = query

        # `query`, `key` and `value`  have shape `[seq_len, batch_size, d_model]`
        seq_len, batch_size, _ = query.shape

        if mask is not None:
            mask = self.prepare_mask(mask, query.shape, key.shape)

        # Prepare `query`, `key` and `value` for attention computation.
        # These will then have shape `[seq_len, batch_size, heads, d_k]`.
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)

        self.saved_steps['transformed_query'] = query

        print("# Shapes after linear transform and split into heads")
        print(f"query.shape: {query.shape}")
        print(f"key.shape: {key.shape}")
        print(f"value.shape: {value.shape}\n")

        # Compute attention scores $Q K^\top$.
        # This gives a tensor of shape `[seq_len, seq_len, batch_size, heads]`.
        scores = self.get_scores(query, key)
        self.saved_steps['scores'] = scores
        print(f"q * k^T scores.shape: {scores.shape}\n")

        # Scale scores $\frac{Q K^\top}{\sqrt{d_k}}$.
        # This must be out-of-place: an in-place `scores *= self.scale` would
        # also rescale the tensor just stored under `saved_steps['scores']`,
        # making the raw and scaled snapshots identical.
        scores = scores * self.scale

        self.saved_steps['scaled_scores'] = scores
        # Apply mask
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # $softmax$ attention along the key sequence dimension
        # $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)$
        attn = self.softmax(scores)
        self.saved_steps['softmax'] = attn
        print(f"softmax attn.shape: {attn.shape}\n")

        # Apply dropout
        attn = self.dropout(attn)
        self.saved_steps['dropout'] = attn

        # Multiply by values
        # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$
        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
        self.saved_steps['post_softmax'] = x
        print(f"*v x.shape: {x.shape}\n")

        # Save attentions for any other calculations
        self.attn = attn.detach()

        # Concatenate multiple heads
        x = x.reshape(seq_len, batch_size, -1)
        self.saved_steps['concat'] = x
        print(f"After reshape x.shape: {x.shape}\n")

        # Output layer
        out = self.output(x)
        self.saved_steps['output'] = out
        print(f"*o out.shape: {out.shape}\n")
        return out

In [8]:
# Small configuration with more than one head (d_model // n_heads = 2 features per head).
context_len, batch_size = 2, 2
d_model, n_heads = 4, 2

# Input laid out as (context_len, batch_size, d_model); drawn before the module is built.
x = torch.randn(context_len, batch_size, d_model, requires_grad=False)
print(f"Input x.shape: {x.shape}\n")

# Dropout is disabled (probability 0.0) so the forward pass is deterministic
# for the given weights.
mha = MultiHeadAttention(n_heads, d_model, dropout_prob=0.0, bias=False)

# Self-attention without gradient tracking: x serves as query, key and value.
with torch.no_grad():
    out = mha(query=x, key=x, value=x)

Input x.shape: torch.Size([2, 2, 4])

query.shape: torch.Size([2, 2, 4])
key.shape: torch.Size([2, 2, 4])
value.shape: torch.Size([2, 2, 4])

# Shapes after linear transform and split into heads
query.shape: torch.Size([2, 2, 2, 2])
key.shape: torch.Size([2, 2, 2, 2])
value.shape: torch.Size([2, 2, 2, 2])

q * k^T scores.shape: torch.Size([2, 2, 2, 2])

softmax attn.shape: torch.Size([2, 2, 2, 2])

*v x.shape: torch.Size([2, 2, 2, 2])

After reshape x.shape: torch.Size([2, 2, 4])

*o out.shape: torch.Size([2, 2, 4])



In [11]:

def _as_jnp(tensor):
    """Copy a torch tensor (detached from autograd) into a JAX array."""
    return jnp.array(tensor.detach().numpy())


x_jnp = jnp.array(x)
mha_jx = attention.MultiHeadAttention(n_heads, d_model, bias=False)

# Build the JAX state from the torch module's parameters so that both
# implementations run with identical weights.
mha_state = MultiHeadAttentionState(
    query_state=DenseState(_as_jnp(mha.query.linear.weight), None),
    key_state=DenseState(_as_jnp(mha.key.linear.weight), None),
    # The torch value projection always carries a bias, so it is transferred too.
    value_state=DenseState(_as_jnp(mha.value.linear.weight), _as_jnp(mha.value.linear.bias)),
    # The torch output layer is an `nn.Linear` with its default bias; passing
    # `None` here would drop that bias and the final 'output' step could not match.
    # NOTE(review): assumes the JAX output dense applies a non-None bias the same
    # way the value projection does -- confirm against `attention.py`.
    output_state=DenseState(_as_jnp(mha.output.weight), _as_jnp(mha.output.bias)),
)

s2 = mha_jx.forward(mha_state, x_jnp, x_jnp, x_jnp)

# Compare every intermediate step recorded by the JAX implementation with the
# tensor the torch implementation recorded under the same key.
print("\n###############\n")
for state, vec in mha_jx.saved_steps.items():
    print(f"## STATE {state} ##")
    torch_vec = mha.saved_steps[state]
    print(f"{state}: {vec.shape}")
    print(f"torch_{state}: {torch_vec.shape}")
    allclose = np.allclose(torch_vec.detach().numpy(), vec, atol=1e-6)
    print(f"allclose: {allclose}")
    print()
    # Always dump both tensors so a mismatch can be inspected by eye.
    print(f"\ntorch_{state}: {torch_vec.detach().numpy()}")
    print(f"\njax_{state}: {vec}")

    print()

# TODO: It is unclear whether the reference implementation being compared against
# is correct; the comparison still fails on the scores step.
# Should I try to work it out on my own?

q.shape = (2, 2, 4), k.shape = (2, 2, 4), v.shape = (2, 2, 4)
# Shapes after linear transform and split into heads
query.shape = (2, 2, 2, 2), key.shape = (2, 2, 2, 2), value.shape = (2, 2, 2, 2)
q * k^T = s.shape = (2, 2, 2, 2)
Softmax attn.shape = (2, 2, 2, 2)

###############

## STATE input_query ##
input_query: (2, 2, 4)
torch_input_query: torch.Size([2, 2, 4])
allclose: True


torch_input_query: [[[ 1.7285914   0.0934689  -1.3468537  -0.56231266]
  [-0.26797503  0.19827268  0.56406194 -1.5388839 ]]

 [[ 0.08391768 -0.74255586  0.9938217  -0.30552605]
  [ 1.0010856   1.1944772   0.23038462  0.974807  ]]]

jax_input_query: [[[ 1.7285914   0.0934689  -1.3468537  -0.56231266]
  [-0.26797503  0.19827268  0.56406194 -1.5388839 ]]

 [[ 0.08391768 -0.74255586  0.9938217  -0.30552605]
  [ 1.0010856   1.1944772   0.23038462  0.974807  ]]]

## STATE transformed_query ##
transformed_query: (2, 2, 2, 2)
torch_transformed_query: torch.Size([2, 2, 2, 2])
allclose: True


torch_transformed_query