# Elegy (High-level API for deep learning in JAX & Flax)

Click the image below to read the post online.

<a target="_blank" href="https://www.machinelearningnuggets.com/elegy-flax-jax"><img src="https://digitalpress.fra1.cdn.digitaloceanspaces.com/mhujhsj/2022/07/logho-1.png" alt="Open in ML Nuggets"></a>

In [None]:
pip install -U elegy flax jax jaxlib

In [None]:
import os
# Kaggle API credentials. The values below are placeholders; the real ones
# can be obtained from https://www.kaggle.com/username/account
os.environ.update(
    {
        "KAGGLE_USERNAME": "KAGGLE_USERNAME",
        "KAGGLE_KEY": "KAGGLE_KEY",
    }
)

In [None]:
import kaggle

In [None]:
!kaggle datasets download lakshmi25npathi/imdb-dataset-of-50k-movie-reviews

In [None]:
import zipfile
# Unpack the downloaded archive into a directory named after the dataset.
with zipfile.ZipFile(file='imdb-dataset-of-50k-movie-reviews.zip', mode='r') as zip_ref:
    zip_ref.extractall(path='imdb-dataset-of-50k-movie-reviews')

In [None]:
import numpy as np 
import pandas as pd 
from numpy import array
import tensorflow as tf
from sklearn.model_selection import train_test_split 
from sklearn.preprocessing import LabelEncoder
import matplotlib.pyplot as plt

In [None]:
df = pd.read_csv("imdb-dataset-of-50k-movie-reviews/IMDB Dataset.csv")

In [None]:
import nltk
from nltk.corpus import stopwords
nltk.download('stopwords')
def remove_stop_words(review, stop_words=None):
    """Return ``review`` with stop words removed.

    The review is split on whitespace, every token that appears in
    ``stop_words`` is dropped, and the remaining tokens are joined back
    together with single spaces. Matching is exact (case-sensitive), so
    "The" is kept when only "the" is listed.

    Args:
        review: Text to clean.
        stop_words: Optional iterable of words to drop. When omitted, the
            list returned by ``stopwords.words('english')`` is used. Callers
            cleaning many reviews can pass a prebuilt collection so the
            default is not recomputed on every call.

    Returns:
        The cleaned review as a single string.
    """
    if stop_words is None:
        stop_words = stopwords.words('english')
    # A set makes each membership test O(1) instead of scanning a list.
    blocked = frozenset(stop_words)
    return ' '.join(word for word in review.split() if word not in blocked)

In [None]:
# Strip stop words from every review and turn the sentiment labels into
# integer codes, replacing both columns in a single step.
labelencoder = LabelEncoder()
df = df.assign(
    review=df['review'].apply(remove_stop_words),
    sentiment=labelencoder.fit_transform(df["sentiment"]),
)

In [None]:
from sklearn.model_selection import train_test_split
# Drop repeated rows, then hold out 20% of the data for testing.
df = df.drop_duplicates()
docs, labels = df['review'], array(df['sentiment'])
X_train, X_test, y_train, y_test = train_test_split(
    docs,
    labels,
    test_size=0.20,
    random_state=0,
)

In [None]:
import tensorflow as tf
# Settings shared by the vectorizer, the datasets and the model.
max_features = 10000  # Maximum vocab size.
batch_size = 128
max_len = 50  # Sequence length to pad the outputs to.

# Lower-case, strip punctuation and map each review to `max_len` integer ids.
vectorize_layer = tf.keras.layers.TextVectorization(
    max_tokens=max_features,
    standardize='lower_and_strip_punctuation',
    output_mode='int',
    output_sequence_length=max_len,
)
# The vocabulary is built from the training reviews only.
vectorize_layer.adapt(X_train)

In [None]:
# Apply the adapted vectorizer to the training split, then the test split.
X_train_padded, X_test_padded = map(vectorize_layer, (X_train, X_test))

In [None]:
# Pair the vectorized reviews with their labels and group them into batches.
training_data = tf.data.Dataset.from_tensor_slices(
    (X_train_padded, y_train)
).batch(batch_size)
validation_data = tf.data.Dataset.from_tensor_slices(
    (X_test_padded, y_test)
).batch(batch_size)

In [None]:
import tensorflow_datasets as tfds
def get_train_batches():
    """Return the training batches as an iterable of NumPy arrays.

    ``training_data`` is repeated three times and shuffled with a buffer of
    three elements (reshuffled on every iteration). Prefetching is applied
    as the final transformation.

    Returns:
        The result of ``tfds.as_numpy`` applied to the transformed dataset.
    """
    ds = training_data.repeat(3)
    ds = ds.shuffle(3, reshuffle_each_iteration=True)
    # Prefetch last so the element being prepared in the background is the
    # output of the whole pipeline, not the input to repeat/shuffle.
    ds = ds.prefetch(1)
    # tfds.as_numpy converts the tf.data.Dataset into an iterable of NumPy arrays
    return tfds.as_numpy(ds)

In [None]:
import jax

In [None]:
jax.devices()

In [None]:
from flax import linen as nn

class LSTMModel(nn.Module):
    """Stacked-LSTM classifier built from Flax linen layers.

    Data flow: embedding -> scanned LSTM (state size 128) -> Dense(256) + ReLU
    -> scanned LSTM (state size 64) -> Dense(128) + ReLU -> scanned LSTM
    (state size 32) -> Dense(64) + ReLU -> Dense(2) on the last position of
    axis 1 -> log-softmax.

    Reads the module-level constants ``max_features`` and ``max_len``.
    """
    def setup(self):
        """Create the embedding, the three scanned LSTM layers and the dense layers."""
        # NOTE(review): the second argument (embedding width) is `max_len`,
        # the sequence-length constant - confirm this is intended rather than
        # a separate embedding dimension.
        self.embedding = nn.Embed(max_features, max_len)
        # Lift the LSTM cell with nn.scan so it is stepped along axis 1 of its
        # input for `max_len` steps; "params" is broadcast and its RNG is not
        # split, so every step uses the same parameters.
        lstm_layer = nn.scan(nn.OptimizedLSTMCell,
                               variable_broadcast="params",
                               split_rngs={"params": False},
                               in_axes=1, 
                               out_axes=1,
                               length=max_len,
                               reverse=False)
        self.lstm1 = lstm_layer()
        self.dense1 = nn.Dense(256)
        self.lstm2 = lstm_layer()
        self.dense2 = nn.Dense(128)
        self.lstm3 = lstm_layer()
        self.dense3 = nn.Dense(64)
        # Final projection to 2 outputs.
        self.dense4 = nn.Dense(2)
        
    # nn.remat is Flax's lifted rematerialization (checkpointing) transform.
    @nn.remat    
    def __call__(self, x_batch):
        """Run the model on a batch and return log-softmax outputs.

        Args:
            x_batch: Input batch; assumed to be integer token ids of shape
                (batch, max_len) - TODO confirm against the caller.

        Returns:
            ``nn.log_softmax`` of the 2 outputs of ``dense4``, computed from
            the last position along axis 1 of the final hidden sequence.
        """
        x = self.embedding(x_batch)
        
        # Initial LSTM state for len(x_batch) rows with 128 features. The key
        # is fixed, so the initial state is the same on every call.
        carry, hidden = nn.OptimizedLSTMCell.initialize_carry(jax.random.PRNGKey(0), batch_dims=(len(x_batch),), size=128)
        # Only the per-step outputs `x` are used below; the returned final
        # state is overwritten before the next layer.
        (carry, hidden), x = self.lstm1((carry, hidden), x)
        
        x = self.dense1(x)
        x = nn.relu(x)
        
        # Second LSTM layer starts from its own fresh state of size 64.
        carry, hidden = nn.OptimizedLSTMCell.initialize_carry(jax.random.PRNGKey(0), batch_dims=(len(x_batch),), size=64)
        (carry, hidden), x = self.lstm2((carry, hidden), x)
        
        x = self.dense2(x)
        x = nn.relu(x)
        
        # Third LSTM layer starts from its own fresh state of size 32.
        carry, hidden = nn.OptimizedLSTMCell.initialize_carry(jax.random.PRNGKey(0), batch_dims=(len(x_batch),), size=32)
        (carry, hidden), x = self.lstm3((carry, hidden), x)
        
       
        x = self.dense3(x)
        x = nn.relu(x)
        # Keep only the last position along axis 1 for classification.
        x = self.dense4(x[:, -1])
        return nn.log_softmax(x)

In [None]:
import jax.numpy as jnp
import jax
import elegy as eg

In [None]:
import optax

# Wrap the Flax module in an Elegy model: cross-entropy plus an L2 penalty as
# the loss, accuracy as the metric, and Adam as the optimizer.
model = eg.Model(
    module=LSTMModel(),
    optimizer=optax.adam(learning_rate=1e-3),
    metrics=eg.metrics.Accuracy(),
    loss=[eg.losses.Crossentropy(), eg.regularizers.L2(l=1e-4)],
)

In [None]:
model.summary(jnp.array(X_train_padded[:64]))

In [None]:
model = model.distributed()

In [None]:
# Log to TensorBoard, checkpoint the best model, and stop once `val_loss`
# has not improved for 5 epochs.
callbacks = [
    eg.callbacks.TensorBoard("logs"),
    eg.callbacks.ModelCheckpoint("models/high-level", save_best_only=True),
    eg.callbacks.EarlyStopping(monitor='val_loss', patience=5),
]
history = model.fit(
    training_data,
    validation_data=validation_data,
    callbacks=callbacks,
    epochs=100,
)

In [None]:
model.evaluate(validation_data)

In [None]:
%load_ext tensorboard

In [None]:
%tensorboard --logdir logs

In [None]:
import matplotlib.pyplot as plt

def plot_history(history):
    """Plot each training metric against its validation counterpart.

    One subplot is drawn per metric that has both a training entry
    (``<name>``) and a validation entry (``val_<name>``). Metrics without a
    validation counterpart are skipped, and nothing is drawn when no such
    pair exists.

    Args:
        history: Object whose ``history`` attribute maps metric names to
            sequences of values, with validation metrics stored under
            ``val_<name>`` keys.
    """
    logs = history.history
    # Pair metrics by name rather than by position, so the plot does not
    # depend on key order or on every key having a `val_` counterpart.
    paired_keys = [
        key
        for key in logs
        if not key.startswith("val_") and f"val_{key}" in logs
    ]
    if not paired_keys:
        return

    n_plots = len(paired_keys)
    plt.figure(figsize=(14, 24))

    for position, key in enumerate(paired_keys, start=1):
        plt.subplot(n_plots, 1, position)
        plt.plot(logs[key], label=f"Training {key}")
        plt.plot(logs[f"val_{key}"], label=f"Validation {key}")
        plt.legend(loc="lower right")
        plt.ylabel(key)
        plt.title(f"Training and Validation {key}")
    plt.show()
    
plot_history(history)

In [None]:
# Take the first validation batch and run the model on its inputs.
text, test_labels = next(iter(validation_data))
y_pred = model.predict(jnp.array(text))

In [None]:
y_pred

In [None]:
# You can use `save`, but `ModelCheckpoint` has already serialized the model
# model.save("model")

# print the identity of the model object currently bound to `model`
print("current model id:", id(model))

# load the model from the path that was given to `ModelCheckpoint`
model = eg.load("models/high-level")

# print the identity again to show that `model` is now a different object
print("new model id:    ", id(model))

# check that the reloaded model works by evaluating it on the validation data
model.evaluate(validation_data)

## Where to go from here
Follow us on [LinkedIn](https://www.linkedin.com/company/mlnuggets), [Twitter](https://twitter.com/ml_nuggets), [GitHub](https://github.com/mlnuggets) and subscribe to our [blog](https://www.machinelearningnuggets.com/#/portal) so that you don't miss a new issue.