For Kaggle

In [1]:
# !pip install kaggle-environments -U > /dev/null 2>&1
# !cp -r ../input/lux-ai-2021/* .

For agents validation

In [2]:
# timeout 1h lux-ai-2021 --tournament --rankSystem wins --storeReplay false --storeLogs false --maxConcurrentMatches 1 agent/main.py agent_simple/main.py submission_v4/main.py

In [1]:
import os
import numpy as np
import json
from pathlib import Path
import os
import random
from tqdm.notebook import tqdm
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.transforms.functional as TF
from torchvision.transforms import ToTensor, RandomHorizontalFlip, RandomVerticalFlip
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.model_selection import train_test_split
import optuna
from optuna.trial import TrialState

In [2]:
def seed_everything(seed_value):
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    os.environ['PYTHONHASHSEED'] = str(seed_value)
    
    if torch.cuda.is_available(): 
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = True

seed_everything(42)

# Preprocessing

Actions:
- m - move: m, unit who moves, direction
- bcity - build city: bcity, unit who builds
- bw - build worker, x-coord of building city, y-coord of building city
- r - research , x-coord of researching city, y-coord of researching city
- t - transfer: transfer, from user_1, to user_2, resourse, quantity

Updates:
- rp - research point: player, number of rp
- r - resources: r, type of resource, x-coord, y-coord, quantity
- u - user: u, worker/cart, player, user id, x-coord, y-coord, cooldown, wood, coal, uranium
- c - city: c, player, city id, number of resources, amount of consuming light at night
- ct - city tile: ct, player, city id, x-coord, y-coord, cooldown
- ccd - level of road: ccd, x-coord, y-coord, level value

Remove not necessary files from dir

In [3]:
!cd episodes && rm *_info.json && cd ..

In [4]:
episode_list = list()
episode_dir = 'episodes'

with open(episode_dir + '/' +'eps.txt', 'r') as f:
    for s in f:
        s = s.split(' ')
        episode_list.append(s[-1][1:-1] + '.json')
        
for fname in episode_list:
    fname = episode_dir + '/' + fname
    if os.path.isfile(fname):
        os.remove(fname)

Get data from json files

In [5]:
from random import choice

def unit_label(action):
    strs = action.split(' ')
    unit_id = strs[1]
    if strs[0] == 'm':
        label = {'c': 0, 'n': 1, 's': 2, 'w': 3, 'e': 4}[strs[2]]
    elif strs[0] == 'bcity':
        label = 5
    else:
        label = None
    return unit_id, label

def city_label(action):
    strs = action.split(' ')
    if len(strs) < 3:
        return None, None
    ctile_coord = (strs[1], strs[2])
    if strs[0] == 'bw':
        label = 0
    elif strs[0] == 'r':
        label = 1
    else:
        label = None
    return ctile_coord, label

def unit_coord_and_cooldown(updates, index):
    strs = updates.split(' ')
    if strs[0] == 'u' and strs[2] == index:
        unit_id = strs[3]
        unit_coord = (strs[4], strs[5])
        cooldown = strs[6]
        return unit_id, unit_coord, cooldown
    return None, None, None

def city_tile_coord(updates, index):
    strs = updates.split(' ')
    if strs[0] == 'ct' and strs[1] == index:
        ctile_coord = (strs[3], strs[4])
        cooldown = strs[5]
        return ctile_coord, cooldown
    return None, None

def research_points(updates, index):
    for u in updates:
        strs = u.split(' ')
        if strs[0] == 'rp' and strs[1] == index:
            return int(strs[2])
    return 0

def depleted_resources(obs):
    for u in obs['updates']:
        if u.split(' ')[0] == 'r':
            return False
    return True


def create_dataset_from_json(episode_dir, team_name='Toad Brigade'): 
    obses = {}
    unit_samples = []
    city_samples = []
    
    episodes = [path for path in Path(episode_dir).glob('*.json') if 'output' not in path.name]
    for filepath in tqdm(episodes): 
        with open(filepath) as f:
            json_load = json.load(f)
        # load episode from episode json file
        ep_id = json_load['info']['EpisodeId']
        index = np.argmax([r or 0 for r in json_load['rewards']])
        if json_load['info']['TeamNames'][index] != team_name:
            continue

        updates = None
        
        for i in range(len(json_load['steps'])-1):
            units = {}
            city_tiles = {}
            
            if json_load['steps'][i][index]['status'] == 'ACTIVE':
                actions = json_load['steps'][i+1][index]['action']
                obs = json_load['steps'][i][0]['observation']
                updates = obs['updates']
                rp = research_points(updates, str(index))
                # get updates from previous steps
                if i > 0:
                    obs['updates_lag_1'] = json_load['steps'][i-1][0]['observation']['updates']
                else:
                    obs['updates_lag_1'] = None
                if i > 1:
                    obs['updates_lag_2'] = json_load['steps'][i-2][0]['observation']['updates']
                else:
                    obs['updates_lag_2'] = None
                if i > 2:
                    obs['updates_lag_3'] = json_load['steps'][i-3][0]['observation']['updates']
                else:
                    obs['updates_lag_3'] = None
                if i > 3:
                    obs['updates_lag_4'] = json_load['steps'][i-4][0]['observation']['updates']
                else:
                    obs['updates_lag_4'] = None
                if i > 4:
                    obs['updates_lag_5'] = json_load['steps'][i-5][0]['observation']['updates']
                else:
                    obs['updates_lag_5'] = None
                if i > 5:
                    obs['updates_lag_6'] = json_load['steps'][i-6][0]['observation']['updates']
                else:
                    obs['updates_lag_6'] = None
                if i > 6:
                    obs['updates_lag_7'] = json_load['steps'][i-7][0]['observation']['updates']
                else:
                    obs['updates_lag_7'] = None
                
                for u in updates:
                    # get coords for every friendly city tile
                    ctile_coord, cooldown = city_tile_coord(u, str(index))
                    if ctile_coord and cooldown:
                        city_tiles[ctile_coord] = [ctile_coord, float(cooldown)]
                        continue
                    # get cooldown for every friendly unit 
                    unit_id, unit_coord, cooldown = unit_coord_and_cooldown(u, str(index))
                    if unit_id and cooldown:
                        units[unit_id] = (unit_coord, float(cooldown))                    
                
                # stop research if resources are depleted
                if depleted_resources(obs):
                    break
                
                obs['player'] = index
                obs = dict([
                    (k,v) for k,v in obs.items() 
                    if k in ['step', 'updates', 'updates_lag_1', 'updates_lag_2', 
                             'updates_lag_3', 'updates_lag_4', 'updates_lag_5', 
                             'updates_lag_6', 'updates_lag_7', 'player', 'width', 'height']
                ])
                obs_id = f'{ep_id}_{i}'
                obses[obs_id] = obs
                
                for action in actions:
                    unit_id, label = unit_label(action)
                    # if unit acts - add this action to train
                    if label is not None:
                        unit_samples.append((obs_id, unit_id, label))
                        continue
                    # if unit can act but doesn't - add this action to action list
                    if unit_id in units:
                        cooldown = units[unit_id][1]
                        if cooldown == 0:
                            unit_samples.append((obs_id, unit_id, 0))
                            continue
                    ctile_coord, label = city_label(action)
                    # if city tile acts - add this action to train
                    if label is not None:
                        city_samples.append((obs_id, ctile_coord, label))
                        city_tiles[ctile_coord][0] = label
                        
#                 # it there is a possibility for city tiles to act, (research points number < 200 and number of
#                 # city tiles more than number of units find those of them who can act (cooldown = 0) but doesn't
#                 if rp < 200 and len(city_tiles) > len(units):
#                     city_tiles_no_action = [(k, v[0]) for k, v in city_tiles.items() if v[0] == 0 and v[1] == 0]
#                     # if there are cities with no action - randomly select one item from there 
#                     # and add it to city samples
#                     if city_tiles_no_action:
#                         ct = choice(city_tiles_no_action)
#                         city_samples.append((obs_id, ct[0], ct[1]))
                    
                    
    return obses, unit_samples, city_samples

In [6]:
episode_dir = 'episodes'
obses, samples, city_samples = create_dataset_from_json(episode_dir)

print('observations:', len(obses), 'worker samples:', len(samples), 'city samples:', len(city_samples))

  0%|          | 0/277 [00:00<?, ?it/s]

observations: 84269 worker samples: 604080 city samples: 44815


In [7]:
labels = [sample[-1] for sample in samples]
actions = ['center','north', 'south', 'west', 'east', 'bcity']
for value, count in zip(*np.unique(labels, return_counts=True)):
    print(f'{actions[value]}: {count}')

center: 31630
north: 152947
south: 146204
west: 119659
east: 117579
bcity: 36061


In [8]:
labels_city = [sample[-1] for sample in city_samples]
actions_city = ['build_worker', 'research']
for value, count in zip(*np.unique(labels_city, return_counts=True)):
    print(f'{actions_city[value]}: {count}')

build_worker: 14674
research: 30141


In [11]:
# obses['27781498_17']

In [12]:
# city_samples[200000]

# Training

b - training tensor of float32. b dimensions is 20x32x32

- b[0] - position of current unit
- b[1] - cargo sum/100 of current unit
- b[2] - previous position of current unit
- b[3] - pre-previous position of current unit
- b[4, 5, 6, 7] - position, cooldown/6, cargo sum/100 and Manhattan distance to units from friendly team
- b[8, 9, 10, 11] - position, cooldown/6, cargo sum/100 and Manhattan distance to units from adversarial team
- b[12, 13, 14] - position,  min(city fuel/city energy consumption, 10)/10 and Manhattan distance to cities from friendly team
- b[15, 16, 17] - position,  min(city fuel/city energy consumption, 10)/10 and Manhattan distance to cities from adversarial team
- b[18] - amount of wood / 800 on current position
- b[19] - amount of coal / 800 on current position
- b[20] - amount of uranium / 800 on current position
- b[21] - 1 if unit can mine this resource, 0 otherwise
- b[22] - Manhattan distance to this resource
- b[23] - research points / 200 of unit's team
- b[24] - research points / 200 of another team
- b[25] - time of the day (from 0 to 1, step 0.05)
- b[26] - step of the game (from 0 to 1, step 1/360)
- b[27] - day/night flag
- b[28] - map size

In [9]:
def manhattan_distance(x1, y1, x2, y2):
    return (abs(x2-x1) + abs(y2-y1))

# make list of users and their current and previous coords and cooldowns 
# to write them in NN training subset
def find_units_from_previous_obs(obs, x_shift, y_shift):
    # at first fill the unit dict with units from previous observation
    updates = obs['updates']
    updates_lag_1 = obs['updates_lag_1']
    updates_lag_2 = obs['updates_lag_2']
    updates_lag_3 = obs['updates_lag_3']
    updates_lag_4 = obs['updates_lag_4']
    updates_lag_5 = obs['updates_lag_5']
    updates_lag_6 = obs['updates_lag_6']
    updates_lag_7 = obs['updates_lag_7']
    prev_updates = [updates, updates_lag_1, updates_lag_2, updates_lag_3, 
                    updates_lag_4, updates_lag_5, updates_lag_6, updates_lag_7]
    prev_units = list()
    for prev_update in prev_updates:
        units = dict()
        if prev_update:
            for update in prev_update:
                strs = update.split(' ')
                input_identifier = strs[0]
                # if we found observation for user
                if input_identifier == 'u':
                    unit_id = strs[3]
                    x = int(strs[4]) + x_shift
                    y = int(strs[5]) + y_shift
                    cooldown = float(strs[6])
                    units[unit_id] = [x, y, cooldown]
            prev_units.append(units)
        else:
            break
    return prev_units

def find_prev_coords(prev_units, unit_id):
    # find previous coordinates of unit by ananlyzing its cooldown
    x, y = [None, None], [None, None]
    for i in range(len(prev_units)-1):
        if unit_id in prev_units[i] and unit_id in prev_units[i+1]:
            cooldown = prev_units[i][unit_id][0]
            prev_x = prev_units[i+1][unit_id][0]
            prev_y = prev_units[i+1][unit_id][1]
            prev_cooldown = prev_units[i+1][unit_id][2]
            if cooldown > 0 and prev_cooldown == 0:
                if x[1] and y[1]:
                    return x, y
                if not x[0] and not y[0]:
                    x[0], y[0] = prev_x,prev_y
                else:
                    x[1], y[1] = prev_x,prev_y
        else:
            break
    return x, y
            
# Input for Neural Network for workers
def make_input(obs, unit_id):
    width, height = obs['width'], obs['height']
    x_shift = (32 - width) // 2
    y_shift = (32 - height) // 2

    cities = {}
    
    b = np.zeros((29, 32, 32), dtype=np.float32)
    
    prev_units = find_units_from_previous_obs(obs, x_shift, y_shift)
    x_u, y_u = prev_units[0][unit_id][0], prev_units[0][unit_id][1]
    
    for update in obs['updates']:
        strs = update.split(' ')
        input_identifier = strs[0]
        my_rp = 0
        
        if input_identifier == 'u':
            x = int(strs[4]) + x_shift
            y = int(strs[5]) + y_shift
            wood = int(strs[7])
            coal = int(strs[8])
            uranium = int(strs[9])
            if strs[3] == unit_id: # 0:2
                # Position, Cargo and Previous Position
                b[:2, x, y] = (
                    1,
                    (wood + coal + uranium) / 100
                )
                prev_x, prev_y = find_prev_coords(prev_units, unit_id)
                if not prev_x[0] and not prev_y[0]:
                    prev_x[0], prev_y[0] = x, y
                    prev_x[1], prev_y[1] = x, y
                b[2, prev_x[0], prev_y[0]] = 1
                b[3, prev_x[1], prev_y[1]] = 1
            else:                  # 4:11
                # Units
                team = int(strs[2])
                cooldown = float(strs[6])
                idx = 4 + (team - obs['player']) % 2 * 4
                m_dist = manhattan_distance(x_u, y_u, x, y)
                b[idx:idx + 4, x, y] = (
                    1,
                    cooldown / 6,
                    (wood + coal + uranium) / 100,
                    m_dist/((width-1) + (height-1))
                )
        elif input_identifier == 'ct':  # 12:17
            # CityTiles
            team = int(strs[1])
            city_id = strs[2]
            x = int(strs[3]) + x_shift
            y = int(strs[4]) + y_shift
            idx = 12 + (team - obs['player']) % 2 * 3
            m_dist = manhattan_distance(x_u, y_u, x, y)
            b[idx:idx + 3, x, y] = (
                1,
                cities[city_id],
                m_dist/((width-1) + (height-1))
            )
        elif input_identifier == 'r':  # 18:21
            # Resources
            r_type = strs[1]
            x = int(strs[2]) + x_shift
            y = int(strs[3]) + y_shift
            amt = int(float(strs[4]))
            access_level = {'wood': 0, 'coal': 50, 'uranium': 200}[r_type]
            access = 0 if my_rp < access_level else 1
            b[{'wood': 18, 'coal': 19, 'uranium': 20}[r_type], x, y] = amt / 800
            b[21, x, y] = access
            b[22, x, y] = manhattan_distance(x_u, y_u, x, y)/((width-1) + (height-1))
        elif input_identifier == 'rp':  # 23:24
            # Research Points
            team = int(strs[1])
            rp = int(strs[2])
            my_rp = rp if team == obs['player'] else my_rp
            b[23 + (team - obs['player']) % 2, :] = min(rp, 200) / 200
        elif input_identifier == 'c':
            # Cities
            city_id = strs[2]
            fuel = float(strs[3])
            lightupkeep = float(strs[4])
            cities[city_id] = min(fuel / lightupkeep, 10) / 10
    # Day/Night Cycle
    b[25, :] = obs['step'] % 40 / 40
    # Turns
    b[26, :] = obs['step'] / 360
    # Day/Night Flag
    b[27, :] = 1 if obs['step'] % 40 < 30 else 0
    # Map Size
    b[28, x_shift:32 - x_shift, y_shift:32 - y_shift] = 1
        
    return b

In [14]:
# obses['26762301_160']

In [15]:
# import sys

# np.set_printoptions(threshold=sys.maxsize)

# make_input(obses['26762301_170'], 'u_36')[3]

Data for the cities:
- b[0] - position of current city
- b[1] - min(city fuel/city energy consumption, 10)/10 of current city
- b[2, 3, 4] - position, cooldown/10, and min(city fuel/city energy consumption, 10)/10 for cities from the same team
- b[5, 6, 7] - position, cooldown/10, and min(city fuel/city energy consumption, 10)/10 for cities from another team
- b[8, 9] - position and cargo sum/100 for units from the same team 
- b[10, 11] - position and cargo sum/100 for units from from another team
- b[12] - amount of wood / 800
- b[13] - amount of coal / 800
- b[14] - amount of uranium / 800
- b[15] - research points / 200 of unit's team
- b[16] - research piints / 200 of another team
- b[17] - time of the day (from 0 to 1, step 0.05)
- b[18] - step of the game (from 0 to 1, step 1/360)
- b[19] - day/night flag
- b[20] - map size

In [10]:
# Input for Neural Network for cities
def make_city_input(obs, city_coord):
    width, height = obs['width'], obs['height']
    x_shift = (32 - width) // 2
    y_shift = (32 - height) // 2
    cities = {}
    
    b = np.zeros((21, 32, 32), dtype=np.float32)
    
    for update in obs['updates']:
        strs = update.split(' ')
        input_identifier = strs[0]
        
        if input_identifier == 'ct':
            # CityTiles
            city_id = strs[2]
            x = int(strs[3]) 
            y = int(strs[4])
            cooldown = float(strs[5])
            if x == int(city_coord[0]) and y == int(city_coord[1]):
                b[:2, x + x_shift, y + y_shift] = (
                    1,
                    cities[city_id]
                )
            else:
                team = int(strs[1])
                idx = 2 + (team - obs['player']) % 2 * 3
                b[idx:idx + 3, x + x_shift, y + y_shift] = (
                    1,
                    cooldown / 10,
                    cities[city_id]
                )
        elif input_identifier == 'u':
            team = int(strs[2])
            x = int(strs[4]) + x_shift
            y = int(strs[5]) + y_shift
            wood = int(strs[7])
            coal = int(strs[8])
            uranium = int(strs[9])
            idx = 8 + (team - obs['player']) % 2 * 2
            b[idx:idx + 2, x, y] = (
                1,
                (wood + coal + uranium) / 100
            )
        elif input_identifier == 'r':
            # Resources
            r_type = strs[1]
            x = int(strs[2]) + x_shift
            y = int(strs[3]) + y_shift
            amt = int(float(strs[4]))
            b[{'wood': 12, 'coal': 13, 'uranium': 14}[r_type], x, y] = amt / 800
        elif input_identifier == 'rp':
            # Research Points
            team = int(strs[1])
            rp = int(strs[2])
            b[15 + (team - obs['player']) % 2, :] = min(rp, 200) / 200
        elif input_identifier == 'c':
            # Cities
            city_id = strs[2]
            fuel = float(strs[3])
            lightupkeep = float(strs[4])
            cities[city_id] = min(fuel / lightupkeep, 10) / 10
    
    # Day/Night Cycle
    b[17, :] = obs['step'] % 40 / 40
    # Turns
    b[18, :] = obs['step'] / 360
    # Day/Night Flag
    b[19, :] = 1 if obs['step'] % 40 < 30 else 0
    # Map Size
    b[20, x_shift:32 - x_shift, y_shift:32 - y_shift] = 1

    return b

### Set modules for NN training

In [11]:
# Neural Network for Lux AI
class BasicConv2d(nn.Module):
    def __init__(self, input_dim, output_dim, kernel_size, bn, p=0.1):
        super().__init__()
        self.conv = nn.Conv2d(
            input_dim, output_dim, 
            kernel_size=kernel_size, 
            padding=(kernel_size[0] // 2, kernel_size[1] // 2)
        )
        self.bn = nn.BatchNorm2d(output_dim) if bn else None
        self.dropout = nn.Dropout(p) 

    def forward(self, x):
        h = self.conv(x)
        h = self.bn(h) if self.bn is not None else h
        h = self.dropout(h) 
        return h


class LuxNet(nn.Module):
    def __init__(self):
        super().__init__()
        layers, filters = 12, 64
        self.conv0 = BasicConv2d(29, filters, (3, 3), True, p=0.2)
        self.blocks = nn.ModuleList([BasicConv2d(filters, filters, (3, 3), True) for _ in range(layers)])
        self.head_p = nn.Linear(filters, 6, bias=False)

    def forward(self, x):
        h = F.relu_(self.conv0(x))
        for block in self.blocks:
            h = F.relu_(h + block(h))
        h_head = (h * x[:,:1]).view(h.size(0), h.size(1), -1).sum(-1)
        p = self.head_p(h_head)
        return p
    
    
class LuxCityNet(nn.Module):
    def __init__(self):
        super().__init__()
        layers, filters = 12, 32
        self.conv0 = BasicConv2d(21, filters, (3, 3), True, p=0.2)
        self.blocks = nn.ModuleList([BasicConv2d(filters, filters, (3, 3), True) for i in range(layers)])
        self.head_p = nn.Linear(filters, 2, bias=False)

    def forward(self, x):
        h = F.relu_(self.conv0(x))
        for block in self.blocks:
            h = F.relu_(h + block(h))
        h_head = (h * x[:,:1]).view(h.size(0), h.size(1), -1).sum(-1)
        p = self.head_p(h_head)
        return p

### Set dataloaders

In [12]:
# mask for apropriate change of moving action during rotation
# 0 - no transform, 1 - rotation on 90, 2 - rotation on 180, 
# 3 - rotation on 270, 4 - horizontal flip, 5 - vertical flip
mask = (
        (0, 1, 2, 3, 4, 5), 
        (0, 3, 4, 2, 1, 5), # N -> W, S -> E, W -> S, E -> N - rot 90
        (0, 2, 1, 4, 3, 5), # N -> S, S -> N, W -> E, E -> W - rot 180
        (0, 4, 3, 1, 2, 5), # N -> E, S -> W, W -> N, E -> S - rot 270
        (0, 2, 1, 3, 4, 5), # N -> S, S -> N, W -> W, E -> E - hflip
        (0, 1, 2, 4, 3, 5), # N -> N, S -> S, W -> E, E -> W - vflip
        )

def do_nothing(x):
    return x

class RandomChoice(torch.nn.Module):
    def __init__(self, transforms):
       super().__init__()
       self.transforms = transforms

    def __call__(self, x):
        idx = random.choice([i for i in range(len(self.transforms))])
        t = self.transforms[idx]
        x = torch.from_numpy(x)
        return t(x), idx

transform=RandomChoice([lambda x: x,
                        lambda x: torch.rot90(x, 1, [2, 1]),
                        lambda x: torch.rot90(x, 2, [2, 1]),
                        lambda x: torch.rot90(x, 1, [1, 2]),
                        TF.hflip,
                        TF.vflip])
    
class LuxDataset(Dataset):
    def __init__(self, obses, samples, transform=transform, mask=mask, city=False):
        self.obses = obses
        self.samples = samples
        self.data_len = len(self.samples)
        self.len = self.data_len*6
        self.transform = transform
        self.mask = mask
        self.city = city
            
    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        data_idx = idx % self.data_len
        obs_id, unit_id, action = self.samples[data_idx]
        obs = self.obses[obs_id]
        if not self.city:
            state = make_input(obs, unit_id)
        else:
            state = make_city_input(obs, unit_id)
        if self.transform:
            t = self.transform(state)
            state = t[0]
            if not self.city:
                action = self.mask[t[1]][action]
        return state, action

### Function for NN training

In [13]:
import matplotlib.pyplot as plt
# from torch.utils.tensorboard import SummaryWriter

def train_model(model, dataloaders_dict, criterion, optimizer, scheduler, num_epochs, city=False):
#     tb = SummaryWriter()
    
    best_acc = 0.0
    
    for epoch in range(num_epochs):
        model.cuda()
        
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()
                
            epoch_loss = 0.0
            epoch_acc = 0
            
            dataloader = dataloaders_dict[phase]
            for item in tqdm(dataloader, leave=False):
                
                states = item[0].cuda().float()
                actions = item[1].cuda().long()

                optimizer.zero_grad()
                
                with torch.set_grad_enabled(phase == 'train'):
                    policy = model(states)
                    loss = criterion(policy, actions)
                    _, preds = torch.max(policy, 1)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                    epoch_loss += loss.item() * len(policy)
                    epoch_acc += torch.sum(preds == actions.data)

            data_size = len(dataloader.dataset)
            epoch_loss = epoch_loss / data_size
            epoch_acc = epoch_acc.double() / data_size

#             if phase == 'train':
#                 tb.add_scalar("Train Loss", epoch_loss, epoch)
#                 tb.add_scalar("Train Accuracy", epoch_acc, epoch)
#             else:
#                 tb.add_scalar("Val Loss", epoch_loss, epoch)
#                 tb.add_scalar("Val Accuracy", epoch_acc, epoch)
            
            print(f'Epoch {epoch + 1}/{num_epochs} | {phase:^5} | Loss: {epoch_loss:.4f} | Acc: {epoch_acc:.4f}')
        
        if epoch_acc > best_acc:
            if city:
                traced = torch.jit.trace(model.cpu(), torch.rand(1, 21, 32, 32))
                traced.save('agent/model_city.pth')
            else:
                traced = torch.jit.trace(model.cpu(), torch.rand(1, 29, 32, 32))
                traced.save('agent/model.pth')
            best_acc = epoch_acc
            
        scheduler.step(epoch_loss)
        print(f'LR: {optimizer.param_groups[0]["lr"]}')
            
#     tb.close()

### Init dataloaders

In [14]:
# model for unit actions
model = LuxNet()
train, val = train_test_split(samples, test_size=0.1, random_state=42, stratify=labels)
batch_size = 256

train_loader = DataLoader(
    LuxDataset(obses, train), 
    batch_size=batch_size, 
    shuffle=True, 
    num_workers=2
)
val_loader = DataLoader(
    LuxDataset(obses, val), 
    batch_size=batch_size, 
    shuffle=False, 
    num_workers=2
)
dataloaders_dict = {"train": train_loader, "val": val_loader}

# model for city actions
model_city = LuxCityNet()
train_city, val_city = train_test_split(city_samples, test_size=0.1, random_state=42, stratify=labels_city)
batch_size_city = 128

train_city_loader = DataLoader(
    LuxDataset(obses, train_city, city=True), 
    batch_size=batch_size, 
    shuffle=True, 
    num_workers=2
)
    
val_city_loader = DataLoader(
    LuxDataset(obses, val_city, city=True), 
    batch_size=batch_size, 
    shuffle=False, 
    num_workers=2
)
dataloaders_city_dict = {"train": train_city_loader, "val": val_city_loader}

### Train models

In [18]:
%%time

# model = torch.jit.load('agent/model.pth')
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = ReduceLROnPlateau(optimizer, 'min', patience=1, threshold=1e-3, cooldown=3, min_lr=1e-6)

train_model(model, dataloaders_dict, criterion, optimizer, scheduler, num_epochs=5)

  0%|          | 0/12743 [00:00<?, ?it/s]

Epoch 1/5 | train | Loss: 0.7781 | Acc: 0.6937


  0%|          | 0/1416 [00:00<?, ?it/s]

Epoch 1/5 |  val  | Loss: 0.6309 | Acc: 0.7500
LR: 0.001


  0%|          | 0/12743 [00:00<?, ?it/s]

Epoch 2/5 | train | Loss: 0.6476 | Acc: 0.7436


  0%|          | 0/1416 [00:00<?, ?it/s]

Epoch 2/5 |  val  | Loss: 0.5875 | Acc: 0.7669
LR: 0.001


  0%|          | 0/12743 [00:00<?, ?it/s]

Traceback (most recent call last):
  File "/usr/lib/python3.8/multiprocessing/queues.py", line 245, in _feed
    send_bytes(obj)
  File "/usr/lib/python3.8/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/usr/lib/python3.8/multiprocessing/connection.py", line 411, in _send_bytes
    self._send(header + buf)
  File "/usr/lib/python3.8/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe


KeyboardInterrupt: 

In [17]:
# Epoch 1/20 | train | Loss: 0.7660 | Acc: 0.7015
#   0%|          | 0/1419 [00:00<?, ?it/s]
# Epoch 1/20 |  val  | Loss: 0.6420 | Acc: 0.7505
# LR: 0.001
#   0%|          | 0/12765 [00:00<?, ?it/s]
# Epoch 2/20 | train | Loss: 0.6396 | Acc: 0.7492
#   0%|          | 0/1419 [00:00<?, ?it/s]
# Epoch 2/20 |  val  | Loss: 0.5845 | Acc: 0.7710
# LR: 0.001
#   0%|          | 0/12765 [00:00<?, ?it/s]
# Epoch 3/20 | train | Loss: 0.6002 | Acc: 0.7640
#   0%|          | 0/1419 [00:00<?, ?it/s]
# Epoch 3/20 |  val  | Loss: 0.5563 | Acc: 0.7812
# LR: 0.001
#   0%|          | 0/12765 [00:00<?, ?it/s]
# Epoch 4/20 | train | Loss: 0.5776 | Acc: 0.7719
#   0%|          | 0/1419 [00:00<?, ?it/s]
# Epoch 4/20 |  val  | Loss: 0.5406 | Acc: 0.7866
# LR: 0.001
#   0%|          | 0/12765 [00:00<?, ?it/s]
# Epoch 5/20 | train | Loss: 0.5624 | Acc: 0.7779
#   0%|          | 0/1419 [00:00<?, ?it/s]
# Epoch 5/20 |  val  | Loss: 0.5335 | Acc: 0.7903
# LR: 0.001
#   0%|          | 0/12765 [00:00<?, ?it/s]
# Epoch 6/20 | train | Loss: 0.5517 | Acc: 0.7820
#   0%|          | 0/1419 [00:00<?, ?it/s]
# Epoch 6/20 |  val  | Loss: 0.5235 | Acc: 0.7938
# LR: 0.001
#   0%|          | 0/12765 [00:00<?, ?it/s]
# Epoch 7/20 | train | Loss: 0.5436 | Acc: 0.7851
#   0%|          | 0/1419 [00:00<?, ?it/s]
# Epoch 7/20 |  val  | Loss: 0.5268 | Acc: 0.7920
# LR: 0.001
#   0%|          | 0/12765 [00:00<?, ?it/s]
# Epoch 8/20 | train | Loss: 0.5371 | Acc: 0.7876
#   0%|          | 0/1419 [00:00<?, ?it/s]
# Epoch 8/20 |  val  | Loss: 0.5149 | Acc: 0.7962

In [15]:
%%time

# model_city = torch.jit.load('agent/model_city.pth')
criterion_city = nn.CrossEntropyLoss()
optimizer_city = torch.optim.AdamW(model_city.parameters(), lr=1e-3)
scheduler_city = ReduceLROnPlateau(optimizer_city, 'min', patience=2, threshold=1e-3, cooldown=3, min_lr=1e-6)

train_model(model_city, dataloaders_city_dict, criterion_city, optimizer_city, scheduler_city, 
            num_epochs=30, city=True)

  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 1/30 | train | Loss: 0.3552 | Acc: 0.8417


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 1/30 |  val  | Loss: 0.2999 | Acc: 0.8724
LR: 0.001


  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 2/30 | train | Loss: 0.2860 | Acc: 0.8754


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 2/30 |  val  | Loss: 0.2888 | Acc: 0.8810
LR: 0.001


  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 3/30 | train | Loss: 0.2648 | Acc: 0.8860


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 3/30 |  val  | Loss: 0.2638 | Acc: 0.8900
LR: 0.001


  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 4/30 | train | Loss: 0.2500 | Acc: 0.8929


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 4/30 |  val  | Loss: 0.2541 | Acc: 0.8964
LR: 0.001


  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 5/30 | train | Loss: 0.2383 | Acc: 0.8977


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 5/30 |  val  | Loss: 0.2484 | Acc: 0.8942
LR: 0.001


  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 6/30 | train | Loss: 0.2303 | Acc: 0.9024


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 6/30 |  val  | Loss: 0.2512 | Acc: 0.8969
LR: 0.001


  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 7/30 | train | Loss: 0.2221 | Acc: 0.9055


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 7/30 |  val  | Loss: 0.2478 | Acc: 0.8941
LR: 0.001


  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 8/30 | train | Loss: 0.2157 | Acc: 0.9078


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 8/30 |  val  | Loss: 0.2356 | Acc: 0.9039
LR: 0.001


  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 9/30 | train | Loss: 0.2098 | Acc: 0.9111


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 9/30 |  val  | Loss: 0.2368 | Acc: 0.9049
LR: 0.001


  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 10/30 | train | Loss: 0.2035 | Acc: 0.9138


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 10/30 |  val  | Loss: 0.2284 | Acc: 0.9079
LR: 0.001


  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 11/30 | train | Loss: 0.1986 | Acc: 0.9157


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 11/30 |  val  | Loss: 0.2287 | Acc: 0.9088
LR: 0.001


  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 12/30 | train | Loss: 0.1929 | Acc: 0.9183


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 12/30 |  val  | Loss: 0.2290 | Acc: 0.9158
LR: 0.001


  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 13/30 | train | Loss: 0.1896 | Acc: 0.9200


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 13/30 |  val  | Loss: 0.2270 | Acc: 0.9108
LR: 0.001


  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 14/30 | train | Loss: 0.1833 | Acc: 0.9232


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 14/30 |  val  | Loss: 0.2233 | Acc: 0.9090
LR: 0.001


  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 15/30 | train | Loss: 0.1809 | Acc: 0.9230


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 15/30 |  val  | Loss: 0.2211 | Acc: 0.9129
LR: 0.001


  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 16/30 | train | Loss: 0.1772 | Acc: 0.9258


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 16/30 |  val  | Loss: 0.2257 | Acc: 0.9074
LR: 0.001


  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 17/30 | train | Loss: 0.1738 | Acc: 0.9270


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 17/30 |  val  | Loss: 0.2116 | Acc: 0.9172
LR: 0.001


  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 18/30 | train | Loss: 0.1723 | Acc: 0.9269


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 18/30 |  val  | Loss: 0.2108 | Acc: 0.9195
LR: 0.001


  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 19/30 | train | Loss: 0.1692 | Acc: 0.9287


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 19/30 |  val  | Loss: 0.2074 | Acc: 0.9186
LR: 0.001


  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 20/30 | train | Loss: 0.1655 | Acc: 0.9306


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 20/30 |  val  | Loss: 0.2114 | Acc: 0.9151
LR: 0.001


  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 21/30 | train | Loss: 0.1642 | Acc: 0.9305


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 21/30 |  val  | Loss: 0.2145 | Acc: 0.9164
LR: 0.001


  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 22/30 | train | Loss: 0.1624 | Acc: 0.9320


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 22/30 |  val  | Loss: 0.2052 | Acc: 0.9206
LR: 0.001


  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 23/30 | train | Loss: 0.1597 | Acc: 0.9325


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 23/30 |  val  | Loss: 0.2077 | Acc: 0.9190
LR: 0.001


  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 24/30 | train | Loss: 0.1588 | Acc: 0.9326


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 24/30 |  val  | Loss: 0.2142 | Acc: 0.9182
LR: 0.001


  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 25/30 | train | Loss: 0.1567 | Acc: 0.9341


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 25/30 |  val  | Loss: 0.2025 | Acc: 0.9216
LR: 0.001


  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 26/30 | train | Loss: 0.1547 | Acc: 0.9351


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 26/30 |  val  | Loss: 0.2085 | Acc: 0.9172
LR: 0.001


  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 27/30 | train | Loss: 0.1527 | Acc: 0.9352


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 27/30 |  val  | Loss: 0.1932 | Acc: 0.9263
LR: 0.001


  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 28/30 | train | Loss: 0.1523 | Acc: 0.9361


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 28/30 |  val  | Loss: 0.1990 | Acc: 0.9235
LR: 0.001


  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 29/30 | train | Loss: 0.1513 | Acc: 0.9361


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 29/30 |  val  | Loss: 0.2089 | Acc: 0.9205
LR: 0.001


  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 30/30 | train | Loss: 0.1491 | Acc: 0.9369


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 30/30 |  val  | Loss: 0.1926 | Acc: 0.9258
LR: 0.001
CPU times: user 59min 17s, sys: 1min 56s, total: 1h 1min 13s
Wall time: 1h 4min 29s


In [16]:
model_city = torch.jit.load('agent/model_city.pth')
criterion_city = nn.CrossEntropyLoss()
optimizer_city = torch.optim.AdamW(model_city.parameters(), lr=1e-4)
scheduler_city = ReduceLROnPlateau(optimizer_city, 'min', patience=2, threshold=1e-3, cooldown=3, min_lr=1e-6)

train_model(model_city, dataloaders_city_dict, criterion_city, optimizer_city, scheduler_city, 
            num_epochs=20, city=True)

  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 1/20 | train | Loss: 0.1068 | Acc: 0.9551


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 1/20 |  val  | Loss: 0.1959 | Acc: 0.9298
LR: 0.0001




  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 2/20 | train | Loss: 0.1002 | Acc: 0.9579


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 2/20 |  val  | Loss: 0.2021 | Acc: 0.9293
LR: 0.0001


  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 3/20 | train | Loss: 0.0969 | Acc: 0.9592


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 3/20 |  val  | Loss: 0.2069 | Acc: 0.9302
LR: 0.0001


  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 4/20 | train | Loss: 0.0943 | Acc: 0.9600


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 4/20 |  val  | Loss: 0.2121 | Acc: 0.9290
LR: 1e-05


  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 5/20 | train | Loss: 0.0900 | Acc: 0.9619


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 5/20 |  val  | Loss: 0.2099 | Acc: 0.9295
LR: 1e-05


  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 6/20 | train | Loss: 0.0887 | Acc: 0.9627


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 6/20 |  val  | Loss: 0.2132 | Acc: 0.9300
LR: 1e-05


  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 7/20 | train | Loss: 0.0890 | Acc: 0.9626


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 7/20 |  val  | Loss: 0.2123 | Acc: 0.9290
LR: 1e-05


  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 8/20 | train | Loss: 0.0884 | Acc: 0.9632


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 8/20 |  val  | Loss: 0.2125 | Acc: 0.9294
LR: 1e-05


  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 9/20 | train | Loss: 0.0872 | Acc: 0.9631


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 9/20 |  val  | Loss: 0.2178 | Acc: 0.9285
LR: 1e-05


  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 10/20 | train | Loss: 0.0876 | Acc: 0.9631


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 10/20 |  val  | Loss: 0.2179 | Acc: 0.9298
LR: 1.0000000000000002e-06


  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 11/20 | train | Loss: 0.0873 | Acc: 0.9631


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 11/20 |  val  | Loss: 0.2177 | Acc: 0.9301
LR: 1.0000000000000002e-06


  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 12/20 | train | Loss: 0.0868 | Acc: 0.9637


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 12/20 |  val  | Loss: 0.2153 | Acc: 0.9293
LR: 1.0000000000000002e-06


  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 13/20 | train | Loss: 0.0873 | Acc: 0.9632


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 13/20 |  val  | Loss: 0.2152 | Acc: 0.9301
LR: 1.0000000000000002e-06


  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 14/20 | train | Loss: 0.0871 | Acc: 0.9633


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 14/20 |  val  | Loss: 0.2163 | Acc: 0.9298
LR: 1.0000000000000002e-06


  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 15/20 | train | Loss: 0.0872 | Acc: 0.9631


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 15/20 |  val  | Loss: 0.2148 | Acc: 0.9291
LR: 1.0000000000000002e-06


  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 16/20 | train | Loss: 0.0871 | Acc: 0.9633


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 16/20 |  val  | Loss: 0.2149 | Acc: 0.9303
LR: 1.0000000000000002e-06


  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 17/20 | train | Loss: 0.0870 | Acc: 0.9637


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 17/20 |  val  | Loss: 0.2161 | Acc: 0.9293
LR: 1.0000000000000002e-06


  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 18/20 | train | Loss: 0.0874 | Acc: 0.9631


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 18/20 |  val  | Loss: 0.2154 | Acc: 0.9307
LR: 1.0000000000000002e-06


  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 19/20 | train | Loss: 0.0872 | Acc: 0.9634


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 19/20 |  val  | Loss: 0.2134 | Acc: 0.9301
LR: 1.0000000000000002e-06


  0%|          | 0/946 [00:00<?, ?it/s]

Epoch 20/20 | train | Loss: 0.0868 | Acc: 0.9631


  0%|          | 0/106 [00:00<?, ?it/s]

Epoch 20/20 |  val  | Loss: 0.2145 | Acc: 0.9311
LR: 1.0000000000000002e-06


In [23]:
# !tensorboard --logdir runs

# Submission

In [24]:
%%writefile agent/agent.py
import os
import json
import threading
import numpy as np
import torch
from lux.game import Game

path = '/kaggle_simulations/agent' if os.path.exists('/kaggle_simulations') else '.' # change to 'agent' for tests
model = torch.jit.load(f'{path}/model.pth')
model.eval()
model_city = torch.jit.load(f'{path}/model_city.pth')
model_city.eval()

def manhattan_distance(x1, y1, x2, y2):
    return (abs(x2-x1) + abs(y2-y1))

# make list of users and their current and previous coords and cooldowns 
# to write them in NN training subset
def find_units_from_previous_obs(obs, x_shift, y_shift):
    # at first fill the unit dict with units from previous observation
    updates = obs['updates']
    updates_lag_1 = obs['updates_lag_1']
    updates_lag_2 = obs['updates_lag_2']
    updates_lag_3 = obs['updates_lag_3']
    updates_lag_4 = obs['updates_lag_4']
    updates_lag_5 = obs['updates_lag_5']
    updates_lag_6 = obs['updates_lag_6']
    updates_lag_7 = obs['updates_lag_7']
    prev_updates = [updates, updates_lag_1, updates_lag_2, updates_lag_3, 
                    updates_lag_4, updates_lag_5, updates_lag_6, updates_lag_7]
    prev_units = list()
    for prev_update in prev_updates:
        units = dict()
        if prev_update:
            for update in prev_update:
                strs = update.split(' ')
                input_identifier = strs[0]
                # if we found observation for user
                if input_identifier == 'u':
                    unit_id = strs[3]
                    x = int(strs[4]) + x_shift
                    y = int(strs[5]) + y_shift
                    cooldown = float(strs[6])
                    units[unit_id] = [x, y, cooldown]
            prev_units.append(units)
        else:
            break
    return prev_units

def find_prev_coords(prev_units, unit_id):
    # find previous coordinates of unit by ananlyzing its cooldown
    x, y = [None, None], [None, None]
    for i in range(len(prev_units)-1):
        if unit_id in prev_units[i] and unit_id in prev_units[i+1]:
            cooldown = prev_units[i][unit_id][0]
            prev_x = prev_units[i+1][unit_id][0]
            prev_y = prev_units[i+1][unit_id][1]
            prev_cooldown = prev_units[i+1][unit_id][2]
            if cooldown > 0 and prev_cooldown == 0:
                if x[1] and y[1]:
                    return x, y
                if not x[0] and not y[0]:
                    x[0], y[0] = prev_x,prev_y
                else:
                    x[1], y[1] = prev_x,prev_y
        else:
            break
    return x, y
            
# Input for Neural Network for workers
def make_input(obs, unit_id):
    width, height = obs['width'], obs['height']
    x_shift = (32 - width) // 2
    y_shift = (32 - height) // 2

    cities = {}
    
    b = np.zeros((29, 32, 32), dtype=np.float32)
    
    prev_units = find_units_from_previous_obs(obs, x_shift, y_shift)
    x_u, y_u = prev_units[0][unit_id][0], prev_units[0][unit_id][1]
    
    for update in obs['updates']:
        strs = update.split(' ')
        input_identifier = strs[0]
        my_rp = 0
        
        if input_identifier == 'u':
            x = int(strs[4]) + x_shift
            y = int(strs[5]) + y_shift
            wood = int(strs[7])
            coal = int(strs[8])
            uranium = int(strs[9])
            if strs[3] == unit_id: # 0:2
                # Position, Cargo and Previous Position
                b[:2, x, y] = (
                    1,
                    (wood + coal + uranium) / 100
                )
                prev_x, prev_y = find_prev_coords(prev_units, unit_id)
                if not prev_x[0] and not prev_y[0]:
                    prev_x[0], prev_y[0] = x, y
                    prev_x[1], prev_y[1] = x, y
                try:
                    b[2, prev_x[0], prev_y[0]] = 1
                    b[3, prev_x[1], prev_y[1]] = 1
                except IndexError:
                    b[2, x, y] = 1
                    b[3, x, y] = 1
            else:                  # 4:11
                # Units
                team = int(strs[2])
                cooldown = float(strs[6])
                idx = 4 + (team - obs['player']) % 2 * 4
                m_dist = manhattan_distance(x_u, y_u, x, y)
                b[idx:idx + 4, x, y] = (
                    1,
                    cooldown / 6,
                    (wood + coal + uranium) / 100,
                    m_dist/((width-1) + (height-1))
                )
        elif input_identifier == 'ct':  # 12:17
            # CityTiles
            team = int(strs[1])
            city_id = strs[2]
            x = int(strs[3]) + x_shift
            y = int(strs[4]) + y_shift
            idx = 12 + (team - obs['player']) % 2 * 3
            m_dist = manhattan_distance(x_u, y_u, x, y)
            b[idx:idx + 3, x, y] = (
                1,
                cities[city_id],
                m_dist/((width-1) + (height-1))
            )
        elif input_identifier == 'r':  # 18:21
            # Resources
            r_type = strs[1]
            x = int(strs[2]) + x_shift
            y = int(strs[3]) + y_shift
            amt = int(float(strs[4]))
            access_level = {'wood': 0, 'coal': 50, 'uranium': 200}[r_type]
            access = 0 if my_rp < access_level else 1
            b[{'wood': 18, 'coal': 19, 'uranium': 20}[r_type], x, y] = amt / 800
            b[21, x, y] = access
            b[22, x, y] = manhattan_distance(x_u, y_u, x, y)/((width-1) + (height-1))
        elif input_identifier == 'rp':  # 23:24
            # Research Points
            team = int(strs[1])
            rp = int(strs[2])
            my_rp = rp if team == obs['player'] else my_rp
            b[23 + (team - obs['player']) % 2, :] = min(rp, 200) / 200
        elif input_identifier == 'c':
            # Cities
            city_id = strs[2]
            fuel = float(strs[3])
            lightupkeep = float(strs[4])
            cities[city_id] = min(fuel / lightupkeep, 10) / 10
    # Day/Night Cycle
    b[25, :] = obs['step'] % 40 / 40
    # Turns
    b[26, :] = obs['step'] / 360
    # Day/Night Flag
    b[27, :] = 1 if obs['step'] % 40 < 30 else 0
    # Map Size
    b[28, x_shift:32 - x_shift, y_shift:32 - y_shift] = 1
        
    return b

# Input for Neural Network for cities
def make_city_input(obs, city_coord):
    width, height = obs['width'], obs['height']
    x_shift = (32 - width) // 2
    y_shift = (32 - height) // 2
    cities = {}
    
    b = np.zeros((21, 32, 32), dtype=np.float32)
    
    for update in obs['updates']:
        strs = update.split(' ')
        input_identifier = strs[0]
        
        if input_identifier == 'ct':
            # CityTiles
            city_id = strs[2]
            x = int(strs[3]) 
            y = int(strs[4])
            cooldown = float(strs[5])
            if x == int(city_coord[0]) and y == int(city_coord[1]):
                b[:2, x + x_shift, y + y_shift] = (
                    1,
                    cities[city_id]
                )
            else:
                team = int(strs[1])
                idx = 2 + (team - obs['player']) % 2 * 3
                b[idx:idx + 3, x + x_shift, y + y_shift] = (
                    1,
                    cooldown / 10,
                    cities[city_id]
                )
        elif input_identifier == 'u':
            team = int(strs[2])
            x = int(strs[4]) + x_shift
            y = int(strs[5]) + y_shift
            wood = int(strs[7])
            coal = int(strs[8])
            uranium = int(strs[9])
            idx = 8 + (team - obs['player']) % 2 * 2
            b[idx:idx + 2, x, y] = (
                1,
                (wood + coal + uranium) / 100
            )
        elif input_identifier == 'r':
            # Resources
            r_type = strs[1]
            x = int(strs[2]) + x_shift
            y = int(strs[3]) + y_shift
            amt = int(float(strs[4]))
            b[{'wood': 12, 'coal': 13, 'uranium': 14}[r_type], x, y] = amt / 800
        elif input_identifier == 'rp':
            # Research Points
            team = int(strs[1])
            rp = int(strs[2])
            b[15 + (team - obs['player']) % 2, :] = min(rp, 200) / 200
        elif input_identifier == 'c':
            # Cities
            city_id = strs[2]
            fuel = float(strs[3])
            lightupkeep = float(strs[4])
            cities[city_id] = min(fuel / lightupkeep, 10) / 10
    
    # Day/Night Cycle
    b[17, :] = obs['step'] % 40 / 40
    # Turns
    b[18, :] = obs['step'] / 360
    # Day/Night Flag
    b[19, :] = 1 if obs['step'] % 40 < 30 else 0
    # Map Size
    b[20, x_shift:32 - x_shift, y_shift:32 - y_shift] = 1

    return b

game_state = None
player = None


def get_game_state(observation):
    global game_state
    
    if observation["step"] == 0:
        game_state = Game()
        game_state._initialize(observation["updates"])
        game_state._update(observation["updates"][2:])
        game_state.id = observation["player"]
    else:
        game_state._update(observation["updates"])
    return game_state


# check if unit is in city
def in_city(pos):    
    try:
        city = game_state.map.get_cell_by_pos(pos).citytile
        return city is not None and city.team == game_state.id 
    except:
        return False

# check if city will survive the night
def city_will_survive(pos):    
    try:
        city_id = game_state.map.get_cell_by_pos(pos).citytile.cityid
    except:
        return False
    if city_id in player.cities:
        city = player.cities[city_id]
        if city.fuel > (city.get_light_upkeep()) * 10:
            return True
    return False
    
# check if unit has enough time and space to build a city
def build_city_is_possible(unit, pos):    
    global game_state
    global player

    if game_state.turn % 40 < 30:
        return True
    x, y = pos.x, pos.y
    for i, j in ((x-1, y), (x+1, y), (x, y-1), (x, y+1)):
        try:
            city_id = game_state.map.get_cell(i, j).citytile.cityid
        except:
            continue
        if city_id in player.cities:
            city = player.cities[city_id]
            if city.fuel > (city.get_light_upkeep() + 18) * 10:
                return True
    return False
    

def call_func(obj, method, args=[]):
    return getattr(obj, method)(*args)


# translate unit policy to action
unit_actions = [('move', 'c'), ('move', 'n'), ('move', 's'), ('move', 'w'), ('move', 'e'), ('build_city',)]
def get_unit_action(policy, unit, dest):
    for label in np.argsort(policy)[::-1]:
        act = unit_actions[label]
        pos = unit.pos.translate(act[-1], 1) or unit.pos
        if label == 5 and not build_city_is_possible(unit, pos): # try to remove this !!!
            return unit.move('c'), unit.pos
        if pos not in dest or in_city(pos):
            return call_func(unit, *act), pos      
    return unit.move('c'), unit.pos

# translate city policy to action
city_actions = [('build_worker',), ('research', )]
def get_city_action(policy, city_tile):
    global player
    global unit_count
    
    for label in np.argsort(policy)[::-1]:
        act = city_actions[label]
        # build unit only if their number less than number of cities
        if label == 0 and unit_count < player.city_tile_count and unit_count < 80:
            unit_count += 1
            res = call_func(city_tile, *act)
        elif label == 1 and not player.researched_uranium():
            player.research_points += 1
            res = call_func(city_tile, *act)
        else:
            res = None
        return res

# agent for making actions
def agent(observation, configuration):
    global game_state
    global player
    global actions
    global dest
    global unit_count
    global city_count
    
    game_state = get_game_state(observation)    
    player = game_state.players[observation.player]
    actions = [] 
    dest = []
    prev_obs = dict()
    
    with open(f'{path}/tmp.json') as json_file:
        try:
            prev_obs = json.load(json_file)
        except json.decoder.JSONDecodeError:
            prev_obs['updates_lag_1'] = None
            prev_obs['updates_lag_2'] = None
            prev_obs['updates_lag_3'] = None
            prev_obs['updates_lag_4'] = None
            prev_obs['updates_lag_5'] = None
            prev_obs['updates_lag_6'] = None
            prev_obs['updates_lag_7'] = None
            
    observation['updates_lag_7'] = prev_obs['updates_lag_7']
    observation['updates_lag_6'] = prev_obs['updates_lag_6']
    observation['updates_lag_5'] = prev_obs['updates_lag_5']
    observation['updates_lag_4'] = prev_obs['updates_lag_4']
    observation['updates_lag_3'] = prev_obs['updates_lag_3']
    observation['updates_lag_2'] = prev_obs['updates_lag_2']
    observation['updates_lag_1'] = prev_obs['updates_lag_1']
    
    prev_obs['updates_lag_7'] = prev_obs['updates_lag_6']
    prev_obs['updates_lag_6'] = prev_obs['updates_lag_5']
    prev_obs['updates_lag_5'] = prev_obs['updates_lag_4']
    prev_obs['updates_lag_4'] = prev_obs['updates_lag_3']
    prev_obs['updates_lag_3'] = prev_obs['updates_lag_2']
    prev_obs['updates_lag_2'] = prev_obs['updates_lag_1']
    prev_obs['updates_lag_1'] = observation['updates']
    
    if game_state.turn == 0 or game_state.turn == 359:
        open(f'{path}/tmp.json', 'w+').close()
    else:
        with open(f'{path}/tmp.json', 'w+') as json_file:
            json.dump(prev_obs, json_file)
    
    # Unit Actions
    def unit_actions(unit, player, game_state, model, observation):
        global actions
        global dest
        if unit.can_act(): # and (game_state.turn % 40 < 31 or (not in_city(unit.pos))):
            state = make_input(observation, unit.id)
            with torch.no_grad():
                p = model(torch.from_numpy(state).unsqueeze(0))

            policy = p.squeeze(0).numpy()

            action, pos = get_unit_action(policy, unit, dest)
            actions.append(action)
            dest.append(pos)
            
    for unit in player.units:
        unit_actions(unit, player, game_state, model, observation)
        
    # City Actions
    def city_actions(city_tile, game_state, model, observation):
        global actions
        global dest
        global unit_count
        if city_tile.can_act():
            # on the last step build as many workers as possible to win the game in case of tie
            if game_state.turn > 350:
                actions.append(city_tile.build_worker())
            # if number of cities is too high or map size is 12. switch to simplify strategy to prevent too high lags
            elif player.city_tile_count > 80: # or (game_state.map.height == 12 and game_state.turn < 40):
                if unit_count < 80: 
                    actions.append(city_tile.build_worker())
                    unit_count += 1
                elif not player.researched_uranium():
                    actions.append(city_tile.research())
                    player.research_points += 1
            # else follow NN strategy
            else:
                state = make_city_input(observation, [city_tile.pos.x, city_tile.pos.y])
                with torch.no_grad():
                    p = model_city(torch.from_numpy(state).unsqueeze(0))

                policy = p.squeeze(0).numpy()

                action = get_city_action(policy, city_tile)
                if action:
                    actions.append(action)    

    unit_count = len(player.units)
    for city in player.cities.values():
        for city_tile in city.citytiles:
            city_actions(city_tile, game_state, model, observation)
            
    
    return actions

Overwriting agent/agent.py


Submit predictions

In [25]:
!cd agent && tar -czf submission.tar.gz lux agent.py main.py model.pth model_city.pth tmp.json

Test agents on 12x12 field

In [37]:
# from kaggle_environments import make

# env = make("lux_ai_2021", configuration={"width": 12, "height": 12, "loglevel": 2, "annotations": True}, debug=False)

# # first agent is yellow
# # second agent is blue
# steps = env.run(['agent/agent.py', 'agent.py'])

# env.render(mode="ipython", width=1200, height=800)

Test agent on 16x16 field

In [39]:
# env = make("lux_ai_2021", configuration={"width": 16, "height": 16, "loglevel": 2, "annotations": True}, debug=False)

# # first agent is yellow
# # second agent is blue
# steps = env.run(['agent/agent.py', 'agent.py'])

# env.render(mode="ipython", width=1200, height=800)

Test agent on 24x24 field

In [42]:
# env = make("lux_ai_2021", configuration={"width": 24, "height": 24, "loglevel": 2, "annotations": True}, debug=False)

# # first agent is yellow
# # second agent is blue
# steps = env.run(['agent/agent.py', 'agent.py'])

# env.render(mode="ipython", width=1200, height=800)

Test agents on 32x32 field

In [43]:
# env = make("lux_ai_2021", configuration={"width": 32, "height": 32, "loglevel": 2, "annotations": True}, debug=False)

# # first agent is yellow
# # second agent is blue
# steps = env.run(['agent/agent.py', 'agent.py'])

# env.render(mode="ipython", width=1200, height=800)

### Optimize NN parameters with Optuna

In [32]:
# def objective(trial):

#     num_epochs = 10
    
#     # model for unit actions
#     model = LuxNet()
#     train, val = train_test_split(samples, test_size=0.1, random_state=42, stratify=labels)
#     batch_size = 64

#     train_loader = DataLoader(
#         LuxDataset(obses, train), 
#         batch_size=batch_size, 
#         shuffle=True, 
#         num_workers=2
#     )
#     val_loader = DataLoader(
#         LuxDataset(obses, val), 
#         batch_size=batch_size, 
#         shuffle=False, 
#         num_workers=2
#     )
#     dataloaders_dict = {"train": train_loader, "val": val_loader}

#     # Generate the optimizers.
#     criterion = nn.CrossEntropyLoss()
#     optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "AdamW", "RMSprop", "SGD"])
#     lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
#     optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)

#     for epoch in range(num_epochs):
#         model.cuda()
        
#         for phase in ['train', 'val']:
#             if phase == 'train':
#                 model.train()
#             else:
#                 model.eval()
                
#             epoch_loss = 0.0
#             epoch_acc = 0
            
#             dataloader = dataloaders_dict[phase]
#             for item in dataloader:
#                 states = item[0].cuda().float()
#                 actions = item[1].cuda().long()

#                 optimizer.zero_grad()
                
#                 with torch.set_grad_enabled(phase == 'train'):
#                     policy = model(states)
#                     loss = criterion(policy, actions)
#                     _, preds = torch.max(policy, 1)

#                     if phase == 'train':
#                         loss.backward()
#                         optimizer.step()

#                     epoch_loss += loss.item() * len(policy)
#                     epoch_acc += torch.sum(preds == actions.data)

#             data_size = len(dataloader.dataset)
#             epoch_loss = epoch_loss / data_size
#             epoch_acc = epoch_acc.double() / data_size

#         trial.report(epoch_acc, epoch)

#         # Handle pruning based on the intermediate value.
#         if trial.should_prune():
#             raise optuna.exceptions.TrialPruned()

#     return epoch_acc


# study = optuna.create_study(direction="maximize")
# study.optimize(objective, n_trials=500, timeout=10*3600)

# pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
# complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])

# print("Study statistics: ")
# print("  Number of finished trials: ", len(study.trials))
# print("  Number of pruned trials: ", len(pruned_trials))
# print("  Number of complete trials: ", len(complete_trials))

# print("Best trial:")
# trial = study.best_trial

# print("  Value: ", trial.value)

# print("  Params: ")
# for key, value in trial.params.items():
#     print("    {}: {}".format(key, value))

# NNs ensemble

In [33]:
import os
import numpy as np
import torch
from lux.game import Game

path = '/kaggle_simulations/agent' if os.path.exists('/kaggle_simulations') else '.' # change to 'agent' for tests
# unit NNs
model_v2 = torch.jit.load(f'{path}/model_v2.pth')
model_v2.eval()
model_v4 = torch.jit.load(f'{path}/model_v4.pth')
model_v4.eval()
model_v5 = torch.jit.load(f'{path}/model_v5.pth')
model_v5.eval()
model_v11 = torch.jit.load(f'{path}/model_v11.pth')
model_v11.eval()
# city NNs
model_city_v2 = torch.jit.load(f'{path}/model_city_v2.pth')
model_city_v2.eval()
model_city_v4 = torch.jit.load(f'{path}/model_city_v4.pth')
model_city_v4.eval()
model_city_v5 = torch.jit.load(f'{path}/model_city_v5.pth')
model_city_v5.eval()
model_city_v11 = torch.jit.load(f'{path}/model_city_v11.pth')
model_city_v11.eval()

# Input for Neural Network for units
def make_input(obs, unit_id):
    width, height = obs['width'], obs['height']
    x_shift = (32 - width) // 2
    y_shift = (32 - height) // 2
    cities = {}
    
    b = np.zeros((20, 32, 32), dtype=np.float32)
    
    for update in obs['updates']:
        strs = update.split(' ')
        input_identifier = strs[0]
        
        if input_identifier == 'u':
            x = int(strs[4]) + x_shift
            y = int(strs[5]) + y_shift
            wood = int(strs[7])
            coal = int(strs[8])
            uranium = int(strs[9])
            if unit_id == strs[3]:
                # Position and Cargo
                b[:2, x, y] = (
                    1,
                    (wood + coal + uranium) / 100
                )
            else:
                # Units
                team = int(strs[2])
                cooldown = float(strs[6])
                idx = 2 + (team - obs['player']) % 2 * 3
                b[idx:idx + 3, x, y] = (
                    1,
                    cooldown / 6,
                    (wood + coal + uranium) / 100
                )
        elif input_identifier == 'ct':
            # CityTiles
            team = int(strs[1])
            city_id = strs[2]
            x = int(strs[3]) + x_shift
            y = int(strs[4]) + y_shift
            idx = 8 + (team - obs['player']) % 2 * 2
            b[idx:idx + 2, x, y] = (
                1,
                cities[city_id]
            )
        elif input_identifier == 'r':
            # Resources
            r_type = strs[1]
            x = int(strs[2]) + x_shift
            y = int(strs[3]) + y_shift
            amt = int(float(strs[4]))
            b[{'wood': 12, 'coal': 13, 'uranium': 14}[r_type], x, y] = amt / 800
        elif input_identifier == 'rp':
            # Research Points
            team = int(strs[1])
            rp = int(strs[2])
            b[15 + (team - obs['player']) % 2, :] = min(rp, 200) / 200
        elif input_identifier == 'c':
            # Cities
            city_id = strs[2]
            fuel = float(strs[3])
            lightupkeep = float(strs[4])
            cities[city_id] = min(fuel / lightupkeep, 10) / 10
    
    # Day/Night Cycle
    b[17, :] = obs['step'] % 40 / 40
    # Turns
    b[18, :] = obs['step'] / 360
    # Day/Night Flag
    b[19, :] = 1 if obs['step'] % 40 < 30 else 0
    # Map Size
    b[20, x_shift:32 - x_shift, y_shift:32 - y_shift] = 1
    return b


# Input for Neural Network for cities
def make_city_input(obs, city_coord):
    width, height = obs['width'], obs['height']
    x_shift = (32 - width) // 2
    y_shift = (32 - height) // 2
    cities = {}
    
    b = np.zeros((20, 32, 32), dtype=np.float32)
    
    for update in obs['updates']:
        strs = update.split(' ')
        input_identifier = strs[0]
        
        if input_identifier == 'ct':
            # CityTiles
            city_id = strs[2]
            x = int(strs[3]) 
            y = int(strs[4])
            cooldown = float(strs[5])
            if x == int(city_coord[0]) and y == int(city_coord[1]):
                b[:2, x + x_shift, y + y_shift] = (
                    1,
                    cities[city_id]
                )
            else:
                team = int(strs[1])
                idx = 2 + (team - obs['player']) % 2 * 3
                b[idx:idx + 3, x + x_shift, y + y_shift] = (
                    1,
                    cooldown / 10,
                    cities[city_id]
                )
        elif input_identifier == 'u':
            team = int(strs[2])
            x = int(strs[4]) + x_shift
            y = int(strs[5]) + y_shift
            wood = int(strs[7])
            coal = int(strs[8])
            uranium = int(strs[9])
            idx = 8 + (team - obs['player']) % 2 * 2
            b[idx:idx + 2, x, y] = (
                1,
                (wood + coal + uranium) / 100
            )
        elif input_identifier == 'r':
            # Resources
            r_type = strs[1]
            x = int(strs[2]) + x_shift
            y = int(strs[3]) + y_shift
            amt = int(float(strs[4]))
            b[{'wood': 12, 'coal': 13, 'uranium': 14}[r_type], x, y] = amt / 800
        elif input_identifier == 'rp':
            # Research Points
            team = int(strs[1])
            rp = int(strs[2])
            b[15 + (team - obs['player']) % 2, :] = min(rp, 200) / 200
        elif input_identifier == 'c':
            # Cities
            city_id = strs[2]
            fuel = float(strs[3])
            lightupkeep = float(strs[4])
            cities[city_id] = min(fuel / lightupkeep, 10) / 10
    
    # Day/Night Cycle
    b[17, :] = obs['step'] % 40 / 40
    # Turns
    b[18, :] = obs['step'] / 360
    # Map Size
    b[19, x_shift:32 - x_shift, y_shift:32 - y_shift] = 1

    return b

game_state = None
player = None


def get_game_state(observation):
    global game_state
    
    if observation["step"] == 0:
        game_state = Game()
        game_state._initialize(observation["updates"])
        game_state._update(observation["updates"][2:])
        game_state.id = observation["player"]
    else:
        game_state._update(observation["updates"])
    return game_state


def in_city(pos):    
    try:
        city = game_state.map.get_cell_by_pos(pos).citytile
        return city is not None and city.team == game_state.id
    except:
        return False
    
# check if unit has enough time and space to build a city
def build_city_is_possible(unit, pos):    
    global game_state
    global player

    if game_state.turn % 40 < 30:
        return True
    x, y = pos.x, pos.y
    for i, j in ((x-1, y), (x+1, y), (x, y-1), (x, y+1)):
        try:
            city_id = game_state.map.get_cell(i, j).citytile.cityid
        except:
            continue
        if city_id in player.cities:
            city = player.cities[city_id]
            if city.fuel > (city.get_light_upkeep() + 18) * 10:
                return True
    return False


def call_func(obj, method, args=[]):
    return getattr(obj, method)(*args)


# translate unit policy to action
unit_actions = [('move', 'n'), ('move', 's'), ('move', 'w'), ('move', 'e'), ('build_city',)]
def get_unit_action(policy, unit, dest):
    for label in np.argsort(policy)[::-1]:
        act = unit_actions[label]
        pos = unit.pos.translate(act[-1], 1) or unit.pos
        if label == 4 and not build_city_is_possible(unit, pos):
            return unit.move('c'), unit.pos
        if pos not in dest or in_city(pos):
            return call_func(unit, *act), pos 
            
    return unit.move('c'), unit.pos

# translate city policy to action
city_actions = [('build_worker',), ('research', )]
def get_city_action(policy, city_tile, unit_count):
    global player
    
    for label in np.argsort(policy)[::-1]:
        act = city_actions[label]
        if label == 0 and unit_count < player.city_tile_count:
            unit_count += 1
            res = call_func(city_tile, *act)
        elif label == 1 and not player.researched_uranium():
            player.research_points += 1
            res = call_func(city_tile, *act)
        else:
            res = None
        return res, unit_count

# agent for making actions
def agent(observation, configuration):
    global game_state
    global player
    
    game_state = get_game_state(observation)    
    player = game_state.players[observation.player]
    actions = []        

    # Unit Actions
    dest = []
    for unit in player.units:
        if unit.can_act() and (game_state.turn % 40 < 30 or not in_city(unit.pos)):
            state = make_input(observation, unit.id)
            with torch.no_grad():
                p_2 = model_v2(torch.from_numpy(state).unsqueeze(0))
                p_4 = model_v4(torch.from_numpy(state).unsqueeze(0))
                #p_5 = model_v5(torch.from_numpy(state).unsqueeze(0))
                p_11 = model_v11(torch.from_numpy(state).unsqueeze(0))

            policy_2 = p_2.squeeze(0).numpy()
            policy_4 = p_4.squeeze(0).numpy()
            #policy_5 = p_5.squeeze(0).numpy()
            policy_11 = p_11.squeeze(0).numpy()

            policy = [sum(x) for x in zip(*[policy_2, policy_4, policy_11])]

            action, pos = get_unit_action(policy, unit, dest)
            actions.append(action)
            dest.append(pos)

    # City Actions
    unit_count = len(player.units)
    for city in player.cities.values():
        for city_tile in city.citytiles:
            if city_tile.can_act():
                # at first game stages try to produce maximum amount of agents and research point
                if game_state.turn < 60:
                    if unit_count < player.city_tile_count: 
                        actions.append(city_tile.build_worker())
                        unit_count += 1
                    elif not player.researched_uranium():
                        actions.append(city_tile.research())
                        player.research_points += 1
                # then follow NN strategy
                else:
                    state = make_city_input(observation, [city_tile.pos.x, city_tile.pos.y])
                    with torch.no_grad():
                        p_2 = model_city_v2(torch.from_numpy(state).unsqueeze(0))
                        p_4 = model_city_v4(torch.from_numpy(state).unsqueeze(0))
                        #p_5 = model_city_v5(torch.from_numpy(state).unsqueeze(0))
                        p_11 = model_city_v11(torch.from_numpy(state).unsqueeze(0))

                    policy_2 = p_2.squeeze(0).numpy()
                    policy_4 = p_4.squeeze(0).numpy()
                    #policy_5 = p_5.squeeze(0).numpy()
                    policy_11 = p_11.squeeze(0).numpy()

                    policy = [sum(x) for x in zip(*[policy_2, policy_4, policy_11])]

                    action, unit_count = get_city_action(policy, city_tile, unit_count)
                    if action:
                        actions.append(action)
    
    return actions

ValueError: The provided filename ./model_v2.pth does not exist

# Futher Ideas

Hyperparameters
- optimize hyperparameters (number of features and layers, regularization, dropout probability, learning rate) with Optuna

In [None]:
# dict for apropriate change of moving action during rotation
# 1 - rotation on 90, 2 - rotation on 180, 3 - rotation on 270
rot_dict = {0: {0:0, 1:1, 2:2, 3:3, 4:4}, # N -> E, S -> W, W -> N, E -> S
            1: {0:3, 1:2, 2:0, 3:1, 4:4}, # N -> E, S -> W, W -> N, E -> S
            2: {0:1, 1:0, 2:3, 3:2, 4:4}, # N -> S, S -> N, W -> E, E -> W
            3: {0:2, 1:3, 2:1, 3:0, 4:4}  # N -> W, S -> E, W -> S, E -> N
           }