# AIRL with TransactionsGraphEnvironment v2

In [1]:
import ast

import gymnasium as gym
import mlflow
import numpy as np
import pandas as pd
import torch
from imitation.algorithms.adversarial.airl import AIRL
from imitation.data.rollout import flatten_trajectories
from imitation.data.types import Trajectory
from imitation.data.wrappers import RolloutInfoWrapper
from imitation.rewards.reward_nets import BasicShapedRewardNet
from imitation.util import logger
from imitation.util.networks import RunningNorm
from imitation.util.util import make_vec_env
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import VecCheckNan
from stable_baselines3.ppo import MlpPolicy

import graph_reinforcement_learning_using_blockchain_data as grl
from graph_reinforcement_learning_using_blockchain_data import config

# Load environment settings through the project's config helper
# (presumably from a .env file -- confirm in `config`).
config.load_dotenv()
# Send all MLflow logging in this notebook to the configured tracking server.
mlflow.set_tracking_uri(uri=config.MLFLOW_TRACKING_URI)

[32m2025-05-04 10:42:01.780[0m | [1mINFO    [0m | [36mgraph_reinforcement_learning_using_blockchain_data.config[0m:[36m<module>[0m:[36m12[0m - [1mPROJ_ROOT path is: /Users/liamtessendorf/Programming/Uni/2_Master/4_FS25_Programming/graph-reinforcement-learning-using-blockchain-data[0m


In [2]:
# Seeded NumPy generator shared by the vectorised-environment factories below.
RNG = np.random.default_rng(42)

## Dataset 

In [3]:
# State embeddings (the file name indicates pre-trained, 128-d vectors).
df_emb = pd.read_csv(config.FLASHBOTS_Q2_DATA_DIR / "state_embeddings_pre_trained_128.csv")

# Transaction receipts and account ETH balances, stored in separate files per class.
df_class0 = pd.read_csv(config.RAW_DATA_DIR / "receipts_class0.csv")
df_eth_balances_class0 = pd.read_csv(config.RAW_DATA_DIR / "eth_balances_class0.csv")
df_class1 = pd.read_csv(config.RAW_DATA_DIR / "receipts_class1.csv")
df_eth_balances_class1 = pd.read_csv(config.RAW_DATA_DIR / "eth_balances_class1.csv")

In [4]:
# Show the column names of a receipts frame and a balances frame; the merge
# cell that follows joins them on (from, blockNumber) == (account, block_number).
print(df_class0.columns)
print(df_eth_balances_class0.columns)

Index(['block_number', 'transaction_hash', 'blockHash', 'blockNumber',
       'logsBloom', 'gasUsed', 'contractAddress', 'cumulativeGasUsed',
       'transactionIndex', 'from', 'to', 'type', 'effectiveGasPrice', 'logs',
       'status'],
      dtype='object')
Index(['account', 'block_number', 'balance'], dtype='object')


In [5]:
# Attach to every receipt the sender's ETH balance recorded for the same block.
# The inner join discards receipts that have no matching (account, block) balance row.
df_class0_with_eth_balances = pd.merge(
    df_class0,
    df_eth_balances_class0,
    how="inner",
    left_on=["from", "blockNumber"],
    right_on=["account", "block_number"],
)
df_class1_with_eth_balances = pd.merge(
    df_class1,
    df_eth_balances_class1,
    how="inner",
    left_on=["from", "blockNumber"],
    right_on=["account", "block_number"],
)

In [6]:
# Class-0 rows whose sender occurs more than once (keep=False marks every occurrence).
df_class0_multi_occ = df_class0_with_eth_balances.loc[
    df_class0_with_eth_balances.duplicated(subset="from", keep=False)
]

In [7]:
# Parse the stringified embedding lists into float32 vectors.
# The column is overwritten in place, so values that are already arrays (the cell
# was re-run on a live kernel) are passed through instead of crashing in
# ast.literal_eval. Any other unexpected value still raises, as before.
df_emb["embeddings"] = df_emb["embeddings"].apply(
    lambda x: x if isinstance(x, np.ndarray) else np.array(ast.literal_eval(x), dtype=np.float32)
)

In [8]:
# Tag each frame with its class id before the two are concatenated.
df_class0_with_eth_balances = df_class0_with_eth_balances.assign(label=0)
df_class1_with_eth_balances = df_class1_with_eth_balances.assign(label=1)

In [9]:
# Stack both classes and keep the first row seen for every transaction hash.
df_receipts = pd.concat(
    [df_class0_with_eth_balances, df_class1_with_eth_balances], ignore_index=True
).drop_duplicates(subset="transaction_hash")

# Right join: every embedding row is kept; embeddings without a receipt get NaN
# in the receipt columns (including `label`).
df = pd.merge(
    df_receipts,
    df_emb,
    how="right",
    left_on="transaction_hash",
    right_on="transactionHash",
)

In [10]:
# Preview the merged receipts + embeddings frame.
df.head()

Unnamed: 0,block_number_x,transaction_hash,blockHash,blockNumber,logsBloom,gasUsed,contractAddress,cumulativeGasUsed,transactionIndex,from,...,created_at,account_address,profit_token_address,start_amount,end_amount,profit_amount,error,protocols,transactionHash,embeddings
0,16969850,0xe4029c908cfc40f825051cd0957797c66196eb8ba437...,0x241b2ebd536a6f546ce2214bcf2146d359ae4e0cc4e3...,16969850,0x00000000000000000000000000000000000000000000...,21000,,1315400,19,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,,,,,,0xe4029c908cfc40f825051cd0957797c66196eb8ba437...,"[0.16500826, -0.09403816, 0.21792151, 0.012831..."
1,16975162,0xbcfb84169287cf7acce33fba2b7390cfe21852871f78...,0x577d4ef683ebd84c73cdbc3635c7177f6ec137eef3bf...,16975162,0x00000000000000000000000080000000000000000000...,116880,,396290,3,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,,,,,,0xbcfb84169287cf7acce33fba2b7390cfe21852871f78...,"[-0.09382023, -0.012681959, 0.14728837, 0.0354..."
2,16983327,0x7a954e541df7c296b8fec61b77d22ea8bacd63399467...,0xe79daae4074cb858a3e79b4d95064845295eaddc134c...,16983327,0x00000000000000000000000000000000000000000000...,103421,,3934866,21,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,,,,,,0x7a954e541df7c296b8fec61b77d22ea8bacd63399467...,"[-0.10671734, 0.0006063916, 0.160602, 0.039328..."
3,16990575,0x016858c7c133cdf545b6934653544c19475c5c111ba0...,0x9c3fa88c63603c9c5696174085177f9559e20b331a64...,16990575,0x00200000000000000000000080004000000000000000...,105953,,4683545,39,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,,,,,,0x016858c7c133cdf545b6934653544c19475c5c111ba0...,"[-0.11480839, 0.027570017, 0.22743803, 0.04695..."
4,16994491,0x0f358c5aed8ae456a298eedaf3fc08cb802c3417a6b4...,0xae6c66f64f13ba5404bbe4cebd33fdd666c7bc8fe601...,16994491,0x00200000000011000000000080000000000000000000...,223722,,3972377,34,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,,,,,,0x0f358c5aed8ae456a298eedaf3fc08cb802c3417a6b4...,"[-0.067668885, -0.094198264, 0.12297824, 0.158..."


In [11]:
# Per-block gas-price statistics over the transactions present in `df`.
# One named aggregation replaces four separate group-by passes; the result is
# still indexed by blockNumber with the same four columns.
df_median_gas_prices = df.groupby("blockNumber")["effectiveGasPrice"].agg(
    median_gas_price="median",
    std_gas_price="std",  # NaN when a block contributes a single transaction
    max_gas_price="max",
    min_gas_price="min",
)

# Broadcast the block-level statistics back onto every transaction of that block.
df_with_median_gas_prices = df.merge(df_median_gas_prices, how="left", on="blockNumber")
df_with_median_gas_prices.head()

Unnamed: 0,block_number_x,transaction_hash,blockHash,blockNumber,logsBloom,gasUsed,contractAddress,cumulativeGasUsed,transactionIndex,from,...,end_amount,profit_amount,error,protocols,transactionHash,embeddings,median_gas_price,std_gas_price,max_gas_price,min_gas_price
0,16969850,0xe4029c908cfc40f825051cd0957797c66196eb8ba437...,0x241b2ebd536a6f546ce2214bcf2146d359ae4e0cc4e3...,16969850,0x00000000000000000000000000000000000000000000...,21000,,1315400,19,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,,0xe4029c908cfc40f825051cd0957797c66196eb8ba437...,"[0.16500826, -0.09403816, 0.21792151, 0.012831...",28412780000.0,5338623000.0,40046142239,27985774295
1,16975162,0xbcfb84169287cf7acce33fba2b7390cfe21852871f78...,0x577d4ef683ebd84c73cdbc3635c7177f6ec137eef3bf...,16975162,0x00000000000000000000000080000000000000000000...,116880,,396290,3,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,,0xbcfb84169287cf7acce33fba2b7390cfe21852871f78...,"[-0.09382023, -0.012681959, 0.14728837, 0.0354...",31324250000.0,13106710000.0,62912040686,29651658352
2,16983327,0x7a954e541df7c296b8fec61b77d22ea8bacd63399467...,0xe79daae4074cb858a3e79b4d95064845295eaddc134c...,16983327,0x00000000000000000000000000000000000000000000...,103421,,3934866,21,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,,0x7a954e541df7c296b8fec61b77d22ea8bacd63399467...,"[-0.10671734, 0.0006063916, 0.160602, 0.039328...",47015990000.0,94835380000.0,298140379626,44115991364
3,16990575,0x016858c7c133cdf545b6934653544c19475c5c111ba0...,0x9c3fa88c63603c9c5696174085177f9559e20b331a64...,16990575,0x00200000000000000000000080004000000000000000...,105953,,4683545,39,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,,0x016858c7c133cdf545b6934653544c19475c5c111ba0...,"[-0.11480839, 0.027570017, 0.22743803, 0.04695...",39976070000.0,22168940000.0,72384430845,29976074199
4,16994491,0x0f358c5aed8ae456a298eedaf3fc08cb802c3417a6b4...,0xae6c66f64f13ba5404bbe4cebd33fdd666c7bc8fe601...,16994491,0x00200000000011000000000080000000000000000000...,223722,,3972377,34,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,,0x0f358c5aed8ae456a298eedaf3fc08cb802c3417a6b4...,"[-0.067668885, -0.094198264, 0.12297824, 0.158...",26952080000.0,3320303000.0,30915355327,22915355327


In [12]:
# Binary action: 1 when the transaction paid strictly more than the median gas
# price of its block, else 0. A vectorised comparison replaces the row-wise
# apply; rows with a missing price still compare False and therefore map to 0.
df_with_actions = df_with_median_gas_prices.copy()
df_with_actions["action"] = (
    df_with_actions["effectiveGasPrice"] > df_with_actions["median_gas_price"]
).astype("int64")

In [13]:
# Fraction of rows with action == 1, i.e. priced above their block median.
df_with_actions["action"].mean()

0.24449359876667337

In [14]:
# Preview the frame with the new `action` column.
df_with_actions.head()

Unnamed: 0,block_number_x,transaction_hash,blockHash,blockNumber,logsBloom,gasUsed,contractAddress,cumulativeGasUsed,transactionIndex,from,...,profit_amount,error,protocols,transactionHash,embeddings,median_gas_price,std_gas_price,max_gas_price,min_gas_price,action
0,16969850,0xe4029c908cfc40f825051cd0957797c66196eb8ba437...,0x241b2ebd536a6f546ce2214bcf2146d359ae4e0cc4e3...,16969850,0x00000000000000000000000000000000000000000000...,21000,,1315400,19,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,0xe4029c908cfc40f825051cd0957797c66196eb8ba437...,"[0.16500826, -0.09403816, 0.21792151, 0.012831...",28412780000.0,5338623000.0,40046142239,27985774295,1
1,16975162,0xbcfb84169287cf7acce33fba2b7390cfe21852871f78...,0x577d4ef683ebd84c73cdbc3635c7177f6ec137eef3bf...,16975162,0x00000000000000000000000080000000000000000000...,116880,,396290,3,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,0xbcfb84169287cf7acce33fba2b7390cfe21852871f78...,"[-0.09382023, -0.012681959, 0.14728837, 0.0354...",31324250000.0,13106710000.0,62912040686,29651658352,1
2,16983327,0x7a954e541df7c296b8fec61b77d22ea8bacd63399467...,0xe79daae4074cb858a3e79b4d95064845295eaddc134c...,16983327,0x00000000000000000000000000000000000000000000...,103421,,3934866,21,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,0x7a954e541df7c296b8fec61b77d22ea8bacd63399467...,"[-0.10671734, 0.0006063916, 0.160602, 0.039328...",47015990000.0,94835380000.0,298140379626,44115991364,1
3,16990575,0x016858c7c133cdf545b6934653544c19475c5c111ba0...,0x9c3fa88c63603c9c5696174085177f9559e20b331a64...,16990575,0x00200000000000000000000080004000000000000000...,105953,,4683545,39,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,0x016858c7c133cdf545b6934653544c19475c5c111ba0...,"[-0.11480839, 0.027570017, 0.22743803, 0.04695...",39976070000.0,22168940000.0,72384430845,29976074199,0
4,16994491,0x0f358c5aed8ae456a298eedaf3fc08cb802c3417a6b4...,0xae6c66f64f13ba5404bbe4cebd33fdd666c7bc8fe601...,16994491,0x00200000000011000000000080000000000000000000...,223722,,3972377,34,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,0x0f358c5aed8ae456a298eedaf3fc08cb802c3417a6b4...,"[-0.067668885, -0.094198264, 0.12297824, 0.158...",26952080000.0,3320303000.0,30915355327,22915355327,1


In [15]:
# Give the balance column an explicit name and pin the dtypes of the numeric
# features and the address columns in one pass.
df_with_actions = df_with_actions.rename(columns={"balance": "eth_balance"}).astype(
    {
        "eth_balance": "float64",
        "median_gas_price": "float64",
        "std_gas_price": "float64",
        "from": "string",
        "to": "string",
    }
)
df_with_actions.head()

Unnamed: 0,block_number_x,transaction_hash,blockHash,blockNumber,logsBloom,gasUsed,contractAddress,cumulativeGasUsed,transactionIndex,from,...,profit_amount,error,protocols,transactionHash,embeddings,median_gas_price,std_gas_price,max_gas_price,min_gas_price,action
0,16969850,0xe4029c908cfc40f825051cd0957797c66196eb8ba437...,0x241b2ebd536a6f546ce2214bcf2146d359ae4e0cc4e3...,16969850,0x00000000000000000000000000000000000000000000...,21000,,1315400,19,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,0xe4029c908cfc40f825051cd0957797c66196eb8ba437...,"[0.16500826, -0.09403816, 0.21792151, 0.012831...",28412780000.0,5338623000.0,40046142239,27985774295,1
1,16975162,0xbcfb84169287cf7acce33fba2b7390cfe21852871f78...,0x577d4ef683ebd84c73cdbc3635c7177f6ec137eef3bf...,16975162,0x00000000000000000000000080000000000000000000...,116880,,396290,3,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,0xbcfb84169287cf7acce33fba2b7390cfe21852871f78...,"[-0.09382023, -0.012681959, 0.14728837, 0.0354...",31324250000.0,13106710000.0,62912040686,29651658352,1
2,16983327,0x7a954e541df7c296b8fec61b77d22ea8bacd63399467...,0xe79daae4074cb858a3e79b4d95064845295eaddc134c...,16983327,0x00000000000000000000000000000000000000000000...,103421,,3934866,21,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,0x7a954e541df7c296b8fec61b77d22ea8bacd63399467...,"[-0.10671734, 0.0006063916, 0.160602, 0.039328...",47015990000.0,94835380000.0,298140379626,44115991364,1
3,16990575,0x016858c7c133cdf545b6934653544c19475c5c111ba0...,0x9c3fa88c63603c9c5696174085177f9559e20b331a64...,16990575,0x00200000000000000000000080004000000000000000...,105953,,4683545,39,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,0x016858c7c133cdf545b6934653544c19475c5c111ba0...,"[-0.11480839, 0.027570017, 0.22743803, 0.04695...",39976070000.0,22168940000.0,72384430845,29976074199,0
4,16994491,0x0f358c5aed8ae456a298eedaf3fc08cb802c3417a6b4...,0xae6c66f64f13ba5404bbe4cebd33fdd666c7bc8fe601...,16994491,0x00200000000011000000000080000000000000000000...,223722,,3972377,34,0x00000000000124d994209fbB955E0217B5C2ECA1,...,,,,0x0f358c5aed8ae456a298eedaf3fc08cb802c3417a6b4...,"[-0.067668885, -0.094198264, 0.12297824, 0.158...",26952080000.0,3320303000.0,30915355327,22915355327,1


In [16]:
# Rows whose block-level standard deviation is undefined. `isna()` already
# yields a boolean mask, so no comparison against True is needed.
df_with_actions[df_with_actions["std_gas_price"].isna()]

Unnamed: 0,block_number_x,transaction_hash,blockHash,blockNumber,logsBloom,gasUsed,contractAddress,cumulativeGasUsed,transactionIndex,from,...,profit_amount,error,protocols,transactionHash,embeddings,median_gas_price,std_gas_price,max_gas_price,min_gas_price,action
63,16975107,0xd05a59ef18204af79ae9bf2a7ba722bca892055819c7...,0x8d8cd7dfa64a3867f12d472be895ee2a1b163d854a02...,16975107,0x00200000000000000000000080000000000008000000...,196537,,685900,5,0x00000006e42915A2B6907f8b3fAF311B68862f60,...,6855415086744551,,"[""uniswap_v2""]",0xd05a59ef18204af79ae9bf2a7ba722bca892055819c7...,"[0.33906102, 0.11006685, -0.11049971, -0.11027...",3.442311e+10,,34423111170,34423111170,0
65,16976564,0x5530313d0b0271506691e3732c517172d5bfa1b2ba3d...,0x458b66a35808bf44a7e332b9f9b326ca8660952ab40a...,16976564,0x00200000000000000000000080000000200000000000...,190691,,1322470,7,0x00000006e42915A2B6907f8b3fAF311B68862f60,...,7796163749018364,,"[""uniswap_v2""]",0x5530313d0b0271506691e3732c517172d5bfa1b2ba3d...,"[0.41424438, 0.1412386, -0.15269175, -0.147597...",4.041179e+10,,40411786608,40411786608,0
76,16979829,0xe47601937f0538ecc2a67c0a1b2481a1d339b52b2ef0...,0x4d1ebc8a72732a87fb083679233a06cf0a33ec984b35...,16979829,0x00200000000000000000000084000000200000000000...,158198,,158198,0,0x00000006e42915A2B6907f8b3fAF311B68862f60,...,24408839782523421,,"[""uniswap_v2"",""uniswap_v3""]",0xe47601937f0538ecc2a67c0a1b2481a1d339b52b2ef0...,"[0.4598713, 0.15552238, -0.16139461, -0.179502...",1.532201e+11,,153220129205,153220129205,0
98,16992014,0x23e05562df7784836aaf6c8235d2aca5501621e6aab2...,0x8794222099471416ba2721137ee0e0d60a149b2f2029...,16992014,0x00200000000000000000000080000200000000000000...,219788,,219788,0,0x00000006e42915A2B6907f8b3fAF311B68862f60,...,35168219119161781,,"[""uniswap_v2"",""uniswap_v3""]",0x23e05562df7784836aaf6c8235d2aca5501621e6aab2...,"[0.46931684, 0.12999484, -0.15137316, -0.15249...",1.595249e+11,,159524873969,159524873969,0
167,17240935,0x74e6628155b2f61c067a568235c952e3c1fa4aa22d76...,0xa658109201d2dcd46dcc9f5ad1c71d29d6942c44b1c4...,17240935,0x00000000040000000000000000000000000000000000...,210289,,2954344,23,0x00000006e42915A2B6907f8b3fAF311B68862f60,...,13097045417966494,,"[""uniswap_v3""]",0x74e6628155b2f61c067a568235c952e3c1fa4aa22d76...,"[0.31971875, 0.08582469, -0.06865533, -0.08339...",6.190800e+10,,61907998723,61907998723,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
149181,17296632,0x2f06cc7f885dd98fbf765ad57706fe505c5b63478928...,0x64522455567aa9a4209551fd938a1e31b448487ac157...,17296632,0x00200000000000000010000080000000000000000000...,178376,,440031,1,0xffFf14106945bCB267B34711c416AA3085B8865F,...,22902832502491512,,"[""uniswap_v2""]",0x2f06cc7f885dd98fbf765ad57706fe505c5b63478928...,"[0.6747055, 0.22695717, -0.25271302, -0.264935...",3.450291e+10,,34502909809,34502909809,0
149182,17297932,0xd477617aaa93aad5aa5f5a2f880f3e934b70a91cd523...,0x493341206c063a7ba6c85137473b17f697ed72701c07...,17297932,0x00200000000000000010000080000000000000000000...,178388,,440065,1,0xffFf14106945bCB267B34711c416AA3085B8865F,...,31467600199400760,,"[""uniswap_v2""]",0xd477617aaa93aad5aa5f5a2f880f3e934b70a91cd523...,"[0.73240715, 0.24466729, -0.27648795, -0.27974...",2.928406e+10,,29284056698,29284056698,0
149183,17298360,0xda0b7b1156a57ff85ef5c2e80f47d6fc0c89c6253fca...,0x21a1c5ae76e2cbff779f308288e1cf2273d999f978e1...,17298360,0x00200000000000000010000080000000000000000000...,178352,,439975,1,0xffFf14106945bCB267B34711c416AA3085B8865F,...,38228015952142520,,"[""uniswap_v2""]",0xda0b7b1156a57ff85ef5c2e80f47d6fc0c89c6253fca...,"[0.7970462, 0.2638063, -0.30356315, -0.2968765...",2.959543e+10,,29595427147,29595427147,0
149184,17298379,0x7755b8553cc9f2479e5cc48ebf18dd6f25fe038a668f...,0xa0171ad96bcd73eb6aa92c8da1e3a8629c1c067ea811...,17298379,0x00200000000000000010000080000000000000000000...,178364,,1162669,7,0xffFf14106945bCB267B34711c416AA3085B8865F,...,28570015952142520,,"[""uniswap_v2""]",0x7755b8553cc9f2479e5cc48ebf18dd6f25fe038a668f...,"[0.84662515, 0.26852706, -0.4143204, -0.313244...",2.783189e+10,,27831892530,27831892530,0


In [17]:
# The standard deviation is NaN for blocks represented by a single transaction
# (min == max in the rows shown above); treat those as having zero spread.
df_with_actions = df_with_actions.fillna({"std_gas_price": 0})

In [18]:
# One frame per class; rows without a label (NaN after the right join) fall in neither.
df_with_actions_0 = df_with_actions.loc[df_with_actions["label"].eq(0)]
df_with_actions_1 = df_with_actions.loc[df_with_actions["label"].eq(1)]

In [19]:
# Account-level 80/10/10 split for class 0: all transactions of one sender end
# up in the same partition. Accounts are taken in order of first appearance;
# no shuffling is applied.
unique_accs_0 = df_with_actions_0["from"].unique()

# First 80% of the accounts -> train
accs_train_0 = unique_accs_0[: int(len(unique_accs_0) * 0.8)]
df_with_actions_0_train = df_with_actions_0.loc[df_with_actions_0["from"].isin(accs_train_0)]

# Next 10% -> validation
accs_val_0 = unique_accs_0[int(len(unique_accs_0) * 0.8) : int(len(unique_accs_0) * 0.9)]
df_with_actions_0_val = df_with_actions_0.loc[df_with_actions_0["from"].isin(accs_val_0)]

# Remaining 10% -> test
accs_test_0 = unique_accs_0[int(len(unique_accs_0) * 0.9) :]
df_with_actions_0_test = df_with_actions_0.loc[df_with_actions_0["from"].isin(accs_test_0)]

In [20]:
# Account-level 80/10/10 split for class 1: all transactions of one sender end
# up in the same partition. Accounts are taken in order of first appearance;
# no shuffling is applied.
unique_accs_1 = df_with_actions_1["from"].unique()

# First 80% of the accounts -> train
accs_train_1 = unique_accs_1[: int(len(unique_accs_1) * 0.8)]
df_with_actions_1_train = df_with_actions_1.loc[df_with_actions_1["from"].isin(accs_train_1)]

# Next 10% -> validation
accs_val_1 = unique_accs_1[int(len(unique_accs_1) * 0.8) : int(len(unique_accs_1) * 0.9)]
df_with_actions_1_val = df_with_actions_1.loc[df_with_actions_1["from"].isin(accs_val_1)]

# Remaining 10% -> test
accs_test_1 = unique_accs_1[int(len(unique_accs_1) * 0.9) :]
df_with_actions_1_test = df_with_actions_1.loc[df_with_actions_1["from"].isin(accs_test_1)]

In [21]:
# Single validation frame containing the validation accounts of both classes.
df_val = pd.concat([df_with_actions_0_val, df_with_actions_1_val])

In [22]:
# Ad-hoc lookup of one specific sender address in the class-1 training split
# (the displayed result has no rows).
df_with_actions_1_train[
    df_with_actions_1_train["from"] == "0x1e6c1c4669f612112a7caCa5596BfE6629e669aA"
]

Unnamed: 0,block_number_x,transaction_hash,blockHash,blockNumber,logsBloom,gasUsed,contractAddress,cumulativeGasUsed,transactionIndex,from,...,profit_amount,error,protocols,transactionHash,embeddings,median_gas_price,std_gas_price,max_gas_price,min_gas_price,action


## Creating trajectories

In [23]:
def extract_trajectories(df: pd.DataFrame) -> list:
    """Build one demonstration trajectory per sender account.

    Rows are grouped by the ``from`` column and ordered chronologically. Each
    trajectory is a dict with:

    * ``"obs"``: array of shape ``(n + 1, d)`` -- the ``n`` transaction
      embeddings followed by one all-zero terminal observation,
    * ``"acts"``: array with the ``n`` values of the ``action`` column,
    * ``"label"``: the ``label`` value of the account's first transaction.

    Args:
        df: Frame with ``from``, ``blockNumber``, ``embeddings``, ``action`` and
            ``label`` columns. ``transactionIndex`` is used as a tie-breaker
            inside a block when the column is present. All embeddings must be
            1-D arrays of the same length.

    Returns:
        List of trajectory dicts, one per distinct ``from`` value (empty when
        ``df`` has no rows).
    """
    # Order by block and, where available, by position inside the block so that
    # several transactions of one account in the same block keep a deterministic,
    # chronological order. The stable sort covers the single-key fallback.
    sort_keys = ["blockNumber"]
    if "transactionIndex" in df.columns:
        sort_keys.append("transactionIndex")

    trajectories = []
    for _, group in df.groupby("from"):
        group = group.sort_values(sort_keys, kind="mergesort")
        embeddings = group["embeddings"].tolist()
        # Terminal observation: zeros with the shape and dtype of the embeddings
        # instead of a hard-coded 128-d float32 vector.
        terminal_obs = np.zeros_like(np.asarray(embeddings[0]))
        trajectories.append(
            {
                "obs": np.stack(embeddings + [terminal_obs]),
                "acts": np.array(group["action"].tolist()),
                "label": group["label"].iloc[0],
            }
        )
    return trajectories


# Raw per-account trajectory dicts for every split.
# Train and test are kept per class; validation mixes both classes (df_val).
trajectories_1_train = extract_trajectories(df_with_actions_1_train)
trajectories_0_train = extract_trajectories(df_with_actions_0_train)
trajectories_1_test = extract_trajectories(df_with_actions_1_test)
trajectories_0_test = extract_trajectories(df_with_actions_0_test)
trajectories_val = extract_trajectories(df_val)

In [24]:
# Wrap every raw trajectory dict in an imitation `Trajectory` (terminal, without
# infos) and flatten each collection with `flatten_trajectories`.
trajectories_1 = flatten_trajectories(
    [
        Trajectory(obs=raw["obs"], acts=raw["acts"], infos=None, terminal=True)
        for raw in trajectories_1_train
    ]
)
trajectories_0 = flatten_trajectories(
    [
        Trajectory(obs=raw["obs"], acts=raw["acts"], infos=None, terminal=True)
        for raw in trajectories_0_train
    ]
)
# `trajectories_val` is rebound from the raw dicts to the flattened result, so the
# extraction cell has to be re-run before this cell can be executed again.
trajectories_val = flatten_trajectories(
    [
        Trajectory(obs=raw["obs"], acts=raw["acts"], infos=None, terminal=True)
        for raw in trajectories_val
    ]
)

In [25]:
# Same conversion for the held-out test accounts of each class. Both names are
# rebound from raw dicts to the flattened result, so the extraction cell has to
# be re-run before this cell can be executed again.
trajectories_1_test = flatten_trajectories(
    [
        Trajectory(obs=raw["obs"], acts=raw["acts"], infos=None, terminal=True)
        for raw in trajectories_1_test
    ]
)
trajectories_0_test = flatten_trajectories(
    [
        Trajectory(obs=raw["obs"], acts=raw["acts"], infos=None, terminal=True)
        for raw in trajectories_0_test
    ]
)

## Setting up environments

In [26]:
# Register one Gymnasium environment per class. Both share the same entry point
# and differ only in the data frame and label passed as constructor kwargs.
# The device falls back to CPU when the MPS backend is unavailable, so the
# notebook is not tied to Apple-silicon machines.
# NOTE(review): the environments receive the full per-class frames
# (`df_with_actions_0` / `df_with_actions_1`), not the train splits used for the
# expert demonstrations -- confirm that exposing val/test accounts is intended.
ID0 = "gymnasium_env/TransactionGraphEnv0-v2"
gym.envs.register(
    id=ID0,
    entry_point=grl.TransactionGraphEnvV2,
    kwargs={
        "df": df_with_actions_0,
        "alpha": 0.9,
        "device": torch.device("mps" if torch.backends.mps.is_available() else "cpu"),
        "label": 0,
    },
    max_episode_steps=300,
)

ID1 = "gymnasium_env/TransactionGraphEnv1-v2"
gym.envs.register(
    id=ID1,
    entry_point=grl.TransactionGraphEnvV2,
    kwargs={
        "df": df_with_actions_1,
        "alpha": 0.9,
        "device": torch.device("mps" if torch.backends.mps.is_available() else "cpu"),
        "label": 1,
    },
    max_episode_steps=300,
)

In [27]:
# Print the Gymnasium registry to check the registered environment ids.
gym.pprint_registry()

===== classic_control =====
Acrobot-v1             CartPole-v0            CartPole-v1
MountainCar-v0         MountainCarContinuous-v0 Pendulum-v1
===== phys2d =====
phys2d/CartPole-v0     phys2d/CartPole-v1     phys2d/Pendulum-v0
===== box2d =====
BipedalWalker-v3       BipedalWalkerHardcore-v3 CarRacing-v2
LunarLander-v2         LunarLanderContinuous-v2
===== toy_text =====
Blackjack-v1           CliffWalking-v0        FrozenLake-v1
FrozenLake8x8-v1       Taxi-v3
===== tabular =====
tabular/Blackjack-v0   tabular/CliffWalking-v0
===== mujoco =====
Ant-v2                 Ant-v3                 Ant-v4
HalfCheetah-v2         HalfCheetah-v3         HalfCheetah-v4
Hopper-v2              Hopper-v3              Hopper-v4
Humanoid-v2            Humanoid-v3            Humanoid-v4
HumanoidStandup-v2     HumanoidStandup-v4     InvertedDoublePendulum-v2
InvertedDoublePendulum-v4 InvertedPendulum-v2    InvertedPendulum-v4
Pusher-v2              Pusher-v4              Reacher-v2
Reacher-v4         

In [28]:
# Stand-alone monitored instance of the class-0 environment.
env0 = Monitor(gym.make(ID0))

# Single, in-process vectorised environment for training. Every created env is
# wrapped in RolloutInfoWrapper, and VecCheckNan raises as soon as a NaN shows up.
venv0 = VecCheckNan(
    make_vec_env(
        ID0,
        rng=RNG,
        n_envs=1,
        parallel=False,
        post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],
    ),
    raise_exception=True,
)
venv0.reset()

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

array([[ 0.18248384, -0.10978805,  0.24212572,  0.01781837,  0.01548136,
        -0.09479892,  0.3450035 ,  0.04917271, -0.22358862, -0.10396218,
        -0.08917117,  0.3299945 , -0.1735357 ,  0.02740014, -0.07751278,
         0.26748022,  0.02605709, -0.1688954 , -0.1445691 , -0.1184966 ,
         0.4315584 ,  0.01139639, -0.30225065, -0.02593596,  0.18624091,
         0.03413535,  0.24536717, -0.00505128, -0.02266762, -0.05883092,
        -0.11157656, -0.09253941,  0.02740761,  0.08645051, -0.04514315,
         0.16507655,  0.1712383 , -0.12266177,  0.22952211,  0.14224248,
        -0.18861863, -0.14294359,  0.04085501,  0.15176764,  0.31031537,
         0.04521439, -0.31404567, -0.3016088 ,  0.20373118, -0.24753569,
        -0.39547986, -0.14190657,  0.03015415, -0.15986337,  0.01052236,
         0.08406067,  0.10165153,  0.0656958 , -0.1357242 , -0.07307398,
         0.24345124,  0.09660505,  0.11921884, -0.2926217 , -0.12088049,
        -0.07295685,  0.18425424, -0.19700514,  0.1

In [29]:
# Stand-alone monitored instance of the class-1 environment.
env1 = Monitor(gym.make(ID1))

# Single, in-process vectorised environment for training. Every created env is
# wrapped in RolloutInfoWrapper, and VecCheckNan raises as soon as a NaN shows up.
venv1 = VecCheckNan(
    make_vec_env(
        ID1,
        rng=RNG,
        n_envs=1,
        parallel=False,
        post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],
    ),
    raise_exception=True,
)
venv1.reset()

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

array([[ 6.12601787e-02, -8.50920826e-02, -4.99629602e-03,
         4.24816459e-02,  1.22531548e-01,  2.30932459e-02,
         3.42156701e-02, -4.21753488e-02, -3.70578989e-02,
        -5.04653268e-02, -6.62448108e-02,  1.44533977e-01,
         2.52744071e-02, -1.07371524e-01, -5.75902015e-02,
        -4.96821105e-02, -1.12326294e-02, -3.01924534e-02,
         2.28365511e-03,  4.03076038e-02,  2.48026028e-02,
        -1.10254273e-01, -3.90574411e-02,  5.98332658e-03,
         5.38153164e-02,  8.54561031e-02, -2.48320475e-02,
         4.48055193e-02, -1.12906098e-01,  3.27808969e-02,
        -5.19185998e-02, -7.53983855e-04,  3.44916247e-02,
         2.26563774e-02, -4.81302738e-02, -4.45410386e-02,
         1.98302120e-02, -3.14079225e-03, -2.44430453e-03,
        -7.65292495e-02, -1.06847443e-01, -4.31123413e-02,
         5.26587740e-02, -5.49951494e-02, -6.86649084e-02,
         9.87035707e-02,  8.52322206e-03,  2.91684214e-02,
        -6.71655685e-02,  3.76953860e-03, -5.10714240e-0

## AIRL setup

In [30]:
# --- PPO (generator) hyper-parameters -------------------------------------
# Optimisation
learning_rate = 0.001  # Learning rate, can be a function of progress
batch_size = 60  # Mini batch size for each gradient update
n_epochs = 15  # N of epochs when optimizing the surrogate loss
max_grad_norm = 0.5  # The maximum value for the gradient clipping

# Returns and advantages
gamma = 0.5  # Discount factor, focus on the recent rewards
gae_lambda = 0  # Generalized advantage estimation
normalize_advantage = True  # Whether to normalize or not the advantage

# Loss terms
clip_range = 0.1  # Clipping parameter
clip_range_vf = None  # Clip for the value function
ent_coef = 0.01  # Entropy coefficient for the loss calculation
vf_coef = 0.5  # Value function coef. for the loss calculation

# Exploration and logging
use_sde = False  # Use State Dependent Exploration
sde_sample_freq = -1  # SDE - noise matrix frequency (-1 = disable)
verbose = 0  # Verbosity level: 0 no output, 1 info, 2 debug
policy_kwargs = dict(use_expln=True)  # Fixing an issue with NaNs

# --- AIRL trainer settings ------------------------------------------------
gen_replay_buffer_capacity = None
allow_variable_horizon = True
disc_opt_kwargs = dict(lr=0.001)

In [31]:
# Set the number of timesteps, batch size and number of disc updates

# Generator: N steps in the environment per one round (also used as PPO n_steps)
gen_train_timesteps = 3000
n_steps = gen_train_timesteps

# Total number of timesteps in the whole training: 25 generator rounds
total_timesteps = gen_train_timesteps * 25

# Discriminator batches
demo_batch_size = 300 * 10  # N samples in the batch of expert data (batch)
demo_minibatch_size = 60  # N samples in minibatch for one discrim. update
n_disc_updates_per_round = 4  # N discriminator updates per one round

In [32]:
# Configure the imitation logger and add the project's MLflow output format to
# its default logger. The same logger object is passed to both AIRL trainers.
hier_logger = logger.configure()
hier_logger.default_logger.output_formats.append(grl.MLflowOutputFormat())

In [33]:
# Initialize the learner PPO policy (generator) for class 0.
# The device falls back to CPU when the MPS backend is unavailable, so the cell
# also runs on machines without Apple silicon.
learner0 = PPO(
    env=venv0,
    policy=MlpPolicy,
    policy_kwargs=policy_kwargs,
    learning_rate=learning_rate,
    n_steps=n_steps,
    batch_size=batch_size,
    n_epochs=n_epochs,
    gamma=gamma,
    gae_lambda=gae_lambda,
    clip_range=clip_range,
    clip_range_vf=clip_range_vf,
    normalize_advantage=normalize_advantage,
    ent_coef=ent_coef,
    vf_coef=vf_coef,
    max_grad_norm=max_grad_norm,
    use_sde=use_sde,
    sde_sample_freq=sde_sample_freq,
    verbose=verbose,
    seed=42,
    device="mps" if torch.backends.mps.is_available() else "cpu",
)

# Reward network trained by the discriminator; inputs are normalised with RunningNorm.
reward_net0 = BasicShapedRewardNet(
    observation_space=venv0.observation_space,
    action_space=venv0.action_space,
    normalize_input_layer=RunningNorm,
)

# Initialize the AIRL trainer with the class-0 training demonstrations.
airl_trainer0 = AIRL(
    demonstrations=trajectories_0,
    demo_batch_size=demo_batch_size,
    demo_minibatch_size=demo_minibatch_size,
    n_disc_updates_per_round=n_disc_updates_per_round,
    gen_train_timesteps=gen_train_timesteps,
    gen_replay_buffer_capacity=gen_replay_buffer_capacity,
    venv=venv0,
    gen_algo=learner0,
    reward_net=reward_net0,
    allow_variable_horizon=allow_variable_horizon,
    disc_opt_kwargs=disc_opt_kwargs,
    custom_logger=hier_logger,
)

Running with `allow_variable_horizon` set to True. Some algorithms are biased towards shorter or longer episodes, which may significantly confound results. Additionally, even unbiased algorithms can exploit the information leak from the termination condition, producing spuriously high performance. See https://imitation.readthedocs.io/en/latest/getting-started/variable-horizon.html for more information.


In [34]:
# Initialize the learner PPO policy (generator) for class 1.
# The device falls back to CPU when the MPS backend is unavailable, so the cell
# also runs on machines without Apple silicon.
learner1 = PPO(
    env=venv1,
    policy=MlpPolicy,
    policy_kwargs=policy_kwargs,
    learning_rate=learning_rate,
    n_steps=n_steps,
    batch_size=batch_size,
    n_epochs=n_epochs,
    gamma=gamma,
    gae_lambda=gae_lambda,
    clip_range=clip_range,
    clip_range_vf=clip_range_vf,
    normalize_advantage=normalize_advantage,
    ent_coef=ent_coef,
    vf_coef=vf_coef,
    max_grad_norm=max_grad_norm,
    use_sde=use_sde,
    sde_sample_freq=sde_sample_freq,
    verbose=verbose,
    seed=42,
    device="mps" if torch.backends.mps.is_available() else "cpu",
)

# Reward network trained by the discriminator; inputs are normalised with RunningNorm.
reward_net1 = BasicShapedRewardNet(
    observation_space=venv1.observation_space,
    action_space=venv1.action_space,
    normalize_input_layer=RunningNorm,
)

# Initialize the AIRL trainer with the class-1 training demonstrations.
airl_trainer1 = AIRL(
    demonstrations=trajectories_1,
    demo_batch_size=demo_batch_size,
    demo_minibatch_size=demo_minibatch_size,
    n_disc_updates_per_round=n_disc_updates_per_round,
    gen_train_timesteps=gen_train_timesteps,
    gen_replay_buffer_capacity=gen_replay_buffer_capacity,
    venv=venv1,
    gen_algo=learner1,
    reward_net=reward_net1,
    allow_variable_horizon=allow_variable_horizon,
    disc_opt_kwargs=disc_opt_kwargs,
    custom_logger=hier_logger,
)

Running with `allow_variable_horizon` set to True. Some algorithms are biased towards shorter or longer episodes, which may significantly confound results. Additionally, even unbiased algorithms can exploit the information leak from the termination condition, producing spuriously high performance. See https://imitation.readthedocs.io/en/latest/getting-started/variable-horizon.html for more information.


## Training the AIRL discriminator and generator (stats are logged to MLflow)

We train two distinct AIRL trainers: one on arbitrage (class 1) transactions and one on class 0 transactions. The goal is to use the two learned reward functions for classification.

In [36]:
mlflow.set_experiment("AIRLv2 DGI pre-trained")

# The `with` block closes the MLflow run on exit (also when training raises),
# so no explicit `mlflow.end_run()` call is needed inside it.
with mlflow.start_run():
    mlflow.log_param("n_steps", n_steps)
    mlflow.log_param("batch_size", batch_size)
    mlflow.log_param("total_timesteps", total_timesteps)

    # Adversarial training of the class-1 generator and discriminator.
    airl_trainer1.train(total_timesteps=total_timesteps)

    # Persist the generator policy and the learned reward network.
    learner1.save(config.MODELS_DIR / "learner1v2")
    torch.save(reward_net1, config.MODELS_DIR / "reward_net1v2")

    # Attach both files to the run; the saved policy is picked up with a .zip suffix.
    mlflow.log_artifact(config.MODELS_DIR / "learner1v2.zip")
    mlflow.log_artifact(config.MODELS_DIR / "reward_net1v2")

2025/05/04 10:43:40 INFO mlflow.tracking.fluent: Experiment with name 'AIRLv2 DGI pre-trained' does not exist. Creating a new experiment.
round:   0%|          | 0/25 [00:00<?, ?it/s]

------------------------------------------
| raw/                        |          |
|    gen/rollout/ep_len_mean  | 22.7     |
|    gen/rollout/ep_rew_mean  | 10.8     |
|    gen/time/fps             | 26       |
|    gen/time/iterations      | 1        |
|    gen/time/time_elapsed    | 115      |
|    gen/time/total_timesteps | 3138     |
------------------------------------------
--------------------------------------------------
| raw/                                |          |
|    disc/disc_acc                    | 0.5      |
|    disc/disc_acc_expert             | 1        |
|    disc/disc_acc_gen                | 0        |
|    disc/disc_entropy                | 0.594    |
|    disc/disc_loss                   | 0.0159   |
|    disc/disc_proportion_expert_pred | 1        |
|    disc/disc_proportion_expert_true | 0.5      |
|    disc/global_step                 | 1        |
|    disc/n_expert                    | 60       |
|    disc/n_generated                 | 60       |
-

round:   4%|▍         | 1/25 [02:04<49:45, 124.42s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 25.1         |
|    gen/rollout/ep_rew_mean         | 11.8         |
|    gen/rollout/ep_rew_wrapped_mean | 21.1         |
|    gen/time/fps                    | 30           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 98           |
|    gen/time/total_timesteps        | 6138         |
|    gen/train/approx_kl             | 0.0023807904 |
|    gen/train/clip_fraction         | 0.0849       |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.691       |
|    gen/train/explained_variance    | 0.0273       |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 0.169        |
|    gen/train/n_updates             | 15           |
|    gen/train/policy_gradient_loss  | -0.00124     |
|    gen/train/value_loss   

round:   8%|▊         | 2/25 [03:50<43:31, 113.55s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 25.4         |
|    gen/rollout/ep_rew_mean         | 11.9         |
|    gen/rollout/ep_rew_wrapped_mean | -3.99        |
|    gen/time/fps                    | 31           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 96           |
|    gen/time/total_timesteps        | 9138         |
|    gen/train/approx_kl             | 0.0015264143 |
|    gen/train/clip_fraction         | 0.0935       |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.688       |
|    gen/train/explained_variance    | -1.11        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 0.051        |
|    gen/train/n_updates             | 30           |
|    gen/train/policy_gradient_loss  | 0.000637     |
|    gen/train/value_loss   

round:  12%|█▏        | 3/25 [05:34<39:59, 109.08s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 27.3         |
|    gen/rollout/ep_rew_mean         | 13.4         |
|    gen/rollout/ep_rew_wrapped_mean | -11.5        |
|    gen/time/fps                    | 33           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 90           |
|    gen/time/total_timesteps        | 12138        |
|    gen/train/approx_kl             | 0.0037156004 |
|    gen/train/clip_fraction         | 0.161        |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.69        |
|    gen/train/explained_variance    | 0.509        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 0.0136       |
|    gen/train/n_updates             | 45           |
|    gen/train/policy_gradient_loss  | -0.00407     |
|    gen/train/value_loss   

round:  16%|█▌        | 4/25 [07:12<36:44, 104.96s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 26.8         |
|    gen/rollout/ep_rew_mean         | 13.2         |
|    gen/rollout/ep_rew_wrapped_mean | -23.8        |
|    gen/time/fps                    | 30           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 99           |
|    gen/time/total_timesteps        | 15138        |
|    gen/train/approx_kl             | 0.0036581145 |
|    gen/train/clip_fraction         | 0.166        |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.682       |
|    gen/train/explained_variance    | 0.709        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 0.0382       |
|    gen/train/n_updates             | 60           |
|    gen/train/policy_gradient_loss  | -0.00656     |
|    gen/train/value_loss   

round:  20%|██        | 5/25 [09:00<35:20, 106.05s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 30.6         |
|    gen/rollout/ep_rew_mean         | 15.9         |
|    gen/rollout/ep_rew_wrapped_mean | -33.6        |
|    gen/time/fps                    | 29           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 101          |
|    gen/time/total_timesteps        | 18138        |
|    gen/train/approx_kl             | 0.0030810514 |
|    gen/train/clip_fraction         | 0.126        |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.67        |
|    gen/train/explained_variance    | 0.861        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 0.317        |
|    gen/train/n_updates             | 75           |
|    gen/train/policy_gradient_loss  | -0.00326     |
|    gen/train/value_loss   

round:  24%|██▍       | 6/25 [10:49<33:54, 107.07s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 28.9         |
|    gen/rollout/ep_rew_mean         | 15.2         |
|    gen/rollout/ep_rew_wrapped_mean | -46.5        |
|    gen/time/fps                    | 31           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 94           |
|    gen/time/total_timesteps        | 21138        |
|    gen/train/approx_kl             | 0.0019803063 |
|    gen/train/clip_fraction         | 0.143        |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.649       |
|    gen/train/explained_variance    | 0.879        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 0.496        |
|    gen/train/n_updates             | 90           |
|    gen/train/policy_gradient_loss  | -0.00466     |
|    gen/train/value_loss   

round:  28%|██▊       | 7/25 [12:31<31:33, 105.18s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 33           |
|    gen/rollout/ep_rew_mean         | 17.7         |
|    gen/rollout/ep_rew_wrapped_mean | -83.3        |
|    gen/time/fps                    | 37           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 80           |
|    gen/time/total_timesteps        | 24138        |
|    gen/train/approx_kl             | 0.0020836259 |
|    gen/train/clip_fraction         | 0.0943       |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.627       |
|    gen/train/explained_variance    | 0.736        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 0.618        |
|    gen/train/n_updates             | 105          |
|    gen/train/policy_gradient_loss  | -0.0034      |
|    gen/train/value_loss   

round:  32%|███▏      | 8/25 [13:58<28:12, 99.58s/it] 

---------------------------------------------------
| raw/                               |            |
|    gen/rollout/ep_len_mean         | 16         |
|    gen/rollout/ep_rew_mean         | 8.8        |
|    gen/rollout/ep_rew_wrapped_mean | -83.3      |
|    gen/time/fps                    | 34         |
|    gen/time/iterations             | 1          |
|    gen/time/time_elapsed           | 87         |
|    gen/time/total_timesteps        | 27138      |
|    gen/train/approx_kl             | 0.00206825 |
|    gen/train/clip_fraction         | 0.0654     |
|    gen/train/clip_range            | 0.1        |
|    gen/train/entropy_loss          | -0.606     |
|    gen/train/explained_variance    | 0.794      |
|    gen/train/learning_rate         | 0.001      |
|    gen/train/loss                  | 1.98       |
|    gen/train/n_updates             | 120        |
|    gen/train/policy_gradient_loss  | -0.00139   |
|    gen/train/value_loss            | 4.08       |
------------

round:  36%|███▌      | 9/25 [15:33<26:08, 98.01s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 32.5         |
|    gen/rollout/ep_rew_mean         | 19           |
|    gen/rollout/ep_rew_wrapped_mean | -45.7        |
|    gen/time/fps                    | 30           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 97           |
|    gen/time/total_timesteps        | 30138        |
|    gen/train/approx_kl             | 0.0021440426 |
|    gen/train/clip_fraction         | 0.055        |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.593       |
|    gen/train/explained_variance    | 0.795        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 2.39         |
|    gen/train/n_updates             | 135          |
|    gen/train/policy_gradient_loss  | -0.00224     |
|    gen/train/value_loss   

round:  40%|████      | 10/25 [17:17<25:00, 100.06s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 18           |
|    gen/rollout/ep_rew_mean         | 10.8         |
|    gen/rollout/ep_rew_wrapped_mean | -89.2        |
|    gen/time/fps                    | 33           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 89           |
|    gen/time/total_timesteps        | 33138        |
|    gen/train/approx_kl             | 0.0010947405 |
|    gen/train/clip_fraction         | 0.048        |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.583       |
|    gen/train/explained_variance    | 0.76         |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 0.982        |
|    gen/train/n_updates             | 150          |
|    gen/train/policy_gradient_loss  | -0.00127     |
|    gen/train/value_loss   

round:  44%|████▍     | 11/25 [18:55<23:10, 99.30s/it] 

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 32.4         |
|    gen/rollout/ep_rew_mean         | 19.9         |
|    gen/rollout/ep_rew_wrapped_mean | -51          |
|    gen/time/fps                    | 31           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 95           |
|    gen/time/total_timesteps        | 36138        |
|    gen/train/approx_kl             | 0.0015950074 |
|    gen/train/clip_fraction         | 0.0845       |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.562       |
|    gen/train/explained_variance    | 0.828        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 1.58         |
|    gen/train/n_updates             | 165          |
|    gen/train/policy_gradient_loss  | -0.00246     |
|    gen/train/value_loss   

round:  48%|████▊     | 12/25 [20:39<21:50, 100.78s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 22.5         |
|    gen/rollout/ep_rew_mean         | 14           |
|    gen/rollout/ep_rew_wrapped_mean | -141         |
|    gen/time/fps                    | 31           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 94           |
|    gen/time/total_timesteps        | 39138        |
|    gen/train/approx_kl             | 0.0016483116 |
|    gen/train/clip_fraction         | 0.0797       |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.527       |
|    gen/train/explained_variance    | 0.733        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 4.71         |
|    gen/train/n_updates             | 180          |
|    gen/train/policy_gradient_loss  | -0.0022      |
|    gen/train/value_loss   

round:  52%|█████▏    | 13/25 [22:21<20:13, 101.10s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 29.8         |
|    gen/rollout/ep_rew_mean         | 19.3         |
|    gen/rollout/ep_rew_wrapped_mean | -84.9        |
|    gen/time/fps                    | 33           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 90           |
|    gen/time/total_timesteps        | 42138        |
|    gen/train/approx_kl             | 0.0013329561 |
|    gen/train/clip_fraction         | 0.101        |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.506       |
|    gen/train/explained_variance    | 0.834        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 1.83         |
|    gen/train/n_updates             | 195          |
|    gen/train/policy_gradient_loss  | -0.00233     |
|    gen/train/value_loss   

round:  56%|█████▌    | 14/25 [23:58<18:19, 99.94s/it] 

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 37.4         |
|    gen/rollout/ep_rew_mean         | 24.6         |
|    gen/rollout/ep_rew_wrapped_mean | -124         |
|    gen/time/fps                    | 36           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 82           |
|    gen/time/total_timesteps        | 45138        |
|    gen/train/approx_kl             | 0.0014231693 |
|    gen/train/clip_fraction         | 0.0666       |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.494       |
|    gen/train/explained_variance    | 0.852        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 3.39         |
|    gen/train/n_updates             | 210          |
|    gen/train/policy_gradient_loss  | -0.00115     |
|    gen/train/value_loss   

round:  60%|██████    | 15/25 [25:27<16:06, 96.69s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 30.8         |
|    gen/rollout/ep_rew_mean         | 20.5         |
|    gen/rollout/ep_rew_wrapped_mean | -176         |
|    gen/time/fps                    | 35           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 84           |
|    gen/time/total_timesteps        | 48138        |
|    gen/train/approx_kl             | 0.0009096148 |
|    gen/train/clip_fraction         | 0.0564       |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.493       |
|    gen/train/explained_variance    | 0.802        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 3.84         |
|    gen/train/n_updates             | 225          |
|    gen/train/policy_gradient_loss  | -0.000609    |
|    gen/train/value_loss   

round:  64%|██████▍   | 16/25 [26:59<14:17, 95.30s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 24.3         |
|    gen/rollout/ep_rew_mean         | 15.8         |
|    gen/rollout/ep_rew_wrapped_mean | -159         |
|    gen/time/fps                    | 33           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 88           |
|    gen/time/total_timesteps        | 51138        |
|    gen/train/approx_kl             | 0.0011952106 |
|    gen/train/clip_fraction         | 0.0576       |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.481       |
|    gen/train/explained_variance    | 0.884        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 24.6         |
|    gen/train/n_updates             | 240          |
|    gen/train/policy_gradient_loss  | -0.0016      |
|    gen/train/value_loss   

round:  68%|██████▊   | 17/25 [28:35<12:43, 95.39s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 31.6         |
|    gen/rollout/ep_rew_mean         | 20.8         |
|    gen/rollout/ep_rew_wrapped_mean | -99.4        |
|    gen/time/fps                    | 35           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 85           |
|    gen/time/total_timesteps        | 54138        |
|    gen/train/approx_kl             | 0.0009931062 |
|    gen/train/clip_fraction         | 0.0381       |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.497       |
|    gen/train/explained_variance    | 0.789        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 11           |
|    gen/train/n_updates             | 255          |
|    gen/train/policy_gradient_loss  | -0.000286    |
|    gen/train/value_loss   

round:  72%|███████▏  | 18/25 [30:08<11:02, 94.62s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 33.4         |
|    gen/rollout/ep_rew_mean         | 22.3         |
|    gen/rollout/ep_rew_wrapped_mean | -190         |
|    gen/time/fps                    | 29           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 100          |
|    gen/time/total_timesteps        | 57138        |
|    gen/train/approx_kl             | 0.0010046152 |
|    gen/train/clip_fraction         | 0.0336       |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.459       |
|    gen/train/explained_variance    | 0.817        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 4.35         |
|    gen/train/n_updates             | 270          |
|    gen/train/policy_gradient_loss  | -0.000725    |
|    gen/train/value_loss   

round:  76%|███████▌  | 19/25 [31:56<09:52, 98.69s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 24.5         |
|    gen/rollout/ep_rew_mean         | 15.9         |
|    gen/rollout/ep_rew_wrapped_mean | -184         |
|    gen/time/fps                    | 31           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 94           |
|    gen/time/total_timesteps        | 60138        |
|    gen/train/approx_kl             | 0.0013255443 |
|    gen/train/clip_fraction         | 0.09         |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.469       |
|    gen/train/explained_variance    | 0.782        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 5.76         |
|    gen/train/n_updates             | 285          |
|    gen/train/policy_gradient_loss  | 0.00172      |
|    gen/train/value_loss   

round:  80%|████████  | 20/25 [33:38<08:18, 99.71s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 25.2         |
|    gen/rollout/ep_rew_mean         | 17.3         |
|    gen/rollout/ep_rew_wrapped_mean | -120         |
|    gen/time/fps                    | 29           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 100          |
|    gen/time/total_timesteps        | 63138        |
|    gen/train/approx_kl             | 0.0024900339 |
|    gen/train/clip_fraction         | 0.101        |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.475       |
|    gen/train/explained_variance    | 0.792        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 3.55         |
|    gen/train/n_updates             | 300          |
|    gen/train/policy_gradient_loss  | -0.00182     |
|    gen/train/value_loss   

round:  84%|████████▍ | 21/25 [35:29<06:52, 103.13s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 30.1         |
|    gen/rollout/ep_rew_mean         | 21.5         |
|    gen/rollout/ep_rew_wrapped_mean | -136         |
|    gen/time/fps                    | 28           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 105          |
|    gen/time/total_timesteps        | 66138        |
|    gen/train/approx_kl             | 0.0012305041 |
|    gen/train/clip_fraction         | 0.0929       |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.438       |
|    gen/train/explained_variance    | 0.818        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 7.12         |
|    gen/train/n_updates             | 315          |
|    gen/train/policy_gradient_loss  | -0.000361    |
|    gen/train/value_loss   

round:  88%|████████▊ | 22/25 [37:23<05:18, 106.24s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 31.6         |
|    gen/rollout/ep_rew_mean         | 23.5         |
|    gen/rollout/ep_rew_wrapped_mean | -180         |
|    gen/time/fps                    | 29           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 100          |
|    gen/time/total_timesteps        | 69138        |
|    gen/train/approx_kl             | 0.0017438586 |
|    gen/train/clip_fraction         | 0.067        |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.416       |
|    gen/train/explained_variance    | 0.807        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 5.39         |
|    gen/train/n_updates             | 330          |
|    gen/train/policy_gradient_loss  | 0.000593     |
|    gen/train/value_loss   

round:  92%|█████████▏| 23/25 [39:11<03:33, 106.93s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 38.8         |
|    gen/rollout/ep_rew_mean         | 28           |
|    gen/rollout/ep_rew_wrapped_mean | -223         |
|    gen/time/fps                    | 34           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 86           |
|    gen/time/total_timesteps        | 72138        |
|    gen/train/approx_kl             | 0.0009967635 |
|    gen/train/clip_fraction         | 0.0455       |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.389       |
|    gen/train/explained_variance    | 0.812        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 15           |
|    gen/train/n_updates             | 345          |
|    gen/train/policy_gradient_loss  | 8.34e-05     |
|    gen/train/value_loss   

round:  96%|█████████▌| 24/25 [40:45<01:43, 103.02s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 20.6         |
|    gen/rollout/ep_rew_mean         | 15           |
|    gen/rollout/ep_rew_wrapped_mean | -221         |
|    gen/time/fps                    | 36           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 81           |
|    gen/time/total_timesteps        | 75138        |
|    gen/train/approx_kl             | 0.0034892184 |
|    gen/train/clip_fraction         | 0.115        |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.433       |
|    gen/train/explained_variance    | 0.823        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 7.38         |
|    gen/train/n_updates             | 360          |
|    gen/train/policy_gradient_loss  | 0.000199     |
|    gen/train/value_loss   

round: 100%|██████████| 25/25 [42:14<00:00, 101.37s/it]

🏃 View run awesome-cat-250 at: http://127.0.0.1:8080/#/experiments/720343879195193287/runs/8fab1f13b9774d0d94ec16a0c3b94076
🧪 View experiment at: http://127.0.0.1:8080/#/experiments/720343879195193287





In [37]:
# Train `airl_trainer0`, persist its generator policy and reward network,
# and record parameters and artifacts as one MLflow run.
mlflow.set_experiment("AIRLv2 DGI pre-trained")

# The run is named explicitly because the other trainer logs identical
# parameters into this same experiment; auto-generated names make the two
# runs hard to tell apart afterwards.
with mlflow.start_run(run_name="airl_trainer0"):
    # Log all hyperparameters in a single call.
    mlflow.log_params(
        {
            "n_steps": n_steps,
            "batch_size": batch_size,
            "total_timesteps": total_timesteps,
        }
    )

    airl_trainer0.train(total_timesteps=total_timesteps)

    # The policy is saved without an extension but logged below as
    # "learner0v2.zip" (Stable-Baselines3 appends the ".zip" suffix on save).
    learner0.save(config.MODELS_DIR / "learner0v2")
    # NOTE(review): this pickles the whole module object rather than its
    # state_dict, so loading needs the same class definitions importable.
    # Kept as-is because downstream loading presumably relies on it -- confirm.
    torch.save(reward_net0, config.MODELS_DIR / "reward_net0v2")

    mlflow.log_artifact(config.MODELS_DIR / "learner0v2.zip")
    mlflow.log_artifact(config.MODELS_DIR / "reward_net0v2")
    # No explicit mlflow.end_run(): leaving the `with` block ends the run and
    # marks it FAILED instead of FINISHED if training raised.

round:   0%|          | 0/25 [00:00<?, ?it/s]

----------------------------------------------------
| raw/                              |              |
|    gen/rollout/ep_len_mean        | 6.75         |
|    gen/rollout/ep_rew_mean        | 2.8          |
|    gen/time/fps                   | 54           |
|    gen/time/iterations            | 1            |
|    gen/time/time_elapsed          | 54           |
|    gen/time/total_timesteps       | 3000         |
|    gen/train/approx_kl            | 0.0015636042 |
|    gen/train/clip_fraction        | 0.0589       |
|    gen/train/clip_range           | 0.1          |
|    gen/train/entropy_loss         | -0.397       |
|    gen/train/explained_variance   | 0.819        |
|    gen/train/learning_rate        | 0.001        |
|    gen/train/loss                 | 18.3         |
|    gen/train/n_updates            | 375          |
|    gen/train/policy_gradient_loss | 0.00163      |
|    gen/train/value_loss           | 29.1         |
----------------------------------------------

round:   4%|▍         | 1/25 [01:02<24:50, 62.11s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 5.63         |
|    gen/rollout/ep_rew_mean         | 2.45         |
|    gen/rollout/ep_rew_wrapped_mean | 9.55         |
|    gen/time/fps                    | 32           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 91           |
|    gen/time/total_timesteps        | 6000         |
|    gen/train/approx_kl             | 0.0015388231 |
|    gen/train/clip_fraction         | 0.078        |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.691       |
|    gen/train/explained_variance    | -0.074       |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 0.133        |
|    gen/train/n_updates             | 15           |
|    gen/train/policy_gradient_loss  | -2.62e-05    |
|    gen/train/value_loss   

round:   8%|▊         | 2/25 [02:40<32:05, 83.71s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 4.78         |
|    gen/rollout/ep_rew_mean         | 1.96         |
|    gen/rollout/ep_rew_wrapped_mean | -0.786       |
|    gen/time/fps                    | 56           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 52           |
|    gen/time/total_timesteps        | 9000         |
|    gen/train/approx_kl             | 0.0021245626 |
|    gen/train/clip_fraction         | 0.0949       |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.685       |
|    gen/train/explained_variance    | -3.95        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 0.149        |
|    gen/train/n_updates             | 30           |
|    gen/train/policy_gradient_loss  | -0.00053     |
|    gen/train/value_loss   

round:  12%|█▏        | 3/25 [03:41<26:44, 72.94s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 2.89        |
|    gen/rollout/ep_rew_mean         | 0.93        |
|    gen/rollout/ep_rew_wrapped_mean | -1.59       |
|    gen/time/fps                    | 36          |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 82          |
|    gen/time/total_timesteps        | 12000       |
|    gen/train/approx_kl             | 0.002131009 |
|    gen/train/clip_fraction         | 0.118       |
|    gen/train/clip_range            | 0.1         |
|    gen/train/entropy_loss          | -0.669      |
|    gen/train/explained_variance    | 0.156       |
|    gen/train/learning_rate         | 0.001       |
|    gen/train/loss                  | 0.0555      |
|    gen/train/n_updates             | 45          |
|    gen/train/policy_gradient_loss  | -0.00265    |
|    gen/train/value_loss            | 1.06   

round:  16%|█▌        | 4/25 [05:10<27:50, 79.57s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 5.76         |
|    gen/rollout/ep_rew_mean         | 2.45         |
|    gen/rollout/ep_rew_wrapped_mean | -0.394       |
|    gen/time/fps                    | 53           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 56           |
|    gen/time/total_timesteps        | 15000        |
|    gen/train/approx_kl             | 0.0005364566 |
|    gen/train/clip_fraction         | 0.0407       |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.677       |
|    gen/train/explained_variance    | 0.244        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 2.22         |
|    gen/train/n_updates             | 60           |
|    gen/train/policy_gradient_loss  | -0.000465    |
|    gen/train/value_loss   

round:  20%|██        | 5/25 [06:14<24:37, 73.88s/it]

---------------------------------------------------
| raw/                               |            |
|    gen/rollout/ep_len_mean         | 3.19       |
|    gen/rollout/ep_rew_mean         | 1.13       |
|    gen/rollout/ep_rew_wrapped_mean | -5.01      |
|    gen/time/fps                    | 58         |
|    gen/time/iterations             | 1          |
|    gen/time/time_elapsed           | 51         |
|    gen/time/total_timesteps        | 18000      |
|    gen/train/approx_kl             | 0.00228062 |
|    gen/train/clip_fraction         | 0.125      |
|    gen/train/clip_range            | 0.1        |
|    gen/train/entropy_loss          | -0.678     |
|    gen/train/explained_variance    | 0.559      |
|    gen/train/learning_rate         | 0.001      |
|    gen/train/loss                  | 0.213      |
|    gen/train/n_updates             | 75         |
|    gen/train/policy_gradient_loss  | -0.00296   |
|    gen/train/value_loss            | 1.83       |
------------

round:  24%|██▍       | 6/25 [07:13<21:46, 68.74s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 5.92         |
|    gen/rollout/ep_rew_mean         | 2.49         |
|    gen/rollout/ep_rew_wrapped_mean | -2.09        |
|    gen/time/fps                    | 59           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 50           |
|    gen/time/total_timesteps        | 21000        |
|    gen/train/approx_kl             | 0.0013951008 |
|    gen/train/clip_fraction         | 0.0858       |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.67        |
|    gen/train/explained_variance    | 0.47         |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 0.72         |
|    gen/train/n_updates             | 90           |
|    gen/train/policy_gradient_loss  | -0.00122     |
|    gen/train/value_loss   

round:  28%|██▊       | 7/25 [08:10<19:30, 65.05s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 4.03         |
|    gen/rollout/ep_rew_mean         | 1.59         |
|    gen/rollout/ep_rew_wrapped_mean | -4.57        |
|    gen/time/fps                    | 52           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 56           |
|    gen/time/total_timesteps        | 24000        |
|    gen/train/approx_kl             | 0.0016279512 |
|    gen/train/clip_fraction         | 0.0141       |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.661       |
|    gen/train/explained_variance    | 0.238        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 4.35         |
|    gen/train/n_updates             | 105          |
|    gen/train/policy_gradient_loss  | -0.00179     |
|    gen/train/value_loss   

round:  32%|███▏      | 8/25 [09:14<18:19, 64.69s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 4.28         |
|    gen/rollout/ep_rew_mean         | 1.82         |
|    gen/rollout/ep_rew_wrapped_mean | -4.97        |
|    gen/time/fps                    | 61           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 48           |
|    gen/time/total_timesteps        | 27000        |
|    gen/train/approx_kl             | 0.0011468899 |
|    gen/train/clip_fraction         | 0.0561       |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.655       |
|    gen/train/explained_variance    | 0.558        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 1.45         |
|    gen/train/n_updates             | 120          |
|    gen/train/policy_gradient_loss  | -0.000561    |
|    gen/train/value_loss   

round:  36%|███▌      | 9/25 [10:10<16:30, 61.90s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 4.48         |
|    gen/rollout/ep_rew_mean         | 1.98         |
|    gen/rollout/ep_rew_wrapped_mean | -1.92        |
|    gen/time/fps                    | 61           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 48           |
|    gen/time/total_timesteps        | 30000        |
|    gen/train/approx_kl             | 0.0012943759 |
|    gen/train/clip_fraction         | 0.0839       |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.643       |
|    gen/train/explained_variance    | 0.511        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 4.29         |
|    gen/train/n_updates             | 135          |
|    gen/train/policy_gradient_loss  | -0.00311     |
|    gen/train/value_loss   

round:  40%|████      | 10/25 [11:06<15:00, 60.04s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 5.69         |
|    gen/rollout/ep_rew_mean         | 2.54         |
|    gen/rollout/ep_rew_wrapped_mean | -11.2        |
|    gen/time/fps                    | 55           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 53           |
|    gen/time/total_timesteps        | 33000        |
|    gen/train/approx_kl             | 0.0017855392 |
|    gen/train/clip_fraction         | 0.0786       |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.626       |
|    gen/train/explained_variance    | 0.611        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 9.66         |
|    gen/train/n_updates             | 150          |
|    gen/train/policy_gradient_loss  | -0.00242     |
|    gen/train/value_loss   

round:  44%|████▍     | 11/25 [12:07<14:04, 60.30s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 4.4          |
|    gen/rollout/ep_rew_mean         | 2.02         |
|    gen/rollout/ep_rew_wrapped_mean | -10.3        |
|    gen/time/fps                    | 59           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 50           |
|    gen/time/total_timesteps        | 36000        |
|    gen/train/approx_kl             | 0.0031200491 |
|    gen/train/clip_fraction         | 0.114        |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.623       |
|    gen/train/explained_variance    | 0.78         |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 2.86         |
|    gen/train/n_updates             | 165          |
|    gen/train/policy_gradient_loss  | 0.000331     |
|    gen/train/value_loss   

round:  48%|████▊     | 12/25 [13:05<12:54, 59.54s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 5.4          |
|    gen/rollout/ep_rew_mean         | 2.69         |
|    gen/rollout/ep_rew_wrapped_mean | -13.6        |
|    gen/time/fps                    | 52           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 56           |
|    gen/time/total_timesteps        | 39000        |
|    gen/train/approx_kl             | 0.0017510633 |
|    gen/train/clip_fraction         | 0.0958       |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.601       |
|    gen/train/explained_variance    | 0.637        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 3.99         |
|    gen/train/n_updates             | 180          |
|    gen/train/policy_gradient_loss  | -0.00137     |
|    gen/train/value_loss   

round:  52%|█████▏    | 13/25 [14:09<12:10, 60.91s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 3.39         |
|    gen/rollout/ep_rew_mean         | 1.36         |
|    gen/rollout/ep_rew_wrapped_mean | -13.8        |
|    gen/time/fps                    | 57           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 52           |
|    gen/time/total_timesteps        | 42000        |
|    gen/train/approx_kl             | 0.0016283828 |
|    gen/train/clip_fraction         | 0.122        |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.604       |
|    gen/train/explained_variance    | 0.559        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 41.1         |
|    gen/train/n_updates             | 195          |
|    gen/train/policy_gradient_loss  | -0.000379    |
|    gen/train/value_loss   

round:  56%|█████▌    | 14/25 [15:08<11:05, 60.50s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 8.01        |
|    gen/rollout/ep_rew_mean         | 3.89        |
|    gen/rollout/ep_rew_wrapped_mean | -9.9        |
|    gen/time/fps                    | 56          |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 53          |
|    gen/time/total_timesteps        | 45000       |
|    gen/train/approx_kl             | 0.000619154 |
|    gen/train/clip_fraction         | 0.0291      |
|    gen/train/clip_range            | 0.1         |
|    gen/train/entropy_loss          | -0.614      |
|    gen/train/explained_variance    | 0.65        |
|    gen/train/learning_rate         | 0.001       |
|    gen/train/loss                  | 7.95        |
|    gen/train/n_updates             | 210         |
|    gen/train/policy_gradient_loss  | -0.000437   |
|    gen/train/value_loss            | 29.3   

round:  60%|██████    | 15/25 [16:09<10:05, 60.54s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 5.15         |
|    gen/rollout/ep_rew_mean         | 2.3          |
|    gen/rollout/ep_rew_wrapped_mean | -9.99        |
|    gen/time/fps                    | 60           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 49           |
|    gen/time/total_timesteps        | 48000        |
|    gen/train/approx_kl             | 0.0014081947 |
|    gen/train/clip_fraction         | 0.117        |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.622       |
|    gen/train/explained_variance    | 0.627        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 4.14         |
|    gen/train/n_updates             | 225          |
|    gen/train/policy_gradient_loss  | -0.000601    |
|    gen/train/value_loss   

round:  64%|██████▍   | 16/25 [17:06<08:55, 59.49s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 5.99         |
|    gen/rollout/ep_rew_mean         | 2.85         |
|    gen/rollout/ep_rew_wrapped_mean | -8.28        |
|    gen/time/fps                    | 61           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 48           |
|    gen/time/total_timesteps        | 51000        |
|    gen/train/approx_kl             | 0.0023374856 |
|    gen/train/clip_fraction         | 0.134        |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.612       |
|    gen/train/explained_variance    | 0.564        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 12.1         |
|    gen/train/n_updates             | 240          |
|    gen/train/policy_gradient_loss  | -0.0013      |
|    gen/train/value_loss   

round:  68%|██████▊   | 17/25 [18:02<07:47, 58.38s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 8.42         |
|    gen/rollout/ep_rew_mean         | 4.29         |
|    gen/rollout/ep_rew_wrapped_mean | -25.7        |
|    gen/time/fps                    | 63           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 47           |
|    gen/time/total_timesteps        | 54000        |
|    gen/train/approx_kl             | 0.0035403073 |
|    gen/train/clip_fraction         | 0.105        |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.588       |
|    gen/train/explained_variance    | 0.434        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 2.63         |
|    gen/train/n_updates             | 255          |
|    gen/train/policy_gradient_loss  | -0.00237     |
|    gen/train/value_loss   

round:  72%|███████▏  | 18/25 [18:56<06:40, 57.26s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 6.44         |
|    gen/rollout/ep_rew_mean         | 3.25         |
|    gen/rollout/ep_rew_wrapped_mean | -15.9        |
|    gen/time/fps                    | 57           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 52           |
|    gen/time/total_timesteps        | 57000        |
|    gen/train/approx_kl             | 0.0032291093 |
|    gen/train/clip_fraction         | 0.135        |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.58        |
|    gen/train/explained_variance    | 0.559        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 5.12         |
|    gen/train/n_updates             | 270          |
|    gen/train/policy_gradient_loss  | -0.000434    |
|    gen/train/value_loss   

round:  76%|███████▌  | 19/25 [19:55<05:47, 57.84s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 5.94         |
|    gen/rollout/ep_rew_mean         | 2.9          |
|    gen/rollout/ep_rew_wrapped_mean | -24.7        |
|    gen/time/fps                    | 61           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 48           |
|    gen/time/total_timesteps        | 60000        |
|    gen/train/approx_kl             | 0.0013520322 |
|    gen/train/clip_fraction         | 0.11         |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.583       |
|    gen/train/explained_variance    | 0.305        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 12.9         |
|    gen/train/n_updates             | 285          |
|    gen/train/policy_gradient_loss  | 0.000918     |
|    gen/train/value_loss   

round:  80%|████████  | 20/25 [20:51<04:45, 57.13s/it]

------------------------------------------------------
| raw/                               |               |
|    gen/rollout/ep_len_mean         | 4.42          |
|    gen/rollout/ep_rew_mean         | 2.05          |
|    gen/rollout/ep_rew_wrapped_mean | -20.4         |
|    gen/time/fps                    | 58            |
|    gen/time/iterations             | 1             |
|    gen/time/time_elapsed           | 51            |
|    gen/time/total_timesteps        | 63000         |
|    gen/train/approx_kl             | 0.00096317584 |
|    gen/train/clip_fraction         | 0.0484        |
|    gen/train/clip_range            | 0.1           |
|    gen/train/entropy_loss          | -0.571        |
|    gen/train/explained_variance    | 0.45          |
|    gen/train/learning_rate         | 0.001         |
|    gen/train/loss                  | 25.5          |
|    gen/train/n_updates             | 300           |
|    gen/train/policy_gradient_loss  | -2.89e-05     |
|    gen/t

round:  84%|████████▍ | 21/25 [21:50<03:50, 57.58s/it]

------------------------------------------------------
| raw/                               |               |
|    gen/rollout/ep_len_mean         | 5.83          |
|    gen/rollout/ep_rew_mean         | 2.74          |
|    gen/rollout/ep_rew_wrapped_mean | -22.4         |
|    gen/time/fps                    | 55            |
|    gen/time/iterations             | 1             |
|    gen/time/time_elapsed           | 54            |
|    gen/time/total_timesteps        | 66000         |
|    gen/train/approx_kl             | 0.00066591724 |
|    gen/train/clip_fraction         | 0.0538        |
|    gen/train/clip_range            | 0.1           |
|    gen/train/entropy_loss          | -0.565        |
|    gen/train/explained_variance    | 0.503         |
|    gen/train/learning_rate         | 0.001         |
|    gen/train/loss                  | 68            |
|    gen/train/n_updates             | 315           |
|    gen/train/policy_gradient_loss  | -0.00055      |
|    gen/t

round:  88%|████████▊ | 22/25 [22:51<02:56, 58.71s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 4.97         |
|    gen/rollout/ep_rew_mean         | 2.31         |
|    gen/rollout/ep_rew_wrapped_mean | -23.3        |
|    gen/time/fps                    | 60           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 49           |
|    gen/time/total_timesteps        | 69000        |
|    gen/train/approx_kl             | 0.0024262115 |
|    gen/train/clip_fraction         | 0.101        |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.567       |
|    gen/train/explained_variance    | 0.661        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 19.2         |
|    gen/train/n_updates             | 330          |
|    gen/train/policy_gradient_loss  | -0.00169     |
|    gen/train/value_loss   

round:  92%|█████████▏| 23/25 [23:48<01:56, 58.11s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 5.94         |
|    gen/rollout/ep_rew_mean         | 3.19         |
|    gen/rollout/ep_rew_wrapped_mean | -12.1        |
|    gen/time/fps                    | 55           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 53           |
|    gen/time/total_timesteps        | 72000        |
|    gen/train/approx_kl             | 0.0014856467 |
|    gen/train/clip_fraction         | 0.107        |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.562       |
|    gen/train/explained_variance    | 0.723        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 53.1         |
|    gen/train/n_updates             | 345          |
|    gen/train/policy_gradient_loss  | -0.00121     |
|    gen/train/value_loss   

round:  96%|█████████▌| 24/25 [24:49<00:58, 58.95s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 7.67         |
|    gen/rollout/ep_rew_mean         | 4.1          |
|    gen/rollout/ep_rew_wrapped_mean | -33.7        |
|    gen/time/fps                    | 59           |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 50           |
|    gen/time/total_timesteps        | 75000        |
|    gen/train/approx_kl             | 0.0024076882 |
|    gen/train/clip_fraction         | 0.0967       |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.564       |
|    gen/train/explained_variance    | 0.446        |
|    gen/train/learning_rate         | 0.001        |
|    gen/train/loss                  | 16.8         |
|    gen/train/n_updates             | 360          |
|    gen/train/policy_gradient_loss  | 0.00152      |
|    gen/train/value_loss   

round: 100%|██████████| 25/25 [25:46<00:00, 61.87s/it]

🏃 View run powerful-crane-750 at: http://127.0.0.1:8080/#/experiments/720343879195193287/runs/e091b3e5dd224156a616374786de404a
🧪 View experiment at: http://127.0.0.1:8080/#/experiments/720343879195193287





## Metrics

Download the trained reward networks (one per class) from MLflow.

In [38]:
# Fetch the `reward_net0v2` artifact from the MLflow artifact store to a local file.
reward_net0_uri = "mlflow-artifacts:/720343879195193287/e091b3e5dd224156a616374786de404a/artifacts/reward_net0v2"
local_path = mlflow.artifacts.download_artifacts(artifact_uri=reward_net0_uri)

# NOTE(review): weights_only=False performs a full unpickle of the file, which can
# execute arbitrary code; only load artifacts that come from a trusted tracking server.
reward_net0 = torch.load(local_path, weights_only=False)

Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

In [39]:
# Fetch the `reward_net1v2` artifact to a local file. It lives under a different
# run id (8fab1f13…) than the `reward_net0v2` artifact loaded in the previous cell.
reward_net1_uri = "mlflow-artifacts:/720343879195193287/8fab1f13b9774d0d94ec16a0c3b94076/artifacts/reward_net1v2"
local_path = mlflow.artifacts.download_artifacts(artifact_uri=reward_net1_uri)

# NOTE(review): weights_only=False performs a full unpickle of the file, which can
# execute arbitrary code; only load artifacts that come from a trusted tracking server.
reward_net1 = torch.load(local_path, weights_only=False)

Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

### Comparing **uncalibrated** mean rewards

In [40]:
# Pull the four arrays passed to `reward_net.predict(...)` out of the class-0 test set.
# NOTE: despite their names, `obs0` / `obs1` hold the *actions* (`.acts`);
# the observations are in `states0` / `states1`.
states0 = trajectories_0_test.obs
obs0 = trajectories_0_test.acts
next_states0 = trajectories_0_test.next_obs
dones0 = trajectories_0_test.dones

# Same four arrays for the class-1 test set.
states1 = trajectories_1_test.obs
obs1 = trajectories_1_test.acts
next_states1 = trajectories_1_test.next_obs
dones1 = trajectories_1_test.dones

In [41]:
# For reward_net1
rewards1 = reward_net1.predict(states0, obs0, next_states0, dones0)
norm_rewards1 = (rewards1 - rewards1.mean()) / rewards1.std()
print("Reward network 1 with traj0: ", norm_rewards1.mean())

# For reward_net0
rewards0 = reward_net0.predict(states0, obs0, next_states0, dones0)
norm_rewards0 = (rewards0 - rewards0.mean()) / rewards0.std()
print("Reward network 0 with traj0: ", norm_rewards0.mean())

Reward network 1 with traj0:  6.010181e-08
Reward network 0 with traj0:  3.1818605e-08


In [42]:
# For reward_net1
rewards1 = reward_net1.predict(states1, obs1, next_states1, dones1)
norm_rewards1 = (rewards1 - rewards1.mean()) / rewards1.std()
print("Reward network 1 with traj1: ", norm_rewards1.mean())

# For reward_net0
rewards0 = reward_net0.predict(states1, obs1, next_states1, dones1)
norm_rewards0 = (rewards0 - rewards0.mean()) / rewards0.std()
print("Reward network 0 with traj1: ", norm_rewards0.mean())

Reward network 1 with traj1:  -9.3858986e-08
Reward network 0 with traj1:  -1.3408426e-08


### Calibration using affine transformations

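We use a validation set to calculate the $\alpha$ and $\beta$ parameters of an affine transformation for each reward network, which is simply
$f(x) = \alpha x + \beta$, where $\alpha = \frac{\sigma_{target}}{\sigma_{x}}$ and $\beta = \mu_{target} - \alpha \mu_x$. Here $\mu_x$ and $\sigma_x$ are the mean and standard deviation of that network's outputs on the validation set.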

In [43]:
# Pull the four arrays passed to `reward_net.predict(...)` out of the validation set.
# NOTE: despite its name, `obs` holds the *actions* (`.acts`); observations are in `states`.
states = trajectories_val.obs
obs = trajectories_val.acts
next_states = trajectories_val.next_obs
dones = trajectories_val.dones

In [44]:
# Target moments that both networks' outputs are mapped onto (mean 0, std 1).
target_mean, target_std = 0.0, 1.0

# Raw outputs of each reward network on the validation transitions.
outputs_arb = reward_net1.predict(states, obs, next_states, dones)
outputs_nonarb = reward_net0.predict(states, obs, next_states, dones)

# reward_net1 ("arb"): empirical moments -> affine parameters -> calibrated outputs.
# NOTE(review): a zero std would make alpha infinite; this is not guarded here.
mean_arb = outputs_arb.mean()
std_arb = outputs_arb.std()
alpha_arb = target_std / std_arb
beta_arb = target_mean - alpha_arb * mean_arb
calibrated_arb = outputs_arb * alpha_arb + beta_arb

# reward_net0 ("nonarb"): identical procedure.
mean_nonarb = outputs_nonarb.mean()
std_nonarb = outputs_nonarb.std()
alpha_nonarb = target_std / std_nonarb
beta_nonarb = target_mean - alpha_nonarb * mean_nonarb
calibrated_nonarb = outputs_nonarb * alpha_nonarb + beta_nonarb

### Accuracy, Precision, Recall and F1-Score for both calibrated reward networks

In [45]:
# For reward_net1
rewards1 = reward_net1.predict(states0, obs0, next_states0, dones0)
calibrated_rewards1 = alpha_arb * rewards1 + beta_arb
print("Reward network 1 with traj0: ", calibrated_rewards1.mean())

# For reward_net0
rewards0 = reward_net0.predict(states0, obs0, next_states0, dones0)
calibrated_rewards0 = alpha_nonarb * rewards0 + beta_nonarb
print("Reward network 0 with traj0: ", calibrated_rewards0.mean())

Reward network 1 with traj0:  -0.21643028
Reward network 0 with traj0:  0.05022597


In [46]:
# A class-0 transition counts as correctly classified (prediction 1) when reward_net0's
# calibrated reward is strictly greater than reward_net1's.
predictions = (calibrated_rewards0 > calibrated_rewards1).astype(int)

# Every sample here is a class-0 transition, so the ground truth is all ones.
# With no negatives in the labels, precision is trivially 1.0 as soon as any positive
# is predicted, and recall equals accuracy.
true_labels = np.ones(len(predictions), dtype=int)

accuracy, precision, recall, f1 = (
    metric(true_labels, predictions)
    for metric in (accuracy_score, precision_score, recall_score, f1_score)
)

print("Accuracy:", accuracy)
print("Precision:", precision)
print("Recall:", recall)
print("F1 Score:", f1)

Accuracy: 0.7152455977757183
Precision: 1.0
Recall: 0.7152455977757183
F1 Score: 0.8339862218019721


In [47]:
# For reward_net1
rewards1 = reward_net1.predict(states1, obs1, next_states1, dones1)
calibrated_rewards1 = alpha_arb * rewards1 + beta_arb
print("Reward network 1 with traj1: ", calibrated_rewards1.mean())

# For reward_net0
rewards0 = reward_net0.predict(states1, obs1, next_states1, dones1)
calibrated_rewards0 = alpha_nonarb * rewards0 + beta_nonarb
print("Reward network 0 with traj1: ", calibrated_rewards0.mean())

Reward network 1 with traj1:  0.39639646
Reward network 0 with traj1:  0.12317644


In [48]:
# A class-1 transition counts as correctly classified (prediction 1) when reward_net1's
# calibrated reward is strictly greater than reward_net0's.
predictions = (calibrated_rewards0 < calibrated_rewards1).astype(int)

# Every sample here is a class-1 transition, so the ground truth is all ones.
# With no negatives in the labels, precision is trivially 1.0 as soon as any positive
# is predicted, and recall equals accuracy.
true_labels = np.ones(len(predictions), dtype=int)

accuracy, precision, recall, f1 = (
    metric(true_labels, predictions)
    for metric in (accuracy_score, precision_score, recall_score, f1_score)
)

print("Accuracy:", accuracy)
print("Precision:", precision)
print("Recall:", recall)
print("F1 Score:", f1)

Accuracy: 0.8791739894551845
Precision: 1.0
Recall: 0.8791739894551845
F1 Score: 0.9357025952770633
