# `018` Reinforcement Learning

Requirements: 016 Transformers

⚡⚡⚡⚡WIP⚡⚡⚡⚡

In [2]:
import torch
from random import choice

# Prefer a CUDA GPU, then Apple-silicon MPS, then the CPU.
# torch.cuda.is_available() checks for a usable CUDA device; torch.backends.cudnn.is_available()
# only reports cuDNN availability in the PyTorch build, which says nothing about an actual GPU.
if torch.cuda.is_available():
	device = torch.device('cuda')
elif torch.backends.mps.is_available():
	device = torch.device('mps')
else:
	device = torch.device('cpu')
print(f'Using device: {device}')

Using device: mps


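Let's define an environment for tic-tac-toe. There are many ways to do this; we'll use a class with `reset`, `step` and `render` methods, loosely modeled on the widely used [Gymnasium](https://gymnasium.farama.org) API. Although we'll implement our own RL setup, following this shape makes the environment easy to adapt to RL libraries that expect a Gymnasium-style interface. Unlike Gymnasium, `reset` and `step` don't return anything here. Instead, the board, the legal-move mask, the reward and the done flag are stored as the attributes `state`, `mask`, `reward` and `done`.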

In [3]:
class TicTacToeEnv:
	"""Two-player tic-tac-toe on a flat 9-cell board with a Gymnasium-like reset/step/render API.

	Cells hold 0 (empty), 1 (X, player 0) or 2 (O, player 1); cell ``i`` is row ``i // 3``,
	column ``i % 3``. After ``reset()`` and after every ``step()`` these attributes describe the game:

	- ``state``: integer tensor of shape (9,) with the cell codes.
	- ``mask``: uint8 tensor of shape (9,), 1 where a move is legal (the cell is empty).
	- ``player``: index (0 or 1) of the player to move next.
	- ``reward``: 1 if the move just played completed a line, else 0.
	- ``done``: True once the game is over (a win or a full board), else False.
	"""
	WIN_MASKS = (0, 1, 2), (3, 4, 5), (6, 7, 8), (0, 3, 6), (1, 4, 7), (2, 5, 8), (0, 4, 8), (2, 4, 6)

	def reset(self):
		"""Start a new game: empty board, every cell legal, player 0 (X) to move."""
		self.state = torch.zeros(9, device=device, dtype=int)
		self.mask = torch.ones(9, device=device, dtype=torch.uint8)
		self.player = 0
		# Initialise these so they are defined before the first step, with consistent types.
		self.reward = 0
		self.done = False

	def step(self, action):
		"""Place the current player's mark on cell ``action`` and update mask/reward/done/player.

		Raises:
			AssertionError: if cell ``action`` is already occupied.
		"""
		assert self.state[action] == 0, f'Invalid action: {action} in state: {self.state}'
		piece = self.player + 1
		self.state[action] = piece
		# Legal moves are exactly the empty cells; keep the mask in sync even when the game ends.
		self.mask = (self.state == 0).to(torch.uint8)
		# Gather all 8 lines at once -> (8, 3) cell codes; a single device sync instead of one per cell.
		lines = self.state[torch.tensor(self.WIN_MASKS, device=self.state.device)]
		won = bool((lines == piece).all(dim=1).any())
		self.reward = 1 if won else 0
		# A win ends the game; otherwise it ends when no empty cell remains (a draw). Always a bool.
		self.done = won or not self.mask.any()
		self.player = 1 - self.player

	def render(self):
		"""Print the board as three rows: '·' empty, 'X' player 0, 'O' player 1."""
		for row in range(3):
			print(''.join('·XO'[self.state[row * 3 + col]] for col in range(3)))

Let's now test the environment with two random agents. A random agent ignores the board and returns random logits (unnormalized scores) for the 9 actions; a softmax turns them into a random probability distribution over the moves. Then we define a function, `play_episode`, that runs a full game while recording the states (observations), actions and rewards. Occupied cells get zero probability. An `epsilon` parameter adds a constant to the probability of every legal move, which smooths the distribution so that the agents keep exploring the environment.

In [4]:
class RandomAgent:
	"""Baseline player that ignores the observation entirely.

	Every call returns fresh i.i.d. standard-normal logits, one per board cell, so the resulting
	softmax distribution over moves is random (not uniform).
	"""

	NUM_ACTIONS = 9  # one action per board cell

	def __call__(self, _):
		noise = torch.randn(self.NUM_ACTIONS, device=device)
		return noise

@torch.no_grad()
def play_episode(env, players, epsilon=0, render=False):
	"""Play one full game with ``players`` taking turns in order, recording the trajectory.

	Runs without autograd: the logits are only used for sampling and are not returned, so building
	graphs through the players would be wasted work.

	Args:
		env: environment with ``reset()``, ``step(action)`` and ``render()``, exposing ``state`` and
			``mask`` after ``reset()`` plus ``reward`` and ``done`` after each ``step()``.
		players: sequence of callables mapping the current state to one logit per action;
			player ``i`` moves on turns ``i, i + len(players), ...``.
		epsilon: constant added to the probability of every legal action before sampling, so that
			unlikely legal moves are still explored.
		render: if True, print the board after every move.

	Returns:
		``(observations, rewards, actions)``: one entry per turn — a copy of the state before the
		move, the reward reported by the environment after it, and the chosen action index.
	"""
	observations, rewards, actions = [], [], []
	env.reset()
	if render: env.render()
	turn = 0
	while True:
		player = turn % len(players)
		if render: print(f'\nTurn {turn}, player {player}:')
		observations.append(env.state.clone())
		logits = players[player](env.state)
		legal = env.mask.bool()
		# Mask illegal moves *before* the softmax. Normalising over all cells first lets a confident
		# preference for an occupied cell underflow every legal probability to 0, and
		# torch.multinomial then fails. With epsilon=0 the sampling distribution is otherwise unchanged.
		probs = (logits.masked_fill(~legal, float('-inf')).softmax(-1) + epsilon) * legal
		action = torch.multinomial(probs, 1).item()
		actions.append(action)
		env.step(action)
		rewards.append(env.reward)
		if render:
			env.render()
			print(f'Reward: {env.reward}, Done: {env.done}')
		if env.done: break
		turn += 1
	return observations, rewards, actions

# One rendered game between two random agents; the trailing ';' hides the returned trajectory.
play_episode(env=TicTacToeEnv(), players=(RandomAgent(), RandomAgent()), render=True);

···
···
···

Turn 0, player 0:
···
···
··X
Reward: 0, Done: False

Turn 1, player 1:
···
O··
··X
Reward: 0, Done: False

Turn 2, player 0:
···
O·X
··X
Reward: 0, Done: False

Turn 3, player 1:
···
O·X
O·X
Reward: 0, Done: False

Turn 4, player 0:
·X·
O·X
O·X
Reward: 0, Done: False

Turn 5, player 1:
·X·
O·X
OOX
Reward: 0, Done: False

Turn 6, player 0:
·X·
OXX
OOX
Reward: 0, Done: False

Turn 7, player 1:
·XO
OXX
OOX
Reward: 0, Done: False

Turn 8, player 0:
XXO
OXX
OOX
Reward: 1, Done: 1


In [5]:
class Player(torch.nn.Module):
	"""Transformer policy over the 9 board cells: maps a board to one logit per cell.

	Args:
		emb_dim: size of each cell's embedding produced by ``BoardEmbeding``.
		hidden_dim: width of the transformer blocks. When it differs from ``emb_dim`` a linear
			projection is inserted; with equal sizes (the default) no extra parameters are created.
		num_blocks: number of stacked ``TransformerBlock`` layers.
		num_heads: attention heads per block (``hidden_dim`` must be divisible by it).
	"""
	def __init__(self, emb_dim=128, hidden_dim=128, num_blocks=4, num_heads=4):
		super().__init__()
		self.embedding = BoardEmbeding(emb_dim)
		# Without this, emb_dim != hidden_dim would fail inside the first block. Identity keeps the
		# default model's parameters (and state_dict) exactly as before.
		self.proj = torch.nn.Identity() if emb_dim == hidden_dim else torch.nn.Linear(emb_dim, hidden_dim, device=device)
		self.hidden = torch.nn.Sequential(*[TransformerBlock(hidden_dim, num_heads) for _ in range(num_blocks)]).to(device)
		self.out = torch.nn.Linear(hidden_dim, 1, device=device)

	def forward(self, x):
		"""Return one logit per cell: a board of shape (9,) gives logits of shape (9,)."""
		x = self.embedding(x)
		x = self.proj(x)
		x = self.hidden(x)
		x = self.out(x)
		return x.squeeze(-1)

class BoardEmbeding(torch.nn.Module):
	"""Learned board embedding: a piece vector per cell code (0, 1, 2) plus a vector per cell position.

	Args:
		emb_dim: size of the vector produced for each cell.
	"""
	def __init__(self, emb_dim):
		super().__init__()
		self.piece_embedding = torch.nn.Parameter(torch.randn(3, emb_dim, device=device))  # indexed by cell code 0/1/2
		self.pos_embedding = torch.nn.Parameter(torch.randn(9, emb_dim, device=device))  # one row per board cell
		self.register_buffer('pos', torch.arange(9, device=device))

	def forward(self, x):
		"""Embed a board ``x`` of shape (..., 9) holding cell codes; returns shape (..., 9, emb_dim).

		``x`` is cast to long first because tensor indexing rejects floating-point indices, so
		boards stored as floats (e.g. in a float replay buffer) would otherwise fail here.
		"""
		return self.piece_embedding[x.long()] + self.pos_embedding[self.pos]

class TransformerBlock(torch.nn.Module):
	"""Pre-norm transformer encoder block: self-attention, then a 4x-wide GELU MLP, each as a residual branch.

	Accepts unbatched ``(seq, channels)`` or batched ``(batch, seq, channels)`` input and returns a
	tensor of the same shape.

	Args:
		channels: model width (must be divisible by ``num_heads``).
		num_heads: number of attention heads.
		dropout: dropout probability for the attention weights and for each residual branch output.
	"""
	def __init__(self, channels, num_heads, dropout=.1):
		super().__init__()
		self.norm1 = torch.nn.LayerNorm(channels)
		# batch_first=True: batched input is (batch, seq, channels); unbatched 2-D input is unaffected.
		self.attn = torch.nn.MultiheadAttention(channels, num_heads, dropout, batch_first=True)
		self.ff = torch.nn.Sequential(
			torch.nn.Linear(channels, 4 * channels),
			torch.nn.GELU(),
			torch.nn.Linear(4 * channels, channels)
		)
		self.norm2 = torch.nn.LayerNorm(channels)
		self.dropout = torch.nn.Dropout(dropout)

	def forward(self, x):
		# Standard pre-LN residual: normalise only the input of each branch and add the (dropped-out)
		# branch output back onto the un-normalised stream, so the skip path carries x unchanged.
		h = self.norm1(x)
		x = x + self.dropout(self.attn(h, h, h, need_weights=False)[0])
		x = x + self.dropout(self.ff(self.norm2(x)))
		return x

In [66]:
# Two untrained (randomly initialised) transformer players play one game against each other.
p1, p2 = Player(), Player()
observations, rewards, actions = play_episode(TicTacToeEnv(), players=(p1, p2))
observations

[tensor([0, 0, 0, 0, 0, 0, 0, 0, 0], device='mps:0'),
 tensor([0, 0, 0, 0, 0, 1, 0, 0, 0], device='mps:0'),
 tensor([0, 0, 2, 0, 0, 1, 0, 0, 0], device='mps:0'),
 tensor([1, 0, 2, 0, 0, 1, 0, 0, 0], device='mps:0'),
 tensor([1, 0, 2, 2, 0, 1, 0, 0, 0], device='mps:0'),
 tensor([1, 1, 2, 2, 0, 1, 0, 0, 0], device='mps:0'),
 tensor([1, 1, 2, 2, 2, 1, 0, 0, 0], device='mps:0'),
 tensor([1, 1, 2, 2, 2, 1, 1, 0, 0], device='mps:0'),
 tensor([1, 1, 2, 2, 2, 1, 1, 2, 0], device='mps:0')]

In [69]:
# Cell index (0-8) chosen on each turn of the episode above, in play order.
actions

[5, 2, 0, 3, 1, 4, 6, 7, 8]

# CONTINUE FROM HERE

In [None]:
# NOTE(review): BoardTransformer and new_board are not defined anywhere above in this notebook
# (presumably carried over from an earlier notebook) — TODO define them or switch to Player.
policy = BoardTransformer(4).to(device)
n_policy_params = sum(p.numel() for p in policy.parameters())
print(f'Number of parameters in the policy net: {n_policy_params:,}')
print(policy(new_board()))

value = BoardTransformer(1).to(device)
n_value_params = sum(p.numel() for p in value.parameters())
print(f'Number of parameters in the value net: {n_value_params:,}')
print(value(new_board()))

In [None]:
# NOTE(review): play_episode defined above takes (env, players, ...) and returns a tuple, and
# display_board is not defined above — TODO update this cell to the current helpers.
history = play_episode(policy)
final_board = history['states'][-1]
display_board(final_board)

    1      2      4      2
    8     32     64      4
    4     16     32      2
    1      2      2      2


In [None]:
class ReplayBuffer:
	"""Fixed-capacity circular buffer of (state, action, reward, logits) transitions stored on ``device``.

	Once ``maxlen`` transitions have been stored, each new one overwrites the oldest.

	Args:
		maxlen: capacity (must be positive).
		state_shape: shape of one stored state (default ``(16, 4, 4)``, as before).
		num_actions: length of each stored logits vector (default 4, as before).
		state_dtype: dtype of stored states; None (default) uses torch's default float dtype, as before.
			Pass an integer dtype such as ``torch.long`` for boards a model indexes or embeds, because
			indexing with floating-point tensors raises.

	Raises:
		ValueError: if ``maxlen`` is not positive.
	"""
	def __init__(self, maxlen, state_shape=(16, 4, 4), num_actions=4, state_dtype=None):
		if maxlen <= 0:
			raise ValueError(f'maxlen must be positive, got {maxlen}')
		self.states = torch.empty((maxlen, *state_shape), dtype=state_dtype, device=device)
		self.actions = torch.empty((maxlen,), dtype=torch.long, device=device)
		self.rewards = torch.empty((maxlen,), device=device)
		self.logits = torch.empty((maxlen, num_actions), device=device)
		self.append_idx = 0  # slot the next transition is written to
		self.cycled = False  # True once every slot has been written at least once

	def __len__(self):
		"""Number of transitions currently stored (at most ``maxlen``)."""
		return len(self.states) if self.cycled else self.append_idx

	def append(self, state, action, reward, logits):
		"""Store one transition, overwriting the oldest one when the buffer is full."""
		i = self.append_idx
		self.states[i] = state
		self.actions[i] = action
		self.rewards[i] = reward
		self.logits[i] = logits
		self.append_idx = (i + 1) % len(self.states)
		if self.append_idx == 0:
			self.cycled = True

	def sample(self, batch_size):
		"""Return ``batch_size`` stored transitions drawn uniformly with replacement.

		Returns:
			``(states, actions, rewards, logits)``, each with leading dimension ``batch_size``.

		Raises:
			ValueError: if the buffer is empty.
		"""
		size = len(self)
		if size == 0:
			raise ValueError('cannot sample from an empty ReplayBuffer')
		# Draw indices on the storage device to avoid a host-to-device copy per sample.
		ix = torch.randint(0, size, (batch_size,), device=self.states.device)
		return self.states[ix], self.actions[ix], self.rewards[ix], self.logits[ix]

# Shared replay memory for training, holding up to 100k transitions.
buffer = ReplayBuffer(maxlen=100_000)

In [None]:
def compute_advantages(states, rewards, value_net, gamma=0.99, dones=None):
	"""One-step TD advantages ``r_t + gamma * V(s_{t+1}) - V(s_t)`` along a single trajectory.

	Args:
		states: consecutive states s_0 .. s_{T-1} of one trajectory, in order.
		rewards: tensor of shape (T,), the reward received after each state.
		value_net: callable mapping ``states`` to values of shape (T,) or (T, 1).
		gamma: discount factor.
		dones: optional tensor of shape (T,), nonzero where the episode ends after step t; those
			steps do not bootstrap from the following state.

	Returns:
		Tensor of shape (T,). The step after the last state is treated as terminal (value 0).
		Not detached, so gradients can flow into ``value_net``.
	"""
	values = value_net(states)
	if values.dim() > rewards.dim():
		# A (T, 1) value head would otherwise broadcast against (T,) rewards into a (T, T) matrix.
		values = values.squeeze(-1)
	# Every step needs a successor value; the one after the last state is terminal.
	next_values = torch.cat((values[1:], values.new_zeros(1)))
	if dones is not None:
		next_values = next_values * (1 - dones.to(values.dtype))
	return rewards + gamma * next_values - values

def update_policy(policy_net, value_net, buffer, optimizer_policy, optimizer_value, batch_size=64, epochs=4, gamma=0.99, clip=0.2):
	"""Run ``epochs`` PPO-style clipped updates of the policy and value nets on replay minibatches.

	Each epoch samples one minibatch ``(states, actions, returns, old_logits)`` from ``buffer`` and
	takes one optimizer step per network on the whole batch.

	- Stored and new ``logits`` are raw scores, so log-probabilities come from ``log_softmax``
	  (``.log()`` of raw logits is NaN for negative scores).
	- Stored ``rewards`` are used as the return target. NOTE(review): this is only meaningful if the
	  buffer holds returns (e.g. discounted reward-to-go) rather than per-step rewards — TODO confirm.
	- Advantages are ``return - V(s)`` (detached). A randomly sampled batch is not a consecutive
	  trajectory, so it cannot be used for TD bootstrapping.

	Args:
		policy_net: maps a batch of states to one logit per action.
		value_net: maps a batch of states to values of shape (B,) or (B, 1).
		buffer: object whose ``sample(batch_size)`` returns ``(states, actions, rewards, logits)``.
		optimizer_policy, optimizer_value: optimizers over the respective networks' parameters.
		batch_size: transitions per minibatch.
		epochs: number of minibatch updates.
		gamma: unused here and kept for interface compatibility (discount before storing returns).
		clip: PPO clipping range for the probability ratio.
	"""
	for _ in range(epochs):
		states, actions, returns, logits_old = buffer.sample(batch_size)
		action_idx = actions.unsqueeze(-1)
		log_probs_old = logits_old.detach().log_softmax(-1).gather(-1, action_idx).squeeze(-1)

		values = value_net(states).reshape(returns.shape)
		advantages = (returns - values).detach()

		log_probs_new = policy_net(states).log_softmax(-1).gather(-1, action_idx).squeeze(-1)
		ratio = (log_probs_new - log_probs_old).exp()
		surr1 = ratio * advantages
		surr2 = ratio.clamp(1.0 - clip, 1.0 + clip) * advantages
		policy_loss = -torch.min(surr1, surr2).mean()
		optimizer_policy.zero_grad()
		policy_loss.backward()
		optimizer_policy.step()

		value_loss = (values - returns).pow(2).mean()
		optimizer_value.zero_grad()
		value_loss.backward()
		optimizer_value.step()

def train_policy(policy_net, value_net, buffer, num_episodes, lr=1e-4):
	"""Alternate self-play and learning for ``num_episodes`` episodes.

	Each episode's transitions are appended to ``buffer``, then ``update_policy`` runs once. The Adam
	optimizers are built from ``policy_net`` and ``value_net`` themselves, not from the notebook
	globals, so the networks passed in are the ones that get trained.

	NOTE(review): ``play_episode`` defined earlier in this notebook takes ``(env, players, ...)`` and
	returns a tuple. This loop expects a single-argument version that returns a dict with 'states',
	'actions', 'rewards' and 'logits' keys — TODO reconcile.
	"""
	optimizer_policy = torch.optim.Adam(policy_net.parameters(), lr=lr)
	optimizer_value = torch.optim.Adam(value_net.parameters(), lr=lr)
	for _ in range(num_episodes):
		episode = play_episode(policy_net)
		transitions = zip(episode['states'], episode['actions'], episode['rewards'], episode['logits'])
		for state, action, reward, logits in transitions:
			buffer.append(state, action, reward, logits)
		update_policy(policy_net, value_net, buffer, optimizer_policy, optimizer_value)

# Train the policy/value pair for 100 self-play episodes.
train_policy(policy, value, buffer, num_episodes=100)

RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.cuda.FloatTensor instead (while checking arguments for embedding)