# Positional encoding using Sine and Cosine

Faisal Qureshi     
faisal.qureshi@ontariotechu.ca

In [None]:
import matplotlib.pyplot as plt
import numpy as np

## Plotting sine waves of different frequencies

A general sine function can be written as

$$
f(x) = A \sin \left( \frac{2 \pi}{B} x + C \right) + D
$$

Here $A$ is the amplitude, $B$ is the period, $C$ is the phase shift (in radians), and $D$ is the vertical shift of this function. The frequency is $1/B$, i.e., the number of complete cycles per unit of $x$.
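As a quick numerical check of these definitions, the sketch below evaluates $f$ for arbitrarily chosen parameter values (the values of `A`, `B`, `C`, `D` and the helper name `general_sine` are illustrative, not from the notebook). It confirms that shifting $x$ by one period $B$ leaves the function unchanged and that the function ranges from $D - A$ to $D + A$.

```python
import numpy as np

def general_sine(x, A=1.0, B=1.0, C=0.0, D=0.0):
    """f(x) = A sin(2*pi*x/B + C) + D: amplitude A, period B, phase C, vertical shift D."""
    return A * np.sin(2 * np.pi * x / B + C) + D

# Illustrative parameter values (chosen for this example only)
A, B, C, D = 2.0, 8.0, 0.5, 1.0
x = np.linspace(0, 3 * B, 10_001)
y = general_sine(x, A, B, C, D)

# Shifting x by one period B reproduces the same values
periodic = np.allclose(general_sine(x + B, A, B, C, D), y)

# Over several periods the function ranges from D - A to D + A
lo, hi = y.min(), y.max()

print("frequency 1/B =", 1 / B)
print("periodic:", periodic)
print("range:", round(lo, 3), "to", round(hi, 3))
```

Changing `B` changes how quickly the wave repeats, which is the knob the rest of this notebook turns when encoding positions.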

In [None]:
period = 128
x = np.linspace(0, period, 1000)    # one full period, sampled densely
y = np.sin((2*np.pi/period) * x)    # A = 1, C = 0, D = 0
plt.figure(figsize=(5,5))
plt.plot(x, y)
plt.title(f'Sine wave of period {period}')

## Positional encoding

In [None]:
p = np.arange(8)
print(p)

### Using binary vectors

Let's consider the problem of encoding positions as binary vectors.  To represent 8 positions, we need three bits.  

0 = 000  
1 = 001  
2 = 010  
3 = 011  
4 = 100  
5 = 101  
6 = 110  
7 = 111

Notice something else as well. The leftmost (most significant) bit flips from 0 to 1 once, the middle bit flips from 0 to 1 twice, and the rightmost (least significant) bit flips from 0 to 1 four times. In other words, each bit oscillates at twice the frequency of the bit to its left. One problem with this encoding scheme is that it is discrete. Another is that it isn't immediately obvious how to compute distances between two positions. We will use sine functions to create a continuous representation. An added benefit of using sine functions is that we do not need to restrict ourselves to 0s and 1s: sine functions vary smoothly between -1 and 1.
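The bit-flip pattern described above can be checked with a few lines of NumPy. This sketch builds the 3-bit codes for positions 0 to 7 (most significant bit in the first column) and counts the 0→1 transitions in each column.

```python
import numpy as np

positions = np.arange(8)
num_bits = 3

# Column k holds one bit; column 0 is the most significant (leftmost) bit
shifts = np.arange(num_bits - 1, -1, -1)     # [2, 1, 0]
bits = (positions[:, None] >> shifts) & 1    # shape (8, 3)
print(bits)

# Count 0 -> 1 transitions in each column while stepping through positions 0..7
rises = np.sum((bits[:-1] == 0) & (bits[1:] == 1), axis=0)
print(rises)  # [1 2 4]
```

The counts double from one bit to the next, which is the discrete analogue of using waves whose frequencies double.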

Below, let's assume that we use only two sine functions to encode the 8 positions shown above. In one period a sine function goes from 0 up to 1, back to 0, down to -1, and back to 0. In particular, it returns to 0 after half a period, so position 0 and the position half a period later produce the same value. This suggests that at least one sine function should have a period that is twice the number of positions that we intend to store. So, if we intend to store 8 positions, we need at least one sine wave of period 16.

### Using sine functions

In [None]:
# When using sine functions of periods 8 and 4

period1 = 8
period2 = 4

def f(x, period):
    return np.sin(2*np.pi*x/period)

x = np.linspace(0, 16, 1000)

plt.plot(x, f(x, period1))
plt.plot(x, f(x, period2))

You will notice that both curves pass through 0 together at $x = 4$, again at $x = 8$, and so on. This means that positions 0, 4, 8, ... are all encoded as $(0, 0)$ and will look the same.
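To see the collision directly rather than reading it off the plot, the sketch below samples the two sine functions (periods 8 and 4) at integer positions and compares the resulting 2-D encodings. The helper name `sine_encoding` is introduced here for illustration only.

```python
import numpy as np

def sine_encoding(positions, periods):
    """Hypothetical helper: one column per period, sin(2*pi*position/period)."""
    positions = np.asarray(positions, dtype=float).reshape(-1, 1)
    periods = np.asarray(periods, dtype=float).reshape(1, -1)
    return np.sin(2 * np.pi * positions / periods)

enc = sine_encoding(np.arange(9), [8, 4])

# Positions 0, 4 and 8 all map to (numerically) (0, 0)
same_0_4 = np.allclose(enc[0], enc[4])
same_0_8 = np.allclose(enc[0], enc[8])
# Position 2 maps elsewhere: (sin(pi/2), sin(pi)) = (1, 0)
same_0_2 = np.allclose(enc[0], enc[2])
print(same_0_4, same_0_8, same_0_2)  # True True False
```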

In [None]:
# When using sine functions of periods 16 and 8

period1 = 16
period2 = 8

def f(x, period):
    return np.sin(2*np.pi*x/period)

x = np.linspace(0, 16, 1000)

plt.plot(x, f(x, period1))
plt.plot(x, f(x, period2))

### Frequency/period considerations

Now both curves pass through 0 together only at $x = 8$, $x = 16$, and so on. This means that the encodings of positions 0, 8, 16, ... will be the same.

So, when you use sine functions to encode positions, at least one sine function should have a period that is at least twice the number of positions you plan to encode, i.e., if you plan to encode $n$ positions, at least one sine function should have a period of $2n$ or longer.
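The rule can also be tested numerically. The sketch below encodes positions with sines whose periods halve successively ($P, P/2, \ldots$), as in the next cell, and checks whether any two positions receive identical encodings. The helper names `sine_encoding` and `has_duplicates` and the tolerance are illustrative assumptions.

```python
import numpy as np

def sine_encoding(positions, longest_period, num_sines):
    """Hypothetical helper: sines with periods P, P/2, P/4, ... evaluated at each position."""
    periods = longest_period / 2.0 ** np.arange(num_sines)
    pos = np.asarray(positions, dtype=float)[:, None]
    return np.sin(2 * np.pi * pos / periods[None, :])

def has_duplicates(enc, tol=1e-9):
    """True if any two rows (positions) of enc are numerically identical."""
    diffs = enc[:, None, :] - enc[None, :, :]    # all pairwise differences
    dists = np.linalg.norm(diffs, axis=-1)
    np.fill_diagonal(dists, np.inf)              # ignore each row compared with itself
    return bool(dists.min() < tol)

n = 8
positions = np.arange(n)
print(has_duplicates(sine_encoding(positions, 2 * n, 2)))  # longest period 16: False
print(has_duplicates(sine_encoding(positions, n, 2)))      # longest period 8: True (0 and 4 collide)
```

With a longest period of $2n$, the first sine alone cannot tell position $k$ from $n - k$, but the second sine takes opposite signs at those two positions, so the pair separates them.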

In [None]:
# Encode n integer positions using num_sines sine functions whose periods halve
# successively (period1, period1/2, ...). The encoding repeats every period1 positions.

n = 1280
period1 = 128 # Set this to 8 (or 16) to see position 4 (or 8) mapped to the same point as position 0, confirming our intuition above.
num_sines = 2

periods = period1/(2**np.arange(num_sines))
print(periods[::-1])

freqs = 2*np.pi/periods    # angular frequencies
print(freqs[::-1])

pos = np.arange(n)
#print(pos)

enc = np.sin(pos.reshape(-1,1)*freqs.reshape(1,-1))    # shape (n, num_sines)
#print(enc)

# Color cycle for the scatter plot (adapted from Stack Overflow)
phi = np.linspace(0, 2*np.pi, n)
rgb_cycle = np.vstack((               # Three sinusoids
    .5*(1.+np.cos(phi          )),    # scaled to [0,1]
    .5*(1.+np.cos(phi+2*np.pi/3)),    # 120° phase shifted.
    .5*(1.+np.cos(phi-2*np.pi/3)))).T # Shape = (n, 3)

fig, ax = plt.subplots(figsize=(5,5))
ax.plot(enc[:,0], enc[:,1])
ax.scatter(enc[:,0], enc[:,1], c=rgb_cycle, s=64)

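Interestingly, this suggests that we need a sine function of some "minimum" frequency (equivalently, a long enough period) to encode positions without collisions.  

Given this minimum frequency, the frequencies of the other sine functions increase *monotonically* from it. A common choice spaces them geometrically, which is often written as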

$$
M_{ij} = \sin\left(x_i \omega_0^{j / d_{\mathrm{model}}} \right)
$$

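where $x_i$ is the $i$-th position, $j$ indexes the sine functions (the dimensions of the encoding), $d_{\mathrm{model}}$ is the number of dimensions, and $\omega_0 < 1$ sets the smallest frequency (corresponding to the longest period). Because $\omega_0 < 1$, the frequency $\omega_0^{j / d_{\mathrm{model}}}$ equals 1 for $j = 0$ and decreases toward $\omega_0$ as $j$ increases.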
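To make the formula concrete, the sketch below evaluates the frequencies $\omega_0^{j/d_{\mathrm{model}}}$, their periods $2\pi/\omega$, and the matrix $M$ for a few positions. The setting $\omega_0 = 1/10000$, $d_{\mathrm{model}} = 8$ is an illustrative assumption for this example.

```python
import numpy as np

omega_0 = 1 / 10000   # illustrative smallest-frequency setting
d_model = 8           # illustrative number of sine functions
positions = np.arange(5)

j = np.arange(d_model)
freqs = omega_0 ** (j / d_model)    # angular frequency of sine j
periods = 2 * np.pi / freqs         # corresponding periods

# M[i, j] = sin(x_i * omega_0**(j / d_model))
M = np.sin(positions[:, None] * freqs[None, :])

print(np.round(freqs, 5))
print(np.round(periods, 1))
print(M.shape)  # (5, 8)
```

The first sine ($j = 0$) has frequency 1 and period $2\pi$; each subsequent one is slower by the constant factor $\omega_0^{1/d_{\mathrm{model}}}$, so the periods grow geometrically toward $2\pi/\omega_0$.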

In [None]:
# Encode positions 0 to n-1 (here n = 64) using num_sines sine functions whose
# angular frequencies are omega_0**((j+1)/num_sines), j = 0, ..., num_sines-1

n = 64
omega_0 = 1/64
num_sines = 2

freqs = (omega_0 ** ((np.arange(num_sines)+1)/num_sines))
print(freqs)

pos = np.arange(n)
#print(pos)

# Column 0 uses the lowest frequency (omega_0), column 1 the next one up
enc = np.sin(pos.reshape(-1,1)*freqs[::-1].reshape(1,-1))
#print(enc)

# Color cycle for the scatter plot (adapted from Stack Overflow)
phi = np.linspace(0, 2*np.pi, n)
rgb_cycle = np.vstack((               # Three sinusoids
    .5*(1.+np.cos(phi          )),    # scaled to [0,1]
    .5*(1.+np.cos(phi+2*np.pi/3)),    # 120° phase shifted.
    .5*(1.+np.cos(phi-2*np.pi/3)))).T # Shape = (n, 3)

fig, ax = plt.subplots(figsize=(5,5))
ax.plot(enc[:,0], enc[:,1])
ax.scatter(enc[:,0], enc[:,1], c=rgb_cycle, s=64)

### Dealing with translations: using both sine and cosine functions

One problem with this encoding is that it is not clear how to express a translation (a shift in position) in terms of the encoding itself. Ideally, we would like something like the following:

$$
PE(x + \Delta x) = T(\Delta x)PE(x) 
$$

i.e., we want to express a translation by $\Delta x$ as a linear transformation (a matrix multiplication) whose matrix $T(\Delta x)$ depends only on $\Delta x$, not on $x$. Since sines and cosines operate on angles, we can use the following angle-addition identities to construct a positional encoding in which translation is a linear transformation.

$$
\begin{aligned}
\cos(\theta+\phi) &= \cos(\theta)\cos(\phi) - \sin(\theta)\sin(\phi) \\
\sin(\theta+\phi) &= \cos(\theta)\sin(\phi) + \sin(\theta)\cos(\phi)
\end{aligned}
$$

In matrix form, this reads

$$
\left[ \begin{array}{c} 
\cos(\theta+\phi)\\
\sin(\theta+\phi)
\end{array} \right]
=
\left[ \begin{array}{cc} 
\cos(\phi) & -\sin(\phi)\\
\sin(\phi) & \cos(\phi)\\
\end{array} \right]
\left[ \begin{array}{c} 
\cos(\theta)\\
\sin(\theta)
\end{array} \right]
$$
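The matrix identity above can be verified numerically. The sketch below (angle values chosen only for illustration; the helper name `rotation` is hypothetical) rotates $[\cos\theta, \sin\theta]^T$ by the matrix built from $\phi$ and compares the result with $[\cos(\theta+\phi), \sin(\theta+\phi)]^T$.

```python
import numpy as np

def rotation(phi):
    """2x2 matrix from the identity above: maps [cos(t), sin(t)] to [cos(t+phi), sin(t+phi)]."""
    return np.array([[np.cos(phi), -np.sin(phi)],
                     [np.sin(phi),  np.cos(phi)]])

theta, phi = 0.7, 1.9    # illustrative angles (radians)
v = np.array([np.cos(theta), np.sin(theta)])
rotated = rotation(phi) @ v
expected = np.array([np.cos(theta + phi), np.sin(theta + phi)])

print(np.allclose(rotated, expected))  # True
```

Because the matrix depends only on $\phi$, the same matrix shifts every angle $\theta$ by $\phi$; this is what makes the translation $T(\Delta x)$ independent of $x$.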

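We can therefore pair a sine and a cosine of the same frequency in our positional encoding: shifting the position by $\Delta x$ then rotates each (sine, cosine) pair by an angle that depends only on $\Delta x$ and that pair's frequency.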

In [None]:
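n = 1024           # number of positions
omega_0 = 1/10000  # sets the lowest frequency (longest period)
d_model = 64       # encoding dimension

# Dimensions 2i and 2i+1 share the exponent 2i/d_model
powers = 2*(np.arange(d_model)//2)/d_model
print(powers)

print(omega_0)
freqs = omega_0**powers

# Angle for every (position, dimension) pair; sine in even columns, cosine in odd columns
enc = np.arange(n).reshape(-1,1)*freqs.reshape(1,-1)
enc[:,0::2] = np.sin(enc[:,0::2])
enc[:,1::2] = np.cos(enc[:,1::2])

plt.figure(figsize=(1,50))
plt.imshow(enc)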
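With sines in the even columns and cosines in the odd columns, each (sine, cosine) pair rotates when the position shifts, so $PE(x + \Delta x) = T(\Delta x)\,PE(x)$ with a block-diagonal $T$. The sketch below checks this numerically for the layout used in the cell above. The helper names `sinusoidal_encoding` and `translation_matrix`, and the chosen position and shift, are illustrative assumptions. Because each pair here is ordered (sine, cosine), each 2×2 block is the transpose of the rotation matrix written earlier for (cosine, sine).

```python
import numpy as np

def sinusoidal_encoding(positions, d_model, omega_0=1/10000):
    """Same layout as the cell above: sin in even columns, cos in odd columns."""
    powers = 2 * (np.arange(d_model) // 2) / d_model
    angles = np.asarray(positions, dtype=float)[:, None] * (omega_0 ** powers)[None, :]
    enc = angles.copy()
    enc[:, 0::2] = np.sin(angles[:, 0::2])
    enc[:, 1::2] = np.cos(angles[:, 1::2])
    return enc

def translation_matrix(delta, d_model, omega_0=1/10000):
    """Block-diagonal T(delta): rotates each (sin, cos) pair by delta times its frequency."""
    T = np.zeros((d_model, d_model))
    for k in range(0, d_model, 2):
        a = delta * omega_0 ** (k / d_model)    # angle shift for this pair
        # [sin(t+a), cos(t+a)] = [[cos a, sin a], [-sin a, cos a]] @ [sin t, cos t]
        T[k:k+2, k:k+2] = [[np.cos(a), np.sin(a)],
                           [-np.sin(a), np.cos(a)]]
    return T

d_model = 64
x, delta = 37, 12    # illustrative position and shift
pe_x = sinusoidal_encoding([x], d_model)[0]
pe_shifted = sinusoidal_encoding([x + delta], d_model)[0]

T = translation_matrix(delta, d_model)
print(np.allclose(T @ pe_x, pe_shifted))  # True
```

The same matrix `T` works for every starting position, which is exactly the property the sine/cosine pairing was designed to provide.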

### Putting it all together

In [None]:
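import positional_encoding as pe   # local helper module (not included in this notebook)
from importlib import reload
reload(pe)                         # pick up edits to the module without restarting the kernel

n = 64
pos = np.arange(n)
d_model = 64
enc = pe.positional_encoding(pos, d_model)

plt.figure(figsize=(10,4))
plt.subplot(121)
plt.imshow(enc)        # full encoding matrix
plt.subplot(122)
plt.plot(enc[49,:])    # row 49, i.e. the encoding of position 49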
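The `positional_encoding` module imported above is not part of this notebook. For readers who want to run the cell, the sketch below is one possible stand-in, assuming the function takes an array of positions and `d_model` and returns an array of shape `(len(pos), d_model)` with the sine/cosine layout of the earlier cell. The signature, the `omega_0` default of 1/10000, and the body are assumptions, not the author's module.

```python
import numpy as np

def positional_encoding(pos, d_model, omega_0=1/10000):
    """Hypothetical stand-in for pe.positional_encoding (assumed signature).

    Returns an array of shape (len(pos), d_model): column 2i holds
    sin(pos * omega_0**(2i/d_model)) and column 2i+1 the matching cosine.
    """
    pos = np.asarray(pos, dtype=float).reshape(-1, 1)
    powers = 2 * (np.arange(d_model) // 2) / d_model    # 0, 0, 2/d, 2/d, ...
    angles = pos * (omega_0 ** powers).reshape(1, -1)
    enc = np.empty_like(angles)
    enc[:, 0::2] = np.sin(angles[:, 0::2])
    enc[:, 1::2] = np.cos(angles[:, 1::2])
    return enc

enc = positional_encoding(np.arange(64), 64)
print(enc.shape)  # (64, 64)
```

Saving this function in a file named `positional_encoding.py` next to the notebook would let the cell above import it as `pe`.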