In [1]:
import torch
# Pick the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
cuda_available = torch.cuda.is_available()
device = torch.device("cuda") if cuda_available else torch.device("cpu")
cuda_available

True

In [2]:
from IPython.display import clear_output

In [3]:
import os, json
from src.state_generator import generate_states

# Generate the state table on first use, then read it back from disk.
if not os.path.exists("data/states.json"):
    generate_states()

with open("data/states.json") as file:
    states_json = json.load(file)

# Only the 'states' entry of the JSON document is used by the rest of the notebook.
states_dict = states_json['states']

In [4]:
from collections import namedtuple, deque
import random

# One stored step of experience: state, chosen action, resulting state, reward.
Transition = namedtuple('Transition', ['state', 'action', 'next_state', 'reward'])


class ReplayMemory:
    """Fixed-capacity store of Transition records.

    Backed by a deque with `maxlen=capacity`, so once the store is full each
    new record pushes out the oldest one.
    """

    def __init__(self, capacity):
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Build a Transition from the positional fields and store it."""
        record = Transition(*args)
        self.memory.append(record)

    def sample(self, batch_size):
        """Return `batch_size` distinct stored transitions chosen at random."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)

In [5]:
import torch.nn as nn
import torch.nn.functional as F

class DQN(nn.Module):
  """Dueling Q-network: a shared MLP trunk feeding a value head and an advantage head.

  Args:
    n_observations: length of the flat observation vector.
    n_actions: number of discrete actions; the network outputs one Q-value per action.
  """

  def __init__(self, n_observations, n_actions):
    super().__init__()
    self.layer_input = nn.Linear(n_observations, 2048)
    self.layer_h_1 = nn.Linear(2048, 2048)
    self.layer_v = nn.Linear(2048, 1)          # state value V(s)
    self.layer_a = nn.Linear(2048, n_actions)  # advantages A(s, a)

  def forward(self, x):
    """Return Q-values with the same leading shape as `x` and `n_actions` as last axis.

    Q(s, a) = V(s) + A(s, a) - mean_a' A(s, a'). The mean is taken over the
    action axis of each sample (keepdim so it broadcasts back), never over the
    batch: averaging over the whole batch would make one sample's Q-values
    depend on whichever other samples happen to share its mini-batch.
    """
    x = F.relu(self.layer_input(x))
    x = F.relu(self.layer_h_1(x))

    v = self.layer_v(x)
    a = self.layer_a(x)

    return v + a - a.mean(dim=-1, keepdim=True)

In [6]:
from src.tictactoe import decode as state_decode

# Raw board strings and their decoded form, both in states_dict order.
encoded_states = [state['encoded'] for state in states_dict]
decoded_states = list(map(state_decode, encoded_states))


In [7]:
# Board symbol -> network input value, keyed by the side to move:
# own marks are 1, the other side's marks are -1, empty cells are 0.
turn_marks = {
    'x': {'x': 1, 'o': -1, '-': 0},
    'o': {'x': -1, 'o': 1, '-': 0},
}


def get_game_obs(state_dict: dict) -> list:
    """Encode a state's board from the point of view of the side to move.

    Args:
        state_dict: a state record; its 'turn' (int) and 'encoded' (iterable of
            'x' / 'o' / '-' symbols) entries are read.

    Returns:
        One value per board symbol, taken from `turn_marks`; 'x' is the side
        to move on even turns and 'o' on odd turns.
    """
    x_to_move = state_dict['turn'] % 2 == 0
    symbol_values = turn_marks['x' if x_to_move else 'o']
    return [symbol_values[cell] for cell in state_dict['encoded']]

def _column(field):
    """Collect one field from every state record, in states_dict order."""
    return [state[field] for state in states_dict]


# Per-state columns, all index-aligned with states_dict.
all_games_obs = list(map(get_game_obs, states_dict))
all_games_actions = _column('actions')
all_games_isdone = _column('done')
all_games_winner = _column('winner')
all_games_possible_wins = _column('possible_wins')

In [8]:
# Nine board cells in, one Q-value per cell out.
n_observations = 9
n_actions = 9

# Replay capacity: one slot per (state, player action, opponent reply) combination.
REPLAY_SIZE = n_actions * n_actions * len(all_games_obs)

In [9]:
# Fill the replay memory exhaustively: for every unfinished state and every
# player action, store either one terminal transition (the move ends the game
# or is invalid) or one transition per valid opponent reply.
memory = ReplayMemory(REPLAY_SIZE)
n = len(states_dict)
for sd_id, sd in enumerate(states_dict):
    if sd_id % 500 == 499:
        print(f"{100 * sd_id/n}%")

    # A finished game offers no action to take, so it yields no transitions.
    if sd['done']:
        continue

    # Both values depend only on the current state, so compute them once per state.
    turn_mark = 'x' if sd['turn'] % 2 == 0 else 'o'
    obs = torch.tensor(get_game_obs(sd), dtype=torch.float32, device=device).unsqueeze(0)

    for i in range(n_actions):
        action = torch.tensor([[i]], dtype=torch.long, device=device)
        oponent_sd = states_dict[sd['actions'][i]]

        if oponent_sd['done']:
            # The player's own move ended the game: +10 for a win, 0 otherwise (draw).
            reward = 10 if oponent_sd['winner'] == turn_mark else 0
            reward = torch.tensor([reward], dtype=torch.float32, device=device)
            memory.push(obs, action, None, reward)
            continue

        if sd == oponent_sd:
            # Invalid move: the action leads back to the same state. Penalise and end.
            reward = torch.tensor([-10], dtype=torch.float32, device=device)
            memory.push(obs, action, None, reward)
            continue

        # Valid, non-terminal move: enumerate every opponent reply.
        for j in range(n_actions):
            next_sd = states_dict[oponent_sd['actions'][j]]
            if next_sd == oponent_sd:
                # invalid oponent action
                continue

            # -10 when the opponent's reply wins the game, 0 otherwise.
            opponent_won = next_sd['done'] and next_sd['winner'] not in ('-', turn_mark)
            reward = torch.tensor([-10 if opponent_won else 0], dtype=torch.float32, device=device)

            if next_sd['done']:
                # The game is over after the reply, so there is no successor to
                # bootstrap from. Storing the terminal board as next_state would
                # add GAMMA * max Q of a state the network is never trained on.
                next_obs = None
            else:
                next_obs = torch.tensor(get_game_obs(next_sd), dtype=torch.float32, device=device).unsqueeze(0)

            # ('state', 'action','next_state', 'reward')
            memory.push(obs, action, next_obs, reward)

9.109163928441037%
18.236582694414018%
27.364001460387%
36.491420226359985%
45.61883899233297%
54.74625775830595%
63.873676524278935%
73.00109529025191%
82.1285140562249%
91.25593282219788%


In [10]:
# Number of transitions currently held in the replay memory.
len(memory)

71019

In [52]:
# Spot-check 50 randomly drawn transitions; terminal ones carry next_state=None.
memory.sample(50)

[Transition(state=tensor([[ 0.,  0.,  1.,  0.,  0.,  0.,  0.,  0., -1.]], device='cuda:0'), action=tensor([[0]], device='cuda:0'), next_state=tensor([[ 1.,  0.,  1.,  0.,  0.,  0., -1.,  0., -1.]], device='cuda:0'), reward=tensor([0.], device='cuda:0')),
 Transition(state=tensor([[ 1.,  0.,  0., -1.,  0.,  0.,  0., -1.,  1.]], device='cuda:0'), action=tensor([[5]], device='cuda:0'), next_state=tensor([[ 1.,  0.,  0., -1., -1.,  1.,  0., -1.,  1.]], device='cuda:0'), reward=tensor([0.], device='cuda:0')),
 Transition(state=tensor([[ 0., -1.,  1.,  0., -1.,  0.,  0.,  0.,  0.]], device='cuda:0'), action=tensor([[6]], device='cuda:0'), next_state=tensor([[-1., -1.,  1.,  0., -1.,  0.,  1.,  0.,  0.]], device='cuda:0'), reward=tensor([0.], device='cuda:0')),
 Transition(state=tensor([[ 0.,  1., -1.,  1., -1.,  0.,  0.,  0., -1.]], device='cuda:0'), action=tensor([[4]], device='cuda:0'), next_state=None, reward=tensor([-10.], device='cuda:0')),
 Transition(state=tensor([[ 1., -1., -1.,  1.,

In [64]:
# Mini-batch size: 1/500 of the replay memory. Derived from the memory's actual
# length so it follows the generated data instead of a number copied by hand
# from an earlier cell's output; never below 1.
BATCH_SIZE = max(1, len(memory) // 500)
GAMMA = 0.95   # discount factor applied to the next-state value in the TD target
TAU = 0.00005  # soft-update rate used when blending policy weights into the target net
LR = 0.0001    # optimizer learning rate

In [65]:
import torch.optim as optim

# Two identically shaped networks: policy_net is the one being optimised,
# target_net supplies the bootstrap values and starts as an exact copy of it.
policy_net = DQN(n_observations, n_actions).to(device)
target_net = DQN(n_observations, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())

optimizer = optim.RMSprop(policy_net.parameters(), lr=LR) # TODO: compare against an optimizer that offers amsgrad (e.g. Adam/AdamW)
# Number of optimisation steps done so far; advanced by the training loop.
global_step = 0

# NOTE(review): not read anywhere in the visible cells -- possibly a leftover.
episode_durations = []

def optimize_model():
    """Run one gradient step on `policy_net` from a random replay mini-batch.

    Samples BATCH_SIZE transitions from `memory`, builds the TD target
    `reward + GAMMA * max_a Q_target(next_state, a)` (the bootstrap term is 0
    for transitions whose next_state is None) and minimises the MSE between
    that target and `policy_net`'s Q-value for the action taken.

    Returns:
        The loss of this step as a float, or None when `memory` holds fewer
        than BATCH_SIZE transitions (no step is taken in that case).
    """
    if len(memory) < BATCH_SIZE:
        return

    transitions = memory.sample(BATCH_SIZE)
    # Turn a list of Transitions into one Transition of per-field tuples.
    batch = Transition(*zip(*transitions))

    non_final_mask = torch.tensor(
        [s is not None for s in batch.next_state], dtype=torch.bool, device=device
    )
    non_final_next_states = [s for s in batch.next_state if s is not None]
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Q(s, a) of the action that was actually taken, shape (BATCH_SIZE, 1).
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    # torch.cat raises on an empty list, so skip the bootstrap entirely when
    # every sampled transition is terminal; the values then stay at zero.
    if non_final_next_states:
        with torch.no_grad():
            next_state_values[non_final_mask] = target_net(torch.cat(non_final_next_states)).max(1)[0]
    expected_next_action_values = reward_batch + GAMMA * next_state_values

    criterion = nn.MSELoss()
    loss = criterion(state_action_values, expected_next_action_values.unsqueeze(1))

    optimizer.zero_grad()
    loss.backward()

    # Clamp each gradient element to [-100, 100] before the update.
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()

    return loss.item()

In [66]:
from tensorboardX import SummaryWriter
import time
import numpy as np

# Timestamp taken once per run; later cells use it to name the TensorBoard
# log directory and the saved-model directory.
timestr = time.strftime("%Y_%m_%d_%H_%M_%S")

In [67]:
# Number of optimisation steps: a long run on GPU, a short one on CPU.
max_epoch = 20000 if torch.cuda.is_available() else 50

# Hyper-parameters logged next to the training metrics.
h_params = dict(
    REPLAY_SIZE=REPLAY_SIZE,
    BATCH_SIZE=BATCH_SIZE,
    GAMMA=GAMMA,
    TAU=TAU,
    LR=LR,
)

In [68]:
from itertools import count
import numpy as np

ep_losses = []

with SummaryWriter(log_dir=f'duel_runs/{timestr}') as writer:

    while global_step < max_epoch:
        # One optimisation step on the policy network.
        step_loss = optimize_model()
        ep_losses.append(step_loss)

        # Soft update of the target network's weights
        # θ′ ← τ θ + (1 −τ )θ′
        blended_state = target_net.state_dict()
        for name, policy_value in policy_net.state_dict().items():
            blended_state[name] = policy_value*TAU + blended_state[name]*(1-TAU)
        target_net.load_state_dict(blended_state)

        # Log the hyper-parameters together with this step's metrics.
        step_metrics = {
            'i_episode': global_step,
            'Memory_len': len(memory),
            'Loss': np.mean(ep_losses),
        }
        writer.add_hparams(h_params, step_metrics, name='.', global_step=global_step)

        # The loss buffer is emptied every step, so the logged mean covers one step.
        ep_losses = []

        global_step += 1
        writer.flush()

In [69]:
import os

save_model_dir = './duel_saved_models'

# Create <save_model_dir>/<timestr> in one call. makedirs builds the missing
# parent too and exist_ok=True makes a re-run harmless, which avoids the
# check-then-create race of separate exists()/mkdir() pairs.
os.makedirs(f'{save_model_dir}/{timestr}', exist_ok=True)

# Persist the weights of both networks for this run.
torch.save(policy_net.state_dict(), f'{save_model_dir}/{timestr}/policy_net')
torch.save(target_net.state_dict(), f'{save_model_dir}/{timestr}/target_net')

In [70]:
"""_MODEL_DATE_NAME = '2024_04_25_12_56_11'

model = DQN(n_observations, n_actions).to(device)
load_dict = torch.load(f'./duel_saved_models/{_MODEL_DATE_NAME}/policy_net')
model.load_state_dict(load_dict)"""

"_MODEL_DATE_NAME = '2024_04_25_12_56_11'\n\nmodel = DQN(n_observations, n_actions).to(device)\nload_dict = torch.load(f'./duel_saved_models/{_MODEL_DATE_NAME}/policy_net')\nmodel.load_state_dict(load_dict)"

In [71]:
#import matplotlib.pyplot as plt
#from IPython import display

#_, ax = plt.subplots(1, 1)

#img = ax.imshow(env.render())


# Let the policy network play both sides, one move per <Enter>; runs until interrupted.
while True:
    # policy_net.eval()
    state = states_dict[0]  # every game restarts from the first entry of the state table

    for t in count():
        obs = torch.tensor(get_game_obs(state), dtype=torch.float32, device=device).unsqueeze(0)
        with torch.no_grad():
            q_values = policy_net(obs)
            # Greedy choice: index of the largest Q-value.
            action = q_values.max(1)[1].view(1,1)

        # Follow the chosen action through the precomputed transition table.
        next_state = states_dict[state['actions'][action.item()]]
        done = next_state['done']
        state = next_state

        for line in state_decode(state['encoded']):
            print(line)
        print(state)

        # On the final move also show the observation that produced it.
        if done:
            print(obs)

        input()
        if done:
            break

['-', '-', '-']
['-', 'x', '-']
['-', '-', '-']
{'id': 4763, 'encoded': '----x----', 'actions': [4764, 4915, 5015, 5076, 4763, 5111, 5124, 5134, 5141], 'done': False, 'turn': 1, 'winner': '-', 'possible_wins': 0}
['-', '-', '-']
['-', 'x', '-']
['-', '-', 'o']
{'id': 5141, 'encoded': '----x---o', 'actions': [1867, 3254, 4163, 4759, 5141, 5142, 5143, 5144, 5141], 'done': False, 'turn': 2, 'winner': '-', 'possible_wins': 0}
['-', 'x', '-']
['-', 'x', '-']
['-', '-', 'o']
{'id': 3254, 'encoded': '-x--x---o', 'actions': [2385, 3254, 2687, 2909, 3254, 3171, 3217, 3242, 3254], 'done': False, 'turn': 3, 'winner': '-', 'possible_wins': 0}
['-', 'x', '-']
['-', 'x', '-']
['-', 'o', 'o']
{'id': 3242, 'encoded': '-x--x--oo', 'actions': [1839, 3242, 3233, 3238, 3242, 3243, 3244, 3242, 3242], 'done': False, 'turn': 4, 'winner': '-', 'possible_wins': 0}
['-', 'x', '-']
['-', 'x', '-']
['x', 'o', 'o']
{'id': 3244, 'encoded': '-x--x-xoo', 'actions': [2383, 3244, 2685, 2907, 3244, 3169, 3244, 3244, 324

KeyboardInterrupt: Interrupted by user

In [82]:
#import matplotlib.pyplot as plt
#from IPython import display

#_, ax = plt.subplots(1, 1)

#img = ax.imshow(env.render())


# Play against the policy network; runs game after game until interrupted.
while True:
    # policy_net.eval()
    state = states_dict[0]
    # Coin flip for who makes the first move.
    model_turn = random.random() > 0.5
    for t in count():
        clear_output()
        if model_turn:
            obs = torch.tensor(get_game_obs(state), dtype=torch.float32, device=device).unsqueeze(0)
            with torch.no_grad():
                # Greedy choice: index of the largest Q-value.
                action = policy_net(obs).max(1)[1].view(1,1)

            next_state = states_dict[state['actions'][action.item()]]

        else:
            for line in state_decode(state['encoded']):
                print(line, flush = True)

            # Keep asking until the input is a usable index into the action
            # table: non-numeric text would raise ValueError, and a negative
            # number would silently index from the end of the list.
            n_moves = len(state['actions'])
            while True:
                raw_move = input("player move: ")
                try:
                    action = int(raw_move)
                except ValueError:
                    print(f"'{raw_move}' is not a whole number", flush = True)
                    continue
                if 0 <= action < n_moves:
                    break
                print(f"move must be between 0 and {n_moves - 1}", flush = True)

            next_state = states_dict[state['actions'][action]]

        done = next_state['done']
        state = next_state
        model_turn = not model_turn
        if done:
            # Show the final board before starting the next game.
            for line in state_decode(state['encoded']):
                print(line, flush = True)
            input("press to continue...")
            break

['-', '-', '-']


['-', '-', '-']
['-', '-', '-']
