In [2]:
import numpy as np
import torch
from torch import Tensor
import math

# Toy inputs: one row per sequence position, four features per row.
x_i = np.array([
    [ 0.12, -0.87,  0.33,  0.45],
    [ 0.76,  0.21, -0.34,  0.67],
    [-0.55,  0.18,  0.29, -0.73],
    [ 0.03, -0.99,  0.42,  0.11],
    [ 0.76,  0.21, -0.34,  0.67],
    [-0.31,  0.66, -0.74,  0.09],
    [-0.92,  0.37,  0.28, -0.50],
])

# Encoder hidden states: the first four rows carry values, the rest are zero.
h_e = np.zeros_like(x_i)
h_e[:4] = [
    [ 0.64, -0.27,  0.89, -0.12],
    [-0.45,  0.33,  0.71,  0.08],
    [ 0.19, -0.94,  0.56,  0.37],
    [ 0.03,  0.85, -0.41,  0.76],
]

# Decoder hidden states: only rows 4 and 5 carry values, every other row is zero.
h_d = np.zeros_like(x_i)
h_d[4:6] = [
    [0.58, -0.13,  0.94, 0.22],
    [0.45,  0.11, -0.88, 0.67],
]

# Wrap each NumPy array as a torch tensor.
x_i_tensor, h_d_tensor, h_e_tensor = map(torch.from_numpy, (x_i, h_d, h_e))

In [3]:
def score(h_d: Tensor, h_e: Tensor):
    '''
    Calculate the scaled dot-product score of two vectors.

    Args:
        h_d: 1-D tensor (``torch.dot`` only accepts 1-D inputs).
        h_e: 1-D tensor with the same length and dtype as ``h_d``.

    Returns:
        A 0-dim tensor holding ``dot(h_d, h_e) / sqrt(len(h_d))`` in the
        dtype of the inputs.
    '''
    dim = h_d.size(dim=0)
    # Divide by a plain Python float so the result stays a scalar tensor in the
    # input dtype. Dividing by a one-element float32 tensor would instead
    # produce a shape-(1,) float32 result.
    return torch.dot(h_d, h_e) / math.sqrt(dim)

def calculate_attn(x_i_tensor: Tensor, h_d_tensor: Tensor, h_e_tensor: Tensor):
    '''
    Calculate the attention matrix.

    Before the softmax, entry (i, j) holds score(h_d_tensor[i - 1], h_e_tensor[j])
    for every i >= 1. Row 0 is never scored and stays at zero, so its softmax
    is uniform. x_i_tensor is only read for the size of its first dimension.

    Returns:
        An (n, n) tensor, n = x_i_tensor.size(0), with softmax applied along
        each row.
    '''
    seq_len = x_i_tensor.size(dim=0)
    raw_scores = torch.zeros((seq_len, seq_len))

    # Row `row` pairs h_d_tensor[row - 1] with every row of h_e_tensor.
    for row in range(1, seq_len):
        for col in range(seq_len):
            raw_scores[row, col] = score(h_d_tensor[row - 1], h_e_tensor[col])

    return torch.softmax(raw_scores, dim=1)

In [4]:
# Attention weights for the toy tensors; `a` is reused by the contribution cell.
a = calculate_attn(
    x_i_tensor=x_i_tensor,
    h_d_tensor=h_d_tensor,
    h_e_tensor=h_e_tensor,
)
print(a)

tensor([[0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429],
        [0.2181, 0.1436, 0.1807, 0.1016, 0.1187, 0.1187, 0.1187],
        [0.1065, 0.0997, 0.1264, 0.2349, 0.1441, 0.1441, 0.1441]])


In [5]:
def contribution(i: int, a: Tensor, h_e_tensor: Tensor):
    """Return the sum of the rows of ``h_e_tensor`` weighted by row ``i`` of ``a``.

    Args:
        i: Row of ``a`` whose weights are used.
        a: Weight matrix; ``a[i][j]`` scales row ``j`` of ``h_e_tensor``.
        h_e_tensor: 2-D tensor whose rows are accumulated.

    Returns:
        A tensor of shape ``(1, h_e_tensor.size(1))``. The accumulator is
        created with ``torch.zeros``, so the result keeps that tensor's dtype.
    """
    width = h_e_tensor.shape[1]
    n_rows = h_e_tensor.shape[0]

    weighted_sum = torch.zeros((1, width))
    for row in range(n_rows):
        # In-place accumulation keeps the (1, width) shape of the accumulator.
        weighted_sum.add_(a[i][row] * h_e_tensor[row])

    return weighted_sum

# Weighted sum of the encoder states using row 6 of the attention matrix `a`.
c_i_traduction = contribution(i=6, a=a, h_e_tensor=h_e_tensor)
print(c_i_traduction)

tensor([[0.0544, 0.0850, 0.1400, 0.2205]])


In [6]:
# 4x4 projection matrices for queries, keys and values.
W_Q = np.array([
    [ 0.12, -0.87,  0.33,  0.45],
    [ 0.76,  0.21, -0.34,  0.67],
    [-0.55,  0.18,  0.29, -0.73],
    [ 0.03, -0.99,  0.42,  0.11],
])

W_K = np.array([
    [ 0.64, -0.27,  0.89, -0.12],
    [-0.45,  0.33,  0.71,  0.08],
    [ 0.19, -0.94,  0.56,  0.37],
    [ 0.03,  0.85, -0.41,  0.76],
])

W_V = np.array([
    [ 0.58, -0.13,  0.94,  0.22],
    [-0.31,  0.66, -0.74,  0.09],
    [ 0.45,  0.11, -0.88,  0.67],
    [-0.92,  0.37,  0.28, -0.50],
])

# Wrap each weight matrix as a torch tensor.
W_Q_tensor, W_K_tensor, W_V_tensor = map(torch.from_numpy, (W_Q, W_K, W_V))

In [7]:
# Project the inputs with the query, key and value weight matrices.
Q_tensor = x_i_tensor @ W_Q_tensor
K_tensor = x_i_tensor @ W_K_tensor
V_tensor = x_i_tensor @ W_V_tensor

def scaled_dot_product_attention(Q: Tensor, K: Tensor, V: Tensor, d: int):
    """Compute ``softmax(Q @ K^T / sqrt(d)) @ V``.

    The last two dimensions of each input are treated as (rows, features).
    Any leading dimensions are carried through by ``torch.matmul``, so both
    plain 2-D matrices and batched inputs are supported.

    Args:
        Q: Query tensor of shape ``(..., n_q, d_k)``.
        K: Key tensor of shape ``(..., n_k, d_k)``.
        V: Value tensor of shape ``(..., n_k, d_v)``.
        d: Positive scaling dimension; scores are multiplied by ``1 / sqrt(d)``.

    Returns:
        Tensor of shape ``(..., n_q, d_v)``.

    Side effects:
        Prints the scaling factor to stdout.
    """
    scaling_factor = 1 / math.sqrt(d)
    print(scaling_factor)

    # transpose(-2, -1) and dim=-1 equal K.T and dim=1 for 2-D inputs, and
    # stay correct when leading batch dimensions are present.
    attn_weight = torch.matmul(Q, K.transpose(-2, -1)) * scaling_factor
    attn_weight = torch.softmax(attn_weight, dim=-1)
    return torch.matmul(attn_weight, V)

# d=4 matches the column count of the 4x4 projection matrices.
print(
    scaled_dot_product_attention(Q=Q_tensor, K=K_tensor, V=V_tensor, d=4)
)


0.5
tensor([[-0.0278,  0.0325, -0.3108,  0.0412],
        [ 0.0488, -0.1117,  0.0405,  0.0265],
        [-0.2811,  0.1322,  0.1575, -0.2035],
        [-0.0825,  0.0691, -0.2955,  0.0031],
        [ 0.0488, -0.1117,  0.0405,  0.0265],
        [-0.0947, -0.0424,  0.3061, -0.1067],
        [-0.2932,  0.1361,  0.1879, -0.2166]], dtype=torch.float64)
