In [19]:
import torch
num_boards = 1000000

In [20]:
import torch.nn.functional as F

In [21]:
# generate starting boards


random_floats = torch.randn((num_boards, 16), dtype=torch.float32)
boards = torch.zeros((num_boards, 16), dtype=torch.float32)
starting_boards = random_floats.topk(2, dim=1).indices
boards[(range(num_boards), starting_boards[:,0])] = 1
boards[(range(num_boards), starting_boards[:,1])] = 1

In [22]:
boards = boards.reshape((num_boards, 4, 4))

In [23]:

boards[0][0] = torch.tensor([2,2,1,1])
boards[0]

tensor([[2., 2., 1., 1.],
        [0., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.]])

In [24]:
moves = torch.randint(0, 4, (num_boards,))


rotations_0 = boards
rotations_1 = torch.rot90(boards, 1, (1,2))
rotations_2 = torch.rot90(boards, 2, (1,2))
rotations_3 = torch.rot90(boards, 3, (1,2))


In [25]:
moves[0]

tensor(0)

In [26]:
mask_0 = (moves == 0)
mask_90 = (moves == 1)
mask_180 = (moves == 2)
mask_270 = (moves == 3)


In [27]:
result = torch.zeros_like(boards)
result[mask_0] = rotations_0[mask_0]
result[mask_90] = rotations_1[mask_90]
result[mask_180] = rotations_2[mask_180]
result[mask_270] = rotations_3[mask_270]


boards = result

In [28]:
boards[0]

tensor([[2., 2., 1., 1.],
        [0., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.]])

In [29]:

shape = boards.shape
boards_flat = boards.view(-1, shape[-1])
mask = (boards_flat != 0).float()
_, sorted_indices = torch.sort(mask, dim=1, descending=True)
boards_flat = torch.gather(boards_flat, 1, sorted_indices)


In [30]:
boards_flat[:4]

tensor([[2., 2., 1., 1.],
        [0., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.]])

In [31]:
# merge similar tiles from left to right
for i in range(3):
    is_same = torch.logical_and(boards_flat[:,i] == boards_flat[:,i+1], boards_flat[:,i] != 0).float()
    boards_flat[:,i] += is_same
    boards_flat[:,i+1] *= (1 - is_same)

In [32]:

mask = (boards_flat != 0).float()
_, sorted_indices = torch.sort(mask, dim=1, descending=True)
boards_flat = torch.gather(boards_flat, 1, sorted_indices)

In [33]:
boards = boards_flat.view(shape)

In [34]:
moves = (4 - moves) % 4
rotations_0 = boards
rotations_1 = torch.rot90(boards, 1, (1,2))
rotations_2 = torch.rot90(boards, 2, (1,2))
rotations_3 = torch.rot90(boards, 3, (1,2))
mask_0 = (moves == 0)
mask_90 = (moves == 1)
mask_180 = (moves == 2)
mask_270 = (moves == 3)
result = torch.zeros_like(boards)
result[mask_0] = rotations_0[mask_0]
result[mask_90] = rotations_1[mask_90]
result[mask_180] = rotations_2[mask_180]
result[mask_270] = rotations_3[mask_270]

boards = result


In [35]:
boards, moves

(tensor([[[3., 2., 0., 0.],
          [0., 0., 0., 0.],
          [1., 0., 0., 0.],
          [1., 0., 0., 0.]],
 
         [[0., 0., 1., 1.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],
 
         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 2.],
          [0., 0., 0., 0.]],
 
         ...,
 
         [[1., 0., 1., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],
 
         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [2., 0., 0., 0.]],
 
         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [2., 0., 0., 0.]]]),
 tensor([0, 3, 2,  ..., 3, 0, 0]))