In [49]:
import torch

In [50]:
num_states = 10000
states = torch.zeros((num_states, 2, 8, 8), dtype=torch.long)

In [51]:
states[:, 0, 3, 3] = 1
states[:, 1, 3, 4] = 1
states[:, 1, 4, 3] = 1
states[:, 0, 4, 4] = 1
states[:, 0, 5, 2] = 1

In [52]:
from envs.othello.torchscripts import get_legal_actions
ray_tensor = torch.zeros((num_states, 48, 8, 8), dtype=torch.long)
rays = get_legal_actions(states, ray_tensor)

In [53]:
num_rays = ray_tensor.shape[1]
states_size = states.shape[-1]

In [54]:
flips = torch.zeros((num_rays, states_size, states_size, states_size, states_size), device=states.device, requires_grad=False, dtype=torch.bool)
f_index = 0
for i in range(2, states_size):
    for x in range(states_size):
        for y in range(states_size):
            # right, down, left, up
            if x+1 < states_size:
                flips[f_index, y, x, y, x+1:min(x+i, states_size)] = True
            if y+1 < states_size:
                flips[f_index+1, y, x, y+1:min(y+i, states_size), x] = True
            flips[f_index+2, y, x, y, max(x-i, 0):x] = True
            flips[f_index+3, y, x, max(y-i, 0):y, x] = True

            # diag right down, diag left down, diag left up, diag right up
            for j in range(1, i):
                if y+j < states_size:
                    if x+j < states_size:
                        flips[f_index+4, y, x, y+j, x+j] = True
                    if x-j >= 0:
                        flips[f_index+5, y, x, y+j, x-j] = True
                if y-j >= 0:
                    if x-j >= 0:
                        flips[f_index+6, y, x, y-j, x-j] = True
                    if x+j < states_size:
                        flips[f_index+7, y, x, y-j, x+j] = True
    f_index += 8

In [55]:
action_ids = rays.view(-1, 64).long().argmax(dim=1)

In [56]:
actions = torch.zeros((num_states, 64), dtype=torch.bool)
actions[torch.arange(num_states), rays.view(-1, 64).long().argmax(dim=1)] = True
actions = actions.view(num_states, 8, 8)

In [57]:
actions.shape

torch.Size([10000, 8, 8])

In [58]:
action_xs = action_ids % 8
action_ys = action_ids // 8

In [59]:
action_xs

tensor([4, 4, 4,  ..., 4, 4, 4])

In [60]:
action_ys

tensor([2, 2, 2,  ..., 2, 2, 2])

In [61]:
activated_rays = ray_tensor[torch.arange(num_states), :, action_ys, action_xs]

In [73]:
action_ys.unsqueeze(1).repeat(1, 48).shape

torch.Size([10000, 48])

In [69]:
activated_rays.shape

torch.Size([10000, 48])

In [84]:
flips_to_apply = flips[activated_rays, action_ys.unsqueeze(1).repeat(1, 48), action_xs.unsqueeze(1).repeat(1, 48)].amax(dim=1)

In [88]:
(states[:, 0, :, :] | flips_to_apply)[0]

tensor([[0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 1, 0, 0],
        [0, 0, 0, 1, 1, 0, 0, 0],
        [0, 0, 0, 0, 1, 0, 0, 0],
        [0, 0, 1, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0]])

In [89]:
(states[:, 1, :, :] & torch.logical_not(flips_to_apply))[0]

tensor([[0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0]])