In [None]:
import os
# Pass XLA the flags that dump HLO as text into the directory named below.
# The variable is assigned ahead of the TensorFlow import — presumably so the
# flags are already in the environment when TensorFlow/XLA initializes;
# TODO confirm.
# NOTE(review): the dump directory is a hard-coded absolute path under one
# user's home directory; it has to be adjusted on any other machine.
os.environ['XLA_FLAGS'] = '--xla_dump_hlo_as_text --xla_dump_to=/home/karl/tmp/hlo'
import tensorflow as tf
# Switch on the JIT setting of TensorFlow's optimizer configuration.
tf.config.optimizer.set_jit(True)

In [None]:
import itertools
import time
from contextlib import contextmanager
from datetime import datetime
from timeit import timeit

In [None]:
def prepare(dtype):
    """Create the two random operands used by the benchmark.

    Args:
        dtype: TensorFlow dtype the operands are cast to.

    Returns:
        A pair ``(a, b)``: ``a`` is built from a (1000, 1000) uniform sample
        multiplied by 100, ``b`` from a (1000,) uniform sample; both are
        floored and then cast to ``dtype``.
    """
    def to_dtype(values):
        # Floor first so the cast receives whole numbers.
        return tf.cast(tf.floor(values), dtype)

    matrix = to_dtype(tf.random.uniform((1000, 1000)) * 100)
    # NOTE(review): unlike the matrix, this sample is not scaled before
    # flooring. If tf.random.uniform draws from [0, 1) by default, every
    # entry of the vector ends up 0 — confirm whether a factor was intended.
    vector = to_dtype(tf.random.uniform((1000,)))
    return matrix, vector


def py_func(a, b):
    """Accumulate ``d * a`` for every element ``d`` of ``b``.

    Args:
        a: Tensor that is scaled and summed.
        b: Iterable of multipliers, visited in order.

    Returns:
        The running sum, started from ``tf.zeros_like(a)``.
    """
    total = tf.zeros_like(a)
    for factor in b:
        total = total + factor * a
    return total


@tf.function
def tf_func(a, b):
    """Run ``py_func`` on ``a`` and ``b`` through a ``tf.function`` wrapper."""
    return py_func(a, b)

In [None]:
@contextmanager
def benchmark(key, stats=None):
    """Time a named group of trials and print summary statistics.

    Prints ``key`` as a header, stores a fresh list under ``stats[key]`` and
    yields a ``trial`` factory. Every ``with trial():`` block is timed; its
    duration in seconds is appended to that list and printed. When the outer
    ``with`` block finishes without raising, the mean, sample standard
    deviation, minimum and maximum of the recorded durations are printed.

    Args:
        key: Label of the group; printed and used as the key in ``stats``.
        stats: Dict that receives the list of durations. When omitted, a new
            dict is created for this call, so separate calls never share
            state through a default argument.

    Yields:
        A zero-argument callable returning a context manager that times its
        body and yields the list of durations recorded so far.
    """
    if stats is None:
        stats = {}

    durations = []

    @contextmanager
    def trial():
        # perf_counter is monotonic, so system clock adjustments cannot
        # distort (or make negative) a measured duration.
        start = time.perf_counter()
        yield durations
        durations.append(time.perf_counter() - start)
        print('  trial {}: {:.6f}'.format(len(durations), durations[-1]))

    print(f'{key}:')
    stats[key] = durations
    yield trial
    if not durations:
        print('  (no trials)')
    else:
        count = len(durations)
        mean = sum(durations) / count
        if count > 1:
            # Sample standard deviation: divide by n - 1 (Bessel's correction).
            variance = sum((x - mean) ** 2 for x in durations) / (count - 1)
        else:
            # A single measurement has no spread to estimate.
            variance = 0.0
        deviation = variance ** 0.5
        print('  mean: {:.6f}'.format(mean))
        print('  dev: {:.6f}'.format(deviation))
        print('  min: {:.6f}'.format(min(durations)))
        print('  max: {:.6f}'.format(max(durations)))


def do_benchmark(func, dtype, iterations=30):
    """Benchmark ``func`` on operands of ``dtype`` and record the timings.

    The operands come from ``prepare(dtype)``; the timings are stored in the
    module-level ``stats`` dict under the key ``'<func name> <dtype name>'``.

    Args:
        func: Callable invoked as ``func(a, b)``; must expose ``__name__``.
        dtype: Dtype handed to ``prepare``; must expose ``name``.
        iterations: Upper bound on the number of timed trials.
    """
    lhs, rhs = prepare(dtype)
    label = f'{func.__name__} {dtype.name}'
    with benchmark(label, stats) as trial:
        for index in range(iterations):
            with trial() as timings:
                func(lhs, rhs)
            # Stop early once a trial past the first six exceeds 3 seconds.
            latest_is_slow = timings[-1] > 3
            if latest_is_slow and index > 5:
                break


# Module-level results store: do_benchmark passes this dict to benchmark(),
# which files each run's list of trial durations under that run's key.
stats = {}

In [None]:
# Element types to benchmark: real floats, complex floats, then integers.
dtypes = [
    tf.float32, tf.float64,
    tf.complex64, tf.complex128,
    tf.int8, tf.int16, tf.int32,
    tf.uint8,
]
# The plain Python function first, then its tf.function-decorated wrapper.
funcs = [py_func, tf_func]
# Every dtype is run against every implementation, dtype-major.
for dtype in dtypes:
    for func in funcs:
        do_benchmark(func, dtype)

In [None]:
import json
# Write the collected timings as indented JSON with sorted keys.
with open('tf-dtype-benchmark-xla.json', 'w') as f:
    json.dump(stats, f, indent=2, sort_keys=True)