In [1]:
# Count the entries in the dataset directory.
!ls ../input/halite-offline-dataset-builder | wc -l

8854


In [2]:
import glob
# Number of files matching x_*.pt in the dataset directory; the loading
# loops use it as the exclusive upper bound of the shard index.
nb = len(glob.glob('../input/halite-offline-dataset-builder/x_*.pt'))

In [3]:
# Display the number of x_*.pt shards found.
nb

4425

In [4]:
import torch
import os
from tqdm.auto import tqdm

# Load every x_{i}.pt shard onto the GPU and join them along dim 0.
# The shards are gathered in a list and concatenated once: calling torch.cat
# inside the loop re-copies everything loaded so far on every iteration.
x_chunks = []
for i in tqdm(range(nb)):
    p = f"../input/halite-offline-dataset-builder/x_{i}.pt"
    # Indices without a file on disk are skipped.
    if os.path.exists(p):
        x_chunks.append(torch.load(p, map_location='cuda'))

# Stays None when no shard was found.
dataset_x = torch.cat(x_chunks, 0) if x_chunks else None
# Drop the per-shard tensors so only the concatenated copy remains on the GPU.
del x_chunks
torch.cuda.empty_cache()

  0%|          | 0/4425 [00:00<?, ?it/s]

In [5]:
# Shape of the concatenated feature tensor.
dataset_x.shape

torch.Size([322015, 9, 21, 21])

In [6]:
import math

# Hold out 10% of the samples (rounded down) for testing; everything else
# is available for training.
dataset_s = dataset_x.size(0)
test_nb = math.floor(0.1 * dataset_s)
train_nb = dataset_s - test_nb

In [7]:
import numpy as np

In [8]:
# Release cached, unused GPU memory held by PyTorch's allocator.
torch.cuda.empty_cache()

In [9]:
import gc
# Force a garbage-collection pass; the displayed value is the number of
# unreachable objects it found.
gc.collect()

207

In [10]:
# Reproducible train/test split. A seeded generator makes the held-out set
# (and therefore the reported test loss) repeatable across kernel restarts.
SPLIT_SEED = 0
split_rng = np.random.default_rng(SPLIT_SEED)
test_idx = split_rng.choice(dataset_s, size=test_nb, replace=False)
test_x = dataset_x[test_idx]

# Every index that was not drawn for the test set, in ascending order.
# A boolean mask does this in vectorised NumPy instead of building Python
# sets of hundreds of thousands of NumPy scalars.
is_train = np.ones(dataset_s, dtype=bool)
is_train[test_idx] = False
train_idx = np.flatnonzero(is_train)
train_x = dataset_x[train_idx]

In [11]:
import gc
# train_x and test_x hold copies of the samples, so the full tensor can be
# dropped; then release cached GPU memory and run the garbage collector.
del dataset_x
torch.cuda.empty_cache()
gc.collect()

42

In [12]:
# Load every y_{i}.pt shard onto the GPU and join them along dim 0, visiting
# the same indices in the same order as the feature loader so rows stay aligned.
# The shards are gathered in a list and concatenated once: calling torch.cat
# inside the loop re-copies everything loaded so far on every iteration.
y_chunks = []
for i in tqdm(range(nb)):
    p = f"../input/halite-offline-dataset-builder/y_{i}.pt"
    # Indices without a file on disk are skipped.
    if os.path.exists(p):
        y_chunks.append(torch.load(p, map_location='cuda'))

# Stays None when no shard was found.
dataset_y = torch.cat(y_chunks, 0) if y_chunks else None
# Drop the per-shard tensors so only the concatenated copy remains on the GPU.
del y_chunks
torch.cuda.empty_cache()

  0%|          | 0/4425 [00:00<?, ?it/s]

In [13]:
# Release cached, unused GPU memory held by PyTorch's allocator.
torch.cuda.empty_cache()

In [14]:
# Apply the same index sets used for the features so that x and y rows
# stay paired in both splits.
test_y = dataset_y[test_idx]
train_y = dataset_y[train_idx]

In [15]:
# The full label tensor is no longer needed once the splits exist;
# drop it and release cached GPU memory.
del dataset_y
torch.cuda.empty_cache()

In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.categorical import Categorical

class TorusConv2d(nn.Module):
    """2-D convolution whose input wraps around at every border (a torus).

    Before an unpadded ``nn.Conv2d`` the input is extended on each side with
    ``kernel_size // 2`` rows/columns copied from the opposite edge, so for
    odd kernel sizes the output keeps the input's height and width.

    Args:
        input_dim: number of input channels.
        output_dim: number of output channels.
        kernel_size: ``(height, width)`` pair.
        bn: when truthy, ``nn.BatchNorm2d`` is applied after the convolution.
    """

    def __init__(self, input_dim, output_dim, kernel_size, bn):
        super().__init__()
        self.edge_size = (kernel_size[0] // 2, kernel_size[1] // 2)
        self.conv = nn.Conv2d(input_dim, output_dim, kernel_size=kernel_size)
        self.bn = nn.BatchNorm2d(output_dim) if bn else None

    def forward(self, x):
        """Wrap-pad ``x`` (batch, channels, height, width), convolve, optionally normalise."""
        pad_h, pad_w = self.edge_size
        h = x
        # A zero edge must be skipped explicitly: ``h[..., -0:]`` selects the
        # whole axis, which would duplicate the input instead of adding nothing.
        if pad_w > 0:
            h = torch.cat([h[:, :, :, -pad_w:], h, h[:, :, :, :pad_w]], dim=3)
        if pad_h > 0:
            h = torch.cat([h[:, :, -pad_h:], h, h[:, :, :pad_h]], dim=2)
        h = self.conv(h)
        return self.bn(h) if self.bn is not None else h


class HaliteNet(nn.Module):
    """Residual tower of torus convolutions with per-cell policy heads.

    A 9-channel board goes through one ``TorusConv2d`` stem and 16 residual
    ``TorusConv2d`` blocks (128 filters, 3x3 kernels, batch norm). Two 1x1
    convolutions then produce, for every board cell, 6 ship logits and
    2 shipyard logits.

    ``forward`` returns a tensor of shape ``(batch, 8, height * width)``:
    channels 0-5 are the ship logits and channels 6-7 the shipyard logits.
    The spatial size is taken from the input instead of being fixed to 21x21.
    """

    def __init__(self):
        super().__init__()
        layers, filters = 16, 128
        self.conv0 = TorusConv2d(9, filters, (3, 3), True)
        self.blocks = nn.ModuleList([TorusConv2d(filters, filters, (3, 3), True) for _ in range(layers)])

        self.head_ships_p = nn.Conv2d(filters, 6, kernel_size=1, stride=1)
        self.head_shipyards_p = nn.Conv2d(filters, 2, kernel_size=1, stride=1)
        # Value head. forward() never returned its output, so it is no longer
        # evaluated there; the layer stays so parameter names and pickled
        # models remain compatible.
        # Open question kept from the original code: should the value features
        # be pooled only over the current player's ships, i.e. (h * x[:, :1])
        # summed over the board, rather than over every cell?
        self.head_v = nn.Linear(filters * 2, 1, bias=False)

    def forward(self, x, action=None):
        """Return per-cell action logits for ``x`` (batch, 9, height, width).

        ``action`` is accepted for interface compatibility and ignored.
        """
        h = F.relu_(self.conv0(x))
        for block in self.blocks:
            h = F.relu_(h + block(h))

        # Flatten the board so that dim 2 indexes cells: (batch, classes, H*W).
        ships_logits = self.head_ships_p(h).flatten(2)
        shipyards_logits = self.head_shipyards_p(h).flatten(2)
        return torch.cat([ships_logits, shipyards_logits], 1)

In [17]:
import numpy as np

MINI_BATCH_SIZE = 128
BATCH_SIZE = MINI_BATCH_SIZE*4

# Per-class loss weights, built once instead of once per board cell per step.
SHIP_CLASS_WEIGHTS = torch.tensor([0.75, 1., 1., 1., 1., 400.], device='cuda')
SHIPYARD_CLASS_WEIGHTS = torch.tensor([1., 20.], device='cuda')

# Channel order that exchanges the pairs (1,2), (3,4), (5,6), (7,8) and keeps
# channel 0 -- presumably own/opponent feature planes; verify against the
# dataset builder.
OPPONENT_VIEW = [0, 2, 1, 4, 3, 6, 5, 8, 7]


def swap_labels(labels, a, b):
    """Return a copy of ``labels`` in which the values ``a`` and ``b`` are exchanged."""
    swapped = labels.clone()
    swapped[labels == a] = b
    swapped[labels == b] = a
    return swapped


def policy_loss(logits, targets, boards):
    """Masked, class-weighted cross-entropy over all board cells.

    Args:
        logits: model output, (batch, 8, cells); channels 0-5 are ship
            logits, channels 6-7 are shipyard logits.
        targets: int64 labels, reshaped here to (batch, 2, cells); index 0
            holds the ship labels and index 1 the shipyard labels.
        boards: the model input. Channel 1 weights the ship loss of each cell
            and channel 5 the shipyard loss (presumably the acting player's
            ship / shipyard positions -- TODO confirm).

    Returns:
        Scalar tensor: per-sample sum over cells of both masked losses,
        averaged over the batch.
    """
    n = logits.size(0)
    flat_targets = targets.reshape(n, 2, -1)
    ship_mask = boards[:, 1].reshape(n, -1)
    shipyard_mask = boards[:, 5].reshape(n, -1)
    # cross_entropy accepts (N, C, d1) logits with (N, d1) targets, so a single
    # call per head replaces one call per board cell.
    ship_loss = F.cross_entropy(logits[:, :6], flat_targets[:, 0],
                                weight=SHIP_CLASS_WEIGHTS, reduction='none')
    shipyard_loss = F.cross_entropy(logits[:, 6:], flat_targets[:, 1],
                                    weight=SHIPYARD_CLASS_WEIGHTS, reduction='none')
    per_sample = (ship_loss * ship_mask).sum(1) + (shipyard_loss * shipyard_mask).sum(1)
    return per_sample.mean()


model = HaliteNet().to('cuda')
criterion = torch.nn.CrossEntropyLoss(reduction='mean')  # not used by this loop
optimizer = torch.optim.SGD(model.parameters(), lr=0.0005)
# Sliding window over the last 40 training losses.
roll_loss = torch.zeros((40,1))

# float('inf') rather than np.Inf: the latter alias was removed in NumPy 2.0.
best_test_score = float('inf')
# Only full batches are evaluated; the incomplete tail of the test set is skipped.
n_test_batches = test_nb // BATCH_SIZE

for t in tqdm(range(20000)):
    batch_idx = np.random.choice(np.arange(train_x.shape[0]),BATCH_SIZE,replace=False)

    # Advanced indexing copies, so the batch can be edited in place.
    batch_x = train_x[batch_idx]

    # For half of the batch, learn from the second player's actions instead:
    # take label set 1 and exchange the paired input channels.
    opponent_idx = np.random.choice(np.arange(BATCH_SIZE),2*MINI_BATCH_SIZE,replace=False)
    both_players_y = train_y[batch_idx].to(torch.int64)
    batch_y = both_players_y[:,0]
    batch_y[opponent_idx] = both_players_y[opponent_idx,1]
    batch_x[opponent_idx] = batch_x[opponent_idx][:, OPPONENT_VIEW]

    # Symmetry augmentation along board dim 2 for a random half of the batch:
    # flip boards and labels, then exchange ship labels 2 and 4.
    flip_x_idx = np.random.choice(np.arange(BATCH_SIZE),2*MINI_BATCH_SIZE,replace=False)
    batch_x[flip_x_idx] = batch_x[flip_x_idx].flip(2)
    batch_y[flip_x_idx] = batch_y[flip_x_idx].flip(2)
    batch_y[flip_x_idx,0] = swap_labels(batch_y[flip_x_idx,0], 2, 4)

    # Same along board dim 3, exchanging ship labels 1 and 3.
    flip_y_idx = np.random.choice(np.arange(BATCH_SIZE),2*MINI_BATCH_SIZE,replace=False)
    batch_x[flip_y_idx] = batch_x[flip_y_idx].flip(3)
    batch_y[flip_y_idx] = batch_y[flip_y_idx].flip(3)
    batch_y[flip_y_idx,0] = swap_labels(batch_y[flip_y_idx,0], 1, 3)

    # Forward pass and loss.
    y_pred = model(batch_x)
    loss = policy_loss(y_pred, batch_y, batch_x)

    roll_loss = torch.cat([roll_loss[1:], loss.detach().cpu().unsqueeze(0).unsqueeze(0)])
    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if t % 40 == 39:
        print(t, loss.item(), roll_loss.mean().numpy())

    if t % 300 == 299 and n_test_batches > 0:
        model.eval()
        total_test_loss = 0.0
        # The whole evaluation runs without autograd and accumulates plain
        # floats, so no GPU tensors or graphs are kept alive between batches.
        with torch.no_grad():
            for b in range(n_test_batches):
                batch_test_x = test_x[BATCH_SIZE*b:BATCH_SIZE*(b+1)]
                batch_test_y = test_y[BATCH_SIZE*b:BATCH_SIZE*(b+1)].to(torch.int64)
                test_y_pred = model(batch_test_x)
                # The test loss uses label set 0 only, with unmodified boards.
                total_test_loss += policy_loss(test_y_pred, batch_test_y[:,0], batch_test_x).item()
        if best_test_score > total_test_loss:
            print('Saved new best model')
            # Saved while still in eval mode, as before.
            torch.save(model, 'model.pkl')
            best_test_score = total_test_loss
        model.train()
        print(f"Test loss {total_test_loss/n_test_batches}")

  0%|          | 0/20000 [00:00<?, ?it/s]

39 71.29249572753906 95.2831
79 69.02450561523438 73.16762
119 79.4933853149414 71.09863
159 59.315513610839844 68.05406
199 74.53529357910156 69.01979
239 71.41436767578125 68.58394
279 63.925594329833984 67.03667
Saved new best model
Test loss 65.03410093245968
319 62.975982666015625 68.581474
359 68.8465347290039 64.8388
399 61.91830062866211 64.86824
439 81.61508178710938 83.437744
479 62.91874694824219 78.529976
519 73.353515625 73.81151
559 76.72874450683594 67.67672
599 60.87340545654297 68.45598
Test loss 65.55796370967742
639 74.30863189697266 69.38322
679 74.87690734863281 66.81705
719 65.52664184570312 65.55499
759 57.65904235839844 65.93186
799 69.53068542480469 67.106415
839 67.7800521850586 68.47622
879 61.52591323852539 65.83551
Test loss 66.67809664818549
919 64.11488342285156 67.68554
959 72.10465240478516 68.20837
999 57.388946533203125 66.42534
1039 64.03872680664062 67.64824
1079 68.90668487548828 66.78093
1119 63.88421630859375 66.04503
1159 66.92798614501953 65.22