##### Copyright 2019 Google LLC.

Licensed under the Apache License, Version 2.0 (the "License");



In [0]:
#@title Default title text
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# High Performance Monte Carlo Simulation of Ising Model on TPU Clusters

This notebook is a companion to the paper *High Performance Monte Carlo Simulation of Ising Model on TPU Clusters (Yang et al., 2019)*. See the [README.md](https://github.com/google-research/google-research/blob/master/simulation_research/ising_model/README.md) for details on how to simulate the Ising model on Cloud TPU.

In [0]:
"""Ising Model MCMC Simulation on TPU.

This is the implementation of Algorithm 2: UpdateOptim in the paper.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools
import os
import time

import numpy as np
import tensorflow as tf
from tensorflow.contrib.tpu.python.ops import tpu_ops
from tensorflow.contrib.cluster_resolver import TPUClusterResolver

# Constants
# Conversion factor used when reporting the time per spin flip.
_NANOS_PER_SECOND = 1e9

# Whether bfloat16 is used in the simulation (float32 is used otherwise, see
# get_dtype()).
USE_BFLOAT16 = True
# Lattice sub dimension, where each sub lattice is a square grid of spin values.
# The code below works with LATTICE_SUB_DIM // 2 sized compact sub-lattices.
LATTICE_SUB_DIM = 256
# Lattice block dimensions (a list of 2 entries), where each block is a grid of
# sub lattices with shape LATTICE_BLOCK_DIMS.
LATTICE_BLOCK_DIMS = [96, 48]
# Lattice super block dimension, where each super block is a square grid of
# lattice blocks.
LATTICE_SUPER_BLOCK_DIM = 2
# Cores topology (a list of 2 entries); each core is one replica of the
# computation. [1, 1] selects the single-core code path.
CORES_TOPOLOGY = [2, 4]
# Number of burn-in whole-lattice updates in MCMC.
NUMBER_OF_BURN_IN_UPDATE = 1
# Number of whole-lattice updates used for the estimation in MCMC.
NUMBER_OF_WHOLE_LATTICE_UPDATE = 1000
# Inverse temperature, the default is the critical inverse temperature.
INVERSE_TEMPERATURE = 1.0 / 2.26918531421

In [0]:
def get_dtype():
  """Returns the TensorFlow dtype of the spins, selected by USE_BFLOAT16."""
  if USE_BFLOAT16:
    return tf.bfloat16
  return tf.float32


def is_single_core(cores_topology):
  """Returns True iff the cores topology describes exactly one core.

  Args:
    cores_topology: a sequence (list, tuple, ...) of the number of cores along
      each of the two topology axes, or None.

  Returns:
    True if the topology is equivalent to `[1, 1]`, False otherwise (including
    when `cores_topology` is None).
  """
  if cores_topology is None:
    return False
  # Normalize to a list so that an equivalent tuple such as (1, 1) is accepted;
  # a plain `== [1, 1]` is False for any non-list sequence.
  return list(cores_topology) == [1, 1]


def create_iterator(shape):
  """Iterates over every index tuple of an array with the given shape.

  Tuples are produced in row-major order, i.e. the last index varies fastest.
  """
  return itertools.product(*(range(extent) for extent in shape))


def create_list(shape):
  """Builds a nested list of the given shape whose leaves are all None.

  An empty shape yields None itself (a single leaf).
  """
  if not shape:
    return None
  outer, inner_shape = shape[0], shape[1:]
  return [create_list(inner_shape) for _ in range(outer)]

In [0]:
class NearestNeighborCalculatorOptim(object):
  """Nearest-neighbor spin sums for the compact checkerboard representation.

  `super_grids[l][m][i][j]` is addressed by the compact sub-lattice index
  (l, m) and the super block position (i, j). The (0, 0) and (1, 1) compact
  sub-lattices are handled by the "black" methods, (0, 1) and (1, 0) by the
  "white" methods.
  """

  def __init__(self):
    half_dim = LATTICE_SUB_DIM // 2
    block_shape = [LATTICE_BLOCK_DIMS[0], LATTICE_BLOCK_DIMS[1]]

    # 0/1 kernel with ones on the main diagonal and the first super-diagonal.
    # Multiplying with it (or its transpose) from the left sums vertically
    # adjacent entries, from the right horizontally adjacent entries.
    kernel = (
        np.eye(half_dim, k=0, dtype=np.float16) +
        np.eye(half_dim, k=1, dtype=np.float16))
    self._grid_nn = tf.broadcast_to(
        tf.constant(kernel, dtype=get_dtype()),
        block_shape + [half_dim, half_dim])

    # One-hot vectors (index 0 and index half_dim - 1). Their outer product
    # with a boundary vector places that boundary on the first / last row or
    # column of an otherwise zero sub-lattice, which is then added to the
    # neighbor sums.
    first_slot = tf.one_hot(0, half_dim, dtype=get_dtype())
    self._grid_expand_s = tf.broadcast_to(first_slot, block_shape + [half_dim])
    last_slot = tf.one_hot(half_dim - 1, half_dim, dtype=get_dtype())
    self._grid_expand_e = tf.broadcast_to(last_slot, block_shape + [half_dim])

    # [source, target] core id pairs for exchanging boundaries between
    # replicas in each of the 4 directions, with periodic wrap-around on the
    # cores topology. They are only used on the multi-core path.
    tpu_x, tpu_y = CORES_TOPOLOGY
    core_ids = np.arange(tpu_x * tpu_y).reshape((tpu_x, tpu_y))

    def source_target_pairs(shift, axis):
      sources = np.roll(core_ids, shift, axis=axis)
      return [[sources[i, j], core_ids[i, j]]
              for i, j in itertools.product(range(tpu_x), range(tpu_y))]

    self._permute_n = source_target_pairs(1, 0)
    self._permute_s = source_target_pairs(-1, 0)
    self._permute_w = source_target_pairs(1, 1)
    self._permute_e = source_target_pairs(-1, 1)

  def get_boundary_n(self, super_grids, i, j, single_core, black):
    """The boundary on the northern direction."""
    source = super_grids[1][0] if black else super_grids[1][1]
    edge = source[(i - 1) % LATTICE_SUPER_BLOCK_DIM][j][-1:, :, -1, :]
    if not single_core and i == 0:
      # The northern neighbor lives in the replica above (periodic boundary
      # condition on the cores topology).
      edge = tpu_ops.collective_permute(edge, self._permute_n)
    interior = source[i][j][:-1, :, -1, :]
    return tf.concat([edge, interior], axis=0)

  def get_boundary_s(self, super_grids, i, j, single_core, black):
    """The boundary on the southern direction."""
    source = super_grids[0][1] if black else super_grids[0][0]
    edge = source[(i + 1) % LATTICE_SUPER_BLOCK_DIM][j][:1, :, 0, :]
    if not single_core and i == LATTICE_SUPER_BLOCK_DIM - 1:
      # The southern neighbor lives in the replica below (periodic boundary
      # condition on the cores topology).
      edge = tpu_ops.collective_permute(edge, self._permute_s)
    interior = source[i][j][1:, :, 0, :]
    return tf.concat([interior, edge], axis=0)

  def get_boundary_w(self, super_grids, i, j, single_core, black):
    """The boundary on the western direction."""
    source = super_grids[0][1] if black else super_grids[1][1]
    edge = source[i][(j - 1) % LATTICE_SUPER_BLOCK_DIM][:, -1:, :, -1]
    if not single_core and j == 0:
      # The western neighbor lives in the replica on the left (periodic
      # boundary condition on the cores topology).
      edge = tpu_ops.collective_permute(edge, self._permute_w)
    interior = source[i][j][:, :-1, :, -1]
    return tf.concat([edge, interior], axis=1)

  def get_boundary_e(self, super_grids, i, j, single_core, black):
    """The boundary on the eastern direction."""
    source = super_grids[1][0] if black else super_grids[0][0]
    edge = source[i][(j + 1) % LATTICE_SUPER_BLOCK_DIM][:, :1, :, 0]
    if not single_core and j == LATTICE_SUPER_BLOCK_DIM - 1:
      # The eastern neighbor lives in the replica on the right (periodic
      # boundary condition on the cores topology).
      edge = tpu_ops.collective_permute(edge, self._permute_e)
    interior = source[i][j][:, 1:, :, 0]
    return tf.concat([interior, edge], axis=1)

  def sum_of_nearest_neighbors_black(self, super_grids, single_core):
    """Neighbor sums for the (0, 0) and (1, 1) compact sub-lattices.

    Returns:
      A pair `(sum_nn_00, sum_nn_11)` of 2-D lists indexed by super block
      position.
    """
    super_shape = [LATTICE_SUPER_BLOCK_DIM] * 2
    grids_01, grids_10 = super_grids[0][1], super_grids[1][0]
    kernel = self._grid_nn
    outer = 'mni,mnj->mnij'

    # Contributions from neighbors inside the same sub-lattice.
    sum_nn_00 = create_list(super_shape)
    for i, j in create_iterator(super_shape):
      horizontal = tf.matmul(grids_01[i][j], kernel)
      vertical = tf.matmul(kernel, grids_10[i][j], transpose_a=True)
      sum_nn_00[i][j] = horizontal + vertical
    sum_nn_11 = create_list(super_shape)
    for i, j in create_iterator(super_shape):
      vertical = tf.matmul(kernel, grids_01[i][j])
      horizontal = tf.matmul(grids_10[i][j], kernel, transpose_b=True)
      sum_nn_11[i][j] = vertical + horizontal

    # Contributions from neighbors across sub-lattice boundaries.
    for i, j in create_iterator(super_shape):
      # Northern/western boundary feeds the (0, 0) sums.
      north = self.get_boundary_n(super_grids, i, j, single_core, True)
      west = self.get_boundary_w(super_grids, i, j, single_core, True)
      sum_nn_00[i][j] = sum_nn_00[i][j] + (
          tf.einsum(outer, self._grid_expand_s, north) +
          tf.einsum(outer, west, self._grid_expand_s))

      # Southern/eastern boundary feeds the (1, 1) sums.
      south = self.get_boundary_s(super_grids, i, j, single_core, True)
      east = self.get_boundary_e(super_grids, i, j, single_core, True)
      sum_nn_11[i][j] = sum_nn_11[i][j] + (
          tf.einsum(outer, self._grid_expand_e, south) +
          tf.einsum(outer, east, self._grid_expand_e))

    return sum_nn_00, sum_nn_11

  def sum_of_nearest_neighbors_white(self, super_grids, single_core):
    """Neighbor sums for the (0, 1) and (1, 0) compact sub-lattices.

    Returns:
      A pair `(sum_nn_01, sum_nn_10)` of 2-D lists indexed by super block
      position.
    """
    super_shape = [LATTICE_SUPER_BLOCK_DIM] * 2
    grids_00, grids_11 = super_grids[0][0], super_grids[1][1]
    kernel = self._grid_nn
    outer = 'mni,mnj->mnij'

    # Contributions from neighbors inside the same sub-lattice.
    sum_nn_01 = create_list(super_shape)
    for i, j in create_iterator(super_shape):
      horizontal = tf.matmul(grids_00[i][j], kernel, transpose_b=True)
      vertical = tf.matmul(kernel, grids_11[i][j], transpose_a=True)
      sum_nn_01[i][j] = horizontal + vertical
    sum_nn_10 = create_list(super_shape)
    for i, j in create_iterator(super_shape):
      vertical = tf.matmul(kernel, grids_00[i][j])
      horizontal = tf.matmul(grids_11[i][j], kernel)
      sum_nn_10[i][j] = vertical + horizontal

    # Contributions from neighbors across sub-lattice boundaries.
    for i, j in create_iterator(super_shape):
      # Northern/eastern boundary feeds the (0, 1) sums.
      north = self.get_boundary_n(super_grids, i, j, single_core, False)
      east = self.get_boundary_e(super_grids, i, j, single_core, False)
      sum_nn_01[i][j] = sum_nn_01[i][j] + (
          tf.einsum(outer, self._grid_expand_s, north) +
          tf.einsum(outer, east, self._grid_expand_e))
      # Southern/western boundary feeds the (1, 0) sums.
      south = self.get_boundary_s(super_grids, i, j, single_core, False)
      west = self.get_boundary_w(super_grids, i, j, single_core, False)
      sum_nn_10[i][j] = sum_nn_10[i][j] + (
          tf.einsum(outer, self._grid_expand_e, south) +
          tf.einsum(outer, west, self._grid_expand_s))

    return sum_nn_01, sum_nn_10

In [0]:
def _validate_params():
  """Asserts that the module-level simulation parameters are usable.

  Raises:
    AssertionError: if a size, step count or the inverse temperature is not
      positive, or if a 2-entry parameter is None or has the wrong length.
  """
  assert LATTICE_SUB_DIM > 0

  assert LATTICE_BLOCK_DIMS is not None and len(LATTICE_BLOCK_DIMS) == 2
  assert all(dim > 0 for dim in LATTICE_BLOCK_DIMS)
  for scalar_param in (LATTICE_SUPER_BLOCK_DIM, INVERSE_TEMPERATURE,
                       NUMBER_OF_BURN_IN_UPDATE,
                       NUMBER_OF_WHOLE_LATTICE_UPDATE):
    assert scalar_param > 0

  assert CORES_TOPOLOGY is not None and len(CORES_TOPOLOGY) == 2
  assert all(dim > 0 for dim in CORES_TOPOLOGY)


def compute_nanoseconds_per_flip(step_time):
  """Converts a wall-clock duration into nanoseconds per spin flip.

  Args:
    step_time: duration in seconds of NUMBER_OF_WHOLE_LATTICE_UPDATE whole
      lattice updates.

  Returns:
    `step_time` in nanoseconds divided by the number of spins per core times
    the number of whole lattice updates times the number of cores.
  """
  spins_per_core = ((LATTICE_SUB_DIM * LATTICE_SUPER_BLOCK_DIM)**2 *
                    np.prod(LATTICE_BLOCK_DIMS))
  total_flips = (
      spins_per_core * NUMBER_OF_WHOLE_LATTICE_UPDATE *
      np.prod(CORES_TOPOLOGY))
  return step_time * _NANOS_PER_SECOND / total_flips


# pylint: disable=unused-argument
def _grid_initializer(shape, dtype, partition_info):
  """Variable initializer that sets every entry to -1 or +1 at random.

  Args:
    shape: shape of the variable to initialize.
    dtype: dtype of the returned tensor.
    partition_info: unused; accepted because variable initializers are called
      with it.

  Returns:
    A tensor of the given shape and dtype with entries in {-1, +1}.
  """
  # 1 where the uniform sample exceeds 0.5, 0 elsewhere.
  spin_up = tf.cast(tf.random_uniform(shape) > 0.5, dtype=dtype)
  # Map {0, 1} onto {-1, +1}.
  return 2.0 * spin_up - 1.0

In [0]:
def update_optim(sweeps, fn, single_core):
  """Simulation in each replica using 'compact' representation.

  Boundary value communications are handled by collective permute.

  Ref: Algorithm UpdateOptim in the paper.

  Args:
    sweeps: the number of whole lattice update.
    fn: the function on a given configuration of the lattice, it takes a 2-D
      list of variables, i.e., v_ij and returns a float32.
    single_core: a bool that specify whether the lattice is updated on single or
      multiple cores.

  Returns:
    The estimated expectation of fn, i.e., <fn>, computed as the running mean
    of fn over the `sweeps` whole lattice updates.
  """
  _validate_params()

  # Flips are decided in the log domain: a spin s with neighbor sum nn is
  # flipped when log(u) < -2 * beta * nn * s for u uniform in [0, 1).
  temperature_multiplier = -2 * INVERSE_TEMPERATURE

  lattice_sub_shape = [
      LATTICE_BLOCK_DIMS[0],
      LATTICE_BLOCK_DIMS[1],
      LATTICE_SUB_DIM // 2,
      LATTICE_SUB_DIM // 2,
  ]

  nn_calculator = NearestNeighborCalculatorOptim()

  def checkerboard(k, estimation):
    """Checkerboard algorithm update."""
    with tf.variable_scope('tpu', reuse=tf.AUTO_REUSE, use_resource=True):
      super_grids = create_list([2] * 2 + [LATTICE_SUPER_BLOCK_DIM] * 2)
      for l, m in create_iterator([2] * 2):
        for i, j in create_iterator([LATTICE_SUPER_BLOCK_DIM] * 2):
          # The variable dtype must follow get_dtype() so that it matches the
          # constants of the nearest neighbor calculator and the random
          # numbers below when USE_BFLOAT16 is False.
          super_grids[l][m][i][j] = tf.get_variable(
              'grids_%d%d_%d%d' % (l, m, i, j),
              initializer=_grid_initializer,
              shape=lattice_sub_shape,
              dtype=get_dtype())

    def update(probs, black):
      """Checkerboard algorithm update for a given color."""
      if black:
        idx = [[0, 0], [1, 1]]
        sum_nn_color = \
            nn_calculator.sum_of_nearest_neighbors_black(
                super_grids, single_core)
      else:
        idx = [[0, 1], [1, 0]]
        sum_nn_color = \
            nn_calculator.sum_of_nearest_neighbors_white(
                super_grids, single_core)
      assign_ops = []
      for i, j in create_iterator([LATTICE_SUPER_BLOCK_DIM] * 2):
        for [idx0, idx1], sum_nn in zip(idx, sum_nn_color):
          acceptance_ratio_ij = (
              temperature_multiplier * sum_nn[i][j] *
              super_grids[idx0][idx1][i][j])
          # 1 where the spin is flipped, 0 elsewhere.
          flips_ij = tf.cast(
              probs[idx0][idx1][i][j] < acceptance_ratio_ij, dtype=get_dtype())
          # Subtracting 2 * s from a flipped spin s turns it into -s.
          assign_ops.append(super_grids[idx0][idx1][i][j].assign_sub(
              flips_ij * super_grids[idx0][idx1][i][j] *
              tf.constant(2.0, dtype=get_dtype())))
      return assign_ops

    # Log of uniform random numbers, one per spin, drawn once per sweep.
    probs = create_list([2] * 2 + [LATTICE_SUPER_BLOCK_DIM] * 2)
    for l, m in create_iterator([2] * 2):
      for i, j in create_iterator([LATTICE_SUPER_BLOCK_DIM] * 2):
        probs[l][m][i][j] = (
            tf.log(tf.random_uniform(lattice_sub_shape, dtype=get_dtype())))
    # The white update is ordered after the black update, and fn is evaluated
    # after the white update.
    grid_black_update = update(probs, black=True)
    with tf.control_dependencies(grid_black_update):
      grid_white_update = update(probs, black=False)
    with tf.control_dependencies(grid_white_update):
      # Running mean: new = old * k / (k + 1) + fn / (k + 1).
      return k + 1, (
          estimation * tf.cast(k, dtype=tf.float32) /
          tf.cast(k + 1, dtype=tf.float32) +
          tf.cast(fn(super_grids), dtype=tf.float32) /
          tf.cast(k + 1, dtype=tf.float32))

  def while_loop(sweeps):
    _, estimation = tf.while_loop(lambda i, _: i < sweeps, checkerboard, [
        0,
        tf.constant(0.0, dtype=tf.float32),
    ])
    return estimation

  return while_loop(sweeps)

In [0]:
def create_ising_mcmc_simulator(fn, single_core):
  """Builds the burn-in and estimation computations for one or many TPU cores.

  The whole lattice is distributed equally among TPU cores, and each TPU core
  runs one replica given cores topology `[l1, l2]`. In each core, because of
  the hard limit of protobuf (2GB), the sub-lattice is split into multi-scale
  sub-lattices as follows:

  The sub-lattice in each core is divided into a `kxk` grid of sub-lattices,
  each of which is a tensor variable, where `k` is LATTICE_SUPER_BLOCK_DIM,
  `[v_00, ........... v_{0,k-1}]`
  `[v_10, ........... v_{1,k-1}]`
  `.............................`
  `[v_{k-1,0}, ..., v_{k-1,k-1}]`

  Each v_ij is again a `m1xm2` grid of smaller sub-lattices, where `m1` and
  `m2` are given by LATTICE_BLOCK_DIMS,
  `[g_00, ............, g_{0,m2-1}]`
  `[g_10, ............, g_{1,m2-1}]`
  `................................`
  `[g_{m1-1,0}, ..., g_{m1-1,m2-1}]`

  and each g_ij is a `nxn` sub-lattice, where `n` is LATTICE_SUB_DIM.

  Each g_ij is further split into 4 compact sub-lattices, i.e.,
  `g_ij_00 = g_ij[0::2, 0::2]`
  `g_ij_01 = g_ij[0::2, 1::2]`
  `g_ij_10 = g_ij[1::2, 0::2]`
  `g_ij_11 = g_ij[1::2, 1::2]`

  where g_ij_00 and g_ij_11 are all black spins, and g_ij_01 and g_ij_10 are
  all white spins.

  Thus the whole lattice has dimensions [l1*k*m1*n, l2*k*m2*n].

  The boundaries of each sub-lattice are collectively permuted among replicas
  to calculate nearest neighbor sums, which are used to compute
  Metropolis-Hastings updates.

  Args:
    fn: the function on a given configuration of the sub_lattice on the
      replica. It takes a 2-D list of variables, i.e., v_ij, and returns a
      float32. fn must be additive in order to ensure the correctness of the
      estimation, i.e., `fn([sub_lattice_1, ..., sub_lattice_n]) = sum_{i=1}^n
      fn([sub_lattice_i])`.
    single_core: a bool that specifies whether the lattice is updated on a
      single core or on multiple cores.

  Returns:
    The tuple of 2 operators for the burn_in and the estimation.
  """

  def make_computation(on_single_core):
    return lambda sweeps: update_optim(sweeps, fn, on_single_core)

  if single_core:
    # Single core: rewrite the computation for the TPU with the number of
    # sweeps as its only input.
    burn_in = tf.contrib.tpu.rewrite(
        make_computation(True), [NUMBER_OF_BURN_IN_UPDATE])
    estimation = tf.contrib.tpu.rewrite(
        make_computation(True), [NUMBER_OF_WHOLE_LATTICE_UPDATE])
    return burn_in, estimation

  # Multiple cores: one replica per core, each fed the same number of sweeps.
  burn_in_inputs = [
      [NUMBER_OF_BURN_IN_UPDATE] for _ in range(np.prod(CORES_TOPOLOGY))
  ]
  burn_in = tf.contrib.tpu.replicate(
      make_computation(False), inputs=burn_in_inputs)
  estimation_inputs = [
      [NUMBER_OF_WHOLE_LATTICE_UPDATE] for _ in range(np.prod(CORES_TOPOLOGY))
  ]
  estimation = tf.contrib.tpu.replicate(
      make_computation(False), inputs=estimation_inputs)
  return burn_in, estimation

In [0]:
def ising_mcmc_simulation(fn, session):
  """Runs burn-in followed by a timed estimation of the expectation of fn.

  Args:
    fn: the function on a given configuration of the lattice. It takes a 2-D
      list of variables, i.e., v_ij and returns a float32. Refer to
      create_ising_mcmc_simulator for details on the variables.
    session: a callable that returns a tensorflow session; the session is used
      as a context manager.

  Returns:
    The value obtained from running the estimation operator, i.e. the
    estimation of the expectation of fn.
  """
  on_single_core = is_single_core(CORES_TOPOLOGY)
  burn_in_op, estimation_op = create_ising_mcmc_simulator(fn, on_single_core)
  with session() as sess:
    print(sess.list_devices())
    sess.run(tf.contrib.tpu.initialize_system())
    sess.run(tf.global_variables_initializer())

    print('--- start burning in ---')
    sess.run(burn_in_op)
    print('--- finish burning in ---')

    # Only the estimation run is timed.
    tic = time.time()
    estimate = sess.run(estimation_op)
    elapsed_seconds = time.time() - tic
    flip_nanos = compute_nanoseconds_per_flip(elapsed_seconds)
    print('--- %s seconds ---' % elapsed_seconds)
    print('--- %s nanoseconds per flip ---' % flip_nanos)

    sess.run(tf.contrib.tpu.shutdown_system())
    print('--- Done ---')
  return estimate

In [0]:
def get_session():
  """Creates a TensorFlow session connected to the TPU named by $TPU_NAME.

  Returns:
    A `tf.Session` whose target is the master of the TPU cluster resolved
    from the `TPU_NAME` environment variable.

  Raises:
    KeyError: if the `TPU_NAME` environment variable is not set.
  """
  def _get_tpu_setup():
    # os.environ is a mapping and must be subscripted; calling it raises
    # TypeError.
    tpu_cluster_resolver = TPUClusterResolver(tpu=os.environ['TPU_NAME'])
    cluster_def = tpu_cluster_resolver.cluster_spec().as_cluster_def()
    tpu_master_grpc_path = tpu_cluster_resolver.get_master()
    return cluster_def, tpu_master_grpc_path

  cluster_def, tpu_master_grpc_path = _get_tpu_setup()
  config = tf.ConfigProto(
      allow_soft_placement=True,
      isolate_session_state=True,
      cluster_def=cluster_def)
  return tf.Session(tpu_master_grpc_path, config=config)


def reduce_mean(super_grids):
  """Average of the per-variable mean spin over all variables in super_grids.

  Args:
    super_grids: nested list indexed as `super_grids[l][m][i][j]` with l, m in
      {0, 1} and i, j in range(LATTICE_SUPER_BLOCK_DIM).

  Returns:
    A float32 scalar tensor: the sum of `tf.reduce_mean` of every entry,
    divided by the number of entries.
  """
  index_shape = [2, 2, LATTICE_SUPER_BLOCK_DIM, LATTICE_SUPER_BLOCK_DIM]
  total = tf.constant(0.0, dtype=tf.float32)
  for l, m, i, j in create_iterator(index_shape):
    grid_mean = tf.reduce_mean(super_grids[l][m][i][j])
    total = total + tf.cast(grid_mean, dtype=tf.float32)
  num_grids = np.prod(index_shape).astype(np.float32)
  return total / num_grids


# Build the simulation in a fresh graph and run it, using reduce_mean as the
# function whose expectation is estimated.
with tf.Graph().as_default():
  ising_mcmc_simulation(reduce_mean, get_session)