In [1]:
# The MIT License
# 
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation 
# files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, 
# modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the 
# Software is furnished to do so, subject to the following conditions:
# 
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE 
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, 
# ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# 
# ============================================================================
"""Fourinarow reinforcement learning environment."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import dm_env
from dm_env import specs
import numpy as np

In [2]:
def makemove(bp,wp,color):
    """
    stand-in for the opponent, currently the opponent moves on a random unoccupied square on the board
    in the future, this function will wrap the C++ agent code
    """
    unoccupied_squares = np.nonzero(np.logical_and([bp==0,wp==0]))
    return np.random.choice(unoccupied_squares)

In [48]:
class Fourinarow(dm_env.Environment):
    """
    DM environment for playing fourinarow against a constant opponent
    """

    def __init__(self, seed=1):
        self._rng = np.random.RandomState(seed)
        self._fourinarows = np.array( [[ 0,  9, 18, 27], #hard-coding all ways in which four-in-a-row can appear
                                       [ 1, 10, 19, 28],
                                       [ 2, 11, 20, 29],
                                       [ 3, 12, 21, 30],
                                       [ 4, 13, 22, 31],
                                       [ 5, 14, 23, 32],
                                       [ 6, 15, 24, 33],
                                       [ 7, 16, 25, 34],
                                       [ 8, 17, 26, 35],
                                       [ 0, 10, 20, 30],
                                       [ 1, 11, 21, 31],
                                       [ 2, 12, 22, 32],
                                       [ 3, 13, 23, 33],
                                       [ 4, 14, 24, 34],
                                       [ 5, 15, 25, 35],
                                       [ 3, 11, 19, 27],
                                       [ 4, 12, 20, 28],
                                       [ 5, 13, 21, 29],
                                       [ 6, 14, 22, 30],
                                       [ 7, 15, 23, 31],
                                       [ 8, 16, 24, 32],
                                       [ 0,  1,  2,  3],
                                       [ 1,  2,  3,  4],
                                       [ 2,  3,  4,  5],
                                       [ 3,  4,  5,  6],
                                       [ 4,  5,  6,  7],
                                       [ 5,  6,  7,  8],
                                       [ 9, 10, 11, 12],
                                       [10, 11, 12, 13],
                                       [11, 12, 13, 14],
                                       [12, 13, 14, 15],
                                       [13, 14, 15, 16],
                                       [14, 15, 16, 17],
                                       [18, 19, 20, 21],
                                       [19, 20, 21, 22],
                                       [20, 21, 22, 23],
                                       [21, 22, 23, 24],
                                       [22, 23, 24, 25],
                                       [23, 24, 25, 26],
                                       [27, 28, 29, 30],
                                       [28, 29, 30, 31],
                                       [29, 30, 31, 32],
                                       [30, 31, 32, 33],
                                       [31, 32, 33, 34],
                                       [32, 33, 34, 35]],dtype=int)
        self._reset_next_step = True
        self.reset()

    def check_fourinarow(self,pieces):
        return np.any(np.sum(pieces[self._fourinarows],axis=1)==4)
        
    def check_draw(self):
        return np.sum(self._bp) + np.sum(self._wp)==36
        
    def reset(self):
        """Returns the first `TimeStep` of a new episode."""
        self._reset_next_step = False
        self._bp = np.zeros(36, dtype=np.float32) #Black pieces
        self._wp = np.zeros(36, dtype=np.float32) #White pieces
        return dm_env.restart(self._observation())

    def step(self, action):
        """Updates the environment according to the action."""
        if self._reset_next_step:
            return self.reset()

        agent_move = action

        if self._bp[agent_move] == 1: #illegal move
            terminate = True
            reward = -1 #making an illegal move incurs a reward of -1
        else:
            terminate = False
        
        if not terminate:    
            # Add a black piece to the board
            self._bp[agent_move] = 1

            #check if the agent has ended the game
            if self.check_fourinarow(self._bp): #check if the agent made a winning move
                reward = 1 #winning incurs a reward of 1 
                terminate = True
            if self.check_draw():
                reward = 0 # drawing is neutral
                terminate = True

        if not terminate:
            opponent_move = makemove(self._bp,self._wp,"White")
            self._wp[opponent_move] = 1
            if self.check_fourinarow(self._wp): #check if the opponent made a winning move
                reward = -1 #losing incurs a penalty of -1
                terminate = True
            if self.check_draw():
                reward = 0 # drawing is neutral
                terminate = True
                
        if not terminate:
            return dm_env.transition(reward=0., observation=self._observation())

        else:            
            self._reset_next_step = True
            return dm_env.termination(reward=reward, observation=self._observation())
            
    def observation_spec(self):
        """Returns the observation spec."""
        return specs.BoundedArray(shape=(12,9), dtype=int, name="observation", minimum=0, maximum=1)

    def action_spec(self):
        """Returns the action spec."""
        return specs.DiscreteArray(dtype=int, num_values=36, name="action")

    def _observation(self):
        """
        An observation consists of the placement of the black pieces, the placement of the white pieces and 
        a mask for unoccupied squares
        """
        return np.vstack([self._bp.copy(),
                          self._wp.copy(),
                          np.logical_and(self._bp==0,self._wp==0)])

IndentationError: expected an indented block (<ipython-input-48-f87a5d06fa54>, line 80)

In [49]:
env = Fourinarow()

In [50]:
env.observation_spec()

BoundedArray(shape=(3, 36), dtype=dtype('int32'), name='observation', minimum=0, maximum=1)

In [47]:
env._observation()

array([[0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1.]], dtype=float32)

In [51]:
from ctypes import *
cdll.LoadLibrary('C:/')