## Attention Mechanism

### Scaled Dot-Product Attention
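$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
$$
where $d_k$ is the dimension of the key vectors. Dividing by $\sqrt{d_k}$ keeps the dot products from growing with the dimension and saturating the softmax.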

### Multi-Head Attention
$$
\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \text{head}_2, \ldots, \text{head}_h)W^O
$$
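where $W^O$ is the output projection, and each head $i$ applies its own learned projections $W_i^Q$, $W_i^K$, $W_i^V$ before attention: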
$$
\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
$$

In [4]:
import networkx as nx
import numpy as np

In [17]:
class Node:
    """A graph node with a random feature vector and its own key/query/value
    projection matrices (a toy self-attention setup).

    Attributes:
        name: Label supplied by the caller; stored unchanged.
        data: Feature vector of shape ``(dim,)`` drawn from a standard normal.
            This is a plain attribute, not a method.
        wkey, wquery, wvalue: Projection matrices of shape ``(dim, dim)``,
            each drawn from a standard normal.
    """

    def __init__(self, name, dim=2):
        """Create a node with random features and projection weights.

        Args:
            name: Label for the node.
            dim: Length of the feature vector and side length of each
                projection matrix. Defaults to 2, the previously fixed size.
        """
        self.name = name
        # Draw order (data, wkey, wquery, wvalue) is kept fixed so that
        # seeded runs reproduce the same values as before.
        self.data = np.random.randn(dim)
        self.wkey = np.random.randn(dim, dim)
        self.wquery = np.random.randn(dim, dim)
        self.wvalue = np.random.randn(dim, dim)

    def key(self):
        """Return the key vector: ``data`` projected through ``wkey``.

        In the attention analogy this is what the node "has" -- the vector
        other nodes' queries are compared against.
        """
        return np.dot(self.data, self.wkey)

    def query(self):
        """Return the query vector: ``data`` projected through ``wquery``.

        In the attention analogy this is what the node is looking for.
        """
        return np.dot(self.data, self.wquery)

    def value(self):
        """Return the value vector: ``data`` projected through ``wvalue``.

        In the attention analogy this is what the node communicates to
        nodes that attend to it.
        """
        return np.dot(self.data, self.wvalue)

In [16]:
node = Node("test")
# Show the random feature vector, the key projection matrix, and the key
# vector (features projected through that matrix), one per line.
print(node.data, node.wkey, node.key(), sep="\n")

[0.31061503 1.21354462]
[[ 0.17328739 -0.36964248]
 [-0.42329822 -0.92023993]]
[-0.45986561 -1.23156872]


In [None]:
class Graph:
    def __init__(self, num_nodes=10, num_edges=10):
        """Build a random graph of ``Node`` objects joined by index pairs.

        Args:
            num_nodes: Number of nodes to create; node ``i`` is named ``i``.
                Defaults to 10, the previously fixed size.
            num_edges: Number of ``(src, dst)`` index pairs to draw. Both
                endpoints are sampled uniformly with replacement, so
                self-loops and duplicate pairs can occur. Defaults to 10.

        Raises:
            ValueError: If edges are requested but there are no nodes to
                use as endpoints.
        """
        self.nodes = [Node(i) for i in range(num_nodes)]
        count = len(self.nodes)
        if num_edges > 0 and count == 0:
            raise ValueError("cannot draw edges for a graph with no nodes")
        # Within each pair, src is drawn before dst (tuple elements evaluate
        # left to right), matching the original draw order for seeded runs.
        self.edges = [
            (np.random.randint(count), np.random.randint(count))
            for _ in range(num_edges)
        ]
    
    def run(self):
        updates = []
        for i,n in enumerate(self.nodes):
            #what is this node looking for
            q = n.query()
            inputs = [self.nodes[src] for (src, dst) in self.edges if dst == i]
            if len(inputs) == 0:
                continue
            #what is the key of the inputs
            keys = np.array([x.key() for x in inputs])
            #calculate the compatibility
            scores = [np.dot(k, q) for k in keys]
            scores = np.exp(scores) / np.sum(np.exp(scores))