# Summer School

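## Practice

Smoke tests for the attention building blocks in `src.attention` (`scaled_dot_product`, `MultiheadAttention`, `EncoderBlock`, `TransformerEncoder`): each cell runs one component on random inputs and reports the output shapes.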

In [1]:
import jax
from jax import random
import jax.numpy as jnp

from src.attention import scaled_dot_product, MultiheadAttention, EncoderBlock, TransformerEncoder

In [2]:
# Single root key for the whole notebook; every later cell splits from it.
SEED = 42
main_rng = random.PRNGKey(SEED)

In [3]:
# Scaled dot-product attention on a tiny example (3 tokens, head dim 2).
seq_len, d_k = 3, 2
main_rng, rand1 = random.split(main_rng)
# One draw for all three tensors: the leading axis of size 3 is q / k / v,
# not a batch dimension.
qkv = random.normal(rand1, (3, seq_len, d_k))
q, k, v = qkv[0], qkv[1], qkv[2]
values, attention = scaled_dot_product(q, k, v)

# Sanity checks: one output vector per query, a (query, key) weight matrix,
# and each row of weights summing to 1 (presumably a softmax over keys, which
# matches the printed output below).
assert values.shape == (seq_len, d_k)
assert attention.shape == (seq_len, seq_len)
assert jnp.allclose(attention.sum(axis=-1), 1.0, atol=1e-5)

print(f"{values=}\n{attention=}")
del rand1, qkv, q, k, v

values=Array([[ 0.376226  , -0.1465618 ],
       [-0.42778558, -0.5989566 ],
       [ 0.43624768, -0.11678301]], dtype=float32)
attention=Array([[0.27963293, 0.54049295, 0.17987415],
       [0.22194658, 0.06706189, 0.7109916 ],
       [0.27977085, 0.5837308 , 0.13649832]], dtype=float32)


In [4]:
# Run a 4-head attention layer over a random (3, 16, 128) input whose last
# axis matches embed_dim.
main_rng, x_rng = random.split(main_rng)
x = random.normal(x_rng, (3, 16, 128))

mh_attn = MultiheadAttention(num_heads=4, embed_dim=128)
main_rng, init_rng = random.split(main_rng)
# Initialise, keep only the 'params' collection, and apply with it.
variables = {'params': mh_attn.init(init_rng, x)['params']}
out, attn = mh_attn.apply(variables, x)

print(f"{out.shape=}\n{attn.shape=}")
del mh_attn, variables

out.shape=(3, 16, 128)
attn.shape=(3, 4, 16, 16)


In [5]:
# One encoder block run with train=True; a 'dropout' RNG stream is supplied
# to both init and apply alongside the parameter RNG.
main_rng, x_rng = random.split(main_rng)
x = random.normal(x_rng, (3, 16, 128))

encoder_block = EncoderBlock(
    input_dim=128,
    num_heads=4,
    dim_feedforward=512,
    dropout_prob=0.1,
)
main_rng, init_rng, dropout_init_rng = random.split(main_rng, 3)
init_rngs = {'params': init_rng, 'dropout': dropout_init_rng}
params = encoder_block.init(init_rngs, x, train=True)['params']

main_rng, dropout_apply_rng = random.split(main_rng)
output = encoder_block.apply(
    {'params': params},
    x,
    train=True,
    rngs={'dropout': dropout_apply_rng},
)

print('Output', output.shape)
del encoder_block, params, init_rngs

Output (3, 16, 128)


In [7]:
# Full encoder stack. bind() attaches the params and a dropout RNG so the
# module can be called directly and its get_attention_maps helper used.
main_rng, x_rng = random.split(main_rng)
x = random.normal(x_rng, (3, 16, 128))

num_layers = 5
trans_enc = TransformerEncoder(
    num_layers=num_layers,
    input_dim=128,
    num_heads=4,
    dim_feedforward=256,
    dropout_prob=0.15,
)

main_rng, init_rng, dropout_init_rng = random.split(main_rng, 3)
params = trans_enc.init({'params': init_rng, 'dropout': dropout_init_rng}, x, train=True)['params']

main_rng, dropout_apply_rng = random.split(main_rng)
bound_mod = trans_enc.bind({'params': params}, rngs={'dropout': dropout_apply_rng})

output = bound_mod(x, train=True)
# The stack should preserve the input shape.
assert output.shape == x.shape
print('Output', output.shape)

attention_maps = bound_mod.get_attention_maps(x, train=True)
# Expect one attention map per encoder layer.
assert len(attention_maps) == num_layers
print('Attention maps', len(attention_maps), attention_maps[0].shape)

del trans_enc, bound_mod, params, num_layers

Output (3, 16, 128)
Attention maps 5 (3, 4, 16, 16)
