In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append('../../')

%matplotlib inline
from matplotlib import pyplot as plt
from IPython.display import clear_output

import numpy as np
from pandas import ewma

def plot_stats(data, name):
    plt.grid()
    plt.plot(data, label=name, alpha=0.2)
    plt.plot(ewma(np.array(data), span=10), label='{} ewma@10'.format(name), alpha=0.5)
    plt.plot(ewma(np.array(data), span=100), label='{} ewma@100'.format(name))
    plt.title('{} survivors'.format(name))
    plt.legend()

In [3]:
from torch import nn
import torch.nn.functional as F
import numpy as np


class A3CModel(nn.Module):
    def __init__(self, in_channels, n_actions):
        super(A3CModel, self).__init__()
        
        self.fc1 = nn.Linear(27, 256)
        
        self.logits = nn.Linear(256, n_actions)
        self.state_value = nn.Linear(256, 1)

    def forward(self, obs):
        middle = tuple(ax // 2 for ax in obs[0].shape)

        state = np.stack(
            (obs[0] == obs[0][middle], obs[1], obs[2]),
            axis=-1
        ).flatten().astype(float)
        x = Variable(torch.FloatTensor(state).unsqueeze(0))
        
        x = F.relu(self.fc1(x))

        logits = self.logits(x)
        state_value = self.state_value(x)
        
        return logits, state_value

In [4]:
from azkaban.agent import RandomAgent, A3CAgent, A3CParams
from azkaban.env import TeamsEnv, TeamsEnvConf
from azkaban.optim import SharedAdam

import torch
from torch.autograd import Variable
from torch.optim import Adam
from torch.multiprocessing import Lock

conf = TeamsEnvConf(
    world_shape=(7, 7),
    comm_shape=(0,),
    team_names=[
        'a3c',
        'random'
    ]
)

def model_generator():
    return A3CModel(
        in_channels=3,
        n_actions=conf.action_space.shape()[0]
    )

params = A3CParams()

shared_model = model_generator()
shared_optimizer = SharedAdam(shared_model.parameters(), lr=params.lr)
shared_lock = Lock()

a3c_team = 5
random_team = 5

def env_generator():
    return TeamsEnv(
        teams=[
            tuple(
                A3CAgent(
                    conf=conf,
                    params=params,
                    model=model_generator(),
                    shared_model=shared_model,
                    shared_optimizer=shared_optimizer,
                    trainable=True,
                    lock=shared_lock
                ) for i in range(a3c_team)
            ),
            tuple(
                RandomAgent(
                    conf=conf
                ) for i in range(random_team)
            )
        ],
        conf=conf
    )

In [5]:
from copy import copy

stats_lock = Lock()
stats = []

max_ticks = 1000

def worker():
    env = env_generator()
    
    while True:
        env.reset()
        
        for i in range(max_ticks):
            done = env.step(interrupt=(i == max_ticks - 1))
            
            if done:
                break
        print('!')
        with stats_lock:
            print('?')
            stats.append(tuple(copy(env.members)))

In [10]:
from torch.multiprocessing import Process
import time

n_workers = 1
workers = []

for _ in range(n_workers):
    p = Process(target=worker)
    p.start()
    workers.append(p)

print('Press Ctrl+C to exit')
try:
    time.sleep(1)
    
    clear_output(True)
    
    if len(stats) > 0:
        with stats_lock:
            a3c, random = zip(*stats)

        plt.figure(figsize=(16, 5))
        plt.subplot(1, 2, 1)
        plot_stats(a3c, name='a3c')

        plt.subplot(1, 2, 2)
        plot_stats(random, name='random')
        plt.savefig('results.png')
        plt.show()
except KeyboardInterrupt:
    print('Killing workers...')
    
    for worker in workers:
        worker.terminate()
    
    shared_lock = Lock()

Process Process-3:
Traceback (most recent call last):
  File "/Users/laplab/.pyenv/versions/3.6.1/lib/python3.6/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
  File "/Users/laplab/.pyenv/versions/3.6.1/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
TypeError: 'Process' object is not callable


Press Ctrl+C to exit


In [11]:
for worker in workers:
    worker.terminate()
shared_lock = Lock()

In [12]:
stats

[]