In [1]:
#!pip install -e git+https://github.com/casperbroch/ai-economist@stockmarket#egg=ai-economist

In [2]:
import matplotlib.pyplot as plt
import numpy as np
import time

In [3]:
# Configuration of the environment that will be built.

env_config_dict = {
    # ===== SCENARIO CLASS =====
    # Name of the Scenario class to use, as listed in the Scenario Registry (foundation.scenarios).
    # The environment object will be an instance of that Scenario class.
    "scenario_name": "stock_market_simulation",

    # ===== COMPONENTS =====
    # Components to use, given as a list of ("component_name", {component_kwargs}) tuples.
    #   "component_name" is the Component class's name in the Component Registry (foundation.components)
    #   {component_kwargs} is a dictionary of kwargs passed to the Component class
    # Components reset, step, and generate obs in the order they are listed here.
    "components": [
        # (1) Stock buy/sell component, created with its default kwargs
        ("BuyOrSellStocks", {}),
        # (2) Circuit-breaker component (currently disabled)
        #('ExecCircuitBreaker', {}),
    ],

    # ===== STANDARD ARGUMENTS ======
    # kwargs that are used by every Scenario class (i.e. defined in BaseEnvironment)
    "n_agents": 4,           # Number of non-planner agents (must be > 1)
    "world_size": [1, 1],    # [Height, Width] of the env world
    "episode_length": 100,   # Number of timesteps per episode

    # In multi-action-mode, the policy selects an action for each action subspace (defined in component code).
    # Otherwise, the policy selects only 1 action.
    "multi_action_mode_agents": False,
    "multi_action_mode_planner": False,

    # When flattening observations, concatenate scalar & vector observations before output.
    # Otherwise, return observations with minimal processing.
    "flatten_observations": False,
    # When flattening masks, concatenate each action subspace mask into a single array.
    # Note: flatten_masks = True is required for masking action logits in the code below.
    "flatten_masks": True,
    # NOTE(review): presumably "record a dense log every N episodes" -- confirm against BaseEnvironment.
    "dense_log_frequency": 1,
}

In [4]:
from rllib.env_wrapper import RLlibEnvWrapper
# Build a local environment instance; its observation/action spaces are reused
# further down to define the agent ("a") and planner ("p") policies.
# NOTE(review): verbose=True presumably triggers the "[EnvWrapper] Spaces"
# printout visible in this cell's output -- confirm in rllib/env_wrapper.py.
env_obj = RLlibEnvWrapper({"env_config_dict": env_config_dict}, verbose=True)

Inside covid19_components.py: 0 GPUs are available.
No GPUs found! Running the simulation on a CPU.
Inside covid19_env.py: 0 GPUs are available.
No GPUs found! Running the simulation on a CPU.


  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


test std 0.0
[EnvWrapper] Spaces
[EnvWrapper] Obs (a)   
BuyOrSellStocks-stock_price: (1,)
action_mask    : (21,)
time           : (1,)
world-Endogenous-AbleToBuy: (1,)
world-Endogenous-AbleToSell: (1,)
world-Endogenous-AvailableFunds: (1,)
world-Endogenous-Demand: (1,)
world-Endogenous-Labor: (1,)
world-Endogenous-NumberOfStocks: (1,)
world-Endogenous-StockPrice: (1,)
world-Endogenous-StockPriceHistory: (100,)
world-Endogenous-StocksLeft: (1,)
world-Endogenous-Supply: (1,)
world-Endogenous-TotalBalance: (1,)
world-Endogenous-Volumes: (100,)


[EnvWrapper] Obs (p)   
action_mask    : (1,)
time           : (1,)
world-Prices_History: (100,)
world-Total_Demand: (1,)
world-Total_Supply: (1,)
world-Volumes  : (100,)


[EnvWrapper] Action (a) Discrete(21)
[EnvWrapper] Action (p) Discrete(1)


In [5]:
import ray
from ray.rllib.agents.ppo import PPOTrainer

In [6]:
# One policy spec per policy id. Each spec is the 4-tuple
# (policy class, observation space, action space, policy config).
policies = dict(
    # Agent policy: None selects the default policy; {} is where a custom
    # agent policy configuration would go.
    a=(None, env_obj.observation_space, env_obj.action_space, {}),
    # Planner policy: None selects the default policy; {} is where a custom
    # planner policy configuration would go.
    p=(None, env_obj.observation_space_pl, env_obj.action_space_pl, {}),
)

# In foundation, every agent has an integer id, while the social planner's id is "p".
def policy_mapping_fun(i):
    """Return the policy id for agent id ``i``.

    Ids whose string form consists only of digits map to the agent policy
    "a"; every other id maps to the planner policy "p".
    """
    if str(i).isdigit():
        return "a"
    return "p"


policies_to_train = ["a", "p"]

In [7]:
# Multi-agent section of the trainer configuration: the policy specs, which of
# them are trained, and the function that maps an agent id to a policy id.
trainer_config = dict(
    multiagent=dict(
        policies=policies,
        policies_to_train=policies_to_train,
        policy_mapping_fn=policy_mapping_fun,
    )
)

In [8]:
# Add the rollout-worker layout and the batch/SGD settings to the trainer config.
trainer_config.update(
    num_workers=2,
    num_envs_per_worker=2,
    # Other training parameters
    train_batch_size=4000,
    sgd_minibatch_size=4000,
    num_sgd_iter=1,
)

In [9]:
# The env. wrapper is also given "num_envs_per_worker" so it can index the environments.
env_config = dict(
    env_config_dict=env_config_dict,
    # .get() keeps the value None (instead of raising) if the key was never set.
    num_envs_per_worker=trainer_config.get('num_envs_per_worker'),
)

# Hand the environment configuration to the trainer.
trainer_config["env_config"] = env_config

In [10]:
# Initialize Ray, with the web UI bound to localhost.
# ignore_reinit_error=True keeps this cell re-runnable: without it, executing
# the cell a second time in the same kernel raises a RuntimeError because Ray
# is already initialized.
ray.init(webui_host="127.0.0.1", ignore_reinit_error=True)

2024-05-03 10:43:43,848	INFO resource_spec.py:212 -- Starting Ray with 4.39 GiB memory available for workers and up to 2.21 GiB for objects. You can adjust these settings with ray.init(memory=<bytes>, object_store_memory=<bytes>).
2024-05-03 10:43:45,335	INFO services.py:1165 -- View the Ray dashboard at [1m[32m127.0.0.1:8265[39m[22m


{'node_ip_address': '192.168.1.41',
 'raylet_ip_address': '192.168.1.41',
 'redis_address': '192.168.1.41:6379',
 'object_store_address': 'tcp://127.0.0.1:64630',
 'raylet_socket_name': 'tcp://127.0.0.1:65499',
 'webui_url': '127.0.0.1:8265',
 'session_dir': 'C:\\Users\\caspe\\AppData\\Local\\Temp\\ray\\session_2024-05-03_10-43-43_841999_19056'}

In [11]:
# Create the PPO trainer. The environment is passed as a class (not the
# env_obj instance built earlier), together with the assembled trainer_config.
trainer = PPOTrainer(env=RLlibEnvWrapper, config=trainer_config)

2024-05-03 10:43:50,638	ERROR syncer.py:46 -- Log sync requires rsync to be installed.
2024-05-03 10:43:50,642	INFO trainer.py:585 -- Tip: set framework=tfe or the --eager flag to enable TensorFlow eager execution
2024-05-03 10:43:50,643	INFO trainer.py:612 -- Current log_level is WARN. For more information, set 'log_level': 'INFO' / 'DEBUG' or use the -v and -vv flags.


test std 0.0
test std 0.0
[2m[36m(pid=18224)[0m   _np_qint8 = np.dtype([("qint8", np.int8, 1)])
[2m[36m(pid=18224)[0m   _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
[2m[36m(pid=18224)[0m   _np_qint16 = np.dtype([("qint16", np.int16, 1)])
[2m[36m(pid=18224)[0m   _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
[2m[36m(pid=18224)[0m   _np_qint32 = np.dtype([("qint32", np.int32, 1)])
[2m[36m(pid=18224)[0m   np_resource = np.dtype([("resource", np.ubyte, 1)])
[2m[36m(pid=8980)[0m   _np_qint8 = np.dtype([("qint8", np.int8, 1)])
[2m[36m(pid=8980)[0m   _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
[2m[36m(pid=8980)[0m   _np_qint16 = np.dtype([("qint16", np.int16, 1)])
[2m[36m(pid=8980)[0m   _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
[2m[36m(pid=8980)[0m   _np_qint32 = np.dtype([("qint32", np.int32, 1)])
[2m[36m(pid=8980)[0m   np_resource = np.dtype([("resource", np.ubyte, 1)])
[2m[36m(pid=18224)[0m   _np_qint8 = np.dtype([("qint8", np.

[2m[36m(pid=18224)[0m Inside covid19_components.py: 0 GPUs are available.
[2m[36m(pid=18224)[0m No GPUs found! Running the simulation on a CPU.
[2m[36m(pid=18224)[0m Inside covid19_env.py: 0 GPUs are available.
[2m[36m(pid=18224)[0m No GPUs found! Running the simulation on a CPU.
[2m[36m(pid=8980)[0m Inside covid19_components.py: 0 GPUs are available.
[2m[36m(pid=8980)[0m No GPUs found! Running the simulation on a CPU.
[2m[36m(pid=8980)[0m Inside covid19_env.py: 0 GPUs are available.
[2m[36m(pid=8980)[0m No GPUs found! Running the simulation on a CPU.
[2m[36m(pid=18224)[0m test std 0.0
[2m[36m(pid=8980)[0m test std 0.0


2024-05-03 10:44:01,873	INFO trainable.py:181 -- _setup took 11.232 seconds. If your trainable is slow to initialize, consider setting reuse_actors=True to reduce actor creation overheads.


In [12]:
# Number of trainer.train() calls to run.
NUM_ITERS = 10

# Durations are measured with time.perf_counter(): it is monotonic, so the
# elapsed values cannot be distorted by system clock adjustments the way
# differences of time.time() can.
start_train = time.perf_counter()

for iteration in range(NUM_ITERS):
    print(f'********** Iter : {iteration} **********')
    start = time.perf_counter()
    result = trainer.train()
    length = time.perf_counter() - start
    print(f'''episode_reward_mean: {result.get('episode_reward_mean')}''')
    print(f'''it_time_taken: {length}''')

length_train = time.perf_counter() - start_train
# Single f-string so the message has exactly one space before "seconds."
print(f"Training took {length_train} seconds.")

********** Iter : 0 **********
[2m[36m(pid=18224)[0m test std 0.0
[2m[36m(pid=8980)[0m test std 0.0
[2m[36m(pid=18224)[0m test std 0.0
[2m[36m(pid=8980)[0m test std 0.0
[2m[36m(pid=18224)[0m test std 0.0
[2m[36m(pid=18224)[0m test std 151.8447784589663
[2m[36m(pid=18224)[0m test std 145.5866083822985
[2m[36m(pid=18224)[0m test std 154.360613080576
[2m[36m(pid=18224)[0m test std 141.8117319279604
[2m[36m(pid=18224)[0m test std 167.39035442841893
[2m[36m(pid=18224)[0m test std 174.62649380415607
[2m[36m(pid=18224)[0m test std 143.4136241463843
[2m[36m(pid=18224)[0m test std 191.1411001630064
[2m[36m(pid=18224)[0m test std 154.07726446050276
[2m[36m(pid=18224)[0m test std 172.7483467943524
[2m[36m(pid=8980)[0m test std 0.0
[2m[36m(pid=18224)[0m test std 190.76072262470072
[2m[36m(pid=18224)[0m test std 166.4613927955505
[2m[36m(pid=18224)[0m test std 212.9434494395952
[2m[36m(pid=18224)[0m test std 194.73543555720042
[2m[36m(pi

[2m[36m(pid=18224)[0m test std 67.12027639980671
[2m[36m(pid=18224)[0m test std 47.68559940899836
[2m[36m(pid=18224)[0m test std 67.12027639980671
[2m[36m(pid=18224)[0m test std 47.68559940899836
[2m[36m(pid=18224)[0m test std 67.12027639980671
[2m[36m(pid=18224)[0m test std 47.68559940899836
[2m[36m(pid=18224)[0m test std 67.12027639980671
[2m[36m(pid=18224)[0m test std 47.68559940899836
[2m[36m(pid=18224)[0m test std 67.12027639980671
[2m[36m(pid=18224)[0m test std 47.68559940899836
[2m[36m(pid=18224)[0m test std 67.12027639980671
[2m[36m(pid=18224)[0m test std 47.68559940899836
[2m[36m(pid=18224)[0m test std 67.12027639980671
[2m[36m(pid=18224)[0m test std 47.68559940899836
[2m[36m(pid=18224)[0m test std 67.12027639980671
[2m[36m(pid=8980)[0m test std 40.83972744622399
[2m[36m(pid=8980)[0m test std 83.3832444803744
[2m[36m(pid=8980)[0m test std 40.83972744622399
[2m[36m(pid=8980)[0m test std 83.3832444803744
[2m[36m(pid=8980

[2m[36m(pid=8980)[0m test std 40.83972744622399
[2m[36m(pid=8980)[0m test std 83.3832444803744
[2m[36m(pid=8980)[0m test std 40.83972744622399
[2m[36m(pid=8980)[0m test std 83.3832444803744
[2m[36m(pid=8980)[0m test std 40.83972744622399
[2m[36m(pid=8980)[0m test std 83.3832444803744
[2m[36m(pid=8980)[0m test std 40.83972744622399
[2m[36m(pid=8980)[0m test std 83.3832444803744
[2m[36m(pid=8980)[0m test std 40.83972744622399
[2m[36m(pid=8980)[0m test std 83.3832444803744
[2m[36m(pid=8980)[0m test std 40.83972744622399
[2m[36m(pid=8980)[0m test std 83.3832444803744
[2m[36m(pid=8980)[0m test std 40.83972744622399
[2m[36m(pid=8980)[0m test std 83.3832444803744
[2m[36m(pid=18224)[0m test std 47.68559940899836
[2m[36m(pid=8980)[0m test std 40.83972744622399
[2m[36m(pid=8980)[0m test std 83.3832444803744
[2m[36m(pid=8980)[0m test std 40.83972744622399
[2m[36m(pid=8980)[0m test std 83.3832444803744
[2m[36m(pid=8980)[0m test std 40.839

[2m[36m(pid=18224)[0m test std 86.06544378965005
[2m[36m(pid=18224)[0m test std 53.021711035796976
[2m[36m(pid=18224)[0m test std 86.06544378965005
[2m[36m(pid=18224)[0m test std 53.021711035796976
[2m[36m(pid=18224)[0m test std 86.06544378965005
[2m[36m(pid=18224)[0m test std 53.021711035796976
[2m[36m(pid=18224)[0m test std 86.06544378965005
[2m[36m(pid=18224)[0m test std 53.021711035796976
[2m[36m(pid=18224)[0m test std 86.06544378965005
[2m[36m(pid=18224)[0m test std 53.021711035796976
[2m[36m(pid=18224)[0m test std 86.06544378965005
[2m[36m(pid=18224)[0m test std 53.021711035796976
[2m[36m(pid=18224)[0m test std 86.06544378965005
[2m[36m(pid=18224)[0m test std 53.021711035796976
[2m[36m(pid=18224)[0m test std 86.06544378965005
[2m[36m(pid=8980)[0m test std 70.78490307013571
[2m[36m(pid=8980)[0m test std 54.452449254351606
[2m[36m(pid=8980)[0m test std 70.78490307013571
[2m[36m(pid=8980)[0m test std 54.452449254351606
[2m[3

[2m[36m(pid=18224)[0m test std 86.06544378965005
[2m[36m(pid=18224)[0m test std 53.021711035796976
[2m[36m(pid=18224)[0m test std 86.06544378965005
[2m[36m(pid=18224)[0m test std 53.021711035796976
[2m[36m(pid=18224)[0m test std 86.06544378965005
[2m[36m(pid=18224)[0m test std 53.021711035796976
[2m[36m(pid=8980)[0m test std 70.78490307013571
[2m[36m(pid=8980)[0m test std 54.452449254351606
[2m[36m(pid=8980)[0m test std 70.78490307013571
[2m[36m(pid=8980)[0m test std 54.452449254351606
[2m[36m(pid=8980)[0m test std 70.78490307013571
[2m[36m(pid=8980)[0m test std 54.452449254351606
[2m[36m(pid=8980)[0m test std 70.78490307013571
[2m[36m(pid=8980)[0m test std 54.452449254351606
[2m[36m(pid=8980)[0m test std 70.78490307013571
[2m[36m(pid=8980)[0m test std 54.452449254351606
[2m[36m(pid=8980)[0m test std 70.78490307013571
[2m[36m(pid=8980)[0m test std 54.452449254351606
[2m[36m(pid=18224)[0m test std 0.0
[2m[36m(pid=8980)[0m test 

[2m[36m(pid=18224)[0m test std 80.57940279044446
[2m[36m(pid=18224)[0m test std 69.4042117298251
[2m[36m(pid=8980)[0m test std 61.84098970691936
[2m[36m(pid=8980)[0m test std 63.10702991250365
[2m[36m(pid=18224)[0m test std 80.57940279044446
[2m[36m(pid=8980)[0m test std 61.84098970691936
[2m[36m(pid=18224)[0m test std 69.4042117298251
[2m[36m(pid=8980)[0m test std 63.10702991250365
[2m[36m(pid=18224)[0m test std 80.57940279044446
[2m[36m(pid=8980)[0m test std 61.84098970691936
[2m[36m(pid=8980)[0m test std 63.10702991250365
[2m[36m(pid=18224)[0m test std 69.4042117298251
[2m[36m(pid=18224)[0m test std 80.57940279044446
[2m[36m(pid=18224)[0m test std 69.4042117298251
[2m[36m(pid=18224)[0m test std 80.57940279044446
[2m[36m(pid=18224)[0m test std 69.4042117298251
[2m[36m(pid=18224)[0m test std 80.57940279044446
[2m[36m(pid=18224)[0m test std 69.4042117298251
[2m[36m(pid=18224)[0m test std 80.57940279044446
[2m[36m(pid=18224)[0m

[2m[36m(pid=18224)[0m test std 69.4042117298251
[2m[36m(pid=8980)[0m test std 63.10702991250365
[2m[36m(pid=18224)[0m test std 80.57940279044446
[2m[36m(pid=8980)[0m test std 61.84098970691936
[2m[36m(pid=8980)[0m test std 63.10702991250365
[2m[36m(pid=18224)[0m test std 69.4042117298251
[2m[36m(pid=18224)[0m test std 80.57940279044446
[2m[36m(pid=8980)[0m test std 61.84098970691936
[2m[36m(pid=8980)[0m test std 63.10702991250365
[2m[36m(pid=18224)[0m test std 69.4042117298251
[2m[36m(pid=18224)[0m test std 80.57940279044446
[2m[36m(pid=8980)[0m test std 61.84098970691936
[2m[36m(pid=8980)[0m test std 63.10702991250365
[2m[36m(pid=18224)[0m test std 69.4042117298251
[2m[36m(pid=18224)[0m test std 80.57940279044446
[2m[36m(pid=8980)[0m test std 61.84098970691936
[2m[36m(pid=18224)[0m test std 69.4042117298251
[2m[36m(pid=8980)[0m test std 63.10702991250365
[2m[36m(pid=18224)[0m test std 80.57940279044446
[2m[36m(pid=8980)[0m te

[2m[36m(pid=18224)[0m test std 53.5258360385045
[2m[36m(pid=8980)[0m test std 73.77823498130155
[2m[36m(pid=8980)[0m test std 78.0664949730819
[2m[36m(pid=18224)[0m test std 85.5897962078537
[2m[36m(pid=18224)[0m test std 53.5258360385045
[2m[36m(pid=18224)[0m test std 85.5897962078537
[2m[36m(pid=18224)[0m test std 53.5258360385045
[2m[36m(pid=18224)[0m test std 85.5897962078537
[2m[36m(pid=18224)[0m test std 53.5258360385045
[2m[36m(pid=18224)[0m test std 85.5897962078537
[2m[36m(pid=18224)[0m test std 53.5258360385045
[2m[36m(pid=18224)[0m test std 85.5897962078537
[2m[36m(pid=18224)[0m test std 53.5258360385045
[2m[36m(pid=18224)[0m test std 85.5897962078537
[2m[36m(pid=18224)[0m test std 53.5258360385045
[2m[36m(pid=18224)[0m test std 85.5897962078537
[2m[36m(pid=18224)[0m test std 53.5258360385045
[2m[36m(pid=8980)[0m test std 73.77823498130155
[2m[36m(pid=8980)[0m test std 78.0664949730819
[2m[36m(pid=8980)[0m test std 

[2m[36m(pid=18224)[0m test std 85.5897962078537
[2m[36m(pid=18224)[0m test std 53.5258360385045
[2m[36m(pid=18224)[0m test std 85.5897962078537
[2m[36m(pid=18224)[0m test std 53.5258360385045
[2m[36m(pid=18224)[0m test std 85.5897962078537
[2m[36m(pid=18224)[0m test std 53.5258360385045
[2m[36m(pid=18224)[0m test std 85.5897962078537
[2m[36m(pid=18224)[0m test std 53.5258360385045
[2m[36m(pid=18224)[0m test std 85.5897962078537
[2m[36m(pid=18224)[0m test std 53.5258360385045
[2m[36m(pid=18224)[0m test std 85.5897962078537
[2m[36m(pid=18224)[0m test std 53.5258360385045
[2m[36m(pid=18224)[0m test std 85.5897962078537
[2m[36m(pid=8980)[0m test std 73.77823498130155
[2m[36m(pid=8980)[0m test std 78.0664949730819
[2m[36m(pid=8980)[0m test std 73.77823498130155
[2m[36m(pid=8980)[0m test std 78.0664949730819
[2m[36m(pid=8980)[0m test std 73.77823498130155
[2m[36m(pid=8980)[0m test std 78.0664949730819
[2m[36m(pid=8980)[0m test std 7

[2m[36m(pid=18224)[0m test std 61.40360707714383
[2m[36m(pid=18224)[0m test std 44.155600989225015
[2m[36m(pid=18224)[0m test std 61.40360707714383
[2m[36m(pid=18224)[0m test std 44.155600989225015
[2m[36m(pid=18224)[0m test std 61.40360707714383
[2m[36m(pid=18224)[0m test std 44.155600989225015
[2m[36m(pid=18224)[0m test std 61.40360707714383
[2m[36m(pid=18224)[0m test std 44.155600989225015
[2m[36m(pid=18224)[0m test std 61.40360707714383
[2m[36m(pid=18224)[0m test std 44.155600989225015
[2m[36m(pid=18224)[0m test std 61.40360707714383
[2m[36m(pid=18224)[0m test std 44.155600989225015
[2m[36m(pid=18224)[0m test std 61.40360707714383
[2m[36m(pid=8980)[0m test std 65.48861733126026
[2m[36m(pid=8980)[0m test std 36.685408626707904
[2m[36m(pid=8980)[0m test std 65.48861733126026
[2m[36m(pid=8980)[0m test std 36.685408626707904
[2m[36m(pid=8980)[0m test std 65.48861733126026
[2m[36m(pid=8980)[0m test std 36.685408626707904
[2m[36m

[2m[36m(pid=18224)[0m test std 61.40360707714383
[2m[36m(pid=18224)[0m test std 44.155600989225015
[2m[36m(pid=8980)[0m test std 36.685408626707904
[2m[36m(pid=8980)[0m test std 0.0
[2m[36m(pid=18224)[0m test std 0.0
[2m[36m(pid=8980)[0m test std 0.0
[2m[36m(pid=8980)[0m test std 134.68576509479715
[2m[36m(pid=8980)[0m test std 187.96182095682158
[2m[36m(pid=8980)[0m test std 167.20469825317272
[2m[36m(pid=8980)[0m test std 197.6779668112387
[2m[36m(pid=8980)[0m test std 187.7814019217745
[2m[36m(pid=8980)[0m test std 163.99453882810772
[2m[36m(pid=8980)[0m test std 194.94367865087193
[2m[36m(pid=8980)[0m test std 146.5054001222411
[2m[36m(pid=8980)[0m test std 201.20615398472648
[2m[36m(pid=8980)[0m test std 139.9522326879267
[2m[36m(pid=18224)[0m test std 0.0
[2m[36m(pid=18224)[0m test std 162.46085932817368
[2m[36m(pid=18224)[0m test std 138.43472109990915
[2m[36m(pid=18224)[0m test std 178.55880573985198
[2m[36m(pid=8980

[2m[36m(pid=18224)[0m test std 82.47792077081306
[2m[36m(pid=18224)[0m test std 63.983322060221894
[2m[36m(pid=18224)[0m test std 82.47792077081306
[2m[36m(pid=18224)[0m test std 63.983322060221894
[2m[36m(pid=18224)[0m test std 82.47792077081306
[2m[36m(pid=18224)[0m test std 63.983322060221894
[2m[36m(pid=18224)[0m test std 82.47792077081306
[2m[36m(pid=18224)[0m test std 63.983322060221894
[2m[36m(pid=18224)[0m test std 82.47792077081306
[2m[36m(pid=18224)[0m test std 63.983322060221894
[2m[36m(pid=18224)[0m test std 82.47792077081306
[2m[36m(pid=18224)[0m test std 63.983322060221894
[2m[36m(pid=18224)[0m test std 82.47792077081306
[2m[36m(pid=18224)[0m test std 63.983322060221894
[2m[36m(pid=18224)[0m test std 82.47792077081306
[2m[36m(pid=18224)[0m test std 63.983322060221894
[2m[36m(pid=8980)[0m test std 49.9695962246169
[2m[36m(pid=8980)[0m test std 73.47162780759736
[2m[36m(pid=8980)[0m test std 49.9695962246169
[2m[36m

[2m[36m(pid=18224)[0m test std 63.983322060221894
[2m[36m(pid=18224)[0m test std 82.47792077081306
[2m[36m(pid=18224)[0m test std 63.983322060221894
[2m[36m(pid=18224)[0m test std 82.47792077081306
[2m[36m(pid=18224)[0m test std 63.983322060221894
[2m[36m(pid=18224)[0m test std 82.47792077081306
[2m[36m(pid=18224)[0m test std 63.983322060221894
[2m[36m(pid=18224)[0m test std 82.47792077081306
[2m[36m(pid=18224)[0m test std 63.983322060221894
[2m[36m(pid=18224)[0m test std 82.47792077081306
[2m[36m(pid=18224)[0m test std 63.983322060221894
[2m[36m(pid=18224)[0m test std 82.47792077081306
[2m[36m(pid=18224)[0m test std 63.983322060221894
[2m[36m(pid=18224)[0m test std 82.47792077081306
[2m[36m(pid=18224)[0m test std 63.983322060221894
[2m[36m(pid=18224)[0m test std 82.47792077081306
[2m[36m(pid=18224)[0m test std 63.983322060221894
[2m[36m(pid=8980)[0m test std 0.0
[2m[36m(pid=18224)[0m test std 0.0
[2m[36m(pid=18224)[0m test st

[2m[36m(pid=18224)[0m test std 74.65949959994336
[2m[36m(pid=18224)[0m test std 82.32880155997347
[2m[36m(pid=18224)[0m test std 74.65949959994336
[2m[36m(pid=18224)[0m test std 82.32880155997347
[2m[36m(pid=18224)[0m test std 74.65949959994336
[2m[36m(pid=18224)[0m test std 82.32880155997347
[2m[36m(pid=18224)[0m test std 74.65949959994336
[2m[36m(pid=18224)[0m test std 82.32880155997347
[2m[36m(pid=18224)[0m test std 74.65949959994336
[2m[36m(pid=18224)[0m test std 82.32880155997347
[2m[36m(pid=18224)[0m test std 74.65949959994336
[2m[36m(pid=18224)[0m test std 82.32880155997347
[2m[36m(pid=18224)[0m test std 74.65949959994336
[2m[36m(pid=8980)[0m test std 81.12957928389534
[2m[36m(pid=8980)[0m test std 78.15124409739623
[2m[36m(pid=8980)[0m test std 81.12957928389534
[2m[36m(pid=8980)[0m test std 78.15124409739623
[2m[36m(pid=8980)[0m test std 81.12957928389534
[2m[36m(pid=8980)[0m test std 78.15124409739623
[2m[36m(pid=8980

[2m[36m(pid=18224)[0m test std 0.0
[2m[36m(pid=8980)[0m test std 0.0
[2m[36m(pid=18224)[0m test std 0.0
[2m[36m(pid=18224)[0m test std 165.51065232500076
[2m[36m(pid=18224)[0m test std 125.65047597778575
[2m[36m(pid=18224)[0m test std 136.52080260148657
[2m[36m(pid=8980)[0m test std 0.0
[2m[36m(pid=8980)[0m test std 156.97502594890548
[2m[36m(pid=8980)[0m test std 183.11234990024792
[2m[36m(pid=8980)[0m test std 186.0078023696563
[2m[36m(pid=18224)[0m test std 132.87762323490102
[2m[36m(pid=8980)[0m test std 182.11881316240311
[2m[36m(pid=18224)[0m test std 180.105862157136
[2m[36m(pid=18224)[0m test std 137.90501817986524
[2m[36m(pid=18224)[0m test std 182.10794533197173
[2m[36m(pid=18224)[0m test std 147.01636758502158
[2m[36m(pid=18224)[0m test std 169.45704975212186
[2m[36m(pid=18224)[0m test std 157.87978307867357
[2m[36m(pid=18224)[0m test std 174.38784563974832
[2m[36m(pid=18224)[0m test std 142.69619096516436
[2m[36m

[2m[36m(pid=18224)[0m test std 63.976599158526525
[2m[36m(pid=18224)[0m test std 56.24844995749446
[2m[36m(pid=18224)[0m test std 63.976599158526525
[2m[36m(pid=18224)[0m test std 56.24844995749446
[2m[36m(pid=18224)[0m test std 63.976599158526525
[2m[36m(pid=18224)[0m test std 56.24844995749446
[2m[36m(pid=18224)[0m test std 63.976599158526525
[2m[36m(pid=18224)[0m test std 56.24844995749446
[2m[36m(pid=18224)[0m test std 63.976599158526525
[2m[36m(pid=18224)[0m test std 56.24844995749446
[2m[36m(pid=18224)[0m test std 63.976599158526525
[2m[36m(pid=18224)[0m test std 56.24844995749446
[2m[36m(pid=18224)[0m test std 63.976599158526525
[2m[36m(pid=18224)[0m test std 56.24844995749446
[2m[36m(pid=18224)[0m test std 63.976599158526525
[2m[36m(pid=8980)[0m test std 74.88519668536688
[2m[36m(pid=8980)[0m test std 76.59457036185893
[2m[36m(pid=8980)[0m test std 74.88519668536688
[2m[36m(pid=8980)[0m test std 76.59457036185893
[2m[36

[2m[36m(pid=18224)[0m test std 63.976599158526525
[2m[36m(pid=18224)[0m test std 56.24844995749446
[2m[36m(pid=18224)[0m test std 63.976599158526525
[2m[36m(pid=18224)[0m test std 56.24844995749446
[2m[36m(pid=18224)[0m test std 63.976599158526525
[2m[36m(pid=18224)[0m test std 56.24844995749446
[2m[36m(pid=18224)[0m test std 63.976599158526525
[2m[36m(pid=18224)[0m test std 56.24844995749446
[2m[36m(pid=8980)[0m test std 74.88519668536688
[2m[36m(pid=8980)[0m test std 76.59457036185893
[2m[36m(pid=8980)[0m test std 74.88519668536688
[2m[36m(pid=8980)[0m test std 76.59457036185893
[2m[36m(pid=8980)[0m test std 74.88519668536688
[2m[36m(pid=8980)[0m test std 76.59457036185893
[2m[36m(pid=8980)[0m test std 74.88519668536688
[2m[36m(pid=8980)[0m test std 76.59457036185893
[2m[36m(pid=8980)[0m test std 74.88519668536688
[2m[36m(pid=8980)[0m test std 76.59457036185893
[2m[36m(pid=18224)[0m test std 0.0
[2m[36m(pid=8980)[0m test std

[2m[36m(pid=8980)[0m test std 60.35555981985598
[2m[36m(pid=8980)[0m test std 80.5982751385532
[2m[36m(pid=8980)[0m test std 60.35555981985598
[2m[36m(pid=8980)[0m test std 80.5982751385532
[2m[36m(pid=18224)[0m test std 88.38609836548375
[2m[36m(pid=18224)[0m test std 55.37356568473801
[2m[36m(pid=8980)[0m test std 60.35555981985598
[2m[36m(pid=18224)[0m test std 88.38609836548375
[2m[36m(pid=8980)[0m test std 80.5982751385532
[2m[36m(pid=18224)[0m test std 55.37356568473801
[2m[36m(pid=18224)[0m test std 88.38609836548375
[2m[36m(pid=18224)[0m test std 55.37356568473801
[2m[36m(pid=18224)[0m test std 88.38609836548375
[2m[36m(pid=18224)[0m test std 55.37356568473801
[2m[36m(pid=18224)[0m test std 88.38609836548375
[2m[36m(pid=18224)[0m test std 55.37356568473801
[2m[36m(pid=18224)[0m test std 88.38609836548375
[2m[36m(pid=18224)[0m test std 55.37356568473801
[2m[36m(pid=18224)[0m test std 88.38609836548375
[2m[36m(pid=18224)

[2m[36m(pid=18224)[0m test std 55.37356568473801
[2m[36m(pid=18224)[0m test std 88.38609836548375
[2m[36m(pid=18224)[0m test std 55.37356568473801
[2m[36m(pid=18224)[0m test std 88.38609836548375
[2m[36m(pid=18224)[0m test std 55.37356568473801
[2m[36m(pid=18224)[0m test std 88.38609836548375
[2m[36m(pid=18224)[0m test std 55.37356568473801
[2m[36m(pid=18224)[0m test std 88.38609836548375
[2m[36m(pid=18224)[0m test std 55.37356568473801
[2m[36m(pid=18224)[0m test std 88.38609836548375
[2m[36m(pid=18224)[0m test std 55.37356568473801
[2m[36m(pid=18224)[0m test std 88.38609836548375
[2m[36m(pid=18224)[0m test std 55.37356568473801
[2m[36m(pid=8980)[0m test std 60.35555981985598
[2m[36m(pid=8980)[0m test std 80.5982751385532
[2m[36m(pid=8980)[0m test std 60.35555981985598
[2m[36m(pid=8980)[0m test std 80.5982751385532
[2m[36m(pid=8980)[0m test std 60.35555981985598
[2m[36m(pid=8980)[0m test std 80.5982751385532
[2m[36m(pid=8980)[

[2m[36m(pid=18224)[0m test std 42.2873465476372
[2m[36m(pid=18224)[0m test std 62.61882406549803
[2m[36m(pid=18224)[0m test std 42.2873465476372
[2m[36m(pid=18224)[0m test std 62.61882406549803
[2m[36m(pid=18224)[0m test std 42.2873465476372
[2m[36m(pid=18224)[0m test std 62.61882406549803
[2m[36m(pid=18224)[0m test std 42.2873465476372
[2m[36m(pid=18224)[0m test std 62.61882406549803
[2m[36m(pid=18224)[0m test std 42.2873465476372
[2m[36m(pid=18224)[0m test std 62.61882406549803
[2m[36m(pid=18224)[0m test std 42.2873465476372
[2m[36m(pid=18224)[0m test std 62.61882406549803
[2m[36m(pid=18224)[0m test std 42.2873465476372
[2m[36m(pid=18224)[0m test std 62.61882406549803
[2m[36m(pid=18224)[0m test std 42.2873465476372
[2m[36m(pid=18224)[0m test std 62.61882406549803
[2m[36m(pid=8980)[0m test std 45.92744967894664
[2m[36m(pid=8980)[0m test std 50.01757231849717
[2m[36m(pid=8980)[0m test std 45.92744967894664
[2m[36m(pid=8980)[0m

[2m[36m(pid=18224)[0m test std 42.2873465476372
[2m[36m(pid=18224)[0m test std 62.61882406549803
[2m[36m(pid=18224)[0m test std 42.2873465476372
[2m[36m(pid=18224)[0m test std 62.61882406549803
[2m[36m(pid=18224)[0m test std 42.2873465476372
[2m[36m(pid=18224)[0m test std 62.61882406549803
[2m[36m(pid=18224)[0m test std 42.2873465476372
[2m[36m(pid=18224)[0m test std 62.61882406549803
[2m[36m(pid=18224)[0m test std 42.2873465476372
[2m[36m(pid=18224)[0m test std 62.61882406549803
[2m[36m(pid=18224)[0m test std 42.2873465476372
[2m[36m(pid=18224)[0m test std 62.61882406549803
[2m[36m(pid=18224)[0m test std 42.2873465476372
[2m[36m(pid=18224)[0m test std 62.61882406549803
[2m[36m(pid=18224)[0m test std 42.2873465476372
[2m[36m(pid=18224)[0m test std 62.61882406549803
[2m[36m(pid=8980)[0m test std 45.92744967894664
[2m[36m(pid=8980)[0m test std 50.01757231849717
[2m[36m(pid=18224)[0m test std 42.2873465476372
[2m[36m(pid=18224)[0

[2m[36m(pid=18224)[0m test std 85.22697084795576
[2m[36m(pid=18224)[0m test std 30.778899645054032
[2m[36m(pid=18224)[0m test std 85.22697084795576
[2m[36m(pid=18224)[0m test std 30.778899645054032
[2m[36m(pid=18224)[0m test std 85.22697084795576
[2m[36m(pid=18224)[0m test std 30.778899645054032
[2m[36m(pid=18224)[0m test std 85.22697084795576
[2m[36m(pid=18224)[0m test std 30.778899645054032
[2m[36m(pid=18224)[0m test std 85.22697084795576
[2m[36m(pid=18224)[0m test std 30.778899645054032
[2m[36m(pid=18224)[0m test std 85.22697084795576
[2m[36m(pid=18224)[0m test std 30.778899645054032
[2m[36m(pid=18224)[0m test std 85.22697084795576
[2m[36m(pid=18224)[0m test std 30.778899645054032
[2m[36m(pid=8980)[0m test std 105.58755834748762
[2m[36m(pid=8980)[0m test std 112.86278588844863
[2m[36m(pid=8980)[0m test std 105.58755834748762
[2m[36m(pid=8980)[0m test std 112.86278588844863
[2m[36m(pid=8980)[0m test std 105.58755834748762
[2m

[2m[36m(pid=18224)[0m test std 85.22697084795576
[2m[36m(pid=18224)[0m test std 30.778899645054032
[2m[36m(pid=18224)[0m test std 85.22697084795576
[2m[36m(pid=18224)[0m test std 30.778899645054032
[2m[36m(pid=18224)[0m test std 85.22697084795576
[2m[36m(pid=18224)[0m test std 30.778899645054032
[2m[36m(pid=18224)[0m test std 85.22697084795576
[2m[36m(pid=18224)[0m test std 30.778899645054032
[2m[36m(pid=18224)[0m test std 85.22697084795576
[2m[36m(pid=18224)[0m test std 30.778899645054032
[2m[36m(pid=18224)[0m test std 85.22697084795576
[2m[36m(pid=18224)[0m test std 30.778899645054032
[2m[36m(pid=8980)[0m test std 105.58755834748762
[2m[36m(pid=8980)[0m test std 112.86278588844863
[2m[36m(pid=8980)[0m test std 105.58755834748762
[2m[36m(pid=8980)[0m test std 112.86278588844863
[2m[36m(pid=8980)[0m test std 105.58755834748762
[2m[36m(pid=8980)[0m test std 112.86278588844863
[2m[36m(pid=8980)[0m test std 105.58755834748762
[2m[

[2m[36m(pid=18224)[0m test std 83.88182514097589
[2m[36m(pid=18224)[0m test std 72.33877206092038
[2m[36m(pid=18224)[0m test std 83.88182514097589
[2m[36m(pid=18224)[0m test std 72.33877206092038
[2m[36m(pid=18224)[0m test std 83.88182514097589
[2m[36m(pid=18224)[0m test std 72.33877206092038
[2m[36m(pid=18224)[0m test std 83.88182514097589
[2m[36m(pid=18224)[0m test std 72.33877206092038
[2m[36m(pid=18224)[0m test std 83.88182514097589
[2m[36m(pid=18224)[0m test std 72.33877206092038
[2m[36m(pid=18224)[0m test std 83.88182514097589
[2m[36m(pid=18224)[0m test std 72.33877206092038
[2m[36m(pid=18224)[0m test std 83.88182514097589
[2m[36m(pid=18224)[0m test std 72.33877206092038
[2m[36m(pid=18224)[0m test std 83.88182514097589
[2m[36m(pid=18224)[0m test std 72.33877206092038
[2m[36m(pid=18224)[0m test std 83.88182514097589
[2m[36m(pid=18224)[0m test std 72.33877206092038
[2m[36m(pid=18224)[0m test std 83.88182514097589
[2m[36m(pi

[2m[36m(pid=18224)[0m test std 0.0
[2m[36m(pid=8980)[0m test std 0.0
[2m[36m(pid=18224)[0m test std 0.0
[2m[36m(pid=8980)[0m test std 0.0


KeyboardInterrupt: 

In [None]:
#env_obj.env.get_agent(3).state

In [None]:
def generate_rollout_from_current_trainer_policy(
    trainer, 
    env_obj,
    num_dense_logs=5
):
    """Play episodes with the trainer's current policies and collect dense logs.

    For each of ``num_dense_logs`` episodes the environment is reset with
    ``force_dense_logging=True`` and stepped until either
    ``env_obj.env.episode_length`` steps have run or the environment reports
    ``done['__all__']``. Agent actions come from policy ``"a"`` and the
    planner action from policy ``"p"``.

    Note: the initial policy states are passed unchanged on every step; any
    state returned by the policies is not fed back into later calls.

    Args:
        trainer: Object exposing ``get_policy(policy_id)`` and
            ``compute_action(obs, state, policy_id=..., full_fetch=...)``.
        env_obj: Environment wrapper exposing ``reset``, ``step`` and an
            ``env`` attribute with ``n_agents``, ``episode_length`` and
            ``dense_log``.
        num_dense_logs: Number of episodes to roll out.

    Returns:
        dict mapping the episode index (0-based) to the value of
        ``env_obj.env.dense_log`` read after that episode finished.
    """
    collected_logs = {}
    for episode_idx in range(num_dense_logs):
        # Fresh initial states: one per agent (policy "a"), one for the planner ("p").
        agent_states = {
            str(agent_idx): trainer.get_policy("a").get_initial_state()
            for agent_idx in range(env_obj.env.n_agents)
        }
        planner_states = trainer.get_policy("p").get_initial_state()

        # Roll the episode forward one timestep at a time.
        obs = env_obj.reset(force_dense_logging=True)
        for _ in range(env_obj.env.episode_length):
            # Sample one action per agent directly from the trainer.
            actions = {
                agent_id: trainer.compute_action(
                    obs[agent_id],
                    agent_states[agent_id],
                    policy_id="a",
                    full_fetch=False,
                )
                for agent_id in map(str, range(env_obj.env.n_agents))
            }

            # The planner acts through its own policy.
            actions["p"] = trainer.compute_action(
                obs['p'],
                planner_states,
                policy_id='p',
                full_fetch=False,
            )

            obs, _rew, done, _info = env_obj.step(actions)
            if done['__all__']:
                break

        collected_logs[episode_idx] = env_obj.env.dense_log
    return collected_logs

In [None]:
# Roll out a single episode with the current policies and keep its dense log
# (stored under key 0 of the returned dict).
dense_logs = generate_rollout_from_current_trainer_policy(trainer, env_obj, num_dense_logs=1)

In [None]:
#dense_logs[0]['states'][0]

In [None]:
#planner_gr_score_importances = [log["p"]["GreenScoreImportance"] for log in dense_logs[0]['states']]

In [None]:
# Per-timestep "TotalBalance" of agents "0".."3", read from the endogenous
# state recorded in the first dense log. One list per agent, in agent order.
(
    agent_0_green_scores,
    agent_1_green_scores,
    agent_2_green_scores,
    agent_3_green_scores,
) = (
    [snapshot[agent_id]["endogenous"]["TotalBalance"] for snapshot in dense_logs[0]['states']]
    for agent_id in ("0", "1", "2", "3")
)

# Per-timestep "StockPrice", taken from agent "0"'s endogenous state.
stock_prices = [
    snapshot["0"]["endogenous"]["StockPrice"]
    for snapshot in dense_logs[0]['states']
]

In [None]:
# Plot each agent's total balance at every logged timestep of the rollout.
# The x-axis is sized from the data instead of a hard-coded 101 points, so the
# cell keeps working when the episode length changes or a rollout ends early.
for agent_label, balances in (
    ("Agent 0", agent_0_green_scores),
    ("Agent 1", agent_1_green_scores),
    ("Agent 2", agent_2_green_scores),
    ("Agent 3", agent_3_green_scores),
):
    plt.plot(np.arange(len(balances)), balances, label=agent_label)
plt.title('Stock Broker Total Balance')
plt.legend()
plt.xlabel('Days')
plt.ylabel('Total Balance')
#plt.savefig("miners_green_scores.png")
plt.show()

In [None]:
# Plot the stock price at every logged timestep of the rollout.
# The x-axis length follows the data (no hard-coded 101), and the legend entry
# names the series that is actually drawn rather than "Agent 0".
plt.plot(np.arange(len(stock_prices)), stock_prices, label="Stock Price")
plt.title('Stock Price Over Time')
plt.legend()
plt.xlabel('Days')
plt.ylabel('Stock Price')
#plt.savefig("miners_green_scores.png")
plt.show()

In [None]:
# Per-timestep "StocksLeft" value, read from agent "0"'s endogenous state in
# the first dense log.
stocks_left = [log["0"]["endogenous"]["StocksLeft"] for log in dense_logs[0]['states']]

# The x-axis length follows the data (no hard-coded 101), and the y-axis is
# labelled as a quantity rather than a price.
plt.plot(np.arange(len(stocks_left)), stocks_left, label="Stocks Quantity Left")
plt.title('Stock Quantity Over Time')
plt.legend()
plt.xlabel('Days')
plt.ylabel('Stocks Left')
#plt.savefig("miners_green_scores.png")
plt.show()

In [None]:
# Per-timestep "NumberOfStocks" value, read from agent "0"'s endogenous state
# in the first dense log. Stored under its own name so the `stocks_left`
# series from the previous cell is not silently overwritten.
# NOTE(review): this cell was previously titled "Circuit Breaker Activation"
# with a "Stock Price" y-axis, although it plots "NumberOfStocks". The labels
# now describe the plotted key; confirm that this is the intended series.
number_of_stocks = [log["0"]["endogenous"]["NumberOfStocks"] for log in dense_logs[0]['states']]

# The x-axis length follows the data instead of a hard-coded 101 points.
plt.plot(np.arange(len(number_of_stocks)), number_of_stocks, label="Number Of Stocks (Agent 0 state)")
plt.title('Number Of Stocks Over Time')
plt.legend()
plt.xlabel('Days')
plt.ylabel('Number Of Stocks')
#plt.savefig("miners_green_scores.png")
plt.show()

In [None]:
# Shut down Ray once the notebook no longer needs it.
# NOTE(review): `ray` is not imported in the notebook's visible import cell;
# confirm it is imported in an earlier cell so this works on a fresh kernel.
ray.shutdown()