In [None]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import dropout_edge, softmax

class AttentionResidualGNNLayer(MessagePassing):
    """Multi-head scaled dot-product attention layer with a residual path.

    Each node's features are projected to per-head queries, keys and values.
    For every edge ``j -> i`` the score ``<Q_i, K_j> / sqrt(head_dim)`` is
    computed per head, normalised over all incoming edges of the target
    node ``i``, and used to weight ``V_j``. Weighted values are summed per
    target node, added to a (projected) residual of the input, and passed
    through LayerNorm, LeakyReLU and dropout.

    Args:
        in_ch: Size of the last dimension of the input node features.
        out_ch: Size of the output node features; must be divisible by
            ``heads``.
        dropout: Dropout probability applied to the layer output.
        heads: Number of attention heads.
        dropedge_rate: Probability of dropping an edge (training only).

    Raises:
        ValueError: If ``heads`` is not positive or ``out_ch`` is not a
            multiple of ``heads``.
    """

    def __init__(self, in_ch, out_ch, dropout=0.3, heads=4, dropedge_rate=0.1):
        # Without this check, ``view(-1, heads, head_dim)`` in ``forward``
        # would either fail late or silently reshape to the wrong number of
        # rows whenever the element count happens to divide evenly.
        if heads <= 0 or out_ch % heads != 0:
            raise ValueError(
                f"out_ch ({out_ch}) must be a positive multiple of heads ({heads})"
            )
        super().__init__(aggr='add', node_dim=0)
        self.heads = heads
        self.head_dim = out_ch // heads
        self.linear_q = nn.Linear(in_ch, out_ch)
        self.linear_k = nn.Linear(in_ch, out_ch)
        self.linear_v = nn.Linear(in_ch, out_ch)
        # Project the skip connection only when the widths differ.
        self.residual = nn.Linear(in_ch, out_ch) if in_ch != out_ch else nn.Identity()
        self.norm = nn.LayerNorm(out_ch)
        self.act = nn.LeakyReLU(0.2)
        self.drop = nn.Dropout(dropout)
        self.dropedge_rate = dropedge_rate

    def forward(self, x, edge_index):
        """Run one attention + residual step.

        Args:
            x: Node feature tensor whose last dimension is ``in_ch``.
            edge_index: Graph connectivity passed to ``dropout_edge`` and
                ``propagate`` (presumably the usual ``[2, num_edges]`` COO
                layout -- confirm against the caller).

        Returns:
            Tensor of updated node features with last dimension ``out_ch``.
        """
        residual = self.residual(x)
        # DropEdge regularisation; a no-op when the module is in eval mode.
        # The second return value (edge mask / ids) is not needed here.
        edge_index, _ = dropout_edge(
            edge_index,
            p=self.dropedge_rate,
            force_undirected=True,
            training=self.training,
        )
        # Split each projection into [num_nodes, heads, head_dim].
        Q = self.linear_q(x).view(-1, self.heads, self.head_dim)
        K = self.linear_k(x).view(-1, self.heads, self.head_dim)
        V = self.linear_v(x).view(-1, self.heads, self.head_dim)
        # Keyword names must match the ``Q_i`` / ``K_j`` / ``V_j`` arguments
        # of ``message`` so that they are gathered along the edges.
        x_attn = self.propagate(edge_index, Q=Q, K=K, V=V)
        # Concatenate the heads back into a single feature vector per node.
        x_attn = x_attn.view(-1, self.heads * self.head_dim)
        x = x_attn + residual
        x = self.norm(x)
        x = self.act(x)
        x = self.drop(x)
        return x

    def message(self, Q_i, K_j, V_j, index, ptr, size_i):
        """Compute attention-weighted values for every edge.

        Args:
            Q_i: Queries of the target nodes, ``[num_edges, heads, head_dim]``.
            K_j: Keys of the source nodes, same shape as ``Q_i``.
            V_j: Values of the source nodes, same shape as ``Q_i``.
            index: Target-node index of each edge (supplied by ``propagate``).
            ptr: Optional CSR pointer for sorted edges (supplied by
                ``propagate``; may be ``None``).
            size_i: Number of target nodes (supplied by ``propagate``).

        Returns:
            Per-edge messages of shape ``[num_edges, heads, head_dim]``.
        """
        # Scaled dot-product score per edge and head: [num_edges, heads].
        alpha = (Q_i * K_j).sum(dim=-1) / math.sqrt(self.head_dim)
        # Normalise over the incoming edges of each target node (segment
        # softmax), independently per head. A plain softmax over dim=1 would
        # instead normalise across heads and leave neighbour weights
        # unnormalised.
        alpha = softmax(alpha, index=index, ptr=ptr, num_nodes=size_i)
        return alpha.unsqueeze(-1) * V_j

class EnhancedGNNWithAttention(nn.Module):
    """Node classifier: input projection, three attention layers, MLP head.

    The feature width shrinks from ``hidden_dim`` to ``hidden_dim // 4``
    across the attention stack, then to ``hidden_dim // 8`` inside the
    output head, which finally maps to ``output_dim`` scores per node.

    Args:
        input_dim: Size of the last dimension of ``data.x``.
        hidden_dim: Width of the projected features entering the stack.
        output_dim: Number of output scores per node.
        dropout: Dropout probability used throughout the model.
    """

    def __init__(self, input_dim, hidden_dim=256, output_dim=2, dropout=0.3):
        super().__init__()
        half = hidden_dim // 2
        quarter = hidden_dim // 4
        eighth = hidden_dim // 8

        # Lift raw node features to the hidden width.
        self.input_proj = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Dropout(dropout),
        )

        # One (in_ch, out_ch, heads) triple per attention layer, in order.
        layer_specs = (
            (hidden_dim, hidden_dim, 4),
            (hidden_dim, half, 4),
            (half, quarter, 2),
        )
        self.gnn_layers = nn.ModuleList([
            AttentionResidualGNNLayer(
                src_width, dst_width, dropout, heads=num_heads, dropedge_rate=0.1
            )
            for src_width, dst_width, num_heads in layer_specs
        ])

        # Two-layer MLP head producing the per-node scores.
        self.output_layer = nn.Sequential(
            nn.Linear(quarter, eighth),
            nn.LeakyReLU(0.2),
            nn.Dropout(dropout),
            nn.Linear(eighth, output_dim),
        )

    def forward(self, data):
        """Return log-softmax scores (over dim 1) for the nodes in ``data``.

        Args:
            data: Object exposing ``x`` (node features) and ``edge_index``
                (graph connectivity) attributes.
        """
        features, edge_index = data.x, data.edge_index
        hidden = self.input_proj(features)
        for gnn_layer in self.gnn_layers:
            hidden = gnn_layer(hidden, edge_index)
        logits = self.output_layer(hidden)
        return F.log_softmax(logits, dim=1)
