# AIRL with TransactionsGraphEnvironment v2

In [1]:
import ast

import gymnasium as gym
import mlflow
import numpy as np
import pandas as pd
import torch
from imitation.algorithms.adversarial.airl import AIRL
from imitation.data.rollout import flatten_trajectories
from imitation.data.types import Trajectory
from imitation.data.wrappers import RolloutInfoWrapper
from imitation.rewards.reward_nets import BasicShapedRewardNet
from imitation.util import logger
from imitation.util.networks import RunningNorm
from imitation.util.util import make_vec_env
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import VecCheckNan
from stable_baselines3.ppo import MlpPolicy

import graph_reinforcement_learning_using_blockchain_data as grl
from graph_reinforcement_learning_using_blockchain_data import config

config.load_dotenv()
mlflow.set_tracking_uri(uri=config.MLFLOW_TRACKING_URI)

[32m2025-05-03 15:26:51.772[0m | [1mINFO    [0m | [36mgraph_reinforcement_learning_using_blockchain_data.config[0m:[36m<module>[0m:[36m12[0m - [1mPROJ_ROOT path is: /Users/liamtessendorf/Programming/Uni/2_Master/4_FS25_Programming/graph-reinforcement-learning-using-blockchain-data[0m


In [2]:
RNG = np.random.default_rng(seed=42)

## Dataset 

In [3]:
df_emb = pd.read_csv(config.FLASHBOTS_Q2_DATA_DIR / "state_embeddings_dgi_128.csv")
df_class0 = pd.read_csv(config.RAW_DATA_DIR / "receipts_class0.csv")
df_class1 = pd.read_csv(config.RAW_DATA_DIR / "receipts_class1.csv")
df_eth_balances_class1 = pd.read_csv(config.RAW_DATA_DIR / "eth_balances_class1.csv")
df_eth_balances_class0 = pd.read_csv(config.RAW_DATA_DIR / "eth_balances_class0.csv")

In [4]:
print(df_class0.columns)
print(df_eth_balances_class0.columns)

Index(['block_number', 'transaction_hash', 'blockHash', 'blockNumber',
       'logsBloom', 'gasUsed', 'contractAddress', 'cumulativeGasUsed',
       'transactionIndex', 'from', 'to', 'type', 'effectiveGasPrice', 'logs',
       'status'],
      dtype='object')
Index(['account', 'block_number', 'balance'], dtype='object')


In [5]:
df_class0_with_eth_balances = df_class0.merge(
    df_eth_balances_class0,
    left_on=["from", "blockNumber"],
    right_on=["account", "block_number"],
    how="inner",
)
df_class1_with_eth_balances = df_class1.merge(
    df_eth_balances_class1,
    left_on=["from", "blockNumber"],
    right_on=["account", "block_number"],
    how="inner",
)

In [6]:
df_class0_multi_occ = df_class0_with_eth_balances[
    df_class0_with_eth_balances["from"].duplicated(keep=False)
]

In [7]:
df_emb["embeddings"] = df_emb["embeddings"].apply(
    lambda x: np.array(ast.literal_eval(x), dtype=np.float32)
)

In [8]:
df_class0_with_eth_balances["label"] = 0
df_class1_with_eth_balances["label"] = 1

In [9]:
df_receipts = pd.concat(
    [df_class0_with_eth_balances, df_class1_with_eth_balances], ignore_index=True
)
df_receipts.drop_duplicates("transaction_hash", inplace=True)
df = df_receipts.merge(df_emb, how="right", left_on="transaction_hash", right_on="transactionHash")

In [10]:
df.head()

Unnamed: 0,block_number_x,transaction_hash,blockHash,blockNumber,logsBloom,gasUsed,contractAddress,cumulativeGasUsed,transactionIndex,from,...,created_at,account_address,profit_token_address,start_amount,end_amount,profit_amount,error,protocols,transactionHash,embeddings
0,16969850,0xe4029c908cfc40f825051cd0957797c66196eb8ba437...,0x241b2ebd536a6f546ce2214bcf2146d359ae4e0cc4e3...,16969850,0x00000000000000000000000000000000000000000000...,21000,,1315400,19,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,,,,,,0xe4029c908cfc40f825051cd0957797c66196eb8ba437...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
1,16975162,0xbcfb84169287cf7acce33fba2b7390cfe21852871f78...,0x577d4ef683ebd84c73cdbc3635c7177f6ec137eef3bf...,16975162,0x00000000000000000000000080000000000000000000...,116880,,396290,3,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,,,,,,0xbcfb84169287cf7acce33fba2b7390cfe21852871f78...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
2,16983327,0x7a954e541df7c296b8fec61b77d22ea8bacd63399467...,0xe79daae4074cb858a3e79b4d95064845295eaddc134c...,16983327,0x00000000000000000000000000000000000000000000...,103421,,3934866,21,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,,,,,,0x7a954e541df7c296b8fec61b77d22ea8bacd63399467...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
3,16990575,0x016858c7c133cdf545b6934653544c19475c5c111ba0...,0x9c3fa88c63603c9c5696174085177f9559e20b331a64...,16990575,0x00200000000000000000000080004000000000000000...,105953,,4683545,39,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,,,,,,0x016858c7c133cdf545b6934653544c19475c5c111ba0...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0071866573, 0.0, 0..."
4,16994491,0x0f358c5aed8ae456a298eedaf3fc08cb802c3417a6b4...,0xae6c66f64f13ba5404bbe4cebd33fdd666c7bc8fe601...,16994491,0x00200000000011000000000080000000000000000000...,223722,,3972377,34,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,,,,,,0x0f358c5aed8ae456a298eedaf3fc08cb802c3417a6b4...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.007031055, 0.0, 0...."


In [11]:
df_median_gas_prices = pd.DataFrame(
    {
        "median_gas_price": df.groupby(["blockNumber"])["effectiveGasPrice"].median(),
        "std_gas_price": df.groupby(["blockNumber"])["effectiveGasPrice"].std(),
        "max_gas_price": df.groupby(["blockNumber"])["effectiveGasPrice"].max(),
        "min_gas_price": df.groupby(["blockNumber"])["effectiveGasPrice"].min(),
    }
)

df_with_median_gas_prices = df.merge(df_median_gas_prices, how="left", on="blockNumber")
df_with_median_gas_prices.head()

Unnamed: 0,block_number_x,transaction_hash,blockHash,blockNumber,logsBloom,gasUsed,contractAddress,cumulativeGasUsed,transactionIndex,from,...,end_amount,profit_amount,error,protocols,transactionHash,embeddings,median_gas_price,std_gas_price,max_gas_price,min_gas_price
0,16969850,0xe4029c908cfc40f825051cd0957797c66196eb8ba437...,0x241b2ebd536a6f546ce2214bcf2146d359ae4e0cc4e3...,16969850,0x00000000000000000000000000000000000000000000...,21000,,1315400,19,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,,0xe4029c908cfc40f825051cd0957797c66196eb8ba437...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",28412780000.0,5338623000.0,40046142239,27985774295
1,16975162,0xbcfb84169287cf7acce33fba2b7390cfe21852871f78...,0x577d4ef683ebd84c73cdbc3635c7177f6ec137eef3bf...,16975162,0x00000000000000000000000080000000000000000000...,116880,,396290,3,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,,0xbcfb84169287cf7acce33fba2b7390cfe21852871f78...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",31324250000.0,13106710000.0,62912040686,29651658352
2,16983327,0x7a954e541df7c296b8fec61b77d22ea8bacd63399467...,0xe79daae4074cb858a3e79b4d95064845295eaddc134c...,16983327,0x00000000000000000000000000000000000000000000...,103421,,3934866,21,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,,0x7a954e541df7c296b8fec61b77d22ea8bacd63399467...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",47015990000.0,94835380000.0,298140379626,44115991364
3,16990575,0x016858c7c133cdf545b6934653544c19475c5c111ba0...,0x9c3fa88c63603c9c5696174085177f9559e20b331a64...,16990575,0x00200000000000000000000080004000000000000000...,105953,,4683545,39,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,,0x016858c7c133cdf545b6934653544c19475c5c111ba0...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0071866573, 0.0, 0...",39976070000.0,22168940000.0,72384430845,29976074199
4,16994491,0x0f358c5aed8ae456a298eedaf3fc08cb802c3417a6b4...,0xae6c66f64f13ba5404bbe4cebd33fdd666c7bc8fe601...,16994491,0x00200000000011000000000080000000000000000000...,223722,,3972377,34,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,,0x0f358c5aed8ae456a298eedaf3fc08cb802c3417a6b4...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.007031055, 0.0, 0....",26952080000.0,3320303000.0,30915355327,22915355327


In [12]:
df_with_actions = df_with_median_gas_prices.copy()
df_with_actions["action"] = df_with_median_gas_prices.apply(
    lambda r: 1 if r["effectiveGasPrice"] > r["median_gas_price"] else 0, axis=1
)

In [13]:
df_with_actions["action"].mean()

0.24449359876667337

In [14]:
df_with_actions.head()

Unnamed: 0,block_number_x,transaction_hash,blockHash,blockNumber,logsBloom,gasUsed,contractAddress,cumulativeGasUsed,transactionIndex,from,...,profit_amount,error,protocols,transactionHash,embeddings,median_gas_price,std_gas_price,max_gas_price,min_gas_price,action
0,16969850,0xe4029c908cfc40f825051cd0957797c66196eb8ba437...,0x241b2ebd536a6f546ce2214bcf2146d359ae4e0cc4e3...,16969850,0x00000000000000000000000000000000000000000000...,21000,,1315400,19,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,0xe4029c908cfc40f825051cd0957797c66196eb8ba437...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",28412780000.0,5338623000.0,40046142239,27985774295,1
1,16975162,0xbcfb84169287cf7acce33fba2b7390cfe21852871f78...,0x577d4ef683ebd84c73cdbc3635c7177f6ec137eef3bf...,16975162,0x00000000000000000000000080000000000000000000...,116880,,396290,3,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,0xbcfb84169287cf7acce33fba2b7390cfe21852871f78...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",31324250000.0,13106710000.0,62912040686,29651658352,1
2,16983327,0x7a954e541df7c296b8fec61b77d22ea8bacd63399467...,0xe79daae4074cb858a3e79b4d95064845295eaddc134c...,16983327,0x00000000000000000000000000000000000000000000...,103421,,3934866,21,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,0x7a954e541df7c296b8fec61b77d22ea8bacd63399467...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",47015990000.0,94835380000.0,298140379626,44115991364,1
3,16990575,0x016858c7c133cdf545b6934653544c19475c5c111ba0...,0x9c3fa88c63603c9c5696174085177f9559e20b331a64...,16990575,0x00200000000000000000000080004000000000000000...,105953,,4683545,39,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,0x016858c7c133cdf545b6934653544c19475c5c111ba0...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0071866573, 0.0, 0...",39976070000.0,22168940000.0,72384430845,29976074199,0
4,16994491,0x0f358c5aed8ae456a298eedaf3fc08cb802c3417a6b4...,0xae6c66f64f13ba5404bbe4cebd33fdd666c7bc8fe601...,16994491,0x00200000000011000000000080000000000000000000...,223722,,3972377,34,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,0x0f358c5aed8ae456a298eedaf3fc08cb802c3417a6b4...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.007031055, 0.0, 0....",26952080000.0,3320303000.0,30915355327,22915355327,1


In [15]:
df_with_actions.rename(columns={"balance": "eth_balance"}, inplace=True)
df_with_actions["eth_balance"] = df_with_actions["eth_balance"].astype("float64")
df_with_actions["median_gas_price"] = df_with_actions["median_gas_price"].astype("float64")
df_with_actions["std_gas_price"] = df_with_actions["std_gas_price"].astype("float64")
df_with_actions["from"] = df_with_actions["from"].astype("string")
df_with_actions["to"] = df_with_actions["to"].astype("string")
df_with_actions.head()

Unnamed: 0,block_number_x,transaction_hash,blockHash,blockNumber,logsBloom,gasUsed,contractAddress,cumulativeGasUsed,transactionIndex,from,...,profit_amount,error,protocols,transactionHash,embeddings,median_gas_price,std_gas_price,max_gas_price,min_gas_price,action
0,16969850,0xe4029c908cfc40f825051cd0957797c66196eb8ba437...,0x241b2ebd536a6f546ce2214bcf2146d359ae4e0cc4e3...,16969850,0x00000000000000000000000000000000000000000000...,21000,,1315400,19,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,0xe4029c908cfc40f825051cd0957797c66196eb8ba437...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",28412780000.0,5338623000.0,40046142239,27985774295,1
1,16975162,0xbcfb84169287cf7acce33fba2b7390cfe21852871f78...,0x577d4ef683ebd84c73cdbc3635c7177f6ec137eef3bf...,16975162,0x00000000000000000000000080000000000000000000...,116880,,396290,3,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,0xbcfb84169287cf7acce33fba2b7390cfe21852871f78...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",31324250000.0,13106710000.0,62912040686,29651658352,1
2,16983327,0x7a954e541df7c296b8fec61b77d22ea8bacd63399467...,0xe79daae4074cb858a3e79b4d95064845295eaddc134c...,16983327,0x00000000000000000000000000000000000000000000...,103421,,3934866,21,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,0x7a954e541df7c296b8fec61b77d22ea8bacd63399467...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",47015990000.0,94835380000.0,298140379626,44115991364,1
3,16990575,0x016858c7c133cdf545b6934653544c19475c5c111ba0...,0x9c3fa88c63603c9c5696174085177f9559e20b331a64...,16990575,0x00200000000000000000000080004000000000000000...,105953,,4683545,39,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,0x016858c7c133cdf545b6934653544c19475c5c111ba0...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0071866573, 0.0, 0...",39976070000.0,22168940000.0,72384430845,29976074199,0
4,16994491,0x0f358c5aed8ae456a298eedaf3fc08cb802c3417a6b4...,0xae6c66f64f13ba5404bbe4cebd33fdd666c7bc8fe601...,16994491,0x00200000000011000000000080000000000000000000...,223722,,3972377,34,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,0x0f358c5aed8ae456a298eedaf3fc08cb802c3417a6b4...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.007031055, 0.0, 0....",26952080000.0,3320303000.0,30915355327,22915355327,1


In [16]:
df_with_actions[df_with_actions["std_gas_price"].isna() == True]

Unnamed: 0,block_number_x,transaction_hash,blockHash,blockNumber,logsBloom,gasUsed,contractAddress,cumulativeGasUsed,transactionIndex,from,...,profit_amount,error,protocols,transactionHash,embeddings,median_gas_price,std_gas_price,max_gas_price,min_gas_price,action
63,16975107,0xd05a59ef18204af79ae9bf2a7ba722bca892055819c7...,0x8d8cd7dfa64a3867f12d472be895ee2a1b163d854a02...,16975107,0x00200000000000000000000080000000000008000000...,196537,,685900,5,0x00000006e42915A2B6907f8b3fAF311B68862f60,...,6855415086744551,,"[""uniswap_v2""]",0xd05a59ef18204af79ae9bf2a7ba722bca892055819c7...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",3.442311e+10,,34423111170,34423111170,0
65,16976564,0x5530313d0b0271506691e3732c517172d5bfa1b2ba3d...,0x458b66a35808bf44a7e332b9f9b326ca8660952ab40a...,16976564,0x00200000000000000000000080000000200000000000...,190691,,1322470,7,0x00000006e42915A2B6907f8b3fAF311B68862f60,...,7796163749018364,,"[""uniswap_v2""]",0x5530313d0b0271506691e3732c517172d5bfa1b2ba3d...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",4.041179e+10,,40411786608,40411786608,0
76,16979829,0xe47601937f0538ecc2a67c0a1b2481a1d339b52b2ef0...,0x4d1ebc8a72732a87fb083679233a06cf0a33ec984b35...,16979829,0x00200000000000000000000084000000200000000000...,158198,,158198,0,0x00000006e42915A2B6907f8b3fAF311B68862f60,...,24408839782523421,,"[""uniswap_v2"",""uniswap_v3""]",0xe47601937f0538ecc2a67c0a1b2481a1d339b52b2ef0...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",1.532201e+11,,153220129205,153220129205,0
98,16992014,0x23e05562df7784836aaf6c8235d2aca5501621e6aab2...,0x8794222099471416ba2721137ee0e0d60a149b2f2029...,16992014,0x00200000000000000000000080000200000000000000...,219788,,219788,0,0x00000006e42915A2B6907f8b3fAF311B68862f60,...,35168219119161781,,"[""uniswap_v2"",""uniswap_v3""]",0x23e05562df7784836aaf6c8235d2aca5501621e6aab2...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",1.595249e+11,,159524873969,159524873969,0
167,17240935,0x74e6628155b2f61c067a568235c952e3c1fa4aa22d76...,0xa658109201d2dcd46dcc9f5ad1c71d29d6942c44b1c4...,17240935,0x00000000040000000000000000000000000000000000...,210289,,2954344,23,0x00000006e42915A2B6907f8b3fAF311B68862f60,...,13097045417966494,,"[""uniswap_v3""]",0x74e6628155b2f61c067a568235c952e3c1fa4aa22d76...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",6.190800e+10,,61907998723,61907998723,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
149181,17296632,0x2f06cc7f885dd98fbf765ad57706fe505c5b63478928...,0x64522455567aa9a4209551fd938a1e31b448487ac157...,17296632,0x00200000000000000010000080000000000000000000...,178376,,440031,1,0xffFf14106945bCB267B34711c416AA3085B8865F,...,22902832502491512,,"[""uniswap_v2""]",0x2f06cc7f885dd98fbf765ad57706fe505c5b63478928...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",3.450291e+10,,34502909809,34502909809,0
149182,17297932,0xd477617aaa93aad5aa5f5a2f880f3e934b70a91cd523...,0x493341206c063a7ba6c85137473b17f697ed72701c07...,17297932,0x00200000000000000010000080000000000000000000...,178388,,440065,1,0xffFf14106945bCB267B34711c416AA3085B8865F,...,31467600199400760,,"[""uniswap_v2""]",0xd477617aaa93aad5aa5f5a2f880f3e934b70a91cd523...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",2.928406e+10,,29284056698,29284056698,0
149183,17298360,0xda0b7b1156a57ff85ef5c2e80f47d6fc0c89c6253fca...,0x21a1c5ae76e2cbff779f308288e1cf2273d999f978e1...,17298360,0x00200000000000000010000080000000000000000000...,178352,,439975,1,0xffFf14106945bCB267B34711c416AA3085B8865F,...,38228015952142520,,"[""uniswap_v2""]",0xda0b7b1156a57ff85ef5c2e80f47d6fc0c89c6253fca...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",2.959543e+10,,29595427147,29595427147,0
149184,17298379,0x7755b8553cc9f2479e5cc48ebf18dd6f25fe038a668f...,0xa0171ad96bcd73eb6aa92c8da1e3a8629c1c067ea811...,17298379,0x00200000000000000010000080000000000000000000...,178364,,1162669,7,0xffFf14106945bCB267B34711c416AA3085B8865F,...,28570015952142520,,"[""uniswap_v2""]",0x7755b8553cc9f2479e5cc48ebf18dd6f25fe038a668f...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",2.783189e+10,,27831892530,27831892530,0


In [17]:
df_with_actions = df_with_actions.fillna({"std_gas_price": 0})

In [18]:
df_with_actions_0 = df_with_actions[df_with_actions["label"] == 0]
df_with_actions_1 = df_with_actions[df_with_actions["label"] == 1]

In [19]:
unique_accs_0 = df_with_actions_0["from"].unique()
accs_train_0 = unique_accs_0[: int(0.8 * len(unique_accs_0))]
accs_val_0 = unique_accs_0[int(0.8 * len(unique_accs_0)) : int(0.9 * len(unique_accs_0))]
accs_test_0 = unique_accs_0[int(0.9 * len(unique_accs_0)) :]
df_with_actions_0_train = df_with_actions_0[df_with_actions_0["from"].isin(accs_train_0)]
df_with_actions_0_val = df_with_actions_0[df_with_actions_0["from"].isin(accs_val_0)]
df_with_actions_0_test = df_with_actions_0[df_with_actions_0["from"].isin(accs_test_0)]

In [20]:
unique_accs_1 = df_with_actions_1["from"].unique()
accs_train_1 = unique_accs_1[: int(0.8 * len(unique_accs_1))]
accs_val_1 = unique_accs_1[int(0.8 * len(unique_accs_1)) : int(0.9 * len(unique_accs_1))]
accs_test_1 = unique_accs_1[int(0.9 * len(unique_accs_1)) :]
df_with_actions_1_train = df_with_actions_1[df_with_actions_1["from"].isin(accs_train_1)]
df_with_actions_1_val = df_with_actions_1[df_with_actions_1["from"].isin(accs_val_1)]
df_with_actions_1_test = df_with_actions_1[df_with_actions_1["from"].isin(accs_test_1)]

In [21]:
df_val = pd.concat([df_with_actions_0_val, df_with_actions_1_val])

In [22]:
df_with_actions_1_train[
    df_with_actions_1_train["from"] == "0x1e6c1c4669f612112a7caCa5596BfE6629e669aA"
]

Unnamed: 0,block_number_x,transaction_hash,blockHash,blockNumber,logsBloom,gasUsed,contractAddress,cumulativeGasUsed,transactionIndex,from,...,profit_amount,error,protocols,transactionHash,embeddings,median_gas_price,std_gas_price,max_gas_price,min_gas_price,action


## Creating trajectories

In [23]:
def extract_trajectories(df: pd.DataFrame):
    trajectories = []
    for account, group in df.groupby("from"):
        group = group.sort_values("blockNumber")
        obs_list = group["embeddings"].tolist() + [np.zeros(128, dtype=np.float32)]
        traj = {
            "obs": np.stack(obs_list),  # Convert list of arrays to a single numpy array
            "acts": np.array(group["action"].tolist()),
            "label": group["label"].iloc[0],
        }
        trajectories.append(traj)
    return trajectories


trajectories_1_train = extract_trajectories(df_with_actions_1_train)
trajectories_0_train = extract_trajectories(df_with_actions_0_train)
trajectories_1_test = extract_trajectories(df_with_actions_1_test)
trajectories_0_test = extract_trajectories(df_with_actions_0_test)
trajectories_val = extract_trajectories(df_val)

In [24]:
trajectories_1 = [
    Trajectory(obs=traj["obs"], acts=traj["acts"], infos=None, terminal=True)
    for traj in trajectories_1_train
]
trajectories_0 = [
    Trajectory(obs=traj["obs"], acts=traj["acts"], infos=None, terminal=True)
    for traj in trajectories_0_train
]
trajectories_val = [
    Trajectory(obs=traj["obs"], acts=traj["acts"], infos=None, terminal=True)
    for traj in trajectories_val
]

trajectories_1 = flatten_trajectories(trajectories_1)
trajectories_0 = flatten_trajectories(trajectories_0)
trajectories_val = flatten_trajectories(trajectories_val)

In [25]:
trajectories_1_test = [
    Trajectory(obs=traj["obs"], acts=traj["acts"], infos=None, terminal=True)
    for traj in trajectories_1_test
]
trajectories_0_test = [
    Trajectory(obs=traj["obs"], acts=traj["acts"], infos=None, terminal=True)
    for traj in trajectories_0_test
]

trajectories_1_test = flatten_trajectories(trajectories_1_test)
trajectories_0_test = flatten_trajectories(trajectories_0_test)

## Setting up environments

In [26]:
ID0 = "gymnasium_env/TransactionGraphEnv0-v2"
gym.envs.register(
    id=ID0,
    entry_point=grl.TransactionGraphEnvV2,
    kwargs={"df": df_with_actions_0, "alpha": 0.9, "device": torch.device("mps"), "label": 0},
    max_episode_steps=300,
)

ID1 = "gymnasium_env/TransactionGraphEnv1-v2"
gym.envs.register(
    id=ID1,
    entry_point=grl.TransactionGraphEnvV2,
    kwargs={"df": df_with_actions_1, "alpha": 0.9, "device": torch.device("mps"), "label": 1},
    max_episode_steps=300,
)

In [27]:
gym.pprint_registry()

===== classic_control =====
Acrobot-v1             CartPole-v0            CartPole-v1
MountainCar-v0         MountainCarContinuous-v0 Pendulum-v1
===== phys2d =====
phys2d/CartPole-v0     phys2d/CartPole-v1     phys2d/Pendulum-v0
===== box2d =====
BipedalWalker-v3       BipedalWalkerHardcore-v3 CarRacing-v2
LunarLander-v2         LunarLanderContinuous-v2
===== toy_text =====
Blackjack-v1           CliffWalking-v0        FrozenLake-v1
FrozenLake8x8-v1       Taxi-v3
===== tabular =====
tabular/Blackjack-v0   tabular/CliffWalking-v0
===== mujoco =====
Ant-v2                 Ant-v3                 Ant-v4
HalfCheetah-v2         HalfCheetah-v3         HalfCheetah-v4
Hopper-v2              Hopper-v3              Hopper-v4
Humanoid-v2            Humanoid-v3            Humanoid-v4
HumanoidStandup-v2     HumanoidStandup-v4     InvertedDoublePendulum-v2
InvertedDoublePendulum-v4 InvertedPendulum-v2    InvertedPendulum-v4
Pusher-v2              Pusher-v4              Reacher-v2
Reacher-v4         

In [28]:
env0 = Monitor(gym.make(ID0))

venv0 = make_vec_env(
    ID0,
    rng=RNG,
    n_envs=1,
    post_wrappers=[lambda env0, _: RolloutInfoWrapper(env0)],
    parallel=False,
)

venv0 = VecCheckNan(venv0, raise_exception=True)  # Check for NaN observations
venv0.reset()

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

array([[0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.25352484,
        0.        , 0.        , 0.        , 0.24583885, 0.        ,
        0.        , 0.05596162, 0.        , 0.73574173, 0.        ,
        0.34176326, 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.6663085 , 0.        , 0.        ,
        0.        , 0.28583455, 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.4065075 , 0.        , 0.        , 1.2506028 ,
        0.        , 0.25370663, 0.08524673, 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.5160529 , 0.        , 0.        , 0.        , 0.        ,
        0.8520548 , 0.        , 0.23360552, 0.        , 0.14838207,
        0.        , 0.06213059, 0.        , 0.        , 0.        ,
        0.6388226 , 0.26219702, 0.        , 0.  

In [29]:
env1 = Monitor(gym.make(ID1))

venv1 = make_vec_env(
    ID1,
    rng=RNG,
    n_envs=1,
    post_wrappers=[lambda env1, _: RolloutInfoWrapper(env1)],
    parallel=False,
)

venv1 = VecCheckNan(venv1, raise_exception=True)  # Check for NaN observations
venv1.reset()

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

array([[0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.02219819,
        0.312723  , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.16475148, 0.        , 0.4680296 , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.05625061, 0.        , 0.00204224, 0.        , 0.09073605,
        0.        , 0.07746468, 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.24095674, 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.89360684,
        0.        , 1.486804  , 0.22280207, 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.01592473, 0.7889279 ,
        0.56510556, 0.        , 0.        , 0.09371138, 0.        ,
        1.6673076 , 0.        , 0.604187  , 0.        , 0.13083528,
        0.        , 0.        , 0.        , 0.11422996, 0.        ,
        0.        , 0.        , 0.        , 0.  

## AIRL setup

In [30]:
# Set parameters for the PPO algorithm (generator)
learning_rate = 0.001  # Learning rate, can be a function of progress
batch_size = 60  # Mini batch size for each gradient update
n_epochs = 15  # N of epochs when optimizing the surrogate loss

gamma = 0.5  # Discount factor, focus on the recent rewards
gae_lambda = 0  # Generalized advantage estimation
clip_range = 0.1  # Clipping parameter
ent_coef = 0.01  # Entropy coefficient for the loss calculation
vf_coef = 0.5  # Value function coef. for the loss calculation
max_grad_norm = 0.5  # The maximum value for the gradient clipping

verbose = 0  # Verbosity level: 0 no output, 1 info, 2 debug
normalize_advantage = True  # Whether to normalize or not the advantage

clip_range_vf = None  # Clip for the value function
use_sde = False  # Use State Dependent Exploration
sde_sample_freq = -1  # SDE - noise matrix frequency (-1 = disable)

# Set parameters for the AIRL trainer
gen_replay_buffer_capacity = None
allow_variable_horizon = True

disc_opt_kwargs = {
    "lr": 0.001,
}
policy_kwargs = {"use_expln": True}  # Fixing an issue with NaNs

In [31]:
# Set the number of timesteps, batch size and number of disc updates

# Total number of timesteps in the whole training
total_timesteps = 3000 * 25

# Generator
gen_train_timesteps = 3000  # N steps in the environment per one round
n_steps = gen_train_timesteps

# Discriminator batches
demo_minibatch_size = 60  # N samples in minibatch for one discrim. update
demo_batch_size = 300 * 10  # N samples in the batch of expert data (batch)
n_disc_updates_per_round = 4  # N discriminator updates per one round

In [32]:
hier_logger = logger.configure()
hier_logger.default_logger.output_formats.append(grl.MLflowOutputFormat())

In [33]:
# Initialize the learner PPO policy (generator)
learner0 = PPO(
    env=venv0,
    policy=MlpPolicy,
    policy_kwargs=policy_kwargs,
    learning_rate=learning_rate,
    n_steps=n_steps,
    batch_size=batch_size,
    n_epochs=n_epochs,
    gamma=gamma,
    gae_lambda=gae_lambda,
    clip_range=clip_range,
    clip_range_vf=clip_range_vf,
    normalize_advantage=normalize_advantage,
    ent_coef=ent_coef,
    vf_coef=vf_coef,
    max_grad_norm=max_grad_norm,
    use_sde=use_sde,
    sde_sample_freq=sde_sample_freq,
    verbose=verbose,
    seed=42,
    device="mps",
)

reward_net0 = BasicShapedRewardNet(
    observation_space=venv0.observation_space,
    action_space=venv0.action_space,
    normalize_input_layer=RunningNorm,
)

# Initialize the AIRL trainer
airl_trainer0 = AIRL(
    demonstrations=trajectories_0,
    demo_batch_size=demo_batch_size,
    demo_minibatch_size=demo_minibatch_size,
    n_disc_updates_per_round=n_disc_updates_per_round,
    gen_train_timesteps=gen_train_timesteps,
    gen_replay_buffer_capacity=gen_replay_buffer_capacity,
    venv=venv0,
    gen_algo=learner0,
    reward_net=reward_net0,
    allow_variable_horizon=allow_variable_horizon,
    disc_opt_kwargs=disc_opt_kwargs,
    custom_logger=hier_logger,
)

Running with `allow_variable_horizon` set to True. Some algorithms are biased towards shorter or longer episodes, which may significantly confound results. Additionally, even unbiased algorithms can exploit the information leak from the termination condition, producing spuriously high performance. See https://imitation.readthedocs.io/en/latest/getting-started/variable-horizon.html for more information.


In [34]:
# Initialize the learner PPO policy (generator)
learner1 = PPO(
    env=venv1,
    policy=MlpPolicy,
    policy_kwargs=policy_kwargs,
    learning_rate=learning_rate,
    n_steps=n_steps,
    batch_size=batch_size,
    n_epochs=n_epochs,
    gamma=gamma,
    gae_lambda=gae_lambda,
    clip_range=clip_range,
    clip_range_vf=clip_range_vf,
    normalize_advantage=normalize_advantage,
    ent_coef=ent_coef,
    vf_coef=vf_coef,
    max_grad_norm=max_grad_norm,
    use_sde=use_sde,
    sde_sample_freq=sde_sample_freq,
    verbose=verbose,
    seed=42,
    device="mps",
)

reward_net1 = BasicShapedRewardNet(
    observation_space=venv1.observation_space,
    action_space=venv1.action_space,
    normalize_input_layer=RunningNorm,
)

# Initialize the AIRL trainer
airl_trainer1 = AIRL(
    demonstrations=trajectories_1,
    demo_batch_size=demo_batch_size,
    demo_minibatch_size=demo_minibatch_size,
    n_disc_updates_per_round=n_disc_updates_per_round,
    gen_train_timesteps=gen_train_timesteps,
    gen_replay_buffer_capacity=gen_replay_buffer_capacity,
    venv=venv1,
    gen_algo=learner1,
    reward_net=reward_net1,
    allow_variable_horizon=allow_variable_horizon,
    disc_opt_kwargs=disc_opt_kwargs,
    custom_logger=hier_logger,
)

Running with `allow_variable_horizon` set to True. Some algorithms are biased towards shorter or longer episodes, which may significantly confound results. Additionally, even unbiased algorithms can exploit the information leak from the termination condition, producing spuriously high performance. See https://imitation.readthedocs.io/en/latest/getting-started/variable-horizon.html for more information.


## Training AIRL discriminator and generator, stats are saved with mlflow

We need to train 2 distinct airl trainers, one for arbitrage transactions and the other for class 0 transactions. The goal is to use the resulting reward functions to do classification.

In [35]:
mlflow.set_experiment("AIRLv2")

with mlflow.start_run():
    mlflow.log_param("n_steps", n_steps)
    mlflow.log_param("batch_size", batch_size)
    mlflow.log_param("total_timesteps", total_timesteps)

    airl_trainer1.train(total_timesteps=total_timesteps)

    learner1.save(config.MODELS_DIR / "learner1v2")
    torch.save(reward_net1, config.MODELS_DIR / "reward_net1v2")

    mlflow.log_artifact(config.MODELS_DIR / "learner1v2.zip")
    mlflow.log_artifact(config.MODELS_DIR / "reward_net1v2")
    mlflow.end_run()

round:   0%|          | 0/25 [00:00<?, ?it/s]

------------------------------------------
| raw/                        |          |
|    gen/rollout/ep_len_mean  | 22.7     |
|    gen/rollout/ep_rew_mean  | 10.8     |
|    gen/time/fps             | 27       |
|    gen/time/iterations      | 1        |
|    gen/time/time_elapsed    | 109      |
|    gen/time/total_timesteps | 3000     |
------------------------------------------
--------------------------------------------------
| raw/                                |          |
|    disc/disc_acc                    | 0.5      |
|    disc/disc_acc_expert             | 1        |
|    disc/disc_acc_gen                | 0        |
|    disc/disc_entropy                | 0.59     |
|    disc/disc_loss                   | 0.0153   |
|    disc/disc_proportion_expert_pred | 1        |
|    disc/disc_proportion_expert_true | 0.5      |
|    disc/global_step                 | 1        |
|    disc/n_expert                    | 60       |
|    disc/n_generated                 | 60       |
-

round:   4%|▍         | 1/25 [01:57<46:49, 117.08s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 24.9         |
|    gen/rollout/ep_rew_mean         | 11.6         |
|    gen/rollout/ep_rew_wrapped_mean | 19           |
|    gen/time/fps                    | 31           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 95           |
|    gen/time/total_timesteps        | 6000         |
|    gen/train/approx_kl             | 0.0011770331 |
|    gen/train/clip_fraction         | 0.086        |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.691       |
|    gen/train/explained_variance    | -0.00454     |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 0.0169       |
|    gen/train/n_updates             | 15           |
|    gen/train/policy_gradient_loss  | -0.00134     |
|    gen/train/value_loss   

round:   8%|▊         | 2/25 [03:39<41:40, 108.71s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 26.7         |
|    gen/rollout/ep_rew_mean         | 13           |
|    gen/rollout/ep_rew_wrapped_mean | -5.7         |
|    gen/time/fps                    | 32           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 92           |
|    gen/time/total_timesteps        | 9000         |
|    gen/train/approx_kl             | 0.0015401584 |
|    gen/train/clip_fraction         | 0.0884       |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.684       |
|    gen/train/explained_variance    | -0.883       |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 0.0256       |
|    gen/train/n_updates             | 30           |
|    gen/train/policy_gradient_loss  | -0.000356    |
|    gen/train/value_loss   

round:  12%|█▏        | 3/25 [05:19<38:17, 104.43s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 25.2         |
|    gen/rollout/ep_rew_mean         | 12.4         |
|    gen/rollout/ep_rew_wrapped_mean | -14.1        |
|    gen/time/fps                    | 34           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 87           |
|    gen/time/total_timesteps        | 12000        |
|    gen/train/approx_kl             | 0.0041269707 |
|    gen/train/clip_fraction         | 0.186        |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.689       |
|    gen/train/explained_variance    | 0.509        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 0.0218       |
|    gen/train/n_updates             | 45           |
|    gen/train/policy_gradient_loss  | -0.00348     |
|    gen/train/value_loss   

round:  16%|█▌        | 4/25 [06:54<35:13, 100.66s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 29.5         |
|    gen/rollout/ep_rew_mean         | 14.1         |
|    gen/rollout/ep_rew_wrapped_mean | -23.5        |
|    gen/time/fps                    | 34           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 86           |
|    gen/time/total_timesteps        | 15000        |
|    gen/train/approx_kl             | 0.0036146024 |
|    gen/train/clip_fraction         | 0.228        |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.692       |
|    gen/train/explained_variance    | 0.715        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 0.041        |
|    gen/train/n_updates             | 60           |
|    gen/train/policy_gradient_loss  | -0.00646     |
|    gen/train/value_loss   

round:  20%|██        | 5/25 [08:28<32:45, 98.30s/it] 

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 29.8         |
|    gen/rollout/ep_rew_mean         | 14.5         |
|    gen/rollout/ep_rew_wrapped_mean | -34.6        |
|    gen/time/fps                    | 31           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 95           |
|    gen/time/total_timesteps        | 18000        |
|    gen/train/approx_kl             | 0.0031699846 |
|    gen/train/clip_fraction         | 0.186        |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.687       |
|    gen/train/explained_variance    | 0.867        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 0.191        |
|    gen/train/n_updates             | 75           |
|    gen/train/policy_gradient_loss  | -0.00541     |
|    gen/train/value_loss   

round:  24%|██▍       | 6/25 [10:11<31:38, 99.93s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 28.9         |
|    gen/rollout/ep_rew_mean         | 13.7         |
|    gen/rollout/ep_rew_wrapped_mean | -47.4        |
|    gen/time/fps                    | 31           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 94           |
|    gen/time/total_timesteps        | 21000        |
|    gen/train/approx_kl             | 0.0039541167 |
|    gen/train/clip_fraction         | 0.137        |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.687       |
|    gen/train/explained_variance    | 0.869        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 0.237        |
|    gen/train/n_updates             | 90           |
|    gen/train/policy_gradient_loss  | -0.00466     |
|    gen/train/value_loss   

round:  28%|██▊       | 7/25 [11:52<30:07, 100.41s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 35.1         |
|    gen/rollout/ep_rew_mean         | 17           |
|    gen/rollout/ep_rew_wrapped_mean | -85.6        |
|    gen/time/fps                    | 37           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 80           |
|    gen/time/total_timesteps        | 24000        |
|    gen/train/approx_kl             | 0.0036134466 |
|    gen/train/clip_fraction         | 0.0906       |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.681       |
|    gen/train/explained_variance    | 0.757        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 1.33         |
|    gen/train/n_updates             | 105          |
|    gen/train/policy_gradient_loss  | -0.00319     |
|    gen/train/value_loss   

round:  32%|███▏      | 8/25 [13:20<27:17, 96.35s/it] 

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 15.2         |
|    gen/rollout/ep_rew_mean         | 7.56         |
|    gen/rollout/ep_rew_wrapped_mean | -97.8        |
|    gen/time/fps                    | 34           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 86           |
|    gen/time/total_timesteps        | 27000        |
|    gen/train/approx_kl             | 0.0033162183 |
|    gen/train/clip_fraction         | 0.108        |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.662       |
|    gen/train/explained_variance    | 0.758        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 2.51         |
|    gen/train/n_updates             | 120          |
|    gen/train/policy_gradient_loss  | -0.00349     |
|    gen/train/value_loss   

round:  36%|███▌      | 9/25 [14:53<25:27, 95.44s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 29.8        |
|    gen/rollout/ep_rew_mean         | 15.5        |
|    gen/rollout/ep_rew_wrapped_mean | -44.2       |
|    gen/time/fps                    | 34          |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 87          |
|    gen/time/total_timesteps        | 30000       |
|    gen/train/approx_kl             | 0.003690639 |
|    gen/train/clip_fraction         | 0.203       |
|    gen/train/clip_range            | 0.1         |
|    gen/train/entropy_loss          | -0.65       |
|    gen/train/explained_variance    | 0.802       |
|    gen/train/learning_rate         | 0.001       |
|    gen/train/loss                  | 1.78        |
|    gen/train/n_updates             | 135         |
|    gen/train/policy_gradient_loss  | -0.00398    |
|    gen/train/value_loss            | 6.72   

round:  40%|████      | 10/25 [16:28<23:48, 95.23s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 18           |
|    gen/rollout/ep_rew_mean         | 9.97         |
|    gen/rollout/ep_rew_wrapped_mean | -87.8        |
|    gen/time/fps                    | 39           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 76           |
|    gen/time/total_timesteps        | 33000        |
|    gen/train/approx_kl             | 0.0024977394 |
|    gen/train/clip_fraction         | 0.0886       |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.628       |
|    gen/train/explained_variance    | 0.726        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 2.09         |
|    gen/train/n_updates             | 150          |
|    gen/train/policy_gradient_loss  | -0.00345     |
|    gen/train/value_loss   

round:  44%|████▍     | 11/25 [17:52<21:22, 91.61s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 33           |
|    gen/rollout/ep_rew_mean         | 19           |
|    gen/rollout/ep_rew_wrapped_mean | -54.1        |
|    gen/time/fps                    | 34           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 86           |
|    gen/time/total_timesteps        | 36000        |
|    gen/train/approx_kl             | 0.0020805723 |
|    gen/train/clip_fraction         | 0.116        |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.603       |
|    gen/train/explained_variance    | 0.825        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 1.29         |
|    gen/train/n_updates             | 165          |
|    gen/train/policy_gradient_loss  | -0.00321     |
|    gen/train/value_loss   

round:  48%|████▊     | 12/25 [19:25<19:58, 92.17s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 22.7         |
|    gen/rollout/ep_rew_mean         | 13.2         |
|    gen/rollout/ep_rew_wrapped_mean | -148         |
|    gen/time/fps                    | 34           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 86           |
|    gen/time/total_timesteps        | 39000        |
|    gen/train/approx_kl             | 0.0015283431 |
|    gen/train/clip_fraction         | 0.0782       |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.583       |
|    gen/train/explained_variance    | 0.772        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 4.34         |
|    gen/train/n_updates             | 180          |
|    gen/train/policy_gradient_loss  | -0.00167     |
|    gen/train/value_loss   

round:  52%|█████▏    | 13/25 [20:58<18:30, 92.52s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 31.6         |
|    gen/rollout/ep_rew_mean         | 19.3         |
|    gen/rollout/ep_rew_wrapped_mean | -87.5        |
|    gen/time/fps                    | 35           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 83           |
|    gen/time/total_timesteps        | 42000        |
|    gen/train/approx_kl             | 0.0037573203 |
|    gen/train/clip_fraction         | 0.0912       |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.563       |
|    gen/train/explained_variance    | 0.848        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 3.6          |
|    gen/train/n_updates             | 195          |
|    gen/train/policy_gradient_loss  | -0.00204     |
|    gen/train/value_loss   

round:  56%|█████▌    | 14/25 [22:29<16:52, 92.02s/it]

------------------------------------------------------
| raw/                               |               |
|    gen/rollout/ep_len_mean         | 37.4          |
|    gen/rollout/ep_rew_mean         | 23.4          |
|    gen/rollout/ep_rew_wrapped_mean | -140          |
|    gen/time/fps                    | 37            |
|    gen/time/iterations             | 1             |
|    gen/time/time_elapsed           | 79            |
|    gen/time/total_timesteps        | 45000         |
|    gen/train/approx_kl             | 0.00079923763 |
|    gen/train/clip_fraction         | 0.0672        |
|    gen/train/clip_range            | 0.1           |
|    gen/train/entropy_loss          | -0.548        |
|    gen/train/explained_variance    | 0.862         |
|    gen/train/learning_rate         | 0.001         |
|    gen/train/loss                  | 2.78          |
|    gen/train/n_updates             | 210           |
|    gen/train/policy_gradient_loss  | -0.000797     |
|    gen/t

round:  60%|██████    | 15/25 [23:56<15:04, 90.43s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 31.3         |
|    gen/rollout/ep_rew_mean         | 19.9         |
|    gen/rollout/ep_rew_wrapped_mean | -182         |
|    gen/time/fps                    | 36           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 81           |
|    gen/time/total_timesteps        | 48000        |
|    gen/train/approx_kl             | 0.0017615955 |
|    gen/train/clip_fraction         | 0.0932       |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.531       |
|    gen/train/explained_variance    | 0.798        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 1.61         |
|    gen/train/n_updates             | 225          |
|    gen/train/policy_gradient_loss  | -0.000459    |
|    gen/train/value_loss   

round:  64%|██████▍   | 16/25 [25:25<13:29, 89.96s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 23.8         |
|    gen/rollout/ep_rew_mean         | 14.7         |
|    gen/rollout/ep_rew_wrapped_mean | -163         |
|    gen/time/fps                    | 35           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 83           |
|    gen/time/total_timesteps        | 51000        |
|    gen/train/approx_kl             | 0.0011363286 |
|    gen/train/clip_fraction         | 0.0424       |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.527       |
|    gen/train/explained_variance    | 0.887        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 5.59         |
|    gen/train/n_updates             | 240          |
|    gen/train/policy_gradient_loss  | -0.00101     |
|    gen/train/value_loss   

round:  68%|██████▊   | 17/25 [26:55<12:01, 90.14s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 28.5         |
|    gen/rollout/ep_rew_mean         | 17.6         |
|    gen/rollout/ep_rew_wrapped_mean | -98.5        |
|    gen/time/fps                    | 36           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 81           |
|    gen/time/total_timesteps        | 54000        |
|    gen/train/approx_kl             | 0.0021148617 |
|    gen/train/clip_fraction         | 0.104        |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.517       |
|    gen/train/explained_variance    | 0.842        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 3.22         |
|    gen/train/n_updates             | 255          |
|    gen/train/policy_gradient_loss  | -0.00334     |
|    gen/train/value_loss   

round:  72%|███████▏  | 18/25 [28:24<10:28, 89.75s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 37.1         |
|    gen/rollout/ep_rew_mean         | 22.9         |
|    gen/rollout/ep_rew_wrapped_mean | -188         |
|    gen/time/fps                    | 32           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 91           |
|    gen/time/total_timesteps        | 57000        |
|    gen/train/approx_kl             | 0.0008773387 |
|    gen/train/clip_fraction         | 0.0839       |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.509       |
|    gen/train/explained_variance    | 0.815        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 5.78         |
|    gen/train/n_updates             | 270          |
|    gen/train/policy_gradient_loss  | -0.00217     |
|    gen/train/value_loss   

round:  76%|███████▌  | 19/25 [30:02<09:13, 92.24s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 26.3         |
|    gen/rollout/ep_rew_mean         | 17           |
|    gen/rollout/ep_rew_wrapped_mean | -213         |
|    gen/time/fps                    | 36           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 81           |
|    gen/time/total_timesteps        | 60000        |
|    gen/train/approx_kl             | 0.0013346612 |
|    gen/train/clip_fraction         | 0.0413       |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.521       |
|    gen/train/explained_variance    | 0.803        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 6.09         |
|    gen/train/n_updates             | 285          |
|    gen/train/policy_gradient_loss  | -0.000332    |
|    gen/train/value_loss   

round:  80%|████████  | 20/25 [31:31<07:35, 91.15s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 28          |
|    gen/rollout/ep_rew_mean         | 18          |
|    gen/rollout/ep_rew_wrapped_mean | -138        |
|    gen/time/fps                    | 36          |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 82          |
|    gen/time/total_timesteps        | 63000       |
|    gen/train/approx_kl             | 0.003080748 |
|    gen/train/clip_fraction         | 0.0947      |
|    gen/train/clip_range            | 0.1         |
|    gen/train/entropy_loss          | -0.514      |
|    gen/train/explained_variance    | 0.824       |
|    gen/train/learning_rate         | 0.001       |
|    gen/train/loss                  | 4.03        |
|    gen/train/n_updates             | 300         |
|    gen/train/policy_gradient_loss  | -0.00108    |
|    gen/train/value_loss            | 12.4   

round:  84%|████████▍ | 21/25 [33:01<06:03, 90.76s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 29.9         |
|    gen/rollout/ep_rew_mean         | 20           |
|    gen/rollout/ep_rew_wrapped_mean | -151         |
|    gen/time/fps                    | 33           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 90           |
|    gen/time/total_timesteps        | 66000        |
|    gen/train/approx_kl             | 0.0024789649 |
|    gen/train/clip_fraction         | 0.0993       |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.485       |
|    gen/train/explained_variance    | 0.815        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 6.03         |
|    gen/train/n_updates             | 315          |
|    gen/train/policy_gradient_loss  | -0.00123     |
|    gen/train/value_loss   

round:  88%|████████▊ | 22/25 [34:38<04:38, 92.69s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 32.1        |
|    gen/rollout/ep_rew_mean         | 22.1        |
|    gen/rollout/ep_rew_wrapped_mean | -185        |
|    gen/time/fps                    | 34          |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 86          |
|    gen/time/total_timesteps        | 69000       |
|    gen/train/approx_kl             | 0.006878122 |
|    gen/train/clip_fraction         | 0.0896      |
|    gen/train/clip_range            | 0.1         |
|    gen/train/entropy_loss          | -0.461      |
|    gen/train/explained_variance    | 0.823       |
|    gen/train/learning_rate         | 0.001       |
|    gen/train/loss                  | 6.98        |
|    gen/train/n_updates             | 330         |
|    gen/train/policy_gradient_loss  | 0.00025     |
|    gen/train/value_loss            | 15.6   

round:  92%|█████████▏| 23/25 [36:11<03:05, 92.98s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 38.8         |
|    gen/rollout/ep_rew_mean         | 24.3         |
|    gen/rollout/ep_rew_wrapped_mean | -233         |
|    gen/time/fps                    | 37           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 79           |
|    gen/time/total_timesteps        | 72000        |
|    gen/train/approx_kl             | 0.0020466365 |
|    gen/train/clip_fraction         | 0.0801       |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.462       |
|    gen/train/explained_variance    | 0.817        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 9.96         |
|    gen/train/n_updates             | 345          |
|    gen/train/policy_gradient_loss  | -0.000467    |
|    gen/train/value_loss   

round:  96%|█████████▌| 24/25 [37:38<01:31, 91.12s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 20.6         |
|    gen/rollout/ep_rew_mean         | 12           |
|    gen/rollout/ep_rew_wrapped_mean | -226         |
|    gen/time/fps                    | 37           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 80           |
|    gen/time/total_timesteps        | 75000        |
|    gen/train/approx_kl             | 0.0025121227 |
|    gen/train/clip_fraction         | 0.114        |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.554       |
|    gen/train/explained_variance    | 0.846        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 6.21         |
|    gen/train/n_updates             | 360          |
|    gen/train/policy_gradient_loss  | -0.000963    |
|    gen/train/value_loss   

round: 100%|██████████| 25/25 [39:06<00:00, 93.85s/it]

🏃 View run unequaled-skunk-7 at: http://127.0.0.1:8080/#/experiments/282678262450638424/runs/7006d8322b42432aa5a49b35983a938a
🧪 View experiment at: http://127.0.0.1:8080/#/experiments/282678262450638424





In [36]:
mlflow.set_experiment("AIRLv2")
with mlflow.start_run():
    mlflow.log_param("n_steps", n_steps)
    mlflow.log_param("batch_size", batch_size)
    mlflow.log_param("total_timesteps", total_timesteps)

    airl_trainer0.train(total_timesteps=total_timesteps)

    learner0.save(config.MODELS_DIR / "learner0v2")
    torch.save(reward_net0, config.MODELS_DIR / "reward_net0v2")

    mlflow.log_artifact(config.MODELS_DIR / "learner0v2.zip")
    mlflow.log_artifact(config.MODELS_DIR / "reward_net0v2")
    mlflow.end_run()

round:   0%|          | 0/25 [00:00<?, ?it/s]

----------------------------------------------------
| raw/                              |              |
|    gen/rollout/ep_len_mean        | 6.75         |
|    gen/rollout/ep_rew_mean        | 2.7          |
|    gen/time/fps                   | 59           |
|    gen/time/iterations            | 1            |
|    gen/time/time_elapsed          | 50           |
|    gen/time/total_timesteps       | 3000         |
|    gen/train/approx_kl            | 0.0018359831 |
|    gen/train/clip_fraction        | 0.12         |
|    gen/train/clip_range           | 0.1          |
|    gen/train/entropy_loss         | -0.551       |
|    gen/train/explained_variance   | 0.81         |
|    gen/train/learning_rate        | 0.001        |
|    gen/train/loss                 | 18.8         |
|    gen/train/n_updates            | 375          |
|    gen/train/policy_gradient_loss | -0.00136     |
|    gen/train/value_loss           | 31.9         |
----------------------------------------------

round:   4%|▍         | 1/25 [00:57<22:56, 57.36s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 5.63         |
|    gen/rollout/ep_rew_mean         | 2.34         |
|    gen/rollout/ep_rew_wrapped_mean | 9.65         |
|    gen/time/fps                    | 35           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 84           |
|    gen/time/total_timesteps        | 6000         |
|    gen/train/approx_kl             | 0.0017490982 |
|    gen/train/clip_fraction         | 0.112        |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.691       |
|    gen/train/explained_variance    | -0.0421      |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 0.241        |
|    gen/train/n_updates             | 15           |
|    gen/train/policy_gradient_loss  | -0.00396     |
|    gen/train/value_loss   

round:   8%|▊         | 2/25 [02:28<29:35, 77.20s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 4.78         |
|    gen/rollout/ep_rew_mean         | 1.77         |
|    gen/rollout/ep_rew_wrapped_mean | -0.879       |
|    gen/time/fps                    | 60           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 49           |
|    gen/time/total_timesteps        | 9000         |
|    gen/train/approx_kl             | 0.0025618884 |
|    gen/train/clip_fraction         | 0.126        |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.689       |
|    gen/train/explained_variance    | -3.94        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 0.27         |
|    gen/train/n_updates             | 30           |
|    gen/train/policy_gradient_loss  | -0.00153     |
|    gen/train/value_loss   

round:  12%|█▏        | 3/25 [03:24<24:48, 67.64s/it]

---------------------------------------------------
| raw/                               |            |
|    gen/rollout/ep_len_mean         | 2.89       |
|    gen/rollout/ep_rew_mean         | 0.92       |
|    gen/rollout/ep_rew_wrapped_mean | -1.68      |
|    gen/time/fps                    | 38         |
|    gen/time/iterations             | 1          |
|    gen/time/time_elapsed           | 77         |
|    gen/time/total_timesteps        | 12000      |
|    gen/train/approx_kl             | 0.00100704 |
|    gen/train/clip_fraction         | 0.0781     |
|    gen/train/clip_range            | 0.1        |
|    gen/train/entropy_loss          | -0.691     |
|    gen/train/explained_variance    | 0.144      |
|    gen/train/learning_rate         | 0.001      |
|    gen/train/loss                  | 0.0594     |
|    gen/train/n_updates             | 45         |
|    gen/train/policy_gradient_loss  | -0.000651  |
|    gen/train/value_loss            | 1.28       |
------------

round:  16%|█▌        | 4/25 [04:49<26:00, 74.30s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 5.76         |
|    gen/rollout/ep_rew_mean         | 2.47         |
|    gen/rollout/ep_rew_wrapped_mean | -0.44        |
|    gen/time/fps                    | 56           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 53           |
|    gen/time/total_timesteps        | 15000        |
|    gen/train/approx_kl             | 0.0009570085 |
|    gen/train/clip_fraction         | 0.025        |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.691       |
|    gen/train/explained_variance    | 0.192        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 4.51         |
|    gen/train/n_updates             | 60           |
|    gen/train/policy_gradient_loss  | 7.25e-05     |
|    gen/train/value_loss   

round:  20%|██        | 5/25 [05:49<23:03, 69.20s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 3.19         |
|    gen/rollout/ep_rew_mean         | 1.13         |
|    gen/rollout/ep_rew_wrapped_mean | -5.17        |
|    gen/time/fps                    | 61           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 48           |
|    gen/time/total_timesteps        | 18000        |
|    gen/train/approx_kl             | 0.0024029135 |
|    gen/train/clip_fraction         | 0.12         |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.689       |
|    gen/train/explained_variance    | 0.621        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 0.593        |
|    gen/train/n_updates             | 75           |
|    gen/train/policy_gradient_loss  | -0.00171     |
|    gen/train/value_loss   

round:  24%|██▍       | 6/25 [06:44<20:25, 64.51s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 5.92         |
|    gen/rollout/ep_rew_mean         | 2.46         |
|    gen/rollout/ep_rew_wrapped_mean | -2.18        |
|    gen/time/fps                    | 63           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 46           |
|    gen/time/total_timesteps        | 21000        |
|    gen/train/approx_kl             | 0.0015087682 |
|    gen/train/clip_fraction         | 0.045        |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.687       |
|    gen/train/explained_variance    | 0.393        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 1.1          |
|    gen/train/n_updates             | 90           |
|    gen/train/policy_gradient_loss  | -0.000224    |
|    gen/train/value_loss   

round:  28%|██▊       | 7/25 [07:38<18:17, 60.98s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 4.03        |
|    gen/rollout/ep_rew_mean         | 1.34        |
|    gen/rollout/ep_rew_wrapped_mean | -4.78       |
|    gen/time/fps                    | 57          |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 51          |
|    gen/time/total_timesteps        | 24000       |
|    gen/train/approx_kl             | 0.002944558 |
|    gen/train/clip_fraction         | 0.0263      |
|    gen/train/clip_range            | 0.1         |
|    gen/train/entropy_loss          | -0.691      |
|    gen/train/explained_variance    | 0.286       |
|    gen/train/learning_rate         | 0.001       |
|    gen/train/loss                  | 8.46        |
|    gen/train/n_updates             | 105         |
|    gen/train/policy_gradient_loss  | -0.00101    |
|    gen/train/value_loss            | 13.1   

round:  32%|███▏      | 8/25 [08:37<17:04, 60.28s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 4.28         |
|    gen/rollout/ep_rew_mean         | 1.71         |
|    gen/rollout/ep_rew_wrapped_mean | -5.16        |
|    gen/time/fps                    | 66           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 45           |
|    gen/time/total_timesteps        | 27000        |
|    gen/train/approx_kl             | 0.0023134362 |
|    gen/train/clip_fraction         | 0.0642       |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.683       |
|    gen/train/explained_variance    | 0.475        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 1.95         |
|    gen/train/n_updates             | 120          |
|    gen/train/policy_gradient_loss  | -0.00129     |
|    gen/train/value_loss   

round:  36%|███▌      | 9/25 [09:29<15:22, 57.69s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 4.48         |
|    gen/rollout/ep_rew_mean         | 1.71         |
|    gen/rollout/ep_rew_wrapped_mean | -2.06        |
|    gen/time/fps                    | 65           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 45           |
|    gen/time/total_timesteps        | 30000        |
|    gen/train/approx_kl             | 0.0018557204 |
|    gen/train/clip_fraction         | 0.0563       |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.674       |
|    gen/train/explained_variance    | 0.416        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 8.88         |
|    gen/train/n_updates             | 135          |
|    gen/train/policy_gradient_loss  | -0.00109     |
|    gen/train/value_loss   

round:  40%|████      | 10/25 [10:21<14:02, 56.15s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 5.69         |
|    gen/rollout/ep_rew_mean         | 2.73         |
|    gen/rollout/ep_rew_wrapped_mean | -11.5        |
|    gen/time/fps                    | 60           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 49           |
|    gen/time/total_timesteps        | 33000        |
|    gen/train/approx_kl             | 0.0016394929 |
|    gen/train/clip_fraction         | 0.0352       |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.676       |
|    gen/train/explained_variance    | 0.385        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 10.1         |
|    gen/train/n_updates             | 150          |
|    gen/train/policy_gradient_loss  | -0.000632    |
|    gen/train/value_loss   

round:  44%|████▍     | 11/25 [11:18<13:08, 56.34s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 4.4          |
|    gen/rollout/ep_rew_mean         | 1.91         |
|    gen/rollout/ep_rew_wrapped_mean | -10.6        |
|    gen/time/fps                    | 64           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 46           |
|    gen/time/total_timesteps        | 36000        |
|    gen/train/approx_kl             | 0.0014705589 |
|    gen/train/clip_fraction         | 0.104        |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.67        |
|    gen/train/explained_variance    | 0.704        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 6.24         |
|    gen/train/n_updates             | 165          |
|    gen/train/policy_gradient_loss  | -0.00123     |
|    gen/train/value_loss   

round:  48%|████▊     | 12/25 [12:11<11:58, 55.29s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 5.4         |
|    gen/rollout/ep_rew_mean         | 2.44        |
|    gen/rollout/ep_rew_wrapped_mean | -13.9       |
|    gen/time/fps                    | 56          |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 53          |
|    gen/time/total_timesteps        | 39000       |
|    gen/train/approx_kl             | 0.002051407 |
|    gen/train/clip_fraction         | 0.107       |
|    gen/train/clip_range            | 0.1         |
|    gen/train/entropy_loss          | -0.659      |
|    gen/train/explained_variance    | 0.544       |
|    gen/train/learning_rate         | 0.001       |
|    gen/train/loss                  | 4.65        |
|    gen/train/n_updates             | 180         |
|    gen/train/policy_gradient_loss  | -0.00163    |
|    gen/train/value_loss            | 14.4   

round:  52%|█████▏    | 13/25 [13:11<11:20, 56.71s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 3.39         |
|    gen/rollout/ep_rew_mean         | 1.18         |
|    gen/rollout/ep_rew_wrapped_mean | -14.2        |
|    gen/time/fps                    | 61           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 49           |
|    gen/time/total_timesteps        | 42000        |
|    gen/train/approx_kl             | 0.0020605237 |
|    gen/train/clip_fraction         | 0.11         |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.656       |
|    gen/train/explained_variance    | 0.467        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 41.5         |
|    gen/train/n_updates             | 195          |
|    gen/train/policy_gradient_loss  | 0.00132      |
|    gen/train/value_loss   

round:  56%|█████▌    | 14/25 [14:07<10:21, 56.47s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 8.01        |
|    gen/rollout/ep_rew_mean         | 3.92        |
|    gen/rollout/ep_rew_wrapped_mean | -10.1       |
|    gen/time/fps                    | 59          |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 50          |
|    gen/time/total_timesteps        | 45000       |
|    gen/train/approx_kl             | 0.003232118 |
|    gen/train/clip_fraction         | 0.13        |
|    gen/train/clip_range            | 0.1         |
|    gen/train/entropy_loss          | -0.633      |
|    gen/train/explained_variance    | 0.472       |
|    gen/train/learning_rate         | 0.001       |
|    gen/train/loss                  | 12.8        |
|    gen/train/n_updates             | 210         |
|    gen/train/policy_gradient_loss  | -0.00241    |
|    gen/train/value_loss            | 36     

round:  60%|██████    | 15/25 [15:04<09:27, 56.74s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 5.15         |
|    gen/rollout/ep_rew_mean         | 2.14         |
|    gen/rollout/ep_rew_wrapped_mean | -10.2        |
|    gen/time/fps                    | 63           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 46           |
|    gen/time/total_timesteps        | 48000        |
|    gen/train/approx_kl             | 0.0026796258 |
|    gen/train/clip_fraction         | 0.0853       |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.636       |
|    gen/train/explained_variance    | 0.58         |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 3.25         |
|    gen/train/n_updates             | 225          |
|    gen/train/policy_gradient_loss  | -0.000609    |
|    gen/train/value_loss   

round:  64%|██████▍   | 16/25 [15:58<08:22, 55.83s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 5.99         |
|    gen/rollout/ep_rew_mean         | 2.82         |
|    gen/rollout/ep_rew_wrapped_mean | -8.34        |
|    gen/time/fps                    | 65           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 45           |
|    gen/time/total_timesteps        | 51000        |
|    gen/train/approx_kl             | 0.0028806692 |
|    gen/train/clip_fraction         | 0.104        |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.62        |
|    gen/train/explained_variance    | 0.453        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 10.9         |
|    gen/train/n_updates             | 240          |
|    gen/train/policy_gradient_loss  | -0.00267     |
|    gen/train/value_loss   

round:  68%|██████▊   | 17/25 [16:50<07:18, 54.79s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 8.42        |
|    gen/rollout/ep_rew_mean         | 4.37        |
|    gen/rollout/ep_rew_wrapped_mean | -26.1       |
|    gen/time/fps                    | 66          |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 44          |
|    gen/time/total_timesteps        | 54000       |
|    gen/train/approx_kl             | 0.002071661 |
|    gen/train/clip_fraction         | 0.0961      |
|    gen/train/clip_range            | 0.1         |
|    gen/train/entropy_loss          | -0.607      |
|    gen/train/explained_variance    | 0.48        |
|    gen/train/learning_rate         | 0.001       |
|    gen/train/loss                  | 3.46        |
|    gen/train/n_updates             | 255         |
|    gen/train/policy_gradient_loss  | -0.00169    |
|    gen/train/value_loss            | 28.1   

round:  72%|███████▏  | 18/25 [17:42<06:17, 53.89s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 6.44        |
|    gen/rollout/ep_rew_mean         | 3.46        |
|    gen/rollout/ep_rew_wrapped_mean | -16.1       |
|    gen/time/fps                    | 61          |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 48          |
|    gen/time/total_timesteps        | 57000       |
|    gen/train/approx_kl             | 0.004281704 |
|    gen/train/clip_fraction         | 0.185       |
|    gen/train/clip_range            | 0.1         |
|    gen/train/entropy_loss          | -0.552      |
|    gen/train/explained_variance    | 0.501       |
|    gen/train/learning_rate         | 0.001       |
|    gen/train/loss                  | 5.69        |
|    gen/train/n_updates             | 270         |
|    gen/train/policy_gradient_loss  | -0.00292    |
|    gen/train/value_loss            | 31.7   

round:  76%|███████▌  | 19/25 [18:38<05:25, 54.30s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 5.94        |
|    gen/rollout/ep_rew_mean         | 3.35        |
|    gen/rollout/ep_rew_wrapped_mean | -25         |
|    gen/time/fps                    | 65          |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 45          |
|    gen/time/total_timesteps        | 60000       |
|    gen/train/approx_kl             | 0.005299911 |
|    gen/train/clip_fraction         | 0.077       |
|    gen/train/clip_range            | 0.1         |
|    gen/train/entropy_loss          | -0.562      |
|    gen/train/explained_variance    | 0.321       |
|    gen/train/learning_rate         | 0.001       |
|    gen/train/loss                  | 15.9        |
|    gen/train/n_updates             | 285         |
|    gen/train/policy_gradient_loss  | -0.000251   |
|    gen/train/value_loss            | 88     

round:  80%|████████  | 20/25 [19:30<04:28, 53.70s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 4.42        |
|    gen/rollout/ep_rew_mean         | 2.24        |
|    gen/rollout/ep_rew_wrapped_mean | -20.5       |
|    gen/time/fps                    | 62          |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 47          |
|    gen/time/total_timesteps        | 63000       |
|    gen/train/approx_kl             | 0.002830202 |
|    gen/train/clip_fraction         | 0.102       |
|    gen/train/clip_range            | 0.1         |
|    gen/train/entropy_loss          | -0.514      |
|    gen/train/explained_variance    | 0.463       |
|    gen/train/learning_rate         | 0.001       |
|    gen/train/loss                  | 31.3        |
|    gen/train/n_updates             | 300         |
|    gen/train/policy_gradient_loss  | -0.000449   |
|    gen/train/value_loss            | 53.4   

round:  84%|████████▍ | 21/25 [20:24<03:35, 53.99s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 5.83         |
|    gen/rollout/ep_rew_mean         | 3.49         |
|    gen/rollout/ep_rew_wrapped_mean | -22.5        |
|    gen/time/fps                    | 58           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 51           |
|    gen/time/total_timesteps        | 66000        |
|    gen/train/approx_kl             | 0.0006946413 |
|    gen/train/clip_fraction         | 0.0473       |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.43        |
|    gen/train/explained_variance    | 0.49         |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 59.9         |
|    gen/train/n_updates             | 315          |
|    gen/train/policy_gradient_loss  | 0.000193     |
|    gen/train/value_loss   

round:  88%|████████▊ | 22/25 [21:22<02:45, 55.15s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 4.97         |
|    gen/rollout/ep_rew_mean         | 2.75         |
|    gen/rollout/ep_rew_wrapped_mean | -23.2        |
|    gen/time/fps                    | 64           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 46           |
|    gen/time/total_timesteps        | 69000        |
|    gen/train/approx_kl             | 0.0010696149 |
|    gen/train/clip_fraction         | 0.0754       |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.422       |
|    gen/train/explained_variance    | 0.625        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 12.1         |
|    gen/train/n_updates             | 330          |
|    gen/train/policy_gradient_loss  | 0.000857     |
|    gen/train/value_loss   

round:  92%|█████████▏| 23/25 [22:16<01:49, 54.59s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 5.94         |
|    gen/rollout/ep_rew_mean         | 3.6          |
|    gen/rollout/ep_rew_wrapped_mean | -11.9        |
|    gen/time/fps                    | 59           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 50           |
|    gen/time/total_timesteps        | 72000        |
|    gen/train/approx_kl             | 0.0018942554 |
|    gen/train/clip_fraction         | 0.0753       |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.426       |
|    gen/train/explained_variance    | 0.676        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 50.4         |
|    gen/train/n_updates             | 345          |
|    gen/train/policy_gradient_loss  | 0.00192      |
|    gen/train/value_loss   

round:  96%|█████████▌| 24/25 [23:13<00:55, 55.49s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 7.67         |
|    gen/rollout/ep_rew_mean         | 4.79         |
|    gen/rollout/ep_rew_wrapped_mean | -33.9        |
|    gen/time/fps                    | 63           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 47           |
|    gen/time/total_timesteps        | 75000        |
|    gen/train/approx_kl             | 0.0015538778 |
|    gen/train/clip_fraction         | 0.048        |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.413       |
|    gen/train/explained_variance    | 0.263        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 12.9         |
|    gen/train/n_updates             | 360          |
|    gen/train/policy_gradient_loss  | 0.00258      |
|    gen/train/value_loss   

round: 100%|██████████| 25/25 [24:07<00:00, 57.91s/it]

🏃 View run funny-mole-185 at: http://127.0.0.1:8080/#/experiments/282678262450638424/runs/cc1c5aa1802e4ac8be9cdc69ef0af740
🧪 View experiment at: http://127.0.0.1:8080/#/experiments/282678262450638424





## Metrics

Download the reward network from mlflow.

In [37]:
local_path = mlflow.artifacts.download_artifacts(
    artifact_uri="mlflow-artifacts:/282678262450638424/cc1c5aa1802e4ac8be9cdc69ef0af740/artifacts/reward_net0v2"
)

reward_net0 = torch.load(local_path, weights_only=False)

Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

In [38]:
local_path = mlflow.artifacts.download_artifacts(
    artifact_uri="mlflow-artifacts:/282678262450638424/7006d8322b42432aa5a49b35983a938a/artifacts/reward_net1v2"
)

reward_net1 = torch.load(local_path, weights_only=False)

Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

### Comparing **uncalibrated** mean rewards

In [39]:
states0, obs0, next_states0, dones0 = (
    trajectories_0_test.obs,
    trajectories_0_test.acts,
    trajectories_0_test.next_obs,
    trajectories_0_test.dones,
)
states1, obs1, next_states1, dones1 = (
    trajectories_1_test.obs,
    trajectories_1_test.acts,
    trajectories_1_test.next_obs,
    trajectories_1_test.dones,
)

In [40]:
# For reward_net1
rewards1 = reward_net1.predict(states0, obs0, next_states0, dones0)
norm_rewards1 = (rewards1 - rewards1.mean()) / rewards1.std()
print("Reward network 1 with traj0: ", norm_rewards1.mean())

# For reward_net0
rewards0 = reward_net0.predict(states0, obs0, next_states0, dones0)
norm_rewards0 = (rewards0 - rewards0.mean()) / rewards0.std()
print("Reward network 0 with traj0: ", norm_rewards0.mean())

Reward network 1 with traj0:  3.8889407e-08
Reward network 0 with traj0:  3.5354006e-08


In [41]:
# For reward_net1
rewards1 = reward_net1.predict(states1, obs1, next_states1, dones1)
norm_rewards1 = (rewards1 - rewards1.mean()) / rewards1.std()
print("Reward network 1 with traj1: ", norm_rewards1.mean())

# For reward_net0
rewards0 = reward_net0.predict(states1, obs1, next_states1, dones1)
norm_rewards0 = (rewards0 - rewards0.mean()) / rewards0.std()
print("Reward network 0 with traj1: ", norm_rewards0.mean())

Reward network 1 with traj1:  1.2738005e-07
Reward network 0 with traj1:  -2.0280245e-07


### Calibration using affine transformations

We use a validation set to calculate alpha and beta parameters for our affine transformation, which is simply
$f(x) = \alpha x + \beta$, where $\alpha = {\sigma_{target}\over\sigma_{x}}$ and $\beta = \mu_{target} - \alpha_x * \mu_x$

In [42]:
states, obs, next_states, dones = (
    trajectories_val.obs,
    trajectories_val.acts,
    trajectories_val.next_obs,
    trajectories_val.dones,
)

In [43]:
outputs_arb = reward_net1.predict(states, obs, next_states, dones)
outputs_nonarb = reward_net0.predict(states, obs, next_states, dones)

# Compute empirical mean and std for each network
mean_arb, std_arb = outputs_arb.mean(), outputs_arb.std()
mean_nonarb, std_nonarb = outputs_nonarb.mean(), outputs_nonarb.std()

# Define target calibration values (e.g., mean=0, std=1)
target_mean, target_std = 0.0, 1.0

# Calculate affine transformation parameters
alpha_arb = target_std / std_arb
beta_arb = target_mean - alpha_arb * mean_arb

alpha_nonarb = target_std / std_nonarb
beta_nonarb = target_mean - alpha_nonarb * mean_nonarb

# Calibrate the outputs
calibrated_arb = outputs_arb * alpha_arb + beta_arb
calibrated_nonarb = outputs_nonarb * alpha_nonarb + beta_nonarb

### Accuracy, Recall and F1-Score for both calibrated reward networks

In [44]:
# For reward_net1
rewards1 = reward_net1.predict(states0, obs0, next_states0, dones0)
calibrated_rewards1 = alpha_arb * rewards1 + beta_arb
print("Reward network 1 with traj0: ", calibrated_rewards1.mean())

# For reward_net0
rewards0 = reward_net0.predict(states0, obs0, next_states0, dones0)
calibrated_rewards0 = alpha_nonarb * rewards0 + beta_nonarb
print("Reward network 0 with traj0: ", calibrated_rewards0.mean())

Reward network 1 with traj0:  -0.15752733
Reward network 0 with traj0:  -0.17344248


In [45]:
predictions = np.array(
    [1 if r0 > r1 else 0 for r0, r1 in zip(calibrated_rewards0, calibrated_rewards1)]
)
true_labels = np.array([1] * len(predictions))

accuracy = accuracy_score(true_labels, predictions)
precision = precision_score(true_labels, predictions)
recall = recall_score(true_labels, predictions)
f1 = f1_score(true_labels, predictions)

print("Accuracy:", accuracy)
print("Precision:", precision)
print("Recall:", recall)
print("F1 Score:", f1)

Accuracy: 0.6362372567191844
Precision: 1.0
Recall: 0.6362372567191844
F1 Score: 0.777683375814217


In [46]:
# For reward_net1
rewards1 = reward_net1.predict(states1, obs1, next_states1, dones1)
calibrated_rewards1 = alpha_arb * rewards1 + beta_arb
print("Reward network 1 with traj1: ", calibrated_rewards1.mean())

# For reward_net0
rewards0 = reward_net0.predict(states1, obs1, next_states1, dones1)
calibrated_rewards0 = alpha_nonarb * rewards0 + beta_nonarb
print("Reward network 0 with traj1: ", calibrated_rewards0.mean())

Reward network 1 with traj1:  0.3318534
Reward network 0 with traj1:  0.37683234


In [47]:
predictions = np.array(
    [1 if r0 < r1 else 0 for r0, r1 in zip(calibrated_rewards0, calibrated_rewards1)]
)
true_labels = np.array([1] * len(predictions))

accuracy = accuracy_score(true_labels, predictions)
precision = precision_score(true_labels, predictions)
recall = recall_score(true_labels, predictions)
f1 = f1_score(true_labels, predictions)

print("Accuracy:", accuracy)
print("Precision:", precision)
print("Recall:", recall)
print("F1 Score:", f1)

Accuracy: 0.19244288224956063
Precision: 1.0
Recall: 0.19244288224956063
F1 Score: 0.32277081798084006
