In [16]:
import mlx.nn as nn
import mlx.core as mx

## Understanding and Coding Self-Attention 

This is an MLX version of the awesome article created by Sebastian Raschka: [Understanding and Coding Self-Attention, Multi-Head Attention, Cross-Attention, and Causal-Attention in LLMs](https://magazine.sebastianraschka.com/p/understanding-and-coding-self-attention). 

I just wanted to run this on MLX to learn and get the practive of using the lib. 


### **Embedding an Input Sentence**

In [6]:
sentence = 'Life is short, eat dessert first'

dc = { s:i for i, s in enumerate(sorted(sentence.replace(',', '').split(' '))) }

dc


{'Life': 0, 'dessert': 1, 'eat': 2, 'first': 3, 'is': 4, 'short': 5}

In [8]:
sentence_int = mx.array([ dc[s] for s in sentence.replace(',', '').split(' ') ])

print(sentence_int)

array([0, 4, 5, 2, 1, 3], dtype=int32)


In [19]:
vocab_size = 50000

mx.random.seed(123)

embed = nn.Embedding(vocab_size, 3)
embedded_sentence = embed(sentence_int)

print(embedded_sentence, embedded_sentence.shape)

array([[-0.696892, 0.251667, -0.973329],
       [0.185253, -0.116112, 0.462164],
       [-0.169591, 0.474225, -0.639736],
       [0.0158762, 0.00966178, -0.00937334],
       [-0.372818, 0.371361, -0.115717],
       [-0.222577, 0.155811, 0.154389]], dtype=float32) [6, 3]


### **Defining the Weight Matrices**

In [23]:
mx.random.seed(123)

d = embedded_sentence.shape[1]

d_q, d_k, d_v = 2, 2, 4

W_query = mx.array(mx.random.normal((d, d_q)))
W_key = mx.array(mx.random.normal((d, d_k)))
W_value = mx.array(mx.random.normal((d, d_v)))


In [57]:
kk = embedded_sentence @ W_key
qq = embedded_sentence @ W_query
vv = embedded_sentence @ W_value

print(kk.shape, qq.shape, vv.shape)

[6, 2] [6, 2] [6, 4]


In [68]:
att

array([[-0.375703, 0.122179, -0.385584, -0.00105096, -0.421755, -0.217912],
       [0.375137, -0.21757, 0.589012, 0.0162334, 0.246473, 0.00777121],
       [-0.667561, 0.493398, -1.27491, -0.0457645, -0.244485, 0.219374],
       [-0.0348911, 0.0283096, -0.072017, -0.00279252, -0.00817113, 0.0170009],
       [0.314138, -0.200112, 0.531488, 0.0164408, 0.17365, -0.0328317],
       [0.491097, -0.326443, 0.859923, 0.0278635, 0.246611, -0.0811913]], dtype=float32)

In [69]:
att = qq @ kk.T

mx.softmax(att / sentence_int.shape[0]**0.5, -1)

array([[0.155411, 0.190439, 0.154785, 0.181096, 0.152517, 0.165752],
       [0.180194, 0.141466, 0.196635, 0.155635, 0.170973, 0.155098],
       [0.136975, 0.220029, 0.106896, 0.176558, 0.1628, 0.196742],
       [0.165107, 0.169423, 0.162624, 0.167285, 0.166918, 0.168642],
       [0.178534, 0.144726, 0.1951, 0.158103, 0.168583, 0.154954],
       [0.18508, 0.132559, 0.215155, 0.153189, 0.167499, 0.146518]], dtype=float32)

### **Computing the Unnormalized Attention Weights**

In [26]:
x_2 = embedded_sentence[1]
query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value

print(query_2.shape)
print(key_2.shape)
print(value_2.shape)

[2]
[2]
[4]


In [30]:
keys = embedded_sentence @ W_key
values = embedded_sentence @ W_value

print("keys.shape:", keys.shape)
print("values.shape:", values.shape)

keys.shape: [6, 2]
values.shape: [6, 4]


In [46]:
omega_24 = (query_2 * keys[4]).sum(axis=0)
print(omega_24)

array(0.246473, dtype=float32)


In [47]:
omega_2 = query_2 @ keys.T
print(omega_2)

array([0.375137, -0.21757, 0.589012, 0.0162334, 0.246473, 0.00777121], dtype=float32)


### **Computing the Attention Weights**

In [49]:
omega_2 / d_k**0.5

array([0.265262, -0.153845, 0.416495, 0.0114788, 0.174283, 0.00549508], dtype=float32)

In [53]:

attention_weights_2 = mx.softmax(omega_2 / d_k**0.5, 0)
print(attention_weights_2)

array([0.189357, 0.124528, 0.220273, 0.146915, 0.17289, 0.146038], dtype=float32)


### **Self-Attention**

In [90]:
class SelfAttention(nn.Module):
    def __init__(self, d_in, d_out_kq, d_out_v):
        super().__init__()
        self.d_out_kq = d_out_kq
        self.W_query = nn.Linear(d_in, d_out_kq, bias=False)
        self.W_key = nn.Linear(d_in, d_out_kq, bias=False)
        self.W_value = nn.Linear(d_in, d_out_v, bias=False)
    
    def __call__(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.T # unnormalized attention weights

        attn_weights = mx.softmax(attn_scores / self.d_out_kq**0.5, -1)

        context_vec = attn_weights @ values
        return context_vec

In [91]:
mx.random.seed(123)

# reduce d_out_v from 4 to 1, because we have 4 heads
d_in, d_out_kq, d_out_v = 3, 2, 4

sa = SelfAttention(d_in, d_out_kq, d_out_v)
print(sa(embedded_sentence))

array([[-0.0313195, 0.104693, 0.132542, -0.0298887],
       [-0.0297479, 0.106741, 0.163512, -0.043829],
       [-0.0305369, 0.106373, 0.150997, -0.0381344],
       [-0.0302048, 0.106577, 0.15645, -0.0406181],
       [-0.0302526, 0.105162, 0.148332, -0.0372819],
       [-0.0300828, 0.105523, 0.152456, -0.0390983]], dtype=float32)
