In [12]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data


In [13]:
class GNN(torch.nn.Module):
    """Two-layer GCN that returns raw per-node class logits.

    In this notebook the training loop turns these node outputs into edge
    predictions by reading the logits of each edge's source node
    (``out[edge_index[0]]``).

    Args:
        hidden_channels: Width of the hidden GCN layer.
        in_channels: Number of input features per node (default 2, matching
            the two-column feature tensors used in this notebook).
        out_channels: Number of output classes (default 2: causal,
            non-causal).
        dropout: Dropout probability applied between the two GCN layers;
            only active while the module is in training mode.
    """

    def __init__(self, hidden_channels, in_channels=2, out_channels=2, dropout=0.5):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)
        self.dropout_p = dropout

    def forward(self, x, edge_index):
        """Return unnormalized logits, one row per node, ``out_channels`` columns."""
        x = self.conv1(x, edge_index)
        x = x.relu()
        # `training=self.training` makes dropout a no-op after model.eval().
        x = F.dropout(x, p=self.dropout_p, training=self.training)
        return self.conv2(x, edge_index)


In [14]:
# Example data: one feature row per edge
# (difference of mean, difference of std).
edge_features = torch.tensor([
    [0.2, 0.1], # edge 0 -> 1
    [-0.3, 0.2], # edge 1 -> 2
    [0.1, -0.2], # edge 2 -> 0
], dtype=torch.float)

# Edge index (row 0 = source, row 1 = target)
edge_index = torch.tensor([
    [0, 1, 2],
    [1, 2, 0],
], dtype=torch.long)

# Labels for the edges (1 = causal, 0 = non-causal)
edge_labels = torch.tensor([1, 0, 1], dtype=torch.long)

# The GCN operates on nodes, so the edge features are fed in as node features
# `x` and edge e is later scored through the output of its source node
# (out[edge_index[0]]). That is only meaningful when there is one feature row
# and one label per edge, and edge e starts at node e.
assert edge_features.size(0) == edge_index.size(1) == edge_labels.size(0), \
    "need exactly one feature row and one label per edge"
assert torch.equal(edge_index[0], torch.arange(edge_index.size(1))), \
    "edge e must originate at node e for edge features to be used as node features"

# Define Graph
data = Data(x=edge_features, edge_index=edge_index)


In [15]:
# Seed before building the model so weight initialisation, dropout masks and
# therefore the logged losses are reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Define model and optimizer
model = GNN(hidden_channels=16)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Each edge is scored by the logits of its source node; this index does not
# change between epochs, so compute it once.
edge_src = data.edge_index[0]

# Training loop
model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.cross_entropy(out[edge_src], edge_labels)
    loss.backward()
    optimizer.step()
    if epoch % 10 == 0:
        print(f'Epoch {epoch}, Loss {loss.item()}')


Epoch 0, Loss 0.6931743025779724
Epoch 10, Loss 0.6460168361663818
Epoch 20, Loss 0.6297574043273926
Epoch 30, Loss 0.5344840884208679
Epoch 40, Loss 0.509902834892273
Epoch 50, Loss 0.4772721827030182
Epoch 60, Loss 0.5411622524261475
Epoch 70, Loss 0.5012961030006409
Epoch 80, Loss 0.578216016292572
Epoch 90, Loss 0.5210465788841248
Epoch 100, Loss 0.3169881999492645
Epoch 110, Loss 0.4279502332210541
Epoch 120, Loss 0.5111482739448547
Epoch 130, Loss 0.345454603433609
Epoch 140, Loss 0.2688114047050476
Epoch 150, Loss 0.5245887637138367
Epoch 160, Loss 0.2210673838853836
Epoch 170, Loss 0.30617836117744446
Epoch 180, Loss 0.3048185408115387
Epoch 190, Loss 0.4811214208602905


In [16]:
# New edge features for evaluation (one row per edge, same column layout as
# the training features)
eval_edge_features = torch.tensor([
    [0.3, 0.1],  # edge 0 -> 1
    [-0.4, 0.2], # edge 1 -> 2
    [0.15, -0.2],# edge 2 -> 0
], dtype=torch.float)

# Edge index for evaluation. The graph may differ from training, but edge e
# must still start at node e, because each edge is scored through its source
# node's output.
eval_edge_index = torch.tensor([
    [0, 1, 2],
    [1, 2, 0],
], dtype=torch.long)

assert eval_edge_features.size(0) == eval_edge_index.size(1), \
    "need exactly one feature row per evaluation edge"
assert torch.equal(eval_edge_index[0], torch.arange(eval_edge_index.size(1))), \
    "edge e must originate at node e for edge features to be used as node features"

# Define Graph for Evaluation
eval_data = Data(x=eval_edge_features, edge_index=eval_edge_index)


In [17]:
# Set the model to evaluation mode (turns dropout off)
model.eval()

# The model returns per-node logits. Select each edge's source node, exactly
# as the training loss does, to obtain one row of logits per edge.
with torch.no_grad():
    node_logits = model(eval_data.x, eval_data.edge_index)
    predictions = node_logits[eval_data.edge_index[0]]

# Convert the per-edge logits to class probabilities
probs = F.softmax(predictions, dim=1)

# Predicted class per edge (1 = causal, 0 = non-causal)
predicted_labels = probs.argmax(dim=1)

predicted_labels


tensor([1, 0, 1])
