In [None]:
#| default_exp components
%load_ext autoreload
%autoreload 2


# Transformer components to be assembled in a model

Work in progress

In [None]:
#\ export
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.nn import functional as F

# Pick the best available accelerator instead of assuming Apple-silicon MPS:
# a hard-coded 'mps' device makes every later `.to(device)` fail on machines
# without MPS support.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

## Embeddings

- Token embeddings, using torch embedding lookup
- Positional embedding, which can be fixed or learned.

In [None]:
#\ export

class TokenEmbeddings(nn.Module):
    """Trainable lookup table that maps token ids to dense vectors.

    Args:
        vocab_size: number of rows in the lookup table (one per token id).
        embedding_dim: size of the vector returned for each token id.
    """

    def __init__(self, vocab_size: int, embedding_dim: int) -> None:
        super().__init__()
        self.vocab_size, self.embedding_dim = vocab_size, embedding_dim
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)

    def forward(self, x):
        """Return the embedding vectors for the token ids in `x`."""
        token_vectors = self.embedding(x)
        return token_vectors

In [None]:
vocab_size, embedding_dim = 6, 50
emb = TokenEmbeddings(vocab_size, embedding_dim=embedding_dim).to(device)
x = torch.tensor([0, 1, 2, 1], dtype=torch.long, device=device)
emb_x = emb(x)
# one embedding vector per input token
assert emb_x.shape == (x.shape[0], embedding_dim)
# token id 1 appears at positions 1 and 3, so both rows must be identical
assert torch.equal(emb_x[1], emb_x[3])

### Positional Encoding

As each word in a sentence simultaneously flows through the Transformer’s encoder/decoder stack, the model itself doesn’t have any sense of position/order for each word. Consequently, there’s still the need for a way to incorporate the order of the words into our model.
So we give the model some sense of position of the token in the sequence. 

Either we give the position as an input to the model or the model learns it.

#### Non learned positional embeddings

##### Potential solutions:

The first idea that might come to mind is to assign a number to each time-step within the [0, 1] range in which 0 means the first word and 1 is the last time-step. One of the problems it will introduce is that you can’t figure out how many words are present within a specific range. In other words, time-step delta doesn’t have consistent meaning across different sentences.

Another idea is to assign a number to each time-step linearly. That is, the first word is given “1”, the second word is given “2”, and so on. The problem with this approach is that not only the values could get quite large, but also our model can face sentences longer than the ones in training.

Ideally, the following criteria should be satisfied:

- It should output a unique encoding for each time-step (word’s position in a sentence)
- Distance between any two time-steps should be consistent across sentences with different lengths.
- Our model should generalize to longer sentences without any effort. Its values should be bounded.
- It must be deterministic.

##### Proposed solutions:

The initial solution that was proposed isn’t a single number. Instead, it’s a d-dimensional vector that contains information about a specific position in a sentence. This vector, if not learned, is not integrated in the model.

$\begin{align}
  \vec{p_t}^{(i)} = f(t)^{(i)} & := 
  \begin{cases}
      \sin({\omega_k} . t),  & \text{if}\  i = 2k \\
      \cos({\omega_k} . t),  & \text{if}\  i = 2k + 1
  \end{cases}
\end{align} $

where $\omega_k = \frac{1}{n^{2k / d}}$

where : 
- L: sequence length
- t: position of token in input sequence
- d: dimension of embedding
- P(t,i): position function mapping a position t in the sequence to the entry (t,i) of the positional matrix
- n: user defined scalar (ex: 10'000)
- i: column index in the positional matrix
- k: floor division (partie entière) of i by 2, so i = 2k (i is even - pair) or i = 2k+1 (i is odd - impair), because we add a sin/cos pair on every two embedding slots

The frequency of the sinusoids decreases along the vector dimension (as i grows).

So we got a vector with pairs of sin/cosines for each frequency.

$\vec{p_t} = \begin{bmatrix} 
\sin({\omega_0}.t)\\ 
\cos({\omega_0}.t)\\ 
\\
\sin({\omega_1}.t)\\ 
\cos({\omega_1}.t)\\ 
\\
\vdots\\ 
\\
\sin({\omega_{d/2-1}}.t)\\ 
\cos({\omega_{d/2-1}}.t) 
\end{bmatrix}_{d \times 1}$

It's like the encoding of numbers in binary format: 

$
\begin{align}
  0: \ \ \ \ \color{orange}{\texttt{0}} \ \ \color{green}{\texttt{0}} \ \ \color{blue}{\texttt{0}} \ \ \color{red}{\texttt{0}} & & 
  8: \ \ \ \ \color{orange}{\texttt{1}} \ \ \color{green}{\texttt{0}} \ \ \color{blue}{\texttt{0}} \ \ \color{red}{\texttt{0}} \\
  1: \ \ \ \ \color{orange}{\texttt{0}} \ \ \color{green}{\texttt{0}} \ \ \color{blue}{\texttt{0}} \ \ \color{red}{\texttt{1}} & & 
  9: \ \ \ \ \color{orange}{\texttt{1}} \ \ \color{green}{\texttt{0}} \ \ \color{blue}{\texttt{0}} \ \ \color{red}{\texttt{1}} \\ 
  2: \ \ \ \ \color{orange}{\texttt{0}} \ \ \color{green}{\texttt{0}} \ \ \color{blue}{\texttt{1}} \ \ \color{red}{\texttt{0}} & & 
  10: \ \ \ \ \color{orange}{\texttt{1}} \ \ \color{green}{\texttt{0}} \ \ \color{blue}{\texttt{1}} \ \ \color{red}{\texttt{0}} \\ 
  3: \ \ \ \ \color{orange}{\texttt{0}} \ \ \color{green}{\texttt{0}} \ \ \color{blue}{\texttt{1}} \ \ \color{red}{\texttt{1}} & & 
  11: \ \ \ \ \color{orange}{\texttt{1}} \ \ \color{green}{\texttt{0}} \ \ \color{blue}{\texttt{1}} \ \ \color{red}{\texttt{1}} \\ 
  4: \ \ \ \ \color{orange}{\texttt{0}} \ \ \color{green}{\texttt{1}} \ \ \color{blue}{\texttt{0}} \ \ \color{red}{\texttt{0}} & & 
  12: \ \ \ \ \color{orange}{\texttt{1}} \ \ \color{green}{\texttt{1}} \ \ \color{blue}{\texttt{0}} \ \ \color{red}{\texttt{0}} \\
  5: \ \ \ \ \color{orange}{\texttt{0}} \ \ \color{green}{\texttt{1}} \ \ \color{blue}{\texttt{0}} \ \ \color{red}{\texttt{1}} & & 
  13: \ \ \ \ \color{orange}{\texttt{1}} \ \ \color{green}{\texttt{1}} \ \ \color{blue}{\texttt{0}} \ \ \color{red}{\texttt{1}} \\
  6: \ \ \ \ \color{orange}{\texttt{0}} \ \ \color{green}{\texttt{1}} \ \ \color{blue}{\texttt{1}} \ \ \color{red}{\texttt{0}} & & 
  14: \ \ \ \ \color{orange}{\texttt{1}} \ \ \color{green}{\texttt{1}} \ \ \color{blue}{\texttt{1}} \ \ \color{red}{\texttt{0}} \\
  7: \ \ \ \ \color{orange}{\texttt{0}} \ \ \color{green}{\texttt{1}} \ \ \color{blue}{\texttt{1}} \ \ \color{red}{\texttt{1}} & & 
  15: \ \ \ \ \color{orange}{\texttt{1}} \ \ \color{green}{\texttt{1}} \ \ \color{blue}{\texttt{1}} \ \ \color{red}{\texttt{1}} \\
\end{align}$

where the last bit alternates on every number, the previous one on every 2 numbers, and so on. So instead of using bits (which would be a waste of space), we use their continuous float counterpart: sinusoidal functions.


In [None]:
def positional_encoding(length, depth, n=10000):
  """Build a fixed sinusoidal positional-encoding table with numpy.

  The first half of the columns holds the sines and the second half the
  cosines (the two halves are concatenated, not interleaved).

  Args:
      length: number of positions (rows).
      depth: number of encoding dimensions (columns).
      n: base of the geometric frequency progression. Defaults to 10000.

  Returns:
      Array of shape (length, depth).
  """
  half_depth = depth / 2

  positions = np.arange(length)[:, np.newaxis]                # (length, 1)
  depths = np.arange(half_depth)[np.newaxis, :] / half_depth  # (1, ceil(depth/2))

  angle_rates = 1 / (n**depths)             # (1, ceil(depth/2))
  angle_rads = positions * angle_rates      # (length, ceil(depth/2))

  pos_encoding = np.concatenate(
      [np.sin(angle_rads), np.cos(angle_rads)],
      axis=-1)

  # For an odd depth the concatenation has depth + 1 columns; trim so the
  # result always has exactly `depth` columns. No-op when depth is even.
  return pos_encoding[:, :int(depth)]


pos_encoding = positional_encoding(length=50, depth=130)

# Check the shape.
print(pos_encoding.shape)

# Heatmap of the table: positions along x, encoding dimensions along y.
fig, ax = plt.subplots()
mesh = ax.pcolormesh(pos_encoding.T, cmap='RdBu')
ax.set_ylabel('Depth')
ax.set_xlabel('Position')
fig.colorbar(mesh, ax=ax)
plt.show()

In [None]:
#\ export

class PositionalEncoder(nn.Module):
    """Add positional information to a batch of embeddings.

    Two modes are supported:

    - ``is_learned=True``: positions are looked up in a trainable
      ``nn.Embedding`` table of shape (max_seq_len, embedding_dim).
    - ``is_learned=False``: a fixed sinusoidal table is computed once and
      stored in the buffer ``pos_encodings`` of shape
      (1, max_seq_len, embedding_dim).

    Args:
        max_seq_len: number of positions available in the table.
        embedding_dim: size of each position vector.
        is_learned: choose the trainable table over the sinusoidal one.
        n: base of the sinusoidal frequency progression. Defaults to 10000.
    """

    def __init__(self, max_seq_len: int, embedding_dim: int, is_learned: bool = True, n: int = 10000) -> None:
        super().__init__()
        self.max_seq_len = max_seq_len
        self.embedding_dim = embedding_dim
        self.is_learned = is_learned
        # Must be set before positional_encoding() is called below.
        self.n = n
        if self.is_learned:
            self.pos_embedding = nn.Embedding(max_seq_len, embedding_dim)
        else:
            # Leading dimension of size 1 lets the table broadcast over the batch.
            pos_encodings = self.positional_encoding().unsqueeze(dim=0)
            # A buffer follows the module across devices but is not trained.
            self.register_buffer('pos_encodings', pos_encodings)

    def get_angles(self, t: torch.Tensor, i: torch.Tensor, n: torch.Tensor = torch.tensor(10000)) -> torch.Tensor:
        '''Compute the angles to which sin/cos are applied.

        Args:
            t (Tensor): 1-D tensor of positions in the sequence.
            i (Tensor): 1-D tensor of positions in the embedding dimension.
            n (Tensor or number, optional): frequency base. Defaults to 10000.

        Returns:
            Tensor: matrix of shape (len(t), len(i)) holding t * 1 / n^(2k/d),
            with d = self.embedding_dim.
        '''
        # k = i // 2, so columns 2k (even) and 2k+1 (odd) share one frequency
        k = torch.div(i, 2, rounding_mode='floor')
        angle_rates = 1 / torch.pow(n, (2 * k) / self.embedding_dim)
        return t.unsqueeze(1) * angle_rates

    def positional_encoding(self) -> torch.Tensor:
        '''Return the sinusoidal table of shape (max_seq_len, embedding_dim).'''
        seq_positions = torch.arange(self.max_seq_len)
        embedding_positions = torch.arange(self.embedding_dim)
        angles = self.get_angles(seq_positions, embedding_positions, n=self.n)
        encodings = torch.empty_like(angles)
        # even columns take the sine, odd columns the cosine of their angle
        encodings[:, 0::2] = torch.sin(angles[:, 0::2])
        encodings[:, 1::2] = torch.cos(angles[:, 1::2])
        return encodings

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''Add the encodings of positions 0..x.size(1)-1 to `x`.

        `x` is indexed as (batch, seq_len, embedding_dim): dimension 1 gives the
        number of positions, and the position vectors broadcast over dimension 0.
        '''
        seq_len = x.size(1)
        if self.is_learned:
            positions = torch.arange(seq_len, device=x.device)
            return x + self.pos_embedding(positions)
        return x + self.pos_encodings[:, :seq_len]
                        
        

seq_len = 50
embedding_dim = 130
pos_enc = PositionalEncoder(max_seq_len=seq_len, embedding_dim=embedding_dim, is_learned=False)
# Feeding zeros returns the positional table itself (with a leading batch
# dimension). The module is called directly rather than through `.forward`
# so that nn.Module's call machinery (hooks) is not bypassed.
x = torch.zeros(1, seq_len, embedding_dim)
pos_encoding = pos_enc(x)
assert pos_encoding.shape == (1, seq_len, embedding_dim)


In [None]:
fig, ax = plt.subplots(dpi=100)
# rows of the table are positions and columns are embedding dimensions,
# so depth runs along x and position along y
mesh = ax.pcolormesh(pos_encoding[0].numpy(), cmap='RdBu')
ax.set_xlabel('Depth')
ax.set_ylabel('Position')
fig.colorbar(mesh, ax=ax)

plt.show()

In [None]:
# Unit-normalise every position vector over the embedding axis. The result
# goes into a new tensor so re-running this cell never rescales `pos_encoding`.
pos_vectors = pos_encoding[0]
pos_vectors_unit = pos_vectors / torch.linalg.norm(pos_vectors, dim=-1, keepdim=True)

# Derive the reference position and the zoom window from the number of
# positions actually available, instead of assuming a table of 1000+ rows.
n_positions = pos_vectors_unit.shape[0]
ref_pos = n_positions // 2
half_window = max(1, n_positions // 10)
zoom_lo, zoom_hi = ref_pos - half_window, ref_pos + half_window

# Dot product of the reference position with every position
p = pos_vectors_unit[ref_pos]
dots = torch.einsum('pd,d -> p', pos_vectors_unit, p)
plt.subplot(2,1,1)
plt.plot(dots)
plt.ylim([0,1])
# the NaN entry breaks the line, so two separate vertical markers are drawn
plt.plot([zoom_lo, zoom_lo, float('nan'), zoom_hi, zoom_hi],
         [0,1,float('nan'),0,1], color='k', label='Zoom')
plt.legend()
plt.subplot(2,1,2)
plt.plot(dots)
plt.xlim([zoom_lo, zoom_hi])
plt.ylim([0,1])

In [None]:

def get_angles(t: torch.Tensor, i:torch.Tensor, embedding_dim: torch.Tensor, n: torch.Tensor=torch.tensor(10000)):
    '''Compute the angles to which sin/cos are applied.

    Args:
        t (Tensor): positions in the sequence.
        i (Tensor): positions in the embedding dimension.
        embedding_dim (Tensor): embedding dimension of the model.
        n (Tensor, optional): frequency base. Defaults to 10000.

    Returns:
        Tensor: `t.unsqueeze(1)` multiplied by 1 / n^(2k/embedding_dim), where
        k = i // 2; for 1-D `t` and `i` this is a (len(t), len(i)) matrix.
    '''
    # k = i // 2, so columns 2k (even) and 2k+1 (odd) share one frequency
    pair_index = torch.div(i, 2, rounding_mode='floor')
    exponent = (2 * pair_index) / embedding_dim
    inverse_frequency = 1 / torch.pow(n, exponent)
    return t.unsqueeze(1) * inverse_frequency

# Sine of the angle at sequence position 0, embedding column 0
origin_angle = get_angles(
    t=torch.tensor([0]),
    i=torch.tensor([0]),
    embedding_dim=torch.tensor(embedding_dim),
)
torch.sin(origin_angle)


In [None]:
# Sequence positions and embedding column indices as float tensors
t = torch.arange(seq_len, dtype=torch.float32)
i = torch.arange(embedding_dim, dtype=torch.float32)
angles = get_angles(t, i, embedding_dim=torch.tensor(embedding_dim))
# `angles` is a seq_len x embedding_dim matrix


In [None]:
# Display the inputs that were passed to get_angles
i, embedding_dim, t

In [None]:
#| hide
import nbdev
nbdev.nbdev_export()