In [1]:
import sys
sys.path.append('../')
import jax.numpy as jnp
import numpy as np

from attention import *
from utils import *

In [None]:
n_in, n_out = 3, 2

x = jax.random.normal(jax.random.PRNGKey(0), (n_in,))
print(x.shape)

In [None]:
context_len = 3
batch_size = 2
emb_size = 32
n_heads = 1
d_k = emb_size // n_heads

z = jnp.array(np.random.normal(size=(context_len, batch_size, emb_size)))
print(z.shape)
print(z.shape[:-1])
v = z.transpose(1, 0, 2)
print(v.shape)

preattn = PreAttention(emb_size=emb_size, n_heads=n_heads, d_k=d_k, bias=False)
rng = jax.random.PRNGKey(0)
state = preattn.init_state(rng)

q = preattn(state, z)
print(q.shape)

# scores borde vara (3, 3, 2, 1)

In [None]:
s = jnp.einsum('ibhd,jbhd->ijbh', q, q)
print(s.shape)



In [None]:
jnp.matmul(q, jnp.transpose(q, axes=(0, 1, 3, 2))).shape

In [None]:
import torch
qt = torch.randn(size=(context_len, batch_size, n_heads, d_k))
kt = torch.randn(size=(context_len, batch_size, n_heads, d_k))

qt_1 = torch.einsum('ibhd,jbhd->ijbh', qt, kt)
print(qt_1.shape)

qt_2 = jnp.einsum('ibhd,jbhd->ijbh', qt.detach().numpy(), kt.detach().numpy(), optimize='optimal')
print(qt_2.shape)

allclose = np.allclose(qt_1.detach().numpy(), qt_2)
print(allclose)

In [None]:
qtj = qt.detach().numpy()
ktj = kt.detach().numpy()

res = jnp.einsum("...id,...jd->...ij", qtj, ktj)
print(res.shape)

# Example 2: do the same operation using matmul

res2 = jnp.matmul(qtj, ktj.transpose(0, 1, 3, 2))

print(res2.shape)

allclose = np.allclose(res, res2)
print(allclose)


In [None]:
res3 = jnp.einsum("cbhd,Cbhd->cCbh", qtj, ktj)
print(res3.shape)

In [None]:
# Transposing a matrix is just dimension permutation

x = np.random.normal(size=(2, 3))
print(x)
print(x.transpose(1, 0))

In [None]:
v = np.random.normal(size=(1, 3))
jnp.einsum('ij,kj->ik', x, v)

In [None]:
x = np.random.normal(size=(2, 2))
# Matrix matrix
print(jnp.einsum('ij,kj->ik', x, x))
print(x @ x.T)
print(jnp.einsum('ik,kj->ij', x, x))

In [None]:
from attention import softmax
r = 0
x = np.random.normal(size=(2, 2))
print(x)
xs = softmax(x, dim=r)
print(xs)
print(sum(xs, 0))

In [None]:
context_len = 3
batch_size = 16
n_heads = 4

single_mask = jnp.tril(jnp.ones((context_len, context_len)), k=0)
print(single_mask)

attention_weights = jax.random.normal(jax.random.PRNGKey(0), (context_len, context_len))
print(attention_weights)

filled = jnp.where(single_mask == 0, float('-inf'), attention_weights)
print(filled)

softmaxed = softmax(filled, dim=-1)
print(softmaxed)

In [None]:
# (context_len, context_len, batch_size, n_heads)

single_mask = jnp.tril(jnp.ones((context_len, context_len)), k=0)
print(f"single_mask: {single_mask.shape}")
# expand last dim twice
single_mask = jnp.expand_dims(single_mask, axis=-1)
single_mask = jnp.expand_dims(single_mask, axis=-1)
print(f"single_mask: {single_mask.shape}")

batch_mask = jnp.repeat(single_mask, batch_size, axis=2)
print(f"batch_mask: {batch_mask.shape}")

head_mask = jnp.repeat(batch_mask, n_heads, axis=3)
print(f"head_mask: {head_mask.shape}")



In [None]:
single_mask = jnp.tril(jnp.ones((context_len, context_len)), k=0)
single_mask = single_mask.reshape((context_len, context_len, 1, 1))
mask = jnp.tile(single_mask, (1, 1, batch_size, n_heads))
print(mask.shape)

In [2]:
##### COMPARE MHA MASK #####

context_len = 2
batch_size = 1
emb_size = 1
n_heads = 1

# should be of size (context_len, context_len)
import torch
torch_mask = torch.tril(torch.ones((context_len, context_len)), diagonal=0)
print(torch_mask.numpy())

torch_mha = torch.nn.MultiheadAttention(emb_size, n_heads, bias=False)

torch_x = torch.randn(size=(context_len, batch_size, emb_size))
print(f"input: {torch_x}, shape: {torch_x.shape}")
with torch.no_grad():
    torch_res, torch_weights = torch_mha(torch_x, torch_x, torch_x, attn_mask=torch_mask.T)
print(torch_res.shape)
print(f"Output: {torch_res.numpy()}")

[[1. 0.]
 [1. 1.]]
input: tensor([[[0.4739]],

        [[0.0339]]]), shape: torch.Size([2, 1, 1])
torch.Size([2, 1, 1])
Output: [[[0.10363345]]

 [[0.06198265]]]


In [3]:
# print all weighs of torch_attn
jax_mha_state = to_jax_state(torch_mha)

In [4]:
x = jnp.array(torch_x.numpy())
print(f"Input: {x}, shape: {x.shape}")

attn = MultiHeadAttention(emb_size=emb_size, n_heads=n_heads, v_bias=False)

mask = attn.get_causal_mask(context_len, batch_size)
print(f"mask.shape: {mask.shape}")
print(mask[:,:, 0, 0])

res = attn(jax_mha_state, x, x, x, mask)
print(f"out.shape: {res.shape}")
print(res)

print(f"Mask close: {np.allclose(torch_mask.numpy(), mask[:,:, 0, 0])}")
print(f"Output Close: {np.allclose(torch_res.numpy(), res, atol=1e-5)}")

# Not getting masked outputs to match...

Input: [[[0.4738906 ]]

 [[0.03385897]]], shape: (2, 1, 1)
mask.shape: (2, 2, 1, 1)
[[1. 0.]
 [1. 1.]]
out.shape: (2, 1, 1)
[[[0.19293994]]

 [[0.10338199]]]
Mask close: True
Output Close: False
