In [1]:
from mcts_model import AlphaZeroNN, ReplayBuffer, MCTS_Policy, Trainer, generate_lbfs_start_states, CONFIG, NN_Policy
from mcts_worker import run_episode_worker
from mcts_NetworkClasses import ExtendedSixClassNetwork, LBFSPolicy

In [13]:
lbfs_start_states = generate_lbfs_start_states(num_states=5000, warmup=100000, separation=2500)

--- Generating 5000 Warm Start States using LBFS ---
Warming up for 100000 time units...
  Captured 50/5000 states...
  Captured 100/5000 states...
  Captured 150/5000 states...
  Captured 200/5000 states...
  Captured 250/5000 states...
  Captured 300/5000 states...
  Captured 350/5000 states...
  Captured 400/5000 states...
  Captured 450/5000 states...
  Captured 500/5000 states...
  Captured 550/5000 states...
  Captured 600/5000 states...
  Captured 650/5000 states...
  Captured 700/5000 states...
  Captured 750/5000 states...
  Captured 800/5000 states...
  Captured 850/5000 states...
  Captured 900/5000 states...
  Captured 950/5000 states...
  Captured 1000/5000 states...
  Captured 1050/5000 states...
  Captured 1100/5000 states...
  Captured 1150/5000 states...
  Captured 1200/5000 states...
  Captured 1250/5000 states...
  Captured 1300/5000 states...
  Captured 1350/5000 states...
  Captured 1400/5000 states...
  Captured 1450/5000 states...
  Captured 1500/5000 states...
 

In [2]:
import torch

# Define the dimensions exactly as they were during training
# (For your L=2 network, state_size=10 and action_space_size=7)
state_size = CONFIG["MAX_QUEUES_STATE"] + CONFIG["L"]

# A safe way to get the action size without guessing is to init a dummy policy
temp_policy = MCTS_Policy(None, CONFIG) 
action_space_size = len(temp_policy.master_action_list)

# 1. Create a fresh, empty model (random weights)
loaded_model = AlphaZeroNN(state_size, action_space_size)

# 2. Load the dictionary of numbers from the file
path_to_file = "SUCCESS.pth" # Change to your filename
weights = torch.load(path_to_file, map_location=CONFIG["device"])

# 3. Pour the weights into the model
loaded_model.load_state_dict(weights)

# 4. Set to Evaluation Mode (Important for Inference)
loaded_model.eval()
loaded_model.to(CONFIG["device"])

print("Model loaded successfully!")

Model loaded successfully!


  weights = torch.load(path_to_file, map_location=CONFIG["device"])


In [15]:
trainer = Trainer(CONFIG, model = loaded_model, start_states = lbfs_start_states)
trainer.run_training_loop_parallel(num_workers=4)


--- PARALLEL TRAINING: 100 loops x 35 eps ---

--- LOOP 1/100 ---
  > Episode 35/35 fin. (Score: -0.676, Size: 25.95)
  Generated 60103 samples.
  Avg Score: -0.5011
  Avg Sys Size: 20.0557
  Avg Episode Time: 166.65s
  Buffer Size: 60103
  Avg Loss: 1.1121

--- LOOP 2/100 ---
  > Episode 35/35 fin. (Score: -0.680, Size: 26.17)
  Generated 59758 samples.
  Avg Score: -0.5449
  Avg Sys Size: 22.4533
  Avg Episode Time: 189.15s
  Buffer Size: 100000
  Avg Loss: 1.1948

--- LOOP 3/100 ---
  > Episode 35/35 fin. (Score: -0.555, Size: 20.25)
  Generated 59198 samples.
  Avg Score: -0.4883
  Avg Sys Size: 18.8162
  Avg Episode Time: 170.71s
  Buffer Size: 100000
  Avg Loss: 1.2881

--- LOOP 4/100 ---
  > Episode 35/35 fin. (Score: -0.500, Size: 18.08)
  Generated 59922 samples.
  Avg Score: -0.5284
  Avg Sys Size: 20.5934
  Avg Episode Time: 183.55s
  Buffer Size: 100000
  Avg Loss: 1.2443

--- LOOP 5/100 ---
  > Episode 35/35 fin. (Score: -0.882, Size: 39.47)
  Generated 58995 samples.
  A

trainer.save_model(filepath = "SUCCESS.pth")

In [7]:
# Policy Evaluation (NN Only)
import random
policy = NN_Policy(loaded_model)
net = ExtendedSixClassNetwork(policy=policy, L=2, seed=random.randint(100000, 999999))
net.run_and_get_batch_means_stats(
        warmup_time= 20000,
        num_batches= 500,
        batch_duration= 300000,
        include_service = True) 

Running warmup for 20000 time units...
Warmup complete. Starting batch means measurement...
Measurement complete.


{'mean_jobs_in_system': 16.57973327695953,
 'ci_half_width': 0.1372827426719885,
 'reported_mean': 16.57973327695953,
 'reported_ci_half_width': 0.1372827426719885,
 'num_batches': 500}