### Getting Started with JAX

This notebook was authored by [DRC Lab](http://www.dulun.com/)

In [None]:
!nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Tue_Aug_15_22:02:13_PDT_2023
Cuda compilation tools, release 12.2, V12.2.140
Build cuda_12.2.r12.2/compiler.33191640_0


In [None]:
!nvidia-smi

Thu Jan 18 20:22:24 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   49C    P0              28W /  70W |  11449MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [None]:
import jax

In [None]:
jax.devices()

[cuda(id=0)]

In [1]:
from jax import grad
import jax.numpy as jnp

def tanh(x):  # Define a function
  y = jnp.exp(-2.0 * x)
  return (1.0 - y) / (1.0 + y)

grad_tanh = grad(tanh)  # Obtain its gradient function
print(grad_tanh(1.0))   # Evaluate it at x = 1.0

0.4199743
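As a quick sanity check (a sketch added for illustration, not part of the original notebook), the autodiff result can be compared with the analytic derivative d/dx tanh(x) = 1 - tanh(x)**2, using `jnp.tanh` as the reference. Because `jax.grad` here produces a function of a single scalar, `jax.vmap` is used to evaluate it at several points at once. The sample points in `xs` are arbitrary.

```python
import jax
import jax.numpy as jnp

def tanh(x):
    # Same hand-written tanh as in the cell above
    y = jnp.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)

grad_tanh = jax.grad(tanh)

# Compare the autodiff gradient with the analytic derivative 1 - tanh(x)**2
x0 = 1.0
autodiff_value = grad_tanh(x0)
analytic_value = 1.0 - jnp.tanh(x0) ** 2
print(autodiff_value, analytic_value)

# grad_tanh maps a scalar to a scalar; jax.vmap vectorizes it over an array
xs = jnp.linspace(-2.0, 2.0, 5)   # arbitrary sample points
batched = jax.vmap(grad_tanh)(xs)
print(batched)
print(jnp.allclose(batched, 1.0 - jnp.tanh(xs) ** 2, atol=1e-5))
```

Composing `vmap` with `grad` this way avoids writing a Python loop over the sample points.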




---



In [4]:
x = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)
x
x.shape

(3,)

In [7]:
x.at[1].set(4)

Array([1., 4., 3.], dtype=float32)

JAX's array API deliberately mirrors NumPy's. This is because we may want to run some code on a vector processor such as a GPU or TPU using JAX, while running other code on a CPU with NumPy.
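A minimal sketch of what the shared API looks like in practice: the same expression written against `numpy` and `jax.numpy`, plus conversion in both directions with `np.asarray` and `jnp.asarray`. The array values are arbitrary.

```python
import jax
import jax.numpy as jnp
import numpy as np

# The same computation written against both APIs
np_x = np.array([1.0, 2.0, 3.0], dtype=np.float32)
jax_x = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)

np_result = np.sum(np_x ** 2)     # runs on the CPU via NumPy
jax_result = jnp.sum(jax_x ** 2)  # runs on JAX's default device (GPU/TPU if available)

# Moving data between the two libraries
back_to_numpy = np.asarray(jax_x)  # JAX array -> NumPy ndarray (host memory)
to_jax = jnp.asarray(np_x)         # NumPy ndarray -> JAX array (default device)

print(np_result, jax_result)
print(type(back_to_numpy), type(to_jax))
```

Note that converting a JAX array that lives on an accelerator back to NumPy involves a device-to-host transfer, so it is best done sparingly inside performance-critical loops.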

The other point to note is that JAX tensors (arrays) have a shape, which is a tuple with one entry per axis. So `(3,)` describes a vector with three elements along its single (first) axis. A matrix has two axes, and a tensor has three or more axes.
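A short sketch of how `.shape` and `.ndim` relate for a vector, a matrix, and a three-axis tensor. The sizes chosen here are arbitrary; `reshape` shows that a row vector `(1, 3)` is a different shape from the flat vector `(3,)`.

```python
import jax.numpy as jnp

vector = jnp.array([1.0, 2.0, 3.0])  # one axis
matrix = jnp.ones((2, 3))            # two axes: 2 rows, 3 columns
tensor = jnp.zeros((4, 2, 3))        # three axes

for name, a in [("vector", vector), ("matrix", matrix), ("tensor", tensor)]:
    # .shape is a tuple with one entry per axis; .ndim is its length
    print(name, a.shape, a.ndim)

# Same three elements, but now laid out as a 1 x 3 matrix
row = vector.reshape(1, 3)
print(row.shape)  # (1, 3)
```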



Now we come to the places where JAX differs from NumPy. It is important to read “[JAX - The Sharp Bits](https://oreil.ly/qqcFM)” to understand these differences. JAX’s philosophy centers on speed and purity. By requiring functions to be pure (free of side effects) and data to be immutable, JAX can make guarantees to XLA (Accelerated Linear Algebra), the compiler it uses to target GPUs and TPUs. Because these functions have no side effects and always produce the same result for the same inputs, XLA can safely compile them and run them in parallel, so they typically run much faster than the equivalent code executed step by step in NumPy.
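The sketch below illustrates both ideas under simple assumptions: in-place item assignment on a JAX array is rejected, the `.at[...]` methods return new arrays instead, and a pure function can be compiled with `jax.jit`. The function `scaled_sum_of_squares` is a hypothetical example, not something from the original notebook.

```python
import jax
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])

# JAX arrays are immutable: item assignment raises a TypeError
rejected = False
try:
    x[1] = 4.0
except TypeError:
    rejected = True
print("in-place update rejected:", rejected)

# The functional alternative returns a new array and leaves x untouched
y = x.at[1].set(4.0)   # [1., 4., 3.]
z = x.at[0].add(10.0)  # [11., 2., 3.]
print(x, y, z)

# Hypothetical pure function: output depends only on its inputs, no side effects,
# so jax.jit can trace it and hand it to XLA for compilation
@jax.jit
def scaled_sum_of_squares(v, scale):
    return scale * jnp.sum(v ** 2)

print(scaled_sum_of_squares(x, 2.0))  # 2 * (1 + 4 + 9) = 28.0
```

The `.at[...].set` call in the earlier cell works the same way: it returned a new array and did not modify `x`.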

In [9]:
x = jnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=jnp.int32)
x.shape


(3, 3)



---

