# Positional Embeddings

This is a small visualization of time embeddings.

Good resources on why time embeddings are added instead of concatenated:
* https://ai.stackexchange.com/questions/35990/why-are-embeddings-added-not-concatenated
* https://mathoverflow.net/questions/248466/why-are-two-random-vectors-in-mathbb-rn-approximately-orthogonal-for-large
* https://www.reddit.com/r/MachineLearning/comments/cttefo/comment/exs7d08/?context=3
* https://medium.com/@waelrashwan/demystifying-transformer-architecture-the-magic-of-positional-encoding-5fe8154d4a64
* Desmos: https://www.desmos.com/calculator/nvpbogxcnd?lang=de


In [None]:
import torch
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np

def get_time_embedding(time_steps, temb_dim):
    r"""
    Build sinusoidal embeddings for a batch of time steps.

    Every time step is divided by ``temb_dim // 2`` scaling factors
    ``10000^(i / (temb_dim // 2))`` and the result is fed through sine and
    cosine. The sine half occupies the first ``temb_dim // 2`` columns, the
    cosine half the remaining ones.
    :param time_steps: 1D tensor of length batch size
    :param temb_dim: Dimension of the embedding (must be even)
    :return: BxD embedding representation of B time steps
    """
    assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"

    half_dim = temb_dim // 2

    # exponent_i = i / half_dim, which equals 2i / temb_dim
    exponents = torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=time_steps.device
    ) / half_dim
    denominators = 10000 ** exponents

    # time_steps: (B,) -> (B, 1) -> (B, half_dim), then scaled column-wise
    tiled_steps = time_steps[:, None].repeat(1, half_dim)
    angles = tiled_steps / denominators

    # Concatenate (not interleave): [sin | cos] -> (B, temb_dim)
    return torch.cat((torch.sin(angles), torch.cos(angles)), dim=-1)


# Static overview: a heatmap of all embeddings (one row per time step) on top,
# and a handful of individual embedding rows as line plots below.
temb_dim = 128
all_time_steps_considered = torch.arange(0, 500)
t_emb = get_time_embedding(all_time_steps_considered, temb_dim)

x = np.arange(t_emb.shape[1])
cmap = 'magma'

fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(20, 10), gridspec_kw={'height_ratios': [3, 1], 'hspace': 0.3})

# Plot heatmap.
# No explicit `extent` is passed: with the default origin ('upper') an extent
# of [.., 0, n_rows] keeps row 0 at the top but labels it n_rows, i.e. the
# vertical axis would be inverted relative to the data. The default extent is
# pixel-centred, so tick value i on the y-axis is row i and tick value j on the
# x-axis is column j.
cax = ax1.imshow(t_emb.cpu().numpy(), aspect='auto', cmap=cmap)
ax1.set_xlabel('x-axis')
ax1.set_ylabel('Embeddings')
ax1.set_title('Embeddings Visualization')

# Plot each line with a different color from the colormap.
# `matplotlib.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9;
# `plt.get_cmap` accepts the same (name, lut) arguments. A separate name is
# used so `cmap` keeps holding the colormap name.
line_cmap = plt.get_cmap(cmap, t_emb.shape[0])
plot_every = 50
for i in range(0, t_emb.shape[0], plot_every):
    # Integer argument -> direct lookup of entry i in the resampled colormap
    ax2.plot(x, t_emb[i].cpu().numpy(), color=line_cmap(i))
ax2.set_xlabel('x-axis')
ax2.set_ylabel(f'Embedding values of every {plot_every}th value')
ax2.set_title('Embeddings Visualization')

# Adjust layout to make room for colorbar
fig.subplots_adjust(right=0.85)
cbar_ax = fig.add_axes([0.88, 0.15, 0.02, 0.7])
fig.colorbar(cax, cax=cbar_ax, label='Embedding values')

plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
from tqdm.notebook import tqdm

# Enable inline plotting in Jupyter notebook
%matplotlib inline

def get_time_embedding(time_steps, temb_dim):
    r"""
    Build sinusoidal embeddings for a batch of time steps.

    Every time step is divided by ``temb_dim // 2`` scaling factors
    ``10000^(i / (temb_dim // 2))`` and the result is fed through sine and
    cosine. The sine half occupies the first ``temb_dim // 2`` columns, the
    cosine half the remaining ones.
    :param time_steps: 1D tensor of length batch size
    :param temb_dim: Dimension of the embedding (must be even)
    :return: BxD embedding representation of B time steps
    """
    assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"

    half_dim = temb_dim // 2

    # exponent_i = i / half_dim, which equals 2i / temb_dim
    exponents = torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=time_steps.device
    ) / half_dim
    denominators = 10000 ** exponents

    # time_steps: (B,) -> (B, 1) -> (B, half_dim), then scaled column-wise
    tiled_steps = time_steps[:, None].repeat(1, half_dim)
    angles = tiled_steps / denominators

    # Concatenate (not interleave): [sin | cos] -> (B, temb_dim)
    return torch.cat((torch.sin(angles), torch.cos(angles)), dim=-1)

# Animated view: a heatmap of all embeddings with a marker line on the current
# time step, and below it the embedding vector of that time step.

# Parameters
temb_dim = 64
time_steps = torch.arange(0, 100)
t_emb = get_time_embedding(time_steps, temb_dim)

x = np.arange(t_emb.shape[1])
cmap = 'magma'
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(20, 10), gridspec_kw={'height_ratios': [3, 1]})

# heatmap of embedding
# No explicit `extent` is passed: with the default origin ('upper') an extent
# of [.., 0, n_rows] keeps row 0 at the top but labels it n_rows, so a marker
# drawn at y=frame would sit on the wrong row. With the default pixel-centred
# extent, y=frame runs through the middle of row `frame`.
cax = ax1.imshow(t_emb.cpu().numpy(), aspect='auto', cmap=cmap)
cbar = fig.colorbar(cax, ax=[ax1, ax2], location='right', pad=0.05)
cbar.set_label('Embedding values')
ax1.set_xlabel('Embedding dimensions')
ax1.set_ylabel('Timesteps')
ax1.set_title('Embeddings Visualization')

# Marker for the current time step (moved by `update`)
line = ax1.axhline(0, color='white', linewidth=2)
plot_line, = ax2.plot([], [], lw=2)
# Share the heatmap's x-range so dimension j lines up in both panels
ax2.set_xlim(ax1.get_xlim())
ax2.set_ylim(torch.min(t_emb).item(), torch.max(t_emb).item())
ax2.set_xlabel('Embedding dimensions')
ax2.set_ylabel('Embedding value')
ax2.set_title('Embedding Vector for Current Timestep')

progress_bar = tqdm(total=len(time_steps), desc="Generating Frame for Timestep")

def update(frame):
    """Draw animation frame ``frame``: move the marker and redraw the vector.

    :param frame: index of the time step (row of ``t_emb``) to display
    :return: the artists that changed, as required for ``blit=True``
    """
    # Update horizontal line position on heatmap
    line.set_ydata([frame, frame])
    plot_line.set_data(x, t_emb[frame].cpu().numpy())
    # Set the count from the frame index instead of incrementing it: the
    # animation may call this function more than once for the same frame
    # (e.g. for an initial draw), which would push the bar past its total.
    progress_bar.n = int(frame) + 1
    progress_bar.set_description(f"Generating Frame for Timestep [{frame} / {len(time_steps)}]")

    return line, plot_line

# Animation (takes a while)
ani = FuncAnimation(fig, update, frames=np.arange(0, len(time_steps)), blit=True, interval=100, repeat=False)
# Close the figure so only the animation (not a static copy) is displayed
plt.close(fig)
animation_html = ani.to_jshtml()
progress_bar.close()
# Must stay the last expression so the notebook renders it
HTML(animation_html)
