In [None]:
import numpy as np
import sys
import tensorflow as tf
import tensorflow.keras as keras
import matplotlib.pyplot as plt
import matplotlib.colors as cols
import matplotlib.cm as cm

#### the same as above but with eager execution enabled

In [None]:
import numpy as np
import sys
import tensorflow as tf
tf.enable_eager_execution()
import tensorflow.keras as keras
import matplotlib.pyplot as plt
import matplotlib.colors as cols
import matplotlib.cm as cm

In [None]:
# Marker bit added on top of the one-hot value so that bin() always yields
# a fixed number of digits after the leading 1.
CODE_MSB_VALUE = 1<<3


def codeBoard(board):
    """Encode every cell of a 2-D board as a list of binary digits.

    For a cell value ``v`` the digits are those of
    ``(1 << v) + CODE_MSB_VALUE`` with the ``0b`` prefix and the leading 1 bit
    stripped; the values 0, 1 and 2 therefore become ``[0, 0, 1]``,
    ``[0, 1, 0]`` and ``[1, 0, 0]``.

    Parameters:
        board: iterable of rows, each row an iterable of integer cell values.

    Returns:
        A NumPy array built from the nested per-cell digit lists.
    """
    def cellBits(value):
        # [3:] removes the '0b' prefix and the leading 1 bit
        digits = bin((1 << value) + CODE_MSB_VALUE)[3:]
        return [int(digit) for digit in digits]

    return np.array([[cellBits(cell) for cell in row] for row in board])

In [None]:
def getRot3(board):
    """Return `board` followed by its three successive ``np.rot90`` rotations.

    The first element of the returned list is the `board` object itself; each
    following element is ``np.rot90`` applied to the previous one.
    """
    current = board
    result = [current]
    for _ in range(3):
        current = np.rot90(current)
        result.append(current)
    return result

def augment(board):
    """Return the encoded versions of the eight symmetric variants of `board`.

    The variants are the board with its three rotations (``getRot3``) followed
    by the transposed board with its three rotations; each one is passed
    through ``codeBoard``.
    """
    variants = getRot3(board)
    variants.extend(getRot3(np.transpose(board)))
    return [codeBoard(variant) for variant in variants]

In [None]:
# Sample 4x4 position used by the cells below.
# presumably 0 marks an empty cell and 1/2 the two players — TODO confirm
board = [
    [0, 0, 1, 2],
    [1, 1, 2, 0],
    [0, 0, 2, 1],
    [1, 2, 0, 0],
]

# Stack the encoded symmetric variants into one float32 array and record the
# sizes of its second, third and fourth axes.
code = np.array(augment(board), dtype=np.float32)
segment_hight, segment_width, code_depth = code.shape[1:4]

In [None]:
# NOTE(review): getActionsIterator is not defined anywhere in the visible
# notebook, so this cell raises NameError on a fresh kernel unless the name is
# provided elsewhere — TODO confirm (getVacationsIterator, defined in a later
# cell, may be the intended function).
l = getActionsIterator(board)
list(l)

#### actions iterator

In [None]:
def getVacationsIterator(board):
    """Iterate over the index tuples of every cell of `board` that equals 0.

    `board` is converted with ``np.array``; the result is a ``zip`` over the
    index arrays returned by ``np.nonzero``, so it can be consumed only once.
    """
    vacant_mask = np.array(board) == 0
    index_arrays = np.nonzero(vacant_mask)
    return zip(*index_arrays)

# One representative of each of the four line orientations; every direction is
# also scanned in the opposite sense by isStraitConnection.
orig_dirs = [(-1,0),(-1,-1),(0,-1),(1,-1)]
# Board side length assumed by isInRange (kept for callers of that helper).
SingleDim = 4
# Number of equal cells in a row (including the newly placed one) that counts
# as a connection.
MIN_LINE_SIZE = 3
# True when every coordinate of `loc` lies in [0, SingleDim).
isInRange = lambda loc: np.all(np.array(loc)>=0) and np.all(np.array(loc)<SingleDim)


def _countRun(board, location, step, player):
    """Count consecutive `player` cells next to `location`, walking by `step`.

    The walk starts at ``location + step`` and stops at the first cell that is
    outside the board or does not hold `player`.  Bounds are checked before
    indexing so that negative coordinates never wrap around.
    """
    limits = np.shape(board)
    count = 0
    loc = (location[0] + step[0], location[1] + step[1])
    while all(0 <= i < n for i, n in zip(loc, limits)) and board[loc] == player:
        count += 1
        loc = (loc[0] + step[0], loc[1] + step[1])
    return count


def isStraitConnection(board, location, player):
    """Tell whether placing `player` at `location` completes a straight line.

    For each orientation in ``orig_dirs`` the cells holding `player` are
    counted in both senses starting next to `location`; the cell at `location`
    itself is not inspected.  A line is complete when at least
    ``MIN_LINE_SIZE - 1`` such neighbours are found.

    Parameters:
        board: 2-D array indexable by a ``(row, col)`` tuple (e.g. a NumPy
            array).  Its own shape defines the valid coordinate range, so
            boards of any size are supported.
        location: ``(row, col)`` of the cell being played.
        player: cell value to look for.

    Returns:
        True if some orientation holds enough neighbouring `player` cells,
        otherwise False.
    """
    for shift in orig_dirs:
        opposite = (-shift[0], -shift[1])
        neighbours = (_countRun(board, location, shift, player)
                      + _countRun(board, location, opposite, player))
        if neighbours >= MIN_LINE_SIZE - 1:
            return True

    return False

def getOneMoveTransition(board, move):
    """Apply `move` to a copy of `board` and report whether it ends the game.

    Parameters:
        board: object supporting ``.copy()`` and tuple indexing (e.g. a NumPy
            array); it is not modified.
        move: pair ``(player, location)``.

    Returns:
        ``(is_terminal, next_board)`` where `is_terminal` is the result of
        ``isStraitConnection`` on the position before the move and
        `next_board` is the copy with `player` written at `location`.
    """
    player, location = move
    successor = board.copy()
    successor[location] = player
    finished = isStraitConnection(board, location, player)
    return (finished, successor)
    

#### test the environment

In [None]:
# For every cell of the sample board that equals 0, pair the terminal flag of
# player 2 moving there with the location itself.
[(getOneMoveTransition(np.array(board),(2,loc))[0], loc) for loc in getVacationsIterator(board)]

#### test Conv2D

In [None]:
# Hand-written 4x4 kernel with 3 values per position.
k2D = np.array([[[2, 1, 0],[6, 0, 1],[2, 1, 0],[0,-1, 1]],
               [[2, 1, 3],[9,-7, 1],[2, 1,-9],[5, 0, 0]],
               [[0, 1, 3],[0, 0, 1],[2, 1,-3],[5,-1, 1]],
               [[2, 1,-3],[0,-2, 1],[2, 1, 3],[4, 0,11]]
              ], dtype=np.float32)
# Append a trailing axis of size 1 (the single output channel of the filter).
kernel2D = tf.reshape(k2D, k2D.shape+(1,), name='kernel2D')

#reshape for 2D convolution: merge the first two axes of `code` and add a
#leading batch axis of size 1
code2D = code.reshape((1,code.shape[0]*code.shape[1],)+code.shape[2:])
# The filter defined above is kernel2D; the previous `filters=kernel` referred
# to a name that is never defined in this notebook.
print(tf.squeeze(tf.nn.conv2d(code2D, filters=kernel2D, strides=4, padding='VALID')).numpy())

#### test Conv1D

In [None]:
# Hand-written 1-D kernel: 4 positions with 3 values each.
k1D = np.array(
    [[2, 1, 3],
     [9, 0, 1],
     [0, -1, 4],
     [5, 0, -3]],
    dtype=np.float32,
)
# Append a trailing axis of size 1 (the single output channel of the filter).
kernel1D = tf.reshape(k1D, k1D.shape + (1,), name='kernel1D')
# Reshape for 1-D convolution: flatten the first three axes of `code` into one
# and add a leading batch axis of size 1.
code1D = code.reshape((1, -1, code.shape[3]))
print(
    tf.squeeze(
        tf.nn.conv1d(code1D, filters=kernel1D, stride=4, padding='VALID')
    ).numpy()
)

#### build the model

In [None]:
class vModel(tf.keras.Model):
    """Two stacked Dense layers (44 units, then 1 unit) reduced by a maximum.

    ``call`` feeds the input through ``dense1`` and ``dense2`` and returns the
    maximum of the result along axis 0.
    """

    def __init__(self):
        super().__init__()
        self.dense1 = tf.keras.layers.Dense(44, name='dense1')
        self.dense2 = tf.keras.layers.Dense(1, name='dense2')

    def call(self, input):
        hidden = self.dense1(input)
        return tf.math.reduce_max(self.dense2(hidden), axis=0)


# Module-level model instance used by the gradient cell below.
model = vModel()

#### get the gradient

In [None]:
def grad(position, estimate):
    """Gradient of the squared error of the global `model` at `position`.

    The loss is ``tf.square(model(position) - estimate)``.  (The previous code
    called ``tf.square(y, estimate)``, which passes `estimate` as the op name
    instead of subtracting it, and never returned the gradient.)

    Parameters:
        position: input passed to ``model``.
        estimate: target value the model output is compared against.

    Returns:
        The gradients of the loss with respect to
        ``model.trainable_variables``, as returned by ``tape.gradient``.
    """
    with tf.GradientTape() as tape:
        y = model(position)
        loss = tf.square(y - estimate)
    # Computed after leaving the tape context so the gradient ops themselves
    # are not recorded.
    return tape.gradient(loss, model.trainable_variables)

### Proof number search
#### the class ProofNumberNode

In [None]:
class ProofNumberNode:
    """Node of an AND/OR tree used by the proof-number search in this notebook.

    Attributes (all set in ``__init__``):
        is_and: True for an AND node, False for an OR node.
        expanded: whether ``expand`` has already generated the children.
        proof_num / disproof_num: current proof and disproof numbers.
        children: list of child nodes, filled by ``expand``.

    Subclasses must override ``getExpandIterator``.  ``update`` reads the
    module-level constant ``INF``, which has to be defined before it is called.
    """

    def __init__(self, is_and, pn=1, dn=1):
        self.is_and = is_and
        self.expanded = False
        self.proof_num = pn
        self.disproof_num = dn
        self.children = []

    def getNumber(self):
        """Return the proof number for an AND node, else the disproof number."""
        return self.proof_num if self.is_and else self.disproof_num

    def isAnd(self):
        """Return True when this is an AND node."""
        return self.is_and

    def isExpanded(self):
        """Return True once ``expand`` has run on this node."""
        return self.expanded

    def getExpandIterator(self):
        """Return an iterable of child nodes; must be provided by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(
            "subclasses of ProofNumberNode must implement getExpandIterator()")

    def update(self):
        """Recompute both numbers from the children; no-op without children.

        OR node:  disproof = sum of child disproof numbers,
                  proof    = minimum of child proof numbers.
        AND node: proof    = sum of child proof numbers,
                  disproof = minimum of child disproof numbers.
        """
        if len(self.children) == 0: return
        if not self.is_and:
            self.disproof_num = 0
            self.proof_num = INF
            for child in self.children:
                self.disproof_num += child.disproof_num
                self.proof_num = min(child.proof_num, self.proof_num)
        else:
            self.proof_num = 0
            self.disproof_num = INF
            for child in self.children:
                self.proof_num += child.proof_num
                self.disproof_num = min(child.disproof_num, self.disproof_num)

    def expand(self):
        """Generate the children once; later calls do nothing.

        Generation stops early right after a child that already decides this
        node: a child with proof number 0 under an OR node, or a child with
        disproof number 0 under an AND node.  That child is still appended.
        """
        if self.expanded: return
        # Obtain the iterator first so that a missing getExpandIterator()
        # implementation does not leave the node marked as expanded.
        it = self.getExpandIterator()
        self.expanded = True
        for child in it:
            shortcut = child.proof_num == 0 and not self.is_and or \
               child.disproof_num == 0 and self.is_and
            self.children.append(child)
            if shortcut: break

    def __repr__(self):
        return "<Type: "+["OR","AND"][self.is_and]+"; pn:"+str(self.proof_num)+ \
            "; dn:"+str(self.disproof_num)+"; "+["","expd "][self.expanded]+ \
            str(len(self.children))+" ch>"

#### Search implementation

In [None]:
def descendToMPN(node):
    """Descend from `node` towards a most-proving node and expand it.

    An unexpanded node is expanded and updated on the spot.  Otherwise the
    children with the smallest disproof number (AND node) or proof number
    (OR node) are tried in order until one of the recursive descents expands
    something; children whose number is above ``INF`` are never candidates.
    `node` is updated before returning.

    Returns:
        True if some node was expanded during the descent, otherwise False.
    """
    if not node.isExpanded():
        node.expand()
        node.update()
        return True

    attr = "disproof_num" if node.isAnd() else "proof_num"
    # INF acts as the ceiling: only children at or below it can be selected.
    best = min([INF] + [getattr(child, attr) for child in node.children])
    candidates = [child for child in node.children if getattr(child, attr) == best]

    # any() stops at the first candidate whose descent expanded a node
    expanded = any(descendToMPN(child) for child in candidates)

    node.update()
    return expanded

def iteratePNSearch(root, max_nodes=100):
    """Run proof-number search iterations from `root`.

    Iterations continue while ``root.count`` is below `max_nodes`.  The loop
    also ends when a descent expands nothing, or when the root's proof or
    disproof number reaches 0; a message is printed in each of these cases.
    After every other iteration a progress line is printed.
    """
    iteration = 0
    while root.count < max_nodes:
        progressed = descendToMPN(root)
        if not progressed:
            print("No expansion. Search terminated.")
            return
        if root.proof_num == 0:
            print("prooved")
            return
        if root.disproof_num == 0:
            print("disprooved")
            return
        iteration += 1
        print("Iteration {:3}: nodes count is {:4}".format(iteration, root.count))

### Test PNS

In [None]:
# Evaluation codes for a position. WIN_NODE and LOOSE_NODE are compared against
# in expandNode / expandNode2; DRAW_NODE and UNKNOWN_NODE are not referenced in
# the visible part of the notebook.
WIN_NODE = 1
LOOSE_NODE = -1
DRAW_NODE = 2
UNKNOWN_NODE = 0
# Stand-in for "infinity" in proof and disproof numbers.
INF = sys.maxsize

In [None]:
class TestPNS(ProofNumberNode):
    """Randomly generated node for exercising the proof-number search.

    On construction the node draws its future number of children from
    ``np.random.randint(min_children, 5)`` and its initial numbers as follows:
    with a first draw below 0.15 it is proven (pn=0, dn=INF); otherwise with a
    second draw below 0.25 it is disproven (pn=INF, dn=0); otherwise a node
    without children gets ``INF//2`` for both numbers and any other node gets
    1 for both.
    """

    # Number of TestPNS instances created so far.  Defined here so that the
    # class works even when no other cell has initialised it; assign
    # ``TestPNS.count = 0`` to restart counting before a new run.
    count = 0

    def __init__(self, is_and, min_children=0):
        self.ch_num = np.random.randint(min_children,5)
        if np.random.rand() < 0.15:
            pn = 0
            dn = INF
        elif np.random.rand() < 0.25:
            pn = INF
            dn = 0
        elif self.ch_num == 0:
            pn = INF//2
            dn = INF//2
        else:
            pn = 1
            dn = 1
        ProofNumberNode.__init__(self, is_and, pn, dn)
        TestPNS.count += 1

    def getExpandIterator(self):
        """Yield ``ch_num`` new random nodes of the opposite AND/OR type."""
        for i in range(self.ch_num):
            yield TestPNS(not self.is_and)

class rootPNS(TestPNS):
    """Random test node for the tree root: at least two children are drawn and
    both numbers are reset to 1 regardless of the random outcome."""

    def __init__(self, is_and):
        super().__init__(is_and, min_children=2)
        self.proof_num = self.disproof_num = 1


#### Run test search with maximum 100 expanded nodes

In [None]:
# Restart the node counter, build a random OR root (is_and=False) and search
# with the default limit of 100 nodes.
# NOTE(review): no random seed is set in the visible notebook, so every run
# builds a different tree.
TestPNS.count = 0
root = rootPNS(False)
iteratePNSearch(root)
print(root.count)

In [None]:
# Show the node counter, the root summary and the root's direct children.
root.count, root, root.children

In [None]:
# NOTE(review): hard-coded path into a randomly generated tree; raises
# IndexError whenever one of the indexed children does not exist.
root.children[1].children[1].children[0].children[0].children

In [None]:
# Manually run one descent below the root's first child, then refresh the
# root's numbers from its children.
print(descendToMPN(root.children[0]))
#root.children[2].update()
root.update()

#### Continue the test search if it previously stopped because the maximum number of nodes was exceeded

In [None]:
# Resume the search on the existing tree with a node limit of 1000.
iteratePNSearch(root, max_nodes=1000)

#### PNS stuff

In [None]:
def rootExpandIterator(inst):
    """Yield between 2 and 4 random TestPNS nodes of the opposite type to `inst`.

    The number of nodes is drawn with ``np.random.randint(2, 5)`` when the
    generator is first advanced.
    """
    remaining = np.random.randint(2, 5)
    while remaining > 0:
        yield TestPNS(not inst.is_and)
        remaining -= 1

In [None]:
'''                    
    def expand(self):
        if self.expanded: return
        self.expanded = True
        it = self.getChildrenIterator()
        shortcut = False
        for child in it:
            e = child.evaluate()
            if e == WIN_NODE:
                new_node = newNode(child_is_and, 0, INF)
                shortcut = not node.isAnd():
            elif e == LOOSE_NODE:
                new_node = newNode(child_is_and, INF, 0)
                shortcut = node.isAnd()
            else new_node = newNode(child_is_and, 1, 1)
            node.children.append(new_node)
            if shortcut: break
'''
    
def expandNode(node, getChildrenIterator, evaluateNode, newNode):
    """Create the children of `node` from externally supplied callbacks.

    Parameters:
        node: node offering ``isAnd()`` and a ``children`` list.
        getChildrenIterator: callable returning an iterable of child
            positions for `node`.
        evaluateNode: callable mapping a child position to an evaluation
            code; ``WIN_NODE`` and ``LOOSE_NODE`` are treated specially.
        newNode: callable ``newNode(is_and, proof_num, disproof_num)`` that
            builds the node object to append.

    A winning child gets numbers (0, INF), a losing child (INF, 0) and any
    other child (1, 1).  Generation stops right after appending a winning
    child of an OR node or a losing child of an AND node.
    """
    # Children alternate type with their parent.
    child_is_and = not node.isAnd()

    for child in getChildrenIterator(node):
        e = evaluateNode(child)
        shortcut = False
        if e == WIN_NODE:
            new_node = newNode(child_is_and, 0, INF)
            shortcut = not node.isAnd()
        elif e == LOOSE_NODE:
            new_node = newNode(child_is_and, INF, 0)
            shortcut = node.isAnd()
        else:
            new_node = newNode(child_is_and, 1, 1)
        node.children.append(new_node)
        if shortcut:
            break
            
def updateNode(node):
    """Recompute the two numbers of `node` from the nodes it iterates over.

    The sum of ``child.getValue()`` is stored with ``node.setNumber`` and the
    minimum of ``child.getNumber()`` (capped at ``INF``) with
    ``node.setValue``.  Both accumulators are now initialised before the loop;
    previously they were read before assignment.
    """
    value = 0
    num = INF
    for child in node:
        value += child.getValue()
        if child.getNumber() < num:
            num = child.getNumber()
    node.setNumber(value)
    node.setValue(num)
        
def selectMostProvingNode(node):
    """Walk down from `node` to an unexpanded node, following minimal numbers.

    At every expanded node the child with the smallest ``getNumber()`` below
    ``INF`` is followed (the first one on ties).  The walk also stops at an
    expanded node that offers no such child, instead of failing on an unbound
    variable or looping forever on a stale choice.

    Returns:
        The node at which the walk stopped.
    """
    while node.isExpanded():
        best = None
        value = INF
        for child in node:
            if child.getNumber() < value:
                best = child
                value = child.getNumber()
        if best is None:
            # no child with a number below INF: nothing left to descend into
            break
        node = best

    return node

def updateProofNumbers(node):
    """Recompute the two numbers of `node` from the nodes it iterates over.

    The sum of ``child.getValue()`` is stored with ``node.setNumber`` and the
    minimum of ``child.getNumber()`` (capped at ``INF``) with
    ``node.setValue``.
    """
    total = 0
    smallest = INF
    for child in node:
        total += child.getValue()
        candidate = child.getNumber()
        if candidate < smallest:
            smallest = candidate
    node.setNumber(total)
    node.setValue(smallest)
    
def updateNumbers(node):
    """Recompute the "all" and "min" numbers of `node` from its children.

    The sum of ``child.getMinNumber()`` is stored with ``node.setAllNumber``
    and the minimum of ``child.getAllNumber()`` (capped at ``INF``) with
    ``node.setMinNumber``.
    """
    summed = 0
    lowest = INF
    for child in node:
        summed += child.getMinNumber()
        lowest = min(lowest, child.getAllNumber())
    node.setAllNumber(summed)
    node.setMinNumber(lowest)

def expandNode2(node, getChildrenIterator, evaluateNode, newNode):
    """Create the children of `node` and mark it as expanded.

    Parameters:
        node: node offering ``isAnd()``, a ``children`` list and an
            ``expanded`` attribute.
        getChildrenIterator: callable returning an iterable of child
            positions for `node`.
        evaluateNode: callable mapping a child position to an evaluation
            code; ``WIN_NODE`` and ``LOOSE_NODE`` are treated specially.
        newNode: callable ``newNode(is_and, proof_num, disproof_num)`` that
            builds the node object to append.

    A winning child gets numbers (0, INF), a losing child (INF, 0) and any
    other child (1, 1).  Generation stops right after appending a child that
    decides `node`: a winning child of an OR node or a losing child of an AND
    node.

    NOTE(review): this replaces a half-translated draft that ended in C-like
    pseudo-code; the unused all_number/min_number accumulators of the draft
    were dropped — confirm that no incremental number update was intended here.
    """
    # Children alternate type with their parent.
    child_is_and = not node.isAnd()

    for child in getChildrenIterator(node):
        e = evaluateNode(child)
        if e == WIN_NODE:
            new_node = newNode(child_is_and, 0, INF)
            decided = not node.isAnd()
        elif e == LOOSE_NODE:
            new_node = newNode(child_is_and, INF, 0)
            decided = node.isAnd()
        else:
            new_node = newNode(child_is_and, 1, 1)
            decided = False

        # Attach the child before stopping so the deciding child is kept.
        node.children.append(new_node)
        if decided:
            # shortcut the search
            break

    node.expanded = True

def doPNS(root):
    """Unfinished stub: selects the most-proving node below `root` and stops.

    The selected node is only stored in a local variable; nothing is expanded
    or updated and the function returns None.
    """
    mpn = selectMostProvingNode(root)
    

### TicTacToe PNS implementation

In [None]:
class TicTacToe(ProofNumberNode):
    def __init__(self, position):
        
        self.
        ProofNumberNode.__init__(self, )
        
def 