In [338]:
import numpy as np
from scipy.special import softmax

In [339]:
np.random.seed(234)

In [340]:
# Create R,G,B channel with 2x2 array filled with random values between 0 and 1
r_channel = np.random.rand(2, 2).round(1)
g_channel = np.random.rand(2, 2).round(1)
b_channel = np.random.rand(2, 2).round(1)

print(r_channel)
print(g_channel)
print(b_channel)

[[0.  0.3]
 [0.9 1. ]]
[[0.2 0.7]
 [0.9 0.6]]
[[0.5 0.5]
 [0.7 0.3]]


In [341]:
# Flatten each channel
r_channel_flattened = r_channel.flatten()
g_channel_flattened = g_channel.flatten()
b_channel_flattened = b_channel.flatten()

print(r_channel_flattened)
print(g_channel_flattened)
print(b_channel_flattened)

[0.  0.3 0.9 1. ]
[0.2 0.7 0.9 0.6]
[0.5 0.5 0.7 0.3]


In [342]:
# Combine the flattened channels into a 4x3 matrix
combined_matrix = np.column_stack([r_channel_flattened, g_channel_flattened, b_channel_flattened])

print(combined_matrix.shape)
print(combined_matrix)

(4, 3)
[[0.  0.2 0.5]
 [0.3 0.7 0.5]
 [0.9 0.9 0.7]
 [1.  0.6 0.3]]


In [343]:
# Create weights for Q, K, V with 3x2 array filled with random values
q_weight = np.random.randn(3, 2).round(1)
k_weight = np.random.randn(3, 2).round(1)
v_weight = np.random.randn(3, 2).round(1)

print(q_weight)
print(k_weight)
print(v_weight)

[[ 0.   1.4]
 [-0.6  0.9]
 [-1.3  1.1]]
[[-1.9  0.5]
 [ 1.3 -0.2]
 [-0.1 -0.9]]
[[-0.3  0.9]
 [-0.6 -1.3]
 [ 1.7  0.3]]


In [344]:
# Dot product of combined_matrix with Q, K, V weights
q = np.dot(combined_matrix, q_weight)
k = np.dot(combined_matrix, k_weight)
v = np.dot(combined_matrix, v_weight)

print(q)
print(k)
print(v)

[[-0.77  0.73]
 [-1.07  1.6 ]
 [-1.45  2.84]
 [-0.75  2.27]]
[[ 0.21 -0.49]
 [ 0.29 -0.44]
 [-0.61 -0.36]
 [-1.15  0.11]]
[[ 0.73 -0.11]
 [ 0.34 -0.49]
 [ 0.38 -0.15]
 [-0.15  0.21]]


In [345]:
# Dot product of Q, K (transposed)
k_transposed = k.T  # Transpose k
attention_weight = np.dot(q, k_transposed).round(2)

print(attention_weight.shape)
print(attention_weight)

(4, 4)
[[-0.52 -0.54  0.21  0.97]
 [-1.01 -1.01  0.08  1.41]
 [-1.7  -1.67 -0.14  1.98]
 [-1.27 -1.22 -0.36  1.11]]


In [346]:
# Devided attention weight with numbers of things in K
num_k = k.size
attention_weight = (attention_weight / np.sqrt(num_k)).round(2)

print(num_k)
print(attention_weight)

8
[[-0.18 -0.19  0.07  0.34]
 [-0.36 -0.36  0.03  0.5 ]
 [-0.6  -0.59 -0.05  0.7 ]
 [-0.45 -0.43 -0.13  0.39]]


In [347]:
# Apply softmax to attention weight
attention_weight = softmax(attention_weight, axis=1).round(2)

print(attention_weight)

[[0.2  0.2  0.26 0.34]
 [0.17 0.17 0.25 0.4 ]
 [0.13 0.14 0.23 0.5 ]
 [0.18 0.18 0.24 0.41]]


In [348]:
# Dot product of attention weight with V (Perform weighted average on V)
attention_value = np.dot(attention_weight, v)

print(attention_value)

[[ 0.2618 -0.0876]
 [ 0.2169 -0.0555]
 [ 0.1549 -0.0124]
 [ 0.2223 -0.0579]]


In [349]:
# Dot product with learnable output weight (linear layer) to convert back to RGB space
output_weight = np.random.randn(2, 3)
output = np.dot(attention_value, output_weight).round(2)

print(output)

[[ 0.05  0.2  -0.01]
 [ 0.04  0.18  0.  ]
 [ 0.02  0.15  0.02]
 [ 0.04  0.18  0.  ]]


In [350]:
# Skipped connection with original RGB value to preserve original pixel value information
output = output + combined_matrix

print(combined_matrix)
print(output)

[[0.  0.2 0.5]
 [0.3 0.7 0.5]
 [0.9 0.9 0.7]
 [1.  0.6 0.3]]
[[0.05 0.4  0.49]
 [0.34 0.88 0.5 ]
 [0.92 1.05 0.72]
 [1.04 0.78 0.3 ]]
