-
Notifications
You must be signed in to change notification settings - Fork 0
/
policy_jy.py
55 lines (48 loc) · 1.78 KB
/
policy_jy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# fixes import troubles
import sys, os
sys.path.append('./picomino_play')
import logging
from picomino_play import PicominoNNet, PicominoState, PicominoActions, Replay
from policy import Policy
log = logging.getLogger('policy_jy')
class PolicyJY(Policy):
    """Policy that chooses moves with a pre-trained ``PicominoNNet``.

    The network checkpoint is loaded when the policy is constructed.
    ``play`` copies the caller's state into a ``PicominoState``, asks the
    network for the best of the valid actions and converts that action
    back into the integer code returned to the caller.
    """

    # Checkpoint location, defined once so ``__init__`` and ``q`` cannot drift.
    _CHECKPOINT_DIR = 'picomino_play/temp'
    _CHECKPOINT_FILE = 'best.pth.'

    def __init__(self):
        """Create the network and load its weights from the checkpoint."""
        self.nnet = PicominoNNet()
        self.nnet.load_checkpoint(self._CHECKPOINT_DIR, self._CHECKPOINT_FILE)

    @property
    def q(self):
        """Return a ``Q`` named tuple whose ``fname`` is the checkpoint path.

        NOTE(review): presumably exists so callers can read ``policy.q.fname``
        as they do for other policies -- confirm against the ``Policy`` users.
        """
        import collections

        Q = collections.namedtuple('Q', 'fname')
        # Build an actual instance; assigning ``fname`` on the namedtuple
        # class would clobber the field descriptor and hand back a type.
        return Q(fname=self._CHECKPOINT_DIR + '/' + self._CHECKPOINT_FILE)

    @staticmethod
    def _first_to_last(values):
        """Return ``values`` as a list with its first element moved to the end."""
        # Coerce first: ``+`` only concatenates for lists (a tuple would raise,
        # an ndarray would add element-wise instead of rotating).
        values = list(values)
        return values[1:] + values[:1]

    def play(self, state):
        """Choose an action for ``state`` using the neural network.

        Args:
            state: object providing ``stash``, ``player``, ``opponent``,
                ``roll``, ``dices`` and a ``total()`` method.
                NOTE(review): ``roll`` and ``dices`` are assumed to be
                per-face sequences whose index 0 must become the last
                entry on the ``PicominoState`` side -- confirm.

        Returns:
            A 3-tuple ``(act, [], [])``. ``act`` is ``-1`` when there is no
            valid action. Otherwise it is ``action.dice_type + 1`` when that
            value is below 6 and ``0`` if not, with 6 added when the chosen
            action is not a ``Replay`` (i.e. a tile is taken or stolen
            instead of taking the dice and rerolling).
        """
        # State -> PicominoState
        ps = PicominoState()
        ps.remaining_pico = state.stash
        ps.player1_pico_stack = state.player
        ps.player2_pico_stack = state.opponent
        ps.active_player = 1
        ps.remaining_dices = sum(state.roll)
        ps.rolled_dices = self._first_to_last(state.roll)
        # Kept dice are passed on as 0/1 flags rather than counts.
        ps.kept_dices = [1 if n else 0 for n in self._first_to_last(state.dices)]
        ps.current_score = state.total()
        # These fields are always filled with the same fixed values here.
        ps.nb_turns = 1
        ps.player1_nb_lost_picos = 0
        ps.player2_nb_lost_picos = 0

        # Choose the action.
        possible_action_ids = ps.getValidActions()
        log.debug('possible_action_ids: %s', possible_action_ids)
        if not possible_action_ids:
            return -1, [], []
        best_action_id = self.nnet.predict_best_action_id(ps, possible_action_ids)
        action = PicominoActions.action_indexes[best_action_id]

        # Convert the action back to the caller's integer encoding.
        dice = action.dice_type + 1
        # Take dice and reroll.
        act = dice if dice < 6 else 0
        if not isinstance(action, Replay):
            # Take or steal a tile.
            act += 6
        return act, [], []