# Conectar cuaderno a la GPU.

In [None]:
import numpy as np
import jax.numpy as jnp
import jax
import time
import numba
from numba import cuda
import math

In [None]:
# NumPy
print("NumPy:")
arr_np = np.random.rand(10_000_000)
np_time = %timeit -o np.square(arr_np)

# JAX
print("\nJAX:")
arr_jax = jnp.array(arr_np)
square_fn = jit(jnp.square)
jax_time = %timeit -o square_fn(arr_jax).block_until_ready()

# Numba
print("\nNumba:")

@numba_jit
def square_numba(arr):
    return arr * arr
numba_time = %timeit -o square_numba(arr_np)



# Report the `.best` value of each %timeit result, one line per library.
for _label, _timing in (
    ("\nNumPy Square Time:", np_time),
    ("JAX Square Time:", jax_time),
    ("Numba Square Time:", numba_time),
):
    print(_label, _timing.best)

NumPy:
21 ms ± 173 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

JAX:
512 µs ± 3.59 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

Numba:
57.1 ms ± 1.43 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

NumPy Square Time: 0.020709080799997538
JAX Square Time: 0.0005079865359998621
Numba Square Time: 0.055021133000082045


In [None]:
# NumPy baseline: multiply a 1000x1000 random matrix by itself.
print("NumPy:")
mat_np = np.random.rand(1000, 1000)
# Untimed product, computed once before the benchmark and kept for reference.
mat_np_result = np.matmul(mat_np, mat_np)
np_time = %timeit -o np.matmul(mat_np, mat_np)

# JAX: the same product on a JAX array built from the NumPy matrix.
print("\nJAX:")
mat_jax = jnp.array(mat_np)
# Untimed product, run once ahead of the timed loop.
mat_jax_result = jnp.matmul(mat_jax, mat_jax)
# block_until_ready() makes every timed call wait for the computed result.
jax_time = %timeit -o jnp.matmul(mat_jax, mat_jax).block_until_ready()

# Numba CUDA kernel
@cuda.jit
def matmul(a, b, out):
    """Write the matrix product ``a @ b`` into ``out``.

    Each thread computes one output element ``out[x, y]`` as the dot
    product of row ``x`` of ``a`` and column ``y`` of ``b``.

    Args:
        a: Left operand, indexed as ``a[row, k]``.
        b: Right operand, indexed as ``b[k, col]``.
        out: Output array, written in place.
    """
    x, y = cuda.grid(2)
    # The launch grid may be larger than the output, so threads that fall
    # outside it must do nothing.
    if x < out.shape[0] and y < out.shape[1]:
        # Accumulate over the shared dimension: storing the single product
        # a[x, y] * b[y, x] is not a matrix multiplication.
        acc = 0.0
        for k in range(a.shape[1]):
            acc += a[x, k] * b[k, y]
        out[x, y] = acc

TPB = (16,16)
BPG = (math.ceil(mat_np.shape[0]/TPB[0]),
       math.ceil(mat_np.shape[0]/TPB[0]))

out = cuda.device_array((1000,1000))
print("\nNumba:")

numba_time = %timeit -o matmul[BPG, TPB](mat_np, mat_np, out)

# Report the `.best` value of each %timeit result, one line per library.
for _label, _timing in (
    ("\nNumPy Matrix Multiplication Time:", np_time),
    ("JAX Matrix Multiplication Time:", jax_time),
    ("Numba Matrix Multiplication Time:", numba_time),
):
    print(_label, _timing.best)


NumPy:
57.6 ms ± 5.29 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

JAX:
600 µs ± 7.33 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

Numba:




8.19 ms ± 350 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

NumPy Matrix Multiplication Time: 0.053225178299999246
JAX Matrix Multiplication Time: 0.0005902321309999934
Numba Matrix Multiplication Time: 0.008006213300000127


In [None]:
# Amount added to the two boundary slabs of the grid by `stencil`.
MAX = 255.0


@jax.jit
def stencil_iter(grid):
    """Apply one 7-point averaging step to the interior of a 3-D grid.

    Every interior cell is replaced by the mean of itself and its six
    axis-aligned neighbours; cells on the outer faces are left unchanged.

    Args:
        grid: 3-D JAX array.

    Returns:
        A new array of the same shape with the interior updated.
    """
    core = slice(1, -1)      # interior cells along an axis
    before = slice(0, -2)    # interior shifted one cell towards index 0
    after = slice(2, None)   # interior shifted one cell towards the end

    # Sum the centre and its six neighbours, axis by axis.
    total = grid[core, core, core] + grid[before, core, core]
    total = total + grid[after, core, core]
    total = total + grid[core, before, core]
    total = total + grid[core, after, core]
    total = total + grid[core, core, before]
    total = total + grid[core, core, after]

    return grid.at[core, core, core].set(total / 7.0)

def stencil(grid, max_iter):
    """Run `max_iter` stencil steps after raising two boundary slabs.

    Args:
        grid: 3-D JAX array holding the initial state.
        max_iter: Number of times `stencil_iter` is applied.

    Returns:
        The array obtained after the boundary update and all iterations.
    """
    # Add MAX to the first and the last slab along axis 1.
    state = grid.at[:, 0, :].add(MAX).at[:, -1, :].add(MAX)

    # Repeatedly average the interior, feeding each result into the next step.
    for _step in range(max_iter):
        state = stencil_iter(state)

    return state


In [None]:
    # Show the default device, then build a size^3 grid of zeros on it.
    print(jax.devices()[0])
    size = 250
    max_iter = 1000
    arr = jnp.zeros(shape=(size, size, size))
    gpu_arr = jax.device_put(arr)
    # NOTE(review): `device_buffer` looks deprecated in recent JAX releases
    # (`gpu_arr.devices()` appears to be the replacement) - confirm against
    # the installed JAX version.
    print(gpu_arr.device_buffer.device())
    # Warm-up: compile stencil_iter for this shape so compilation time is
    # not counted in the measurement.
    stencil_iter(gpu_arr).block_until_ready()
    a = time.perf_counter()
    # JAX dispatches work asynchronously: without block_until_ready() the
    # timer could stop before the computation has actually finished.
    stencil(gpu_arr, max_iter).block_until_ready()
    b = time.perf_counter()
    print(b-a)

gpu:0
gpu:0
6.026596316999985
