# Visualization of functions

In [52]:
WEIGHT = 10
D_MODEL = 8

In [53]:
WEIGHT**(1/D_MODEL)

1.333521432163324

## postition encoding

In [15]:
import numpy as np
import torch
import torch.nn as nn
import math
import matplotlib.pyplot as plt
%matplotlib inline

In [63]:
def position_encoding(t):
    result = []
    
    for i in range(D_MODEL//2):
        we = 1/(WEIGHT**(2*i/D_MODEL))
        result.append(math.sin(we*t))
        result.append(math.cos(we*t))

    return result

In [64]:
position_encoding(0)

[0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0]

In [66]:
np.meshgrid(3, D_MODEL//2)

[array([[3]]), array([[4]])]

In [37]:
# pe[:, 0::2]
pe = torch.zeros([1,1,4,5])
pe

tensor([[[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]]])

In [38]:
pe[..., 0::2] = torch.ones(pe[..., 0::2].shape)
print(pe)

tensor([[[[1., 0., 1., 0., 1.],
          [1., 0., 1., 0., 1.],
          [1., 0., 1., 0., 1.],
          [1., 0., 1., 0., 1.]]]])


In [39]:
pe = torch.zeros([1,1,4,5])
pe[..., 1::2] = torch.ones(pe[..., 1::2].shape)
print(pe)

tensor([[[[0., 1., 0., 1., 0.],
          [0., 1., 0., 1., 0.],
          [0., 1., 0., 1., 0.],
          [0., 1., 0., 1., 0.]]]])


In [65]:
class PositionEmbeddingSine(nn.Module):
    def __init__(self,d_model,max_rows = 5000,weight = 10000.0) -> None:
        super().__init__()
        position = torch.zeros(max_rows,d_model)
        pe = torch.arange(0, max_rows, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0,d_model,2).float()*(-math.log(weight)/d_model))
        position[..., 0::2] = torch.sin(pe * div_term)
        position[..., 1::2] = torch.cos(pe * div_term)
        position = position.unsqueeze(0)
        # .transpose(0, 1)
        self.register_buffer('position', position)
    
    def forward(self,x):
        print(x.shape)
        x = x + self.position[:,:x.size(1),:]
        return x

In [67]:
# batch,rows,d_model
model = PositionEmbeddingSine(6)
x = torch.ones([9,3,6])
print(model(x).shape)

torch.Size([5000, 6])
torch.Size([5000, 1])
torch.Size([3])
tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00],
        [1.0000e+00, 4.6416e-02, 2.1544e-03],
        [2.0000e+00, 9.2832e-02, 4.3089e-03],
        ...,
        [4.9970e+03, 2.3194e+02, 1.0766e+01],
        [4.9980e+03, 2.3199e+02, 1.0768e+01],
        [4.9990e+03, 2.3203e+02, 1.0770e+01]])
torch.Size([9, 3, 6])
torch.Size([9, 3, 6])
