In [1]:
!pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html --upgrade

Looking in links: https://storage.googleapis.com/jax-releases/libtpu_releases.html
Collecting jax[tpu]
  Downloading jax-0.4.20-py3-none-any.whl (1.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.7/1.7 MB[0m [31m7.9 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
Collecting libtpu-nightly==0.1.dev20231102
  Downloading https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20231102-py3-none-any.whl (122.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m122.6/122.6 MB[0m [31m7.1 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Collecting jaxlib==0.4.20
  Downloading jaxlib-0.4.20-cp310-cp310-manylinux2014_x86_64.whl (85.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m85.8/85.8 MB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
Installing collected packages: libtpu-nightly, jaxlib, jax
  Attempting uninstall: libtpu-nightly
    Found existing installation: libtp

In [2]:
import jax
# Sanity check: show the installed JAX version and the devices it can see, one per line.
print(jax.__version__, jax.devices(), sep='\n')

0.4.20
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]


In [3]:
!pip install brainpy

Collecting brainpy
  Downloading brainpy-2.4.6-py3-none-any.whl (678 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m678.5/678.5 kB[0m [31m9.4 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Installing collected packages: brainpy
Successfully installed brainpy-2.4.6
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m23.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [4]:
import time
from pprint import pprint  
import numpy as np
import brainpy as bp
import brainpy.math as bm
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# NOTE(review): presumably switches off BrainPy's internal checking for the rest of
# the session -- confirm against the BrainPy documentation.
bp.check.turn_off()

In [6]:
class Tool:
  """Stimulus generator and plotting helper for a single trial.

  A trial is made of three consecutive periods: a pre-stimulus period, a
  stimulus period and a delay period. The time axis of the plots is labelled
  in ms.

  Parameters
  ----------
  pre_stimulus_period : float
    Duration before the stimulus is switched on.
  stimulus_period : float
    Duration during which the stimulus is on.
  delay_period : float
    Duration after the stimulus has been removed.
  """

  def __init__(self, pre_stimulus_period=100., stimulus_period=1000., delay_period=500.):
    self.pre_stimulus_period = pre_stimulus_period
    self.stimulus_period = stimulus_period
    self.delay_period = delay_period
    # Passed to ``np.random.normal`` as its scale argument, i.e. it is used as
    # a standard deviation despite the attribute name.
    self.freq_variance = 10.
    # A new stimulus frequency is drawn once per ``freq_interval``.
    self.freq_interval = 50.
    self.total_period = pre_stimulus_period + stimulus_period + delay_period

  def generate_freqs(self, mean, num=1):
    """Build the per-time-step input frequency for a whole trial.

    Frequencies are zero during the pre-stimulus and delay periods. During the
    stimulus period a value is drawn from a normal distribution centred on
    ``mean`` once per ``freq_interval`` and held constant for that interval.

    Parameters
    ----------
    mean : float
      Mean of the normal distribution the stimulus frequencies are drawn from.
    num : int
      Number of independent frequency traces (columns) to generate.

    Returns
    -------
    A ``bm`` array of shape ``(n_steps, num)`` where ``n_steps`` is the number
    of ``bm.get_dt()`` steps in the three periods combined.
    """
    # stimulus period
    n_stim = int(self.stimulus_period / self.freq_interval)
    n_interval = int(self.freq_interval / bm.get_dt())
    freqs_stim = np.random.normal(mean, self.freq_variance, (n_stim, 1, num))
    # hold every drawn value for ``n_interval`` consecutive time steps
    freqs_stim = np.tile(freqs_stim, (1, n_interval, 1)).reshape(n_stim * n_interval, num)
    # pre stimulus period
    freqs_pre = np.zeros([int(self.pre_stimulus_period / bm.get_dt()), num])
    # post stimulus period
    freqs_delay = np.zeros([int(self.delay_period / bm.get_dt()), num])
    all_freqs = np.concatenate([freqs_pre, freqs_stim, freqs_delay], axis=0)
    return bm.asarray(all_freqs)

  def _mark_periods(self, ax, t_start):
    """Set the shared x-range on ``ax`` and draw a dashed line at the end of each period."""
    ax.set_xlim(t_start, self.total_period + 1)
    stim_on = self.pre_stimulus_period
    stim_off = stim_on + self.stimulus_period
    trial_end = stim_off + self.delay_period
    for boundary in (stim_on, stim_off, trial_end):
      ax.axvline(boundary, linestyle='dashed')

  def visualize_results(self, mon, IA_freqs, IB_freqs, t_start=0., title=None):
    """Plot spikes, population rates and input frequencies of one trial.

    Four stacked panels are drawn: raster plot of group A, raster plot of
    group B, firing rates of both groups, and the input frequencies given to
    both groups.

    Parameters
    ----------
    mon : mapping
      Must provide ``'ts'``, ``'A.spike'`` and ``'B.spike'``.
    IA_freqs, IB_freqs : array-like
      Input frequencies of group A / B, plotted against ``mon['ts']``.
    t_start : float
      Left limit of the time axis.
    title : str, optional
      Title placed above the first panel.
    """
    fig, gs = bp.visualize.get_figure(4, 1, 3, 10)
    axes = [fig.add_subplot(gs[i, 0]) for i in range(4)]

    ax = axes[0]
    bp.visualize.raster_plot(mon['ts'], mon['A.spike'], markersize=1, ax=ax)
    if title: ax.set_title(title)
    ax.set_ylabel("Group A")
    self._mark_periods(ax, t_start)

    ax = axes[1]
    bp.visualize.raster_plot(mon['ts'], mon['B.spike'], markersize=1, ax=ax)
    ax.set_ylabel("Group B")
    self._mark_periods(ax, t_start)

    ax = axes[2]
    rateA = bp.measure.firing_rate(mon['A.spike'], width=10.)
    rateB = bp.measure.firing_rate(mon['B.spike'], width=10.)
    ax.plot(mon['ts'], rateA, label="Group A")
    ax.plot(mon['ts'], rateB, label="Group B")
    ax.set_ylabel('Population activity [Hz]')
    self._mark_periods(ax, t_start)
    ax.legend()

    ax = axes[3]
    ax.plot(mon['ts'], IA_freqs, label="group A")
    ax.plot(mon['ts'], IB_freqs, label="group B")
    ax.set_ylabel("Input activity [Hz]")
    self._mark_periods(ax, t_start)
    ax.legend()
    ax.set_xlabel("Time [ms]")

    plt.show()

In [7]:
class ExpSyn(bp.Projection):
  """Projection using an exponential synapse (``bp.dyn.Expon``) with ``bp.dyn.COBA`` output.

  The synapse is sized by ``post.num`` and the pieces are wired together by
  ``bp.dyn.ProjAlignPostMg2``.

  Parameters
  ----------
  pre, post :
    Source and target populations; both must expose ``.num``.
  conn : str
    ``'all2all'`` or ``'one2one'``.
  delay :
    Forwarded unchanged to ``bp.dyn.ProjAlignPostMg2``.
  g_max : float
    Weight handed to the communication layer.
  tau : float
    Forwarded as ``tau`` to ``bp.dyn.Expon``.
  E : float
    Forwarded as ``E`` to ``bp.dyn.COBA`` (presumably the synaptic reversal potential).

  Raises
  ------
  ValueError
    If ``conn`` is neither ``'all2all'`` nor ``'one2one'``.
  """

  def __init__(self, pre, post, conn, delay, g_max, tau, E):
    super().__init__()
    if conn == 'all2all':
      comm = bp.dnn.AllToAll(pre.num, post.num, g_max)
    elif conn == 'one2one':
      comm = bp.dnn.OneToOne(pre.num, g_max)
    else:
      # name the offending value so a typo in a long list of projections is easy to locate
      raise ValueError(f"Unknown connection type {conn!r}; expected 'all2all' or 'one2one'.")
    syn = bp.dyn.Expon.desc(post.num, tau=tau, sharding=[bm.sharding.NEU_AXIS])
    out = bp.dyn.COBA.desc(E=E)
    self.proj = bp.dyn.ProjAlignPostMg2(
      pre=pre, delay=delay, comm=comm,
      syn=syn, out=out, post=post
    )

In [8]:
class NMDA(bp.Projection):
  """Projection using ``bp.dyn.NMDA`` synapses with a ``bp.dyn.MgBlock`` output.

  The synapse is sized by ``pre.num`` and the pieces are wired together by
  ``bp.dyn.ProjAlignPreMg2``. Synapse constants (``a=0.5``, ``tau_decay=100.``,
  ``tau_rise=2.``) and output constants (``E=0.``, ``cc_Mg=1.0``) are fixed.

  Parameters
  ----------
  pre, post :
    Source and target populations; both must expose ``.num``.
  conn : str
    ``'all2all'`` or ``'one2one'``.
  delay :
    Forwarded unchanged to ``bp.dyn.ProjAlignPreMg2``.
  g_max : float
    Weight handed to the communication layer.

  Raises
  ------
  ValueError
    If ``conn`` is neither ``'all2all'`` nor ``'one2one'``.
  """

  def __init__(self, pre, post, conn, delay, g_max):
    super().__init__()
    if conn == 'all2all':
      comm = bp.dnn.AllToAll(pre.num, post.num, g_max)
    elif conn == 'one2one':
      comm = bp.dnn.OneToOne(pre.num, g_max)
    else:
      # name the offending value so a typo in a long list of projections is easy to locate
      raise ValueError(f"Unknown connection type {conn!r}; expected 'all2all' or 'one2one'.")
    syn = bp.dyn.NMDA.desc(pre.num, a=0.5, tau_decay=100., tau_rise=2., sharding=[bm.sharding.NEU_AXIS])
    # here ``out`` is an instance, whereas ``syn`` is passed as a ``.desc(...)`` describer
    out = bp.dyn.MgBlock(E=0., cc_Mg=1.0)
    self.proj = bp.dyn.ProjAlignPreMg2(
      pre=pre, delay=delay, syn=syn,
      comm=comm, out=out, post=post
    )

In [9]:
class DecisionMakingNet(bp.DynSysGroup):
  """Spiking network with two selective excitatory groups, one non-selective group and one inhibitory group.

  Populations (all ``bp.dyn.LifRef``): ``A`` and ``B`` (selective excitatory),
  ``N`` (remaining excitatory) and ``I`` (interneurons). ``IA``/``IB`` are
  Poisson stimulus groups whose ``freqs`` is a ``bm.Variable`` so it can be
  overwritten at every step; ``noise_*`` are Poisson groups with a fixed rate.

  Parameters
  ----------
  scale : float
    Multiplies the population sizes (1600 excitatory and 400 inhibitory
    neurons at ``scale=1``) and divides the recurrent conductances.
  f : float
    Fraction of the excitatory neurons assigned to each of ``A`` and ``B``.
  """

  def __init__(self, scale=1., f=0.15):
    super().__init__()
    # number of neurons in each group of the network
    num_exc = int(1600 * scale)
    num_I, num_A, num_B = int(400 * scale), int(f * num_exc), int(f * num_exc)
    num_N = num_exc - num_A - num_B
    self.num_A, self.num_B, self.num_N, self.num_I = num_A, num_B, num_N, num_I
    # total number of LIF neurons (the Poisson input groups are not counted)
    self.num = num_A + num_B + num_N + num_I

    poisson_freq = 2400.  # Hz
    # w_pos scales connections within a selective group (A->A, B->B);
    # w_neg scales N->A, N->B, A->B and B->A.
    w_pos = 1.7
    w_neg = 1. - f * (w_pos - 1.) / (1. - f)
    g_ext2E_AMPA = 2.1  # nS
    g_ext2I_AMPA = 1.62  # nS
    # recurrent conductances are divided by ``scale``; the external ones above are not
    g_E2E_AMPA = 0.05 / scale  # nS
    g_E2I_AMPA = 0.04 / scale  # nS
    g_E2E_NMDA = 0.165 / scale  # nS
    g_E2I_NMDA = 0.13 / scale  # nS
    g_I2E_GABAa = 1.3 / scale  # nS
    g_I2I_GABAa = 1.0 / scale  # nS

    # keyword arguments shared by all four LIF populations
    neu_par = dict(V_rest=-70., V_reset=-55., V_th=-50., V_initializer=bp.init.OneInit(-70.),
                   sharding=[bm.sharding.NEU_AXIS])

    # E neurons/pyramid neurons
    self.A = bp.dyn.LifRef(num_A, tau=20., R=0.04, tau_ref=2., **neu_par)
    self.B = bp.dyn.LifRef(num_B, tau=20., R=0.04, tau_ref=2., **neu_par)
    self.N = bp.dyn.LifRef(num_N, tau=20., R=0.04, tau_ref=2., **neu_par)

    # I neurons/interneurons
    self.I = bp.dyn.LifRef(num_I, tau=10., R=0.05, tau_ref=1., **neu_par)

    # poisson stimulus  # 'freqs' as bm.Variable
    self.IA = bp.dyn.PoissonGroup(num_A, freqs=bm.Variable(bm.zeros(1)), sharding=[bm.sharding.NEU_AXIS])
    self.IB = bp.dyn.PoissonGroup(num_B, freqs=bm.Variable(bm.zeros(1)), sharding=[bm.sharding.NEU_AXIS])

    # noise neurons
    self.noise_B = bp.dyn.PoissonGroup(num_B, freqs=poisson_freq, sharding=[bm.sharding.NEU_AXIS])
    self.noise_A = bp.dyn.PoissonGroup(num_A, freqs=poisson_freq, sharding=[bm.sharding.NEU_AXIS])
    self.noise_N = bp.dyn.PoissonGroup(num_N, freqs=poisson_freq, sharding=[bm.sharding.NEU_AXIS])
    self.noise_I = bp.dyn.PoissonGroup(num_I, freqs=poisson_freq, sharding=[bm.sharding.NEU_AXIS])

    # define external inputs
    self.IA2A = ExpSyn(self.IA, self.A, 'one2one', None, g_ext2E_AMPA, tau=2., E=0.)
    self.IB2B = ExpSyn(self.IB, self.B, 'one2one', None, g_ext2E_AMPA, tau=2., E=0.)

    # define AMPA projections from N
    self.N2B_AMPA = ExpSyn(self.N, self.B, 'all2all', 0.5, g_E2E_AMPA * w_neg, tau=2., E=0.)
    self.N2A_AMPA = ExpSyn(self.N, self.A, 'all2all', 0.5, g_E2E_AMPA * w_neg, tau=2., E=0.)
    self.N2N_AMPA = ExpSyn(self.N, self.N, 'all2all', 0.5, g_E2E_AMPA, tau=2., E=0.)
    self.N2I_AMPA = ExpSyn(self.N, self.I, 'all2all', 0.5, g_E2I_AMPA, tau=2., E=0.)

    # define NMDA projections from N
    self.N2B_NMDA = NMDA(self.N, self.B, 'all2all', 0.5, g_E2E_NMDA * w_neg)
    self.N2A_NMDA = NMDA(self.N, self.A, 'all2all', 0.5, g_E2E_NMDA * w_neg)
    self.N2N_NMDA = NMDA(self.N, self.N, 'all2all', 0.5, g_E2E_NMDA)
    self.N2I_NMDA = NMDA(self.N, self.I, 'all2all', 0.5, g_E2I_NMDA)

    # define AMPA projections from B
    self.B2B_AMPA = ExpSyn(self.B, self.B, 'all2all', 0.5, g_E2E_AMPA * w_pos, tau=2., E=0.)
    self.B2A_AMPA = ExpSyn(self.B, self.A, 'all2all', 0.5, g_E2E_AMPA * w_neg, tau=2., E=0.)
    self.B2N_AMPA = ExpSyn(self.B, self.N, 'all2all', 0.5, g_E2E_AMPA, tau=2., E=0.)
    self.B2I_AMPA = ExpSyn(self.B, self.I, 'all2all', 0.5, g_E2I_AMPA, tau=2., E=0.)

    # define NMDA projections from B
    self.B2B_NMDA = NMDA(self.B, self.B, 'all2all', 0.5, g_E2E_NMDA * w_pos)
    self.B2A_NMDA = NMDA(self.B, self.A, 'all2all', 0.5, g_E2E_NMDA * w_neg)
    self.B2N_NMDA = NMDA(self.B, self.N, 'all2all', 0.5, g_E2E_NMDA)
    self.B2I_NMDA = NMDA(self.B, self.I, 'all2all', 0.5, g_E2I_NMDA)

    # define AMPA projections from A
    self.A2B_AMPA = ExpSyn(self.A, self.B, 'all2all', 0.5, g_E2E_AMPA * w_neg, tau=2., E=0.)
    self.A2A_AMPA = ExpSyn(self.A, self.A, 'all2all', 0.5, g_E2E_AMPA * w_pos, tau=2., E=0.)
    self.A2N_AMPA = ExpSyn(self.A, self.N, 'all2all', 0.5, g_E2E_AMPA, tau=2., E=0.)
    self.A2I_AMPA = ExpSyn(self.A, self.I, 'all2all', 0.5, g_E2I_AMPA, tau=2., E=0.)

    # define NMDA projections from A
    self.A2B_NMDA = NMDA(self.A, self.B, 'all2all', 0.5, g_E2E_NMDA * w_neg)
    self.A2A_NMDA = NMDA(self.A, self.A, 'all2all', 0.5, g_E2E_NMDA * w_pos)
    self.A2N_NMDA = NMDA(self.A, self.N, 'all2all', 0.5, g_E2E_NMDA)
    self.A2I_NMDA = NMDA(self.A, self.I, 'all2all', 0.5, g_E2I_NMDA)

    # define I->E/I conn
    self.I2B = ExpSyn(self.I, self.B, 'all2all', 0.5, g_I2E_GABAa, tau=5., E=-70.)
    self.I2A = ExpSyn(self.I, self.A, 'all2all', 0.5, g_I2E_GABAa, tau=5., E=-70.)
    self.I2N = ExpSyn(self.I, self.N, 'all2all', 0.5, g_I2E_GABAa, tau=5., E=-70.)
    self.I2I = ExpSyn(self.I, self.I, 'all2all', 0.5, g_I2I_GABAa, tau=5., E=-70.)

    # define external projections
    self.noise2B = ExpSyn(self.noise_B, self.B, 'one2one', None, g_ext2E_AMPA, tau=2., E=0.)
    self.noise2A = ExpSyn(self.noise_A, self.A, 'one2one', None, g_ext2E_AMPA, tau=2., E=0.)
    self.noise2N = ExpSyn(self.noise_N, self.N, 'one2one', None, g_ext2E_AMPA, tau=2., E=0.)
    self.noise2I = ExpSyn(self.noise_I, self.I, 'one2one', None, g_ext2I_AMPA, tau=2., E=0.)

In [10]:
def run_a_simulation(scale):
  """Run one trial of the decision-making network at the given ``scale`` and plot it.

  The spikes of groups A and B are recorded at every step and handed, together
  with the two input-frequency traces, to ``Tool.visualize_results``.
  """
  # trial protocol: 100 before the stimuli, 1000 with the stimuli on, 500 after they are removed
  tool = Tool(pre_stimulus_period=100., stimulus_period=1000., delay_period=500.)

  # input frequencies: group A receives the base rate plus the coherence offset, group B minus it
  base_rate = 40.
  coh = 25.6
  offset = base_rate / 100. * coh
  IA_freqs = tool.generate_freqs(base_rate + offset)
  IB_freqs = tool.generate_freqs(base_rate - offset)

  # model
  net = DecisionMakingNet(scale)
  dt = bm.get_dt()

  def step(idx):
    # publish the step index and current time, then feed this step's input frequencies
    bp.share.save(i=idx, t=dt * idx)
    net.IA.freqs.value = IA_freqs[idx]
    net.IB.freqs.value = IB_freqs[idx]
    net.update()
    return {'A.spike': net.A.spike.value, 'B.spike': net.B.spike.value}

  # integrate over the whole trial and plot
  n_steps = int(tool.total_period / dt)
  indices = np.arange(n_steps)
  mon = bm.for_loop(step, indices)
  mon['ts'] = indices * dt
  tool.visualize_results(mon, IA_freqs, IB_freqs)

In [11]:
def simulate_a_trial(scale, platform='cpu', x64=False, monitor=True):
  """Time one jitted trial of the decision-making network.

  Parameters
  ----------
  scale : float
    Network scale forwarded to ``DecisionMakingNet``.
  platform : str
    Forwarded to ``bm.set_platform`` and ``bm.clear_buffer_memory``.
  x64 : bool
    If true, ``bm.enable_x64()`` is called before the network is built.
  monitor : bool
    If true, the spikes of groups A and B are returned from every step.

  Returns
  -------
  dict
    ``'num'`` is the number of neurons in the network. ``'exe_time'`` and
    ``'run_time'`` both hold the wall-clock duration in seconds of the single
    jitted call.

  Notes
  -----
  The jitted function is called exactly once, so the measured time includes
  compilation; ``'exe_time'`` and ``'run_time'`` are therefore the same number.
  ``bm.disable_x64()`` and ``bm.clear_buffer_memory(platform)`` are executed
  on exit even when the run raises (e.g. on a device out-of-memory error), so
  a failed trial does not leak x64 mode or device buffers into the next one.
  """
  bm.set_platform(platform)
  bm.random.seed()
  if x64:
    bm.enable_x64()

  try:
    # simulation tools
    pre_stimulus_period = 100.  # time before the external stimuli are given
    stimulus_period = 1000.  # time within which the external stimuli are given
    delay_period = 500.  # time after the external stimuli are removed
    tool = Tool(pre_stimulus_period, stimulus_period, delay_period)

    # stimulus
    mu0 = 40.
    coherence = 25.6
    IA_freqs = tool.generate_freqs(mu0 + mu0 / 100. * coherence)
    IB_freqs = tool.generate_freqs(mu0 - mu0 / 100. * coherence)

    # network
    net = DecisionMakingNet(scale)

    def run(i):
      bp.share.save(i=i, t=bm.get_dt() * i)
      net.IA.freqs.value = IA_freqs[i]
      net.IB.freqs.value = IB_freqs[i]
      net.update()
      if monitor:
        return {'A.spike': net.A.spike.value, 'B.spike': net.B.spike.value}

    @bm.jit
    def jit_run(indices):
      return bm.for_loop(run, indices)

    n_step = int(tool.total_period / bm.get_dt())
    indices = np.arange(n_step)

    # single timed call; block until the device has finished so the timing is meaningful
    t0 = time.perf_counter()
    jax.block_until_ready(jit_run(indices))
    t1 = time.perf_counter()

    print(f'platform = {platform}, x64 = {x64}, scale = {scale}, '
          f'first run = {t1 - t0} s')

    return {'num': net.num,
            'exe_time': t1 - t0,
            'run_time': t1 - t0}
  finally:
    # always restore global state and release device memory, including on failure
    bm.disable_x64()
    bm.clear_buffer_memory(platform)

In [12]:
def benchmark(devices, platform='cpu', x64=True, scales=(1, 4, 8, 10, 20, 40, 60, 80, 100), n_trials=2):
    """Time ``simulate_a_trial`` for several network scales on a device mesh.

    Parameters
    ----------
    devices :
        Devices handed to ``bm.sharding.device_mesh``.
    platform : str
        Forwarded to ``simulate_a_trial``.
    x64 : bool
        Forwarded to ``simulate_a_trial``.
    scales : iterable
        Network scales to benchmark, in order.
    n_trials : int
        Number of trials run for every scale.

    Returns
    -------
    dict
        Maps each scale to ``{'exetime': [...], 'runtime': [...]}`` with one
        entry per trial. The same information is also printed per scale.
    """
    final_results = dict()
    for scale in scales:
        res = {'exetime': [], 'runtime': []}
        for _ in range(n_trials):
            with bm.sharding.device_mesh(devices, [bm.sharding.NEU_AXIS]):
                r = simulate_a_trial(scale=scale, platform=platform, x64=x64, monitor=True)
            res['exetime'].append(r['exe_time'])
            res['runtime'].append(r['run_time'])
        print(f'Scale = {scale}:')
        pprint(res)
        # keep the timings so the caller can analyse them instead of parsing printed text
        final_results[scale] = res
    return final_results

# 8 devices

In [14]:
# Benchmark on every device returned by jax.devices().
benchmark(
    jax.devices(), 
    platform='tpu', 
    x64=False, 
    scales=(1, 4, 8, 10, 20, 40, 60, 80, 100, 200, 400, 800, 1000, 2000, 4000, 8000, 10000), 
)

platform = tpu, x64 = False, scale = 1, first run = 8.095625162124634 s
platform = tpu, x64 = False, scale = 1, first run = 7.82029128074646 s
Scale = 1:
{'exetime': [8.095625162124634, 7.82029128074646],
 'runtime': [8.095625162124634, 7.82029128074646]}
platform = tpu, x64 = False, scale = 4, first run = 8.117016315460205 s
platform = tpu, x64 = False, scale = 4, first run = 8.010966062545776 s
Scale = 4:
{'exetime': [8.117016315460205, 8.010966062545776],
 'runtime': [8.117016315460205, 8.010966062545776]}
platform = tpu, x64 = False, scale = 8, first run = 8.053727149963379 s
platform = tpu, x64 = False, scale = 8, first run = 8.151676654815674 s
Scale = 8:
{'exetime': [8.053727149963379, 8.151676654815674],
 'runtime': [8.053727149963379, 8.151676654815674]}
platform = tpu, x64 = False, scale = 10, first run = 8.062843561172485 s
platform = tpu, x64 = False, scale = 10, first run = 8.424638509750366 s
Scale = 10:
{'exetime': [8.062843561172485, 8.424638509750366],
 'runtime': [8.0

XlaRuntimeError: RESOURCE_EXHAUSTED: Error loading program: Attempting to reserve 8.96G at the bottom of memory. That was not possible. There are 6.39G free, 0B reserved, and 6.39G reservable.: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).

In [15]:
# 4 devices: benchmark on the first four entries of jax.devices().
benchmark(
    jax.devices()[:4], 
    platform='tpu', 
    x64=False, 
    scales=(1, 4, 8, 10, 20, 40, 60, 80, 100, 200, 400, 800, 1000, 2000, 4000, 8000), 
)

platform = tpu, x64 = False, scale = 1, first run = 7.646655559539795 s
platform = tpu, x64 = False, scale = 1, first run = 8.274417400360107 s
Scale = 1:
{'exetime': [7.646655559539795, 8.274417400360107],
 'runtime': [7.646655559539795, 8.274417400360107]}
platform = tpu, x64 = False, scale = 4, first run = 7.800139665603638 s
platform = tpu, x64 = False, scale = 4, first run = 7.973190546035767 s
Scale = 4:
{'exetime': [7.800139665603638, 7.973190546035767],
 'runtime': [7.800139665603638, 7.973190546035767]}
platform = tpu, x64 = False, scale = 8, first run = 7.961609363555908 s
platform = tpu, x64 = False, scale = 8, first run = 8.504684448242188 s
Scale = 8:
{'exetime': [7.961609363555908, 8.504684448242188],
 'runtime': [7.961609363555908, 8.504684448242188]}
platform = tpu, x64 = False, scale = 10, first run = 8.145019292831421 s
platform = tpu, x64 = False, scale = 10, first run = 8.31432294845581 s
Scale = 10:
{'exetime': [8.145019292831421, 8.31432294845581],
 'runtime': [8.

In [16]:
# 2 devices: benchmark on the first two entries of jax.devices().
benchmark(
    jax.devices()[:2], 
    platform='tpu', 
    x64=False, 
    scales=(1, 4, 8, 10, 20, 40, 60, 80, 100, 200, 400, 800, 1000, 2000, 4000, 8000), 
)

platform = tpu, x64 = False, scale = 1, first run = 6.603653192520142 s
platform = tpu, x64 = False, scale = 1, first run = 6.030699014663696 s
Scale = 1:
{'exetime': [6.603653192520142, 6.030699014663696],
 'runtime': [6.603653192520142, 6.030699014663696]}
platform = tpu, x64 = False, scale = 4, first run = 6.159245014190674 s
platform = tpu, x64 = False, scale = 4, first run = 6.222467660903931 s
Scale = 4:
{'exetime': [6.159245014190674, 6.222467660903931],
 'runtime': [6.159245014190674, 6.222467660903931]}
platform = tpu, x64 = False, scale = 8, first run = 6.183906078338623 s
platform = tpu, x64 = False, scale = 8, first run = 6.308734178543091 s
Scale = 8:
{'exetime': [6.183906078338623, 6.308734178543091],
 'runtime': [6.183906078338623, 6.308734178543091]}
platform = tpu, x64 = False, scale = 10, first run = 6.830740451812744 s
platform = tpu, x64 = False, scale = 10, first run = 6.270754814147949 s
Scale = 10:
{'exetime': [6.830740451812744, 6.270754814147949],
 'runtime': [

XlaRuntimeError: RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 14.30G. That was not possible. There are 706.58M free.; (0x0x0_HBM0): while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).

In [18]:
# 1 device: benchmark on the first entry of jax.devices() only.
benchmark(
    jax.devices()[:1], 
    platform='tpu', 
    x64=False, 
    scales=(1, 4, 8, 10, 20, 40, 60, 80, 100, 200, 400, 800, 1000, 2000, 4000, 8000), 
)

platform = tpu, x64 = False, scale = 1, first run = 4.9059929847717285 s
platform = tpu, x64 = False, scale = 1, first run = 4.815359354019165 s
Scale = 1:
{'exetime': [4.9059929847717285, 4.815359354019165],
 'runtime': [4.9059929847717285, 4.815359354019165]}
platform = tpu, x64 = False, scale = 4, first run = 5.15300726890564 s
platform = tpu, x64 = False, scale = 4, first run = 5.128577709197998 s
Scale = 4:
{'exetime': [5.15300726890564, 5.128577709197998],
 'runtime': [5.15300726890564, 5.128577709197998]}
platform = tpu, x64 = False, scale = 8, first run = 5.963993787765503 s
platform = tpu, x64 = False, scale = 8, first run = 5.169392108917236 s
Scale = 8:
{'exetime': [5.963993787765503, 5.169392108917236],
 'runtime': [5.963993787765503, 5.169392108917236]}
platform = tpu, x64 = False, scale = 10, first run = 5.428404092788696 s
platform = tpu, x64 = False, scale = 10, first run = 5.571969270706177 s
Scale = 10:
{'exetime': [5.428404092788696, 5.571969270706177],
 'runtime': [

XlaRuntimeError: RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 14.30G. That was not possible. There are 706.58M free.; (0x0x0_HBM0)