In [1]:
from ray import rllib, tune
from ray.rllib.contrib.alpha_zero.core.alpha_zero_trainer import AlphaZeroTrainer
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.policy.policy_map import PolicyMap
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.tune.registry import register_env
from ray.rllib.contrib.alpha_zero.models.custom_torch_models import DenseModel
from ray.rllib.models.catalog import ModelCatalog
import gym
from src.jss_lite.jss_lite import jss_lite
ModelCatalog.register_custom_model("dense_model", DenseModel)
from copy import deepcopy
import numpy as np

# --- Run switches and paths -------------------------------------------------
train_agent = True       # run the training cell further down
restore_agent = False    # load `restore_path` into the agent before doing anything else
num_episodes = 10        # number of agent.train() calls made by the training cell
instance_path = 'resources/jsp_instances/standard/ft06.txt'
restore_path = 'published_checkpoints/checkpoints_az_jsslite/checkpoint-255'

# --- Trainer configuration ----------------------------------------------------
# Previously tried and disabled: "horizon": 600 together with "soft_horizon": True.
config = dict(
    framework="torch",
    disable_env_checking=True,
    num_workers=6,
    rollout_fragment_length=50,
    train_batch_size=500,
    sgd_minibatch_size=64,
    lr=1e-4,
    num_sgd_iter=1,
    horizon=100,
    mcts_config=dict(
        puct_coefficient=1.5,
        num_simulations=100,
        temperature=1.0,
        dirichlet_epsilon=0.2,
        dirichlet_noise=0.03,
        argmax_tree_policy=False,
        add_dirichlet_noise=False,
    ),
    ranked_rewards=dict(enable=True),
    # Name must match the one passed to ModelCatalog.register_custom_model.
    model=dict(custom_model="dense_model"),
)

# def env_creator(env_config):
#     env = jss_lite(instance_path='resources/jsp_instances/standard/ft06.txt')
#     return env

from wrapper.jssplight_wrapper import jssp_light_obs_wrapper


def env_creator(config):
    """Create a new ``jss_lite`` environment wrapped in ``jssp_light_obs_wrapper``.

    The ``config`` argument is accepted but never read; the problem instance is
    taken from the module-level ``instance_path``.
    """
    return jssp_light_obs_wrapper(jss_lite(instance_path=instance_path))

# Make the dense torch model available under the name used in config["model"].
ModelCatalog.register_custom_model("dense_model", DenseModel)

# Register the environment factory with Tune so the AlphaZero trainer can
# build the environment from its string name.
tune.register_env('custom_jssp', env_creator)

agent = AlphaZeroTrainer(env='custom_jssp', config=config)

2022-09-28 12:45:35,741	INFO trainable.py:159 -- Trainable.setup took 24.268 seconds. If your trainable is slow to initialize, consider setting reuse_actors=True to reduce actor creation overheads.


In [2]:
# Optionally load previously saved trainer state from `restore_path`.
if restore_agent:
    agent.load_checkpoint(restore_path)

In [3]:
# env = env_creator("setting")
# config.update(
#     simple_optimizer=True,
#     num_workers=0,
#     train_batch_size=0,
#     rollout_fragment_length=0,
#     timesteps_per_iteration=0,
#     evaluation_interval=1,
#     # evaluation_num_workers=...,
#     # evaluation_config=dict(explore=False),
#     # evaluation_num_episodes=...,
# )

# results = tune.run(
#     agent,
#     config=config)

In [4]:

import os
import time

if train_agent:
    # Directory that receives the checkpoint written after every training iteration.
    checkpoint_dir = "training_checkpoints/checkpoints_az_jsslite"
    # Create it up front so saving does not depend on the folder already existing
    # (e.g. on a fresh clone of the repository).
    os.makedirs(checkpoint_dir, exist_ok=True)

    print("start training")
    # `num_episodes` is used here as the number of agent.train() calls.
    # A named loop variable is used because the value is printed, and because
    # rebinding `_` would overwrite IPython's last-output shortcut.
    for iteration in range(num_episodes):
        # perf_counter is monotonic, so the duration cannot be skewed by
        # system clock adjustments during a long training iteration.
        t = time.perf_counter()
        agent.train()
        print(f"training iteration {iteration} finished after {time.perf_counter()-t} seconds")
        # Save after every iteration so an interrupted run keeps its progress.
        agent.save_checkpoint(checkpoint_dir)
    

start training
training iteration 0 finished after 116.73885703086853 seconds
training iteration 1 finished after 76.65271878242493 seconds
training iteration 2 finished after 63.644580125808716 seconds
training iteration 3 finished after 102.0259461402893 seconds
training iteration 4 finished after 431.44853138923645 seconds


In [None]:


import time

# Per-episode results collected across the rollouts below.
length_list = []
reward_list = []

# One evaluation rollout; increase the range to collect more episodes.
# The loop and the unpacked throwaways use real names instead of `_`:
# the original rebound its own loop variable and overwrote IPython's
# last-output shortcut `_`.
for episode_nr in range(1):
    policy = agent.get_policy(DEFAULT_POLICY_ID)
    action_list = []
    env = env_creator("s")

    obs = env.reset()
    # env2 is a copy of the freshly reset env, kept for later evaluation.
    env2 = deepcopy(env)

    # Stand-alone episode object built outside of a rollout worker so that it
    # can be handed to policy.compute_single_action.
    episode = MultiAgentEpisode(
        PolicyMap(0,0),
        lambda _, __: DEFAULT_POLICY_ID,
        lambda: None,
        lambda _: None,
        0,
    )

    # NOTE(review): presumably read by the AlphaZero policy to seed its tree
    # search from the env's start state; confirm against the RLlib
    # contrib/alpha_zero policy implementation.
    episode.user_data['initial_state'] = env.get_state()

    done = False

    while not done:
        action, state_out, extra_fetches = policy.compute_single_action(obs, episode=episode)
        action_list.append(action)
        obs, reward, done, info = env.step(action)
        episode.length += 1

    length_list.append(episode.length)
    # Only the reward returned by the final step is recorded for the episode.
    reward_list.append(reward)
    #env.close()

In [None]:
# Visual check: render `env` in whatever state the most recently executed rollout left it.
env.render()

In [None]:
# Reward returned by the last env.step() call that was executed.
print(reward)

In [None]:
# Inspect `current_timestep` on the object wrapped by `env`
# (presumably the inner jss_lite instance; confirm against the wrapper).
print(env.env.current_timestep)


In [None]:
# Inspect the `done` attribute of the object wrapped by `env`.
print(env.env.done)

In [None]:
# Inspect the `production_list` attribute of the object wrapped by `env`.
print(env.env.production_list)

In [None]:
obs_1 = env.reset()

# Two independently created environments that receive the same action
# sequence; any difference in their returned states is printed.
env_1 = env_creator("a")
env_2 = env_creator("b")
state_1 = env_1.reset()
state_2 = env_2.reset()

# NOTE(review): `a_list` is only assigned by the random-rollout cell further
# down, so this cell raises NameError on a fresh top-to-bottom run.
for a in a_list:
    state_1, reward_1, done_1, info_1 = env_1.step(a)
    state_2, reward_2, done_2, info_2 = env_2.step(a)

    # Compare both entries of the returned state dicts, observation first.
    for key, label in (('obs', "error obs"), ('action_mask', "error mask")):
        if not np.array_equal(state_1[key], state_2[key]):
            print(label)
            print(state_1)
            print(state_2)

In [None]:
# Roll out one episode per saved checkpoint and report reward, length and time.
for i in range(5):
    # The file that gets loaded is checkpoint-(i+1); use the same number in the
    # report so the printed label matches the checkpoint that was evaluated
    # (the original printed the zero-based loop index instead).
    checkpoint_nr = i + 1
    agent.load_checkpoint(f'training_checkpoints/checkpoints_az_jsslite/checkpoint-{checkpoint_nr}')
    policy = agent.get_policy(DEFAULT_POLICY_ID)
    action_list = []
    env = env_creator("s")

    obs = env.reset()

    # Stand-alone episode object built outside of a rollout worker so that it
    # can be handed to policy.compute_single_action.
    episode = MultiAgentEpisode(
        PolicyMap(0,0),
        lambda _, __: DEFAULT_POLICY_ID,
        lambda: None,
        lambda _: None,
        0,
    )

    # NOTE(review): presumably read by the AlphaZero policy to seed its tree
    # search from the env's start state; confirm against the RLlib
    # contrib/alpha_zero policy implementation.
    episode.user_data['initial_state'] = env.get_state()

    done = False
    steps = 0
    t = time.time()
    while not done:
        # Named throwaways: rebinding `_` would overwrite IPython's
        # last-output shortcut.
        action, state_out, extra_fetches = policy.compute_single_action(obs, episode=episode)
        action_list.append(action)
        obs, reward, done, info = env.step(action)
        steps += 1
    print(f"checkpoint {checkpoint_nr} got reward {reward} in {steps} steps and time: {time.time()-t}")

In [None]:
# Visual check: render `env` as left by the most recently executed rollout.
env.render()

In [None]:
# Inspect the `job_tasklength_matrix` attribute of the object wrapped by `env`.
print(env.env.job_tasklength_matrix)

In [None]:
# Inspect the most recent episode: last reward, the inner object's
# `count_finished_tasks_job_matrix`, and two renderings of `env`.
print(reward)
print(env.env.count_finished_tasks_job_matrix)
env.render()
# Second rendering with explicit y_bar/x_bar arguments.
env.render(y_bar="Machine",x_bar="Job")


In [None]:
# Render `env`, then print the last `done` flag, the last reward and the
# wrapper's `invalid_actions` attribute.
env.render()
#env.render(y_bar="Machine",x_bar="Job")
print(done)
print(reward)
print(env.invalid_actions)

In [None]:
env.reset()
from random import randrange

# Every action taken in any of the random episodes below, in order.
a_list = []
for _ in range(30):
    state = env.reset()
    done = False
    i = 0
    reward = 0
    while not done:
        i += 1
        legal_action = state['action_mask']
        # Draw one action index with probability proportional to its mask entry.
        probabilities = legal_action / legal_action.sum()
        action = np.random.choice(len(legal_action), 1, p=probabilities)[0]
        a_list.append(action)
        token = True
        state, reward, done, info = env.step(action)
    print(reward)

    # Dump diagnostics for episodes whose final reward is above -68.
    if reward > -68:
        print(env.env.get_legal_actions("yy"))
        print(env.done)
        print(env.env.done)
        print(env.env.production_list)
        env.env.render()


In [None]:
# Inspect the inner object's state: result of get_legal_actions(...) plus the
# `current_timestep`, `current_machines_status` and
# `processed_and_max_time_job_matrix` attributes.
print(env.env.get_legal_actions("a"))
print(env.env.current_timestep)
print(env.env.current_machines_status)
print(env.env.processed_and_max_time_job_matrix)

In [None]:
# Print the result of get_legal_actions(...) followed by `blocked_actions`.
print(env.env.get_legal_actions("a"))
print(env.env.blocked_actions)

In [None]:

# Visual check: render `env` in its current state.
env.render()

In [None]:
# Print `blocked_actions` followed by the result of get_legal_actions(...).
print(env.env.blocked_actions)
print(env.env.get_legal_actions("obs"))


In [None]:
def contains_none(rows):
    """Return True if at least one row in ``rows`` contains ``None``.

    ``rows`` is an iterable of containers (e.g. a list of lists). An empty
    ``rows`` yields False.
    """
    return any(None in row for row in rows)


x = [[1, 2, None], [1, 2, 3]]
# The membership test has to run per row: comparing a whole row to None
# (as in `any(x == None for x in x)`) is always False, so a None nested
# inside a row was never detected.
if contains_none(x):
    print(True)

In [None]:
# Print `current_machines_status` followed by the result of get_legal_actions(...).
print(env.env.current_machines_status)
print(env.env.get_legal_actions("stat"))

In [None]:
# Visual check: render `env` in its current state.
env.render()