In [1]:
import numpy as np
import torch

In [2]:
class Tester:
    def __init__(self):
        self.pos_map = {0: (-2, -2), 1: (-2, 2), 2: (-1, -1), 3: (-1, 1), 4: (1, -1), 5: (1, 1), 6: (2, -2), 7: (2, 2)}
        self.inv_pos_map = {v: k for k, v in self.pos_map.items()}
        
    def action_dict_to_array(self, action_dict):
        """
        Converts the legal moves from the action dictionary into an array that the computer can easily handle
        :param action_dict:
        :return:
        """
        available_actions = np.zeros(64 * 8)
        available_pieces = list(action_dict.keys())
        col_size = 8
        for i, piece in enumerate(available_pieces):
            legal_moves_for_piece = action_dict[piece]
            for j, available_action in enumerate(legal_moves_for_piece):
                unraveled_piece_position = (piece[0] * col_size + piece[1])
                movement = (available_action[0] - piece[0], available_action[1] - piece[1])
                pos = self.inv_pos_map[movement]
                available_actions[unraveled_piece_position * 8 + pos] = 1
        return available_actions

    def interpret_action(self, action):
        """
        Interprets the action (an int) and changes that to a selected piece and selected move
        :param action:
        :return:
        """
        unraveled_piece_pos = action // 8
        selected_piece = (unraveled_piece_pos // 8, unraveled_piece_pos % 8)

        unraveled_move_pos = action % 8
        i, j = self.pos_map[unraveled_move_pos]

        selected_move = (selected_piece[0] + i,  selected_piece[1] + j)

        return selected_piece, selected_move

In [3]:
t = Tester()
action_dict = {(0, 6): [], (1, 5): [(0, 4), (2, 4)], (1, 7): [], (2, 6): [], (3, 5): [(2, 4), (4, 4)], (3, 7): [], (4, 6): [], (5, 5): [(4, 4), (6, 4)], (5, 7): [], (6, 6): [], (7, 5): [(6, 4)], (7, 7): []}
action_array = t.action_dict_to_array(action_dict)
action = np.random.choice(np.where(action_array == 1)[0])
print(action_dict)
t.interpret_action(action)

{(0, 6): [], (1, 5): [(0, 4), (2, 4)], (1, 7): [], (2, 6): [], (3, 5): [(2, 4), (4, 4)], (3, 7): [], (4, 6): [], (5, 5): [(4, 4), (6, 4)], (5, 7): [], (6, 6): [], (7, 5): [(6, 4)], (7, 7): []}


((5, 5), (4, 4))

In [4]:
state_rep = [['[]', '[]', 'B-', '[]', '[]', '[]', '[]', '[]'],
 ['[]', 'R-', '[]', '[]', '[]', '[]', '[]', 'RK'],
 ['[]', '[]', 'RK', '[]', '[]', '[]', '[]', '[]'],
 ['[]', 'BK', '[]', 'B-', '[]', '[]', '[]', '[]'],
 ['[]', '[]', '[]', '[]', '[]', '[]', '[]', '[]'],
 ['[]', '[]', '[]', '[]', '[]', '[]', '[]', '[]'],
 ['[]', '[]', 'B-', '[]', '[]', '[]', '[]', '[]'],
 ['[]', 'BK', '[]', 'BK', '[]', '[]', '[]', '[]']]
state_rep

[['[]', '[]', 'B-', '[]', '[]', '[]', '[]', '[]'],
 ['[]', 'R-', '[]', '[]', '[]', '[]', '[]', 'RK'],
 ['[]', '[]', 'RK', '[]', '[]', '[]', '[]', '[]'],
 ['[]', 'BK', '[]', 'B-', '[]', '[]', '[]', '[]'],
 ['[]', '[]', '[]', '[]', '[]', '[]', '[]', '[]'],
 ['[]', '[]', '[]', '[]', '[]', '[]', '[]', '[]'],
 ['[]', '[]', 'B-', '[]', '[]', '[]', '[]', '[]'],
 ['[]', 'BK', '[]', 'BK', '[]', '[]', '[]', '[]']]

In [5]:
import torch
import time


In [20]:
# %%time
player = 1
state = torch.zeros((4*4,8,8))
for i in range(1):
#     n, r, c = len(state_reps), len(state_reps[0]), len(state_reps[0][0])

#     for k, state_rep in enumerate(state_reps):
    k, n = 0, 1
    for i in range(8):
        for j in range(8):
            if state_rep[i][j] == 'B-':
                state[0 + 2*k, i, j] = 1.
            elif state_rep[i][j] == 'BK':
                state[1 + 2*k, i, j] = 1.
            elif state_rep[i][j] == 'R-':
                state[0 + 2*(k + 4), i, j] = 1.
            elif state_rep[i][j] == 'RK':
                state[1 + 2*(k + 4), i, j] = 1.
                
    if player == 1:
        state = torch.cat([state, torch.ones((1, 8, 8))], dim=0)
    else:
        state = torch.cat([state, torch.zeros((1, 8, 8))], dim=0)

    # Pad the current state if we have less observations
#     padding = self.observations_per_state - n
#     if padding > 0:
#         state = torch.cat([state, torch.zeros(padding, 8, 8)], dim=0)
        
[print(state[i]) for i in range(17)]

tensor([[0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]])
tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 1., 0., 0., 0., 0.]])
tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.,

[None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None]

In [4]:
import pickle
def load_game(game_file):
    with open(game_file, 'rb') as handle:
        shape, indices, values = pickle.load(handle)
    loaded_game = torch.zeros(shape)
    loaded_game[indices[:, 0], indices[:, 1]] = values

    return loaded_game

game = load_game('./game_data/game_0_agent_0.pickle')


RuntimeError: expected device cpu but got device cuda:0