In [1]:
import numpy as np
import torch
from torch import nn
from pettingzoo.classic import tictactoe_v3
import random
from tqdm.auto import tqdm
from collections import namedtuple, deque
import matplotlib.pyplot as plt

In [2]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        
        # Input layer size: 3x3x2 (18)
        # Hidden layers: 27x27x27
        # Output layer: 9
        
        self.model = nn.Sequential(
            nn.Linear(9, 64),
            nn.LeakyReLU(inplace=True),
            nn.Linear(64, 64),
            nn.LeakyReLU(inplace=True),
            nn.Linear(64, 64),
            nn.LeakyReLU(inplace=True),
            nn.Linear(64, 9),
            nn.LeakyReLU(inplace=True),
        )
        
    def forward(self, obs, state = None, info={}):
        assert obs.shape[-1] == 9
        if not isinstance(obs, torch.Tensor):
            obs = torch.tensor(obs, dtype=torch.float)
        batch = obs.shape[0]
        logits = self.model(obs.view(batch,-1))
        return logits

In [3]:
Transition = namedtuple("Transition", ("state", "action", "next_state", "reward"))

In [4]:
def get_state(obs, player = 'player_1'):
    state = obs['observation'][:,:,0] - obs['observation'][:,:,1]
    if player == 'player_1':
        return state.reshape(1,-1)
    else:
        return -state.reshape(1,-1)

In [5]:
class ReplayMemory(object):
    
    def __init__(self, capacity=10e5):
        self.capacity = int(capacity)
        self.memory = deque([], maxlen=self.capacity)
    
    def push(self, *args):
        self.memory.append(Transition(*args))
        
    def sample(self, batch_size):
        assert batch_size <= self.capacity, "Batch size larger than capacity"
        return random.sample(self.memory, batch_size)

    def initialize(self, env, player = 'player_1'):
        # Fill with random transitions
        players = env.agents
        initial_state = np.ones(9)*-1
        previous_state = initial_state
        previous_action = -1
        previous_reward = 0
        while len(self.memory) < self.capacity:
            env.reset()
            obs, reward, done, truncation, info = env.last()
            while done == False:
                p = env.agent_selection
                if p != player:
                    a = random.sample(range(9), 1)[0]
                    env.step(a)
                    obs, reward, done, truncation, info = env.last()
                    continue
                    
                    
                state = get_state(obs, p)
                if previous_action != -1:
                    self.push(previous_state, previous_action, state, r)
                
                a = random.sample(range(9), 1)[0]
                if obs['action_mask'][a] == 1:
                    env.step(a)
                    obs, reward, done, truncation, info = env.last()
                    r = env.rewards[p]
                else:
                    done = True
                    self.push(state, a, np.ones(state.shape), -10)
                    
                previous_state = state
                previous_action = a
                previous_reward = r
                
    
    def __len__(self):
        return len(self.memory)

In [12]:
# hyperparameters
episodes = 500000
batch_size = 8
eps_start = 1
eps_end = 0.1
rate = 1/100000 * np.log(eps_end/eps_start)
capacity = 200000
lr = 0.00001
epsilon = 0.1
gamma = 1
training_freq = 1
target_update_freq = 100
eval_episodes = 2000
eval_freq = 2000
save_freq = 2000

In [13]:
# Initializations
env = tictactoe_v3.env()
env.reset()
obs, reward, termination, truncation, info = env.last()
state_shape = obs['observation'].shape
action_shape = obs['action_mask'].shape

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')
X_p_net = Net()
X_p_net.to(device)
O_p_net = Net()
O_p_net.to(device)
policy_nets = {'player_1': X_p_net, 'player_2': O_p_net}
X_target = Net()
X_target.to(device)
X_target.load_state_dict(X_p_net.state_dict())
O_target = Net()
O_target.to(device)
O_target.load_state_dict(O_p_net.state_dict())
target_nets = {'player_1': X_target, 'player_2': O_target}

X_memory = ReplayMemory(capacity)
O_memory = ReplayMemory(capacity)

memories = {'player_1': X_memory, 'player_2': O_memory}

optimizer_X = torch.optim.AdamW(policy_nets['player_1'].parameters(), lr=lr, amsgrad=True)
optimizer_O = torch.optim.AdamW(policy_nets['player_2'].parameters(), lr=lr, amsgrad=True)

optimizers = {'player_1': optimizer_X, 'player_2': optimizer_O}

criterion_1 = nn.MSELoss()
criterion_2 = nn.MSELoss()
crits = {'player_1':criterion_1, 'player_2': criterion_2}


In [14]:
def evaluate(env, n, policy, player = 'player_1'):
    # We evaluate against a random opponent
    r = 0
    for _ in range(n):
        env.reset()
        obs, reward, done, truncation, info = env.last()
        while done == False:
            p = env.agent_selection
            
            # Random
            if p != player:
                a = int(np.random.choice(np.nonzero(obs['action_mask'])[0], 1))
                env.step(a)
                obs, reward, done, truncation, info = env.last()
                continue
            
            # Correct Player
            state = torch.tensor(get_state(obs, player), device = device, dtype=torch.float32)
            output = policy(torch.Tensor(state))
            a = int(torch.argmax(output))
            if obs['action_mask'][a] == 1:
                env.step(a)
                r += env.rewards[p]
            else:
                done = True
                continue
                
            obs, reward, done, truncation, info = env.last()
    r /= n
    return r

def evaluate_random(env, n, player = 'player_1'):
    r = 0
    for _ in range(n):
        env.reset()
        obs, reward, done, truncation, info = env.last()
        while done == False:
            p = env.agent_selection
            a = int(np.random.choice(np.nonzero(obs['action_mask'])[0], 1))
            env.step(a)
            obs, reward, done, truncation, info = env.last()
            print(p, reward)
            print(env.rewards)
            
    return r/n
    
def select_action(state, policy, eps_tresh, mask, greedy=False):
    assert state.shape[0] == 1
    assert state.shape[1] == 9
    sample = random.random()
    if greedy == True or sample>eps_tresh:
        state = torch.tensor(state, device = device, dtype = torch.float32)
        legal_actions = np.nonzero(mask)[0]
        action = int(torch.argmax(policy(state)[0,legal_actions]))
        return action
    else:
        return int(np.random.choice(np.nonzero(mask)[0], 1))

In [15]:
time_steps = 0
evals = {'player_1':[], 'player_2':[]}
players = env.agents
previous_state = {}
previous_action = {}
previous_reward = {}
progress_bar = tqdm(range(episodes), desc = "Episodes", position = 0, leave = True)
epsilon_bar = tqdm(bar_format = '{desc}', position = 1)
p1_bar = tqdm(bar_format = '{desc}', position = 2)
p2_bar = tqdm(bar_format = '{desc}', position = 3)
bars = {'player_1': p1_bar, 'player_2': p2_bar}

for steps_done in progress_bar:
    
    for p in players:
        previous_action[p] = -1
        previous_reward[p] = 0
    
    states = []
        
    env.reset()
    obs, reward, done, truncation, info = env.last()
    state = get_state(obs, 'player_1')
    states.append(state)
    a = select_action(state, policy_nets['player_1'], 1, obs['action_mask'])
    previous_action['player_1'] = a
    previous_reward['player_1'] = 0
    env.step(a)
    obs, reward, done, truncation, info = env.last()

    i = 1
    ended = True
    while done == False:
        
        p = env.agent_selection
        state = get_state(obs, p)
        states.append(state)
        
        if i >= 2:
            memories[p].push(states[i-2], previous_action[p], states[i], previous_reward[p])
        
        eps = max(eps_start * np.exp(rate*steps_done), eps_end)
        a = select_action(state, policy_nets[p], eps, obs['action_mask'])
        
        # Legal action
        if obs['action_mask'][a] == 1:
            env.step(a)
            r = env.rewards[p]
            
        # Illegal action
        else:
            done = True
            memories[p].push(states[i], a, states[i], -1)
            ended = False
            continue
            
        obs, reward, done, truncation, info = env.last()
        
        previous_action[p] = a
        previous_reward[p] = r
        
        time_steps += 1
        i += 1       
        
        if time_steps%training_freq==0:
            if len(memories[p]) >= batch_size:
                batch = memories[p].sample(batch_size)
            else:
                continue
            target_input = torch.empty(batch_size, 9, device=device)
            policy_input = torch.empty(batch_size, 9, device = device)
            rewards = torch.empty(batch_size,device = device)
            actions = np.zeros(batch_size)
            for ind, t in enumerate(batch):
                policy_input[ind,:] = torch.tensor(t.state)
                target_input[ind,:] = torch.tensor(t.next_state)
                rewards[ind] = t.reward
                actions[ind] = t.action
            
            q_values = policy_nets[p](policy_input)[np.arange(batch_size), actions]
            next_state_q_values = torch.max(target_nets[p](target_input), 1)[0]*gamma + rewards[:]
            loss = crits[p](q_values, next_state_q_values)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(policy_nets[p].parameters(), max_norm = 1, norm_type = 2)
            optimizers[p].step()
            optimizers[p].zero_grad()
            
        if time_steps % target_update_freq == 0:
            for p, net in target_nets.items():
                net.load_state_dict(policy_nets[p].state_dict())
        
    # Update terminal states:
    if ended:
        _ = env.agent_selection
        terminal = get_state(obs, _)
        # update player_1
        if p == "player_1": # final move made by player_1
            memories["player_1"].push(states[-1], previous_action["player_1"], terminal, env.rewards["player_1"])
            memories["player_2"].push(states[-2], previous_action["player_2"], terminal, env.rewards["player_2"])
        else: # final move made by player_2
            memories["player_1"].push(states[-2], previous_action["player_1"], terminal, env.rewards["player_1"])
            memories["player_2"].push(states[-1], previous_action["player_2"], terminal, env.rewards["player_2"])
        
    if steps_done%eval_freq == 0:
        for p in evals.keys():
            e = evaluate(env, eval_episodes, policy_nets[p], p)
            bars[p].set_description_str(f'{p} win rate = {e*100}%')
            evals[p].append(e)
        epsilon_bar.set_description_str(f'epsilon = {eps}')

    if steps_done%save_freq == 0:
        for p in evals.keys():
            torch.save(policy_nets[p].state_dict(), f"nn_params/TTT_{p}_policy_network_state_dict.pt")
            np.save(f"nn_params/TTT_{p}_evaluations.py", evals[p])

Episodes:   0%|          | 0/500000 [00:00<?, ?it/s]







KeyboardInterrupt: 

In [16]:
policy_nets['player_2'](np.array([[0,0,0,0,0,1,0,0,0]]))

tensor([[269.9980, 257.8204, 257.2065, 231.2742, 276.9026, 240.5143, 231.4798,
         214.4390, 245.5986]], grad_fn=<LeakyReluBackward1>)

In [11]:
print(policy_nets['player_2'].model[0].weight.grad)

tensor([[-3.1734e-05, -1.0977e-05, -5.9294e-05, -1.8750e-05,  1.7679e-05,
          6.3044e-05, -6.3339e-05,  2.2786e-05, -1.7977e-05],
        [ 6.3012e-05, -1.0142e-05,  1.0522e-05, -1.0185e-05,  1.0203e-05,
          9.5310e-06,  5.9029e-07, -8.2633e-05,  7.2588e-05],
        [ 8.1377e-04,  8.1429e-04, -7.8897e-04,  8.2483e-04, -8.2482e-04,
         -8.1852e-04,  4.2058e-06,  7.8406e-04,  1.9604e-05],
        [-1.9794e-03, -1.9452e-03,  1.9585e-03, -1.9318e-03,  1.9318e-03,
          1.9586e-03, -1.3434e-05, -1.9590e-03,  4.3454e-07],
        [ 3.6718e-05, -3.9206e-06,  7.8065e-05,  3.4602e-05, -3.3649e-05,
          1.0265e-05,  3.7683e-06, -1.1349e-04,  6.2840e-05],
        [ 1.9009e-04,  1.9046e-04, -1.9007e-04,  1.9034e-04, -1.9033e-04,
         -1.9095e-04,  7.2903e-07,  1.9034e-04,  4.7874e-08],
        [-3.0350e-03, -3.0365e-03,  3.0363e-03, -3.0365e-03,  3.0365e-03,
          3.0364e-03, -2.9951e-07, -3.0372e-03,  9.8304e-07],
        [ 1.2937e-05,  5.6012e-05, -1.4292e-04, 

In [None]:
np