# 00 - Environment Sanity Check

Verify that JAX, Flax, Transformers, and our `tunix_hack` package are installed and working correctly.

In [2]:
# Confirm JAX imports cleanly, then report its version, the devices it can
# see, and the backend it will use by default.
import jax
import jax.numpy as jnp

jax_version = jax.__version__
print(f"JAX version: {jax_version}")

visible_devices = jax.devices()
print(f"Devices: {visible_devices}")

active_backend = jax.default_backend()
print(f"Default backend: {active_backend}")

JAX version: 0.8.1
Devices: [CudaDevice(id=0)]
Default backend: gpu


W1130 11:44:57.123799   21008 cuda_executor.cc:1802] GPU interconnect information not available: INTERNAL: NVML doesn't support extracting fabric info or NVLink is not used by the device.
W1130 11:44:57.125975   20876 cuda_executor.cc:1802] GPU interconnect information not available: INTERNAL: NVML doesn't support extracting fabric info or NVLink is not used by the device.


In [3]:
# Quick GPU test: multiply two all-ones matrices and verify the result,
# instead of relying on a reader to eyeball the printout.
x = jnp.ones((1000, 1000))
y = jnp.dot(x, x)
print(f"Matrix multiplication result shape: {y.shape}")
print(f"Sum: {y.sum()}")

# Each entry of ones(1000, 1000) @ ones(1000, 1000) is a sum of 1000 ones,
# so a working backend must return a (1000, 1000) array filled with 1000.0.
assert y.shape == (1000, 1000), f"unexpected result shape: {y.shape}"
assert jnp.allclose(y, 1000.0), "matmul returned wrong values; check the JAX backend"

Matrix multiplication result shape: (1000, 1000)
Sum: 1000000000.0


In [4]:
# Confirm Flax is importable and report the installed release
import flax

flax_version = flax.__version__
print(f"Flax version: {flax_version}")

Flax version: 0.12.1


In [5]:
# Confirm transformers is importable and report the installed release
import transformers

transformers_version = transformers.__version__
print(f"Transformers version: {transformers_version}")

Transformers version: 4.57.3


In [None]:
# Confirm the project's own tunix_hack package is importable, then run its
# XML helpers on a small tagged sample.
from tunix_hack.utils.xml_parsing import extract_tag, has_valid_format
# creative_reward is not called in this cell; importing it still checks that
# the name resolves.
from tunix_hack.rewards import math_reward, creative_reward

# Sample completion carrying one <reasoning> tag and one <answer> tag
test_output = "<reasoning>Step 1: Add 2+2=4</reasoning><answer>4</answer>"

reasoning_text = extract_tag(test_output, 'reasoning')
print(f"Reasoning: {reasoning_text}")

answer_text = extract_tag(test_output, 'answer')
print(f"Answer: {answer_text}")

format_ok = has_valid_format(test_output)
print(f"Valid format: {format_ok}")

In [7]:
# Test math reward: a completion whose <answer> matches the target must score
# strictly higher than one whose <answer> does not.
output = "<reasoning>To find 2+2, I add the numbers together. 2+2=4.</reasoning><answer>4</answer>"
reward = math_reward(output, "4")
print(f"Math reward (correct): {reward}")

wrong_output = "<reasoning>Let me think...</reasoning><answer>5</answer>"
wrong_reward = math_reward(wrong_output, "4")
print(f"Math reward (wrong): {wrong_reward}")

# Fail loudly rather than leaving the comparison to the reader: without this
# check a broken reward function would still let the notebook report success.
assert reward > wrong_reward, (
    f"math_reward did not rank the correct answer above the wrong one: "
    f"{reward} vs {wrong_reward}"
)

Math reward (correct): 1.0
Math reward (wrong): 0.0


In [8]:
# Final marker cell: announce that the notebook ran through to the end
completion_message = "Environment sanity check passed!"
print(completion_message)

Environment sanity check passed!
