In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv
from torch_geometric.data import Data
from sklearn.metrics import accuracy_score

# Define a simple Graph Neural Network (GNN) model for edge direction prediction
class EdgeDirectionPredictionModel(nn.Module):
    def __init__(self, in_channels, hidden_channels):
        super(EdgeDirectionPredictionModel, self).__init__()
        
        # GraphSAGE layers
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels)
        
        # Output layer for binary classification
        self.fc = nn.Linear(hidden_channels * 2, 1)  # Adjust output size for edge direction prediction

    def forward(self, x, edge_index):
        # Apply GraphSAGE layers
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        
        # Pairwise combination of node embeddings for edge prediction
        edge_pairs = torch.cat((x[edge_index[0]], x[edge_index[1]]), dim=1)
        
        # Fully connected layer for binary classification (edge direction)
        x = self.fc(edge_pairs)
        
        return x

# Create a synthetic directed graph
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]], dtype=torch.long)  # Directed edges
x = torch.randn(4, 16)  # Node features

# Create labels for edge directions (1 for A to B, 0 for B to A)
labels = torch.tensor([1, 1, 0, 0], dtype=torch.float)

# Create a PyTorch Geometric Data object
data = Data(x=x, edge_index=edge_index, y=labels)

# Define the GNN model
model = EdgeDirectionPredictionModel(in_channels=16, hidden_channels=32)

# Define a binary classification loss function (BCEWithLogitsLoss)
criterion = nn.BCEWithLogitsLoss()

# Define an optimizer (e.g., Adam)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Training loop
for epoch in range(100):
    optimizer.zero_grad()
    
    # Predict edge directions
    logits = model(data.x, data.edge_index)
    print(logits)
    # Calculate the binary classification loss
    loss = criterion(logits.squeeze(), data.y)
    print(data.y)
    # Backpropagation
    loss.backward()
    optimizer.step()
    
    if epoch % 10 == 0:
        print(f'Epoch {epoch}, Loss: {loss.item()}')

# Evaluate the model (e.g., on a validation set)
# You can use metrics like accuracy, precision, recall, etc.

# Predict edge directions for the same graph (you can replace this with your own test data)
test_logits = model(data.x, data.edge_index)
test_pred_labels = (test_logits.squeeze() > 0).float()

# Calculate accuracy on the test data
test_accuracy = accuracy_score(data.y.cpu().numpy(), test_pred_labels.cpu().numpy())
print(f'Test Accuracy: {test_accuracy * 100:.2f}%')


tensor([[-0.0917],
        [-0.6163],
        [-0.0307],
        [-0.1698]], grad_fn=<AddmmBackward0>)
tensor([1., 1., 0., 0.])
Epoch 0, Loss: 0.769457995891571
tensor([[ 0.3394],
        [ 0.2948],
        [-0.6116],
        [-0.3233]], grad_fn=<AddmmBackward0>)
tensor([1., 1., 0., 0.])
tensor([[ 0.8175],
        [ 1.1313],
        [-1.1675],
        [-0.5969]], grad_fn=<AddmmBackward0>)
tensor([1., 1., 0., 0.])
tensor([[ 1.4031],
        [ 2.0328],
        [-1.8774],
        [-1.1213]], grad_fn=<AddmmBackward0>)
tensor([1., 1., 0., 0.])
tensor([[ 2.1752],
        [ 3.0725],
        [-2.7895],
        [-1.9682]], grad_fn=<AddmmBackward0>)
tensor([1., 1., 0., 0.])
tensor([[ 3.1572],
        [ 4.2619],
        [-3.9060],
        [-3.1141]], grad_fn=<AddmmBackward0>)
tensor([1., 1., 0., 0.])
tensor([[ 4.3012],
        [ 5.5596],
        [-5.1542],
        [-4.4263]], grad_fn=<AddmmBackward0>)
tensor([1., 1., 0., 0.])
tensor([[ 5.5489],
        [ 6.9038],
        [-6.4795],
        [-5.84

In [2]:
x

tensor([[-0.3780, -0.0895, -0.7423, -3.1009, -0.0369,  0.9758, -0.5607,  0.5412,
         -0.0240,  1.4392,  0.4660, -1.3272, -1.5561, -1.9406,  0.8731,  0.5163],
        [-0.8288,  1.5572, -0.1635,  0.4426,  0.6115, -0.7658,  0.3615, -0.0908,
          1.0235, -1.7073,  0.2634, -0.1714,  1.0694,  1.2531,  0.7995, -1.8892],
        [-0.3652,  1.7381,  0.5900,  0.9933, -0.2686, -1.1239, -0.5722, -0.2785,
         -0.9907, -0.1474, -0.1222, -0.5758, -1.2241, -0.4331,  1.2500,  0.6371],
        [-0.6630,  1.8745,  1.6405,  0.3374,  0.4381, -0.8106,  0.1468, -1.0950,
         -0.0608, -0.2938, -0.0285, -0.8962, -0.8278, -1.0560,  0.6829,  0.9871]])

In [4]:
edge_index[0]

tensor([0, 1, 2, 3])

In [10]:
x[edge_index[0]]

tensor([[-0.3780, -0.0895, -0.7423, -3.1009, -0.0369,  0.9758, -0.5607,  0.5412,
         -0.0240,  1.4392,  0.4660, -1.3272, -1.5561, -1.9406,  0.8731,  0.5163],
        [-0.8288,  1.5572, -0.1635,  0.4426,  0.6115, -0.7658,  0.3615, -0.0908,
          1.0235, -1.7073,  0.2634, -0.1714,  1.0694,  1.2531,  0.7995, -1.8892],
        [-0.3652,  1.7381,  0.5900,  0.9933, -0.2686, -1.1239, -0.5722, -0.2785,
         -0.9907, -0.1474, -0.1222, -0.5758, -1.2241, -0.4331,  1.2500,  0.6371],
        [-0.6630,  1.8745,  1.6405,  0.3374,  0.4381, -0.8106,  0.1468, -1.0950,
         -0.0608, -0.2938, -0.0285, -0.8962, -0.8278, -1.0560,  0.6829,  0.9871]])

In [11]:
edge_index[1]

tensor([1, 2, 3, 0])

In [9]:
x[edge_index[1]]

tensor([[-0.8288,  1.5572, -0.1635,  0.4426,  0.6115, -0.7658,  0.3615, -0.0908,
          1.0235, -1.7073,  0.2634, -0.1714,  1.0694,  1.2531,  0.7995, -1.8892],
        [-0.3652,  1.7381,  0.5900,  0.9933, -0.2686, -1.1239, -0.5722, -0.2785,
         -0.9907, -0.1474, -0.1222, -0.5758, -1.2241, -0.4331,  1.2500,  0.6371],
        [-0.6630,  1.8745,  1.6405,  0.3374,  0.4381, -0.8106,  0.1468, -1.0950,
         -0.0608, -0.2938, -0.0285, -0.8962, -0.8278, -1.0560,  0.6829,  0.9871],
        [-0.3780, -0.0895, -0.7423, -3.1009, -0.0369,  0.9758, -0.5607,  0.5412,
         -0.0240,  1.4392,  0.4660, -1.3272, -1.5561, -1.9406,  0.8731,  0.5163]])

In [25]:
single_edge_index = torch.tensor([[0], [1]], dtype=torch.long)  # Edge from 0 to 1
#single_edge_index = torch.tensor([[0, 1],[1, 0]], dtype=torch.long)  # Edge from 0 to 1
single_edge_logits = model(data.x, single_edge_index)

test_logits = model(data.x, single_edge_index)
test_pred_labels = (test_logits.squeeze() > 0).float()

print(edge_index)
print(single_edge_index)
print(test_logits)
# Calculate accuracy on the test data
#test_accuracy = accuracy_score(data.y.cpu().numpy(), test_pred_labels)
#print(f'Test Accuracy: {test_accuracy * 100:.2f}%')

tensor([[0, 1, 2, 3],
        [1, 2, 3, 0]])
tensor([[0],
        [1]])
tensor([[15.6966]], grad_fn=<AddmmBackward0>)
