In [1]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append(r"../..")


from torch import nn
from torch.nn import functional as F
import torch.optim as optim
import torch
import os
import pickle
import numpy as np
import matplotlib.pyplot as plt

from callbacks import state_to_features, ACTION_MAP, ACTION_MAP_INV
from networks import AgentNet

from events import *

In [2]:
from torch.utils.data import Dataset


class BombermanDataset(Dataset):
    def __init__(self, states_dir, max_game_step=200, min_reward=1):
        self.features = []
        self.actions = []
        self.cum_rewards = []
        self.total_number = 0
        for num, states_file in enumerate(os.listdir(states_dir)):
            print(f"{num}", end="\r")

            if not states_file.endswith('.pickle'):
                continue
            with open(os.path.join(states_dir, states_file), "rb") as f:
                data = pickle.load(f)
            self.total_number += len(data['game_state'])
            last_round = -1
            running_horizons = []
            running_rewards = []
            for i in range(len(data['game_state'])-1,-1,-1):
                game_state = data['game_state'][i]               

                    
                if last_round != game_state['round']:
                    last_round = game_state['round']
                    running_horizons = []
                    running_rewards = []
                if game_state['step'] > max_game_step:
                    continue
                    
                # determine if action should be included in training set
                
                horizons = [EVENT_HORIZON[e] for e in data['events'][i]]
                rewards = [REWARDS[e] for e in data['events'][i]]
                cr = np.where(data['events'] == "CRATE_DESTROYED")         

                running_rewards = [r for j,r in enumerate(running_rewards) if running_horizons[j]>1]
                running_horizons = [h-1 for h in running_horizons if h > 1]
                running_rewards.extend(rewards)
                running_horizons.extend(horizons)

                self.cum_rewards.append(np.sum(running_rewards))

                last_action = None if game_state['step'] == 1 else data['action'][i-1]
                
                features, act_map = state_to_features(game_state, r=4, last_action=last_action)
                action = ACTION_MAP[act_map[data['action'][i]]]
                self.features.append({key: torch.tensor(value, dtype=torch.float) for key, value in features.items()})
                self.actions.append(action)

        self.actions = torch.tensor(self.actions)
        self.cum_rewards = torch.tensor(self.cum_rewards, dtype=torch.float)

        print(f"Loaded {len(self.actions)} actions.")

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return self.features[idx], self.actions[idx], self.cum_rewards[idx]

In [4]:
data_path = f"../data/"
weight_path = "models/model_weights_pre.pth"
num_epoch_per_run = 1

# optimizer = optim.SGD(net.parameters(), momentum=0.9, lr=0.001)
net = AgentNet()

In [5]:
optimizer = optim.AdamW(net.parameters(), lr=0.001)
batch_size = 128

In [6]:
EVENT_HORIZON = {
    MOVED_LEFT: 1,
    MOVED_RIGHT: 1,
    MOVED_UP: 1,
    MOVED_DOWN: 1,
    WAITED: 1,
    INVALID_ACTION: 1,
    BOMB_DROPPED: 1,
    BOMB_EXPLODED: 0,
    CRATE_DESTROYED: 5,
    COIN_FOUND: 0,
    COIN_COLLECTED: 5,
    KILLED_OPPONENT: 8,
    KILLED_SELF: 5,
    GOT_KILLED: 5,
    OPPONENT_ELIMINATED: 0,
    SURVIVED_ROUND: 0,
}

MOVE_REWARD = 1
REWARDS = {
    MOVED_LEFT: MOVE_REWARD,
    MOVED_RIGHT: MOVE_REWARD,
    MOVED_UP: MOVE_REWARD,
    MOVED_DOWN: MOVE_REWARD,
    WAITED: MOVE_REWARD,
    INVALID_ACTION: -1,
    BOMB_DROPPED: -2,
    BOMB_EXPLODED: 0,
    CRATE_DESTROYED: 6,
    COIN_FOUND: 0,
    COIN_COLLECTED: 10,
    KILLED_OPPONENT: 1,
    KILLED_SELF: 0,
    GOT_KILLED: -10,
    OPPONENT_ELIMINATED: 0,
    SURVIVED_ROUND: 0,
}

In [None]:
losses = []
for run in range(20):
    # os.system("rm -rf ../rule_based_agent/data")
    os.system("rm -rf ../data")
    os.system("cd ../..; python main.py play --no-gui --agents rule_based_agent rule_based_agent rule_based_agent rule_based_agent --train 4 --n-rounds 10 --scenario loot-crate")

    trainset = BombermanDataset(data_path)
    N = len(trainset)
    print(f"{N} actions in the training set:")
    for i in range(6):
        n = np.sum(trainset.actions.numpy() == i)
        print(f"{ACTION_MAP_INV[i]}: {n/N*100:.2f}% ({n})")

    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

    for epoch in range(num_epoch_per_run):
        running_loss = 0.0
        running = 0
        for i, batch in enumerate(trainloader, 0):
        # for i, batch in enumerate(sampler, 0):
            print(i, end="\r")
            # get the inputs; data is a list of [inputs, labels]
            features, actions, rewards = batch

            # zero the parameter gradients
            optimizer.zero_grad()
    
            # forward + backward + optimize
            outputs = net(features['coin_view'], features['local_view'], features['features'])
            # loss = criterion(outputs, labels)
            log_logits = F.log_softmax(outputs, dim=-1)
            log_probs = log_logits.gather(1, actions[:,np.newaxis])
            
            loss = torch.mean(-log_probs * rewards)
            loss.backward()
            optimizer.step()
    
            # print statistics
            running_loss += loss.item()
            running += 1
            if i % 200 == 199:    # print every 2000 mini-batches
                print(f'[{run + 1}, {epoch + 1}, {i + 1:5d}] loss: {running_loss / running:.3f}')
                torch.save(net.state_dict(), weight_path)
                if i == 199:
                    losses.append(running_loss / running)
                running_loss = 0.0
                running = 0

        if running > 0:
            print(f'[{run + 1}, {epoch + 1}] loss: {running_loss / running:.3f}')
            
    torch.save(net.state_dict(), weight_path)

np.save("models/losses_pre.npy", np.array(losses))
print('Finished Training') 

In [7]:
weight_path = "models/model_weights_fine.pth"
optimizer = optim.AdamW(net.parameters(), lr=0.0001)
batch_size = 256

In [8]:
losses = []
for run in range(20):
    # os.system("rm -rf ../rule_based_agent/data")
    os.system("rm -rf ../data")
    os.system("cd ../..; python main.py play --no-gui --agents rule_based_agent rule_based_agent rule_based_agent rule_based_agent --train 4 --n-rounds 10 --scenario loot-crate")

    trainset = BombermanDataset(data_path)
    N = len(trainset)
    print(f"{N} actions in the training set:")
    for i in range(6):
        n = np.sum(trainset.actions.numpy() == i)
        print(f"{ACTION_MAP_INV[i]}: {n/N*100:.2f}% ({n})")
    # sampler = BatchSampler(WeightedRandomSampler(trainset.weights, len(trainset), replacement=True,), batch_size, False)
    # trainloader = torch.utils.data.DataLoader(trainset, batch_sampler=sampler)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

    for epoch in range(num_epoch_per_run):
        running_loss = 0.0
        running = 0
        for i, batch in enumerate(trainloader, 0):
        # for i, batch in enumerate(sampler, 0):
            print(i, end="\r")
            # get the inputs; data is a list of [inputs, labels]
            features, actions, rewards = batch

            # zero the parameter gradients
            optimizer.zero_grad()
    
            # forward + backward + optimize
            outputs = net(features['coin_view'], features['local_view'], features['features'])
            # loss = criterion(outputs, labels)
            log_logits = F.log_softmax(outputs, dim=-1)
            log_probs = log_logits.gather(1, actions[:,np.newaxis])
            
            loss = torch.mean(-log_probs * rewards)
            loss.backward()
            optimizer.step()
    
            # print statistics
            running_loss += loss.item()
            running += 1
            if i % 200 == 199:    # print every 2000 mini-batches
                print(f'[{run + 1}, {epoch + 1}, {i + 1:5d}] loss: {running_loss / running:.3f}')
                torch.save(net.state_dict(), weight_path)
                if i == 199:
                    losses.append(running_loss / running)
                running_loss = 0.0
                running = 0
        
        if i < 199:
            losses.append(running_loss / running)
        if running > 0:
            print(f'[{run + 1}, {epoch + 1}] loss: {running_loss / running:.3f}')
            
    torch.save(net.state_dict(), weight_path)

np.save("models/losses_fine.npy", np.array(losses))
print('Finished Training') 

100%|██████████| 10/10 [00:04<00:00,  2.03it/s]


Loaded 33887 actions.
33887 actions in the training set:
UP: 19.24% (6519)
DOWN: 25.09% (8501)
LEFT: 19.51% (6611)
RIGHT: 26.07% (8833)
BOMB: 8.36% (2832)
WAIT: 1.74% (591)
[1, 1] loss: 4.133


100%|██████████| 10/10 [00:05<00:00,  1.69it/s]


Loaded 35939 actions.
35939 actions in the training set:
UP: 19.27% (6925)
DOWN: 24.95% (8967)
LEFT: 19.61% (7047)
RIGHT: 25.45% (9148)
BOMB: 8.87% (3189)
WAIT: 1.84% (663)
[2, 1] loss: 4.144


100%|██████████| 10/10 [00:05<00:00,  1.77it/s]


Loaded 34620 actions.
34620 actions in the training set:
UP: 19.38% (6710)
DOWN: 25.17% (8714)
LEFT: 19.99% (6922)
RIGHT: 25.01% (8660)
BOMB: 8.56% (2965)
WAIT: 1.87% (649)
[3, 1] loss: 4.254


100%|██████████| 10/10 [00:06<00:00,  1.66it/s]


Loaded 35016 actions.
35016 actions in the training set:
UP: 19.99% (7001)
DOWN: 25.48% (8923)
LEFT: 19.36% (6780)
RIGHT: 24.64% (8627)
BOMB: 8.72% (3053)
WAIT: 1.80% (632)
[4, 1] loss: 4.149


100%|██████████| 10/10 [00:05<00:00,  1.74it/s]


Loaded 36326 actions.
36326 actions in the training set:
UP: 19.10% (6940)
DOWN: 24.49% (8896)
LEFT: 19.96% (7250)
RIGHT: 26.06% (9467)
BOMB: 8.78% (3191)
WAIT: 1.60% (582)
[5, 1] loss: 3.778


100%|██████████| 10/10 [00:05<00:00,  1.70it/s]


Loaded 34910 actions.
34910 actions in the training set:
UP: 19.18% (6696)
DOWN: 25.75% (8989)
LEFT: 19.57% (6833)
RIGHT: 25.29% (8829)
BOMB: 8.75% (3054)
WAIT: 1.46% (509)
[6, 1] loss: 4.124


100%|██████████| 10/10 [00:05<00:00,  1.80it/s]


Loaded 34690 actions.
34690 actions in the training set:
UP: 19.31% (6697)
DOWN: 25.58% (8874)
LEFT: 19.61% (6802)
RIGHT: 24.94% (8653)
BOMB: 8.64% (2996)
WAIT: 1.93% (668)
[7, 1] loss: 4.072


100%|██████████| 10/10 [00:05<00:00,  1.79it/s]


Loaded 35455 actions.
35455 actions in the training set:
UP: 19.06% (6759)
DOWN: 24.96% (8851)
LEFT: 19.58% (6943)
RIGHT: 25.70% (9111)
BOMB: 8.73% (3095)
WAIT: 1.96% (696)
[8, 1] loss: 4.057


100%|██████████| 10/10 [00:05<00:00,  1.69it/s]


Loaded 34956 actions.
34956 actions in the training set:
UP: 19.98% (6983)
DOWN: 26.05% (9106)
LEFT: 19.42% (6787)
RIGHT: 24.33% (8506)
BOMB: 8.65% (3023)
WAIT: 1.58% (551)
[9, 1] loss: 4.148


100%|██████████| 10/10 [00:06<00:00,  1.66it/s]


Loaded 36593 actions.
36593 actions in the training set:
UP: 20.54% (7516)
DOWN: 25.76% (9426)
LEFT: 18.78% (6873)
RIGHT: 24.66% (9024)
BOMB: 8.63% (3158)
WAIT: 1.63% (596)
[10, 1] loss: 3.892


100%|██████████| 10/10 [00:05<00:00,  1.77it/s]


Loaded 34257 actions.
34257 actions in the training set:
UP: 19.76% (6768)
DOWN: 25.17% (8623)
LEFT: 19.53% (6690)
RIGHT: 25.15% (8617)
BOMB: 8.56% (2933)
WAIT: 1.83% (626)
[11, 1] loss: 4.009


100%|██████████| 10/10 [00:06<00:00,  1.65it/s]


Loaded 36185 actions.
36185 actions in the training set:
UP: 19.13% (6922)
DOWN: 25.39% (9188)
LEFT: 20.11% (7278)
RIGHT: 25.41% (9196)
BOMB: 8.41% (3044)
WAIT: 1.54% (557)
[12, 1] loss: 3.980


100%|██████████| 10/10 [00:05<00:00,  1.84it/s]


Loaded 32361 actions.
32361 actions in the training set:
UP: 18.93% (6126)
DOWN: 25.28% (8180)
LEFT: 19.78% (6401)
RIGHT: 25.68% (8311)
BOMB: 8.58% (2775)
WAIT: 1.76% (568)
[13, 1] loss: 3.929


100%|██████████| 10/10 [00:06<00:00,  1.64it/s]


Loaded 37711 actions.
37711 actions in the training set:
UP: 19.39% (7312)
DOWN: 24.31% (9166)
LEFT: 20.33% (7666)
RIGHT: 26.05% (9823)
BOMB: 8.29% (3128)
WAIT: 1.63% (616)
[14, 1] loss: 3.888


100%|██████████| 10/10 [00:05<00:00,  1.74it/s]


Loaded 34481 actions.
34481 actions in the training set:
UP: 19.40% (6688)
DOWN: 25.46% (8778)
LEFT: 19.27% (6646)
RIGHT: 25.59% (8822)
BOMB: 8.40% (2896)
WAIT: 1.89% (651)
[15, 1] loss: 4.096


100%|██████████| 10/10 [00:05<00:00,  1.75it/s]


Loaded 35046 actions.
35046 actions in the training set:
UP: 18.36% (6435)
DOWN: 23.40% (8201)
LEFT: 21.21% (7434)
RIGHT: 27.18% (9524)
BOMB: 8.55% (2996)
WAIT: 1.30% (456)
[16, 1] loss: 3.940


100%|██████████| 10/10 [00:05<00:00,  1.71it/s]


Loaded 35091 actions.
35091 actions in the training set:
UP: 19.88% (6976)
DOWN: 25.25% (8862)
LEFT: 19.31% (6775)
RIGHT: 24.78% (8697)
BOMB: 8.75% (3069)
WAIT: 2.03% (712)
[17, 1] loss: 3.867


100%|██████████| 10/10 [00:05<00:00,  1.73it/s]


Loaded 34784 actions.
34784 actions in the training set:
UP: 19.59% (6813)
DOWN: 25.01% (8699)
LEFT: 18.95% (6591)
RIGHT: 25.89% (9006)
BOMB: 8.74% (3040)
WAIT: 1.83% (635)
[18, 1] loss: 3.959


100%|██████████| 10/10 [00:06<00:00,  1.64it/s]


Loaded 36455 actions.
36455 actions in the training set:
UP: 20.26% (7384)
DOWN: 24.32% (8867)
LEFT: 19.55% (7127)
RIGHT: 25.56% (9319)
BOMB: 8.93% (3255)
WAIT: 1.38% (503)
[19, 1] loss: 3.950


100%|██████████| 10/10 [00:05<00:00,  1.79it/s]


Loaded 35063 actions.
35063 actions in the training set:
UP: 20.03% (7024)
DOWN: 25.28% (8863)
LEFT: 19.49% (6835)
RIGHT: 24.32% (8528)
BOMB: 8.91% (3125)
WAIT: 1.96% (688)
[20, 1] loss: 3.756
Finished Training
