In [None]:
# models/GAT.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, global_mean_pool

class GATClassifier(nn.Module):
    """Graph-level classifier: three stacked GATConv layers, mean pooling, linear head.

    One graph is processed per ``forward`` call. Node features pass through two
    multi-head GAT layers (heads concatenated, ELU activations) and a final
    single-head GAT layer. The result is averaged over all nodes and mapped to
    ``n_class`` logits.

    Args:
        n_class: Number of output classes.
        n_features: Width of the input node features.
        hidden_dim: Per-head hidden size of each GAT layer.
        heads: Number of attention heads in the first two GAT layers.
        dropout: Passed to each ``GATConv`` as its ``dropout`` argument. No
            separate feature dropout is applied in this module.
    """

    def __init__(self, n_class, n_features=768, hidden_dim=64, heads=4, dropout=0.1):
        super().__init__()

        # The first two layers concatenate their heads, so the next layer's
        # input width is hidden_dim * heads.
        self.gat1 = GATConv(n_features, hidden_dim, heads=heads, dropout=dropout)
        self.gat2 = GATConv(hidden_dim * heads, hidden_dim, heads=heads, dropout=dropout)
        # The last layer uses a single head without concatenation -> width hidden_dim.
        self.gat3 = GATConv(hidden_dim * heads, hidden_dim, heads=1, concat=False, dropout=dropout)

        self.classifier = nn.Linear(hidden_dim, n_class)
        self.criterion = nn.CrossEntropyLoss()

    @staticmethod
    def _squeeze_batch_dim(t, name):
        """Drop a leading batch dimension of size 1 from a 3-D tensor.

        Tensors that are not 3-D are returned unchanged. In particular, a 2-D
        tensor whose first dimension happens to be 1 (e.g. a single-node graph)
        is preserved, which ``Tensor.squeeze(0)`` would not do.

        Raises:
            ValueError: If ``t`` is 3-D with a batch size other than 1.
        """
        if t.dim() == 3:
            if t.size(0) != 1:
                raise ValueError(
                    f"{name} has batch size {t.size(0)}; "
                    "GATClassifier processes one graph per call"
                )
            return t.squeeze(0)
        return t

    @staticmethod
    def _dense_adj_to_edge_index(adj):
        """Convert a dense (N, N) or (1, N, N) adjacency matrix to a (2, E) edge_index.

        Every non-zero entry ``adj[i, j]`` becomes an edge ``i -> j``. Edge
        weights are discarded.

        Raises:
            ValueError: If ``adj`` is not 2-D after removing a size-1 batch dim.
        """
        adj = GATClassifier._squeeze_batch_dim(adj, "adj")
        if adj.dim() != 2:
            raise ValueError(f"adj must be (N, N) or (1, N, N); got shape {tuple(adj.shape)}")
        # nonzero gives (E, 2) row/col pairs; transpose to PyG's (2, E) layout.
        return adj.nonzero(as_tuple=False).t()

    @staticmethod
    def _apply_conv(conv, x, edge_index, want_attention):
        """Run one GAT layer and return ``(output, attention)``.

        ``attention`` is the ``(edge_index, alpha)`` pair from ``GATConv`` when
        requested, else None. When not requested, the flag is left at its
        default (None) rather than passed as False, because ``GATConv`` may
        still return a tuple for any bool value.
        """
        if want_attention:
            return conv(x, edge_index, return_attention_weights=True)
        return conv(x, edge_index), None

    def forward(self, node_feat, labels, adj, mask, return_attention=False):
        """Classify one graph and compute its cross-entropy loss.

        Args:
            node_feat: Node features, (N, n_features) or (1, N, n_features).
            labels: Target passed directly to ``nn.CrossEntropyLoss`` against
                logits of shape (1, n_class). Presumably a (1,) long tensor;
                confirm with the caller.
            adj: Dense adjacency matrix, (N, N) or (1, N, N). Non-zero entries
                define directed edges.
            mask: Accepted for interface compatibility but NOT used. All N
                nodes contribute to the mean pool.
            return_attention: If True, also return each layer's attention weights.

        Returns:
            ``(pred, labels, loss)``:
            - ``pred`` is the argmax class index with shape (1,).
            - ``labels`` is the input, unchanged.
            - ``loss`` is the cross-entropy loss.

            With ``return_attention=True``, a fourth element
            ``(attn1, attn2, attn3)`` is appended. Each is the
            ``(edge_index, alpha)`` pair returned by the corresponding ``GATConv``.

        Raises:
            ValueError: If ``node_feat`` or ``adj`` has a batch dimension other
                than 1, or ``adj`` is not a square-matrix-shaped tensor.
        """
        edge_index = self._dense_adj_to_edge_index(adj)
        x = self._squeeze_batch_dim(node_feat, "node_feat")

        want_attention = bool(return_attention)

        x, attn1 = self._apply_conv(self.gat1, x, edge_index, want_attention)
        x = F.elu(x)

        x, attn2 = self._apply_conv(self.gat2, x, edge_index, want_attention)
        x = F.elu(x)

        x, attn3 = self._apply_conv(self.gat3, x, edge_index, want_attention)

        # Global mean pooling over all nodes -> (1, hidden_dim) graph embedding.
        out = self.classifier(x.mean(dim=0, keepdim=True))

        loss = self.criterion(out, labels)
        pred = out.argmax(dim=1)

        if return_attention:
            return pred, labels, loss, (attn1, attn2, attn3)

        return pred, labels, loss