In [1]:
! pip install ray
import ray
ray.init()
ray.available_resources()

Collecting ray
  Downloading ray-2.10.0-cp310-cp310-manylinux2014_x86_64.whl.metadata (13 kB)
Collecting msgpack<2.0.0,>=1.0.0 (from ray)
  Downloading msgpack-1.0.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.1 kB)
Collecting protobuf!=3.19.5,>=3.15.3 (from ray)
  Downloading protobuf-5.26.1-cp37-abi3-manylinux2014_x86_64.whl.metadata (592 bytes)
Collecting aiosignal (from ray)
  Downloading aiosignal-1.3.1-py3-none-any.whl.metadata (4.0 kB)
Collecting frozenlist (from ray)
  Downloading frozenlist-1.4.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Downloading ray-2.10.0-cp310-cp310-manylinux2014_x86_64.whl (65.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m65.1/65.1 MB[0m [31m50.3 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hDownloading msgpack-1.0.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (385 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━

2024-04-10 07:13:41,488	INFO worker.py:1752 -- Started a local Ray instance.


{'node:__internal_head__': 1.0,
 'accelerator_type:TITAN': 1.0,
 'CPU': 15.0,
 'object_store_memory': 10000000000.0,
 'memory': 29781571789.0,
 'GPU': 1.0,
 'node:172.17.0.6': 1.0}

In [2]:
import numpy as np
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(0)

from tqdm.notebook import trange
##from tqdm import tqdm
import random

class ConnectFour:
    """Rules and state helpers for Connect Four on a 6x7 board.

    A state is a ``(row_count, column_count)`` NumPy array holding ``1`` and
    ``-1`` for the two players and ``0`` for empty cells. Row index 0 is the
    top of the board, so pieces stack from the highest row index upwards.
    """

    def __init__(self):
        self.row_count = 6
        self.column_count = 7
        self.action_size = self.column_count  # one action per column
        self.in_a_row = 4  # run length required to win

    def get_initial_state(self):
        """Return an empty board (all zeros)."""
        return np.zeros((self.row_count, self.column_count))

    def get_next_state(self, state, action, player):
        """Drop ``player``'s piece into column ``action``.

        ``state`` is modified in place and also returned; callers that need
        the previous position must pass a copy.

        Raises:
            ValueError: if column ``action`` has no empty cell left.
        """
        ### the top row is index zero,
        ### so the piece lands on the largest row index that is still empty
        empty_rows = np.where(state[:, action] == 0)[0]
        if empty_rows.size == 0:
            raise ValueError(f"column {action} is full")
        state[empty_rows.max(), action] = player
        return state

    def get_valid_moves(self, state):
        """Return a uint8 mask with 1 for every column whose top cell is empty."""
        return (state[0] == 0).astype(np.uint8)

    def check_win(self, state, action):
        """Return True if the piece last dropped into column ``action`` wins.

        ``action`` may be ``None`` (no move played yet), which is never a win.
        Otherwise the column must already contain at least one piece.
        """
        if action is None:
            return False

        # The most recently dropped piece is the topmost non-empty cell.
        row = np.min(np.where(state[:, action] != 0))
        column = action
        player = state[row][column]

        def count(offset_row, offset_column):
            """Count consecutive ``player`` cells from (row, column) along one direction."""
            for i in range(1, self.in_a_row):
                r = row + offset_row * i
                c = column + offset_column * i
                if (
                    r < 0
                    or r >= self.row_count
                    or c < 0
                    or c >= self.column_count
                    or state[r][c] != player
                ):
                    return i - 1
            return self.in_a_row - 1

        # The placed piece itself is not counted, hence ``in_a_row - 1``.
        needed = self.in_a_row - 1
        return (
            count(1, 0) >= needed # vertical (only downwards: nothing lies above the new piece)
            or (count(0, 1) + count(0, -1)) >= needed # horizontal
            or (count(1, 1) + count(-1, -1)) >= needed # top left diagonal
            or (count(1, -1) + count(-1, 1)) >= needed # top right diagonal
        )

    def get_value_and_terminated(self, state, action):
        """Return ``(value, is_terminal)`` after ``action`` was played.

        ``(1, True)`` if the move wins, ``(0, True)`` if no valid move is
        left (draw), ``(0, False)`` otherwise.
        """
        if self.check_win(state, action):
            return 1, True
        if np.sum(self.get_valid_moves(state)) == 0:
            return 0, True
        return 0, False

    def get_opponent(self, player):
        """Return the other player (players are ``1`` and ``-1``)."""
        return -player

    def get_opponent_value(self, value):
        """Return ``value`` as seen by the opponent (its negation)."""
        return -value

    def change_perspective(self, state, player):
        """Return a new board multiplied by ``player``; ``-1`` swaps the two sides."""
        return state * player

    def get_encoded_state(self, state):
        """Encode the board as three float32 planes: cells == -1, == 0, == 1."""
        encoded_state = np.stack(
            (state == -1, state == 0, state == 1)
        ).astype(np.float32)

        return encoded_state

class ResNet(nn.Module):
    """Residual CNN with a policy head and a value head.

    The network consumes a 3-plane board encoding, runs it through a stem
    convolution and ``num_resBlocks`` residual blocks, and produces
    ``game.action_size`` policy logits plus one tanh-squashed value.
    """

    def __init__(self, game, num_resBlocks, num_hidden, device):
        super().__init__()
        self.device = device

        # Number of board cells; sizes the linear layers that follow Flatten.
        board_cells = game.row_count * game.column_count

        # Stem: lift the 3 input planes to ``num_hidden`` channels.
        self.startBlock = nn.Sequential(
            nn.Conv2d(3, num_hidden, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_hidden),
            nn.ReLU(),
        )

        # Residual tower.
        tower = []
        for _ in range(num_resBlocks):
            tower.append(ResBlock(num_hidden))
        self.backBone = nn.ModuleList(tower)

        # Policy head: one logit per action (softmax is applied by callers).
        self.policyHead = nn.Sequential(
            nn.Conv2d(num_hidden, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(32 * board_cells, game.action_size),
        )

        # Value head: a single scalar squashed into [-1, 1] by tanh.
        self.valueHead = nn.Sequential(
            nn.Conv2d(num_hidden, 3, kernel_size=3, padding=1),
            nn.BatchNorm2d(3),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(3 * board_cells, 1),
            nn.Tanh(),
        )

        self.to(device)

    def forward(self, x):
        """Return ``(policy_logits, value)`` for a batch of encoded boards."""
        features = self.startBlock(x)
        for block in self.backBone:
            features = block(features)
        return self.policyHead(features), self.valueHead(features)

class ResBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers with an identity skip connection."""

    def __init__(self, num_hidden):
        super().__init__()
        # Both convolutions keep the channel count and the spatial size.
        conv_options = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(num_hidden, num_hidden, **conv_options)
        self.bn1 = nn.BatchNorm2d(num_hidden)
        self.conv2 = nn.Conv2d(num_hidden, num_hidden, **conv_options)
        self.bn2 = nn.BatchNorm2d(num_hidden)

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + x)``."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        # Skip connection: add the block input back before the final ReLU.
        return F.relu(hidden + x)

class Node:
    """One node of the Monte Carlo search tree.

    ``state`` is stored from the point of view of the player about to move at
    this node: ``expand`` always plays as player ``1`` and then flips the
    child's board with ``change_perspective(..., player=-1)``.
    """

    def __init__(self, game, args, state, parent=None, action_taken=None, prior=0, visit_count=0):
        self.game = game
        self.args = args  # reads args['C'] (exploration constant) in get_ucb
        self.state = state
        self.parent = parent
        self.action_taken = action_taken  # action that led from parent to this node

        self.children = []
        self.prior = prior  # policy probability assigned to this node by its parent

        self.visit_count = visit_count
        self.value_sum = 0

    def is_fully_expanded(self):
        """Return True once the node has children (expand creates them all at once)."""
        return len(self.children) > 0

    def select(self):
        """Return the child with the highest UCB score, or None if there are no children."""
        best_child = None
        best_ucb = -np.inf

        for child in self.children:
            ucb = self.get_ucb(child)
            if ucb > best_ucb:
                best_child = child
                best_ucb = ucb

        return best_child

    def get_ucb(self, child):
        """Return the selection score of ``child`` as seen from this node."""
        if child.visit_count == 0:
            q_value = 0
        else:
            # Map the child's mean value from [-1, 1] to [0, 1], then invert it:
            # backpropagate flips the sign at every level, so the child's value
            # is expressed from the opposite side of this node's.
            q_value = 1 - ((child.value_sum / child.visit_count) + 1) / 2
        return q_value + self.args['C'] * (math.sqrt(self.visit_count) / (child.visit_count + 1)) * child.prior

    def expand(self, policy):
        """Create one child per action with positive probability in ``policy``.

        Returns:
            The last child created, or ``None`` when no entry of ``policy`` is
            positive (previously this case raised ``UnboundLocalError``).
        """
        child = None
        for action, prob in enumerate(policy):
            if prob > 0:
                # get_next_state mutates its argument, so work on a copy.
                child_state = self.state.copy()
                child_state = self.game.get_next_state(child_state, action, 1)
                child_state = self.game.change_perspective(child_state, player=-1)

                child = Node(self.game, self.args, child_state, self, action, prob)
                self.children.append(child)

        return child

    def backpropagate(self, value):
        """Add ``value`` to this node and propagate its negation up to the root."""
        self.value_sum += value
        self.visit_count += 1

        value = self.game.get_opponent_value(value)
        if self.parent is not None:
            self.parent.backpropagate(value)

class MCTS:
    """Monte Carlo tree search guided by a policy/value network.

    Reads ``args['num_searches']``, ``args['dirichlet_epsilon']`` and
    ``args['dirichlet_alpha']``; the nodes it creates read ``args['C']``.
    """

    def __init__(self, game, args, model):
        self.game = game
        self.args = args
        self.model = model

    def _evaluate(self, state):
        """Run the model on one board; return (softmax policy ndarray, value tensor).

        Only called from ``search``, i.e. with gradient tracking disabled.
        """
        encoded = self.game.get_encoded_state(state)
        batch = torch.tensor(encoded, device=self.model.device).unsqueeze(0)
        logits, value = self.model(batch)
        probs = torch.softmax(logits, axis=1).squeeze(0).cpu().numpy()
        return probs, value

    def _mask_and_normalize(self, policy, state):
        """Zero the entries of ``policy`` for invalid moves and renormalise it in place."""
        policy *= self.game.get_valid_moves(state)
        policy /= np.sum(policy)
        return policy

    @torch.no_grad()
    def search(self, state):
        """Return visit-count based action probabilities for ``state``.

        ``state`` is expected from the perspective of the player to move.
        """
        root = Node(self.game, self.args, state, visit_count=1)

        # Root prior: network policy mixed with Dirichlet noise.
        root_policy, _ = self._evaluate(state)
        epsilon = self.args['dirichlet_epsilon']
        noise = np.random.dirichlet([self.args['dirichlet_alpha']] * self.game.action_size)
        root_policy = (1 - epsilon) * root_policy + epsilon * noise
        root.expand(self._mask_and_normalize(root_policy, state))

        for _ in range(self.args['num_searches']):
            # Selection: walk down until a node without children is reached.
            node = root
            while node.is_fully_expanded():
                node = node.select()

            # The stored outcome belongs to the player who made ``action_taken``;
            # negate it to express it from this node's point of view.
            value, is_terminal = self.game.get_value_and_terminated(node.state, node.action_taken)
            value = self.game.get_opponent_value(value)

            if not is_terminal:
                # Expansion: replace the value with the network's estimate.
                leaf_policy, value = self._evaluate(node.state)
                self._mask_and_normalize(leaf_policy, node.state)
                value = value.item()
                node.expand(leaf_policy)

            node.backpropagate(value)

        # Turn the root children's visit counts into a distribution.
        visit_counts = np.zeros(self.game.action_size)
        for child in root.children:
            visit_counts[child.action_taken] = child.visit_count
        return visit_counts / np.sum(visit_counts)







class AlphaZero:
    """Sequential AlphaZero loop: MCTS self-play followed by network training.

    Besides the keys read by MCTS, ``args`` must provide ``'temperature'``,
    ``'batch_size'``, ``'num_iterations'``, ``'num_selfPlay_iterations'`` and
    ``'num_epochs'``.
    """

    def __init__(self, model, optimizer, game, args):
        self.model = model
        self.optimizer = optimizer
        self.game = game
        self.args = args
        self.mcts = MCTS(game, args, model)

    def selfPlay(self):
        """Play one game against itself and return its training samples.

        Returns:
            A list of ``(encoded_state, action_probs, outcome)`` tuples, one
            per move. ``outcome`` is the final game value for the player who
            was to move in that position, the opponent's value for the other
            side, and 0 for both on a draw.
        """
        memory = []
        player = 1
        state = self.game.get_initial_state()

        while True:
            # MCTS always searches from the point of view of the player to move.
            neutral_state = self.game.change_perspective(state, player)
            action_probs = self.mcts.search(neutral_state)

            memory.append((neutral_state, action_probs, player))

            # Temperature reshapes the visit distribution before sampling.
            temperature_action_probs = action_probs ** (1 / self.args['temperature'])
            temperature_action_probs /= np.sum(temperature_action_probs)
            action = np.random.choice(self.game.action_size, p=temperature_action_probs)

            state = self.game.get_next_state(state, action, player)

            value, is_terminal = self.game.get_value_and_terminated(state, action)

            if is_terminal:
                returnMemory = []
                for hist_neutral_state, hist_action_probs, hist_player in memory:
                    # ``value`` is from the last mover's perspective.
                    hist_outcome = value if hist_player == player else self.game.get_opponent_value(value)
                    returnMemory.append((
                        self.game.get_encoded_state(hist_neutral_state),
                        hist_action_probs,
                        hist_outcome
                    ))
                return returnMemory

            player = self.game.get_opponent(player)

    def train(self, memory):
        """Run one optimisation pass over ``memory`` in shuffled mini-batches.

        ``memory`` is shuffled in place. Every sample is used exactly once;
        the final batch may be smaller than ``args['batch_size']``.
        """
        random.shuffle(memory)
        batch_size = self.args['batch_size']
        for batchIdx in range(0, len(memory), batch_size):
            # Plain slicing already clamps at the end of the list. The former
            # ``min(len(memory) - 1, ...)`` bound dropped the last sample and
            # could produce an empty batch that crashed the unpacking below.
            sample = memory[batchIdx:batchIdx + batch_size]
            state, policy_targets, value_targets = zip(*sample)

            state, policy_targets, value_targets = np.array(state), np.array(policy_targets), np.array(value_targets).reshape(-1, 1)

            state = torch.tensor(state, dtype=torch.float32, device=self.model.device)
            policy_targets = torch.tensor(policy_targets, dtype=torch.float32, device=self.model.device)
            value_targets = torch.tensor(value_targets, dtype=torch.float32, device=self.model.device)

            out_policy, out_value = self.model(state)

            # Policy targets are probability vectors (soft labels).
            policy_loss = F.cross_entropy(out_policy, policy_targets)
            value_loss = F.mse_loss(out_value, value_targets)
            loss = policy_loss + value_loss

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    def learn(self):
        """Alternate self-play and training; checkpoint after every iteration.

        Writes ``model_{iteration}.pt`` and ``optimizer_{iteration}.pt`` to the
        current working directory.
        """
        for iteration in range(self.args['num_iterations']):
            memory = []

            # Self-play runs the network in eval mode.
            self.model.eval()
            for selfPlay_iteration in trange(self.args['num_selfPlay_iterations']):
                memory += self.selfPlay()

            self.model.train()
            for epoch in trange(self.args['num_epochs']):
                self.train(memory)

            torch.save(self.model.state_dict(), f"model_{iteration}.pt")
            torch.save(self.optimizer.state_dict(), f"optimizer_{iteration}.pt")





def main():
    """Play Connect Four in the console: human (player 1) against MCTS (player -1).

    The network is freshly initialised unless the ``load_state_dict`` line
    below is enabled. Input that is not an integer, is out of range, or names
    a full column is rejected with "action not valid" and asked for again.
    """
    game = ConnectFour()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    player = 1

    args = {
        'C': 2,
        'num_searches': 100,
        'dirichlet_epsilon': 0.,  # no exploration noise when playing
        'dirichlet_alpha': 0.3
    }

    model = ResNet(game, 9, 128, device=device)
    ### model.load_state_dict(torch.load('model_2.pt', map_location=device))
    model.eval()
    mcts = MCTS(game, args, model)

    state = game.get_initial_state()

    while True:
        print(state)

        if player == 1:
            valid_moves = game.get_valid_moves(state)
            print("valid_moves", [i for i in range(game.action_size) if valid_moves[i] == 1])

            try:
                action = int(input(f"{player}:"))
            except ValueError:
                # Non-numeric input used to crash the game loop.
                print("action not valid")
                continue

            # Range check first: a negative index would otherwise wrap around
            # in NumPy and silently play a different column.
            if not 0 <= action < game.action_size or valid_moves[action] == 0:
                print("action not valid")
                continue

        else:
            # The search expects the board from the mover's perspective.
            neutral_state = game.change_perspective(state, player)
            mcts_probs = mcts.search(neutral_state)
            action = np.argmax(mcts_probs)

        state = game.get_next_state(state, action, player)

        value, is_terminal = game.get_value_and_terminated(state, action)

        if is_terminal:
            print(state)
            if value == 1:
                print(player, "won")
            else:
                print("draw")
            break

        player = game.get_opponent(player)

def test():
    """Print the network's value and policy for one hand-built position.

    NOTE(review): ``TicTacToe`` is not defined in the visible part of this
    notebook, so this function presumably dates from an earlier TicTacToe
    version - confirm before re-enabling. It also requires the checkpoint
    file ``model_2.pt`` in the working directory.
    """
    game = TicTacToe()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Build the position move by move: (action, player) pairs.
    state = game.get_initial_state()
    for action, player in ((2, -1), (4, -1), (6, 1), (8, 1)):
        state = game.get_next_state(state, action, player)
    print(state)

    encoded_state = game.get_encoded_state(state)
    print(encoded_state)

    # Add a batch dimension of size 1.
    batch = torch.tensor(encoded_state, device=device).unsqueeze(0)

    model = ResNet(game, 4, 64, device=device)
    model.load_state_dict(torch.load('model_2.pt', map_location=device))
    model.eval()

    policy_logits, value_out = model(batch)
    value = value_out.item()
    policy = torch.softmax(policy_logits, axis=1).squeeze(0).detach().cpu().numpy()

    print(value, policy)


In [3]:




@ray.remote(num_gpus=0.2)
def selfPlay(model, optimizer, game, args):
    """Ray task: play one self-play game and return its training samples.

    ``optimizer`` is accepted but not used here. The result is a list of
    ``(encoded_state, action_probs, outcome)`` tuples, one per move, where
    ``outcome`` is the final game value from the perspective of the player
    who was to move in that position.
    """
    mcts = MCTS(game, args, model)
    history = []
    player = 1
    state = game.get_initial_state()

    while True:
        # Search from the point of view of the player to move.
        neutral_state = game.change_perspective(state, player)
        action_probs = mcts.search(neutral_state)
        history.append((neutral_state, action_probs, player))

        # Apply the temperature, renormalise, then sample the move.
        tempered_probs = action_probs ** (1 / args['temperature'])
        tempered_probs /= np.sum(tempered_probs)
        action = np.random.choice(game.action_size, p=tempered_probs)

        state = game.get_next_state(state, action, player)
        value, is_terminal = game.get_value_and_terminated(state, action)

        if not is_terminal:
            player = game.get_opponent(player)
            continue

        # ``value`` belongs to the last mover (``player``); flip it for the
        # positions in which the other side was to move.
        return [
            (
                game.get_encoded_state(past_state),
                past_probs,
                value if past_player == player else game.get_opponent_value(value),
            )
            for past_state, past_probs, past_player in history
        ]




def train(memory, model, optimizer, game, args):
    """Run one optimisation pass over ``memory`` in shuffled mini-batches.

    Args:
        memory: list of ``(encoded_state, policy_target, value_target)``
            tuples; shuffled in place.
        model: network returning ``(policy_logits, value)`` and exposing a
            ``device`` attribute.
        optimizer: optimizer bound to ``model``'s parameters.
        game: unused; kept for a uniform call signature.
        args: must provide ``'batch_size'``.

    Every sample is used exactly once; the final batch may be smaller than
    ``args['batch_size']``.
    """
    random.shuffle(memory)
    batch_size = args['batch_size']
    for batchIdx in range(0, len(memory), batch_size):
        # Plain slicing already clamps at the end of the list. The former
        # ``min(len(memory) - 1, ...)`` bound dropped the last sample and
        # could produce an empty batch that crashed the unpacking below.
        sample = memory[batchIdx:batchIdx + batch_size]
        state, policy_targets, value_targets = zip(*sample)

        state, policy_targets, value_targets = np.array(state), np.array(policy_targets), np.array(value_targets).reshape(-1, 1)

        state = torch.tensor(state, dtype=torch.float32, device=model.device)
        policy_targets = torch.tensor(policy_targets, dtype=torch.float32, device=model.device)
        value_targets = torch.tensor(value_targets, dtype=torch.float32, device=model.device)

        out_policy, out_value = model(state)

        # Policy targets are probability vectors (soft labels).
        policy_loss = F.cross_entropy(out_policy, policy_targets)
        value_loss = F.mse_loss(out_value, value_targets)
        loss = policy_loss + value_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

def learn(model, optimizer, game, args):
    """Alternate Ray-parallel self-play and training; checkpoint every iteration.

    Per iteration: submit ``args['num_selfPlay_iterations']`` ``selfPlay``
    tasks, wait for all of them, train for ``args['num_epochs']`` passes on
    the pooled samples, then write ``model_{iteration}.pt`` and
    ``optimizer_{iteration}.pt`` to the working directory.
    """
    for iteration in range(args['num_iterations']):
        model.eval()

        # Submitting is non-blocking, so this progress bar only tracks task
        # submission; ray.get below is where the games are waited for.
        pending = [
            selfPlay.remote(model, optimizer, game, args)
            for _ in trange(args['num_selfPlay_iterations'])
        ]

        # Pool the per-game sample lists into one flat list.
        memory = []
        for game_samples in ray.get(pending):
            memory.extend(game_samples)

        model.train()
        for _ in trange(args['num_epochs']):
            train(memory, model, optimizer, game, args)

        torch.save(model.state_dict(), f"model_{iteration}.pt")
        torch.save(optimizer.state_dict(), f"optimizer_{iteration}.pt")

In [7]:
    ### Start Training
    ### Builds the game, network and optimizer, runs the Ray-based `learn`
    ### loop once, and prints the wall-clock duration in seconds.
    import time
    game = ConnectFour()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 9 residual blocks with 128 hidden channels.
    model = ResNet(game, 9, 128, device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)

    # Hyperparameters, read by Node.get_ucb ('C'), MCTS.search
    # ('num_searches', 'dirichlet_*'), selfPlay ('temperature'),
    # train ('batch_size') and learn ('num_iterations',
    # 'num_selfPlay_iterations', 'num_epochs').
    args = {
        'C': 2,
        'num_searches': 600,
        'num_iterations': 1,
        'num_selfPlay_iterations': 10,
        'num_epochs': 1,
        'batch_size': 128,
        'temperature': 1.25,
        'dirichlet_epsilon': 0.25,
        'dirichlet_alpha': 0.3
    }

    ## Earlier variant that ran the AlphaZero class as a Ray actor (disabled):
    ##alphaZero = AlphaZero.remote(model, optimizer, game, args)
    ##alphaZero.learn.remote()
    start = time.time()
    learn(model, optimizer, game, args)
    runtime = time.time()-start
    print(runtime)

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

113.09848499298096


In [None]:
    ### Test the trained model. Disabled: test() is written for TicTacToe, which is not used in this Connect4 notebook.
    ### test()

[[ 0.  0. -1.]
 [ 0. -1.  0.]
 [ 1.  0.  1.]]
[[[0. 0. 1.]
  [0. 1. 0.]
  [0. 0. 0.]]

 [[1. 1. 0.]
  [1. 0. 1.]
  [0. 1. 0.]]

 [[0. 0. 0.]
  [0. 0. 0.]
  [1. 0. 1.]]]
-0.08826334774494171 [0.07518821 0.08451733 0.11294591 0.09931102 0.0952905  0.17337684
 0.12486862 0.14422536 0.09027616]


In [None]:
  ### Play an interactive game against the MCTS agent.
  ### main() blocks on console input until the game ends or the cell is interrupted.
  main()

[[0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0.]]
valid_moves [0, 1, 2, 3, 4, 5, 6]


KeyboardInterrupt: Interrupted by user

In [None]:
# Scratch check: joining two lists of rows gives one flat list of rows (c),
# whereas wrapping them in a new list keeps them nested (d).
a = [[0, 1], [1, 1]]
b = [[2, 1], [2, 2]]
d = [a, b]
c = [*a, *b]

print(c)

[[0, 1], [1, 1], [2, 1], [2, 2]]
