<a href="https://colab.research.google.com/github/omkar-salunke/The-Elements-of-Statistical-Learning-Python-Notebooks/blob/master/Graph_data_explore.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv
from torch_geometric.data import Data

# 1. Define the Graph Data
# Four nodes connected in a chain 0 - 1 - 2 - 3. Every undirected edge is
# listed as two directed pairs, (u, v) and (v, u); transposing the pair list
# gives the 2 x num_edges layout.
edge_index = torch.tensor(
    [[0, 1], [1, 0], [1, 2], [2, 1], [2, 3], [3, 2]],
    dtype=torch.long,
).t().contiguous()

# Initial node features: one row per node, 3 features each
x = torch.tensor(
    [
        [1.0, 0.0, 0.0],  # Node 0
        [0.0, 1.0, 0.0],  # Node 1
        [0.0, 0.0, 1.0],  # Node 2
        [0.5, 0.5, 0.0],  # Node 3
    ],
    dtype=torch.float32,
)

# Bundle the node features and the connectivity into a single PyG Data object
data = Data(x=x, edge_index=edge_index)

# Show the graph summary first, then the raw tensors it holds
print("Graph Data:", data, sep="\n")
print(
    "\nInitial Node Features (data.x):\n",
    data.x,
)
print(
    "\nEdge Indices (data.edge_index):\n",
    data.edge_index,
)

# 2. Define a Simple GAT Model
class SimpleGAT(torch.nn.Module):
    """One graph-attention layer (``GATConv``) followed by a ReLU.

    Args:
        in_channels: Forwarded to ``GATConv`` as its input size.
        out_channels: Forwarded to ``GATConv`` as its output size.
        heads: Number of attention heads forwarded to ``GATConv``.
    """

    def __init__(self, in_channels, out_channels, heads=1):
        super().__init__()
        self.conv1 = GATConv(in_channels, out_channels, heads=heads)

    def forward(self, data):
        """Apply the attention layer to ``data.x`` / ``data.edge_index``, then ReLU.

        Args:
            data: Object exposing ``x`` and ``edge_index`` attributes.

        Returns:
            The ReLU-activated output of ``conv1``.
        """
        return F.relu(self.conv1(data.x, data.edge_index))

# 3. Instantiate the Model
# Input feature dimension = 3 (read from the graph via data.num_node_features)
# Output feature dimension per attention head = 2 (we can choose this)
# Number of attention heads = 2 (we can also choose this)
# NOTE: GATConv concatenates the head outputs by default, so each node
# embedding ends up with num_heads * output_dim = 4 values, not 2.
input_dim = data.num_node_features
output_dim = 2
num_heads = 2
model = SimpleGAT(in_channels=input_dim, out_channels=output_dim, heads=num_heads)
print("\nGAT Model:\n", model)

# 4. Perform a Forward Pass
# Put the model in evaluation mode before inference so that train-only
# behaviour (such as dropout) is disabled.
model.eval()
with torch.no_grad():  # We don't need gradients for a simple forward pass
    output = model(data)

print("\nOutput Node Embeddings after one GAT layer:\n", output)
print("\nShape of Output Embeddings:", output.shape)