In [1]:
import numpy as np
import torch
# Pick the GPU when PyTorch can see one, otherwise fall back to the CPU.
# `is_available` must be *called*: the bare function object is always truthy,
# which would force 'cuda' even on CPU-only machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
class Game:
    """An n-dimensional tic-tac-toe style board shared by two players.

    The board is a ``size ** n_dim`` integer grid allocated on the
    notebook-level ``device``. ``score`` credits player 1 for every line whose
    cells sum to ``size`` and player 2 for every line whose cells sum to
    ``-size``. This presumably assumes cells hold 1 (player 1), -1 (player 2)
    or 0 (empty) — TODO confirm against the move logic. Under any other
    encoding, a matching sum would not imply a fully-owned line.
    """

    def __init__(self, player1, player2, size=3, n_dim=2):
        """Create an empty (all-zero) board.

        Args:
            player1: First player object (stored; not otherwise used in this class).
            player2: Second player object (stored; not otherwise used in this class).
            size: Number of cells along each axis.
            n_dim: Number of board dimensions; must be an int >= 2.

        Raises:
            AssertionError: If ``n_dim`` is not an int >= 2.
        """
        assert type(n_dim) is int and n_dim >= 2, "wrong n_dim"
        self.n_dim = n_dim
        self.size = size
        self.player1 = player1
        self.player2 = player2
        # Allocate directly on the target device rather than building the
        # tensor on the CPU and copying it over with .to(device).
        self.board = torch.zeros([size] * n_dim, dtype=torch.long, device=device)

    def score(self):
        """Count completed axis-parallel lines for each player.

        A line runs along a single axis with every other coordinate fixed, so
        a board of side ``size`` in ``n_dim`` dimensions has
        ``n_dim * size ** (n_dim - 1)`` such lines. Diagonals are not counted.

        Returns:
            tuple[int, int]: ``(score_p1, score_p2)``, the number of lines
            whose cells sum to ``size`` and to ``-size`` respectively.
        """
        score_p1 = 0
        score_p2 = 0
        for d in range(self.n_dim):
            # Reducing along axis d produces exactly one sum per line parallel
            # to d. This avoids building a full-board boolean mask per line.
            line_sums = self.board.sum(dim=d)
            score_p1 += int((line_sums == self.size).sum().item())
            score_p2 += int((line_sums == -self.size).sum().item())
        return score_p1, score_p2

In [3]:
# Demo on a 3x3x3 board: first fill the middle slice of the last axis with 1,
# then overwrite the last slice of the second axis with -1. The two writes
# overlap at board[:, 2, 1], where -1 wins.
game = Game(player1=None, player2=None, size=3, n_dim=3)

game.board[..., 1] = 1
game.board[:, 2] = -1
game.board

tensor([[[ 0,  1,  0],
         [ 0,  1,  0],
         [-1, -1, -1]],

        [[ 0,  1,  0],
         [ 0,  1,  0],
         [-1, -1, -1]],

        [[ 0,  1,  0],
         [ 0,  1,  0],
         [-1, -1, -1]]], device='cuda:0')

In [4]:
# Completed lines per player, displayed as (player 1, player 2).
p1_lines, p2_lines = game.score()
p1_lines, p2_lines

(2, 6)

In [5]:
# Enumerate every coordinate of a 4-D grid with side 2 by expanding a flat
# index into base-2 digits, most significant digit first.
for flat_index in range(2**4):
    digits = [(flat_index // 2**power) % 2 for power in reversed(range(4))]
    print(digits)

[0, 0, 0, 0]
[0, 0, 0, 1]
[0, 0, 1, 0]
[0, 0, 1, 1]
[0, 1, 0, 0]
[0, 1, 0, 1]
[0, 1, 1, 0]
[0, 1, 1, 1]
[1, 0, 0, 0]
[1, 0, 0, 1]
[1, 0, 1, 0]
[1, 0, 1, 1]
[1, 1, 0, 0]
[1, 1, 0, 1]
[1, 1, 1, 0]
[1, 1, 1, 1]
