In [20]:
import torch
from torch_geometric.nn import GATConv
from torch_geometric.utils import to_undirected
from torch_geometric.utils.hetero import group_hetero_graph
from torch_geometric.datasets import OGB_MAG
#from torch_geometric.utils.hetero import to_hetero
from torch_geometric.utils import degree

In [3]:
def compute_saliency_map(model, data, label):
    model.eval()
    out = model(data.x_dict, data.edge_index_dict)
    loss = torch.nn.functional.mse_loss(out[label], data.y_dict[label])
    loss.backward(retain_graph=True)
    saliency_map = torch.abs(data.x_dict[label].grad)  # Compute absolute gradient of input features
    return saliency_map

In [22]:
# Example usage:
import torch_geometric.datasets as datasets
from torch_geometric.data import DataLoader
from torch_geometric.nn import GATConv

# Load AMiner dataset
dataset2 = OGB_MAG(root='./data')

Using existing file mag.zip
Extracting data/mag/raw/mag.zip


BadZipFile: File is not a zip file

In [5]:
# Define GAT model
class GAT(torch.nn.Module):
    def __init__(self, in_channels, out_channels, num_heads):
        super(GAT, self).__init__()
        self.conv1 = GATConv(in_channels, out_channels, heads=num_heads)
        self.conv2 = GATConv(out_channels*num_heads, dataset.num_classes, heads=1)
        
    def forward(self, x_dict, edge_index_dict):
        x_dict = self.conv1(x_dict, edge_index_dict)
        x_dict = torch.nn.functional.elu(x_dict)
        x_dict = self.conv2(x_dict, edge_index_dict)
        return x_dict

In [19]:
# Convert dataset to heterogenous graph
hetero_data = dataset[0]
hetero_data

HeteroData(
  [1mauthor[0m={
    y=[246678],
    y_index=[246678],
    num_nodes=1693531
  },
  [1mvenue[0m={
    y=[134],
    y_index=[134],
    num_nodes=3883
  },
  [1mpaper[0m={ num_nodes=3194405 },
  [1m(paper, written_by, author)[0m={ edge_index=[2, 9323605] },
  [1m(author, writes, paper)[0m={ edge_index=[2, 9323605] },
  [1m(paper, published_in, venue)[0m={ edge_index=[2, 3194405] },
  [1m(venue, publishes, paper)[0m={ edge_index=[2, 3194405] }
)

In [None]:

# Define GAT model and optimizer
model = GAT(hetero_data.num_node_features, 16, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Train the model for a few epochs
model.train()
for epoch in range(10):
    optimizer.zero_grad()
    out = model(grouped_data.x_dict, grouped_data.edge_index_dict)
    loss = torch.nn.functional.mse_loss(out['author'], grouped_data.y_dict['author'])
    loss.backward()
    optimizer.step()

# Compute saliency map for author label
label = 'author'
saliency_map = compute_saliency_map(model, grouped_data, label)

# Compute degree-normalized saliency scores
degree = degree(grouped_data.edge_index_dict[label][0])
degree_saliency = saliency_map / degree.unsqueeze(1)

# Print the top 5 most salient nodes
_, indices = degree_saliency.sort(descending=True)
top_nodes = indices[:5]
print("Top 5 most salient nodes:", top_nodes)