In [1]:
import glob
import json
from tqdm.auto import tqdm
import bz2
import pickle
import os
import torch
import gc

In [2]:
# Replay JSON files sit directly inside the dataset directory.
# NOTE: "[!_info]" is a glob character class, not a negated substring: it
# matches exactly ONE character before ".json" that is not '_', 'i', 'n',
# 'f' or 'o'. That excludes "*_info.json" files (their last char is 'o'),
# but it would also silently drop any replay whose stem ends in one of
# those letters.
DATA_DIR = "../input/halite-iml-dataset"
path = f"{DATA_DIR}/*[!_info].json"

In [3]:
# Count every entry in the dataset directory (replays plus any *_info.json files).
!ls ../input/halite-iml-dataset | wc -l

4840


In [4]:
from kaggle_environments.envs.halite.helpers import *
def get_observation(raw_obs, conf):
    """Encode one Halite replay observation as a 9-channel board tensor.

    Builds a ``Board`` from the raw replay observation/configuration and
    rasterises it from the point of view of ``board.current_player``.
    Entries are indexed ``obs[channel, x, y]`` using each position's
    (x, y) components.

    Channel layout (two-player games only):
        0  halite on each cell / 1000
        1  current player's ships (1 where a ship stands)
        2  opponent's ships
        3  cargo of the current player's ships / 1000
        4  cargo of the opponent's ships / 1000
        5  current player's shipyards (1 where a shipyard stands)
        6  opponent's shipyards
        7  current player's stored halite / 5000 (constant plane)
        8  opponent's stored halite / 5000 (constant plane)

    Args:
        raw_obs: raw observation dict from a replay step. It is not
            modified; a shallow copy with ``remainingOverageTime`` set is
            passed to ``Board``.
        conf: raw environment configuration dict.

    Returns:
        ``(obs, board)``: a float tensor of shape (9, size, size) and the
        ``Board`` it was built from.

    Raises:
        ValueError: if the board does not have exactly one opponent. The
            layout above has no channels for additional players, and
            writing them at ``2+i`` / ``4+i`` / ``6+i`` would overwrite
            other channels.
    """
    # Set before building the Board (presumably replay observations omit
    # this key). Copy so the caller's dict is left untouched.
    raw_obs = {**raw_obs, 'remainingOverageTime': 60}
    board = Board(raw_observation=raw_obs, raw_configuration=conf)

    current_player = board.current_player
    opponents = board.opponents
    if len(opponents) != 1:
        raise ValueError(
            "get_observation supports 2-player games only, "
            f"got {len(opponents) + 1} players")
    opponent = opponents[0]

    size = board.configuration.size
    obs = torch.zeros((9, size, size))

    # Stored halite of each player, broadcast over a whole plane.
    obs[7] = current_player.halite / 5000
    obs[8] = opponent.halite / 5000

    # Halite on every cell; positions unpack as (x, y).
    for (x, y), cell in board.cells.items():
        obs[0, x, y] = cell.halite / 1000

    # (player, ship-presence channel, ship-cargo channel, shipyard channel)
    channel_plan = ((current_player, 1, 3, 5), (opponent, 2, 4, 6))
    for player, ship_ch, cargo_ch, yard_ch in channel_plan:
        for ship in player.ships:
            x, y = ship.position
            obs[ship_ch, x, y] = 1
            obs[cargo_ch, x, y] = ship.halite / 1000
        for yard in player.shipyards:
            x, y = yard.position
            obs[yard_ch, x, y] = 1

    return obs, board

In [5]:
# Integer code for each ship action string found in replay action dicts:
# NORTH=1, EAST=2 (right), SOUTH=3, WEST=4 (left), CONVERT=5.
# Code 0 is deliberately left unassigned; the extraction loop writes 0 for
# ships with no recorded action.
inverseShipActions = {
    action_name: code
    for code, action_name in enumerate(
        ("NORTH", "EAST", "SOUTH", "WEST", "CONVERT"), start=1)
}

In [6]:
# Exclusive upper bound on the replay step sampled from each episode;
# shorter episodes stop at len(steps) - 1 instead.
MAX_STEP: int = 325

In [7]:
# Build (observation, action-label) tensors for every replay and save them
# as x_{k}.pt / y_{k}.pt in the working directory.
#
# For each replay step t in [START_STEP, min(len(steps) - 1, MAX_STEP)):
#   x[t - START_STEP] = get_observation(steps[t][0]['observation'], conf)
#   y[t - START_STEP] = actions recorded at steps[t + 1]
# y[:, p, 0, x, y] is the ship action code from inverseShipActions (0 = none);
# y[:, p, 1, x, y] is 1 where a shipyard issued any action.
# p = 0 labels the observing player (replay agent 0); p = 1 labels the
# opponents using replay agent 1's actions.
#
# NOTE: despite the .pt extension, each file holds pickle.dump of the
# bz2-compressed bytes of a pickled tensor. Read it back with
# pickle.loads(bz2.decompress(pickle.load(f))), and only for trusted files,
# because unpickling can execute arbitrary code.
START_STEP = 75  # earliest replay step kept from each episode
BOARD_SIZE = 21  # board side length assumed for the output tensors


def _label_player(episode_y, sample, slot, player, actions):
    """Write one player's recorded actions into episode_y[sample, slot].

    ``actions`` is the action dict from the *next* replay step, or None.
    Ship codes go in channel 0 (0 when the ship has no entry). A 1 goes in
    channel 1 for each shipyard that has an entry.
    """
    if actions is None:
        actions = {}
    for ship in player.ships:
        code = inverseShipActions[actions[ship.id]] if ship.id in actions else 0
        episode_y[sample, slot, 0, ship.position.x, ship.position.y] = code
    for yard in player.shipyards:
        if yard.id in actions:
            episode_y[sample, slot, 1, yard.position.x, yard.position.y] = 1


def _save_compressed(tensor, filename):
    """Pickle ``tensor``, bz2-compress it, and pickle the bytes into ``filename``."""
    with open(filename, "wb") as out_file:
        pickle.dump(bz2.compress(pickle.dumps(tensor)), out_file)


counter = 0
# Exclude *_info.json explicitly instead of relying only on the glob class.
replay_paths = [p for p in glob.glob(path) if not p.endswith("_info.json")]
for path_to_json in tqdm(replay_paths):
    with open(path_to_json, 'r') as json_file:
        replay = json.load(json_file)  # parse once
    steps = replay["steps"]
    conf = replay["configuration"]

    n_samples = min(len(steps) - 1, MAX_STEP) - START_STEP
    if n_samples <= 0:
        # Too short to yield any (observation, next-action) pair; skipping
        # avoids writing empty tensor files.
        continue

    episode_x = torch.zeros((n_samples, 9, BOARD_SIZE, BOARD_SIZE))
    episode_y = torch.zeros((n_samples, 2, 2, BOARD_SIZE, BOARD_SIZE))
    for t in range(START_STEP, START_STEP + n_samples):
        sample = t - START_STEP
        episode_x[sample], board = get_observation(steps[t][0]['observation'], conf)
        # The actions taken from state t are recorded in step t + 1.
        next_step = steps[t + 1]
        _label_player(episode_y, sample, 0, board.current_player,
                      next_step[0]['action'])
        for opponent in board.opponents:
            _label_player(episode_y, sample, 1, opponent,
                          next_step[1]['action'])

    _save_compressed(episode_x, f"x_{counter}.pt")
    _save_compressed(episode_y, f"y_{counter}.pt")
    counter += 1

  0%|          | 0/4840 [00:00<?, ?it/s]

In [8]:
# Completion marker: printed once the preceding cells have run.
print("ok", end="\n")

ok
