In [1]:
from pathlib import Path
import sys

ROOT = Path.cwd().parent
sys.path.insert(0, str(ROOT))

In [2]:
from src.battleships_torch import BattleshipsBoardCollection, StepOut
from src.policies import Policy_10x10, sample_action_get_log_prob
from src.losses import VanillaPolictGradientLoss
from tqdm import tqdm

import matplotlib.pyplot as plt
import numpy as np
import torch

### Simple vanilla policy gradient RL

In [3]:
DEVICE = 'mps'
NUM_PARALLEL_GAMES = 1000
NUM_EPOCHS = 100
LEARNING_RATE = 1e-2
REWARD_FINAL = 10
REWARD_HIT = 1
STEP_PENALTY = -0.05
REWARD_DISCOUNT = 0.99

In [4]:
device = torch.device(DEVICE)
policy = Policy_10x10().to(device)
boards = BattleshipsBoardCollection(batch_size=NUM_PARALLEL_GAMES, device=DEVICE)
optimizer = torch.optim.AdamW(policy.parameters(), lr=LEARNING_RATE)
loss = VanillaPolictGradientLoss()

In [5]:
epoch_progress = tqdm(range(NUM_EPOCHS))
for epoch in epoch_progress:

    boards.reset()

    # arrays for tensors to track rewards and log_probs
    mean_entropies = []
    all_log_probs = []
    all_rewards = []
    all_active = []

    # loop until all games are finished
    not_finished = True
    while not_finished:
        not_finished = False

        # get action logits according to policy
        states = boards.state()
        logits = policy(states)

        # sample actions, get their log_probs, and entropies of all actions to track early collapse
        masks = boards.mask()
        actions, log_probs, entropies = sample_action_get_log_prob(logits=logits, mask=masks)
        
        # carry out a step
        step_out: StepOut = boards.step(actions)
        not_finished = not step_out.all_finished

        # calculate rewards
        active_mask = step_out.active
        hit = step_out.hit
        won = step_out.won

        rewards = torch.full((NUM_PARALLEL_GAMES,), STEP_PENALTY, device=device, dtype=torch.float32)
        rewards += hit.to(rewards.dtype) * REWARD_HIT
        rewards += won.to(rewards.dtype) * REWARD_FINAL
        rewards *= active_mask.float()

        # track rewards and log probs for loss calculation at the end
        all_rewards.append(rewards)
        all_log_probs.append(log_probs * active_mask.float())
        all_active.append(active_mask.float())
        mean_entropies.append((entropies * active_mask.float()).sum() / active_mask.sum().clamp_min(1))

    # prepare NxT matrices
    LP = torch.stack(all_log_probs, dim=1)
    RW = torch.stack(all_rewards, dim=1)

    # calculate subsequent rewards
    T = RW.shape[1]
    d = (torch.arange(T, device=device)[:, None] - torch.arange(T, device=device)[None, :])
    L = torch.where(d >= 0, (REWARD_DISCOUNT ** d).to(RW.dtype), torch.zeros_like(d, dtype=RW.dtype))
    # L = torch.tril(torch.ones((T, T), device=device, dtype=RW.dtype))
    G = RW @ L # NxT @ TxT -> NxT

    # substract baseline
    G_mean = G.mean(dim=0)
    G -= G_mean

    # calculate loss
    M = torch.stack(all_active, dim=1)
    denom = M.sum().clamp_min(1.0)
    loss = -(LP * G).sum() / denom

    avg_moves = boards.move_count.float().mean().item()
    avg_entropy = sum(mean_entropies) / len(mean_entropies)
    epoch_progress.set_description(f"[epoch: {epoch + 1}] [loss: {loss.item()}] [avg moves: {avg_moves}] [avg entropy {avg_entropy}]")

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

[epoch: 100] [loss: 0.06555197387933731] [avg moves: 93.52300262451172] [avg entropy 0.5614566802978516]: 100%|██████████| 100/100 [05:47<00:00,  3.48s/it]
