In [3]:
# Illustration of self-attention for connected graphs
import numpy as np

In [53]:
class Node:
    """A graph node carrying an embedding and its own attention projections.

    Every node owns three randomly initialised ``(embed_size, embed_size)``
    matrices (key, query, value) and a random ``(embed_size, 1)`` column
    vector ``data``. The getters return the corresponding projection of the
    node's *current* ``data``.
    """

    def __init__(self, id, embed_size) -> None:
        """Create a node with random weights and a random embedding.

        Args:
            id: Identifier stored on the node as ``self.id``.
            embed_size: Dimension of the embedding; weights are square
                matrices of this size and ``data`` has shape
                ``(embed_size, 1)``.
        """
        self.embed_size = embed_size
        # Projection weights for keys, queries and values.
        # The draw order (key, query, value, data) is kept fixed so that a
        # seeded global NumPy RNG reproduces the same node.
        self.weights_key = np.random.randn(embed_size, embed_size)
        self.weights_query = np.random.randn(embed_size, embed_size)
        self.weights_value = np.random.randn(embed_size, embed_size)
        self.id = id
        # Column-vector embedding of the node.
        self.data = np.random.randn(embed_size, 1)

    def get_key(self) -> np.ndarray:
        """Return the key projection ``weights_key @ data``, shape ``(embed_size, 1)``."""
        return self.weights_key @ self.data

    def get_query(self) -> np.ndarray:
        """Return the query projection ``weights_query @ data``, shape ``(embed_size, 1)``."""
        return self.weights_query @ self.data

    def get_value(self) -> np.ndarray:
        """Return the value projection ``weights_value @ data``, shape ``(embed_size, 1)``."""
        return self.weights_value @ self.data

# Smoke test: build a single node and show its three projections.
embed_size = 8
# NOTE(review): `data` is not passed to Node (the node draws its own
# embedding); it is kept so the global RNG stream and notebook globals
# stay exactly as before.
data = np.random.randn(embed_size, 1)
node: Node = Node(id=1, embed_size=embed_size)

print(
    f'Key Node : {node.get_key()}',
    f'query Node : {node.get_query()}',
    f'value Node : {node.get_value()}',
    sep='\n',
)

Key Node : [[ 0.4626994 ]
 [ 4.75630941]
 [-3.40442323]
 [-0.24714526]
 [ 0.82205102]
 [ 1.20897123]
 [ 1.39797928]
 [ 1.66658014]]
query Node : [[-1.15038849]
 [-1.29709697]
 [ 0.98150268]
 [ 4.72505494]
 [ 2.47338327]
 [ 1.50709356]
 [-0.80299357]
 [-1.70404016]]
value Node : [[ 0.60668619]
 [ 3.28401293]
 [-1.58843323]
 [ 0.00711878]
 [-0.78993758]
 [-1.35890831]
 [ 5.30046708]
 [ 1.71565161]]


In [15]:
# define the graph



Node : [[ 1.87004213]
 [ 2.95803927]
 [ 0.45843419]
 [-1.08161022]
 [ 1.20145844]
 [ 2.62779471]
 [ 4.02582803]
 [-2.40531459]]


In [86]:
# Graph
from typing import List, Dict, OrderedDict
from collections import OrderedDict
class Graph:
    """Directed graph of ``Node`` objects with a one-step attention update.

    Nodes are stored by integer id. Edges are kept twice: ``out_edges`` maps
    a node id to the ids it points to, ``in_edges`` maps a node id to the ids
    pointing at it. A node without edges in a given direction simply has no
    entry in the corresponding map.
    """

    def __init__(self, embed_size) -> None:
        """Create an empty graph whose nodes will use ``embed_size``-dim embeddings."""
        self.nodes : OrderedDict[int, Node] = OrderedDict()
        self.out_edges : OrderedDict[int, List[int]] = OrderedDict()  # node id -> neighbours in the forward pass
        self.in_edges : OrderedDict[int, List[int]] = OrderedDict()   # node id -> nodes that point at it
        self.embed_size = embed_size

    def read_graph(self, txt_file):
        """Load nodes and edges from a text file of ``src,dst`` integer pairs.

        Each non-blank line describes one directed edge ``src -> dst``. A node
        is created for every id that appears as either endpoint, in ascending
        id order. Neighbour lists end up sorted ascending.

        Args:
            txt_file: Path of the edge-list file.

        Raises:
            ValueError: If a field on a non-blank line is not an integer.
            IndexError: If a non-blank line has fewer than two fields.
        """
        with open(txt_file, 'r') as fid:
            # Blank lines (e.g. a trailing newline) are skipped instead of
            # crashing int('').
            edges = [list(map(int, line.split(','))) for line in fid if line.strip()]

        # Register destination-only nodes too; otherwise they would be
        # missing from the graph altogether.
        node_ids = sorted({node_id for edge in edges for node_id in edge[:2]})
        for node_id in node_ids:
            self.nodes[node_id] = Node(node_id, self.embed_size)

        for edge in edges:
            src, dst = edge[0], edge[1]
            self.out_edges.setdefault(src, []).append(dst)
            self.in_edges.setdefault(dst, []).append(src)

        # Sort each neighbour list once instead of after every append.
        for neighbours in self.out_edges.values():
            neighbours.sort()
        for neighbours in self.in_edges.values():
            neighbours.sort()

    def print_graph(self):
        """Print every node's outgoing neighbours, then every node's incoming ones."""
        for node in self.nodes.values():
            # .get: a node may legitimately have no outgoing edges.
            print(f'Node {node.id} -> {self.out_edges.get(node.id, [])}')

        for node in self.nodes.values():
            # .get: a node may legitimately have no incoming edges.
            print(f'Node {node.id} <- {self.in_edges.get(node.id, [])}')

    def compute_attention(self):
        """Run one synchronous attention step over the incoming edges.

        For every node, the dot products between its query and the keys of
        its incoming neighbours are turned into softmax weights, and the
        weighted sum of the neighbours' values is added to the node's
        ``data``. All updates are computed from the pre-update embeddings and
        only then applied. A node without incoming edges receives a zero
        update. Scores and updates are printed per node.
        """
        updates = []
        for idx, node in self.nodes.items():
            sources = self.in_edges.get(node.id, [])

            if not sources:
                # Nothing to attend to: empty score set, embedding unchanged.
                scores = np.zeros((0, 1, 1))
                update = np.zeros_like(node.data)
            else:
                query_vec = node.get_query()

                # One (1, 1) dot product per incoming neighbour -> (k, 1, 1).
                logits = np.array(
                    [self.nodes[src].get_key().T.dot(query_vec) for src in sources]
                )

                # Softmax with the maximum subtracted first: mathematically
                # identical, but exp() can no longer overflow to inf (which
                # would turn the normalised scores into nan).
                scores = np.exp(logits - logits.max())
                scores /= np.sum(scores)

                # Neighbour values stacked to (k, D, 1); the (k, 1, 1) scores
                # broadcast over the embedding axis.
                values = np.array([self.nodes[src].get_value() for src in sources])
                update = np.sum(scores * values, axis=0)

            print(f'Scores for Node {idx} : {scores}')
            print(f'Update for Node {idx} : {update}')
            updates.append(update)

        # Apply all updates only after every one has been computed.
        for idx, update in zip(self.nodes.keys(), updates):
            self.nodes[idx].data += update
            

# Build the graph from the edge-list file, show its connectivity, then run
# one attention step (which prints the per-node scores and updates).
embed_size = 8
txt_file = "sample_graph.txt"

net: Graph = Graph(embed_size=embed_size)
net.read_graph(txt_file)
net.print_graph()
net.compute_attention()

Node 0 -> [1, 5]
Node 1 -> [2, 4]
Node 2 -> [3]
Node 3 -> [0, 1]
Node 4 -> [3]
Node 5 -> [4]
Node 0 <- [3]
Node 1 <- [0, 3]
Node 2 <- [1]
Node 3 <- [2, 4]
Node 4 <- [1, 5]
Node 5 <- [0]
Scores for Node 0 : [[[1.]]]
Update for Node 0 : [[-0.1369384 ]
 [ 0.21513234]
 [-0.63818914]
 [ 0.87274661]
 [-3.58753591]
 [-0.2935847 ]
 [ 0.50978782]
 [ 2.90907945]]
Scores for Node 1 : [[[0.02323996]]

 [[0.97676004]]]
Update for Node 1 : [[-0.22575942]
 [ 0.24152754]
 [-0.66833969]
 [ 0.78985401]
 [-3.52546204]
 [-0.35101307]
 [ 0.58211253]
 [ 2.8256734 ]]
Scores for Node 2 : [[[1.]]]
Update for Node 2 : [[-1.67320754]
 [-0.37502434]
 [ 0.43805411]
 [ 1.86356745]
 [ 0.77415342]
 [-0.28624065]
 [-4.72286251]
 [-1.21263537]]
Scores for Node 3 : [[[0.99359623]]

 [[0.00640377]]]
Update for Node 3 : [[ 2.46450992]
 [ 1.60815185]
 [ 1.15011059]
 [-4.38669085]
 [ 4.71130729]
 [-2.53225295]
 [-3.90145627]
 [ 5.91356638]]
Scores for Node 4 : [[[5.85335271e-23]]

 [[1.00000000e+00]]]
Update for Node 4 : [[

In [None]:
# Compute self-attention
