## GPU acceleration with JAX


Based on the JAX quickstart: https://jax.readthedocs.io/en/latest/notebooks/quickstart.html

In [1]:
import jax
import jax.numpy as jnp

import numpy as np
import time

import matplotlib.pyplot as plt

In [2]:
# List the devices JAX can use; for this notebook a CUDA GPU is expected.
devices = jax.devices()
devices

[CudaDevice(id=0)]

In [3]:
# Shell out to nvidia-smi to show the GPU model, driver/CUDA versions,
# current memory usage and utilization.
!nvidia-smi

Sat Jul 19 11:41:35 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.169                Driver Version: 570.169        CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce GTX 1660 ...    Off |   00000000:01:00.0  On |                  N/A |
| N/A   53C    P0             22W /   60W |     123MiB /   6144MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

  pid, fd = os.forkpty()


## Define the function (NumPy)

In [4]:
def myfunc_np(x, alpha=1.67, lmbda=1.05):
  """SELU-style activation, NumPy version (element-wise).

  Computes ``lmbda * x`` where ``x > 0`` and
  ``lmbda * alpha * (exp(x) - 1)`` elsewhere.

  Args:
    x: NumPy array (or scalar) input; the dtype of a float array is preserved.
    alpha: scale of the negative (exponential) branch.
    lmbda: overall output scale.

  Returns:
    ndarray with the activation applied element-wise.
  """
  is_pos = x > 0
  # np.where evaluates BOTH branches for every element, so clamp the
  # exponential's argument to <= 0: large positive inputs can no longer
  # overflow exp() (inf + RuntimeWarning) in the branch that is discarded.
  x_neg = np.where(is_pos, 0.0, x)
  # expm1(t) == exp(t) - 1, but stays accurate for t close to 0.
  return lmbda * np.where(is_pos, x, alpha * np.expm1(x_neg))

## Define the function (JAX)

In [5]:
def myfunc_jnp(x, alpha=1.67, lmbda=1.05):
  """SELU-style activation, JAX version (element-wise, jit/grad friendly).

  Computes ``lmbda * x`` where ``x > 0`` and
  ``lmbda * alpha * (exp(x) - 1)`` elsewhere.

  Args:
    x: JAX or NumPy array (or scalar) input.
    alpha: scale of the negative (exponential) branch.
    lmbda: overall output scale.

  Returns:
    jax.Array with the activation applied element-wise.
  """
  is_pos = x > 0
  # jnp.where evaluates both branches. Without clamping, exp(x) is inf for
  # large positive x; the forward value is still correct, but reverse-mode
  # autodiff multiplies the zero cotangent of the unused branch by inf and
  # the gradient becomes NaN. Feeding the exponential only values <= 0 avoids
  # this (the "double where" pattern).
  x_neg = jnp.where(is_pos, 0.0, x)
  # expm1(t) == exp(t) - 1, but stays accurate for t close to 0.
  return lmbda * jnp.where(is_pos, x, alpha * jnp.expm1(x_neg))

## Measure time (NumPy)

In [6]:
# 5M float32 samples from a standard normal distribution. A seeded Generator
# makes the input reproducible on Restart & Run All, and drawing float32
# directly avoids allocating a float64 temporary and casting it.
rng = np.random.default_rng(seed=0)
x = rng.standard_normal(size=5_000_000, dtype=np.float32)

In [7]:
# Time 10 calls of the NumPy version. NumPy runs eagerly on the CPU, so the
# work is finished when the call returns.
for _ in range(10):
  # perf_counter is monotonic and high-resolution, unlike the wall clock
  # time.time(), which can jump and has coarser resolution on some platforms.
  t_start = time.perf_counter()
  myfunc_np(x)
  elapsed_ms = (time.perf_counter() - t_start) * 1000
  print('[ms] : {:.4f}'.format(elapsed_ms))

[ms] : 46.0935
[ms] : 39.0298
[ms] : 38.7335
[ms] : 38.9185
[ms] : 38.5909
[ms] : 38.8465
[ms] : 39.0699
[ms] : 38.7752
[ms] : 38.7580
[ms] : 38.9898


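## Measure time (JAX)

The function is JIT-compiled and the input is copied to the GPU once, up front. The first timed call also includes tracing and XLA compilation, so it is much slower than the calls that follow.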

In [8]:
# JIT: jax.jit returns a wrapped myfunc_jnp that is traced and compiled with
# XLA on its first call for a given input shape/dtype; later calls with the
# same shape/dtype reuse the compiled executable.
myfunc_jit = jax.jit(myfunc_jnp) # JIT

In [9]:
# Copy the host array to the default device (the GPU here). device_put is
# dispatched asynchronously; blocking until it finishes keeps the
# host-to-device transfer out of the first timed call below.
x_dev = jax.device_put(x).block_until_ready()

In [10]:
# Time 10 calls of the jitted function on the device-resident input.
# JAX dispatches work asynchronously, so block_until_ready() is required for
# the measurement to include the actual device computation. The first
# iteration additionally includes tracing + XLA compilation.
for _ in range(10):
  # perf_counter is monotonic and high-resolution, which matters for
  # sub-millisecond timings like these.
  t_start = time.perf_counter()
  myfunc_jit(x_dev).block_until_ready()
  elapsed_ms = (time.perf_counter() - t_start) * 1000
  print('[ms] : {:.4f}'.format(elapsed_ms))

[ms] : 67.3943
[ms] : 0.2489
[ms] : 0.2549
[ms] : 0.2534
[ms] : 0.2422
[ms] : 0.3369
[ms] : 0.2983
[ms] : 0.2255
[ms] : 0.3812
[ms] : 0.2563
