# Single Ship
Trains an agent to move a single ship on a 21x21 board with 25% halite for 100 episode steps. The agent collects on average around 35 pieces of halite. The ship does not deposit in a shipyard, it just collects the Halite.

- Single Ship
- No Shipyards
- No Convert
- No Halite Regeneration
- Single Halite Value (either 0 or 1 per cell)

In [None]:
%pip install \
    pip install -e ../../ \
#    git+https://github.com/mauricef/halite-iv-jax.git@709ff0fbfbf09c366a9c23a8565e83bc8ec2f34d \
    git+https://github.com/deepmind/dm-haiku@v0.0.4 \
    git+https://github.com/deepmind/optax@v0.0.9 \

In [17]:
from jax.tree_util import tree_map

from collections import defaultdict
from enum import Enum
from typing import NamedTuple

import haiku as hk
from jax import jit, partial
import jax.nn as nn
import jax.numpy as np
import jax.lax as lax
from jax.ops import index_update
import jax.random as random
from jax.tree_util import tree_flatten, tree_structure, tree_unflatten, tree_multimap, tree_map
import optax

from kaggle_environments import evaluate, make
from kaggle_environments.envs.halite.helpers import ShipAction, ShipyardAction, board_agent

from halite_jax import Configuration, State, Cell, Action, \
environment_to_initial_state, generate_episode, Action, episode_to_environment, generate_empty_state, \
get_next_state

In [18]:
@partial(jit, static_argnums=(1,))
def random_mask(rng, shape, count, filter_mask=None):
    filter_mask = filter_mask if filter_mask is not None else np.array(np.zeros(shape), bool)
    filter_mask = np.ravel(filter_mask)
    mask = random.uniform(rng, shape=filter_mask.shape)
    mask = np.where(filter_mask, np.inf, mask)
    mask = mask < np.sort(mask)[count]
    mask = np.reshape(mask, shape)
    return mask

rng, r = random.split(rng)
np.array(random_mask(r, (5, 5), 5), int)

DeviceArray([[0, 1, 0, 0, 0],
             [0, 0, 0, 0, 1],
             [0, 0, 0, 0, 1],
             [0, 0, 0, 1, 0],
             [0, 0, 0, 0, 1]], dtype=int32)

In [2]:
class Trajectory(NamedTuple):
    states: State
    actions: Action

## Generate Initial State

In [19]:
# initial_state.py
import jax.numpy as np
import jax.random as random

from halite_jax.public_types import State, Cell, Action

def generate_initial_state(rng, configuration, player_count, 
       halite_cell_pct, initial_player_halite=5000., ships_per_player=1, shipyards_per_player=0):
    
    board_size = configuration['size']
    max_cell_halite = configuration['maxCellHalite']
    starting_halite = configuration['startingHalite']
    
    board_shape = (board_size, board_size)
    ship_count = player_count * ships_per_player
    shipyard_count = player_count * shipyards_per_player
    cell_count = board_size ** 2
    halite_count = np.array(np.floor(halite_cell_pct * cell_count), int)

    rng, r = random.split(rng)
    ship_mask = random_mask(r, (cell_count,), ship_count)

    rng, r = random.split(rng)
    shipyard_mask = random_mask(r, (cell_count,), shipyard_count, ship_mask)

    rng, r = random.split(rng)
    halite_mask = random_mask(r, (cell_count,), halite_count, ship_mask | shipyard_mask)
    
    rng, r = random.split(rng)
    halite = random.uniform(r, minval=0, maxval=max_cell_halite, shape=(cell_count,))
    halite = np.where(halite_mask, halite, 0.)
    halite /= np.sum(halite)
    halite *= starting_halite
    
    owner = np.ones(cell_count) * -1
    for player in range(player_count):
        rng, r = random.split(rng)
        owner_filter_mask = ~((ship_mask | shipyard_mask) & (owner == -1))
        owner_mask = random_mask(r, (cell_count,), ships_per_player + shipyards_per_player, owner_filter_mask)
        owner = np.where(owner_mask, player, owner)
        
    return State(
        halite=np.array([initial_player_halite] * player_count, float),
        step=np.array(0, int),
        cells=Cell(
            owner=np.array(owner.reshape(board_shape), int),
            halite=halite.reshape(board_shape),
            shipyard=np.array(shipyard_mask.reshape(board_shape), bool),
            ship=np.array(ship_mask.reshape(board_shape), bool),
            cargo=np.zeros(board_shape, float)
        )
    )

## Model

In [20]:
nn.sigmoid(np.linspace(-10, 10, 11))

DeviceArray([4.5397872e-05, 3.3535014e-04, 2.4726230e-03, 1.7986210e-02,
             1.1920292e-01, 5.0000000e-01, 8.8079715e-01, 9.8201376e-01,
             9.9752742e-01, 9.9966466e-01, 9.9995458e-01], dtype=float32)

In [21]:
def decode_action(state, action):
    ship_location = np.unravel_index(np.argmax(state.cells.ship), state.cells.ship.shape)
    return action.ship[ship_location]

def encode_action(state, action):
    return Action(
        ship=np.where(state.cells.ship, action, 0),
        shipyard=np.zeros_like(state.cells.shipyard, np.int32)
    )

def get_state_value(state):
    return np.sum(state.cells.cargo)

def loss(weights, trajectory, model):
    final_state = tree_map(lambda a: a[-1], trajectory.states)
    value = get_state_value(final_state)
    
    @jax.grad
    def _batch_compute_loss(weights):
        @jax.vmap
        def _compute_loss(state, action):
            action = decode_action(state, action)
            action_logit, predicted_value = model.apply(weights, state)
            action_prob = jax.nn.softmax(action_logit)
            action_prob = action_prob[action]
            action_prob = np.squeeze(action_prob)
            advantage = value - predicted_value
            action_log_prob = np.log(action_prob)
            actor_loss = action_log_prob * advantage
            critic_loss = np.square(advantage)
            loss = critic_loss - actor_loss
            return loss

        losses = _compute_loss(trajectory.states, trajectory.actions)
        return np.sum(losses)
    
    return _batch_compute_loss(weights)

def manhattan_distances(size):
    assert size % 2 == 1, 'requires odd size'
    center = size // 2
    rows = []
    for i in range(size):
        row = []
        for j in range(size):
            row.append(abs(i - center) + abs(j - center))
        rows.append(row)
    return np.array(rows)

class Radar(hk.Module): 
    def __call__(self, location, field):
        size = field.shape[0]
        assert size % 2 == 1, 'requires odd size'
        assert field.shape[0] == field.shape[1]
        center = size // 2
        distances = manhattan_distances(size) - 1
        indexes = [
            jax.ops.index[:center, :], #north
            jax.ops.index[:, center+1:], #east
            jax.ops.index[center+1:, :], #south
            jax.ops.index[:, :center] #west
        ]        
        roll = location[0], location[1]
        field = np.roll(field, (center-roll[0], center-roll[1]), axis=(0, 1))
        init = hk.initializers.RandomUniform()
        discount = hk.get_parameter('discount', [1], np.float32, init=init)
        discount = nn.sigmoid(discount)
        def _measure_direction(index):
            distances_ = np.ravel(distances[index])
            weights = np.power(discount, distances_)
            values = np.ravel(field[index])
            return np.dot(weights, values)
        return np.stack([_measure_direction(index) for index in indexes])
    
class ShipModel(hk.Module):
    def __init__(self, action_count, name=None):
        super().__init__(name=name)
        self.action_count = action_count
        
    def __call__(self, ship_location, state):
        ship_cell = tree_map(lambda a: a[ship_location], state.cells)
        ship_is_over_halite = ship_cell.halite
        halite_radar = Radar()(ship_location, state.cells.halite)  
        max_radar = np.max(halite_radar)
        mine = hk.Linear(1, name='mine')(np.array([
            ship_is_over_halite,
        ]))
        
        move_linear = hk.Linear(1, name='move')
        def get_direction(i):
            r = np.array([
                halite_radar[i],
            ])
            return move_linear(r)[0]
        moves = np.array(list(map(get_direction, range(4))))
        return np.concatenate([mine, moves])

class CriticModel(hk.Module):        
    def __call__(self, state):
        board_halite = np.sum(state.cells.halite)
        features = np.array([
            board_halite
        ])
        value = hk.Sequential([
            hk.Linear(1)
        ])(features)[0]         
        return value
    
def ActorCriticModel(action_count):
    @hk.without_apply_rng
    @hk.transform
    def model(state):
        ship_model = ShipModel(action_count)
        critic_model = CriticModel()
        ship_location = np.unravel_index(np.argmax(state.cells.ship), state.cells.ship.shape)
        policy = ship_model(ship_location, state)
        value = critic_model(state)
        return policy, value
    return model

In [56]:
rng = random.PRNGKey(42)

In [60]:
board_size = 5
rng, r = random.split(rng)
weights = model.init(r, generate_empty_state(board_size))
optstate = optimizer.init(weights)
episode_steps = 50
environment = make("halite", configuration=dict(size=board_size, episodeSteps=episode_steps, maxCellHalite=250.))
_ = environment.reset(1)
config = environment.configuration
epochs = 10

In [77]:
action_count = 5
model = ActorCriticModel(action_count)
optimizer = optax.adam(learning_rate=.01)
eval_batch_size = 128

def sample_action(rng, weights, state):
#    action_logits, _ = model.apply(weights, state)
#    action = random.categorical(rng, action_logits)
    action = random.randint(rng, minval=0, maxval=5, shape=())
    action = encode_action(state, action)
    return action

def sample_episode(rng, weights):
    rng, r = random.split(rng)
    initial_state = generate_initial_state(r, config, player_count=1,
                                           halite_cell_pct=.5, 
                                          ships_per_player=1)
    
    def step(carry, x):
        rng, state = carry
        rng, r = random.split(rng)
        action = sample_action(r, weights, state)
        next_state = get_next_state(config, state, action)
        return (rng, next_state), Trajectory(state, action)

    init = rng, initial_state
    _, trajectory = jax.lax.scan(step, init, xs=None, length=episode_steps)
    return trajectory

def evaluate(rng, weights):
    trajectory = sample_episode(r, weights)
    final_state = tree_map(lambda a: a[-1], trajectory.states)
    value = get_state_value(final_state)
    return value

In [78]:
try:
    for epoch in range(0, epochs):
        rng, r = random.split(rng)
        trajectory = sample_episode(r, weights)
        grads = loss(weights, trajectory, model)
        updates, optstate = optimizer.update(grads, optstate, weights)
        weights = optax.apply_updates(weights, updates)
        print(f'grads {grads}')
        rng, r = random.split(rng)
        value = evaluate(r, weights)
        print(f'{epoch} {value}')
        
except KeyboardInterrupt:
    pass

grads FlatMap({
  'critic_model/linear': FlatMap({
                           'b': DeviceArray([-inf], dtype=float32),
                           'w': DeviceArray([[-inf]], dtype=float32),
                         }),
  'ship_model/mine': FlatMap({
                       'b': DeviceArray([nan], dtype=float32),
                       'w': DeviceArray([[nan]], dtype=float32),
                     }),
  'ship_model/move': FlatMap({
                       'b': DeviceArray([nan], dtype=float32),
                       'w': DeviceArray([[nan]], dtype=float32),
                     }),
  'ship_model/radar': FlatMap({'discount': DeviceArray([nan], dtype=float32)}),
})
0 425.0
grads FlatMap({
  'critic_model/linear': FlatMap({
                           'b': DeviceArray([nan], dtype=float32),
                           'w': DeviceArray([[nan]], dtype=float32),
                         }),
  'ship_model/mine': FlatMap({
                       'b': DeviceArray([nan], dtype=float32),
             

In [81]:
rng, r = random.split(rng)
episode = sample_episode(r, weights)
environment = episode_to_environment(config, episode)
environment.render(mode="ipython", width=800, height=700)