In [28]:
from numba import njit, prange
import numpy as np
import time

x = np.arange(100).reshape(10, 10)

@njit()
def go_fast_n(a): # Function is compiled and runs in machine code
    trace = 0.0
    for i in range(a.shape[0]):
        if i % 2 == 0:
            trace += np.tanh(a[i, i])
        else:
            trace += 1 / max(1, np.power(2, a[i, i]))
    return a + trace

def go_fast_i(a): # Function is compiled and runs in machine code
    trace = 0.0
    for i in range(a.shape[0]):
        if i % 2 == 0:
            trace += np.tanh(a[i, i])
        else:
            trace += 1 / max(1, np.power(2, a[i, i]))
    return a + trace

# trigger jit
print(go_fast_n(x))
print(go_fast_i(x))

[[  6.00048828   7.00048828   8.00048828   9.00048828  10.00048828
   11.00048828  12.00048828  13.00048828  14.00048828  15.00048828]
 [ 16.00048828  17.00048828  18.00048828  19.00048828  20.00048828
   21.00048828  22.00048828  23.00048828  24.00048828  25.00048828]
 [ 26.00048828  27.00048828  28.00048828  29.00048828  30.00048828
   31.00048828  32.00048828  33.00048828  34.00048828  35.00048828]
 [ 36.00048828  37.00048828  38.00048828  39.00048828  40.00048828
   41.00048828  42.00048828  43.00048828  44.00048828  45.00048828]
 [ 46.00048828  47.00048828  48.00048828  49.00048828  50.00048828
   51.00048828  52.00048828  53.00048828  54.00048828  55.00048828]
 [ 56.00048828  57.00048828  58.00048828  59.00048828  60.00048828
   61.00048828  62.00048828  63.00048828  64.00048828  65.00048828]
 [ 66.00048828  67.00048828  68.00048828  69.00048828  70.00048828
   71.00048828  72.00048828  73.00048828  74.00048828  75.00048828]
 [ 76.00048828  77.00048828  78.00048828  79.00048828  

In [29]:
%timeit go_fast_n(x)

1.98 µs ± 21.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [30]:
%timeit go_fast_i(x)

40.7 µs ± 3.65 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [84]:
def softmax_sample(counts, actions, temp):
    if temp == 0:
        i = np.argmax(counts)
        return counts[i], actions[i]
     
    p = 1/temp
    powers = counts**p
    total_sum = np.sum(powers)
    probabilities = powers / total_sum
    i = np.argmax(np.random.multinomial(1, probabilities, 1))
    return probabilities[i], actions[i]        

@njit()
def softmax_sample_c(counts, actions, temp):
    if temp == 0:
        i = np.argmax(counts)
        return counts[i], actions[i]
     
    p = 1/temp
    powers = counts**p
    total_sum = np.sum(powers)
    probabilities = powers / total_sum
    i = np.argmax(np.random.multinomial(1, probabilities, 1))
    return probabilities[i], actions[i]   

In [89]:
counts = np.arange(10)*100
actions = np.arange(10)

[softmax_sample_c(counts, actions, 1) for _ in range(10)]

[(0.17777777777777778, 8),
 (0.08888888888888889, 4),
 (0.044444444444444446, 2),
 (0.08888888888888889, 4),
 (0.1111111111111111, 5),
 (0.13333333333333333, 6),
 (0.17777777777777778, 8),
 (0.17777777777777778, 8),
 (0.1111111111111111, 5),
 (0.15555555555555556, 7)]

In [90]:
%timeit softmax_sample(counts, actions, 1)

48 µs ± 9.06 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [91]:
%timeit softmax_sample_c(counts, actions, 1)

1.58 µs ± 5.7 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)


In [92]:
%timeit softmax_sample_c(counts, actions, 0.5)

12.9 µs ± 4.04 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [93]:
%timeit softmax_sample_c(counts, actions, 0)

856 ns ± 72.4 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
