In [None]:
import torch
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Last expression of the cell: shows whether CUDA is available.
torch.cuda.is_available()

In [None]:
from IPython.display import clear_output

In [None]:
import os, json
from src.state_generator import generate_states

# Build the state table on first use, then load it.
# NOTE(review): generate_states() is presumably what writes data/states.json —
# confirm in src.state_generator.
if not os.path.exists("data/states.json"):
    generate_states()
with open("data/states.json") as file:
    states_json = json.load(file)
    # Despite its name, states_dict is indexed by integer position elsewhere in
    # this notebook (states_dict[0], states_dict[state['actions'][action]]).
    states_dict = states_json['states']

In [None]:
from collections import namedtuple, deque
import random

Transition = namedtuple('Transition', ['state', 'action', 'next_state', 'reward'])


class ReplayMemory():
  """Fixed-capacity buffer of Transition records.

  Backed by a deque with ``maxlen=capacity``, so once the buffer is full each
  new transition evicts the oldest one.
  """

  def __init__(self, capacity):
    self.memory = deque(maxlen=capacity)

  def push(self, *args):
    """Store one transition; ``args`` are the Transition fields in order."""
    record = Transition(*args)
    self.memory.append(record)

  def sample(self, batch_size):
    """Return ``batch_size`` distinct stored transitions chosen at random."""
    return random.sample(self.memory, k=batch_size)

  def __len__(self):
    """Number of transitions currently stored."""
    return len(self.memory)

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class DQN(nn.Module):
  """Dueling Q-network: a shared MLP trunk feeding a value and an advantage head.

  The heads are combined as ``Q = V + A - mean(A)``, where the mean is taken
  over the action dimension of each sample.

  Args:
    n_observations: size of the input feature vector.
    n_actions: number of discrete actions (width of the advantage head).
  """

  def __init__(self, n_observations, n_actions):
    super().__init__()
    # NOTE(review): 1028 may have been meant as 1024; kept unchanged so that
    # previously saved state dicts still load.
    n = 1028
    self.layer_input = nn.Linear(n_observations, n)
    self.layer_h_1 = nn.Linear(n, n)
    self.layer_h_2 = nn.Linear(n, n)
    self.layer_v = nn.Linear(n, 1)
    self.layer_a = nn.Linear(n, n_actions)

  def forward(self, x):
    """Return Q-values with the same leading shape as ``x`` and ``n_actions`` last."""
    x = F.relu(self.layer_input(x))
    x = F.relu(self.layer_h_1(x))
    x = F.relu(self.layer_h_2(x))

    v = self.layer_v(x)
    a = self.layer_a(x)

    # Centre the advantages per sample. A bare a.mean() averages over the whole
    # batch as well, which makes each sample's Q-values depend on whatever else
    # happens to be in the batch.
    q = v + a - a.mean(dim=-1, keepdim=True)

    return q

In [None]:
# Lookup table: turn_marks[player][cell_mark] gives 1 for the player's own mark,
# -1 for the other mark and 0 for an empty cell ('-').
# Not referenced by any active code visible in this notebook.
turn_marks = {
    'x': {
        'x': 1,
        'o': -1,
        '-': 0,
    },
    'o': {
        'x': -1,
        'o': 1,
        '-': 0,
    }
}

def get_game_obs(state_dict: dict) -> list:
    """Encode a game state as a flat list of 0/1 features.

    The result is the concatenation of:
      * one flag per cell of ``state_dict['encoded']`` that is 1 where the cell is 'x',
      * one flag per cell that is 1 where the cell is 'o',
      * a single trailing element, ``state_dict['turn'] % 2``.
    """
    cells = state_dict['encoded']
    x_plane = [int(cell == 'x') for cell in cells]
    o_plane = [int(cell == 'o') for cell in cells]
    parity = state_dict['turn'] % 2

    return x_plane + o_plane + [parity]

# all_games_obs = [get_game_obs(sd) for sd in states_dict]
# get_game_obs(states_dict[2]), states_dict[2]
# Encode one state; the length of this observation is used below as the
# network's input size.
sample_obs = get_game_obs(states_dict[2])

In [None]:
# Network input width is taken from an actual encoded observation.
n_observations = len(sample_obs)
# One action per board cell (the loops below index actions 0..8).
n_actions = 9

# Replay capacity: number of known states times n_actions squared.
REPLAY_SIZE = len(states_dict) * n_actions**2

# Global replay buffer shared by the training loop and optimize_model().
memory = ReplayMemory(REPLAY_SIZE)

In [None]:
# Number of transitions sampled per optimisation step.
BATCH_SIZE = 20 #343224 // 100 #1024
# Discount factor applied to the bootstrapped next-state value.
GAMMA = 0.95
# Interpolation weight of the soft target-network update.
TAU = 0.005
# Optimizer learning rate.
LR = 0.0003
# Starting exploration rate of the epsilon-greedy policy.
EPS = 0.5
# Multiplicative decay factor; not used by any active code in this notebook.
EPS_DECAY = 0.9999
# Exploration rate the schedule settles at.
EPS_MIN = 0.1

# Current exploration rate; updated by the training loop.
epsilon = EPS

In [None]:
import torch.optim as optim

# The policy network is the one being trained; the target network starts as an
# exact copy and supplies the bootstrapped next-state values.
policy_net = DQN(n_observations, n_actions).to(device)
target_net = DQN(n_observations, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())

# Only the policy network's parameters are optimised.
optimizer = optim.RMSprop(policy_net.parameters(), lr=LR) # amsgrad? r:
# Counts completed episodes; the training loop runs until it reaches max_epoch.
global_step = 0

# NOTE(review): not referenced by any code visible in this notebook.
episode_durations = []

def optimize_model():
    """Run one gradient step of the policy network on a random replay batch.

    Samples BATCH_SIZE transitions from the global ``memory``, builds the
    one-step TD target ``reward + GAMMA * max_a target_net(next_state)`` (the
    bootstrap term is 0 for transitions whose next_state is None) and minimises
    the MSE between that target and ``policy_net``'s value for the action taken.

    Returns:
        The loss as a Python float, or None when the buffer holds fewer than
        BATCH_SIZE transitions and no update was performed.
    """
    if len(memory) < BATCH_SIZE:
        return

    transitions = memory.sample(BATCH_SIZE)
    # Transpose a list of Transitions into one Transition of tuples.
    batch = Transition(*zip(*transitions))

    non_final_next_states = [s for s in batch.next_state if s is not None]
    non_final_mask = torch.tensor(
        [s is not None for s in batch.next_state], dtype=torch.bool, device=device
    )
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    # A batch made up entirely of terminal transitions has nothing to bootstrap
    # from, and torch.cat() raises on an empty list, so skip the target network.
    if non_final_next_states:
        with torch.no_grad():
            next_q = target_net(torch.cat(non_final_next_states))
            next_state_values[non_final_mask] = next_q.max(1)[0]

    optimizer.zero_grad()

    expected_next_action_values = reward_batch + GAMMA * next_state_values
    # Add a trailing dimension so the target lines up with gather()'s output.
    expected_next_action_values = expected_next_action_values.unsqueeze(1)

    # Q(s, a) of the policy network for the actions that were actually taken.
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    criterion = nn.MSELoss()
    loss = criterion(state_action_values, expected_next_action_values)

    loss.backward()

    # Clamp every gradient element to [-100, 100] before stepping.
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()

    return loss.item()

In [None]:
from tensorboardX import SummaryWriter
import time
import numpy as np

# Timestamp of this run; used to name the TensorBoard log directory and the
# checkpoint directory.
timestr = time.strftime("%Y_%m_%d_%H_%M_%S")

In [None]:
# Number of episodes to train for: a long run on GPU, a short one on CPU.
if torch.cuda.is_available():
  max_epoch = 100_000
else:
  max_epoch = 50

# Hyper-parameters logged with every add_hparams() call in the training loop.
h_params = {
    'REPLAY_SIZE': REPLAY_SIZE,
    'BATCH_SIZE': BATCH_SIZE,
    'GAMMA': GAMMA,
    'TAU': TAU,
    'LR': LR,
}

In [None]:
# Rewards assigned by the training loop.
# Terminal state with a winner: one side's last transition gets WIN_REWARD and
# the other side's gets LOST_REWARD.
WIN_REWARD = 1
LOST_REWARD = -1
# Terminal state without a winner (both players' last transitions).
DRAW_REWARD = -0.5
# Action that leaves the state unchanged; ends the episode.
INVALID_MOVE_REWARD = -3.0
# Every non-terminal valid move.
STEP_REWARD = -0.1

In [None]:
def tensor_reward(reward):
    """Wrap a scalar reward in a one-element float32 tensor on ``device``."""
    return torch.tensor([reward], dtype=torch.float32, device=device)

In [21]:
from itertools import count
import numpy as np

# Self-play training loop.
#
# Each pass of the outer loop plays one game starting from states_dict[0]. Two
# players alternate: the "AI" picks epsilon-greedy actions from policy_net, the
# other player picks uniformly among its valid moves. Transitions of BOTH
# players are stored in the replay buffer. A player's non-terminal transition is
# only stored after the other player has replied, so its next observation is the
# board that player will see on its next turn.
with SummaryWriter(log_dir=f'duel_runs/{timestr}') as writer:

    while global_step < max_epoch:

        ep_losses = []
        ep_rewards = []
        ep_qvalues = []
        ep_epsilon = []

        state = states_dict[0]
        obs = torch.tensor(get_game_obs(state), dtype=torch.float32, device=device).unsqueeze(0)

        # ai_turn is flipped at the top of every step, so this coin toss decides
        # which player moves first.
        ai_turn = random.random() < 0.5
        # Observation/action of the previous mover (the other player).
        prev_obs = None
        prev_action = None
        for t in count():
            ai_turn = not ai_turn
            if ai_turn:
                with torch.no_grad():
                    q_values = policy_net(obs)
                    max_q_value = q_values.max(1)
                    qvalue = max_q_value[0].item()
                    action = max_q_value[1].view(1,1).item()

                # Linear decay from EPS, clamped so it never drops below EPS_MIN.
                epsilon = max(EPS_MIN, EPS - (global_step / max_epoch))

                if random.random() < epsilon:
                    action = random.randint(0, 8)
            else:
                # Random opponent: an action whose target is the current state's
                # own id is an invalid move, so exclude those.
                actions = state['actions']
                state_id = state["id"]
                valid_actions = [a for a in actions if a != state_id]
                action = actions.index(random.choice(valid_actions))

            next_state = states_dict[state['actions'][action]]
            next_obs = torch.tensor(get_game_obs(next_state), dtype=torch.float32, device=device).unsqueeze(0)
            done = next_state['done']

            # Mark of the player who is moving in `state`.
            turn_mark = 'x' if state['turn'] % 2 == 0 else 'o'

            action = torch.tensor([[action]], dtype=torch.long).to(device)

            if state == next_state:
                # Invalid move: the board did not change. Penalise and end the episode.
                next_state = None
                reward = INVALID_MOVE_REWARD
                memory.push(obs, action, None, tensor_reward(reward))
            elif done:
                # Terminal move: store the mover's transition and close the
                # previous mover's pending transition with the mirrored reward.
                if next_state['winner'] != '-':
                    if next_state['winner'] == turn_mark:
                        reward = WIN_REWARD
                        prev_reward = LOST_REWARD
                    else:
                        reward = LOST_REWARD
                        prev_reward = WIN_REWARD
                else:
                    reward = DRAW_REWARD
                    prev_reward = DRAW_REWARD
                memory.push(obs, action, None, tensor_reward(reward))
                # Never store a transition without an observation (there is no
                # previous mover on the very first step).
                if prev_obs is not None:
                    memory.push(prev_obs, prev_action, None, tensor_reward(prev_reward))
                next_state = None
            else:
                reward = STEP_REWARD
                # The previous mover's transition ends in the board it will see
                # on its next turn, i.e. after this reply.
                if prev_obs is not None:
                    memory.push(prev_obs, prev_action, next_obs, tensor_reward(reward))

            if ai_turn:
                ep_rewards.append(reward)
                ep_qvalues.append(qvalue)
                ep_epsilon.append(epsilon)

            if BATCH_SIZE < len(memory):
                loss_scalar = optimize_model()

                ep_losses.append(loss_scalar)

                target_net_state_dict = target_net.state_dict()
                policy_net_state_dict = policy_net.state_dict()

                # Soft update of the target network's weights
                # θ′ ← τ θ + (1 −τ )θ′
                for key in policy_net_state_dict:
                    target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
                target_net.load_state_dict(target_net_state_dict)

            if next_state is None:
                break

            prev_action = action
            prev_obs = obs

            state = next_state
            obs = next_obs

        if BATCH_SIZE < len(memory):
            loss_mean = np.mean(ep_losses)
            rewards_mean = np.mean(ep_rewards)
            qvalues_mean = np.mean(ep_qvalues)
            epsilons_mean = np.mean(ep_epsilon)
            writer.add_hparams(
                h_params,
                {
                    # Index of the last step of the episode.
                    'i_episode': t,
                    'Memory_len': len(memory),
                    'Loss': loss_mean, # loss_scalar,
                    # Reward of the AI's last move in this episode.
                    'Reward': ep_rewards[-1],
                    'Qvalue': qvalues_mean,
                    'Epsilon': epsilons_mean
                }, name='.', global_step=global_step,
            )

        global_step += 1
        writer.flush()

        

In [None]:
import os

save_model_dir = './duel_saved_models'

# Create <save_model_dir>/<timestr> in one call; exist_ok makes re-running the
# cell harmless and avoids the check-then-create race of exists() + mkdir().
os.makedirs(f'{save_model_dir}/{timestr}', exist_ok=True)

# Persist the weights (state dicts only) of both networks for this run.
torch.save(policy_net.state_dict(), f'{save_model_dir}/{timestr}/policy_net')
torch.save(target_net.state_dict(), f'{save_model_dir}/{timestr}/target_net')

In [None]:
# Disabled snippet, kept as a bare string literal so it does not execute: it
# would load a previously saved policy network from a run directory into `model`.
"""_MODEL_DATE_NAME = '2024_04_25_12_56_11'

model = DQN(n_observations, n_actions).to(device)
load_dict = torch.load(f'./duel_saved_models/{_MODEL_DATE_NAME}/policy_net')
model.load_state_dict(load_dict)"""

In [None]:
#import matplotlib.pyplot as plt
#from IPython import display

#_, ax = plt.subplots(1, 1)

#img = ax.imshow(env.render())

from src.tictactoe import decode as state_decode

# Interactive demo: the trained policy plays greedily against itself from the
# initial state. The board is printed after every move and the loop waits for
# Enter before continuing. It runs until the kernel is interrupted.
while True:
  # policy_net.eval()
  state = states_dict[0]

  for t in count():
    obs = torch.tensor(get_game_obs(state), dtype=torch.float32, device=device).unsqueeze(0)
    with torch.no_grad():
      # Greedy action: index of the largest Q-value.
      action = policy_net(obs).max(1)[1].view(1,1)

    next_state = states_dict[state['actions'][action.item()]]
    done = next_state['done']
    # An invalid move maps a state onto itself. The policy is deterministic, so
    # it would pick the same move forever and this inner loop would never end.
    invalid_move = next_state == state

    state = next_state

    for line in state_decode(state['encoded']):
      print(line, flush=True)
    print(state, flush=True)

    if done:
      print(obs, flush=True)
      input()
      break

    if invalid_move:
      print(f'invalid move {action.item()}: board unchanged, restarting game', flush=True)
      input()
      break

    input()