In [1]:
import os
os.environ['NUMEXPR_MAX_THREADS'] = '1'

import logging
import numexpr as ne
import numpy as np
import torch
import datetime
from ddopai.envs.pricing.dynamic import DynamicPricingEnv
from ddopai.envs.pricing.dynamic_inventory import DynamicPricingInvEnv
from ddopai.envs.actionprocessors import ClipAction, RoundAction

from ddopai.experiments.experiment_functions_online import run_experiment
from ddopai.experiments.meta_experiment_functions import *
import requests
import yaml
import re
import pandas as pd
import wandb
from copy import deepcopy
import warnings
import gc
from mushroom_rl import core 
import pickle
from tqdm import tqdm, trange

In [2]:
# --- Logging ------------------------------------------------------------------
logging_level = logging.INFO
logging.basicConfig(level=logging_level)

# --- Threading / backends -----------------------------------------------------
# Keep numexpr and torch single-threaded (NUMEXPR_MAX_THREADS is also pinned to
# "1" in the import cell) and switch cuDNN off.
torch.set_num_threads(1)
torch.backends.cudnn.enabled = False
ne.set_num_threads(1)

# --- Warnings -----------------------------------------------------------------
# set_warnings is provided by the ddopai meta_experiment_functions star import.
# Author's note: it turns off warnings for levels >= the given level.
# NOTE(review): behavior lives in ddopai and is not verified here — confirm.
set_warnings(logging_level)

# --- Project and configuration files ------------------------------------------
project_name = "CMDP-Bandit"
config_hp_sweep, config_env = map(
    import_config,
    ("config_hp_sweep.yaml", "config_env.yaml"),
)

INFO:root:Configuration file 'config_hp_sweep.yaml' successfully loaded.
INFO:root:Configuration file 'config_env.yaml' successfully loaded.


In [3]:
# Generate `num_trials` raw datasets and log each one to W&B as a new version of
# the `raw_data` artifact, collecting the versioned artifact names (e.g.
# "raw_data:v121") for the sweep configuration built below.
RAW_DATA_PATH = "data/raw_data.pkl"

artifacts = []

# The context manager finishes the run even if the loop raises, instead of
# leaving a dangling active run in the kernel.
with wandb.init(
    project=project_name,
    name=f"{project_name}_artifact_{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}",
) as run:
    # open(..., "wb") below fails if the target directory does not exist yet.
    os.makedirs(os.path.dirname(RAW_DATA_PATH), exist_ok=True)

    for artifact_index in trange(config_hp_sweep["num_trials"]):
        # NOTE(review): assumes get_online_data returns freshly generated data on
        # every call with overwrite=False; if it returns cached data instead, all
        # logged artifact versions will be identical — confirm against ddopai.
        raw_data, val_index_start, test_index_start = get_online_data(
            config_env,
            overwrite=False,
        )

        with open(RAW_DATA_PATH, "wb") as f:
            pickle.dump(raw_data, f)

        # dtype=np.int64: the default integer is 32-bit on some platforms (NumPy < 2
        # on Windows), where a high bound of 2**32 - 1 raises ValueError.
        # int(): keep the metadata a plain, JSON-serializable Python int.
        # NOTE(review): this seed is drawn after raw_data is generated and is not
        # used in this cell; presumably downstream runs read it from the artifact
        # metadata — confirm.
        seed = int(np.random.randint(0, 2**32 - 1, dtype=np.int64))

        artifact = wandb.Artifact("raw_data", type="data")
        artifact.add_file(RAW_DATA_PATH)
        artifact.metadata = {"seed": seed}
        run.log_artifact(artifact)
        # Wait for logging to complete so artifact.name includes the resolved
        # version suffix.
        artifact.wait()
        artifacts.append(artifact.name)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mtimlachner[0m. Use [1m`wandb login --relogin`[0m to force relogin


100%|██████████| 5/5 [00:20<00:00,  4.04s/it]


In [None]:
# Register the logged artifact names as ONE grid value: `[artifacts]` wraps the
# whole list, so every sweep run receives all artifact names together rather than
# the grid iterating over them one by one.
# NOTE(review): the saved output of the next cell shows a nested
# {'values': [{'values': [...]}]} here, i.e. it was produced by an earlier
# version of this cell (this cell shows In [None]) — re-run to refresh it.
config_hp_sweep["parameters"].update(artifacts={"values": [artifacts]})


In [5]:
# `num_trials` is only read by the artifact-generation loop above and is not a
# W&B sweep key, so drop it before passing the config to wandb.sweep.
# pop(..., None) keeps the cell idempotent; a plain `del` raised KeyError on re-run.
config_hp_sweep.pop("num_trials", None)
config_hp_sweep

{'method': 'grid',
 'metric': {'name': 'cumulative_mean_true_reward', 'goal': 'maximize'},
 'name': 'config_hp_sweep',
 'parameters': {'config_train-agent': {'values': ['RL2PPO']},
  'config_agent-RL2PPO-learning_rate_actor': {'values': [0.001,
    0.0005,
    0.002]},
  'artifacts': {'values': [{'values': ['raw_data:v121',
      'raw_data:v122',
      'raw_data:v123',
      'raw_data:v124',
      'raw_data:v125']}]}}}

In [6]:
# Register the sweep under this project; wandb.sweep returns the new sweep's ID.
sweep_id = wandb.sweep(sweep=config_hp_sweep, project=project_name)

Create sweep with ID: shg82khf
Sweep URL: https://wandb.ai/timlachner/CMDP-Bandit/sweeps/shg82khf


In [7]:
# Defensive cleanup: close any W&B run still attached to this kernel.
# wandb.sweep does not start a run, so this is normally a no-op here.
if wandb.run is not None:
    wandb.finish()