<a href="https://colab.research.google.com/github/lmcanavals/acomplex/blob/main/0905_ds_hex.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title
import graphviz as gv

In [None]:
#@title
def drawDS(ds):
  """Render a parent-array disjoint set as a Graphviz digraph.

  Every position of ``ds`` becomes a node labelled with its index. A
  non-negative entry is that element's parent and is drawn as an edge
  child -> parent; a negative entry marks a root and gets no outgoing edge.
  ``rankdir=BT`` lays the graph out bottom-to-top, so roots appear on top.
  """
  dot = gv.Digraph("DisjointSet")
  dot.graph_attr['rankdir'] = "BT"
  for child, parent in enumerate(ds):
    label = str(child)
    dot.node(label)
    if parent < 0:
      continue  # root entry: nothing to point at
    dot.edge(label, str(parent))
  return dot

In [14]:
import numpy as np

EMPTY = 0  # unoccupied cell
BLACK = 1
WHITE = 2


class Hex:
  """Hex game board that detects wins with one disjoint set per player.

  Cell ``(i, j)`` maps to index ``i * n + j`` (row-major). Each player's
  disjoint set has ``n**2 + 2`` entries: one per cell plus two virtual
  nodes, ``n**2`` and ``n**2 + 1``, pre-joined to the two board edges that
  player must connect:

  * BLACK: top row (``i == 0``) and bottom row (``i == n - 1``).
  * WHITE: left column (``j == 0``) and right column (``j == n - 1``).

  A player has won once their two virtual nodes share a root. In the set
  arrays a root stores the negated size of its set and every other entry
  stores its parent's index.
  """

  def __init__(self, n):
    """Create an empty ``n`` x ``n`` board."""
    self.n = n
    self.board = np.zeros((n, n), dtype=int)
    first_edge, second_edge = n**2, n**2 + 1  # virtual node indices
    self.blackds = [-1] * (n**2 + 2)
    self.whiteds = [-1] * (n**2 + 2)
    for k in range(n):
      # BLACK: top row -> first_edge, bottom row -> second_edge.
      self.__union(self.blackds, k, first_edge)
      self.__union(self.blackds, (n - 1) * n + k, second_edge)
      # WHITE: left column -> first_edge, right column -> second_edge.
      self.__union(self.whiteds, k * n, first_edge)
      self.__union(self.whiteds, k * n + n - 1, second_edge)

  def __find(self, ds, e):
    """Return the root of ``e`` in ``ds``, compressing the path to it."""
    if ds[e] < 0:
      return e
    ds[e] = self.__find(ds, ds[e])
    return ds[e]

  def __union(self, ds, a, b):
    """Merge the sets containing ``a`` and ``b`` in ``ds`` (union by size).

    The root of the larger set (the more negative entry) becomes the parent;
    on a tie ``a``'s root wins.
    """
    a = self.__find(ds, a)
    b = self.__find(ds, b)
    if a == b:
      return
    if ds[a] > ds[b]:
      a, b = b, a  # make ``a`` the root of the larger set
    ds[a] += ds[b]
    ds[b] = a

  def move(self, i, j, piece):
    """Place ``piece`` at row ``i``, column ``j`` and report whether it wins.

    Args:
      i: zero-based row, must lie in ``range(self.n)``.
      j: zero-based column, must lie in ``range(self.n)``.
      piece: BLACK or WHITE.

    Returns:
      -1 if the cell is already occupied (the board is left unchanged);
      ``piece`` if this move connects that player's two edges; otherwise
      EMPTY.

    Raises:
      IndexError: if ``(i, j)`` is outside the board.
      ValueError: if ``piece`` is neither BLACK nor WHITE.
    """
    n = self.n
    if not (0 <= i < n and 0 <= j < n):
      # Reject explicitly: numpy silently wraps negative indices, and
      # i * n + j would then address the wrong entry of the disjoint set
      # (possibly one of the virtual edge nodes).
      raise IndexError(f"cell ({i}, {j}) is outside a {n}x{n} board")
    if piece not in (BLACK, WHITE):
      # Any other value would be written to the board and unioned into
      # BLACK's set, corrupting it.
      raise ValueError(f"piece must be BLACK or WHITE, got {piece!r}")
    if self.board[i, j] != EMPTY:
      return -1
    self.board[i, j] = piece
    ds = self.whiteds if piece == WHITE else self.blackds
    cell = i * n + j
    # The six neighbours of (i, j) on a rhombus-shaped hex grid.
    neighbours = ((i - 1, j), (i - 1, j + 1), (i, j - 1),
                  (i, j + 1), (i + 1, j - 1), (i + 1, j))
    for ii, jj in neighbours:
      if 0 <= ii < n and 0 <= jj < n and self.board[ii, jj] == piece:
        self.__union(ds, cell, ii * n + jj)
    edges = n * n  # index of the first virtual edge node
    connected = self.__find(ds, edges) == self.__find(ds, edges + 1)
    return piece if connected else EMPTY

In [35]:
# Start a fresh 5x5 game; every cell is EMPTY (0).
# NOTE(review): the name `hex` shadows the builtin hex() in this notebook.
hex = Hex(n=5)
print(hex.board)

[[0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]]


In [36]:
hex.move(i=0, j=2, piece=WHITE)  # no connection yet, so the cell shows EMPTY (0)

0

In [37]:
# Chain WHITE from the left column, (1, 0), through (0, 1)..(0, 4) to the
# right column; only the last move's result is displayed: WHITE (2).
hex.move(i=1, j=0, piece=WHITE)
hex.move(i=0, j=1, piece=WHITE)
hex.move(i=0, j=3, piece=WHITE)
hex.move(i=0, j=4, piece=WHITE)

2

In [38]:
print(hex.board)  # show the board after WHITE's winning chain

[[0 2 2 2 2]
 [2 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]]
