
#### Copyright 2021 Google LLC. Licensed under the Apache License, Version 2.0 (the "License");

In [None]:
#@title License
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# The Game of Life and Graph Neural Networks

Conway's Game of Life is played on an infinite grid of cells. At every step, the grid evolves according to the following rules (from [Wikipedia](https://en.wikipedia.org/wiki/Conway%27s_Game_of_Life)):
* Any live cell with two or three live neighbours survives.
* Any dead cell with three live neighbours becomes a live cell.
* All other live cells die in the next generation. Similarly, all other dead cells stay dead.

For the purposes of this tutorial, we will look at the Game of Life on finite-sized grids.

We use a GCN model to predict future grid states, given the current grid state!


# Installation and Imports

In [None]:
# Install Sonnet, pinned to version 2.0.0.
!pip install dm-sonnet==2.0.0

# Uninstall every installed package whose name contains "tensorboard", then
# install TensorBoard again (unpinned) and load its notebook extension.
!pip list --format=freeze | grep tensorboard | xargs pip uninstall -y
!pip install -q tensorboard
%load_ext tensorboard

In [None]:
from enum import IntEnum
from itertools import product
from typing import Iterable, Tuple, List, Callable

from IPython.display import HTML
import matplotlib.pyplot as plt
import numpy as np
import networkx as nx
import tensorflow as tf
import sonnet as snt
import seaborn as sns
import scipy
import random
import tqdm
import os
import datetime as dt
from google.colab import drive

# Experimental Parameters

In [None]:
# Seed applied to NumPy and TensorFlow before the training data is generated.
data_generation_seed = 7
# Seed applied to NumPy and TensorFlow before the model is built.
model_seed = 2
# Number of Game of Life steps between an input grid and its target grid.
num_update_steps = 1
# Width multiplier; passed to build_model as `m`.
overcompleteness = m = 2
# Number of grids per training batch.
batch_size = 50
# Number of times the dataset is repeated.
num_epochs = 1000
# Early-stopping window, measured in training steps.
early_stopping_lim = 300
# Grid dimensions handed to GameOfLifeGrid.
grid_dims = (8, 8)
# Total number of cells in one grid.
grid_size = np.prod(grid_dims)
# Random grids generated for every alive-cell count in 1..grid_size // 2.
num_grids_per_num_alive_cells = 5
# Total number of generated grids.
num_grids = num_grids_per_num_alive_cells * (grid_size//2)
# Name of the architecture to train; must be accepted by build_model.
model_type = 'gcn-no-norm'

# Directories for Output

In [None]:
# Save to Google Drive.
# Location at which Google Drive is mounted.
GDRIVE_DIR = '/content/gdrive'
MYDRIVE_DIR = os.path.join(GDRIVE_DIR, 'MyDrive')
# Root directory for everything this notebook writes.
MAIN_DIR = os.path.join(MYDRIVE_DIR, 'game-of-life')

# TensorBoard logs, checkpoints and exported models each get their own
# sub-directory of MAIN_DIR.
LOG_DIR = os.path.join(MAIN_DIR, 'logs')
CHECKPOINT_DIR = os.path.join(MAIN_DIR, 'checkpoints')
SAVED_DIR = os.path.join(MAIN_DIR, 'saved')

# Mount Google Drive at GDRIVE_DIR.
drive.mount(GDRIVE_DIR)

In [None]:
# Deletes everything under LOG_DIR, including logs from previous runs.
# Comment this line out to keep the existing logs.
!rm -rf {LOG_DIR}

# GNN Layer Definitions

* For the GCN, normalization refers to a scaling factor applied to the neighbours' embeddings:
    * 'degree-i': The GCN variant with degree-normalization:
$$
    \text{out}_v = W \cdot \sum\limits_{u \in \mathcal{N}(v)}\frac{h_u}{|\mathcal{N}(v)|} + B \cdot h_v + C
$$
    * 'degree-ij': The original GCN variant from Kipf et al.:
$$
    \text{out}_v = W \cdot \sum\limits_{u \in \mathcal{N}(v)}\frac{h_u}{\sqrt{|\mathcal{N}(u)||\mathcal{N}(v)|}} + B \cdot h_v + C
$$
    * 'none': The GCN variant with no normalization: 
$$
    \text{out}_v = W \cdot \sum\limits_{u \in \mathcal{N}(v)} h_u + B \cdot h_v + C
$$


In [None]:
class GCNLayer(snt.Module):
  """A graph-convolution layer over a fixed graph.

  For node features `H` the layer computes

      A_norm @ H @ w + H @ b + c

  where `A_norm` is the adjacency matrix scaled according to `normalization`:

    * 'none':      A_norm = A
    * 'degree-i':  A_norm[i, j] = A[i, j] / deg(i)
    * 'degree-ij': A_norm[i, j] = A[i, j] / sqrt(deg(i) * deg(j))

  with deg(i) the i-th row sum of A. Nodes with degree zero get a scaling
  factor of zero instead of producing inf/NaN entries.
  """

  def __init__(self, input_dims, output_dims, normalization='none', name=None,
               adjacency=None):
    """Creates the layer's variables and the normalised adjacency matrix.

    Args:
      input_dims: Size of the last axis of the input features.
      output_dims: Size of the last axis of the output features.
      normalization: One of 'none', 'degree-i' or 'degree-ij'.
      name: Optional module name.
      adjacency: Optional (num_nodes, num_nodes) adjacency matrix. When
        omitted, the notebook-level `adjacency_matrix` is used, which must
        already be defined when the layer is constructed.

    Raises:
      ValueError: If `normalization` is not one of the supported options.
    """
    super(GCNLayer, self).__init__(name=name)

    valid_normalizations = ('none', 'degree-i', 'degree-ij')
    if normalization not in valid_normalizations:
      raise ValueError('Unknown normalization %r; expected one of: %s.'
                       % (normalization, ', '.join(valid_normalizations)))

    # Variables are created in the order w, b, c so that seeded runs draw the
    # same initial values.
    self.w = tf.Variable(tf.random.normal((input_dims, output_dims), dtype=tf.float32), name='w')
    self.b = tf.Variable(tf.random.normal((input_dims, output_dims), dtype=tf.float32), name='b')
    self.c = tf.Variable(tf.random.normal((output_dims,), dtype=tf.float32), name='c')
    self.normalization = normalization

    if adjacency is None:
      # Fall back to the adjacency matrix defined at notebook level.
      adjacency = adjacency_matrix
    adjacency = np.asarray(adjacency, dtype=np.float64)

    if normalization != 'none':
      degrees = adjacency.sum(axis=1)
      # 1 / degree, leaving zero for isolated nodes rather than dividing by 0.
      inv_degrees = np.zeros_like(degrees)
      np.divide(1.0, degrees, out=inv_degrees, where=degrees > 0)

      if normalization == 'degree-i':
        adjacency = adjacency * inv_degrees[:, np.newaxis]
      else:
        inv_sqrt_degrees = np.sqrt(inv_degrees)
        adjacency = inv_sqrt_degrees[:, np.newaxis] * adjacency * inv_sqrt_degrees[np.newaxis, :]

    self.adjacency = tf.cast(adjacency, tf.float32)

  def __call__(self, features):
    """Applies the layer.

    Args:
      features: Node features whose last two axes are
        (num_nodes, input_dims); cast to float32 before use.

    Returns:
      A float32 tensor whose last axis has size `output_dims`.
    """
    features = tf.cast(features, tf.float32)
    neighbour_term = tf.matmul(tf.matmul(self.adjacency, features), self.w)
    self_term = tf.matmul(features, self.b)
    return neighbour_term + self_term + self.c

# Game of Life Logic

In [None]:
class CellState(IntEnum):
  """The two states a grid cell can take, stored as integers in the grid."""
  ALIVE = 1
  DEAD = 0

In [None]:
class GameOfLifeGrid:
  """A finite Game of Life grid stored as a 2-D integer array.

  The grid does not wrap around: neighbour coordinates are clamped to the
  grid bounds, so border cells simply have fewer neighbours. Positions are
  (row, column) index pairs into `self.grid` and may be given as tuples,
  lists or NumPy arrays.
  """

  def __init__(self,
               grid_dims: Tuple[int, int]):
    """Creates a grid of the given dimensions with every cell dead.

    Raises:
      ValueError: If `grid_dims` does not have exactly two entries.
    """
    if len(grid_dims) != 2:
      raise ValueError('Grid must be a 2-tuple of (height, width).')

    self.grid = np.zeros(grid_dims, dtype=int)

  @staticmethod
  def _set_grid_state(grid: np.ndarray,
                      state: 'CellState',
                      positions: Iterable[Tuple[int, int]]):
    """Writes `state` into `grid` at every position in `positions`.

    `positions` may be a single position, any iterable of positions
    (including a generator), or empty, in which case nothing is written.
    """
    if not isinstance(positions, np.ndarray):
      # np.array() cannot consume generators, so materialise them first.
      positions = list(positions)
    positions = np.array(positions)

    # Nothing to update; an empty array has no coordinates to index with.
    if positions.size == 0:
      return

    # A flat array is a single position: promote it to a one-row batch.
    if positions.ndim == 1:
      positions = positions[np.newaxis, :]

    grid[positions[:, 0], positions[:, 1]] = state


  @staticmethod
  def _set_as_alive(grid: np.ndarray,
                    positions: Iterable[Tuple[int, int]]):
    """Marks `positions` as alive in `grid`."""
    GameOfLifeGrid._set_grid_state(grid, CellState.ALIVE, positions)


  @staticmethod
  def _set_as_dead(grid: np.ndarray,
                   positions: Iterable[Tuple[int, int]]):
    """Marks `positions` as dead in `grid`."""
    GameOfLifeGrid._set_grid_state(grid, CellState.DEAD, positions)


  def set_as_alive(self,
                   positions: Iterable[Tuple[int, int]]):
    """Marks `positions` as alive in this grid."""
    GameOfLifeGrid._set_as_alive(self.grid, positions)


  def set_as_dead(self,
                  positions: Iterable[Tuple[int, int]]):
    """Marks `positions` as dead in this grid."""
    GameOfLifeGrid._set_as_dead(self.grid, positions)


  def initialize_randomly(self,
                          num_alive_cells: int):
    """Resets the grid and makes `num_alive_cells` randomly chosen cells alive.

    Uses NumPy's global random state, so results are reproducible after
    `np.random.seed`.
    """
    self.grid = np.zeros(self.grid.shape, dtype=int)
    all_positions = list(iter(self))
    np.random.shuffle(all_positions)
    alive_positions = all_positions[:num_alive_cells]
    self.set_as_alive(alive_positions)


  def __iter__(self):
    """Iterates over all cell positions in row-major order."""
    rows, cols = np.indices(self.grid.shape)
    return iter(zip(rows.ravel(), cols.ravel()))


  def get_state(self,
                position: Tuple[int, int]) -> int:
    """Returns the stored state of the cell at `position`."""
    return self.grid[position[0], position[1]]


  def get_neighbours(self,
                     position: Tuple[int, int]) -> List[Tuple[int, int]]:
    """Returns the positions of the (up to eight) cells adjacent to `position`.

    Coordinates are clamped to the grid, de-duplicated, and the cell itself
    is excluded.
    """
    xmax, ymax = self.grid.shape
    # Plain ints make the self-exclusion below independent of whether the
    # caller passed a tuple, a list or an array.
    x, y = int(position[0]), int(position[1])
    neighbours_x_pos = [max(0, x - 1), x, min(x + 1, xmax - 1)]
    neighbours_y_pos = [max(0, y - 1), y, min(y + 1, ymax - 1)]
    neighbours = set(product(neighbours_x_pos, neighbours_y_pos))
    neighbours.discard((x, y))
    return list(neighbours)


  def num_live_neighbours(self,
                          position: Tuple[int, int]) -> int:
    """Counts the alive cells among the neighbours of `position`."""
    return sum(1 for neighbour in self.get_neighbours(position)
               if self.get_state(neighbour) == CellState.ALIVE.value)


  def update(self, num_steps, inplace=False) -> List['GameOfLifeGrid']:
    """Advances the grid by `num_steps` steps.

    Returns:
      A list of `num_steps + 1` grids, starting with this one. With
      `inplace=True` every entry is this same (mutated) object.
    """
    curr_grid = self
    grids = [curr_grid]
    for _ in range(num_steps):
      curr_grid = curr_grid.update_one_step(inplace=inplace)
      grids.append(curr_grid)
    return grids


  def update_one_step(self, inplace=False) -> 'GameOfLifeGrid':
    """Applies the Game of Life rules once.

    Neighbour counts are always taken from the current grid; changes are
    written to a copy so they do not influence each other.

    Returns:
      This grid (updated) if `inplace` is true, otherwise a new grid of the
      same type holding the next state.
    """
    # Make a copy of the current grid.
    new_grid = self.grid.copy()

    # Get alive and dead cell positions.
    live_positions = np.argwhere(self.grid == CellState.ALIVE)
    dead_positions = np.argwhere(self.grid == CellState.DEAD)

    # Live cells die unless they have exactly two or three live neighbours.
    dying = [pos for pos in live_positions
             if self.num_live_neighbours(pos) not in (2, 3)]
    # Dead cells with exactly three live neighbours become alive.
    born = [pos for pos in dead_positions
            if self.num_live_neighbours(pos) == 3]

    GameOfLifeGrid._set_as_dead(new_grid, dying)
    GameOfLifeGrid._set_as_alive(new_grid, born)

    # Update inplace, or make a new grid?
    if inplace:
      self.grid = new_grid
      return self

    new_grid_instance = type(self)(grid_dims=self.grid.shape)
    new_grid_instance.grid = new_grid
    return new_grid_instance

  def get_features(self):
    """Returns an adjacency matrix and node features for the current grid.

    Nodes are numbered in the iteration order of the grid (row-major).

    Returns:
      A pair `(adjacency_matrix, node_features)`: a
      (num_nodes, num_nodes) array with a 1 wherever two cells are
      neighbours, and a (num_nodes, 1) float array of cell states.
    """
    # Map cell positions to indices.
    index_map = {pos: index for index, pos in enumerate(self)}

    # Compute adjacency matrix.
    num_nodes = self.grid.size
    adjacency_matrix = np.zeros((num_nodes, num_nodes))
    for pos, index in index_map.items():
      neighbour_indices = [index_map[neigh_pos]
                           for neigh_pos in self.get_neighbours(pos)]
      adjacency_matrix[index, neighbour_indices] = 1

    # Node features are the cell states; a row-major reshape matches the
    # node numbering used above.
    node_features = self.grid.reshape((num_nodes, 1)).astype('double')

    return adjacency_matrix, node_features

  def plot(self, fig, ax, show_ticks=False):
    """Draws the grid on `ax`.

    Args:
      fig: Unused; kept so existing callers do not break.
      ax: Matplotlib axes to draw on.
      show_ticks: Whether to label every row and column index.
    """
    ax.imshow(self.grid.T, cmap='Greys', origin='lower')

    if show_ticks:
      xticks = np.arange(self.grid.shape[0])
      yticks = np.arange(self.grid.shape[1])
    else:
      xticks = []
      yticks = []

    ax.set_xticks(xticks)
    ax.set_yticks(yticks)

In [None]:
# Builds the dataset of (input, target) node features plus per-grid masks.
def generate_training_data():
  """Generates random grids, their evolved targets and train/val/test masks.

  For every alive-cell count from 1 to grid_size // 2,
  `num_grids_per_num_alive_cells` random grids are drawn and advanced by
  `num_update_steps` steps. In each grid a quarter of the cells is assigned
  to training, a quarter to validation and the remaining half to testing.

  Returns:
    The grid's adjacency matrix, followed by arrays of shape
    (num_grids, grid_size, 1) holding input features, target features and
    the train, validation and test masks.
  """
  # One grid object is reused for all samples; the topology never changes.
  grid = GameOfLifeGrid(grid_dims=grid_dims)
  adjacency_matrix, _ = grid.get_features()

  per_sample_shape = (num_grids, grid_size, 1)
  all_inp_features = np.zeros(per_sample_shape, dtype=np.float32)
  all_out_features = np.zeros(per_sample_shape, dtype=np.float32)
  all_train_masks = np.zeros(per_sample_shape, dtype=np.int32)
  all_val_masks = np.zeros(per_sample_shape, dtype=np.int32)
  all_test_masks = np.zeros(per_sample_shape, dtype=np.int32)

  sample = 0
  for num_alive_cells in range(1, grid_size//2 + 1):
    for _ in range(num_grids_per_num_alive_cells):

      # Draw a random starting grid and evolve it.
      grid.initialize_randomly(num_alive_cells)
      evolved_grid = grid.update(num_update_steps)[-1]

      # Node features before and after the update.
      inp_features = grid.get_features()[1]
      out_features = evolved_grid.get_features()[1]

      # Label every cell with its split: 1 = train, 2 = validation, 0 = test.
      split = np.zeros(grid_size, dtype=np.int32)
      split[:grid_size//4] = 1
      split[grid_size//4:grid_size//2] = 2
      np.random.shuffle(split)

      # Store the sample.
      all_inp_features[sample] = inp_features
      all_out_features[sample] = out_features
      all_train_masks[sample] = np.expand_dims(split == 1, axis=1)
      all_val_masks[sample] = np.expand_dims(split == 2, axis=1)
      all_test_masks[sample] = np.expand_dims(split == 0, axis=1)
      sample += 1

  return adjacency_matrix, all_inp_features, all_out_features, all_train_masks, all_val_masks, all_test_masks

In [None]:
# Make data generation reproducible: seed both NumPy and TensorFlow.
np.random.seed(data_generation_seed)
tf.random.set_seed(data_generation_seed)

# Generate the graph topology, the features and the per-grid masks.
(adjacency_matrix,
 all_inp_features,
 all_out_features,
 all_train_masks,
 all_val_masks,
 all_test_masks) = generate_training_data()

In [None]:
# Convert to tf.Datasets, to use batching.
# Only the train and validation masks are included; the test masks are not.
dataset = tf.data.Dataset.from_tensor_slices((all_inp_features, all_out_features, all_train_masks, all_val_masks))
# The shuffle buffer covers the whole repeated dataset (num_epochs * num_grids
# elements), so grids from different repetitions are mixed before batching.
dataset = dataset.repeat(num_epochs).shuffle(num_epochs * num_grids, reshuffle_each_iteration=True).batch(batch_size)

# Index into LOG_DIR using the current date and time.
current_time = dt.datetime.now().strftime("%Y-%m-%d/%H:%M:%S")
train_log_dir = LOG_DIR + '/' + current_time + '/train'
val_log_dir = LOG_DIR + '/' + current_time + '/val'
# Exported models are grouped by model type.
train_saved_dir = SAVED_DIR + '/' + model_type

# Create separate writers for the training and validation sets.
train_summary_writer = tf.summary.create_file_writer(train_log_dir)
val_summary_writer = tf.summary.create_file_writer(val_log_dir)

# Define our metrics.
train_loss_metr = tf.keras.metrics.Mean('train_loss', dtype=tf.float64)
val_loss_metr = tf.keras.metrics.Mean('val_loss', dtype=tf.float64)
train_accuracy_metr = tf.keras.metrics.CategoricalAccuracy('train_accuracy')
val_accuracy_metr = tf.keras.metrics.CategoricalAccuracy('val_accuracy')

In [None]:
# Open TensorBoard inside the notebook, reading from LOG_DIR.
%tensorboard --logdir {LOG_DIR}

In [None]:
# Define the model.
def build_model(model_type, m):
  """Builds the network selected by `model_type`.

  Every architecture has the same structure: a neighbourhood-aggregating
  layer with 2*m output channels, a per-cell update with m channels, and a
  per-cell linear layer producing two logits.

  Args:
    model_type: 'gcn-no-norm', 'gcn-degree-i' or 'gcn-degree-ij' for a GCN
      with the corresponding normalization, or 'cnn' for a 3x3 convolution.
    m: Width multiplier for the hidden layers.

  Returns:
    An `snt.Sequential` model.

  Raises:
    ValueError: If `model_type` is not one of the names listed above.
  """
  # GCN variants differ only in how the adjacency matrix is normalised.
  gcn_normalizations = {
      'gcn-no-norm': 'none',
      'gcn-degree-i': 'degree-i',
      'gcn-degree-ij': 'degree-ij',
  }

  if model_type in gcn_normalizations:
    return snt.Sequential([
        GCNLayer(input_dims=1, output_dims=2*m,
                 normalization=gcn_normalizations[model_type], name='gcn_layer_1'),
        tf.nn.relu,
        snt.Linear(output_size=m, name='point_update'),
        tf.nn.relu,
        snt.Linear(output_size=2, name='predict'),
    ])

  if model_type == 'cnn':
    return snt.Sequential([
        snt.Conv2D(kernel_shape=3, output_channels=2*m, name='cnn_layer_1'),
        tf.nn.relu,
        snt.Conv2D(kernel_shape=1, output_channels=m, name='point_update'),
        tf.nn.relu,
        snt.Linear(output_size=2, name='predict'),
    ])

  raise ValueError('Invalid model_type.')

# Seed before building the model.
np.random.seed(model_seed)
tf.random.set_seed(model_seed)

# Construct model.
print('Chosen model: %s' % model_type)
print('Saved model will be at %s' % train_saved_dir)
model = build_model(model_type, m)

# Per-example input shape, derived from the grid configuration so that
# changing grid_dims does not require editing this cell. The CNN consumes the
# grid as an image with one channel; the other models consume one feature per
# node.
if model_type == 'cnn':
  inp_shape = tuple(int(dim) for dim in grid_dims) + (1,)
else:
  inp_shape = (int(grid_size), 1)

# Call model before evaluating its size, so that all variables exist.
model(all_inp_features[0].reshape(-1, *inp_shape))
model_size = np.sum([np.prod(variable.shape) for variable in model.variables])
print('Model Size: %d parameters.' % model_size)

# Optimizer.
learning_rate = 5e-2
opt = snt.optimizers.SGD(learning_rate)


# Export the trained model.
def save_with_prefix(prefix):
  """Saves the model as a SavedModel under `train_saved_dir/<prefix>`.

  The exported module exposes `predict`, which maps a batch of inputs of
  shape [None, *inp_shape] to the arg-max class per cell, and
  `all_variables`, the list of the model's variables.
  """
  input_signature = [tf.TensorSpec([None, *inp_shape])]

  @tf.function(input_signature=input_signature)
  def predict(x):
    class_logits = model(x)
    return tf.argmax(class_logits, axis=-1)

  exported = snt.Module()
  exported.predict = predict
  exported.all_variables = list(model.variables)

  export_path = train_saved_dir + '/' + str(prefix)
  tf.saved_model.save(exported, export_path)


# One optimisation step on a single batch.
def train_step(step, inp_features, out_features, train_mask, val_mask):
  """Updates the model on one batch and logs training/validation metrics.

  Args:
    step: Global step index used for the summaries.
    inp_features: Batch of input node features.
    out_features: Batch of target node features (cell states).
    train_mask: Non-zero for the cells that contribute to the training loss.
    val_mask: Non-zero for the cells that contribute to the validation loss.

  Returns:
    A pair `(train_loss, val_loss)`, both computed from the predictions made
    before the parameter update.
  """
  # Flatten masks and targets to one entry per cell; shape the model input.
  flat_train_mask = tf.reshape(train_mask, -1)
  flat_val_mask = tf.reshape(val_mask, -1)
  flat_targets = tf.reshape(out_features, -1)
  model_inputs = tf.reshape(inp_features, (-1, *inp_shape))

  # Indices and labels of the cells selected by each mask.
  train_indices = tf.reshape(tf.where(flat_train_mask), -1)
  val_indices = tf.reshape(tf.where(flat_val_mask), -1)
  train_labels = tf.cast(tf.gather(flat_targets, train_indices), tf.int32)
  val_labels = tf.cast(tf.gather(flat_targets, val_indices), tf.int32)

  # Record the forward pass and the training loss for differentiation.
  with tf.GradientTape() as tape:
    logits = tf.reshape(model(model_inputs), (-1, 2))

    train_logits = tf.gather(logits, train_indices)
    train_loss = tf.math.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(logits=train_logits, labels=train_labels))
    train_loss_metr(train_loss)

  # Gradient step.
  params = model.trainable_variables
  opt.apply(tape.gradient(train_loss, params), params)

  # Validation loss, from the same (pre-update) logits.
  val_logits = tf.gather(logits, val_indices)
  val_loss = tf.math.reduce_mean(
      tf.nn.sparse_softmax_cross_entropy_with_logits(logits=val_logits, labels=val_labels))
  val_loss_metr(val_loss)

  # Accuracies compare one-hot labels with one-hot arg-max predictions.
  train_accuracy_metr(
      tf.one_hot(train_labels, depth=2),
      tf.one_hot(tf.math.argmax(train_logits, axis=1), depth=2))
  val_accuracy_metr(
      tf.one_hot(val_labels, depth=2),
      tf.one_hot(tf.math.argmax(val_logits, axis=1), depth=2))

  # Write to logs.
  with train_summary_writer.as_default():
    tf.summary.scalar('loss', train_loss_metr.result(), step=step)
    tf.summary.scalar('accuracy', train_accuracy_metr.result(), step=step)
  with val_summary_writer.as_default():
    tf.summary.scalar('loss', val_loss_metr.result(), step=step)
    tf.summary.scalar('accuracy', val_accuracy_metr.result(), step=step)

  # Reset the metrics so that each step is logged on its own rather than as
  # a running average.
  for metric in (train_loss_metr, val_loss_metr, train_accuracy_metr, val_accuracy_metr):
    metric.reset_states()

  return train_loss, val_loss


# Train until the dataset is exhausted or early stopping triggers, then
# export the model.
val_losses = []
for step, (inp_features, out_features, train_mask, val_mask) in enumerate(dataset):
  train_loss, val_loss = train_step(step, inp_features, out_features, train_mask, val_mask)
  val_losses.append(val_loss.numpy())

  if step % 1000 == 0:
    print('Step %d completed.' % step)

  # Stop once none of the last `early_stopping_lim` validation losses is as
  # low as the loss recorded just before that window.
  if step > early_stopping_lim:
    reference_loss = val_losses[-early_stopping_lim - 1]
    best_recent_loss = np.min(val_losses[-early_stopping_lim:])
    if best_recent_loss > reference_loss:
      # Report only the two compared values, not the whole loss history.
      print('Best validation loss in the last %d steps: %f; loss before that window: %f.'
            % (early_stopping_lim, best_recent_loss, reference_loss))
      print('Early-stopping at step %d.' % step)
      break

# The exported model's directory name is its parameter count.
save_with_prefix(model_size)