In [2]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append(r"../..")


from torch import nn
from torch.nn import functional as F
import torch.optim as optim
import torch
import os
import pickle
import numpy as np
import matplotlib.pyplot as plt

from callbacks import state_to_features, ACTION_MAP, ACTION_MAP_INV
from networks import AgentNet

from events import *

In [6]:
EVENT_HORIZON = {
    MOVED_LEFT: 1,
    MOVED_RIGHT: 1,
    MOVED_UP: 1,
    MOVED_DOWN: 1,
    WAITED: 2,
    INVALID_ACTION: 1,
    BOMB_DROPPED: 2,
    BOMB_EXPLODED: 2,
    CRATE_DESTROYED: 8,
    COIN_FOUND: 3,
    COIN_COLLECTED: 5,
    KILLED_OPPONENT: 8,
    KILLED_SELF: 3,
    GOT_KILLED: 3,
    OPPONENT_ELIMINATED: 0,
    SURVIVED_ROUND: 8
}
"""
REWARDS = {
    MOVED_LEFT: 1,
    MOVED_RIGHT: 1,
    MOVED_UP: 1,
    MOVED_DOWN: 1,
    WAITED: 0,
    INVALID_ACTION: -5,
    BOMB_DROPPED: 1,
    BOMB_EXPLODED: 0,
    CRATE_DESTROYED: 4,
    COIN_FOUND: 0,
    COIN_COLLECTED: 5,
    KILLED_OPPONENT: 5,
    KILLED_SELF: -8,
    GOT_KILLED: -8,
    OPPONENT_ELIMINATED: 0,
    SURVIVED_ROUND: 0
}
"""
REWARDS = {
    MOVED_LEFT: 100,
    MOVED_RIGHT: 100,
    MOVED_UP: 100,
    MOVED_DOWN: 100,
    WAITED: -100,
    INVALID_ACTION: -10,
    BOMB_DROPPED: 0,
    BOMB_EXPLODED: 0,
    CRATE_DESTROYED: 0,
    COIN_FOUND: 0,
    COIN_COLLECTED: 0,
    KILLED_OPPONENT: 0,
    KILLED_SELF: 0,
    GOT_KILLED: 0,
    OPPONENT_ELIMINATED: 0,
    SURVIVED_ROUND: 0
}

# REWARDS = {k: v + 10 for k, v in REWARDS.items()}



In [3]:
# Specify directory with training data
data_path = "../data/"

In [None]:
print(np.random.normal())
ns = np.zeros(1000)
for i in range(1000):
    ns[i] = np.random.normal()

plt.hist(ns)

In [None]:
 if i > 12:
                    if data['events'][i-4] == data['events'][i] and data['events'][i-4] == data['events'][i-8]:
                        if data['events'][i-3] == data['events'][i-7]:
                            if data['events'][i] in ["MOVED_LEFT", "MOVED_RIGHT", "MOVED_UP", "MOVED_DOWN"]:
                                if data['events'][i-3] in ["MOVED_LEFT", "MOVED_RIGHT", "MOVED_UP", "MOVED_DOWN"]:
                                    if data['events'][i-8] == data['events'][i-12]:
                                
                                        rewards[0] -= 1/50*(REWARDS["MOVED_LEFT"]+REWARDS["MOVED_RIGHT"]+REWARDS["MOVED_UP"]+REWARDS["MOVED_DOWN"])
                        
                if i>10:
                    if data['events'][i-2] == data['events'][i] and data['events'][i-4] == data['events'][i-2]:
                        if data['events'][i-4] == data['events'][i-6] and data['events'][i-8] == data['events'][i-6]:
                            if data['events'][i-3] == data['events'][i-5]:
                                if data['events'][i-3] in ["MOVED_LEFT", "MOVED_RIGHT", "MOVED_UP", "MOVED_DOWN"]:
                                    if data['events'][i-8] == data['events'][i-10]:
                                        if data['events'][i] in ["MOVED_LEFT", "MOVED_RIGHT", "MOVED_UP", "MOVED_DOWN"]:
                                            rewards[0] -= 1/50*(REWARDS["MOVED_LEFT"]+REWARDS["MOVED_RIGHT"]+REWARDS["MOVED_UP"]+REWARDS["MOVED_DOWN"])

In [3]:
from torch.utils.data import Dataset


class BombermanDataset(Dataset):
    def __init__(self, states_dir, max_game_step=200, min_reward=1):
        self.features = []
        self.actions = []
        self.cum_rewards = []
        self.total_number = 0
        for states_file in os.listdir(states_dir):

            if not states_file.endswith('.pickle'):
                continue
            with open(os.path.join(states_dir, states_file), "rb") as f:
                data = pickle.load(f)
            self.total_number += len(data['game_state'])
            last_round = -1
            running_horizons = []
            running_rewards = []
            for i in range(len(data['game_state'])-1,-1,-1):
                game_state = data['game_state'][i]               

                    
                if last_round != game_state['round']:
                    last_round = game_state['round']
                    running_horizons = []
                    running_rewards = []
                if game_state['step'] > max_game_step:
                    continue
                    
                # determine if action should be included in training set
                
                horizons = [EVENT_HORIZON[e] for e in data['events'][i]]
                rewards = [REWARDS[e] for e in data['events'][i]]
                cr = np.where(data['events'] == "CRATE_DESTROYED")
                #print(cr)
                #rewards[np.array(cr[-1:], dtype = "int")] == 0
                #for j in cr[1:]:
                #    rewards[j] = 0             

                running_rewards = [r for j,r in enumerate(running_rewards) if running_horizons[j]>1]
                running_horizons = [h-1 for h in running_horizons if h > 1]
                running_rewards.extend(rewards)
                running_horizons.extend(horizons)
                #if "INVALID_ACTION" in data['events'][i]:
                #    continue   
                self.cum_rewards.append(np.sum(running_rewards))

                last_action = None if game_state['step'] == 1 else data['action'][i-1]
                #print("LAST ACTION:")
                #print(last_action)
                
                channels, features, act_map = state_to_features(game_state, last_action=last_action)
                action = ACTION_MAP[act_map[data['action'][i]]]
                self.features.append((
                    torch.tensor(channels[np.newaxis], dtype=torch.float),
                    torch.tensor(features, dtype=torch.float)
                ))
                self.actions.append(action)
                
                    
        self.cum_rewards = torch.tensor(self.cum_rewards, dtype=torch.float)
        self.n = np.array([np.sum([t == i for t in self.actions]) for i in range(6)])
        self.weights = np.array([1/self.n[t] for t in self.actions])
        self.proportion_selected = len(self.actions)/self.total_number
        print(f"Loaded {len(self.actions)} actions.")

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return self.features[idx], self.actions[idx], self.cum_rewards[idx]

In [4]:
os.system("rm -rf ../data")
os.system("cd ../..; python main.py play --no-gui --agents rule_based_agent rule_based_agent rule_based_agent rule_based_agent --train 4 --n-rounds 100 --scenario loot-crate")

100%|██████████| 100/100 [04:04<00:00,  2.44s/it]


0

In [None]:
trainset = BombermanDataset(data_path)

N = len(trainset)
print(f"{N} actions in the training set:")
for i in range(6):
    n = np.sum([t == i for t in trainset.actions])
    print(f"{ACTION_MAP_INV[i]}: {n/N*100:.2f}% ({n})")

In [None]:
batch_size = 32
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True)

In [None]:
net = AgentNet()
optimizer = optim.Adam(net.parameters(), lr=0.001)

In [None]:
os.makedirs("models", exist_ok=True)
for epoch in range(5):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, batch in enumerate(trainloader, 0):
        print(i, end="\r")
        # get the inputs; data is a list of [inputs, labels]
        (channels, features), actions, rewards = batch

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(channels, features)
        # loss = criterion(outputs, labels)
        log_logits = F.log_softmax(outputs, dim=-1)
        log_probs = log_logits.gather(1, actions[:,np.newaxis])
        
        loss = torch.mean(-log_probs * rewards)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 200 == 199:    # print every 2000 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
            torch.save(net.state_dict(), 'models/model_weights.pth')
            running_loss = 0.0

torch.save(net.state_dict(), 'models/model_weights.pth')
print('Finished Training')

In [None]:
a.gather(1, torch.tensor([0,1,0])[:,np.newaxis])

In [14]:
expnet = AgentNet()
expnet.load_state_dict(torch.load("models/model_weights.pth"))
#(channels, features), actions, rewards = enumerate(trainloader,0)
#print(expnet.cnn(trainloader).size())
for i, batch in enumerate(trainloader,0):
    if i%100==0:
        (channels, features), actions, rewards = batch    
        print(expnet.cnn(channels).size())

torch.Size([32, 128])
torch.Size([32, 128])
torch.Size([32, 128])
torch.Size([32, 128])
torch.Size([32, 128])
torch.Size([32, 128])
torch.Size([32, 128])
torch.Size([32, 128])
torch.Size([32, 128])
torch.Size([32, 128])
torch.Size([32, 128])


In [4]:
batch_size = 64
data_path = f"../data/"
weight_path = "models/model_weights_pretrained.pth"
num_epoch_per_run = 1

# optimizer = optim.SGD(net.parameters(), momentum=0.9, lr=0.001)
net = AgentNet()
net.load_state_dict(torch.load(weight_path))
optimizer = optim.AdamW(net.parameters(), lr=0.001)

In [11]:
optimizer = optim.AdamW(net.parameters(), lr=0.0001)
batch_size = 128


In [12]:
EVENT_HORIZON = {
    MOVED_LEFT: 1,
    MOVED_RIGHT: 1,
    MOVED_UP: 1,
    MOVED_DOWN: 1,
    WAITED: 1,
    INVALID_ACTION: 1,
    BOMB_DROPPED: 1,
    BOMB_EXPLODED: 1,
    CRATE_DESTROYED: 6,
    COIN_FOUND: 1,
    COIN_COLLECTED: 5,
    KILLED_OPPONENT: 6,
    KILLED_SELF: 5,
    GOT_KILLED: 6,
    OPPONENT_ELIMINATED: 4,
    SURVIVED_ROUND: 100
}


REWARDS = {
    MOVED_LEFT: 5, #1
    MOVED_RIGHT: 5, #1
    MOVED_UP: 5,   #1
    MOVED_DOWN: 5,  #1
    WAITED: 5,  #1
    INVALID_ACTION: -120, #-7
    BOMB_DROPPED: 5,  #1
    BOMB_EXPLODED: 0,
    CRATE_DESTROYED: 120,
    COIN_FOUND: 0,
    COIN_COLLECTED: 120,
    KILLED_OPPONENT: 60,
    KILLED_SELF: -120, #-12
    GOT_KILLED: -120,
    OPPONENT_ELIMINATED: 5,
    SURVIVED_ROUND: 200 #20
}

In [7]:
#os.system("rm -rf ../data")
#os.system("cd ../..; python main.py play --no-gui --agents basic_agent rule_based_agent rule_based_agent rule_based_agent --train 4 --n-rounds 100 --scenario loot-crate")


In [13]:
for run in range(20):
    # os.system("rm -rf ../rule_based_agent/data")
    os.system("rm -rf ../data")
    os.system("cd ../..; python main.py play --no-gui --agents rule_based_agent rule_based_agent rule_based_agent rule_based_agent --train 4 --n-rounds 40 --scenario loot-crate")

    trainset = BombermanDataset(data_path)
    N = len(trainset)
    print(f"{N} actions in the training set:")
    for i in range(6):
        n = np.sum([t == i for t in trainset.actions])
        print(f"{ACTION_MAP_INV[i]}: {n/N*100:.2f}% ({n})")
    # sampler = BatchSampler(WeightedRandomSampler(trainset.weights, len(trainset), replacement=True,), batch_size, False)
    # trainloader = torch.utils.data.DataLoader(trainset, batch_sampler=sampler)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

    for epoch in range(num_epoch_per_run):
        running_loss = 0.0
        running = 0
        for i, batch in enumerate(trainloader, 0):
        # for i, batch in enumerate(sampler, 0):
            print(i, end="\r")
            # get the inputs; data is a list of [inputs, labels]
            (channels, features), actions, rewards = batch

            # zero the parameter gradients
            optimizer.zero_grad()
    
            # forward + backward + optimize
            outputs = net(channels, features)
            # loss = criterion(outputs, labels)
            log_logits = F.log_softmax(outputs, dim=-1)
            log_probs = log_logits.gather(1, actions[:,np.newaxis])
            
            loss = torch.mean(-log_probs * rewards)
            loss.backward()
            optimizer.step()
    
            # print statistics
            running_loss += loss.item()
            running += 1
            if i % 200 == 199:    # print every 2000 mini-batches
                print(f'[{run + 1}, {epoch + 1}, {i + 1:5d}] loss: {running_loss / running:.3f}')
                torch.save(net.state_dict(), weight_path)
                running_loss = 0.0
                running = 0
                #print(log_probs)
                #print(rewards)
                print(i*batch_size/N)
            if i*batch_size/N > 0.2:
                break
        if running > 0:
            print(f'[{run + 1}, {epoch + 1}] loss: {running_loss / running:.3f}')
            
    torch.save(net.state_dict(), weight_path)

print('Finished Training') 

100%|██████████| 40/40 [00:59<00:00,  1.48s/it]


Loaded 525770 actions.
525770 actions in the training set:
UP: 19.63% (103232)
DOWN: 25.37% (133403)
LEFT: 19.51% (102601)
RIGHT: 25.06% (131746)
BOMB: 8.75% (45992)
WAIT: 1.67% (8796)
[1, 1,   200] loss: 92.586
0.0484470395800445
[1, 1,   400] loss: 84.784
0.09713753162029024
[1, 1,   600] loss: 77.504
0.145828023660536
[1, 1,   800] loss: 70.732
0.1945185157007817
[1, 1] loss: 69.262


100%|██████████| 40/40 [00:53<00:00,  1.33s/it]


Loaded 523498 actions.
523498 actions in the training set:
UP: 19.50% (102087)
DOWN: 25.31% (132492)
LEFT: 19.55% (102332)
RIGHT: 25.13% (131548)
BOMB: 8.74% (45752)
WAIT: 1.77% (9287)
[2, 1,   200] loss: 93.142
0.04865730146055954
[2, 1,   400] loss: 81.633
0.09755911197368472
[2, 1,   600] loss: 72.792
0.1464609224868099
[2, 1,   800] loss: 68.250
0.19536273299993506
[2, 1] loss: 68.489


100%|██████████| 40/40 [00:52<00:00,  1.32s/it]


Loaded 517882 actions.
517882 actions in the training set:
UP: 19.57% (101357)
DOWN: 25.10% (130011)
LEFT: 19.58% (101424)
RIGHT: 25.28% (130923)
BOMB: 8.75% (45309)
WAIT: 1.71% (8858)
[3, 1,   200] loss: 89.229
0.04918494946725316
[3, 1,   400] loss: 78.171
0.09861705948459301
[3, 1,   600] loss: 74.151
0.14804916950193286
[3, 1,   800] loss: 68.119
0.19748127951927272
[3, 1] loss: 68.285


100%|██████████| 40/40 [00:53<00:00,  1.35s/it]


Loaded 513999 actions.
513999 actions in the training set:
UP: 19.91% (102319)
DOWN: 25.71% (132139)
LEFT: 19.26% (99021)
RIGHT: 24.82% (127598)
BOMB: 8.63% (44382)
WAIT: 1.66% (8540)
[4, 1,   200] loss: 86.765
0.04955651664691955
[4, 1,   400] loss: 75.953
0.09936206101568291
[4, 1,   600] loss: 72.083
0.14916760538444626
[4, 1,   800] loss: 66.310
0.19897314975320965
[4, 1] loss: 67.990


100%|██████████| 40/40 [00:53<00:00,  1.34s/it]


Loaded 515581 actions.
515581 actions in the training set:
UP: 19.65% (101306)
DOWN: 25.11% (129479)
LEFT: 19.78% (101967)
RIGHT: 25.28% (130356)
BOMB: 8.50% (43850)
WAIT: 1.67% (8623)
[5, 1,   200] loss: 87.889
0.04940445827134825
[5, 1,   400] loss: 78.042
0.09905718015210026
[5, 1,   600] loss: 71.587
0.14870990203285225
[5, 1,   800] loss: 65.966
0.19836262391360426
[5, 1] loss: 64.794


100%|██████████| 40/40 [00:53<00:00,  1.33s/it]


Loaded 514361 actions.
514361 actions in the training set:
UP: 19.47% (100146)
DOWN: 24.97% (128450)
LEFT: 19.95% (102617)
RIGHT: 25.59% (131609)
BOMB: 8.60% (44249)
WAIT: 1.42% (7290)
[6, 1,   200] loss: 84.822
0.049521639471110754
[6, 1,   400] loss: 76.082
0.09929213140187533
[6, 1,   600] loss: 70.022
0.14906262333263992
[6, 1,   800] loss: 65.121
0.1988331152634045
[6, 1] loss: 61.227


100%|██████████| 40/40 [00:56<00:00,  1.41s/it]


Loaded 534933 actions.
534933 actions in the training set:
UP: 19.71% (105458)
DOWN: 25.73% (137663)
LEFT: 19.30% (103218)
RIGHT: 25.18% (134684)
BOMB: 8.59% (45932)
WAIT: 1.49% (7978)
[7, 1,   200] loss: 90.046
0.04761717822605822
[7, 1,   400] loss: 79.443
0.09547363875475995
[7, 1,   600] loss: 75.195
0.14333009928346166
[7, 1,   800] loss: 69.407
0.1911865598121634
[7, 1] loss: 64.381


100%|██████████| 40/40 [00:52<00:00,  1.31s/it]


Loaded 509123 actions.
509123 actions in the training set:
UP: 19.54% (99499)
DOWN: 25.63% (130498)
LEFT: 19.40% (98775)
RIGHT: 25.07% (127624)
BOMB: 8.63% (43961)
WAIT: 1.72% (8766)
[8, 1,   200] loss: 85.123
0.050031131966145705
[8, 1,   400] loss: 75.986
0.10031367665573938
[8, 1,   600] loss: 67.984
0.15059622134533304
[8, 1] loss: 65.614


100%|██████████| 40/40 [00:58<00:00,  1.46s/it]


Loaded 526928 actions.
526928 actions in the training set:
UP: 19.54% (102955)
DOWN: 25.01% (131791)
LEFT: 19.69% (103753)
RIGHT: 25.36% (133635)
BOMB: 8.57% (45138)
WAIT: 1.83% (9656)
[9, 1,   200] loss: 82.919
0.04834057024868673
[9, 1,   400] loss: 72.998
0.09692405793580906
[9, 1,   600] loss: 67.673
0.14550754562293142
[9, 1,   800] loss: 62.809
0.19409103331005376
[9, 1] loss: 56.704


100%|██████████| 40/40 [00:55<00:00,  1.38s/it]


Loaded 507710 actions.
507710 actions in the training set:
UP: 19.72% (100115)
DOWN: 25.86% (131309)
LEFT: 19.31% (98055)
RIGHT: 24.96% (126716)
BOMB: 8.63% (43819)
WAIT: 1.52% (7696)
[10, 1,   200] loss: 87.356
0.05017037285064308
[10, 1,   400] loss: 78.948
0.10059285812767131
[10, 1,   600] loss: 70.859
0.15101534340469952
[10, 1] loss: 65.657


100%|██████████| 40/40 [00:52<00:00,  1.31s/it]


Loaded 507399 actions.
507399 actions in the training set:
UP: 19.81% (100505)
DOWN: 25.42% (128960)
LEFT: 19.26% (97724)
RIGHT: 25.28% (128291)
BOMB: 8.63% (43795)
WAIT: 1.60% (8124)
[11, 1,   200] loss: 80.711
0.05020112377044496
[11, 1,   400] loss: 73.888
0.10065451449451024
[11, 1,   600] loss: 66.147
0.1511079052185755
[11, 1] loss: 61.612


100%|██████████| 40/40 [00:57<00:00,  1.44s/it]


Loaded 518319 actions.
518319 actions in the training set:
UP: 19.50% (101056)
DOWN: 25.14% (130285)
LEFT: 19.64% (101802)
RIGHT: 25.36% (131453)
BOMB: 8.62% (44700)
WAIT: 1.74% (9023)
[12, 1,   200] loss: 82.968
0.049143481138063626
[12, 1,   400] loss: 75.722
0.09853391444265018
[12, 1,   600] loss: 69.291
0.14792434774723673
[12, 1,   800] loss: 61.987
0.1973147810518233
[12, 1] loss: 64.201


100%|██████████| 40/40 [00:55<00:00,  1.38s/it]


Loaded 519883 actions.
519883 actions in the training set:
UP: 19.59% (101871)
DOWN: 25.48% (132470)
LEFT: 19.55% (101649)
RIGHT: 25.06% (130291)
BOMB: 8.72% (45322)
WAIT: 1.59% (8280)
[13, 1,   200] loss: 83.972
0.048995639403481166
[13, 1,   400] loss: 75.229
0.09823748805019591
[13, 1,   600] loss: 68.606
0.14747933669691066
[13, 1,   800] loss: 64.323
0.1967211853436254
[13, 1] loss: 62.519


100%|██████████| 40/40 [00:57<00:00,  1.45s/it]


Loaded 542123 actions.
542123 actions in the training set:
UP: 19.83% (107515)
DOWN: 25.19% (136554)
LEFT: 19.41% (105217)
RIGHT: 24.90% (134998)
BOMB: 8.71% (47213)
WAIT: 1.96% (10626)
[14, 1,   200] loss: 79.830
0.046985647168631475
[14, 1,   400] loss: 70.121
0.09420740311700482
[14, 1,   600] loss: 65.491
0.14142915906537815
[14, 1,   800] loss: 62.763
0.1886509150137515
[14, 1] loss: 58.825


100%|██████████| 40/40 [00:54<00:00,  1.37s/it]


Loaded 533802 actions.
533802 actions in the training set:
UP: 19.59% (104557)
DOWN: 24.93% (133059)
LEFT: 19.75% (105440)
RIGHT: 25.54% (136327)
BOMB: 8.61% (45968)
WAIT: 1.58% (8451)
[15, 1,   200] loss: 85.547
0.047718067747966476
[15, 1,   400] loss: 76.726
0.09567592478109861
458

KeyboardInterrupt: 

In [None]:
torch.save(net.state_dict(), "models/model_weights.pth")


#weight_path_ref = "models/model_weights_pretrained_ref.pth"
#torch.save(net.state_dict(),weight_path_ref)

In [72]:
batch_size = 64
data_path = f"../data/"
weight_path = "models/model_weights.pth"
pretrained_weight_path = "models/model_weights_pretrained.pth"
num_epoch_per_run = 1

In [73]:
# pretrained_net = AgentNet()
# pretrained_net.load_state_dict(torch.load(pretrained_weight_path))
# torch.save(pretrained_net.state_dict(), weigth_path)

In [74]:
net = AgentNet()
pretrained_net = AgentNet()
pretrained_net.load_state_dict(torch.load(pretrained_weight_path))
net.cnn.load_state_dict(pretrained_net.cnn.state_dict())
torch.save(net.state_dict(), weight_path)
optimizer = optim.AdamW(net.mlp.parameters(), lr=0.001)

In [75]:
optimizer = optim.AdamW(net.mlp.parameters(), lr=0.001)

EVENT_HORIZON = {
    MOVED_LEFT: 1,
    MOVED_RIGHT: 1,
    MOVED_UP: 1,
    MOVED_DOWN: 1,
    WAITED: 1,
    INVALID_ACTION: 1,
    BOMB_DROPPED: 1,
    BOMB_EXPLODED: 1,
    CRATE_DESTROYED: 5,
    COIN_FOUND: 3,
    COIN_COLLECTED: 8,
    KILLED_OPPONENT: 8,
    KILLED_SELF: 4,
    GOT_KILLED: 3,
    OPPONENT_ELIMINATED: 7,
    SURVIVED_ROUND: 5
}


REWARDS = {
    MOVED_LEFT: -0.5,  #-0.2
    MOVED_RIGHT: -0.5, #-0.2
    MOVED_UP: -0.5,  #-0.2
    MOVED_DOWN: -0.5,  #-0.2    
    WAITED: 4,
    INVALID_ACTION: -2,
    BOMB_DROPPED: 0,
    BOMB_EXPLODED: 1,
    CRATE_DESTROYED: 2,
    COIN_FOUND: 0,
    COIN_COLLECTED: 6,
    KILLED_OPPONENT: 1,
    KILLED_SELF: -10,
    GOT_KILLED: -1,
    OPPONENT_ELIMINATED: 2,
    SURVIVED_ROUND: 9
}

In [76]:
action_probs = np.zeros((6,500))


In [79]:
for run in range(10):
    # os.system("rm -rf ../rule_based_agent/data")
    os.system("rm -rf ../data")
    #os.system("cd ../..; python main.py play --no-gui --agents basic_agent basic_agent basic_agent basic_agent --train 4 --n-rounds 30 --scenario loot-crate")
    os.system("cd ../..; python main.py play --no-gui --agents basic_agent basic_agent basic_agent basic_agent  --train 4 --n-rounds 30 --scenario loot-crate")

    # os.system("cd ../..; python main.py play --no-gui --agents rule_based_agent rule_based_agent rule_based_agent rule_based_agent --train 4 --n-rounds 10 --scenario classic")

    trainset = BombermanDataset(data_path)
    N = len(trainset)
    print(f"{N} actions in the training set:")
    for i in range(6):
        n = np.sum([t == i for t in trainset.actions])
        print(f"{ACTION_MAP_INV[i]}: {n/N*100:.2f}% ({n})")
        action_probs[i,run] = n/N
    
    
    # sampler = BatchSampler(WeightedRandomSampler(trainset.weights, len(trainset), replacement=True,), batch_size, False)
    # trainloader = torch.utils.data.DataLoader(trainset, batch_sampler=sampler)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

    for epoch in range(num_epoch_per_run):
        running_loss = 0.0
        running = 0
        for i, batch in enumerate(trainloader, 0):
        # for i, batch in enumerate(sampler, 0):
            print(i, end="\r")
            # get the inputs; data is a list of [inputs, labels]
            (channels, features), actions, rewards = batch

            # zero the parameter gradients
            optimizer.zero_grad()
    
            # forward + backward + optimize
            outputs = net(channels, features)
            # loss = criterion(outputs, labels)
            log_logits = F.log_softmax(outputs, dim=-1)
            log_probs = log_logits.gather(1, actions[:,np.newaxis])
            """
            if i == 0:
                print(actions, rewards, log_probs)
                print(-log_probs * rewards)
            """
            loss = torch.mean(-log_probs * rewards)
            loss.backward()
            optimizer.step()
    
            # print statistics
            running_loss += loss.item()
            running += 1
            if i % 200 == 199:    # print every 2000 mini-batches
                print(f'[{run + 1}, {epoch + 1}, {i + 1:5d}] loss: {running_loss / running:.3f}')
                torch.save(net.state_dict(), weight_path)
                running_loss = 0.0
                running = 0
        if running > 0:
            print(f'[{run + 1}, {epoch + 1}] loss: {running_loss / running:.3f}')
    torch.save(net.state_dict(), weight_path)

print('Finished Training')

100%|██████████| 30/30 [00:58<00:00,  1.94s/it]


Loaded 298683 actions.
298683 actions in the training set:
UP: 15.22% (45452)
DOWN: 34.32% (102520)
LEFT: 16.40% (48982)
RIGHT: 26.48% (79098)
BOMB: 6.45% (19274)
WAIT: 1.12% (3357)
[1, 1,   200] loss: 20.545
[1, 1,   400] loss: 17.363
[1, 1,   600] loss: 15.881
[1, 1,   800] loss: 15.734
[1, 1,  1000] loss: 14.239
[1, 1,  1200] loss: 13.750
[1, 1,  1400] loss: 13.503
[1, 1,  1600] loss: 13.106
[1, 1,  1800] loss: 12.162
[1, 1,  2000] loss: 12.178
[1, 1,  2200] loss: 11.820
[1, 1,  2400] loss: 11.852
[1, 1,  2600] loss: 11.281
[1, 1,  2800] loss: 11.148
[1, 1,  3000] loss: 10.654
[1, 1,  3200] loss: 11.071
[1, 1,  3400] loss: 10.446
[1, 1,  3600] loss: 10.402
[1, 1,  3800] loss: 9.904
[1, 1,  4000] loss: 9.635
[1, 1,  4200] loss: 9.458
[1, 1,  4400] loss: 9.511
[1, 1,  4600] loss: 9.639
[1, 1] loss: 9.649


100%|██████████| 30/30 [00:58<00:00,  1.94s/it]


Loaded 275019 actions.
275019 actions in the training set:
UP: 14.84% (40818)
DOWN: 35.45% (97483)
LEFT: 16.15% (44411)
RIGHT: 26.31% (72357)
BOMB: 6.03% (16597)
WAIT: 1.22% (3353)
[2, 1,   200] loss: 15.642
[2, 1,   400] loss: 12.522
[2, 1,   600] loss: 10.814
[2, 1,   800] loss: 10.197
[2, 1,  1000] loss: 9.779
[2, 1,  1200] loss: 8.961
[2, 1,  1400] loss: 8.457
[2, 1,  1600] loss: 8.409
[2, 1,  1800] loss: 8.007
[2, 1,  2000] loss: 7.834
[2, 1,  2200] loss: 7.641
[2, 1,  2400] loss: 7.661
[2, 1,  2600] loss: 7.253
[2, 1,  2800] loss: 6.938
[2, 1,  3000] loss: 6.859
[2, 1,  3200] loss: 7.091
[2, 1,  3400] loss: 7.851
[2, 1,  3600] loss: 6.596
[2, 1,  3800] loss: 7.056
[2, 1,  4000] loss: 6.786
[2, 1,  4200] loss: 6.336
[2, 1] loss: 5.841


100%|██████████| 30/30 [01:02<00:00,  2.07s/it]


Loaded 296943 actions.
296943 actions in the training set:
UP: 14.74% (43777)
DOWN: 34.50% (102442)
LEFT: 15.75% (46779)
RIGHT: 27.90% (82858)
BOMB: 5.79% (17200)
WAIT: 1.31% (3887)
[3, 1,   200] loss: 13.117
[3, 1,   400] loss: 10.529
[3, 1,   600] loss: 8.514
[3, 1,   800] loss: 7.821
[3, 1,  1000] loss: 7.589
[3, 1,  1200] loss: 7.152
[3, 1,  1400] loss: 6.555
[3, 1,  1600] loss: 6.769
[3, 1,  1800] loss: 6.426
[3, 1,  2000] loss: 6.444
[3, 1,  2200] loss: 6.472
[3, 1,  2400] loss: 6.159
[3, 1,  2600] loss: 5.700
[3, 1,  2800] loss: 6.195
[3, 1,  3000] loss: 6.146
[3, 1,  3200] loss: 6.141
[3, 1,  3400] loss: 5.716
[3, 1,  3600] loss: 5.535
[3, 1,  3800] loss: 5.579
[3, 1,  4000] loss: 5.579
[3, 1,  4200] loss: 6.154
[3, 1,  4400] loss: 5.372
[3, 1,  4600] loss: 5.222
[3, 1] loss: 6.239


100%|██████████| 30/30 [00:53<00:00,  1.80s/it]


Loaded 288673 actions.
288673 actions in the training set:
UP: 14.39% (41553)
DOWN: 41.43% (119600)
LEFT: 15.19% (43854)
RIGHT: 22.47% (64863)
BOMB: 5.37% (15504)
WAIT: 1.14% (3299)
[4, 1,   200] loss: 10.355
[4, 1,   400] loss: 8.255
[4, 1,   600] loss: 7.762
[4, 1,   800] loss: 6.706
[4, 1,  1000] loss: 6.530
[4, 1,  1200] loss: 6.148
[4, 1,  1400] loss: 5.509
[4, 1,  1600] loss: 5.585
[4, 1,  1800] loss: 5.632
[4, 1,  2000] loss: 5.073
[4, 1,  2200] loss: 4.864
[4, 1,  2400] loss: 5.536
[4, 1,  2600] loss: 5.197
[4, 1,  2800] loss: 4.650
[4, 1,  3000] loss: 5.106
[4, 1,  3200] loss: 5.531
[4, 1,  3400] loss: 4.926
[4, 1,  3600] loss: 4.624
[4, 1,  3800] loss: 4.378
[4, 1,  4000] loss: 4.987
[4, 1,  4200] loss: 4.764
[4, 1,  4400] loss: 4.636
[4, 1] loss: 4.888


100%|██████████| 30/30 [00:52<00:00,  1.76s/it]


Loaded 278948 actions.
278948 actions in the training set:
UP: 13.04% (36378)
DOWN: 39.66% (110624)
LEFT: 14.77% (41196)
RIGHT: 26.79% (74744)
BOMB: 4.77% (13313)
WAIT: 0.97% (2693)
[5, 1,   200] loss: 11.025
[5, 1,   400] loss: 8.949
[5, 1,   600] loss: 7.630
[5, 1,   800] loss: 7.161
[5, 1,  1000] loss: 6.631
[5, 1,  1200] loss: 6.287
[5, 1,  1400] loss: 6.397
[5, 1,  1600] loss: 5.681
[5, 1,  1800] loss: 6.000
[5, 1,  2000] loss: 6.066
[5, 1,  2200] loss: 5.989
[5, 1,  2400] loss: 5.975
[5, 1,  2600] loss: 5.652
[5, 1,  2800] loss: 5.766
[5, 1,  3000] loss: 5.277
[5, 1,  3200] loss: 5.607
[5, 1,  3400] loss: 5.522
[5, 1,  3600] loss: 5.495
[5, 1,  3800] loss: 5.514
[5, 1,  4000] loss: 5.363
[5, 1,  4200] loss: 5.473
[5, 1] loss: 5.809


100%|██████████| 30/30 [00:51<00:00,  1.71s/it]


Loaded 261623 actions.
261623 actions in the training set:
UP: 15.45% (40432)
DOWN: 39.18% (102493)
LEFT: 14.66% (38343)
RIGHT: 23.93% (62594)
BOMB: 5.39% (14107)
WAIT: 1.40% (3654)
[6, 1,   200] loss: 10.541
[6, 1,   400] loss: 8.960
[6, 1,   600] loss: 7.116
[6, 1,   800] loss: 6.318
[6, 1,  1000] loss: 6.269
[6, 1,  1200] loss: 5.886
[6, 1,  1400] loss: 6.083
[6, 1,  1600] loss: 5.624
[6, 1,  1800] loss: 5.623
[6, 1,  2000] loss: 5.538
[6, 1,  2200] loss: 5.588
[6, 1,  2400] loss: 5.382
[6, 1,  2600] loss: 5.600
[6, 1,  2800] loss: 5.469
[6, 1,  3000] loss: 5.517
[6, 1,  3200] loss: 5.241
[6, 1,  3400] loss: 5.374
[6, 1,  3600] loss: 5.240
[6, 1,  3800] loss: 5.201
[6, 1,  4000] loss: 5.039
[6, 1] loss: 5.375


100%|██████████| 30/30 [00:55<00:00,  1.83s/it]


Loaded 266770 actions.
266770 actions in the training set:
UP: 15.18% (40502)
DOWN: 41.63% (111067)
LEFT: 13.93% (37161)
RIGHT: 22.58% (60247)
BOMB: 5.43% (14481)
WAIT: 1.24% (3312)
[7, 1,   200] loss: 10.172
[7, 1,   400] loss: 8.596
[7, 1,   600] loss: 7.097
[7, 1,   800] loss: 6.640
[7, 1,  1000] loss: 6.194
[7, 1,  1200] loss: 6.160
[7, 1,  1400] loss: 5.763
[7, 1,  1600] loss: 5.387
[7, 1,  1800] loss: 5.189
[7, 1,  2000] loss: 5.195
[7, 1,  2200] loss: 5.018
[7, 1,  2400] loss: 5.369
[7, 1,  2600] loss: 4.919
[7, 1,  2800] loss: 5.100
[7, 1,  3000] loss: 5.002
[7, 1,  3200] loss: 5.368
[7, 1,  3400] loss: 4.856
[7, 1,  3600] loss: 5.708
[7, 1,  3800] loss: 5.186
[7, 1,  4000] loss: 4.962
[7, 1] loss: 4.603


100%|██████████| 30/30 [00:52<00:00,  1.74s/it]


Loaded 272686 actions.
272686 actions in the training set:
UP: 14.62% (39861)
DOWN: 41.77% (113900)
LEFT: 11.12% (30333)
RIGHT: 26.36% (71867)
BOMB: 4.95% (13500)
WAIT: 1.18% (3225)
[8, 1,   200] loss: 8.618
[8, 1,   400] loss: 7.104
[8, 1,   600] loss: 6.021
[8, 1,   800] loss: 5.576
[8, 1,  1000] loss: 5.072
[8, 1,  1200] loss: 4.652
[8, 1,  1400] loss: 4.711
[8, 1,  1600] loss: 4.287
[8, 1,  1800] loss: 4.299
[8, 1,  2000] loss: 4.500
[8, 1,  2200] loss: 4.388
[8, 1,  2400] loss: 4.180
[8, 1,  2600] loss: 4.070
2624

KeyboardInterrupt: 

In [39]:
a = np.zeros((1,7))
print(a)
print(a[0])
print(a[0,0])


[[0. 0. 0. 0. 0. 0. 0.]]
[0. 0. 0. 0. 0. 0. 0.]
0.0


In [51]:
a = np.array([1,2,3,4,5,6,7])-1
print(a)
print(a[1:7])

[0 1 2 3 4 5 6]
[1 2 3 4 5 6]



    extensions = np.zeros(4)
    if pos[0]+4 <= 16:             #field, box_map, explosion_map, others_map, coin_map
        if field[pos[0]+4, pos[1]] == 1:
            extensions[0] = 1
        if box_map[pos[0]+4, pos[1]] ==1:
            extensions[0] = 2
        if explosion_map[pos[0]+4, pos[1]] ==1:
            extensions[0] = 3
        if others_map[pos[0]+4, pos[1]] == 1:
            extensions[0] = 4
        if coin_map[pos[0]+4, pos[1]] ==1:
            extensions[0] =5

    if pos[0]-4 >= 0:             #field, box_map, explosion_map, others_map, coin_map
        if field[pos[0]-4, pos[1]] == 1:
            extensions[1] = 1
        if box_map[pos[0]-4, pos[1]] ==1:
            extensions[1] = 2
        if explosion_map[pos[0]+4, pos[1]] ==1:
            extensions[1] = 3
        if others_map[pos[0]-4, pos[1]] == 1:
            extensions[1] = 4
        if coin_map[pos[0]-4, pos[1]] ==1:
            extensions[1] =5

    if pos[1]+4 <= 16:             #field, box_map, explosion_map, others_map, coin_map
        if field[pos[0], pos[1]+4] == 1:
            extensions[2] = 1
        if box_map[pos[0], pos[1]+4] ==1:
            extensions[2] = 2
        if explosion_map[pos[0], pos[1]+4] ==1:
            extensions[2] = 3
        if others_map[pos[0], pos[1]+4] == 1:
            extensions[2] = 4
        if coin_map[pos[0], pos[1]+4] ==1:
            extensions[2] = 5

    if pos[1]-4 >= 0:             #field, box_map, explosion_map, others_map, coin_map
        if field[pos[0], pos[1]-4] == 1:
            extensions[3] = 1
        if box_map[pos[0], pos[1]-4] ==1:
            extensions[3] = 2
        if explosion_map[pos[0], pos[1]-4] ==1:
            extensions[3] = 3
        if others_map[pos[0], pos[1]-4] == 1:
            extensions[3] = 4
        if coin_map[pos[0], pos[1]-4] ==1:
            extensions[3] = 5

