In [None]:
%pip install \
    git+https://github.com/mauricef/halite-iv-jax.git@709ff0fbfbf09c366a9c23a8565e83bc8ec2f34d \
#    git+https://github.com/deepmind/dm-haiku@v0.0.4 \
#    git+https://github.com/deepmind/optax@v0.0.9 \

In [1]:
import jax.numpy as np
from jax import jit, partial
import jax.numpy as np
import jax.random as random

from kaggle_environments import make

from halite_jax.py_to_jax import environment_to_initial_state
from halite_jax import generate_episode, episode_to_environment

Loading environment football failed: No module named 'gfootball'


In [64]:
@partial(jit, static_argnums=(1,))
def random_mask(rng, shape, count, filter_mask=None):
    """Return a boolean mask of `shape` with `count` randomly chosen True cells.

    Every cell draws an independent uniform score; the `count` cells with the
    lowest scores are selected.

    Args:
        rng: JAX PRNG key used to draw the per-cell scores.
        shape: Shape of the returned mask (static argument under `jit`).
        count: Non-negative integer number of cells to select. It is used as
            an array index, so it must have an integer dtype.
        filter_mask: Optional boolean array with as many elements as `shape`.
            Cells where it is True are never selected. Defaults to excluding
            nothing.

    Returns:
        Boolean array of shape `shape`. When `count` is at least the number of
        eligible cells, every eligible cell is selected.
    """
    if filter_mask is None:
        filter_mask = np.zeros(shape, bool)
    excluded = np.ravel(filter_mask)

    scores = random.uniform(rng, shape=excluded.shape)
    # Excluded cells get an infinite score so they can never fall below the
    # selection threshold.
    scores = np.where(excluded, np.inf, scores)

    # Append an `inf` sentinel so that `count == number of cells` indexes a
    # valid threshold. Without it the index is out of bounds, JAX clamps it to
    # the largest score, and one cell too few is selected.
    thresholds = np.concatenate([np.sort(scores), np.full((1,), np.inf, scores.dtype)])
    selected = scores < thresholds[count]
    return np.reshape(selected, shape)

# Seed a key in this cell so it runs on a fresh kernel; no earlier cell in
# notebook order defines `rng`.
rng = random.PRNGKey(42)
rng, r = random.split(rng)
# Show a 5x5 mask with 5 selected cells as 0/1 integers.
np.array(random_mask(r, (5, 5), 5), int)

DeviceArray([[0, 0, 0, 0, 0],
             [1, 0, 0, 0, 0],
             [0, 0, 0, 1, 1],
             [0, 0, 1, 1, 0],
             [0, 0, 0, 0, 0]], dtype=int32)

In [69]:
# initial_state.py
import jax.numpy as np
import jax.random as random

from halite_jax.public_types import State, Cell, Action

def generate_initial_state(rng, configuration, player_count,
       halite_cell_pct, initial_player_halite=5000., ships_per_player=1, shipyards_per_player=0):
    """Build a random initial `State` for a square board.

    Ships, shipyards and halite are placed on mutually disjoint cells: shipyards
    avoid ship cells, and halite avoids both.

    Args:
        rng: JAX PRNG key.
        configuration: Mapping providing 'size', 'maxCellHalite' and
            'startingHalite'.
        player_count: Number of players (concrete Python int; it drives a
            Python loop).
        halite_cell_pct: Fraction of the board's cells that receive halite
            (the cell count is floored).
        initial_player_halite: Halite each player starts with.
        ships_per_player: Number of ships given to each player.
        shipyards_per_player: Number of shipyards given to each player.

    Returns:
        A `State` at step 0 whose cell arrays have shape (size, size). Cells
        without an owner carry owner -1.
    """
    board_size = configuration['size']
    max_cell_halite = configuration['maxCellHalite']
    starting_halite = configuration['startingHalite']

    board_shape = (board_size, board_size)
    cell_count = board_size ** 2
    flat_shape = (cell_count,)

    # These counts end up as array indices inside `random_mask`, so force an
    # integer dtype even when a caller passes e.g. `0.`.
    ships_each = np.array(ships_per_player, int)
    shipyards_each = np.array(shipyards_per_player, int)
    ship_count = player_count * ships_each
    shipyard_count = player_count * shipyards_each
    halite_count = np.array(np.floor(halite_cell_pct * cell_count), int)

    rng, r = random.split(rng)
    ship_mask = random_mask(r, flat_shape, ship_count)

    rng, r = random.split(rng)
    shipyard_mask = random_mask(r, flat_shape, shipyard_count, ship_mask)

    rng, r = random.split(rng)
    halite_mask = random_mask(r, flat_shape, halite_count, ship_mask | shipyard_mask)

    rng, r = random.split(rng)
    halite = random.uniform(r, minval=0, maxval=max_cell_halite, shape=flat_shape)
    halite = np.where(halite_mask, halite, 0.)
    # Rescale so the board total equals `starting_halite`. Because of this
    # rescaling, `max_cell_halite` does not bound the final per-cell values.
    # Guard the division so a board with no halite cells stays all-zero
    # instead of becoming NaN.
    total_halite = np.sum(halite)
    halite = halite / np.where(total_halite > 0, total_halite, 1.)
    halite = halite * starting_halite

    owner = np.full(flat_shape, -1)
    for player in range(player_count):
        rng, r_ship, r_shipyard = random.split(rng, 3)
        # Ships and shipyards are assigned separately so every player gets
        # exactly `ships_per_player` ships and `shipyards_per_player`
        # shipyards. Drawing from the pooled set of units could hand a
        # player the wrong mix.
        unowned = owner == -1
        player_ships = random_mask(
            r_ship, flat_shape, ships_each, ~(ship_mask & unowned))
        player_shipyards = random_mask(
            r_shipyard, flat_shape, shipyards_each, ~(shipyard_mask & unowned))
        owner = np.where(player_ships | player_shipyards, player, owner)

    return State(
        halite=np.array([initial_player_halite] * player_count, float),
        step=np.array(0, int),
        cells=Cell(
            owner=np.array(owner.reshape(board_shape), int),
            halite=halite.reshape(board_shape),
            shipyard=np.array(shipyard_mask.reshape(board_shape), bool),
            ship=np.array(ship_mask.reshape(board_shape), bool),
            cargo=np.zeros(board_shape, float)
        )
    )

In [70]:
# Demo: build a random initial state for a small board.
# Re-seed the PRNG chain so this cell is reproducible on its own.
rng = random.PRNGKey(42)
environment = make("halite", configuration=dict(size=5, episodeSteps=50, maxCellHalite=250.))
# NOTE(review): the argument to reset() is presumably the number of agents — confirm.
_ = environment.reset(1)
# Later cells read this module-level `configuration`.
configuration = environment.configuration
rng, r = random.split(rng)
state = generate_initial_state(r, configuration, player_count=2, 
                               halite_cell_pct=.5, initial_player_halite=100, 
                               ships_per_player=2, shipyards_per_player=2)

In [75]:
# getting_started_agent.py

import jax.numpy as np
from jax.scipy.signal import correlate
import jax.random as random

from kaggle_environments.envs.halite.helpers import ShipAction, ShipyardAction

from halite_jax.public_types import Action

def generate_manhattan_distances(size):
    """Return a (size, size) grid of Manhattan distances to the centre cell.

    The centre cell is at index `size // 2` along both axes.
    """
    axis_offsets = np.abs(np.arange(size) - size // 2)
    row_part = axis_offsets[:, None]
    col_part = axis_offsets[None, :]
    return col_part + row_part

def generate_smoothing_kernel(kernel_size, discount=.5):
    """Return a (kernel_size, kernel_size) kernel that decays with distance.

    Each weight is `discount` raised to the cell's Manhattan distance from the
    kernel centre, so the centre weight is 1.

    Args:
        kernel_size: Side length of the square kernel.
        discount: Base of the per-step decay.
    """
    distances = generate_manhattan_distances(kernel_size)
    return np.power(discount, distances)

def calculate_smooth_field(field, discount=.5):
    """Smooth `field` with a distance-discounted kernel on a wrapping board.

    The kernel is as large as the board (`field.shape[0]`), and the field is
    padded in 'wrap' mode so the correlation treats the board edges as
    connected. The result has the same shape as `field`.

    Args:
        field: Square 2-D array (the kernel side is taken from `field.shape[0]`).
        discount: Per-step decay passed to `generate_smoothing_kernel`.
    """
    kernel_size = field.shape[0]
    kernel = generate_smoothing_kernel(kernel_size, discount)
    # The kernel centre sits at index `kernel_size // 2`, so pad that many
    # cells before and the remainder after. For odd sizes both sides are equal;
    # for even sizes symmetric padding would make the 'valid' output one cell
    # too large along each axis.
    pad_before = kernel_size // 2
    pad_after = kernel_size - 1 - pad_before
    field = np.pad(field, pad_width=((pad_before, pad_after),) * field.ndim, mode='wrap')
    return correlate(field, kernel, mode='valid')

def roll_moves(a):
    """Stack `a` with its four one-cell wrapping shifts along a new last axis.

    Entry 0 of the last axis is `a` itself. Entries 1-4 hold, for each cell,
    the value of the neighbour at row-1, column+1, row+1 and column-1
    respectively (with wrap-around).
    """
    shifts = ((0, 0), (1, 0), (0, -1), (-1, 0), (0, 1))
    shifted = [np.roll(a, shift, axis=(0, 1)) for shift in shifts]
    return np.stack(shifted, axis=-1)


def agent(
    state,
    rng,
    player,
    shipyard_discount,
    cell_halite_discount,
    min_mining_halite,
    max_cargo
    ):
    """Choose an action for every cell of the board for `player`.

    Each of the player's ships follows one of three rules:
      * cargo >= `max_cargo`: step towards the strongest smoothed shipyard
        signal;
      * otherwise, cell halite >= `min_mining_halite`: stay put (action 0);
      * otherwise: step towards the strongest smoothed halite signal.
    Cells without one of the player's ships get action 0. Shipyards never act.

    Ship actions are indices into the stack built by `roll_moves`
    (0 = current cell, 1-4 = the four neighbours).
    NOTE(review): presumably these indices match the environment's
    NORTH/EAST/SOUTH/WEST action codes — confirm.

    Args:
        state: Game state; `state.cells` provides `owner`, `ship`, `shipyard`,
            `halite` and `cargo` arrays.
        rng: Unused; kept so the signature stays unchanged for callers.
        player: Index of the player to act for.
        shipyard_discount: Decay used when smoothing the shipyard field.
        cell_halite_discount: Decay used when smoothing the halite field.
        min_mining_halite: Minimum cell halite at which a ship stays to mine.
        max_cargo: Cargo level at which a ship heads for a shipyard.

    Returns:
        `Action` with per-cell `ship` and `shipyard` arrays.
    """
    cells = state.cells
    owned = cells.owner == player
    ships = np.array(cells.ship & owned, float)
    shipyards = np.array(cells.shipyard & owned, float)

    # Shipyards never act in this agent.
    shipyard_action = np.zeros_like(shipyards)

    # For every cell, the index of the neighbour (or the cell itself) with the
    # largest smoothed halite.
    smoothed_halite = calculate_smooth_field(cells.halite, discount=cell_halite_discount)
    direction_to_halite = np.argmax(roll_moves(smoothed_halite), -1)

    # Same for the smoothed shipyard field. With no shipyards the field is all
    # zeros and argmax yields 0 (stay).
    smoothed_shipyards = calculate_smooth_field(shipyards, discount=shipyard_discount)
    direction_to_shipyard = np.argmax(roll_moves(smoothed_shipyards), axis=-1)

    has_ship = ships == 1
    returning = has_ship & (cells.cargo >= max_cargo)
    mining = has_ship & (cells.halite >= min_mining_halite) & ~returning
    seeking = has_ship & ~returning & ~mining

    # The three ship groups are disjoint, so a nested select is equivalent to
    # summing the per-group actions.
    ship_action = np.where(
        returning,
        direction_to_shipyard,
        np.where(seeking, direction_to_halite, 0),
    )
    return Action(ship=ship_action, shipyard=shipyard_action)

In [82]:
# Re-seed and rebuild the environment with an 11x11 board for the episode demo.
rng = random.PRNGKey(42)
environment = make("halite", configuration=dict(size=11, episodeSteps=100, maxCellHalite=500.))
# NOTE(review): the argument to reset() is presumably the number of agents — confirm.
_ = environment.reset(1)
# `generate_episode_` below reads this module-level `configuration`.
configuration = environment.configuration

@jit
def generate_episode_(rng):
    """Generate one single-player episode from a random initial state.

    Reads the module-level `configuration`; it is not an argument, so re-run
    this cell after changing it.

    Args:
        rng: JAX PRNG key; split once for the initial state and once for the
            episode itself.

    Returns:
        Whatever `generate_episode` returns for the configured agent.
    """
    agent_ = partial(
        agent,
        shipyard_discount=.5,
        cell_halite_discount=.5,
        min_mining_halite=100,
        max_cargo=500,
    )

    rng, r = random.split(rng)
    initial_state = generate_initial_state(
        r,
        configuration,
        player_count=1,
        halite_cell_pct=.05,
        initial_player_halite=5000,
        ships_per_player=1,
        # Must be an integer: the count is used as an array index when the
        # placement masks are drawn, and a float (`0.`) is not a valid index.
        shipyards_per_player=0,
    )
    rng, r = random.split(rng)
    return generate_episode(configuration, [agent_], initial_state, r)

In [83]:
# Draw a fresh subkey and generate one episode with the jitted helper.
rng, r = random.split(rng)
episode = generate_episode_(r)

In [84]:
# Convert the generated episode into an environment object and render it
# inline. This rebinds the module-level `environment`.
environment = episode_to_environment(configuration, episode)
environment.render(mode="ipython", width=800, height=700)

In [86]:
import haiku as hk

import jax
from jax import jit, partial, vmap, grad
from jax import random
import jax.lax as lax
import jax.nn as nn
import jax.numpy as np

import optax