For Kaggle

In [9]:
# !pip install kaggle-environments -U > /dev/null 2>&1
# !cp -r ../input/lux-ai-2021/* .

For agents validation

In [10]:
# timeout 1h lux-ai-2021 --tournament --rankSystem wins --storeReplay false --storeLogs false --maxConcurrentMatches 1 agent/main.py agent_simple/main.py submission_v4/main.py

In [11]:
import numpy as np
import json
from pathlib import Path
import os
import random
from tqdm.notebook import tqdm
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.transforms.functional as TF
from torchvision.transforms import ToTensor, RandomHorizontalFlip, RandomVerticalFlip
import torch.optim as optim
from sklearn.model_selection import train_test_split
import optuna
from optuna.trial import TrialState

In [12]:
def seed_everything(seed_value):
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    os.environ['PYTHONHASHSEED'] = str(seed_value)
    
    if torch.cuda.is_available(): 
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = True

seed_everything(42)

# Preprocessing

Actions:
- m - move: m, unit who moves, direction
- bcity - build city: bcity, unit who builds
- bw - build worker, x-coord of building city, y-coord of building city
- r - research , x-coord of researching city, y-coord of researching city
- t - transfer: transfer, from user_1, to user_2, resourse, quantity

Updates:
- rp - research point: player, number of rp
- r - resources: r, type of resource, x-coord, y-coord, quantity
- u - user: u, worker/cart, player, user id, x-coord, y-coord, cooldown, wood, coal, uranium
- c - city: c, player, city id, number of resources, amount of consuming light at night
- ct - city tile: ct, player, city id, x-coord, y-coord, cooldown
- ccd - level of road: ccd, x-coord, y-coord, level value

In [13]:
!cd episodes && rm *_info.json && cd ..

rm: cannot remove '*_info.json': No such file or directory


Get data from json files

In [15]:
from random import choice

def unit_label(action):
    strs = action.split(' ')
    unit_id = strs[1]
    if strs[0] == 'm':
        label = {'c': None, 'n': 0, 's': 1, 'w': 2, 'e': 3}[strs[2]]
    elif strs[0] == 'bcity':
        label = 4
    else:
        label = None
    return unit_id, label

def city_label(action):
    strs = action.split(' ')
    ctile_coord = (strs[1], strs[2])
    if strs[0] == 'bw':
        label = 0
    elif strs[0] == 'r':
        label = 1
    else:
        label = None
    return ctile_coord, label

def unit_cooldown(updates, index):
    strs = updates.split(' ')
    if strs[0] == 'u' and strs[2] == index:
        unit_id = strs[3]
        cooldown = strs[6]
        return unit_id, cooldown
    return None, None

def city_tile_coord(updates, index):
    strs = updates.split(' ')
    if strs[0] == 'ct' and strs[1] == index:
        ctile_coord = (strs[3], strs[4])
        cooldown = strs[5]
        return ctile_coord, cooldown
    return None, None

def research_points(updates, index):
    for u in updates:
        strs = u.split(' ')
        if strs[0] == 'rp' and strs[1] == index:
            return int(strs[2])
    return 0

def depleted_resources(obs):
    for u in obs['updates']:
        if u.split(' ')[0] == 'r':
            return False
    return True


def create_dataset_from_json(episode_dir, team_name='Toad Brigade'): 
    obses = {}
    unit_samples = []
    city_samples = []
    
    episodes = [path for path in Path(episode_dir).glob('*.json') if 'output' not in path.name]
    for filepath in tqdm(episodes): 
        with open(filepath) as f:
            json_load = json.load(f)
        # load episode from episode json file
        ep_id = json_load['info']['EpisodeId']
        index = np.argmax([r or 0 for r in json_load['rewards']])
        if json_load['info']['TeamNames'][index] != team_name:
            continue

        updates = None
        
        for i in range(len(json_load['steps'])-1):
            units = {}
            city_tiles = {}
            
            if json_load['steps'][i][index]['status'] == 'ACTIVE':
                actions = json_load['steps'][i+1][index]['action']
                obs = json_load['steps'][i][0]['observation']
                updates = obs['updates']
                rp = research_points(updates, str(index))
                # get updates from previous steps
                if i > 0:
                    obs['updates_lag_1'] = json_load['steps'][i-1][0]['observation']['updates']
                else:
                    obs['updates_lag_1'] = None
                if i > 1:
                    obs['updates_lag_2'] = json_load['steps'][i-2][0]['observation']['updates']
                else:
                    obs['updates_lag_2'] = None
                if i > 2:
                    obs['updates_lag_3'] = json_load['steps'][i-3][0]['observation']['updates']
                else:
                    obs['updates_lag_3'] = None
                if i > 3:
                    obs['updates_lag_4'] = json_load['steps'][i-4][0]['observation']['updates']
                else:
                    obs['updates_lag_4'] = None
                
                for u in updates:
                    # get coords for every friendly city tile
                    ctile_coord, cooldown = city_tile_coord(u, str(index))
                    if ctile_coord and cooldown:
                        city_tiles[ctile_coord] = [0, float(cooldown)]
                        continue
                    # get cooldown for every friendly unit 
                    unit_id, cooldown = unit_cooldown(u, str(index))
                    if unit_id and cooldown:
                        units[unit_id] = float(cooldown)                     
                
                # stop research if resources are depleted
                if depleted_resources(obs):
                    break
                
                obs['player'] = index
                obs = dict([
                    (k,v) for k,v in obs.items() 
                    if k in ['step', 'updates', 'updates_lag_1', 'updates_lag_2', 
                             'updates_lag_3', 'updates_lag_4', 'player', 'width', 'height']
                ])
                obs_id = f'{ep_id}_{i}'
                obses[obs_id] = obs
                
                for action in actions:
                    unit_id, label = unit_label(action)
                    # if unit acts - add this action to train
                    if label is not None:
                        unit_samples.append((obs_id, unit_id, label))
                        continue
                    ctile_coord, label = city_label(action)
                    # if city tile acts - add this action to train
                    if label is not None:
                        city_samples.append((obs_id, ctile_coord, label))
                        city_tiles[ctile_coord][0] = label
                        
#                 # it there is a possibility for city tiles to act, (research points number < 200 and number of
#                 # city tiles more than number of units find those of them who can act (cooldown = 0) but doesn't
#                 if rp < 200 and len(city_tiles) > len(units):
#                     city_tiles_no_action = [(k, v[0]) for k, v in city_tiles.items() if v[0] == 0 and v[1] == 0]
#                     # if there are cities with no action - randomly select one item from there 
#                     # and add it to city samples
#                     if city_tiles_no_action:
#                         ct = choice(city_tiles_no_action)
#                         city_samples.append((obs_id, ct[0], ct[1]))
                    
                    
    return obses, unit_samples, city_samples

In [16]:
episode_dir = 'episodes'
obses, samples, city_samples = create_dataset_from_json(episode_dir)

print('observations:', len(obses), 'worker samples:', len(samples), 'city samples:', len(city_samples))

  0%|          | 0/520 [00:00<?, ?it/s]

observations: 138211 worker samples: 647453 city samples: 107505


In [17]:
labels = [sample[-1] for sample in samples]
actions = ['north', 'south', 'west', 'east', 'bcity']
for value, count in zip(*np.unique(labels, return_counts=True)):
    print(f'{actions[value]}: {count}')

north: 144948
south: 141492
west: 153270
east: 154811
bcity: 52932


In [18]:
labels_city = [sample[-1] for sample in city_samples]
actions_city = ['build_worker', 'research']
for value, count in zip(*np.unique(labels_city, return_counts=True)):
    print(f'{actions_city[value]}: {count}')

build_worker: 32407
research: 75098


In [19]:
# obses['26762301_160']

In [20]:
# samples

In [21]:
# city_samples[200000]

# Training

b - training tensor of float32. b dimensions is 20x32x32

- b[0] - position of current unit
- b[1] - cargo sum/100 of current unit
- b[2] - previous position of current unit
- b[3, 4, 5, 6] - position, cooldown/6, cargo sum/100 and Manhattan distance to units from friendly team
- b[7, 8, 9, 10] - position, cooldown/6, cargo sum/100 and Manhattan distance to units from adversarial team
- b[11, 12, 13] - position,  min(city fuel/city energy consumption, 10)/10 and Manhattan distance to cities from friendly team
- b[14, 15, 16] - position,  min(city fuel/city energy consumption, 10)/10 and Manhattan distance to cities from adversarial team
- b[17] - amount of wood / 800 on current position
- b[18] - amount of coal / 800 on current position
- b[19] - amount of uranium / 800 on current position
- b[20] - 1 if unit can mine this resource, 0 otherwise
- b[21] - Manhattan distance to this resource
- b[22] - research points / 200 of unit's team
- b[23] - research points / 200 of another team
- b[24] - time of the day (from 0 to 1, step 0.05)
- b[25] - step of the game (from 0 to 1, step 1/360)
- b[26] - day/night
- b[27] - map size + tiles where unit can't go (other units and adversarial cities)

In [22]:
def manhattan_distance(x1, y1, x2, y2):
    return (abs(x2-x1) + abs(y2-y1))

# make list of users and their current and previous coords and cooldowns 
# to write them in NN training subset
def find_units_from_previous_obs(obs, x_shift, y_shift):
    # at first fill the unit dict with units from previous observation
    updates = obs['updates']
    updates_lag_1 = obs['updates_lag_1']
    updates_lag_2 = obs['updates_lag_2']
    updates_lag_3 = obs['updates_lag_3']
    updates_lag_4 = obs['updates_lag_4']
    prev_updates = [updates, updates_lag_1, updates_lag_2, updates_lag_3, updates_lag_4]
    prev_units = list()
    for prev_update in prev_updates:
        units = dict()
        if prev_update:
            for update in prev_update:
                strs = update.split(' ')
                input_identifier = strs[0]
                # if we found observation for user
                if input_identifier == 'u':
                    unit_id = strs[3]
                    x = int(strs[4]) + x_shift
                    y = int(strs[5]) + y_shift
                    cooldown = float(strs[6])
                    units[unit_id] = [x, y, cooldown]
            prev_units.append(units)
        else:
            break
    return prev_units

def find_prev_coords(prev_units, unit_id):
    # find previous coordinates of unit by ananlyzing its cooldown
    x, y = None, None
    for i in range(len(prev_units)-1):
        if unit_id in prev_units[i] and unit_id in prev_units[i+1]:
            cooldown = prev_units[i][unit_id][0]
            prev_x = prev_units[i+1][unit_id][0]
            prev_y = prev_units[i+1][unit_id][1]
            prev_cooldown = prev_units[i+1][unit_id][2]
            if cooldown > 0 and prev_cooldown == 0:
                return prev_x, prev_y
        else:
            break
    return x, y
            
# Input for Neural Network for workers
def make_input(obs, unit_id):
    width, height = obs['width'], obs['height']
    x_shift = (32 - width) // 2
    y_shift = (32 - height) // 2

    cities = {}
    
    b = np.zeros((25, 32, 32), dtype=np.float32)
    
    prev_units = find_units_from_previous_obs(obs, x_shift, y_shift)
    x_u, y_u = prev_units[0][unit_id][0], prev_units[0][unit_id][1]
    
    for update in obs['updates']:
        strs = update.split(' ')
        input_identifier = strs[0]
        my_rp = 0
        
        if input_identifier == 'u':
            x = int(strs[4]) + x_shift
            y = int(strs[5]) + y_shift
            wood = int(strs[7])
            coal = int(strs[8])
            uranium = int(strs[9])
            if strs[3] == unit_id: # 0:2
                # Position, Cargo and Previous Position
                b[:2, x, y] = (
                    1,
                    (wood + coal + uranium) / 100
                )
#                 prev_x, prev_y = find_prev_coords(prev_units, unit_id)
#                 if not prev_x and not prev_y:
#                     prev_x, prev_y = x, y
#                 b[2, prev_x, prev_y] = 1
            else:                  # 2:9
                # Units
                team = int(strs[2])
                cooldown = float(strs[6])
                idx = 2 + (team - obs['player']) % 2 * 4
                m_dist = manhattan_distance(x_u, y_u, x, y)
                b[idx:idx + 4, x, y] = (
                    1,
                    cooldown / 6,
                    (wood + coal + uranium) / 100,
                    m_dist/((width-1) + (height-1))
                )
        elif input_identifier == 'ct':  # 10:15
            # CityTiles
            team = int(strs[1])
            city_id = strs[2]
            x = int(strs[3]) + x_shift
            y = int(strs[4]) + y_shift
            idx = 10 + (team - obs['player']) % 2 * 3
            m_dist = manhattan_distance(x_u, y_u, x, y)
            b[idx:idx + 3, x, y] = (
                1,
                cities[city_id],
                m_dist/((width-1) + (height-1))
            )
        elif input_identifier == 'r':  # 16:20
            # Resources
            r_type = strs[1]
            x = int(strs[2]) + x_shift
            y = int(strs[3]) + y_shift
            amt = int(float(strs[4]))
            access_level = {'wood': 0, 'coal': 50, 'uranium': 200}[r_type]
            access = 0 if my_rp < access_level else 1
            b[{'wood': 16, 'coal': 17, 'uranium': 18}[r_type], x, y] = amt / 800
            b[19, x, y] = access
            b[20, x, y] = manhattan_distance(x_u, y_u, x, y)/((width-1) + (height-1))
        elif input_identifier == 'rp':  # 21:22
            # Research Points
            team = int(strs[1])
            rp = int(strs[2])
            my_rp = rp if team == obs['player'] else my_rp
            b[21 + (team - obs['player']) % 2, :] = min(rp, 200) / 200
        elif input_identifier == 'c':
            # Cities
            city_id = strs[2]
            fuel = float(strs[3])
            lightupkeep = float(strs[4])
            cities[city_id] = min(fuel / lightupkeep, 10) / 10
    # Day/Night Cycle
    b[23, :] = obs['step'] % 40 / 40
    # Turns
    b[24, :] = obs['step'] / 360
#     # Day / Night
#     b[26, :] = 1 if obs['step'] % 40 < 30 else 0
    # Map Size
    b[25, x_shift:32 - x_shift, y_shift:32 - y_shift] = 1
        
    return b

In [23]:
# obses['26762301_160']

In [24]:
# import sys

# np.set_printoptions(threshold=sys.maxsize)

# make_input(obses['26762301_160'], 'u_36')[3]

Data for the cities:
- b[0] - position of current city
- b[1] - min(city fuel/city energy consumption, 10)/10 of current city
- b[2, 3, 4] - position, cooldown/10, and min(city fuel/city energy consumption, 10)/10 for cities from the same team
- b[5, 6, 7] - position, cooldown/10, and min(city fuel/city energy consumption, 10)/10 for cities from another team
- b[8, 9] - position and cargo sum/100 for units from the same team 
- b[10, 11] - position and cargo sum/100 for units from from another team
- b[12] - amount of wood / 800
- b[13] - amount of coal / 800
- b[14] - amount of uranium / 800
- b[15] - research points / 200 of unit's team
- b[16] - research piints / 200 of another team
- b[17] - time of the day (from 0 to 1, step 0.05)
- b[18] - step of the game (from 0 to 1, step 1/360)
- b[19] - map size

In [34]:
# Input for Neural Network for cities
def make_city_input(obs, city_coord):
    width, height = obs['width'], obs['height']
    x_shift = (32 - width) // 2
    y_shift = (32 - height) // 2
    cities = {}
    
    b = np.zeros((20, 32, 32), dtype=np.float32)
    
    for update in obs['updates']:
        strs = update.split(' ')
        input_identifier = strs[0]
        
        if input_identifier == 'ct':
            # CityTiles
            city_id = strs[2]
            x = int(strs[3]) 
            y = int(strs[4])
            cooldown = float(strs[5])
            if x == int(city_coord[0]) and y == int(city_coord[1]):
                b[:2, x + x_shift, y + y_shift] = (
                    1,
                    cities[city_id]
                )
            else:
                team = int(strs[1])
                idx = 2 + (team - obs['player']) % 2 * 3
                b[idx:idx + 3, x + x_shift, y + y_shift] = (
                    1,
                    cooldown / 10,
                    cities[city_id]
                )
        elif input_identifier == 'u':
            team = int(strs[2])
            x = int(strs[4]) + x_shift
            y = int(strs[5]) + y_shift
            wood = int(strs[7])
            coal = int(strs[8])
            uranium = int(strs[9])
            idx = 8 + (team - obs['player']) % 2 * 2
            b[idx:idx + 2, x, y] = (
                1,
                (wood + coal + uranium) / 100
            )
        elif input_identifier == 'r':
            # Resources
            r_type = strs[1]
            x = int(strs[2]) + x_shift
            y = int(strs[3]) + y_shift
            amt = int(float(strs[4]))
            b[{'wood': 12, 'coal': 13, 'uranium': 14}[r_type], x, y] = amt / 800
        elif input_identifier == 'rp':
            # Research Points
            team = int(strs[1])
            rp = int(strs[2])
            b[15 + (team - obs['player']) % 2, :] = min(rp, 200) / 200
        elif input_identifier == 'c':
            # Cities
            city_id = strs[2]
            fuel = float(strs[3])
            lightupkeep = float(strs[4])
            cities[city_id] = min(fuel / lightupkeep, 10) / 10
    
    # Day/Night Cycle
    b[17, :] = obs['step'] % 40 / 40
    # Turns
    b[18, :] = obs['step'] / 360
#     # Day/Night Flag
#     b[19, :] = 1 if obs['step'] % 40 < 30 else 0
    # Map Size
    b[19, x_shift:32 - x_shift, y_shift:32 - y_shift] = 1

    return b

### Set modules for NN training

In [26]:
# mask for apropriate change of moving action during rotation
# 0 - no transform, 1 - rotation on 90, 2 - rotation on 180, 
# 3 - rotation on 270, 4 - horizontal flip, 5 - vertical flip
mask = (
        (0, 1, 2, 3, 4), 
        (3, 2, 0, 1, 4), # N -> E, S -> W, W -> N, E -> S
        (1, 0, 3, 2, 4), # N -> S, S -> N, W -> E, E -> W
        (2, 3, 1, 0, 4), # N -> W, S -> E, W -> S, E -> N
        (0, 1, 3, 2, 4), # N -> N, S -> S, W -> E, E -> W
        (1, 0, 2, 3, 4)  # N -> S, S -> N, W -> W, E -> E
        )

def do_nothing(x):
    return x

class RandomChoice(torch.nn.Module):
    def __init__(self, transforms):
       super().__init__()
       self.transforms = transforms

    def __call__(self, x):
        idx = random.choice([i for i in range(len(self.transforms))])
        t = self.transforms[idx]
        x = torch.from_numpy(x)
        return t(x), idx

transform=RandomChoice([lambda x: x,
                        lambda x: torch.rot90(x, 1, [2, 1]),
                        lambda x: torch.rot90(x, 2, [2, 1]),
                        lambda x: torch.rot90(x, 1, [1, 2]),
                        TF.hflip, 
                        TF.vflip])
    
class LuxDataset(Dataset):
    def __init__(self, obses, samples, transform=None, mask=mask, city=False):
        self.obses = obses
        self.samples = samples
        self.data_len = len(self.samples)
        self.len = self.data_len
        self.transform = transform
        self.mask = mask
        self.city = city
            
    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        data_idx = idx % self.data_len
        obs_id, unit_id, action = self.samples[data_idx]
        obs = self.obses[obs_id]
        if not self.city:
            state = make_input(obs, unit_id)
        else:
            state = make_city_input(obs, unit_id)
        if self.transform:
            t = self.transform(state)
            state, action = t[0], self.mask[t[1]][action]
        return state, action

In [27]:
# Neural Network for Lux AI
class BasicConv2d(nn.Module):
    def __init__(self, input_dim, output_dim, kernel_size, bn, p=0.1):
        super().__init__()
        self.conv = nn.Conv2d(
            input_dim, output_dim, 
            kernel_size=kernel_size, 
            padding=(kernel_size[0] // 2, kernel_size[1] // 2)
        )
        self.bn = nn.BatchNorm2d(output_dim) if bn else None
        self.dropout = nn.Dropout(p) 

    def forward(self, x):
        h = self.conv(x)
        h = self.bn(h) if self.bn is not None else h
        h = self.dropout(h) 
        return h


class LuxNet(nn.Module):
    def __init__(self):
        super().__init__()
        layers, filters = 12, 44
        self.conv0 = BasicConv2d(26, filters, (3, 3), True)
        self.blocks = nn.ModuleList([BasicConv2d(filters, filters, (3, 3), True) for _ in range(layers)])
        self.head_p = nn.Linear(filters, 5, bias=False)

    def forward(self, x):
        h = F.relu_(self.conv0(x))
        for block in self.blocks:
            h = F.relu_(h + block(h))
        h_head = (h * x[:,:1]).view(h.size(0), h.size(1), -1).sum(-1)
        p = self.head_p(h_head)
        return p
    
    
class LuxCityNet(nn.Module):
    def __init__(self):
        super().__init__()
        layers, filters = 12, 32
        self.conv0 = BasicConv2d(20, filters, (3, 3), True)
        self.blocks = nn.ModuleList([BasicConv2d(filters, filters, (3, 3), True) for _ in range(layers)])
        self.head_p = nn.Linear(filters, 2, bias=False)

    def forward(self, x):
        h = F.relu_(self.conv0(x))
        for block in self.blocks:
            h = F.relu_(h + block(h))
        h_head = (h * x[:,:1]).view(h.size(0), h.size(1), -1).sum(-1)
        p = self.head_p(h_head)
        return p

### Function for NN training

In [28]:
import matplotlib.pyplot as plt
# from torch.utils.tensorboard import SummaryWriter

def train_model(model, dataloaders_dict, criterion, optimizer, num_epochs, city=False):
#     tb = SummaryWriter()
    
    best_acc = 0.0
    
    for epoch in range(num_epochs):
        model.cuda()
        
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()
                
            epoch_loss = 0.0
            epoch_acc = 0
            
            dataloader = dataloaders_dict[phase]
            for item in tqdm(dataloader, leave=False):
                
                states = item[0].cuda().float()
                actions = item[1].cuda().long()

                optimizer.zero_grad()
                
                with torch.set_grad_enabled(phase == 'train'):
                    policy = model(states)
                    loss = criterion(policy, actions)
                    _, preds = torch.max(policy, 1)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                    epoch_loss += loss.item() * len(policy)
                    epoch_acc += torch.sum(preds == actions.data)

            data_size = len(dataloader.dataset)
            epoch_loss = epoch_loss / data_size
            epoch_acc = epoch_acc.double() / data_size

#             if phase == 'train':
#                 tb.add_scalar("Train Loss", epoch_loss, epoch)
#                 tb.add_scalar("Train Accuracy", epoch_acc, epoch)
#             else:
#                 tb.add_scalar("Val Loss", epoch_loss, epoch)
#                 tb.add_scalar("Val Accuracy", epoch_acc, epoch)
            
            print(f'Epoch {epoch + 1}/{num_epochs} | {phase:^5} | Loss: {epoch_loss:.4f} | Acc: {epoch_acc:.4f}')
        
        if epoch_acc > best_acc:
            if city:
                traced = torch.jit.trace(model.cpu(), torch.rand(1, 20, 32, 32))
                traced.save('agent/model_city.pth')
            else:
                traced = torch.jit.trace(model.cpu(), torch.rand(1, 26, 32, 32))
                traced.save('agent/model.pth')
            best_acc = epoch_acc
            
#     tb.close()

In [29]:
# model for unit actions
model = LuxNet()
train, val = train_test_split(samples, test_size=0.1, random_state=42, stratify=labels)
batch_size = 128

train_loader = DataLoader(
    LuxDataset(obses, train), 
    batch_size=batch_size, 
    shuffle=True, 
    num_workers=2
)
val_loader = DataLoader(
    LuxDataset(obses, val), 
    batch_size=batch_size, 
    shuffle=False, 
    num_workers=2
)
dataloaders_dict = {"train": train_loader, "val": val_loader}

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# model for city actions
model_city = LuxCityNet()
train_city, val_city = train_test_split(city_samples, test_size=0.1, random_state=42, stratify=labels_city)
batch_size_city = 128

train_city_loader = DataLoader(
    LuxDataset(obses, train_city, city=True), 
    batch_size=batch_size, 
    shuffle=True, 
    num_workers=2
)
    
val_city_loader = DataLoader(
    LuxDataset(obses, val_city, city=True), 
    batch_size=batch_size, 
    shuffle=False, 
    num_workers=2
)
dataloaders_city_dict = {"train": train_city_loader, "val": val_city_loader}

criterion_city = nn.CrossEntropyLoss()
optimizer_city = torch.optim.AdamW(model_city.parameters(), lr=1e-3)

In [30]:
%%time

num_epochs = 25

train_model(model, dataloaders_dict, criterion, optimizer, num_epochs=num_epochs)

  0%|          | 0/4553 [00:00<?, ?it/s]

Epoch 1/25 | train | Loss: 0.8117 | Acc: 0.6705


  0%|          | 0/506 [00:00<?, ?it/s]

Epoch 1/25 |  val  | Loss: 0.6632 | Acc: 0.7349


  0%|          | 0/4553 [00:00<?, ?it/s]

Epoch 2/25 | train | Loss: 0.6660 | Acc: 0.7321


  0%|          | 0/506 [00:00<?, ?it/s]

Epoch 2/25 |  val  | Loss: 0.6139 | Acc: 0.7539


  0%|          | 0/4553 [00:00<?, ?it/s]

Epoch 3/25 | train | Loss: 0.6246 | Acc: 0.7485


  0%|          | 0/506 [00:00<?, ?it/s]

Epoch 3/25 |  val  | Loss: 0.5905 | Acc: 0.7618


  0%|          | 0/4553 [00:00<?, ?it/s]

Epoch 4/25 | train | Loss: 0.5990 | Acc: 0.7597


  0%|          | 0/506 [00:00<?, ?it/s]

Epoch 4/25 |  val  | Loss: 0.5605 | Acc: 0.7758


  0%|          | 0/4553 [00:00<?, ?it/s]

Epoch 5/25 | train | Loss: 0.5809 | Acc: 0.7669


  0%|          | 0/506 [00:00<?, ?it/s]

Epoch 5/25 |  val  | Loss: 0.5468 | Acc: 0.7829


  0%|          | 0/4553 [00:00<?, ?it/s]

Epoch 6/25 | train | Loss: 0.5674 | Acc: 0.7725


  0%|          | 0/506 [00:00<?, ?it/s]

Epoch 6/25 |  val  | Loss: 0.5448 | Acc: 0.7819


  0%|          | 0/4553 [00:00<?, ?it/s]

Epoch 7/25 | train | Loss: 0.5571 | Acc: 0.7767


  0%|          | 0/506 [00:00<?, ?it/s]

Epoch 7/25 |  val  | Loss: 0.5312 | Acc: 0.7876


  0%|          | 0/4553 [00:00<?, ?it/s]

Epoch 8/25 | train | Loss: 0.5477 | Acc: 0.7811


  0%|          | 0/506 [00:00<?, ?it/s]

Epoch 8/25 |  val  | Loss: 0.5247 | Acc: 0.7906


  0%|          | 0/4553 [00:00<?, ?it/s]

Epoch 9/25 | train | Loss: 0.5405 | Acc: 0.7840


  0%|          | 0/506 [00:00<?, ?it/s]

Epoch 9/25 |  val  | Loss: 0.5279 | Acc: 0.7885


  0%|          | 0/4553 [00:00<?, ?it/s]

Epoch 10/25 | train | Loss: 0.5347 | Acc: 0.7861


  0%|          | 0/506 [00:00<?, ?it/s]

Epoch 10/25 |  val  | Loss: 0.5155 | Acc: 0.7950


  0%|          | 0/4553 [00:00<?, ?it/s]

Epoch 11/25 | train | Loss: 0.5286 | Acc: 0.7883


  0%|          | 0/506 [00:00<?, ?it/s]

Epoch 11/25 |  val  | Loss: 0.5101 | Acc: 0.7956


  0%|          | 0/4553 [00:00<?, ?it/s]

Epoch 12/25 | train | Loss: 0.5230 | Acc: 0.7913


  0%|          | 0/506 [00:00<?, ?it/s]

Epoch 12/25 |  val  | Loss: 0.5088 | Acc: 0.7970


  0%|          | 0/4553 [00:00<?, ?it/s]

Epoch 13/25 | train | Loss: 0.5196 | Acc: 0.7916


  0%|          | 0/506 [00:00<?, ?it/s]

Epoch 13/25 |  val  | Loss: 0.5076 | Acc: 0.7969


  0%|          | 0/4553 [00:00<?, ?it/s]

Epoch 14/25 | train | Loss: 0.5155 | Acc: 0.7939


  0%|          | 0/506 [00:00<?, ?it/s]

Epoch 14/25 |  val  | Loss: 0.4984 | Acc: 0.8008


  0%|          | 0/4553 [00:00<?, ?it/s]

Epoch 15/25 | train | Loss: 0.5118 | Acc: 0.7957


  0%|          | 0/506 [00:00<?, ?it/s]

Epoch 15/25 |  val  | Loss: 0.5070 | Acc: 0.7962


  0%|          | 0/4553 [00:00<?, ?it/s]

Epoch 16/25 | train | Loss: 0.5093 | Acc: 0.7964


  0%|          | 0/506 [00:00<?, ?it/s]

Epoch 16/25 |  val  | Loss: 0.5103 | Acc: 0.7948


  0%|          | 0/4553 [00:00<?, ?it/s]

Epoch 17/25 | train | Loss: 0.5063 | Acc: 0.7976


  0%|          | 0/506 [00:00<?, ?it/s]

Epoch 17/25 |  val  | Loss: 0.4996 | Acc: 0.8001


  0%|          | 0/4553 [00:00<?, ?it/s]

Epoch 18/25 | train | Loss: 0.5042 | Acc: 0.7982


  0%|          | 0/506 [00:00<?, ?it/s]

Epoch 18/25 |  val  | Loss: 0.4964 | Acc: 0.8016


  0%|          | 0/4553 [00:00<?, ?it/s]

Epoch 19/25 | train | Loss: 0.5006 | Acc: 0.7994


  0%|          | 0/506 [00:00<?, ?it/s]

Epoch 19/25 |  val  | Loss: 0.4968 | Acc: 0.8007


  0%|          | 0/4553 [00:00<?, ?it/s]

Epoch 20/25 | train | Loss: 0.4991 | Acc: 0.8007


  0%|          | 0/506 [00:00<?, ?it/s]

Epoch 20/25 |  val  | Loss: 0.4936 | Acc: 0.8023


  0%|          | 0/4553 [00:00<?, ?it/s]

Epoch 21/25 | train | Loss: 0.4967 | Acc: 0.8014


  0%|          | 0/506 [00:00<?, ?it/s]

Epoch 21/25 |  val  | Loss: 0.4959 | Acc: 0.8007


  0%|          | 0/4553 [00:00<?, ?it/s]

Epoch 22/25 | train | Loss: 0.4952 | Acc: 0.8024


  0%|          | 0/506 [00:00<?, ?it/s]

Epoch 22/25 |  val  | Loss: 0.4967 | Acc: 0.8027


  0%|          | 0/4553 [00:00<?, ?it/s]

Epoch 23/25 | train | Loss: 0.4927 | Acc: 0.8034


  0%|          | 0/506 [00:00<?, ?it/s]

Epoch 23/25 |  val  | Loss: 0.4968 | Acc: 0.8005


  0%|          | 0/4553 [00:00<?, ?it/s]

Epoch 24/25 | train | Loss: 0.4911 | Acc: 0.8036


  0%|          | 0/506 [00:00<?, ?it/s]

Epoch 24/25 |  val  | Loss: 0.4917 | Acc: 0.8026


  0%|          | 0/4553 [00:00<?, ?it/s]

Epoch 25/25 | train | Loss: 0.4906 | Acc: 0.8038


  0%|          | 0/506 [00:00<?, ?it/s]

Epoch 25/25 |  val  | Loss: 0.4892 | Acc: 0.8035
CPU times: user 3h 12min 49s, sys: 4min 50s, total: 3h 17min 39s
Wall time: 3h 20min 51s


In [38]:
# Epoch 1/25 | train | Loss: 0.8117 | Acc: 0.6705
# Epoch 1/25 |  val  | Loss: 0.6632 | Acc: 0.7349
# Epoch 2/25 | train | Loss: 0.6660 | Acc: 0.7321
# Epoch 2/25 |  val  | Loss: 0.6139 | Acc: 0.7539
# Epoch 3/25 | train | Loss: 0.6246 | Acc: 0.7485
# Epoch 3/25 |  val  | Loss: 0.5905 | Acc: 0.7618
# Epoch 4/25 | train | Loss: 0.5990 | Acc: 0.7597
# Epoch 4/25 |  val  | Loss: 0.5605 | Acc: 0.7758
# Epoch 5/25 | train | Loss: 0.5809 | Acc: 0.7669
# Epoch 5/25 |  val  | Loss: 0.5468 | Acc: 0.7829
# Epoch 6/25 | train | Loss: 0.5674 | Acc: 0.7725
# Epoch 6/25 |  val  | Loss: 0.5448 | Acc: 0.7819
# Epoch 7/25 | train | Loss: 0.5571 | Acc: 0.7767
# Epoch 7/25 |  val  | Loss: 0.5312 | Acc: 0.7876
# Epoch 8/25 | train | Loss: 0.5477 | Acc: 0.7811
# Epoch 8/25 |  val  | Loss: 0.5247 | Acc: 0.7906
# Epoch 9/25 | train | Loss: 0.5405 | Acc: 0.7840
# Epoch 9/25 |  val  | Loss: 0.5279 | Acc: 0.7885
# Epoch 10/25 | train | Loss: 0.5347 | Acc: 0.7861
# Epoch 10/25 |  val  | Loss: 0.5155 | Acc: 0.7950
# Epoch 11/25 | train | Loss: 0.5286 | Acc: 0.7883
# Epoch 11/25 |  val  | Loss: 0.5101 | Acc: 0.7956
# Epoch 12/25 | train | Loss: 0.5230 | Acc: 0.7913
# Epoch 12/25 |  val  | Loss: 0.5088 | Acc: 0.7970
# Epoch 13/25 | train | Loss: 0.5196 | Acc: 0.7916
# Epoch 13/25 |  val  | Loss: 0.5076 | Acc: 0.7969
# Epoch 14/25 | train | Loss: 0.5155 | Acc: 0.7939
# Epoch 14/25 |  val  | Loss: 0.4984 | Acc: 0.8008
# Epoch 15/25 | train | Loss: 0.5118 | Acc: 0.7957
# Epoch 15/25 |  val  | Loss: 0.5070 | Acc: 0.7962
# Epoch 16/25 | train | Loss: 0.5093 | Acc: 0.7964
# Epoch 16/25 |  val  | Loss: 0.5103 | Acc: 0.7948
# Epoch 17/25 | train | Loss: 0.5063 | Acc: 0.7976
# Epoch 17/25 |  val  | Loss: 0.4996 | Acc: 0.8001
# Epoch 18/25 | train | Loss: 0.5042 | Acc: 0.7982
# Epoch 18/25 |  val  | Loss: 0.4964 | Acc: 0.8016
# Epoch 19/25 | train | Loss: 0.5006 | Acc: 0.7994
# Epoch 19/25 |  val  | Loss: 0.4968 | Acc: 0.8007
# Epoch 20/25 | train | Loss: 0.4991 | Acc: 0.8007
# Epoch 20/25 |  val  | Loss: 0.4936 | Acc: 0.8023
# Epoch 21/25 | train | Loss: 0.4967 | Acc: 0.8014
# Epoch 21/25 |  val  | Loss: 0.4959 | Acc: 0.8007
# Epoch 22/25 | train | Loss: 0.4952 | Acc: 0.8024
# Epoch 22/25 |  val  | Loss: 0.4967 | Acc: 0.8027
# Epoch 23/25 | train | Loss: 0.4927 | Acc: 0.8034
# Epoch 23/25 |  val  | Loss: 0.4968 | Acc: 0.8005
# Epoch 24/25 | train | Loss: 0.4911 | Acc: 0.8036
# Epoch 24/25 |  val  | Loss: 0.4917 | Acc: 0.8026
# Epoch 25/25 | train | Loss: 0.4906 | Acc: 0.8038
# Epoch 25/25 |  val  | Loss: 0.4892 | Acc: 0.8035

In [35]:
%%time

num_epochs = 15

train_model(model_city, dataloaders_city_dict, criterion_city, optimizer_city, num_epochs=num_epochs, city=True)

  0%|          | 0/756 [00:00<?, ?it/s]

Epoch 1/15 | train | Loss: 0.3273 | Acc: 0.8619


  0%|          | 0/84 [00:00<?, ?it/s]

Epoch 1/15 |  val  | Loss: 0.2643 | Acc: 0.8908


  0%|          | 0/756 [00:00<?, ?it/s]

Epoch 2/15 | train | Loss: 0.2598 | Acc: 0.8919


  0%|          | 0/84 [00:00<?, ?it/s]

Epoch 2/15 |  val  | Loss: 0.2601 | Acc: 0.8896


  0%|          | 0/756 [00:00<?, ?it/s]

Epoch 3/15 | train | Loss: 0.2404 | Acc: 0.9007


  0%|          | 0/84 [00:00<?, ?it/s]

Epoch 3/15 |  val  | Loss: 0.2499 | Acc: 0.8930


  0%|          | 0/756 [00:00<?, ?it/s]

Epoch 4/15 | train | Loss: 0.2304 | Acc: 0.9032


  0%|          | 0/84 [00:00<?, ?it/s]

Epoch 4/15 |  val  | Loss: 0.2400 | Acc: 0.9000


  0%|          | 0/756 [00:00<?, ?it/s]

Epoch 5/15 | train | Loss: 0.2220 | Acc: 0.9076


  0%|          | 0/84 [00:00<?, ?it/s]

Epoch 5/15 |  val  | Loss: 0.2255 | Acc: 0.9073


  0%|          | 0/756 [00:00<?, ?it/s]

Epoch 6/15 | train | Loss: 0.2139 | Acc: 0.9108


  0%|          | 0/84 [00:00<?, ?it/s]

Epoch 6/15 |  val  | Loss: 0.2153 | Acc: 0.9095


  0%|          | 0/756 [00:00<?, ?it/s]

Epoch 7/15 | train | Loss: 0.2103 | Acc: 0.9123


  0%|          | 0/84 [00:00<?, ?it/s]

Epoch 7/15 |  val  | Loss: 0.2235 | Acc: 0.9059


  0%|          | 0/756 [00:00<?, ?it/s]

Epoch 8/15 | train | Loss: 0.2052 | Acc: 0.9147


  0%|          | 0/84 [00:00<?, ?it/s]

Epoch 8/15 |  val  | Loss: 0.1997 | Acc: 0.9170


  0%|          | 0/756 [00:00<?, ?it/s]

Epoch 9/15 | train | Loss: 0.2014 | Acc: 0.9167


  0%|          | 0/84 [00:00<?, ?it/s]

Epoch 9/15 |  val  | Loss: 0.2126 | Acc: 0.9138


  0%|          | 0/756 [00:00<?, ?it/s]

Epoch 10/15 | train | Loss: 0.1949 | Acc: 0.9183


  0%|          | 0/84 [00:00<?, ?it/s]

Epoch 10/15 |  val  | Loss: 0.1996 | Acc: 0.9168


  0%|          | 0/756 [00:00<?, ?it/s]

Epoch 11/15 | train | Loss: 0.1915 | Acc: 0.9210


  0%|          | 0/84 [00:00<?, ?it/s]

Epoch 11/15 |  val  | Loss: 0.2029 | Acc: 0.9171


  0%|          | 0/756 [00:00<?, ?it/s]

Epoch 12/15 | train | Loss: 0.1864 | Acc: 0.9218


  0%|          | 0/84 [00:00<?, ?it/s]

Epoch 12/15 |  val  | Loss: 0.1981 | Acc: 0.9213


  0%|          | 0/756 [00:00<?, ?it/s]

Epoch 13/15 | train | Loss: 0.1829 | Acc: 0.9247


  0%|          | 0/84 [00:00<?, ?it/s]

Epoch 13/15 |  val  | Loss: 0.2081 | Acc: 0.9116


  0%|          | 0/756 [00:00<?, ?it/s]

Epoch 14/15 | train | Loss: 0.1778 | Acc: 0.9268


  0%|          | 0/84 [00:00<?, ?it/s]

Epoch 14/15 |  val  | Loss: 0.2014 | Acc: 0.9184


  0%|          | 0/756 [00:00<?, ?it/s]

Epoch 15/15 | train | Loss: 0.1763 | Acc: 0.9271


  0%|          | 0/84 [00:00<?, ?it/s]

Epoch 15/15 |  val  | Loss: 0.1919 | Acc: 0.9233
CPU times: user 11min 37s, sys: 25.3 s, total: 12min 2s
Wall time: 12min 6s


In [None]:
# !tensorboard --logdir runs

# Submission

In [23]:
%%writefile agent/agent.py
import os
import json
import numpy as np
import torch
from lux.game import Game

path = '/kaggle_simulations/agent' if os.path.exists('/kaggle_simulations') else '.' # change to 'agent' for tests
model = torch.jit.load(f'{path}/model.pth')
model.eval()
model_city = torch.jit.load(f'{path}/model_city.pth')
model_city.eval()

def manhattan_distance(x1, y1, x2, y2):
    return (abs(x2-x1) + abs(y2-y1))

# make list of users and their current and previous coords and cooldowns 
# to write them in NN training subset
def find_units_from_previous_obs(obs, x_shift, y_shift):
    # at first fill the unit dict with units from previous observation
    updates = obs['updates']
    updates_lag_1 = obs['updates_lag_1']
    updates_lag_2 = obs['updates_lag_2']
    updates_lag_3 = obs['updates_lag_3']
    updates_lag_4 = obs['updates_lag_4']
    prev_updates = [updates, updates_lag_1, updates_lag_2, updates_lag_3, updates_lag_4]
    prev_units = list()
    for prev_update in prev_updates:
        units = dict()
        if prev_update:
            for update in prev_update:
                strs = update.split(' ')
                input_identifier = strs[0]
                # if we found observation for user
                if input_identifier == 'u':
                    unit_id = strs[3]
                    x = int(strs[4]) + x_shift
                    y = int(strs[5]) + y_shift
                    cooldown = float(strs[6])
                    units[unit_id] = [x, y, cooldown]
            prev_units.append(units)
        else:
            break
    return prev_units

def find_prev_coords(prev_units, unit_id):
    # find previous coordinates of unit by ananlyzing its cooldown
    x, y = None, None
    for i in range(len(prev_units)-1):
        if unit_id in prev_units[i] and unit_id in prev_units[i+1]:
            cooldown = prev_units[i][unit_id][0]
            prev_x = prev_units[i+1][unit_id][0]
            prev_y = prev_units[i+1][unit_id][1]
            prev_cooldown = prev_units[i+1][unit_id][2]
            if cooldown > 0 and prev_cooldown == 0:
                return prev_x, prev_y
        else:
            break
    return x, y
            
# Input for Neural Network for workers
def make_input(obs, unit_id):
    width, height = obs['width'], obs['height']
    x_shift = (32 - width) // 2
    y_shift = (32 - height) // 2

    cities = {}
    
    b = np.zeros((28, 32, 32), dtype=np.float32)
    
    prev_units = find_units_from_previous_obs(obs, x_shift, y_shift)
    x_u, y_u = prev_units[0][unit_id][0], prev_units[0][unit_id][1]
    
    for update in obs['updates']:
        strs = update.split(' ')
        input_identifier = strs[0]
        my_rp = 0
        
        if input_identifier == 'u':
            x = int(strs[4]) + x_shift
            y = int(strs[5]) + y_shift
            wood = int(strs[7])
            coal = int(strs[8])
            uranium = int(strs[9])
            if strs[3] == unit_id: # 0:2
                # Position, Cargo and Previous Position
                b[:2, x, y] = (
                    1,
                    (wood + coal + uranium) / 100
                )
                prev_x, prev_y = find_prev_coords(prev_units, unit_id)
                if not prev_x and not prev_y:
                    prev_x, prev_y = x, y
                b[2, prev_x, prev_y] = 1
            else:                  # 3:10
                # Units
                team = int(strs[2])
                cooldown = float(strs[6])
                idx = 3 + (team - obs['player']) % 2 * 4
                m_dist = manhattan_distance(x_u, y_u, x, y)
                b[idx:idx + 4, x, y] = (
                    1,
                    cooldown / 6,
                    (wood + coal + uranium) / 100,
                    m_dist/((width-1) + (height-1))
                )
        elif input_identifier == 'ct':  # 11:16
            # CityTiles
            team = int(strs[1])
            city_id = strs[2]
            x = int(strs[3]) + x_shift
            y = int(strs[4]) + y_shift
            idx = 11 + (team - obs['player']) % 2 * 3
            m_dist = manhattan_distance(x_u, y_u, x, y)
            b[idx:idx + 3, x, y] = (
                1,
                cities[city_id],
                m_dist/((width-1) + (height-1))
            )
        elif input_identifier == 'r':  # 17:20
            # Resources
            r_type = strs[1]
            x = int(strs[2]) + x_shift
            y = int(strs[3]) + y_shift
            amt = int(float(strs[4]))
            access_level = {'wood': 0, 'coal': 50, 'uranium': 200}[r_type]
            access = 0 if my_rp < access_level else 1
            b[{'wood': 17, 'coal': 18, 'uranium': 19}[r_type], x, y] = amt / 800
            b[20, x, y] = access
            b[21, x, y] = manhattan_distance(x_u, y_u, x, y)/((width-1) + (height-1))
        elif input_identifier == 'rp':  # 22:23
            # Research Points
            team = int(strs[1])
            rp = int(strs[2])
            my_rp = rp if team == obs['player'] else my_rp
            b[22 + (team - obs['player']) % 2, :] = min(rp, 200) / 200
        elif input_identifier == 'c':
            # Cities
            city_id = strs[2]
            fuel = float(strs[3])
            lightupkeep = float(strs[4])
            cities[city_id] = min(fuel / lightupkeep, 10) / 10
    # Day/Night Cycle
    b[24, :] = obs['step'] % 40 / 40
    # Turns
    b[25, :] = obs['step'] / 360
    # Day/Night Flag
    b[26, :] = 1 if obs['step'] % 40 < 30 else 0
    # Map Size
    b[27, x_shift:32 - x_shift, y_shift:32 - y_shift] = 1
        
    return b

# Input for Neural Network for cities
def make_city_input(obs, city_coord):
    width, height = obs['width'], obs['height']
    x_shift = (32 - width) // 2
    y_shift = (32 - height) // 2
    cities = {}
    
    b = np.zeros((21, 32, 32), dtype=np.float32)
    
    for update in obs['updates']:
        strs = update.split(' ')
        input_identifier = strs[0]
        
        if input_identifier == 'ct':
            # CityTiles
            city_id = strs[2]
            x = int(strs[3]) 
            y = int(strs[4])
            cooldown = float(strs[5])
            if x == int(city_coord[0]) and y == int(city_coord[1]):
                b[:2, x + x_shift, y + y_shift] = (
                    1,
                    cities[city_id]
                )
            else:
                team = int(strs[1])
                idx = 2 + (team - obs['player']) % 2 * 3
                b[idx:idx + 3, x + x_shift, y + y_shift] = (
                    1,
                    cooldown / 10,
                    cities[city_id]
                )
        elif input_identifier == 'u':
            team = int(strs[2])
            x = int(strs[4]) + x_shift
            y = int(strs[5]) + y_shift
            wood = int(strs[7])
            coal = int(strs[8])
            uranium = int(strs[9])
            idx = 8 + (team - obs['player']) % 2 * 2
            b[idx:idx + 2, x, y] = (
                1,
                (wood + coal + uranium) / 100
            )
        elif input_identifier == 'r':
            # Resources
            r_type = strs[1]
            x = int(strs[2]) + x_shift
            y = int(strs[3]) + y_shift
            amt = int(float(strs[4]))
            b[{'wood': 12, 'coal': 13, 'uranium': 14}[r_type], x, y] = amt / 800
        elif input_identifier == 'rp':
            # Research Points
            team = int(strs[1])
            rp = int(strs[2])
            b[15 + (team - obs['player']) % 2, :] = min(rp, 200) / 200
        elif input_identifier == 'c':
            # Cities
            city_id = strs[2]
            fuel = float(strs[3])
            lightupkeep = float(strs[4])
            cities[city_id] = min(fuel / lightupkeep, 10) / 10
    
    # Day/Night Cycle
    b[17, :] = obs['step'] % 40 / 40
    # Turns
    b[18, :] = obs['step'] / 360
    # Day/Night Flag
    b[19, :] = 1 if obs['step'] % 40 < 30 else 0
    # Map Size
    b[20, x_shift:32 - x_shift, y_shift:32 - y_shift] = 1

    return b

game_state = None
player = None


def get_game_state(observation):
    global game_state
    
    if observation["step"] == 0:
        game_state = Game()
        game_state._initialize(observation["updates"])
        game_state._update(observation["updates"][2:])
        game_state.id = observation["player"]
    else:
        game_state._update(observation["updates"])
    return game_state


# check if unit is in city or not
def in_city(pos):    
    try:
        city = game_state.map.get_cell_by_pos(pos).citytile
        return city is not None and city.team == game_state.id
    except:
        return False
    
# check if unit has enough time and space to build a city
def build_city_is_possible(unit, pos):    
    global game_state
    global player

    if game_state.turn % 40 < 30:
        return True
    x, y = pos.x, pos.y
    for i, j in ((x-1, y), (x+1, y), (x, y-1), (x, y+1)):
        try:
            city_id = game_state.map.get_cell(i, j).citytile.cityid
        except:
            continue
        if city_id in player.cities:
            city = player.cities[city_id]
            if city.fuel > (city.get_light_upkeep() + 18) * 10:
                return True
    return False


def call_func(obj, method, args=[]):
    return getattr(obj, method)(*args)


# translate unit policy to action
unit_actions = [('move', 'n'), ('move', 's'), ('move', 'w'), ('move', 'e'), ('build_city',)]
def get_unit_action(policy, unit, dest):
    for label in np.argsort(policy)[::-1]:
        act = unit_actions[label]
        pos = unit.pos.translate(act[-1], 1) or unit.pos
        if label == 4 and not build_city_is_possible(unit, pos):
            return unit.move('c'), unit.pos
        if pos not in dest or in_city(pos):
            return call_func(unit, *act), pos      
    return unit.move('c'), unit.pos

# translate city policy to action
city_actions = [('build_worker',), ('research', )]
def get_city_action(policy, city_tile, unit_count):
    global player
    
    for label in np.argsort(policy)[::-1]:
        act = city_actions[label]
        # build unit only if their number less than number of cities and less than 100 (to prevent too high lags)
        if label == 0 and unit_count < player.city_tile_count and unit_count < 100:
            unit_count += 1
            res = call_func(city_tile, *act)
        elif label == 1 and not player.researched_uranium():
            player.research_points += 1
            res = call_func(city_tile, *act)
        else:
            res = None
        return res, unit_count

# agent for making actions
def agent(observation, configuration):
    global game_state
    global player
    
    game_state = get_game_state(observation)    
    player = game_state.players[observation.player]
    actions = [] 
    prev_obs = dict()
    
    with open(f'{path}/tmp.json') as json_file:
        try:
            prev_obs = json.load(json_file)
        except json.decoder.JSONDecodeError:
            prev_obs['updates_lag_1'] = None
            prev_obs['updates_lag_2'] = None
            prev_obs['updates_lag_3'] = None
            prev_obs['updates_lag_4'] = None
            
    observation['updates_lag_4'] = prev_obs['updates_lag_4']
    observation['updates_lag_3'] = prev_obs['updates_lag_3']
    observation['updates_lag_2'] = prev_obs['updates_lag_2']
    observation['updates_lag_1'] = prev_obs['updates_lag_1']
    
    prev_obs['updates_lag_4'] = prev_obs['updates_lag_3']
    prev_obs['updates_lag_3'] = prev_obs['updates_lag_2']
    prev_obs['updates_lag_2'] = prev_obs['updates_lag_1']
    prev_obs['updates_lag_1'] = observation['updates']
    
    if game_state.turn < 359:
        with open(f'{path}/tmp.json', 'w+') as json_file:
            json.dump(prev_obs, json_file)
    else:
        open(f'{path}/tmp.json', 'w+').close()
    
    # Unit Actions
    dest = []
    for unit in player.units:
        if unit.can_act() and (game_state.turn % 40 < 30 or (not in_city(unit.pos))):
            state = make_input(observation, unit.id)
            with torch.no_grad():
                p = model(torch.from_numpy(state).unsqueeze(0))

            policy = p.squeeze(0).numpy()

            action, pos = get_unit_action(policy, unit, dest)
            actions.append(action)
            dest.append(pos)
    
    # City Actions
    unit_count = len(player.units)
    for city in player.cities.values():
        for city_tile in city.citytiles:
            if city_tile.can_act():
                state = make_city_input(observation, [city_tile.pos.x, city_tile.pos.y])
                with torch.no_grad():
                    p = model_city(torch.from_numpy(state).unsqueeze(0))

                policy = p.squeeze(0).numpy()

                action, unit_count = get_city_action(policy, city_tile, unit_count)
                if action:
                    actions.append(action)
    
    return actions

Overwriting agent/agent.py


Submit predictions

In [25]:
!cd agent && tar -czf submission.tar.gz lux agent.py main.py model.pth model_city.pth tmp.json

Test agents on 12x12 field

In [22]:
from kaggle_environments import make

env = make("lux_ai_2021", configuration={"width": 12, "height": 12, "loglevel": 2, "annotations": True}, debug=False)

# first agent is yellow
# second agent is blue
steps = env.run(['agent/agent.py', 'agent.py'])

env.render(mode="ipython", width=1200, height=800)

Test agent on 16x16 field

In [12]:
# env = make("lux_ai_2021", configuration={"width": 16, "height": 16, "loglevel": 2, "annotations": True}, debug=False)

# # first agent is yellow
# # second agent is blue
# steps = env.run(['agent/agent.py', 'agent.py'])

# env.render(mode="ipython", width=1200, height=800)

Test agent on 24x24 field

In [13]:
# env = make("lux_ai_2021", configuration={"width": 24, "height": 24, "loglevel": 2, "annotations": True}, debug=False)

# # first agent is yellow
# # second agent is blue
# steps = env.run(['agent/agent.py', 'agent.py'])

# env.render(mode="ipython", width=1200, height=800)

Test agents on 32x32 field

In [31]:
# env = make("lux_ai_2021", configuration={"width": 32, "height": 32, "loglevel": 2, "annotations": True}, debug=False)

# # first agent is yellow
# # second agent is blue
# steps = env.run(['agent/agent.py', 'agent.py'])

# env.render(mode="ipython", width=1200, height=800)

### Optimize NN parameters with Optuna

In [32]:
# def objective(trial):

#     num_epochs = 10
    
#     # model for unit actions
#     model = LuxNet()
#     train, val = train_test_split(samples, test_size=0.1, random_state=42, stratify=labels)
#     batch_size = 64

#     train_loader = DataLoader(
#         LuxDataset(obses, train), 
#         batch_size=batch_size, 
#         shuffle=True, 
#         num_workers=2
#     )
#     val_loader = DataLoader(
#         LuxDataset(obses, val), 
#         batch_size=batch_size, 
#         shuffle=False, 
#         num_workers=2
#     )
#     dataloaders_dict = {"train": train_loader, "val": val_loader}

#     # Generate the optimizers.
#     criterion = nn.CrossEntropyLoss()
#     optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "AdamW", "RMSprop", "SGD"])
#     lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
#     optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)

#     for epoch in range(num_epochs):
#         model.cuda()
        
#         for phase in ['train', 'val']:
#             if phase == 'train':
#                 model.train()
#             else:
#                 model.eval()
                
#             epoch_loss = 0.0
#             epoch_acc = 0
            
#             dataloader = dataloaders_dict[phase]
#             for item in dataloader:
#                 states = item[0].cuda().float()
#                 actions = item[1].cuda().long()

#                 optimizer.zero_grad()
                
#                 with torch.set_grad_enabled(phase == 'train'):
#                     policy = model(states)
#                     loss = criterion(policy, actions)
#                     _, preds = torch.max(policy, 1)

#                     if phase == 'train':
#                         loss.backward()
#                         optimizer.step()

#                     epoch_loss += loss.item() * len(policy)
#                     epoch_acc += torch.sum(preds == actions.data)

#             data_size = len(dataloader.dataset)
#             epoch_loss = epoch_loss / data_size
#             epoch_acc = epoch_acc.double() / data_size

#         trial.report(epoch_acc, epoch)

#         # Handle pruning based on the intermediate value.
#         if trial.should_prune():
#             raise optuna.exceptions.TrialPruned()

#     return epoch_acc


# study = optuna.create_study(direction="maximize")
# study.optimize(objective, n_trials=500, timeout=10*3600)

# pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
# complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])

# print("Study statistics: ")
# print("  Number of finished trials: ", len(study.trials))
# print("  Number of pruned trials: ", len(pruned_trials))
# print("  Number of complete trials: ", len(complete_trials))

# print("Best trial:")
# trial = study.best_trial

# print("  Value: ", trial.value)

# print("  Params: ")
# for key, value in trial.params.items():
#     print("    {}: {}".format(key, value))

# NNs ensemble

In [33]:
import os
import numpy as np
import torch
from lux.game import Game

path = '/kaggle_simulations/agent' if os.path.exists('/kaggle_simulations') else '.' # change to 'agent' for tests
# unit NNs
model_v2 = torch.jit.load(f'{path}/model_v2.pth')
model_v2.eval()
model_v4 = torch.jit.load(f'{path}/model_v4.pth')
model_v4.eval()
model_v5 = torch.jit.load(f'{path}/model_v5.pth')
model_v5.eval()
model_v11 = torch.jit.load(f'{path}/model_v11.pth')
model_v11.eval()
# city NNs
model_city_v2 = torch.jit.load(f'{path}/model_city_v2.pth')
model_city_v2.eval()
model_city_v4 = torch.jit.load(f'{path}/model_city_v4.pth')
model_city_v4.eval()
model_city_v5 = torch.jit.load(f'{path}/model_city_v5.pth')
model_city_v5.eval()
model_city_v11 = torch.jit.load(f'{path}/model_city_v11.pth')
model_city_v11.eval()

# Input for Neural Network for units
def make_input(obs, unit_id):
    width, height = obs['width'], obs['height']
    x_shift = (32 - width) // 2
    y_shift = (32 - height) // 2
    cities = {}
    
    b = np.zeros((20, 32, 32), dtype=np.float32)
    
    for update in obs['updates']:
        strs = update.split(' ')
        input_identifier = strs[0]
        
        if input_identifier == 'u':
            x = int(strs[4]) + x_shift
            y = int(strs[5]) + y_shift
            wood = int(strs[7])
            coal = int(strs[8])
            uranium = int(strs[9])
            if unit_id == strs[3]:
                # Position and Cargo
                b[:2, x, y] = (
                    1,
                    (wood + coal + uranium) / 100
                )
            else:
                # Units
                team = int(strs[2])
                cooldown = float(strs[6])
                idx = 2 + (team - obs['player']) % 2 * 3
                b[idx:idx + 3, x, y] = (
                    1,
                    cooldown / 6,
                    (wood + coal + uranium) / 100
                )
        elif input_identifier == 'ct':
            # CityTiles
            team = int(strs[1])
            city_id = strs[2]
            x = int(strs[3]) + x_shift
            y = int(strs[4]) + y_shift
            idx = 8 + (team - obs['player']) % 2 * 2
            b[idx:idx + 2, x, y] = (
                1,
                cities[city_id]
            )
        elif input_identifier == 'r':
            # Resources
            r_type = strs[1]
            x = int(strs[2]) + x_shift
            y = int(strs[3]) + y_shift
            amt = int(float(strs[4]))
            b[{'wood': 12, 'coal': 13, 'uranium': 14}[r_type], x, y] = amt / 800
        elif input_identifier == 'rp':
            # Research Points
            team = int(strs[1])
            rp = int(strs[2])
            b[15 + (team - obs['player']) % 2, :] = min(rp, 200) / 200
        elif input_identifier == 'c':
            # Cities
            city_id = strs[2]
            fuel = float(strs[3])
            lightupkeep = float(strs[4])
            cities[city_id] = min(fuel / lightupkeep, 10) / 10
    
    # Day/Night Cycle
    b[17, :] = obs['step'] % 40 / 40
    # Turns
    b[18, :] = obs['step'] / 360
    # Day/Night Flag
    b[19, :] = 1 if obs['step'] % 40 < 30 else 0
    # Map Size
    b[20, x_shift:32 - x_shift, y_shift:32 - y_shift] = 1
    return b


# Input for Neural Network for cities
def make_city_input(obs, city_coord):
    width, height = obs['width'], obs['height']
    x_shift = (32 - width) // 2
    y_shift = (32 - height) // 2
    cities = {}
    
    b = np.zeros((20, 32, 32), dtype=np.float32)
    
    for update in obs['updates']:
        strs = update.split(' ')
        input_identifier = strs[0]
        
        if input_identifier == 'ct':
            # CityTiles
            city_id = strs[2]
            x = int(strs[3]) 
            y = int(strs[4])
            cooldown = float(strs[5])
            if x == int(city_coord[0]) and y == int(city_coord[1]):
                b[:2, x + x_shift, y + y_shift] = (
                    1,
                    cities[city_id]
                )
            else:
                team = int(strs[1])
                idx = 2 + (team - obs['player']) % 2 * 3
                b[idx:idx + 3, x + x_shift, y + y_shift] = (
                    1,
                    cooldown / 10,
                    cities[city_id]
                )
        elif input_identifier == 'u':
            team = int(strs[2])
            x = int(strs[4]) + x_shift
            y = int(strs[5]) + y_shift
            wood = int(strs[7])
            coal = int(strs[8])
            uranium = int(strs[9])
            idx = 8 + (team - obs['player']) % 2 * 2
            b[idx:idx + 2, x, y] = (
                1,
                (wood + coal + uranium) / 100
            )
        elif input_identifier == 'r':
            # Resources
            r_type = strs[1]
            x = int(strs[2]) + x_shift
            y = int(strs[3]) + y_shift
            amt = int(float(strs[4]))
            b[{'wood': 12, 'coal': 13, 'uranium': 14}[r_type], x, y] = amt / 800
        elif input_identifier == 'rp':
            # Research Points
            team = int(strs[1])
            rp = int(strs[2])
            b[15 + (team - obs['player']) % 2, :] = min(rp, 200) / 200
        elif input_identifier == 'c':
            # Cities
            city_id = strs[2]
            fuel = float(strs[3])
            lightupkeep = float(strs[4])
            cities[city_id] = min(fuel / lightupkeep, 10) / 10
    
    # Day/Night Cycle
    b[17, :] = obs['step'] % 40 / 40
    # Turns
    b[18, :] = obs['step'] / 360
    # Map Size
    b[19, x_shift:32 - x_shift, y_shift:32 - y_shift] = 1

    return b

game_state = None
player = None


def get_game_state(observation):
    global game_state
    
    if observation["step"] == 0:
        game_state = Game()
        game_state._initialize(observation["updates"])
        game_state._update(observation["updates"][2:])
        game_state.id = observation["player"]
    else:
        game_state._update(observation["updates"])
    return game_state


def in_city(pos):    
    try:
        city = game_state.map.get_cell_by_pos(pos).citytile
        return city is not None and city.team == game_state.id
    except:
        return False
    
# check if unit has enough time and space to build a city
def build_city_is_possible(unit, pos):    
    global game_state
    global player

    if game_state.turn % 40 < 30:
        return True
    x, y = pos.x, pos.y
    for i, j in ((x-1, y), (x+1, y), (x, y-1), (x, y+1)):
        try:
            city_id = game_state.map.get_cell(i, j).citytile.cityid
        except:
            continue
        if city_id in player.cities:
            city = player.cities[city_id]
            if city.fuel > (city.get_light_upkeep() + 18) * 10:
                return True
    return False


def call_func(obj, method, args=[]):
    return getattr(obj, method)(*args)


# translate unit policy to action
unit_actions = [('move', 'n'), ('move', 's'), ('move', 'w'), ('move', 'e'), ('build_city',)]
def get_unit_action(policy, unit, dest):
    for label in np.argsort(policy)[::-1]:
        act = unit_actions[label]
        pos = unit.pos.translate(act[-1], 1) or unit.pos
        if label == 4 and not build_city_is_possible(unit, pos):
            return unit.move('c'), unit.pos
        if pos not in dest or in_city(pos):
            return call_func(unit, *act), pos 
            
    return unit.move('c'), unit.pos

# translate city policy to action
city_actions = [('build_worker',), ('research', )]
def get_city_action(policy, city_tile, unit_count):
    global player
    
    for label in np.argsort(policy)[::-1]:
        act = city_actions[label]
        if label == 0 and unit_count < player.city_tile_count:
            unit_count += 1
            res = call_func(city_tile, *act)
        elif label == 1 and not player.researched_uranium():
            player.research_points += 1
            res = call_func(city_tile, *act)
        else:
            res = None
        return res, unit_count

# agent for making actions
def agent(observation, configuration):
    global game_state
    global player
    
    game_state = get_game_state(observation)    
    player = game_state.players[observation.player]
    actions = []        

    # Unit Actions
    dest = []
    for unit in player.units:
        if unit.can_act() and (game_state.turn % 40 < 30 or not in_city(unit.pos)):
            state = make_input(observation, unit.id)
            with torch.no_grad():
                p_2 = model_v2(torch.from_numpy(state).unsqueeze(0))
                p_4 = model_v4(torch.from_numpy(state).unsqueeze(0))
                #p_5 = model_v5(torch.from_numpy(state).unsqueeze(0))
                p_11 = model_v11(torch.from_numpy(state).unsqueeze(0))

            policy_2 = p_2.squeeze(0).numpy()
            policy_4 = p_4.squeeze(0).numpy()
            #policy_5 = p_5.squeeze(0).numpy()
            policy_11 = p_11.squeeze(0).numpy()

            policy = [sum(x) for x in zip(*[policy_2, policy_4, policy_11])]

            action, pos = get_unit_action(policy, unit, dest)
            actions.append(action)
            dest.append(pos)

    # City Actions
    unit_count = len(player.units)
    for city in player.cities.values():
        for city_tile in city.citytiles:
            if city_tile.can_act():
                # at first game stages try to produce maximum amount of agents and research point
                if game_state.turn < 60:
                    if unit_count < player.city_tile_count: 
                        actions.append(city_tile.build_worker())
                        unit_count += 1
                    elif not player.researched_uranium():
                        actions.append(city_tile.research())
                        player.research_points += 1
                # then follow NN strategy
                else:
                    state = make_city_input(observation, [city_tile.pos.x, city_tile.pos.y])
                    with torch.no_grad():
                        p_2 = model_city_v2(torch.from_numpy(state).unsqueeze(0))
                        p_4 = model_city_v4(torch.from_numpy(state).unsqueeze(0))
                        #p_5 = model_city_v5(torch.from_numpy(state).unsqueeze(0))
                        p_11 = model_city_v11(torch.from_numpy(state).unsqueeze(0))

                    policy_2 = p_2.squeeze(0).numpy()
                    policy_4 = p_4.squeeze(0).numpy()
                    #policy_5 = p_5.squeeze(0).numpy()
                    policy_11 = p_11.squeeze(0).numpy()

                    policy = [sum(x) for x in zip(*[policy_2, policy_4, policy_11])]

                    action, unit_count = get_city_action(policy, city_tile, unit_count)
                    if action:
                        actions.append(action)
    
    return actions

ValueError: The provided filename ./model_v2.pth does not exist

# Futher Ideas

Cities
- add day/night feature

Hyperparameters
- optimize hyperparameters (number of features and layers, regularization, dropout probability, learning rate) with Optuna

Ensembles
- make ensemble of 3 best NNs (select through Optuna optimization) that makes decision by voting or randomly selects action from presented

In [None]:
# dict for apropriate change of moving action during rotation
# 1 - rotation on 90, 2 - rotation on 180, 3 - rotation on 270
rot_dict = {0: {0:0, 1:1, 2:2, 3:3, 4:4}, # N -> E, S -> W, W -> N, E -> S
            1: {0:3, 1:2, 2:0, 3:1, 4:4}, # N -> E, S -> W, W -> N, E -> S
            2: {0:1, 1:0, 2:3, 3:2, 4:4}, # N -> S, S -> N, W -> E, E -> W
            3: {0:2, 1:3, 2:1, 3:0, 4:4}  # N -> W, S -> E, W -> S, E -> N
           }