In [1]:
import gym
from gym import error, spaces, utils
from gym.utils import seeding
import numpy as np

import tensorflow as tf
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
from keras.layers import Bidirectional
from keras.models import Sequential
from keras.layers.core import Activation, Dropout, Dense
from keras.layers import Flatten, LSTM
from keras.layers import Bidirectional
from ray import tune
from ray.tune.registry import register_env

In [2]:
sequence_length=3
state_len = 52
num_action = 1
encoding_len = state_len + num_action

class WorldMovelEnv(gym.Env):

    def __init__(self):
        
        self.observation_space = gym.spaces.Box(low=0, high=1, shape=(52,))
        self.action_space = gym.spaces.Discrete(145)
        
        self.step_count = 0
        self.st_model = Sequential()
        self.st_model.add(Bidirectional(LSTM(256, activation='relu', return_sequences=True), input_shape=(sequence_length, encoding_len)))
        self.st_model.add(Flatten())
        self.st_model.add(Dense(512, activation='relu'))
        self.st_model.add(Dense(512, activation='relu'))
        self.st_model.add(Dense(512, activation='relu'))
        self.st_model.add(Dense(9674, activation='softmax'))
        self.st_model.compile(optimizer='adam', loss=tf.keras.losses.CategoricalCrossentropy(), metrics=[tf.keras.metrics.Accuracy()])
        self.st_model.load_weights('ClassStateModel')
        
        self.r_model = Sequential()
        self.r_model.add(Bidirectional(LSTM(256, activation='relu', return_sequences=True), input_shape=(sequence_length, state_len)))
        self.r_model.add(Flatten())
        self.r_model.add(Dense(256, activation='relu'))
        self.r_model.add(Dense(128, activation='relu'))
        self.r_model.add(Dense(1, activation='linear'))
        self.r_model.compile(optimizer='adam', loss=tf.keras.losses.MeanSquaredError())
        self.r_model.load_weights('RewardModel')
        
        self.index_to_state = np.load('index_to_state.npy', allow_pickle=True).item()
        self.state_to_index = np.load('state_to_index.npy', allow_pickle=True).item()
        
        self.state_action = np.zeros((sequence_length,encoding_len))
        self.state = np.zeros((sequence_length,state_len))

    def step(self, action):
        self.step_count += 1
        
        self.state_action[0,52] = action / 145
        obs_index_probs = self.st_model.predict(np.array([self.state_action]))
        #obs_index = np.argmax(obs_index_probs)
        obs_index = np.random.choice(np.arange(9674), p=obs_index_probs[0])
        new_state = self.index_to_state[obs_index]
        new_state = np.frombuffer(new_state)
        
        
        for i in range(1,3):
            self.state[i,:] = self.state[i-1,:]
            
        self.state[0,:] = new_state
        
        reward = self.r_model.predict(np.array([self.state]))
        
        for i in range(1,3):
            self.state_action[i,:] = self.state_action[i-1,:]
        
        self.state_action[0,:] = np.concatenate([new_state, [0]])
        
        done = self.step_count == 100
        return new_state, reward[0][0], done, {}

    def reset(self):
        step_count = 0
        
        self.state_action = np.zeros((sequence_length,encoding_len))
        self.state = np.zeros((sequence_length,state_len))
        
        obs_index_probs = self.st_model.predict(np.array([self.state_action]))
        #obs_index = np.argmax(obs_index_probs)
        obs_index = np.random.choice(np.arange(9674), p=obs_index_probs[0])
        new_state = self.index_to_state[obs_index]
        new_state = np.frombuffer(new_state)
      
        return new_state
    

    def render(self, mode='human', close=False):
        pass

    def close(self):
        pass
    
def env_creator(config):
    return WorldMovelEnv() 

In [3]:
register_env(name="DreamCybORG", env_creator=env_creator)

In [4]:
def print_results(results_dict):
    train_iter = results_dict["training_iteration"]
    r_mean = results_dict["episode_reward_mean"]
    r_max = results_dict["episode_reward_max"]
    r_min = results_dict["episode_reward_min"]
    print(f"{train_iter:4d} \tr_mean: {r_mean:.1f} \tr_max: {r_max:.1f} \tr_min: {r_min: .1f}")

In [5]:
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.policy.policy import PolicySpec

config = (
    PPOConfig()
    #Each rollout worker uses a single cpu
    .rollouts(num_rollout_workers=5, num_envs_per_worker=4, horizon=100)\
    .training(sgd_minibatch_size = 100, train_batch_size=4000, gamma=0.85, lr=0.00005, 
              model={"fcnet_hiddens": [256, 256], "fcnet_activation": "tanh",})\
    .environment(disable_env_checking=True, env = 'DreamCybORG')\
    .resources(num_gpus=1)\
    .framework('torch')\
    #.exploration(explore=True, exploration_config={"type": "RE3", "embeds_dim": 128, "beta_schedule": "constant", "sub_exploration": {"type": "StochasticSampling",},})\
)
trainer = config.build()


2023-01-02 23:48:09,839	INFO worker.py:1528 -- Started a local Ray instance.
[2m[36m(RolloutWorker pid=8626)[0m Instructions for updating:
[2m[36m(RolloutWorker pid=8626)[0m Call initializer instance with the dtype argument instead of passing it to the constructor
[2m[36m(RolloutWorker pid=8626)[0m Instructions for updating:
[2m[36m(RolloutWorker pid=8626)[0m Call initializer instance with the dtype argument instead of passing it to the constructor
[2m[36m(RolloutWorker pid=8626)[0m Instructions for updating:
[2m[36m(RolloutWorker pid=8626)[0m Call initializer instance with the dtype argument instead of passing it to the constructor
[2m[36m(RolloutWorker pid=8621)[0m Instructions for updating:
[2m[36m(RolloutWorker pid=8621)[0m Call initializer instance with the dtype argument instead of passing it to the constructor
[2m[36m(RolloutWorker pid=8621)[0m Instructions for updating:
[2m[36m(RolloutWorker pid=8621)[0m Call initializer instance with the dtype arg

In [6]:
for i in range(500):
    print_results(trainer.train())

[2m[36m(RolloutWorker pid=8626)[0m   updates=self.state_updates,
[2m[36m(RolloutWorker pid=8621)[0m   updates=self.state_updates,
[2m[36m(RolloutWorker pid=8622)[0m   updates=self.state_updates,
[2m[36m(RolloutWorker pid=8623)[0m   updates=self.state_updates,
[2m[36m(RolloutWorker pid=8624)[0m   updates=self.state_updates,


   1 	r_mean: -128.3 	r_max: -17.5 	r_min: -262.1
   2 	r_mean: -122.4 	r_max: -17.5 	r_min: -262.1
   3 	r_mean: -125.8 	r_max: -17.5 	r_min: -362.7
   4 	r_mean: -126.4 	r_max: -20.9 	r_min: -362.7
   5 	r_mean: -131.0 	r_max: -20.9 	r_min: -478.9
   6 	r_mean: -134.5 	r_max: -24.8 	r_min: -478.9
   7 	r_mean: -122.2 	r_max: -24.8 	r_min: -379.4
   8 	r_mean: -126.1 	r_max: -18.3 	r_min: -389.8
   9 	r_mean: -114.9 	r_max: -18.3 	r_min: -389.8
  10 	r_mean: -117.2 	r_max: -15.6 	r_min: -389.8
  11 	r_mean: -124.9 	r_max: -15.6 	r_min: -727.0
  12 	r_mean: -125.4 	r_max: -15.6 	r_min: -727.0
  13 	r_mean: -114.3 	r_max: -15.7 	r_min: -341.7
  14 	r_mean: -104.9 	r_max: -15.7 	r_min: -250.7
  15 	r_mean: -101.6 	r_max: -16.0 	r_min: -399.9
  16 	r_mean: -94.6 	r_max: -16.0 	r_min: -399.9
  17 	r_mean: -110.3 	r_max: -17.1 	r_min: -421.6
  18 	r_mean: -112.4 	r_max: -17.1 	r_min: -421.6
  19 	r_mean: -100.2 	r_max: -17.1 	r_min: -274.3
  20 	r_mean: -95.2 	r_max: -24.9 	r_min: -258.5
  

 168 	r_mean: -28.8 	r_max: -12.9 	r_min: -182.6
 169 	r_mean: -29.6 	r_max: -12.5 	r_min: -179.6
 170 	r_mean: -29.5 	r_max: -12.5 	r_min: -238.7
 171 	r_mean: -25.6 	r_max: -12.5 	r_min: -238.7
 172 	r_mean: -29.4 	r_max: -12.3 	r_min: -695.0
 173 	r_mean: -29.7 	r_max: -12.3 	r_min: -695.0
 174 	r_mean: -31.4 	r_max: -12.3 	r_min: -695.0
 175 	r_mean: -25.0 	r_max: -12.5 	r_min: -111.7
 176 	r_mean: -27.6 	r_max: -12.7 	r_min: -151.6
 177 	r_mean: -27.8 	r_max: -12.3 	r_min: -151.6
 178 	r_mean: -23.6 	r_max: -12.3 	r_min: -119.5
 179 	r_mean: -24.1 	r_max: -12.1 	r_min: -141.1
 180 	r_mean: -26.1 	r_max: -12.1 	r_min: -177.7
 181 	r_mean: -24.5 	r_max: -12.1 	r_min: -177.7
 182 	r_mean: -32.7 	r_max: -12.6 	r_min: -329.4
 183 	r_mean: -32.9 	r_max: -12.6 	r_min: -329.4
 184 	r_mean: -30.0 	r_max: -12.3 	r_min: -127.7
 185 	r_mean: -31.2 	r_max: -12.3 	r_min: -127.7
 186 	r_mean: -28.9 	r_max: -12.2 	r_min: -126.8
 187 	r_mean: -27.2 	r_max: -12.2 	r_min: -126.8
 188 	r_mean: -31.6 

 336 	r_mean: -23.2 	r_max: -12.4 	r_min: -202.1
 337 	r_mean: -20.9 	r_max: -12.4 	r_min: -147.4
 338 	r_mean: -23.9 	r_max: -12.3 	r_min: -123.2
 339 	r_mean: -24.7 	r_max: -12.3 	r_min: -202.3
 340 	r_mean: -24.8 	r_max: -12.1 	r_min: -202.3
 345 	r_mean: -28.6 	r_max: -12.2 	r_min: -202.8
 346 	r_mean: -27.6 	r_max: -12.2 	r_min: -202.8
 347 	r_mean: -20.2 	r_max: -12.2 	r_min: -125.5
 348 	r_mean: -19.8 	r_max: -12.6 	r_min: -177.6
 349 	r_mean: -20.4 	r_max: -12.0 	r_min: -177.6
 350 	r_mean: -25.2 	r_max: -12.0 	r_min: -323.6
 351 	r_mean: -24.9 	r_max: -12.0 	r_min: -323.6
 352 	r_mean: -21.4 	r_max: -12.0 	r_min: -213.4
 353 	r_mean: -20.1 	r_max: -12.0 	r_min: -213.4
 354 	r_mean: -20.3 	r_max: -12.2 	r_min: -172.2
 355 	r_mean: -24.2 	r_max: -12.2 	r_min: -176.6
 356 	r_mean: -21.7 	r_max: -12.3 	r_min: -176.6
 357 	r_mean: -21.9 	r_max: -12.3 	r_min: -176.6
 358 	r_mean: -23.1 	r_max: -12.3 	r_min: -168.5
 359 	r_mean: -27.2 	r_max: -12.5 	r_min: -247.7
 360 	r_mean: -28.9 

In [7]:
test = WorldMovelEnv()
import time 

t = time.time()
print(test.reset())
for i in range(100):
    print(test.step(55))
print(t-time.time())

[0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0.]
(array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0.]), -0.023691773, False, {})
(array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0.]), -0.023691773, False, {})
(array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0.]), -0.9225037, False, {})
(array([0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.

(array([0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 1., 0., 0., 1., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0.]), -0.90259886, False, {})
(array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0.]), -0.03656867, False, {})
(array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0.]), -0.36987838, False, {})
(array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 1., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0

(array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 1., 0., 0.,
       0.]), -2.3917186, False, {})
(array([0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0.,
       0.]), -1.7029457, False, {})
(array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0.,
       0.]), -1.7581191, False, {})
(array([0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 

(array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0.,
       0.]), -1.8697165, False, {})
(array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0.,
       0.]), -1.8697165, False, {})
(array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0.,
       0.]), -1.8697165, False, {})
(array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 

In [8]:
trainer.save('results')

'results/checkpoint_000500'

In [9]:
dir(trainer)

['__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_allow_unknown_configs',
 '_allow_unknown_subkeys',
 '_annotated',
 '_automatic_evaluation_duration_fn',
 '_before_evaluate',
 '_by_agent_steps',
 '_checkpoint_info_to_algorithm_state',
 '_close_logfiles',
 '_compile_iteration_results',
 '_counters',
 '_create_checkpoint_dir',
 '_create_local_replay_buffer_if_necessary',
 '_create_logger',
 '_env_id',
 '_episode_history',
 '_episodes_to_be_collected',
 '_episodes_total',
 '_evaluate_async',
 '_evaluation_async_req_manager',
 '_experiment_id',
 '_export_model',
 '_get_env_id_and_creator',
 '_get_latest_available_checkpoint',
 '_get_latest