In [18]:
# Third-party: MLflow (model registry), pandas, PyTorch / PyG, progress bars.
import mlflow.pytorch
import pandas as pd
import torch
from torch.utils.data import ConcatDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import global_mean_pool
from torch_geometric.nn.models import DeepGraphInfomax
# tqdm.auto picks the Jupyter widget when available and falls back to the text bar otherwise.
from tqdm.auto import tqdm

# Project-local package and its configuration (data paths, .env loading).
import graph_reinforcement_learning_using_blockchain_data as grl
from graph_reinforcement_learning_using_blockchain_data import config

config.load_dotenv()

True

## Load Data

In [19]:
# NOTE(security): weights_only=False unpickles arbitrary Python objects (presumably
# required because the file stores graph objects rather than plain tensors). Only load
# this file from a trusted source.
split_path = config.FLASHBOTS_Q2_DATA_DIR / "state_graphs_train_test_split.pt"
dataset = torch.load(split_path, weights_only=False)

# The cells below index these two splits directly; fail early with a clear message if absent.
missing_splits = {"train_graphs", "test_graphs"} - set(dataset)
assert not missing_splits, f"split file is missing keys: {sorted(missing_splits)}"

In [20]:
# Widest node-feature matrix (number of columns of `x`) over every graph in every split;
# 0 if there are no graphs at all.
max_feats_len = max(
    (graph.x.shape[1] for split_graphs in dataset.values() for graph in split_graphs),
    default=0,
)

In [21]:
# Bring every graph's node features to the common width via grl.pad_features.
# The graph objects are updated in place, so re-running this cell is harmless only if
# padding to the same width is a no-op.
for split_graphs in dataset.values():
    for graph in split_graphs:
        graph.x = grl.pad_features(graph.x, max_feats_len)

In [22]:
train_dataset, test_dataset = dataset["train_graphs"], dataset["test_graphs"]
# Train + test together; wrapped by `data_loader` below to embed every transaction.
combined_dataset = ConcatDataset([train_dataset, test_dataset])

In [24]:
def _make_loader(graphs, batch_size, shuffle):
    """Build a PyG DataLoader that drops `account_mapping` from batches and keeps the last partial batch."""
    return DataLoader(
        graphs,
        batch_size=batch_size,
        shuffle=shuffle,
        exclude_keys=["account_mapping"],
        drop_last=False,
    )


train_loader = _make_loader(train_dataset, batch_size=512, shuffle=True)
test_loader = _make_loader(test_dataset, batch_size=128, shuffle=False)
# Deterministic order over train + test, used for embedding extraction.
data_loader = _make_loader(combined_dataset, batch_size=128, shuffle=False)

## Pre-train the Deep Graph Infomax (DGI) encoder

In [None]:
def summary(z, *_, batch=None, **__):
    """Mean-pool node embeddings ``z`` into summary vector(s) for DeepGraphInfomax.

    Any extra positional or keyword arguments are ignored; only an explicit ``batch=``
    keyword is honoured. With ``batch=None``, global_mean_pool averages over all nodes
    and yields a single summary row.

    NOTE(review): if DGI is called with the data object positionally (and no ``batch=``
    keyword), the whole mini-batch shares one summary. Confirm this is the intent.
    """
    return global_mean_pool(z, batch)


def corruption(data):
    """Create a DGI negative sample by permuting node features.

    A clone of the Data/Batch object is returned whose ``x`` rows are reordered by a
    single random permutation over all nodes in the (mini-)batch. The input object
    itself is left untouched.
    """
    node_features = data.x
    order = torch.randperm(node_features.size(0), device=node_features.device)

    shuffled = data.clone()
    shuffled.x = node_features[order]
    return shuffled

In [None]:
num_node_features = max_feats_len
dim_global_features = 0  # not used in the visible cells; kept in case later code reads it
hidden_channels = 128
SEED = 42

# Seed torch's global RNGs so weight initialisation and loader shuffling are reproducible.
torch.manual_seed(SEED)

# Prefer Apple-silicon MPS (the original hard-coded choice) but fall back to CUDA or
# CPU so the notebook also runs on machines without MPS.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
enc = grl.SAGEEncoder(num_node_features, hidden_channels).to(device)
dgi = DeepGraphInfomax(
    hidden_channels=hidden_channels, encoder=enc, summary=summary, corruption=corruption
).to(device)

In [None]:
# Self-supervised DGI pre-training on the training split only; the returned value
# replaces `enc`.
enc = grl.pretrain_dgi(
    train_loader,
    dgi,
    device,
    epochs=20,
)

## Train GraphSAGE Classifier using frozen DGI embeddings

In [14]:
# Overwrite `enc` with the encoder of a previously logged MLflow run rather than the
# one produced in the training cell above.
model_uri = (
    "mlflow-artifacts:/330930495026013213/7cd9300ef7834a76980e2ec75347bbae/artifacts/model"
)
model = mlflow.pytorch.load_model(model_uri)
enc = model.encoder

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

In [15]:
# Freeze the encoder so the optimizer below only receives parameters that still
# require gradients.
# NOTE(review): requires_grad=False does not freeze BatchNorm running statistics or
# disable dropout if the encoder contains them and training switches to train mode.
# Confirm whether enc.eval() is needed inside grl.run_experiment.
for param in enc.parameters():
    param.requires_grad = False

# Prefer MPS (the original hard-coded choice) but fall back to CUDA or CPU elsewhere.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
model = grl.GraphSAGEClassifier(enc, hidden=128, num_classes=2).to(device)
criterion = torch.nn.CrossEntropyLoss()
trainable_params = [param for param in model.parameters() if param.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=1e-3)

In [None]:
# Train/evaluate the classifier (the optimizer only holds parameters with
# requires_grad=True) and also collect embeddings.
model, embeddings = grl.run_experiment(
    "DGI GraphSAGE", 20,
    model, train_loader, test_loader,
    optimizer, criterion, device,
    return_embeddings=True,
)

## Create Embeddings from pre-trained GNN for downstream tasks

In [25]:
# Choose the device first so the checkpoint can be mapped onto it while loading.
# MPS is preferred (the original hard-coded choice), with CUDA/CPU as fallbacks.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
model_uri = "mlflow-artifacts:/132032870842317128/7559d28e50674e629ce8042ea64902de/artifacts/model"
# mlflow.pytorch.load_model forwards extra keyword arguments to torch.load, so
# map_location places the tensors on `device` even if the run was saved elsewhere.
model = mlflow.pytorch.load_model(model_uri, map_location=device)

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

In [26]:
# grl.test is run over train + test; only its third return value (the embeddings) is kept.
criterion = torch.nn.CrossEntropyLoss()
_, _, embeddings = grl.test(
    model,
    data_loader,
    criterion,
    device,
    return_embeddings=True,
)

In [27]:
# Convert each embedding tensor into a plain Python list so it can be stored in a DataFrame.
emb = {
    trx_id: vector.cpu().detach().numpy().tolist()
    for trx_id, vector in embeddings.items()
}

In [28]:
# One row per transaction hash, with its embedding list in the "embeddings" column.
df_embeddings = pd.DataFrame(
    {
        "transactionHash": list(emb.keys()),
        "embeddings": list(emb.values()),
    }
)

In [29]:
pre_trained_csv_path = config.FLASHBOTS_Q2_DATA_DIR / "state_embeddings_pre_trained_128.csv"
df_embeddings.to_csv(pre_trained_csv_path, index=False)

## Create Embeddings from DGI encoder for downstream tasks

In [30]:
# Same MLflow run as the one whose encoder is reused for the frozen classifier.
model_uri = (
    "mlflow-artifacts:/330930495026013213/7cd9300ef7834a76980e2ec75347bbae/artifacts/model"
)
model = mlflow.pytorch.load_model(model_uri)

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

In [31]:
model.eval()
unsup_list, y_list = [], []
# Prefer MPS (the original hard-coded choice) but fall back to CUDA or CPU elsewhere.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
# Batches are moved to `device` below, so the model must live there too.
model = model.to(device)
embeddings = {}
with torch.no_grad():
    for data in tqdm(data_loader):
        data = data.to(device)

        # Node-level encoder output, mean-pooled into one vector per graph.
        z_nodes = model.encoder(data)
        # Copy to CPU once per batch. The dict then holds CPU views instead of pinning
        # device memory for every transaction.
        unsup_emb = global_mean_pool(z_nodes, data.batch).cpu()

        # zip() would silently truncate on a mismatch; fail loudly instead.
        if len(data.trx_id) != unsup_emb.size(0):
            raise ValueError(
                f"got {len(data.trx_id)} trx_ids for {unsup_emb.size(0)} graph embeddings"
            )
        embeddings.update(zip(data.trx_id, unsup_emb))

        unsup_list.append(unsup_emb)
        y_list.append(data.y.cpu())

unsup_X = torch.cat(unsup_list).numpy()
y = torch.cat(y_list).numpy()

100%|██████████| 1166/1166 [00:40<00:00, 28.65it/s]


In [32]:
# Convert each tensor to a plain list, build one row per transaction, and write the CSV.
emb = {}
for trx_id, vector in embeddings.items():
    emb[trx_id] = vector.cpu().detach().numpy().tolist()
df_embeddings = pd.DataFrame(
    {"transactionHash": list(emb.keys()), "embeddings": list(emb.values())}
)

dgi_csv_path = config.FLASHBOTS_Q2_DATA_DIR / "state_embeddings_dgi_128.csv"
df_embeddings.to_csv(dgi_csv_path, index=False)