In [1]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append(r"../..")


from torch import nn
from torch.nn import functional as F
import torch.optim as optim
import torch
import os
import pickle
import numpy as np
import matplotlib.pyplot as plt

from callbacks import state_to_features, ACTION_MAP, ACTION_MAP_INV
from networks import AgentNet

from events import *

In [2]:
EVENT_HORIZON = {
    MOVED_LEFT: 1,
    MOVED_RIGHT: 1,
    MOVED_UP: 1,
    MOVED_DOWN: 1,
    WAITED: 2,
    INVALID_ACTION: 1,
    BOMB_DROPPED: 2,
    BOMB_EXPLODED: 2,
    CRATE_DESTROYED: 8,
    COIN_FOUND: 3,
    COIN_COLLECTED: 5,
    KILLED_OPPONENT: 8,
    KILLED_SELF: 3,
    GOT_KILLED: 3,
    OPPONENT_ELIMINATED: 0,
    SURVIVED_ROUND: 8
}
"""
REWARDS = {
    MOVED_LEFT: 1,
    MOVED_RIGHT: 1,
    MOVED_UP: 1,
    MOVED_DOWN: 1,
    WAITED: 0,
    INVALID_ACTION: -5,
    BOMB_DROPPED: 1,
    BOMB_EXPLODED: 0,
    CRATE_DESTROYED: 4,
    COIN_FOUND: 0,
    COIN_COLLECTED: 5,
    KILLED_OPPONENT: 5,
    KILLED_SELF: -8,
    GOT_KILLED: -8,
    OPPONENT_ELIMINATED: 0,
    SURVIVED_ROUND: 0
}
"""
REWARDS = {
    MOVED_LEFT: 100,
    MOVED_RIGHT: 100,
    MOVED_UP: 100,
    MOVED_DOWN: 100,
    WAITED: -100,
    INVALID_ACTION: -10,
    BOMB_DROPPED: 0,
    BOMB_EXPLODED: 0,
    CRATE_DESTROYED: 0,
    COIN_FOUND: 0,
    COIN_COLLECTED: 0,
    KILLED_OPPONENT: 0,
    KILLED_SELF: 0,
    GOT_KILLED: 0,
    OPPONENT_ELIMINATED: 0,
    SURVIVED_ROUND: 0
}

# REWARDS = {k: v + 10 for k, v in REWARDS.items()}



In [3]:
# Specify directory with training data
data_path = "../data/"

In [2]:
from torch.utils.data import Dataset


class BombermanDataset(Dataset):
    def __init__(self, states_dir, max_game_step=200, min_reward=1):
        self.features = []
        self.actions = []
        self.cum_rewards = []
        self.total_number = 0
        for states_file in os.listdir(states_dir):

            if not states_file.endswith('.pickle'):
                continue
            with open(os.path.join(states_dir, states_file), "rb") as f:
                data = pickle.load(f)
            self.total_number += len(data['game_state'])
            last_round = -1
            running_horizons = []
            running_rewards = []
            for i in range(len(data['game_state'])-1,-1,-1):
                game_state = data['game_state'][i]               

                    
                if last_round != game_state['round']:
                    last_round = game_state['round']
                    running_horizons = []
                    running_rewards = []
                if game_state['step'] > max_game_step:
                    continue
                    
                # determine if action should be included in training set
                
                horizons = [EVENT_HORIZON[e] for e in data['events'][i]]
                rewards = [REWARDS[e] for e in data['events'][i]]
                cr = np.where(data['events'] == "CRATE_DESTROYED")
                #print(cr)
                #rewards[np.array(cr[-1:], dtype = "int")] == 0
                #for j in cr[1:]:
                #    rewards[j] = 0             

                running_rewards = [r for j,r in enumerate(running_rewards) if running_horizons[j]>1]
                running_horizons = [h-1 for h in running_horizons if h > 1]
                running_rewards.extend(rewards)
                running_horizons.extend(horizons)
                #if "INVALID_ACTION" in data['events'][i]:
                #    continue   
                self.cum_rewards.append(np.sum(running_rewards))

                last_action = None if game_state['step'] == 1 else data['action'][i-1]
                #print("LAST ACTION:")
                #print(last_action)
                
                channels, features, act_map = state_to_features(game_state, last_action=last_action)
                action = ACTION_MAP[act_map[data['action'][i]]]
                self.features.append((
                    torch.tensor(channels[np.newaxis], dtype=torch.float),
                    torch.tensor(features, dtype=torch.float)
                ))
                self.actions.append(action)
                
                    
        self.cum_rewards = torch.tensor(self.cum_rewards, dtype=torch.float)
        self.n = np.array([np.sum([t == i for t in self.actions]) for i in range(6)])
        self.weights = np.array([1/self.n[t] for t in self.actions])
        self.proportion_selected = len(self.actions)/self.total_number
        print(f"Loaded {len(self.actions)} actions.")

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return self.features[idx], self.actions[idx], self.cum_rewards[idx]

In [21]:
batch_size = 128
data_path = f"../data/"
weight_path = "models/model_weights_pretrained.pth"
num_epoch_per_run = 1

# optimizer = optim.SGD(net.parameters(), momentum=0.9, lr=0.001)
net = AgentNet()
torch.save(net.state_dict(), weight_path)
optimizer = optim.AdamW(net.parameters(), lr=0.001)

In [22]:
optimizer = optim.AdamW(net.parameters(), lr=0.001)
batch_size = 128


In [3]:
EVENT_HORIZON = {
    MOVED_LEFT: 1,
    MOVED_RIGHT: 1,
    MOVED_UP: 1,
    MOVED_DOWN: 1,
    WAITED: 1,
    INVALID_ACTION: 1,
    BOMB_DROPPED: 1,
    BOMB_EXPLODED: 1,
    CRATE_DESTROYED: 6,
    COIN_FOUND: 1,
    COIN_COLLECTED: 5,
    KILLED_OPPONENT: 6,
    KILLED_SELF: 5,
    GOT_KILLED: 6,
    OPPONENT_ELIMINATED: 4,
    SURVIVED_ROUND: 100
}


REWARDS = {
    MOVED_LEFT: 5, #1
    MOVED_RIGHT: 5, #1
    MOVED_UP: 5,   #1
    MOVED_DOWN: 5,  #1
    WAITED: 5,  #1
    INVALID_ACTION: -120, #-7
    BOMB_DROPPED: 5,  #1
    BOMB_EXPLODED: 0,
    CRATE_DESTROYED: 120,
    COIN_FOUND: 0,
    COIN_COLLECTED: 120,
    KILLED_OPPONENT: 60,
    KILLED_SELF: -120, #-12
    GOT_KILLED: -120,
    OPPONENT_ELIMINATED: 5,
    SURVIVED_ROUND: 200 #20
}

In [24]:
for run in range(10):
    # os.system("rm -rf ../rule_based_agent/data")
    os.system("rm -rf ../data")
    os.system("cd ../..; python main.py play --no-gui --agents rule_based_agent rule_based_agent rule_based_agent rule_based_agent --train 4 --n-rounds 40 --scenario loot-crate")

    trainset = BombermanDataset(data_path)
    N = len(trainset)
    print(f"{N} actions in the training set:")
    for i in range(6):
        n = np.sum([t == i for t in trainset.actions])
        print(f"{ACTION_MAP_INV[i]}: {n/N*100:.2f}% ({n})")
    # sampler = BatchSampler(WeightedRandomSampler(trainset.weights, len(trainset), replacement=True,), batch_size, False)
    # trainloader = torch.utils.data.DataLoader(trainset, batch_sampler=sampler)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

    for epoch in range(num_epoch_per_run):
        running_loss = 0.0
        running = 0
        for i, batch in enumerate(trainloader, 0):
        # for i, batch in enumerate(sampler, 0):
            print(i, end="\r")
            # get the inputs; data is a list of [inputs, labels]
            (channels, features), actions, rewards = batch

            # zero the parameter gradients
            optimizer.zero_grad()
    
            # forward + backward + optimize
            outputs = net(channels, features)
            # loss = criterion(outputs, labels)
            log_logits = F.log_softmax(outputs, dim=-1)
            log_probs = log_logits.gather(1, actions[:,np.newaxis])
            
            loss = torch.mean(-log_probs * rewards)
            loss.backward()
            optimizer.step()
    
            # print statistics
            running_loss += loss.item()
            running += 1
            if i % 200 == 199:    # print every 2000 mini-batches
                print(f'[{run + 1}, {epoch + 1}, {i + 1:5d}] loss: {running_loss / running:.3f}')
                torch.save(net.state_dict(), weight_path)
                running_loss = 0.0
                running = 0
                #print(log_probs)
                #print(rewards)
            if i*batch_size/N > 0.15:
                break
        if running > 0:
            print(f'[{run + 1}, {epoch + 1}] loss: {running_loss / running:.3f}')
    torch.save(net.state_dict(), weight_path)

print('Finished Training')

100%|██████████| 40/40 [01:04<00:00,  1.62s/it]


rule_based_agent_0    467
rule_based_agent_1    471
rule_based_agent_2    485
rule_based_agent_3    483
Loaded 529282 actions.
529282 actions in the training set:
UP: 19.57% (103561)
DOWN: 25.01% (132377)
LEFT: 19.84% (104996)
RIGHT: 25.32% (134024)
BOMB: 8.66% (45840)
WAIT: 1.60% (8484)
[1, 1,   200] loss: 318.752
[1, 1,   400] loss: 153.964
[1, 1,   600] loss: 138.011
[1, 1] loss: 117.551


100%|██████████| 40/40 [00:59<00:00,  1.48s/it]


rule_based_agent_0    468
rule_based_agent_1    472
rule_based_agent_2    488
rule_based_agent_3    481
Loaded 517242 actions.
517242 actions in the training set:
UP: 19.52% (100969)
DOWN: 25.07% (129665)
LEFT: 19.66% (101676)
RIGHT: 25.42% (131482)
BOMB: 8.68% (44871)
WAIT: 1.66% (8579)
[2, 1,   200] loss: 128.109
[2, 1,   400] loss: 116.567
[2, 1,   600] loss: 107.399
[2, 1] loss: 96.754


100%|██████████| 40/40 [01:03<00:00,  1.58s/it]


rule_based_agent_0    484
rule_based_agent_1    487
rule_based_agent_2    515
rule_based_agent_3    466
Loaded 532165 actions.
532165 actions in the training set:
UP: 19.48% (103672)
DOWN: 24.90% (132490)
LEFT: 19.97% (106288)
RIGHT: 25.49% (135629)
BOMB: 8.64% (45984)
WAIT: 1.52% (8102)
[3, 1,   200] loss: 120.316
[3, 1,   400] loss: 111.276
[3, 1,   600] loss: 103.465
[3, 1] loss: 96.600


100%|██████████| 40/40 [01:02<00:00,  1.55s/it]


rule_based_agent_0    511
rule_based_agent_1    514
rule_based_agent_2    440
rule_based_agent_3    463
Loaded 545976 actions.
545976 actions in the training set:
UP: 19.88% (108550)
DOWN: 25.34% (138351)
LEFT: 19.68% (107432)
RIGHT: 24.87% (135770)
BOMB: 8.58% (46830)
WAIT: 1.66% (9043)
[4, 1,   200] loss: 109.717
[4, 1,   400] loss: 102.905
[4, 1,   600] loss: 96.635
[4, 1] loss: 92.557


100%|██████████| 40/40 [01:29<00:00,  2.23s/it]


rule_based_agent_0    494
rule_based_agent_1    512
rule_based_agent_2    516
rule_based_agent_3    449
Loaded 521097 actions.
521097 actions in the training set:
UP: 19.36% (100892)
DOWN: 25.10% (130794)
LEFT: 19.78% (103064)
RIGHT: 25.50% (132905)
BOMB: 8.56% (44620)
WAIT: 1.69% (8822)
[5, 1,   200] loss: 106.748
[5, 1,   400] loss: 96.791
[5, 1,   600] loss: 91.922
[5, 1] loss: 88.345


100%|██████████| 40/40 [01:24<00:00,  2.12s/it]


rule_based_agent_0    473
rule_based_agent_1    498
rule_based_agent_2    502
rule_based_agent_3    480
Loaded 510504 actions.
510504 actions in the training set:
UP: 19.40% (99063)
DOWN: 24.98% (127509)
LEFT: 19.86% (101374)
RIGHT: 25.64% (130870)
BOMB: 8.56% (43702)
WAIT: 1.56% (7986)
[6, 1,   200] loss: 100.977
[6, 1,   400] loss: 91.007
[6, 1,   600] loss: 85.719


100%|██████████| 40/40 [01:25<00:00,  2.14s/it]


rule_based_agent_0    472
rule_based_agent_1    468
rule_based_agent_2    562
rule_based_agent_3    475
Loaded 524460 actions.
524460 actions in the training set:
UP: 19.40% (101723)
DOWN: 24.98% (131019)
LEFT: 19.75% (103564)
RIGHT: 25.50% (133735)
BOMB: 8.74% (45817)
WAIT: 1.64% (8602)
[7, 1,   200] loss: 102.720
[7, 1,   400] loss: 91.762
[7, 1,   600] loss: 83.406
[7, 1] loss: 81.811


100%|██████████| 40/40 [01:47<00:00,  2.68s/it]


rule_based_agent_0    530
rule_based_agent_1    457
rule_based_agent_2    513
rule_based_agent_3    453
Loaded 524907 actions.
524907 actions in the training set:
UP: 19.25% (101040)
DOWN: 25.37% (133179)
LEFT: 19.78% (103842)
RIGHT: 25.50% (133838)
BOMB: 8.54% (44845)
WAIT: 1.56% (8163)
[8, 1,   200] loss: 96.309
[8, 1,   400] loss: 88.023
[8, 1,   600] loss: 83.030
[8, 1] loss: 75.959


100%|██████████| 40/40 [01:45<00:00,  2.65s/it]


rule_based_agent_0    577
rule_based_agent_1    496
rule_based_agent_2    442
rule_based_agent_3    428
Loaded 520692 actions.
520692 actions in the training set:
UP: 20.03% (104279)
DOWN: 25.72% (133910)
LEFT: 19.19% (99937)
RIGHT: 25.01% (130203)
BOMB: 8.49% (44219)
WAIT: 1.56% (8144)
[9, 1,   200] loss: 100.961
[9, 1,   400] loss: 86.638
[9, 1,   600] loss: 79.860
[9, 1] loss: 69.878


100%|██████████| 40/40 [01:07<00:00,  1.69s/it]


rule_based_agent_0    465
rule_based_agent_1    501
rule_based_agent_2    499
rule_based_agent_3    508
Loaded 506860 actions.
506860 actions in the training set:
UP: 19.83% (100530)
DOWN: 25.10% (127240)
LEFT: 19.36% (98129)
RIGHT: 25.36% (128537)
BOMB: 8.70% (44083)
WAIT: 1.65% (8341)
[10, 1,   200] loss: 96.255
[10, 1,   400] loss: 82.778
[10, 1] loss: 76.884
Finished Training


In [25]:
torch.save(net.state_dict(), "models/model_weights_pre")

In [26]:
optimizer = optim.AdamW(net.parameters(), lr=0.0001)
batch_size = 256 

In [27]:
for run in range(10):
    # os.system("rm -rf ../rule_based_agent/data")
    os.system("rm -rf ../data")
    os.system("cd ../..; python main.py play --no-gui --agents rule_based_agent rule_based_agent rule_based_agent rule_based_agent --train 4 --n-rounds 50 --scenario loot-crate")

    trainset = BombermanDataset(data_path)
    N = len(trainset)
    print(f"{N} actions in the training set:")
    for i in range(6):
        n = np.sum([t == i for t in trainset.actions])
        print(f"{ACTION_MAP_INV[i]}: {n/N*100:.2f}% ({n})")
    # sampler = BatchSampler(WeightedRandomSampler(trainset.weights, len(trainset), replacement=True,), batch_size, False)
    # trainloader = torch.utils.data.DataLoader(trainset, batch_sampler=sampler)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

    for epoch in range(num_epoch_per_run):
        running_loss = 0.0
        running = 0
        for i, batch in enumerate(trainloader, 0):
        # for i, batch in enumerate(sampler, 0):
            print(i, end="\r")
            # get the inputs; data is a list of [inputs, labels]
            (channels, features), actions, rewards = batch

            # zero the parameter gradients
            optimizer.zero_grad()
    
            # forward + backward + optimize
            outputs = net(channels, features)
            # loss = criterion(outputs, labels)
            log_logits = F.log_softmax(outputs, dim=-1)
            log_probs = log_logits.gather(1, actions[:,np.newaxis])
            
            loss = torch.mean(-log_probs * rewards)
            loss.backward()
            optimizer.step()
    
            # print statistics
            running_loss += loss.item()
            running += 1
            if i % 200 == 199:    # print every 2000 mini-batches
                print(f'[{run + 1}, {epoch + 1}, {i + 1:5d}] loss: {running_loss / running:.3f}')
                torch.save(net.state_dict(), weight_path)
                running_loss = 0.0
                running = 0
                #print(log_probs)
                #print(rewards)
            if i*batch_size/N > 0.25:
                break
        if running > 0:
            print(f'[{run + 1}, {epoch + 1}] loss: {running_loss / running:.3f}')
    torch.save(net.state_dict(), weight_path)

print('Finished Training')

100%|██████████| 50/50 [01:15<00:00,  1.51s/it]


rule_based_agent_0    638
rule_based_agent_1    580
rule_based_agent_2    561
rule_based_agent_3    633
Loaded 822416 actions.
822416 actions in the training set:
UP: 19.91% (163734)
DOWN: 25.51% (209824)
LEFT: 19.23% (158179)
RIGHT: 24.89% (204718)
BOMB: 8.64% (71082)
WAIT: 1.81% (14879)
[1, 1,   200] loss: 102.332
[1, 1,   400] loss: 93.699
[1, 1,   600] loss: 90.158
[1, 1,   800] loss: 84.334
[1, 1] loss: 84.447


100%|██████████| 50/50 [01:16<00:00,  1.53s/it]


rule_based_agent_0    645
rule_based_agent_1    602
rule_based_agent_2    603
rule_based_agent_3    579
Loaded 826954 actions.
826954 actions in the training set:
UP: 19.63% (162332)
DOWN: 25.31% (209305)
LEFT: 19.67% (162673)
RIGHT: 25.15% (208009)
BOMB: 8.66% (71576)
WAIT: 1.58% (13059)
[2, 1,   200] loss: 95.647
[2, 1,   400] loss: 88.308
[2, 1,   600] loss: 84.899
[2, 1,   800] loss: 81.897
[2, 1] loss: 76.017


100%|██████████| 50/50 [01:14<00:00,  1.49s/it]


rule_based_agent_0    597
rule_based_agent_1    621
rule_based_agent_2    571
rule_based_agent_3    590
Loaded 820584 actions.
820584 actions in the training set:
UP: 19.63% (161094)
DOWN: 25.33% (207830)
LEFT: 19.33% (158648)
RIGHT: 25.48% (209124)
BOMB: 8.66% (71065)
WAIT: 1.56% (12823)
[3, 1,   200] loss: 95.899
[3, 1,   400] loss: 87.939
[3, 1,   600] loss: 84.908
[3, 1,   800] loss: 79.018
[3, 1] loss: 87.280


100%|██████████| 50/50 [01:17<00:00,  1.54s/it]


rule_based_agent_0    571
rule_based_agent_1    617
rule_based_agent_2    597
rule_based_agent_3    605
Loaded 831811 actions.
831811 actions in the training set:
UP: 19.77% (164475)
DOWN: 25.00% (207934)
LEFT: 19.53% (162474)
RIGHT: 25.45% (211686)
BOMB: 8.64% (71857)
WAIT: 1.61% (13385)
[4, 1,   200] loss: 93.526
[4, 1,   400] loss: 88.204
[4, 1,   600] loss: 82.349
[4, 1,   800] loss: 78.860
[4, 1] loss: 74.927


100%|██████████| 50/50 [01:25<00:00,  1.70s/it]


rule_based_agent_0    576
rule_based_agent_1    592
rule_based_agent_2    584
rule_based_agent_3    645
Loaded 770813 actions.
770813 actions in the training set:
UP: 19.51% (150406)
DOWN: 25.19% (194175)
LEFT: 19.45% (149928)
RIGHT: 25.36% (195470)
BOMB: 8.71% (67151)
WAIT: 1.78% (13683)
[5, 1,   200] loss: 85.886
[5, 1,   400] loss: 78.865
[5, 1,   600] loss: 75.416
[5, 1] loss: 72.070


100%|██████████| 50/50 [02:26<00:00,  2.93s/it]


rule_based_agent_0    617
rule_based_agent_1    563
rule_based_agent_2    572
rule_based_agent_3    634
Loaded 817075 actions.
817075 actions in the training set:
UP: 19.83% (162007)
DOWN: 25.41% (207615)
LEFT: 19.41% (158622)
RIGHT: 24.98% (204074)
BOMB: 8.67% (70805)
WAIT: 1.71% (13952)
[6, 1,   200] loss: 93.683
[6, 1,   400] loss: 87.006
[6, 1,   600] loss: 82.005
[6, 1] loss: 76.662


100%|██████████| 50/50 [02:45<00:00,  3.32s/it]


rule_based_agent_0    636
rule_based_agent_1    586
rule_based_agent_2    527
rule_based_agent_3    680
Loaded 800842 actions.
800842 actions in the training set:
UP: 19.87% (159141)
DOWN: 25.66% (205524)
LEFT: 19.07% (152717)
RIGHT: 25.04% (200542)
BOMB: 8.60% (68852)
WAIT: 1.76% (14066)
[7, 1,   200] loss: 86.340
[7, 1,   400] loss: 80.279
[7, 1,   600] loss: 75.939
[7, 1] loss: 71.134


100%|██████████| 50/50 [02:58<00:00,  3.57s/it]


rule_based_agent_0    641
rule_based_agent_1    632
rule_based_agent_2    584
rule_based_agent_3    613
Loaded 825463 actions.
825463 actions in the training set:
UP: 20.06% (165581)
DOWN: 25.75% (212520)
LEFT: 19.14% (157967)
RIGHT: 24.73% (204165)
BOMB: 8.70% (71808)
WAIT: 1.63% (13422)
[8, 1,   200] loss: 89.819
[8, 1,   400] loss: 84.729
[8, 1,   600] loss: 78.933
[8, 1,   800] loss: 74.801
[8, 1] loss: 73.780


100%|██████████| 50/50 [03:00<00:00,  3.61s/it]


rule_based_agent_0    618
rule_based_agent_1    606
rule_based_agent_2    581
rule_based_agent_3    597
Loaded 811923 actions.
811923 actions in the training set:
UP: 19.70% (159965)
DOWN: 25.45% (206639)
LEFT: 19.38% (157357)
RIGHT: 25.13% (204045)
BOMB: 8.65% (70239)
WAIT: 1.68% (13678)
[9, 1,   200] loss: 88.642
[9, 1,   400] loss: 81.953
[9, 1,   600] loss: 77.387
[9, 1] loss: 73.486


100%|██████████| 50/50 [02:34<00:00,  3.09s/it]


rule_based_agent_0    625
rule_based_agent_1    614
rule_based_agent_2    557
rule_based_agent_3    618
Loaded 837987 actions.
837987 actions in the training set:
UP: 19.93% (166986)
DOWN: 25.69% (215248)
LEFT: 19.37% (162302)
RIGHT: 24.72% (207122)
BOMB: 8.67% (72661)
WAIT: 1.63% (13668)
[10, 1,   200] loss: 90.694
[10, 1,   400] loss: 83.622
[10, 1,   600] loss: 78.916
[10, 1,   800] loss: 75.312
[10, 1] loss: 72.207
Finished Training


In [28]:
torch.save(net.state_dict(), "models/model_weights_fine")

In [4]:
batch_size = 256
data_path = f"../data/"
weight_path = "models/model_weights.pth"
num_epoch_per_run = 1

# optimizer = optim.SGD(net.parameters(), momentum=0.9, lr=0.001)
net = AgentNet()
#torch.save(net.state_dict(), weight_path)
#net.load_state_dict(torch.load("models/model_weights.pth"))

optimizer = optim.AdamW(net.parameters(), lr=0.00001)

In [5]:
for run in range(4):
    # os.system("rm -rf ../rule_based_agent/data")
    os.system("rm -rf ../data")
    os.system("cd ../..; python main.py play --no-gui --agents basic_agent rule_based_agent rule_based_agent rule_based_agent --train 4 --n-rounds 60 --scenario loot-crate --save-winner-game True")

    trainset = BombermanDataset(data_path)
    N = len(trainset)
    print(f"{N} actions in the training set:")
    for i in range(6):
        n = np.sum([t == i for t in trainset.actions])
        print(f"{ACTION_MAP_INV[i]}: {n/N*100:.2f}% ({n})")
    # sampler = BatchSampler(WeightedRandomSampler(trainset.weights, len(trainset), replacement=True,), batch_size, False)
    # trainloader = torch.utils.data.DataLoader(trainset, batch_sampler=sampler)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

    for epoch in range(num_epoch_per_run):
        running_loss = 0.0
        running = 0
        for i, batch in enumerate(trainloader, 0):
        # for i, batch in enumerate(sampler, 0):
            print(i, end="\r")
            # get the inputs; data is a list of [inputs, labels]
            (channels, features), actions, rewards = batch

            # zero the parameter gradients
            optimizer.zero_grad()
    
            # forward + backward + optimize
            outputs = net(channels, features)
            # loss = criterion(outputs, labels)
            log_logits = F.log_softmax(outputs, dim=-1)
            log_probs = log_logits.gather(1, actions[:,np.newaxis])
            
            loss = torch.mean(-log_probs * rewards)
            loss.backward()
            optimizer.step()
    
            # print statistics
            running_loss += loss.item()
            running += 1
            if i % 200 == 199:    # print every 2000 mini-batches
                print(f'[{run + 1}, {epoch + 1}, {i + 1:5d}] loss: {running_loss / running:.3f}')
                torch.save(net.state_dict(), weight_path)
                running_loss = 0.0
                running = 0
                #print(log_probs)
                #print(rewards)
            if i*batch_size/N > 0.5:
                break
        if running > 0:
            print(f'[{run + 1}, {epoch + 1}] loss: {running_loss / running:.3f}')
    torch.save(net.state_dict(), weight_path)

print('Finished Training')

100%|██████████| 60/60 [01:19<00:00,  1.32s/it]


basic_agent    609
rule_based_agent_0    794
rule_based_agent_1    752
rule_based_agent_2    735
Loaded 294534 actions.
294534 actions in the training set:
UP: 19.95% (58755)
DOWN: 25.74% (75803)
LEFT: 19.36% (57014)
RIGHT: 25.08% (73856)
BOMB: 8.48% (24976)
WAIT: 1.40% (4130)
[1, 1,   200] loss: 79.986
[1, 1,   400] loss: 79.855
[1, 1] loss: 78.929


100%|██████████| 60/60 [01:22<00:00,  1.38s/it]


basic_agent    639
rule_based_agent_0    725
rule_based_agent_1    779
rule_based_agent_2    754
Loaded 284585 actions.
284585 actions in the training set:
UP: 19.54% (55620)
DOWN: 26.09% (74247)
LEFT: 19.28% (54855)
RIGHT: 25.29% (71978)
BOMB: 8.41% (23930)
WAIT: 1.39% (3955)
[2, 1,   200] loss: 86.051
[2, 1,   400] loss: 86.843
[2, 1] loss: 85.522


100%|██████████| 60/60 [01:18<00:00,  1.31s/it]


basic_agent    584
rule_based_agent_0    825
rule_based_agent_1    767
rule_based_agent_2    770
Loaded 293086 actions.
293086 actions in the training set:
UP: 19.31% (56609)
DOWN: 25.94% (76036)
LEFT: 19.32% (56626)
RIGHT: 25.49% (74721)
BOMB: 8.58% (25154)
WAIT: 1.34% (3940)
[3, 1,   200] loss: 82.996
[3, 1,   400] loss: 81.969
[3, 1] loss: 81.688


100%|██████████| 60/60 [01:26<00:00,  1.44s/it]


basic_agent    630
rule_based_agent_0    677
rule_based_agent_1    851
rule_based_agent_2    742
Loaded 295559 actions.
295559 actions in the training set:
UP: 19.08% (56406)
DOWN: 25.40% (75073)
LEFT: 19.87% (58738)
RIGHT: 25.71% (76001)
BOMB: 8.58% (25357)
WAIT: 1.35% (3984)
[4, 1,   200] loss: 82.551
[4, 1,   400] loss: 83.323
[4, 1] loss: 83.069
Finished Training


In [6]:
torch.save(net.state_dict(), "models/model_weights_winner.pth")


#weight_path_ref = "models/model_weights_pretrained_ref.pth"
#torch.save(net.state_dict(),weight_path_ref)

## batch_size = 64
data_path = f"../data/"
weight_path = "models/model_weights.pth"
pretrained_weight_path = "models/model_weights_pretrained.pth"
num_epoch_per_run = 1

In [73]:
# pretrained_net = AgentNet()
# pretrained_net.load_state_dict(torch.load(pretrained_weight_path))
# torch.save(pretrained_net.state_dict(), weigth_path)

In [74]:
net = AgentNet()
pretrained_net = AgentNet()
pretrained_net.load_state_dict(torch.load(pretrained_weight_path))
net.cnn.load_state_dict(pretrained_net.cnn.state_dict())
torch.save(net.state_dict(), weight_path)
optimizer = optim.AdamW(net.mlp.parameters(), lr=0.001)

In [75]:
optimizer = optim.AdamW(net.mlp.parameters(), lr=0.001)

EVENT_HORIZON = {
    MOVED_LEFT: 1,
    MOVED_RIGHT: 1,
    MOVED_UP: 1,
    MOVED_DOWN: 1,
    WAITED: 1,
    INVALID_ACTION: 1,
    BOMB_DROPPED: 1,
    BOMB_EXPLODED: 1,
    CRATE_DESTROYED: 5,
    COIN_FOUND: 3,
    COIN_COLLECTED: 8,
    KILLED_OPPONENT: 8,
    KILLED_SELF: 4,
    GOT_KILLED: 3,
    OPPONENT_ELIMINATED: 7,
    SURVIVED_ROUND: 5
}


REWARDS = {
    MOVED_LEFT: -0.5,  #-0.2
    MOVED_RIGHT: -0.5, #-0.2
    MOVED_UP: -0.5,  #-0.2
    MOVED_DOWN: -0.5,  #-0.2    
    WAITED: 4,
    INVALID_ACTION: -2,
    BOMB_DROPPED: 0,
    BOMB_EXPLODED: 1,
    CRATE_DESTROYED: 2,
    COIN_FOUND: 0,
    COIN_COLLECTED: 6,
    KILLED_OPPONENT: 1,
    KILLED_SELF: -10,
    GOT_KILLED: -1,
    OPPONENT_ELIMINATED: 2,
    SURVIVED_ROUND: 9
}

In [76]:
action_probs = np.zeros((6,500))


In [79]:
for run in range(10):
    # os.system("rm -rf ../rule_based_agent/data")
    os.system("rm -rf ../data")
    #os.system("cd ../..; python main.py play --no-gui --agents basic_agent basic_agent basic_agent basic_agent --train 4 --n-rounds 30 --scenario loot-crate")
    os.system("cd ../..; python main.py play --no-gui --agents basic_agent basic_agent basic_agent basic_agent  --train 4 --n-rounds 30 --scenario loot-crate")

    # os.system("cd ../..; python main.py play --no-gui --agents rule_based_agent rule_based_agent rule_based_agent rule_based_agent --train 4 --n-rounds 10 --scenario classic")

    trainset = BombermanDataset(data_path)
    N = len(trainset)
    print(f"{N} actions in the training set:")
    for i in range(6):
        n = np.sum([t == i for t in trainset.actions])
        print(f"{ACTION_MAP_INV[i]}: {n/N*100:.2f}% ({n})")
        action_probs[i,run] = n/N
    
    
    # sampler = BatchSampler(WeightedRandomSampler(trainset.weights, len(trainset), replacement=True,), batch_size, False)
    # trainloader = torch.utils.data.DataLoader(trainset, batch_sampler=sampler)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

    for epoch in range(num_epoch_per_run):
        running_loss = 0.0
        running = 0
        for i, batch in enumerate(trainloader, 0):
        # for i, batch in enumerate(sampler, 0):
            print(i, end="\r")
            # get the inputs; data is a list of [inputs, labels]
            (channels, features), actions, rewards = batch

            # zero the parameter gradients
            optimizer.zero_grad()
    
            # forward + backward + optimize
            outputs = net(channels, features)
            # loss = criterion(outputs, labels)
            log_logits = F.log_softmax(outputs, dim=-1)
            log_probs = log_logits.gather(1, actions[:,np.newaxis])
            """
            if i == 0:
                print(actions, rewards, log_probs)
                print(-log_probs * rewards)
            """
            loss = torch.mean(-log_probs * rewards)
            loss.backward()
            optimizer.step()
    
            # print statistics
            running_loss += loss.item()
            running += 1
            if i % 200 == 199:    # print every 2000 mini-batches
                print(f'[{run + 1}, {epoch + 1}, {i + 1:5d}] loss: {running_loss / running:.3f}')
                torch.save(net.state_dict(), weight_path)
                running_loss = 0.0
                running = 0
        if running > 0:
            print(f'[{run + 1}, {epoch + 1}] loss: {running_loss / running:.3f}')
    torch.save(net.state_dict(), weight_path)

print('Finished Training')

100%|██████████| 30/30 [00:58<00:00,  1.94s/it]


Loaded 298683 actions.
298683 actions in the training set:
UP: 15.22% (45452)
DOWN: 34.32% (102520)
LEFT: 16.40% (48982)
RIGHT: 26.48% (79098)
BOMB: 6.45% (19274)
WAIT: 1.12% (3357)
[1, 1,   200] loss: 20.545
[1, 1,   400] loss: 17.363
[1, 1,   600] loss: 15.881
[1, 1,   800] loss: 15.734
[1, 1,  1000] loss: 14.239
[1, 1,  1200] loss: 13.750
[1, 1,  1400] loss: 13.503
[1, 1,  1600] loss: 13.106
[1, 1,  1800] loss: 12.162
[1, 1,  2000] loss: 12.178
[1, 1,  2200] loss: 11.820
[1, 1,  2400] loss: 11.852
[1, 1,  2600] loss: 11.281
[1, 1,  2800] loss: 11.148
[1, 1,  3000] loss: 10.654
[1, 1,  3200] loss: 11.071
[1, 1,  3400] loss: 10.446
[1, 1,  3600] loss: 10.402
[1, 1,  3800] loss: 9.904
[1, 1,  4000] loss: 9.635
[1, 1,  4200] loss: 9.458
[1, 1,  4400] loss: 9.511
[1, 1,  4600] loss: 9.639
[1, 1] loss: 9.649


100%|██████████| 30/30 [00:58<00:00,  1.94s/it]


Loaded 275019 actions.
275019 actions in the training set:
UP: 14.84% (40818)
DOWN: 35.45% (97483)
LEFT: 16.15% (44411)
RIGHT: 26.31% (72357)
BOMB: 6.03% (16597)
WAIT: 1.22% (3353)
[2, 1,   200] loss: 15.642
[2, 1,   400] loss: 12.522
[2, 1,   600] loss: 10.814
[2, 1,   800] loss: 10.197
[2, 1,  1000] loss: 9.779
[2, 1,  1200] loss: 8.961
[2, 1,  1400] loss: 8.457
[2, 1,  1600] loss: 8.409
[2, 1,  1800] loss: 8.007
[2, 1,  2000] loss: 7.834
[2, 1,  2200] loss: 7.641
[2, 1,  2400] loss: 7.661
[2, 1,  2600] loss: 7.253
[2, 1,  2800] loss: 6.938
[2, 1,  3000] loss: 6.859
[2, 1,  3200] loss: 7.091
[2, 1,  3400] loss: 7.851
[2, 1,  3600] loss: 6.596
[2, 1,  3800] loss: 7.056
[2, 1,  4000] loss: 6.786
[2, 1,  4200] loss: 6.336
[2, 1] loss: 5.841


100%|██████████| 30/30 [01:02<00:00,  2.07s/it]


Loaded 296943 actions.
296943 actions in the training set:
UP: 14.74% (43777)
DOWN: 34.50% (102442)
LEFT: 15.75% (46779)
RIGHT: 27.90% (82858)
BOMB: 5.79% (17200)
WAIT: 1.31% (3887)
[3, 1,   200] loss: 13.117
[3, 1,   400] loss: 10.529
[3, 1,   600] loss: 8.514
[3, 1,   800] loss: 7.821
[3, 1,  1000] loss: 7.589
[3, 1,  1200] loss: 7.152
[3, 1,  1400] loss: 6.555
[3, 1,  1600] loss: 6.769
[3, 1,  1800] loss: 6.426
[3, 1,  2000] loss: 6.444
[3, 1,  2200] loss: 6.472
[3, 1,  2400] loss: 6.159
[3, 1,  2600] loss: 5.700
[3, 1,  2800] loss: 6.195
[3, 1,  3000] loss: 6.146
[3, 1,  3200] loss: 6.141
[3, 1,  3400] loss: 5.716
[3, 1,  3600] loss: 5.535
[3, 1,  3800] loss: 5.579
[3, 1,  4000] loss: 5.579
[3, 1,  4200] loss: 6.154
[3, 1,  4400] loss: 5.372
[3, 1,  4600] loss: 5.222
[3, 1] loss: 6.239


100%|██████████| 30/30 [00:53<00:00,  1.80s/it]


Loaded 288673 actions.
288673 actions in the training set:
UP: 14.39% (41553)
DOWN: 41.43% (119600)
LEFT: 15.19% (43854)
RIGHT: 22.47% (64863)
BOMB: 5.37% (15504)
WAIT: 1.14% (3299)
[4, 1,   200] loss: 10.355
[4, 1,   400] loss: 8.255
[4, 1,   600] loss: 7.762
[4, 1,   800] loss: 6.706
[4, 1,  1000] loss: 6.530
[4, 1,  1200] loss: 6.148
[4, 1,  1400] loss: 5.509
[4, 1,  1600] loss: 5.585
[4, 1,  1800] loss: 5.632
[4, 1,  2000] loss: 5.073
[4, 1,  2200] loss: 4.864
[4, 1,  2400] loss: 5.536
[4, 1,  2600] loss: 5.197
[4, 1,  2800] loss: 4.650
[4, 1,  3000] loss: 5.106
[4, 1,  3200] loss: 5.531
[4, 1,  3400] loss: 4.926
[4, 1,  3600] loss: 4.624
[4, 1,  3800] loss: 4.378
[4, 1,  4000] loss: 4.987
[4, 1,  4200] loss: 4.764
[4, 1,  4400] loss: 4.636
[4, 1] loss: 4.888


100%|██████████| 30/30 [00:52<00:00,  1.76s/it]


Loaded 278948 actions.
278948 actions in the training set:
UP: 13.04% (36378)
DOWN: 39.66% (110624)
LEFT: 14.77% (41196)
RIGHT: 26.79% (74744)
BOMB: 4.77% (13313)
WAIT: 0.97% (2693)
[5, 1,   200] loss: 11.025
[5, 1,   400] loss: 8.949
[5, 1,   600] loss: 7.630
[5, 1,   800] loss: 7.161
[5, 1,  1000] loss: 6.631
[5, 1,  1200] loss: 6.287
[5, 1,  1400] loss: 6.397
[5, 1,  1600] loss: 5.681
[5, 1,  1800] loss: 6.000
[5, 1,  2000] loss: 6.066
[5, 1,  2200] loss: 5.989
[5, 1,  2400] loss: 5.975
[5, 1,  2600] loss: 5.652
[5, 1,  2800] loss: 5.766
[5, 1,  3000] loss: 5.277
[5, 1,  3200] loss: 5.607
[5, 1,  3400] loss: 5.522
[5, 1,  3600] loss: 5.495
[5, 1,  3800] loss: 5.514
[5, 1,  4000] loss: 5.363
[5, 1,  4200] loss: 5.473
[5, 1] loss: 5.809


100%|██████████| 30/30 [00:51<00:00,  1.71s/it]


Loaded 261623 actions.
261623 actions in the training set:
UP: 15.45% (40432)
DOWN: 39.18% (102493)
LEFT: 14.66% (38343)
RIGHT: 23.93% (62594)
BOMB: 5.39% (14107)
WAIT: 1.40% (3654)
[6, 1,   200] loss: 10.541
[6, 1,   400] loss: 8.960
[6, 1,   600] loss: 7.116
[6, 1,   800] loss: 6.318
[6, 1,  1000] loss: 6.269
[6, 1,  1200] loss: 5.886
[6, 1,  1400] loss: 6.083
[6, 1,  1600] loss: 5.624
[6, 1,  1800] loss: 5.623
[6, 1,  2000] loss: 5.538
[6, 1,  2200] loss: 5.588
[6, 1,  2400] loss: 5.382
[6, 1,  2600] loss: 5.600
[6, 1,  2800] loss: 5.469
[6, 1,  3000] loss: 5.517
[6, 1,  3200] loss: 5.241
[6, 1,  3400] loss: 5.374
[6, 1,  3600] loss: 5.240
[6, 1,  3800] loss: 5.201
[6, 1,  4000] loss: 5.039
[6, 1] loss: 5.375


100%|██████████| 30/30 [00:55<00:00,  1.83s/it]


Loaded 266770 actions.
266770 actions in the training set:
UP: 15.18% (40502)
DOWN: 41.63% (111067)
LEFT: 13.93% (37161)
RIGHT: 22.58% (60247)
BOMB: 5.43% (14481)
WAIT: 1.24% (3312)
[7, 1,   200] loss: 10.172
[7, 1,   400] loss: 8.596
[7, 1,   600] loss: 7.097
[7, 1,   800] loss: 6.640
[7, 1,  1000] loss: 6.194
[7, 1,  1200] loss: 6.160
[7, 1,  1400] loss: 5.763
[7, 1,  1600] loss: 5.387
[7, 1,  1800] loss: 5.189
[7, 1,  2000] loss: 5.195
[7, 1,  2200] loss: 5.018
[7, 1,  2400] loss: 5.369
[7, 1,  2600] loss: 4.919
[7, 1,  2800] loss: 5.100
[7, 1,  3000] loss: 5.002
[7, 1,  3200] loss: 5.368
[7, 1,  3400] loss: 4.856
[7, 1,  3600] loss: 5.708
[7, 1,  3800] loss: 5.186
[7, 1,  4000] loss: 4.962
[7, 1] loss: 4.603


100%|██████████| 30/30 [00:52<00:00,  1.74s/it]


Loaded 272686 actions.
272686 actions in the training set:
UP: 14.62% (39861)
DOWN: 41.77% (113900)
LEFT: 11.12% (30333)
RIGHT: 26.36% (71867)
BOMB: 4.95% (13500)
WAIT: 1.18% (3225)
[8, 1,   200] loss: 8.618
[8, 1,   400] loss: 7.104
[8, 1,   600] loss: 6.021
[8, 1,   800] loss: 5.576
[8, 1,  1000] loss: 5.072
[8, 1,  1200] loss: 4.652
[8, 1,  1400] loss: 4.711
[8, 1,  1600] loss: 4.287
[8, 1,  1800] loss: 4.299
[8, 1,  2000] loss: 4.500
[8, 1,  2200] loss: 4.388
[8, 1,  2400] loss: 4.180
[8, 1,  2600] loss: 4.070
2624

KeyboardInterrupt: 

In [39]:
a = np.zeros((1,7))
print(a)
print(a[0])
print(a[0,0])


[[0. 0. 0. 0. 0. 0. 0.]]
[0. 0. 0. 0. 0. 0. 0.]
0.0


In [51]:
a = np.array([1,2,3,4,5,6,7])-1
print(a)
print(a[1:7])

[0 1 2 3 4 5 6]
[1 2 3 4 5 6]



    extensions = np.zeros(4)
    if pos[0]+4 <= 16:             #field, box_map, explosion_map, others_map, coin_map
        if field[pos[0]+4, pos[1]] == 1:
            extensions[0] = 1
        if box_map[pos[0]+4, pos[1]] ==1:
            extensions[0] = 2
        if explosion_map[pos[0]+4, pos[1]] ==1:
            extensions[0] = 3
        if others_map[pos[0]+4, pos[1]] == 1:
            extensions[0] = 4
        if coin_map[pos[0]+4, pos[1]] ==1:
            extensions[0] =5

    if pos[0]-4 >= 0:             #field, box_map, explosion_map, others_map, coin_map
        if field[pos[0]-4, pos[1]] == 1:
            extensions[1] = 1
        if box_map[pos[0]-4, pos[1]] ==1:
            extensions[1] = 2
        if explosion_map[pos[0]+4, pos[1]] ==1:
            extensions[1] = 3
        if others_map[pos[0]-4, pos[1]] == 1:
            extensions[1] = 4
        if coin_map[pos[0]-4, pos[1]] ==1:
            extensions[1] =5

    if pos[1]+4 <= 16:             #field, box_map, explosion_map, others_map, coin_map
        if field[pos[0], pos[1]+4] == 1:
            extensions[2] = 1
        if box_map[pos[0], pos[1]+4] ==1:
            extensions[2] = 2
        if explosion_map[pos[0], pos[1]+4] ==1:
            extensions[2] = 3
        if others_map[pos[0], pos[1]+4] == 1:
            extensions[2] = 4
        if coin_map[pos[0], pos[1]+4] ==1:
            extensions[2] = 5

    if pos[1]-4 >= 0:             #field, box_map, explosion_map, others_map, coin_map
        if field[pos[0], pos[1]-4] == 1:
            extensions[3] = 1
        if box_map[pos[0], pos[1]-4] ==1:
            extensions[3] = 2
        if explosion_map[pos[0], pos[1]-4] ==1:
            extensions[3] = 3
        if others_map[pos[0], pos[1]-4] == 1:
            extensions[3] = 4
        if coin_map[pos[0], pos[1]-4] ==1:
            extensions[3] = 5

