In [2]:
import torch
from random import choice

In [152]:
def new_board():
	"""Return a fresh 4x4 integer tensor with every cell set to 0 (empty)."""
	return torch.zeros(4, 4, dtype=torch.long)

# Direction codes; `push` passes the code straight to `Tensor.rot90` as the
# number of quarter-turns, so LEFT (0) means "no rotation".
LEFT, UP, RIGHT, DOWN = range(4)

def push(board, towards):
	"""Slide the board in direction `towards` and drop one new tile.

	The board is rotated by `towards` quarter-turns so that the move becomes a
	push to the left, each row is collapsed with `push_row_left`, and a new
	tile (value 1, 2 or 3) is placed in the last column of a randomly chosen
	row whose last cell is empty after the push.

	Returns the resulting board rotated back to its original orientation, or
	None when no row has an empty last cell.
	"""
	# Rotating first means every direction reduces to "push left".
	rotated = board.rot90(towards)
	for idx in range(len(rotated)):
		rotated[idx] = push_row_left(rotated[idx])
	# Rows whose last cell is free after the push can receive the new tile.
	open_rows = [idx for idx in range(len(rotated)) if rotated[idx, 3] == 0]
	if not open_rows:
		return None
	# The tile value is drawn before the row is picked.
	new_tile = choice([1, 2, 3])
	rotated[choice(open_rows), 3] = new_tile
	return rotated.rot90(-towards)

def push_row_left(row):
	"""Collapse one row towards index 0 and return it as a new integer tensor.

	Zeros are treated as empty cells and skipped. Two equal neighbouring tiles
	merge into a single tile whose value is incremented by one; a tile that was
	just produced by a merge does not merge again within the same push. The
	result is padded with zeros to the length of the input row.

	Args:
		row: 1-D sequence of tile values (a tensor row or a plain list).

	Returns:
		A new int64 tensor with the same length as `row`.
	"""
	# Pad to the input width instead of a fixed 4 so the output length always
	# matches the row that was passed in.
	width = len(row)
	merged = []
	just_merged = False
	for tile in map(int, row):
		if tile == 0:
			continue
		if merged and merged[-1] == tile and not just_merged:
			merged[-1] += 1
			just_merged = True
		else:
			merged.append(tile)
			just_merged = False
	merged.extend([0] * (width - len(merged)))
	# Explicit dtype keeps the result integer even for an empty row.
	return torch.tensor(merged, dtype=torch.int64)

def display_board(board):
	"""Print the board, one text line per row.

	A cell holding n > 0 is shown as the number 2**(n - 1), right-aligned in
	five characters; any other cell is shown as a centred dot.
	"""
	def render(n):
		if n > 0:
			return '%5d' % 2**(n - 1)
		return '  ·  '

	for row in board:
		print('  '.join(map(render, row)))

In [153]:
class Player(torch.nn.Module):
	"""Residual MLP that turns a 16-cell board into 4 move probabilities.

	Every cell is embedded twice, once by its value and once by its position;
	the two embeddings are added, flattened and projected to `hidden_dim`.
	`depth` residual layers follow, each applying ReLU and then rescaling the
	activation vector to unit L2 norm. The final layer yields 4 values that
	are passed through a softmax.

	Args:
		emb_dim: size of each value / position embedding vector.
		hidden_dim: width of the hidden layers.
		depth: number of residual hidden layers.
	"""

	def __init__(self, emb_dim=16, hidden_dim=128, depth=8):
		super().__init__()
		# 16 rows: only cell values 0..15 can be looked up; anything larger
		# is out of range for the embedding table and raises an error.
		self.exp_embedding = torch.nn.Embedding(16, emb_dim)
		# One learned vector per board position (16 cells).
		self.pos_embedding = torch.nn.Embedding(16, emb_dim)
		self.input = torch.nn.Linear(16 * emb_dim, hidden_dim)
		self.hidden = torch.nn.ModuleList([
			torch.nn.Linear(hidden_dim, hidden_dim)
			for _ in range(depth)
		])
		self.out = torch.nn.Linear(hidden_dim, 4)

	def forward(self, x):
		"""Return a length-4 probability vector for a single board `x`.

		`x` must be an integer tensor with 16 elements (e.g. shape (4, 4)).
		"""
		# reshape, not view: view raises on non-contiguous input such as a
		# transposed board, reshape copies only when it has to.
		cells = x.reshape(-1)
		exp_emb = self.exp_embedding(cells)
		# Build the position indices on the same device as the input so the
		# module keeps working after being moved off the CPU.
		positions = torch.arange(16, device=cells.device)
		pos_emb = self.pos_embedding(positions)
		x = self.input((exp_emb + pos_emb).reshape(-1))
		for layer in self.hidden:
			x = (x + layer(x)).relu()
			# Rescale to unit L2 norm. The clamp avoids 0 / 0 = NaN in the
			# forward pass when ReLU has zeroed every activation.
			norm = (x**2).sum()**.5
			x = x / norm.clamp_min(1e-12)
		return self.out(x).softmax(0)

# Build the policy network with its default sizes (emb_dim=16, hidden_dim=128, depth=8).
model = Player()

In [177]:
def episode(model, epsilon=1e-5, display=False):
	"""Play one game with `model` choosing every move and return the turn count.

	Each turn the model's output for the current board is used as sampling
	weights (with `epsilon` added to every entry, so no direction has weight
	exactly zero) and one direction is drawn with `torch.multinomial`. The game
	stops as soon as `push` returns None; that final move is not counted.

	Args:
		model: callable mapping a board tensor to 4 non-negative weights.
		epsilon: constant added to each weight before sampling.
		display: when true, print the board after every move and the total.

	Returns:
		Number of moves for which `push` returned a board.
	"""
	move_names = ['left', 'up', 'right', 'down']
	board = new_board()
	if display:
		display_board(board)
	turns = 0
	while True:
		weights = model(board) + epsilon
		action = torch.multinomial(weights, 1).item()
		board = push(board, action)
		if board is None:
			break
		if display:
			print('--- moved', move_names[action])
			display_board(board)
		turns += 1
	if display:
		print('Total turns:', turns)
	return turns

# Play a single game with `model`, printing the board after every move; the
# call's return value (the number of turns played) is the cell's result.
episode(model, display=True)

  ·      ·      ·      ·  
  ·      ·      ·      ·  
  ·      ·      ·      ·  
  ·      ·      ·      ·  
--- moved left
  ·      ·      ·        4
  ·      ·      ·      ·  
  ·      ·      ·      ·  
  ·      ·      ·      ·  
--- moved up
  ·      ·      ·        4
  ·      ·      ·      ·  
  ·      ·      ·      ·  
  ·      ·      ·        1
--- moved up
  ·      ·      ·        4
  ·      ·      ·        1
  ·      ·      ·      ·  
  ·        2    ·      ·  
--- moved left
    4    ·      ·      ·  
    1    ·      ·      ·  
  ·      ·      ·      ·  
    2    ·      ·        4
--- moved down
  ·        2    ·      ·  
    4    ·      ·      ·  
    1    ·      ·      ·  
    2    ·      ·        4
--- moved right
  ·      ·      ·        2
    2    ·      ·        4
  ·      ·      ·        1
  ·      ·        2      4
--- moved right
  ·      ·      ·        2
  ·      ·        2      4
    2    ·      ·        1
  ·      ·        2      4
--- moved up
    2    ·        4 

36