In [66]:
import torch
num_boards = 100000

In [67]:
import torch.nn.functional as F

In [68]:
# generate starting boards
random_floats = torch.randn((num_boards, 16), dtype=torch.float32)
boards = torch.zeros((num_boards, 16), dtype=torch.int16)
starting_boards = random_floats.topk(2, dim=1).indices
boards[(range(num_boards), starting_boards[:,0])] = 1
boards[(range(num_boards), starting_boards[:,1])] = 1
boards = boards.reshape((num_boards, 1, 4, 4))

In [69]:
def get_legal_moves(bs, very_negative_value=-100):
    num_boards = bs.shape[0]
    # check for empty spaces
    mask0 = torch.tensor([[[[very_negative_value, 1]]]], dtype=torch.int16)
    mask1 = torch.tensor([[[[1, very_negative_value]]]], dtype=torch.int16)
    mask2 = torch.tensor([[[[1], [very_negative_value]]]], dtype=torch.int16)
    mask3 = torch.tensor([[[[very_negative_value], [1]]]], dtype=torch.int16)
    
    m0 = torch.nn.functional.conv2d(bs, mask0, padding=0, bias=None).view(num_boards, 12)
    m1 = torch.nn.functional.conv2d(bs, mask1, padding=0, bias=None).view(num_boards, 12)
    m2 = torch.nn.functional.conv2d(bs, mask2, padding=0, bias=None).view(num_boards, 12)
    m3 = torch.nn.functional.conv2d(bs, mask3, padding=0, bias=None).view(num_boards, 12)

    m0_valid = torch.any(m0 > 0, dim=1, keepdim=True)
    m1_valid = torch.any(m1 > 0, dim=1, keepdim=True)
    m2_valid = torch.any(m2 > 0, dim=1, keepdim=True)
    m3_valid = torch.any(m3 > 0, dim=1, keepdim=True)

    # check for matching tiles
    vertical_comparison = torch.any((torch.logical_and(bs[:,:,:-1,:] == bs[:,:,1:,:], bs[:,:,1:,:] != 0)).view(num_boards, 12), dim=1, keepdim=True)
    horizontal_comparison = torch.any((torch.logical_and(bs[:,:,:,:-1] == bs[:,:,:,1:], bs[:,:,:,1:] != 0)).view(num_boards, 12), dim=1, keepdim=True)

    m0_valid = torch.logical_or(m0_valid, horizontal_comparison)
    m1_valid = torch.logical_or(m1_valid, horizontal_comparison)

    m2_valid = torch.logical_or(m2_valid, vertical_comparison)
    m3_valid = torch.logical_or(m3_valid, vertical_comparison)

    return torch.concat([m0_valid, m1_valid, m2_valid, m3_valid], dim=1)

In [91]:
ones = torch.eye(16, dtype=torch.int16).view(16, 4, 4)
twos = torch.eye(16, dtype=torch.int16).view(16, 4, 4) * 2
base_progressions = torch.concat([ones, twos], dim=0)
probabilities = torch.concat([torch.full((16,), 0.9), torch.full((16,), 0.1)], dim=0)

def get_progressions(bs):
    num_boards = bs.shape[0]
    # check and see if each of the progressions are valid (no tile already in that spot)
    # base_progressions is a 32x4x4 tensor with all the possible progressions
    # bs is an Nx4x4 tensor with N board states
    # returns an 32xNx4x4 tensor with 32 possible progressions for each board state
    valid_progressions = torch.logical_not(torch.any((bs * base_progressions).view(num_boards, 32, 16), dim=2)).view(num_boards, 32, 1, 1)
    progressions = (bs + base_progressions) * valid_progressions
    probs = probabilities * valid_progressions.view(num_boards, 32)
    return progressions, probs

In [101]:
def spawn_tile(bs):
    progs, probs = get_progressions(bs)
    probs += torch.where(probs.amax(dim=1, keepdim=True) == 0, torch.full_like(probs, 1), torch.full_like(probs, 0))
    indices = torch.multinomial(probs, 1)
    bs = progs[(range(bs.shape[0]), indices[:,0])]
    return bs
    

In [77]:
def rotate_in_place(bs, rot_amnts):
    rotations_0 = bs
    rotations_1 = torch.rot90(bs, 1, (2,3))
    rotations_2 = torch.rot90(bs, 2, (2,3))
    rotations_3 = torch.rot90(bs, 3, (2,3))
    mask_0 = (rot_amnts == 0)
    mask_90 = (rot_amnts == 1)
    mask_180 = (rot_amnts == 2)
    mask_270 = (rot_amnts == 3)    
    bs[mask_0] = rotations_0[mask_0]
    bs[mask_90] = rotations_1[mask_90]
    bs[mask_180] = rotations_2[mask_180]
    bs[mask_270] = rotations_3[mask_270]
    return bs

def merge(bs):
    shape = boards.shape
    bs_flat = bs.view(-1, shape[-1])
    mask = (bs_flat != 0)
    _, sorted_indices = torch.sort(mask, dim=1, descending=True)
    bs_flat = torch.gather(bs_flat, 1, sorted_indices)
    for i in range(3):
        is_same = torch.logical_and(bs_flat[:,i] == bs_flat[:,i+1], bs_flat[:,i] != 0).int()
        bs_flat[:,i] += is_same
        bs_flat[:,i+1] *= (1 - is_same)
    mask = (bs_flat != 0)
    _, sorted_indices = torch.sort(mask, dim=1, descending=True)
    bs_flat = torch.gather(bs_flat, 1, sorted_indices)
    bs = bs_flat.view(shape)
    return bs

def push_moves(bs, moves):
    bs = rotate_in_place(bs, moves) # TODO: inplace
    bs = merge(bs)
    bs = rotate_in_place(bs, (4-moves) % 4)
    return bs

In [73]:
moves = torch.randint(0, 4, (num_boards,))

In [80]:
moves[0] = 0
boards[0][0] = torch.tensor([[1,2,3,4], [2, 3, 4, 1], [3, 4, 1, 2], [4, 1, 2, 3]])
boards[0]

tensor([[[1, 2, 3, 4],
         [2, 3, 4, 1],
         [3, 4, 1, 2],
         [4, 1, 2, 3]]], dtype=torch.int16)

In [105]:
spawn_tile(boards).shape

torch.Size([100000, 4, 4])