In [1]:
import torch
from random import choice

In [2]:
def new_board():
	"""Create an empty game board.

	Returns:
		A 4x4 int64 tensor filled with zeros; 0 is the value the rest of the
		notebook treats as an empty cell.
	"""
	return torch.zeros(4, 4, dtype=torch.int64)

# Direction codes for push(): each value is the number of 90-degree rot90
# turns that brings the chosen direction to the left edge of the board.
LEFT, UP, RIGHT, DOWN = range(4)

def push(board, towards):
	"""Slide and merge every tile of ``board`` in one direction.

	The board is rotated so that the requested direction becomes "left", each
	row is compacted with ``_push_row_left`` and the result is rotated back.
	A new tile (exponent 1 with probability 0.9, otherwise 2) is placed on a
	random empty cell only when the move was legal, i.e. at least one tile
	moved or merged. A completely empty board is treated as the start of a
	game: any push on it places the opening tile.

	Args:
		board: 4x4 integer tensor of tile exponents (0 = empty). Not modified.
		towards: direction code ``LEFT``/``UP``/``RIGHT``/``DOWN`` (0-3), used
			as the number of ``rot90`` turns.

	Returns:
		``(new_board, changed)``. ``changed`` is a Python bool; when it is
		``False`` the push was a no-op, ``new_board`` equals ``board`` and no
		tile has been added.
	"""
	rotated = board.rot90(towards)  # this way we always push towards the left
	res = rotated.clone()
	for r, row in enumerate(rotated):
		res[r] = _push_row_left(row)

	# Decide whether the move did anything *before* spawning: a comparison made
	# after the spawn reports every move as a change because of the new tile.
	moved = res.ne(rotated).any().item()
	opening = not board.ne(0).any().item()
	changed = moved or opening

	if changed:
		zeroes = res.eq(0).nonzero().tolist()  # [row, col] pairs, row-major
		if zeroes:
			r, c = choice(zeroes)
			res[r, c] = 1 if torch.rand(1) < 0.9 else 2

	# contiguous() so callers can flatten the returned board with view().
	return res.rot90(-towards).contiguous(), changed

def _push_row_left(row):
	new_row = []
	last_compressed = False
	for tile in map(int, list(row)):
		if tile == 0:
			continue
		if len(new_row) and new_row[-1] == tile and not last_compressed:
			new_row[-1] += 1
			last_compressed = True
		else:
			new_row.append(tile)
			last_compressed = False
	while len(new_row) < 4: new_row.append(0)
	return torch.tensor(new_row)

def display_board(board):
	"""Print the board, one text line per row.

	A positive cell value ``n`` is shown as ``2**(n - 1)`` right-aligned in
	five characters; any other cell is shown as a centred dot. Cells are
	separated by two spaces.
	"""
	for row in board:
		cells = []
		for n in row:
			if n > 0:
				cells.append('%5d' % 2**(n - 1))
			else:
				cells.append('  ·  ')
		print('  '.join(cells))

In [37]:
class BoardTransformer(torch.nn.Module):
	"""Transformer over the board cells that produces four outputs per board.

	Pipeline: ``BoardEmbeding`` -> linear projection to ``hidden_dim`` ->
	``num_hidden`` stacked ``TransformerBlock`` layers -> linear head of width
	4 -> mean over the cell dimension (dim -2).

	Args:
		emb_dim: width of the cell embeddings.
		hidden_dim: channel width inside the transformer blocks.
		num_hidden: number of transformer blocks.
		num_heads: attention heads per block.
	"""
	def __init__(self, emb_dim=16, hidden_dim=128, num_hidden=4, num_heads=4):
		super().__init__()
		nn = torch.nn
		self.embedding = BoardEmbeding(emb_dim)
		self.input = nn.Linear(emb_dim, hidden_dim)
		blocks = []
		for _ in range(num_hidden):
			blocks.append(TransformerBlock(hidden_dim, num_heads))
		self.hidden = nn.Sequential(*blocks)
		self.out = nn.Linear(hidden_dim, 4)

	def forward(self, x):
		"""Map a board tensor to 4 values by averaging the per-cell head outputs."""
		tokens = self.input(self.embedding(x))
		per_cell = self.out(self.hidden(tokens))
		# Collapse the cell dimension so the result has one value per output unit.
		return per_cell.mean(dim=-2)

class BoardEmbeding(torch.nn.Module):
	"""Embed the 16 cells of a 4x4 board.

	Each cell embedding is the sum of three learned vectors: one for the cell
	value (tile exponent), one for its column and one for its row.

	Args:
		emb_dim: width of every embedding vector.
	"""
	def __init__(self, emb_dim=16):
		super().__init__()
		# 16 entries: index 0 is the empty cell and 1..15 are tile exponents;
		# larger cell values are out of range for this table.
		self.exp_embedding = torch.nn.Embedding(16, emb_dim)
		self.pos_x_embedding = torch.nn.Embedding(4, emb_dim)
		self.pos_y_embedding = torch.nn.Embedding(4, emb_dim)

	def forward(self, x):
		"""Return a ``(16, emb_dim)`` tensor for a board of 16 integer cells.

		Cells are flattened in row-major order, so flat index ``i`` belongs to
		row ``i // 4`` and column ``i % 4``.
		"""
		# reshape (not view) so non-contiguous boards, e.g. transposed or
		# rotated views, are accepted instead of raising a RuntimeError.
		exp_emb = self.exp_embedding(x.reshape(-1))
		x_emb = self.pos_x_embedding(torch.arange(4, device=x.device).repeat(4))  # column of each cell
		y_emb = self.pos_y_embedding(torch.arange(4, device=x.device).repeat_interleave(4))  # row of each cell
		return exp_emb + x_emb + y_emb

class TransformerBlock(torch.nn.Module):
	"""Self-attention followed by a GELU feed-forward layer.

	Each sub-layer first applies LayerNorm and then adds its output to the
	normalised activations; dropout is applied to the block output.

	Input and output have shape ``(seq, channels)`` or
	``(batch, seq, channels)``.

	Args:
		channels: feature width of the input and output.
		num_heads: number of attention heads.
		dropout: dropout probability used inside the attention and on the
			block output.
	"""
	def __init__(self, channels, num_heads, dropout=.1):
		super().__init__()
		self.norm1 = torch.nn.LayerNorm(channels)
		# batch_first=True: the surrounding model treats dim -2 as the sequence
		# of cells, so a batched input must be read as (batch, seq, channels).
		# With the default layout the attention would mix different boards.
		# Unbatched (seq, channels) input behaves the same either way.
		self.attn = torch.nn.MultiheadAttention(channels, num_heads, dropout=dropout, batch_first=True)
		self.ff = torch.nn.Sequential(
			torch.nn.Linear(channels, 4 * channels),
			torch.nn.GELU(),
			torch.nn.Linear(4 * channels, channels)
		)
		self.norm2 = torch.nn.LayerNorm(channels)
		self.dropout = torch.nn.Dropout(dropout)

	def forward(self, x):
		"""Apply attention and feed-forward sub-layers to ``x``."""
		x = self.norm1(x)
		x = x + self.attn(x, x, x)[0]  # [0] is the attention output, [1] the weights
		x = self.norm2(x)
		x = x + self.ff(x)
		x = self.dropout(x)
		return x

policy = BoardTransformer()
# Count the parameters of `policy` itself so the cell also runs on a fresh kernel.
print(f'Number of parameters: {sum(p.numel() for p in policy.parameters()):,}')

Number of parameters: 796,164


In [34]:
def episode(policy_net, epsilon=1e-5, display=False):
	"""Play one game from ``new_board()``, sampling every move from ``policy_net``.

	On each iteration the network's output for the current board is turned into
	probabilities with softmax, ``epsilon`` is added to each of them, and one
	direction is sampled. Moves that ``push`` reports as unchanged are not
	recorded and a new move is sampled. The game stops as soon as the board has
	no empty cell after a recorded move.

	Args:
		policy_net: callable mapping a board tensor to 4 logits.
		epsilon: constant added to every probability before sampling.
		display: when true, print the board after every recorded move and the
			total number of turns at the end.

	Returns:
		A dict of equally long lists: ``'states'`` (board before each recorded
		move), ``'logits'`` (network output for that board), ``'actions'`` (the
		sampled direction tensor) and ``'rewards'`` (always 1).
	"""
	direction_names = ['left', 'up', 'right', 'down']
	res = {'states': [], 'logits': [], 'actions': [], 'rewards': []}
	board = new_board()
	if display:
		display_board(board)
	turns = 0
	finished = False
	while not finished:
		logits = policy_net(board)
		towards = torch.multinomial(logits.softmax(-1) + epsilon, 1)
		next_board, changed = push(board, towards.item())
		if not changed:
			continue  # nothing recorded: sample again from the same board
		for key, item in (('states', board), ('logits', logits), ('actions', towards), ('rewards', 1)):
			res[key].append(item)
		board = next_board
		turns += 1
		if display:
			print('--- moved', direction_names[towards])
			display_board(board)
		# Stop once no cell of the board is empty.
		finished = not board.eq(0).any().item()
	if display:
		print('Total turns:', turns)
	return res

# Roll out one game with the current policy, printing every board on the way.
history = episode(policy_net=policy, display=True)

  ·      ·      ·      ·  
  ·      ·      ·      ·  
  ·      ·      ·      ·  
  ·      ·      ·      ·  
--- moved up
  ·      ·      ·      ·  
  ·      ·      ·      ·  
  ·      ·      ·      ·  
  ·        1    ·      ·  
--- moved left
  ·      ·      ·      ·  
  ·      ·      ·      ·  
    1    ·      ·      ·  
    1    ·      ·      ·  
--- moved down
  ·      ·      ·      ·  
  ·        1    ·      ·  
  ·      ·      ·      ·  
    2    ·      ·      ·  
--- moved right
  ·      ·      ·      ·  
  ·        1    ·        1
  ·      ·      ·      ·  
  ·      ·      ·        2
--- moved left
  ·      ·      ·      ·  
    2    ·      ·      ·  
  ·      ·      ·        1
    2    ·      ·      ·  
--- moved down
  ·      ·      ·      ·  
  ·      ·        1    ·  
  ·      ·      ·      ·  
    4    ·      ·        1
--- moved up
    4    ·        1      1
  ·      ·      ·      ·  
  ·        1    ·      ·  
  ·      ·      ·      ·  
--- moved up
    4      1      1  

In [None]:
# Two independently initialised networks with the default architecture.
# NOTE(review): `value` uses the same 4-output head as `policy` — confirm that
# is what the value estimate is meant to look like.
policy, value = BoardTransformer(), BoardTransformer()