For Kaggle

In [1]:
# !pip install kaggle-environments -U > /dev/null 2>&1
# !cp -r ../input/lux-ai-2021/* .

For agents validation

In [2]:
# timeout 1h lux-ai-2021 --tournament --rankSystem wins --storeReplay false --storeLogs false --maxConcurrentMatches 1 agent/main.py agent_simple/main.py submission_v4/main.py

In [3]:
import numpy as np
import json
from pathlib import Path
import os
import random
from tqdm.notebook import tqdm
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.transforms.functional as TF
from torchvision.transforms import ToTensor, RandomHorizontalFlip, RandomVerticalFlip
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.model_selection import train_test_split
import optuna
from optuna.trial import TrialState

In [4]:
def seed_everything(seed_value):
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    os.environ['PYTHONHASHSEED'] = str(seed_value)
    
    if torch.cuda.is_available(): 
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = True

seed_everything(42)

# Preprocessing

Actions:
- m - move: m, unit who moves, direction
- bcity - build city: bcity, unit who builds
- bw - build worker, x-coord of building city, y-coord of building city
- r - research , x-coord of researching city, y-coord of researching city
- t - transfer: transfer, from user_1, to user_2, resourse, quantity

Updates:
- rp - research point: player, number of rp
- r - resources: r, type of resource, x-coord, y-coord, quantity
- u - user: u, worker/cart, player, user id, x-coord, y-coord, cooldown, wood, coal, uranium
- c - city: c, player, city id, number of resources, amount of consuming light at night
- ct - city tile: ct, player, city id, x-coord, y-coord, cooldown
- ccd - level of road: ccd, x-coord, y-coord, level value

In [5]:
!cd episodes && rm *_info.json && cd ..

rm: cannot remove '*_info.json': No such file or directory


Get data from json files

In [6]:
from random import choice

def unit_label(action):
    strs = action.split(' ')
    unit_id = strs[1]
    if strs[0] == 'm':
        label = {'c': None, 'n': 0, 's': 1, 'w': 2, 'e': 3}[strs[2]]
    elif strs[0] == 'bcity':
        label = 4
    else:
        label = None
    return unit_id, label

def city_label(action):
    strs = action.split(' ')
    if len(strs) < 3:
        return None, None
    ctile_coord = (strs[1], strs[2])
    if strs[0] == 'bw':
        label = 0
    elif strs[0] == 'r':
        label = 1
    else:
        label = None
    return ctile_coord, label

def unit_cooldown(updates, index):
    strs = updates.split(' ')
    if strs[0] == 'u' and strs[2] == index:
        unit_id = strs[3]
        cooldown = strs[6]
        return unit_id, cooldown
    return None, None

def city_tile_coord(updates, index):
    strs = updates.split(' ')
    if strs[0] == 'ct' and strs[1] == index:
        ctile_coord = (strs[3], strs[4])
        cooldown = strs[5]
        return ctile_coord, cooldown
    return None, None

def research_points(updates, index):
    for u in updates:
        strs = u.split(' ')
        if strs[0] == 'rp' and strs[1] == index:
            return int(strs[2])
    return 0

def depleted_resources(obs):
    for u in obs['updates']:
        if u.split(' ')[0] == 'r':
            return False
    return True


def create_dataset_from_json(episode_dir, team_name='Toad Brigade'): 
    obses = {}
    unit_samples = []
    city_samples = []
    
    episodes = [path for path in Path(episode_dir).glob('*.json') if 'output' not in path.name]
    for filepath in tqdm(episodes): 
        with open(filepath) as f:
            json_load = json.load(f)
        # load episode from episode json file
        ep_id = json_load['info']['EpisodeId']
        index = np.argmax([r or 0 for r in json_load['rewards']])
        if json_load['info']['TeamNames'][index] != team_name:
            continue

        updates = None
        
        for i in range(len(json_load['steps'])-1):
            units = {}
            city_tiles = {}
            
            if json_load['steps'][i][index]['status'] == 'ACTIVE':
                actions = json_load['steps'][i+1][index]['action']
                obs = json_load['steps'][i][0]['observation']
                updates = obs['updates']
                rp = research_points(updates, str(index))
                # get updates from previous steps
                if i > 0:
                    obs['updates_lag_1'] = json_load['steps'][i-1][0]['observation']['updates']
                else:
                    obs['updates_lag_1'] = None
                if i > 1:
                    obs['updates_lag_2'] = json_load['steps'][i-2][0]['observation']['updates']
                else:
                    obs['updates_lag_2'] = None
                if i > 2:
                    obs['updates_lag_3'] = json_load['steps'][i-3][0]['observation']['updates']
                else:
                    obs['updates_lag_3'] = None
                if i > 3:
                    obs['updates_lag_4'] = json_load['steps'][i-4][0]['observation']['updates']
                else:
                    obs['updates_lag_4'] = None
                if i > 4:
                    obs['updates_lag_5'] = json_load['steps'][i-5][0]['observation']['updates']
                else:
                    obs['updates_lag_5'] = None
                if i > 5:
                    obs['updates_lag_6'] = json_load['steps'][i-6][0]['observation']['updates']
                else:
                    obs['updates_lag_6'] = None
                if i > 6:
                    obs['updates_lag_7'] = json_load['steps'][i-7][0]['observation']['updates']
                else:
                    obs['updates_lag_7'] = None
                
                for u in updates:
                    # get coords for every friendly city tile
                    ctile_coord, cooldown = city_tile_coord(u, str(index))
                    if ctile_coord and cooldown:
                        city_tiles[ctile_coord] = [0, float(cooldown)]
                        continue
                    # get cooldown for every friendly unit 
                    unit_id, cooldown = unit_cooldown(u, str(index))
                    if unit_id and cooldown:
                        units[unit_id] = float(cooldown)                     
                
                # stop research if resources are depleted
                if depleted_resources(obs):
                    break
                
                obs['player'] = index
                obs = dict([
                    (k,v) for k,v in obs.items() 
                    if k in ['step', 'updates', 'updates_lag_1', 'updates_lag_2', 
                             'updates_lag_3', 'updates_lag_4', 'updates_lag_5', 
                             'updates_lag_6', 'updates_lag_7', 'player', 'width', 'height']
                ])
                obs_id = f'{ep_id}_{i}'
                obses[obs_id] = obs
                
                for action in actions:
                    unit_id, label = unit_label(action)
                    # if unit acts - add this action to train
                    if label is not None:
                        unit_samples.append((obs_id, unit_id, label))
                        continue
                    ctile_coord, label = city_label(action)
                    # if city tile acts - add this action to train
                    if label is not None:
                        city_samples.append((obs_id, ctile_coord, label))
                        city_tiles[ctile_coord][0] = label
                        
#                 # it there is a possibility for city tiles to act, (research points number < 200 and number of
#                 # city tiles more than number of units find those of them who can act (cooldown = 0) but doesn't
#                 if rp < 200 and len(city_tiles) > len(units):
#                     city_tiles_no_action = [(k, v[0]) for k, v in city_tiles.items() if v[0] == 0 and v[1] == 0]
#                     # if there are cities with no action - randomly select one item from there 
#                     # and add it to city samples
#                     if city_tiles_no_action:
#                         ct = choice(city_tiles_no_action)
#                         city_samples.append((obs_id, ct[0], ct[1]))
                    
                    
    return obses, unit_samples, city_samples

In [7]:
episode_dir = 'episodes'
obses, samples, city_samples = create_dataset_from_json(episode_dir)

print('observations:', len(obses), 'worker samples:', len(samples), 'city samples:', len(city_samples))

  0%|          | 0/332 [00:00<?, ?it/s]

observations: 96015 worker samples: 553399 city samples: 69609


In [8]:
labels = [sample[-1] for sample in samples]
actions = ['north', 'south', 'west', 'east', 'bcity']
for value, count in zip(*np.unique(labels, return_counts=True)):
    print(f'{actions[value]}: {count}')

north: 122189
south: 122506
west: 132911
east: 136626
bcity: 39167


In [9]:
labels_city = [sample[-1] for sample in city_samples]
actions_city = ['build_worker', 'research']
for value, count in zip(*np.unique(labels_city, return_counts=True)):
    print(f'{actions_city[value]}: {count}')

build_worker: 21127
research: 48482


In [10]:
# obses['26762301_160']

In [11]:
# samples

In [12]:
# city_samples[200000]

# Training

b - training tensor of float32. b dimensions is 20x32x32

- b[0] - position of current unit
- b[1] - cargo sum/100 of current unit
- b[2] - previous position of current unit
- b[3] - pre-previous position of current unit
- b[4, 5, 6, 7] - position, cooldown/6, cargo sum/100 and Manhattan distance to units from friendly team
- b[8, 9, 10, 11] - position, cooldown/6, cargo sum/100 and Manhattan distance to units from adversarial team
- b[12, 13, 14] - position,  min(city fuel/city energy consumption, 10)/10 and Manhattan distance to cities from friendly team
- b[15, 16, 17] - position,  min(city fuel/city energy consumption, 10)/10 and Manhattan distance to cities from adversarial team
- b[18] - amount of wood / 800 on current position
- b[19] - amount of coal / 800 on current position
- b[20] - amount of uranium / 800 on current position
- b[21] - 1 if unit can mine this resource, 0 otherwise
- b[22] - Manhattan distance to this resource
- b[23] - research points / 200 of unit's team
- b[24] - research points / 200 of another team
- b[25] - time of the day (from 0 to 1, step 0.05)
- b[26] - step of the game (from 0 to 1, step 1/360)
- b[27] - day/night flag
- b[28] - map size

In [13]:
def manhattan_distance(x1, y1, x2, y2):
    return (abs(x2-x1) + abs(y2-y1))

# make list of users and their current and previous coords and cooldowns 
# to write them in NN training subset
def find_units_from_previous_obs(obs, x_shift, y_shift):
    # at first fill the unit dict with units from previous observation
    updates = obs['updates']
    updates_lag_1 = obs['updates_lag_1']
    updates_lag_2 = obs['updates_lag_2']
    updates_lag_3 = obs['updates_lag_3']
    updates_lag_4 = obs['updates_lag_4']
    updates_lag_5 = obs['updates_lag_5']
    updates_lag_6 = obs['updates_lag_6']
    updates_lag_7 = obs['updates_lag_7']
    prev_updates = [updates, updates_lag_1, updates_lag_2, updates_lag_3, 
                    updates_lag_4, updates_lag_5, updates_lag_6, updates_lag_7]
    prev_units = list()
    for prev_update in prev_updates:
        units = dict()
        if prev_update:
            for update in prev_update:
                strs = update.split(' ')
                input_identifier = strs[0]
                # if we found observation for user
                if input_identifier == 'u':
                    unit_id = strs[3]
                    x = int(strs[4]) + x_shift
                    y = int(strs[5]) + y_shift
                    cooldown = float(strs[6])
                    units[unit_id] = [x, y, cooldown]
            prev_units.append(units)
        else:
            break
    return prev_units

def find_prev_coords(prev_units, unit_id):
    # find previous coordinates of unit by ananlyzing its cooldown
    x, y = [None, None], [None, None]
    for i in range(len(prev_units)-1):
        if unit_id in prev_units[i] and unit_id in prev_units[i+1]:
            cooldown = prev_units[i][unit_id][0]
            prev_x = prev_units[i+1][unit_id][0]
            prev_y = prev_units[i+1][unit_id][1]
            prev_cooldown = prev_units[i+1][unit_id][2]
            if cooldown > 0 and prev_cooldown == 0:
                if x[1] and y[1]:
                    return x, y
                if not x[0] and not y[0]:
                    x[0], y[0] = prev_x,prev_y
                else:
                    x[1], y[1] = prev_x,prev_y
        else:
            break
    return x, y
            
# Input for Neural Network for workers
def make_input(obs, unit_id):
    width, height = obs['width'], obs['height']
    x_shift = (32 - width) // 2
    y_shift = (32 - height) // 2

    cities = {}
    
    b = np.zeros((29, 32, 32), dtype=np.float32)
    
    prev_units = find_units_from_previous_obs(obs, x_shift, y_shift)
    x_u, y_u = prev_units[0][unit_id][0], prev_units[0][unit_id][1]
    
    for update in obs['updates']:
        strs = update.split(' ')
        input_identifier = strs[0]
        my_rp = 0
        
        if input_identifier == 'u':
            x = int(strs[4]) + x_shift
            y = int(strs[5]) + y_shift
            wood = int(strs[7])
            coal = int(strs[8])
            uranium = int(strs[9])
            if strs[3] == unit_id: # 0:2
                # Position, Cargo and Previous Position
                b[:2, x, y] = (
                    1,
                    (wood + coal + uranium) / 100
                )
                prev_x, prev_y = find_prev_coords(prev_units, unit_id)
                if not prev_x[0] and not prev_y[0]:
                    prev_x[0], prev_y[0] = x, y
                    prev_x[1], prev_y[1] = x, y
                b[2, prev_x[0], prev_y[0]] = 1
                b[3, prev_x[1], prev_y[1]] = 1
            else:                  # 4:11
                # Units
                team = int(strs[2])
                cooldown = float(strs[6])
                idx = 4 + (team - obs['player']) % 2 * 4
                m_dist = manhattan_distance(x_u, y_u, x, y)
                b[idx:idx + 4, x, y] = (
                    1,
                    cooldown / 6,
                    (wood + coal + uranium) / 100,
                    m_dist/((width-1) + (height-1))
                )
        elif input_identifier == 'ct':  # 12:17
            # CityTiles
            team = int(strs[1])
            city_id = strs[2]
            x = int(strs[3]) + x_shift
            y = int(strs[4]) + y_shift
            idx = 12 + (team - obs['player']) % 2 * 3
            m_dist = manhattan_distance(x_u, y_u, x, y)
            b[idx:idx + 3, x, y] = (
                1,
                cities[city_id],
                m_dist/((width-1) + (height-1))
            )
        elif input_identifier == 'r':  # 18:21
            # Resources
            r_type = strs[1]
            x = int(strs[2]) + x_shift
            y = int(strs[3]) + y_shift
            amt = int(float(strs[4]))
            access_level = {'wood': 0, 'coal': 50, 'uranium': 200}[r_type]
            access = 0 if my_rp < access_level else 1
            b[{'wood': 18, 'coal': 19, 'uranium': 20}[r_type], x, y] = amt / 800
            b[21, x, y] = access
            b[22, x, y] = manhattan_distance(x_u, y_u, x, y)/((width-1) + (height-1))
        elif input_identifier == 'rp':  # 23:24
            # Research Points
            team = int(strs[1])
            rp = int(strs[2])
            my_rp = rp if team == obs['player'] else my_rp
            b[23 + (team - obs['player']) % 2, :] = min(rp, 200) / 200
        elif input_identifier == 'c':
            # Cities
            city_id = strs[2]
            fuel = float(strs[3])
            lightupkeep = float(strs[4])
            cities[city_id] = min(fuel / lightupkeep, 10) / 10
    # Day/Night Cycle
    b[25, :] = obs['step'] % 40 / 40
    # Turns
    b[26, :] = obs['step'] / 360
    # Day/Night Flag
    b[27, :] = 1 if obs['step'] % 40 < 30 else 0
    # Map Size
    b[28, x_shift:32 - x_shift, y_shift:32 - y_shift] = 1
        
    return b

In [14]:
# obses['26762301_160']

In [15]:
# import sys

# np.set_printoptions(threshold=sys.maxsize)

# make_input(obses['26762301_170'], 'u_36')[3]

Data for the cities:
- b[0] - position of current city
- b[1] - min(city fuel/city energy consumption, 10)/10 of current city
- b[2, 3, 4] - position, cooldown/10, and min(city fuel/city energy consumption, 10)/10 for cities from the same team
- b[5, 6, 7] - position, cooldown/10, and min(city fuel/city energy consumption, 10)/10 for cities from another team
- b[8, 9] - position and cargo sum/100 for units from the same team 
- b[10, 11] - position and cargo sum/100 for units from from another team
- b[12] - amount of wood / 800
- b[13] - amount of coal / 800
- b[14] - amount of uranium / 800
- b[15] - research points / 200 of unit's team
- b[16] - research piints / 200 of another team
- b[17] - time of the day (from 0 to 1, step 0.05)
- b[18] - step of the game (from 0 to 1, step 1/360)
- b[19] - day/night flag
- b[20] - map size

In [16]:
# Input for Neural Network for cities
def make_city_input(obs, city_coord):
    width, height = obs['width'], obs['height']
    x_shift = (32 - width) // 2
    y_shift = (32 - height) // 2
    cities = {}
    
    b = np.zeros((21, 32, 32), dtype=np.float32)
    
    for update in obs['updates']:
        strs = update.split(' ')
        input_identifier = strs[0]
        
        if input_identifier == 'ct':
            # CityTiles
            city_id = strs[2]
            x = int(strs[3]) 
            y = int(strs[4])
            cooldown = float(strs[5])
            if x == int(city_coord[0]) and y == int(city_coord[1]):
                b[:2, x + x_shift, y + y_shift] = (
                    1,
                    cities[city_id]
                )
            else:
                team = int(strs[1])
                idx = 2 + (team - obs['player']) % 2 * 3
                b[idx:idx + 3, x + x_shift, y + y_shift] = (
                    1,
                    cooldown / 10,
                    cities[city_id]
                )
        elif input_identifier == 'u':
            team = int(strs[2])
            x = int(strs[4]) + x_shift
            y = int(strs[5]) + y_shift
            wood = int(strs[7])
            coal = int(strs[8])
            uranium = int(strs[9])
            idx = 8 + (team - obs['player']) % 2 * 2
            b[idx:idx + 2, x, y] = (
                1,
                (wood + coal + uranium) / 100
            )
        elif input_identifier == 'r':
            # Resources
            r_type = strs[1]
            x = int(strs[2]) + x_shift
            y = int(strs[3]) + y_shift
            amt = int(float(strs[4]))
            b[{'wood': 12, 'coal': 13, 'uranium': 14}[r_type], x, y] = amt / 800
        elif input_identifier == 'rp':
            # Research Points
            team = int(strs[1])
            rp = int(strs[2])
            b[15 + (team - obs['player']) % 2, :] = min(rp, 200) / 200
        elif input_identifier == 'c':
            # Cities
            city_id = strs[2]
            fuel = float(strs[3])
            lightupkeep = float(strs[4])
            cities[city_id] = min(fuel / lightupkeep, 10) / 10
    
    # Day/Night Cycle
    b[17, :] = obs['step'] % 40 / 40
    # Turns
    b[18, :] = obs['step'] / 360
    # Day/Night Flag
    b[19, :] = 1 if obs['step'] % 40 < 30 else 0
    # Map Size
    b[20, x_shift:32 - x_shift, y_shift:32 - y_shift] = 1

    return b

### Set modules for NN training

In [17]:
# Neural Network for Lux AI
class BasicConv2d(nn.Module):
    def __init__(self, input_dim, output_dim, kernel_size, bn, p=0.1):
        super().__init__()
        self.conv = nn.Conv2d(
            input_dim, output_dim, 
            kernel_size=kernel_size, 
            padding=(kernel_size[0] // 2, kernel_size[1] // 2)
        )
        self.bn = nn.BatchNorm2d(output_dim) if bn else None
        self.dropout = nn.Dropout(p) 

    def forward(self, x):
        h = self.conv(x)
        h = self.bn(h) if self.bn is not None else h
        h = self.dropout(h) 
        return h


class LuxNet(nn.Module):
    def __init__(self):
        super().__init__()
        layers, filters = 12, 64
        self.conv0 = BasicConv2d(29, filters, (3, 3), True, p=0.2)
        self.blocks = nn.ModuleList([BasicConv2d(filters, filters, (3, 3), True) for _ in range(layers)])
        self.head_p = nn.Linear(filters, 5, bias=False)

    def forward(self, x):
        h = F.relu_(self.conv0(x))
        for block in self.blocks:
            h = F.relu_(h + block(h))
        h_head = (h * x[:,:1]).view(h.size(0), h.size(1), -1).sum(-1)
        p = self.head_p(h_head)
        return p
    
    
class LuxCityNet(nn.Module):
    def __init__(self):
        super().__init__()
        layers, filters = 12, 32
        self.conv0 = BasicConv2d(21, filters, (3, 3), True, p=0.2)
        self.blocks = nn.ModuleList([BasicConv2d(filters, filters, (3, 3), True) for _ in range(layers)])
        self.head_p = nn.Linear(filters, 2, bias=False)

    def forward(self, x):
        h = F.relu_(self.conv0(x))
        for block in self.blocks:
            h = F.relu_(h + block(h))
        h_head = (h * x[:,:1]).view(h.size(0), h.size(1), -1).sum(-1)
        p = self.head_p(h_head)
        return p

### Set dataloaders

In [18]:
# mask for apropriate change of moving action during rotation
# 0 - no transform, 1 - rotation on 90, 2 - rotation on 180, 
# 3 - rotation on 270, 4 - horizontal flip, 5 - vertical flip
mask = (
        (0, 1, 2, 3, 4), 
        (2, 3, 1, 0, 4), # N -> W, S -> E, W -> S, E -> N - rot 90
        (1, 0, 3, 2, 4), # N -> S, S -> N, W -> E, E -> W - rot 180
        (3, 2, 0, 1, 4), # N -> E, S -> W, W -> N, E -> S - rot 270
        (1, 0, 2, 3, 4), # N -> S, S -> N, W -> W, E -> E - hflip
        (0, 1, 3, 2, 4), # N -> N, S -> S, W -> E, E -> W - vflip
        )

def do_nothing(x):
    return x

class RandomChoice(torch.nn.Module):
    def __init__(self, transforms):
       super().__init__()
       self.transforms = transforms

    def __call__(self, x):
        idx = random.choice([i for i in range(len(self.transforms))])
        t = self.transforms[idx]
        x = torch.from_numpy(x)
        return t(x), idx

transform=RandomChoice([lambda x: x,
                        lambda x: torch.rot90(x, 1, [2, 1]),
                        lambda x: torch.rot90(x, 2, [2, 1]),
                        lambda x: torch.rot90(x, 1, [1, 2]),
                        TF.hflip,
                        TF.vflip])
    
class LuxDataset(Dataset):
    def __init__(self, obses, samples, transform=transform, mask=mask, city=False):
        self.obses = obses
        self.samples = samples
        self.data_len = len(self.samples)
        self.len = self.data_len*6
        self.transform = transform
        self.mask = mask
        self.city = city
            
    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        data_idx = idx % self.data_len
        obs_id, unit_id, action = self.samples[data_idx]
        obs = self.obses[obs_id]
        if not self.city:
            state = make_input(obs, unit_id)
        else:
            state = make_city_input(obs, unit_id)
        if self.transform:
            t = self.transform(state)
            state = t[0]
            if not self.city:
                action = self.mask[t[1]][action]
        return state, action

### Function for NN training

In [19]:
import matplotlib.pyplot as plt
# from torch.utils.tensorboard import SummaryWriter

def train_model(model, dataloaders_dict, criterion, optimizer, scheduler, num_epochs, city=False):
#     tb = SummaryWriter()
    
    best_acc = 0.0
    
    for epoch in range(num_epochs):
        model.cuda()
        
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()
                
            epoch_loss = 0.0
            epoch_acc = 0
            
            dataloader = dataloaders_dict[phase]
            for item in tqdm(dataloader, leave=False):
                
                states = item[0].cuda().float()
                actions = item[1].cuda().long()

                optimizer.zero_grad()
                
                with torch.set_grad_enabled(phase == 'train'):
                    policy = model(states)
                    loss = criterion(policy, actions)
                    _, preds = torch.max(policy, 1)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                    epoch_loss += loss.item() * len(policy)
                    epoch_acc += torch.sum(preds == actions.data)

            data_size = len(dataloader.dataset)
            epoch_loss = epoch_loss / data_size
            epoch_acc = epoch_acc.double() / data_size

#             if phase == 'train':
#                 tb.add_scalar("Train Loss", epoch_loss, epoch)
#                 tb.add_scalar("Train Accuracy", epoch_acc, epoch)
#             else:
#                 tb.add_scalar("Val Loss", epoch_loss, epoch)
#                 tb.add_scalar("Val Accuracy", epoch_acc, epoch)
            
            print(f'Epoch {epoch + 1}/{num_epochs} | {phase:^5} | Loss: {epoch_loss:.4f} | Acc: {epoch_acc:.4f}')
        
        if epoch_acc > best_acc:
            if city:
                traced = torch.jit.trace(model.cpu(), torch.rand(1, 21, 32, 32))
                traced.save('agent/model_city.pth')
            else:
                traced = torch.jit.trace(model.cpu(), torch.rand(1, 29, 32, 32))
                traced.save('agent/model.pth')
            best_acc = epoch_acc
            
        scheduler.step(epoch_loss)
        print(f'LR: {optimizer.param_groups[0]["lr"]}')
            
#     tb.close()

### Init dataloaders

In [20]:
# model for unit actions
model = LuxNet()
train, val = train_test_split(samples, test_size=0.1, random_state=42, stratify=labels)
batch_size = 256

train_loader = DataLoader(
    LuxDataset(obses, train), 
    batch_size=batch_size, 
    shuffle=True, 
    num_workers=2
)
val_loader = DataLoader(
    LuxDataset(obses, val), 
    batch_size=batch_size, 
    shuffle=False, 
    num_workers=2
)
dataloaders_dict = {"train": train_loader, "val": val_loader}

# model for city actions
model_city = LuxCityNet()
train_city, val_city = train_test_split(city_samples, test_size=0.1, random_state=42, stratify=labels_city)
batch_size_city = 128

train_city_loader = DataLoader(
    LuxDataset(obses, train_city, city=True), 
    batch_size=batch_size, 
    shuffle=True, 
    num_workers=2
)
    
val_city_loader = DataLoader(
    LuxDataset(obses, val_city, city=True), 
    batch_size=batch_size, 
    shuffle=False, 
    num_workers=2
)
dataloaders_city_dict = {"train": train_city_loader, "val": val_city_loader}

### Train models

In [None]:
%%time

# model = torch.jit.load('agent/model.pth')
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = ReduceLROnPlateau(optimizer, 'min', patience=4, threshold=1e-2, min_lr=1e-6)

train_model(model, dataloaders_dict, criterion, optimizer, scheduler, num_epochs=10)

  0%|          | 0/11674 [00:00<?, ?it/s]

Epoch 1/30 | train | Loss: 0.6575 | Acc: 0.7359


  0%|          | 0/1298 [00:00<?, ?it/s]

Epoch 1/30 |  val  | Loss: 0.5435 | Acc: 0.7816
LR: 0.001


  0%|          | 0/11674 [00:00<?, ?it/s]

Epoch 2/30 | train | Loss: 0.5531 | Acc: 0.7777


  0%|          | 0/1298 [00:00<?, ?it/s]

Epoch 2/30 |  val  | Loss: 0.5119 | Acc: 0.7941
LR: 0.001


  0%|          | 0/11674 [00:00<?, ?it/s]

Epoch 3/30 | train | Loss: 0.5242 | Acc: 0.7890


  0%|          | 0/1298 [00:00<?, ?it/s]

Epoch 3/30 |  val  | Loss: 0.4874 | Acc: 0.8049
LR: 0.001


  0%|          | 0/11674 [00:00<?, ?it/s]

Epoch 4/30 | train | Loss: 0.5077 | Acc: 0.7955


  0%|          | 0/1298 [00:00<?, ?it/s]

Epoch 4/30 |  val  | Loss: 0.4776 | Acc: 0.8074
LR: 0.001


  0%|          | 0/11674 [00:00<?, ?it/s]

Epoch 5/30 | train | Loss: 0.4962 | Acc: 0.7998


  0%|          | 0/1298 [00:00<?, ?it/s]

Epoch 5/30 |  val  | Loss: 0.4684 | Acc: 0.8111
LR: 0.001


  0%|          | 0/11674 [00:00<?, ?it/s]

Epoch 6/30 | train | Loss: 0.4877 | Acc: 0.8033


  0%|          | 0/1298 [00:00<?, ?it/s]

Epoch 6/30 |  val  | Loss: 0.4747 | Acc: 0.8080
LR: 0.001


  0%|          | 0/11674 [00:00<?, ?it/s]

Epoch 7/30 | train | Loss: 0.4816 | Acc: 0.8057


  0%|          | 0/1298 [00:00<?, ?it/s]

Epoch 7/30 |  val  | Loss: 0.4589 | Acc: 0.8143
LR: 0.001


  0%|          | 0/11674 [00:00<?, ?it/s]

Epoch 8/30 | train | Loss: 0.4761 | Acc: 0.8077


  0%|          | 0/1298 [00:00<?, ?it/s]

Epoch 8/30 |  val  | Loss: 0.4572 | Acc: 0.8149
LR: 0.001


  0%|          | 0/11674 [00:00<?, ?it/s]

Epoch 9/30 | train | Loss: 0.4718 | Acc: 0.8094


  0%|          | 0/1298 [00:00<?, ?it/s]

Epoch 9/30 |  val  | Loss: 0.4621 | Acc: 0.8136
LR: 0.001


  0%|          | 0/11674 [00:00<?, ?it/s]

Epoch 10/30 | train | Loss: 0.4688 | Acc: 0.8108


  0%|          | 0/1298 [00:00<?, ?it/s]

Epoch 10/30 |  val  | Loss: 0.4542 | Acc: 0.8166
LR: 0.001


  0%|          | 0/11674 [00:00<?, ?it/s]

Epoch 11/30 | train | Loss: 0.4655 | Acc: 0.8120


  0%|          | 0/1298 [00:00<?, ?it/s]

Epoch 11/30 |  val  | Loss: 0.4524 | Acc: 0.8170
LR: 0.001


  0%|          | 0/11674 [00:00<?, ?it/s]

Epoch 12/30 | train | Loss: 0.4626 | Acc: 0.8130


  0%|          | 0/1298 [00:00<?, ?it/s]

Epoch 12/30 |  val  | Loss: 0.4496 | Acc: 0.8187
LR: 0.001


  0%|          | 0/11674 [00:00<?, ?it/s]

Epoch 13/30 | train | Loss: 0.4607 | Acc: 0.8139


  0%|          | 0/1298 [00:00<?, ?it/s]

Epoch 13/30 |  val  | Loss: 0.4482 | Acc: 0.8198
LR: 0.001


  0%|          | 0/11674 [00:00<?, ?it/s]

Epoch 14/30 | train | Loss: 0.4581 | Acc: 0.8147


  0%|          | 0/1298 [00:00<?, ?it/s]

Epoch 14/30 |  val  | Loss: 0.4515 | Acc: 0.8170
LR: 0.001


  0%|          | 0/11674 [00:00<?, ?it/s]

Epoch 15/30 | train | Loss: 0.4563 | Acc: 0.8156


  0%|          | 0/1298 [00:00<?, ?it/s]

Epoch 15/30 |  val  | Loss: 0.4589 | Acc: 0.8150
LR: 0.001


  0%|          | 0/11674 [00:00<?, ?it/s]

Epoch 16/30 | train | Loss: 0.4545 | Acc: 0.8162


  0%|          | 0/1298 [00:00<?, ?it/s]

Epoch 16/30 |  val  | Loss: 0.4495 | Acc: 0.8172
LR: 0.001


  0%|          | 0/11674 [00:00<?, ?it/s]

Epoch 17/30 | train | Loss: 0.4537 | Acc: 0.8162


  0%|          | 0/1298 [00:00<?, ?it/s]

Epoch 17/30 |  val  | Loss: 0.4413 | Acc: 0.8215
LR: 0.001


  0%|          | 0/11674 [00:00<?, ?it/s]

Epoch 18/30 | train | Loss: 0.4521 | Acc: 0.8171


  0%|          | 0/1298 [00:00<?, ?it/s]

Epoch 18/30 |  val  | Loss: 0.4463 | Acc: 0.8193
LR: 0.001


  0%|          | 0/11674 [00:00<?, ?it/s]

Epoch 19/30 | train | Loss: 0.4505 | Acc: 0.8178


  0%|          | 0/1298 [00:00<?, ?it/s]

Epoch 19/30 |  val  | Loss: 0.4415 | Acc: 0.8215
LR: 0.001


  0%|          | 0/11674 [00:00<?, ?it/s]

Epoch 20/30 | train | Loss: 0.4494 | Acc: 0.8183


  0%|          | 0/1298 [00:00<?, ?it/s]

Epoch 20/30 |  val  | Loss: 0.4412 | Acc: 0.8227
LR: 0.001


  0%|          | 0/11674 [00:00<?, ?it/s]

Epoch 21/30 | train | Loss: 0.4485 | Acc: 0.8184


  0%|          | 0/1298 [00:00<?, ?it/s]

Epoch 21/30 |  val  | Loss: 0.4435 | Acc: 0.8204
LR: 0.0001


  0%|          | 0/11674 [00:00<?, ?it/s]

Epoch 22/30 | train | Loss: 0.4185 | Acc: 0.8305


  0%|          | 0/1298 [00:00<?, ?it/s]

Epoch 22/30 |  val  | Loss: 0.4178 | Acc: 0.8309
LR: 0.0001


  0%|          | 0/11674 [00:00<?, ?it/s]

Epoch 23/30 | train | Loss: 0.4116 | Acc: 0.8332


  0%|          | 0/1298 [00:00<?, ?it/s]

Epoch 23/30 |  val  | Loss: 0.4143 | Acc: 0.8323
LR: 0.0001


  0%|          | 0/11674 [00:00<?, ?it/s]

Epoch 24/30 | train | Loss: 0.4082 | Acc: 0.8343


  0%|          | 0/1298 [00:00<?, ?it/s]

Epoch 24/30 |  val  | Loss: 0.4113 | Acc: 0.8335
LR: 0.0001


  0%|          | 0/11674 [00:00<?, ?it/s]

Epoch 25/30 | train | Loss: 0.4060 | Acc: 0.8353


  0%|          | 0/1298 [00:00<?, ?it/s]

Epoch 25/30 |  val  | Loss: 0.4120 | Acc: 0.8336
LR: 0.0001


  0%|          | 0/11674 [00:00<?, ?it/s]

Epoch 26/30 | train | Loss: 0.4043 | Acc: 0.8359


  0%|          | 0/1298 [00:00<?, ?it/s]

Epoch 26/30 |  val  | Loss: 0.4113 | Acc: 0.8332
LR: 0.0001


  0%|          | 0/11674 [00:00<?, ?it/s]

Epoch 27/30 | train | Loss: 0.4030 | Acc: 0.8367


  0%|          | 0/1298 [00:00<?, ?it/s]

Epoch 27/30 |  val  | Loss: 0.4124 | Acc: 0.8342
LR: 0.0001


  0%|          | 0/11674 [00:00<?, ?it/s]

Epoch 28/30 | train | Loss: 0.4022 | Acc: 0.8369


  0%|          | 0/1298 [00:00<?, ?it/s]

Epoch 28/30 |  val  | Loss: 0.4106 | Acc: 0.8342
LR: 0.0001


  0%|          | 0/11674 [00:00<?, ?it/s]

Epoch 29/30 | train | Loss: 0.4007 | Acc: 0.8374


  0%|          | 0/1298 [00:00<?, ?it/s]

Epoch 29/30 |  val  | Loss: 0.4090 | Acc: 0.8350
LR: 0.0001


  0%|          | 0/11674 [00:00<?, ?it/s]

In [22]:
# Epoch 1/21 | train | Loss: 0.7122 | Acc: 0.7126
# Epoch 1/21 |  val  | Loss: 0.5889 | Acc: 0.7635
# Epoch 2/21 | train | Loss: 0.6028 | Acc: 0.7573
# Epoch 2/21 |  val  | Loss: 0.5381 | Acc: 0.7830
# Epoch 3/21 | train | Loss: 0.5726 | Acc: 0.7694
# Epoch 3/21 |  val  | Loss: 0.5305 | Acc: 0.7860
# Epoch 4/21 | train | Loss: 0.5554 | Acc: 0.7760
# Epoch 4/21 |  val  | Loss: 0.5152 | Acc: 0.7925
# Epoch 5/21 | train | Loss: 0.5441 | Acc: 0.7804
# Epoch 5/21 |  val  | Loss: 0.5014 | Acc: 0.7973
# Epoch 6/21 | train | Loss: 0.5355 | Acc: 0.7840
# Epoch 6/21 |  val  | Loss: 0.4989 | Acc: 0.7992
# Epoch 7/21 | train | Loss: 0.5286 | Acc: 0.7867
# Epoch 7/21 |  val  | Loss: 0.5029 | Acc: 0.7970
# Epoch 8/21 | train | Loss: 0.5232 | Acc: 0.7888
# Epoch 8/21 |  val  | Loss: 0.4933 | Acc: 0.8007
# Epoch 9/21 | train | Loss: 0.5186 | Acc: 0.7905
# Epoch 9/21 |  val  | Loss: 0.4909 | Acc: 0.8010
# Epoch 10/21 | train | Loss: 0.5153 | Acc: 0.7920
# Epoch 10/21 |  val  | Loss: 0.4833 | Acc: 0.8044
# Epoch 11/21 | train | Loss: 0.5125 | Acc: 0.7932
# Epoch 11/21 |  val  | Loss: 0.4838 | Acc: 0.8047
# Epoch 12/21 | train | Loss: 0.5097 | Acc: 0.7940
# Epoch 12/21 |  val  | Loss: 0.4815 | Acc: 0.8064
# Epoch 13/21 | train | Loss: 0.5074 | Acc: 0.7948
# Epoch 13/21 |  val  | Loss: 0.4794 | Acc: 0.8070
# Epoch 14/21 | train | Loss: 0.5057 | Acc: 0.7957
# Epoch 14/21 |  val  | Loss: 0.4782 | Acc: 0.8071
# Epoch 15/21 | train | Loss: 0.5041 | Acc: 0.7962
# Epoch 15/21 |  val  | Loss: 0.4754 | Acc: 0.8077
# Epoch 16/21 | train | Loss: 0.5024 | Acc: 0.7971
# Epoch 16/21 |  val  | Loss: 0.4748 | Acc: 0.8081
# Epoch 17/21 | train | Loss: 0.5012 | Acc: 0.7977
# Epoch 17/21 |  val  | Loss: 0.4861 | Acc: 0.8045
# Epoch 18/21 | train | Loss: 0.4995 | Acc: 0.7982
# Epoch 18/21 |  val  | Loss: 0.4707 | Acc: 0.8111
# Epoch 19/21 | train | Loss: 0.4989 | Acc: 0.7983
# Epoch 19/21 |  val  | Loss: 0.4721 | Acc: 0.8090
# Epoch 20/21 | train | Loss: 0.4980 | Acc: 0.7987
# Epoch 20/21 |  val  | Loss: 0.4756 | Acc: 0.8077
# Epoch 21/21 | train | Loss: 0.4965 | Acc: 0.7992
# Epoch 21/21 |  val  | Loss: 0.4671 | Acc: 0.8115

In [23]:
%%time

# model_city = torch.jit.load('agent/model_city.pth')
# criterion_city = nn.CrossEntropyLoss()
# optimizer_city = torch.optim.AdamW(model_city.parameters(), lr=1e-3)
# scheduler_city = ReduceLROnPlateau(optimizer_city, 'min', patience=3, threshold=1e-3, min_lr=1e-5)

# train_model(model_city, dataloaders_city_dict, criterion_city, optimizer_city, scheduler_city, 
#             num_epochs=35, city=True)

  0%|          | 0/1392 [00:00<?, ?it/s]

Epoch 1/35 | train | Loss: 0.2636 | Acc: 0.8870


  0%|          | 0/155 [00:00<?, ?it/s]

Epoch 1/35 |  val  | Loss: 0.2170 | Acc: 0.9077
LR: 0.001


  0%|          | 0/1392 [00:00<?, ?it/s]

Epoch 2/35 | train | Loss: 0.2046 | Acc: 0.9139


  0%|          | 0/155 [00:00<?, ?it/s]

Epoch 2/35 |  val  | Loss: 0.2045 | Acc: 0.9123
LR: 0.001


  0%|          | 0/1392 [00:00<?, ?it/s]

Epoch 3/35 | train | Loss: 0.1873 | Acc: 0.9215


  0%|          | 0/155 [00:00<?, ?it/s]

Epoch 3/35 |  val  | Loss: 0.1946 | Acc: 0.9151
LR: 0.001


  0%|          | 0/1392 [00:00<?, ?it/s]

Epoch 4/35 | train | Loss: 0.1776 | Acc: 0.9256


  0%|          | 0/155 [00:00<?, ?it/s]

Epoch 4/35 |  val  | Loss: 0.1770 | Acc: 0.9229
LR: 0.001


  0%|          | 0/1392 [00:00<?, ?it/s]

Epoch 5/35 | train | Loss: 0.1682 | Acc: 0.9294


  0%|          | 0/155 [00:00<?, ?it/s]

Epoch 5/35 |  val  | Loss: 0.1754 | Acc: 0.9239
LR: 0.001


  0%|          | 0/1392 [00:00<?, ?it/s]

Epoch 6/35 | train | Loss: 0.1610 | Acc: 0.9325


  0%|          | 0/155 [00:00<?, ?it/s]

Epoch 6/35 |  val  | Loss: 0.1771 | Acc: 0.9240
LR: 0.001


  0%|          | 0/1392 [00:00<?, ?it/s]

Epoch 7/35 | train | Loss: 0.1558 | Acc: 0.9353


  0%|          | 0/155 [00:00<?, ?it/s]

Epoch 7/35 |  val  | Loss: 0.1663 | Acc: 0.9283
LR: 0.001


  0%|          | 0/1392 [00:00<?, ?it/s]

Epoch 8/35 | train | Loss: 0.1507 | Acc: 0.9373


  0%|          | 0/155 [00:00<?, ?it/s]

Epoch 8/35 |  val  | Loss: 0.1881 | Acc: 0.9175
LR: 0.001


  0%|          | 0/1392 [00:00<?, ?it/s]

Epoch 9/35 | train | Loss: 0.1450 | Acc: 0.9396


  0%|          | 0/155 [00:00<?, ?it/s]

Epoch 9/35 |  val  | Loss: 0.1582 | Acc: 0.9349
LR: 0.001


  0%|          | 0/1392 [00:00<?, ?it/s]

Epoch 10/35 | train | Loss: 0.1411 | Acc: 0.9413


  0%|          | 0/155 [00:00<?, ?it/s]

Epoch 10/35 |  val  | Loss: 0.1552 | Acc: 0.9382
LR: 0.001


  0%|          | 0/1392 [00:00<?, ?it/s]

Epoch 11/35 | train | Loss: 0.1379 | Acc: 0.9428


  0%|          | 0/155 [00:00<?, ?it/s]

Epoch 11/35 |  val  | Loss: 0.1565 | Acc: 0.9353
LR: 0.001


  0%|          | 0/1392 [00:00<?, ?it/s]

Epoch 12/35 | train | Loss: 0.1348 | Acc: 0.9441


  0%|          | 0/155 [00:00<?, ?it/s]

Epoch 12/35 |  val  | Loss: 0.1573 | Acc: 0.9349
LR: 0.001


  0%|          | 0/1392 [00:00<?, ?it/s]

Epoch 13/35 | train | Loss: 0.1315 | Acc: 0.9452


  0%|          | 0/155 [00:00<?, ?it/s]

Epoch 13/35 |  val  | Loss: 0.1505 | Acc: 0.9364
LR: 0.001


  0%|          | 0/1392 [00:00<?, ?it/s]

Epoch 14/35 | train | Loss: 0.1292 | Acc: 0.9462


  0%|          | 0/155 [00:00<?, ?it/s]

Epoch 14/35 |  val  | Loss: 0.1500 | Acc: 0.9366
LR: 0.001


  0%|          | 0/1392 [00:00<?, ?it/s]

Epoch 15/35 | train | Loss: 0.1269 | Acc: 0.9474


  0%|          | 0/155 [00:00<?, ?it/s]

Epoch 15/35 |  val  | Loss: 0.1497 | Acc: 0.9384
LR: 0.001


  0%|          | 0/1392 [00:00<?, ?it/s]

Epoch 16/35 | train | Loss: 0.1255 | Acc: 0.9482


  0%|          | 0/155 [00:00<?, ?it/s]

Epoch 16/35 |  val  | Loss: 0.1541 | Acc: 0.9333
LR: 0.001


  0%|          | 0/1392 [00:00<?, ?it/s]

Epoch 17/35 | train | Loss: 0.1230 | Acc: 0.9491


  0%|          | 0/155 [00:00<?, ?it/s]

Epoch 17/35 |  val  | Loss: 0.1455 | Acc: 0.9400
LR: 0.001


  0%|          | 0/1392 [00:00<?, ?it/s]

Epoch 18/35 | train | Loss: 0.1203 | Acc: 0.9501


  0%|          | 0/155 [00:00<?, ?it/s]

Epoch 18/35 |  val  | Loss: 0.1449 | Acc: 0.9412
LR: 0.001


  0%|          | 0/1392 [00:00<?, ?it/s]

Epoch 19/35 | train | Loss: 0.1183 | Acc: 0.9509


  0%|          | 0/155 [00:00<?, ?it/s]

Epoch 19/35 |  val  | Loss: 0.1457 | Acc: 0.9403
LR: 0.001


  0%|          | 0/1392 [00:00<?, ?it/s]

Epoch 20/35 | train | Loss: 0.1175 | Acc: 0.9515


  0%|          | 0/155 [00:00<?, ?it/s]

Epoch 20/35 |  val  | Loss: 0.1466 | Acc: 0.9403
LR: 0.001


  0%|          | 0/1392 [00:00<?, ?it/s]

Epoch 21/35 | train | Loss: 0.1163 | Acc: 0.9515


  0%|          | 0/155 [00:00<?, ?it/s]

Epoch 21/35 |  val  | Loss: 0.1508 | Acc: 0.9381
LR: 0.001


  0%|          | 0/1392 [00:00<?, ?it/s]

Epoch 22/35 | train | Loss: 0.1146 | Acc: 0.9526


  0%|          | 0/155 [00:00<?, ?it/s]

Epoch 22/35 |  val  | Loss: 0.1452 | Acc: 0.9403
LR: 0.001


  0%|          | 0/1392 [00:00<?, ?it/s]

Epoch 23/35 | train | Loss: 0.1123 | Acc: 0.9535


  0%|          | 0/155 [00:00<?, ?it/s]

Epoch 23/35 |  val  | Loss: 0.1492 | Acc: 0.9420
LR: 0.0001


  0%|          | 0/1392 [00:00<?, ?it/s]

Epoch 24/35 | train | Loss: 0.0994 | Acc: 0.9586


  0%|          | 0/155 [00:00<?, ?it/s]

Epoch 24/35 |  val  | Loss: 0.1380 | Acc: 0.9463
LR: 0.0001


  0%|          | 0/1392 [00:00<?, ?it/s]

Epoch 25/35 | train | Loss: 0.0967 | Acc: 0.9599


  0%|          | 0/155 [00:00<?, ?it/s]

Epoch 25/35 |  val  | Loss: 0.1441 | Acc: 0.9437
LR: 0.0001


  0%|          | 0/1392 [00:00<?, ?it/s]

Epoch 26/35 | train | Loss: 0.0965 | Acc: 0.9600


  0%|          | 0/155 [00:00<?, ?it/s]

Epoch 26/35 |  val  | Loss: 0.1418 | Acc: 0.9450
LR: 0.0001


  0%|          | 0/1392 [00:00<?, ?it/s]

Epoch 27/35 | train | Loss: 0.0944 | Acc: 0.9611


  0%|          | 0/155 [00:00<?, ?it/s]

Epoch 27/35 |  val  | Loss: 0.1423 | Acc: 0.9441
LR: 0.0001


  0%|          | 0/1392 [00:00<?, ?it/s]

Epoch 28/35 | train | Loss: 0.0939 | Acc: 0.9609


  0%|          | 0/155 [00:00<?, ?it/s]

Epoch 28/35 |  val  | Loss: 0.1445 | Acc: 0.9444
LR: 0.0001


  0%|          | 0/1392 [00:00<?, ?it/s]

Epoch 29/35 | train | Loss: 0.0934 | Acc: 0.9614


  0%|          | 0/155 [00:00<?, ?it/s]

Epoch 29/35 |  val  | Loss: 0.1431 | Acc: 0.9446
LR: 1e-05


  0%|          | 0/1392 [00:00<?, ?it/s]

Epoch 30/35 | train | Loss: 0.0919 | Acc: 0.9620


  0%|          | 0/155 [00:00<?, ?it/s]

Epoch 30/35 |  val  | Loss: 0.1436 | Acc: 0.9436
LR: 1e-05


  0%|          | 0/1392 [00:00<?, ?it/s]

Epoch 31/35 | train | Loss: 0.0917 | Acc: 0.9621


  0%|          | 0/155 [00:00<?, ?it/s]

Epoch 31/35 |  val  | Loss: 0.1443 | Acc: 0.9431
LR: 1e-05


  0%|          | 0/1392 [00:00<?, ?it/s]

Epoch 32/35 | train | Loss: 0.0915 | Acc: 0.9623


  0%|          | 0/155 [00:00<?, ?it/s]

Epoch 32/35 |  val  | Loss: 0.1445 | Acc: 0.9449
LR: 1e-05


  0%|          | 0/1392 [00:00<?, ?it/s]

Epoch 33/35 | train | Loss: 0.0916 | Acc: 0.9621


  0%|          | 0/155 [00:00<?, ?it/s]

Epoch 33/35 |  val  | Loss: 0.1476 | Acc: 0.9435
LR: 1e-05


  0%|          | 0/1392 [00:00<?, ?it/s]

Epoch 34/35 | train | Loss: 0.0916 | Acc: 0.9618


  0%|          | 0/155 [00:00<?, ?it/s]

Epoch 34/35 |  val  | Loss: 0.1411 | Acc: 0.9450
LR: 1.0000000000000002e-06


  0%|          | 0/1392 [00:00<?, ?it/s]

Epoch 35/35 | train | Loss: 0.0912 | Acc: 0.9621


  0%|          | 0/155 [00:00<?, ?it/s]

Epoch 35/35 |  val  | Loss: 0.1420 | Acc: 0.9440
LR: 1.0000000000000002e-06
CPU times: user 1h 33min 25s, sys: 2min 49s, total: 1h 36min 15s
Wall time: 1h 39min 8s


In [24]:
# Epoch 1/8 | train | Loss: 0.2698 | Acc: 0.8844
# Epoch 1/8 |  val  | Loss: 0.2208 | Acc: 0.9044
# Epoch 2/8 | train | Loss: 0.2080 | Acc: 0.9113
# Epoch 2/8 |  val  | Loss: 0.2100 | Acc: 0.9125
# Epoch 3/8 | train | Loss: 0.1920 | Acc: 0.9188
# Epoch 3/8 |  val  | Loss: 0.1863 | Acc: 0.9228
# Epoch 4/8 | train | Loss: 0.1817 | Acc: 0.9231
# Epoch 4/8 |  val  | Loss: 0.2121 | Acc: 0.9164
# Epoch 5/8 | train | Loss: 0.1733 | Acc: 0.9272
# Epoch 5/8 |  val  | Loss: 0.1744 | Acc: 0.9258
# Epoch 6/8 | train | Loss: 0.1672 | Acc: 0.9295
# Epoch 6/8 |  val  | Loss: 0.1684 | Acc: 0.9307
# Epoch 7/8 | train | Loss: 0.1598 | Acc: 0.9329
# Epoch 7/8 |  val  | Loss: 0.1611 | Acc: 0.9340
# Epoch 8/8 | train | Loss: 0.1549 | Acc: 0.9348
# Epoch 8/8 |  val  | Loss: 0.1606 | Acc: 0.9349

In [25]:
# !tensorboard --logdir runs

# Submission

In [26]:
%%writefile agent/agent.py
import os
import json
import threading
import numpy as np
import torch
from lux.game import Game

path = '/kaggle_simulations/agent' if os.path.exists('/kaggle_simulations') else '.' # change to 'agent' for tests
model = torch.jit.load(f'{path}/model.pth')
model.eval()
model_city = torch.jit.load(f'{path}/model_city.pth')
model_city.eval()

def manhattan_distance(x1, y1, x2, y2):
    return (abs(x2-x1) + abs(y2-y1))

# make list of users and their current and previous coords and cooldowns 
# to write them in NN training subset
def find_units_from_previous_obs(obs, x_shift, y_shift):
    # at first fill the unit dict with units from previous observation
    updates = obs['updates']
    updates_lag_1 = obs['updates_lag_1']
    updates_lag_2 = obs['updates_lag_2']
    updates_lag_3 = obs['updates_lag_3']
    updates_lag_4 = obs['updates_lag_4']
    updates_lag_5 = obs['updates_lag_5']
    updates_lag_6 = obs['updates_lag_6']
    updates_lag_7 = obs['updates_lag_7']
    prev_updates = [updates, updates_lag_1, updates_lag_2, updates_lag_3, 
                    updates_lag_4, updates_lag_5, updates_lag_6, updates_lag_7]
    prev_units = list()
    for prev_update in prev_updates:
        units = dict()
        if prev_update:
            for update in prev_update:
                strs = update.split(' ')
                input_identifier = strs[0]
                # if we found observation for user
                if input_identifier == 'u':
                    unit_id = strs[3]
                    x = int(strs[4]) + x_shift
                    y = int(strs[5]) + y_shift
                    cooldown = float(strs[6])
                    units[unit_id] = [x, y, cooldown]
            prev_units.append(units)
        else:
            break
    return prev_units

def find_prev_coords(prev_units, unit_id):
    # find previous coordinates of unit by ananlyzing its cooldown
    x, y = [None, None], [None, None]
    for i in range(len(prev_units)-1):
        if unit_id in prev_units[i] and unit_id in prev_units[i+1]:
            cooldown = prev_units[i][unit_id][0]
            prev_x = prev_units[i+1][unit_id][0]
            prev_y = prev_units[i+1][unit_id][1]
            prev_cooldown = prev_units[i+1][unit_id][2]
            if cooldown > 0 and prev_cooldown == 0:
                if x[1] and y[1]:
                    return x, y
                if not x[0] and not y[0]:
                    x[0], y[0] = prev_x,prev_y
                else:
                    x[1], y[1] = prev_x,prev_y
        else:
            break
    return x, y
            
# Input for Neural Network for workers
def make_input(obs, unit_id):
    width, height = obs['width'], obs['height']
    x_shift = (32 - width) // 2
    y_shift = (32 - height) // 2

    cities = {}
    
    b = np.zeros((29, 32, 32), dtype=np.float32)
    
    prev_units = find_units_from_previous_obs(obs, x_shift, y_shift)
    x_u, y_u = prev_units[0][unit_id][0], prev_units[0][unit_id][1]
    
    for update in obs['updates']:
        strs = update.split(' ')
        input_identifier = strs[0]
        my_rp = 0
        
        if input_identifier == 'u':
            x = int(strs[4]) + x_shift
            y = int(strs[5]) + y_shift
            wood = int(strs[7])
            coal = int(strs[8])
            uranium = int(strs[9])
            if strs[3] == unit_id: # 0:2
                # Position, Cargo and Previous Position
                b[:2, x, y] = (
                    1,
                    (wood + coal + uranium) / 100
                )
                prev_x, prev_y = find_prev_coords(prev_units, unit_id)
                if not prev_x[0] and not prev_y[0]:
                    prev_x[0], prev_y[0] = x, y
                    prev_x[1], prev_y[1] = x, y
                try:
                    b[2, prev_x[0], prev_y[0]] = 1
                    b[3, prev_x[1], prev_y[1]] = 1
                except IndexError:
                    b[2, x, y] = 1
                    b[3, x, y] = 1
            else:                  # 4:11
                # Units
                team = int(strs[2])
                cooldown = float(strs[6])
                idx = 4 + (team - obs['player']) % 2 * 4
                m_dist = manhattan_distance(x_u, y_u, x, y)
                b[idx:idx + 4, x, y] = (
                    1,
                    cooldown / 6,
                    (wood + coal + uranium) / 100,
                    m_dist/((width-1) + (height-1))
                )
        elif input_identifier == 'ct':  # 12:17
            # CityTiles
            team = int(strs[1])
            city_id = strs[2]
            x = int(strs[3]) + x_shift
            y = int(strs[4]) + y_shift
            idx = 12 + (team - obs['player']) % 2 * 3
            m_dist = manhattan_distance(x_u, y_u, x, y)
            b[idx:idx + 3, x, y] = (
                1,
                cities[city_id],
                m_dist/((width-1) + (height-1))
            )
        elif input_identifier == 'r':  # 18:21
            # Resources
            r_type = strs[1]
            x = int(strs[2]) + x_shift
            y = int(strs[3]) + y_shift
            amt = int(float(strs[4]))
            access_level = {'wood': 0, 'coal': 50, 'uranium': 200}[r_type]
            access = 0 if my_rp < access_level else 1
            b[{'wood': 18, 'coal': 19, 'uranium': 20}[r_type], x, y] = amt / 800
            b[21, x, y] = access
            b[22, x, y] = manhattan_distance(x_u, y_u, x, y)/((width-1) + (height-1))
        elif input_identifier == 'rp':  # 23:24
            # Research Points
            team = int(strs[1])
            rp = int(strs[2])
            my_rp = rp if team == obs['player'] else my_rp
            b[23 + (team - obs['player']) % 2, :] = min(rp, 200) / 200
        elif input_identifier == 'c':
            # Cities
            city_id = strs[2]
            fuel = float(strs[3])
            lightupkeep = float(strs[4])
            cities[city_id] = min(fuel / lightupkeep, 10) / 10
    # Day/Night Cycle
    b[25, :] = obs['step'] % 40 / 40
    # Turns
    b[26, :] = obs['step'] / 360
    # Day/Night Flag
    b[27, :] = 1 if obs['step'] % 40 < 30 else 0
    # Map Size
    b[28, x_shift:32 - x_shift, y_shift:32 - y_shift] = 1
        
    return b

# Input for Neural Network for cities
def make_city_input(obs, city_coord):
    width, height = obs['width'], obs['height']
    x_shift = (32 - width) // 2
    y_shift = (32 - height) // 2
    cities = {}
    
    b = np.zeros((21, 32, 32), dtype=np.float32)
    
    for update in obs['updates']:
        strs = update.split(' ')
        input_identifier = strs[0]
        
        if input_identifier == 'ct':
            # CityTiles
            city_id = strs[2]
            x = int(strs[3]) 
            y = int(strs[4])
            cooldown = float(strs[5])
            if x == int(city_coord[0]) and y == int(city_coord[1]):
                b[:2, x + x_shift, y + y_shift] = (
                    1,
                    cities[city_id]
                )
            else:
                team = int(strs[1])
                idx = 2 + (team - obs['player']) % 2 * 3
                b[idx:idx + 3, x + x_shift, y + y_shift] = (
                    1,
                    cooldown / 10,
                    cities[city_id]
                )
        elif input_identifier == 'u':
            team = int(strs[2])
            x = int(strs[4]) + x_shift
            y = int(strs[5]) + y_shift
            wood = int(strs[7])
            coal = int(strs[8])
            uranium = int(strs[9])
            idx = 8 + (team - obs['player']) % 2 * 2
            b[idx:idx + 2, x, y] = (
                1,
                (wood + coal + uranium) / 100
            )
        elif input_identifier == 'r':
            # Resources
            r_type = strs[1]
            x = int(strs[2]) + x_shift
            y = int(strs[3]) + y_shift
            amt = int(float(strs[4]))
            b[{'wood': 12, 'coal': 13, 'uranium': 14}[r_type], x, y] = amt / 800
        elif input_identifier == 'rp':
            # Research Points
            team = int(strs[1])
            rp = int(strs[2])
            b[15 + (team - obs['player']) % 2, :] = min(rp, 200) / 200
        elif input_identifier == 'c':
            # Cities
            city_id = strs[2]
            fuel = float(strs[3])
            lightupkeep = float(strs[4])
            cities[city_id] = min(fuel / lightupkeep, 10) / 10
    
    # Day/Night Cycle
    b[17, :] = obs['step'] % 40 / 40
    # Turns
    b[18, :] = obs['step'] / 360
    # Day/Night Flag
    b[19, :] = 1 if obs['step'] % 40 < 30 else 0
    # Map Size
    b[20, x_shift:32 - x_shift, y_shift:32 - y_shift] = 1

    return b

game_state = None
player = None


def get_game_state(observation):
    global game_state
    
    if observation["step"] == 0:
        game_state = Game()
        game_state._initialize(observation["updates"])
        game_state._update(observation["updates"][2:])
        game_state.id = observation["player"]
    else:
        game_state._update(observation["updates"])
    return game_state


# check if unit is in city
def in_city(pos):    
    try:
        city = game_state.map.get_cell_by_pos(pos).citytile
        return city is not None and city.team == game_state.id 
    except:
        return False

# check if city will survive the night
def city_will_survive(pos):    
    try:
        city_id = game_state.map.get_cell_by_pos(pos).citytile.cityid
    except:
        return False
    if city_id in player.cities:
        city = player.cities[city_id]
        if city.fuel > (city.get_light_upkeep()) * 10:
            return True
    return False
    
# check if unit has enough time and space to build a city
def build_city_is_possible(unit, pos):    
    global game_state
    global player

    if game_state.turn % 40 < 30:
        return True
    x, y = pos.x, pos.y
    for i, j in ((x-1, y), (x+1, y), (x, y-1), (x, y+1)):
        try:
            city_id = game_state.map.get_cell(i, j).citytile.cityid
        except:
            continue
        if city_id in player.cities:
            city = player.cities[city_id]
            if city.fuel > (city.get_light_upkeep() + 18) * 10:
                return True
    return False
    

def call_func(obj, method, args=[]):
    return getattr(obj, method)(*args)


# translate unit policy to action
unit_actions = [('move', 'n'), ('move', 's'), ('move', 'w'), ('move', 'e'), ('build_city',)]
def get_unit_action(policy, unit, dest):
    for label in np.argsort(policy)[::-1]:
        act = unit_actions[label]
        pos = unit.pos.translate(act[-1], 1) or unit.pos
        if label == 4 and not build_city_is_possible(unit, pos): # try to remove this
            return unit.move('c'), unit.pos
        if pos not in dest or in_city(pos):
            return call_func(unit, *act), pos      
    return unit.move('c'), unit.pos

# translate city policy to action
city_actions = [('build_worker',), ('research', )]
def get_city_action(policy, city_tile):
    global player
    global unit_count
    
    for label in np.argsort(policy)[::-1]:
        act = city_actions[label]
        # build unit only if their number less than number of cities
        if label == 0 and unit_count < player.city_tile_count and unit_count < 80:
            unit_count += 1
            res = call_func(city_tile, *act)
        elif label == 1 and not player.researched_uranium():
            player.research_points += 1
            res = call_func(city_tile, *act)
        else:
            res = None
        return res

# agent for making actions
def agent(observation, configuration):
    global game_state
    global player
    global actions
    global dest
    global unit_count
    global city_count
    
    game_state = get_game_state(observation)    
    player = game_state.players[observation.player]
    actions = [] 
    dest = []
    prev_obs = dict()
    
    with open(f'{path}/tmp.json') as json_file:
        try:
            prev_obs = json.load(json_file)
        except json.decoder.JSONDecodeError:
            prev_obs['updates_lag_1'] = None
            prev_obs['updates_lag_2'] = None
            prev_obs['updates_lag_3'] = None
            prev_obs['updates_lag_4'] = None
            prev_obs['updates_lag_5'] = None
            prev_obs['updates_lag_6'] = None
            prev_obs['updates_lag_7'] = None
            
    observation['updates_lag_7'] = prev_obs['updates_lag_7']
    observation['updates_lag_6'] = prev_obs['updates_lag_6']
    observation['updates_lag_5'] = prev_obs['updates_lag_5']
    observation['updates_lag_4'] = prev_obs['updates_lag_4']
    observation['updates_lag_3'] = prev_obs['updates_lag_3']
    observation['updates_lag_2'] = prev_obs['updates_lag_2']
    observation['updates_lag_1'] = prev_obs['updates_lag_1']
    
    prev_obs['updates_lag_7'] = prev_obs['updates_lag_6']
    prev_obs['updates_lag_6'] = prev_obs['updates_lag_5']
    prev_obs['updates_lag_5'] = prev_obs['updates_lag_4']
    prev_obs['updates_lag_4'] = prev_obs['updates_lag_3']
    prev_obs['updates_lag_3'] = prev_obs['updates_lag_2']
    prev_obs['updates_lag_2'] = prev_obs['updates_lag_1']
    prev_obs['updates_lag_1'] = observation['updates']
    
    if game_state.turn == 0 or game_state.turn == 359:
        open(f'{path}/tmp.json', 'w+').close()
    else:
        with open(f'{path}/tmp.json', 'w+') as json_file:
            json.dump(prev_obs, json_file)
    
    # Unit Actions
    def unit_actions(unit, player, game_state, model, observation):
        global actions
        global dest
        if unit.can_act():
            state = make_input(observation, unit.id)
            with torch.no_grad():
                p = model(torch.from_numpy(state).unsqueeze(0))

            policy = p.squeeze(0).numpy()

            action, pos = get_unit_action(policy, unit, dest)
            actions.append(action)
            dest.append(pos)
            
    for unit in player.units:
        unit_actions(unit, player, game_state, model, observation)
        
    # City Actions
    def city_actions(city_tile, game_state, model, observation):
        global actions
        global dest
        global unit_count
        if city_tile.can_act():
            # on the last step build as many workers as possible to win the game in case of tie
            if game_state.turn == 358:
                city_tile.build_worker()
            # if number of cities is too high, switch to simplify strategy to prevent too high lags
            elif player.city_tile_count > 80:
                if unit_count < 80: 
                    actions.append(city_tile.build_worker())
                    unit_count += 1
                elif not player.researched_uranium():
                    actions.append(city_tile.research())
                    player.research_points += 1
            # else follow NN strategy
            else:
                state = make_city_input(observation, [city_tile.pos.x, city_tile.pos.y])
                with torch.no_grad():
                    p = model_city(torch.from_numpy(state).unsqueeze(0))

                policy = p.squeeze(0).numpy()

                action = get_city_action(policy, city_tile)
                if action:
                    actions.append(action)    

    unit_count = len(player.units)
    for city in player.cities.values():
        for city_tile in city.citytiles:
            city_actions(city_tile, game_state, model, observation)
            
    
    return actions

Overwriting agent/agent.py


Submit predictions

In [31]:
!cd agent && tar -czf submission.tar.gz lux agent.py main.py model.pth model_city.pth tmp.json

Test agents on 12x12 field

In [36]:
# from kaggle_environments import make

# env = make("lux_ai_2021", configuration={"width": 12, "height": 12, "loglevel": 2, "annotations": True}, debug=False)

# # first agent is yellow
# # second agent is blue
# steps = env.run(['agent/agent.py', 'agent.py'])

# env.render(mode="ipython", width=1200, height=800)

Test agent on 16x16 field

In [38]:
# env = make("lux_ai_2021", configuration={"width": 16, "height": 16, "loglevel": 2, "annotations": True}, debug=False)

# # first agent is yellow
# # second agent is blue
# steps = env.run(['agent/agent.py', 'agent.py'])

# env.render(mode="ipython", width=1200, height=800)

Test agent on 24x24 field

In [42]:
# env = make("lux_ai_2021", configuration={"width": 24, "height": 24, "loglevel": 2, "annotations": True}, debug=False)

# # first agent is yellow
# # second agent is blue
# steps = env.run(['agent/agent.py', 'agent.py'])

# env.render(mode="ipython", width=1200, height=800)

Test agents on 32x32 field

In [43]:
# env = make("lux_ai_2021", configuration={"width": 32, "height": 32, "loglevel": 2, "annotations": True}, debug=False)

# # first agent is yellow
# # second agent is blue
# steps = env.run(['agent/agent.py', 'agent.py'])

# env.render(mode="ipython", width=1200, height=800)

### Optimize NN parameters with Optuna

In [32]:
# def objective(trial):

#     num_epochs = 10
    
#     # model for unit actions
#     model = LuxNet()
#     train, val = train_test_split(samples, test_size=0.1, random_state=42, stratify=labels)
#     batch_size = 64

#     train_loader = DataLoader(
#         LuxDataset(obses, train), 
#         batch_size=batch_size, 
#         shuffle=True, 
#         num_workers=2
#     )
#     val_loader = DataLoader(
#         LuxDataset(obses, val), 
#         batch_size=batch_size, 
#         shuffle=False, 
#         num_workers=2
#     )
#     dataloaders_dict = {"train": train_loader, "val": val_loader}

#     # Generate the optimizers.
#     criterion = nn.CrossEntropyLoss()
#     optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "AdamW", "RMSprop", "SGD"])
#     lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
#     optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)

#     for epoch in range(num_epochs):
#         model.cuda()
        
#         for phase in ['train', 'val']:
#             if phase == 'train':
#                 model.train()
#             else:
#                 model.eval()
                
#             epoch_loss = 0.0
#             epoch_acc = 0
            
#             dataloader = dataloaders_dict[phase]
#             for item in dataloader:
#                 states = item[0].cuda().float()
#                 actions = item[1].cuda().long()

#                 optimizer.zero_grad()
                
#                 with torch.set_grad_enabled(phase == 'train'):
#                     policy = model(states)
#                     loss = criterion(policy, actions)
#                     _, preds = torch.max(policy, 1)

#                     if phase == 'train':
#                         loss.backward()
#                         optimizer.step()

#                     epoch_loss += loss.item() * len(policy)
#                     epoch_acc += torch.sum(preds == actions.data)

#             data_size = len(dataloader.dataset)
#             epoch_loss = epoch_loss / data_size
#             epoch_acc = epoch_acc.double() / data_size

#         trial.report(epoch_acc, epoch)

#         # Handle pruning based on the intermediate value.
#         if trial.should_prune():
#             raise optuna.exceptions.TrialPruned()

#     return epoch_acc


# study = optuna.create_study(direction="maximize")
# study.optimize(objective, n_trials=500, timeout=10*3600)

# pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
# complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])

# print("Study statistics: ")
# print("  Number of finished trials: ", len(study.trials))
# print("  Number of pruned trials: ", len(pruned_trials))
# print("  Number of complete trials: ", len(complete_trials))

# print("Best trial:")
# trial = study.best_trial

# print("  Value: ", trial.value)

# print("  Params: ")
# for key, value in trial.params.items():
#     print("    {}: {}".format(key, value))

# NNs ensemble

In [33]:
import os
import numpy as np
import torch
from lux.game import Game

path = '/kaggle_simulations/agent' if os.path.exists('/kaggle_simulations') else '.' # change to 'agent' for tests
# unit NNs
model_v2 = torch.jit.load(f'{path}/model_v2.pth')
model_v2.eval()
model_v4 = torch.jit.load(f'{path}/model_v4.pth')
model_v4.eval()
model_v5 = torch.jit.load(f'{path}/model_v5.pth')
model_v5.eval()
model_v11 = torch.jit.load(f'{path}/model_v11.pth')
model_v11.eval()
# city NNs
model_city_v2 = torch.jit.load(f'{path}/model_city_v2.pth')
model_city_v2.eval()
model_city_v4 = torch.jit.load(f'{path}/model_city_v4.pth')
model_city_v4.eval()
model_city_v5 = torch.jit.load(f'{path}/model_city_v5.pth')
model_city_v5.eval()
model_city_v11 = torch.jit.load(f'{path}/model_city_v11.pth')
model_city_v11.eval()

# Input for Neural Network for units
def make_input(obs, unit_id):
    width, height = obs['width'], obs['height']
    x_shift = (32 - width) // 2
    y_shift = (32 - height) // 2
    cities = {}
    
    b = np.zeros((20, 32, 32), dtype=np.float32)
    
    for update in obs['updates']:
        strs = update.split(' ')
        input_identifier = strs[0]
        
        if input_identifier == 'u':
            x = int(strs[4]) + x_shift
            y = int(strs[5]) + y_shift
            wood = int(strs[7])
            coal = int(strs[8])
            uranium = int(strs[9])
            if unit_id == strs[3]:
                # Position and Cargo
                b[:2, x, y] = (
                    1,
                    (wood + coal + uranium) / 100
                )
            else:
                # Units
                team = int(strs[2])
                cooldown = float(strs[6])
                idx = 2 + (team - obs['player']) % 2 * 3
                b[idx:idx + 3, x, y] = (
                    1,
                    cooldown / 6,
                    (wood + coal + uranium) / 100
                )
        elif input_identifier == 'ct':
            # CityTiles
            team = int(strs[1])
            city_id = strs[2]
            x = int(strs[3]) + x_shift
            y = int(strs[4]) + y_shift
            idx = 8 + (team - obs['player']) % 2 * 2
            b[idx:idx + 2, x, y] = (
                1,
                cities[city_id]
            )
        elif input_identifier == 'r':
            # Resources
            r_type = strs[1]
            x = int(strs[2]) + x_shift
            y = int(strs[3]) + y_shift
            amt = int(float(strs[4]))
            b[{'wood': 12, 'coal': 13, 'uranium': 14}[r_type], x, y] = amt / 800
        elif input_identifier == 'rp':
            # Research Points
            team = int(strs[1])
            rp = int(strs[2])
            b[15 + (team - obs['player']) % 2, :] = min(rp, 200) / 200
        elif input_identifier == 'c':
            # Cities
            city_id = strs[2]
            fuel = float(strs[3])
            lightupkeep = float(strs[4])
            cities[city_id] = min(fuel / lightupkeep, 10) / 10
    
    # Day/Night Cycle
    b[17, :] = obs['step'] % 40 / 40
    # Turns
    b[18, :] = obs['step'] / 360
    # Day/Night Flag
    b[19, :] = 1 if obs['step'] % 40 < 30 else 0
    # Map Size
    b[20, x_shift:32 - x_shift, y_shift:32 - y_shift] = 1
    return b


# Input for Neural Network for cities
def make_city_input(obs, city_coord):
    width, height = obs['width'], obs['height']
    x_shift = (32 - width) // 2
    y_shift = (32 - height) // 2
    cities = {}
    
    b = np.zeros((20, 32, 32), dtype=np.float32)
    
    for update in obs['updates']:
        strs = update.split(' ')
        input_identifier = strs[0]
        
        if input_identifier == 'ct':
            # CityTiles
            city_id = strs[2]
            x = int(strs[3]) 
            y = int(strs[4])
            cooldown = float(strs[5])
            if x == int(city_coord[0]) and y == int(city_coord[1]):
                b[:2, x + x_shift, y + y_shift] = (
                    1,
                    cities[city_id]
                )
            else:
                team = int(strs[1])
                idx = 2 + (team - obs['player']) % 2 * 3
                b[idx:idx + 3, x + x_shift, y + y_shift] = (
                    1,
                    cooldown / 10,
                    cities[city_id]
                )
        elif input_identifier == 'u':
            team = int(strs[2])
            x = int(strs[4]) + x_shift
            y = int(strs[5]) + y_shift
            wood = int(strs[7])
            coal = int(strs[8])
            uranium = int(strs[9])
            idx = 8 + (team - obs['player']) % 2 * 2
            b[idx:idx + 2, x, y] = (
                1,
                (wood + coal + uranium) / 100
            )
        elif input_identifier == 'r':
            # Resources
            r_type = strs[1]
            x = int(strs[2]) + x_shift
            y = int(strs[3]) + y_shift
            amt = int(float(strs[4]))
            b[{'wood': 12, 'coal': 13, 'uranium': 14}[r_type], x, y] = amt / 800
        elif input_identifier == 'rp':
            # Research Points
            team = int(strs[1])
            rp = int(strs[2])
            b[15 + (team - obs['player']) % 2, :] = min(rp, 200) / 200
        elif input_identifier == 'c':
            # Cities
            city_id = strs[2]
            fuel = float(strs[3])
            lightupkeep = float(strs[4])
            cities[city_id] = min(fuel / lightupkeep, 10) / 10
    
    # Day/Night Cycle
    b[17, :] = obs['step'] % 40 / 40
    # Turns
    b[18, :] = obs['step'] / 360
    # Map Size
    b[19, x_shift:32 - x_shift, y_shift:32 - y_shift] = 1

    return b

game_state = None
player = None


def get_game_state(observation):
    global game_state
    
    if observation["step"] == 0:
        game_state = Game()
        game_state._initialize(observation["updates"])
        game_state._update(observation["updates"][2:])
        game_state.id = observation["player"]
    else:
        game_state._update(observation["updates"])
    return game_state


def in_city(pos):    
    try:
        city = game_state.map.get_cell_by_pos(pos).citytile
        return city is not None and city.team == game_state.id
    except:
        return False
    
# check if unit has enough time and space to build a city
def build_city_is_possible(unit, pos):    
    global game_state
    global player

    if game_state.turn % 40 < 30:
        return True
    x, y = pos.x, pos.y
    for i, j in ((x-1, y), (x+1, y), (x, y-1), (x, y+1)):
        try:
            city_id = game_state.map.get_cell(i, j).citytile.cityid
        except:
            continue
        if city_id in player.cities:
            city = player.cities[city_id]
            if city.fuel > (city.get_light_upkeep() + 18) * 10:
                return True
    return False


def call_func(obj, method, args=[]):
    return getattr(obj, method)(*args)


# translate unit policy to action
unit_actions = [('move', 'n'), ('move', 's'), ('move', 'w'), ('move', 'e'), ('build_city',)]
def get_unit_action(policy, unit, dest):
    for label in np.argsort(policy)[::-1]:
        act = unit_actions[label]
        pos = unit.pos.translate(act[-1], 1) or unit.pos
        if label == 4 and not build_city_is_possible(unit, pos):
            return unit.move('c'), unit.pos
        if pos not in dest or in_city(pos):
            return call_func(unit, *act), pos 
            
    return unit.move('c'), unit.pos

# translate city policy to action
city_actions = [('build_worker',), ('research', )]
def get_city_action(policy, city_tile, unit_count):
    global player
    
    for label in np.argsort(policy)[::-1]:
        act = city_actions[label]
        if label == 0 and unit_count < player.city_tile_count:
            unit_count += 1
            res = call_func(city_tile, *act)
        elif label == 1 and not player.researched_uranium():
            player.research_points += 1
            res = call_func(city_tile, *act)
        else:
            res = None
        return res, unit_count

# agent for making actions
def agent(observation, configuration):
    global game_state
    global player
    
    game_state = get_game_state(observation)    
    player = game_state.players[observation.player]
    actions = []        

    # Unit Actions
    dest = []
    for unit in player.units:
        if unit.can_act() and (game_state.turn % 40 < 30 or not in_city(unit.pos)):
            state = make_input(observation, unit.id)
            with torch.no_grad():
                p_2 = model_v2(torch.from_numpy(state).unsqueeze(0))
                p_4 = model_v4(torch.from_numpy(state).unsqueeze(0))
                #p_5 = model_v5(torch.from_numpy(state).unsqueeze(0))
                p_11 = model_v11(torch.from_numpy(state).unsqueeze(0))

            policy_2 = p_2.squeeze(0).numpy()
            policy_4 = p_4.squeeze(0).numpy()
            #policy_5 = p_5.squeeze(0).numpy()
            policy_11 = p_11.squeeze(0).numpy()

            policy = [sum(x) for x in zip(*[policy_2, policy_4, policy_11])]

            action, pos = get_unit_action(policy, unit, dest)
            actions.append(action)
            dest.append(pos)

    # City Actions
    unit_count = len(player.units)
    for city in player.cities.values():
        for city_tile in city.citytiles:
            if city_tile.can_act():
                # at first game stages try to produce maximum amount of agents and research point
                if game_state.turn < 60:
                    if unit_count < player.city_tile_count: 
                        actions.append(city_tile.build_worker())
                        unit_count += 1
                    elif not player.researched_uranium():
                        actions.append(city_tile.research())
                        player.research_points += 1
                # then follow NN strategy
                else:
                    state = make_city_input(observation, [city_tile.pos.x, city_tile.pos.y])
                    with torch.no_grad():
                        p_2 = model_city_v2(torch.from_numpy(state).unsqueeze(0))
                        p_4 = model_city_v4(torch.from_numpy(state).unsqueeze(0))
                        #p_5 = model_city_v5(torch.from_numpy(state).unsqueeze(0))
                        p_11 = model_city_v11(torch.from_numpy(state).unsqueeze(0))

                    policy_2 = p_2.squeeze(0).numpy()
                    policy_4 = p_4.squeeze(0).numpy()
                    #policy_5 = p_5.squeeze(0).numpy()
                    policy_11 = p_11.squeeze(0).numpy()

                    policy = [sum(x) for x in zip(*[policy_2, policy_4, policy_11])]

                    action, unit_count = get_city_action(policy, city_tile, unit_count)
                    if action:
                        actions.append(action)
    
    return actions

ValueError: The provided filename ./model_v2.pth does not exist

# Futher Ideas

Hyperparameters
- optimize hyperparameters (number of features and layers, regularization, dropout probability, learning rate) with Optuna

In [None]:
# dict for apropriate change of moving action during rotation
# 1 - rotation on 90, 2 - rotation on 180, 3 - rotation on 270
rot_dict = {0: {0:0, 1:1, 2:2, 3:3, 4:4}, # N -> E, S -> W, W -> N, E -> S
            1: {0:3, 1:2, 2:0, 3:1, 4:4}, # N -> E, S -> W, W -> N, E -> S
            2: {0:1, 1:0, 2:3, 3:2, 4:4}, # N -> S, S -> N, W -> E, E -> W
            3: {0:2, 1:3, 2:1, 3:0, 4:4}  # N -> W, S -> E, W -> S, E -> N
           }