In [1]:
import mlflow.pytorch
import pandas as pd
import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from sklearn.model_selection import train_test_split
from torch.nn import Linear
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GATv2Conv
from torch_geometric.nn import GINConv
from torch_geometric.nn import GINEConv
from torch_geometric.nn import global_mean_pool
from tqdm import tqdm

from graph_reinforcement_learning_using_blockchain_data import config
from graph_reinforcement_learning_using_blockchain_data.modeling import gnn

config.load_dotenv()

[32m2025-03-14 15:56:24.336[0m | [1mINFO    [0m | [36mgraph_reinforcement_learning_using_blockchain_data.config[0m:[36m<module>[0m:[36m11[0m - [1mPROJ_ROOT path is: /Users/liamtessendorf/Programming/Uni/2_Master/4_FS25_Programming/graph-reinforcement-learning-using-blockchain-data[0m


True

In [2]:
dataset = torch.load(config.FLASHBOTS_Q2_DATA_DIR / "trx_graphs.pt", weights_only=False)
train_dataset, test_dataset = train_test_split(dataset, test_size=0.2, random_state=42)

In [13]:
train_loader = DataLoader(
    train_dataset, batch_size=256, shuffle=True, exclude_keys=["account_mapping"], drop_last=False
)
test_loader = DataLoader(
    test_dataset, batch_size=128, shuffle=False, exclude_keys=["account_mapping"], drop_last=False
)
data_loader = DataLoader(
    dataset, batch_size=256, shuffle=False, exclude_keys=["account_mapping"], drop_last=False
)

In [9]:
for step, data in enumerate(test_loader):
    print(f"Step {step + 1}:")
    print("=======")
    print(f"Number of graphs in the current batch: {data.num_graphs}")
    print(data.edge_index.shape)
    print()

Step 1:
Number of graphs in the current batch: 1
torch.Size([0])

Step 2:
Number of graphs in the current batch: 1
torch.Size([2, 2])

Step 3:
Number of graphs in the current batch: 1
torch.Size([2, 4])

Step 4:
Number of graphs in the current batch: 1
torch.Size([2, 1])

Step 5:
Number of graphs in the current batch: 1
torch.Size([2, 3])

Step 6:
Number of graphs in the current batch: 1
torch.Size([2, 1])

Step 7:
Number of graphs in the current batch: 1
torch.Size([2, 4])

Step 8:
Number of graphs in the current batch: 1
torch.Size([2, 5])

Step 9:
Number of graphs in the current batch: 1
torch.Size([2, 3])

Step 10:
Number of graphs in the current batch: 1
torch.Size([2, 2])

Step 11:
Number of graphs in the current batch: 1
torch.Size([2, 1])

Step 12:
Number of graphs in the current batch: 1
torch.Size([2, 1])

Step 13:
Number of graphs in the current batch: 1
torch.Size([2, 1])

Step 14:
Number of graphs in the current batch: 1
torch.Size([0])

Step 15:
Number of graphs in the cu

## Training a Graph Neural Network (GNN)

Training a GNN for graph classification usually follows a simple recipe:

1. Embed each node by performing multiple rounds of message passing
2. Aggregate node embeddings into a unified graph embedding (**readout layer**)
3. Train a final classifier on the graph embedding

There exists multiple **readout layers** in literature, but the most common one is to simply take the average of node embeddings:

$$
\mathbf{x}_{\mathcal{G}} = \frac{1}{|\mathcal{V}|} \sum_{v \in \mathcal{V}} \mathcal{x}^{(L)}_v
$$

PyTorch Geometric provides this functionality via [`torch_geometric.nn.global_mean_pool`](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.glob.global_mean_pool), which takes in the node embeddings of all nodes in the mini-batch and the assignment vector `batch` to compute a graph embedding of size `[batch_size, hidden_channels]` for each graph in the batch.

The final architecture for applying GNNs to the task of graph classification then looks as follows and allows for complete end-to-end training:

In [4]:
# def train(model, loader, optimizer, criterion, device):
#     model.train()
#     total_loss = 0
#     for data in loader:
#         # data.x = torch.cat([data.x, data.global_features[data.batch].unsqueeze(1)], dim=-1)
#         data = data.to(device)
#         out = model(data)
#         loss = criterion(out, data.y)
#         loss.backward()
#         optimizer.step()
#         optimizer.zero_grad()
#         total_loss += loss.item() * data.num_graphs
#
#     return total_loss / len(loader.dataset)
#
#
# def test(model, loader, criterion, device, return_embeddings=False):
#     model.eval()
#     total_loss = 0
#     correct = 0
#     total = 0
#
#     embeddings = {}
#
#     with torch.no_grad():
#         for data in loader:
#             # data.x = torch.cat([data.x, data.global_features[data.batch].unsqueeze(1)], dim=-1)
#             data = data.to(device)
#             if return_embeddings:
#                 out, emb = model(data, return_embeddings)
#                 mapping = {trx_id: emb for trx_id, emb in zip(data.trx_id, emb)}
#                 embeddings.update(mapping)
#             else:
#                 out = model(data)
#             loss = criterion(out, data.y)
#             total_loss += loss.item() * data.num_graphs
#             pred = out.argmax(dim=1)
#             correct += (pred == data.y).sum().item()
#             total += data.num_graphs
#     return total_loss / len(loader.dataset), correct / total, embeddings

In [5]:
num_node_features = 1
dim_global_features = 0
hidden_channels = 256  # adjust as needed
num_classes = 2  # binary classification

In [6]:
model_GNNSAGE = gnn.GraphSAGE(
    num_node_features + dim_global_features, hidden_channels, num_classes
)
print(model_GNNSAGE)

GraphSAGE(
  (conv1): SAGEConv(1, 256, aggr=mean)
  (conv2): SAGEConv(256, 256, aggr=mean)
  (conv3): SAGEConv(256, 256, aggr=mean)
  (conv4): SAGEConv(256, 256, aggr=mean)
  (fc1): Linear(in_features=256, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=2, bias=True)
)


In [13]:
optimizer = torch.optim.Adam(model_GNNSAGE.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()
device = torch.device("mps")
model_GNNSAGE.to(device)

GraphSAGE(
  (conv1): SAGEConv(1, 256, aggr=mean)
  (conv2): SAGEConv(256, 256, aggr=mean)
  (conv3): SAGEConv(256, 256, aggr=mean)
  (conv4): SAGEConv(256, 256, aggr=mean)
  (fc1): Linear(in_features=256, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=2, bias=True)
)

In [12]:
gnn.run_experiment(
    "Graph SAGE", 20, model_GNNSAGE, train_loader, test_loader, optimizer, criterion, device
)

# for epoch in tqdm(range(1, 2)):
#     print(f"Epoch {epoch} starts")
#     loss = train(model_GNNSAGE, train_loader, optimizer, criterion, device)
#     loss, acc, embeddings = test(model_GNNSAGE, test_loader, criterion, device)
#     print(f'Epoch: {epoch:03d}, Train Loss: {loss:.4f}, Test Acc: {acc:.4f}')

  0%|          | 0/20 [00:00<?, ?it/s]

Epoch 1 starts


  5%|▌         | 1/20 [01:32<29:10, 92.15s/it]

Epoch: 001, Train Loss: 0.1579, Test Acc: 0.9612
Epoch 2 starts


 10%|█         | 2/20 [03:03<27:27, 91.52s/it]

Epoch: 002, Train Loss: 0.1565, Test Acc: 0.9607
Epoch 3 starts


 15%|█▌        | 3/20 [04:47<27:32, 97.22s/it]

Epoch: 003, Train Loss: 0.1563, Test Acc: 0.9611
Epoch 4 starts


 20%|██        | 4/20 [06:35<27:05, 101.62s/it]

Epoch: 004, Train Loss: 0.1547, Test Acc: 0.9608
Epoch 5 starts


 25%|██▌       | 5/20 [08:19<25:36, 102.41s/it]

Epoch: 005, Train Loss: 0.1547, Test Acc: 0.9611
Epoch 6 starts


 30%|███       | 6/20 [10:07<24:22, 104.49s/it]

Epoch: 006, Train Loss: 0.1547, Test Acc: 0.9614
Epoch 7 starts


 35%|███▌      | 7/20 [12:12<24:04, 111.11s/it]

Epoch: 007, Train Loss: 0.1538, Test Acc: 0.9613
Epoch 8 starts


 40%|████      | 8/20 [14:31<24:00, 120.01s/it]

Epoch: 008, Train Loss: 0.1534, Test Acc: 0.9605
Epoch 9 starts


 45%|████▌     | 9/20 [16:46<22:49, 124.50s/it]

Epoch: 009, Train Loss: 0.1532, Test Acc: 0.9609
Epoch 10 starts


 50%|█████     | 10/20 [18:54<20:56, 125.63s/it]

Epoch: 010, Train Loss: 0.1531, Test Acc: 0.9612
Epoch 11 starts


 55%|█████▌    | 11/20 [21:14<19:32, 130.24s/it]

Epoch: 011, Train Loss: 0.1525, Test Acc: 0.9612
Epoch 12 starts


 60%|██████    | 12/20 [23:32<17:38, 132.31s/it]

Epoch: 012, Train Loss: 0.1521, Test Acc: 0.9605
Epoch 13 starts


 65%|██████▌   | 13/20 [25:45<15:28, 132.59s/it]

Epoch: 013, Train Loss: 0.1523, Test Acc: 0.9613
Epoch 14 starts


 70%|███████   | 14/20 [27:54<13:09, 131.56s/it]

Epoch: 014, Train Loss: 0.1519, Test Acc: 0.9613
Epoch 15 starts


 75%|███████▌  | 15/20 [30:05<10:57, 131.51s/it]

Epoch: 015, Train Loss: 0.1516, Test Acc: 0.9617
Epoch 16 starts


 80%|████████  | 16/20 [32:18<08:47, 131.94s/it]

Epoch: 016, Train Loss: 0.1514, Test Acc: 0.9605
Epoch 17 starts


 85%|████████▌ | 17/20 [34:28<06:34, 131.37s/it]

Epoch: 017, Train Loss: 0.1514, Test Acc: 0.9613
Epoch 18 starts


 90%|█████████ | 18/20 [36:55<04:31, 135.98s/it]

Epoch: 018, Train Loss: 0.1507, Test Acc: 0.9610
Epoch 19 starts


 95%|█████████▌| 19/20 [38:55<02:11, 131.24s/it]

Epoch: 019, Train Loss: 0.1507, Test Acc: 0.9613
Epoch 20 starts


100%|██████████| 20/20 [40:48<00:00, 122.43s/it]

Epoch: 020, Train Loss: 0.1505, Test Acc: 0.9611





🏃 View run aged-quail-253 at: http://127.0.0.1:8080/#/experiments/145054897104438872/runs/3ad1aa83e33c41f2b6d28b69d9b5f2bc
🧪 View experiment at: http://127.0.0.1:8080/#/experiments/145054897104438872


In [5]:
model_uri = "runs:/3ad1aa83e33c41f2b6d28b69d9b5f2bc/model"
model_GNNSAGE = mlflow.pytorch.load_model(model_uri)

  from .autonotebook import tqdm as notebook_tqdm
Downloading artifacts: 100%|██████████| 6/6 [00:00<00:00, 209.74it/s] 


In [14]:
model_GNNSAGE.eval()
device = torch.device("mps")
model_GNNSAGE.to(device)

all_preds = []
all_labels = []

with torch.no_grad():
    for data in test_loader:
        data = data.to(device)
        out = model_GNNSAGE(data)
        preds = out.argmax(dim=1).cpu().numpy()
        labels = data.y.cpu().numpy()
        all_preds.extend(preds)
        all_labels.extend(labels)

#         for data in loader:
#             # data.x = torch.cat([data.x, data.global_features[data.batch].unsqueeze(1)], dim=-1)
#             data = data.to(device)
#             if return_embeddings:
#                 out, emb = model(data, return_embeddings)
#                 mapping = {trx_id: emb for trx_id, emb in zip(data.trx_id, emb)}
#                 embeddings.update(mapping)
#             else:
#                 out = model(data)
#             loss = criterion(out, data.y)
#             total_loss += loss.item() * data.num_graphs
#             pred = out.argmax(dim=1)
#             correct += (pred == data.y).sum().item()
#             total += data.num_graphs

accuracy = accuracy_score(all_labels, all_preds)
precision = precision_score(all_labels, all_preds, average="weighted")
recall = recall_score(all_labels, all_preds, average="weighted")
f1 = f1_score(all_labels, all_preds, average="weighted")

print("Accuracy:", accuracy)
print("Precision:", precision)
print("Recall:", recall)
print("F1 Score:", f1)

Accuracy: 0.9611338398597311
Precision: 0.9618294597648908
Recall: 0.9611338398597311
F1 Score: 0.9609082623446633


In [23]:
with torch.no_grad():
    for data in data_loader:
        data = data.to(device)

        if data.num_graphs != len(data.trx_id):
            print("Batch num_graphs:", data.num_graphs)
            print("Number of trx_ids:", len(data.trx_id))

In [15]:
loss, acc, embeddings = gnn.test(
    model_GNNSAGE, data_loader, criterion, device, return_embeddings=True
)

In [16]:
len(embeddings)

184132

In [17]:
emb = {trx_id: emb.cpu().detach().numpy() for trx_id, emb in embeddings.items()}

In [18]:
pd_embeddings = pd.DataFrame({"transactionHash": emb.keys(), "embeddings": emb.values()})

In [19]:
pd_embeddings.to_csv(config.FLASHBOTS_Q2_DATA_DIR / "embeddings_128.csv", index=False)

In [4]:
class GAT(torch.nn.Module):
    def __init__(self, input_features, hidden_channels, num_classes, edge_attr_dim):
        super(GAT, self).__init__()
        torch.manual_seed(42)
        self.conv1 = GATv2Conv(input_features, hidden_channels, edge_dim=edge_attr_dim)
        self.conv2 = GATv2Conv(hidden_channels, hidden_channels, edge_dim=edge_attr_dim)
        self.conv3 = GATv2Conv(hidden_channels, hidden_channels, edge_dim=edge_attr_dim)
        self.conv4 = GATv2Conv(hidden_channels, hidden_channels, edge_dim=edge_attr_dim)
        self.lin = Linear(hidden_channels, 256)
        self.lin2 = Linear(256, num_classes)
        self.batchnorm = nn.BatchNorm1d(256)

    def forward(self, data):
        # 1. Obtain node embeddings
        x = self.conv1(data.x, data.edge_index, data.edge_attr)
        x = x.relu()
        x = self.conv2(x, data.edge_index, data.edge_attr)
        x = x.relu()
        x = self.conv3(x, data.edge_index, data.edge_attr)

        # 2. Readout layer
        x = global_mean_pool(x, data.batch, size=data.num_graphs)

        # 3. Apply a final classifier
        x = self.lin(x)

        # x = self.batchnorm(x)

        x = x.relu()
        x = self.lin2(x)
        return x

In [9]:
edge_attr_dim = 2
model_GAT = GAT(num_node_features, hidden_channels, num_classes, edge_attr_dim)
optimizer = torch.optim.Adam(model_GAT.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()
device = torch.device("mps")
model_GAT.to(device)

for epoch in tqdm(range(1, 10)):
    print(f"Epoch {epoch} starts")
    loss = train(model_GAT, train_loader, optimizer, criterion, device)
    loss, acc, embeddings = test(model_GAT, test_loader, criterion, device)
    print(f"Epoch: {epoch:03d}, Train Loss: {loss:.4f}, Test Acc: {acc:.4f}")

  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 1 starts


 11%|█         | 1/9 [03:31<28:08, 211.09s/it]

Epoch: 001, Train Loss: 0.6717, Test Acc: 0.6042
Epoch 2 starts


 11%|█         | 1/9 [05:06<40:53, 306.65s/it]


KeyboardInterrupt: 

In [21]:
class GINE(torch.nn.Module):
    def __init__(self, input_features, hidden_channels, num_classes, edge_attr_dim):
        super(GINE, self).__init__()
        torch.manual_seed(42)
        mlp1 = nn.Sequential(
            nn.Linear(input_features, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, hidden_channels),
        )
        self.conv1 = GINEConv(mlp1, edge_dim=edge_attr_dim)

        mlp2 = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, hidden_channels),
        )
        self.conv2 = GINEConv(mlp2, edge_dim=edge_attr_dim)

        mlp3 = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, hidden_channels),
        )
        self.conv3 = GINEConv(mlp3, edge_dim=edge_attr_dim)

        self.lin = Linear(hidden_channels, 256)
        self.lin2 = Linear(256, num_classes)
        self.batchnorm = nn.BatchNorm1d(256)

    def forward(self, data):
        # 1. Obtain node embeddings
        # edge_attr = data.edge_attr.unsqueeze(-1)  # Now shape: [num_edges, 1]

        x = self.conv1(data.x, data.edge_index, data.edge_attr)
        x = x.relu()
        x = self.conv2(x, data.edge_index, data.edge_attr)
        x = x.relu()
        x = self.conv3(x, data.edge_index, data.edge_attr)

        # 2. Readout layer
        x = global_mean_pool(x, data.batch, size=data.num_graphs)

        # 3. Apply a final classifier
        x = self.lin(x)

        # x = self.batchnorm(x)

        x = x.relu()
        x = self.lin2(x)
        return x

In [23]:
edge_attr_dim = 2
model_GINE = GINE(num_node_features, hidden_channels, num_classes, edge_attr_dim)
optimizer = torch.optim.Adam(model_GINE.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()
device = torch.device("mps")
model_GINE.to(device)

for epoch in tqdm(range(1, 10)):
    print(f"Epoch {epoch} starts")
    loss = train(model_GINE, train_loader, optimizer, criterion, device)
    loss, acc, embeddings = test(model_GINE, test_loader, criterion, device)
    print(f"Epoch: {epoch:03d}, Train Loss: {loss:.4f}, Test Acc: {acc:.4f}")

  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 1 starts


 11%|█         | 1/9 [00:56<07:30, 56.37s/it]

Epoch: 001, Train Loss: 0.5691, Test Acc: 0.7184
Epoch 2 starts


 22%|██▏       | 2/9 [03:05<11:34, 99.24s/it]

Epoch: 002, Train Loss: 0.5684, Test Acc: 0.7184
Epoch 3 starts


 22%|██▏       | 2/9 [03:11<11:11, 95.88s/it]


KeyboardInterrupt: 

In [18]:
class GINC(torch.nn.Module):
    def __init__(self, input_features, hidden_channels, num_classes):
        super(GINC, self).__init__()
        torch.manual_seed(42)
        mlp1 = nn.Sequential(
            nn.Linear(input_features, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, hidden_channels),
        )
        self.conv1 = GINConv(mlp1)

        mlp2 = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, hidden_channels),
        )
        self.conv2 = GINConv(mlp2)

        mlp3 = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, hidden_channels),
        )
        self.conv3 = GINConv(mlp3)

        self.lin = Linear(hidden_channels, 256)
        self.lin2 = Linear(256, num_classes)
        self.batchnorm = nn.BatchNorm1d(256)

    def forward(self, data):
        # 1. Obtain node embeddings
        x = self.conv1(data.x, data.edge_index)
        x = x.relu()
        x = self.conv2(x, data.edge_index)
        x = x.relu()
        x = self.conv3(x, data.edge_index)

        # 2. Readout layer
        x = global_mean_pool(x, data.batch, size=data.num_graphs)

        # 3. Apply a final classifier
        x = self.lin(x)

        # x = self.batchnorm(x)

        x = x.relu()
        x = self.lin2(x)
        return x

In [63]:
model_GINC = GINC(num_node_features, hidden_channels, num_classes)
optimizer = torch.optim.Adam(model_GINC.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()
device = torch.device("mps")
model_GINC.to(device)

for epoch in tqdm(range(0, 10)):
    print(f"Epoch {epoch} starts")
    loss = train(model_GINC, train_loader, optimizer, criterion, device)
    loss, acc, embeddings = test(model_GINC, test_loader, criterion, device)
    print(f"Epoch: {epoch:03d}, Train Loss: {loss:.4f}, Test Acc: {acc:.4f}")

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch 0 starts


 10%|█         | 1/10 [00:16<02:31, 16.80s/it]

Epoch: 000, Train Loss: 0.2053, Test Acc: 0.9325
Epoch 1 starts


 20%|██        | 2/10 [00:32<02:07, 15.90s/it]

Epoch: 001, Train Loss: 0.2182, Test Acc: 0.9325
Epoch 2 starts


 20%|██        | 2/10 [00:38<02:33, 19.21s/it]


KeyboardInterrupt: 