In [1]:
import jax
import jax.numpy as jnp

In [2]:
# --- Input image (3x3) ---
X = jnp.array([
    [1.0, 2.0, 1.0],
    [0.0, 3.0, 2.0],
    [1.0, 0.0, 1.0],
])

# --- Learnable parameters ---
K = jnp.diag(jnp.array([1.0, -1.0]))        # 2x2 conv kernel
W_Q = jnp.array([[1.0, 0.0], [-1.0, 0.0]])  # query projection
W_K = jnp.ones((2, 2))                      # key projection
W_V = jnp.eye(2)                            # value projection
W = jnp.array([-1.0, 1.0, 0.0, 1.0])        # fully connected weights
b = 2.0                                     # fully connected bias

# --- Regression target ---
y = 4

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


In [3]:
def forward_pass(K, W_Q, W_K, W_V, W, b, p=False, X=None, y=4):
    """Compute the scalar loss of the toy conv -> attention -> linear model.

    Pipeline: 'valid' 2-D cross-correlation of ``X`` with ``K``; query/key/value
    projections of the conv output; attention scores ``Q @ K_att.T`` passed
    through a softmax over axis 0; weighted sum of the values; flatten; linear
    layer ``W @ fc_in + b``; loss ``0.5 * (y - fc_out) ** 2``.

    Args:
        K: convolution kernel.
        W_Q, W_K, W_V: query / key / value projection matrices.
        W: fully connected weight vector (its length must equal the number of
            elements of the attention output).
        b: fully connected bias.
        p: if True, print every intermediate tensor.
        X: input image. Defaults to the notebook's 3x3 example input, so
            existing calls behave as before; pass the global ``X`` explicitly
            to keep the two in sync.
        y: regression target (default 4).

    Returns:
        The scalar loss.
    """
    if X is None:
        X = jnp.array([[1., 2., 1.], [0., 3., 2.], [1., 0., 1.]])

    def show(label, value):
        # Print an intermediate only when verbose mode is on.
        if p:
            print(label, value)

    conv_out = jax.scipy.signal.correlate2d(X, K, 'valid')
    show('conv_out:', conv_out)

    Q = jnp.dot(conv_out, W_Q)
    show('Q:', Q)
    # Separate name so the conv kernel argument K is not shadowed.
    K_att = jnp.dot(conv_out, W_K)
    show('K:', K_att)
    V = jnp.dot(conv_out, W_V)
    show('V:', V)

    att_scores = jnp.dot(Q, jnp.transpose(K_att))
    show('Attention scores:', att_scores)

    # Normalizes over axis 0 (each column sums to 1), matching the manual
    # derivation in this notebook. NOTE(review): the usual scaled dot-product
    # attention normalizes over the last axis; confirm axis=0 is intended.
    att_weights = jax.nn.softmax(att_scores, axis=0)
    show('Attention weights:', att_weights)

    att_out = jnp.dot(att_weights, V)
    show('Attention out:', att_out)

    fc_in = att_out.ravel()
    show('FC in:', fc_in)

    fc_out = jnp.dot(W, fc_in) + b
    show('FC out:', fc_out)

    loss = 0.5 * (y - fc_out) ** 2
    show('Loss:', loss)
    return loss

In [4]:
# Run the forward pass with the parameters defined above, printing every
# intermediate tensor (p=True); the cell's value is the scalar loss.
forward_pass(K, W_Q, W_K, W_V, W, b, p=True)

conv_out: [[-2.  0.]
 [ 0.  2.]]
Q: [[-2.  0.]
 [-2.  0.]]
K: [[-2. -2.]
 [ 2.  2.]]
V: [[-2.  0.]
 [ 0.  2.]]
Attention scores: [[ 4. -4.]
 [ 4. -4.]]
Attention weights: [[0.5 0.5]
 [0.5 0.5]]
Attention out: [[-1.  1.]
 [-1.  1.]]
FC in: [-1.  1. -1.  1.]
FC out: 5.0
Loss: 0.5


Array(0.5, dtype=float32)

In [5]:
# Reference gradients from JAX autodiff w.r.t. (K, W_Q, W_K, W_V, W, b);
# the manual backward pass below reproduces these values.
jax.grad(forward_pass, argnums=tuple(range(6)))(K, W_Q, W_K, W_V, W, b)

(Array([[ 2.5,  0.5],
        [ 6.5, -0.5]], dtype=float32),
 Array([[2., 2.],
        [2., 2.]], dtype=float32),
 Array([[0., 0.],
        [0., 0.]], dtype=float32),
 Array([[ 1., -2.],
        [-1.,  2.]], dtype=float32),
 Array([-1.,  1., -1.,  1.], dtype=float32),
 Array(1., dtype=float32, weak_type=True))

## Forward pass

In [6]:
# Jacobian of the axis-0 softmax evaluated at the attention scores the
# forward pass produces, flattened to a (4, 4) matrix.
jax.jacobian(jax.nn.softmax, argnums=0)(
    jnp.array([[4., -4.], [4., -4.]]), axis=0
).reshape(4, 4)

Array([[ 0.25,  0.  , -0.25,  0.  ],
       [ 0.  ,  0.25,  0.  , -0.25],
       [-0.25,  0.  ,  0.25,  0.  ],
       [ 0.  , -0.25,  0.  ,  0.25]], dtype=float32)

#### CNN

In [7]:
# Forward pass: 'valid' 2-D cross-correlation of the input with the conv kernel.
conv_out = jax.scipy.signal.correlate2d(X, K, mode='valid')
conv_out

Array([[-2.,  0.],
       [ 0.,  2.]], dtype=float32)

In [8]:
# Reference implementation of the 'valid' 2-D cross-correlation, checked
# against the jax.scipy.signal.correlate2d result from the previous cell.
# (Replaces a commented-out loop that used the unimported `np`.)
kh, kw = K.shape
conv_out_manual = jnp.array([
    [jnp.sum(X[i:i + kh, j:j + kw] * K) for j in range(X.shape[1] - kw + 1)]
    for i in range(X.shape[0] - kh + 1)
])
assert jnp.allclose(conv_out_manual, conv_out), (conv_out_manual, conv_out)

#### Attention

In [9]:
# Query projection of the conv output.
Q = jnp.matmul(conv_out, W_Q)
Q

Array([[-2.,  0.],
       [-2.,  0.]], dtype=float32)

In [10]:
# Key projection of the conv output.
# NOTE: this rebinds K, which until now held the conv kernel; from here on K
# is the attention key matrix.
K = jnp.matmul(conv_out, W_K)
K

Array([[-2., -2.],
       [ 2.,  2.]], dtype=float32)

In [11]:
# Value projection of the conv output.
V = jnp.matmul(conv_out, W_V)
V

Array([[-2.,  0.],
       [ 0.,  2.]], dtype=float32)

In [12]:
def softmax(x, axis=0):
    """Numerically stable softmax of ``x`` normalized along ``axis``.

    The default ``axis=0`` reproduces the previous definition
    ``exp(x) / sum(exp(x))``: Python's built-in ``sum`` iterates over the
    first axis, so each column of a 2-D input sums to 1. This matches
    ``jax.nn.softmax(x, axis=0)``.

    The max along ``axis`` is subtracted before ``exp`` so large scores do not
    overflow to inf (which previously produced nan).
    """
    shifted = x - jnp.max(x, axis=axis, keepdims=True)
    e = jnp.exp(shifted)
    return e / jnp.sum(e, axis=axis, keepdims=True)

In [13]:
# Raw attention scores S = Q @ K^T.
jnp.matmul(Q, K.T)

Array([[ 4., -4.],
       [ 4., -4.]], dtype=float32)

In [14]:
# Attention output: softmax-normalized scores (each column sums to 1)
# applied to the values.
att_out = jnp.matmul(softmax(jnp.matmul(Q, K.T)), V)
att_out

Array([[-1.,  1.],
       [-1.,  1.]], dtype=float32)

#### Fully connected

In [15]:
# Flatten the attention output (row-major) into the FC input vector.
fc_in = jnp.ravel(att_out)
fc_in

Array([-1.,  1., -1.,  1.], dtype=float32)

In [16]:
# Fully connected layer: weighted sum of the flattened attention output plus bias.
fc_out = jnp.dot(W, fc_in) + b
fc_out

Array(5., dtype=float32)

In [17]:
# Squared-error loss: 0.5 * (y - fc_out)^2.
loss = jnp.square(y - fc_out) * 0.5
loss

Array(0.5, dtype=float32)

## Backward pass

#### Fully connected

In [18]:
# dL/dfc_out for L = 0.5 * (y - fc_out)^2 is -(y - fc_out).
dL_dfc = jnp.negative(y - fc_out)
dL_dfc

Array(1., dtype=float32)

In [19]:
# dL/dW = dL/dfc_out * d(fc_out)/dW, and d(fc_out)/dW = fc_in.
dL_dw = jnp.multiply(dL_dfc, fc_in)
dL_dw

Array([-1.,  1., -1.,  1.], dtype=float32)

In [20]:
# d(fc_out)/db = 1, so the bias gradient equals dL/dfc_out.
dL_db = dL_dfc
dL_db

Array(1., dtype=float32)

#### Attention

In [21]:
# fc_out = W @ fc_in + b  =>  dL/dfc_in = dL_dfc * W. Reshape back to the shape
# of att_out (fc_in is its row-major flattening). `.T` was dropped: it is a
# no-op on a 1-D array, and the shape is no longer hard-coded to (2, 2).
dL_dat = (dL_dfc * W).reshape(att_out.shape)
dL_dat

Array([[-1.,  1.],
       [ 0.,  1.]], dtype=float32)

In [22]:
# Backprop through att_out = P @ V and V = conv_out @ W_V.
# For Y = A @ B:  dL/dA = dL/dY @ B.T  and  dL/dB = A.T @ dL/dY.
# Both right-operand gradients need the transpose of the left operand; the
# untransposed form only worked because P and conv_out are symmetric here.
P = softmax(Q @ K.T)
dL_dV = P.T @ dL_dat
dL_dW_V = conv_out.T @ dL_dV
dL_dW_V

Array([[ 1., -2.],
       [-1.,  2.]], dtype=float32)

In [23]:
# dL/dP for att_out = P @ V is dL/datt_out @ V^T.
dL_daw = jnp.matmul(dL_dat, V.T)

In [24]:
# Backprop through P = softmax(S, axis=0) with the full Jacobian of the
# flattened map vec(S) -> vec(P); J[i, j] = d vec(P)_i / d vec(S)_j, so the
# chain rule is dL/dvec(S) = J.T @ dL/dvec(P). Sizes come from the scores
# instead of being hard-coded.
att_scores = Q @ K.T
n_scores = att_scores.size
grad_softmax = jax.jacobian(jax.nn.softmax, 0)(att_scores, axis=0).reshape(n_scores, n_scores)
dL_das = (grad_softmax.T @ dL_daw.ravel()).reshape(att_scores.shape)

In [25]:
# Gradient of the loss w.r.t. the attention scores Q @ K.T.
dL_das

Array([[ 0.5,  0. ],
       [-0.5,  0. ]], dtype=float32)

In [26]:
# Flattened (4, 4) Jacobian of the axis-0 softmax, as computed by jax.jacobian.
grad_softmax

Array([[ 0.25,  0.  , -0.25,  0.  ],
       [ 0.  ,  0.25,  0.  , -0.25],
       [-0.25,  0.  ,  0.25,  0.  ],
       [ 0.  , -0.25,  0.  ,  0.25]], dtype=float32)

In [27]:
# NOTE(review): abandoned hand-written Jacobian. Reshaped to (4, 4) its rows are
# [0.25, -0.25, -0.25, 0.25], [0, 0, 0, 0], ..., which does not match the
# jax.jacobian result displayed by the grad_softmax cell; kept for reference only.
# grad_softmax = jnp.array([[[0.25, -0.25], [-0.25, 0.25]], [[0., 0.], [0., 0.]], [[0., 0.], [0., 0.]], [[0.25, -0.25], [-0.25, 0.25]]]).reshape(4,4)
# dL_das = (grad_softmax @ dL_daw.ravel()).reshape(2,2)

In [28]:
# Shortcut valid for this example only: every column of P is [0.5, 0.5], so the
# per-column softmax Jacobian diag(s) - s s^T is the same 2x2 matrix for both
# columns and can be applied to all columns of dL_daw with one matmul.
# grad_softmax = jnp.array([[0.25, -0.25], [-0.25, 0.25]])
# dL_das = grad_softmax @ dL_daw

In [29]:
# S = Q @ K.T  =>  dL/dK = (dL/dS).T @ Q.
# K = conv_out @ W_K  =>  dL/dW_K = conv_out.T @ dL/dK (the transpose is
# required in general; conv_out merely happens to be symmetric here).
dL_dK = dL_das.T @ Q
dL_dW_K = conv_out.T @ dL_dK
dL_dW_K

Array([[0., 0.],
       [0., 0.]], dtype=float32)

In [30]:
# S = Q @ K.T  =>  dL/dQ = dL/dS @ K.
# Q = conv_out @ W_Q  =>  dL/dW_Q = conv_out.T @ dL/dQ (the transpose is
# required in general; conv_out merely happens to be symmetric here).
dL_dQ = dL_das @ K
dL_dW_Q = conv_out.T @ dL_dQ
dL_dW_Q

Array([[2., 2.],
       [2., 2.]], dtype=float32)

#### CNN

In [31]:
# conv_out feeds Q, K and V through W_Q, W_K and W_V, so its gradient is the
# sum of the three back-projected contributions.
dL_dconv = (jnp.matmul(dL_dQ, W_Q.T)
            + jnp.matmul(dL_dK, W_K.T)
            + jnp.matmul(dL_dV, W_V.T))
dL_dconv

Array([[-1.5,  2. ],
       [ 0.5,  0. ]], dtype=float32)

In [34]:
# Gradient of the loss w.r.t. the conv kernel as an explicit sum:
#   dL/dkernel = sum_{i,j} dL_dconv[i, j] * X[i:i+kh, j:j+kw]
# Shapes come from X and dL_dconv. Previously jnp.zeros_like(K) was used, and
# at this point K is the attention key matrix, not the kernel.
# NOTE: this rebinds dL_dK, which earlier held the gradient w.r.t. the attention
# key matrix; re-run cells in order.
out_h, out_w = dL_dconv.shape
# For a 'valid' correlation, out = in - kernel + 1, so kernel = in - out + 1.
filter_h = X.shape[0] - out_h + 1
filter_w = X.shape[1] - out_w + 1

dL_dK = jnp.zeros((filter_h, filter_w), dtype=dL_dconv.dtype)
for i in range(out_h):
    for j in range(out_w):
        dL_dK += dL_dconv[i, j] * X[i:i + filter_h, j:j + filter_w]

In [37]:
# Vectorized kernel gradient: 'valid' cross-correlation of the input with the
# upstream gradient dL_dconv.
jax.scipy.signal.correlate2d(X, dL_dconv, mode='valid')

Array([[ 2.5,  0.5],
       [ 6.5, -0.5]], dtype=float32)

In [35]:
# Kernel gradient from the explicit loop; it matches the correlate2d result
# and the first gradient returned by jax.grad.
dL_dK

Array([[ 2.5,  0.5],
       [ 6.5, -0.5]], dtype=float32)