In [1]:
import sys, os
import jax
import jax.numpy as jnp
import chex
from dog import *
import numpy as np
import random as rnd

### speed test all validation functions

In [2]:
env = env_reset(0, 2, distance=jnp.int32(10),
                enable_circular_board=True,
                enable_jump_in_goal_area=True,
                enable_start_blocking=False)
pins = jnp.array([[38,25,10,0],
                      [6, 45, 44, -1]])
board = set_pins_on_board(env.board, pins)
hands = jnp.ones((2,14), dtype=jnp.int8)
env = env.replace(pins=pins, board=board, hands=hands)

In [3]:
def test_swap():
    val_swap(env)

def test_normal_move():
    i = rnd.randint(1, 13)
    val_action_normal_move(env, i)

def test_seven_move():
    i = rnd.randint(0, len(DISTS_7_4)-1)
    val_action_7(env, DISTS_7_4[i])

In [9]:
%timeit -n 1000 test_swap()


20.7 µs ± 1.49 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [10]:
%timeit -n 1000 test_normal_move()


26.5 µs ± 1.56 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [11]:
%timeit -n 1000 test_seven_move()

127 µs ± 4.81 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [20]:
def test_swap_step():
    pin = rnd.randint(0,3)
    pos = rnd.randint(0, len(env.board)-1)
    step_swap(env, pin, pos)

def test_normal_move_step():
    i = rnd.randint(1, 13)
    pin = rnd.randint(0,3)
    step_normal_move(env, pin, i)

def test_seven_move_step():
    i = rnd.randint(0, len(DISTS_7_4)-1)
    step_hot_7(env, DISTS_7_4[i])

def test_neg_move_step():
    pin = rnd.randint(0,3) 
    step_neg_move(env, pin, -4)

In [21]:
# Warm-up (Kompilierung)
test_swap_step() 

# Messung (nur Ausführung)
%timeit -n 1000 test_swap_step()

23.1 µs ± 611 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [22]:
# Warm-up (Kompilierung)
test_normal_move_step()

%timeit -n 1000 test_normal_move_step()

22.8 µs ± 1.38 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [23]:
# Warm-up (Kompilierung)
test_neg_move_step()

%timeit -n 1000 test_neg_move_step()

20.9 µs ± 493 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [24]:
# Warm-up (Kompilierung)
test_seven_move_step()

%timeit -n 100 test_seven_move_step()

118 µs ± 4.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [18]:
# Für echte Performance-Tests
key = jax.random.PRNGKey(0)

# Option 1: Pre-generierte Arrays mit JAX random indexing
pins = jnp.array([rnd.randint(0, 3) for _ in range(1000)])
moves = jnp.array([rnd.randint(1, 13) for _ in range(1000)])

def test_batch():
    # Teste ALLE auf einmal mit vmap
    return jax.vmap(lambda p, m: step_normal_move(env, p, m))(pins, moves)

%timeit -n 100 test_batch()  # Durchschnitt pro Call = total_time / 1000

927 µs ± 52.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [19]:
# Fixe Inputs für reproduzierbare Messung
pin = jnp.array(2)
move = jnp.array(5)

@jax.jit
def test_single():
    return step_normal_move(env, pin, move)

# Erste Ausführung kompiliert
test_single()

# Jetzt messen
%timeit -n 10000 test_single()

10.8 µs ± 188 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [12]:
# funktionen die die haupt funktion aufrufen
def test_swap_step():
    # pin = rnd.randint(0,3)
    # pos = rnd.randint(0, len(env.board)-1)
    # action = jnp.zeros(6, dtype=jnp.int32)
    # action = action.at[1].set(1)  # Swap action
    # action = action.at[pin+2].set(pos)
    # act_idx = map_move_to_action(env, action)
    act_idx = rnd.randint(364, 556)  # Precomputed swap action indices
    env_step(env, act_idx)

def test_normal_move_step():
    # i = rnd.randint(1, 13)
    # pin = rnd.randint(0,3)
    # action = jnp.zeros(6, dtype=jnp.int32)
    # action = action.at[1].set(1)  # Swap action
    # action = action.at[pin+2].set(i)
    # act_idx = map_move_to_action(env, action)
    act_idx = rnd.randint(676, 724)  # Precomputed normal move action indices
    env_step(env, act_idx)

def test_seven_move_step():
    # i = rnd.randint(0, len(DISTS_7_4)-1)
    # action = jnp.zeros(6, dtype=jnp.int32)
    # action = action.at[2:].set(DISTS_7_4[i])
    # act_idx = map_move_to_action(env, action)
    act_idx = rnd.randint(556, 676)  # Precomputed seven move action indices
    env_step(env, act_idx)

def test_neg_move_step():
    # pin = rnd.randint(0,3) 
    # action = jnp.zeros(6, dtype=jnp.int32)
    # action = action.at[1].set(1)  # Swap action
    # action = action.at[pin+2].set(-4)
    # act_idx = map_move_to_action(env, action)
    act_idx = rnd.randint(724, 728)  # Precomputed negative move action indices
    env_step(env, act_idx)

In [15]:
env_step(env, 0)  # Dummy call to compile env_step

(DOG(board=Array([ 0, -1, -1, -1, -1, -1,  1, -1, -1, -1,  0, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1,  0, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1,  0, -1, -1, -1, -1, -1,  1,  1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1], dtype=int8), current_player=Array(1, dtype=int8), pins=Array([[38, 25, 10,  0],
        [ 6, 45, 44, -1]], dtype=int32), reward=Array(-1, dtype=int8), done=Array(False, dtype=bool), deck=Array([6, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8], dtype=int8), hands=Array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=int8), num_players=Array(2, dtype=int8), start=Array([ 0, 10], dtype=int32), target=Array([39,  9], dtype=int32), goal=Array([[40, 41, 42, 43],
        [44, 45, 46, 47]], dtype=int32), swap_choices=Array([-1, -1, -1, -1], dtype=int8), round_starter=Array(0, dtype=int8), phase=Array(0, dtype=int8), board_size=40, total_board_size=56, rules={'enable_teams': False, 'e

In [13]:
map_move_to_action(env, jnp.zeros(6, dtype=jnp.int32))  # Dummy call to compile map_move_to_action

Array(739, dtype=int32)

In [16]:
# Messung (nur Ausführung)
%timeit -n 100 test_swap_step()

57.1 µs ± 3.59 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [17]:
%timeit -n 1000 test_normal_move_step()

54.8 µs ± 1.17 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [18]:
%timeit -n 1000 test_neg_move_step()

55.8 µs ± 1.47 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [19]:
%timeit -n 100 test_seven_move_step()

58.4 µs ± 4.76 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
