In [1]:
import sys, os
import jax
import jax.numpy as jnp
import chex
from dog import *
import numpy as np
import random as rnd

### speed test all validation functions

In [2]:
env = env_reset(0, distance=jnp.int32(10),
                enable_circular_board=True,
                enable_jump_in_goal_area=True,
                enable_start_blocking=False)
env.pins = jnp.array([[38,40,41,0],
                      [6, 45, 44, -1]])
env.board = set_pins_on_board(env.board, env.pins)
env.current_player = 0

In [3]:
def test_swap():
    val_swap(env)

def test_normal_move():
    i = rnd.randint(1, 13)
    val_action_normal_move(env, i)

def test_seven_move():
    i = rnd.randint(0, len(DISTS_7_4)-1)
    val_action_7(env, DISTS_7_4[i])

In [9]:
%timeit -n 1000 test_swap()


20.7 µs ± 1.49 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [10]:
%timeit -n 1000 test_normal_move()


26.5 µs ± 1.56 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [11]:
%timeit -n 1000 test_seven_move()

127 µs ± 4.81 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [3]:
def test_swap_step():
    pin = rnd.randint(0,3)
    pos = rnd.randint(0, len(env.board)-1)
    step_swap(env, pin, pos)

def test_normal_move_step():
    i = rnd.randint(1, 13)
    pin = rnd.randint(0,3)
    step_normal_move(env, pin, i)

@jax.jit
def test_seven_move_step():
    i = rnd.randint(0, len(DISTS_7_4)-1)
    step_hot_7(env, DISTS_7_4[i])

def test_neg_move_step():
    pin = rnd.randint(0,3)
    step_neg_move(env, pin, -4)

In [5]:
%timeit -n 1000 test_swap_step()

31.4 µs ± 1.26 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [7]:
%timeit -n 1000 test_normal_move_step()

32.7 µs ± 793 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [9]:
%timeit -n 1000 test_neg_move_step()

31.8 µs ± 1.66 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [9]:
%timeit -n 100 test_seven_move_step()

136 ms ± 46.4 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [24]:
# Für echte Performance-Tests
key = jax.random.PRNGKey(0)

# Option 1: Pre-generierte Arrays mit JAX random indexing
pins = jnp.array([rnd.randint(0, 3) for _ in range(1000)])
moves = jnp.array([rnd.randint(1, 13) for _ in range(1000)])

def test_batch():
    # Teste ALLE auf einmal mit vmap
    return jax.vmap(lambda p, m: step_normal_move(env, p, m))(pins, moves)

%timeit -n 100 test_batch()  # Durchschnitt pro Call = total_time / 1000

The slowest run took 6.52 times longer than the fastest. This could mean that an intermediate result is being cached.
2.01 ms ± 2.14 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [25]:
# Fixe Inputs für reproduzierbare Messung
pin = jnp.array(2)
move = jnp.array(5)

@jax.jit
def test_single():
    return step_normal_move(env, pin, move)

# Erste Ausführung kompiliert
test_single()

# Jetzt messen
%timeit -n 10000 test_single()

6.83 µs ± 380 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
