In [1]:
import os
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import json
from tqdm.notebook import tqdm
import sys
import logging
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)

def episode_loader(episode_dir): 
    action_dict = {'n': 0, 's': 1, 'w': 2, 'e': 3}
    observations = {}
    samples = []
    jsons = os.listdir(episode_dir)
    for filename in jsons: 
        path = os.path.join(episode_dir,filename)
        with open(path) as f:
            temp = json.load(f)
        index = np.argmax(temp['rewards'])
        if temp['info']['TeamNames'][index] != 'Toad Brigade' and temp['info']['TeamNames'][index] !='RL is all you need' and temp['info']['TeamNames'][index] != 'skyramp': #They are the top score teams.
            continue
        i=0
        while True:
            if i >= len(temp['steps'])-1:
                break
            observation = temp['steps'][i][0]['observation']
            flag = 0
            for update in observation['updates']:
                if update.split(' ')[0] == 'r':
                    flag = 1
                    break
            if temp['steps'][i][index]['status'] == 'ACTIVE' and flag:
                observation['player'] = index
                filter_observation = {}
                for key,value in observation.items():
                    if key not in ['width', 'height','player','updates','step']:
                        continue
                    else:
                        filter_observation[key] = value
                observation_id = str(temp['info']['EpisodeId'])+'_'+str(i)
                observations[observation_id] = filter_observation
                actions = temp['steps'][i+1][index]['action']     
                for action in actions:
                    strs = action.split(' ')
                    if strs[0] == 'm' :
                        act = strs[2]
                        if act != 'c':
                            samples.append((observation_id, strs[1], action_dict[act]))
                    elif strs[0] == 'bcity':
                        samples.append((observation_id, strs[1], 4))
            i+=1
    return observations, samples


In [2]:
def get_state(observation, unit_id):
    delta_pos = [(32 - observation['width']) // 2,(32 - observation['height']) // 2]
    cities = np.zeros(1024,dtype = np.float32)
    state = np.zeros((20, 32, 32), dtype=np.float32)
    
    for update in observation['updates']:
        info = update.split(' ')
        if info[0] == 'u':
            pos = [int(info[4]) + delta_pos[0],int(info[5]) + delta_pos[1]]
            total_resource = (int(info[7]) + int(info[8]) + int(info[9])) / 100
            if unit_id == info[3]:
                # present_unit->0,1
                state[:2, pos[0], pos[1]] = (1,total_resource)
            else:
                if (int(info[2])- observation['player']) %2 == 0: #our team
                    state[2:5, pos[0], pos[1]] = (1,float(info[6]) / 6,total_resource)
                else:#enemy team
                    state[5:8, pos[0], pos[1]] = (1,float(info[6]) / 6,total_resource)
        elif info[0] == 'ct':
            pos = [int(info[3]) + delta_pos[0],int(info[4]) + delta_pos[1]]
            if (int(info[1])- observation['player']) %2 == 0: #our team
                state[8:10, pos[0], pos[1]] = (1,cities[int(info[2].split('_')[-1])])
            else:#enemy team
                state[10:12, pos[0], pos[1]] = (1,cities[int(info[2].split('_')[-1])])
        elif info[0] == 'r':
            pos = [int(info[2]) + delta_pos[0],int(info[3]) + delta_pos[1]]
            if info[1] == 'wood':
                state[12, pos[0], pos[1]] = int(float(info[4])) / 800
            elif info[1] == 'coal':
                state[13, pos[0], pos[1]] = int(float(info[4])) / 800
            else:
                state[14, pos[0], pos[1]] = int(float(info[4])) / 800
        elif info[0] == 'rp':
            if (int(info[1])- observation['player']) %2 == 0: #our team
                state[15,:,:] = min(int(info[2]), 200) / 200
            else:#enemy team
                state[16,:,:] = min(int(info[2]), 200) / 200
        elif info[0] == 'c':
            cities[int(info[2].split('_')[-1])] = min(float(info[3]) / float(info[4]), 10) / 10
    
    state[17, :,:] = observation['step'] % 40 / 40 #time 
    state[18, :,:] = observation['step'] / 360 #->turn
    state[19, delta_pos[0]:32 - delta_pos[0], delta_pos[1]:32 - delta_pos[1]] = 1 #other unidentifier places

    return state


In [3]:
class LuxDataset(Dataset):
    def __init__(self, observations, samples):
        self.observations = observations
        self.samples = samples
        
    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return get_state(self.observations[self.samples[idx][0]], self.samples[idx][1]), self.samples[idx][2]

class LuxNN(nn.Module):
    def __init__(self):
        super(LuxNN, self).__init__()
        self.conv = ConvLayer(in_dim=20, out_dim=36, kernel_size=3)
        self.bn = nn.BatchNorm2d(36)

        self.convs = nn.ModuleList(
            [
                ConvLayer(in_dim=36, out_dim=36, kernel_size=3) for i in range(36)
            ]
        )
        self.bns = nn.ModuleList(
            [
                nn.BatchNorm2d(36) for i in range(36)
            ]
        )

        self.fc = nn.Linear(36, 5)

        self.relu = nn.ReLU()

    def forward(self, x):
        x_r = x

        x = self.relu(
            self.bn(
                self.conv(x)
            )
        )

        for conv, bn in zip(self.convs, self.bns):
            # Residual block
            s = x + bn(conv(x))
            x = self.relu(s)

        return self.fc(
            (
                x * x_r[:,:1]).view(x.size(0),
                x.size(1),
                -1
            ).sum(-1)
        )

# convolution layer
class ConvLayer(nn.Module):
    def __init__(self, in_dim, out_dim, kernel_size, stride=1, dilation=1):
        super(ConvLayer, self).__init__()
        reflect_padding = int(dilation * (kernel_size - 1) / 2)
        self.reflection_pad = nn.ReflectionPad2d(reflect_padding)
        self.conv2d = nn.Conv2d(in_dim, out_dim, kernel_size, stride, dilation=dilation)

    def forward(self, x):
        return self.conv2d(
            self.reflection_pad(x)
        )


In [4]:
# training config

model = LuxNN()

observations, samples = episode_loader('./dataset')
labels = [sample[-1] for sample in samples]

from sklearn.model_selection import train_test_split

train_set, val_set = train_test_split(samples, test_size=0.1, random_state=4, stratify=labels)

train_loader = DataLoader(
    LuxDataset(observations, train_set), 
    batch_size=64, 
    shuffle=True, 
    num_workers=0
)

val_loader = DataLoader(
    LuxDataset(observations, val_set), 
    batch_size=64, 
    shuffle=False, 
    num_workers=0
)

model_save_path = './best_model.pth'
max_epoch = 21
lr = 1e-3
weight_decay = 1e-6
val_freq = 2

In [7]:

def train_model(
    model,
    train_loader,
    val_loader,
    model_save_path,
    max_epoch,
    lr,
    weight_decay,
    val_freq = 10
):
    best_score = -1
    model.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    loss_f = nn.CrossEntropyLoss()

    if os.path.exists(model_save_path):
        model.load_state_dict(torch.load(model_save_path))
        
        

    tbar = tqdm(total=max_epoch*len(train_loader))
    for epoch in range(max_epoch):
        model.train()
        epoch_loss = 0.0
        step = 0

        for x, y in train_loader:
            tbar.update(1)
            step += 1
            x, y = x.cuda().float(), y.cuda()
            optimizer.zero_grad()
            y_pred = model(x)

            loss = loss_f(y_pred, y)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        epoch_loss /= step
        logging.info(f'epoch {epoch + 1}/{max_epoch}, average loss: {epoch_loss:.4f}')

        if epoch % val_freq == (val_freq-1):
            logging.info('validation processing ...')
            val_score = 0.0
            for x, y in val_loader:
                x, y = x.cuda().float(), y.cuda()
                _, action_pred = torch.max(model(x), 1)
                val_score += torch.sum(action_pred == y.data)
            if val_score > best_score:
                best_score = val_score
                torch.save(model.state_dict(), model_save_path)
                logging.info(f'new best score {best_score:.4f}, model saved at epoch {epoch}!')
    tbar.close()

In [8]:
# start training
train_model(
    model,
    train_loader,
    val_loader,
    model_save_path,
    max_epoch,
    lr,
    weight_decay,
    val_freq
)

  0%|          | 0/32298 [00:00<?, ?it/s]

INFO:root:epoch 1/21, average loss: 0.2603
INFO:root:epoch 2/21, average loss: 0.2478
INFO:root:validation processing ...
INFO:root:new best score 8593.0000, model saved at epoch 1!
INFO:root:epoch 3/21, average loss: 0.2391
INFO:root:epoch 4/21, average loss: 0.2283
INFO:root:validation processing ...
INFO:root:epoch 5/21, average loss: 0.2220
INFO:root:epoch 6/21, average loss: 0.2125
INFO:root:validation processing ...
INFO:root:epoch 7/21, average loss: 0.2046
INFO:root:epoch 8/21, average loss: 0.1963
INFO:root:validation processing ...
INFO:root:epoch 9/21, average loss: 0.1914
INFO:root:epoch 10/21, average loss: 0.1857
INFO:root:validation processing ...
INFO:root:epoch 11/21, average loss: 0.1783
INFO:root:epoch 12/21, average loss: 0.1745
INFO:root:validation processing ...
INFO:root:epoch 13/21, average loss: 0.1669
INFO:root:epoch 14/21, average loss: 0.1616
INFO:root:validation processing ...
INFO:root:epoch 15/21, average loss: 0.1573
INFO:root:epoch 16/21, average loss: 

In [9]:
%%writefile agent.py
import os
import copy
import numpy as np
import torch
from lux.game import Game

model.load_state_dict(torch.load(model_save_path))
model.eval()


def get_state(observation, unit_id):
    delta_pos = [(32 - observation['width']) // 2,(32 - observation['height']) // 2]
    cities = np.zeros(1024,dtype = np.float32)
    state = np.zeros((20, 32, 32), dtype=np.float32)
    
    for update in observation['updates']:
        info = update.split(' ')
        if info[0] == 'u':
            pos = [int(info[4]) + delta_pos[0],int(info[5]) + delta_pos[1]]
            total_resource = (int(info[7]) + int(info[8]) + int(info[9])) / 100
            if unit_id == info[3]:
                # present_unit->0,1
                state[:2, pos[0], pos[1]] = (1,total_resource)
            else:
                if (int(info[2])- observation['player']) %2 == 0: #our team
                    state[2:5, pos[0], pos[1]] = (1,float(info[6]) / 6,total_resource)
                else:#enemy team
                    state[5:8, pos[0], pos[1]] = (1,float(info[6]) / 6,total_resource)
        elif info[0] == 'ct':
            pos = [int(info[3]) + delta_pos[0],int(info[4]) + delta_pos[1]]
            if (int(info[1])- observation['player']) %2 == 0: #our team
                state[8:10, pos[0], pos[1]] = (1,cities[int(info[2].split('_')[-1])])
            else:#enemy team
                state[10:12, pos[0], pos[1]] = (1,cities[int(info[2].split('_')[-1])])
        elif info[0] == 'r':
            pos = [int(info[2]) + delta_pos[0],int(info[3]) + delta_pos[1]]
            if info[1] == 'wood':
                state[12, pos[0], pos[1]] = int(float(info[4])) / 800
            elif info[1] == 'coal':
                state[13, pos[0], pos[1]] = int(float(info[4])) / 800
            else:
                state[14, pos[0], pos[1]] = int(float(info[4])) / 800
        elif info[0] == 'rp':
            if (int(info[1])- observation['player']) %2 == 0: #our team
                state[15,:,:] = min(int(info[2]), 200) / 200
            else:#enemy team
                state[16,:,:] = min(int(info[2]), 200) / 200
        elif info[0] == 'c':
            cities[int(info[2].split('_')[-1])] = min(float(info[3]) / float(info[4]), 10) / 10
    
    state[17, :,:] = observation['step'] % 40 / 40 #time
    state[18, :,:] = observation['step'] / 360 #->turn
    state[19, delta_pos[0]:32 - delta_pos[0], delta_pos[1]:32 - delta_pos[1]] = 1 #other unidentifier places

    return state

def bool_act(city_tile):
    return city_tile.can_act()

def city_tile_action(player):
    #find city and corresponding city_tile
    units_num = len(player.units)
    research_point = player.research_points
    city_actions = []
    cities = copy.deepcopy(player.cities.values())
    for city in cities:
        for i in range(len(city.citytiles)):
            act_flag = bool_act(city.citytiles[i])
            if not act_flag:
                continue
            elif player.city_tile_count > units_num:
                units_num+=1
                city_actions +=[city.citytiles[i].build_worker()] 
            elif not player.researched_uranium():
                research_point +=1
                city_actions +=[city.citytiles[i].research()]
    return city_actions,research_point

def get_worker_action(observation,player):
    global game_state
    pos_judged = []
    temp_units = copy.deepcopy(player.units)
    worker_actions = []
    for unit in temp_units:
        if unit.can_act():
            is_city = 0
            try:
                city = game_state.map.get_cell_by_pos(unit.pos).citytile
                if city is not None：
                    if city.team == game_state.id:
                        is_city = 1
            except:
                is_city = 0
            if not(game_state.turn % 40 >= 30 and is_city):
                with torch.no_grad():
                    policy = model(torch.from_numpy(get_state(observation, unit.id)).unsqueeze(0))
                    policy = np.squeeze(policy.numpy(),0)
                    for i in np.argsort[policy][::-1]:
                        actions = ['n','s','w','e','build_city']
                        pos = unit.pos.translate(actions[i], 1)
                        if pos is None:
                            pos = unit.pos
                        is_city = 0
                        try:
                            city = game_state.map.get_cell_by_pos(pos).citytile
                            if city is not None:
                                if city.team == game_state.id:
                                    is_city = 1
                        except:
                            is_city = 0

                        if pos not in pos_judged or is_city:
                            if actions[i] != 'build_city':
                                worker_actions += [unit.move(actions[i])]
                                pos_judged +=[pos]
                            else:
                                worker_actions += [unit.build_city()]
                                pos_judged +=[pos]
                    worker_actions += [unit.move('c')]
                    pos_judged +=[unit.pos]
    return worker_actions
        

def agent(observation, configuration):
    global game_state
    
    if observation["step"] == 0:
        game_state = Game()
        game_state._initialize(observation["updates"])
        game_state._update(observation["updates"][2:])
        game_state.id = observation["player"]
    else:
        game_state._update(observation["updates"])        
    player = game_state.players[observation.player]
    
    city_actions,player.research_points = city_tile_action(player)
    worker_actions = get_worker_action(observation,player)
    return city_actions+worker_actions

Overwriting agent.py


In [None]:
from kaggle_environments import make

env = make("lux_ai_2021", configuration={"width": 24, "height": 24, "loglevel": 2, "annotations": True}, debug=False)
steps = env.run(['agent.py', 'agent.py'])
env.render(mode="ipython", width=1200, height=800)

In [None]:
#!tar -czf submission.tar.gz *