In [43]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import math
import subprocess as sub
import random

from tqdm.notebook import trange

In [44]:
def get_minimum_moves(state):
  command = ["./solver.exe", "get-min", str(len(state)), *[str(x) for x in state]]
  run_process = sub.run(command, stdout=sub.PIPE)
  output = int(run_process.stdout.decode())
  return output

def get_index_of_state(state):
  command = ["./solver.exe", "get-index", str(len(state)), *[str(x) for x in state]]
  run_process = sub.run(command, stdout=sub.PIPE)
  output = int(run_process.stdout.decode())
  return output

def get_state_from_index(index, n):
  command = ["./solver.exe", "get-permutation", str(n), str(index)]
  run_process = sub.run(command, stdout=sub.PIPE)
  output = run_process.stdout.decode().split()
  output = list(map(int, output))
  return output

def get_random_state(n):
  index = np.random.randint(0, math.factorial(n))
  state = get_state_from_index(index, n)
  return state


class ShortSwap:
  def __init__(self, state_size):
    self.state_size = state_size
    self.action_size = 2*state_size-3 # 2 moves per position except for last 2 positions

  def get_position_and_move(self, action):
    position = action%(self.state_size-1)
    move = action//(self.state_size-1)
    return position, move+1

  def get_next_state(self, state, action):
    position, move = self.get_position_and_move(action)
    state = state.copy()
    state[position], state[position+move] = state[position+move], state[position]
    return state
  
  def check_win(self, state):
    return get_index_of_state(state) == 0
  
  def get_value_and_terminated(self, state, moves, visited):
    if self.check_win(state):
      return 1+1/(moves+1), True
    
    if sum(self.get_valid_moves(state, visited)) == 0:
      return -1, True
    return 0, False
  
  def get_valid_moves(self, state, visited):
    state = state.copy()
    valid_moves = np.ones(self.action_size)
    for position in range(self.state_size-1):
      for move in range(1,2+1):
        action = (move-1)*(self.state_size-1) + position
        if action >= self.action_size:
          continue

        state[position], state[position+move] = state[position+move], state[position]
        if get_index_of_state(state) in visited:
          valid_moves[action] = 0
        state[position], state[position+move] = state[position+move], state[position]
    return valid_moves

  def get_encoded_state(self, state):
    encoded_state = np.array(state).astype(np.float32).reshape(1, -1)
    return encoded_state

  def get_initial_state(self):
    return get_random_state(self.state_size)

n = 5
ss = ShortSwap(n)
get_random_state(n)

[3, 1, 5, 4, 2]

In [45]:
class ResNet(nn.Module):
  def __init__(self, game, num_resBlocks, num_hidden, device):
    super().__init__()

    self.device = device
    self.startBlock = nn.Sequential(
      nn.Conv1d(1, num_hidden, kernel_size=3, padding=1),
      nn.BatchNorm1d(num_hidden),
      nn.ReLU()
    )

    self.backBone = nn.ModuleList(
      [ResBlock(num_hidden) for i in range(num_resBlocks)]
    )

    self.policyHead = nn.Sequential(
      nn.Conv1d(num_hidden, 32, kernel_size=3, padding=1),
      nn.BatchNorm1d(32),
      nn.ReLU(),
      nn.Flatten(),
      nn.Linear(32*game.state_size, game.action_size)
    )    

    self.valueHead = nn.Sequential(
      nn.Conv1d(num_hidden, 3, kernel_size=3, padding=1),
      nn.BatchNorm1d(3),
      nn.ReLU(),
      nn.Flatten(),
      nn.Linear(3*game.state_size, 1),
      nn.Tanh()
    )

    self.to(device)
  
  def forward(self, x):
    x = self.startBlock(x)
    for resBlock in self.backBone:
      x = resBlock(x)
    policy = self.policyHead(x)
    value = self.valueHead(x)

    return policy, value

class ResBlock(nn.Module):
  def __init__(self, num_hidden):
    super().__init__()
    self.conv1 = nn.Conv1d(num_hidden, num_hidden, kernel_size=3, padding=1)
    self.bn1 = nn.BatchNorm1d(num_hidden)
    self.conv2 = nn.Conv1d(num_hidden, num_hidden, kernel_size=3, padding=1)
    self.bn2 = nn.BatchNorm1d(num_hidden)

  def forward(self, x):
    residual = x
    x = F.relu(self.bn1(self.conv1(x)))
    x = F.relu(self.bn2(self.conv2(x)) + residual)
    return x

In [46]:
class Node:
  def __init__(self, game, args, state, parent=None, action_taken=None, prior=0, visit_count=0):
    self.game = game
    self.args = args
    self.state = state
    self.parent = parent
    self.action_taken = action_taken
    self.prior = prior

    if self.parent is None:
      self.moves = 0
    else:
      self.moves = self.parent.moves+1
    
    self.children = []

    self.visit_count = visit_count
    self.value_sum = 0

  def is_fully_expanded(self):
    return len(self.children) > 0
  
  def select(self):
    best_child = None
    best_ucb = -np.inf
    for child in self.children:
      ucb = self.get_ucb(child)
      if ucb > best_ucb:
        best_ucb = ucb
        best_child = child

    return best_child
  
  def get_ucb(self, child):
    if child.visit_count == 0:
      q_value = 0
    else:
      q_value = 1 - (child.value_sum / child.visit_count + 1)/2
    return q_value + self.args['C'] * child.prior * math.sqrt(self.visit_count) / (1+child.visit_count)
  
  def expand(self, policy):
    for action, prob in enumerate(policy):
      if prob > 0:
        child_state = self.state.copy()
        child_state = self.game.get_next_state(child_state, action)
        
        child = Node(self.game, self.args, child_state, self, action, prob)
        self.children.append(child)
    
  def backpropagate(self, value):
    self.value_sum += value
    self.visit_count += 1

    if self.parent is not None:
      self.parent.backpropagate(value)

class MCTS:
  def __init__(self, game, args, model):
    self.game = game
    self.args = args
    self.model = model

  @torch.no_grad()
  def search(self, state, visited):
    root = Node(self.game, self.args, state, visit_count=1)

    policy, _ = self.model(
      torch.tensor(self.game.get_encoded_state(state), device=self.model.device).unsqueeze(0)
    )
    policy = torch.softmax(policy, axis=1).squeeze(0).cpu().numpy()
    
    policy = (1-self.args["dirichlet_epsilon"])*policy + self.args["dirichlet_epsilon"]\
      *np.random.dirichlet([self.args["dirichlet_alpha"]*self.game.action_size])
    valid_moves = self.game.get_valid_moves(state, visited)
    policy *= valid_moves
    policy /= np.sum(policy)
    root.expand(policy)

    for search in range(self.args["num_searches"]):
      node = root
      while node.is_fully_expanded():
        node = node.select()
        
      value, is_terminal = self.game.get_value_and_terminated(node.state, node.moves, visited)

      if not is_terminal:
        policy, value = self.model(
          torch.tensor(self.game.get_encoded_state(node.state), device=self.model.device).unsqueeze(0)
        )
        policy = torch.softmax(policy, axis=1).squeeze(0).cpu().numpy()
        valid_moves = self.game.get_valid_moves(node.state, visited)
        policy *= valid_moves
        policy /= np.sum(policy)

        value = value.item()

        node.expand(policy)

      node.backpropagate(value)

    action_probs = np.zeros(self.game.action_size)
    for child in root.children:
      action_probs[child.action_taken] = child.visit_count
    action_probs /= np.sum(action_probs)

    return action_probs

In [47]:
class AlphaZero:
  def __init__(self, model, optimizer, game, args):
    self.model = model
    self.optimizer = optimizer
    self.game = game
    self.args = args
    self.mcts = MCTS(game, args, model)

  def selfPlay(self):
    memory = []
    visited = set()
    state = self.game.get_initial_state()
    visited.add(get_index_of_state(state))
    moves = 0
    while True:
      action_probs = self.mcts.search(state, visited)

      memory.append((state, action_probs))

      temperature_action_probs = action_probs ** (1 / self.args["temperature"])
      temperature_action_probs /= np.sum(temperature_action_probs)
      action = np.random.choice(self.game.action_size, p=temperature_action_probs)
      state = self.game.get_next_state(state, action)
      visited.add(get_index_of_state(state))

      value, is_terminal = self.game.get_value_and_terminated(state, moves, visited)
      print(state)
      moves += 1
      if is_terminal:
        returnMemory = []
        for hist_neutral_state, hist_action_probs in memory:
          hist_outcome = value
          returnMemory.append((
            self.game.get_encoded_state(hist_neutral_state),
            hist_action_probs,
            hist_outcome
          ))
          return returnMemory

  def train(self, memory):
    random.shuffle(memory)
    for batchIdx in range(0, len(memory), self.args["batch_size"]):
      sample = memory[batchIdx:min(len(memory), batchIdx + self.args["batch_size"])]
      state, policy_targets, value_targets = zip(*sample)
      state = torch.tensor(np.array(state), dtype=torch.float32, device=self.model.device)
      policy_targets = torch.tensor(np.array(policy_targets), dtype=torch.float32, device=self.model.device)
      value_targets = torch.tensor(np.array(value_targets).reshape(-1,1), dtype=torch.float32, device=self.model.device)

      out_policy, out_value = self.model(state)

      policy_loss = F.cross_entropy(out_policy, policy_targets)
      value_loss = F.mse_loss(out_value, value_targets)

      loss = policy_loss + value_loss

      self.optimizer.zero_grad()
      loss.backward()
      self.optimizer.step()

  def learn(self):
    for iteration in range(self.args["num_iterations"]):
      memory = []

      self.model.eval()
      for selfPlay_iteration in range(self.args["num_selfPlay_iterations"]):
        print(f"Playing game {selfPlay_iteration}...")
        memory += self.selfPlay()

      self.model.train()
      for epoch in range(self.args["num_epochs"]):
        self.train(memory)

      torch.save(self.model.state_dict(), f"model_{iteration}.pt")
      torch.save(self.optimizer.state_dict(), f"optimizer_{iteration}.pt")

In [48]:
N = 4
shortswap = ShortSwap(N)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ResNet(shortswap, 4, 64, device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)
args = {
  "C": 2,
  "num_searches": 10,
  "num_iterations": 3,
  "num_selfPlay_iterations": 1,
  "num_epochs": 4,
  "temperature": 1.25,
  "dirichlet_epsilon": 0.25,
  "dirichlet_alpha": 0.3,
  "batch_size": 32
}

alphaZero = AlphaZero(model, optimizer, shortswap, args)
alphaZero.learn()

Playing game 0...
[1, 2, 3, 4]
Playing game 0...
[3, 1, 2, 4]
[3, 2, 1, 4]
[3, 4, 1, 2]
[1, 4, 3, 2]
[4, 1, 3, 2]
[4, 3, 1, 2]
[1, 3, 4, 2]
[1, 2, 4, 3]
[1, 4, 2, 3]
[2, 4, 1, 3]
[2, 3, 1, 4]
[2, 3, 4, 1]
[2, 4, 3, 1]
[2, 1, 3, 4]
[1, 2, 3, 4]
Playing game 0...
[2, 3, 1, 4]
[2, 3, 4, 1]
[2, 4, 3, 1]
[3, 4, 2, 1]
[3, 4, 1, 2]
[3, 2, 1, 4]
[3, 2, 4, 1]
[4, 2, 3, 1]
[4, 2, 1, 3]
[2, 4, 1, 3]
[1, 4, 2, 3]
[1, 4, 3, 2]
[4, 1, 3, 2]
[4, 3, 1, 2]
[4, 3, 2, 1]
[4, 1, 2, 3]
[2, 1, 4, 3]
[2, 1, 3, 4]
[3, 1, 2, 4]
[3, 1, 4, 2]
[1, 3, 4, 2]
[1, 2, 4, 3]
[1, 2, 3, 4]


In [53]:
def play_game(state):
  mcts = MCTS(shortswap, args, model)
  visited = set()
  visited.add(get_index_of_state(state))
  moves = 0
  value, is_terminal = shortswap.get_value_and_terminated(state, moves, visited)
    print(state)
  
  while True:
    action_probs = mcts.search(state, visited)

    temperature_action_probs = action_probs ** (1 / args["temperature"])
    temperature_action_probs /= np.sum(temperature_action_probs)
    action = np.random.choice(shortswap.action_size, p=temperature_action_probs)
    state = shortswap.get_next_state(state, action)
    visited.add(get_index_of_state(state))

    value, is_terminal = shortswap.get_value_and_terminated(state, moves, visited)
    print(state)
    moves += 1
    if is_terminal:
      return value, moves

state=[1,3,4,2]
play_game([1,2,3,4])
get_minimum_moves

[3, 2, 1, 4]
[2, 3, 1, 4]
[2, 3, 4, 1]
[2, 1, 4, 3]
[4, 1, 2, 3]
[4, 2, 1, 3]
[4, 2, 3, 1]
[4, 3, 2, 1]
[3, 4, 2, 1]
[2, 4, 3, 1]
[2, 1, 3, 4]
[3, 1, 2, 4]
[1, 3, 2, 4]
[1, 4, 2, 3]
[2, 4, 1, 3]


(-1, 15)