# `018` Reinforcement Learning

Requirements: 016 Transformers

⚡⚡⚡⚡WIP⚡⚡⚡⚡

In [1]:
import torch
from random import choice

device = torch.device('cuda' if torch.backends.cudnn.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
print(f'Using device: {device}')

Using device: cuda


In [2]:
def new_board():
	return torch.tensor([[0] * 4 for _ in range(4)], device=device)

LEFT, UP, RIGHT, DOWN = 0, 1, 2, 3

def push(board, towards):
	res = board.rot90(towards)  # this way we always push towards the left
	zeroes = []
	for r, row in enumerate(res):
		res[r] = _push_row_left(row)
		zeroes.extend((r, c) for c, n in enumerate(res[r]) if n == 0)
	if zeroes:
		r, c = choice(zeroes)
		res[r, c] = 1 if torch.rand(1) < 0.9 else 2
	res = res.rot90(-towards)
	changed = res.ne(board).any().item()
	return res, changed

def _push_row_left(row):
	new_row = []
	last_compressed = False
	for tile in map(int, list(row)):
		if tile == 0:
			continue
		if len(new_row) and new_row[-1] == tile and not last_compressed:
			new_row[-1] += 1
			last_compressed = True
		else:
			new_row.append(tile)
			last_compressed = False
	while len(new_row) < 4: new_row.append(0)
	return torch.tensor(new_row)

def display_board(board):
	for row in board:
		print('  '.join('%5d' % 2**(n - 1) if n > 0 else '  ·  ' for n in row))

In [3]:
class BoardEmbeding(torch.nn.Module):
	def __init__(self, emb_dim=16):
		super().__init__()
		self.exp_embedding = torch.nn.Embedding(33, emb_dim)  # assuming max is 2**32
		self.pos_x_embedding = torch.nn.Embedding(4, emb_dim)
		self.register_buffer('pos_x', torch.arange(4, device=device).repeat(4))
		self.pos_y_embedding = torch.nn.Embedding(4, emb_dim)
		self.register_buffer('pos_y', torch.arange(4, device=device).repeat_interleave(4))
	
	def forward(self, x):
		exp_emb = self.exp_embedding(x.flatten(-2))
		x_emb = self.pos_x_embedding(self.pos_x)
		y_emb = self.pos_y_embedding(self.pos_y)
		return exp_emb + x_emb + y_emb

In [4]:
class TransformerBlock(torch.nn.Module):
	def __init__(self, channels, num_heads, dropout=.1):
		super().__init__()
		self.norm1 = torch.nn.LayerNorm(channels)
		self.attn = torch.nn.MultiheadAttention(channels, num_heads, dropout)
		self.ff = torch.nn.Sequential(
			torch.nn.Linear(channels, 4 * channels),
			torch.nn.GELU(),
			torch.nn.Linear(4 * channels, channels)
		)
		self.norm2 = torch.nn.LayerNorm(channels)
		self.dropout = torch.nn.Dropout(dropout)

	def forward(self, x):
		x = self.norm1(x)
		x = x + self.attn(x, x, x)[0]
		x = self.norm2(x)
		x = x + self.ff(x)
		x = self.dropout(x)
		return x

In [5]:
class BoardTransformer(torch.nn.Module):
	def __init__(self, outputs, emb_dim=128, hidden_dim=128, num_blocks=4, num_heads=4):
		super().__init__()
		self.embedding = BoardEmbeding(emb_dim)
		self.hidden = torch.nn.Sequential(*[TransformerBlock(hidden_dim, num_heads) for _ in range(num_blocks)])
		self.out = torch.nn.Linear(16 * hidden_dim, outputs)
	
	def forward(self, x):
		x = self.embedding(x)
		x = self.hidden(x)
		x = x.flatten(-2)
		x = self.out(x)
		return x

policy = BoardTransformer(4).to(device)
print(f'Number of parameters in the policy net: {sum(p.numel() for p in policy.parameters()):,}')
print(policy(new_board()))

value = BoardTransformer(1).to(device)
print(f'Number of parameters in the value net: {sum(p.numel() for p in value.parameters()):,}')
print(value(new_board()))

Number of parameters in the policy net: 806,532
tensor([-0.9889, -0.3664, -0.2587,  0.6007], device='cuda:0',
       grad_fn=<ViewBackward0>)
Number of parameters in the value net: 800,385
tensor([-0.9474], device='cuda:0', grad_fn=<ViewBackward0>)


In [6]:
def play_episode(policy, epsilon=1e-5, display=False):
	res = {'states': [], 'logits': [], 'actions': [], 'rewards': []}
	board = new_board()
	if display: display_board(board)
	turns = 0
	while True:
		logits = policy(board)
		probs = logits.softmax(-1)
		towards = torch.multinomial(probs + epsilon, 1)
		next_board, changed = push(board, towards.item())
		if changed:
			res['states'].append(board)
			res['logits'].append(logits)
			res['actions'].append(towards)
			res['rewards'].append(1)
			board = next_board
			if display:
				print('--- moved', ['left', 'up', 'right', 'down'][towards])
				display_board(board)
			turns += 1
			# if no moves are possible and no zero is there, the game is over
			if not (board == 0).any() and not (board[:, :-1] == board[:, 1:]).any() and not (board[:-1] == board[1:]).any(): break
	if display: print('Total turns:', turns)
	return res

history = play_episode(policy)
display_board(history['states'][-1])

    1      2      4      2
    8     32     64      4
    4     16     32      2
    1      2      2      2


In [7]:
class ReplayBuffer:
	def __init__(self, maxlen):
		self.states = torch.empty((maxlen, 16, 4, 4), device=device)
		self.actions = torch.empty((maxlen,), dtype=torch.long, device=device)
		self.rewards = torch.empty((maxlen,), device=device)
		self.logits = torch.empty((maxlen, 4), device=device)
		self.append_idx = 0
		self.cycled = False
	
	def append(self, state, action, reward, logits):
		self.states[self.append_idx] = state
		self.actions[self.append_idx] = action
		self.rewards[self.append_idx] = reward
		self.logits[self.append_idx] = logits
		self.append_idx += 1
		if self.append_idx == len(self.states):
			self.append_idx = 0
			self.cycled = True
	
	def sample(self, batch_size):
		ix = torch.randint(0, len(self.states) if self.cycled else self.append_idx, (batch_size,))
		return self.states[ix], self.actions[ix], self.rewards[ix], self.logits[ix]

buffer = ReplayBuffer(100_000)

In [8]:
def compute_advantages(states, rewards, value_net, gamma=0.99):
	values = value_net(states)
	return rewards + gamma * values[1:] - values[:-1]

def update_policy(policy_net, value_net, buffer, optimizer_policy, optimizer_value, batch_size=64, epochs=4, gamma=0.99, clip=0.2):
	for _ in range(epochs):
		states, actions, rewards, logits = buffer.sample(batch_size)
		advantages = compute_advantages(states, rewards, value_net, gamma)
		log_probs_old = logits.gather(1, actions.unsqueeze(-1)).squeeze(-1).log()
		for state, action, reward, log_prob_old in zip(states, actions, rewards, log_probs_old):
			optimizer_policy.zero_grad()
			optimizer_value.zero_grad()
			logits_new = policy_net(state.unsqueeze(0)).squeeze(0)
			log_probs_new = logits_new.gather(0, action).log()
			ratio = (log_probs_new - log_prob_old).exp()
			surr1 = ratio * reward
			surr2 = torch.clamp(ratio, 1.0 - clip, 1.0 + clip) * reward
			policy_loss = -torch.min(surr1, surr2).mean()
			policy_loss.backward()
			optimizer_policy.step()
			value_loss = ((value_net(state.unsqueeze(0)) - reward) ** 2).mean()
			value_loss.backward()
			optimizer_value.step()

def train_policy(policy_net, value_net, buffer, num_episodes, lr=1e-4):
	optimizer_policy = torch.optim.Adam(policy.parameters(), lr=lr)
	optimizer_value = torch.optim.Adam(value.parameters(), lr=lr)
	for _ in range(num_episodes):
		episode_data = play_episode(policy_net)
		for state, action, reward, logits in zip(episode_data['states'], episode_data['actions'], episode_data['rewards'], episode_data['logits']):
			buffer.append(state, action, reward, logits)
		update_policy(policy_net, value_net, buffer, optimizer_policy, optimizer_value)

train_policy(policy, value, buffer, 100)

RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.cuda.FloatTensor instead (while checking arguments for embedding)