In [2]:
"hello"

'hello'

In [None]:
import torch
import torch.nn as nn 
import torch.nn.functional as F
from torch_geometric.nn import GATConv

https://github.com/xjtuwgt/GNN-MAGNA

```GATConv(in_channels, out_channels, heads=1, concat=True, negative_slope=0.2, dropout=0.0, add_self_loops=True, edge_dim=None, fill_value='mean', bias=True, residual=False)``` 

- The graph attentional operator from the Graph Attention Networks paper 
- in_channels - size of each input sample, or -1 to derive the size from the first input(s) to the forward method 
- out_channels - size of each output sample 
- heads - number of multi-head attentions 
- concat - if set to false, the multi head attentions are averaged instead of concatenated 
- negative_slope - LeakyReLU angle of the negative slope 
- dropout - dropout probability of the normalized attention coefficients, which exposes each node to a stochastically sampled neighborhood during training 
- add_self_loops - if set to false, self-loops are not added to the input graph 
- edge_dim - edge feature dimensionality in case there are any 
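- fill_value - how to generate edge features for the added self-loops (only relevant when edge_dim is set): a float or tensor is used directly, while a string ("add", "mean", "min", "max", "mul") aggregates the features of the edges pointing to the node; default is "mean" 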



```nn.ModuleList```
```nn.LayerNorm```
```nn.Linear```
```nn.Embedding```
```nn.Parameter```
```torch.bmm```


# 2.2 Multi Hop Attention Diffusion

Attention diffusion computes the multi-hop attention directly; it operates on MAGNA's attention scores at each layer. 

Input ($v_i, r_k, v_j$) 
- $v_i, v_j$ are nodes 
- $r_k$ is edge type 

1. MAGNA computes the attention scores on all edges 
2. The attention diffusion module then computes the attention values between pairs of nodes not directly connected by an edge via a diffusion process 

## Step 1. Edge Attention Computation 

At each layer 


In [None]:
class MAGNA(nn.Module):
    """Stack of GAT layers, each followed by approximate attention diffusion.

    Every layer applies, in order: a multi-head ``GATConv``, an attention
    diffusion step over a dense node-to-node attention matrix, a residual
    connection with layer normalisation, and a single-linear feed-forward
    step with its own residual connection.

    Args:
        in_channels: Feature size of the input passed to the first layer.
        out_channels: Per-head output size of every ``GATConv``; since the
            heads are concatenated, each layer emits
            ``out_channels * num_heads`` features.
        num_heads: Number of attention heads per ``GATConv``.
        alpha: Weight of the un-diffused features in each diffusion step
            (the propagated term is weighted by ``1 - alpha``).
        num_layers: Number of stacked layers.
        num_hops: Number of recursive diffusion updates per layer.
    """

    def __init__(self, in_channels, out_channels, num_heads=8, alpha=0.1, num_layers=3, num_hops=6):
        super().__init__()
        self.num_heads = num_heads
        self.alpha = alpha
        self.num_layers = num_layers
        self.num_hops = num_hops

        # Heads are concatenated, so every layer after the first consumes
        # out_channels * num_heads features.
        hidden_channels = out_channels * num_heads

        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(GATConv(in_channels, out_channels, heads=num_heads, concat=True, dropout=0.6))
            in_channels = hidden_channels

        self.norms = nn.ModuleList([nn.LayerNorm(hidden_channels) for _ in range(num_layers)])
        self.feed_forwards = nn.ModuleList([nn.Linear(hidden_channels, hidden_channels) for _ in range(num_layers)])

    # IMPORTANT: this recursion is the core of the model.
    def attention_diffusion(self, A, H):
        """Approximate attention diffusion with ``num_hops`` recursive updates.

        Starting from ``Z = H``, repeatedly applies
        ``Z <- (1 - alpha) * A @ Z + alpha * H``.

        Args:
            A: Attention matrix that is matrix-multiplied with the features.
            H: Features to diffuse.

        Returns:
            The diffused features, with the same shape as ``H``.
        """
        # No in-place operation touches Z, so H does not need to be cloned.
        Z = H
        for _ in range(self.num_hops):
            Z = (1 - self.alpha) * torch.matmul(A, Z) + self.alpha * H
        return Z

    def forward(self, x, edge_index):
        """Run every layer and return the final node features.

        Args:
            x: Node feature matrix (2-D; one row per node).
            edge_index: Graph connectivity, forwarded to each ``GATConv``.

        Returns:
            Node features after the last layer.
        """
        for layer, norm, feed_forward in zip(self.layers, self.norms, self.feed_forwards):
            h = layer(x, edge_index)
            # NOTE(review): this attention matrix is dense over all node pairs
            # and does not use edge_index; only the GATConv above sees the
            # graph structure. Confirm this is the intended approximation.
            A = F.softmax(torch.matmul(h, h.T), dim=-1)
            z = self.attention_diffusion(A, h)
            # Residual connection: add the pre-diffusion features back. The
            # GATConv output is used as the skip path because the raw layer
            # input may have a different width on the first layer.
            x = norm(z + h)
            x = F.relu(feed_forward(x)) + x  # Feed forward with residual connection

        return x
