# `018` Reinforcement Learning

Requirements: 016 Transformers

Neural networks are universal function approximators: with enough capacity, they can approximate a very wide class of functions. Because of this, they are typically applied to classification and regression problems, in which each input is mapped to a single output.

However, the real world doesn't always work like this. On many occasions we face sequential problems, in which an agent has to interact with an environment many times before reaching the desired goal. For instance, to train a neural network to play tic-tac-toe in the usual supervised way, we would need a loss function that gets lower the better the network plays. But there is no labeled "correct move" for each board: whether a move was good often only becomes clear when the game ends.

Reinforcement learning (RL) was born as a solution to this problem. In RL, we define a neural network, called the agent, which is trained by interacting with an environment. The agent receives observations from the environment (in tic-tac-toe, the pieces on the board) and outputs the logits of the actions it can take. An action is sampled from those logits, the environment executes it, and it returns the new observation (the new board state) and a reward (a number that tells the agent how well it's doing).

Note that this example describes RL on a discrete problem (there is a finite number of actions) with stochastic environment transitions (the same action might lead to different outcomes depending on how the opponent reacts).
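To make the "sample an action from the logits" step concrete, here is a minimal plain-Python sketch; the helper names `masked_softmax` and `sample_action` are hypothetical and not part of any library. It turns logits into probabilities with a softmax, gives zero probability to illegal moves, and draws one move at random, which is what the notebook later does with `softmax`, `masked_fill` and `torch.multinomial`.

```python
import math
import random

def masked_softmax(logits, forbidden):
    """Softmax over the logits, giving zero probability to forbidden actions."""
    m = max(x for x, f in zip(logits, forbidden) if not f)  # subtract the largest legal logit for stability
    exps = [0.0 if f else math.exp(x - m) for x, f in zip(logits, forbidden)]
    total = sum(exps)
    return [e / total for e in exps]

def sample_action(logits, forbidden, rng=random):
    """Draw one action index with the probabilities given by masked_softmax."""
    probs = masked_softmax(logits, forbidden)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]

# cell 4 (the centre) is already occupied, so it must never be sampled
logits = [0.1, -0.3, 0.8, 0.0, 2.5, -1.2, 0.4, 0.2, -0.5]
forbidden = [i == 4 for i in range(9)]
print(masked_softmax(logits, forbidden)[4])  # 0.0
print(sample_action(logits, forbidden))
```

Even though cell 4 has the largest logit, masking removes it before sampling, so the most likely legal move becomes cell 2.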

In [15]:
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
print(f'Using device: {device}')

Using device: mps


Let's define an environment for the game. Any structure that can initialize the board, execute an action and report whether the game is over will do. We will use a simple class-based format that closely resembles the standard (and rather verbose) `reset`/`step` interface popularized by the OpenAI Gym library, which most RL setups use. The main difference is that our `step` stores the reward and the end-of-game flag as attributes instead of returning them.

In [16]:
class TicTacToeEnv:

	WIN_MASKS = (0, 1, 2), (3, 4, 5), (6, 7, 8), (0, 3, 6), (1, 4, 7), (2, 5, 8), (0, 4, 8), (2, 4, 6)

	def reset(self):
		self.state = torch.zeros(9, device=device, dtype=torch.long)  # 0 = empty, 1 = X, 2 = O; first row is (0, 1, 2), first column (0, 3, 6)
		self.forbidden = torch.full((9,), False, device=device)  # cells that are already occupied
		self.player = 1  # X always moves first
		self.reward, self.done = 0, False

	def step(self, action):
		assert not self.forbidden[action], f'Invalid action: {action} with state {self.state}'
		self.state[action] = self.player
		self.forbidden[action] = True
		for a, b, c in TicTacToeEnv.WIN_MASKS:
			if self.state[a] == self.state[b] == self.state[c] == self.player:
				self.reward = 1  # the player who just moved completed a line
				self.done = 1
				break
		else:
			self.reward = 0
			self.done = self.forbidden.all().item()  # a full board without a winner is a draw
		self.player = 1 if self.player == 2 else 2  # hand the turn to the other player

	def render(self):
		for row in range(3):
			print(''.join('·XO'[self.state[row * 3 + col]] for col in range(3)))
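The eight `WIN_MASKS` follow directly from the row-major layout noted in `reset`: cell `(row, col)` is stored at index `row * 3 + col`. As a quick plain-Python sanity check, the hypothetical helpers below rebuild the masks and test a board for a win; the function names are illustrative only.

```python
def win_lines(n=3):
    """Rebuild the winning lines of an n x n board stored row-major (cell (r, c) at index r * n + c)."""
    def idx(r, c):
        return r * n + c
    rows = [tuple(idx(r, c) for c in range(n)) for r in range(n)]
    cols = [tuple(idx(r, c) for r in range(n)) for c in range(n)]
    diagonals = [tuple(idx(i, i) for i in range(n)), tuple(idx(i, n - 1 - i) for i in range(n))]
    return tuple(rows + cols + diagonals)

def is_win(board, player, lines=win_lines()):
    """True if `player` occupies every cell of at least one line (0 = empty, 1 = X, 2 = O)."""
    return any(all(board[i] == player for i in line) for line in lines)

print(win_lines())  # same order as TicTacToeEnv.WIN_MASKS
# X holds the first column (0, 3, 6)
print(is_win([1, 0, 0, 1, 2, 0, 1, 0, 2], 1))  # True
```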

Let's now test the environment with two random agents. The random agent simply returns random logits for the nine cells. Then, we will define a function that plays a full episode while keeping track of the states (observations), actions and rewards. This function takes a parameter `epsilon` that smooths the action probabilities towards a uniform distribution, allowing the agent to explore the environment.

In [17]:
class RandomAgent:
	def __call__(self, _):
		return torch.randn(9, device=device)

def play_episode(env, players, epsilon=0, render=False):
	observations, rewards, actions = [], [], []
	env.reset()
	if render: env.render()
	turn = 0
	while True:
		player = turn % len(players)  # players take turns: index 0 plays X, index 1 plays O
		if render: print(f'\nTurn {turn}, player {player}:')
		observations.append(env.state.clone())
		with torch.no_grad():  # choosing a move does not need gradients
			logits = players[player](env.state)
		probs = logits.softmax(-1) * (1 - epsilon) + epsilon / logits.numel()  # mix with a uniform distribution; if epsilon=1, probs are uniform
		probs = torch.masked_fill(probs, env.forbidden, 0)  # don't sample forbidden actions
		action = torch.multinomial(probs, 1).item()  # multinomial renormalizes the remaining weights
		actions.append(action)
		env.step(action)
		rewards.append(env.reward)
		if render:
			env.render()
			print(f'Reward: {env.reward}, Done: {env.done}')
		if env.done: break
		turn += 1
	return observations, rewards, actions

play_episode(TicTacToeEnv(), (RandomAgent(), RandomAgent()), render=True);

···
···
···

Turn 0, player 0:
X··
···
···
Reward: 0, Done: False

Turn 1, player 1:
X··
···
··O
Reward: 0, Done: False

Turn 2, player 0:
X··
···
X·O
Reward: 0, Done: False

Turn 3, player 1:
X··
·O·
X·O
Reward: 0, Done: False

Turn 4, player 0:
X··
XO·
X·O
Reward: 1, Done: 1
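The smoothing controlled by `epsilon` mixes the policy with a uniform distribution. A common formulation, sketched below in plain Python with a hypothetical `smooth` helper, is `(1 - epsilon) * p + epsilon / n` for `n` actions: it remains a valid distribution for any `epsilon` in `[0, 1]`, leaves the policy unchanged at `epsilon = 0` and becomes uniform at `epsilon = 1`.

```python
def smooth(probs, epsilon):
    """Mix a distribution with the uniform one: (1 - epsilon) * p + epsilon / n."""
    n = len(probs)
    return [(1 - epsilon) * p + epsilon / n for p in probs]

greedy = [1.0, 0.0, 0.0, 0.0]
print(smooth(greedy, 0.0))  # unchanged
print(smooth(greedy, 0.2))  # every action keeps at least 0.2 / 4 = 0.05
print(smooth(greedy, 1.0))  # uniform
```

In `play_episode` the occupied cells are zeroed after this step, and `torch.multinomial` renormalizes whatever weight remains.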


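We now have a way to collect experiences to train on. Let's define our player network as a very simple transformer with a hidden size of 128 and 4 blocks of 4 attention heads each. The 9 cells of the board are treated as a sequence of 9 tokens, and the network outputs one logit per cell.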

In [27]:
class Player(torch.nn.Module):
	def __init__(self, hidden_size=128, num_blocks=4, num_heads=4):
		super().__init__()
		self.embedding = BoardEmbeding(hidden_size)
		self.hidden = torch.nn.Sequential(*[
			TransformerBlock(hidden_size, num_heads)
			for _ in range(num_blocks)
		]).to(device)
		self.out = torch.nn.Linear(hidden_size, 1)
	
	def forward(self, x):
		x = self.embedding(x)
		x = self.hidden(x)
		x = self.out(x)
		return x.squeeze(-1)

class BoardEmbeding(torch.nn.Module):
	def __init__(self, size):
		super().__init__()
		self.piece_embedding = torch.nn.Parameter(torch.randn(3, size, device=device))
		self.pos_embedding = torch.nn.Parameter(torch.randn(9, size, device=device))
		self.register_buffer('pos', torch.arange(9, device=device))
	
	def forward(self, x):
		x = self.piece_embedding[x] + self.pos_embedding[self.pos]
		return x

class TransformerBlock(torch.nn.Module):
	def __init__(self, hidden_size, num_heads, dropout=.1):
		super().__init__()
		self.norm1 = torch.nn.LayerNorm(hidden_size)
		# batch_first=True so that batched boards of shape (batch, 9, hidden_size) attend across cells, not across boards
		self.attn = torch.nn.MultiheadAttention(hidden_size, num_heads, dropout, batch_first=True)
		self.ff = torch.nn.Sequential(
			torch.nn.Linear(hidden_size, 4 * hidden_size),
			torch.nn.GELU(),
			torch.nn.Linear(4 * hidden_size, hidden_size)
		)
		self.norm2 = torch.nn.LayerNorm(hidden_size)
		self.dropout = torch.nn.Dropout(dropout)

	def forward(self, x):
		# pre-norm residual block: normalize the input of each sublayer, then add its output to the unnormalized stream
		h = self.norm1(x)
		x = x + self.dropout(self.attn(h, h, h, need_weights=False)[0])
		x = x + self.dropout(self.ff(self.norm2(x)))
		return x

player = Player().to(device)
print(f'Actor network has {sum(p.numel() for p in player.parameters()):,} parameters.')

Actor network has 794,753 parameters.
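The printed count can be reproduced by hand, which is a useful check that the modules are wired as intended. The arithmetic below assumes PyTorch's default layouts: `LayerNorm` has one weight and one bias per channel, and `MultiheadAttention` stores a packed `3h x h` input projection plus an `h x h` output projection, each with biases. The number of heads does not appear, because the heads split those same projections.

```python
# Parameter count of Player(hidden_size=128, num_blocks=4, num_heads=4), derived by hand
h, num_blocks, num_cells, num_pieces = 128, 4, 9, 3

embedding = num_pieces * h + num_cells * h            # piece_embedding (3 x h) + pos_embedding (9 x h)
layer_norm = 2 * h                                    # one weight and one bias per channel
attention = (3 * h * h + 3 * h) + (h * h + h)         # packed q/k/v projection + output projection
feed_forward = (h * 4 * h + 4 * h) + (4 * h * h + h)  # Linear(h, 4h) + Linear(4h, h)
block = 2 * layer_norm + attention + feed_forward
head = h + 1                                          # Linear(h, 1)

total = embedding + num_blocks * block + head
print(f'{total:,}')  # 794,753
```

The critic defined next only swaps the head for `Linear(9 * 128, 1)`, i.e. 1,153 parameters instead of 129, which accounts for its extra 1,024 parameters.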


Let's now test it with our `play_episode` function. Since the network is untrained, it won't play any better than the random agent.

In [28]:
observations, rewards, actions = play_episode(TicTacToeEnv(), (player, player))
print('actions', actions)
print('rewards', rewards)

actions [2, 4, 7, 0, 3, 6, 1, 8]
rewards [0, 0, 0, 0, 0, 0, 0, 1]


Now, let's define the setup we will use to train the network. Many of the most successful modern reinforcement learning algorithms follow an actor-critic setup, in which two networks are trained simultaneously: the actor (our agent), which tries to play the environment as well as possible, and the critic, which tries to estimate the value of the actor's current policy. In this sense:

* The actor receives the board state and returns the logits of all possible actions.
* The critic receives the board state and returns a single number: the expected return when starting from that state and then acting according to the current policy.
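The "expected return" in the critic's definition is usually the discounted sum of future rewards, `G_t = r_t + gamma * G_{t+1}`, where the discount factor `gamma` weights near rewards more than distant ones. A minimal plain-Python sketch that computes it backwards over one episode (the function name and the `gamma` values are illustrative assumptions):

```python
def discounted_returns(rewards, gamma=0.99):
    """Return G_t = r_t + gamma * G_{t+1} for every step t of one episode."""
    returns = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):  # walk backwards so G_{t+1} is already known
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns

# rewards of the five-move random game rendered earlier: only the winning move is rewarded
print(discounted_returns([0, 0, 0, 0, 1], gamma=0.9))
```

In a two-player game the moves alternate between players, so rewards also have to be attributed to the right player before computing returns; in this environment only the winning move is rewarded.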

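We have already defined the actor; let's now define the critic. It shares the actor's architecture, except that its output layer reads all 9 cell embeddings at once and produces a single value.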

In [25]:
class Critic(torch.nn.Module):
	def __init__(self, hidden_size=128, num_blocks=4, num_heads=4):
		super().__init__()
		self.embedding = BoardEmbeding(hidden_size)
		self.hidden = torch.nn.Sequential(*[
			TransformerBlock(hidden_size, num_heads)
			for _ in range(num_blocks)
		]).to(device)
		self.out = torch.nn.Linear(9 * hidden_size, 1)
	
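	def forward(self, x):
		x = self.embedding(x)
		x = self.hidden(x)
		x = self.out(x.flatten(-2))  # concatenate the 9 cell embeddings into one vector per board
		return x.squeeze(-1)  # one value per board

critic = Critic().to(device)  # the output layer is created on the CPU, so move the whole module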
print(f'Critic network has {sum(p.numel() for p in critic.parameters()):,} parameters.')

Critic network has 795,777 parameters.


Now we are in a good position to implement a reinforcement learning algorithm. We will implement [Proximal Policy Optimization](https://arxiv.org/pdf/1707.06347) in its clipped variant (PPO-Clip), since it has empirically proven to work well across many different contexts.
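For reference, the clipped surrogate objective from the linked paper is written below, with the probability ratio $r_t(\theta)$, an advantage estimate $\hat{A}_t$ and the clipping parameter $\epsilon$; $\pi_{\theta_{\mathrm{old}}}$ plays the role of `policy_k` in the code. The second line is an equivalent per-sample form that uses the helper `g` defined in the next cell. The objective is maximized, which is why training code minimizes its negative.

$$
L^{\mathrm{CLIP}}(\theta) = \hat{\mathbb{E}}_t\left[\min\left(r_t(\theta)\,\hat{A}_t,\ \operatorname{clip}\left(r_t(\theta),\,1-\epsilon,\,1+\epsilon\right)\hat{A}_t\right)\right],
\qquad r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\mathrm{old}}}(a_t \mid s_t)}
$$

$$
\min\left(r_t\hat{A}_t,\ \operatorname{clip}(r_t, 1-\epsilon, 1+\epsilon)\,\hat{A}_t\right) = \min\left(r_t\hat{A}_t,\ g(\hat{A}_t)\right),
\qquad g(\hat{A}) = \begin{cases} (1+\epsilon)\,\hat{A} & \hat{A} \ge 0 \\ (1-\epsilon)\,\hat{A} & \hat{A} < 0 \end{cases}
$$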

In [None]:
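def g(adv, epsilon=.2):
	# bound of the clipped objective: min(ratio * adv, clip(ratio, 1 - epsilon, 1 + epsilon) * adv) == min(ratio * adv, g(adv))
	if adv >= 0:
		return (1 + epsilon) * adv
	else:
		return (1 - epsilon) * adv

def L(s, a, adv, policy_k, policy, epsilon=.2):
	# policy_k: frozen copy of the policy that collected the data; policy: the network being optimized
	# adv: advantage estimate of action a in state s (e.g. the observed return minus the critic's value)
	prob = policy(s).softmax(-1)[a]  # the networks output logits, so turn them into probabilities first
	with torch.no_grad():
		prob_k = policy_k(s).softmax(-1)[a]
	return min(prob / prob_k * adv, g(adv, epsilon))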
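The equivalence used by `g` (clipping the ratio inside the `min` is the same as capping the second term at `g(adv)`) can be checked numerically. This plain-Python sketch compares the paper's form against the `g` form on random ratios and advantages; `g_bound` mirrors the notebook's `g`, and all names here are illustrative.

```python
import random

def clip(x, lo, hi):
    return max(lo, min(x, hi))

def g_bound(adv, epsilon=.2):
    # same rule as the notebook's g
    return (1 + epsilon) * adv if adv >= 0 else (1 - epsilon) * adv

def clipped_surrogate(ratio, adv, epsilon=.2):
    # per-sample PPO-Clip term as written in the paper
    return min(ratio * adv, clip(ratio, 1 - epsilon, 1 + epsilon) * adv)

rng = random.Random(0)
for _ in range(10_000):
    ratio, adv = rng.uniform(0, 3), rng.uniform(-2, 2)
    assert abs(clipped_surrogate(ratio, adv) - min(ratio * adv, g_bound(adv))) < 1e-12

print(clipped_surrogate(1.5, 1.0))   # positive advantage: capped at (1 + 0.2) * 1.0
print(clipped_surrogate(0.5, -1.0))  # negative advantage: capped at (1 - 0.2) * -1.0
```

The two spot checks show the pessimistic effect of clipping: with a positive advantage the objective stops growing once the ratio passes `1 + epsilon`, and with a negative advantage it stops improving once the ratio drops below `1 - epsilon`.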

In [None]:
policy = Player().to(device)
print(f'Number of parameters in the policy net: {sum(p.numel() for p in policy.parameters()):,}')
print(policy(torch.zeros(9, dtype=torch.long, device=device)))  # logits for the 9 cells of an empty board

value = Critic().to(device)
print(f'Number of parameters in the value net: {sum(p.numel() for p in value.parameters()):,}')
print(value(torch.zeros(9, dtype=torch.long, device=device)))  # estimated return of the empty board

In [None]:
env = TicTacToeEnv()
observations, rewards, actions = play_episode(env, (policy, policy))  # self-play: the same network plays both sides
env.render()  # final board


In [None]:
class ReplayBuffer:
	# One row per move: the board the mover saw, the move, the return from the mover's point of view,
	# and the log-probability the policy gave to that move when it was played.
	def __init__(self, maxlen):
		self.states = torch.empty((maxlen, 9), dtype=torch.long, device=device)  # long, because BoardEmbeding indexes its tables with it
		self.actions = torch.empty((maxlen,), dtype=torch.long, device=device)
		self.returns = torch.empty((maxlen,), device=device)
		self.log_probs = torch.empty((maxlen,), device=device)
		self.append_idx = 0
		self.cycled = False
	
	def append(self, state, action, ret, log_prob):
		self.states[self.append_idx] = state
		self.actions[self.append_idx] = action
		self.returns[self.append_idx] = ret
		self.log_probs[self.append_idx] = log_prob
		self.append_idx += 1
		if self.append_idx == len(self.states):  # once full, overwrite the oldest entries
			self.append_idx = 0
			self.cycled = True
	
	def sample(self, batch_size):
		ix = torch.randint(0, len(self.states) if self.cycled else self.append_idx, (batch_size,), device=device)
		return self.states[ix], self.actions[ix], self.returns[ix], self.log_probs[ix]

buffer = ReplayBuffer(100_000)  # note: PPO is normally trained on recent on-policy data; older entries come from earlier policies

In [None]:
def masked_log_probs(policy_net, states, actions):
	# log pi(a|s), excluding occupied cells exactly as play_episode does
	logits = policy_net(states).masked_fill(states != 0, float('-inf'))
	return logits.log_softmax(-1).gather(1, actions.unsqueeze(-1)).squeeze(-1)

def compute_advantages(states, returns, value_net):
	# A(s, a) ~ G - V(s): how much better the observed return was than the critic expected
	return returns - value_net(states).reshape(-1)

def update_policy(policy_net, value_net, buffer, optimizer_policy, optimizer_value, batch_size=64, epochs=4, clip=0.2):
	for _ in range(epochs):
		states, actions, returns, log_probs_old = buffer.sample(batch_size)
		with torch.no_grad():  # the advantages are targets for the actor, not a path to train the critic
			advantages = compute_advantages(states, returns, value_net)
		ratio = (masked_log_probs(policy_net, states, actions) - log_probs_old).exp()
		surr1 = ratio * advantages
		surr2 = torch.clamp(ratio, 1.0 - clip, 1.0 + clip) * advantages
		policy_loss = -torch.min(surr1, surr2).mean()  # PPO-Clip maximizes the surrogate, so minimize its negative
		optimizer_policy.zero_grad()
		policy_loss.backward()
		optimizer_policy.step()
		value_loss = (value_net(states).reshape(-1) - returns).pow(2).mean()  # regress the critic onto the returns
		optimizer_value.zero_grad()
		value_loss.backward()
		optimizer_value.step()

def train_policy(policy_net, value_net, buffer, num_episodes, lr=1e-4, gamma=0.99):
	optimizer_policy = torch.optim.Adam(policy_net.parameters(), lr=lr)
	optimizer_value = torch.optim.Adam(value_net.parameters(), lr=lr)
	# disable dropout so that the probabilities used to sample a move match the ones recomputed for the ratio
	policy_net.eval()
	value_net.eval()
	env = TicTacToeEnv()
	for _ in range(num_episodes):
		observations, rewards, actions = play_episode(env, (policy_net, policy_net))
		states = torch.stack(observations)
		actions = torch.tensor(actions, device=device)
		# Assumption (zero-sum self-play): only the last move can be rewarded (1 for a win, 0 for a draw), so the
		# last mover's moves get +reward and the opponent's get -reward, discounted by the number of moves left.
		steps_left = torch.arange(len(rewards) - 1, -1, -1, device=device)
		sign = 1 - 2 * (steps_left % 2)
		returns = rewards[-1] * sign * gamma ** steps_left
		with torch.no_grad():
			log_probs = masked_log_probs(policy_net, states, actions)
		for state, action, ret, log_prob in zip(states, actions, returns, log_probs):
			buffer.append(state, action, ret, log_prob)
		update_policy(policy_net, value_net, buffer, optimizer_policy, optimizer_value)

train_policy(policy, value, buffer, 100)
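`ReplayBuffer` implements a fixed-size circular buffer by hand: once it is full, new moves overwrite the oldest ones, and `sample` draws indices with replacement. For intuition, the same behaviour can be sketched in plain Python with `collections.deque(maxlen=...)`, which silently drops the oldest item when a new one is appended to a full deque; `TinyReplayBuffer` is a hypothetical name used only for illustration.

```python
import random
from collections import deque

class TinyReplayBuffer:
    """Keeps the most recent `maxlen` items and samples them with replacement."""
    def __init__(self, maxlen):
        self.data = deque(maxlen=maxlen)  # appending to a full deque discards the oldest item

    def append(self, item):
        self.data.append(item)

    def sample(self, batch_size, rng=random):
        return rng.choices(self.data, k=batch_size)  # with replacement, like the torch.randint indices

buf = TinyReplayBuffer(maxlen=3)
for move in range(5):
    buf.append(move)
print(list(buf.data))  # [2, 3, 4]
print(buf.sample(4, random.Random(0)))
```

Indexing into the middle of a deque is O(n), whereas the preallocated tensors in the notebook support fast random indexing and keep each sampled batch on `device`.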