<a href="https://colab.research.google.com/github/hunsii/jax-practice/blob/main/p2_jax_with_tpu.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Load library




```
from jax.experimental import optimizers
```
에 문제가 생기면 
```
from jax.example_libraries import optimizers
```
로 변경하여 사용하자.

In [1]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
# from jax.experimental import optimizers
from jax.example_libraries import optimizers
from jax.nn import relu, softmax, log_softmax
from jax.config import config


from tensorflow.keras.datasets import mnist

# Set up JAX to use TPU

위의 코드에서는 MNIST 데이터셋을 분류하기 위해 2개의 은닉층을 가지는 간단한 신경망을 정의합니다. JAX 최적화 알고리즘인 Adam을 사용하여 신경망을 학습시킵니다. jax.device_put_sharded 함수를 사용하여 학습 데이터를 TPU로 이동시킵니다. 마지막으로, 여러 epoch에 대해 학습 데이터를 반복하고, 학습 및 테스트 데이터셋에서 신경망의 정확도를 계산합니다.

위의 코드를 실행하기 위해서는 <TPU_IP_ADDRESS> 부분을 TPU 인스턴스의 IP 주소로 교체해야 합니다. 또한, batch_size 및 num_epochs 매개변수를 사용자의 상황에 맞게 수정해야 할 수도 있습니다.





In [None]:
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://<TPU_IP_ADDRESS>:8470"

# Set up TPU device
tpu_device = jax.devices()[0]

# Load MNIST dataset

In [2]:
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
train_images = train_images.reshape(train_images.shape[0], 784) / 255.0
test_images = test_images.reshape(test_images.shape[0], 784) / 255.0
train_labels = jax.nn.one_hot(train_labels, 10)
test_labels = jax.nn.one_hot(test_labels, 10)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz




# Move training data to TPU

In [None]:
train_images = jax.device_put_sharded(tpu_device, train_images)
train_labels = jax.device_put_sharded(tpu_device, train_labels)

# Define neural network architecture

In [3]:
def net(params, x):
  w1, b1, w2, b2 = params
  hidden = jnp.dot(x, w1) + b1
  hidden = relu(hidden)
  logits = jnp.dot(hidden, w2) + b2
  return logits

# Define loss function

In [None]:
def loss(params, x, y):
    logits = net(params, x)
    return -jnp.mean(jnp.sum(y * log_softmax(logits), axis=1))

# Define accuracy metric

In [None]:
def accuracy(params, images, labels):
    preds = net(params, images)
    return jnp.mean(jnp.argmax(preds, axis=1) == jnp.argmax(labels, axis=1))

# Initialize network parameters

In [None]:
key = random.PRNGKey(0)
input_shape = (-1, 784)
hidden_shape = 256
output_shape = 10
w1 = random.normal(key, (784, hidden_shape))
b1 = jnp.zeros(hidden_shape)
w2 = random.normal(key, (hidden_shape, output_shape))
b2 = jnp.zeros(output_shape)
params = (w1, b1, w2, b2)

# Define optimizer

In [None]:
step_size = 0.001
opt_init, opt_update, get_params = optimizers.adam(step_size)
opt_state = opt_init(params)

# Define training step

In [None]:
@jit
def update(params, x, y, opt_state):
    grads = grad(loss)(params, x, y)
    return opt_update(0, grads, opt_state), get_params(opt_state)

# Train the network

In [None]:
num_epochs = 10
batch_size = 128
num_batches = train_images.shape[0] // batch_size

In [4]:
for epoch in range(num_epochs):
    for batch in range(num_batches):
        start_idx = batch * batch_size
        end_idx = (batch + 1) * batch_size
        batch_images = train_images[start_idx:end_idx]
        batch_labels = train_labels[start_idx:end_idx]
        opt_state, params = update(params, batch_images, batch_labels, opt_state)
    train_acc = accuracy(params, train_images, train_labels)
    test_acc = accuracy(params, test_images, test_labels)
    print("Epoch {}: train acc = {:.3f}, test acc = {:.3f}".format(epoch+1, train_acc, test_acc))

Epoch 1: train acc = 0.811, test acc = 0.809
Epoch 2: train acc = 0.853, test acc = 0.847
Epoch 3: train acc = 0.874, test acc = 0.867
Epoch 4: train acc = 0.888, test acc = 0.879
Epoch 5: train acc = 0.897, test acc = 0.888
Epoch 6: train acc = 0.905, test acc = 0.893
Epoch 7: train acc = 0.912, test acc = 0.898
Epoch 8: train acc = 0.917, test acc = 0.902
Epoch 9: train acc = 0.922, test acc = 0.906
Epoch 10: train acc = 0.927, test acc = 0.911
minjae > sihun = true 
 minjae < sihun = false 
 how typing korean??? i don't know.... my notebook's operation system 
