In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import network

import gc
import copy
import random

import numpy as np
import axelrod as axl
from pprint import pprint
from itertools import permutations
from collections import namedtuple, deque

from network.replay import ReplayMemory

In [3]:
Config = {}
GAME_LEN = 20
C = axl.Action.C
D = axl.Action.D
GAME = axl.Game(r=4, s=0, t=5, p=1)

def Match(players, turns=GAME_LEN, reset=True):
    return axl.Match(players, turns=turns, reset=reset, game=GAME)

In [4]:
players = (axl.Alternator(), axl.Random())
game = Match(players)
game.play()

[(C, C),
 (D, D),
 (C, C),
 (D, D),
 (C, D),
 (D, D),
 (C, C),
 (D, D),
 (C, C),
 (D, C),
 (C, C),
 (D, D),
 (C, D),
 (D, C),
 (C, C),
 (D, C),
 (C, D),
 (D, C),
 (C, D),
 (D, D)]

In [5]:
# possible to change the way this class behaves to redefine input structure
class State():
    def __init__(self, depth):
        self.depth = depth
        self.reset()
        
    def reset(self):
        self.state = [deque([0 for _ in range(self.depth)], maxlen=self.depth) for _ in range(2)]
    
    def __repr__(self):
        return str(s.state).replace("),", "),\n")
    
    def values(self):
        return np.array(self.state, ndmin=3)
    
    def push(self, *args):
        play, coplay = map(self.encode, args)
        self.state[0].append(play)
        self.state[1].append(coplay)
        return self.values()
    
    @staticmethod
    def encode(play):
        if play == axl.Action.C:
            return 1
        else:
            return -1
    
    
s = State(20)
print(s.values())

s.push(C, D)
print(s)

[[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
  [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]]
[deque([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], maxlen=20),
 deque([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1], maxlen=20)]


In [6]:
def extract(game, memory, depth=20):
    """
    extract transitions,
    game = axl.Match object, with a finished game,
    memory = ReplayMemory
    """
    actions = game.result
    rewards = game.scores()
    state = State(depth)
    
    s = state.values()
    iterator = iter(zip(actions, rewards))
    while True:
        try:
            a_, r_ = next(iterator)
            s_ = state.push(*a_)
            
            memory.push(s, a_[0], s_, r_[0])
            s, a, r = (s_, a_, r_)
            
        except StopIteration:
            break

In [7]:
memory = ReplayMemory(1000)
print(len(memory))
extract(game, memory, GAME_LEN)
print(len(memory))

0
20


In [8]:
memory.sample(1)

[Transition(state=array([[[ 0,  0,  0,  0,  0,  0,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,
           1, -1,  1, -1],
         [ 0,  0,  0,  0,  0,  0,  1, -1,  1, -1, -1, -1,  1, -1,  1,  1,
           1, -1, -1,  1]]]), action=C, next_state=array([[[ 0,  0,  0,  0,  0,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1,
          -1,  1, -1,  1],
         [ 0,  0,  0,  0,  0,  1, -1,  1, -1, -1, -1,  1, -1,  1,  1,  1,
          -1, -1,  1,  1]]]), reward=4)]

In [9]:
def collect_exp(players, memory):
    old = len(memory)
    for pair in players:
        game = Match(pair, turns=GAME_LEN)
        game.play()
        extract(game, memory, GAME_LEN)
    new = len(memory)
    print(f"Collected {new-old} experience.")

players = permutations([axl.TitForTat(), axl.TitForTat(), axl.Random(), axl.Alternator()], 2)
collect_exp(players, memory)

Collected 240 experience.


In [10]:
class NNplayer(axl.Player):
    """
    """
    
    # These are various properties for the strategy
    name = 'NNplayer'
    classifier = {
        'memory_depth': 1,
        'stochastic': False,
        'inspects_source': False,
        'manipulates_source': False,
        'manipulates_state': False
    }
    
    decision = (axl.Action.C, axl.Action.D)
    
    def __init__(self, network, memory, greedy=0.2, gamma=0.999):
        super().__init__()
        
        self.network = network
        self.memory  = memory
        self.state   = State(GAME_LEN)
        
        self.greedy  = greedy
        self.gamma   = gamma
        
    def reset(self):
        self.state.reset()
        
    def strategy(self, opponent):
        """Make decision"""
        
        # make random choice to explore
        if random.random() < self.greedy:
            return random.choice(self.decision)
        
        # or query the network to exploit
        else:
            d = self.network.query(self.state.values())[0]
            return self.decision[d]
    
    # overwrite update_history to update self state
    def update_history(self, *args):
        self.history.append(*args)
        self.update_state(*args)
        
    def update_state(self, play, coplay):
        """update current game state & record transition into replay memory"""
        s  = self.state.values()
        s_ = self.state.push(play, coplay)
        r  = axl.interaction_utils.compute_scores([(play, coplay)])[0][0]
        self.memory.push(s, play, s_, r)
        
    def train(self, epoch, param):
        for _ in range(epoch):
            self.network.learn(self.memory, param, self.gamma)
        self.network.update_target()

In [33]:
from copy import deepcopy
from network import NeuralNetwork
from collections import namedtuple, deque

Transition = namedtuple('Transition', 
                        ('state', 'action', 'next_state', 'reward'))

class DQN():
    
    def __init__(self, layers):
        
        # define networks
        self.policy_net = NeuralNetwork(layers)
        self.target_net = deepcopy(self.policy_net)
        
        self.loss = None
    
    def query(self, state):
        """make decision from given state"""
        #return self.policy_net(state, mode='classification')
        d = self.policy_net(state, mode='rgr')
        print(d)
        return np.argmax(d, axis=1)
    
    def update_target(self):
        self.target_net = deepcopy(self.policy_net)
        
        
        
    def learn(self, memory, param, gamma):
        
        
        length = len(memory)
        batch_size = param['batch']
        sections = length // batch_size
        
        param['epoch'] += 1
        param['mode'] = 'train'
        self.policy_net.set_loss_func('mse')
        
        # get batches
        ts = Transition(*zip(*memory.sample(length)))
        ss  = np.vstack(ts.state)
        ss_ = np.vstack(ts.next_state)
        ats = np.array([[True, False] if a==C else [False, True] for a in ts.action])
        rs  = np.array(ts.reward, ndmin=2).T

#         print(ss)
#         print(ss_)
#         print(ats)
#         print(rs)
        
        # split into batches
        ss, ss_, ats, rs = map(lambda x: np.array_split(x, sections), (ss, ss_, ats, rs))
        
        # train
        for s, s_, at, r in zip(ss, ss_, ats, rs):
            
            # value of current state
            Q_values = self.policy_net(s, param=param, mode='rg') * at
            
            # value of next state
            Q_values_ = np.max(self.target_net(s_, mode='rg'), axis=1, keepdims=True)
            
            # expected value of current state
            E_values = gamma*Q_values_ + r
            
            # feedback
            loss, _ = self.policy_net.loss_fn(E_values, Q_values)
            loss = loss * at  # relocate loss to action taken
            self.policy_net.backprop(loss, param)
            
            # track training loss
            if not self.loss:
                self.loss = np.mean(np.max(np.abs(loss),axis=1))
            else:
                self.loss = 0.9*self.loss + 0.1*np.mean(np.max(np.abs(loss),axis=1))

In [81]:
dqn = DQN([
#                     network.Flatten_layer(),
#                     network.Maxout_layer(GAME_LEN*2, 100),
#                     network.BatchNorm_layer(100),
#                     network.Maxout_layer(100, 60),
#                     network.Maxout_layer(60, 40),
#                     network.BatchNorm_layer(40),
#                     network.Maxout_layer(40, 20),
#                     network.Maxout_layer(20, 2),
    
                    network.Flatten_layer(),
                    network.Linear_layer(GAME_LEN*2, 100, bias=0.01),
                    network.Activation_layer('ReLU'),
                    network.Linear_layer(100, 200),
                    #network.BatchNorm_layer(200),
                    network.Activation_layer('ReLU'),
                    network.Linear_layer(200, 40),
                    network.Activation_layer('ReLU'),
                    network.Linear_layer(40, 2),
                    ])
p1 = NNplayer(dqn, ReplayMemory(1000), gamma=0.95)
del dqn
gc.collect()

1096

In [85]:
param = {"lr": 5e-6, 'batch': 8, "momentum": 0.9, "mode": "train", "eps": 1e-9, "beta":(0.9, 0.999), 
         "epoch": 0, 'optimizer': 'adam', 't': 1, 'clip': 1.0, 'decay': 0.0}

In [86]:
p1.greedy=0.4
for i in range(50):
    players = (p1, axl.TitForTat())
    game = Match(players)
    game.play()

[[1.79695166 1.49711955]]
[[2.26127406 2.43677847]]
[[2.2219914 2.4267914]]
[[3.07908386 4.39382773]]
[[2.90605139 3.09221148]]
[[3.54704793 3.27885892]]
[[3.70140505 4.28501626]]
[[3.54973471 3.94766085]]
[[5.19766886 3.13086387]]
[[3.88159838 3.90593843]]
[[4.77738787 4.40220985]]
[[3.38175231 6.11037265]]
[[5.37844521 6.10622328]]
[[1.2245196 1.5523293]]
[[2.50779952 2.23504199]]
[[2.57313295 2.99546457]]
[[2.35265798 2.83555153]]
[[3.47230317 4.23844543]]
[[3.84458933 4.81023885]]
[[2.68733123 5.07078031]]
[[3.28752929 3.57101975]]
[[3.23827996 6.46705102]]
[[4.4688553  3.34514714]]
[[2.73639084 4.83369059]]
[[5.8462996  6.68023944]]
[[0.96192236 0.84744182]]
[[1.95431987 2.2860138 ]]
[[2.31712225 2.66559291]]
[[2.84645402 3.02161588]]
[[2.83266201 2.94775148]]
[[2.81185003 3.99815105]]
[[2.34292844 3.11898296]]
[[5.0005134  8.03489519]]
[[5.6262069  4.74471337]]
[[3.85770472 4.24161239]]
[[3.94807999 6.16253637]]
[[4.62282078 4.85873255]]
[[5.72435738 7.24415528]]
[[1.28432275 1.5

In [87]:
for _ in range(100):
    p1.train(100, param)
    print(p1.network.loss)

1.0404306954428502
0.704123739116273
0.5591478736705868
0.9043273667970992
0.7737533060198768
1.1261093649939178
0.7569202533279128
0.8382366858352143
0.6426984392028872
0.6242767641192025
0.41997848561819634
0.3919831514972496
0.27305803608738427
0.24228702622328974
0.19661657080766295
0.19483550267755004
0.10697761167834988
0.16383323577834275
0.11992759116773812
0.11064939575631141


KeyboardInterrupt: 

In [88]:
p1.greedy=0.0
players = (p1, axl.TitForTat())
game = Match(players)
actions = game.play()
scores = game.scores()
scores

[[73.86892394 66.39919678]]
[[83.17215829 69.12339757]]
[[89.81793779 76.41083988]]
[[92.74789261 77.10714116]]
[[100.46323416  83.22528053]]
[[101.88848377  83.65256574]]
[[108.92331135  89.53133053]]
[[107.28415626  87.64398574]]
[[111.61498452  91.64488501]]
[[112.81914344  95.51472776]]
[[102.69321404  87.0306835 ]]
[[102.28900745  85.64276806]]
[[101.24422622  85.10027345]]
[[98.85096368 82.34685671]]
[[98.18741719 83.3916201 ]]
[[93.61082008 77.84463345]]
[[85.99185844 70.43336263]]
[[78.56692975 64.4232145 ]]
[[73.22571813 59.15781452]]
[[76.53038789 61.85024125]]


[(4, 4),
 (4, 4),
 (4, 4),
 (4, 4),
 (4, 4),
 (4, 4),
 (4, 4),
 (4, 4),
 (4, 4),
 (4, 4),
 (4, 4),
 (4, 4),
 (4, 4),
 (4, 4),
 (4, 4),
 (4, 4),
 (4, 4),
 (4, 4),
 (4, 4),
 (4, 4)]

In [66]:
p1.network.policy_net.print_parameters()

--0--
Printing flatten layer:
{'shape': (1, 2, 20), 'type': 'flatten'}
--1--
Printing linear layer:
{'bias': 0.01,
 'input': array([[0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1.]]),
 'input_nodes': 40,
 'm1': array([[ 0.        ,  0.        ,  0.        , ...,  0.        ,
         0.        ,  0.        ],
       [-0.00020977, -0.00095223, -0.0009181 , ..., -0.00029098,
        -0.00020432, -0.00078971],
       [ 0.00219831,  0.00349761, -0.00507549, ...,  0.00112909,
         0.00121084, -0.00195639],
       ...,
       [-0.00540627, -0.03318098, -0.01051075, ...,  0.00178815,
        -0.00703176, -0.01861019],
       [ 0.00212579,  0.01869414, -0.01959067, ..., -0.0105582 ,
        -0.01002088,  0.00296873],
       [-0.03286595, -0.06181669, -0.06969266, ..., -0.04258941,
        -0.02395079, -0.05628469]]),
 'm2': array([[0.00000000e+00, 0.00000000e+

In [58]:
p1.target_net.print_parameters()

--0--
Printing linear layer:
{'bias': 0,
 'input': array([[-1, -1, -1, -1, -1, -1, -1, -1,  3,  3,  3,  3,  3,  3,  3,  3,
         3,  3,  3,  3],
       [-1, -1,  0,  1,  3,  1,  0,  1,  0,  5,  0,  5,  3,  1,  0,  5,
         0,  5,  0,  1],
       [-1, -1, -1, -1, -1, -1,  3,  5,  0,  5,  0,  5,  0,  5,  0,  5,
         0,  5,  0,  5],
       [-1, -1, -1, -1, -1, -1, -1, -1,  0,  1,  1,  1,  1,  1,  5,  3,
         3,  0,  5,  0],
       [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,  0,
        50, 30, 50, 30],
       [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1,  3],
       [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1,  3,  3,  3],
       [-1, -1, -1, -1, -1, -1,  0,  1,  3,  1,  0,  1,  0,  5,  0,  5,
         3,  1,  0,  5],
       [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,  3,  3,  3,
         3,  3,  3,  3],
       [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
    

In [59]:
# TODO: is the last turn really learned?

In [50]:
p1.greedy=0
players = (p1, axl.Alternator())
game = Match(players, turns=GAME_LEN, reset=False)
actions = game.play()
scores = game.scores()
scores

[[139.26483157  80.73600473]]
[[139.26483157  80.73600473]]
[[131.32996287  76.44282417]]
[[125.04856293  72.07261553]]
[[121.46701121  70.68969251]]
[[118.51196138  68.50633181]]
[[112.41266457  65.51117411]]
[[107.3983868   62.03491765]]
[[97.65106035 56.60011832]]
[[107.03143318  62.10103404]]
[[96.55901029 55.61153098]]
[[92.38216847 53.59353024]]
[[103.09273132  59.65961577]]
[[91.30520482 53.02587687]]
[[102.8436791   59.18046654]]
[[87.48615107 51.51756402]]
[[103.7733925   59.26855639]]
[[87.85346375 51.58897998]]
[[104.3980876   59.53855703]]
[[91.72162096 53.20190031]]


[(4, 4),
 (0, 5),
 (4, 4),
 (0, 5),
 (4, 4),
 (0, 5),
 (4, 4),
 (0, 5),
 (4, 4),
 (0, 5),
 (4, 4),
 (0, 5),
 (4, 4),
 (0, 5),
 (4, 4),
 (0, 5),
 (4, 4),
 (0, 5),
 (4, 4),
 (0, 5)]