In [1]:
import sys

import jax
import jax.numpy as jnp
import numpy as np

sys.path.append('../')
import attention
from attention import *

In [11]:
# Warm-up: draw one standard-normal vector of length n_in from a fixed PRNG key.
n_in = 3
n_out = 2
x = jax.random.normal(jax.random.PRNGKey(0), shape=(n_in,))
print(x.shape)

(3,)


In [2]:
# Toy attention setup: random embeddings laid out as (context_len, batch_size, emb_size).
context_len = 3
batch_size = 2
emb_size = 32
n_heads = 1
d_k = emb_size // n_heads  # per-head width: emb_size split evenly across heads

# Seeded generator so z (and everything derived from it) is reproducible on
# Restart & Run All.
z = jnp.array(np.random.default_rng(0).normal(size=(context_len, batch_size, emb_size)))
print(z.shape)
print(z.shape[:-1])
v = z.transpose(1, 0, 2)  # batch-major copy: (batch_size, context_len, emb_size)
print(v.shape)

preattn = PreAttention(emb_size=emb_size, n_heads=n_heads, d_k=d_k, bias=False)
rng = jax.random.PRNGKey(0)
state = preattn.init_state(rng)

q = preattn(state, z)
print(q.shape)
# Recorded output is (3, 2, 1, 32), i.e. (context_len, batch_size, n_heads, d_k);
# assert it so a change in PreAttention's layout fails loudly here.
assert q.shape == (context_len, batch_size, n_heads, d_k)

# scores should be (3, 3, 2, 1)

(3, 2, 32)
(3, 2)
(2, 3, 32)
(3, 2, 1, 32)


In [3]:
# Raw self-attention scores: dot product over d of query position q with key
# position k (the queries double as keys here), per batch element b and head h.
s = jnp.einsum('qbhd,kbhd->qkbh', q, q)
print(s.shape)



(3, 3, 2, 1)


In [4]:
# Contrast with the einsum above: a batched matmul of q against itself with the last
# two axes swapped contracts d inside each (position, batch) pair. The result is an
# (n_heads x n_heads) block per position, not the (context_len x context_len) scores.
jnp.matmul(q, jnp.swapaxes(q, -1, -2)).shape

(3, 2, 1, 1)

In [7]:
import torch

# Cross-check that torch.einsum and jnp.einsum agree on the score contraction
# 'ibhd,jbhd->ijbh' for random (context_len, batch_size, n_heads, d_k) inputs.
torch.manual_seed(0)  # make qt / kt reproducible across kernel restarts
qt = torch.randn(size=(context_len, batch_size, n_heads, d_k))
kt = torch.randn(size=(context_len, batch_size, n_heads, d_k))

qt_1 = torch.einsum('ibhd,jbhd->ijbh', qt, kt)
print(qt_1.shape)

qt_2 = jnp.einsum('ibhd,jbhd->ijbh', qt.detach().numpy(), kt.detach().numpy(), optimize='optimal')
print(qt_2.shape)

allclose = np.allclose(qt_1.detach().numpy(), qt_2)
print(allclose)
# Fail loudly instead of relying on a reader to notice a printed False.
assert allclose, "torch.einsum and jnp.einsum disagree"

torch.Size([3, 3, 2, 1])
(3, 3, 2, 1)
True


In [None]:
# Plain NumPy copies of the torch tensors from the previous cell.
qtj = qt.detach().numpy()
ktj = kt.detach().numpy()

# Example 1: einsum with an ellipsis. Every leading axis is a batch axis, and the last
# two axes (here h, d) are multiplied as q_slice @ k_slice.T, giving an (h, h) block.
# This is therefore NOT the (i, j) attention-score layout.
res = jnp.einsum("...id,...jd->...ij", qtj, ktj)
print(res.shape)

# Example 2: the same contraction as a batched matmul against k with its last two
# axes swapped.
res2 = jnp.matmul(qtj, np.swapaxes(ktj, -1, -2))
print(res2.shape)

allclose = np.allclose(res, res2)
print(allclose)


In [None]:
# The score contraction once more, on the plain arrays. Einsum label letters are
# arbitrary (e.g. "cbhd,Cbhd->cCbh" is the same contraction); only the pattern matters.
# Result layout: (context, context, batch, head).
res3 = jnp.einsum("ibhd,jbhd->ijbh", qtj, ktj)
print(res3.shape)

In [None]:
# Transposing a matrix only permutes its axes: for a 2-D array, .T equals transpose(1, 0).
x = np.random.normal(size=(2, 3))
print(x)
print(x.T)

In [None]:
# Row vector v (1 x 3) against the (2 x 3) x from the previous cell. Contracting the
# shared column label gives x @ v.T, a (2 x 1) result.
v = np.random.normal(size=(1, 3))
jnp.einsum('ab,cb->ac', x, v)

In [None]:
# Two einsum contractions of a square matrix with itself:
#   'ij,kj->ik' sums over the shared column index -> x @ x.T (printed right after it)
#   'ik,kj->ij' sums over the inner index         -> the ordinary product x @ x
x = np.random.normal(size=(2, 2))
print(jnp.einsum('ij,kj->ik', x, x))
print(x @ np.transpose(x))
print(jnp.einsum('ik,kj->ij', x, x))

In [None]:
from attention import softmax

# Softmax sanity check: normalising along axis r must make every slice along r sum to 1.
r = 0
x = np.random.normal(size=(2, 2))
print(x)
xs = softmax(x, dim=r)
print(xs)
# Reduce along the softmax axis explicitly. The builtin sum(xs, 0) treats 0 as a
# *start value*, not an axis, and always sums over axis 0 whatever r is.
col_sums = jnp.sum(xs, axis=r)
print(col_sums)
assert jnp.allclose(col_sums, 1.0), "softmax slices along dim r do not sum to 1"

In [15]:
context_len = 3
batch_size = 16
n_heads = 4

# Causal (lower-triangular) mask: row i keeps columns j <= i.
single_mask = jnp.tril(jnp.ones((context_len, context_len)))
print(single_mask)

attention_weights = jax.random.normal(jax.random.PRNGKey(0), (context_len, context_len))
print(attention_weights)

# Keep the weights where the mask is set. Everything above the diagonal becomes -inf,
# which the softmax below turns into zero probability.
filled = jnp.where(single_mask != 0, attention_weights, -jnp.inf)
print(filled)

softmaxed = softmax(filled, dim=-1)
print(softmaxed)

[[1. 0. 0.]
 [1. 1. 0.]
 [1. 1. 1.]]
[[-0.3721109   0.26423115 -0.18252768]
 [-0.7368197   0.44973662 -0.1521442 ]
 [-0.67135346 -0.5908641   0.73168886]]
[[-0.3721109         -inf        -inf]
 [-0.7368197   0.44973662        -inf]
 [-0.67135346 -0.5908641   0.73168886]]
[[1.         0.         0.        ]
 [0.23387541 0.7661246  0.        ]
 [0.16256534 0.17619112 0.66124356]]


In [16]:
# Target layout: (context_len, context_len, batch_size, n_heads)

single_mask = jnp.tril(jnp.ones((context_len, context_len)), k=0)
print(f"single_mask: {single_mask.shape}")
# Append two trailing singleton axes in one indexing step -> (context_len, context_len, 1, 1)
single_mask = single_mask[:, :, None, None]
print(f"single_mask: {single_mask.shape}")

# Copy the mask once per batch element (axis 2), then once per head (axis 3).
batch_mask = jnp.repeat(single_mask, repeats=batch_size, axis=2)
print(f"batch_mask: {batch_mask.shape}")

head_mask = jnp.repeat(batch_mask, repeats=n_heads, axis=3)
print(f"head_mask: {head_mask.shape}")



single_mask: (3, 3)
single_mask: (3, 3, 1, 1)
batch_mask: (3, 3, 16, 1)
head_mask: (3, 3, 16, 4)


In [17]:
# Same (context_len, context_len, batch_size, n_heads) mask, built by broadcasting a
# (context_len, context_len, 1, 1) causal mask rather than repeating it axis by axis.
single_mask = jnp.tril(jnp.ones((context_len, context_len)), k=0).reshape(
    (context_len, context_len, 1, 1)
)
mask = jnp.broadcast_to(single_mask, (context_len, context_len, batch_size, n_heads))
print(mask.shape)

(3, 3, 16, 4)
