## How to load datasets in JAX with TensorFlow
Click the image below to read the post online.

<a target="_blank" href="https://www.machinelearningnuggets.com/how-to-load-dataset-in-jax-with-tensorflow"><img src="https://digitalpress.fra1.cdn.digitaloceanspaces.com/mhujhsj/2022/07/logo.png" alt="Open in ML Nuggets"></a>

In [1]:
import os
from getpass import getpass
from pathlib import Path

# Kaggle API credentials: create a token at https://www.kaggle.com/<username>/account.
# Never hard-code real credentials in a notebook. The Kaggle client reads
# KAGGLE_USERNAME / KAGGLE_KEY before kaggle.json, so these variables must not be
# set to placeholder strings: that would mask a valid kaggle.json. Existing
# credentials are left untouched; otherwise you are prompted for them.
_kaggle_json = Path(os.environ.get("KAGGLE_CONFIG_DIR", Path.home() / ".kaggle")) / "kaggle.json"
if not _kaggle_json.exists():
    if "KAGGLE_USERNAME" not in os.environ:
        os.environ["KAGGLE_USERNAME"] = input("Kaggle username: ")
    if "KAGGLE_KEY" not in os.environ:
        os.environ["KAGGLE_KEY"] = getpass("Kaggle API key: ")

In [2]:
import kaggle


In [None]:
# Download the dataset archive (imdb-dataset-of-50k-movie-reviews.zip) into the
# current working directory, where the next cell expects to find it.
!kaggle datasets download lakshmi25npathi/imdb-dataset-of-50k-movie-reviews

In [4]:
import zipfile

# Unpack the Kaggle archive into its own folder; the CSV read below lives at
# imdb-dataset-of-50k-movie-reviews/IMDB Dataset.csv.
with zipfile.ZipFile('imdb-dataset-of-50k-movie-reviews.zip', mode='r') as archive:
    archive.extractall(path='imdb-dataset-of-50k-movie-reviews')

In [5]:
import numpy as np 
import pandas as pd 
from numpy import array
import tensorflow as tf
from sklearn.model_selection import train_test_split 
from sklearn.preprocessing import LabelEncoder
import matplotlib.pyplot as plt

In [6]:
df = pd.read_csv("imdb-dataset-of-50k-movie-reviews/IMDB Dataset.csv")


In [7]:
df.head()


Unnamed: 0,review,sentiment
0,One of the other reviewers has mentioned that ...,positive
1,A wonderful little production. <br /><br />The...,positive
2,I thought this was a wonderful way to spend ti...,positive
3,Basically there's a family where a little boy ...,negative
4,"Petter Mattei's ""Love in the Time of Money"" is...",positive


In [8]:
import nltk
from nltk.corpus import stopwords
nltk.download('stopwords')
def remove_stop_words(review, stop_words=None):
    """Drop English stop words from a whitespace-tokenised review.

    Args:
        review: Raw review text. It is split on whitespace only, so punctuation
            stays attached to words and is not stripped here.
        stop_words: Optional collection of words to remove. Defaults to NLTK's
            English stop-word list, loaded once and cached as a frozenset (the
            original reloaded the list and scanned it linearly for every word
            of every review).

    Returns:
        The remaining words joined by single spaces.

    Note:
        Matching is case-sensitive. Capitalised words such as "A" or "I" at the
        start of a sentence therefore survive (visible in the cleaned output).
    """
    if stop_words is None:
        stop_words = _english_stop_words()
    return ' '.join(word for word in review.split() if word not in stop_words)


def _english_stop_words():
    """Return NLTK's English stop words as a frozenset, reading the corpus only once."""
    cached = getattr(_english_stop_words, 'cache', None)
    if cached is None:
        cached = frozenset(stopwords.words('english'))
        _english_stop_words.cache = cached
    return cached

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [9]:
# Strip stop words from every review and map the string sentiment labels to
# integers ("negative" -> 0, "positive" -> 1, as the output below shows).
labelencoder = LabelEncoder()
df = df.assign(
    review=df['review'].apply(remove_stop_words),
    sentiment=labelencoder.fit_transform(df["sentiment"]),
)

In [10]:
df.head()

Unnamed: 0,review,sentiment
0,One reviewers mentioned watching 1 Oz episode ...,1
1,A wonderful little production. <br /><br />The...,1
2,I thought wonderful way spend time hot summer ...,1
3,Basically there's family little boy (Jake) thi...,0
4,"Petter Mattei's ""Love Time Money"" visually stu...",1


In [11]:
from sklearn.model_selection import train_test_split

# Drop exact duplicate (review, label) rows before splitting so the same pair
# cannot appear in both the training and the test split.
df = df.drop_duplicates()
docs = df['review']
labels = np.array(df['sentiment'])
X_train, X_test, y_train, y_test = train_test_split(
    docs,
    labels,
    test_size=0.20,
    random_state=0,
)

In [12]:
import tensorflow as tf

max_features = 5000  # Maximum vocabulary size.
batch_size = 32
max_len = 512  # Every encoded review is padded (or truncated) to this many tokens.

# Build the vocabulary from the training split only, so test-set words do not
# leak into preprocessing.
vectorize_layer = tf.keras.layers.TextVectorization(
    max_tokens=max_features,
    standardize='lower_and_strip_punctuation',
    output_mode='int',
    output_sequence_length=max_len,
)
vectorize_layer.adapt(X_train, batch_size=None)

In [13]:
X_train_padded =  vectorize_layer(X_train)
X_test_padded =  vectorize_layer(X_test)

In [14]:
X_train_padded

<tf.Tensor: shape=(39664, 512), dtype=int64, numpy=
array([[   1,    1,   48, ...,    0,    0,    0],
       [   2,   46, 1269, ...,    0,    0,    0],
       [  84,  784,   38, ...,    0,    0,    0],
       ...,
       [   2,  310,  328, ...,    0,    0,    0],
       [   2,  695, 1704, ...,    0,    0,    0],
       [   1,  241, 3061, ...,    0,    0,    0]])>

In [15]:
X_test_padded

<tf.Tensor: shape=(9917, 512), dtype=int64, numpy=
array([[1781,    2,    1, ...,    0,    0,    0],
       [  10,    1,  860, ...,    0,    0,    0],
       [ 108, 2906,    1, ...,    0,    0,    0],
       ...,
       [ 169,  104, 1132, ...,    0,    0,    0],
       [  10,  159,    1, ...,    0,    0,    0],
       [ 273,  285,    6, ...,    0,    0,    0]])>

In [16]:
# Reshuffle the training examples every epoch (shuffle *before* batch) so SGD
# does not see the same batch order on every pass. The buffer spans the whole
# training split for a uniform shuffle; the fixed seed keeps runs reproducible
# while `reshuffle_each_iteration` still gives a new order per epoch.
training_data = (
    tf.data.Dataset.from_tensor_slices((X_train_padded, y_train))
    .shuffle(buffer_size=len(y_train), seed=0, reshuffle_each_iteration=True)
    .batch(batch_size)
)
# Evaluation order does not matter, so the validation set is only batched.
validation_data = tf.data.Dataset.from_tensor_slices((X_test_padded, y_test)).batch(batch_size)

In [17]:
import tensorflow_datasets as tfds
def get_train_batches(dataset=None, prefetch_buffer=1):
    """Return training batches as an iterable of NumPy ``(text, labels)`` pairs.

    Args:
        dataset: A batched ``tf.data.Dataset``. Defaults to the notebook-global
            ``training_data`` (the previous, implicit behaviour).
        prefetch_buffer: Number of batches tf.data prepares ahead of the consumer.

    Returns:
        ``tfds.as_numpy`` applied to the prefetched dataset. It yields NumPy arrays
        that can be passed directly to the jitted JAX training step.
    """
    if dataset is None:
        dataset = training_data
    # tfds.as_numpy (not "tfds.dataset_as_numpy", which does not exist) converts
    # the tf.data.Dataset into an iterable of NumPy arrays.
    return tfds.as_numpy(dataset.prefetch(prefetch_buffer))

In [18]:
# Inspect a single batch to confirm the shapes that will be fed to the model.
for text, labels in training_data.take(1):
    print(text.shape, labels.shape)

(32, 512) (32,)


In [19]:
import jax
import jax.numpy as jnp

In [20]:
# pip install flax

In [21]:
import flax
from flax import linen as nn  

class Model(nn.Module):
    """Two-layer MLP classifier applied to a padded sequence of token ids.

    Attributes:
        hidden_features: Width of the hidden ReLU layer (default 256, as before).
        num_classes: Number of output classes (default 2: negative / positive).

    The output is ``log_softmax`` of the final layer, i.e. log-probabilities.
    Downstream code passes it to ``optax.softmax_cross_entropy`` as "logits".
    That is still correct because ``log_softmax`` is idempotent, and ``argmax``
    is unaffected.

    NOTE(review): token ids are fed straight into ``nn.Dense`` as numbers, so the
    layer treats an id's magnitude as a feature. An ``nn.Embed`` layer followed
    by pooling is the usual way to consume ids; consider it if accuracy matters.
    """

    hidden_features: int = 256
    num_classes: int = 2

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=self.hidden_features)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.num_classes)(x)
        return nn.log_softmax(x)

In [22]:
import optax
import jax.numpy as jnp
def cross_entropy_loss(*, logits, labels):
    """Mean softmax cross-entropy between ``logits`` and integer class ``labels``.

    Args:
        logits: Array whose last axis holds one score per class. The class count
            is read from that axis instead of being hard-coded, so the loss stays
            consistent with however many classes the model outputs.
        labels: Integer class indices matching the leading shape of ``logits``.

    Returns:
        Scalar mean loss over all examples.
    """
    labels_onehot = jax.nn.one_hot(labels, num_classes=logits.shape[-1])
    return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean()

In [23]:
def compute_metrics(*, logits, labels):
    """Return a batch's mean cross-entropy loss and classification accuracy.

    Accuracy compares the arg-max over the last axis of ``logits`` with the
    integer ``labels``.
    """
    predictions = jnp.argmax(logits, -1)
    return {
        'loss': cross_entropy_loss(logits=logits, labels=labels),
        'accuracy': jnp.mean(predictions == labels),
    }

In [24]:
from flax.training import train_state

def create_train_state(rng, learning_rate, momentum, sample_input=None):
    """Create the initial ``TrainState`` (model parameters + SGD-with-momentum).

    Args:
        rng: JAX PRNG key used to initialise the model parameters.
        learning_rate: SGD step size.
        momentum: SGD momentum coefficient.
        sample_input: Example input that drives parameter-shape inference in
            ``Model.init``. Defaults to the first padded training review (the
            previous behaviour). Pass it explicitly to avoid depending on the
            notebook-global ``X_train_padded``.

    Returns:
        A ``flax.training.train_state.TrainState`` whose ``apply_fn`` is the
        model's ``apply``.
    """
    if sample_input is None:
        sample_input = X_train_padded[0]
    model = Model()
    params = model.init(rng, sample_input)['params']
    tx = optax.sgd(learning_rate, momentum)
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

In [25]:
def compute_loss(params, text, labels):
    """Forward pass plus loss, shaped for ``jax.value_and_grad(..., has_aux=True)``.

    Returns:
        ``(loss, logits)``: the scalar being differentiated and, as auxiliary
        output, the model outputs so metrics can reuse them without a second pass.
    """
    logits = Model().apply({'params': params}, text)
    return cross_entropy_loss(logits=logits, labels=labels), logits
    
    
@jax.jit
def train_step(state, text, labels):
    """Apply one gradient update for a batch and report that batch's metrics.

    The metrics come from the forward pass made *before* the update is applied.
    """
    grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
    (_loss, logits), grads = grad_fn(state.params, text, labels)
    new_state = state.apply_gradients(grads=grads)
    return new_state, compute_metrics(logits=logits, labels=labels)

In [26]:
@jax.jit
def eval_step(state, text, labels):
    """Compute loss/accuracy for a batch without updating any parameters.

    Uses ``state.apply_fn`` (the function the state was created with) rather than
    a freshly constructed ``Model()``. Evaluation therefore runs exactly the
    model configuration that is being trained.
    """
    logits = state.apply_fn({'params': state.params}, text)
    return compute_metrics(logits=logits, labels=labels)

In [27]:
def train_one_epoch(state):
    """Run ``train_step`` once over every training batch.

    Returns:
        ``(state, epoch_metrics)``. ``state`` carries the updated parameters.
        ``epoch_metrics`` maps each metric name to its unweighted mean over the
        per-batch values.
    """
    per_batch = []
    for batch_text, batch_labels in get_train_batches():
        state, step_metrics = train_step(state, batch_text, batch_labels)
        per_batch.append(step_metrics)

    # Pull every batch's metrics to the host with a single jax.device_get call.
    per_batch_host = jax.device_get(per_batch)
    metric_names = per_batch_host[0]
    epoch_metrics = {
        name: np.mean([batch[name] for batch in per_batch_host])
        for name in metric_names
    }
    return state, epoch_metrics

In [28]:
def evaluate_model(state, text, test_lbls):
    """Evaluate ``state`` on the given inputs and labels in one ``eval_step`` call.

    Args:
        state: Current ``TrainState``.
        text: Padded token-id array for the evaluation examples.
        test_lbls: Integer labels aligned with ``text``.

    Returns:
        The metrics dict from ``eval_step`` with every value converted to a
        Python scalar.
    """
    metrics = jax.device_get(eval_step(state, text, test_lbls))
    # `jax.tree_map` is deprecated (and removed in newer JAX releases);
    # `jax.tree_util.tree_map` is the stable spelling.
    return jax.tree_util.tree_map(lambda x: x.item(), metrics)

In [None]:
learning_rate = 0.1
momentum = 0.9
seed = 0
num_epochs = 30

rng = jax.random.PRNGKey(seed)
rng, init_rng = jax.random.split(rng)
state = create_train_state(init_rng, learning_rate, momentum)
del init_rng  # Must not be used anymore.

# Evaluate on the whole held-out split every epoch, not just its first batch of
# `batch_size` reviews, so the test curves reflect the full test set.
test_text = jnp.asarray(X_test_padded)
test_labels = jnp.asarray(y_test)

training_loss = []
training_accuracy = []
testing_loss = []
testing_accuracy = []
for epoch in range(1, num_epochs + 1):
    # Rebind `state` so every epoch continues from the previous epoch's
    # parameters. Assigning the result to `train_state` instead would discard the
    # update and shadow the `flax.training.train_state` module.
    state, train_metrics = train_one_epoch(state)
    training_loss.append(train_metrics['loss'])
    training_accuracy.append(train_metrics['accuracy'])
    print(f"Train epoch: {epoch}, loss: {train_metrics['loss']}, accuracy: {train_metrics['accuracy'] * 100}")

    test_metrics = evaluate_model(state, test_text, test_labels)
    testing_loss.append(test_metrics['loss'])
    testing_accuracy.append(test_metrics['accuracy'])
    print(f"Test epoch: {epoch}, loss: {test_metrics['loss']}, accuracy: {test_metrics['accuracy'] * 100}")


In [None]:
def plot_learning_curves(train_values, test_values, ylabel):
    """Plot a training-vs-test metric against the 1-based epoch number.

    The training loop numbers epochs from 1, so the x-axis does too (plotting
    the bare lists would label the first epoch as 0).
    """
    fig, ax = plt.subplots()
    ax.plot(range(1, len(train_values) + 1), train_values, label="Training")
    ax.plot(range(1, len(test_values) + 1), test_values, label="Test")
    ax.set(xlabel="Epoch", ylabel=ylabel, title=f"{ylabel} per epoch")
    ax.legend()
    plt.show()
    return ax


plot_learning_curves(training_accuracy, testing_accuracy, "Accuracy")
plot_learning_curves(training_loss, testing_loss, "Loss");

## Where to go from here
Follow us on [LinkedIn](https://www.linkedin.com/company/mlnuggets), [Twitter](https://twitter.com/ml_nuggets), [GitHub](https://github.com/mlnuggets) and subscribe to our [blog](https://www.machinelearningnuggets.com/#/portal) so that you don't miss a new issue.