# How Positional Embeddings Work? Part 2


The attention mechanism computes dot products between projections of embeddings. 

$$
\sum_i W_i cos(\theta_{t + k} - \theta{t})
$$

In [13]:
sequence_length = 15
embedding_length = 9

import jax.numpy as jnp

def pos(t, i):

    x = t / (100.0**(i/embedding_length))
    return x, jnp.where(i%2 == 0, jnp.sin(x), jnp.cos(x))

In [14]:
import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(42)

def generate_embeddings(rng, shape=(sequence_length, embedding_length), min_val=-0.1, max_val=0.1):
    embeddings = jax.random.uniform(rng, shape, minval=min_val, maxval=max_val)
    return embeddings

sem_embeddings = generate_embeddings(rng)

In [15]:
import jax.numpy as jnp

def generate_pos_embeddings(shape=(sequence_length, embedding_length)):

    embeddings = jnp.empty(shape)
    for row in range(shape[0]):
        for col in range(shape[1]):
            embeddings = embeddings.at[(row,col)].set(pos(row,col)[1])

    return embeddings

pos_embeddings = generate_pos_embeddings()

In [16]:
def generate_dataset(num_samples, dataset):

    i_values = jax.random.randint(rng, (num_samples,), 0, 10)
    j_values = i_values + jax.random.randint(rng + 1, (num_samples,), 0, 10 - i_values)
    j_values = jnp.clip(j_values, 0, 9)

    emb_i = dataset[i_values,:]
    emb_j = dataset[j_values, :]

    diff = j_values - i_values

    return jnp.stack([emb_i, emb_j], axis=1), diff



In [None]:
import jax
import flax.nnx as nnx

class AttnBlock(nnx.Module):
    def __init__(self, h: int, rngs: nnx.Rngs):
        self.q = nnx.Linear(h, h, rngs=rngs)
        self.k = nnx.Linear(h, h, rngs=rngs)
        self.v = nnx.Linear(h, h, rngs=rngs)

    def __call__(self, x):
        q = self.q(x) # B x T x E
        k = self.k(x)
        v = self.v(x)

        att = q @ k / jnp.sqrt(x.shape[-1]) # B x T x T 
        x = att @ v # B x T x E
        return x

k = jax.random.key(42)
rngs = nnx.Rngs(k)

x = pos_embeddings[:3, :]
attnblk = AttnBlock(embedding_length, rngs)
x = attnblk(x)






(9,)


ValueError: matmul input operand 0 must have ndim at least 1, but it has ndim 0

In [22]:
import torch

src = torch.tensor([[10, 11, 12, 13],
                    [20, 21, 22, 23]])
# shape = (2 rows, 4 columns)

idx = torch.tensor([[3, 1],   # for output row 0, cols → [src[0,3], src[0,1]]
                    [2, 0]])  # for output row 1, cols → [src[1,2], src[1,0]]
# shape = (2 rows, 2 “gathered” columns)

out = torch.gather(src, dim=1, index=idx)
print(out)

tensor([[13, 11],
        [22, 20]])
