In [1]:
import os
# Set before numexpr is imported (below) so numexpr caps its thread pool at 1.
os.environ['NUMEXPR_MAX_THREADS'] = '1'

import logging
import numexpr as ne
import numpy as np
import torch
# NOTE(review): datetime is not referenced by any visible cell.
import datetime
from ddopai.envs.pricing.dynamic import DynamicPricingEnv
from ddopai.envs.pricing.dynamic_RL2 import RL2DynamicPricingEnv
from ddopai.envs.pricing.dynamic_lag import LagDynamicPricingEnv
from ddopai.envs.pricing.dynamic_inventory import DynamicPricingInvEnv
from ddopai.envs.actionprocessors import ClipAction, RoundAction
from ddopai.agents.obsprocessors import ConvertDictSpace

from ddopai.experiments.experiment_functions_online import run_experiment 
# NOTE(review): star import; presumably this supplies prep_experiment, get_online_data,
# prepare_env_online, set_up_agent, set_up_earlystoppinghandler and set_warnings used
# below. Explicit imports would make their origin visible.
from ddopai.experiments.meta_experiment_functions import *
# NOTE(review): requests, yaml, re, pandas, warnings, gc and Core are not referenced by
# any visible cell; pickle is only used in a commented-out cell.
import requests
import yaml
import re
import pandas as pd
import wandb
from copy import deepcopy
import warnings
import gc
#from mushroom_rl import core 
from ddopai.experiments.meta_core import Core
import pickle

In [2]:
logging_level = logging.INFO
logging.basicConfig(level=logging_level)

# Pin numexpr and torch to a single thread and disable cuDNN, in line with the
# NUMEXPR_MAX_THREADS=1 environment variable set before the imports.
ne.set_num_threads(1)
torch.backends.cudnn.enabled = False
torch.set_num_threads(1)

set_warnings(logging.INFO) # turn off warnings for any level higher or equal to the input level
# Libraries whose versions prep_experiment logs (see its INFO output).
LIBRARIES_TO_TRACK = ["ddopai", "mushroom_rl"]
# W&B project name passed to prep_experiment.
PROJECT_NAME = "pricing_cMDP_test"

# NOTE(review): ENVCLASS is not used by any active cell; environments are built
# through get_ENVCLASS instead.
ENVCLASS = DynamicPricingEnv
# Output directory handed to run_experiment (runs land in e.g. results/<run id>).
RESULTS_DIR = "results"
def get_ENVCLASS(class_name):
    """Resolve an environment class from the name stored in the env config.

    Args:
        class_name: One of "DynamicPricingEnv", "DynamicPricingInvEnv",
            "LagDynamicPricingEnv" or "RL2DynamicPricingEnv".

    Returns:
        The matching environment class itself (not an instance).

    Raises:
        ValueError: If ``class_name`` is none of the supported names.
    """
    if class_name == "DynamicPricingEnv":
        return DynamicPricingEnv
    if class_name == "DynamicPricingInvEnv":
        return DynamicPricingInvEnv
    if class_name == "LagDynamicPricingEnv":
        return LagDynamicPricingEnv
    if class_name == "RL2DynamicPricingEnv":
        return RL2DynamicPricingEnv
    # Fell through every known name.
    raise ValueError(f"Unknown class name {class_name}")

# Experiment preparations
## Set-up WandB
### Init WandB

In [3]:
# Kept identical to PROJECT_NAME (the W&B project actually passed to prep_experiment)
# so the two names can never point at different projects.
project_name = PROJECT_NAME


### Track library versions and git hash of experiment

# Experiment parameters

In [4]:
# Load the train/agent/env YAML configs and the agent class/name. Judging by its
# output, prep_experiment also logs the tracked library versions and the git hash.
(
    config_train,
    config_agent,
    config_env,
    AgentClass,
    agent_name,
) = prep_experiment(
    PROJECT_NAME,
    LIBRARIES_TO_TRACK,
    config_train_name="config_train.yaml",
    config_agent_name="config_agent.yaml",
    config_env_name="config_env.yaml",
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mtimlachner[0m. Use [1m`wandb login --relogin`[0m to force relogin


INFO:root:ddopai: 0.0.7
INFO:root:mushroom_rl: 1.10.1
INFO:root:Git hash: 0240bca214a0886c93a097164834526abebcdb70
INFO:root:Configuration file 'config_train.yaml' successfully loaded.
INFO:root:Configuration file 'config_agent.yaml' successfully loaded.
INFO:root:Configuration file 'config_env.yaml' successfully loaded.


In [5]:
# Optional override: when config_env["lag_window_params"]["lag_window"] is set, push it
# into every environment's kwargs and switch those environments to
# LagDynamicPricingEnv. A missing "lag_window_params" section is treated as "no
# override" instead of raising KeyError.
lag_window = (config_env.get("lag_window_params") or {}).get("lag_window")
if lag_window is not None:
    for env_kwargs in config_env["env_kwargs"]:
        env_kwargs["lag_window"] = lag_window
        env_kwargs["env_class"] = "LagDynamicPricingEnv"


In [6]:
# Disabled alternative: load raw_data from the 'raw_data:latest' W&B artifact
# instead of calling get_online_data (next cell).
# NOTE(review): if re-enabled, open the file in a `with` block so it gets closed,
# and only unpickle artifacts you trust (pickle can execute arbitrary code).
#artifact = wandb.use_artifact('raw_data:latest')
#path = artifact.download()
#raw_data = pickle.load(open(os.path.join(path, 'raw_data.pkl'), 'rb'))

In [7]:
# Load the online dataset. By their names, the other two return values are the
# start indices of the validation and test ranges (TODO confirm in get_online_data).
raw_data, val_index_start, test_index_start = get_online_data(config_env, overwrite=False)

In [8]:
# Peek at one raw feature row. In the recorded output it equals the 'features' vector of
# the first environment observation.
# NOTE(review): the meaning of the three index levels is not visible here; confirm in get_online_data.
raw_data[0][0][0]

array([1.        , 0.17042525, 0.02231759, 0.37522929, 0.41223724,
       0.01571206, 0.48614612, 0.42916275, 0.19386668, 0.10603631])

## Environment parameters

* Get the environment parameters from the config file.
* Overwrite the `lag_window` parameter with the value from `config_env['lag_window_params']`, if one is specified (the lag window is provided by the environment, but it is a tunable hyperparameter of the agent).

In [None]:
# Action post-processors handed to the environments: round actions to the
# configured unit size.
round_action = RoundAction(unit_size=config_env["unit_size"])
# TODO: also clip actions to the environment's bounds with ClipAction and append it
# to `postprocessors`. The line that used to be here indexed setup_kwargs with an
# empty subscript (a SyntaxError that stopped this whole cell from running), and
# its result was never used. Restore it once the correct setup_kwargs key is known.
postprocessors = [round_action]

# prepare_env_online returns a sequence of environments (indexed/sliced below).
# NOTE(review): the val/test start indices returned by get_online_data are not used
# here (0 is passed for both); confirm this is intended.
environments = prepare_env_online(
    get_ENVCLASS=get_ENVCLASS,
    raw_data=raw_data,
    val_index_start=0,
    test_index_start=0,
    config_env=config_env,
    postprocessors=postprocessors,
)

In [10]:
# Raw inventory of the first environment. In the recorded outputs, the observation's
# 'inventory' entry appears to be this value divided by the starting inventory.
environments[0].inv

array([2000.])

In [11]:
# Dict observation space: a 'features' Box plus one-element inventory, prev_action,
# prev_done and prev_reward entries (see output for the bounds).
environments[0].observation_space

Dict('features': Box([1.0000000e+00 1.4287769e-04 3.8245969e-04 9.3830319e-04 2.1704179e-04
 2.1403281e-03 8.0802059e-04 1.9769531e-03 2.2935087e-03 2.8973455e-03], [1.         0.49980804 0.49928105 0.49585605 0.49503562 0.49853876
 0.49922723 0.49936512 0.4979667  0.4993368 ], (10,), float32), 'inventory': Box(0.0, 1.0, (1,), float32), 'prev_action': Box(0.0, 20.0, (1,), float32), 'prev_done': Box(0.0, 1.0, (1,), float32), 'prev_reward': Box(-inf, inf, (1,), float32))

In [12]:
# get_observation() returns a pair: keep the observation dict, discard the second item.
# NOTE: assigning `_` shadows IPython's last-output variable for the rest of the session.
obs, _ = environments[0].get_observation()
obs

{'features': array([1.        , 0.17042525, 0.02231759, 0.37522929, 0.41223724,
        0.01571206, 0.48614612, 0.42916275, 0.19386668, 0.10603631]),
 'inventory': array([[1.]], dtype=float32),
 'prev_action': array([0.], dtype=float32),
 'prev_reward': array([0.], dtype=float32),
 'prev_done': array([1.], dtype=float32)}

In [13]:
# Flatten the dict observation with ConvertDictSpace, the processor the agent cell
# gives to SAC/PPORNN/RL2PPO. In the output, the feature values are followed by the
# inventory, prev_action, prev_reward and prev_done values.
conv = ConvertDictSpace(keep_time_dim=False)
X = conv(environments[0].get_observation()[0])
X

array([1.        , 0.17042525, 0.02231759, 0.37522929, 0.41223724,
       0.01571206, 0.48614612, 0.42916275, 0.19386668, 0.10603631,
       1.        , 0.        , 0.        , 1.        ])

## Agent parameters

In [14]:
logging.info(f"Agent: {agent_name}")

# SAC, PPORNN and RL2PPO receive the dict observation flattened into one vector by
# ConvertDictSpace; every other agent gets no observation processors.
obsprocessors = (
    [ConvertDictSpace(keep_time_dim=False)]
    if agent_name in ("SAC", "PPORNN", "RL2PPO")
    else []
)

# Online training only supports agents that learn through environment interaction.
if AgentClass.train_mode != "env_interaction":
    raise ValueError("Invalid train_mode for online training")

# A "link" entry is resolved by set_up_agent into a GLM link (stored as "g") and a
# price function. "link" is then removed so it is not forwarded to the constructor.
# NOTE: this mutates config_agent in place.
if "link" in config_agent:
    glm_link, price_function = set_up_agent(AgentClass, environments[0], config_agent)
    config_agent["g"] = glm_link
    config_agent["price_function"] = price_function
    del config_agent["link"]

if agent_name == "Clairvoyant":
    # The clairvoyant agent is handed the environment's alpha/beta directly and
    # receives no observation processors.
    agent = AgentClass(
        alpha=environments[0].alpha,
        beta=environments[0].beta,
        environment_info=environments[0].mdp_info,
        **config_agent
    )
else:
    agent = AgentClass(
        environment_info=environments[0].mdp_info,
        obsprocessors=obsprocessors,
        **config_agent
    )

INFO:root:Agent: RL2PPO
INFO:root:Actor (RL²) network:


Layer (type:depth-idx)                   Output Shape              Param #
RL2RNNActor                              [1, 1, 1]                 --
├─RL2RNN: 1-1                            [1, 1, 1]                 --
│    └─SpecificRNNWrapperHS: 2-1         [1, 1, 64]                --
│    │    └─GRU: 3-1                     [1, 1, 64]                15,360
│    └─Sequential: 2-2                   [1, 1]                    --
│    │    └─Linear: 3-2                  [1, 64]                   4,160
│    │    └─ReLU: 3-3                    [1, 64]                   --
│    │    └─Dropout: 3-4                 [1, 64]                   --
│    │    └─Linear: 3-5                  [1, 32]                   2,080
│    │    └─ReLU: 3-6                    [1, 32]                   --
│    │    └─Dropout: 3-7                 [1, 32]                   --
│    │    └─Linear: 3-8                  [1, 1]                    33
Total params: 21,633
Trainable params: 21,633
Non-trainable params: 0
Total

INFO:root:Critic (RL²) network:


Layer (type:depth-idx)                   Output Shape              Param #
RL2RNNValue                              [1, 1, 1]                 --
├─RL2RNN: 1-1                            [1, 1, 1]                 --
│    └─SpecificRNNWrapperHS: 2-1         [1, 1, 64]                --
│    │    └─GRU: 3-1                     [1, 1, 64]                15,360
│    └─Sequential: 2-2                   [1, 1]                    --
│    │    └─Linear: 3-2                  [1, 64]                   4,160
│    │    └─ReLU: 3-3                    [1, 64]                   --
│    │    └─Dropout: 3-4                 [1, 64]                   --
│    │    └─Linear: 3-5                  [1, 32]                   2,080
│    │    └─ReLU: 3-6                    [1, 32]                   --
│    │    └─Dropout: 3-7                 [1, 32]                   --
│    │    └─Linear: 3-8                  [1, 1]                    33
Total params: 21,633
Trainable params: 21,633
Non-trainable params: 0
Total

In [15]:
# Build the early-stopping handler from config_train; it is passed to run_experiment.
earlystoppinghandler = set_up_earlystoppinghandler(config_train)

In [16]:
# Smoke test: take one step at a fixed action (price 5) on a throwaway copy of the
# first environment. Stepping a copy means the probe does not consume the inventory,
# internal state or random demand noise of the environment that run_experiment
# trains on below.
deepcopy(environments[0]).step(np.array([5]))

({'features': array([1.        , 0.00350445, 0.09968054, 0.08908962, 0.00783165,
         0.42725708, 0.24546228, 0.4031461 , 0.03825726, 0.02094546]),
  'inventory': array([[0.99793303]], dtype=float32),
  'prev_action': array([[5.]], dtype=float32),
  'prev_reward': array([[20.669678]], dtype=float32),
  'prev_done': array([0.], dtype=float32)},
 array([20.66967708]),
 array([False]),
 {'inv': array([1995.86606026]),
  'demand': array([4.13393542]),
  'demand_noise_free': array([4.96748428]),
  'action': array([5.]),
  'reward': array([20.66967708])})

In [17]:
# Train the agent online on the first environment only, tracking to the active W&B
# run and writing results under RESULTS_DIR. Fitting is episode-based
# (n_steps_per_fit=None, n_episodes_per_fit=1).
dataset = run_experiment(
    agent,
    environments[:1],
    n_epochs=config_train["n_epochs"],
    n_steps=config_train["n_steps"],
    n_steps_per_fit=None,
    n_episodes_per_fit=1,
    early_stopping_handler=earlystoppinghandler,
    save_best=config_train["save_best"],
    run_id=wandb.run.id,
    tracking="wandb",
    eval_step_info=False,
    print_freq=1,
    results_dir=RESULTS_DIR,
    return_dataset=True,
    return_score=False,
)

INFO:root:Starting experiment


Experiment directory: results/5kdasfgd


  0%|          | 0/400 [00:00<?, ?it/s]INFO:root:Epoch 1: R=[-10.06929935], J=[-10.06929935]
  0%|          | 1/400 [00:00<00:48,  8.15it/s]

advs (204, 1)
vpreds (204, 1)
tdlam_rets (204, 1)


INFO:root:Epoch 2: R=[13.93980706], J=[13.93980706]
  0%|          | 2/400 [00:00<00:46,  8.56it/s]

advs (223, 1)
vpreds (223, 1)
tdlam_rets (223, 1)


INFO:root:Epoch 3: R=[0.44701754], J=[0.44701754]
  1%|          | 3/400 [00:00<00:44,  8.88it/s]

advs (211, 1)
vpreds (211, 1)
tdlam_rets (211, 1)


INFO:root:Epoch 4: R=[-9.52421152], J=[-9.52421152]
  1%|          | 4/400 [00:00<00:43,  9.10it/s]

advs (210, 1)
vpreds (210, 1)
tdlam_rets (210, 1)


INFO:root:Epoch 5: R=[-5.64145389], J=[-5.64145389]
  1%|▏         | 5/400 [00:00<00:42,  9.25it/s]

advs (205, 1)
vpreds (205, 1)
tdlam_rets (205, 1)


INFO:root:Epoch 6: R=[-5.74289771], J=[-5.74289771]


advs (202, 1)
vpreds (202, 1)
tdlam_rets (202, 1)


  2%|▏         | 6/400 [00:00<00:51,  7.72it/s]INFO:root:Epoch 7: R=[7.08539874], J=[7.08539874]
  2%|▏         | 7/400 [00:00<00:47,  8.25it/s]

advs (199, 1)
vpreds (199, 1)
tdlam_rets (199, 1)


INFO:root:Epoch 8: R=[-0.27054091], J=[-0.27054091]
  2%|▏         | 8/400 [00:00<00:44,  8.73it/s]

advs (195, 1)
vpreds (195, 1)
tdlam_rets (195, 1)


INFO:root:Epoch 9: R=[-12.07181465], J=[-12.07181465]
  2%|▏         | 9/400 [00:01<00:43,  9.04it/s]

advs (202, 1)
vpreds (202, 1)
tdlam_rets (202, 1)


INFO:root:Epoch 10: R=[-11.95891311], J=[-11.95891311]
  2%|▎         | 10/400 [00:01<00:45,  8.50it/s]

advs (196, 1)
vpreds (196, 1)
tdlam_rets (196, 1)


INFO:root:Epoch 11: R=[-11.62137369], J=[-11.62137369]


advs (192, 1)
vpreds (192, 1)
tdlam_rets (192, 1)




advs (194, 1)
vpreds (194, 1)
tdlam_rets (194, 1)


INFO:root:Epoch 12: R=[-12.75330205], J=[-12.75330205]
  3%|▎         | 12/400 [00:01<00:42,  9.09it/s]INFO:root:Epoch 13: R=[-2.59301796], J=[-2.59301796]
  3%|▎         | 13/400 [00:01<00:41,  9.29it/s]

advs (194, 1)
vpreds (194, 1)
tdlam_rets (194, 1)




advs

INFO:root:Epoch 14: R=[-30.49812982], J=[-30.49812982]


 (178, 1)
vpreds (178, 1)
tdlam_rets (178, 1)




advs (176, 1)
vpreds (176, 1)
tdlam_rets (176, 1)


INFO:root:Epoch 15: R=[2.01608156], J=[2.01608156]
  4%|▍         | 15/400 [00:01<00:39,  9.66it/s]

advs (168, 1)
vpreds (168, 1)
tdlam_rets (168, 1)


INFO:root:Epoch 16: R=[13.78403753], J=[13.78403753]
INFO:root:Epoch 17: R=[7.15829537], J=[7.15829537]
  4%|▍         | 17/400 [00:01<00:40,  9.52it/s]

advs (248, 1)
vpreds (248, 1)
tdlam_rets (248, 1)


INFO:root:Epoch 18: R=[-9.52421152], J=[-9.52421152]
  4%|▍         | 18/400 [00:02<00:43,  8.83it/s]

advs (282, 1)
vpreds (282, 1)
tdlam_rets (282, 1)


INFO:root:Epoch 19: R=[-14.72214884], J=[-14.72214884]


advs (142, 1)
vpreds (142, 1)
tdlam_rets (142, 1)


INFO:root:Epoch 20: R=[7.30350603], J=[7.30350603]
  5%|▌         | 20/400 [00:02<00:43,  8.64it/s]

advs (259, 1)
vpreds (259, 1)
tdlam_rets (259, 1)


INFO:root:Epoch 21: R=[-29.51478565], J=[-29.51478565]


advs (135, 1)
vpreds (135, 1)
tdlam_rets (135, 1)


INFO:root:Epoch 22: R=[-21.22055663], J=[-21.22055663]
  6%|▌         | 22/400 [00:02<00:38,  9.91it/s]

advs (136, 1)
vpreds (136, 1)
tdlam_rets (136, 1)


INFO:root:Epoch 23: R=[-18.33349435], J=[-18.33349435]


advs (132, 1)
vpreds (132, 1)
tdlam_rets (132, 1)


INFO:root:Epoch 24: R=[-3.36184483], J=[-3.36184483]
  6%|▌         | 24/400 [00:02<00:34, 11.01it/s]

advs (129, 1)
vpreds (129, 1)
tdlam_rets (129, 1)


INFO:root:Epoch 25: R=[-27.71382086], J=[-27.71382086]


advs (126, 1)
vpreds (126, 1)
tdlam_rets (126, 1)


INFO:root:Epoch 26: R=[-15.19356324], J=[-15.19356324]
  6%|▋         | 26/400 [00:02<00:32, 11.45it/s]

advs (123, 1)
vpreds (123, 1)
tdlam_rets (123, 1)


INFO:root:Epoch 27: R=[-19.57611591], J=[-19.57611591]


advs (121, 1)
vpreds (121, 1)
tdlam_rets (121, 1)


INFO:root:Epoch 28: R=[-43.51377119], J=[-43.51377119]
  7%|▋         | 28/400 [00:02<00:29, 12.43it/s]

advs (118, 1)
vpreds (118, 1)
tdlam_rets (118, 1)




advs (115, 1)
vpreds (115, 1)
tdlam_rets (115, 1)


INFO:root:Epoch 29: R=[-25.94567563], J=[-25.94567563]


advs (111, 1)
vpreds (111, 1)
tdlam_rets (111, 1)


INFO:root:Epoch 30: R=[-27.71382086], J=[-27.71382086]
  8%|▊         | 30/400 [00:02<00:27, 13.27it/s]

advs (109, 1)
vpreds (109, 1)
tdlam_rets (109, 1)


INFO:root:Epoch 31: R=[-35.40972937], J=[-35.40972937]
INFO:root:Epoch 32: R=[-35.55758666], J=[-35.55758666]
  8%|▊         | 32/400 [00:03<00:27, 13.26it/s]

advs (105, 1)
vpreds (105, 1)
tdlam_rets (105, 1)


INFO:root:Epoch 33: R=[-23.02819633], J=[-23.02819633]


advs (101, 1)
vpreds (101, 1)
tdlam_rets (101, 1)


INFO:root:Epoch 34: R=[-47.68878083], J=[-47.68878083]
  8%|▊         | 34/400 [00:03<00:25, 14.35it/s]

advs (100, 1)
vpreds (100, 1)
tdlam_rets (100, 1)


INFO:root:Epoch 35: R=[-80.58963942], J=[-80.58963942]


advs (97, 1)
vpreds (97, 1)
tdlam_rets (97, 1)


INFO:root:Epoch 36: R=[-46.55182556], J=[-46.55182556]
  9%|▉         | 36/400 [00:03<00:23, 15.27it/s]

advs (94, 1)
vpreds (94, 1)
tdlam_rets (94, 1)


INFO:root:Epoch 37: R=[-51.65915504], J=[-51.65915504]


advs (91, 1)
vpreds (91, 1)
tdlam_rets (91, 1)


INFO:root:Epoch 38: R=[-83.99128347], J=[-83.99128347]
 10%|▉         | 38/400 [00:03<00:22, 16.27it/s]

advs (89, 1)
vpreds (89, 1)
tdlam_rets (89, 1)


INFO:root:Epoch 39: R=[-54.0268635], J=[-54.0268635]


advs (86, 1)
vpreds (86, 1)
tdlam_rets (86, 1)


INFO:root:Epoch 40: R=[-71.28794721], J=[-71.28794721]


advs (84, 1)
vpreds (84, 1)
tdlam_rets (84, 1)




advs (81, 1)
vpreds (81, 1)
tdlam_rets (81, 1)


INFO:root:Epoch 41: R=[-93.47712585], J=[-93.47712585]
 10%|█         | 41/400 [00:03<00:20, 17.63it/s]

advs (79, 1)
vpreds (79, 1)
tdlam_rets (79, 1)


INFO:root:Epoch 42: R=[-67.54429717], J=[-67.54429717]


advs (77, 1)
vpreds (77, 1)
tdlam_rets (77, 1)


INFO:root:Epoch 43: R=[-34.23386223], J=[-34.23386223]
INFO:root:Epoch 44: R=[-153.51454361], J=[-153.51454361]
 11%|█         | 44/400 [00:03<00:18, 19.18it/s]

advs (75, 1)
vpreds (75, 1)
tdlam_rets (75, 1)


INFO:root:Epoch 45: R=[-154.81797697], J=[-154.81797697]


advs (74, 1)
vpreds (74, 1)
tdlam_rets (74, 1)


INFO:root:Epoch 46: R=[-120.6253768], J=[-120.6253768]
 12%|█▏        | 46/400 [00:03<00:19, 18.09it/s]

advs (71, 1)
vpreds (71, 1)
tdlam_rets (71, 1)


INFO:root:Epoch 47: R=[-95.38111569], J=[-95.38111569]


advs (69, 1)
vpreds (69, 1)
tdlam_rets (69, 1)


INFO:root:Epoch 48: R=[-153.25443954], J=[-153.25443954]


advs (69, 1)
vpreds (69, 1)
tdlam_rets (69, 1)


INFO:root:Epoch 49: R=[-165.6892598], J=[-165.6892598]
 12%|█▏        | 49/400 [00:03<00:17, 20.02it/s]

advs (68, 1)
vpreds (68, 1)
tdlam_rets (68, 1)


INFO:root:Epoch 50: R=[-173.85067849], J=[-173.85067849]


advs (67, 1)
vpreds (67, 1)
tdlam_rets (67, 1)




advs (67, 1)
vpreds (67, 1)
tdlam_rets (67, 1)


INFO:root:Epoch 51: R=[-155.60236737], J=[-155.60236737]


advs (66, 1)
vpreds (66, 1)
tdlam_rets (66, 1)


INFO:root:Epoch 52: R=[-177.71915372], J=[-177.71915372]
 13%|█▎        | 52/400 [00:04<00:16, 21.39it/s]

advs

INFO:root:Epoch 53: R=[-123.92562693], J=[-123.92562693]


 (67, 1)
vpreds (67, 1)
tdlam_rets (67, 1)


INFO:root:Epoch 54: R=[-131.12576256], J=[-131.12576256]


advs (67, 1)
vpreds (67, 1)
tdlam_rets (67, 1)


INFO:root:Epoch 55: R=[-125.82857362], J=[-125.82857362]
 14%|█▍        | 55/400 [00:04<00:15, 22.32it/s]

advs (68, 1)
vpreds (68, 1)
tdlam_rets (68, 1)


INFO:root:Epoch 56: R=[-107.58160501], J=[-107.58160501]


advs (69, 1)
vpreds (69, 1)
tdlam_rets (69, 1)


INFO:root:Epoch 57: R=[-119.45594048], J=[-119.45594048]


advs (70, 1)
vpreds (70, 1)
tdlam_rets (70, 1)


 14%|█▍        | 57/400 [00:04<00:26, 13.10it/s]


KeyboardInterrupt: 

In [None]:
# Close the active W&B run (the run whose id was passed to run_experiment).
wandb.finish()

0,1
Action,▄▇▁▆▄▃▅▄▄▂▅▆▆▄▃▇▄▇▆▇▆▆▃▄▄▁▆▅█▆▄▇▆▇██▂▄▆▆
Cumulative_Reward,▁▁▁▁▅▃▃▂▄▅▂▄▄▂▃▃▂▄▂▄▃▃▆▃▂▂▅▄▂▅▅▂▂▃▃▂▆▅▆█
Epoch,▁▁▁▁▁▂▂▂▂▂▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇████
Inventory,▆█▂▃▂▅▇▂█▁▄▂▁▃▅▇▃██▄▆▅▇▆▁▃▂█▄▃▂█▆▄▃▁▅▁▂█
Reward,▆▇▄▆▄▄▅▆▅█▅▅█▆▅▅▆▆▆▆▆▁▆▄▇▅█▄█▆▅▅▆▃▁▆▂▃▆▅
t,▃▂▂▃▁▂▂▁▆█▆▃▇▇▂▁▄▇▁▃▂▇▆▃▇▄▅▆▇█▅▄▁▅▇▂▅▂▄▁

0,1
Action,0.71
Cumulative_Reward,1184.4423
Epoch,399.0
Inventory,0.0
Reward,3.26642
t,220.0
