In [1]:
from graph_encode import base_graph_edges,move_to_index,index_to_move,adjacency_list, node_feature_matrix,edge_feature_matrix 

created 64 nodes
created 1792 edge
current state (FEN): r1bqkbnr/pppp1ppp/2n5/1B2p3/4P3/5N2/PPPP1PPP/RNBQK2R b KQkq - 3 3

encode result:
node matrix shape: (64, 21)
edge matrix shape: (1792, 11)


In [None]:
import numpy as np
import torch

Wu = torch.randn(21, 3)
We = torch.randn(11, 3)
Wv = torch.randn(21, 3)
W0 = torch.randn(21, 21)
Wh = torch.randn(21, 21)
Wg = torch.randn(11, 21)
a = torch.randn(3)

node_feature_matrix = torch.tensor(node_feature_matrix)
edge_feature_matrix = torch.tensor(edge_feature_matrix)

import torch
import torch.nn as nn
import torch.nn.functional as F
leaky_relu = nn.LeakyReLU(negative_slope=0.2)


hew_i = None
for i, neighborhood in enumerate(adjacency_list):
    gilist = []
    value_list = []
    for j, (node_index, edge_index) in enumerate(neighborhood):
        gijnew = (
            node_feature_matrix[i] @ Wu +
            edge_feature_matrix[edge_index] @ We +
            node_feature_matrix[node_index] @ Wv
        )

        gilist.append(leaky_relu(a @ gijnew))
        value_list.append(node_feature_matrix[node_index] @ Wh + edge_feature_matrix[edge_index] @ Wg)
    gilist = torch.tensor(gilist)
    value_list = torch.stack(value_list)
    hew_i = node_feature_matrix[i] @ W0 + F.softmax(gilist, dim=0) @ value_list

    print(hew_i)

    break


  node_feature_matrix = torch.tensor(node_feature_matrix)
  edge_feature_matrix = torch.tensor(edge_feature_matrix)


In [30]:

source_nodes = []
target_nodes = []
edge_feature_indices = []

for i, neighborhood in enumerate(adjacency_list):
    for neighbor_node, edge_index_val in neighborhood:
        target_nodes.append(i)
        source_nodes.append(neighbor_node)
        edge_feature_indices.append(edge_index_val)

edge_index = torch.tensor([target_nodes, source_nodes], dtype=torch.long)
edge_map = torch.tensor(edge_feature_indices, dtype=torch.long)
num_nodes = len(node_feature_matrix)



In [28]:


class GATEAULayer(nn.Module):

    def __init__(self, node_in_features, edge_in_features, node_out_features, attention_dim):

        super(GATEAULayer, self).__init__()
        self.node_in_features = node_in_features
        self.edge_in_features = edge_in_features
        self.node_out_features = node_out_features
        self.attention_dim = attention_dim

  
        self.Wv = nn.Parameter(torch.randn(node_in_features, attention_dim))
        self.Wu = nn.Parameter(torch.randn(node_in_features, attention_dim))
        self.We = nn.Parameter(torch.randn(edge_in_features, attention_dim))
        
        self.Wh = nn.Parameter(torch.randn(node_in_features, node_out_features))
        self.Wg = nn.Parameter(torch.randn(edge_in_features, node_out_features))
        self.W0 = nn.Parameter(torch.randn(node_in_features, node_out_features))

        self.a = nn.Parameter(torch.randn(attention_dim))

        self.leaky_relu = nn.LeakyReLU(negative_slope=0.2)
        
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.Wv)
        nn.init.xavier_uniform_(self.Wu)
        nn.init.xavier_uniform_(self.We)
        nn.init.xavier_uniform_(self.Wh)
        nn.init.xavier_uniform_(self.Wg)
        nn.init.xavier_uniform_(self.W0)
        nn.init.zeros_(self.a)


    def forward(self, node_feature_matrix, edge_feature_matrix, edge_index, edge_map):

        num_nodes = node_feature_matrix.shape[0]
        target_node_idx, source_node_idx = edge_index[0], edge_index[1]

        h_nodes_v = node_feature_matrix @ self.Wv
        h_nodes_u = node_feature_matrix @ self.Wu
        h_nodes_0 = node_feature_matrix @ self.W0
        h_nodes_h = node_feature_matrix @ self.Wh
        h_edges_e = edge_feature_matrix @ self.We
        h_edges_g = edge_feature_matrix @ self.Wg
        

        target_node_feats_for_attention = h_nodes_u[target_node_idx]
        source_node_feats_for_attention = h_nodes_v[source_node_idx]
        edge_feats_for_attention = h_edges_e[edge_map]

        attention_scores_raw = target_node_feats_for_attention + source_node_feats_for_attention + edge_feats_for_attention
        attention_scores = self.leaky_relu(attention_scores_raw @ self.a) 
        max_scores = torch.full((num_nodes,), -1e9, device=attention_scores.device, dtype=attention_scores.dtype)
        max_scores.scatter_reduce_(0, target_node_idx, attention_scores, reduce="amax", include_self=False)
        
        scores_max_per_edge = max_scores[target_node_idx]
        attention_scores_exp = torch.exp(attention_scores - scores_max_per_edge)

        sum_exp_scores = torch.zeros(num_nodes, device=attention_scores.device, dtype=attention_scores.dtype)
        sum_exp_scores.index_add_(0, target_node_idx, attention_scores_exp)
        
        sum_exp_per_edge = sum_exp_scores[target_node_idx]

        alpha = attention_scores_exp / (sum_exp_per_edge + 1e-10)


        source_node_values = h_nodes_h[source_node_idx]
        edge_values = h_edges_g[edge_map]
        values = source_node_values + edge_values
        
        weighted_values = values * alpha.unsqueeze(-1)

        aggregated_messages = torch.zeros_like(h_nodes_0)
        aggregated_messages.index_add_(0, target_node_idx, weighted_values)


        new_final = h_nodes_0 + aggregated_messages

        return new_final, attention_scores_raw

In [None]:
class BNR(nn.Module):
    def __init__(self, num_features):
        super(BNR, self).__init__()
        self.norm = nn.BatchNorm1d(num_features)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.norm(x)
        x = self.relu(x)
        return x

In [None]:
class ResGATEAU(nn.Module):

    def __init__(self, node_in_features, edge_in_features, node_out_features, attention_dim):
        super(ResGATEAU, self).__init__()

        self.bnr1 = BNR(node_in_features)
        self.gateau1 = GATEAULayer(node_in_features, edge_in_features, node_out_features, attention_dim)
        

        self.bnr2 = BNR(node_out_features)
        self.gateau2 = GATEAULayer(node_out_features, edge_in_features, node_out_features, attention_dim)
        
        if node_in_features != node_out_features:
            self.residual_transform = nn.Linear(node_in_features, node_out_features)
        else:
            self.residual_transform = nn.Identity()

    def forward(self, node_feature_matrix, edge_feature_matrix, edge_index, edge_map):
        residual = self.residual_transform(node_feature_matrix)


        x = self.bnr1(node_feature_matrix)
        

        x, e = self.gateau1(x, edge_feature_matrix, edge_index, edge_map)
        
        x = self.bnr2(x)

        x, e = self.gateau2(x, e, edge_index, edge_map)


        output_node_features = residual + x
        
        return output_node_features, e

In [35]:
gat_layer = GATEAULayer(
    node_in_features=21,
    edge_in_features=11,
    node_out_features=21,
    attention_dim=3
)


with torch.no_grad():
    gat_layer.Wu.copy_(Wu)
    gat_layer.We.copy_(We)
    gat_layer.Wv.copy_(Wv)
    gat_layer.W0.copy_(W0)
    gat_layer.Wh.copy_(Wh)
    gat_layer.Wg.copy_(Wg)
    gat_layer.a.copy_(a)

hew_module_output,_ = gat_layer(
    node_feature_matrix=node_feature_matrix,
    edge_feature_matrix=edge_feature_matrix,
    edge_index=edge_index,
    edge_map=edge_map
)



are_close = torch.allclose(hew_i, hew_module_output[0])

print(f"original output shape: {hew_i.shape}")
print(f"nn.Module output shape: {hew_module_output[0].shape}")
print(f"is output same? -> {are_close}")



original output shape: torch.Size([21])
nn.Module output shape: torch.Size([21])
is output same? -> True
