In [23]:
from pathlib import Path
import sys

ROOT = Path.cwd().parent
sys.path.insert(0, str(ROOT))

In [None]:
from src.battleships import BattleshipsBoard, AttackResult, Direction
from src.policies import Policy_10x10, sample_action_get_log_prob
from src.losses import VanillaPolictGradientLoss
from tqdm import tqdm

import matplotlib.pyplot as plt
import numpy as np
import torch

### Simple vanilla policy gradient RL

In [25]:
NUM_PARALLEL_GAMES = 100
NUM_EPOCHS = 100
LEARNING_RATE = 1e-3
REWARD_FINAL = 10
REWARD_HIT = 1

In [26]:
device = torch.device('mps')
policy = Policy_10x10().to(device)
optimizer = torch.optim.AdamW(policy.parameters(), lr=LEARNING_RATE)
loss = VanillaPolictGradientLoss()

In [29]:
epoch_progress = tqdm(range(NUM_EPOCHS))
for epoch in epoch_progress:

    # initialize multiple game boards
    boards = []
    for _ in range(NUM_PARALLEL_GAMES):
        board = BattleshipsBoard()
        board.randomly_place_all_ships()
        boards.append(board)

    # for each board initialize an array to track game progress
    rollouts = [[] for _ in range(NUM_PARALLEL_GAMES)]

    # loop until all games are finished
    not_finished = True
    while not_finished:
        not_finished = False
        active_idxs, active_states, active_masks = [], [], []
        for idx, board in enumerate(boards):
            if not board.is_finished():
                not_finished = True
                mask, state = board.get_mask_and_state()
                active_idxs.append(idx)
                active_states.append(state)
                active_masks.append(mask)

        # batch states and do a forward pass
        state_batch = torch.stack(active_states).to(device)
        logits = policy.forward(state=state_batch)

        # debugging prints
        print(f"state_batch dim: {state_batch.shape}")
        print(f"logits dim: {logits.shape}")

        # batch masks and get actions + log probs
        mask_batch = torch.stack(active_masks).to(device)
        actions, log_probs = sample_action_get_log_prob(logits=logits, mask=mask_batch)

        # debugging prints
        print(f"mask_batch dim: {mask_batch.shape}")
        print(f"actions shape: {actions.shape}")
        print(f"log-probs shape: {log_probs.shape}")

        # collect rewards
        active_rewards = []
        actions_copy = actions.clone().detach()
        for i, idx in enumerate(active_idxs):
            board = boards[idx]
            move = actions_copy[i] # idx of flattened vector
            move_ax_0 = move // 10
            move_ax_1 = move % 10
            try:
                result: AttackResult = board.receive_attack(ax_0=move_ax_0, ax_1=move_ax_1)
                if result.finished:
                    active_rewards.append(REWARD_FINAL)
                elif result.hit:
                    active_rewards.append(REWARD_HIT)
                else:
                    active_rewards.append(0)
            
                print(f"result: {result}")
            except ValueError:
                board.visualize()
                print(f"masked: {active_masks[i][move] == False}")
                print(f"move_ax_0: {move_ax_0}")
                print(f"move_ax_0: {move_ax_1}")

                # end
                not_finished = False
                break

        # break

    # loss = ....
    break # temporary


  0%|          | 0/100 [00:00<?, ?it/s]

state_batch dim: torch.Size([100, 2, 10, 10])
logits dim: torch.Size([100, 100])
mask_batch dim: torch.Size([100, 100])
actions shape: torch.Size([100])
log-probs shape: torch.Size([100])
result: AttackResult(hit=False, finished=False)
result: AttackResult(hit=False, finished=False)
result: AttackResult(hit=False, finished=False)
result: AttackResult(hit=False, finished=False)
result: AttackResult(hit=False, finished=False)
result: AttackResult(hit=False, finished=False)
result: AttackResult(hit=False, finished=False)
result: AttackResult(hit=False, finished=False)
result: AttackResult(hit=True, finished=False)
result: AttackResult(hit=False, finished=False)
result: AttackResult(hit=False, finished=False)
result: AttackResult(hit=False, finished=False)
result: AttackResult(hit=False, finished=False)
result: AttackResult(hit=True, finished=False)
result: AttackResult(hit=False, finished=False)
result: AttackResult(hit=False, finished=False)
result: AttackResult(hit=False, finished=False

  0%|          | 0/100 [00:18<?, ?it/s]

result: AttackResult(hit=True, finished=True)
result: AttackResult(hit=True, finished=True)





RuntimeError: stack expects a non-empty TensorList