In [1]:
import gym
import copy
import numpy as np
import torch
import sys 
sys.path.append("..")
import tools
from tools import get_args, registration_envs,load_policy
from model import DRL_GAT
# Call the environment-registration helper imported from `tools` once, before any
# environment is created.
# NOTE(review): presumably this registers the project's custom environments with gym
# so that a later gym.make(...) can resolve their ids — confirm in tools.
registration_envs()

In [19]:
class usage:
    """Drive a pretrained policy through a gym environment, one box per call.

    Construction reads the run configuration via ``get_args()``, builds the
    environment with ``gym.make``, loads the checkpoint
    ``'../pretrained_models/PCT_setting1.pt'`` into a ``DRL_GAT`` model, switches
    the model to eval mode and resets the environment.

    Attributes:
        action: ``dt['box_action']`` of the most recent step, or ``None`` right
            after a reset.
        obs: the latest observation obtained from the environment.
        env: the environment created by ``gym.make``.
        device: the torch device selected from ``args.no_cuda`` / ``args.device``.
        PCT_policy: the loaded policy network.
    """

    def __init__(self):
        self.action = None
        args = get_args()
        self.num_processes = args.num_processes
        self.internal_node_holder = args.internal_node_holder
        self.leaf_node_holder = args.leaf_node_holder

        # Keyword arguments forwarded unchanged to the environment constructor.
        env_kwargs = dict(
            setting=args.setting,
            item_set=args.item_size_set,
            container_size=args.container_size,
            data_name=args.dataset_path,
            load_test_data=args.load_dataset,
            internal_node_holder=args.internal_node_holder,
            leaf_node_holder=args.leaf_node_holder,
            LNES=args.lnes,
            shuffle=args.shuffle,
            sample_from_distribution=args.sample_from_distribution,
            sample_left_bound=args.sample_left_bound,
            sample_right_bound=args.sample_right_bound,
        )
        self.env = gym.make(args.id, **env_kwargs)

        use_cuda = not args.no_cuda
        if use_cuda:
            self.device = torch.device('cuda', args.device)
            torch.cuda.set_device(self.device)
        else:
            self.device = torch.device('cpu')

        # The model is not moved to self.device here; only the checkpoint is loaded.
        policy = DRL_GAT(args)
        self.PCT_policy = policy
        self.PCT_policy = load_policy('../pretrained_models/PCT_setting1.pt', self.PCT_policy)
        self.PCT_policy.eval()
        self.reset()

    def reset(self):
        """Clear the remembered action, reset the environment and call
        ``box_creator.preview(1000)`` on it."""
        self.action = None
        self.obs = self.env.reset()
        self.env.box_creator.preview(1000)

    def go(self, box):
        """Feed ``box`` to the environment and let the policy act on it.

        ``box`` overwrites the first entry of the environment's box list before
        the observation is recomputed.

        Returns:
            ``(False, action, None)`` while the episode continues, where
            ``action`` is ``dt['box_action']`` from the step.
            ``(True, None, [ratio, counter, reward])`` when the step ended the
            episode: the environment is reset first, which clears
            ``self.action``, so the terminal action is reported as ``None``.
        """
        self.env.box_creator.box_list[0] = box
        self.obs = self.env.cur_observation()

        batched_obs = torch.FloatTensor(self.obs).unsqueeze(dim=0)
        all_nodes, leaf_nodes = tools.get_leaf_nodes_with_factor(
            batched_obs,
            self.num_processes,
            self.internal_node_holder,
            self.leaf_node_holder,
        )

        with torch.no_grad():
            _, selectedIdx, _, _ = self.PCT_policy(all_nodes)

        batch_index = torch.arange(self.num_processes)
        picked_leaf = leaf_nodes[batch_index, selectedIdx.squeeze()]
        # Only the first six entries of the first selected leaf node go to the env.
        act_node = picked_leaf.cpu().numpy()[0][0:6]

        _, _, done, info = self.env.step(act_node, 'testMCTS')
        self.action = info['box_action']
        print(self.action)

        if not done:
            return False, self.action, None

        # Reset before building the return value, exactly as the caller expects:
        # self.action is None from here on.
        self.reset()
        return True, self.action, [info['ratio'], info['counter'], info['reward']]

# Generate the box sequence

In [3]:
import pickle

# Build 10000 random boxes. np.random.randint's upper bound is exclusive, so every
# edge length is an integer in 1..5. Each box is drawn with three scalar calls so the
# stored values stay plain Python ints (an array draw would pickle numpy integers).
# No seed is set here, so every run of this cell produces a different sequence.
boxes = [
    tuple(np.random.randint(1, 6) for _ in range(3))
    for _ in range(10000)
]

# Persist the sequence so the experiment can be replayed on the same boxes.
# The `with` block closes the file on exit; no explicit close() is needed.
with open('box_seq.plk', 'wb') as f:
    pickle.dump(boxes, f)

In [9]:
import pickle

# Load the box sequence written to 'box_seq.plk'.
# SECURITY: pickle.load can execute arbitrary code while unpickling; only load a
# file that this notebook (or another trusted source) produced.
# `boxes` is deliberately not pre-initialised to []: the placeholder was overwritten
# on success and, on failure, left an empty sequence behind in the kernel.
with open('box_seq.plk', 'rb') as f:
    boxes = pickle.load(f)

# Experiment

In [None]:
import csv
import os
import time

# Stop after this many finished episodes.
MAX_EPISODES = 30
# Directory that receives one CSV of actions per finished episode.
ACTIONS_DIR = 'actions'
# Create the output directory up front; without it the first finished episode
# fails with FileNotFoundError on a fresh checkout.
os.makedirs(ACTIONS_DIR, exist_ok=True)

agent = usage()
results = []  # one [ratio, counter, reward] entry per finished episode
actions = []  # actions of the episode currently in progress

for box in boxes:
    is_done, action, result = agent.go(box)

    # agent.go reports None as the action on the step that ends an episode.
    # Identity comparison is used because `!= None` is evaluated element-wise
    # (and is ambiguous in an `if`) should the action ever be an array.
    if action is not None:
        actions.append(action)

    if not is_done:
        continue

    # Name the file after the wall-clock time. The stamp has one-second resolution,
    # so add a numeric suffix when an episode finished within the same second
    # instead of silently overwriting its file.
    stamp = time.strftime('%Y.%m.%d-%H-%M-%S')
    csv_path = os.path.join(ACTIONS_DIR, f'{stamp}.csv')
    duplicate = 1
    while os.path.exists(csv_path):
        csv_path = os.path.join(ACTIONS_DIR, f'{stamp}_{duplicate}.csv')
        duplicate += 1

    # newline='' is required by the csv module; without it extra blank rows
    # appear on platforms that translate '\n' to '\r\n'.
    with open(csv_path, 'w', newline='') as f:
        csv.writer(f).writerows(actions)
    actions = []

    print(result)
    results.append(result)
    if len(results) >= MAX_EPISODES:
        break



../pretrained_models/PCT_setting1.pt
Loading pre-train upper model ../pretrained_models/PCT_setting1.pt
tensor([[3]])
[ 0.  5.  0.  4. 10. 10.]
(4, 5, 5, 0, 5, 0)
tensor([[2]])
[ 9.  7.  0. 10. 10. 10.]
(1, 3, 4, 9, 7, 0)
tensor([[17]])
[ 8.  7.  0.  9. 10. 10.]
(1, 3, 4, 8, 7, 0)
tensor([[10]])
[ 7.  5.  0.  8. 10. 10.]
(1, 5, 4, 7, 5, 0)
tensor([[9]])
[ 4.  5.  0.  5. 10. 10.]
(1, 5, 5, 4, 5, 0)
tensor([[17]])
[ 6.  5.  0.  7. 10. 10.]
(1, 5, 3, 6, 5, 0)
tensor([[25]])
[ 6.  0.  0. 10.  5. 10.]
(4, 5, 4, 6, 0, 0)
tensor([[18]])
[ 7.  5.  4. 10. 10. 10.]
(3, 5, 5, 7, 5, 4)
tensor([[32]])
[ 5.  7.  0.  6. 10. 10.]
(1, 3, 1, 5, 7, 0)
tensor([[39]])
[ 5.  6.  3.  7. 10. 10.]
(2, 4, 3, 5, 6, 3)
tensor([[33]])
[ 7.  7.  9. 10. 10. 10.]
(3, 3, 1, 7, 7, 9)
tensor([[0]])
[ 5.  4.  0.  6.  6. 10.]
(1, 2, 5, 5, 4, 0)
tensor([[34]])
[ 0.  6.  5.  3. 10. 10.]
(3, 4, 4, 0, 6, 5)
tensor([[44]])
[ 0.  0.  0.  4.  5. 10.]
(4, 5, 3, 0, 0, 0)
tensor([[21]])
[ 7.  0.  4. 10.  5. 10.]
(3, 5, 5, 7, 0, 4)


tensor([[27]])
[ 5.  6.  4.  7.  9. 10.]
(2, 3, 2, 5, 6, 4)
tensor([[25]])
[ 4.  6.  7.  8. 10. 10.]
(4, 4, 3, 4, 6, 7)
tensor([[10]])
[ 3.  6.  4.  4. 10. 10.]
(1, 4, 3, 3, 6, 4)
tensor([[46]])
[ 5.  3.  0.  6.  5. 10.]
(1, 2, 5, 5, 3, 0)
tensor([[27]])
[ 7.  0.  4. 10.  5. 10.]
(3, 5, 4, 7, 0, 4)
tensor([[3]])
[ 4.  0.  5.  7.  5. 10.]
(3, 5, 4, 4, 0, 5)
tensor([[0]])
[ 2.  0.  5.  4.  4. 10.]
(2, 4, 5, 2, 0, 5)
tensor([[2]])
[ 7.  3.  8. 10.  6. 10.]
(3, 3, 2, 7, 3, 8)
tensor([[48]])
[0. 0. 0. 0. 0. 0.]
(3, 5, 5, 0, 0, 10)
[0.777, 26, 7.7700000000000005]
tensor([[19]])
[ 0.  8.  0.  1. 10. 10.]
(1, 2, 3, 0, 8, 0)
tensor([[13]])
[ 0.  0.  0.  2.  3. 10.]
(2, 3, 1, 0, 0, 0)
tensor([[40]])
[ 8.  0.  0. 10.  1. 10.]
(2, 1, 2, 8, 0, 0)
tensor([[0]])
[ 1.  5.  0.  4. 10. 10.]
(3, 5, 3, 1, 5, 0)
tensor([[8]])
[ 7.  7.  0. 10. 10. 10.]
(3, 3, 3, 7, 7, 0)
tensor([[1]])
[ 6.  5.  0.  7. 10. 10.]
(1, 5, 2, 6, 5, 0)
tensor([[36]])
[ 9.  6.  0. 10.  7. 10.]
(1, 1, 1, 9, 6, 0)
tensor([[46]])
[ 8.

tensor([[40]])
[ 5.  6.  1.  6.  7. 10.]
(1, 1, 2, 5, 6, 1)
tensor([[12]])
[ 5.  6.  3.  6.  7. 10.]
(1, 1, 3, 5, 6, 3)
tensor([[32]])
[ 9.  3.  4. 10.  5. 10.]
(1, 2, 2, 9, 3, 4)
tensor([[15]])
[ 7.  2.  6. 10.  5. 10.]
(3, 3, 3, 7, 2, 6)
tensor([[1]])
[ 7.  0.  4. 10.  2. 10.]
(3, 2, 5, 7, 0, 4)
tensor([[17]])
[ 4.  5.  0.  5.  6. 10.]
(1, 1, 2, 4, 5, 0)
tensor([[0]])
[ 0.  1.  3.  3.  6. 10.]
(3, 5, 5, 0, 1, 3)
tensor([[17]])
[ 6.  5.  4.  8.  6. 10.]
(2, 1, 5, 6, 5, 4)
tensor([[8]])
[ 4.  0.  5.  7.  4. 10.]
(3, 4, 3, 4, 0, 5)
tensor([[12]])
[ 1.  5.  9.  4. 10. 10.]
(3, 5, 1, 1, 5, 9)
tensor([[26]])
[ 4.  5.  2.  5.  6. 10.]
(1, 1, 2, 4, 5, 2)
tensor([[31]])
[ 8.  6.  9.  9. 10. 10.]
(1, 4, 1, 8, 6, 9)
tensor([[6]])
[ 5.  0.  9. 10.  5. 10.]
(5, 5, 1, 5, 0, 9)
tensor([[2]])
[ 3.  0.  0.  4.  1. 10.]
(1, 1, 4, 3, 0, 0)
tensor([[0]])
[ 2.  0.  8.  5.  5. 10.]
(3, 5, 2, 2, 0, 8)
tensor([[35]])
[0. 0. 0. 0. 0. 0.]
(5, 5, 5, 0, 0, 10)
[0.814, 42, 8.139999999999999]
tensor([[1]])
[ 0.  

tensor([[11]])
[ 7.  0.  4. 10.  5. 10.]
(3, 5, 4, 7, 0, 4)
tensor([[23]])
[ 4.  6.  3.  5. 10. 10.]
(1, 4, 4, 4, 6, 3)
tensor([[7]])
[ 6.  0.  4.  7.  5. 10.]
(1, 5, 4, 6, 0, 4)
tensor([[21]])
[ 0.  6.  7.  1. 10. 10.]
(1, 4, 3, 0, 6, 7)
tensor([[10]])
[ 6.  5.  8.  7. 10. 10.]
(1, 5, 2, 6, 5, 8)
tensor([[5]])
[ 0.  0.  5.  3.  5. 10.]
(3, 5, 3, 0, 0, 5)
tensor([[46]])
[0. 0. 0. 0. 0. 0.]
(5, 4, 5, 0, 0, 8)
[0.747, 26, 7.47]
tensor([[8]])
[ 0.  6.  0.  1. 10. 10.]
(1, 4, 2, 0, 6, 0)
tensor([[7]])
[ 9.  9.  0. 10. 10. 10.]
(1, 1, 4, 9, 9, 0)
tensor([[7]])
[ 7.  6.  0.  9. 10. 10.]
(2, 4, 2, 7, 6, 0)
tensor([[30]])
[ 5.  6.  0.  7. 10. 10.]
(2, 4, 2, 5, 6, 0)
tensor([[39]])
[ 9.  7.  0. 10.  9. 10.]
(1, 2, 4, 9, 7, 0)
tensor([[13]])
[ 4.  6.  0.  5. 10. 10.]
(1, 4, 4, 4, 6, 0)
tensor([[39]])
[ 3.  5.  0.  4. 10. 10.]
(1, 5, 1, 3, 5, 0)
tensor([[48]])
[ 1.  5.  0.  2. 10. 10.]
(1, 5, 5, 1, 5, 0)
tensor([[43]])
[ 2.  6.  0.  3. 10. 10.]
(1, 4, 5, 2, 6, 0)
tensor([[10]])
[ 9.  6.  0. 10.  

tensor([[18]])
[ 7.  6.  7.  9. 10. 10.]
(2, 4, 3, 7, 6, 7)
tensor([[31]])
[ 6.  5.  5.  7. 10. 10.]
(1, 5, 4, 6, 5, 5)
tensor([[25]])
[ 9.  3.  4. 10.  5. 10.]
(1, 2, 3, 9, 3, 4)
tensor([[17]])
[ 0.  5.  4.  3. 10. 10.]
(3, 5, 5, 0, 5, 4)
tensor([[39]])
[ 3.  6.  4.  4. 10. 10.]
(1, 4, 5, 3, 6, 4)
tensor([[3]])
[ 4.  6.  5.  6. 10. 10.]
(2, 4, 2, 4, 6, 5)
tensor([[14]])
[ 0.  0.  0.  4.  5. 10.]
(4, 5, 4, 0, 0, 0)
tensor([[7]])
[ 4.  0.  0.  7.  5. 10.]
(3, 5, 3, 4, 0, 0)
tensor([[40]])
[ 5.  8.  7.  6. 10. 10.]
(1, 2, 2, 5, 8, 7)
tensor([[26]])
[ 6.  0.  4.  9.  4. 10.]
(3, 4, 5, 6, 0, 4)
tensor([[24]])
[ 4.  0.  3.  6.  5. 10.]
(2, 5, 4, 4, 0, 3)
tensor([[11]])
[ 4.  3.  7.  6.  8. 10.]
(2, 5, 3, 4, 3, 7)
tensor([[17]])
[ 9.  0.  7. 10.  4. 10.]
(1, 4, 3, 9, 0, 7)
tensor([[7]])
[ 7.  4.  5.  9.  6. 10.]
(2, 2, 5, 7, 4, 5)
tensor([[33]])
[ 5.  8.  9.  7. 10. 10.]
(2, 2, 1, 5, 8, 9)
tensor([[2]])
[ 4.  0.  7.  6.  3. 10.]
(2, 3, 3, 4, 0, 7)
tensor([[5]])
[ 0.  0.  4.  3.  5. 10.]
(3, 