# Multi Agent Reinforcement Learning

In [None]:
from IPython.lib.display import YouTubeVideo

### Markov Games

In [None]:
YouTubeVideo('Y9qq4Jqnwls')

### Approaches to Multi Agent Reinforcement Learning (MARL)

In [None]:
YouTubeVideo('uKV9AJykin0')

### Cooperation, Competition, Mixed Environments

In [None]:
YouTubeVideo('vx6PIH5_oFg')

# Case Study: Physical Deception

The paper **[Multi Agent Actor Critic for Mixed Cooperative Competitive environments](https://papers.nips.cc/paper/7217-multi-agent-actor-critic-for-mixed-cooperative-competitive-environments.pdf)** from *OpenAI* introduces a multi-agent actor-critic approach to MARL. *OpenAI* has also used a team of 5 neural networks, called [OpenAI Five](https://openai.com/projects/five/), to defeat amateur [Dota 2](https://www.dota2.com/play/) players.

We'll implement a part of this paper to train an agent to solve the **Physical Deception** problem.

In [None]:
YouTubeVideo('nRKrQamUISs')

In [None]:
YouTubeVideo('Ks9-TeCg3Fs')

In [None]:
YouTubeVideo('4hFAhtLJR5U')

#### Objective of the Environment
`Blue` dots are the **good agents**, and the `red` dot is an **adversary**. The goal of every agent is to get close to the `green` target. The blue agents know which target is green, but the red agent is color-blind and cannot tell which target is green and which is black! The optimal solution is for the red agent to chase one of the blue agents, and for the blue agents to split up and go toward each of the targets.

#### Running the workspace on a Multi-core CPU
Using a GPU wouldn't improve the training time for this program. Instead, a multi-core environment is a better choice for increasing the training speed.

### Import the Packages

In [None]:
import os
import time
import random
import numpy as np
from copy import copy
from math import sqrt
from collections import deque

import torch
import torch.nn as nn
from torch.optim import Adam
import torch.nn.functional as F

import matplotlib.pyplot as plt
from matplotlib import animation
%matplotlib inline

from utils.ParallelEnvironments import MultiAgentParallelEnv

from IPython.display import display as Display
from IPython.display import HTML
# NOTE: the name `display` is rebound three times in this cell:
#   1. the pyvirtualdisplay `Display` class (imported under the alias `display`),
#   2. the started virtual display instance,
#   3. the `IPython.display` module, but only when the inline backend is active.
# Later cells call `display.display(...)` / `display.clear_output(...)`, which
# only works after rebinding 3 has happened.
from pyvirtualdisplay import Display as display
display = display(visible=0, size=(1400, 900))
display.start()

is_ipython = 'inline' in plt.get_backend()
if is_ipython:
    from IPython import display

plt.ion()

# The CUDA device is selected and then immediately overridden: training is
# deliberately pinned to the CPU (see "Running on Multi-core CPU" above).
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device = 'cpu'

In [None]:
# scenario file handed to MultiAgentParallelEnv
MULTI_AGENT_ENVIRONMENT = 'simple_adversary.py'  # https://github.com/openai/multiagent-particle-envs
# checkpoint path template, filled with the episode number via str.format;
# NOTE: '{:4d}' pads with spaces, so episode 8 is written to '...-episode-   8.pt'
MODEL_FILE              = './models/maddpg-adversary-episode-{:4d}.pt'
# number of environment copies run in parallel (also used as the torch thread count)
N_ENVIRONMENTS          = 4
# seed shared by the random number generators seeded in the next cell
RANDOM_SEED             = 1

In [None]:
# Seed every random number generator the notebook draws from:
#   * `random`    - replay-buffer sampling and MCTS move selection
#   * `np.random` - the OUNoise exploration process (np.random.randn)
#   * `torch`     - network weight initialisation
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
# one intra-op thread per parallel environment
torch.set_num_threads(N_ENVIRONMENTS)

### Explore the Multi Agent Environment

In [None]:
envs = MultiAgentParallelEnv(MULTI_AGENT_ENVIRONMENT, N_ENVIRONMENTS)

# Summarise the environment layout, one line per property.
for template, value in (
    ('Number of agents: {}', envs.n_agents),
    ('Types of agents: {}', envs.agent_types),
    ('Observation space of each agent: {}', envs.states),
    ('Action space of each agent: {}', envs.actions),
):
    print(template.format(value))

envs.close()

### The Actor-Critic Network

In [None]:
def hidden_init(layer):
    '''Return the symmetric uniform-initialisation bounds for a linear layer.

    Params
    ======
        layer (nn.Linear): layer whose weights are about to be initialised

    Returns
    =======
        (-lim, lim) with lim = 1 / sqrt(fan_in)
    '''
    # nn.Linear stores its weight as (out_features, in_features), so the
    # fan-in (number of inputs per unit) is the SECOND dimension.
    fan_in = layer.weight.data.size()[1]
    lim = 1.0 / np.sqrt(fan_in)
    return -lim, lim


class Actor(nn.Module):
    '''Actor (Policy) Model'''
    def __init__(self, state_size, action_size, fc1_units=16, fc2_units=8):
        '''Initialize parameters and build model.
        Params
        ======
            state_size (int): Dimension of each state
            action_size (int): Dimension of each action
            fc1_units (int): Number of nodes in first hidden layer
            fc2_units (int): Number of nodes in second hidden layer
        '''
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(state_size, fc1_units)
        self.fc2 = nn.Linear(fc1_units, fc2_units)
        self.fc3 = nn.Linear(fc2_units, action_size)
        # The layers keep PyTorch's default initialisation; call
        # reset_parameters() explicitly to use the custom scheme instead.

    def reset_parameters(self):
        '''Re-initialise the weights: fan-in scaled uniform for the hidden
        layers, small uniform values for the output layer.'''
        self.fc1.weight.data.uniform_(*hidden_init(self.fc1))
        self.fc2.weight.data.uniform_(*hidden_init(self.fc2))
        self.fc3.weight.data.uniform_(-1e-3, 1e-3)

    def forward(self, state):
        '''
        Build an actor (policy) network that maps states -> actions
        Return a vector of the force

        `state` may be a single state vector or a batch of them (states on the
        last dimension); every action vector is rescaled independently.
        '''
        x = F.relu(self.fc1(state)) # F.leaky_relu
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        # x holds one force vector per state. The norm is taken per vector
        # (over the last dimension only) so that a batched call returns
        # exactly what the corresponding one-by-one calls would return.
        norm = torch.norm(x, dim=-1, keepdim=True)
        # we bound the norm of each vector to be between 0 and 10;
        # clamping the divisor keeps a zero vector at zero without dividing by 0
        return 10.0 * torch.tanh(norm) * x / norm.clamp(min=1e-12)


class Critic(nn.Module):
    '''Critic (Value) Model'''
    def __init__(self, state_size, action_size, fc1_units=32, fc2_units=16):
        '''Build the three fully connected layers of the critic.
        Params
        ======
            state_size (int): Number of state features in the critic input
            action_size (int): Number of action features in the critic input
                (DDPGAgent passes the actions of ALL agents concatenated)
            fc1_units (int): Width of the first hidden layer
            fc2_units (int): Width of the second hidden layer
        '''
        super().__init__()
        input_width = state_size + action_size
        self.fc1 = nn.Linear(input_width, fc1_units)
        self.fc2 = nn.Linear(fc1_units, fc2_units)
        self.fc3 = nn.Linear(fc2_units, 1)
        # self.reset_parameters()

    def reset_parameters(self):
        '''Re-draw the weights: fan-in based bounds for the hidden layers and a
        narrow uniform range for the output layer.'''
        for hidden_layer in (self.fc1, self.fc2):
            low, high = hidden_init(hidden_layer)
            hidden_layer.weight.data.uniform_(low, high)
        self.fc3.weight.data.uniform_(-1e-3, 1e-3)

    def forward(self, state_action):
        '''Map a concatenated (state, action) input to a single Q-value per row.'''
        activation = state_action
        for hidden_layer in (self.fc1, self.fc2):
            activation = F.relu(hidden_layer(activation)) # F.leaky_relu
        return self.fc3(activation)

### Exploration: Adding random noise to the continuous actions

In [None]:
class OUNoise:
    '''Temporally correlated exploration noise (Ornstein-Uhlenbeck style update).

    The internal state is pulled towards `mu` with strength `theta` and
    perturbed by Gaussian noise scaled by `sigma`; samples are the state
    multiplied by `scale`.
    '''
    def __init__(self, action_size, scale=0.1, mu=0, theta=0.15, sigma=0.2):
        self.action_size = action_size
        self.scale = scale
        self.mu = mu
        self.theta = theta
        self.sigma = sigma
        # reset() creates the initial state (every component equal to mu)
        self.reset()

    def reset(self):
        '''Put the process back at its mean.'''
        self.state = np.ones(self.action_size) * self.mu

    def sample(self):
        '''Advance the process by one step and return the scaled state as a float tensor.'''
        previous = self.state
        pull_to_mean = self.theta * (self.mu - previous)
        random_kick = self.sigma * np.random.randn(len(previous))
        self.state = previous + (pull_to_mean + random_kick)
        scaled_state = self.state * self.scale
        return torch.tensor(scaled_state).float()

### Experience Replay Buffer

In [None]:
class ReplayBuffer:
    '''Fixed-capacity FIFO store of transitions; the oldest entries are dropped first.'''
    def __init__(self, size):
        self.size = size
        self.deque = deque(maxlen=self.size)

    def push(self,transition):
        'push into the buffer'
        # `transition` is a tuple of per-field sequences; zip(*...) regroups it
        # so that every stored item holds one entry of each field.
        self.deque.extend([list(entry) for entry in zip(*transition)])

    def sample(self, batchsize):
        'sample from the buffer'
        chosen = random.sample(self.deque, batchsize)
        # regroup the chosen items back into one list per field
        return [list(field) for field in zip(*chosen)]

    def __len__(self):
        return len(self.deque)

### Hyperparameters for Model

In [None]:
BATCH_SIZE = 1000                # number of transitions sampled from the replay buffer per update
GAMMA = 0.95                     # discount factor
TAU = 0.02                       # interpolation factor for the soft update of target parameters

LR_ACTOR = 1e-2                  # learning rate of the actor
LR_CRITIC = LR_ACTOR             # learning rate of the critic (suggested range: 1x to 10x LR_ACTOR)
WEIGHT_DECAY = 1e-5              # L2 weight decay of the critic optimizer (suggested range: 0 to 0.0001)

TRAIN_EVERY = 2 * N_ENVIRONMENTS # number of episodes between two model updates

### The DDPG Agent

In [None]:
class DDPGAgent:
    '''One agent of the multi-agent setup: local and target actor/critic
    networks, their optimizers and an exploration-noise process.'''
    def __init__(self, state_size, action_size, n_agents):
        '''
        Params
        ======
            state_size (int): dimension of each state
            action_size (int): dimension of each action
            n_agents (int): number of agents in the environment
        '''
        self.actor = Actor(state_size, action_size).to(device)
        self.target_actor = Actor(state_size, action_size).to(device)
        # the critic sees the actions of ALL agents, hence action_size * n_agents
        self.critic = Critic(state_size, action_size * n_agents).to(device)
        self.target_critic = Critic(state_size, action_size * n_agents).to(device)

        self.noise = OUNoise(action_size, scale=1.0)

        # initialize targets same as original networks
        self.hard_update(self.target_actor, self.actor)
        self.hard_update(self.target_critic, self.critic)

        self.actor_optimizer = Adam(self.actor.parameters(), lr=LR_ACTOR)
        self.critic_optimizer = Adam(self.critic.parameters(), lr=LR_CRITIC, weight_decay=WEIGHT_DECAY)

    def _add_noise(self, action, noise):
        '''Add `noise`-scaled exploration noise to `action`.'''
        # OUNoise.sample() always returns a CPU tensor; move it next to the
        # action so the sum also works when `device` is not the CPU.
        # The noise process is advanced on every call, even when noise == 0.
        return action + noise * self.noise.sample().to(action.device)

    def act(self, obs, noise=0.0):
        '''Action of the local actor for `obs`, plus scaled exploration noise.'''
        obs = obs.to(device)
        return self._add_noise(self.actor(obs), noise)

    def target_act(self, obs, noise=0.0):
        '''Action of the target actor for `obs`, plus scaled exploration noise.'''
        obs = obs.to(device)
        return self._add_noise(self.target_actor(obs), noise)

    def hard_update(self, target_model, source_model):
        '''Copy every parameter of `source_model` into `target_model`.'''
        for target_param, source_param in zip(target_model.parameters(), source_model.parameters()):
            target_param.data.copy_(source_param.data)

### Multi Agent DDPG

In [None]:
def transpose_to_tensor(input_list):
    '''Swap the two outermost levels of a nested list and turn every resulting
    group into a float tensor.

    input_list[a][b] ends up at position [a] of the b-th returned tensor.
    '''
    return [torch.tensor(group, dtype=torch.float) for group in zip(*input_list)]

class MADDPG:
    '''Container for one DDPGAgent per agent; runs the per-agent critic/actor
    updates and the soft updates of all target networks.'''
    def __init__(self, state_size, action_size, n_agents):
        super(MADDPG, self).__init__()
        self.maddpg_agent = [DDPGAgent(state_size, action_size, n_agents) for _ in range(n_agents)]
        # number of times update_targets() has been called
        self.iter = 0

    def get_actors(self):
        '''get actors of all the agents in the MADDPG object'''
        return [ddpg_agent.actor for ddpg_agent in self.maddpg_agent]

    def get_target_actors(self):
        '''get target_actors of all the agents in the MADDPG object'''
        return [ddpg_agent.target_actor for ddpg_agent in self.maddpg_agent]

    def act(self, obs_all_agents, noise=0.0):
        '''get actions from all agents in the MADDPG object'''
        actions = [agent.act(obs, noise) for agent, obs in zip(self.maddpg_agent, obs_all_agents)]
        return actions

    def target_act(self, obs_all_agents, noise=0.0):
        '''get target network actions from all the agents in the MADDPG object'''
        target_actions = [ddpg_agent.target_act(obs, noise) for ddpg_agent, obs in zip(self.maddpg_agent, obs_all_agents)]
        return target_actions

    def update(self, samples, agent_number):
        '''
        update the critic and actor of the agent referred by 'agent_number'

        `samples` is a batch as returned by ReplayBuffer.sample(): one list per
        transition field, in the order
        (obs, obs_full, action, reward, next_obs, next_obs_full, done).
        '''
        
        # need to transpose each element of the samples
        # to flip obs[parallel_agent][agent_number] to
        # obs[agent_number][parallel_agent]
        obs, obs_full, action, reward, next_obs, next_obs_full, done = map(transpose_to_tensor, samples)

        # the full observations are stacked here and transposed with .t() below
        # so that they can be concatenated with the actions along dim=1
        obs_full = torch.stack(obs_full)
        next_obs_full = torch.stack(next_obs_full)
        
        agent = self.maddpg_agent[agent_number]
        
        # -------------- update critic network -------------
        agent.critic_optimizer.zero_grad()

        # critic loss = Huber (smooth L1) loss between y and Q(s,a) from the local critic
        # y = reward of this timestep + discount * Q(s', a') from target network
        # (a' comes from the target actors of ALL agents, without noise)
        target_actions = self.target_act(next_obs)
        target_actions = torch.cat(target_actions, dim=1)
        
        target_critic_input = torch.cat((next_obs_full.t(), target_actions), dim=1).to(device)
        
        with torch.no_grad():
            q_next = agent.target_critic(target_critic_input)
        
        # (1 - done) removes the bootstrapped value for transitions flagged as done
        y = reward[agent_number].view(-1, 1) + GAMMA * q_next * (1 - done[agent_number].view(-1, 1))
        action = torch.cat(action, dim=1)
        critic_input = torch.cat((obs_full.t(), action), dim=1).to(device)
        q = agent.critic(critic_input)

        huber_loss = torch.nn.SmoothL1Loss()
        critic_loss = huber_loss(q, y.detach())
        critic_loss.backward()
        # torch.nn.utils.clip_grad_norm_(agent.critic.parameters(), 0.5)
        agent.critic_optimizer.step()

        # ------------- update actor network using policy gradient -------------
        agent.actor_optimizer.zero_grad()
        # make input to agent
        # detach the other agents to save computation
        # saves some time for computing derivative
        q_input = [self.maddpg_agent[i].actor(ob) if i == agent_number
                   else self.maddpg_agent[i].actor(ob).detach()
                   for i, ob in enumerate(obs)]
                
        q_input = torch.cat(q_input, dim=1)
        
        # combine all the actions and observations for input to critic
        # many of the obs are redundant, and obs[1] contains all useful information already
        # NOTE(review): unlike the two critic inputs above, this tensor is not moved with .to(device)
        q_input = torch.cat((obs_full.t(), q_input), dim=1)
        
        # get the policy gradient
        # (minimising -Q moves the actor towards actions the critic rates higher)
        actor_loss = - agent.critic(q_input).mean()
        actor_loss.backward()
        # torch.nn.utils.clip_grad_norm_(agent.actor.parameters(),0.5)
        agent.actor_optimizer.step()

    def update_targets(self):
        '''soft update targets'''
        self.iter += 1
        for ddpg_agent in self.maddpg_agent:
            self.soft_update(ddpg_agent.target_actor, ddpg_agent.actor)
            self.soft_update(ddpg_agent.target_critic, ddpg_agent.critic)
    
    def soft_update(self, target, source):
        '''Blend the source parameters into the target:
        target <- (1 - TAU) * target + TAU * source'''
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_(target_param.data * (1.0 - TAU) + param.data * TAU)

### The Training Loop

In [None]:
def train(n_episodes=1000, planning_horizon=80, print_every=100, save_every=10):
    '''Train the MADDPG agents on N_ENVIRONMENTS parallel copies of the environment.

    Params
    ======
        n_episodes (int): total number of episodes; N_ENVIRONMENTS of them are played per iteration
        planning_horizon (int): number of environment steps per episode
        print_every (int): approximate interval (in episodes) at which the progress line is kept
        save_every (int): approximate interval (in episodes) at which a checkpoint is written

    Returns
    =======
        agent_reward (list): one list per agent with the total reward of every episode played
        agent_types: the agent types reported by the environment
    '''
    progress_line = '\rEpisode: {:04d}/{:04d}\tAvg. Rewards for {} Agents: {}'

    # amplitude of OU noise, this slowly decreases to 0
    noise = 2
    noise_reduction = 0.9999

    # torch.save does not create missing directories
    model_dir = os.path.dirname(MODEL_FILE)
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)

    envs = MultiAgentParallelEnv(MULTI_AGENT_ENVIRONMENT, N_ENVIRONMENTS)
    # the environments are closed even if training is interrupted or fails
    try:
        agent_types = envs.agent_types
        n_agents = envs.n_agents
        state_size = envs.states[0].shape[0]
        action_size = envs.actions[0].shape[0]

        buffer = ReplayBuffer(5000 * planning_horizon) # keep 5000 episodes worth of replay
        maddpg = MADDPG(state_size, action_size, n_agents) # initialize policy and critic
        agent_reward = [[] for _ in range(n_agents)]

        for episode in range(0, n_episodes, N_ENVIRONMENTS):
            # one row per parallel environment, one column per agent
            reward_this_episode = np.zeros((N_ENVIRONMENTS, n_agents))
            obs, obs_full = envs.reset()
            for episode_t in range(planning_horizon):
                actions = maddpg.act(transpose_to_tensor(obs), noise=noise)
                noise *= noise_reduction
                actions_array = torch.stack(actions).detach().cpu().numpy()

                # transpose the list of list
                # flip the first two indices
                # input to step requires the first index to correspond to number of parallel agents
                actions_for_env = np.rollaxis(actions_array, 1)

                # step forward one frame
                next_obs, next_obs_full, rewards, dones, info = envs.step(actions_for_env)

                # add data to buffer
                transition = (obs, obs_full, actions_for_env, rewards, next_obs, next_obs_full, dones)
                buffer.push(transition)
                reward_this_episode += rewards
                obs, obs_full = next_obs, next_obs_full

            for env_idx in range(N_ENVIRONMENTS):
                for agent_idx in range(n_agents):
                    agent_reward[agent_idx].append(reward_this_episode[env_idx, agent_idx])

            # update once after every TRAIN_EVERY episodes
            if len(buffer) > BATCH_SIZE and episode % TRAIN_EVERY < N_ENVIRONMENTS:
                for agent_idx in range(n_agents):
                    samples = buffer.sample(BATCH_SIZE)
                    maddpg.update(samples, agent_idx)
                maddpg.update_targets() # soft update the target network towards the actual networks

            # display progress after every episode
            # (the averages run over all episodes played so far)
            avg_rewards = [np.round(np.mean(agent_reward[i]), 4) for i in range(n_agents)]
            progress = progress_line.format(episode, n_episodes, n_agents, avg_rewards)
            print('\r' + ' ' * 120, end='')
            print(progress, end='')
            if (episode + 1) % print_every < N_ENVIRONMENTS:
                # finish the line so that this report stays visible
                print('\r' + ' ' * 120, end='')
                print(progress)
            if episode % save_every < N_ENVIRONMENTS or episode == n_episodes - N_ENVIRONMENTS:
                torch.save([
                    {
                        'actor_params' : maddpg.maddpg_agent[i].actor.state_dict(),
                        'actor_optim_params': maddpg.maddpg_agent[i].actor_optimizer.state_dict(),
                        'critic_params' : maddpg.maddpg_agent[i].critic.state_dict(),
                        'critic_optim_params' : maddpg.maddpg_agent[i].critic_optimizer.state_dict()
                    } for i in range(n_agents)
                ], MODEL_FILE.format(episode))
    finally:
        envs.close()

    return agent_reward, agent_types

In [None]:
agent_rewards, agent_types = train()

### Plot the Scores

In [None]:
# one subplot per agent type; squeeze=False keeps `axs` an array even for a single subplot
n_plots = len(agent_types)
fig, axs = plt.subplots(n_plots, 1, sharex=True, figsize=(10, 10), squeeze=False)
axs = axs.flatten()
for i, (title, scores) in enumerate(zip(agent_types, agent_rewards)):
    axs[i].plot(np.arange(1, len(scores)+1), scores)
    axs[i].set_title(title)
    axs[i].set_ylabel('Score')
    # the x axis is shared, so only the bottom subplot is labelled
    axs[i].set_xlabel('' if i != n_plots - 1 else 'Episode #')
plt.show()

### Watch the Smart Agents in Action

In [None]:
# episode number of the checkpoint to load; it is substituted into MODEL_FILE
# (with the train() defaults the last checkpoint is written at episode 996)
LOAD_EPISODE = 996
# number of parallel environments created for the playback
LOAD_N_ENVIRONMENTS = 4
# True: collect the frames and show them as one animation with playback controls
# False: redraw the figure in place after every step
SHOW_CONTROL = True

In [None]:
def animate_frames(frames):
    'function to animate a list of frames'
    first_frame = frames[0]
    plt.axis('off')
    # 3-dimensional frames are drawn as they are, anything else with the 'Greys' colour map
    colour_map = 'Greys' if len(first_frame.shape) != 3 else None
    image = plt.imshow(first_frame, cmap=colour_map)

    def show_frame(index):
        return image.set_data(frames[index])

    movie = animation.FuncAnimation(plt.gcf(), show_frame, frames=len(frames), interval=30)
    # close the figure first so only the JS/HTML player is shown, not a static plot as well
    plt.close(movie._fig)
    Display(HTML(movie.to_jshtml()))

envs = MultiAgentParallelEnv(MULTI_AGENT_ENVIRONMENT, LOAD_N_ENVIRONMENTS) # seed=int(time.time())
# make sure the environments are shut down again, even if loading or rendering fails
try:
    n_agents = envs.n_agents
    maddpg = MADDPG(envs.states[0].shape[0], envs.actions[0].shape[0], n_agents)

    # load the weights from file onto the device the networks were created on
    checkpoint = torch.load(MODEL_FILE.format(LOAD_EPISODE), map_location=device)
    for i in range(n_agents):
        maddpg.maddpg_agent[i].actor.load_state_dict(checkpoint[i]['actor_params'])
        maddpg.maddpg_agent[i].actor.eval()

    for e in range(1):
        frames = []
        obs, _ = envs.reset()
        if SHOW_CONTROL:
            frames.append(envs.render(mode='rgb_array'))
        else:
            img = plt.imshow(envs.render(mode='rgb_array'))
        for t in range(200):
            # act without exploration noise
            actions = maddpg.act(transpose_to_tensor(obs), noise=0)
            actions_array = torch.stack(actions).detach().cpu().numpy()
            actions_for_env = np.rollaxis(actions_array, 1)
            if SHOW_CONTROL:
                frames.append(envs.render(mode='rgb_array'))
            else:
                # NOTE: `display` must be the IPython.display module here
                # (it is bound in the import cell when the inline backend is active)
                img.set_data(envs.render(mode='rgb_array'))
                plt.axis('off')
                display.display(plt.gcf())
                display.clear_output(wait=True)
            obs, _, _, dones, _ = envs.step(actions_for_env)
            if dones.any():
                break
        if SHOW_CONTROL:
            animate_frames(frames)
finally:
    envs.close()

# Case Study: AlphaZero

The following materials are derived from the original papers, [AlphaGo Zero](https://deepmind.com/documents/119/agz_unformatted_nature.pdf) and [AlphaZero](https://arxiv.org/abs/1712.01815), by DeepMind.

In [None]:
YouTubeVideo('Zzc1XJ1aJ-4')

## Zero-Sum Game

In [None]:
YouTubeVideo('uPw1dHVqdXQ')

## Monte Carlo Tree Search

### Random Sampling

In [None]:
YouTubeVideo('wn2B3j_Qz6E')

### Expansion and Back-propagation

In [None]:
YouTubeVideo('H34Wtk1iNDY')

## AlphaZero

### Guided Tree Search

In [None]:
YouTubeVideo('LinuRy47xbw')

### Self-Play Training

In [None]:
YouTubeVideo('wl1qfPXqRuQ')

---

# TicTacToe using AlphaZero

We'll now implement an advanced version of `TicTacToe` using AlphaZero.

For code walkthroughs, you can watch these [notebook walkthrough](https://www.youtube.com/watch?v=uUFuBscf98I) and [python classes walkthrough](https://www.youtube.com/watch?v=hKnBQvtJ_zQ) videos.

### Implementation of the Monte Carlo Tree Search (MCTS) Algorithm

In [None]:
# exploration constant: weights the prior-probability term when
# MCTSNode.explore recomputes U for the sibling nodes
c = 1.0

# transformations
# Symmetry transformations applied to the 2-D board array before it is fed to
# the policy. The reversed slices are copied (presumably because torch.tensor
# cannot consume negatively strided arrays -- TODO confirm).
# t0: identity
t0= lambda x: x
# t1: reverse the second axis
t1= lambda x: x[:,::-1].copy()
# t2: reverse the first axis
t2= lambda x: x[::-1,:].copy()
# t3: reverse both axes
t3= lambda x: x[::-1,::-1].copy()
# t4: transpose
t4= lambda x: x.T
# t5: reverse the second axis, then transpose
t5= lambda x: x[:,::-1].T.copy()
# t6: reverse the first axis, then transpose
t6= lambda x: x[::-1,:].T.copy()
# t7: reverse both axes, then transpose
t7= lambda x: x[::-1,::-1].T.copy()

# all eight transformations (used for square boards)
tlist=[t0, t1,t2,t3,t4,t5,t6,t7]
# the four transformations without a transpose (used for non-square boards)
tlist_half=[t0,t1,t2,t3]

def flip(x, dim):
    '''Return a copy of tensor `x` with the order of its entries reversed along `dim`.'''
    return torch.flip(x, dims=[dim])


# Inverse transformations: tKinv undoes tK. They operate on torch tensors
# (flip / .t()) and are used to map the policy output back to the
# orientation of the original board.
t0inv= lambda x: x
t1inv= lambda x: flip(x,1)
t2inv= lambda x: flip(x,0)
t3inv= lambda x: flip(flip(x,0),1)
t4inv= lambda x: x.t()
# for the transposed variants the axis to flip is swapped, because the flip
# is applied before the transpose is undone
t5inv= lambda x: flip(x,0).t()
t6inv= lambda x: flip(x,1).t()
t7inv= lambda x: flip(flip(x,0),1).t()

tinvlist = [t0inv, t1inv, t2inv, t3inv, t4inv, t5inv, t6inv, t7inv]
tinvlist_half=[t0inv, t1inv, t2inv, t3inv]

# (transformation, inverse) pairs; process_policy picks one pair at random
transformation_list = list(zip(tlist, tinvlist))
transformation_list_half = list(zip(tlist_half, tinvlist_half))


def process_policy(policy, game):
    '''Evaluate `policy` on a randomly transformed view of the current board.

    Params
    ======
        policy: callable that takes the board as a (1, 1, w, h) float tensor and returns (prob, v)
        game: game object providing size, state, player, available_mask() and available_moves()

    Returns
    =======
        (moves, probs, v): the available moves, the policy output for those
        moves (mapped back to the original board orientation) as a flat
        tensor, and the squeezed value output
    '''
    # a square board admits all eight symmetries (rotations included),
    # any other board only the four that keep its shape (reflections)
    is_square = game.size[0] == game.size[1]
    symmetries = transformation_list if is_square else transformation_list_half
    t, tinv = random.choice(symmetries)

    # game.state * game.player shows the board from the point of view of the player to move
    oriented_board = t(game.state*game.player)
    frame = torch.tensor(oriented_board, dtype=torch.float, device=device)
    # add the two leading singleton dimensions the policy is called with
    net_input = frame.unsqueeze(0).unsqueeze(0)
    prob, v = policy(net_input)
    mask = torch.tensor(game.available_mask(), dtype=torch.bool)

    moves = game.available_moves()
    # undo the symmetry before selecting the entries of the available moves
    move_probs = tinv(prob)[mask].view(-1)
    return moves, move_probs, v.squeeze()

class MCTSNode:
    '''A node of the Monte Carlo search tree: one game position plus its search statistics.

    Attributes
    ==========
        game: the game position of this node
        child (dict): move tuple -> MCTSNode, empty until the node is expanded
        mother (MCTSNode or None): parent node
        prob (torch.Tensor): prior probability of the move leading to this node
        nn_v (torch.Tensor): value predicted by the policy when the node was expanded
        N (int): visit count
        V (float): running value estimate
        U (float): selection score; +inf / -inf mark positions with a certain outcome
        outcome: game.score at creation time (None while the game is undecided)
    '''
    def __init__(self, game, mother=None, prob=torch.tensor(0., dtype=torch.float)):
        self.game = game

        # child nodes
        self.child = {}
        # numbers for determining which actions to take next
        self.U = 0

        # prior probability from the neural net output
        # it's a torch.tensor object
        # has require_grad enabled
        self.prob = prob
        # the predicted expectation from neural net
        self.nn_v = torch.tensor(0., dtype=torch.float)

        # visit count
        self.N = 0

        # expected V from MCTS
        self.V = 0

        # keeps track of the guaranteed outcome
        # None while the game is undecided
        # this is for speeding the tree-search up
        # by stopping exploration when the outcome is certain
        # and there is a known perfect play
        self.outcome = self.game.score

        # if game is won/loss/draw
        if self.game.score is not None:
            self.V = self.game.score*self.game.player
            # compare by value: an identity test against the literal 0 depends
            # on interpreter caching and fails for e.g. 0.0 or numpy integers
            self.U = 0 if self.game.score == 0 else self.V*float('inf')

        # link to previous node
        self.mother = mother

    def create_child(self, actions, probs):
        '''Create one child node per action, each holding a copy of the game
        with that action played and the matching prior probability.'''
        games = [ copy(self.game) for a in actions ]

        for action, game in zip(actions, games):
            game.move(action)

        child = { tuple(a): MCTSNode(g, self, p) for a,g,p in zip(actions, games, probs) }
        self.child = child

    def explore(self, policy):
        '''Run one search iteration from this node: descend along the children
        with maximal U, expand the reached node with `policy` if needed, and
        propagate the value back up to this node.

        Raises ValueError if the game of this node has already ended.
        '''
        if self.game.score is not None:
            raise ValueError("game has ended with score {0:d}".format(self.game.score))

        current = self

        # explore children of the node
        # to speed things up, stop at nodes whose outcome is known
        while current.child and current.outcome is None:

            child = current.child
            max_U = max(node.U for node in child.values())
            # print("current max_U ", max_U)
            actions = [ a for a, node in child.items() if node.U == max_U ]
            if len(actions) == 0:
                print("error zero length ", max_U)
                print(current.game.state)

            action = random.choice(actions)

            # the best child is already decided, which decides this node as well
            if max_U == -float("inf"):
                current.U = float("inf")
                current.V = 1.0
                break

            elif max_U == float("inf"):
                current.U = -float("inf")
                current.V = -1.0
                break

            current = child[action]

        # if node hasn't been expanded
        if not current.child and current.outcome is None:
            # policy outputs results from the perspective of the next player
            # thus extra - sign is needed
            next_actions, probs, v = process_policy(policy, current.game)
            current.nn_v = -v
            current.create_child(next_actions, probs)
            current.V = -float(v)

        current.N += 1

        # now update U and back-prop
        infinity = float("inf")
        while current.mother:
            mother = current.mother
            mother.N += 1
            # between mother and child, the player is switched, extra - sign
            mother.V += (-current.V - mother.V)/mother.N

            # update U for all sibling nodes
            for sibling in mother.child.values():
                # nodes with a certain outcome keep their +/-inf score.
                # This must be a comparison by value: `x is not float("inf")`
                # is always True, because float("inf") is a new object every time.
                if abs(sibling.U) != infinity:
                    sibling.U = sibling.V + c*float(sibling.prob)* sqrt(mother.N)/(1+sibling.N)

            current = current.mother

    def next(self, temperature=1.0):
        '''Pick the next node to play from the search statistics.

        Params
        ======
            temperature (float): exponent 1/temperature is applied to the normalised visit counts

        Returns
        =======
            (nextstate, (-V, -nn_v, prob, nn_prob)): the chosen child, the sign-flipped
            search value and network value of this node, the move probabilities
            used for the choice and the network priors of the children

        Raises ValueError if the game has ended or the node has no children.
        '''
        if self.game.score is not None:
            raise ValueError('game has ended with score {0:d}'.format(self.game.score))

        if not self.child:
            print(self.game.state)
            raise ValueError('no children found and game hasn\'t ended')

        child=self.child

        # if there are winning moves, just output those
        max_U = max(node.U for node in child.values())

        if max_U == float("inf"):
            prob = torch.tensor([ 1.0 if node.U == float("inf") else 0 for node in child.values()], device=device)

        else:
            # divide things by maxN for numerical stability
            maxN = max(node.N for node in child.values())+1
            prob = torch.tensor([ (node.N/maxN)**(1/temperature) for node in child.values() ], device=device)

        # normalize the probability
        if torch.sum(prob) > 0:
            prob /= torch.sum(prob)

        # if sum is zero, just make things random
        else:
            prob = torch.tensor(1.0/len(child), device=device).repeat(len(child))

        nn_prob = torch.stack([ node.prob for node in child.values() ]).to(device)

        nextstate = random.choices(list(child.values()), weights=prob)[0]

        # V was for the previous player making a move
        # to convert to the current player we add - sign
        return nextstate, (-self.V, -self.nn_v, prob, nn_prob)

    def detach_mother(self):
        '''Drop the link to the parent node, making this node the root of its own tree.'''
        del self.mother
        self.mother = None

### Utility Methods and Classes to Set Up a Game

In [None]:
# output the index of when v has a continuous string of i
# get_runs([0,0,1,1,1,0,0],1) gives [2],[5],[3]
def get_runs(v, i):
    '''Locate the uninterrupted runs of the value `i` in the 1-D array `v`.

    Returns (starts, ends, lengths): for every run the index of its first
    element, the index just past its last element, and its length.
    '''
    hits = (v == i).astype(int)
    # a zero on either side makes runs touching the borders show up as steps too
    padded = np.hstack(([0], hits, [0]))
    steps = np.diff(padded)
    run_starts = np.flatnonzero(steps > 0)
    run_ends = np.flatnonzero(steps < 0)
    return run_starts, run_ends, run_ends - run_starts

# see if vector contains N of certain number in a row
def in_a_row(v, N, i):
    '''Tell whether `v` contains at least `N` consecutive entries equal to `i`.'''
    # a line shorter than N cannot hold a run of length N
    if len(v) < N:
        return False
    run_lengths = get_runs(v, i)[2]
    return np.any(run_lengths >= N)
        
def get_lines(matrix, loc):
    '''Extract the four lines of `matrix` that pass through the cell `loc`.

    Returns (hor, ver, diag_right, diag_left), where hor is matrix[:, j],
    ver is matrix[i, :], diag_right is the diagonal along which both indices
    increase and diag_left the one along which the second index decreases
    while the first increases. Trailing dimensions of `matrix` are kept.
    '''
    row, col = loc
    n_rows, n_cols = matrix.shape[0], matrix.shape[1]

    # the diagonals are read from the row-major flattened array with a constant stride
    cells = matrix.reshape(-1, *matrix.shape[2:])

    def flat_index(r, c):
        return r*n_cols + c

    centre = flat_index(row, col)

    # number of cells between loc and the far borders
    rows_below = n_rows - 1 - row
    cols_right = n_cols - 1 - col

    # walk from loc towards every corner until a border is reached
    up_left = min(row, col)
    down_left = min(rows_below, col)
    up_right = min(row, cols_right)
    down_right = min(rows_below, cols_right)

    top_left = flat_index(row - up_left, col - up_left)
    bottom_left = flat_index(row + down_left, col - down_left)
    top_right = flat_index(row - up_right, col + up_right)
    bottom_right = flat_index(row + down_right, col + down_right)

    hor = matrix[:, col]
    ver = matrix[row, :]
    # stride n_cols+1: one row down and one column right
    diag_right = np.concatenate([cells[top_left:centre:n_cols+1],
                                 cells[centre:bottom_right+1:n_cols+1]])
    # stride n_cols-1: one row down and one column left
    diag_left = np.concatenate([cells[top_right:centre:n_cols-1],
                                cells[centre:bottom_left+1:n_cols-1]])

    return hor, ver, diag_right, diag_left


class ConnectN:
    '''Two-player game on a w x h board where N stones in a row win.

    Attributes
    ==========
        size (tuple): (w, h) board dimensions
        N (int): number of stones in a row needed to win
        state (np.ndarray): the board; 0 is empty, +1 / -1 are the stones of the
            two players, +-0.5 marks the first move while the pie rule is active
        player (int): +1 or -1, the player to move (the winner once the game is won)
        score: None while the game is running, 0 for a draw, otherwise the winner
        last_move (tuple or None): location of the most recent move
        n_moves (int): number of moves played so far
        pie_rule (bool): whether the second player may take over the first move
        switched_side (bool): whether the pie rule has been used
    '''

    def __init__(self, size, N, pie_rule=False):
        self.size = size
        self.w, self.h = size
        self.N = N

        # make sure game is well defined
        if self.w<0 or self.h<0 or self.N<2 or \
           (self.N > self.w and self.N > self.h):
            raise ValueError('Game cannot initialize with a {0:d}x{1:d} grid, and winning condition {2:d} in a row'.format(self.w, self.h, self.N))

        self.score = None
        # the `np.float` alias was removed in NumPy 1.24; the builtin float is the same dtype
        self.state = np.zeros(size, dtype=float)
        self.player = 1
        self.last_move = None
        self.n_moves = 0
        self.pie_rule = pie_rule
        self.switched_side = False

    # fast deepcopy
    def __copy__(self):
        '''Return an independent copy of the game (used through copy.copy).'''
        cls = self.__class__
        new_game = cls.__new__(cls)
        new_game.__dict__.update(self.__dict__)
        # the board is the only mutable attribute, so it is the only one
        # that needs a copy of its own
        new_game.state = self.state.copy()
        return new_game

    # check victory condition
    # fast version
    def get_score(self):
        '''Score of the position after the last move: the current player if that
        move completed N in a row, 0 if the board is full, None otherwise.
        Only the lines through the last move are inspected.'''

        # the game cannot end before 2N-1 moves have been played
        if self.n_moves<2*self.N-1:
            return None

        i,j = self.last_move
        hor, ver, diag_right, diag_left = get_lines(self.state, (i,j))

        # loop over each possibility
        for line in [ver, hor, diag_right, diag_left]:
            if in_a_row(line, self.N, self.player):
                return self.player

        # no more moves
        if np.all(self.state!=0):
            return 0

        return None

    # for rendering
    # output a list of location for the winning line
    def get_winning_loc(self):
        '''Board locations of the winning line through the last move, or an
        empty list if the last move did not complete one.'''

        if self.n_moves<2*self.N-1:
            return []

        loc = self.last_move
        hor, ver, diag_right, diag_left = get_lines(self.state, loc)
        # the same four lines, but holding the (i, j) location of every cell
        ind = np.indices(self.state.shape)
        ind = np.moveaxis(ind, 0, -1)
        hor_ind, ver_ind, diag_right_ind, diag_left_ind = get_lines(ind, loc)

        pieces = [hor, ver, diag_right, diag_left]
        locations = [hor_ind, ver_ind, diag_right_ind, diag_left_ind]

        # loop over each possibility
        for line, line_locations in zip(pieces, locations):
            starts, ends, runs = get_runs(line, self.player)

            # get the start and end location
            winning = (runs >= self.N)
            if not np.any(winning):
                continue

            starts_ind = starts[winning][0]
            ends_ind = ends[winning][0]
            return line_locations[starts_ind:ends_ind]

        return []

    def move(self, loc):
        '''Play the current player's stone at `loc`.

        Returns True if the move was legal and has been played (score and
        player are updated), False otherwise (the game is left unchanged).
        '''
        i,j=loc
        success = False
        if self.w>i>=0 and self.h>j>=0:
            if self.state[i,j]==0:

                # make a move
                self.state[i,j]=self.player

                # if pie rule is enabled
                if self.pie_rule:
                    if self.n_moves==1:
                        # the second player did not switch: turn the
                        # half-valued first move into a regular stone
                        self.state[tuple(self.last_move)]=-self.player
                        self.switched_side=False

                    elif self.n_moves==0:
                        # pie rule, make first move 0.5
                        # this is to let the neural net know
                        self.state[i,j]=self.player/2.0
                        self.switched_side=False

                success = True

            # switching side
            elif self.pie_rule and self.state[i,j] == -self.player/2.0:

                # make a move
                self.state[i,j]=self.player
                self.switched_side=True

                success = True

        if success:
            self.n_moves += 1
            self.last_move = tuple((i,j))
            self.score = self.get_score()

            # if game is not over, switch player
            if self.score is None:
                self.player *= -1

            return True

        return False

    def available_moves(self):
        '''Array of the (i, j) locations that are not occupied by a full stone.'''
        indices = np.moveaxis(np.indices(self.state.shape), 0, -1)
        return indices[np.abs(self.state) != 1]

    def available_mask(self):
        '''Board-shaped uint8 array: 1 where a move is available, 0 elsewhere.'''
        return (np.abs(self.state) != 1).astype(np.uint8)

### The Game Play Class

In [None]:
class Play:
    """Matplotlib front-end for playing or watching a two-player board game.

    Each of ``player1`` / ``player2`` is either ``None`` (that side is played
    with mouse clicks on the figure) or a callable that receives the game and
    returns a move location. When both are callables the game is animated
    automatically. Constructing the object immediately opens the figure and
    starts a game.
    """

    def __init__(self, game, player1=None, player2=None, name='game'):
        """
        Args:
            game: the game to play; a copy of it is used for every round.
            player1: move callable used when ``game.player == 1``, or None.
            player2: move callable used for the other side, or None.
            name: label of the matplotlib figure.
        """
        self.original_game=game
        self.game=copy(game)
        self.player1=player1
        self.player2=player2
        self.player=self.game.player
        self.end=False
        self.click_cid=None
        # forward the requested figure label (it used to be silently ignored)
        self.play(name)

    def reset(self):
        """Start over from a copy of the original game."""
        self.game=copy(self.original_game)
        self.click_cid=None
        self.end=False

    def play(self, name='Game'):
        """Reset the game, draw an empty board and start play.

        Args:
            name: label passed to ``plt.figure``.
        """
        self.reset()

        # small boards get proportionally larger cells and margins
        small_board = self.game.w * self.game.h < 25
        scale = 1.6 if small_board else 2.1
        margin = .2 if small_board else .1

        self.fig=plt.figure(name, figsize=(self.game.w/scale, self.game.h/scale))
        self.fig.subplots_adjust(margin, margin, 1, 1)

        self.fig.show()
        w,h=self.game.size
        self.ax=self.fig.gca()
        self.ax.grid()
        # remove hovering coordinate tooltips
        self.ax.format_coord = lambda x, y: ''
        self.ax.set_xlim([-.5,w-.5])
        self.ax.set_ylim([-.5,h-.5])
        self.ax.set_xticks(np.arange(0, w, 1))
        self.ax.set_yticks(np.arange(0, h, 1))
        self.ax.set_aspect('equal')

        for loc in ['top', 'right', 'bottom', 'left']:
            self.ax.spines[loc].set_visible(False)

        # fully AI game: let the animation pull moves from the generator
        if self.player1 is not None and self.player2 is not None:
            self.anim = animation.FuncAnimation(self.fig, self.draw_move, frames=self.move_generator, interval=500, repeat=False)
            return

        # at least one human
        if self.player1 is not None:
            # the AI opens; keep asking until it produces a legal move
            succeed = False
            while not succeed:
                loc = self.player1(self.game)
                succeed = self.game.move(loc)

            self.draw_move(loc)

        self.click_cid=self.fig.canvas.mpl_connect('button_press_event', self.click)

    def move_generator(self):
        """Yield successive accepted moves of an AI-vs-AI game until it is scored."""
        score = None
        # game not concluded yet
        while score is None:
            self.player = self.game.player
            if self.game.player == 1:
                loc = self.player1(self.game)
            else:
                loc = self.player2(self.game)

            success = self.game.move(loc)

            # see if game is done
            if success:
                score=self.game.score
                yield loc

    def draw_move(self, move=None):
        """Draw a stone for ``self.player`` at ``move`` (default: the game's
        last move) and mark the winner if the game has been decided."""
        if self.end:
            return

        i,j=self.game.last_move if move is None else move
        c='salmon' if self.player==1 else 'lightskyblue'
        self.ax.scatter(i,j,s=500,marker='o',zorder=3, c=c)
        score = self.game.score
        self.draw_winner(score)
        self.fig.canvas.draw()

    def draw_winner(self, score):
        """Highlight the winning squares (for a +-1 score), stop listening to
        clicks and flag the game as finished. Does nothing while ``score`` is None."""
        if score is None:
            return

        if score == -1 or score == 1:
            locs = np.asarray(self.game.get_winning_loc())
            # defensive: an empty result cannot be sliced column-wise
            if locs.ndim == 2 and len(locs) > 0:
                c='darkred' if score==1 else 'darkblue'
                self.ax.scatter(locs[:,0],locs[:,1], s=300, marker='*',c=c,zorder=4)

        # disconnect the click handler only if one was actually registered
        if getattr(self, 'click_cid', None) is not None:
            self.fig.canvas.mpl_disconnect(self.click_cid)
            self.click_cid=None

        self.end=True

    def click(self,event):
        """Mouse handler: play the clicked square for the human, then let the
        AI (if any) answer."""
        # clicks outside the axes carry no data coordinates; a finished game
        # accepts no further moves
        if self.end or event.xdata is None or event.ydata is None:
            return

        loc=(int(round(event.xdata)), int(round(event.ydata)))
        self.player = self.game.player
        if not self.game.move(loc):
            return

        self.draw_move()

        # the human's move may have ended the game; the AI must not reply then
        if self.end:
            return

        if self.player1 is not None or self.player2 is not None:

            succeed = False
            self.player = self.game.player
            while not succeed:
                if self.game.player == 1:
                    loc = self.player1(self.game)
                else:
                    loc = self.player2(self.game)
                succeed = self.game.move(loc)

            self.draw_move()

### Setup a game

In [None]:
# 3x3 board with N=3 (presumably the run length needed to win -- confirm in ConnectN)
game_setting = dict(size=(3, 3), N=3)

game = ConnectN(**game_setting)

### Display the Game State, Current Player and Score after a Move

In [None]:
# play one move at (0, 1), then show the board, the player to move and the score
game.move((0,1))
print(game.state, game.player, game.score, sep='\n')

### Display the Game State after a Sequence of Moves

In [None]:
# continue the game from the previous cell; the sides alternate,
# starting with player -1: -1, +1, -1, +1
for loc in [(0,0), (1,1), (1,0), (2,1)]:
    game.move(loc)

print(game.state, game.player, game.score, sep='\n')

### Play a Game Interactively

In [None]:
# interactive backend: Play listens for mouse clicks on the figure
%matplotlib notebook

# no AI on either side, so both players move by clicking on the board
gameplay=Play(ConnectN(**game_setting),  player1=None, player2=None)

### Initialize an Agent to Play the Game
We need to define a network for tic-tac-toe that takes the game state as input and outputs both a policy (a probability for each move) and a critic (an estimate of the value of the state).

## Tentative Exercise
Code up your own policy for training

In [None]:
class Policy(nn.Module):
    """Actor-critic network for the 3x3 board.

    A shared convolution + linear trunk feeds two heads: an action head that
    produces a probability for each of the 9 squares (zero on occupied
    squares) and a value head squashed to [-1, 1] by ``tanh``.
    """

    def __init__(self):
        super().__init__()
        # shared trunk: 1x3x3 board -> 16 feature maps of 2x2 -> 32 features
        self.conv = nn.Conv2d(1, 16, kernel_size=2, stride=1, bias=False)
        self.size = 2*2*16
        self.fc = nn.Linear(self.size,32)

        # layers for the policy
        self.fc_action1 = nn.Linear(32, 16)
        self.fc_action2 = nn.Linear(16, 9)

        # layers for the critic
        self.fc_value1 = nn.Linear(32, 8)
        self.fc_value2 = nn.Linear(8, 1)
        self.tanh_value = nn.Tanh()

    def forward(self, x):
        """Evaluate a single board.

        Args:
            x: board tensor shaped (1, 1, 3, 3); squares whose absolute
               value is 1 are treated as occupied.

        Returns:
            (prob, value): ``prob`` is a 3x3 tensor of move probabilities that
            sums to 1 over the available squares; ``value`` is the critic output.
        """
        # run on whatever device the module's own weights live on, instead of
        # relying on a global that the module may never have been moved to
        x = x.to(self.conv.weight.device)

        y = F.relu(self.conv(x))
        y = y.view(-1, self.size)
        y = F.relu(self.fc(y))

        # the action head
        a = F.relu(self.fc_action1(y))
        a = self.fc_action2(a)

        # availability of moves
        avail = (torch.abs(x.squeeze())!=1).reshape(-1, 9)

        # Exclude occupied squares *before* normalising. Masking with -inf keeps
        # the softmax stable even when an occupied square has the largest logit
        # (subtracting a global max would underflow every legal move to 0 and
        # produce 0/0 = NaN).
        a = a.masked_fill(~avail, float('-inf'))
        prob = F.softmax(a, dim=-1)

        # the value head
        value = F.relu(self.fc_value1(y))
        value = self.tanh_value(self.fc_value2(value))
        return prob.view(3,3), value

In [None]:
# untrained 3x3 network; Policy_Player_MCTS reads this global when it is called
policy = Policy()

### Define a Player that Uses MCTS and the Expert Policy + Critic to Play a Game
Here we introduce a new parameter:
```
T = temperature
```
This tells us how to choose the next move based on the **MCTS** results

$$p_a = \frac{N_a^{\frac{1}{T}}}{\sum_a N_a^{\frac{1}{T}}}$$

As $T\rightarrow 0$, we choose the action with the largest $N_a$.

In [None]:
def Policy_Player_MCTS(game, n_explore=50, temperature=0.1):
    """Choose a move with MCTS guided by the global ``policy`` network.

    Args:
        game: current game; it is copied, so the caller's game is not the one
              handed to the search tree.
        n_explore: number of ``explore`` calls made on the root before moving.
        temperature: temperature passed to ``MCTSNode.next`` when picking the move.

    Returns:
        The ``last_move`` of the child node selected by the search.
    """
    mytree = MCTSNode(copy(game))

    # playing needs no gradients, so avoid building autograd graphs
    with torch.no_grad():
        for _ in range(n_explore):
            mytree.explore(policy)

        mytreenext, _ = mytree.next(temperature=temperature)

    return mytreenext.game.last_move

def Random_Player(game):
    """Return one of the game's currently available moves, picked uniformly at random."""
    candidates = game.available_moves()
    return random.choice(candidates)

In [None]:
# smoke test: run the MCTS player once on a freshly constructed game
game = ConnectN(**game_setting)

print(game.state)

# trailing ';' suppresses the notebook's echo of the returned move
Policy_Player_MCTS(game);

### Play a Game against the Policy

In [None]:
# interactive backend: Play listens for mouse clicks on the figure
%matplotlib notebook

# player1 is None, so that side is played by clicking; the MCTS policy plays the other side
gameplay=Play(ConnectN(**game_setting), player1=None, player2=Policy_Player_MCTS)

## Train the Agent

Initialize our **AlphaZero** agent and optimizer

In [None]:
# fresh game, fresh (untrained) network and its optimizer for the training run below
game=ConnectN(**game_setting)
policy = Policy()
# weight_decay applies Adam's L2 penalty to the network parameters
optimizer = Adam(policy.parameters(), lr=0.01, weight_decay=1e-4)

### Tentative Exercise
Code up the AlphaZero loss function, defined as
$$L = \sum_t \left\{ \left(v^{(t)}_\theta - z\right)^2  - \sum_a p^{(t)}_a \log \pi_\theta(a|s_t) \right\} + \textrm{constant}$$
The constant term is $\sum_t \sum_a p^{(t)}_a \log p^{(t)}_a$. It is added so that $L=0$ when $v_\theta^{(t)} = z$ and $p^{(t)}_a = \pi_\theta(a|s_t)$, which gives us a simple metric of progress.

In [None]:
episodes = 400
print_every = 50
outcomes = []
losses = []

# pieces of the single-line progress display
blank_line = '\r' + ' ' * 120
status_fmt = '\rTraining Loop:: {:04d}/{:04d}    Avg. Loss: {:3.2f}    Recent Outcomes: {}'

for e in range(episodes):

    # one self-play game per episode, starting from a freshly constructed game
    mytree = MCTSNode(ConnectN(**game_setting))
    vterm, logterm = [], []

    while mytree.outcome is None:
        # grow the search tree before committing to a move
        for _ in range(50):
            mytree.explore(policy)

        current_player = mytree.game.player
        mytree, (v, nn_v, p, nn_p) = mytree.next()
        mytree.detach_mother()

        # ------------- solution --------------
        # p * log(pi) for this step
        loglist = p*torch.log(nn_p)

        # p * log(p), taken as 0 where p == 0; makes the loss 0 when the
        # network output equals the MCTS result
        constant = torch.where(p>0, p*torch.log(p),torch.tensor(0.))
        logterm.append(torch.sum(constant-loglist))

        # network value, signed by the player who was to move
        vterm.append(current_player*nn_v)
        # -------------------------------------

    # the finished game's outcome is the target for every recorded value
    outcome = mytree.outcome
    outcomes.append(outcome)

    # ------------- solution --------------
    loss = torch.sum( (torch.stack(vterm)-outcome)**2 + torch.stack(logterm) )
    # -------------------------------------

    optimizer.zero_grad()
    loss.backward()
    losses.append(float(loss))
    optimizer.step()

    # overwrite the progress line in place; every `print_every` episodes the
    # line is re-printed with a newline so it stays in the log
    status = status_fmt.format(e + 1, episodes, np.mean(losses[-20:]), outcomes[-10:])
    print(blank_line, end='')
    print(status, end='')
    if (e + 1) % print_every == 0:
        print(blank_line, end='')
        print(status)

    del loss

### Plot the Loss

In [None]:
%matplotlib inline
plt.plot(losses)
plt.show()

### Play a Game against the Trained AlphaZero Agent

#### As First Player

In [None]:
# interactive backend: Play listens for mouse clicks on the figure
%matplotlib notebook

# player1 is played by clicking; the MCTS policy answers as player2
gameplay=Play(ConnectN(**game_setting), player1=None, player2=Policy_Player_MCTS)

#### As Second Player

In [None]:
# interactive backend: Play listens for mouse clicks on the figure
%matplotlib notebook

# the MCTS policy is player1 and makes the opening move; player2 is played by clicking
gameplay=Play(ConnectN(**game_setting), player2=None, player1=Policy_Player_MCTS)

---

# Advanced TicTacToe with AlphaZero

We'll now extend the previous idea to an advanced version of `TicTacToe` using AlphaZero.

For code walkthrough, you can watch this [video](https://www.youtube.com/watch?v=MOIk_BbCjRw).

In [None]:
# checkpoint path: written by the training loop below, read back by the challenge cell
MODEL_FILE = './models/mcts-tictactoe-6-6-4.pt'

### Initialize a game

In [None]:
# 6x6 board, N=4, with the pie rule enabled
# (note: a later cell redefines game_setting without 'pie_rule')
game_setting = {'size':(6,6), 'N':4, 'pie_rule':True}

game = ConnectN(**game_setting)

In [None]:
# interactive backend: Play listens for mouse clicks on the figure
%matplotlib notebook

# no AI on either side, so both players move by clicking on the board
gameplay=Play(ConnectN(**game_setting), player1=None, player2=None)

### Define the Policy

See if you can train it under `1000` games and with only `1000` steps of exploration in each move.

In [None]:
class Policy(nn.Module):
    """Actor-critic network for the 6x6 board.

    Two convolutions form a shared trunk that feeds an action head (one
    probability per square, zero on occupied squares) and a value head
    squashed to [-1, 1] by ``tanh``.
    """

    def __init__(self, game):
        """
        Args:
            game: accepted for interface compatibility; it is not used, the
                  layer sizes are fixed for a 6x6 board.
        """
        super().__init__()

        # input = 6x6 board
        # convert to 5x5x16
        self.conv1 = nn.Conv2d(1, 16, kernel_size=2, stride=1, bias=False)
        # 5x5x16 to 3x3x32
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, bias=False)

        self.size=3*3*32

        # the part for actions
        self.fc_action1 = nn.Linear(self.size, self.size//4)
        self.fc_action2 = nn.Linear(self.size//4, 36)

        # the part for the value function
        self.fc_value1 = nn.Linear(self.size, self.size//6)
        self.fc_value2 = nn.Linear(self.size//6, 1)
        self.tanh_value = nn.Tanh()

    def forward(self, x):
        """Evaluate a single board.

        Args:
            x: board tensor shaped (1, 1, 6, 6); squares whose absolute
               value is 1 are treated as occupied.

        Returns:
            (prob, value): ``prob`` is a 6x6 tensor of move probabilities that
            sums to 1 over the available squares; ``value`` is the critic output.
        """
        # run on whatever device the module's own weights live on, instead of
        # relying on a global that the module may never have been moved to
        x = x.to(self.conv1.weight.device)

        y = F.leaky_relu(self.conv1(x))
        y = F.leaky_relu(self.conv2(y))
        y = y.view(-1, self.size)

        # action head
        a = self.fc_action2(F.leaky_relu(self.fc_action1(y)))

        avail = (torch.abs(x.squeeze())!=1).reshape(-1, 36)

        # Exclude occupied squares *before* normalising. Masking with -inf keeps
        # the softmax stable even when an occupied square has the largest logit
        # (subtracting a global max would underflow every legal move to 0 and
        # produce 0/0 = NaN).
        a = a.masked_fill(~avail, float('-inf'))
        prob = F.softmax(a, dim=-1)

        # value head
        value = self.tanh_value(self.fc_value2(F.leaky_relu( self.fc_value1(y) )))
        return prob.view(6,6), value

In [None]:
# redefine the 6x6 / N=4 settings, this time without the 'pie_rule' entry
game_setting = {'size':(6,6), 'N':4}

game = ConnectN(**game_setting)

# untrained 6x6 network; Policy_Player_MCTS reads this global when it is called
policy = Policy(game)

### Define a MCTS Player for Play

In [None]:
def Policy_Player_MCTS(game, n_explore=1000, temperature=0.1):
    """Choose a move with MCTS guided by the global ``policy`` network.

    Args:
        game: current game; it is copied, so the caller's game is not the one
              handed to the search tree.
        n_explore: number of ``explore`` calls made on the root before moving.
        temperature: temperature passed to ``MCTSNode.next`` when picking the move.

    Returns:
        The ``last_move`` of the child node selected by the search.
    """
    mytree = MCTSNode(copy(game))

    # playing needs no gradients, so avoid building autograd graphs
    with torch.no_grad():
        for _ in range(n_explore):
            mytree.explore(policy)

        mytreenext, _ = mytree.next(temperature=temperature)

    return mytreenext.game.last_move

def Random_Player(game):
    """Return one of the game's currently available moves, picked uniformly at random."""
    legal_moves = game.available_moves()
    return random.choice(legal_moves)

### Play a Game against a Random Policy

In [None]:
# interactive backend: Play listens for mouse clicks on the figure
%matplotlib notebook

# the MCTS policy is player1 and makes the opening move; player2 is played by clicking
gameplay=Play(ConnectN(**game_setting), player1=Policy_Player_MCTS, player2=None)

### Training

Note: training is **VERY VERY** slow!!

In [None]:
# fresh game, fresh (untrained) 6x6 network and its optimizer for the training run below
game=ConnectN(**game_setting)
policy = Policy(game)
# weight_decay applies Adam's L2 penalty to the network parameters
optimizer = Adam(policy.parameters(), lr=0.01, weight_decay=1e-5)

In [None]:
episodes = 2000
print_every= 10
outcomes = []
policy_loss = []
Nmax = 1000
save_every = 500

# create the checkpoint directory up front, so a missing folder cannot abort
# the (very slow) run at the first save
model_dir = os.path.dirname(MODEL_FILE)
if model_dir:
    os.makedirs(model_dir, exist_ok=True)

# pieces of the single-line progress display
blank_line = '\r' + ' ' * 120
status_fmt = '\rTraining Loop:: {:04d}/{:04d}    Avg. Loss: {:3.2f}    Recent Outcomes: {}'

for e in range(episodes):

    # start every episode from a freshly constructed game (as the 3x3
    # training loop does) instead of sharing one game object across episodes
    mytree = MCTSNode(ConnectN(**game_setting))
    logterm = []
    vterm = []

    while mytree.outcome is None:
        # explore until the root has accumulated Nmax visits (at most Nmax calls)
        for _ in range(Nmax):
            mytree.explore(policy)
            if mytree.N >= Nmax:
                break

        current_player = mytree.game.player
        mytree, (v, nn_v, p, nn_p) = mytree.next()
        mytree.detach_mother()

        # p * log(pi) for this step
        loglist = torch.log(nn_p)*p
        # p * log(p), taken as 0 where p == 0; makes the loss 0 when the
        # network output equals the MCTS result
        constant = torch.where(p>0, p*torch.log(p),torch.tensor(0.))
        logterm.append(-torch.sum(loglist-constant))

        # network value, signed by the player who was to move
        vterm.append(nn_v*current_player)

    # the finished game's outcome is the target for every recorded value
    outcome = mytree.outcome
    outcomes.append(outcome)

    loss = torch.sum( (torch.stack(vterm)-outcome)**2 + torch.stack(logterm) )
    optimizer.zero_grad()
    loss.backward()
    policy_loss.append(float(loss))
    optimizer.step()

    # overwrite the progress line in place; every `print_every` episodes the
    # line is re-printed with a newline so it stays in the log
    status = status_fmt.format(e + 1, episodes, np.mean(policy_loss[-20:]), outcomes[-10:])
    print(blank_line, end='')
    print(status, end='')
    if (e + 1) % print_every == 0:
        print(blank_line, end='')
        print(status)
    if (e + 1) % save_every == 0:
        torch.save(policy, MODEL_FILE)

    del loss

# keep the final weights even when `episodes` is not a multiple of `save_every`
if episodes % save_every != 0:
    torch.save(policy, MODEL_FILE)

### Watch the Trained Agent Play against the Random Agent

In [None]:
%matplotlib notebook

challenge_policy = torch.load(MODEL_FILE)

def Challenge_Player_MCTS(game):
    mytree = MCTSNode(copy(game))
    for _ in range(1000):
        mytree.explore(challenge_policy)
    mytreenext, (v, nn_v, p, nn_p) = mytree.next(temperature=0.1)
    return mytreenext.game.last_move

gameplay=Play(ConnectN(**game_setting), player2=Policy_Player_MCTS, player1=Challenge_Player_MCTS)

---

Next: [Multi Player Tennis](./Tennis.ipynb)