In [30]:
import sys
sys.path.append('../')
import attention
import jax.numpy as jnp
import numpy as np

In [31]:
import math
from typing import Optional, List

import torch
from torch import nn

# from labml import tracker


class PrepareForMultiHeadAttention(nn.Module):
    """
    <a id="PrepareMHA"></a>

    ## Prepare for multi-head attention

    Projects the input with a single linear layer and reshapes the projected
    features into one vector per attention head.
    The same module type is used for the **key**, **query**, and **value** inputs.

    * `d_model` is the size of the last dimension of the input.
    * `heads` is the number of attention heads.
    * `d_k` is the number of features in each head.
    * `bias` controls whether the linear projection has a bias term.
    """

    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
        super().__init__()
        # Per-head feature size and head count; `forward` needs both to reshape.
        self.d_k = d_k
        self.heads = heads
        # One projection that produces the features of every head at once.
        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)

    def forward(self, x: torch.Tensor):
        # `x` has shape `[seq_len, batch_size, d_model]` or `[batch_size, d_model]`.
        # Only the last dimension is transformed; all leading dimensions are kept.
        projected = self.linear(x)

        # Split the projected last dimension `heads * d_k` into `[heads, d_k]`.
        leading_dims = projected.shape[:-1]

        # Result is `[seq_len, batch_size, heads, d_k]` or `[batch_size, heads, d_k]`.
        return projected.view(*leading_dims, self.heads, self.d_k)

In [32]:
class MultiHeadAttention(nn.Module):
    """
    ## Multi-head attention (instrumented for step-by-step comparison)

    Computes `softmax(Q K^T / sqrt(d_k)) V` over `heads` heads and applies a
    final linear output layer.

    Every intermediate tensor of the most recent `forward` call is stored in
    `self.saved_steps` under the keys `input_query`, `query`, `key`, `value`,
    `scores`, `scaled_scores`, `mask`, `masked_scaled_scores`, `softmax`,
    `scaled_values`, `concat_heads` and `out`. `forward` also prints the shape
    of the main intermediates.
    """

    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
        """
        * `heads` is the number of heads.
        * `d_model` is the number of features in the `query`, `key` and `value` vectors.
          It must be a multiple of `heads`.
        * `dropout_prob` is the dropout probability applied to the attention weights.
        * `bias` controls the bias of the `query` and `key` projections only;
          the `value` projection and the output layer always have a bias.

        Raises `ValueError` if `heads` is not positive or does not divide `d_model`.
        """

        super().__init__()

        # The heads are concatenated back to `heads * d_k` features and fed to an
        # output layer that expects `d_model` inputs, so the split must be exact.
        # Fail here rather than with a shape mismatch deep inside `forward`.
        if heads <= 0:
            raise ValueError(f"heads must be a positive integer, got {heads}")
        if d_model % heads != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by heads ({heads})")

        # Number of features per head
        self.d_k = d_model // heads
        # Number of heads
        self.heads = heads

        # These transform the `query`, `key` and `value` vectors for multi-headed attention.
        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        # The value projection keeps its bias regardless of `bias`.
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)

        # Softmax over dimension 1 of the `[seq_len_q, seq_len_k, batch_size, heads]`
        # score tensor, i.e. along the key positions.
        self.softmax = nn.Softmax(dim=1)

        # Output layer
        self.output = nn.Linear(d_model, d_model)
        # Dropout applied to the attention weights
        self.dropout = nn.Dropout(dropout_prob)
        # Scaling factor before the softmax
        self.scale = 1 / math.sqrt(self.d_k)

        # Attention weights of the last `forward` call (detached), for logging
        self.attn = None
        # Intermediate tensors of the last `forward` call, keyed by step name
        self.saved_steps = {}

    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
        """
        Calculate $Q K^\\top$, i.e. $S_{ijbh} = \\sum_d Q_{ibhd} K_{jbhd}$.

        `query` is `[seq_len_q, batch_size, heads, d_k]` and `key` is
        `[seq_len_k, batch_size, heads, d_k]`; the result is
        `[seq_len_q, seq_len_k, batch_size, heads]`.
        """
        return torch.einsum('ibhd,jbhd->ijbh', query, key)

    def prepare_mask(self, mask: torch.Tensor, query_shape: List[int], key_shape: List[int]):
        """
        `mask` has shape `[seq_len_q, seq_len_k, batch_size]`, where first dimension is the query dimension.
        The query dimension and the batch dimension may each be $1$, in which
        case they are broadcast.

        `query_shape` and `key_shape` are the shapes of the un-projected
        `[seq_len, batch_size, d_model]` inputs and are only used for validation.

        Returns the mask with a trailing singleton head dimension added.
        """

        assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
        assert mask.shape[1] == key_shape[0]
        assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]

        # Same mask applied to all heads.
        mask = mask.unsqueeze(-1)

        # resulting mask has shape `[seq_len_q, seq_len_k, batch_size, 1]`,
        # which broadcasts over the heads dimension
        return mask

    def forward(self, *,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask: Optional[torch.Tensor] = None):
        """
        `query`, `key` and `value` are the tensors that store
        collection of *query*, *key* and *value* vectors.
        They have shape `[seq_len, batch_size, d_model]`.

        `mask` has shape `[seq_len, seq_len, batch_size]` and
        `mask[i, j, b]` indicates whether for batch `b`,
        query at position `i` has access to key-value at position `j`.
        Positions where the mask equals `0`/`False` are excluded.

        Returns a tensor of shape `[seq_len, batch_size, d_model]`.
        """

        self.saved_steps['input_query'] = query

        # `query`, `key` and `value`  have shape `[seq_len, batch_size, d_model]`
        seq_len, batch_size, _ = query.shape

        if mask is not None:
            mask = self.prepare_mask(mask, query.shape, key.shape)

        # Prepare `query`, `key` and `value` for attention computation.
        # These will then have shape `[seq_len, batch_size, heads, d_k]`.
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)

        self.saved_steps['query'] = query
        self.saved_steps['key'] = key
        self.saved_steps['value'] = value

        # Compute attention scores $Q K^\top$.
        # This gives a tensor of shape `[seq_len, seq_len, batch_size, heads]`.
        scores = self.get_scores(query, key)
        self.saved_steps['scores'] = scores
        print(f"q * k^T scores.shape: {scores.shape}\n")

        # Scale scores $\frac{Q K^\top}{\sqrt{d_k}}$
        print(f"Scaling using self.scale: {self.scale}")
        scaled_scores = scores * self.scale

        self.saved_steps['scaled_scores'] = scaled_scores
        # The stored mask is the prepared one (with the extra head dimension), or `None`.
        self.saved_steps['mask'] = mask
        # Apply mask: excluded positions get `-inf` so the softmax assigns them zero weight
        if mask is not None:
            scaled_scores = scaled_scores.masked_fill(mask == 0, float('-inf'))

        self.saved_steps['masked_scaled_scores'] = scaled_scores

        # $softmax$ attention along the key sequence dimension
        # $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)$
        attn = self.softmax(scaled_scores)
        self.saved_steps['softmax'] = attn
        print(f"softmax attn.shape: {attn.shape}\n")

        # Apply dropout
        attn = self.dropout(attn)

        # Multiply by values
        # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$
        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
        self.saved_steps['scaled_values'] = x
        print(f"*v x.shape: {x.shape}\n")

        # Save attentions (after dropout) for any other calculations
        self.attn = attn.detach()

        # Concatenate multiple heads: `[seq_len, batch_size, heads, d_k]` -> `[seq_len, batch_size, heads * d_k]`
        x = x.reshape(seq_len, batch_size, -1)
        self.saved_steps['concat_heads'] = x
        print(f"After reshape x.shape: {x.shape}\n")

        # Output layer
        out = self.output(x)
        self.saved_steps['out'] = out
        print(f"*o out.shape: {out.shape}\n")
        return out

In [33]:
# Tiny problem size so every intermediate tensor can be inspected by eye.
context_len = 2
batch_size = 2
d_model = 2
n_heads = 2

# Seed torch so the random input `x` and the randomly initialised attention
# weights are identical on every Restart & Run All; the comparisons in the
# cells below are otherwise made against different numbers on each run.
torch.manual_seed(1337)

# Mask convention used by `MultiHeadAttention` above: True = query i MAY attend
# to key j (positions where the mask is 0/False are filled with -inf).
# For a 2x2 causal mask this is
#   [True  False]
#   [True  True ]
# i.e. position 0 cannot attend to position 1.
# NOTE: this is the inverse of the boolean `attn_mask` of `torch.nn.MultiheadAttention`,
# where True means "no access"; there the same causal mask reads
#   [False True ]
#   [False False]

# Create the causal mask of shape [seq_len_q, seq_len_k, batch_size]
base_mask = torch.tril(torch.ones((context_len, context_len), dtype=torch.bool))
mask = base_mask.unsqueeze(2).repeat(1, 1, batch_size)
print(f"mask.shape: {mask.shape}\n")
print(mask[:, :, 0])
x = torch.randn(context_len, batch_size, d_model, requires_grad=False)
print(f"Input x.shape: {x.shape}\n")

# q, k bias: False. v bias: True (the value projection always has a bias)
# out bias: True
with torch.no_grad():
    mha = MultiHeadAttention(n_heads, d_model, dropout_prob=0.0, bias=False)
    out = mha(query=x, key=x, value=x, mask=mask)

mask.shape: torch.Size([2, 2, 2])

tensor([[ True, False],
        [ True,  True]])
Input x.shape: torch.Size([2, 2, 2])

q * k^T scores.shape: torch.Size([2, 2, 2, 2])

Scaling using self.scale: 1.0
softmax attn.shape: torch.Size([2, 2, 2, 2])

*v x.shape: torch.Size([2, 2, 2, 1])

After reshape x.shape: torch.Size([2, 2, 2])

*o out.shape: torch.Size([2, 2, 2])



In [34]:
# Scratch check of the JAX masking recipe.
# Lower-triangular (causal) pattern: entry [i, j] is True when j <= i.
mask_test = jnp.tril(jnp.ones((context_len, context_len), dtype=bool))
# Add batch and head axes, then replicate the pattern across both.
mask_tiled = jnp.tile(mask_test[..., None, None], (1, 1, batch_size, n_heads))

# All-ones stand-in for a score tensor of shape [context_len, context_len, batch_size, n_heads].
test = jnp.ones((context_len, context_len, batch_size, n_heads))
# Keep scores where the mask is True, replace the rest with -inf.
test_masked = jnp.where(mask_tiled, test, float("-inf"))

In [55]:
from states import *


def _torch_to_jnp(tensor):
    """Copy a torch tensor (e.g. a parameter) into a JAX array."""
    return jnp.array(tensor.detach().numpy())


# Same input as the torch run, as a JAX array.
x_jnp = jnp.array(x)
mha_jx = attention.MultiHeadAttention(d_model, n_heads, out_bias=True)

# Reuse the torch module's weights instead of initialising fresh JAX parameters,
# so both implementations compute with identical numbers.
# Query and key projections carry no bias; value and output do.
mha_state = MultiHeadAttentionState(
    query_state=LinearState(_torch_to_jnp(mha.query.linear.weight), None),
    key_state=LinearState(_torch_to_jnp(mha.key.linear.weight), None),
    value_state=LinearState(
        _torch_to_jnp(mha.value.linear.weight),
        _torch_to_jnp(mha.value.linear.bias),
    ),
    output_state=LinearState(
        _torch_to_jnp(mha.output.weight),
        _torch_to_jnp(mha.output.bias),
    ),
)

# mask_jax: [context_len, context_len, batch_size, n_heads]
mask_jax = mha_jx.get_causal_mask(context_len, batch_size)
print(mask_jax[:, :, 0, 0])

s2 = mha_jx.forward(mha_state, x_jnp, x_jnp, x_jnp, mask=mask_jax)

[[ True False]
 [ True  True]]


In [None]:
# Causal mask from the JAX implementation for context length 3, batch size 1
# (True = may attend). Show the slice for batch 0, head 0.
causal_mask_jax = mha_jx.get_causal_mask(3, 1)
print(causal_mask_jax[:, :, 0, 0])

# Convert to the torch `attn_mask` convention (True = no access): take the
# [query, key] slice for batch 0 / head 0, invert it, and wrap it as a bool
# torch tensor. Previously this name held the un-inverted JAX slice.
causal_mask_torch = torch.from_numpy(np.array(~causal_mask_jax[:, :, 0, 0])).bool()
print(causal_mask_torch)

In [52]:

# Inverted [query, key] slice (batch 0, head 0) of the JAX mask as a bool torch
# tensor, i.e. the mask expressed in the torch convention (True = no access).
torch.from_numpy(np.logical_not(np.array(mask_jax[:, :, 0, 0]))).bool()

tensor([[False,  True],
        [False, False]])

In [36]:
# Compare every intermediate step recorded by the JAX implementation against
# the tensor the torch implementation stored under the same key.
print("\n###############\n")
for step_name in mha_jx.debug_states:
    jax_value = mha_jx.debug_states[step_name]
    print(f"## STATE {step_name} ##")
    torch_value = mha.saved_steps[step_name]
    print(f"{step_name}: {jax_value.shape}")
    print(f"torch_{step_name}: {torch_value.shape}")
    # Convert once and reuse for both the comparison and the dump below.
    torch_as_numpy = torch_value.detach().numpy()
    step_matches = np.allclose(torch_as_numpy, jax_value, atol=1e-3)
    print(f"allclose: {step_matches}")
    print(f"torch_{step_name}: {torch_as_numpy}")
    print(f"jax_{step_name}: {jax_value}")
    print()

# Final outputs of both implementations, same tolerance as the per-step check.
final_matches = np.allclose(out.detach().numpy(), s2, atol=1e-3)
print(f"Output matches: {final_matches}")



###############

## STATE input_query ##
input_query: (2, 2, 2)
torch_input_query: torch.Size([2, 2, 2])
allclose: True
torch_input_query: [[[ 0.89968914  0.70766306]
  [ 1.3755844   0.95841324]]

 [[ 0.36264262  0.04485898]
  [ 0.52395815 -1.3492844 ]]]
jax_input_query: [[[ 0.89968914  0.70766306]
  [ 1.3755844   0.95841324]]

 [[ 0.36264262  0.04485898]
  [ 0.52395815 -1.3492844 ]]]

## STATE query ##
query: (2, 2, 2, 1)
torch_query: torch.Size([2, 2, 2, 1])
allclose: True
torch_query: [[[[-0.24999537]
   [-0.26984102]]

  [[-0.30883878]
   [-0.42441514]]]


 [[[ 0.0420033 ]
   [-0.13179906]]

  [[ 0.90056217]
   [-0.32592243]]]]
jax_query: [[[[-0.24999537]
   [-0.26984102]]

  [[-0.30883875]
   [-0.42441514]]]


 [[[ 0.0420033 ]
   [-0.13179906]]

  [[ 0.90056217]
   [-0.32592246]]]]

## STATE key ##
key: (2, 2, 2, 1)
torch_key: torch.Size([2, 2, 2, 1])
allclose: True
torch_key: [[[[-0.6479513 ]
   [ 0.98214704]]

  [[-0.97032624]
   [ 1.4354385 ]]]


 [[[-0.22156297]
   [ 0.267061

In [53]:
# Compare with pytorch Multihead Attention
causal_mask = torch.triu(torch.ones(context_len, context_len), diagonal=1).bool()
print(causal_mask.shape)
print(causal_mask)

# input / output bias: True 
# k / v bias: False
mha_torch = nn.MultiheadAttention(d_model, n_heads, bias=False, dropout=0.0)

with torch.no_grad():
    # run mha_torch forward pass on x with causal mask
    out_torch, _ = mha_torch(x, x, x, attn_mask=causal_mask)
    print(f"out_torch.shape: {out_torch.shape}")
    
print(out_torch)

torch.Size([2, 2])
tensor([[False,  True],
        [False, False]])
out_torch.shape: torch.Size([2, 2, 2])
tensor([[[ 0.6829, -0.4547],
         [ 1.0136, -0.6750]],

        [[ 0.4493, -0.2994],
         [ 0.5127, -0.3313]]])


In [38]:
# JAX counterpart of the `nn.MultiheadAttention` reference run.

from states import to_jax_state

# No output bias and no value bias, matching the `bias=False` torch module.
jax_mha2 = attention.MultiHeadAttention(d_model, n_heads, out_bias=False, v_bias=False)
# Carry the torch module's parameters over to the JAX state.
state2 = to_jax_state(mha_torch)
jax_out = jax_mha2(state2, x_jnp, x_jnp, x_jnp, mask=mask_jax)

print(f"jax_out.shape: {jax_out.shape}")
print(jax_out)

# Compare against the torch reference output.
reference_matches = np.allclose(out_torch.detach().numpy(), jax_out, atol=1e-3)
print(f"Output matches: {reference_matches}")

jax_out.shape: (2, 2, 2)
[[[ 0.0831968  -0.3420485 ]
  [ 0.12281953 -0.48863164]]

 [[ 0.05498286 -0.20955701]
  [ 0.0882671  -0.16579203]]]
Output matches: True
