In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from torch_geometric.loader import DataLoader
from torch_geometric.nn import SAGEConv
from torch_geometric.nn import global_mean_pool
from tqdm import tqdm

from graph_reinforcement_learning_using_blockchain_data import config

# Load project environment variables. NOTE(review): presumably wraps python-dotenv's
# load_dotenv (the displayed `True` would then mean variables were loaded) — confirm in config.
config.load_dotenv()

[32m2025-03-06 15:18:38.735[0m | [1mINFO    [0m | [36mgraph_reinforcement_learning_using_blockchain_data.config[0m:[36m<module>[0m:[36m11[0m - [1mPROJ_ROOT path is: /Users/liamtessendorf/Programming/Uni/2_Master/4_FS25_Programming/graph-reinforcement-learning-using-blockchain-data[0m


True

In [2]:
# weights_only=False performs full unpickling (presumably needed to restore the
# stored graph objects); this can execute arbitrary code, so only load trusted,
# locally produced files.
graphs_path = config.FLASHBOTS_Q2_DATA_DIR / "trx_graphs.pt"
dataset = torch.load(graphs_path, weights_only=False)

# Fixed random_state keeps the 80/20 train/test split reproducible across runs.
train_dataset, test_dataset = train_test_split(
    dataset,
    test_size=0.2,
    random_state=42,
)

In [3]:
# `account_mapping` is excluded from batch collation (presumably a per-graph
# lookup that does not batch into tensors — TODO confirm). Training batches are
# shuffled; evaluation order is kept deterministic.
train_loader = DataLoader(
    train_dataset,
    batch_size=256,
    shuffle=True,
    exclude_keys=["account_mapping"],
    drop_last=False,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=128,
    shuffle=False,
    exclude_keys=["account_mapping"],
    drop_last=False,
)

In [4]:
# for step, data in enumerate(train_loader):
#     print(f'Step {step + 1}:')
#     print('=======')
#     print(f'Number of graphs in the current batch: {data.num_graphs}')
#     print(data)
#     print()

In [7]:
class GraphSAGEModel(nn.Module):
    """Graph-level classifier: four GraphSAGE layers, mean-pool readout, MLP head.

    Node features are propagated through four ``SAGEConv`` layers (ReLU after
    each), mean-pooled into one embedding per graph, and mapped to class logits
    by a two-layer MLP head with dropout.

    Args:
        num_node_features: Number of input features per node.
        hidden_channels: Width of every SAGEConv layer and of the pooled embedding.
        num_classes: Number of output logits.
        classifier_hidden: Width of the hidden layer in the MLP head
            (default 256, the value previously hard-coded).
        dropout: Dropout probability applied in the head while training
            (default 0.5, the value previously hard-coded).
    """

    def __init__(self, num_node_features, hidden_channels, num_classes,
                 classifier_hidden=256, dropout=0.5):
        super().__init__()
        # Attribute names are kept (conv1..conv4, fc1, fc2) so state_dict keys stay compatible.
        self.conv1 = SAGEConv(num_node_features, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels)
        self.conv3 = SAGEConv(hidden_channels, hidden_channels)
        self.conv4 = SAGEConv(hidden_channels, hidden_channels)
        self.fc1 = nn.Linear(hidden_channels, classifier_hidden)
        self.fc2 = nn.Linear(classifier_hidden, num_classes)
        self.dropout = dropout

    def forward(self, data, return_embeddings=False):
        """Compute class logits for every graph in a (batched) graph object.

        Args:
            data: Batch exposing ``x``, ``edge_index``, ``batch`` and ``num_graphs``.
            return_embeddings: If true, also return the pooled per-graph embeddings.

        Returns:
            ``(logits, embeddings)`` where ``logits`` has one row per graph and
            ``num_classes`` columns; ``embeddings`` is the pooled tensor fed to
            the head when ``return_embeddings`` is true, otherwise ``None``.
        """
        # 1. node embeddings
        x, edge_index = data.x, data.edge_index
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = F.relu(conv(x, edge_index))

        # 2. readout: one row per graph; size= keeps graphs with no nodes in the output
        embeddings = global_mean_pool(x, data.batch, size=data.num_graphs)

        # 3. classifier head. The ReLU between fc1 and fc2 is required: without it
        # fc2(fc1(.)) collapses to a single affine map and the hidden layer adds nothing.
        x = F.relu(self.fc1(embeddings))
        x = F.dropout(x, p=self.dropout, training=self.training)
        logits = self.fc2(x)
        return logits, (embeddings if return_embeddings else None)

In [51]:
def train(model, loader, optimizer, criterion, device):
    """Run one optimisation epoch over ``loader`` and return the mean per-graph loss.

    Args:
        model: Module whose forward returns either logits or a
            ``(logits, embeddings_or_None)`` tuple (as ``GraphSAGEModel`` does).
        loader: Iterable of batches exposing ``.to(device)``, ``.y`` and ``.num_graphs``.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss function applied to ``(logits, data.y)``; assumed to
            return a per-batch mean, which is re-weighted by ``num_graphs``.
        device: Device each batch is moved to.

    Returns:
        The loss averaged over every graph processed, or ``nan`` if ``loader``
        yields no batches.
    """
    model.train()
    total_loss = 0.0
    total_graphs = 0
    for data in loader:
        # Disabled: would append per-graph global features to every node's features.
        # data.x = torch.cat([data.x, data.global_features[data.batch].unsqueeze(1)], dim=-1)
        data = data.to(device)
        # Clear gradients before the forward pass so no stale gradients leak into the first step.
        optimizer.zero_grad()
        output = model(data)
        # GraphSAGEModel.forward returns (logits, embeddings-or-None); the loss needs only the logits.
        out = output[0] if isinstance(output, tuple) else output
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * data.num_graphs
        total_graphs += data.num_graphs

    # Divide by the graphs actually seen (correct even if the loader drops a last partial batch).
    return total_loss / total_graphs if total_graphs else float("nan")


def test(model, loader, criterion, device, return_embeddings=False):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    
    embeddings = {}

    with torch.no_grad():
        for data in loader:
            # data.x = torch.cat([data.x, data.global_features[data.batch].unsqueeze(1)], dim=-1)
            data = data.to(device)
            out, emb = model(data)
            mapping = {trx_id: emb for trx_id, emb in zip(data.trx_id, emb)}
            embeddings.update(mapping)     
            loss = criterion(out, data.y)
            total_loss += loss.item() * data.num_graphs
            pred = out.argmax(dim=1)
            correct += (pred == data.y).sum().item()
            total += data.num_graphs
    return total_loss / len(loader.dataset), correct / total, embeddings

In [52]:
# Model hyperparameters: one scalar feature per node, no graph-level features
# appended to the node features (dim_global_features = 0).
num_node_features = 1
dim_global_features = 0
hidden_channels = 512  # width of every SAGEConv layer; adjust as needed
num_classes = 2  # binary classification

input_dim = num_node_features + dim_global_features
model = GraphSAGEModel(input_dim, hidden_channels, num_classes)
print(model)

GraphSAGEModel(
  (conv1): SAGEConv(1, 512, aggr=mean)
  (conv2): SAGEConv(512, 512, aggr=mean)
  (conv3): SAGEConv(512, 512, aggr=mean)
  (conv4): SAGEConv(512, 512, aggr=mean)
  (fc1): Linear(in_features=512, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=2, bias=True)
)


In [None]:
num_epochs = 1
# Seed torch's RNGs so batch shuffling and dropout are repeatable across runs.
torch.manual_seed(42)

# Use Apple's Metal (MPS) backend when available; fall back to CPU elsewhere.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
# Move the model before constructing the optimizer, as the torch.optim docs recommend.
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()

for epoch in tqdm(range(1, num_epochs + 1)):
    print(f"Epoch {epoch} starts")
    train_loss = train(model, train_loader, optimizer, criterion, device)
    # `embeddings` keeps the test-set embeddings from the most recent epoch.
    test_loss, acc, embeddings = test(model, test_loader, criterion, device, return_embeddings=True)
    print(
        f"Epoch: {epoch:03d}, Train Loss: {train_loss:.4f}, "
        f"Test Loss: {test_loss:.4f}, Test Acc: {acc:.4f}"
    )

  0%|          | 0/1 [00:00<?, ?it/s]

Epoch 1 starts


In [44]:
# NOTE(review): the count shown below is stale. This cell ran as In [44], before
# the train/test helpers (In [51]) and the model (In [52]) were last defined, and
# the training cell never completed. Re-run after training for a current count.
len(embeddings)

147845