In [1]:
import os
import random
from collections import namedtuple, deque

import numpy as np
import torch
from torch import nn
from pettingzoo.classic import tictactoe_v3
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

In [2]:
class Net(nn.Module):
    """MLP mapping a flattened 3x3 board (9 values) to 9 action values, one per cell.

    Architecture: 9 -> 64 -> 64 -> 64 -> 9, LeakyReLU between layers. The output
    layer is left linear so predicted values can be any real number -- negative
    targets (e.g. the -1 illegal-move / loss rewards) must be representable.
    """

    def __init__(self):
        super().__init__()
        # No activation after the final Linear: a trailing LeakyReLU would shrink
        # every negative prediction by 0.01 and bias the Q-value regression.
        # Parameter names (model.0/2/4/6) are unchanged, so saved state dicts load.
        self.model = nn.Sequential(
            nn.Linear(9, 64),
            nn.LeakyReLU(inplace=True),
            nn.Linear(64, 64),
            nn.LeakyReLU(inplace=True),
            nn.Linear(64, 64),
            nn.LeakyReLU(inplace=True),
            nn.Linear(64, 9),
        )

    def forward(self, obs, state=None, info=None):
        """Return action values of shape (batch, 9).

        Args:
            obs: array-like or tensor of shape (batch, 9), or one board of shape (9,).
            state, info: unused; accepted so existing callers keep working.
        """
        # as_tensor also converts integer boards / lists to float32 and is a no-op
        # for tensors that are already float32.
        obs = torch.as_tensor(obs, dtype=torch.float32)
        if obs.dim() == 1:
            obs = obs.unsqueeze(0)  # single board -> batch of one
        assert obs.shape[-1] == 9, f"expected 9 board cells, got shape {tuple(obs.shape)}"
        return self.model(obs.reshape(obs.shape[0], -1))

In [3]:
# One replay-buffer entry, built positionally as Transition(state, action, next_state, reward).
Transition = namedtuple(
    "Transition",
    ["state", "action", "next_state", "reward"],
)

In [4]:
def get_state(obs, player = 'player_1'):
    """Flatten the two board planes of `obs` into a single (1, 9) row.

    Each cell is plane 0 minus plane 1 of ``obs['observation']``; for any
    `player` other than 'player_1' the whole row is negated. (Assumes the planes
    are 0/1 occupancy masks, giving +1 / -1 / 0 per cell -- confirm against env.)
    """
    planes = obs['observation']
    flat = (planes[:, :, 0] - planes[:, :, 1]).reshape(1, -1)
    if player != 'player_1':
        flat = -flat
    return flat

In [8]:
class ReplayMemory(object):
    """Fixed-size FIFO buffer of `Transition`s for experience replay.

    Once `capacity` entries are stored, each new push evicts the oldest one
    (``deque(maxlen=capacity)``).
    """

    def __init__(self, capacity=10e5):
        # NOTE: the default 10e5 is 1_000_000; it is truncated to int for the deque.
        self.capacity = int(capacity)
        self.memory = deque([], maxlen=self.capacity)

    def push(self, *args):
        """Append one transition; positional args map onto Transition's fields."""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        """Return `batch_size` distinct stored transitions, drawn uniformly at random.

        Raises:
            ValueError: if fewer than `batch_size` transitions are currently stored.
        """
        # Compare against what is actually stored, not the capacity; an explicit
        # exception (unlike `assert`) is not stripped under `python -O`.
        if batch_size > len(self.memory):
            raise ValueError(
                f"Batch size {batch_size} larger than the "
                f"{len(self.memory)} stored transitions"
            )
        return random.sample(self.memory, batch_size)

    def initialize(self, env, player='player_1', illegal_penalty=-10):
        """Fill the buffer with transitions from random play until it is full.

        `player` picks a uniformly random cell each turn (legal or not); the
        opponent picks a uniformly random *legal* cell. For every game:

        * each own move followed by another own turn is stored as
          (state, action, next own state, reward right after the move);
        * when the game ends normally, the last own move is stored with the
          final board as next_state and ``env.rewards[player]`` as reward;
        * an illegal own move ends the game without stepping `env` and is
          stored with an all-ones next_state and `illegal_penalty` as reward.

        Args:
            env: AEC-style environment exposing reset/last/step,
                agent_selection and rewards (as used below).
            player: the seat whose transitions are recorded.
            illegal_penalty: reward stored for an illegal move.
        """
        while len(self.memory) < self.capacity:
            env.reset()
            obs, _, done, truncated, _ = env.last()
            # Bookkeeping is per game so transitions never span two episodes.
            prev_state, prev_action, prev_reward = None, None, 0
            ended_by_illegal_move = False

            while not (done or truncated):
                if env.agent_selection != player:
                    # Opponent: uniformly random among the legal cells.
                    legal = np.flatnonzero(obs['action_mask'])
                    env.step(int(random.choice(legal)))
                    obs, _, done, truncated, _ = env.last()
                    continue

                state = get_state(obs, player)
                if prev_action is not None:
                    self.push(prev_state, prev_action, state, prev_reward)

                action = random.randrange(9)
                if obs['action_mask'][action] != 1:
                    self.push(state, action, np.ones(state.shape), illegal_penalty)
                    ended_by_illegal_move = True
                    break

                env.step(action)
                prev_state, prev_action = state, action
                prev_reward = env.rewards[player]
                obs, _, done, truncated, _ = env.last()

            if not ended_by_illegal_move and prev_action is not None:
                # Game over (won, lost or drawn): close the last own transition.
                # `obs` belongs to env.agent_selection, so encode it for that seat.
                terminal = get_state(obs, env.agent_selection)
                self.push(prev_state, prev_action, terminal, env.rewards[player])

    def __len__(self):
        return len(self.memory)

In [13]:
# hyperparameters

# --- exploration schedule ----------------------------------------------------
episodes = 500_000            # number of self-play games (training-loop iterations)
eps_start = 0.9               # exploration probability at episode 0
eps_end = 0.1                 # exploration probability reached after `episodes` games
# Log-decay rate chosen so that eps_start * exp(rate * episodes) == eps_end.
rate = 1/episodes * np.log(eps_end/eps_start)

# --- learning ----------------------------------------------------------------
batch_size = 2
capacity = 200_000            # replay-buffer size (one buffer per player)
lr = 0.00001
epsilon = 0.1                 # NOTE(review): not referenced by the visible cells; eps comes from the schedule above
gamma = 1                     # undiscounted returns
training_freq = 1             # optimize every N environment steps
target_update_freq = 10_000   # copy policy -> target nets every N environment steps

# --- evaluation / checkpointing ----------------------------------------------
eval_episodes = 2_000         # games per evaluation
eval_freq = 5_000             # evaluate every N episodes
save_freq = 5_000             # checkpoint every N episodes

In [10]:
# Initializations
env = tictactoe_v3.env()
env.reset()
obs, reward, termination, truncation, info = env.last()
state_shape = obs['observation'].shape
action_shape = obs['action_mask'].shape

# One online ("policy") network per seat.
X_p_net = Net()
O_p_net = Net()
policy_nets = {'player_1': X_p_net, 'player_2': O_p_net}

# Target networks start as copies of the policy networks and are only ever
# refreshed via load_state_dict, never optimized. Freeze them so backward()
# does not accumulate (never-cleared) .grad tensors on their parameters;
# load_state_dict copies values in place and keeps requires_grad=False.
X_target = Net()
X_target.load_state_dict(X_p_net.state_dict())
O_target = Net()
O_target.load_state_dict(O_p_net.state_dict())
target_nets = {'player_1': X_target, 'player_2': O_target}
for _target_net in target_nets.values():
    _target_net.requires_grad_(False)

# One replay buffer per seat.
X_memory = ReplayMemory(capacity)
O_memory = ReplayMemory(capacity)

memories = {'player_1': X_memory, 'player_2': O_memory}

optimizer_X = torch.optim.AdamW(policy_nets['player_1'].parameters(), lr=lr, amsgrad=True)
optimizer_O = torch.optim.AdamW(policy_nets['player_2'].parameters(), lr=lr, amsgrad=True)

optimizers = {'player_1': optimizer_X, 'player_2': optimizer_O}

criterion = nn.SmoothL1Loss()

In [11]:
def evaluate(env, n, policy, player = 'player_1'):
    """Play `n` games of greedy `policy` (as `player`) against a random legal opponent.

    Returns the mean, over games, of the rewards `player` receives immediately
    after its own moves (reported by the caller as a win rate). An illegal move
    chosen by `policy` abandons that game without credit.

    Args:
        env: AEC-style environment exposing reset/last/step, agent_selection, rewards.
        n: number of games to play (must be > 0).
        policy: callable mapping a (1, 9) tensor to 9 action scores.
        player: seat controlled by `policy`.
    """
    total = 0
    for _ in range(n):
        env.reset()
        obs, reward, done, truncation, info = env.last()
        while not (done or truncation):
            p = env.agent_selection

            # Opponent: uniformly random among the legal cells.
            if p != player:
                legal = np.nonzero(obs['action_mask'])[0]
                # Scalar draw: int() of a size-1 array is deprecated in NumPy.
                env.step(int(np.random.choice(legal)))
                obs, reward, done, truncation, info = env.last()
                continue

            # Evaluated player: greedy action, inference only.
            state = get_state(obs, player)
            with torch.no_grad():
                output = policy(torch.Tensor(state))
            a = int(torch.argmax(output))
            if obs['action_mask'][a] != 1:
                break  # illegal move: game abandoned, nothing credited
            env.step(a)
            total += env.rewards[p]
            obs, reward, done, truncation, info = env.last()
    return total / n

def evaluate_random(env, n, player = 'player_1'):
    """Random-play baseline: fraction of `n` games won by `player`.

    Both seats play uniformly random legal moves. A game counts as a win when
    `player`'s entry in ``env.rewards`` is positive once the game has ended
    (assumes env.rewards reflects the game-ending step, as in PettingZoo AEC envs).

    Args:
        env: AEC-style environment exposing reset/last/step and rewards.
        n: number of games to play (must be > 0).
        player: seat whose wins are counted.
    """
    wins = 0
    for _ in range(n):
        env.reset()
        obs, reward, done, truncation, info = env.last()
        while not (done or truncation):
            legal = np.nonzero(obs['action_mask'])[0]
            # Scalar draw: int() of a size-1 array is deprecated in NumPy.
            env.step(int(np.random.choice(legal)))
            obs, reward, done, truncation, info = env.last()
        if env.rewards[player] > 0:
            wins += 1
    return wins / n
    
def select_action(state, policy, eps_tresh, greedy=False):
    """Epsilon-greedy choice of a board cell.

    Args:
        state: board encoding of shape (1, 9) (numpy array or tensor).
        policy: callable mapping `state` to 9 action scores.
        eps_tresh: probability of exploring. When exploring, a uniformly random
            cell is returned, and it may be an illegal move.
        greedy: if True, always return the argmax action.

    Returns:
        int: chosen cell index in [0, 9). Both branches return a Python int.
    """
    assert state.shape[0] == 1
    assert state.shape[1] == 9
    sample = random.random()
    if greedy or sample > eps_tresh:
        with torch.no_grad():  # action selection only; no autograd graph needed
            return int(torch.argmax(policy(state)))
    return int(np.random.choice(range(9)))
        

In [14]:
# Self-play DQN training: each seat has its own policy/target nets, memory and optimizer.
SAVE_DIR = "nn_params"
os.makedirs(SAVE_DIR, exist_ok=True)  # torch.save / np.save do not create directories

time_steps = 0
evals = {'player_1': [], 'player_2': []}
players = env.agents
previous_state = {}          # not used below; kept for any later cell referencing it
previous_action = {}
previous_reward = {}
initial_state = np.zeros(9)  # not used below; kept for any later cell referencing it


def _board_row(board):
    """Flatten a stored board (shape (1, 9) or (9,)) to a float32 vector of length 9."""
    return np.asarray(board, dtype=np.float32).reshape(9)


def _optimize_player(player):
    """Run one DQN gradient step for `player`; no-op until its memory holds a batch.

    Transitions whose next_state is None are terminal: their target is the reward
    alone, with no bootstrapped value from the target network.
    """
    memory = memories[player]
    if len(memory) < batch_size:
        return
    batch = memory.sample(batch_size)

    state_batch = torch.from_numpy(np.stack([_board_row(t.state) for t in batch]))
    # Index tensors must be integer-typed; a float index array raises IndexError.
    action_batch = torch.tensor([int(t.action) for t in batch], dtype=torch.long)
    reward_batch = torch.tensor([float(t.reward) for t in batch], dtype=torch.float32)
    terminal = torch.tensor([t.next_state is None for t in batch], dtype=torch.bool)
    next_batch = torch.from_numpy(np.stack([
        np.zeros(9, dtype=np.float32) if t.next_state is None else _board_row(t.next_state)
        for t in batch
    ]))

    q_values = policy_nets[player](state_batch).gather(1, action_batch.unsqueeze(1)).squeeze(1)
    with torch.no_grad():  # targets are constants; no gradient into the target net
        next_q = target_nets[player](next_batch).max(dim=1).values
        next_q = next_q.masked_fill(terminal, 0.0)
        targets = reward_batch + gamma * next_q

    loss = criterion(q_values, targets)
    optimizers[player].zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_value_(policy_nets[player].parameters(), 100)
    optimizers[player].step()


for steps_done in tqdm(range(episodes), desc = "Episodes", position = 0, leave = True):

    for p in players:
        previous_action[p] = -1
        previous_reward[p] = 0

    env.reset()
    obs, reward, done, truncation, info = env.last()
    states = []          # board before each move; the seats alternate turns
    i = 0
    ended = True         # False when the episode was cut short by an illegal move
    last_player = "player_1"
    eps = eps_start * np.exp(rate * steps_done)  # constant within one episode

    while not (done or truncation):
        p = env.agent_selection
        last_player = p
        state = get_state(obs, p)
        states.append(state)

        # Seats alternate, so states[i-2] was p's previous decision point.
        if i >= 2:
            memories[p].push(states[i-2], previous_action[p], states[i], previous_reward[p])

        a = select_action(state, policy_nets[p], eps)

        if obs['action_mask'][a] != 1:
            # Illegal move: stop the episode without stepping env; store as terminal.
            memories[p].push(states[i], a, None, -1)
            ended = False
            break

        env.step(a)
        r = env.rewards[p]
        obs, reward, done, truncation, info = env.last()

        previous_action[p] = a
        previous_reward[p] = r

        time_steps += 1
        i += 1

        if time_steps % training_freq == 0:
            _optimize_player(p)

        if time_steps % target_update_freq == 0:
            # Distinct loop name so `p` is not clobbered mid-episode.
            for name, net in target_nets.items():
                net.load_state_dict(policy_nets[name].state_dict())

    # Game over: close both seats' last transitions as terminal (next_state=None).
    if ended:
        other = "player_2" if last_player == "player_1" else "player_1"
        memories[last_player].push(states[-1], previous_action[last_player], None, env.rewards[last_player])
        memories[other].push(states[-2], previous_action[other], None, env.rewards[other])

    if steps_done % eval_freq == 0:
        for name in evals:
            e = evaluate(env, eval_episodes, policy_nets[name], name)
            tqdm.write(f'{name} win rate = {e*100}%')
            evals[name].append(e)

    if steps_done % save_freq == 0:
        for name in evals:
            torch.save(policy_nets[name].state_dict(), f"{SAVE_DIR}/TTT_{name}_policy_network_state_dict.pt")
            # np.save appends ".npy" unless the name already ends with it.
            np.save(f"{SAVE_DIR}/TTT_{name}_evaluations.npy", evals[name])


Episodes:   0%|          | 0/500000 [00:00<?, ?it/s]

player_1 win rate = 0.0%
player_2 win rate = 0.0%


KeyboardInterrupt: 