In [1]:
!nvidia-smi

Tue Oct 28 21:15:40 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   38C    P8              9W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [4]:
import jax
import jax.numpy as jnp

# Report the JAX build, the active backend, and the visible devices so the
# rest of the notebook's output can be tied to a concrete environment.
for label, value in (
    ("JAX version:", jax.__version__),
    ("Backend:", jax.default_backend()),
    ("Devices:", jax.devices()),
):
    print(label, value)

JAX version: 0.7.2
Backend: gpu
Devices: [CudaDevice(id=0)]


In [5]:
from jax import grad, jit

# Softmax temperature and five evenly spaced inputs on [-2, 2] for the demo.
tau = 0.5
x = jnp.linspace(start=-2.0, stop=2.0, num=5)

def tempered_softmax(x, temperature, axis=None):
    """Softmax of ``x / temperature``, computed with a max-shift for stability.

    Args:
        x: Array of scores.
        temperature: Divisor applied to the shifted scores. It is not
            validated here; the max-shift only guards against overflow when
            the temperature is positive.
        axis: Axis (or axes) to normalise over. ``None`` (the default)
            normalises over every element of ``x``, which is the behaviour
            this function had before ``axis`` existed. Pass e.g. ``-1`` to
            get an independent distribution per row of a batch.

    Returns:
        Array with the same shape as ``x`` whose entries sum to 1 along
        ``axis`` (or over the whole array when ``axis`` is ``None``).

    Note:
        ``axis`` is consumed as a Python value. When calling through
        ``jax.jit`` with an explicit ``axis``, mark it static
        (``static_argnames="axis"``).
    """
    # keepdims=True keeps the reduced axes as size-1 so the subtraction and
    # division broadcast back against x for any choice of axis.
    shifted = x - jnp.max(x, axis=axis, keepdims=True)
    weights = jnp.exp(shifted / temperature)
    return weights / jnp.sum(weights, axis=axis, keepdims=True)

# Eager evaluation, kept so the values can be printed below.
softmax_val = tempered_softmax(x, temperature=tau)

# Gradient of the *sum* of the softmax outputs with respect to the temperature.
# The outputs are divided by their own sum, so the total is 1 for any
# temperature and the printed derivative is expected to be ~0 up to
# floating-point error.
dsoftmax_dtau = grad(lambda t: jnp.sum(tempered_softmax(x, t)))(tau)

print("Softmax:", softmax_val)
print("d/dtau sum =", dsoftmax_dtau)

# Wrap the function with jit and call it once; the result is discarded.
compiled_softmax = jit(tempered_softmax)
_ = compiled_softmax(x, tau)

Softmax: [2.9007587e-04 2.1433870e-03 1.5837606e-02 1.1702495e-01 8.6470395e-01]
d/dtau sum = -4.3092886e-08


In [6]:
def make_dataset(key, n_samples=512, scale=1.0):
    """Draw a synthetic two-feature binary classification dataset.

    Features are standard-normal samples multiplied by ``scale``. Labels are
    Bernoulli draws whose success probability is the sigmoid of a fixed linear
    function of the features (weights ``[2.0, -1.0]``, bias ``-0.8``).

    Args:
        key: JAX PRNG key; it is split once, one half for the features and
            one for the label noise.
        n_samples: Number of rows to generate.
        scale: Multiplier applied to the normal feature samples.

    Returns:
        ``(X, y)`` where ``X`` has shape ``(n_samples, 2)`` and ``y`` has
        shape ``(n_samples,)`` with float32 values in ``{0.0, 1.0}``.
    """
    feature_key, label_key = jax.random.split(key, 2)
    features = jax.random.normal(feature_key, shape=(n_samples, 2)) * scale

    # Ground-truth linear model used only to generate the label probabilities.
    weights = jnp.array([2.0, -1.0])
    bias = -0.8
    positive_prob = jax.nn.sigmoid(jnp.matmul(features, weights) + bias)

    labels = jax.random.bernoulli(label_key, p=positive_prob)
    return features, labels.astype(jnp.float32)

# Fixed seed so the generated training set is the same on every run.
key = jax.random.PRNGKey(seed=0)
X_train, y_train = make_dataset(key=key)
print("Shapes:", X_train.shape, y_train.shape)

Shapes: (512, 2) (512,)


In [7]:
import optax

def model(params, X):
    """Return the linear logits ``X @ w + b``.

    Args:
        params: Mapping with entries ``"w"`` (weights) and ``"b"`` (bias).
        X: Feature array that is matrix-multiplied with ``params["w"]``.

    Returns:
        The un-squashed logits; no sigmoid is applied here.
    """
    weights, bias = params["w"], params["b"]
    return jnp.matmul(X, weights) + bias

def loss_fn(params, X, y):
    """Mean sigmoid binary cross-entropy of the model's logits against ``y``.

    Args:
        params: Parameters forwarded to ``model``.
        X: Features forwarded to ``model``.
        y: Targets handed to ``optax.sigmoid_binary_cross_entropy``.

    Returns:
        The per-example losses averaged over all elements.
    """
    per_example = optax.sigmoid_binary_cross_entropy(model(params, X), y)
    return jnp.mean(per_example)

# Start from all-zero parameters. The bias is created as a 0-d jax array
# rather than a Python float so that every leaf of the pytree is already an
# array of the same kind `update` returns; a Python-float leaf is traced as a
# weakly-typed scalar and can cause `update` to be traced a second time once
# the leaf has been replaced by an array.
params = {"w": jnp.zeros((2,)), "b": jnp.zeros(())}

# Adam with a fixed step size; the optimizer state mirrors the params pytree.
optimizer = optax.adam(learning_rate=3e-2)
opt_state = optimizer.init(params)

@jit
def update(params, opt_state, X, y):
    """Apply one step of the module-level ``optimizer`` and report metrics.

    Args:
        params: Parameter pytree accepted by ``model`` / ``loss_fn``.
        opt_state: Optimizer state from ``optimizer.init`` or from a previous
            call to this function.
        X: Features forwarded to ``model`` / ``loss_fn``.
        y: Targets forwarded to ``loss_fn`` and compared against the
            thresholded predictions.

    Returns:
        ``(new_params, new_opt_state, loss, accuracy)``. ``loss`` and
        ``accuracy`` are both evaluated at the parameters that were passed
        in, i.e. before this step's update is applied, so the two numbers
        describe the same model.
    """
    loss, grads = jax.value_and_grad(loss_fn)(params, X, y)

    # Score accuracy before `params` is rebound below so that it is measured
    # at the same parameters as `loss` (previously it used the updated ones).
    probs = jax.nn.sigmoid(model(params, X))
    accuracy = jnp.mean((probs > 0.5) == y)

    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss, accuracy

# Train for 500 steps on the whole training set, logging every 100th step.
# `params` and `opt_state` are rebound each iteration, so re-running this cell
# continues from the current values instead of starting over.
for step in range(1, 501):
    params, opt_state, loss, acc = update(params, opt_state, X_train, y_train)
    if step % 100:
        continue
    print(f"step {step:03d} loss={loss:.4f} acc={acc:.3f}")

step 100 loss=0.4161 acc=0.805
step 200 loss=0.4136 acc=0.805
step 300 loss=0.4136 acc=0.805
step 400 loss=0.4136 acc=0.805
step 500 loss=0.4136 acc=0.805


In [8]:
from functools import partial

def predict_single(params, x):
    """Probability of the positive class for one example.

    Args:
        params: Mapping with entries ``"w"`` (weights) and ``"b"`` (bias).
        x: Feature vector for a single example; it is dotted with
            ``params["w"]``.

    Returns:
        Sigmoid of ``dot(x, w) + b``.
    """
    logit = jnp.dot(x, params["w"]) + params["b"]
    return jax.nn.sigmoid(logit)

# Bind the current params, then map the single-example predictor over the
# leading axis of its remaining argument.
predict_with_params = partial(predict_single, params)
batched_predict = jax.vmap(predict_with_params, in_axes=0)
probs = batched_predict(X_train)

# Each row is (true label, predicted probability); these are raw pairs, not
# points on a ROC curve.
roc_points = jnp.stack((y_train, probs), axis=-1)
print("First five probability pairs:\n", roc_points[:5])

First five probability pairs:
 [[1.         0.9177734 ]
 [1.         0.32183945]
 [0.         0.04654315]
 [1.         0.88042223]
 [1.         0.992222  ]]


In [9]:
@jax.jit
def evaluate(params, X, y):
    """Fraction of examples whose thresholded prediction equals ``y``.

    Args:
        params: Parameters forwarded to ``model``.
        X: Features forwarded to ``model``.
        y: Targets compared element-wise against the float32 0/1 predictions.

    Returns:
        Mean of the element-wise equality between predictions and ``y``.
    """
    positive_prob = jax.nn.sigmoid(model(params, X))
    # A probability strictly above 0.5 counts as the positive class.
    hard_labels = (positive_prob > 0.5).astype(jnp.float32)
    return jnp.mean(hard_labels == y)

# `evaluate` returns a concrete scalar here, outside any traced function, so a
# plain print with an f-string is sufficient and matches the training-loop
# logging; jax.debug.print is intended for values inside transformed code.
acc = evaluate(params, X_train, y_train)
print(f"Training accuracy: {acc:.3f}")

Training accuracy: 0.805
