In [1]:
from pathlib import Path
import sys

ROOT = Path.cwd().parent
sys.path.insert(0, str(ROOT))

In [2]:
from src.battleships import BattleshipsBoard, AttackResult, Direction
from src.policies import Policy_10x10, sample_action_get_log_prob
from src.losses import VanillaPolictGradientLoss
from tqdm import tqdm

import matplotlib.pyplot as plt
import numpy as np
import torch

### Simple vanilla policy gradient RL

In [11]:
NUM_PARALLEL_GAMES = 1000
NUM_EPOCHS = 100
LEARNING_RATE = 1e-2
REWARD_FINAL = 10
REWARD_HIT = 1
STEP_PENALTY = -0.05
REWARD_DISCOUNT = 0.99

In [12]:
device = torch.device('mps')
policy = Policy_10x10().to(device)
optimizer = torch.optim.AdamW(policy.parameters(), lr=LEARNING_RATE)
loss = VanillaPolictGradientLoss()

In [13]:
epoch_progress = tqdm(range(NUM_EPOCHS))
for epoch in epoch_progress:

    # initialize multiple game boards
    boards = []
    for _ in range(NUM_PARALLEL_GAMES):
        board = BattleshipsBoard()
        board.randomly_place_all_ships()
        boards.append(board)

    # arrays for tensors to track rewards and log_probs
    all_log_probs = []
    all_rewards = []
    mean_entropies = []

    # loop until all games are finished
    not_finished = True
    while not_finished:
        not_finished = False
        active_idxs, active_states, active_masks = [], [], []

        for idx, board in enumerate(boards):
            if not board.is_finished():
                not_finished = True
                mask, state = board.get_mask_and_state()
                active_idxs.append(idx)
                active_states.append(state)
                active_masks.append(mask)

        if not not_finished:
            continue

        # batch states and do a forward pass, get actions and log-probs
        state_batch = torch.stack(active_states).to(device)
        logits = policy(state_batch)
        mask_batch = torch.stack(active_masks).to(device).bool()

        # sample actions and get log-probs
        actions, log_probs, entropies = sample_action_get_log_prob(logits=logits, mask=mask_batch)
        mean_entropies.append(entropies.mean().item())

        # collect rewards
        rewards_active = []
        actions_cpu = actions.detach().cpu()
        for i, idx in enumerate(active_idxs):
            # conver action to cooardinates
            move = actions_cpu[i]
            move_ax_0 = move // 10
            move_ax_1 = move % 10

            # apply action
            board = boards[idx]
            result: AttackResult = board.receive_attack(ax_0=move_ax_0, ax_1=move_ax_1)

            # calculate reward
            step_penalty = float(STEP_PENALTY)
            reward = float(REWARD_FINAL) if result.finished else (float(REWARD_HIT) if result.hit else 0.0)
            reward_corrected = reward - step_penalty
            rewards_active.append(reward_corrected)

        # tensor with ids of active games
        idx_t = torch.tensor(active_idxs, device=device, dtype=torch.long)

        # differentiable placement using scatter
        lp_step = torch.zeros(NUM_PARALLEL_GAMES, device=device, dtype=log_probs.dtype)
        lp_step = lp_step.scatter(0, idx_t, log_probs)  # keeps grad path to log_prob

        # rewards: no grad needed here
        rw_step = torch.zeros(NUM_PARALLEL_GAMES, device=device, dtype=torch.float32)
        rw_step[idx_t] = torch.tensor(rewards_active, device=device, dtype=torch.float32)

        all_log_probs.append(lp_step)
        all_rewards.append(rw_step)

    # prepare NxT matrices
    LP = torch.stack(all_log_probs, dim=1)
    RW = torch.stack(all_rewards, dim=1)

    # calculate subsequent rewards
    T = RW.shape[1]
    d = (torch.arange(T, device=device)[:, None] - torch.arange(T, device=device)[None, :])
    L = torch.where(d >= 0, (REWARD_DISCOUNT ** d).to(RW.dtype), torch.zeros_like(d, dtype=RW.dtype))
    # L = torch.tril(torch.ones((T, T), device=device, dtype=RW.dtype))
    G = RW @ L # NxT @ TxT -> NxT

    # substract baseline
    G_mean = G.mean(dim=0)
    G -= G_mean

    # calculate loss
    loss = -(LP * G).sum() / NUM_PARALLEL_GAMES

    avg_moves = sum([b.get_total_moves() for b in boards]) / NUM_PARALLEL_GAMES
    avg_entropy = sum(mean_entropies) / len(mean_entropies)
    epoch_progress.set_description(f"[epoch: {epoch + 1}] [loss: {loss}] [avg moves: {avg_moves}] [avg entropy {avg_entropy}]")

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

[epoch: 100] [loss: 54.22935104370117] [avg moves: 93.74] [avg entropy 2.7211014279723167]: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 100/100 [14:37<00:00,  8.77s/it]


In [14]:
def play_one_game(policy, device, greedy: bool = False, step_by_step: bool = True):
    board = BattleshipsBoard()
    board.randomly_place_all_ships()

    policy.eval()
    hits = 0
    steps = 0

    while not board.is_finished():
        mask, state = board.get_mask_and_state()
        state_b = state.unsqueeze(0).to(device).float()
        mask_b  = mask.unsqueeze(0).to(device).bool()

        with torch.no_grad():
            logits = policy(state_b)
            masked_logits = logits.masked_fill(~mask_b, -1e9)

            if greedy:
                action = masked_logits.argmax(dim=-1)
            else:
                dist = torch.distributions.Categorical(logits=masked_logits)
                action = dist.sample()

        a = int(action.item())
        ax0, ax1 = a // 10, a % 10

        result: AttackResult = board.receive_attack(ax_0=ax0, ax_1=ax1)
        hits += int(result.hit)
        steps += 1

        print(f"\nStep {steps}: attack=({ax0},{ax1})  hit={result.hit}  remaining={board.remaining}")
        board.visualize(show_ships=True, show_attacks=True)

        if step_by_step and not board.is_finished():
            input("Enter to continue...")

    print(f"\nFinished in {steps} steps, hits={hits}. Final board (ships shown):")
    board.visualize(show_ships=True, show_attacks=True)
    return steps, hits

In [15]:
play_one_game(policy, device=device, greedy=False, step_by_step=False)


Step 1: attack=(1,7)  hit=False  remaining=17
    0  1  2  3  4  5  6  7  8  9
 0 [44m  [0m [44m  [0m [44m  [0m [44m  [0m [44m  [0m [100m  [0m [44m  [0m [44m  [0m [44m  [0m [44m  [0m
 1 [44m  [0m [44m  [0m [44m  [0m [44m  [0m [44m  [0m [100m  [0m [44m  [0m [44mðŸ’¨[0m [44m  [0m [44m  [0m
 2 [44m  [0m [44m  [0m [100m  [0m [44m  [0m [44m  [0m [100m  [0m [44m  [0m [100m  [0m [44m  [0m [44m  [0m
 3 [44m  [0m [44m  [0m [100m  [0m [44m  [0m [44m  [0m [100m  [0m [44m  [0m [100m  [0m [44m  [0m [44m  [0m
 4 [44m  [0m [44m  [0m [44m  [0m [44m  [0m [44m  [0m [44m  [0m [44m  [0m [100m  [0m [44m  [0m [44m  [0m
 5 [44m  [0m [44m  [0m [44m  [0m [44m  [0m [44m  [0m [44m  [0m [44m  [0m [100m  [0m [44m  [0m [44m  [0m
 6 [44m  [0m [44m  [0m [44m  [0m [44m  [0m [44m  [0m [44m  [0m [44m  [0m [100m  [0m [44m  [0m [44m  [0m
 7 [44m  [0m [44m  [0m [44m  [0m [44m  

(81, 17)