-
Notifications
You must be signed in to change notification settings - Fork 1
/
submission.py
73 lines (64 loc) · 2.68 KB
/
submission.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# -*- coding:utf-8 -*-
# Time : 2021/5/31 4:14 PM
# Author: Yahui Cui
"""
# =================================== Important =========================================
Notes:
1. This is a random agent; it works with any environment on the Jidi platform.
2. If you want to load a .pth file, please follow the instructions here:
https://github.com/jidiai/ai_lib/blob/master/examples/demo
"""
def my_controller(observation, action_space, is_act_continuous=False):
    """Return one randomly sampled action per sub-space of ``action_space``.

    Args:
        observation: accepted to match the Jidi controller signature; it is
            not read, since this agent acts randomly.
        action_space: sequence of space objects, one per action dimension.
        is_act_continuous: forwarded to ``sample_single_dim`` to choose between
            raw (continuous) samples and one-hot (discrete) encodings.

    Returns:
        A list with one ``sample_single_dim`` result per element of
        ``action_space``, in the same order.
    """
    return [
        sample_single_dim(space, is_act_continuous)
        for space in action_space
    ]
def sample_single_dim(action_space_list_each, is_act_continuous):
    """Draw one random action from a single action space.

    Args:
        action_space_list_each: a space object exposing ``sample()``. In the
            discrete case its class name selects the encoding:
            ``"Discrete"`` (must expose ``n``) or ``"MultiDiscreteParticle"``
            (must expose per-dimension ``low`` / ``high`` sequences).
        is_act_continuous: when true, the raw ``sample()`` result is returned
            unchanged.

    Returns:
        * continuous: whatever ``sample()`` returns;
        * ``Discrete``: a one-hot list of length ``n``;
        * ``MultiDiscreteParticle``: the concatenation, over dimensions, of a
          one-hot segment of length ``high - low + 1`` per dimension;
        * any other discrete space type: an empty list.

    Raises:
        IndexError: if a ``MultiDiscreteParticle`` sample lies outside its
            dimension's ``[low, high]`` range.
    """
    if is_act_continuous:
        return action_space_list_each.sample()

    kind = action_space_list_each.__class__.__name__

    if kind == "Discrete":
        each = [0] * action_space_list_each.n
        idx = action_space_list_each.sample()
        each[idx] = 1
        return each

    if kind == "MultiDiscreteParticle":
        each = []
        sample_indexes = action_space_list_each.sample()
        # zip works for plain lists as well as numpy arrays (the previous
        # ``high - low + 1`` required array arithmetic).
        for low, high, value in zip(
            action_space_list_each.low,
            action_space_list_each.high,
            sample_indexes,
        ):
            width = int(high) - int(low) + 1
            # The segment covers values low..high, so a sampled value maps to
            # position ``value - low``. Using the raw value overflowed the
            # segment whenever low > 0. Assumes sample() yields values in
            # [low, high] per dimension.
            position = int(value) - int(low)
            if not 0 <= position < width:
                # Fail loudly rather than let a negative index wrap around.
                raise IndexError(
                    f"sampled value {value} outside [{low}, {high}]"
                )
            new_action = [0] * width
            new_action[position] = 1
            each.extend(new_action)
        return each

    # Unknown discrete space type: keep the best-effort behaviour of returning
    # an empty action rather than raising.
    return []
def sample(action_space_list_each, is_act_continuous):
    """Draw one random action for every space in ``action_space_list_each``.

    Args:
        action_space_list_each: sequence of space objects exposing ``sample()``.
        is_act_continuous: when true, each space's raw ``sample()`` result is
            collected unchanged.

    Returns:
        A list of per-space actions. In the discrete case:
        * ``"Discrete"`` spaces give a one-hot list of length ``n``;
        * ``"MultiDiscreteParticle"`` spaces give concatenated one-hot segments
          of length ``high + 1`` per dimension, indexed by the raw sampled
          value;
        * spaces of any other class are skipped, so the result can be shorter
          than the input.

    NOTE(review): the ``high + 1`` segment width differs from the
    ``high - low + 1`` width used by ``sample_single_dim``; the two agree only
    when every ``low`` is 0.
    """
    if is_act_continuous:
        return [space.sample() for space in action_space_list_each]

    player = []
    for space in action_space_list_each:
        kind = space.__class__.__name__
        if kind == "Discrete":
            one_hot = [0] * space.n
            one_hot[space.sample()] = 1
            player.append(one_hot)
        elif kind == "MultiDiscreteParticle":
            flat = []
            highs = space.high
            picks = space.sample()
            for pos, top in enumerate(highs):
                segment = [0] * (top + 1)
                segment[picks[pos]] = 1
                flat.extend(segment)
            player.append(flat)
    return player