[View in Colaboratory](https://colab.research.google.com/github/lukasheinrich/pyhf-benchmarks/blob/master/colab/TPU_standalone.ipynb)

# TPU pyhf interpolation

In [0]:
import os
import pprint
import tensorflow as tf

# The TPU worker endpoint is read from the COLAB_TPU_ADDR environment
# variable; when it is absent we only report the problem and define nothing.
if 'COLAB_TPU_ADDR' in os.environ:
  tpu_address = 'grpc://{}'.format(os.environ['COLAB_TPU_ADDR'])
  print('TPU address is', tpu_address)

  # Open a short-lived session purely to ask the worker which devices it
  # exposes; the session is closed again as soon as the `with` block ends.
  with tf.Session(tpu_address) as session:
    devices = session.list_devices()

  print('TPU devices:')
  pprint.pprint(devices)
else:
  print('ERROR: Not connected to a TPU runtime; please see the first cell in this notebook for instructions!')

If the cell above reports an error, make sure that you have enabled TPU support in the notebook settings. (Edit menu → Notebook settings)

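Now, let's benchmark the pyhf interpolation code on the TPU: for each bin count N from 500 to 7000 (in steps of 500) we build the interpolation graph and time 10 evaluations on random inputs.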

In [26]:
import numpy as np
import timeit

def setup(N,float_t):
    """Build the TPU-rewritten interpolation graph and its input placeholders.

    Args:
        N: size of the last axis of the histogram placeholder, whose shape
            is ``(100, 100, 3, N)``.
        float_t: TensorFlow dtype used for both placeholders and for every
            constant and cast created inside the computation.

    Returns:
        A 3-tuple ``(ops, [a, h], [a_shape, h_shape])``:
        the result of ``tf.contrib.tpu.rewrite`` applied to the interpolation
        function, the two placeholders (alphas first, histograms second) and
        their static shapes in the same order.
    """
    def _ones(shape):
        return tf.ones(shape, dtype = float_t)

    def _zeros(shape):
        return tf.zeros(shape, dtype = float_t)

    def _select(mask, when_set, when_unset):
        # Arithmetic selection: mask * when_set + (1 - mask) * when_unset,
        # with the mask first cast to float_t.
        weight = tf.cast(mask, float_t)
        complement = tf.cast(1-weight, float_t)
        return weight * when_set + complement * when_unset

    def _hfinterp_code1(histogramssets, alphasets):
        # Entry 1 of the third axis is the common denominator; entries 2 and
        # 0 provide the numerators of the "up" and "down" ratios.
        ratio_up = tf.divide(histogramssets[:,:,2], histogramssets[:,:,1])
        ratio_dn = tf.divide(histogramssets[:,:,0], histogramssets[:,:,1])

        # 1 where alpha > 0, 0 elsewhere (expressed in float_t).
        alpha_is_positive = _select(alphasets > 0, _ones(alphasets.shape), _zeros(alphasets.shape))

        # Contracting against all-ones tensors broadcasts every operand to
        # the common 'shab' layout.
        layout = 'sa,shb->shab'
        bases_up = tf.einsum(layout, _ones(alphasets.shape), ratio_up)
        bases_dn = tf.einsum(layout, _ones(alphasets.shape), ratio_dn)
        exponents = tf.einsum(layout, tf.abs(alphasets), _ones(ratio_up.shape))
        masks = tf.einsum(layout, alpha_is_positive, _ones(ratio_up.shape))

        # Use the "up" base where alpha > 0 and the "down" base otherwise,
        # then raise it to |alpha|.
        return tf.pow(_select(masks, bases_up, bases_dn), exponents)

    a_shape = (100,1)
    h_shape = (100,100,3,N)
    a = tf.placeholder(float_t, a_shape)
    h = tf.placeholder(float_t, h_shape)

    # The interpolation function takes (histograms, alphas), hence [h, a].
    ops = tf.contrib.tpu.rewrite(_hfinterp_code1, [h,a])
    return ops, [a,h], [a_shape,h_shape]

def run_it(ops,args,shapes):
    """Evaluate ``ops`` once on freshly drawn random inputs.

    ``args[0]`` and ``args[1]`` are each fed an array drawn from
    ``np.random.uniform(-1, 1)`` with shape ``shapes[0]`` and ``shapes[1]``
    respectively. The run is performed by the module-level ``session`` (it is
    not a parameter), so that name must refer to an open session.

    Returns:
        Whatever ``session.run(ops, ...)`` returns.
    """
    first_input = np.random.uniform(-1,1, size = shapes[0])
    second_input = np.random.uniform(-1,1, size = shapes[1])
    feed = {args[0]: first_input, args[1]: second_input}
    return session.run(ops, feed)

# Benchmark the interpolation for N = 500, 1000, ..., 7000 and collect
# (N, total seconds for 10 calls of run_it) pairs in `results`.
results = []
for n in range(500, 7001, 500):
    # Binding the session with `with` guarantees it is closed even when the
    # TPU shutdown in the `finally` clause raises. `run_it` picks up this
    # module-level `session` name.
    with tf.Session(tpu_address) as session:
        try:
            print('Initializing...')
            session.run(tf.contrib.tpu.initialize_system())
            print('Running ops')
            # Graph construction lives in timeit's `setup` string and is
            # therefore not timed; the random input generation done inside
            # run_it is part of every timed call.
            # NOTE(review): the first of the 10 calls presumably also pays
            # for compilation of the rewritten graph - confirm before
            # comparing timings across N.
            exec_time = timeit.timeit('run_it(o,a,s)', number=10, setup="from __main__ import run_it, setup; import tensorflow as tf; o,a,s = setup({},tf.bfloat16)".format(n))
            results.append((n,exec_time))
            print('N: {} time: {}'.format(n,exec_time))
        finally:
            # For now, TPU sessions must be shutdown separately from
            # closing the session.
            session.run(tf.contrib.tpu.shutdown_system())

Initializing...
Running ops
N: 500 time: 20.807447327999398
Initializing...
Running ops
N: 1000 time: 17.89259686100013
Initializing...
Running ops
N: 1500 time: 48.874835416998394
Initializing...
Running ops
N: 2000 time: 25.663443012999778
Initializing...
Running ops
N: 2500 time: 73.5554243229999
Initializing...
Running ops
N: 3000 time: 47.30244636699899
Initializing...
Running ops
N: 3500 time: 97.17557937099991
Initializing...
Running ops
N: 4000 time: 43.82593866299976
Initializing...
Running ops
N: 4500 time: 121.1753972710012
Initializing...
Running ops
N: 5000 time: 74.95264433599914
Initializing...
Running ops
N: 5500 time: 145.3289784690005
Initializing...
Running ops
N: 6000 time: 70.05059990200061
Initializing...
Running ops
N: 6500 time: 168.16438966199894
Initializing...
Running ops
N: 7000 time: 103.03043713499937


In [27]:
# Display the collected (N, total seconds for the 10 timed calls) pairs.
results

[(500, 20.807447327999398),
 (1000, 17.89259686100013),
 (1500, 48.874835416998394),
 (2000, 25.663443012999778),
 (2500, 73.5554243229999),
 (3000, 47.30244636699899),
 (3500, 97.17557937099991),
 (4000, 43.82593866299976),
 (4500, 121.1753972710012),
 (5000, 74.95264433599914),
 (5500, 145.3289784690005),
 (6000, 70.05059990200061),
 (6500, 168.16438966199894),
 (7000, 103.03043713499937)]