- [Text Classification Using Flax (JAX) Networks](https://coderzcolumn.com/tutorials/artificial-intelligence/text-classification-using-flax-jax-networks)

In [1]:
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

%load_ext autoreload
%autoreload 2

In [13]:
import optax


In [2]:
import numpy as np
from sklearn import datasets
#import gc

all_categories = ['alt.atheism','comp.graphics','comp.os.ms-windows.misc','comp.sys.ibm.pc.hardware',
                  'comp.sys.mac.hardware','comp.windows.x', 'misc.forsale','rec.autos','rec.motorcycles',
                  'rec.sport.baseball','rec.sport.hockey','sci.crypt','sci.electronics','sci.med',
                  'sci.space','soc.religion.christian','talk.politics.guns','talk.politics.mideast',
                  'talk.politics.misc','talk.religion.misc']

selected_categories = ['comp.sys.mac.hardware','comp.windows.x','rec.motorcycles','sci.crypt','talk.politics.mideast']

X_train_text, Y_train = datasets.fetch_20newsgroups(subset="train", categories=selected_categories, return_X_y=True)
X_test_text , Y_test  = datasets.fetch_20newsgroups(subset="test", categories=selected_categories, return_X_y=True)

X_train_text = np.array(X_train_text)
X_test_text = np.array(X_test_text)

classes = np.unique(Y_train)
mapping = dict(zip(classes, selected_categories))

len(X_train_text), len(X_test_text), classes, mapping


(2928,
 1950,
 array([0, 1, 2, 3, 4]),
 {0: 'comp.sys.mac.hardware',
  1: 'comp.windows.x',
  2: 'rec.motorcycles',
  3: 'sci.crypt',
  4: 'talk.politics.mideast'})

In [3]:
import sklearn
from jax import numpy as jnp
from sklearn.feature_extraction.text import CountVectorizer

vectorizer = CountVectorizer(max_features=50000)

vectorizer.fit(np.concatenate((X_train_text, X_test_text)))
X_train = vectorizer.transform(X_train_text)
X_test = vectorizer.transform(X_test_text)

X_train = jnp.array(X_train.toarray(), dtype=jnp.float16)
X_test  = jnp.array(X_test.toarray(), dtype=jnp.float16)

X_train.shape, X_test.shape



tcmalloc: large alloc 1296752640 bytes == 0x55aacea2e000 @  0x7fe596d12680 0x7fe596d322ec 0x7fe58d1ed4a6 0x7fe58d1ed576 0x7fe58d22fec5 0x7fe58d2cccee 0x7fe58d2cd50a 0x7fe58d2cd73e 0x55aa794bce14 0x55aa79480ac9 0x7fe58d21cc24 0x55aa794bce39 0x55aa79476caf 0x55aa7951080a 0x55aa794d1663 0x55aa794d2354 0x55aa7943a755 0x55aa794d1663 0x55aa7957e45c 0x55aa794d245b 0x55aa795b564e 0x55aa794dbc51 0x55aa79438ae6 0x55aa794c6f1b 0x55aa795161be 0x55aa794c6f1b 0x55aa795161be 0x55aa794c6f1b 0x55aa794dc674 0x55aa7943a72f 0x55aa794d2284


((2928, 50000), (1950, 50000))

In [6]:
import gc

gc.collect()


462

In [7]:
from flax import linen
from jax import random

class TextClassifier(linen.Module):
    def setup(self):
        self.linear1 = linen.Dense(features=128, name="DENSE1")
        self.linear2 = linen.Dense(features=64, name="DENSE2")
        self.linear3 = linen.Dense(len(classes), name="DENSE3")

    def __call__(self, inputs):
        x = linen.relu(self.linear1(inputs))
        x = linen.relu(self.linear2(x))
        logits = self.linear3(x)

        return logits #linen.softmax(x)


In [8]:
seed = jax.random.PRNGKey(0)

model = TextClassifier()
params = model.init(seed, X_train[:5])

for layer_params in params["params"].items():
    print("Layer Name : {}".format(layer_params[0]))
    weights, biases = layer_params[1]["kernel"], layer_params[1]["bias"]
    print("\tLayer Weights : {}, Biases : {}".format(weights.shape, biases.shape))



Layer Name : DENSE1
	Layer Weights : (50000, 128), Biases : (128,)
Layer Name : DENSE2
	Layer Weights : (128, 64), Biases : (64,)
Layer Name : DENSE3
	Layer Weights : (64, 5), Biases : (5,)


In [9]:
preds = model.apply(params, X_train[:5])

preds.shape



(5, 5)

In [10]:
def CrossEntropyLoss(weights, input_data, actual):
    logits_preds = model.apply(weights, input_data)
    one_hot_actual = jax.nn.one_hot(actual, num_classes=len(classes))
    return optax.softmax_cross_entropy(logits=logits_preds, labels=one_hot_actual).sum()



In [11]:
from jax import value_and_grad
from tqdm import tqdm
from sklearn.metrics import accuracy_score

def TrainModelInBatches(X, Y, X_val, Y_val, epochs, weights, optimizer_state, batch_size=32):
    for i in range(1, epochs+1):
        batches = jnp.arange((X.shape[0]//batch_size)+1) ### Batch Indices

        losses = [] ## Record loss of each batch
        for batch in tqdm(batches):
            if batch != batches[-1]:
                start, end = int(batch*batch_size), int(batch*batch_size+batch_size)
            else:
                start, end = int(batch*batch_size), None

            X_batch, Y_batch = X[start:end], Y[start:end] ## Single batch of data

            loss, gradients = value_and_grad(CrossEntropyLoss)(weights, X_batch,Y_batch)

            ## Update Weights
            updates, optimizer_state = optimizer.update(gradients, optimizer_state)
            weights = optax.apply_updates(weights, updates)

            losses.append(loss) ## Record Loss

        print("CrossEntropyLoss : {:.3f}".format(jnp.array(losses).mean()))

        Y_val_preds = model.apply(weights, X_val)
        val_acc = accuracy_score(Y_val, jnp.argmax(Y_val_preds, axis=1))
        print("Validation  Accuracy : {:.3f}".format(val_acc))

    return weights



In [14]:
seed = random.PRNGKey(0)
batch_size=256
epochs=8
learning_rate = jnp.array(1/1e3)

model = TextClassifier()
weights = model.init(seed, X_train[:5])

optimizer = optax.adam(learning_rate=learning_rate)
optimizer_state = optimizer.init(weights)

final_weights = TrainModelInBatches(X_train, Y_train, X_test, Y_test, epochs, weights, optimizer_state, batch_size=batch_size)


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:06<00:00,  1.98it/s]


CrossEntropyLoss : 279.461
Validation  Accuracy : 0.949


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:01<00:00, 10.00it/s]


CrossEntropyLoss : 62.105
Validation  Accuracy : 0.957


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:01<00:00,  9.27it/s]


CrossEntropyLoss : 13.602
Validation  Accuracy : 0.969


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:01<00:00,  9.79it/s]


CrossEntropyLoss : 4.436
Validation  Accuracy : 0.966


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:01<00:00,  9.86it/s]


CrossEntropyLoss : 1.980
Validation  Accuracy : 0.967


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:01<00:00,  9.89it/s]


CrossEntropyLoss : 1.137
Validation  Accuracy : 0.967


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:01<00:00,  9.99it/s]


CrossEntropyLoss : 0.764
Validation  Accuracy : 0.967


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:01<00:00,  9.90it/s]

CrossEntropyLoss : 0.560
Validation  Accuracy : 0.966



