# Reinforcment Learning Using Policy Optimisation

Instead of estimating the Q-value of the state, we will directly optimize the policy, however without
MCTS search.

In other words, the agent recieves state of the form $(board, card)$ and returns a probability
distribution over **allowed actions** (achieved by masking out forbidden actions and rescaling
probabilities). Then, based on the probability distribution, the move is chosen.

Inspired by [this phind answer](https://www.phind.com/search?cache=4fb75c5b-c572-479b-9d6c-58df1ff67f52&fbclid=IwAR2lIRiaSEG0jXs3_CVScUe-Gl74HFeKY1I63eyImOhpr8tzuYSRto-lk2Q).

In [1]:
%%capture
%load_ext autoreload
%autoreload 2

%pip install -U pip
%pip install install 'git+https://github.com/balgot/mathematico.git#egg=mathematico&subdirectory=game'
%pip install torch torchview torch-summary graphviz numpy matplotlib tqdm wandb

In [47]:
import random
from tqdm.notebook import tqdm, trange

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchsummary import summary

from collections import deque, namedtuple
from dataclasses import dataclass

dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dev

device(type='cuda')

## Algorithm

The implemented algorithm is PPO (Proximal Policy Optimisation). After playing sufficient number of moves, the agent learns the updated policy.

In [317]:
STATE_SIZE = 25 + 1  # 25 positions on the board + 1 card
ACTION_SIZE = 25

config = {
    "seed": 0,
    "episodes": 100_000,
    "lr": 1e-3,
    "gamma": 0.99,
    "epsilon": 0.5,
    "epsilon_min": 0.01,
    "epsilon_decay": 0.995,
    "memory_size": 1_000,
    "batch_size": 1,
    "decay_epochs": 1_000,
}

random.seed(config["seed"])
torch.random.manual_seed(config["seed"])
np.random.seed(config["seed"])

## Environment

To encourage the agent to play "good" moves and converge quicker, we use reward shaping, i.e. the reward is not awarded only at the end of the episode, but rather after each move as the difference of scores of pre- and post-move boards. Note that the sum of the partial scores is equal to the final score.

In [318]:
from mathematico import Board


class MathematicoEnv:
    def __init__(self):
        self.board = None
        self.deck = None
        self.move_idx = None
        self.reset()
        
    def reset(self):
        self.board = Board()
        self.deck = [k for k in range(1, 13+1) for _ in range(4)]
        random.shuffle(self.deck)
        self.move_idx = 0
        return self.board.grid, [ self.deck[self.move_idx] ]
        
    def step(self, action):
        _score_before = self.board.score()
        self.board.make_move(action, self.deck[self.move_idx])
        self.move_idx += 1
        _score_after = self.board.score()
        return (self.board.grid, [ self.deck[self.move_idx] ]), _score_after - _score_before, self.move_idx >= 25, None
        
env = MathematicoEnv()

## Replay Memory

To keep track of previous moves, actions, next states and rewards, we will use the following class (effectively bounded queue):

In [319]:
Transition = namedtuple('Transition', ('state', 'card', 'action', 'next_state', 'next_card', 'reward'))

class ReplayMemory:
    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        transitions = random.sample(self.memory, batch_size)
        batch = Transition(*zip(*transitions))
        state_batch = torch.cat(batch.state)
        card_batch = torch.cat(batch.card)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward)
        next_state_batch = torch.cat(batch.next_state)
        next_card_batch = torch.cat(batch.next_card)
        return state_batch, card_batch, action_batch, reward_batch, next_state_batch, next_card_batch

    def __len__(self):
        return len(self.memory)

## Neural Network

In [320]:
class DQN(nn.Module):
    def __init__(self, input_size=STATE_SIZE, output_size=ACTION_SIZE):
        super(DQN, self).__init__()
        self.flat = nn.Flatten()
        self.pipe = nn.Sequential(*[
            nn.Linear(26, 128), nn.Tanh(),
            # nn.Linear(128, 128), nn.ReLU(),
            # nn.Linear(128, 128), nn.ReLU(),
            # nn.Linear(128, 128), nn.ReLU(),
            # nn.Linear(128, 128), nn.ReLU(),
            # nn.Linear(128, 128), nn.ReLU(),
            # nn.Linear(128, 128), nn.ReLU(),
            nn.Linear(128, 25)
        ])
        

    def forward(self, board, card):
        x = self.flat(board)
        mask = x
       
        x = torch.cat((x, card), dim=1).float()        
        x = self.pipe(x)
        
        x = torch.where(mask == 0, x, -float('inf'))  # apply mask
        return F.softmax(x, dim=1)
    
_batch = 1
_board = Board().grid
_board[0][0] = 1
_card = 1
_board = torch.tensor([_board] * _batch)
_card = torch.tensor([[_card]] * _batch)
summary(DQN(), input_data=(_board, _card));

Layer (type:depth-idx)                   Output Shape              Param #
├─Flatten: 1-1                           [-1, 25]                  --
├─Sequential: 1-2                        [-1, 25]                  --
|    └─Linear: 2-1                       [-1, 128]                 3,456
|    └─Tanh: 2-2                         [-1, 128]                 --
|    └─Linear: 2-3                       [-1, 25]                  3,225
Total params: 6,681
Trainable params: 6,681
Non-trainable params: 0
Total mult-adds (M): 0.01
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.03
Estimated Total Size (MB): 0.03


## Agent

In [323]:
class Agent:
    def __init__(self):
        self.epsilon = config["epsilon"]
        self.memory = ReplayMemory(config["memory_size"])
        self.policy_net = DQN().to(dev)
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=config["lr"])
        
        self.target_net = DQN().to(dev)
        self.update_target_network()
        self.target_net.eval()

    def act(self, state, card):
        """Play a move given state."""
        if np.random.rand() <= self.epsilon:
            available = [5*row + col for row in range(5) for col in range(5) if state[0][row][col] == 0]
            return torch.tensor([random.choice(available)], device=dev)
        with torch.no_grad():
            distr = self.policy_net(state, card)
            distr = torch.distributions.Categorical(distr)
            sample = distr.sample()
            return sample

    def learn(self):
        if len(self.memory) < config["batch_size"]:
            return
        
        print("\n\nLearning\n=========\n")
        _sample = self.memory.sample(config["batch_size"])
        print("sample", _sample)
        print(f"{self.policy_net.pipe[0].weight=}")
        print(f"{self.target_net.pipe[0].weight=}")
        state_batch, card_batch, action_batch, reward_batch, next_state_batch, next_card_batch = _sample
        # q_values = torch.index_select(self.policy_net(state_batch, card_batch), 0, action_batch.to(dev).long())
        q_values = self.policy_net(state_batch, card_batch)
        q_values = torch.gather(q_values, 1, action_batch.unsqueeze(1)).squeeze(1)
        print(f"{q_values.shape=}\n\t{q_values=}")
        print(f"{action_batch.shape=}")
        next_q_values = self.target_net(next_state_batch, next_card_batch).max(1)[0].detach()
        print(f"{next_q_values.shape=}\n\t{next_q_values=}")
        print(f"\t\t{self.target_net(next_state_batch, next_card_batch)=}")
        expected_q_values = ((next_q_values * config["gamma"]) + reward_batch)
        print(f"{expected_q_values.shape=}\n\t{expected_q_values=}")
        
        loss = F.mse_loss(q_values, expected_q_values)
        print(f"{loss=}")
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def update_target_network(self):
        self.target_net.load_state_dict(self.policy_net.state_dict())

    def decay_epsilon(self):
        self.epsilon = max(config["epsilon_min"], self.epsilon * config["epsilon_decay"])

        
agent = Agent()
agent.act(torch.tensor([Board().grid], device=dev), torch.tensor([[1]], device=dev))

tensor([1], device='cuda:0')

## Training Loop

In [324]:
def _tensor(x):
    return torch.tensor([x], device=dev, dtype=torch.float32)


for i in trange(config["episodes"], desc="Episode"):
    # Reset the environment
    state, card = map(_tensor, env.reset())
    done = False
    score = 0
    
    # Play the moves until the end
    while not done:
        action = agent.act(state, card)
        assert len(action) == 1
        
        row, col = divmod(action.item(), 5)
        next_state, reward, done, _ = env.step((row, col))
        (next_state, next_card), reward = map(_tensor, next_state), _tensor(reward)
        
        agent.memory.push(state, card, action, next_state, next_card, reward)
        agent.learn()
        state, card = next_state, next_card
        score += reward.item()
        
    # Update
    if (i + 1) % config["decay_epochs"] == 0:
        agent.update_target_network()
        agent.decay_epsilon()

Episode:   0%|          | 0/100000 [00:00<?, ?it/s]



Learning

sample (tensor([[[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]], device='cuda:0'), tensor([[1.]], device='cuda:0'), tensor([8], device='cuda:0'), tensor([0.], device='cuda:0'), tensor([[[0., 0., 0., 0., 0.],
         [0., 0., 0., 1., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]], device='cuda:0'), tensor([[3.]], device='cuda:0'))
self.policy_net.pipe[0].weight=Parameter containing:
tensor([[ 0.0199,  0.1933,  0.0048,  ...,  0.0792,  0.1091,  0.1933],
        [ 0.1463, -0.0724,  0.0615,  ...,  0.1861, -0.0947, -0.0113],
        [ 0.0791,  0.0174,  0.1609,  ..., -0.1782, -0.0468,  0.0235],
        ...,
        [-0.0463, -0.1452,  0.1129,  ...,  0.0521, -0.1706, -0.1463],
        [ 0.0761,  0.1072, -0.1423,  ...,  0.1711, -0.0177,  0.0158],
        [-0.1507,  0.1582, -0.0730,  ..., -0.1720, -0.0550, -0.1824]],
       device='c

ValueError: Expected parameter probs (Tensor of shape (1, 25)) of distribution Categorical(probs: torch.Size([1, 25])) to satisfy the constraint Simplex(), but found invalid values:
tensor([[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
         nan]], device='cuda:0')

In [None]:
agent.policy_net.pipe[0].weight


In [258]:
agent.act(state, card)

ValueError: Expected parameter probs (Tensor of shape (1, 25)) of distribution Categorical(probs: torch.Size([1, 25])) to satisfy the constraint Simplex(), but found invalid values:
tensor([[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
         nan]], device='cuda:0')

In [256]:
state

tensor([[[ 0.,  5.,  0.,  6.,  0.],
         [ 0.,  0.,  0.,  8.,  0.],
         [ 5.,  0.,  0.,  0.,  0.],
         [ 0.,  0., 10.,  0.,  0.],
         [ 0.,  0.,  3.,  0.,  7.]]], device='cuda:0')

In [257]:
card

tensor([[13.]], device='cuda:0')

In [220]:
env.reset()

([[0, 0, 0, 0, 0],
  [0, 0, 0, 0, 0],
  [0, 0, 0, 0, 0],
  [0, 0, 0, 0, 0],
  [0, 0, 0, 0, 0]],
 [5])