In [1]:
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# Seed NumPy's global random generator so the inputs drawn below are reproducible
np.random.seed(3)

In [3]:
# Problem size: N input vectors, each with D entries
N = 3
D = 4
# Draw every input as a D x 1 column of standard-normal samples
all_x = [np.random.normal(size=(D, 1)) for _ in range(N)]

print(all_x)

[array([[ 1.78862847],
       [ 0.43650985],
       [ 0.09649747],
       [-1.8634927 ]]), array([[-0.2773882 ],
       [-0.35475898],
       [-0.08274148],
       [-0.62700068]]), array([[-0.04381817],
       [-0.47721803],
       [-1.31386475],
       [ 0.88462238]])]


In [4]:
# Re-seed so the randomly drawn parameters are reproducible
np.random.seed(0)

# D x D weight matrices for the query, key and value projections.
# The draws happen left to right, so the random stream is consumed in the
# order omega_q, omega_k, omega_v, beta_q, beta_k, beta_v - do not reorder.
omega_q, omega_k, omega_v = (np.random.normal(size=(D, D)) for _ in range(3))
# D x 1 bias vectors for the same three projections
beta_q, beta_k, beta_v = (np.random.normal(size=(D, 1)) for _ in range(3))

## Self Attention

In [5]:
# Equations 12.2 & 12.4: each input x gets its own query, key and value,
# every one of them an affine map (bias + weights @ x) of that input.
all_queries = [beta_q + omega_q @ x for x in all_x]
all_keys = [beta_k + omega_k @ x for x in all_x]
all_values = [beta_v + omega_v @ x for x in all_x]

In [6]:
def softmax(items_in):
    """Softmax taken over *all* entries of ``items_in``.

    Parameters
    ----------
    items_in : array_like
        Scores to normalize. Anything ``np.asarray`` accepts works, including
        a list of equally shaped arrays (which is stacked along a new first
        axis).

    Returns
    -------
    numpy.ndarray
        Array with the shape of ``np.asarray(items_in)`` whose entries are
        positive and sum to one. Empty input yields an empty array.
    """
    scores = np.asarray(items_in)
    if scores.size == 0:
        # Nothing to normalize; np.max would raise on an empty array
        return np.exp(scores)
    # Softmax is unchanged by adding a constant to every score. Subtracting
    # the maximum keeps every exponent <= 0, so exp() cannot overflow to inf
    # (which would turn the result into nan).
    exp_scores = np.exp(scores - np.max(scores))
    return exp_scores / np.sum(exp_scores)

In [7]:
# Raw (unnormalized) similarity between the first key and the first query
all_keys[0].T @ all_queries[0]

array([[-30.69506335]])

In [8]:
# Display the first key and the first query side by side for inspection
all_keys[0], all_queries[0]

(array([[ 3.69380505],
        [-3.9952365 ],
        [ 3.7499519 ],
        [ 3.12154293]]),
 array([[-2.36543342],
        [ 3.07476988],
        [-3.59698468],
        [ 1.22226059]]))

In [9]:
all_x_prime = []
# Build one output per query
for n in range(N):
    # Similarity of query n with every key, kept as a list of 1 x 1 dot products
    query_n = all_queries[n]
    all_km_qn = [np.dot(key.T, query_n) for key in all_keys]

    # Normalize the similarities into attention weights
    attention = softmax(all_km_qn)
    print("Attentions for output ", n)
    # The weights should all be positive and sum to one
    print(attention)

    # Equation 12.3: the output is the attention-weighted sum of all values
    x_prime = (attention * all_values).sum(axis=0)

    all_x_prime.append(x_prime)

print("x_prime_0_calculated:", all_x_prime[0].transpose())
print("x_prime_0_true: [[ 0.94744244 -0.24348429 -0.91310441 -0.44522983]]")
print("x_prime_1_calculated:", all_x_prime[1].transpose())
print("x_prime_1_true: [[ 1.64201168 -0.08470004  4.02764044  2.18690791]]")
print("x_prime_2_calculated:", all_x_prime[2].transpose())
print("x_prime_2_true: [[ 1.61949281 -0.06641533  3.96863308  2.15858316]]")

Attentions for output  0
[[[1.24326146e-13]]

 [[9.98281489e-01]]

 [[1.71851130e-03]]]
Attentions for output  1
[[[2.79525306e-12]]

 [[5.85506360e-03]]

 [[9.94144936e-01]]]
Attentions for output  2
[[[0.00505708]]

 [[0.00654776]]

 [[0.98839516]]]
x_prime_0_calculated: [[ 0.94744244 -0.24348429 -0.91310441 -0.44522983]]
x_prime_0_true: [[ 0.94744244 -0.24348429 -0.91310441 -0.44522983]]
x_prime_1_calculated: [[ 1.64201168 -0.08470004  4.02764044  2.18690791]]
x_prime_1_true: [[ 1.64201168 -0.08470004  4.02764044  2.18690791]]
x_prime_2_calculated: [[ 1.61949281 -0.06641533  3.96863308  2.15858316]]
x_prime_2_true: [[ 1.61949281 -0.06641533  3.96863308  2.15858316]]


#### Now let's compute the same thing using matrix calculations. We'll store the N inputs in the columns of a D x N matrix and apply equations 12.6 and 12.7/12.8.

In [10]:
# Softmax that normalizes each column independently
def softmax_cols(data_in):
    """Apply the softmax function separately to every column of ``data_in``.

    Parameters
    ----------
    data_in : array_like
        Scores; normalization runs along axis 0, so for a 2-D input each
        column is treated as its own set of scores.

    Returns
    -------
    numpy.ndarray
        Array of the same shape as ``data_in`` whose entries are positive and
        sum to one along axis 0. Empty input yields an empty array.
    """
    data_in = np.asarray(data_in)
    if data_in.size == 0:
        # Nothing to normalize; np.max would raise on a zero-length axis
        return np.exp(data_in)
    # Softmax is unchanged by subtracting a per-column constant. Removing each
    # column's maximum keeps every exponent <= 0, so exp() cannot overflow to
    # inf (which would turn that column into nan).
    shifted = data_in - np.max(data_in, axis=0, keepdims=True)
    exp_values = np.exp(shifted)
    # keepdims lets broadcasting divide each column by its own sum
    denom = np.sum(exp_values, axis=0, keepdims=True)
    return exp_values / denom

In [13]:
def self_attention(X, omega_v, omega_q, omega_k, beta_v, beta_q, beta_k):
    """Dot-product self-attention in matrix form.

    Parameters
    ----------
    X : numpy.ndarray
        Input matrix with one input per column.
    omega_v, omega_q, omega_k : numpy.ndarray
        Weight matrices of the value, query and key projections.
    beta_v, beta_q, beta_k : numpy.ndarray
        Column-vector biases of the value, query and key projections.

    Returns
    -------
    numpy.ndarray
        ``values @ attention``: one output column per input column.

    Notes
    -----
    Prints the attention matrix as a side effect.
    """
    # Take the number of inputs from X itself. Reading the notebook-level N
    # here would break as soon as N is rebound or X has a different width.
    num_inputs = X.shape[1]
    one = np.ones((num_inputs, 1))

    # 1. Compute queries, keys, and values (the bias is replicated per column)
    queries = beta_q @ one.T + omega_q @ X
    keys = beta_k @ one.T + omega_k @ X
    values = beta_v @ one.T + omega_v @ X

    # 2. Compute dot products: entry (m, n) is key m dotted with query n
    dot_products = np.dot(keys.T, queries)
    # 3. Apply softmax to calculate attentions (each column sums to one)
    attention = softmax_cols(dot_products)
    print(attention)
    # 4. Weight values by attentions
    X_prime = values @ attention

    return X_prime
     

In [14]:

# Copy data into matrix: every D x 1 input becomes one column of X.
# hstack works for any number of inputs, so nothing has to be edited here
# when N changes (the column-by-column copy was hard-coded for three inputs).
X = np.hstack(all_x)

# Run the self attention mechanism
X_prime = self_attention(X, omega_v, omega_q, omega_k, beta_v, beta_q, beta_k)

# Print out the results
print(X_prime)

[[1.24326146e-13 2.79525306e-12 5.05707907e-03]
 [9.98281489e-01 5.85506360e-03 6.54776072e-03]
 [1.71851130e-03 9.94144936e-01 9.88395160e-01]]
[[ 0.94744244  1.64201168  1.61949281]
 [-0.24348429 -0.08470004 -0.06641533]
 [-0.91310441  4.02764044  3.96863308]
 [-0.44522983  2.18690791  2.15858316]]



#### Printing out the attention matrix, we see that the values are quite extreme: in each column, one entry is very close to one and the others are very close to zero.

#### Now we'll fix this problem by using scaled dot-product attention.

In [15]:
def scaled_dot_product_self_attention(X,omega_v, omega_q, omega_k, beta_v, beta_q, beta_k):
    """Scaled dot-product self-attention in matrix form.

    Same computation as plain dot-product self-attention, except that the
    key/query dot products are divided by the square root of the query
    dimension before the softmax (equation 12.9).

    Parameters
    ----------
    X : numpy.ndarray
        Input matrix with one input per column.
    omega_v, omega_q, omega_k : numpy.ndarray
        Weight matrices of the value, query and key projections.
    beta_v, beta_q, beta_k : numpy.ndarray
        Column-vector biases of the value, query and key projections.

    Returns
    -------
    numpy.ndarray
        ``values @ attention``: one output column per input column.

    Notes
    -----
    Prints the attention matrix as a side effect.
    """
    # Take the number of inputs from X itself. Reading the notebook-level N
    # here only works while N happens to match the width of X.
    num_inputs = X.shape[1]
    one = np.ones((num_inputs, 1))

    # 1. Compute queries, keys, and values (the bias is replicated per column)
    queries = beta_q @ one.T + omega_q @ X
    keys = beta_k @ one.T + omega_k @ X
    values = beta_v @ one.T + omega_v @ X

    # 2. Compute dot products: entry (m, n) is key m dotted with query n
    dot_products = np.dot(keys.T, queries)
    # 3. Scale the dot products as in equation 12.9; the number of query rows
    #    is the query dimension
    scaled_dot_products = dot_products / np.sqrt(queries.shape[0])
    # 4. Apply softmax to calculate attentions (each column sums to one)
    attention = softmax_cols(scaled_dot_products)
    print(attention)
    # 5. Weight values by attentions
    X_prime = values @ attention

    return X_prime

In [16]:

# Apply scaled dot-product self-attention to the same input matrix
X_prime = scaled_dot_product_self_attention(
    X, omega_v, omega_q, omega_k, beta_v, beta_q, beta_k
)

# Show the transformed outputs
print(X_prime)

[[3.38843552e-07 1.55730194e-06 6.20418746e-02]
 [9.60161968e-01 7.12734969e-02 7.05962187e-02]
 [3.98376935e-02 9.28724946e-01 8.67361907e-01]]
[[ 0.97411966  1.59622051  1.32638014]
 [-0.23738409 -0.09516106  0.13062402]
 [-0.72333202  3.70194096  3.02371664]
 [-0.34413007  2.01339538  1.6902419 ]]


## Multihead attention

In [17]:

# Fix the seed so the inputs are reproducible
np.random.seed(3)
# Problem size for the multi-head example
N = 6  # number of inputs
D = 8  # dimensionality of each input
# Draw the D x N input matrix (one input per column) from a standard normal
X = np.random.normal(size=(D, N))
# Show the inputs
print(X)

[[ 1.78862847  0.43650985  0.09649747 -1.8634927  -0.2773882  -0.35475898]
 [-0.08274148 -0.62700068 -0.04381817 -0.47721803 -1.31386475  0.88462238]
 [ 0.88131804  1.70957306  0.05003364 -0.40467741 -0.54535995 -1.54647732]
 [ 0.98236743 -1.10106763 -1.18504653 -0.2056499   1.48614836  0.23671627]
 [-1.02378514 -0.7129932   0.62524497 -0.16051336 -0.76883635 -0.23003072]
 [ 0.74505627  1.97611078 -1.24412333 -0.62641691 -0.80376609 -2.41908317]
 [-0.92379202 -1.02387576  1.12397796 -0.13191423 -1.62328545  0.64667545]
 [-0.35627076 -1.74314104 -0.59664964 -0.58859438 -0.8738823   0.02971382]]


In [18]:
# Number of heads
H = 2
# Query/key/value dimension of each head. The stacked head outputs are
# multiplied by the D x D matrix omega_c, so the heads must split D evenly;
# fail loudly instead of silently truncating the division.
assert D % H == 0, "D must be divisible by the number of heads H"
H_D = D // H

# Set seed so we get the same random numbers
np.random.seed(0)

# Choose random values for the parameters for the first head
omega_q1 = np.random.normal(size=(H_D,D))
omega_k1 = np.random.normal(size=(H_D,D))
omega_v1 = np.random.normal(size=(H_D,D))
beta_q1 = np.random.normal(size=(H_D,1))
beta_k1 = np.random.normal(size=(H_D,1))
beta_v1 = np.random.normal(size=(H_D,1))

# Choose random values for the parameters for the second head
omega_q2 = np.random.normal(size=(H_D,D))
omega_k2 = np.random.normal(size=(H_D,D))
omega_v2 = np.random.normal(size=(H_D,D))
beta_q2 = np.random.normal(size=(H_D,1))
beta_k2 = np.random.normal(size=(H_D,1))
beta_v2 = np.random.normal(size=(H_D,1))

# Choose random values for the D x D matrix that is applied to the
# vertically stacked outputs of the heads
omega_c = np.random.normal(size=(D,D))

In [24]:
 # Now let's compute self attention in matrix form
def multihead_scaled_self_attention(X,omega_v1, omega_q1, omega_k1, beta_v1, beta_q1, beta_k1, omega_v2, omega_q2, omega_k2, beta_v2, beta_q2, beta_k2, omega_c):
    """Two-head scaled dot-product self-attention.

    Runs scaled dot-product self-attention on ``X`` once per head (head 1
    first, then head 2), stacks the two head outputs vertically and
    multiplies the stack by ``omega_c``.

    Returns the resulting matrix. Each head prints its attention matrix as a
    side effect of the underlying self-attention call.
    """
    per_head_params = (
        (omega_v1, omega_q1, omega_k1, beta_v1, beta_q1, beta_k1),
        (omega_v2, omega_q2, omega_k2, beta_v2, beta_q2, beta_k2),
    )
    # Evaluate the heads in order so their printed attentions keep the same order
    head_outputs = [
        scaled_dot_product_self_attention(X, *params) for params in per_head_params
    ]

    # Recombine the stacked heads with the output weights
    return omega_c @ np.vstack(head_outputs)

In [25]:
# Apply the two-head scaled self-attention to the inputs
X_prime = multihead_scaled_self_attention(
    X,
    omega_v1, omega_q1, omega_k1, beta_v1, beta_q1, beta_k1,
    omega_v2, omega_q2, omega_k2, beta_v2, beta_q2, beta_k2,
    omega_c,
)

# Show three decimals so the result lines up with the reference values
np.set_printoptions(precision=3)
print("Your answer:")
print(X_prime)

# Reference solution, one matrix row per string
expected_rows = (
    "[[-21.207  -5.373 -20.933  -9.179 -11.319 -17.812]",
    " [ -1.995   7.906 -10.516   3.452   9.863  -7.24 ]",
    " [  5.479   1.115   9.244   0.453   5.656   7.089]",
    " [ -7.413  -7.416   0.363  -5.573  -6.736  -0.848]",
    " [-11.261  -9.937  -4.848  -8.915 -13.378  -5.761]",
    " [  3.548  10.036  -2.244   1.604  12.113  -2.557]",
    " [  4.888  -5.814   2.407   3.228  -4.232   3.71 ]",
    " [  1.248  18.894  -6.409   3.224  19.717  -5.629]]",
)
print("True values:")
for row in expected_rows:
    print(row)

[[6.451e-15 4.457e-14 2.405e-03 9.720e-01 8.643e-02 5.027e-01]
 [1.125e-14 6.564e-17 3.265e-03 2.318e-06 3.599e-04 7.835e-04]
 [1.658e-12 1.384e-08 1.481e-03 2.902e-03 1.048e-03 1.122e-03]
 [5.500e-03 8.826e-08 1.588e-01 2.611e-07 2.639e-02 4.114e-02]
 [9.945e-01 1.109e-09 8.290e-01 2.149e-07 8.553e-01 4.541e-01]
 [6.352e-06 1.000e+00 5.036e-03 2.509e-02 3.047e-02 1.798e-04]]
[[2.543e-03 3.404e-02 1.329e-03 9.340e-03 3.044e-02 7.030e-03]
 [3.154e-03 9.386e-01 1.270e-04 2.490e-01 9.600e-01 2.895e-03]
 [7.096e-07 2.436e-02 9.532e-02 7.415e-01 9.580e-03 3.025e-01]
 [6.826e-01 2.964e-03 6.408e-05 2.040e-09 2.518e-07 1.759e-05]
 [3.117e-01 6.532e-06 2.764e-04 3.434e-10 1.232e-07 1.207e-04]
 [1.807e-07 5.470e-06 9.029e-01 2.150e-04 5.442e-06 6.874e-01]]
Your answer:
[[-21.207  -5.373 -20.933  -9.179 -11.319 -17.812]
 [ -1.995   7.906 -10.516   3.452   9.863  -7.24 ]
 [  5.479   1.115   9.244   0.453   5.656   7.089]
 [ -7.413  -7.416   0.363  -5.573  -6.736  -0.848]
 [-11.261  -9.937  -4.848