In [1]:
# Print the version of the Node.js runtime available in this environment.
!node --version
# Install (or upgrade to the latest) kaggle-environments package.
!pip install kaggle-environments -U
# Copy everything from the lux-ai-2021 input dataset into the working directory.
!cp -r ../input/lux-ai-2021/* .

v12.22.6


In [2]:
import numpy as np
import json
from pathlib import Path
import os
import random
from tqdm.notebook import tqdm
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from sklearn.model_selection import train_test_split

In [3]:
def seed_everything(seed_value):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's legacy global generator and
    PyTorch (CPU and, when available, all CUDA devices), exports
    ``PYTHONHASHSEED`` and configures cuDNN for reproducible execution.

    Args:
        seed_value: Integer seed applied to every generator.
    """
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    # Only affects interpreters started after this point (e.g. worker
    # processes); the running interpreter's hash seed is already fixed.
    os.environ['PYTHONHASHSEED'] = str(seed_value)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)

    # Reproducibility needs BOTH flags: `deterministic` restricts cuDNN to
    # deterministic kernels, while `benchmark` must be off so the same kernel
    # is selected on every run (autotuning may pick different ones).
    # Setting them is harmless when no GPU is present.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Fix the seed once, before any data splitting or model construction.
seed = 42
seed_everything(seed)

# Preprocessing

In [4]:
def to_label(action):
    """Translate a raw action string into a ``(unit_id, label)`` pair.

    The action is split on single spaces; the first token is the command and
    the second token is returned as the unit id.

    Labels:
        * ``'m <unit> n|s|w|e'`` -> 0, 1, 2, 3
        * ``'bcity <unit>'``     -> 4
        * a move to ``'c'`` or any other command -> ``None``

    Raises:
        IndexError: if the action has fewer tokens than required.
        KeyError: if a move names a direction other than c/n/s/w/e.
    """
    move_labels = {'c': None, 'n': 0, 's': 1, 'w': 2, 'e': 3}

    tokens = action.split(' ')
    command, unit_id = tokens[0], tokens[1]

    if command == 'bcity':
        return unit_id, 4
    if command == 'm':
        return unit_id, move_labels[tokens[2]]
    return unit_id, None


def depleted_resources(obs):
    """Return True when ``obs['updates']`` contains no resource (``'r'``) line.

    A line counts as a resource line when its first space-separated token is
    exactly ``'r'``.
    """
    has_resource = any(
        update.split(' ')[0] == 'r' for update in obs['updates']
    )
    return not has_resource


def create_dataset_from_json(episode_dir, team_name='Toad Brigade'):
    """Build imitation-learning samples from episode replay JSON files.

    Every ``*.json`` file in ``episode_dir`` whose name does not contain
    ``'output'`` is read. An episode is kept only when the player with the
    highest reward (missing rewards count as 0) belongs to ``team_name``.
    For each step where that player is ``ACTIVE`` the observation stored in
    player 0's entry of the step is paired with the actions the player
    submitted in the following step. Processing of an episode stops at the
    first observation that contains no resource update.

    Args:
        episode_dir: Directory holding the episode JSON files.
        team_name: Only episodes won by this team are used.

    Returns:
        A tuple ``(obses, samples)`` where ``obses`` maps
        ``'<EpisodeId>_<step index>'`` to a reduced observation dict (keys
        ``step``, ``updates``, ``player``, ``width``, ``height``) and
        ``samples`` is a list of ``(obs_id, unit_id, label)`` tuples with the
        label produced by ``to_label``.
    """
    kept_keys = ('step', 'updates', 'player', 'width', 'height')
    obses = {}
    samples = []

    # glob() yields files in filesystem-dependent order; sorting keeps the
    # sample order (and therefore any seeded train/val split) reproducible.
    episodes = sorted(
        path for path in Path(episode_dir).glob('*.json')
        if 'output' not in path.name
    )
    for filepath in tqdm(episodes):
        with open(filepath) as f:
            json_load = json.load(f)

        ep_id = json_load['info']['EpisodeId']
        # Index of the winning player; stored as a plain int so the
        # observation dicts contain no NumPy scalars.
        index = int(np.argmax([r or 0 for r in json_load['rewards']]))
        if json_load['info']['TeamNames'][index] != team_name:
            continue

        steps = json_load['steps']
        for i in range(len(steps) - 1):
            if steps[i][index]['status'] != 'ACTIVE':
                continue

            # Actions recorded at step i+1 are the response to observation i.
            actions = steps[i + 1][index]['action']
            obs = steps[i][0]['observation']

            if depleted_resources(obs):
                break

            obs['player'] = index
            obs = {k: v for k, v in obs.items() if k in kept_keys}
            obs_id = f'{ep_id}_{i}'
            obses[obs_id] = obs

            for action in actions:
                unit_id, label = to_label(action)
                if label is not None:
                    samples.append((obs_id, unit_id, label))

    return obses, samples

In [5]:
# Directory that holds the episode replay JSON files.
episode_dir = '../input/lux-ai-episodes'
# obses: obs_id -> reduced observation; samples: (obs_id, unit_id, label) tuples.
obses, samples = create_dataset_from_json(episode_dir)
print('obses:', len(obses), 'samples:', len(samples))

  0%|          | 0/125 [00:00<?, ?it/s]

obses: 32575 samples: 109319


In [6]:
# Class balance check: count how many samples carry each action label.
labels = [sample[-1] for sample in samples]
# Display names, indexed by the integer label produced by to_label.
actions = ['north', 'south', 'west', 'east', 'bcity']
for value, count in zip(*np.unique(labels, return_counts=True)):
    print(f'{actions[value]:^5}: {count:>3}')

north: 25714
south: 22410
west : 25145
east : 25029
bcity: 11021


# Training

In [7]:
# Input for Neural Network
def make_input(obs, unit_id):
    """Encode an observation as a ``(20, 32, 32)`` float32 array for one unit.

    Map coordinates are shifted by ``(32 - size) // 2`` so the map sits in the
    middle of the 32x32 canvas. Planes are indexed ``[channel, x, y]``.

    Channel layout:
        0      position of the unit identified by ``unit_id``
        1      that unit's cargo, (wood + coal + uranium) / 100
        2-4    other units of ``obs['player']``: presence, cooldown / 6, cargo / 100
        5-7    units of the other team, same three values
        8-9    city tiles of ``obs['player']``: presence, min(fuel / upkeep, 10) / 10
        10-11  city tiles of the other team, same two values
        12-14  wood, coal, uranium amount / 800
        15-16  research points of own / other team, min(rp, 200) / 200
        17     step % 40 / 40
        18     step / 360
        19     mask that is 1 inside the map area

    A ``'ct'`` line looks up the value stored by an earlier ``'c'`` line for
    the same city id, so the city line must come first in ``obs['updates']``.

    Args:
        obs: Mapping with ``width``, ``height``, ``step``, ``player`` and
            ``updates`` (list of space-separated update strings).
        unit_id: Id of the unit the encoding is centred on.

    Returns:
        ``numpy.ndarray`` of shape ``(20, 32, 32)`` and dtype float32.
    """
    resource_channel = {'wood': 12, 'coal': 13, 'uranium': 14}

    dx = (32 - obs['width']) // 2
    dy = (32 - obs['height']) // 2
    city_fuel = {}

    planes = np.zeros((20, 32, 32), dtype=np.float32)

    for line in obs['updates']:
        tok = line.split(' ')
        kind = tok[0]

        if kind == 'c':
            # City summary: remember its capped fuel-to-upkeep ratio.
            city_id = tok[2]
            fuel = float(tok[3])
            lightupkeep = float(tok[4])
            city_fuel[city_id] = min(fuel / lightupkeep, 10) / 10
        elif kind == 'ct':
            # City tile: presence plus the ratio of the city it belongs to.
            team = int(tok[1])
            city_id = tok[2]
            px = int(tok[3]) + dx
            py = int(tok[4]) + dy
            base = 8 + (team - obs['player']) % 2 * 2
            planes[base:base + 2, px, py] = (1, city_fuel[city_id])
        elif kind == 'r':
            # Resource tile.
            r_type = tok[1]
            px = int(tok[2]) + dx
            py = int(tok[3]) + dy
            amount = int(float(tok[4]))
            planes[resource_channel[r_type], px, py] = amount / 800
        elif kind == 'rp':
            # Research points fill a whole plane per team.
            team = int(tok[1])
            points = int(tok[2])
            planes[15 + (team - obs['player']) % 2, :] = min(points, 200) / 200
        elif kind == 'u':
            px = int(tok[4]) + dx
            py = int(tok[5]) + dy
            wood = int(tok[7])
            coal = int(tok[8])
            uranium = int(tok[9])
            cargo = (wood + coal + uranium) / 100
            if tok[3] == unit_id:
                # The unit we are choosing an action for.
                planes[0, px, py] = 1
                planes[1, px, py] = cargo
            else:
                # Any other unit, split by team relative to obs['player'].
                team = int(tok[2])
                cooldown = float(tok[6])
                base = 2 + (team - obs['player']) % 2 * 3
                planes[base:base + 3, px, py] = (1, cooldown / 6, cargo)

    step = obs['step']
    planes[17, :] = step % 40 / 40      # position inside the 40-step cycle
    planes[18, :] = step / 360          # overall progress
    planes[19, dx:32 - dx, dy:32 - dy] = 1  # map area

    return planes



class LuxDataset(Dataset):
    """Map-style dataset that encodes one (observation, unit) pair per item.

    Args:
        obses: Mapping from observation id to observation dict.
        samples: Sequence of ``(obs_id, unit_id, action)`` tuples.
    """

    def __init__(self, obses, samples):
        self.obses = obses
        self.samples = samples

    def __len__(self):
        """Number of samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(state, action)``; the state is built by ``make_input``."""
        obs_id, unit_id, action = self.samples[idx]
        return make_input(self.obses[obs_id], unit_id), action

In [8]:
# Neural Network for Lux AI
import torch.nn as nn
import torch


class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions, each followed by batch norm.

    Args:
        in_channel: Channels of the incoming tensor.
        out_channel: Channels produced by both convolutions.
        stride: Stride of the first convolution.
        downsample: Optional module applied to the input to build the
            shortcut branch; the raw input is used when it is ``None``.
        **kwargs: Accepted and ignored, so the block can be constructed with
            the same keyword arguments as ``Bottleneck``.
    """

    expansion = 1

    def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
        super().__init__()
        conv_opts = dict(kernel_size=3, padding=1, bias=False)
        self.conv1 = nn.Conv2d(in_channel, out_channel, stride=stride, **conv_opts)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channel, out_channel, stride=1, **conv_opts)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut)``."""
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        out += shortcut
        return self.relu(out)


class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 squeeze, 3x3 (optionally grouped), 1x1 expand.

    The output has ``out_channel * expansion`` channels. The width of the
    inner convolutions is ``int(out_channel * width_per_group / 64) * groups``.

    Args:
        in_channel: Channels of the incoming tensor.
        out_channel: Base channel count of the block.
        stride: Stride of the 3x3 convolution.
        downsample: Optional module applied to the input to build the
            shortcut branch; the raw input is used when it is ``None``.
        groups: Number of groups of the 3x3 convolution.
        width_per_group: Width multiplier base (64 leaves the width unchanged).
    """

    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1, downsample=None,
                 groups=1, width_per_group=64):
        super().__init__()

        width = int(out_channel * (width_per_group / 64.)) * groups
        expanded = out_channel * self.expansion

        # 1x1: squeeze channels
        self.conv1 = nn.Conv2d(in_channel, width, kernel_size=1, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width)
        # 3x3: spatial mixing, carries the stride and the grouping
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride,
                               padding=1, groups=groups, bias=False)
        self.bn2 = nn.BatchNorm2d(width)
        # 1x1: expand channels again
        self.conv3 = nn.Conv2d(width, expanded, kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(expanded)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        """Run the three conv/bn stages, add the shortcut and apply ReLU."""
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        out += shortcut
        return self.relu(out)


class ResNet(nn.Module):
    """ResNet-style classifier fed with 20-channel 32x32 inputs.

    A leading 3x3 convolution lifts the 20 input channels to 147 channels;
    the result is then viewed as a ``(3, 224, 224)`` tensor
    (147 * 32 * 32 == 3 * 224 * 224) and passed through a standard ResNet
    stem and four residual stages.

    Args:
        block: Residual block class exposing an ``expansion`` attribute.
        blocks_num: Sequence with the number of blocks in each of the four stages.
        num_classes: Size of the final linear layer.
        include_top: When False, pooling and the linear layer are omitted and
            ``forward`` returns the feature map of the last stage.
        groups: Forwarded to every block.
        width_per_group: Forwarded to every block.
    """

    def __init__(self,
                 block,
                 blocks_num,
                 num_classes=5,
                 include_top=True,
                 groups=1,
                 width_per_group=64):
        super().__init__()
        self.include_top = include_top
        self.in_channel = 64

        self.groups = groups
        self.width_per_group = width_per_group

        # Adapter from the 20-plane board encoding to the image-shaped stem input.
        self.conv = nn.Conv2d(20, 147, kernel_size=3, stride=1, padding=1, bias=False)

        # Standard ResNet stem.
        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Four residual stages; all but the first halve the resolution.
        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)

        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)
            self.fc = nn.Linear(512 * block.expansion, num_classes)

        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')

    def _make_layer(self, block, channel, block_num, stride=1):
        """Build one stage of ``block_num`` blocks and update ``self.in_channel``.

        The first block receives ``stride`` and, when the shape changes, a
        1x1 conv + batch norm shortcut; the remaining blocks keep the shape.
        """
        out_channel = channel * block.expansion
        shared = dict(groups=self.groups, width_per_group=self.width_per_group)

        downsample = None
        if stride != 1 or self.in_channel != out_channel:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, out_channel, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channel))

        stage = [block(self.in_channel, channel, downsample=downsample, stride=stride, **shared)]
        self.in_channel = out_channel
        stage.extend(block(self.in_channel, channel, **shared) for _ in range(1, block_num))

        return nn.Sequential(*stage)

    def forward(self, x):
        """Return class logits, or the last feature map when ``include_top`` is False."""
        x = self.conv(x)
        x = x.view(x.size(0), 3, 224, 224)
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        if self.include_top:
            x = self.fc(torch.flatten(self.avgpool(x), 1))

        return x


def resnet34(num_classes=5, include_top=True):
    """Build a ``ResNet`` with ``BasicBlock`` stages of depth 3, 4, 6, 3."""
    # https://download.pytorch.org/models/resnet34-333f7ec4.pth
    stage_depths = [3, 4, 6, 3]
    return ResNet(BasicBlock, stage_depths,
                  num_classes=num_classes,
                  include_top=include_top)


def resnet50(num_classes=5, include_top=True):
    """Build a ``ResNet`` with ``Bottleneck`` stages of depth 3, 4, 6, 3."""
    # https://download.pytorch.org/models/resnet50-19c8e357.pth
    stage_depths = [3, 4, 6, 3]
    return ResNet(Bottleneck, stage_depths,
                  num_classes=num_classes,
                  include_top=include_top)


def resnet101(num_classes=5, include_top=True):
    """Build a ``ResNet`` with ``Bottleneck`` stages of depth 3, 4, 23, 3."""
    # https://download.pytorch.org/models/resnet101-5d3b4d8f.pth
    stage_depths = [3, 4, 23, 3]
    return ResNet(Bottleneck, stage_depths,
                  num_classes=num_classes,
                  include_top=include_top)


def resnext50_32x4d(num_classes=1000, include_top=True):
    """Build a ``ResNet`` with grouped ``Bottleneck`` stages (32 groups, width 4).

    Stage depths are 3, 4, 6, 3. Note that ``num_classes`` defaults to 1000
    here, unlike the plain ResNet factories in this notebook.
    """
    # https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth
    cardinality = dict(groups=32, width_per_group=4)
    return ResNet(Bottleneck, [3, 4, 6, 3],
                  num_classes=num_classes,
                  include_top=include_top,
                  **cardinality)


def resnext101_32x8d(num_classes=1000, include_top=True):
    """Build a ``ResNet`` with grouped ``Bottleneck`` stages (32 groups, width 8).

    Stage depths are 3, 4, 23, 3. Note that ``num_classes`` defaults to 1000
    here, unlike the plain ResNet factories in this notebook.
    """
    # https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth
    cardinality = dict(groups=32, width_per_group=8)
    return ResNet(Bottleneck, [3, 4, 23, 3],
                  num_classes=num_classes,
                  include_top=include_top,
                  **cardinality)

In [9]:
def train_model(model, dataloaders_dict, criterion, optimizer, num_epochs):
    """Train ``model`` and save a traced checkpoint whenever validation improves.

    Each epoch runs a ``'train'`` phase followed by a ``'val'`` phase and
    prints the sample-weighted mean loss and the accuracy of each. When the
    validation accuracy exceeds the best value seen so far, the model is
    moved to the CPU, traced with a random ``(1, 20, 32, 32)`` input and
    written to ``model.pth`` in the current directory.

    Runs on ``cuda:0`` when CUDA is available and on the CPU otherwise.

    Args:
        model: Network mapping a ``(N, 20, 32, 32)`` float tensor to logits.
        dataloaders_dict: Mapping with ``'train'`` and ``'val'`` data loaders
            yielding ``(states, actions)`` batches.
        criterion: Loss called as ``criterion(logits, actions)``.
        optimizer: Optimizer holding ``model``'s parameters.
        num_epochs: Number of epochs to run.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    best_acc = 0.0

    for epoch in range(num_epochs):
        # Checkpointing below moves the model to the CPU for tracing, so it
        # has to be put back on the training device at the start of each epoch.
        model.to(device)

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            epoch_loss = 0.0
            epoch_correct = 0

            dataloader = dataloaders_dict[phase]
            for item in tqdm(dataloader, leave=False):
                # Move batches to the same device as the model instead of
                # assuming a GPU, so the CPU fallback actually works.
                states = item[0].to(device).float()
                actions = item[1].to(device).long()

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    policy = model(states)
                    loss = criterion(policy, actions)
                    _, preds = torch.max(policy, 1)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # Weight the batch-mean loss by the batch size so the epoch
                # figure is a true per-sample mean.
                epoch_loss += loss.item() * len(policy)
                epoch_correct += (preds == actions).sum().item()

            data_size = len(dataloader.dataset)
            epoch_loss = epoch_loss / data_size
            epoch_acc = epoch_correct / data_size

            print(f'Epoch {epoch + 1}/{num_epochs} | {phase:^5} | Loss: {epoch_loss:.4f} | Acc: {epoch_acc:.4f}')

        # 'val' is the last phase, so epoch_acc holds the validation accuracy
        # here and the model is still in eval mode for tracing.
        if epoch_acc > best_acc:
            traced = torch.jit.trace(model.cpu(), torch.rand(1, 20, 32, 32))
            traced.save('model.pth')
            best_acc = epoch_acc

In [10]:
# Network to train (5 output classes by default).
model = resnet50()
# Hold out 10% of the samples for validation, keeping the label proportions.
train, val = train_test_split(samples, test_size=0.1, random_state=42, stratify=labels)
batch_size = 64
# Training batches are reshuffled every epoch.
train_loader = DataLoader(
    LuxDataset(obses, train), 
    batch_size=batch_size, 
    shuffle=True, 
    num_workers=2
)
# Validation batches keep a fixed order.
val_loader = DataLoader(
    LuxDataset(obses, val), 
    batch_size=batch_size, 
    shuffle=False, 
    num_workers=2
)
# Keys match the phase names iterated by train_model.
dataloaders_dict = {"train": train_loader, "val": val_loader}
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

In [11]:
# Train for 40 epochs; train_model writes model.pth whenever validation accuracy improves.
train_model(model, dataloaders_dict, criterion, optimizer, num_epochs=40)

  0%|          | 0/1538 [00:00<?, ?it/s]

Epoch 1/40 | train | Loss: 1.5646 | Acc: 0.2464


  0%|          | 0/171 [00:00<?, ?it/s]

Epoch 1/40 |  val  | Loss: 1.5385 | Acc: 0.2605


  0%|          | 0/1538 [00:00<?, ?it/s]

Epoch 2/40 | train | Loss: 1.5369 | Acc: 0.2700


  0%|          | 0/171 [00:00<?, ?it/s]

Epoch 2/40 |  val  | Loss: 1.5219 | Acc: 0.2818


  0%|          | 0/1538 [00:00<?, ?it/s]

Epoch 3/40 | train | Loss: 1.5233 | Acc: 0.2866


  0%|          | 0/171 [00:00<?, ?it/s]

Epoch 3/40 |  val  | Loss: 1.5041 | Acc: 0.3031


  0%|          | 0/1538 [00:00<?, ?it/s]

Epoch 4/40 | train | Loss: 1.4518 | Acc: 0.3222


  0%|          | 0/171 [00:00<?, ?it/s]

Epoch 4/40 |  val  | Loss: 1.3852 | Acc: 0.3577


  0%|          | 0/1538 [00:00<?, ?it/s]

Epoch 5/40 | train | Loss: 1.3431 | Acc: 0.3841


  0%|          | 0/171 [00:00<?, ?it/s]

Epoch 5/40 |  val  | Loss: 1.2995 | Acc: 0.4160


  0%|          | 0/1538 [00:00<?, ?it/s]

Epoch 6/40 | train | Loss: 1.2528 | Acc: 0.4458


  0%|          | 0/171 [00:00<?, ?it/s]

Epoch 6/40 |  val  | Loss: 1.2007 | Acc: 0.4788


  0%|          | 0/1538 [00:00<?, ?it/s]

Epoch 7/40 | train | Loss: 1.1200 | Acc: 0.5258


  0%|          | 0/171 [00:00<?, ?it/s]

Epoch 7/40 |  val  | Loss: 1.0662 | Acc: 0.5553


  0%|          | 0/1538 [00:00<?, ?it/s]

Epoch 8/40 | train | Loss: 1.0085 | Acc: 0.5831


  0%|          | 0/171 [00:00<?, ?it/s]

Epoch 8/40 |  val  | Loss: 0.9951 | Acc: 0.5872


  0%|          | 0/1538 [00:00<?, ?it/s]

Epoch 9/40 | train | Loss: 0.9132 | Acc: 0.6296


  0%|          | 0/171 [00:00<?, ?it/s]

Epoch 9/40 |  val  | Loss: 0.9204 | Acc: 0.6271


  0%|          | 0/1538 [00:00<?, ?it/s]

Epoch 10/40 | train | Loss: 0.8467 | Acc: 0.6607


  0%|          | 0/171 [00:00<?, ?it/s]

Epoch 10/40 |  val  | Loss: 0.8911 | Acc: 0.6403


  0%|          | 0/1538 [00:00<?, ?it/s]

Epoch 11/40 | train | Loss: 0.7947 | Acc: 0.6828


  0%|          | 0/171 [00:00<?, ?it/s]

Epoch 11/40 |  val  | Loss: 0.8693 | Acc: 0.6544


  0%|          | 0/1538 [00:00<?, ?it/s]

Epoch 12/40 | train | Loss: 0.7495 | Acc: 0.7041


  0%|          | 0/171 [00:00<?, ?it/s]

Epoch 12/40 |  val  | Loss: 0.8739 | Acc: 0.6571


  0%|          | 0/1538 [00:00<?, ?it/s]

Epoch 13/40 | train | Loss: 0.7076 | Acc: 0.7243


  0%|          | 0/171 [00:00<?, ?it/s]

Epoch 13/40 |  val  | Loss: 0.8431 | Acc: 0.6718


  0%|          | 0/1538 [00:00<?, ?it/s]

Epoch 14/40 | train | Loss: 0.6730 | Acc: 0.7369


  0%|          | 0/171 [00:00<?, ?it/s]

Epoch 14/40 |  val  | Loss: 0.8155 | Acc: 0.6812


  0%|          | 0/1538 [00:00<?, ?it/s]

Epoch 15/40 | train | Loss: 0.6377 | Acc: 0.7526


  0%|          | 0/171 [00:00<?, ?it/s]

Epoch 15/40 |  val  | Loss: 0.8225 | Acc: 0.6804


  0%|          | 0/1538 [00:00<?, ?it/s]

Epoch 16/40 | train | Loss: 0.6043 | Acc: 0.7672


  0%|          | 0/171 [00:00<?, ?it/s]

Epoch 16/40 |  val  | Loss: 0.8264 | Acc: 0.6928


  0%|          | 0/1538 [00:00<?, ?it/s]

Epoch 17/40 | train | Loss: 0.5751 | Acc: 0.7777


  0%|          | 0/171 [00:00<?, ?it/s]

Epoch 17/40 |  val  | Loss: 0.8577 | Acc: 0.6803


  0%|          | 0/1538 [00:00<?, ?it/s]

Epoch 18/40 | train | Loss: 0.5455 | Acc: 0.7897


  0%|          | 0/171 [00:00<?, ?it/s]

Epoch 18/40 |  val  | Loss: 0.8326 | Acc: 0.6951


  0%|          | 0/1538 [00:00<?, ?it/s]

Epoch 19/40 | train | Loss: 0.5126 | Acc: 0.8021


  0%|          | 0/171 [00:00<?, ?it/s]

Epoch 19/40 |  val  | Loss: 0.8673 | Acc: 0.6930


  0%|          | 0/1538 [00:00<?, ?it/s]

Epoch 20/40 | train | Loss: 0.4866 | Acc: 0.8124


  0%|          | 0/171 [00:00<?, ?it/s]

Epoch 20/40 |  val  | Loss: 0.8938 | Acc: 0.6871


  0%|          | 0/1538 [00:00<?, ?it/s]

Epoch 21/40 | train | Loss: 0.4581 | Acc: 0.8256


  0%|          | 0/171 [00:00<?, ?it/s]

Epoch 21/40 |  val  | Loss: 0.8778 | Acc: 0.6969


  0%|          | 0/1538 [00:00<?, ?it/s]

Epoch 22/40 | train | Loss: 0.4313 | Acc: 0.8341


  0%|          | 0/171 [00:00<?, ?it/s]

Epoch 22/40 |  val  | Loss: 0.8884 | Acc: 0.6959


  0%|          | 0/1538 [00:00<?, ?it/s]

Epoch 23/40 | train | Loss: 0.4100 | Acc: 0.8447


  0%|          | 0/171 [00:00<?, ?it/s]

Epoch 23/40 |  val  | Loss: 0.9388 | Acc: 0.6926


  0%|          | 0/1538 [00:00<?, ?it/s]

Epoch 24/40 | train | Loss: 0.3859 | Acc: 0.8536


  0%|          | 0/171 [00:00<?, ?it/s]

Epoch 24/40 |  val  | Loss: 0.9570 | Acc: 0.6861


  0%|          | 0/1538 [00:00<?, ?it/s]

Epoch 25/40 | train | Loss: 0.3642 | Acc: 0.8618


  0%|          | 0/171 [00:00<?, ?it/s]

Epoch 25/40 |  val  | Loss: 0.9311 | Acc: 0.6868


  0%|          | 0/1538 [00:00<?, ?it/s]

Epoch 26/40 | train | Loss: 0.3461 | Acc: 0.8686


  0%|          | 0/171 [00:00<?, ?it/s]

Epoch 26/40 |  val  | Loss: 0.9642 | Acc: 0.6937


  0%|          | 0/1538 [00:00<?, ?it/s]

Epoch 27/40 | train | Loss: 0.3261 | Acc: 0.8755


  0%|          | 0/171 [00:00<?, ?it/s]

Epoch 27/40 |  val  | Loss: 0.9697 | Acc: 0.7013


  0%|          | 0/1538 [00:00<?, ?it/s]

Epoch 28/40 | train | Loss: 0.3080 | Acc: 0.8844


  0%|          | 0/171 [00:00<?, ?it/s]

Epoch 28/40 |  val  | Loss: 1.0349 | Acc: 0.6917


  0%|          | 0/1538 [00:00<?, ?it/s]

Epoch 29/40 | train | Loss: 0.2939 | Acc: 0.8888


  0%|          | 0/171 [00:00<?, ?it/s]

Epoch 29/40 |  val  | Loss: 1.0566 | Acc: 0.6797


  0%|          | 0/1538 [00:00<?, ?it/s]

Epoch 30/40 | train | Loss: 0.2837 | Acc: 0.8932


  0%|          | 0/171 [00:00<?, ?it/s]

Epoch 30/40 |  val  | Loss: 1.0427 | Acc: 0.6894


  0%|          | 0/1538 [00:00<?, ?it/s]

Epoch 31/40 | train | Loss: 0.2660 | Acc: 0.9003


  0%|          | 0/171 [00:00<?, ?it/s]

Epoch 31/40 |  val  | Loss: 1.1272 | Acc: 0.6840


  0%|          | 0/1538 [00:00<?, ?it/s]

Epoch 32/40 | train | Loss: 0.2608 | Acc: 0.9021


  0%|          | 0/171 [00:00<?, ?it/s]

Epoch 32/40 |  val  | Loss: 1.0727 | Acc: 0.6922


  0%|          | 0/1538 [00:00<?, ?it/s]

Epoch 33/40 | train | Loss: 0.2480 | Acc: 0.9078


  0%|          | 0/171 [00:00<?, ?it/s]

Epoch 33/40 |  val  | Loss: 1.1200 | Acc: 0.6913


  0%|          | 0/1538 [00:00<?, ?it/s]

Epoch 34/40 | train | Loss: 0.2398 | Acc: 0.9106


  0%|          | 0/171 [00:00<?, ?it/s]

Epoch 34/40 |  val  | Loss: 1.1039 | Acc: 0.6964


  0%|          | 0/1538 [00:00<?, ?it/s]

Epoch 35/40 | train | Loss: 0.2345 | Acc: 0.9128


  0%|          | 0/171 [00:00<?, ?it/s]

Epoch 35/40 |  val  | Loss: 1.1339 | Acc: 0.6926


  0%|          | 0/1538 [00:00<?, ?it/s]

Epoch 36/40 | train | Loss: 0.2263 | Acc: 0.9161


  0%|          | 0/171 [00:00<?, ?it/s]

Epoch 36/40 |  val  | Loss: 1.1579 | Acc: 0.6917


  0%|          | 0/1538 [00:00<?, ?it/s]

Epoch 37/40 | train | Loss: 0.2192 | Acc: 0.9183


  0%|          | 0/171 [00:00<?, ?it/s]

Epoch 37/40 |  val  | Loss: 1.1768 | Acc: 0.6981


  0%|          | 0/1538 [00:00<?, ?it/s]

Epoch 38/40 | train | Loss: 0.2117 | Acc: 0.9214


  0%|          | 0/171 [00:00<?, ?it/s]

Epoch 38/40 |  val  | Loss: 1.1502 | Acc: 0.6919


  0%|          | 0/1538 [00:00<?, ?it/s]

Epoch 39/40 | train | Loss: 0.2042 | Acc: 0.9240


  0%|          | 0/171 [00:00<?, ?it/s]

Epoch 39/40 |  val  | Loss: 1.2716 | Acc: 0.6894


  0%|          | 0/1538 [00:00<?, ?it/s]

Epoch 40/40 | train | Loss: 0.1984 | Acc: 0.9256


  0%|          | 0/171 [00:00<?, ?it/s]

Epoch 40/40 |  val  | Loss: 1.2062 | Acc: 0.6993


# Submission

In [12]:
%%writefile agent.py
import os
import numpy as np
import torch
from lux.game import Game


# Use the Kaggle simulation directory when it exists, otherwise the current directory.
path = '/kaggle_simulations/agent' if os.path.exists('/kaggle_simulations') else '.'
# Load the TorchScript model saved as model.pth and switch it to inference mode.
model = torch.jit.load(f'{path}/model.pth')
model.eval()


def make_input(obs, unit_id):
    """Encode an observation as a ``(20, 32, 32)`` float32 array for one unit.

    Map coordinates are shifted by ``(32 - size) // 2`` so the map sits in the
    middle of the 32x32 canvas. Planes are indexed ``[channel, x, y]``.

    Channel layout:
        0      position of the unit identified by ``unit_id``
        1      that unit's cargo, (wood + coal + uranium) / 100
        2-4    other units of ``obs['player']``: presence, cooldown / 6, cargo / 100
        5-7    units of the other team, same three values
        8-9    city tiles of ``obs['player']``: presence, min(fuel / upkeep, 10) / 10
        10-11  city tiles of the other team, same two values
        12-14  wood, coal, uranium amount / 800
        15-16  research points of own / other team, min(rp, 200) / 200
        17     step % 40 / 40
        18     step / 360
        19     mask that is 1 inside the map area

    A ``'ct'`` line looks up the value stored by an earlier ``'c'`` line for
    the same city id, so the city line must come first in ``obs['updates']``.

    Args:
        obs: Mapping with ``width``, ``height``, ``step``, ``player`` and
            ``updates`` (list of space-separated update strings).
        unit_id: Id of the unit the encoding is centred on.

    Returns:
        ``numpy.ndarray`` of shape ``(20, 32, 32)`` and dtype float32.
    """
    resource_channel = {'wood': 12, 'coal': 13, 'uranium': 14}

    dx = (32 - obs['width']) // 2
    dy = (32 - obs['height']) // 2
    city_fuel = {}

    planes = np.zeros((20, 32, 32), dtype=np.float32)

    for line in obs['updates']:
        tok = line.split(' ')
        kind = tok[0]

        if kind == 'c':
            # City summary: remember its capped fuel-to-upkeep ratio.
            city_id = tok[2]
            fuel = float(tok[3])
            lightupkeep = float(tok[4])
            city_fuel[city_id] = min(fuel / lightupkeep, 10) / 10
        elif kind == 'ct':
            # City tile: presence plus the ratio of the city it belongs to.
            team = int(tok[1])
            city_id = tok[2]
            px = int(tok[3]) + dx
            py = int(tok[4]) + dy
            base = 8 + (team - obs['player']) % 2 * 2
            planes[base:base + 2, px, py] = (1, city_fuel[city_id])
        elif kind == 'r':
            # Resource tile.
            r_type = tok[1]
            px = int(tok[2]) + dx
            py = int(tok[3]) + dy
            amount = int(float(tok[4]))
            planes[resource_channel[r_type], px, py] = amount / 800
        elif kind == 'rp':
            # Research points fill a whole plane per team.
            team = int(tok[1])
            points = int(tok[2])
            planes[15 + (team - obs['player']) % 2, :] = min(points, 200) / 200
        elif kind == 'u':
            px = int(tok[4]) + dx
            py = int(tok[5]) + dy
            wood = int(tok[7])
            coal = int(tok[8])
            uranium = int(tok[9])
            cargo = (wood + coal + uranium) / 100
            if tok[3] == unit_id:
                # The unit we are choosing an action for.
                planes[0, px, py] = 1
                planes[1, px, py] = cargo
            else:
                # Any other unit, split by team relative to obs['player'].
                team = int(tok[2])
                cooldown = float(tok[6])
                base = 2 + (team - obs['player']) % 2 * 3
                planes[base:base + 3, px, py] = (1, cooldown / 6, cargo)

    step = obs['step']
    planes[17, :] = step % 40 / 40      # position inside the 40-step cycle
    planes[18, :] = step / 360          # overall progress
    planes[19, dx:32 - dx, dy:32 - dy] = 1  # map area

    return planes


game_state = None
def get_game_state(observation):
    """Return the module-level ``Game`` object, synchronised with ``observation``.

    On step 0 a fresh ``Game`` is created, initialised from the full update
    list, updated with everything after the first two entries, and tagged
    with the player id. On later steps the existing object is updated in place.
    """
    global game_state

    if observation["step"] != 0:
        game_state._update(observation["updates"])
        return game_state

    game_state = Game()
    game_state._initialize(observation["updates"])
    # The first two update lines are skipped here; they are passed to
    # _initialize above.
    game_state._update(observation["updates"][2:])
    game_state.id = observation["player"]
    return game_state


def in_city(pos):
    """Return True when the cell at ``pos`` holds a city tile of our team.

    "Our team" is ``game_state.id``. Any ordinary error raised while looking
    up the cell (for example a position the map cannot resolve) is treated as
    "not in a city" and yields False.

    Args:
        pos: Position passed to ``game_state.map.get_cell_by_pos``.
    """
    try:
        city = game_state.map.get_cell_by_pos(pos).citytile
        return city is not None and city.team == game_state.id
    except Exception:
        # Deliberately best-effort, but narrower than a bare ``except`` so
        # KeyboardInterrupt / SystemExit are no longer swallowed.
        return False


def call_func(obj, method, args=[]):
    """Call ``obj.<method>(*args)`` where ``method`` is given by name.

    ``args`` is unpacked with ``*``, so any iterable works; it is never mutated.
    """
    bound = getattr(obj, method)
    return bound(*args)


unit_actions = [('move', 'n'), ('move', 's'), ('move', 'w'), ('move', 'e'), ('build_city',)]
def get_action(policy, unit, dest):
    """Pick the highest-scoring action whose target cell is still available.

    Labels are tried from the highest to the lowest policy value. A label is
    accepted when its target position is not already in ``dest`` or when the
    target is one of our city tiles. For actions that do not translate to a
    new position the unit's current position is used as the target.

    Args:
        policy: 1-D array of scores, one per entry of ``unit_actions``.
        unit: Unit exposing ``pos``, ``move`` and ``build_city``.
        dest: Positions already claimed by previously processed units.

    Returns:
        ``(action, position)``; falls back to a centre move at the unit's
        current position when every label is rejected.
    """
    ranked_labels = np.argsort(policy)[::-1]
    for label in ranked_labels:
        candidate = unit_actions[label]
        target = unit.pos.translate(candidate[-1], 1)
        if not target:
            target = unit.pos
        if target in dest and not in_city(target):
            # Cell already claimed and stacking is not allowed there.
            continue
        return call_func(unit, *candidate), target

    return unit.move('c'), unit.pos


def agent(observation, configuration):
    """Return the list of actions for this turn.

    City tiles that can act build a worker while the unit count is below the
    city-tile count; otherwise they research until uranium is researched.
    Every unit that can act gets its action from the model, except that
    units standing on one of our city tiles are left alone once
    ``game_state.turn % 40`` reaches 30. Target cells already claimed by a
    previously processed unit are passed to ``get_action`` via the
    destination list.

    Args:
        observation: Current observation; ``observation.player`` selects our player.
        configuration: Unused.
    """
    global game_state

    game_state = get_game_state(observation)
    player = game_state.players[observation.player]
    actions = []

    # City Actions
    worker_total = len(player.units)
    for city in player.cities.values():
        for tile in city.citytiles:
            if not tile.can_act():
                continue
            if worker_total < player.city_tile_count:
                actions.append(tile.build_worker())
                worker_total += 1
            elif not player.researched_uranium():
                actions.append(tile.research())
                # Count the point immediately so later tiles see the
                # updated total in researched_uranium().
                player.research_points += 1

    # Worker Actions
    claimed = []
    for unit in player.units:
        if not unit.can_act():
            continue
        if game_state.turn % 40 >= 30 and in_city(unit.pos):
            continue

        state = make_input(observation, unit.id)
        with torch.no_grad():
            logits = model(torch.from_numpy(state).unsqueeze(0))

        action, pos = get_action(logits.squeeze(0).numpy(), unit, claimed)
        actions.append(action)
        claimed.append(pos)

    return actions

Overwriting agent.py


In [13]:
# Local smoke test: play agent.py against itself on a 24x24 map and render the replay.
from kaggle_environments import make

env = make("lux_ai_2021", configuration={"width": 24, "height": 24, "loglevel": 2, "annotations": True}, debug=True)
steps = env.run(['agent.py', 'agent.py'])
env.render(mode="ipython", width=1200, height=800)

Loading environment football failed: No module named 'gfootball'
[33m[WARN][39m (match_jyOXU8Av4Xus) - Agent 1 tried to build CityTile with insufficient materials wood + coal + uranium: 80; turn 11; cmd: bcity u_2
[33m[WARN][39m (match_jyOXU8Av4Xus) - turn 22; Unit u_11 collided when trying to move w to (5, 3)
[33m[WARN][39m (match_jyOXU8Av4Xus) - Agent 1 tried to build CityTile with insufficient materials wood + coal + uranium: 80; turn 46; cmd: bcity u_6
[33m[WARN][39m (match_jyOXU8Av4Xus) - Agent 1 tried to build CityTile with insufficient materials wood + coal + uranium: 80; turn 53; cmd: bcity u_2
[33m[WARN][39m (match_jyOXU8Av4Xus) - Agent 0 tried to build CityTile with insufficient materials wood + coal + uranium: 80; turn 57; cmd: bcity u_16
[33m[WARN][39m (match_jyOXU8Av4Xus) - Agent 1 tried to build CityTile with insufficient materials wood + coal + uranium: 80; turn 63; cmd: bcity u_19
[33m[WARN][39m (match_jyOXU8Av4Xus) - turn 64; Unit u_24 collided when tryin

In [14]:
# Pack every file in the working directory into a gzip-compressed submission archive.
!tar -czf submission.tar.gz *