In [15]:
import pandas as pd 
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv  # You can also use GATConv, GraphSAGE, etc.
from torch_geometric.data import Data, DataLoader



In [26]:
descriptors = pd.read_csv('example/descriptors.csv')

descriptors.head()

node_labels = descriptors.edge_source.unique()
node_labels_mapping = {label: i for i, label in enumerate(node_labels)}

descriptors['edge_source'] = descriptors['edge_source'].map(node_labels_mapping)
descriptors['edge_dest'] = descriptors['edge_dest'].map(node_labels_mapping)

to_drop = [
    "function_id",
    "graph_id",
    "edge_source",
    "edge_dest",
    "is_causal",
]

features = descriptors.columns.difference(to_drop)

features = ['coeff_cause', 'kurtosis_ef', 'kurtosis_ca', 'HOC_1_3', 'coeff_eff',
       'HOC_3_1', 'eff_m_cau_q0', 'eff_m_cau_q1', 'eff_m_cau_q5',
       'skewness_ef', 'eff_m_cau_q6', 'm_eff_q5', 'eff_m_cau_q3',
       'skewness_ca', 'eff_m_cau_q2', 'eff_m_cau_q4', 'HOC_2_1', 'm_eff_q6',
       'HOC_1_2', 'm_eff_q0']

  descriptors = pd.read_csv('example/descriptors.csv')


In [27]:
graph_data_list = []
for graph_id in descriptors.graph_id.unique(): 
    single_graph = descriptors.loc[descriptors.graph_id == graph_id]

    stacked_edges = np.vstack([single_graph.edge_source.values,
         single_graph.edge_dest.values])
    
    data = Data(
        x=torch.ones((len(single_graph.edge_source.unique()), 1)),
        edge_index=torch.tensor(stacked_edges, dtype=torch.long),
        edge_attr=torch.tensor(single_graph[features].values, dtype=torch.float),
        y=torch.tensor(single_graph.is_causal.values, dtype=torch.float).unsqueeze(1)
    )
    graph_data_list.append(data)

In [28]:
num_nodes = len(node_labels)
node_features = 1
edge_features = len(features)
num_edges = single_graph.shape[0]

In [29]:
from torch_geometric.data import Dataset

class CustomGraphDataset(Dataset):
    def __init__(self, data_list):
        super(CustomGraphDataset, self).__init__()
        self.data_list = data_list

    def len(self):
        return len(self.data_list)

    def get(self, idx):
        return self.data_list[idx]

# Create the dataset from your list of Data objects
dataset = CustomGraphDataset(graph_data_list)


In [34]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool

class EdgeClassifier(torch.nn.Module):
    def __init__(self, node_feature_dim, edge_feature_dim, hidden_dim):
        super(EdgeClassifier, self).__init__()
        self.conv1 = GCNConv(node_feature_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.edge_mlp = torch.nn.Sequential(
            torch.nn.Linear(2 * hidden_dim + edge_feature_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, 1)  # Assuming binary classification
        )

    def forward(self, x, edge_index, batch, edge_attr):
        # Node feature learning
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)

        # Edge feature concatenation (node features of source and target + original edge features)
        row, col = edge_index
        edge_feat = torch.cat([x[row], x[col], edge_attr], dim=1)

        # Use batch information if necessary, for example to aggregate graph-level features
        x =  global_mean_pool(x, batch)  # Uncomment if node-level to graph-level is needed

        # Edge classification
        return torch.sigmoid(self.edge_mlp(edge_feat))


In [35]:
from torch_geometric.loader import DataLoader
from torch.utils.data import random_split

batch_size = 16

# Split the dataset
train_size = int(0.7 * len(dataset))
val_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

# Data loaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


In [36]:
# Assume some dimensions for the example
hidden_dim = 16

model = EdgeClassifier(node_features, edge_features, hidden_dim)


In [37]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.BCELoss()

def train_epoch(train_loader):
    model.train()
    total_loss = 0
    for data in train_loader:
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.batch, data.edge_attr)
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * data.num_graphs
    return total_loss / len(train_loader.dataset)

def validate(val_loader):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for data in val_loader:
            out = model(data.x, data.edge_index, data.batch, data.edge_attr)
            loss = criterion(out, data.y)
            total_loss += loss.item() * data.num_graphs
    return total_loss / len(val_loader.dataset)

def test():
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for data in test_loader:
            out = model(data.x, data.edge_index, data.batch, data.edge_attr)
            loss = criterion(out, data.y)
            total_loss += loss.item() * data.num_graphs
    return total_loss / len(test_loader.dataset)


num_epochs = 200
batch_size = 16
for epoch in range(num_epochs):
    # Reshuffle and split the dataset each epoch
    train_size = int(0.85 * len(dataset))
    train_dataset, val_dataset = random_split(dataset, [train_size, len(dataset) - train_size])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    train_loss = train_epoch(train_loader)
    val_loss = validate(val_loader)

    print(f'Epoch {epoch+1}: Train Loss: {train_loss}, Validation Loss: {val_loss}')

Epoch 1: Train Loss: 1.9089626974218032, Validation Loss: 1.6971874175752912
Epoch 2: Train Loss: 1.8008706800476844, Validation Loss: 1.740078119096302
Epoch 3: Train Loss: 1.9552711681959007, Validation Loss: 1.5376230993725004
Epoch 4: Train Loss: 1.679688756125314, Validation Loss: 1.5689415339061192
Epoch 5: Train Loss: 1.676900782705355, Validation Loss: 1.5695895108722504
Epoch 6: Train Loss: 1.7014341125167718, Validation Loss: 1.7155283298946562
Epoch 7: Train Loss: 1.7577015348642815, Validation Loss: 1.6734077803293863
Epoch 8: Train Loss: 1.6856637951105582, Validation Loss: 1.512958423183078
Epoch 9: Train Loss: 1.739669939890629, Validation Loss: 1.5912588914235433
Epoch 10: Train Loss: 1.6420836867805289, Validation Loss: 1.6254246125902447
Epoch 11: Train Loss: 1.7560079332159346, Validation Loss: 1.707205636614845
Epoch 12: Train Loss: 1.8575716170543382, Validation Loss: 1.6877021076565697
Epoch 13: Train Loss: 1.6730247097255804, Validation Loss: 1.5403233687082927
E

KeyboardInterrupt: 