In [1]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append(r"../..")


from torch import nn
from torch.nn import functional as F
import torch.optim as optim
import torch
import os
import pickle
import numpy as np
import matplotlib.pyplot as plt

from callbacks import state_to_features, ACTION_MAP, ACTION_MAP_INV
from networks import AgentNet

from events import *

In [2]:
EVENT_HORIZON = {
    MOVED_LEFT: 1,
    MOVED_RIGHT: 1,
    MOVED_UP: 1,
    MOVED_DOWN: 1,
    WAITED: 2,
    INVALID_ACTION: 1,
    BOMB_DROPPED: 2,
    BOMB_EXPLODED: 2,
    CRATE_DESTROYED: 8,
    COIN_FOUND: 3,
    COIN_COLLECTED: 5,
    KILLED_OPPONENT: 8,
    KILLED_SELF: 3,
    GOT_KILLED: 3,
    OPPONENT_ELIMINATED: 0,
    SURVIVED_ROUND: 8
}
"""
REWARDS = {
    MOVED_LEFT: 1,
    MOVED_RIGHT: 1,
    MOVED_UP: 1,
    MOVED_DOWN: 1,
    WAITED: 0,
    INVALID_ACTION: -5,
    BOMB_DROPPED: 1,
    BOMB_EXPLODED: 0,
    CRATE_DESTROYED: 4,
    COIN_FOUND: 0,
    COIN_COLLECTED: 5,
    KILLED_OPPONENT: 5,
    KILLED_SELF: -8,
    GOT_KILLED: -8,
    OPPONENT_ELIMINATED: 0,
    SURVIVED_ROUND: 0
}
"""
REWARDS = {
    MOVED_LEFT: 100,
    MOVED_RIGHT: 100,
    MOVED_UP: 100,
    MOVED_DOWN: 100,
    WAITED: -100,
    INVALID_ACTION: -10,
    BOMB_DROPPED: 0,
    BOMB_EXPLODED: 0,
    CRATE_DESTROYED: 0,
    COIN_FOUND: 0,
    COIN_COLLECTED: 0,
    KILLED_OPPONENT: 0,
    KILLED_SELF: 0,
    GOT_KILLED: 0,
    OPPONENT_ELIMINATED: 0,
    SURVIVED_ROUND: 0
}

# REWARDS = {k: v + 10 for k, v in REWARDS.items()}



In [3]:
from torch.utils.data import Dataset


class BombermanDataset(Dataset):
    def __init__(self, states_dir, max_game_step=200, min_reward=1):
        self.features = []
        self.actions = []
        self.cum_rewards = []
        self.total_number = 0
        for states_file in os.listdir(states_dir):

            if not states_file.endswith('.pickle'):
                continue
            with open(os.path.join(states_dir, states_file), "rb") as f:
                data = pickle.load(f)
            self.total_number += len(data['game_state'])
            last_round = -1
            running_horizons = []
            running_rewards = []
            for i in range(len(data['game_state'])-1,-1,-1):
                game_state = data['game_state'][i]               

                    
                if last_round != game_state['round']:
                    last_round = game_state['round']
                    running_horizons = []
                    running_rewards = []
                if game_state['step'] > max_game_step:
                    continue
                    
                # determine if action should be included in training set
                
                horizons = [EVENT_HORIZON[e] for e in data['events'][i]]
                rewards = [REWARDS[e] for e in data['events'][i]]
                cr = np.where(data['events'] == "CRATE_DESTROYED")
                #print(cr)
                #rewards[np.array(cr[-1:], dtype = "int")] == 0
                #for j in cr[1:]:
                #    rewards[j] = 0             

                running_rewards = [r for j,r in enumerate(running_rewards) if running_horizons[j]>1]
                running_horizons = [h-1 for h in running_horizons if h > 1]
                running_rewards.extend(rewards)
                running_horizons.extend(horizons)
                #if "INVALID_ACTION" in data['events'][i]:
                #    continue   
                self.cum_rewards.append(np.sum(running_rewards))

                last_action = None if game_state['step'] == 1 else data['action'][i-1]
                #print("LAST ACTION:")
                #print(last_action)
                
                channels, features, act_map = state_to_features(game_state, last_action=last_action)
                action = ACTION_MAP[act_map[data['action'][i]]]
                self.features.append((
                    torch.tensor(channels[np.newaxis], dtype=torch.float),
                    torch.tensor(features, dtype=torch.float)
                ))
                self.actions.append(action)
                
                    
        self.cum_rewards = torch.tensor(self.cum_rewards, dtype=torch.float)
        self.n = np.array([np.sum([t == i for t in self.actions]) for i in range(6)])
        self.weights = np.array([1/self.n[t] for t in self.actions])
        self.proportion_selected = len(self.actions)/self.total_number
        print(f"Loaded {len(self.actions)} actions.")

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return self.features[idx], self.actions[idx], self.cum_rewards[idx]

In [17]:
batch_size = 64
data_path = f"../data/"
weight_path = "models/model_weights_pretrained.pth"
num_epoch_per_run = 1

# optimizer = optim.SGD(net.parameters(), momentum=0.9, lr=0.001)
net = AgentNet()
#net.load_state_dict(torch.load(weight_path))
optimizer = optim.AdamW(net.parameters(), lr=0.001)

In [18]:
optimizer = optim.AdamW(net.parameters(), lr=0.001)
batch_size = 128


In [4]:
EVENT_HORIZON = {
    MOVED_LEFT: 1,
    MOVED_RIGHT: 1,
    MOVED_UP: 1,
    MOVED_DOWN: 1,
    WAITED: 1,
    INVALID_ACTION: 1,
    BOMB_DROPPED: 1,
    BOMB_EXPLODED: 1,
    CRATE_DESTROYED: 6,
    COIN_FOUND: 1,
    COIN_COLLECTED: 5,
    KILLED_OPPONENT: 6,
    KILLED_SELF: 5,
    GOT_KILLED: 6,
    OPPONENT_ELIMINATED: 4,
    SURVIVED_ROUND: 100
}


REWARDS = {
    MOVED_LEFT: 5, #1
    MOVED_RIGHT: 5, #1
    MOVED_UP: 5,   #1
    MOVED_DOWN: 5,  #1
    WAITED: 5,  #1
    INVALID_ACTION: -120, #-7
    BOMB_DROPPED: 5,  #1
    BOMB_EXPLODED: 0,
    CRATE_DESTROYED: 120,
    COIN_FOUND: 0,
    COIN_COLLECTED: 120,
    KILLED_OPPONENT: 60,
    KILLED_SELF: -120, #-12
    GOT_KILLED: -120,
    OPPONENT_ELIMINATED: 5,
    SURVIVED_ROUND: 200 #20
}

In [20]:
for run in range(10):
    # os.system("rm -rf ../rule_based_agent/data")
    os.system("rm -rf ../data")
    os.system("cd ../..; python main.py play --no-gui --agents rule_based_agent rule_based_agent rule_based_agent rule_based_agent --train 4 --n-rounds 40 --scenario loot-crate")

    trainset = BombermanDataset(data_path)
    N = len(trainset)
    print(f"{N} actions in the training set:")
    for i in range(6):
        n = np.sum([t == i for t in trainset.actions])
        print(f"{ACTION_MAP_INV[i]}: {n/N*100:.2f}% ({n})")
    # sampler = BatchSampler(WeightedRandomSampler(trainset.weights, len(trainset), replacement=True,), batch_size, False)
    # trainloader = torch.utils.data.DataLoader(trainset, batch_sampler=sampler)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

    for epoch in range(num_epoch_per_run):
        running_loss = 0.0
        running = 0
        for i, batch in enumerate(trainloader, 0):
        # for i, batch in enumerate(sampler, 0):
            print(i, end="\r")
            # get the inputs; data is a list of [inputs, labels]
            (channels, features), actions, rewards = batch

            # zero the parameter gradients
            optimizer.zero_grad()
    
            # forward + backward + optimize
            outputs = net(channels, features)
            # loss = criterion(outputs, labels)
            log_logits = F.log_softmax(outputs, dim=-1)
            log_probs = log_logits.gather(1, actions[:,np.newaxis])
            
            loss = torch.mean(-log_probs * rewards)
            loss.backward()
            optimizer.step()
    
            # print statistics
            running_loss += loss.item()
            running += 1
            if i % 200 == 199:    # print every 2000 mini-batches
                print(f'[{run + 1}, {epoch + 1}, {i + 1:5d}] loss: {running_loss / running:.3f}')
                torch.save(net.state_dict(), weight_path)
                running_loss = 0.0
                running = 0
                #print(log_probs)
                #print(rewards)
                print(i*batch_size/N)
            if i*batch_size/N > 0.15:
                break
        if running > 0:
            print(f'[{run + 1}, {epoch + 1}] loss: {running_loss / running:.3f}')
            
    torch.save(net.state_dict(), weight_path)

print('Finished Training') 

100%|██████████| 40/40 [01:02<00:00,  1.57s/it]


rule_based_agent_0    511
rule_based_agent_1    478
rule_based_agent_2    524
rule_based_agent_3    441
Loaded 528784 actions.
528784 actions in the training set:
UP: 19.59% (103597)
DOWN: 25.18% (133160)
LEFT: 19.49% (103045)
RIGHT: 25.55% (135121)
BOMB: 8.60% (45464)
WAIT: 1.59% (8397)
[1, 1,   200] loss: 271.523
0.04817089775787467
[1, 1,   400] loss: 139.790
0.09658386032860299
[1, 1,   600] loss: 118.204
0.1449968228993313
[1, 1] loss: 101.384


100%|██████████| 40/40 [00:55<00:00,  1.39s/it]


rule_based_agent_0    496
rule_based_agent_1    452
rule_based_agent_2    465
rule_based_agent_3    500
Loaded 510991 actions.
510991 actions in the training set:
UP: 19.38% (99032)
DOWN: 25.37% (129626)
LEFT: 19.66% (100450)
RIGHT: 25.12% (128360)
BOMB: 8.68% (44347)
WAIT: 1.80% (9176)
[2, 1,   200] loss: 118.307
0.04984823607460797
[2, 1,   400] loss: 103.313
0.09994696579783205
[2, 1,   600] loss: 93.825
0.15004569552105615


100%|██████████| 40/40 [01:21<00:00,  2.03s/it]


rule_based_agent_0    463
rule_based_agent_1    431
rule_based_agent_2    543
rule_based_agent_3    486
Loaded 514550 actions.
514550 actions in the training set:
UP: 19.83% (102055)
DOWN: 25.25% (129948)
LEFT: 19.39% (99784)
RIGHT: 25.32% (130278)
BOMB: 8.43% (43397)
WAIT: 1.77% (9088)
[3, 1,   200] loss: 111.844
0.04950344961616947
[3, 1,   400] loss: 98.572
0.09925566028568653
[3, 1,   600] loss: 89.211
0.14900787095520357
[3, 1] loss: 97.197


100%|██████████| 40/40 [01:25<00:00,  2.14s/it]


rule_based_agent_0    475
rule_based_agent_1    509
rule_based_agent_2    465
rule_based_agent_3    469
Loaded 504696 actions.
504696 actions in the training set:
UP: 19.47% (98248)
DOWN: 25.24% (127391)
LEFT: 19.55% (98655)
RIGHT: 25.46% (128490)
BOMB: 8.58% (43321)
WAIT: 1.70% (8591)
[4, 1,   200] loss: 102.938
0.05046998589249766
[4, 1,   400] loss: 85.305
0.10119358980455562
[4, 1] loss: 78.836


100%|██████████| 40/40 [01:24<00:00,  2.11s/it]


rule_based_agent_0    485
rule_based_agent_1    480
rule_based_agent_2    499
rule_based_agent_3    485
Loaded 558522 actions.
558522 actions in the training set:
UP: 19.53% (109078)
DOWN: 24.93% (139254)
LEFT: 20.04% (111904)
RIGHT: 25.29% (141255)
BOMB: 8.64% (48266)
WAIT: 1.57% (8765)
[5, 1,   200] loss: 102.070
0.045606081765803316
[5, 1,   400] loss: 88.812
0.09144133982188705
[5, 1,   600] loss: 76.876
0.13727659787797078
[5, 1] loss: 78.121


100%|██████████| 40/40 [00:59<00:00,  1.49s/it]


rule_based_agent_0    464
rule_based_agent_1    459
rule_based_agent_2    506
rule_based_agent_3    523
Loaded 512145 actions.
512145 actions in the training set:
UP: 19.51% (99945)
DOWN: 24.79% (126978)
LEFT: 19.71% (100964)
RIGHT: 25.68% (131542)
BOMB: 8.54% (43762)
WAIT: 1.75% (8954)
[6, 1,   200] loss: 101.376
0.049735914633551044
[6, 1,   400] loss: 84.654
0.09972175848636616
[6, 1,   600] loss: 74.339
0.1497076023391813
[6, 1] loss: 70.495


100%|██████████| 40/40 [00:57<00:00,  1.44s/it]


rule_based_agent_0    483
rule_based_agent_1    509
rule_based_agent_2    508
rule_based_agent_3    419
Loaded 540884 actions.
540884 actions in the training set:
UP: 19.56% (105783)
DOWN: 25.20% (136326)
LEFT: 19.71% (106607)
RIGHT: 25.17% (136141)
BOMB: 8.74% (47280)
WAIT: 1.62% (8747)
[7, 1,   200] loss: 93.473
0.04709327693183751
[7, 1,   400] loss: 78.871
0.09442320349649833
[7, 1,   600] loss: 69.165
0.14175313006115914
[7, 1] loss: 66.582


100%|██████████| 40/40 [00:55<00:00,  1.39s/it]


rule_based_agent_0    486
rule_based_agent_1    449
rule_based_agent_2    517
rule_based_agent_3    467
Loaded 514745 actions.
514745 actions in the training set:
UP: 19.56% (100665)
DOWN: 24.95% (128407)
LEFT: 19.82% (102026)
RIGHT: 25.38% (130651)
BOMB: 8.68% (44685)
WAIT: 1.61% (8311)
[8, 1,   200] loss: 95.278
0.049484696305937895
[8, 1,   400] loss: 78.088
0.09921805942748352
[8, 1,   600] loss: 74.317
0.14895142254902913
[8, 1] loss: 70.503


100%|██████████| 40/40 [00:55<00:00,  1.39s/it]


rule_based_agent_0    544
rule_based_agent_1    478
rule_based_agent_2    410
rule_based_agent_3    468
Loaded 513918 actions.
513918 actions in the training set:
UP: 19.56% (100499)
DOWN: 25.49% (131002)
LEFT: 19.46% (99996)
RIGHT: 25.32% (130138)
BOMB: 8.53% (43848)
WAIT: 1.64% (8435)
[9, 1,   200] loss: 90.847
0.04956432738296771
[9, 1,   400] loss: 75.577
0.09937772173770913
[9, 1,   600] loss: 67.661
0.14919111609245056
[9, 1] loss: 46.900


100%|██████████| 40/40 [00:54<00:00,  1.37s/it]


rule_based_agent_0    501
rule_based_agent_1    517
rule_based_agent_2    522
rule_based_agent_3    440
Loaded 514822 actions.
514822 actions in the training set:
UP: 19.57% (100732)
DOWN: 24.95% (128455)
LEFT: 19.59% (100868)
RIGHT: 25.62% (131877)
BOMB: 8.60% (44295)
WAIT: 1.67% (8595)
[10, 1,   200] loss: 92.429
0.04947729506509046
[10, 1,   400] loss: 75.591
0.09920321975362359
[10, 1,   600] loss: 66.345
0.1489291444421567
[10, 1] loss: 70.222
Finished Training


In [21]:
torch.save(net.state_dict(), "models/model_weights_pre")

In [22]:
optimizer = optim.AdamW(net.parameters(), lr=0.0001)
batch_size = 256


In [23]:
for run in range(10):
    # os.system("rm -rf ../rule_based_agent/data")
    os.system("rm -rf ../data")
    os.system("cd ../..; python main.py play --no-gui --agents rule_based_agent rule_based_agent rule_based_agent rule_based_agent --train 4 --n-rounds 50 --scenario loot-crate")

    trainset = BombermanDataset(data_path)
    N = len(trainset)
    print(f"{N} actions in the training set:")
    for i in range(6):
        n = np.sum([t == i for t in trainset.actions])
        print(f"{ACTION_MAP_INV[i]}: {n/N*100:.2f}% ({n})")
    # sampler = BatchSampler(WeightedRandomSampler(trainset.weights, len(trainset), replacement=True,), batch_size, False)
    # trainloader = torch.utils.data.DataLoader(trainset, batch_sampler=sampler)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

    for epoch in range(num_epoch_per_run):
        running_loss = 0.0
        running = 0
        for i, batch in enumerate(trainloader, 0):
        # for i, batch in enumerate(sampler, 0):
            print(i, end="\r")
            # get the inputs; data is a list of [inputs, labels]
            (channels, features), actions, rewards = batch

            # zero the parameter gradients
            optimizer.zero_grad()
    
            # forward + backward + optimize
            outputs = net(channels, features)
            # loss = criterion(outputs, labels)
            log_logits = F.log_softmax(outputs, dim=-1)
            log_probs = log_logits.gather(1, actions[:,np.newaxis])
            
            loss = torch.mean(-log_probs * rewards)
            loss.backward()
            optimizer.step()
    
            # print statistics
            running_loss += loss.item()
            running += 1
            if i % 200 == 199:    # print every 2000 mini-batches
                print(f'[{run + 1}, {epoch + 1}, {i + 1:5d}] loss: {running_loss / running:.3f}')
                torch.save(net.state_dict(), weight_path)
                running_loss = 0.0
                running = 0
                #print(log_probs)
                #print(rewards)
                print(i*batch_size/N)
            if i*batch_size/N > 0.25:
                break
        if running > 0:
            print(f'[{run + 1}, {epoch + 1}] loss: {running_loss / running:.3f}')
            
    torch.save(net.state_dict(), weight_path)

print('Finished Training') 

100%|██████████| 50/50 [01:14<00:00,  1.49s/it]


rule_based_agent_0    603
rule_based_agent_1    586
rule_based_agent_2    599
rule_based_agent_3    588
Loaded 822204 actions.
822204 actions in the training set:
UP: 19.61% (161249)
DOWN: 25.23% (207447)
LEFT: 19.47% (160121)
RIGHT: 25.35% (208467)
BOMB: 8.63% (70983)
WAIT: 1.70% (13937)
[1, 1,   200] loss: 91.364
0.0619602920929599
[1, 1,   400] loss: 81.923
0.12423194243764321
[1, 1,   600] loss: 77.439
0.18650359278232653
[1, 1,   800] loss: 71.511
0.24877524312700985
[1, 1] loss: 72.563


100%|██████████| 50/50 [01:24<00:00,  1.70s/it]


rule_based_agent_0    578
rule_based_agent_1    541
rule_based_agent_2    606
rule_based_agent_3    654
Loaded 847121 actions.
847121 actions in the training set:
UP: 19.25% (163068)
DOWN: 24.86% (210609)
LEFT: 20.00% (169398)
RIGHT: 25.75% (218125)
BOMB: 8.55% (72397)
WAIT: 1.60% (13524)
[2, 1,   200] loss: 92.761
0.060137807940069954
[2, 1,   400] loss: 85.659
0.12057781592003976
[2, 1,   600] loss: 79.833
0.18101782390000956
[2, 1,   800] loss: 73.716
0.24145783187997938
[2, 1] loss: 72.354


100%|██████████| 50/50 [01:13<00:00,  1.47s/it]


rule_based_agent_0    642
rule_based_agent_1    557
rule_based_agent_2    620
rule_based_agent_3    573
Loaded 776815 actions.
776815 actions in the training set:
UP: 19.80% (153819)
DOWN: 25.63% (199062)
LEFT: 19.27% (149729)
RIGHT: 25.26% (196186)
BOMB: 8.49% (65931)
WAIT: 1.56% (12088)
[3, 1,   200] loss: 87.313
0.06558060799546868
[3, 1,   400] loss: 75.107
0.13149076678488444
[3, 1,   600] loss: 70.318
0.1974009255743002
[3, 1] loss: 65.128


100%|██████████| 50/50 [01:40<00:00,  2.01s/it]


rule_based_agent_0    576
rule_based_agent_1    663
rule_based_agent_2    548
rule_based_agent_3    596
Loaded 808447 actions.
808447 actions in the training set:
UP: 19.70% (159271)
DOWN: 25.35% (204949)
LEFT: 19.62% (158614)
RIGHT: 25.23% (203992)
BOMB: 8.54% (69059)
WAIT: 1.55% (12562)
[4, 1,   200] loss: 86.266
0.0630146441263311
[4, 1,   400] loss: 78.057
0.1263459447558096
[4, 1,   600] loss: 73.214
0.1896772453852881
[4, 1] loss: 67.609


100%|██████████| 50/50 [01:26<00:00,  1.73s/it]


rule_based_agent_0    611
rule_based_agent_1    636
rule_based_agent_2    612
rule_based_agent_3    626
Loaded 855260 actions.
855260 actions in the training set:
UP: 19.65% (168075)
DOWN: 25.45% (217650)
LEFT: 19.52% (166973)
RIGHT: 25.26% (216041)
BOMB: 8.61% (73598)
WAIT: 1.51% (12923)
[5, 1,   200] loss: 91.453
0.059565512241891355
[5, 1,   400] loss: 81.927
0.11943034866590277
[5, 1,   600] loss: 75.920
0.17929518508991418
[5, 1,   800] loss: 70.973
0.2391600215139256
[5, 1] loss: 65.879


100%|██████████| 50/50 [01:12<00:00,  1.44s/it]


rule_based_agent_0    561
rule_based_agent_1    633
rule_based_agent_2    515
rule_based_agent_3    664
Loaded 803166 actions.
803166 actions in the training set:
UP: 19.76% (158744)
DOWN: 25.76% (206911)
LEFT: 19.46% (156305)
RIGHT: 24.79% (199102)
BOMB: 8.60% (69112)
WAIT: 1.62% (12992)
[6, 1,   200] loss: 86.538
0.06342897981239246
[6, 1,   400] loss: 76.485
0.127176698216807
[6, 1,   600] loss: 69.810
0.1909244166212215
[6, 1] loss: 64.309


100%|██████████| 50/50 [01:29<00:00,  1.79s/it]


rule_based_agent_0    594
rule_based_agent_1    638
rule_based_agent_2    623
rule_based_agent_3    621
Loaded 822476 actions.
822476 actions in the training set:
UP: 19.59% (161156)
DOWN: 25.21% (207337)
LEFT: 19.50% (160423)
RIGHT: 25.30% (208061)
BOMB: 8.54% (70270)
WAIT: 1.85% (15229)
[7, 1,   200] loss: 81.298
0.06193980128295537
[7, 1,   400] loss: 71.906
0.12419085784873966
[7, 1,   600] loss: 66.140
0.18644191441452396
[7, 1,   800] loss: 60.735
0.24869297098030824
[7, 1] loss: 56.948


100%|██████████| 50/50 [01:20<00:00,  1.61s/it]


rule_based_agent_0    587
rule_based_agent_1    554
rule_based_agent_2    631
rule_based_agent_3    622
Loaded 806193 actions.
806193 actions in the training set:
UP: 19.45% (156790)
DOWN: 24.97% (201276)
LEFT: 19.68% (158657)
RIGHT: 25.56% (206035)
BOMB: 8.65% (69774)
WAIT: 1.69% (13661)
[8, 1,   200] loss: 84.761
0.06319082403345104
[8, 1,   400] loss: 74.705
0.12669918989621592
[8, 1,   600] loss: 67.596
0.1902075557589808
[8, 1] loss: 62.399


100%|██████████| 50/50 [01:20<00:00,  1.62s/it]


rule_based_agent_0    585
rule_based_agent_1    620
rule_based_agent_2    542
rule_based_agent_3    646
Loaded 846507 actions.
846507 actions in the training set:
UP: 19.62% (166055)
DOWN: 25.47% (215599)
LEFT: 19.61% (165966)
RIGHT: 25.07% (212243)
BOMB: 8.57% (72534)
WAIT: 1.67% (14110)
[9, 1,   200] loss: 87.836
0.06018142791494931
[9, 1,   400] loss: 76.854
0.12066527506565214
[9, 1,   600] loss: 70.995
0.18114912221635499
[9, 1,   800] loss: 65.990
0.2416329693670578
[9, 1] loss: 63.369


100%|██████████| 50/50 [01:23<00:00,  1.67s/it]


rule_based_agent_0    550
rule_based_agent_1    601
rule_based_agent_2    640
rule_based_agent_3    678
Loaded 821429 actions.
821429 actions in the training set:
UP: 19.63% (161236)
DOWN: 25.13% (206458)
LEFT: 19.76% (162351)
RIGHT: 25.36% (208325)
BOMB: 8.54% (70137)
WAIT: 1.57% (12922)
[10, 1,   200] loss: 83.702
0.06201875025108682
[10, 1,   400] loss: 75.458
0.12434915251348565
[10, 1,   600] loss: 68.025
0.18667955477588447
[10, 1,   800] loss: 61.395
0.2490099570382833
[10, 1] loss: 57.100
Finished Training


In [24]:
torch.save(net.state_dict(), "models/model_weights_fine")

In [9]:
data_path = f"../data/"
weight_path = "models/model_weights.pth"
num_epoch_per_run = 1

# optimizer = optim.SGD(net.parameters(), momentum=0.9, lr=0.001)
net = AgentNet()
net.load_state_dict(torch.load("models/model_weights_fine"))


<All keys matched successfully>

In [10]:
optimizer = optim.AdamW(net.parameters(), lr=0.00001)
batch_size = 256


In [11]:
for run in range(20):
    # os.system("rm -rf ../rule_based_agent/data")
    os.system("rm -rf ../data")
    os.system("cd ../..; python main.py play --no-gui --agents lfov_agent rule_based_agent rule_based_agent rule_based_agent --train 4 --n-rounds 60 --scenario loot-crate --save-winner-game True")

    trainset = BombermanDataset(data_path)
    N = len(trainset)
    print(f"{N} actions in the training set:")
    for i in range(6):
        n = np.sum([t == i for t in trainset.actions])
        print(f"{ACTION_MAP_INV[i]}: {n/N*100:.2f}% ({n})")
    # sampler = BatchSampler(WeightedRandomSampler(trainset.weights, len(trainset), replacement=True,), batch_size, False)
    # trainloader = torch.utils.data.DataLoader(trainset, batch_sampler=sampler)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

    for epoch in range(num_epoch_per_run):
        running_loss = 0.0
        running = 0
        for i, batch in enumerate(trainloader, 0):
        # for i, batch in enumerate(sampler, 0):
            print(i, end="\r")
            # get the inputs; data is a list of [inputs, labels]
            (channels, features), actions, rewards = batch

            # zero the parameter gradients
            optimizer.zero_grad()
    
            # forward + backward + optimize
            outputs = net(channels, features)
            # loss = criterion(outputs, labels)
            log_logits = F.log_softmax(outputs, dim=-1)
            log_probs = log_logits.gather(1, actions[:,np.newaxis])
            
            loss = torch.mean(-log_probs * rewards)
            loss.backward()
            optimizer.step()
    
            # print statistics
            running_loss += loss.item()
            running += 1
            if i % 200 == 199:    # print every 2000 mini-batches
                print(f'[{run + 1}, {epoch + 1}, {i + 1:5d}] loss: {running_loss / running:.3f}')
                torch.save(net.state_dict(), weight_path)
                running_loss = 0.0
                running = 0
                #print(log_probs)
                #print(rewards)
                print(i*batch_size/N)
            if i*batch_size/N > 0.5:
                break
        if running > 0:
            print(f'[{run + 1}, {epoch + 1}] loss: {running_loss / running:.3f}')
            
    torch.save(net.state_dict(), weight_path)

print('Finished Training') 

100%|██████████| 60/60 [01:28<00:00,  1.47s/it]


lfov_agent    690
rule_based_agent_0    767
rule_based_agent_1    785
rule_based_agent_2    686
Loaded 296379 actions.
296379 actions in the training set:
UP: 19.33% (57281)
DOWN: 25.83% (76543)
LEFT: 19.53% (57892)
RIGHT: 25.42% (75331)
BOMB: 8.62% (25541)
WAIT: 1.28% (3791)
[1, 1,   200] loss: 93.594
0.17188802175592738
[1, 1,   400] loss: 88.953
0.34463980241515085
[1, 1] loss: 84.614


100%|██████████| 60/60 [01:26<00:00,  1.44s/it]


lfov_agent    648
rule_based_agent_0    698
rule_based_agent_1    720
rule_based_agent_2    861
Loaded 300310 actions.
300310 actions in the training set:
UP: 19.39% (58231)
DOWN: 25.46% (76446)
LEFT: 19.64% (58991)
RIGHT: 25.81% (77515)
BOMB: 8.38% (25161)
WAIT: 1.32% (3966)
[2, 1,   200] loss: 83.185
0.16963804069128569
[2, 1,   400] loss: 82.310
0.3401285338483567
[2, 1] loss: 80.394


100%|██████████| 60/60 [01:29<00:00,  1.48s/it]


lfov_agent    692
rule_based_agent_0    672
rule_based_agent_1    773
rule_based_agent_2    725
Loaded 299652 actions.
299652 actions in the training set:
UP: 19.78% (59261)
DOWN: 26.03% (77998)
LEFT: 19.31% (57863)
RIGHT: 25.19% (75488)
BOMB: 8.37% (25083)
WAIT: 1.32% (3959)
[3, 1,   200] loss: 84.440
0.1700105455661901
[3, 1,   400] loss: 84.410
0.3408754154819591
[3, 1] loss: 82.549


100%|██████████| 60/60 [01:29<00:00,  1.48s/it]


lfov_agent    681
rule_based_agent_0    761
rule_based_agent_1    723
rule_based_agent_2    739
Loaded 301606 actions.
301606 actions in the training set:
UP: 19.20% (57910)
DOWN: 26.08% (78645)
LEFT: 19.48% (58754)
RIGHT: 25.50% (76914)
BOMB: 8.40% (25349)
WAIT: 1.34% (4034)
[4, 1,   200] loss: 79.566
0.1689091065827603
[4, 1,   400] loss: 79.261
0.33866700264583594
[4, 1] loss: 78.342


100%|██████████| 60/60 [01:26<00:00,  1.44s/it]


lfov_agent    610
rule_based_agent_0    738
rule_based_agent_1    769
rule_based_agent_2    788
Loaded 292092 actions.
292092 actions in the training set:
UP: 19.22% (56143)
DOWN: 25.78% (75299)
LEFT: 19.53% (57031)
RIGHT: 25.73% (75161)
BOMB: 8.42% (24591)
WAIT: 1.32% (3867)
[5, 1,   200] loss: 80.779
0.17441080207605822
[5, 1,   400] loss: 78.057
0.3496980403434534
[5, 1] loss: 78.392


100%|██████████| 60/60 [01:26<00:00,  1.44s/it]


lfov_agent    682
rule_based_agent_0    784
rule_based_agent_1    673
rule_based_agent_2    690
Loaded 293564 actions.
293564 actions in the training set:
UP: 19.55% (57381)
DOWN: 26.00% (76336)
LEFT: 19.16% (56249)
RIGHT: 25.91% (76075)
BOMB: 8.13% (23860)
WAIT: 1.25% (3663)
[6, 1,   200] loss: 82.561
0.17353626466460464
[6, 1,   400] loss: 80.592
0.3479445708601872
[6, 1] loss: 78.853


100%|██████████| 60/60 [01:25<00:00,  1.43s/it]


lfov_agent    599
rule_based_agent_0    745
rule_based_agent_1    750
rule_based_agent_2    763
Loaded 294842 actions.
294842 actions in the training set:
UP: 19.40% (57195)
DOWN: 25.44% (75017)
LEFT: 19.49% (57450)
RIGHT: 25.82% (76125)
BOMB: 8.57% (25274)
WAIT: 1.28% (3781)
[7, 1,   200] loss: 80.875
0.17278406739881022
[7, 1,   400] loss: 78.554
0.3464363964428406
[7, 1] loss: 76.612


100%|██████████| 60/60 [01:25<00:00,  1.43s/it]


lfov_agent    620
rule_based_agent_0    764
rule_based_agent_1    761
rule_based_agent_2    723
Loaded 287542 actions.
287542 actions in the training set:
UP: 19.53% (56157)
DOWN: 25.68% (73838)
LEFT: 19.43% (55881)
RIGHT: 25.69% (73877)
BOMB: 8.42% (24199)
WAIT: 1.25% (3590)
[8, 1,   200] loss: 81.135
0.1771706394196326
[8, 1,   400] loss: 79.417
0.35523158355996687
[8, 1] loss: 78.882


100%|██████████| 60/60 [01:25<00:00,  1.43s/it]


lfov_agent    701
rule_based_agent_0    755
rule_based_agent_1    704
rule_based_agent_2    762
Loaded 295418 actions.
295418 actions in the training set:
UP: 19.54% (57722)
DOWN: 26.22% (77453)
LEFT: 19.22% (56790)
RIGHT: 25.28% (74685)
BOMB: 8.35% (24662)
WAIT: 1.39% (4106)
[9, 1,   200] loss: 80.564
0.1724471765430678
[9, 1,   400] loss: 78.910
0.34576092181248264
[9, 1] loss: 77.739


100%|██████████| 60/60 [01:29<00:00,  1.50s/it]


lfov_agent    594
rule_based_agent_0    686
rule_based_agent_1    695
rule_based_agent_2    784
Loaded 295330 actions.
295330 actions in the training set:
UP: 19.59% (57856)
DOWN: 26.09% (77046)
LEFT: 19.32% (57069)
RIGHT: 25.27% (74630)
BOMB: 8.41% (24828)
WAIT: 1.32% (3901)
[10, 1,   200] loss: 79.371
0.17249856093183896
[10, 1,   400] loss: 79.762
0.3458639488030339
[10, 1] loss: 78.461


100%|██████████| 60/60 [01:29<00:00,  1.49s/it]


lfov_agent    682
rule_based_agent_0    774
rule_based_agent_1    699
rule_based_agent_2    743
Loaded 293655 actions.
293655 actions in the training set:
UP: 19.18% (56312)
DOWN: 26.15% (76790)
LEFT: 19.42% (57034)
RIGHT: 25.73% (75546)
BOMB: 8.20% (24079)
WAIT: 1.33% (3894)
[11, 1,   200] loss: 78.615
0.17348248795355092
[11, 1,   400] loss: 77.862
0.3478367472033509
[11, 1] loss: 75.530


100%|██████████| 60/60 [01:26<00:00,  1.44s/it]


lfov_agent    639
rule_based_agent_0    802
rule_based_agent_1    667
rule_based_agent_2    750
Loaded 284726 actions.
284726 actions in the training set:
UP: 19.17% (54578)
DOWN: 26.11% (74338)
LEFT: 19.35% (55084)
RIGHT: 25.68% (73121)
BOMB: 8.41% (23941)
WAIT: 1.29% (3664)
[12, 1,   200] loss: 76.943
0.1789228942913538
[12, 1,   400] loss: 75.507
0.35874489860427217
[12, 1] loss: 75.901


100%|██████████| 60/60 [01:30<00:00,  1.51s/it]


lfov_agent    706
rule_based_agent_0    761
rule_based_agent_1    764
rule_based_agent_2    689
Loaded 300739 actions.
300739 actions in the training set:
UP: 19.31% (58075)
DOWN: 25.58% (76931)
LEFT: 19.64% (59060)
RIGHT: 25.80% (77602)
BOMB: 8.36% (25148)
WAIT: 1.30% (3923)
[13, 1,   200] loss: 79.963
0.1693960543860291
[13, 1,   400] loss: 79.615
0.33964334522625933
[13, 1] loss: 77.691


100%|██████████| 60/60 [01:27<00:00,  1.46s/it]


lfov_agent    663
rule_based_agent_0    707
rule_based_agent_1    791
rule_based_agent_2    737
Loaded 293376 actions.
293376 actions in the training set:
UP: 19.26% (56514)
DOWN: 26.24% (76993)
LEFT: 19.28% (56551)
RIGHT: 25.59% (75072)
BOMB: 8.37% (24554)
WAIT: 1.26% (3692)
[14, 1,   200] loss: 78.516
0.1736474694589878
[14, 1,   400] loss: 79.554
0.3481675392670157
[14, 1] loss: 78.189


100%|██████████| 60/60 [01:21<00:00,  1.37s/it]


lfov_agent    660
rule_based_agent_0    731
rule_based_agent_1    800
rule_based_agent_2    718
Loaded 283413 actions.
283413 actions in the training set:
UP: 19.23% (54492)
DOWN: 25.57% (72465)
LEFT: 19.64% (55670)
RIGHT: 25.74% (72962)
BOMB: 8.53% (24164)
WAIT: 1.29% (3660)
[15, 1,   200] loss: 77.611
0.1797518109613885
[15, 1,   400] loss: 77.159
0.3604068973547438
[15, 1] loss: 75.116


100%|██████████| 60/60 [01:26<00:00,  1.45s/it]


lfov_agent    650
rule_based_agent_0    796
rule_based_agent_1    742
rule_based_agent_2    676
Loaded 295542 actions.
295542 actions in the training set:
UP: 19.71% (58246)
DOWN: 25.85% (76405)
LEFT: 19.41% (57363)
RIGHT: 25.16% (74372)
BOMB: 8.58% (25349)
WAIT: 1.29% (3807)
[16, 1,   200] loss: 77.549
0.1723748232061771
[16, 1,   400] loss: 76.359
0.34561585155409386
[16, 1] loss: 76.783


100%|██████████| 60/60 [01:28<00:00,  1.47s/it]


lfov_agent    610
rule_based_agent_0    711
rule_based_agent_1    762
rule_based_agent_2    776
Loaded 295737 actions.
295737 actions in the training set:
UP: 19.40% (57376)
DOWN: 25.79% (76281)
LEFT: 19.65% (58122)
RIGHT: 25.57% (75617)
BOMB: 8.38% (24769)
WAIT: 1.21% (3572)
[17, 1,   200] loss: 77.975
0.1722611644806027
[17, 1,   400] loss: 78.863
0.3453879629535702
[17, 1] loss: 77.602


100%|██████████| 60/60 [01:28<00:00,  1.47s/it]


lfov_agent    668
rule_based_agent_0    785
rule_based_agent_1    688
rule_based_agent_2    702
Loaded 284315 actions.
284315 actions in the training set:
UP: 19.63% (55813)
DOWN: 25.98% (73853)
LEFT: 19.15% (54452)
RIGHT: 25.51% (72520)
BOMB: 8.37% (23789)
WAIT: 1.37% (3888)
[18, 1,   200] loss: 78.883
0.17918154159998592
[18, 1,   400] loss: 76.349
0.3592634929567557
[18, 1] loss: 75.507


100%|██████████| 60/60 [01:35<00:00,  1.60s/it]


lfov_agent    670
rule_based_agent_0    667
rule_based_agent_1    726
rule_based_agent_2    778
Loaded 292050 actions.
292050 actions in the training set:
UP: 19.28% (56307)
DOWN: 26.11% (76243)
LEFT: 19.39% (56625)
RIGHT: 25.64% (74882)
BOMB: 8.35% (24377)
WAIT: 1.24% (3616)
[19, 1,   200] loss: 78.297
0.17443588426639275
[19, 1,   400] loss: 77.295
0.3497483307652799
[19, 1] loss: 77.027


100%|██████████| 60/60 [01:26<00:00,  1.45s/it]


lfov_agent    548
rule_based_agent_0    750
rule_based_agent_1    719
rule_based_agent_2    818
Loaded 291567 actions.
291567 actions in the training set:
UP: 19.34% (56386)
DOWN: 25.23% (73551)
LEFT: 19.92% (58072)
RIGHT: 25.85% (75384)
BOMB: 8.43% (24581)
WAIT: 1.23% (3593)
[20, 1,   200] loss: 76.469
0.17472484883405873
[20, 1,   400] loss: 75.179
0.3503277119838665
[20, 1] loss: 74.905
Finished Training


In [12]:
torch.save(net.state_dict(), "models/model_weights_win.pth")


#weight_path_ref = "models/model_weights_pretrained_ref.pth"
#torch.save(net.state_dict(),weight_path_ref)

In [72]:
batch_size = 64
data_path = f"../data/"
weight_path = "models/model_weights.pth"
pretrained_weight_path = "models/model_weights_pretrained.pth"
num_epoch_per_run = 1

In [73]:
# pretrained_net = AgentNet()
# pretrained_net.load_state_dict(torch.load(pretrained_weight_path))
# torch.save(pretrained_net.state_dict(), weigth_path)

In [74]:
net = AgentNet()
pretrained_net = AgentNet()
pretrained_net.load_state_dict(torch.load(pretrained_weight_path))
net.cnn.load_state_dict(pretrained_net.cnn.state_dict())
torch.save(net.state_dict(), weight_path)
optimizer = optim.AdamW(net.mlp.parameters(), lr=0.001)

In [75]:
optimizer = optim.AdamW(net.mlp.parameters(), lr=0.001)

EVENT_HORIZON = {
    MOVED_LEFT: 1,
    MOVED_RIGHT: 1,
    MOVED_UP: 1,
    MOVED_DOWN: 1,
    WAITED: 1,
    INVALID_ACTION: 1,
    BOMB_DROPPED: 1,
    BOMB_EXPLODED: 1,
    CRATE_DESTROYED: 5,
    COIN_FOUND: 3,
    COIN_COLLECTED: 8,
    KILLED_OPPONENT: 8,
    KILLED_SELF: 4,
    GOT_KILLED: 3,
    OPPONENT_ELIMINATED: 7,
    SURVIVED_ROUND: 5
}


REWARDS = {
    MOVED_LEFT: -0.5,  #-0.2
    MOVED_RIGHT: -0.5, #-0.2
    MOVED_UP: -0.5,  #-0.2
    MOVED_DOWN: -0.5,  #-0.2    
    WAITED: 4,
    INVALID_ACTION: -2,
    BOMB_DROPPED: 0,
    BOMB_EXPLODED: 1,
    CRATE_DESTROYED: 2,
    COIN_FOUND: 0,
    COIN_COLLECTED: 6,
    KILLED_OPPONENT: 1,
    KILLED_SELF: -10,
    GOT_KILLED: -1,
    OPPONENT_ELIMINATED: 2,
    SURVIVED_ROUND: 9
}

In [76]:
action_probs = np.zeros((6,500))


In [79]:
for run in range(10):
    # os.system("rm -rf ../rule_based_agent/data")
    os.system("rm -rf ../data")
    #os.system("cd ../..; python main.py play --no-gui --agents basic_agent basic_agent basic_agent basic_agent --train 4 --n-rounds 30 --scenario loot-crate")
    os.system("cd ../..; python main.py play --no-gui --agents basic_agent basic_agent basic_agent basic_agent  --train 4 --n-rounds 30 --scenario loot-crate")

    # os.system("cd ../..; python main.py play --no-gui --agents rule_based_agent rule_based_agent rule_based_agent rule_based_agent --train 4 --n-rounds 10 --scenario classic")

    trainset = BombermanDataset(data_path)
    N = len(trainset)
    print(f"{N} actions in the training set:")
    for i in range(6):
        n = np.sum([t == i for t in trainset.actions])
        print(f"{ACTION_MAP_INV[i]}: {n/N*100:.2f}% ({n})")
        action_probs[i,run] = n/N
    
    
    # sampler = BatchSampler(WeightedRandomSampler(trainset.weights, len(trainset), replacement=True,), batch_size, False)
    # trainloader = torch.utils.data.DataLoader(trainset, batch_sampler=sampler)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

    for epoch in range(num_epoch_per_run):
        running_loss = 0.0
        running = 0
        for i, batch in enumerate(trainloader, 0):
        # for i, batch in enumerate(sampler, 0):
            print(i, end="\r")
            # get the inputs; data is a list of [inputs, labels]
            (channels, features), actions, rewards = batch

            # zero the parameter gradients
            optimizer.zero_grad()
    
            # forward + backward + optimize
            outputs = net(channels, features)
            # loss = criterion(outputs, labels)
            log_logits = F.log_softmax(outputs, dim=-1)
            log_probs = log_logits.gather(1, actions[:,np.newaxis])
            """
            if i == 0:
                print(actions, rewards, log_probs)
                print(-log_probs * rewards)
            """
            loss = torch.mean(-log_probs * rewards)
            loss.backward()
            optimizer.step()
    
            # print statistics
            running_loss += loss.item()
            running += 1
            if i % 200 == 199:    # print every 2000 mini-batches
                print(f'[{run + 1}, {epoch + 1}, {i + 1:5d}] loss: {running_loss / running:.3f}')
                torch.save(net.state_dict(), weight_path)
                running_loss = 0.0
                running = 0
        if running > 0:
            print(f'[{run + 1}, {epoch + 1}] loss: {running_loss / running:.3f}')
    torch.save(net.state_dict(), weight_path)

print('Finished Training')

100%|██████████| 30/30 [00:58<00:00,  1.94s/it]


Loaded 298683 actions.
298683 actions in the training set:
UP: 15.22% (45452)
DOWN: 34.32% (102520)
LEFT: 16.40% (48982)
RIGHT: 26.48% (79098)
BOMB: 6.45% (19274)
WAIT: 1.12% (3357)
[1, 1,   200] loss: 20.545
[1, 1,   400] loss: 17.363
[1, 1,   600] loss: 15.881
[1, 1,   800] loss: 15.734
[1, 1,  1000] loss: 14.239
[1, 1,  1200] loss: 13.750
[1, 1,  1400] loss: 13.503
[1, 1,  1600] loss: 13.106
[1, 1,  1800] loss: 12.162
[1, 1,  2000] loss: 12.178
[1, 1,  2200] loss: 11.820
[1, 1,  2400] loss: 11.852
[1, 1,  2600] loss: 11.281
[1, 1,  2800] loss: 11.148
[1, 1,  3000] loss: 10.654
[1, 1,  3200] loss: 11.071
[1, 1,  3400] loss: 10.446
[1, 1,  3600] loss: 10.402
[1, 1,  3800] loss: 9.904
[1, 1,  4000] loss: 9.635
[1, 1,  4200] loss: 9.458
[1, 1,  4400] loss: 9.511
[1, 1,  4600] loss: 9.639
[1, 1] loss: 9.649


100%|██████████| 30/30 [00:58<00:00,  1.94s/it]


Loaded 275019 actions.
275019 actions in the training set:
UP: 14.84% (40818)
DOWN: 35.45% (97483)
LEFT: 16.15% (44411)
RIGHT: 26.31% (72357)
BOMB: 6.03% (16597)
WAIT: 1.22% (3353)
[2, 1,   200] loss: 15.642
[2, 1,   400] loss: 12.522
[2, 1,   600] loss: 10.814
[2, 1,   800] loss: 10.197
[2, 1,  1000] loss: 9.779
[2, 1,  1200] loss: 8.961
[2, 1,  1400] loss: 8.457
[2, 1,  1600] loss: 8.409
[2, 1,  1800] loss: 8.007
[2, 1,  2000] loss: 7.834
[2, 1,  2200] loss: 7.641
[2, 1,  2400] loss: 7.661
[2, 1,  2600] loss: 7.253
[2, 1,  2800] loss: 6.938
[2, 1,  3000] loss: 6.859
[2, 1,  3200] loss: 7.091
[2, 1,  3400] loss: 7.851
[2, 1,  3600] loss: 6.596
[2, 1,  3800] loss: 7.056
[2, 1,  4000] loss: 6.786
[2, 1,  4200] loss: 6.336
[2, 1] loss: 5.841


100%|██████████| 30/30 [01:02<00:00,  2.07s/it]


Loaded 296943 actions.
296943 actions in the training set:
UP: 14.74% (43777)
DOWN: 34.50% (102442)
LEFT: 15.75% (46779)
RIGHT: 27.90% (82858)
BOMB: 5.79% (17200)
WAIT: 1.31% (3887)
[3, 1,   200] loss: 13.117
[3, 1,   400] loss: 10.529
[3, 1,   600] loss: 8.514
[3, 1,   800] loss: 7.821
[3, 1,  1000] loss: 7.589
[3, 1,  1200] loss: 7.152
[3, 1,  1400] loss: 6.555
[3, 1,  1600] loss: 6.769
[3, 1,  1800] loss: 6.426
[3, 1,  2000] loss: 6.444
[3, 1,  2200] loss: 6.472
[3, 1,  2400] loss: 6.159
[3, 1,  2600] loss: 5.700
[3, 1,  2800] loss: 6.195
[3, 1,  3000] loss: 6.146
[3, 1,  3200] loss: 6.141
[3, 1,  3400] loss: 5.716
[3, 1,  3600] loss: 5.535
[3, 1,  3800] loss: 5.579
[3, 1,  4000] loss: 5.579
[3, 1,  4200] loss: 6.154
[3, 1,  4400] loss: 5.372
[3, 1,  4600] loss: 5.222
[3, 1] loss: 6.239


100%|██████████| 30/30 [00:53<00:00,  1.80s/it]


Loaded 288673 actions.
288673 actions in the training set:
UP: 14.39% (41553)
DOWN: 41.43% (119600)
LEFT: 15.19% (43854)
RIGHT: 22.47% (64863)
BOMB: 5.37% (15504)
WAIT: 1.14% (3299)
[4, 1,   200] loss: 10.355
[4, 1,   400] loss: 8.255
[4, 1,   600] loss: 7.762
[4, 1,   800] loss: 6.706
[4, 1,  1000] loss: 6.530
[4, 1,  1200] loss: 6.148
[4, 1,  1400] loss: 5.509
[4, 1,  1600] loss: 5.585
[4, 1,  1800] loss: 5.632
[4, 1,  2000] loss: 5.073
[4, 1,  2200] loss: 4.864
[4, 1,  2400] loss: 5.536
[4, 1,  2600] loss: 5.197
[4, 1,  2800] loss: 4.650
[4, 1,  3000] loss: 5.106
[4, 1,  3200] loss: 5.531
[4, 1,  3400] loss: 4.926
[4, 1,  3600] loss: 4.624
[4, 1,  3800] loss: 4.378
[4, 1,  4000] loss: 4.987
[4, 1,  4200] loss: 4.764
[4, 1,  4400] loss: 4.636
[4, 1] loss: 4.888


100%|██████████| 30/30 [00:52<00:00,  1.76s/it]


Loaded 278948 actions.
278948 actions in the training set:
UP: 13.04% (36378)
DOWN: 39.66% (110624)
LEFT: 14.77% (41196)
RIGHT: 26.79% (74744)
BOMB: 4.77% (13313)
WAIT: 0.97% (2693)
[5, 1,   200] loss: 11.025
[5, 1,   400] loss: 8.949
[5, 1,   600] loss: 7.630
[5, 1,   800] loss: 7.161
[5, 1,  1000] loss: 6.631
[5, 1,  1200] loss: 6.287
[5, 1,  1400] loss: 6.397
[5, 1,  1600] loss: 5.681
[5, 1,  1800] loss: 6.000
[5, 1,  2000] loss: 6.066
[5, 1,  2200] loss: 5.989
[5, 1,  2400] loss: 5.975
[5, 1,  2600] loss: 5.652
[5, 1,  2800] loss: 5.766
[5, 1,  3000] loss: 5.277
[5, 1,  3200] loss: 5.607
[5, 1,  3400] loss: 5.522
[5, 1,  3600] loss: 5.495
[5, 1,  3800] loss: 5.514
[5, 1,  4000] loss: 5.363
[5, 1,  4200] loss: 5.473
[5, 1] loss: 5.809


100%|██████████| 30/30 [00:51<00:00,  1.71s/it]


Loaded 261623 actions.
261623 actions in the training set:
UP: 15.45% (40432)
DOWN: 39.18% (102493)
LEFT: 14.66% (38343)
RIGHT: 23.93% (62594)
BOMB: 5.39% (14107)
WAIT: 1.40% (3654)
[6, 1,   200] loss: 10.541
[6, 1,   400] loss: 8.960
[6, 1,   600] loss: 7.116
[6, 1,   800] loss: 6.318
[6, 1,  1000] loss: 6.269
[6, 1,  1200] loss: 5.886
[6, 1,  1400] loss: 6.083
[6, 1,  1600] loss: 5.624
[6, 1,  1800] loss: 5.623
[6, 1,  2000] loss: 5.538
[6, 1,  2200] loss: 5.588
[6, 1,  2400] loss: 5.382
[6, 1,  2600] loss: 5.600
[6, 1,  2800] loss: 5.469
[6, 1,  3000] loss: 5.517
[6, 1,  3200] loss: 5.241
[6, 1,  3400] loss: 5.374
[6, 1,  3600] loss: 5.240
[6, 1,  3800] loss: 5.201
[6, 1,  4000] loss: 5.039
[6, 1] loss: 5.375


100%|██████████| 30/30 [00:55<00:00,  1.83s/it]


Loaded 266770 actions.
266770 actions in the training set:
UP: 15.18% (40502)
DOWN: 41.63% (111067)
LEFT: 13.93% (37161)
RIGHT: 22.58% (60247)
BOMB: 5.43% (14481)
WAIT: 1.24% (3312)
[7, 1,   200] loss: 10.172
[7, 1,   400] loss: 8.596
[7, 1,   600] loss: 7.097
[7, 1,   800] loss: 6.640
[7, 1,  1000] loss: 6.194
[7, 1,  1200] loss: 6.160
[7, 1,  1400] loss: 5.763
[7, 1,  1600] loss: 5.387
[7, 1,  1800] loss: 5.189
[7, 1,  2000] loss: 5.195
[7, 1,  2200] loss: 5.018
[7, 1,  2400] loss: 5.369
[7, 1,  2600] loss: 4.919
[7, 1,  2800] loss: 5.100
[7, 1,  3000] loss: 5.002
[7, 1,  3200] loss: 5.368
[7, 1,  3400] loss: 4.856
[7, 1,  3600] loss: 5.708
[7, 1,  3800] loss: 5.186
[7, 1,  4000] loss: 4.962
[7, 1] loss: 4.603


100%|██████████| 30/30 [00:52<00:00,  1.74s/it]


Loaded 272686 actions.
272686 actions in the training set:
UP: 14.62% (39861)
DOWN: 41.77% (113900)
LEFT: 11.12% (30333)
RIGHT: 26.36% (71867)
BOMB: 4.95% (13500)
WAIT: 1.18% (3225)
[8, 1,   200] loss: 8.618
[8, 1,   400] loss: 7.104
[8, 1,   600] loss: 6.021
[8, 1,   800] loss: 5.576
[8, 1,  1000] loss: 5.072
[8, 1,  1200] loss: 4.652
[8, 1,  1400] loss: 4.711
[8, 1,  1600] loss: 4.287
[8, 1,  1800] loss: 4.299
[8, 1,  2000] loss: 4.500
[8, 1,  2200] loss: 4.388
[8, 1,  2400] loss: 4.180
[8, 1,  2600] loss: 4.070
2624

KeyboardInterrupt: 

In [39]:
a = np.zeros((1,7))
print(a)
print(a[0])
print(a[0,0])


[[0. 0. 0. 0. 0. 0. 0.]]
[0. 0. 0. 0. 0. 0. 0.]
0.0


In [51]:
a = np.array([1,2,3,4,5,6,7])-1
print(a)
print(a[1:7])

[0 1 2 3 4 5 6]
[1 2 3 4 5 6]



    extensions = np.zeros(4)
    if pos[0]+4 <= 16:             #field, box_map, explosion_map, others_map, coin_map
        if field[pos[0]+4, pos[1]] == 1:
            extensions[0] = 1
        if box_map[pos[0]+4, pos[1]] ==1:
            extensions[0] = 2
        if explosion_map[pos[0]+4, pos[1]] ==1:
            extensions[0] = 3
        if others_map[pos[0]+4, pos[1]] == 1:
            extensions[0] = 4
        if coin_map[pos[0]+4, pos[1]] ==1:
            extensions[0] =5

    if pos[0]-4 >= 0:             #field, box_map, explosion_map, others_map, coin_map
        if field[pos[0]-4, pos[1]] == 1:
            extensions[1] = 1
        if box_map[pos[0]-4, pos[1]] ==1:
            extensions[1] = 2
        if explosion_map[pos[0]+4, pos[1]] ==1:
            extensions[1] = 3
        if others_map[pos[0]-4, pos[1]] == 1:
            extensions[1] = 4
        if coin_map[pos[0]-4, pos[1]] ==1:
            extensions[1] =5

    if pos[1]+4 <= 16:             #field, box_map, explosion_map, others_map, coin_map
        if field[pos[0], pos[1]+4] == 1:
            extensions[2] = 1
        if box_map[pos[0], pos[1]+4] ==1:
            extensions[2] = 2
        if explosion_map[pos[0], pos[1]+4] ==1:
            extensions[2] = 3
        if others_map[pos[0], pos[1]+4] == 1:
            extensions[2] = 4
        if coin_map[pos[0], pos[1]+4] ==1:
            extensions[2] = 5

    if pos[1]-4 >= 0:             #field, box_map, explosion_map, others_map, coin_map
        if field[pos[0], pos[1]-4] == 1:
            extensions[3] = 1
        if box_map[pos[0], pos[1]-4] ==1:
            extensions[3] = 2
        if explosion_map[pos[0], pos[1]-4] ==1:
            extensions[3] = 3
        if others_map[pos[0], pos[1]-4] == 1:
            extensions[3] = 4
        if coin_map[pos[0], pos[1]-4] ==1:
            extensions[3] = 5

