# AIRL with TransactionsGraphEnvironment v2

In [1]:
import ast

import gymnasium as gym
import mlflow
import numpy as np
import pandas as pd
import torch
from imitation.algorithms.adversarial.airl import AIRL
from imitation.data.rollout import flatten_trajectories
from imitation.data.types import Trajectory
from imitation.data.wrappers import RolloutInfoWrapper
from imitation.rewards.reward_nets import BasicShapedRewardNet
from imitation.util import logger
from imitation.util.networks import RunningNorm
from imitation.util.util import make_vec_env
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import VecCheckNan
from stable_baselines3.ppo import MlpPolicy

import graph_reinforcement_learning_using_blockchain_data as grl
from graph_reinforcement_learning_using_blockchain_data import config

# Load environment variables through the project's config helper (presumably from a .env file).
config.load_dotenv()
# Point MLflow at the tracking server configured for this project.
mlflow.set_tracking_uri(uri=config.MLFLOW_TRACKING_URI)

[32m2025-05-04 10:38:57.765[0m | [1mINFO    [0m | [36mgraph_reinforcement_learning_using_blockchain_data.config[0m:[36m<module>[0m:[36m12[0m - [1mPROJ_ROOT path is: /Users/liamtessendorf/Programming/Uni/2_Master/4_FS25_Programming/graph-reinforcement-learning-using-blockchain-data[0m


In [2]:
# Seeded NumPy generator; it is passed to `make_vec_env` when the vectorised environments are built.
RNG = np.random.default_rng(seed=42)

## Dataset 

In [26]:
# Pre-computed state embeddings, stored as stringified lists in the `embeddings` column.
df_emb = pd.read_csv(config.FLASHBOTS_Q2_DATA_DIR / "state_embeddings_dgi.csv")
# Transaction receipts, one file per class (label 0 / label 1).
df_class0 = pd.read_csv(config.RAW_DATA_DIR / "receipts_class0.csv")
df_class1 = pd.read_csv(config.RAW_DATA_DIR / "receipts_class1.csv")
# ETH balances keyed by (account, block_number), one file per class.
df_eth_balances_class1 = pd.read_csv(config.RAW_DATA_DIR / "eth_balances_class1.csv")
df_eth_balances_class0 = pd.read_csv(config.RAW_DATA_DIR / "eth_balances_class0.csv")

In [4]:
# Inspect the column names of the receipts and balance frames (the join keys used next).
print(df_class0.columns)
print(df_eth_balances_class0.columns)

Index(['block_number', 'transaction_hash', 'blockHash', 'blockNumber',
       'logsBloom', 'gasUsed', 'contractAddress', 'cumulativeGasUsed',
       'transactionIndex', 'from', 'to', 'type', 'effectiveGasPrice', 'logs',
       'status'],
      dtype='object')
Index(['account', 'block_number', 'balance'], dtype='object')


In [5]:
# Both classes are joined to their balance snapshots on the same (sender, block) key;
# the inner join keeps only receipts that have a matching balance row.
_balance_join_kwargs = {
    "left_on": ["from", "blockNumber"],
    "right_on": ["account", "block_number"],
    "how": "inner",
}
df_class0_with_eth_balances = df_class0.merge(df_eth_balances_class0, **_balance_join_kwargs)
df_class1_with_eth_balances = df_class1.merge(df_eth_balances_class1, **_balance_join_kwargs)

In [6]:
# Class-0 rows whose sender appears more than once (`keep=False` marks every duplicate).
# NOTE(review): this frame is not referenced by any other cell visible here — confirm it is still needed.
df_class0_multi_occ = df_class0_with_eth_balances[
    df_class0_with_eth_balances["from"].duplicated(keep=False)
]

In [7]:
def _parse_embedding(raw):
    """Convert one `embeddings` cell into a float32 vector.

    Strings (the form read from the CSV) are parsed with `ast.literal_eval`.
    Anything else is treated as already parsed, so re-running this cell does not
    fail on values that were converted by a previous run.
    """
    if isinstance(raw, str):
        raw = ast.literal_eval(raw)
    return np.asarray(raw, dtype=np.float32)


df_emb["embeddings"] = df_emb["embeddings"].apply(_parse_embedding)

In [8]:
# Tag every row with the class of the file it came from; used later to split the data per class.
df_class0_with_eth_balances["label"] = 0
df_class1_with_eth_balances["label"] = 1

In [9]:
# Stack both classes, keep the first row per transaction hash, then attach the embeddings.
# The right join keeps every embedding row, including ones without a matching receipt.
df_receipts = pd.concat(
    [df_class0_with_eth_balances, df_class1_with_eth_balances], ignore_index=True
).drop_duplicates("transaction_hash")
df = df_receipts.merge(
    df_emb, left_on="transaction_hash", right_on="transactionHash", how="right"
)

In [10]:
# Preview the merged receipts + embeddings frame.
df.head()

Unnamed: 0,block_number_x,transaction_hash,blockHash,blockNumber,logsBloom,gasUsed,contractAddress,cumulativeGasUsed,transactionIndex,from,...,created_at,account_address,profit_token_address,start_amount,end_amount,profit_amount,error,protocols,transactionHash,embeddings
0,16969850,0xe4029c908cfc40f825051cd0957797c66196eb8ba437...,0x241b2ebd536a6f546ce2214bcf2146d359ae4e0cc4e3...,16969850,0x00000000000000000000000000000000000000000000...,21000,,1315400,19,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,,,,,,0xe4029c908cfc40f825051cd0957797c66196eb8ba437...,"[0.16500826, -0.09403816, 0.21792151, 0.012831..."
1,16975162,0xbcfb84169287cf7acce33fba2b7390cfe21852871f78...,0x577d4ef683ebd84c73cdbc3635c7177f6ec137eef3bf...,16975162,0x00000000000000000000000080000000000000000000...,116880,,396290,3,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,,,,,,0xbcfb84169287cf7acce33fba2b7390cfe21852871f78...,"[-0.09382023, -0.012681959, 0.14728837, 0.0354..."
2,16983327,0x7a954e541df7c296b8fec61b77d22ea8bacd63399467...,0xe79daae4074cb858a3e79b4d95064845295eaddc134c...,16983327,0x00000000000000000000000000000000000000000000...,103421,,3934866,21,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,,,,,,0x7a954e541df7c296b8fec61b77d22ea8bacd63399467...,"[-0.10671734, 0.0006063916, 0.160602, 0.039328..."
3,16990575,0x016858c7c133cdf545b6934653544c19475c5c111ba0...,0x9c3fa88c63603c9c5696174085177f9559e20b331a64...,16990575,0x00200000000000000000000080004000000000000000...,105953,,4683545,39,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,,,,,,0x016858c7c133cdf545b6934653544c19475c5c111ba0...,"[-0.11480839, 0.027570017, 0.22743803, 0.04695..."
4,16994491,0x0f358c5aed8ae456a298eedaf3fc08cb802c3417a6b4...,0xae6c66f64f13ba5404bbe4cebd33fdd666c7bc8fe601...,16994491,0x00200000000011000000000080000000000000000000...,223722,,3972377,34,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,,,,,,0x0f358c5aed8ae456a298eedaf3fc08cb802c3417a6b4...,"[-0.067668885, -0.094198264, 0.12297824, 0.158..."


In [11]:
# Per-block gas-price statistics, computed in one pass over a single groupby instead of
# rebuilding the same groupby four times.
# NOTE(review): the statistics only cover the transactions present in `df`, which may be a
# subset of each block's transactions — confirm this is the intended reference price.
df_median_gas_prices = df.groupby(["blockNumber"])["effectiveGasPrice"].agg(
    median_gas_price="median",
    std_gas_price="std",
    max_gas_price="max",
    min_gas_price="min",
)

df_with_median_gas_prices = df.merge(df_median_gas_prices, how="left", on="blockNumber")
df_with_median_gas_prices.head()

Unnamed: 0,block_number_x,transaction_hash,blockHash,blockNumber,logsBloom,gasUsed,contractAddress,cumulativeGasUsed,transactionIndex,from,...,end_amount,profit_amount,error,protocols,transactionHash,embeddings,median_gas_price,std_gas_price,max_gas_price,min_gas_price
0,16969850,0xe4029c908cfc40f825051cd0957797c66196eb8ba437...,0x241b2ebd536a6f546ce2214bcf2146d359ae4e0cc4e3...,16969850,0x00000000000000000000000000000000000000000000...,21000,,1315400,19,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,,0xe4029c908cfc40f825051cd0957797c66196eb8ba437...,"[0.16500826, -0.09403816, 0.21792151, 0.012831...",28412780000.0,5338623000.0,40046142239,27985774295
1,16975162,0xbcfb84169287cf7acce33fba2b7390cfe21852871f78...,0x577d4ef683ebd84c73cdbc3635c7177f6ec137eef3bf...,16975162,0x00000000000000000000000080000000000000000000...,116880,,396290,3,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,,0xbcfb84169287cf7acce33fba2b7390cfe21852871f78...,"[-0.09382023, -0.012681959, 0.14728837, 0.0354...",31324250000.0,13106710000.0,62912040686,29651658352
2,16983327,0x7a954e541df7c296b8fec61b77d22ea8bacd63399467...,0xe79daae4074cb858a3e79b4d95064845295eaddc134c...,16983327,0x00000000000000000000000000000000000000000000...,103421,,3934866,21,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,,0x7a954e541df7c296b8fec61b77d22ea8bacd63399467...,"[-0.10671734, 0.0006063916, 0.160602, 0.039328...",47015990000.0,94835380000.0,298140379626,44115991364
3,16990575,0x016858c7c133cdf545b6934653544c19475c5c111ba0...,0x9c3fa88c63603c9c5696174085177f9559e20b331a64...,16990575,0x00200000000000000000000080004000000000000000...,105953,,4683545,39,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,,0x016858c7c133cdf545b6934653544c19475c5c111ba0...,"[-0.11480839, 0.027570017, 0.22743803, 0.04695...",39976070000.0,22168940000.0,72384430845,29976074199
4,16994491,0x0f358c5aed8ae456a298eedaf3fc08cb802c3417a6b4...,0xae6c66f64f13ba5404bbe4cebd33fdd666c7bc8fe601...,16994491,0x00200000000011000000000080000000000000000000...,223722,,3972377,34,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,,0x0f358c5aed8ae456a298eedaf3fc08cb802c3417a6b4...,"[-0.067668885, -0.094198264, 0.12297824, 0.158...",26952080000.0,3320303000.0,30915355327,22915355327


In [12]:
df_with_actions = df_with_median_gas_prices.copy()
# Action 1: the transaction paid strictly more than its block's median gas price; 0 otherwise.
# Vectorised comparison instead of a row-wise `apply`; rows with a missing price compare as
# False and therefore get 0, exactly as before.
df_with_actions["action"] = (
    df_with_actions["effectiveGasPrice"] > df_with_actions["median_gas_price"]
).astype("int64")

In [13]:
# Fraction of rows with action 1 (gas price above the block median).
df_with_actions["action"].mean()

0.24449359876667337

In [14]:
# Preview the frame with the new `action` column.
df_with_actions.head()

Unnamed: 0,block_number_x,transaction_hash,blockHash,blockNumber,logsBloom,gasUsed,contractAddress,cumulativeGasUsed,transactionIndex,from,...,profit_amount,error,protocols,transactionHash,embeddings,median_gas_price,std_gas_price,max_gas_price,min_gas_price,action
0,16969850,0xe4029c908cfc40f825051cd0957797c66196eb8ba437...,0x241b2ebd536a6f546ce2214bcf2146d359ae4e0cc4e3...,16969850,0x00000000000000000000000000000000000000000000...,21000,,1315400,19,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,0xe4029c908cfc40f825051cd0957797c66196eb8ba437...,"[0.16500826, -0.09403816, 0.21792151, 0.012831...",28412780000.0,5338623000.0,40046142239,27985774295,1
1,16975162,0xbcfb84169287cf7acce33fba2b7390cfe21852871f78...,0x577d4ef683ebd84c73cdbc3635c7177f6ec137eef3bf...,16975162,0x00000000000000000000000080000000000000000000...,116880,,396290,3,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,0xbcfb84169287cf7acce33fba2b7390cfe21852871f78...,"[-0.09382023, -0.012681959, 0.14728837, 0.0354...",31324250000.0,13106710000.0,62912040686,29651658352,1
2,16983327,0x7a954e541df7c296b8fec61b77d22ea8bacd63399467...,0xe79daae4074cb858a3e79b4d95064845295eaddc134c...,16983327,0x00000000000000000000000000000000000000000000...,103421,,3934866,21,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,0x7a954e541df7c296b8fec61b77d22ea8bacd63399467...,"[-0.10671734, 0.0006063916, 0.160602, 0.039328...",47015990000.0,94835380000.0,298140379626,44115991364,1
3,16990575,0x016858c7c133cdf545b6934653544c19475c5c111ba0...,0x9c3fa88c63603c9c5696174085177f9559e20b331a64...,16990575,0x00200000000000000000000080004000000000000000...,105953,,4683545,39,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,0x016858c7c133cdf545b6934653544c19475c5c111ba0...,"[-0.11480839, 0.027570017, 0.22743803, 0.04695...",39976070000.0,22168940000.0,72384430845,29976074199,0
4,16994491,0x0f358c5aed8ae456a298eedaf3fc08cb802c3417a6b4...,0xae6c66f64f13ba5404bbe4cebd33fdd666c7bc8fe601...,16994491,0x00200000000011000000000080000000000000000000...,223722,,3972377,34,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,0x0f358c5aed8ae456a298eedaf3fc08cb802c3417a6b4...,"[-0.067668885, -0.094198264, 0.12297824, 0.158...",26952080000.0,3320303000.0,30915355327,22915355327,1


In [15]:
# Give the balance column an explicit name and pin the dtypes of the columns used downstream.
df_with_actions = df_with_actions.rename(columns={"balance": "eth_balance"}).astype(
    {
        "eth_balance": "float64",
        "median_gas_price": "float64",
        "std_gas_price": "float64",
        "from": "string",
        "to": "string",
    }
)
df_with_actions.head()

Unnamed: 0,block_number_x,transaction_hash,blockHash,blockNumber,logsBloom,gasUsed,contractAddress,cumulativeGasUsed,transactionIndex,from,...,profit_amount,error,protocols,transactionHash,embeddings,median_gas_price,std_gas_price,max_gas_price,min_gas_price,action
0,16969850,0xe4029c908cfc40f825051cd0957797c66196eb8ba437...,0x241b2ebd536a6f546ce2214bcf2146d359ae4e0cc4e3...,16969850,0x00000000000000000000000000000000000000000000...,21000,,1315400,19,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,0xe4029c908cfc40f825051cd0957797c66196eb8ba437...,"[0.16500826, -0.09403816, 0.21792151, 0.012831...",28412780000.0,5338623000.0,40046142239,27985774295,1
1,16975162,0xbcfb84169287cf7acce33fba2b7390cfe21852871f78...,0x577d4ef683ebd84c73cdbc3635c7177f6ec137eef3bf...,16975162,0x00000000000000000000000080000000000000000000...,116880,,396290,3,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,0xbcfb84169287cf7acce33fba2b7390cfe21852871f78...,"[-0.09382023, -0.012681959, 0.14728837, 0.0354...",31324250000.0,13106710000.0,62912040686,29651658352,1
2,16983327,0x7a954e541df7c296b8fec61b77d22ea8bacd63399467...,0xe79daae4074cb858a3e79b4d95064845295eaddc134c...,16983327,0x00000000000000000000000000000000000000000000...,103421,,3934866,21,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,0x7a954e541df7c296b8fec61b77d22ea8bacd63399467...,"[-0.10671734, 0.0006063916, 0.160602, 0.039328...",47015990000.0,94835380000.0,298140379626,44115991364,1
3,16990575,0x016858c7c133cdf545b6934653544c19475c5c111ba0...,0x9c3fa88c63603c9c5696174085177f9559e20b331a64...,16990575,0x00200000000000000000000080004000000000000000...,105953,,4683545,39,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,0x016858c7c133cdf545b6934653544c19475c5c111ba0...,"[-0.11480839, 0.027570017, 0.22743803, 0.04695...",39976070000.0,22168940000.0,72384430845,29976074199,0
4,16994491,0x0f358c5aed8ae456a298eedaf3fc08cb802c3417a6b4...,0xae6c66f64f13ba5404bbe4cebd33fdd666c7bc8fe601...,16994491,0x00200000000011000000000080000000000000000000...,223722,,3972377,34,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,0x0f358c5aed8ae456a298eedaf3fc08cb802c3417a6b4...,"[-0.067668885, -0.094198264, 0.12297824, 0.158...",26952080000.0,3320303000.0,30915355327,22915355327,1


In [16]:
# Rows without a per-block standard deviation. In the recorded output these are blocks whose
# max and min gas price coincide, i.e. the sample std is undefined.
df_with_actions[df_with_actions["std_gas_price"].isna()]

Unnamed: 0,block_number_x,transaction_hash,blockHash,blockNumber,logsBloom,gasUsed,contractAddress,cumulativeGasUsed,transactionIndex,from,...,profit_amount,error,protocols,transactionHash,embeddings,median_gas_price,std_gas_price,max_gas_price,min_gas_price,action
63,16975107,0xd05a59ef18204af79ae9bf2a7ba722bca892055819c7...,0x8d8cd7dfa64a3867f12d472be895ee2a1b163d854a02...,16975107,0x00200000000000000000000080000000000008000000...,196537,,685900,5,0x00000006e42915A2B6907f8b3fAF311B68862f60,...,6855415086744551,,"[""uniswap_v2""]",0xd05a59ef18204af79ae9bf2a7ba722bca892055819c7...,"[0.33906102, 0.11006685, -0.11049971, -0.11027...",3.442311e+10,,34423111170,34423111170,0
65,16976564,0x5530313d0b0271506691e3732c517172d5bfa1b2ba3d...,0x458b66a35808bf44a7e332b9f9b326ca8660952ab40a...,16976564,0x00200000000000000000000080000000200000000000...,190691,,1322470,7,0x00000006e42915A2B6907f8b3fAF311B68862f60,...,7796163749018364,,"[""uniswap_v2""]",0x5530313d0b0271506691e3732c517172d5bfa1b2ba3d...,"[0.41424438, 0.1412386, -0.15269175, -0.147597...",4.041179e+10,,40411786608,40411786608,0
76,16979829,0xe47601937f0538ecc2a67c0a1b2481a1d339b52b2ef0...,0x4d1ebc8a72732a87fb083679233a06cf0a33ec984b35...,16979829,0x00200000000000000000000084000000200000000000...,158198,,158198,0,0x00000006e42915A2B6907f8b3fAF311B68862f60,...,24408839782523421,,"[""uniswap_v2"",""uniswap_v3""]",0xe47601937f0538ecc2a67c0a1b2481a1d339b52b2ef0...,"[0.4598713, 0.15552238, -0.16139461, -0.179502...",1.532201e+11,,153220129205,153220129205,0
98,16992014,0x23e05562df7784836aaf6c8235d2aca5501621e6aab2...,0x8794222099471416ba2721137ee0e0d60a149b2f2029...,16992014,0x00200000000000000000000080000200000000000000...,219788,,219788,0,0x00000006e42915A2B6907f8b3fAF311B68862f60,...,35168219119161781,,"[""uniswap_v2"",""uniswap_v3""]",0x23e05562df7784836aaf6c8235d2aca5501621e6aab2...,"[0.46931684, 0.12999484, -0.15137316, -0.15249...",1.595249e+11,,159524873969,159524873969,0
167,17240935,0x74e6628155b2f61c067a568235c952e3c1fa4aa22d76...,0xa658109201d2dcd46dcc9f5ad1c71d29d6942c44b1c4...,17240935,0x00000000040000000000000000000000000000000000...,210289,,2954344,23,0x00000006e42915A2B6907f8b3fAF311B68862f60,...,13097045417966494,,"[""uniswap_v3""]",0x74e6628155b2f61c067a568235c952e3c1fa4aa22d76...,"[0.31971875, 0.08582469, -0.06865533, -0.08339...",6.190800e+10,,61907998723,61907998723,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
149181,17296632,0x2f06cc7f885dd98fbf765ad57706fe505c5b63478928...,0x64522455567aa9a4209551fd938a1e31b448487ac157...,17296632,0x00200000000000000010000080000000000000000000...,178376,,440031,1,0xffFf14106945bCB267B34711c416AA3085B8865F,...,22902832502491512,,"[""uniswap_v2""]",0x2f06cc7f885dd98fbf765ad57706fe505c5b63478928...,"[0.6747055, 0.22695717, -0.25271302, -0.264935...",3.450291e+10,,34502909809,34502909809,0
149182,17297932,0xd477617aaa93aad5aa5f5a2f880f3e934b70a91cd523...,0x493341206c063a7ba6c85137473b17f697ed72701c07...,17297932,0x00200000000000000010000080000000000000000000...,178388,,440065,1,0xffFf14106945bCB267B34711c416AA3085B8865F,...,31467600199400760,,"[""uniswap_v2""]",0xd477617aaa93aad5aa5f5a2f880f3e934b70a91cd523...,"[0.73240715, 0.24466729, -0.27648795, -0.27974...",2.928406e+10,,29284056698,29284056698,0
149183,17298360,0xda0b7b1156a57ff85ef5c2e80f47d6fc0c89c6253fca...,0x21a1c5ae76e2cbff779f308288e1cf2273d999f978e1...,17298360,0x00200000000000000010000080000000000000000000...,178352,,439975,1,0xffFf14106945bCB267B34711c416AA3085B8865F,...,38228015952142520,,"[""uniswap_v2""]",0xda0b7b1156a57ff85ef5c2e80f47d6fc0c89c6253fca...,"[0.7970462, 0.2638063, -0.30356315, -0.2968765...",2.959543e+10,,29595427147,29595427147,0
149184,17298379,0x7755b8553cc9f2479e5cc48ebf18dd6f25fe038a668f...,0xa0171ad96bcd73eb6aa92c8da1e3a8629c1c067ea811...,17298379,0x00200000000000000010000080000000000000000000...,178364,,1162669,7,0xffFf14106945bCB267B34711c416AA3085B8865F,...,28570015952142520,,"[""uniswap_v2""]",0x7755b8553cc9f2479e5cc48ebf18dd6f25fe038a668f...,"[0.84662515, 0.26852706, -0.4143204, -0.313244...",2.783189e+10,,27831892530,27831892530,0


In [17]:
# Treat an undefined per-block standard deviation as zero spread; other columns are left untouched.
df_with_actions = df_with_actions.fillna({"std_gas_price": 0})

In [18]:
# Split by class label. Rows without a label (embeddings with no matching receipt) fall in neither frame.
_label = df_with_actions["label"]
df_with_actions_0 = df_with_actions.loc[_label.eq(0)]
df_with_actions_1 = df_with_actions.loc[_label.eq(1)]

In [19]:
# Split class-0 sender accounts 80/10/10 (train/val/test) in the order returned by `unique()`,
# so all transactions of one account land in exactly one split. No shuffling is applied.
unique_accs_0 = df_with_actions_0["from"].unique()
_train_end_0 = int(0.8 * len(unique_accs_0))
_val_end_0 = int(0.9 * len(unique_accs_0))
accs_train_0, accs_val_0, accs_test_0 = (
    unique_accs_0[:_train_end_0],
    unique_accs_0[_train_end_0:_val_end_0],
    unique_accs_0[_val_end_0:],
)
_senders_0 = df_with_actions_0["from"]
df_with_actions_0_train = df_with_actions_0[_senders_0.isin(accs_train_0)]
df_with_actions_0_val = df_with_actions_0[_senders_0.isin(accs_val_0)]
df_with_actions_0_test = df_with_actions_0[_senders_0.isin(accs_test_0)]

In [20]:
# Split class-1 sender accounts 80/10/10 (train/val/test) in the order returned by `unique()`,
# so all transactions of one account land in exactly one split. No shuffling is applied.
unique_accs_1 = df_with_actions_1["from"].unique()
_train_end_1 = int(0.8 * len(unique_accs_1))
_val_end_1 = int(0.9 * len(unique_accs_1))
accs_train_1, accs_val_1, accs_test_1 = (
    unique_accs_1[:_train_end_1],
    unique_accs_1[_train_end_1:_val_end_1],
    unique_accs_1[_val_end_1:],
)
_senders_1 = df_with_actions_1["from"]
df_with_actions_1_train = df_with_actions_1[_senders_1.isin(accs_train_1)]
df_with_actions_1_val = df_with_actions_1[_senders_1.isin(accs_val_1)]
df_with_actions_1_test = df_with_actions_1[_senders_1.isin(accs_test_1)]

In [21]:
# Validation set: the validation accounts of both classes combined into one frame.
df_val = pd.concat([df_with_actions_0_val, df_with_actions_1_val])

In [22]:
# Spot check: look up one specific sender in the class-1 training split.
# The recorded output shows no matching rows.
df_with_actions_1_train[
    df_with_actions_1_train["from"] == "0x1e6c1c4669f612112a7caCa5596BfE6629e669aA"
]

Unnamed: 0,block_number_x,transaction_hash,blockHash,blockNumber,logsBloom,gasUsed,contractAddress,cumulativeGasUsed,transactionIndex,from,...,profit_amount,error,protocols,transactionHash,embeddings,median_gas_price,std_gas_price,max_gas_price,min_gas_price,action


## Creating trajectories

In [25]:
def extract_trajectories(df: pd.DataFrame):
    """Build one trajectory per sender account from a transactions frame.

    Args:
        df: Frame with the columns ``from`` (sender account), ``blockNumber``,
            ``embeddings`` (one 1-D array per row), ``action`` and ``label``.

    Returns:
        A list with one dict per account, in ``groupby("from")`` order:
        ``obs`` is the account's embeddings ordered by ``blockNumber`` followed by one
        all-zero terminal observation (so it has ``len(acts) + 1`` rows), ``acts`` is the
        array of actions in the same order, and ``label`` is the label of the account's
        first row after sorting.
    """
    trajectories = []
    for _, group in df.groupby("from"):
        group = group.sort_values("blockNumber")
        embeddings = group["embeddings"].tolist()
        # Terminal observation: zeros shaped like the embeddings themselves rather than a
        # hard-coded width, so the function works for any embedding dimension.
        terminal_obs = np.zeros_like(embeddings[0])
        trajectories.append(
            {
                "obs": np.stack(embeddings + [terminal_obs]),
                "acts": np.array(group["action"].tolist()),
                "label": group["label"].iloc[0],
            }
        )
    return trajectories


# One list of raw trajectory dicts per split; the validation list mixes both classes.
trajectories_1_train = extract_trajectories(df_with_actions_1_train)
trajectories_0_train = extract_trajectories(df_with_actions_0_train)
trajectories_1_test = extract_trajectories(df_with_actions_1_test)
trajectories_0_test = extract_trajectories(df_with_actions_0_test)
trajectories_val = extract_trajectories(df_val)

KeyboardInterrupt: 

In [71]:
# Wrap each raw trajectory dict in an `imitation` Trajectory and flatten straight to
# per-step transitions: class-1 train, class-0 train, and the combined validation split.
trajectories_1 = flatten_trajectories(
    [
        Trajectory(obs=raw["obs"], acts=raw["acts"], infos=None, terminal=True)
        for raw in trajectories_1_train
    ]
)
trajectories_0 = flatten_trajectories(
    [
        Trajectory(obs=raw["obs"], acts=raw["acts"], infos=None, terminal=True)
        for raw in trajectories_0_train
    ]
)
# NOTE: `trajectories_val` is rebound from raw dicts to flattened transitions here.
trajectories_val = flatten_trajectories(
    [
        Trajectory(obs=raw["obs"], acts=raw["acts"], infos=None, terminal=True)
        for raw in trajectories_val
    ]
)

In [72]:
# Same conversion for the test splits. Both names are rebound from raw dicts to
# flattened transitions.
trajectories_1_test = flatten_trajectories(
    [
        Trajectory(obs=raw["obs"], acts=raw["acts"], infos=None, terminal=True)
        for raw in trajectories_1_test
    ]
)
trajectories_0_test = flatten_trajectories(
    [
        Trajectory(obs=raw["obs"], acts=raw["acts"], infos=None, terminal=True)
        for raw in trajectories_0_test
    ]
)

## Setting up environments

In [73]:
# Use Apple's MPS backend when it is available and fall back to CPU otherwise, so the
# environments can also be created on machines without MPS.
ENV_DEVICE = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
# MLflow artifact URI passed to both environments as `model_uri`.
ENV_MODEL_URI = (
    "mlflow-artifacts:/330930495026013213/1f3f3d97f6a14569b9f2f285c5a98bb4/artifacts/model"
)

ID0 = "gymnasium_env/TransactionGraphEnv0-v2"
ID1 = "gymnasium_env/TransactionGraphEnv1-v2"

# Register one environment per class; they differ only in the data frame and the label.
for _env_id, _env_df, _env_label in (
    (ID0, df_with_actions_0, 0),
    (ID1, df_with_actions_1, 1),
):
    gym.envs.register(
        id=_env_id,
        entry_point=grl.TransactionGraphEnvV2,
        kwargs={
            "df": _env_df,
            "alpha": 0.9,
            "device": ENV_DEVICE,
            "label": _env_label,
            "model_uri": ENV_MODEL_URI,
            "observation_space_dim": 256,
            "case": "unsupervised",
        },
        max_episode_steps=300,
    )

  logger.warn(f"Overriding environment {new_spec.id} already in registry.")
  logger.warn(f"Overriding environment {new_spec.id} already in registry.")


In [74]:
# List the registered Gymnasium environments to confirm the registration above.
gym.pprint_registry()

===== classic_control =====
Acrobot-v1             CartPole-v0            CartPole-v1
MountainCar-v0         MountainCarContinuous-v0 Pendulum-v1
===== phys2d =====
phys2d/CartPole-v0     phys2d/CartPole-v1     phys2d/Pendulum-v0
===== box2d =====
BipedalWalker-v3       BipedalWalkerHardcore-v3 CarRacing-v2
LunarLander-v2         LunarLanderContinuous-v2
===== toy_text =====
Blackjack-v1           CliffWalking-v0        FrozenLake-v1
FrozenLake8x8-v1       Taxi-v3
===== tabular =====
tabular/Blackjack-v0   tabular/CliffWalking-v0
===== mujoco =====
Ant-v2                 Ant-v3                 Ant-v4
HalfCheetah-v2         HalfCheetah-v3         HalfCheetah-v4
Hopper-v2              Hopper-v3              Hopper-v4
Humanoid-v2            Humanoid-v3            Humanoid-v4
HumanoidStandup-v2     HumanoidStandup-v4     InvertedDoublePendulum-v2
InvertedDoublePendulum-v4 InvertedPendulum-v2    InvertedPendulum-v4
Pusher-v2              Pusher-v4              Reacher-v2
Reacher-v4         

In [75]:
# NOTE(review): `env0` is not referenced by any other cell visible here, yet building it
# instantiates a full environment — confirm it is still needed.
env0 = Monitor(gym.make(ID0))

# Single, non-parallel vectorised environment for class 0, seeded from the shared RNG.
# The lambda's `env0` parameter shadows the module-level `env0` above; it is the freshly
# created environment handed in by `make_vec_env`.
venv0 = make_vec_env(
    ID0,
    rng=RNG,
    n_envs=1,
    post_wrappers=[lambda env0, _: RolloutInfoWrapper(env0)],
    parallel=False,
)

venv0 = VecCheckNan(venv0, raise_exception=True)  # Check for NaN observations
venv0.reset()

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

array([[0.        , 0.19274484, 0.        , 0.        , 0.        ,
        0.00796854, 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.03639701, 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.08493737, 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.01121702, 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.00877421, 0.1003125 ,
        0.        , 0.        , 0.        , 0.14473669, 0.        ,
        0.        , 0.        , 0.        , 0.01462195, 0.        ,
        0.        , 0.        , 0.        , 0.06284675, 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.01207574,
        0.        , 0.        , 0.        , 0.1652727 , 0.        ,
        0.        , 0.        , 0.        , 0.07557897, 0.24741852,
        0.        , 0.10367643, 0.        , 0.        , 0.40269884,
        0.        , 0.        , 0.        , 0.  

In [76]:
# NOTE(review): `env1` is not referenced by any other cell visible here, yet building it
# instantiates a full environment — confirm it is still needed.
env1 = Monitor(gym.make(ID1))

# Single, non-parallel vectorised environment for class 1, seeded from the shared RNG.
# The lambda's `env1` parameter shadows the module-level `env1` above; it is the freshly
# created environment handed in by `make_vec_env`.
venv1 = make_vec_env(
    ID1,
    rng=RNG,
    n_envs=1,
    post_wrappers=[lambda env1, _: RolloutInfoWrapper(env1)],
    parallel=False,
)

venv1 = VecCheckNan(venv1, raise_exception=True)  # Check for NaN observations
venv1.reset()

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

array([[0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.05173976, 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.07948201, 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.00884642, 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.17551276, 0.01525478, 0.        ,
        0.        , 0.        , 0.        , 0.04692906, 1.6734402 ,
        0.        , 0.        , 0.        , 0.  

## AIRL setup

In [77]:
# Set parameters for the PPO algorithm (generator)
# Hyperparameters shared by both PPO learners and both AIRL trainers defined below.
learning_rate = 0.001  # Learning rate, can be a function of progress
batch_size = 60  # Mini batch size for each gradient update
n_epochs = 15  # N of epochs when optimizing the surrogate loss

gamma = 0.5  # Discount factor, focus on the recent rewards
gae_lambda = 0  # Generalized advantage estimation (lambda parameter)
clip_range = 0.1  # Clipping parameter
ent_coef = 0.01  # Entropy coefficient for the loss calculation
vf_coef = 0.5  # Value function coef. for the loss calculation
max_grad_norm = 0.5  # The maximum value for the gradient clipping

verbose = 0  # Verbosity level: 0 no output, 1 info, 2 debug
normalize_advantage = True  # Whether to normalize or not the advantage

clip_range_vf = None  # Clip for the value function
use_sde = False  # Use State Dependent Exploration
sde_sample_freq = -1  # SDE - noise matrix frequency (-1 = disable)

# Set parameters for the AIRL trainer
gen_replay_buffer_capacity = None
allow_variable_horizon = True  # Episodes may differ in length (triggers a library warning)

# Keyword arguments for the discriminator optimizer
disc_opt_kwargs = {
    "lr": 0.001,
}
policy_kwargs = {"use_expln": True}  # Fixing an issue with NaNs

In [78]:
# Set the number of timesteps, batch size and number of disc updates

# Total number of timesteps in the whole training
# (3000 * 100 = 300,000 steps, i.e. 100 generator rounds of 3,000 steps each)
total_timesteps = 3000 * 100

# Generator
gen_train_timesteps = 3000  # N steps in the environment per one round
n_steps = gen_train_timesteps  # PPO rollout length is set equal to one generator round

# Discriminator batches
demo_minibatch_size = 60  # N samples in minibatch for one discrim. update
demo_batch_size = 300 * 10  # N samples in the batch of expert data (batch)
n_disc_updates_per_round = 4  # N discriminator updates per one round

In [79]:
# Configure imitation's logger and attach the project's `MLflowOutputFormat` as an extra
# output format (presumably forwarding the logged training stats to MLflow).
hier_logger = logger.configure()
hier_logger.default_logger.output_formats.append(grl.MLflowOutputFormat())

In [80]:
# Initialize the learner PPO policy (generator) for class 0
learner0 = PPO(
    env=venv0,
    policy=MlpPolicy,
    policy_kwargs=policy_kwargs,
    learning_rate=learning_rate,
    n_steps=n_steps,
    batch_size=batch_size,
    n_epochs=n_epochs,
    gamma=gamma,
    gae_lambda=gae_lambda,
    clip_range=clip_range,
    clip_range_vf=clip_range_vf,
    normalize_advantage=normalize_advantage,
    ent_coef=ent_coef,
    vf_coef=vf_coef,
    max_grad_norm=max_grad_norm,
    use_sde=use_sde,
    sde_sample_freq=sde_sample_freq,
    verbose=verbose,
    seed=42,
    # Use MPS when available; fall back to CPU so the cell also runs on other machines.
    device="mps" if torch.backends.mps.is_available() else "cpu",
)

# Reward network (discriminator side) with running normalisation of its inputs
reward_net0 = BasicShapedRewardNet(
    observation_space=venv0.observation_space,
    action_space=venv0.action_space,
    normalize_input_layer=RunningNorm,
)

# Initialize the AIRL trainer with the class-0 training transitions as demonstrations
airl_trainer0 = AIRL(
    demonstrations=trajectories_0,
    demo_batch_size=demo_batch_size,
    demo_minibatch_size=demo_minibatch_size,
    n_disc_updates_per_round=n_disc_updates_per_round,
    gen_train_timesteps=gen_train_timesteps,
    gen_replay_buffer_capacity=gen_replay_buffer_capacity,
    venv=venv0,
    gen_algo=learner0,
    reward_net=reward_net0,
    allow_variable_horizon=allow_variable_horizon,
    disc_opt_kwargs=disc_opt_kwargs,
    custom_logger=hier_logger,
)

Running with `allow_variable_horizon` set to True. Some algorithms are biased towards shorter or longer episodes, which may significantly confound results. Additionally, even unbiased algorithms can exploit the information leak from the termination condition, producing spuriously high performance. See https://imitation.readthedocs.io/en/latest/getting-started/variable-horizon.html for more information.


In [81]:
# Initialize the learner PPO policy (generator) for class 1
learner1 = PPO(
    env=venv1,
    policy=MlpPolicy,
    policy_kwargs=policy_kwargs,
    learning_rate=learning_rate,
    n_steps=n_steps,
    batch_size=batch_size,
    n_epochs=n_epochs,
    gamma=gamma,
    gae_lambda=gae_lambda,
    clip_range=clip_range,
    clip_range_vf=clip_range_vf,
    normalize_advantage=normalize_advantage,
    ent_coef=ent_coef,
    vf_coef=vf_coef,
    max_grad_norm=max_grad_norm,
    use_sde=use_sde,
    sde_sample_freq=sde_sample_freq,
    verbose=verbose,
    seed=42,
    # Use MPS when available; fall back to CPU so the cell also runs on other machines.
    device="mps" if torch.backends.mps.is_available() else "cpu",
)

# Reward network (discriminator side) with running normalisation of its inputs
reward_net1 = BasicShapedRewardNet(
    observation_space=venv1.observation_space,
    action_space=venv1.action_space,
    normalize_input_layer=RunningNorm,
)

# Initialize the AIRL trainer with the class-1 training transitions as demonstrations
airl_trainer1 = AIRL(
    demonstrations=trajectories_1,
    demo_batch_size=demo_batch_size,
    demo_minibatch_size=demo_minibatch_size,
    n_disc_updates_per_round=n_disc_updates_per_round,
    gen_train_timesteps=gen_train_timesteps,
    gen_replay_buffer_capacity=gen_replay_buffer_capacity,
    venv=venv1,
    gen_algo=learner1,
    reward_net=reward_net1,
    allow_variable_horizon=allow_variable_horizon,
    disc_opt_kwargs=disc_opt_kwargs,
    custom_logger=hier_logger,
)

Running with `allow_variable_horizon` set to True. Some algorithms are biased towards shorter or longer episodes, which may significantly confound results. Additionally, even unbiased algorithms can exploit the information leak from the termination condition, producing spuriously high performance. See https://imitation.readthedocs.io/en/latest/getting-started/variable-horizon.html for more information.


## Training AIRL discriminator and generator, stats are saved with mlflow

We train two separate AIRL trainers: one on arbitrage (class 1) transactions and one on class 0 transactions. The reward functions they learn are then used to classify transactions.

In [None]:
mlflow.set_experiment("AIRLv2 DGI pre-trained embeddings")
# The context manager ends the run when the block exits (also on error), so no explicit
# `mlflow.end_run()` is needed inside it.
with mlflow.start_run():
    mlflow.log_param("n_steps", n_steps)
    mlflow.log_param("batch_size", batch_size)
    mlflow.log_param("total_timesteps", total_timesteps)

    airl_trainer0.train(total_timesteps=total_timesteps)

    # Persist the class-0 generator policy and reward network locally.
    # NOTE: `torch.save` on the module pickles the whole object, not just a state dict.
    learner0.save(config.MODELS_DIR / "learner0v2_dgi")
    torch.save(reward_net0, config.MODELS_DIR / "reward_net0v2_dgi")

    # Attach both files to the MLflow run.
    mlflow.log_artifact(config.MODELS_DIR / "learner0v2_dgi.zip")
    mlflow.log_artifact(config.MODELS_DIR / "reward_net0v2_dgi")

In [None]:
mlflow.set_experiment("AIRLv2 DGI pre-trained embeddings")

# The context manager ends the run when the block exits (also on error), so no explicit
# `mlflow.end_run()` is needed inside it.
with mlflow.start_run():
    mlflow.log_param("n_steps", n_steps)
    mlflow.log_param("batch_size", batch_size)
    mlflow.log_param("total_timesteps", total_timesteps)

    airl_trainer1.train(total_timesteps=total_timesteps)

    # Persist the class-1 generator policy and reward network locally.
    # NOTE: `torch.save` on the module pickles the whole object, not just a state dict.
    learner1.save(config.MODELS_DIR / "learner1v2_dgi")
    torch.save(reward_net1, config.MODELS_DIR / "reward_net1v2_dgi")

    # Attach both files to the MLflow run.
    mlflow.log_artifact(config.MODELS_DIR / "learner1v2_dgi.zip")
    mlflow.log_artifact(config.MODELS_DIR / "reward_net1v2_dgi")

## Metrics

Download the trained reward networks from MLflow.

In [82]:
# Download the class-0 reward network artifact from the MLflow tracking server.
local_path = mlflow.artifacts.download_artifacts(
    artifact_uri="mlflow-artifacts:/727687587886726594/3025fa0c446443cab65b5b8f4628e59f/artifacts/reward_net0v2_dgi"
)

# NOTE(security): `weights_only=False` unpickles arbitrary Python objects; only load
# artifacts that come from a trusted tracking server.
reward_net0 = torch.load(local_path, weights_only=False)

Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

In [87]:
# Download the class-1 reward network artifact from the MLflow tracking server.
local_path = mlflow.artifacts.download_artifacts(
    artifact_uri="mlflow-artifacts:/727687587886726594/661cd09c18864cd39c2b864303821029/artifacts/reward_net1v2_dgi"
)

# NOTE(security): `weights_only=False` unpickles arbitrary Python objects; only load
# artifacts that come from a trusted tracking server.
reward_net1 = torch.load(local_path, weights_only=False)

Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

In [None]:
# Alternative to the MLflow download: load the reward networks saved locally by the
# training cells. Running this rebinds `reward_net0` / `reward_net1`.
# NOTE(security): `weights_only=False` unpickles arbitrary Python objects; only load trusted files.
reward_net0 = torch.load(config.MODELS_DIR / "reward_net0v2_dgi", weights_only=False)
reward_net1 = torch.load(config.MODELS_DIR / "reward_net1v2_dgi", weights_only=False)

### Comparing **uncalibrated** mean rewards

In [88]:
def compare_model_weights(model1, model2):
    """Compare the parameters of two PyTorch models position by position.

    Args:
        model1: First ``torch.nn.Module``.
        model2: Second ``torch.nn.Module``.

    Returns:
        A tuple ``(are_identical, differences)``. ``are_identical`` is True only
        when both models expose the same number of parameters and every pair
        has the same shape and is element-wise close (``torch.allclose``
        defaults). ``differences`` is a list of human-readable messages, one
        per detected mismatch.
    """
    named1 = list(model1.named_parameters())
    named2 = list(model2.named_parameters())
    differences = []

    # `zip` stops at the shorter sequence, so a model with extra parameters
    # would otherwise go unnoticed and could be reported as identical.
    if len(named1) != len(named2):
        differences.append(f"Parameter count mismatch: {len(named1)} vs {len(named2)}")

    with torch.no_grad():
        for (name1, param1), (_name2, param2) in zip(named1, named2):
            if param1.shape != param2.shape:
                # Guard before `allclose`: differing shapes would either be
                # broadcast against each other or raise, neither of which is
                # a meaningful comparison.
                differences.append(
                    f"Shape mismatch in {name1}: {tuple(param1.shape)} vs {tuple(param2.shape)}"
                )
            elif not torch.allclose(param1, param2):
                differences.append(
                    f"Difference in {name1}: {torch.abs(param1 - param2).max().item()}"
                )

    return not differences, differences


# Sanity check: confirm that the two loaded reward networks are not the same
# set of weights, and list every parameter that differs.
are_same, diffs = compare_model_weights(reward_net0, reward_net1)
print(f"Models are identical: {are_same}")
if not are_same:
    print("Differences found:", *(f"  {entry}" for entry in diffs), sep="\n")

Models are identical: False
Differences found:
  Difference in _base.mlp.dense0.weight: 0.5747567415237427
  Difference in _base.mlp.dense0.bias: 0.18085934221744537
  Difference in _base.mlp.dense_final.weight: 0.21682842075824738
  Difference in _base.mlp.dense_final.bias: 0.02910691499710083
  Difference in potential._potential_net.dense0.weight: 0.42624878883361816
  Difference in potential._potential_net.dense0.bias: 0.19475007057189941
  Difference in potential._potential_net.dense1.weight: 0.21872328221797943
  Difference in potential._potential_net.dense1.bias: 0.08184748888015747
  Difference in potential._potential_net.dense_final.weight: 0.20185664296150208
  Difference in potential._potential_net.dense_final.bias: 0.03617658466100693


In [89]:
# Pull out the four arrays that are later passed positionally to
# `reward_net*.predict` for each test set.
# NOTE: despite their names, `obs0` / `obs1` hold the *actions* (`.acts`);
# the names are kept because later cells reference them.
states0 = trajectories_0_test.obs
obs0 = trajectories_0_test.acts
next_states0 = trajectories_0_test.next_obs
dones0 = trajectories_0_test.dones

states1 = trajectories_1_test.obs
obs1 = trajectories_1_test.acts
next_states1 = trajectories_1_test.next_obs
dones1 = trajectories_1_test.dones

In [90]:
# For reward_net1
rewards1 = reward_net1.predict(states0, obs0, next_states0, dones0)
norm_rewards1 = (rewards1 - rewards1.mean()) / rewards1.std()
print("Reward network 1 with traj0: ", norm_rewards1.mean())

# For reward_net0
rewards0 = reward_net0.predict(states0, obs0, next_states0, dones0)
norm_rewards0 = (rewards0 - rewards0.mean()) / rewards0.std()
print("Reward network 0 with traj0: ", norm_rewards0.mean())

Reward network 1 with traj0:  -1.0606202e-08
Reward network 0 with traj0:  -5.303101e-08


In [91]:
# For reward_net1
rewards1 = reward_net1.predict(states1, obs1, next_states1, dones1)
norm_rewards1 = (rewards1 - rewards1.mean()) / rewards1.std()
print("Reward network 1 with traj1: ", norm_rewards1.mean())

# For reward_net0
rewards0 = reward_net0.predict(states1, obs1, next_states1, dones1)
norm_rewards0 = (rewards0 - rewards0.mean()) / rewards0.std()
print("Reward network 0 with traj1: ", norm_rewards0.mean())

Reward network 1 with traj1:  1.7598559e-08
Reward network 0 with traj1:  3.3521066e-09


### Calibration using affine transformations

We use a validation set to estimate the parameters $\alpha$ and $\beta$ of an affine transformation
$f(x) = \alpha x + \beta$, where $\alpha = \frac{\sigma_{target}}{\sigma_{x}}$ and $\beta = \mu_{target} - \alpha \mu_x$. Here $\mu_x$ and $\sigma_x$ are the mean and standard deviation of a reward network's outputs on the validation set.

In [92]:
# Validation-set arrays passed positionally to `reward_net*.predict` when
# fitting the calibration parameters.
# NOTE: `obs` holds the *actions* (`.acts`), despite its name.
states = trajectories_val.obs
obs = trajectories_val.acts
next_states = trajectories_val.next_obs
dones = trajectories_val.dones

In [93]:
# Fit one affine map per reward network so that its validation outputs end up
# with the target mean and standard deviation.

# Raw validation outputs (reward_net1 -> "arb", reward_net0 -> "nonarb").
outputs_arb = reward_net1.predict(states, obs, next_states, dones)
outputs_nonarb = reward_net0.predict(states, obs, next_states, dones)

# Target calibration values (e.g., mean=0, std=1)
target_mean, target_std = 0.0, 1.0

# reward_net1: empirical statistics, then scale (alpha) and shift (beta).
mean_arb = outputs_arb.mean()
std_arb = outputs_arb.std()
alpha_arb = target_std / std_arb
beta_arb = target_mean - alpha_arb * mean_arb

# reward_net0: same recipe.
mean_nonarb = outputs_nonarb.mean()
std_nonarb = outputs_nonarb.std()
alpha_nonarb = target_std / std_nonarb
beta_nonarb = target_mean - alpha_nonarb * mean_nonarb

# Calibrated validation outputs.
calibrated_arb = alpha_arb * outputs_arb + beta_arb
calibrated_nonarb = alpha_nonarb * outputs_nonarb + beta_nonarb

### Accuracy, Precision, Recall and F1-Score for both calibrated reward networks

In [94]:
# For reward_net1
rewards1 = reward_net1.predict(states0, obs0, next_states0, dones0)
calibrated_rewards1 = alpha_arb * rewards1 + beta_arb
print("Reward network 1 with traj0: ", calibrated_rewards1.mean())

# For reward_net0
rewards0 = reward_net0.predict(states0, obs0, next_states0, dones0)
calibrated_rewards0 = alpha_nonarb * rewards0 + beta_nonarb
print("Reward network 0 with traj0: ", calibrated_rewards0.mean())

Reward network 1 with traj0:  -0.24265695
Reward network 0 with traj0:  0.12527163


In [95]:
# Expects `calibrated_rewards0` / `calibrated_rewards1` to hold the values for
# the `trajectories_0_test` transitions (run the preceding cell first).
# A transition is marked 1 ("correct") when reward_net0 scores it strictly
# higher than reward_net1.
predictions = (calibrated_rewards0 > calibrated_rewards1).astype(int)

# Every reference label is 1, so there are no false positives: precision is
# trivially 1.0 whenever at least one prediction is 1, and accuracy == recall.
true_labels = np.ones(len(predictions), dtype=int)

accuracy = accuracy_score(true_labels, predictions)
precision = precision_score(true_labels, predictions)
recall = recall_score(true_labels, predictions)
f1 = f1_score(true_labels, predictions)

print("Accuracy:", accuracy)
print("Precision:", precision)
print("Recall:", recall)
print("F1 Score:", f1)

Accuracy: 0.7159406858202039
Precision: 1.0
Recall: 0.7159406858202039
F1 Score: 0.8344585471239535


In [96]:
# For reward_net1
rewards1 = reward_net1.predict(states1, obs1, next_states1, dones1)
calibrated_rewards1 = alpha_arb * rewards1 + beta_arb
print("Reward network 1 with traj1: ", calibrated_rewards1.mean())

# For reward_net0
rewards0 = reward_net0.predict(states1, obs1, next_states1, dones1)
calibrated_rewards0 = alpha_nonarb * rewards0 + beta_nonarb
print("Reward network 0 with traj1: ", calibrated_rewards0.mean())

Reward network 1 with traj1:  0.4241865
Reward network 0 with traj1:  0.40248522


In [97]:
# Expects `calibrated_rewards0` / `calibrated_rewards1` to hold the values for
# the `trajectories_1_test` transitions (run the preceding cell first).
# A transition is marked 1 ("correct") when reward_net1 scores it strictly
# higher than reward_net0.
predictions = (calibrated_rewards0 < calibrated_rewards1).astype(int)

# Every reference label is 1, so there are no false positives: precision is
# trivially 1.0 whenever at least one prediction is 1, and accuracy == recall.
true_labels = np.ones(len(predictions), dtype=int)

accuracy = accuracy_score(true_labels, predictions)
precision = precision_score(true_labels, predictions)
recall = recall_score(true_labels, predictions)
f1 = f1_score(true_labels, predictions)

print("Accuracy:", accuracy)
print("Precision:", precision)
print("Recall:", recall)
print("F1 Score:", f1)

Accuracy: 0.5566783831282952
Precision: 1.0
Recall: 0.5566783831282952
F1 Score: 0.7152130962461191
