In [1]:
from ray import rllib, tune
from ray.rllib.contrib.alpha_zero.core.alpha_zero_trainer import AlphaZeroTrainer
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.policy.policy_map import PolicyMap
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.tune.registry import register_env
from ray.rllib.contrib.alpha_zero.models.custom_torch_models import DenseModel
from ray.rllib.models.catalog import ModelCatalog
import gym
from src.jss_lite.jss_lite import jss_lite
ModelCatalog.register_custom_model("dense_model", DenseModel)
from copy import deepcopy
import numpy as np

train_agent=True


config = {
    "framework": "torch",
    "disable_env_checking":False,
    "num_workers"       : 6,
    "rollout_fragment_length": 50,
    "train_batch_size"  : 500,
    "sgd_minibatch_size": 64,
    "lr"                : 0.0001,
    "horizon"           : 300,
    "soft_horizon"      : True,
    "num_sgd_iter"      : 1,
    "horizon"           : 100,
    "mcts_config"       : {
        "puct_coefficient"   : 1.5,
        "num_simulations"    : 100,
        "temperature"        : 1.0,
        "dirichlet_epsilon"  : 0.20,
        "dirichlet_noise"    : 0.03,
        "argmax_tree_policy" : False,
        "add_dirichlet_noise": False,
    },
    "ranked_rewards"    : {
        "enable": False,
    },
    "model"             : {
        "custom_model": "dense_model",

    },
}

# def env_creator(env_config):
#     env = jss_lite(instance_path='resources/jsp_instances/standard/ft06.txt')
#     return env

from wrapper.jssplight_wrapper import jssp_light_obs_wrapper


def env_creator(config):
    env = jssp_light_obs_wrapper(jss_lite(instance_path='resources/jsp_instances/standard/ft06.txt'))
    return env

ModelCatalog.register_custom_model("dense_model", DenseModel)    




# use tune to register the custom environment for the ppo trainer
tune.register_env('custom_jssp',env_creator)



In [2]:
from src.jss_lite.jss_lite import jss_lite
env = jss_lite(instance_path='resources/jsp_instances/standard/ft06.txt')
state=env.reset()
#print(state)
#print((max(2*env.n_jobs,env.n_machines)*6,))
env=env_creator("test")
state=env.reset()
print(env.observation_space)
print(env.action_space.n)
print(state['action_mask'])
print(env.spec)

Dict(action_mask:Box([0 0 0 0 0 0 0 0 0 0 0 0], [1 1 1 1 1 1 1 1 1 1 1 1], (12,), int32), obs:Box([0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.], [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.], (72,), float64))
12
[1 1 1 1 1 1 0 0 0 0 0 0]
None


In [3]:
agent = AlphaZeroTrainer( config=config, env='custom_jssp')
if train_agent:
    # checkpoint_path = analysis.get_last_checkpoint() or args.checkpoint
    ## use string number to restore pre trained agent
    # nr_restore="10"
    #checkpoint_path=f'checkpoints_az/rllib_checkpoint{nr_restore}/checkpoint_{nr_restore.zfill(6)}/checkpoint-{nr_restore}'
    #agent.load_checkpoint("checkpoints_az/rllib_checkpoint1")
    #print("awd")
    #agent.restore("checkpoints_az/rllib_checkpoint1/checkpoint_000001/checkpoint-1")
    #agent.load_checkpoint("published_checkpoints/az_taxi/checkpoint-34")
    print("start training")
    for _ in range(0,20):
        agent.train()
        print(f"training iteration {_} finished")
        agent.save_checkpoint(f"training_checkpoints/checkpoints_az_jsslite")
    




start training
training iteration 0 finished
training iteration 1 finished
training iteration 2 finished
training iteration 3 finished
training iteration 4 finished
training iteration 5 finished
training iteration 6 finished
training iteration 7 finished
training iteration 8 finished
training iteration 9 finished
training iteration 10 finished
training iteration 11 finished
training iteration 12 finished
training iteration 13 finished
training iteration 14 finished
training iteration 15 finished
training iteration 16 finished
training iteration 17 finished
training iteration 18 finished
training iteration 19 finished


In [8]:
import time
length_list=[]
reward_list=[]
for _ in range(10):
    policy = agent.get_policy(DEFAULT_POLICY_ID)
    action_list=[]
    env = jss_lite(instance_path='resources/jsp_instances/standard/ft06.txt')

    obs = env.reset()
    # env2 is copy for later going evaluation
    env2=deepcopy(env)

    episode = MultiAgentEpisode(
        PolicyMap(0,0),
        lambda _, __: DEFAULT_POLICY_ID,
        lambda: None,
        lambda _: None,
        0,
    )

    episode.user_data['initial_state'] = env.get_state()

    done = False

    while not done:
        action, _, _ = policy.compute_single_action(obs, episode=episode)
        action_list.append(action)
        #print(action_dic[action])
        obs, reward, done, _ = env.step(action)
        #print(obs)
        #env.render(render_mode='human')
        #time.sleep(0.1)
        episode.length += 1

    length_list.append(episode.length)
    reward_list.append(reward)
    env.close()

In [10]:
env.render()
print(done)
print(reward)

True
-78


In [16]:
from random import randrange
for _ in range(30):
    state=env.reset()
    done=False
    i=0
    #print(state['obs'].shape)
    #for i in range(150):
    while not done:
        i+=1
        legal_action=env.get_legal_actions(state)
        #action=np.random.choice(len(legal_action), 1, p=(legal_action / legal_action.sum()))[0]
        action=randrange(env.action_space.n/2)
        #print(action)
        token=True
        #print("before:")
        #env.render(start_count=1,x_bar="Job",y_bar="Machine")
        state, reward, done, info=env.step(action)
    #env.render()
    print(reward)


non-integer arguments to randrange() have been deprecated since Python 3.10 and will be removed in a subsequent version



-75
