# DGL Walkthrough

## Message Passing

In [2]:
import dgl
import dgl.function as fn
import torch
import torch.nn.functional as F

# Build a 5-node directed graph 0 -> 1 -> 2 -> 3 -> 4 with a random 2-d feature 'x' per node.
g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 4]))
g.ndata['x'] = torch.randn(5, 2)

# Edges can be addressed as a (source ids, destination ids) pair, i.e. u --> v.
# Only edges that exist in the graph can be used; here those are 0 --> 1 and 1 --> 2.
g.send_and_recv(
    ([0, 1], [1, 2]),
    message_func=fn.copy_u('x', 'm'),
    reduce_func=fn.sum('m', 'h'),
)
g.ndata['h']

Using backend: pytorch


tensor([[ 0.0000,  0.0000],
        [ 0.1392,  1.3757],
        [ 0.0146, -1.5229],
        [ 0.0000,  0.0000],
        [ 0.0000,  0.0000]])

In [31]:
# Edges can also be addressed by their integer edge ids; here only edge 0 is used.
g.send_and_recv(
    [0],
    message_func=fn.copy_u('x', 'm'),
    reduce_func=fn.sum('m', 'h'),
)
g.ndata['h']

tensor([[ 0.0000,  0.0000],
        [ 0.2271,  0.0784],
        [-0.3758,  0.2759],
        [ 0.0000,  0.0000],
        [ 0.0000,  0.0000]])

In [26]:
# **local scope**
# Inside `g.local_scope()` we can assign new features to the graph. As the printed
# output below shows, a feature assigned inside the scope is no longer present on the
# graph once the scope exits.
# NOTE(review): DGL's documentation states that in-place tensor operations are still
# visible outside the scope -- confirm before relying on "any mutation" being isolated.
g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 4]))
with g.local_scope():
    g.ndata['x'] = torch.randn(5, 2)
    print(g.ndata)  # inside the scope: contains 'x'
print(g.ndata)  # after the scope: 'x' is gone

{'x': tensor([[ 0.2903, -0.1730],
        [ 0.1695, -0.0738],
        [ 0.6497, -1.1092],
        [ 0.6976, -0.7749],
        [-0.6042,  0.7806]])}
{}


In [27]:
"""
Message passing in DGL resembles functional programming. We can make creating
user-defined functions easier by implementing a compose function.
"""
import numpy as np
import torch as th

def compose(*funcs):
    """Combine several user-defined functions into a single one.

    Every function in ``funcs`` is called, in order, with the same ``edges``
    argument. The dicts they return are merged into one dict; when two
    functions return the same key, the later function wins. A function that
    returns ``None`` contributes nothing to the result.

    Args:
        *funcs: Callables taking one argument and returning a dict or ``None``.

    Returns:
        A function ``composed(edges)`` that returns the merged dict (empty
        when no function produced any output).
    """
    def composed(edges):
        merged = {}
        for _func in funcs:
            result = _func(edges)
            # Previously every result was dropped and ``composed`` returned
            # None, so the composed function could never yield any fields.
            if result is not None:
                merged.update(result)
        return merged
    return composed

# Fresh graph with edges 0 -> 1, 1 -> 2, 2 -> 3, 3 -> 4 and three random 2-d node features.
# NOTE(review): 'k', 'q' and 'v' presumably stand for key / query / value -- confirm.
g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 4]))
g.ndata['k'] = torch.randn(5, 2)
g.ndata['q'] = torch.randn(5, 2)
g.ndata['v'] = torch.randn(5, 2)


def src_dot_dst(src_field, dst_field, out_field):
    """Build an edge function computing a source/destination dot product.

    The returned function reads ``edges.src[src_field]`` and
    ``edges.dst[dst_field]``, multiplies them element-wise and sums over the
    last dimension (keeping that dimension with size 1).

    Args:
        src_field: Key of the feature looked up in ``edges.src``.
        dst_field: Key of the feature looked up in ``edges.dst``.
        out_field: Key under which the result is returned.

    Returns:
        A function ``func(edges)`` returning ``{out_field: dot_product}``.
    """
    def func(edges):
        source_feat = edges.src[src_field]
        dest_feat = edges.dst[dst_field]
        elementwise = source_feat * dest_feat
        dot = elementwise.sum(-1, keepdim=True)
        return {out_field: dot}

    return func

def scaled_exp(field, scale_constant):
    """Build an edge function that exponentiates a scaled edge feature.

    The returned function divides ``edges.data[field]`` by ``scale_constant``,
    clamps the result to the range [-5, 5] and applies ``exp``.

    Args:
        field: Key of the edge feature to read; the result is returned under
            the same key.
        scale_constant: Divisor applied before clamping.

    Returns:
        A function ``func(edges)`` returning ``{field: exp(clamped)}``.
    """
    def func(edges):
        scaled = edges.data[field] / scale_constant
        # Bound the exponent; the original notes this is for softmax
        # numerical stability.
        bounded = scaled.clamp(-5, 5)
        return {field: th.exp(bounded)}

    return func


# Per-edge score: dot product of the source node's 'k' with the destination node's 'q'.
g.apply_edges(src_dot_dst('k', 'q', 'score'))
# Scale by sqrt(2) (the node features assigned above are 2-dimensional), clamp, exponentiate.
g.apply_edges(scaled_exp('score', np.sqrt(2)))
# 'wv': for every node, the sum over incoming edges of (source 'v' * edge 'score').
g.update_all(fn.u_mul_e('v', 'score', 'v'), fn.sum('v', 'wv'))
# 'z': for every node, the sum of the incoming edge scores.
# `fn.copy_e` replaces the deprecated alias `fn.copy_edge` and matches the
# `fn.copy_u` spelling already used in this notebook.
g.update_all(fn.copy_e('score', 'score'), fn.sum('score', 'z'))



# # class Encoder(nn.Module):
    
# #     def __init__(self, layer: nn.Module, N: int):
# #         super().__init__()
# #         self.N = N
# #         self.layers = clones(layer, N)
# #         self.norm = nn.LayerNorm(layer.size)
        
# #     def 

# class MultiHeadAttention(nn.Module):
    
#     def __init__(self):
#         super().__init__()
    
# class Encoder(nn.Module):
    
#     def __init__(self):
#         super().__init__()
        
# class Decoder(nn.Module):
    
#     def __init__(self):
#         super().__init__()
        
        

{'k': tensor([[-0.3580,  0.3352],
        [-0.6294, -1.2133],
        [ 2.5045,  0.4580],
        [-0.1749,  0.9500],
        [ 0.5554, -1.2297]]), 'q': tensor([[ 0.6454,  1.0083],
        [-0.1463, -0.9820],
        [-0.1117,  1.7270],
        [ 1.0037,  2.5795],
        [ 1.1025,  0.6385]]), 'v': tensor([[ 0.1606, -0.5321],
        [ 0.9059, -1.0296],
        [ 0.3193, -1.5199],
        [ 1.0204,  0.5048],
        [ 1.4193, -0.1029]]), 'wv': tensor([[  0.0000,   0.0000],
        [  0.1320,  -0.4376],
        [  0.2164,  -0.2459],
        [  4.3543, -20.7302],
        [  1.3672,   0.6764]]), 'z': tensor([[ 0.0000],
        [ 0.8223],
        [ 0.2388],
        [13.6388],
        [ 1.3399]])}
