- [Text Classification Using Flax (JAX) Networks](https://coderzcolumn.com/tutorials/artificial-intelligence/text-classification-using-flax-jax-networks)

In [1]:
from IPython.display import display, HTML
# Inject CSS so the notebook container uses the full page width.
display(HTML("<style>.container { width:100% !important; }</style>"))

# Pick up edits to imported modules without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Configure Training

In [42]:
# Single source of truth for the run; cells below read it as CONFIG["..."].
CONFIG = {
    # b1 / b2 / eps are forwarded to optax.adamw.
    "b1": 0.9,
    "b2": 0.999,
    # Selects the data sub-directory and which CSV splits are loaded.
    "dataset": "sst2",  # BBC-News, sst2
    "eps": 1e-6,
    # Initial value of the linear learning-rate schedule.
    "learning-rate": 2e-5,
    # Checkpoint name handed to the from_pretrained loaders.
    "model": "bert-base-cased",
    # Directory the tokenizer and fine-tuned weights are saved to.
    "model-directory": "/home/rflagg/model/sst2",
    "number-of-epochs": 3,
    "number-of-labels": 2,
    # True: loaders shard batches and train_step averages across devices.
    "train-in-parallel": False,
    "per-device-batch-size": 4,
    "seed": 0,
    "weight-decay": 1e-2,
    # Column names in the CSV files.
    "text-key": "sentence",  # text for BBC-News; sentence for sst2
    "label-key": "label",
}

# Load Data

In [43]:
from datasets import load_dataset

# Each dataset is a folder of CSV files named after CONFIG['dataset'].
directory = f"/home/rflagg/data/{CONFIG['dataset']}"

# Every layout has a train file; the remaining splits depend on the dataset.
data_files = {'train': f"{directory}/train-df.csv"}
if CONFIG['dataset'] == "BBC-News":
    # In the BBC-News layout test-df.csv is exposed as the validation split.
    data_files['validation'] = f"{directory}/test-df.csv"
else:
    data_files['test'] = f"{directory}/test-df.csv"
    data_files['validation'] = f"{directory}/validation-df.csv"

dataset = load_dataset('csv', data_files=data_files)
dataset



  0%|          | 0/3 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 67349
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 1821
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 872
    })
})

# Preprocess the data

In [44]:
from transformers import AutoTokenizer

# Tokenizer for the checkpoint named in CONFIG['model'].
tokenizer = AutoTokenizer.from_pretrained(CONFIG['model'])

def preprocess_function(data, text_key=CONFIG["text-key"], label_key=CONFIG["label-key"], max_length=128):
    """Tokenize a batch of examples and attach their labels.

    Args:
        data: Mapping from column name to column values, as supplied by
            ``dataset.map(..., batched=True)``.
        text_key: Column holding the raw text. The default is read from CONFIG
            once, when this function is defined, so re-run this cell after
            changing CONFIG.
        label_key: Column holding the labels (same definition-time caveat).
        max_length: Length every sequence is padded or truncated to.

    Returns:
        The tokenizer output with an additional ``"labels"`` entry.
    """
    processed = tokenizer(data[text_key], padding="max_length", max_length=max_length, truncation=True)
    processed["labels"] = data[label_key]
    return processed

In [45]:
# Tokenize every split and drop the raw columns, leaving only what
# preprocess_function returns.
raw_columns = dataset["train"].column_names
dataset_tokenized = dataset.map(preprocess_function, batched=True, remove_columns=raw_columns)
dataset_tokenized

  0%|          | 0/68 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 67349
    })
    test: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 1821
    })
    validation: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 872
    })
})

In [46]:
# Expose each split under its own name for the loaders below.
train_ds = dataset_tokenized["train"]
validation_ds = dataset_tokenized["validation"]
# Not every layout has a "test" split (the BBC-News branch of the loader has
# none). Key off the splits that were actually loaded rather than the dataset
# name, and always bind test_ds so later cells can check it for None.
test_ds = dataset_tokenized["test"] if "test" in dataset_tokenized else None

# Define Metric

In [8]:
from datasets import load_dataset, load_metric
# GLUE metric used by the evaluation loops below.
# NOTE(review): the metric configuration is hard-coded to "sst2" and does not
# follow CONFIG['dataset'].
metric = load_metric('glue', "sst2")

In [None]:
import datasets
from sklearn.metrics import precision_recall_fscore_support

class F1EtcMetric(datasets.Metric):
    """Metric that reports weighted precision, recall and F1 via scikit-learn."""

    def _info(self):
        """Describe the metric and declare its two int64 inputs."""
        input_features = datasets.Features(
            {
                'predictions': datasets.Value('int64'),
                'references': datasets.Value('int64'),
            }
        )
        return datasets.MetricInfo(
            description="Calculates precision, recall, f1 score, and support.",
            citation="TODO: _CITATION",
            inputs_description="_KWARGS_DESCRIPTION",
            features=input_features,
            codebase_urls=[],
            reference_urls=[],
            format='numpy',
        )

    def _compute(self, predictions, references):
        """Return weighted precision, recall and F1; support is computed but not reported."""
        precision, recall, fscore, _ = precision_recall_fscore_support(
            references, predictions, average='weighted'
        )
        return {"precision": precision, "recall": recall, "f1": fscore}
    
# Rebinds `metric`, replacing the GLUE metric loaded in the previous cell.
# With this metric the first reported entry is "precision" (see _compute).
metric = F1EtcMetric()

# Fine-tune the model

- [Huggingface Evaluate](https://huggingface.co/docs/evaluate/index)

In [9]:
from datasets import load_metric
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from itertools import chain
import jax
import jax.numpy as jnp
import optax
from tqdm.notebook import tqdm
from transformers import FlaxAutoModelForSequenceClassification, AutoConfig

In [10]:
# Load the pre-trained checkpoint configured for the requested number of labels.
checkpoint = CONFIG['model']
config = AutoConfig.from_pretrained(checkpoint, num_labels=CONFIG["number-of-labels"])
model = FlaxAutoModelForSequenceClassification.from_pretrained(
    checkpoint,
    config=config,
    seed=CONFIG['seed'],
)

Some weights of the model checkpoint at bert-base-cased were not used when initializing FlaxBertForSequenceClassification: {('cls', 'predictions', 'transform', 'dense', 'kernel'), ('cls', 'predictions', 'transform', 'LayerNorm', 'bias'), ('cls', 'predictions', 'transform', 'dense', 'bias'), ('cls', 'predictions', 'transform', 'LayerNorm', 'scale'), ('cls', 'predictions', 'bias')}
- This IS expected if you are initializing FlaxBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of FlaxBertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased 

In [11]:
# In parallel mode each local device receives its own per-device batch.
device_multiplier = jax.local_device_count() if CONFIG['train-in-parallel'] else 1
total_batch_size = CONFIG['per-device-batch-size'] * device_multiplier
print("The overall batch size (both for training and eval) is", total_batch_size)

The overall batch size (both for training and eval) is 4


In [12]:
# One optimizer step per full batch; the learning rate falls linearly from
# CONFIG['learning-rate'] to 0 over all of them.
steps_per_epoch = len(train_ds) // total_batch_size
num_train_steps = steps_per_epoch * CONFIG['number-of-epochs']
learning_rate_function = optax.linear_schedule(
    init_value=CONFIG['learning-rate'],
    end_value=0,
    transition_steps=num_train_steps,
)

## Defining the training state

In [13]:
import flax
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from typing import Callable

In [14]:
class TrainState(train_state.TrainState):
    """Flax TrainState extended with the two callables used by the step functions."""
    # Both fields are declared with pytree_node=False.
    # Maps logits to predictions; called by eval_step.
    logits_function: Callable = flax.struct.field(pytree_node=False)
    # Computes the loss from logits and labels; called by train_step.
    loss_function: Callable = flax.struct.field(pytree_node=False)

In [15]:
def decay_mask_fn(params):
    """Build a boolean tree marking which parameters receive weight decay.

    A leaf is False when its path ends in "bias" or in ("LayerNorm", "scale"),
    and True otherwise. The result has the same nesting as ``params``.
    """
    mask = {}
    for path in flax.traverse_util.flatten_dict(params):
        is_bias = path[-1] == "bias"
        is_layer_norm_scale = path[-2:] == ("LayerNorm", "scale")
        mask[path] = not (is_bias or is_layer_norm_scale)
    return flax.traverse_util.unflatten_dict(mask)

In [16]:
def adamw(weight_decay):
    """Create the AdamW optimizer used for fine-tuning.

    The learning rate follows ``learning_rate_function``, b1/b2/eps come from
    CONFIG, and ``decay_mask_fn`` selects the parameters that are decayed.
    """
    optimizer_kwargs = dict(
        learning_rate=learning_rate_function,
        b1=CONFIG['b1'],
        b2=CONFIG['b2'],
        eps=CONFIG['eps'],
        weight_decay=weight_decay,
        mask=decay_mask_fn,
    )
    return optax.adamw(**optimizer_kwargs)

In [17]:
def loss_function(logits, labels):
    """Mean softmax cross-entropy; labels are one-hot encoded over CONFIG['number-of-labels'] classes."""
    one_hot_labels = onehot(labels, num_classes=CONFIG['number-of-labels'])
    per_example_loss = optax.softmax_cross_entropy(logits, one_hot_labels)
    return jnp.mean(per_example_loss)
     
def eval_function(logits):
    """Return the index of the largest logit along the last axis."""
    return logits.argmax(-1)

In [18]:
# Bundle the model parameters, the optimizer and the two helper callables.
tx = adamw(weight_decay=CONFIG['weight-decay'])
state = TrainState.create(
    apply_fn=model.__call__,
    params=model.params,
    tx=tx,
    logits_function=eval_function,
    loss_function=loss_function,
)

## Defining the training and evaluation step

In [19]:
def train_step(state, batch, dropout_rng, train_in_parallel=CONFIG['train-in-parallel']):
    """Run one optimization step.

    Args:
        state: TrainState providing ``apply_fn``, ``params``, ``loss_function``
            and ``apply_gradients``.
        batch: Dict of model inputs plus a ``"labels"`` entry. The dict is left
            unmodified.
        dropout_rng: PRNG key. It is split in two: one half drives dropout in
            this step, the other half is returned for the next step.
        train_in_parallel: When True, gradients and metrics are averaged over
            the ``"batch"`` axis, so the function must run under
            ``jax.pmap(..., axis_name="batch")``. The default is read from
            CONFIG once, when this function is defined.

    Returns:
        ``(new_state, metrics, new_dropout_rng)``; ``metrics`` contains
        ``"loss"`` and the ``"learning_rate"`` evaluated at the pre-update step.
    """
    # Separate the labels from the model inputs without popping from the
    # caller's dict.
    targets = batch["labels"]
    inputs = {name: value for name, value in batch.items() if name != "labels"}
    dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)

    # Named compute_loss so it does not shadow the module-level loss_function.
    def compute_loss(params):
        logits = state.apply_fn(**inputs, params=params, dropout_rng=dropout_rng, train=True)[0]
        return state.loss_function(logits, targets)

    loss, grad = jax.value_and_grad(compute_loss)(state.params)
    metrics = {"loss": loss, "learning_rate": learning_rate_function(state.step)}
    if train_in_parallel:
        # Average across devices so every replica applies the same update.
        grad = jax.lax.pmean(grad, "batch")
        metrics = jax.lax.pmean(metrics, axis_name="batch")
    new_state = state.apply_gradients(grads=grad)

    return new_state, metrics, new_dropout_rng

In [20]:
def eval_step(state, batch):
    """Run the model in inference mode and convert its logits to predictions.

    The first element of the model output is passed to ``state.logits_function``.
    """
    outputs = state.apply_fn(**batch, params=state.params, train=False)
    return state.logits_function(outputs[0])

## Defining the data collators

In [21]:
def train_data_loader(rng, dataset, batch_size, train_in_parallel=CONFIG['train-in-parallel']):
    """Yield shuffled, full-size batches as dicts of jnp arrays.

    Indices are permuted with ``rng``; examples that do not fill a complete
    batch are skipped. When ``train_in_parallel`` is True (default read from
    CONFIG at definition time) each batch is passed through ``shard``.
    """
    num_batches = len(dataset) // batch_size
    shuffled = jax.random.permutation(rng, len(dataset))
    # Keep only as many indices as fill whole batches.
    batch_indices = shuffled[: num_batches * batch_size].reshape((num_batches, batch_size))
    for indices in batch_indices:
        arrays = {name: jnp.array(column) for name, column in dataset[indices].items()}
        yield shard(arrays) if train_in_parallel else arrays

In [22]:
def eval_data_loader(dataset, batch_size, train_in_parallel=CONFIG['train-in-parallel']):
    """Yield the dataset in order as dicts of jnp arrays.

    Args:
        dataset: Dataset supporting ``len`` and slice indexing that returns a
            dict of columns.
        batch_size: Number of examples per batch.
        train_in_parallel: When True, batches are passed through ``shard`` and
            only full batches are produced, so a trailing partial batch is
            dropped. When False, the trailing partial batch is yielded as well
            so that every example is evaluated. The default is read from CONFIG
            once, at definition time.

    Yields:
        One dict of arrays per batch.
    """
    num_examples = len(dataset)
    # Parallel mode stops at the last full batch; otherwise cover everything.
    stop = (num_examples // batch_size) * batch_size if train_in_parallel else num_examples
    for start in range(0, stop, batch_size):
        batch = dataset[start : start + batch_size]
        batch = {k: jnp.array(v) for k, v in batch.items()}
        if train_in_parallel:
            batch = shard(batch)
        yield batch

## Training

In [23]:
# jit-compiled step functions; the train step donates argument 0 (the state).
jit_train_step = jax.jit(train_step, donate_argnums=(0,))
jit_eval_step = jax.jit(eval_step)

In [24]:
# Root key from the configured seed, split once to obtain the dropout key.
rng = jax.random.PRNGKey(CONFIG['seed'])
rng, dropout_rng = jax.random.split(rng)

In [25]:
# Smoke test: one training step followed by one evaluation step.
rng, input_rng = jax.random.split(rng)

for batch in train_data_loader(input_rng, train_ds, total_batch_size):
    # Store the returned key in `dropout_rng`, the name the jitted training
    # loop reads, so the next step does not reuse the key consumed here.
    state, train_metrics, dropout_rng = jit_train_step(state, batch, dropout_rng)
    break
for batch in eval_data_loader(validation_ds, total_batch_size):
    labels = batch.pop("labels")
    predictions = jit_eval_step(state, batch)
    break

In [None]:
# Single-batch evaluation check: predict the first validation batch and stop.
for batch in eval_data_loader(validation_ds, total_batch_size):
    labels = batch.pop("labels")
    predictions = jit_eval_step(state, batch)
    break


In [None]:
# Full pass over the validation split, feeding every batch into the metric.
num_eval_batches = len(validation_ds) // total_batch_size
with tqdm(total=num_eval_batches, desc="Evaluating...", leave=False) as progress_bar_eval:
    for batch in eval_data_loader(validation_ds, total_batch_size):
        labels = batch.pop("labels")
        predictions = jit_eval_step(state, batch)
        metric.add_batch(predictions=predictions, references=labels)
        progress_bar_eval.update(1)


In [None]:
# NOTE(review): adds the most recent `predictions`/`labels` to the metric. If
# the evaluation loop above has already run, its last batch is counted a second
# time — confirm this cell is still needed.
metric.add_batch(predictions=predictions, references=labels)

In [None]:
# Summarise the metric accumulated so far together with the latest training loss.
eval_metric = metric.compute()

loss = round(train_metrics['loss'].item(), 4)
# The first entry of the metric dict is reported as the headline score.
eval_score = round(list(eval_metric.values())[0], 4)
metric_name = list(eval_metric.keys())[0]

# This cell runs before the epoch loop, so no epoch counter `i` exists yet;
# referencing it raised NameError on a fresh kernel. Report the scores only.
print(f"Train loss: {loss} | Eval {metric_name}: {eval_score}")

In [None]:
# Summarise the current metric state together with the latest training loss.
# NOTE(review): the previous cell runs the same statements; calling compute()
# again without new add_batch calls may not return scores — confirm this cell
# is still needed.
eval_metric = metric.compute()
loss = round(train_metrics['loss'].item(), 4)
eval_score = round(list(eval_metric.values())[0], 4)
metric_name = list(eval_metric.keys())[0]

# No epoch counter `i` exists before the epoch loop has run; referencing it
# raised NameError on a fresh kernel. Report the scores only.
print(f"Train loss: {loss} | Eval {metric_name}: {eval_score}")

In [29]:
# Show the Flax version this notebook was run with.
import flax
flax.__version__

'0.5.2'

In [26]:
%%time
# These counts only size the progress bars; the loaders decide how many
# batches are actually produced.
train_batches_per_epoch = len(train_ds) // total_batch_size
eval_batches_per_epoch = len(validation_ds) // total_batch_size

epoch_bar = tqdm(range(1, CONFIG['number-of-epochs'] + 1), desc="Epoch ...", position=0, leave=True)
for i, epoch in enumerate(epoch_bar):
    # Fresh key so every epoch shuffles the training data differently.
    rng, input_rng = jax.random.split(rng)

    # train
    with tqdm(total=train_batches_per_epoch, desc="Training...", leave=False) as progress_bar_train:
        for batch in train_data_loader(input_rng, train_ds, total_batch_size):
            state, train_metrics, dropout_rng = jit_train_step(state, batch, dropout_rng)
            progress_bar_train.update(1)

    # evaluate
    with tqdm(total=eval_batches_per_epoch, desc="Evaluating...", leave=False) as progress_bar_eval:
        for batch in eval_data_loader(validation_ds, total_batch_size):
            labels = batch.pop("labels")
            predictions = jit_eval_step(state, batch)
            metric.add_batch(predictions=predictions, references=labels)
            progress_bar_eval.update(1)

    eval_metric = metric.compute()

    # The reported loss is that of the last training batch of the epoch; the
    # reported score is the first entry of the metric dict.
    loss = round(train_metrics['loss'].item(), 4)
    metric_name, raw_score = next(iter(eval_metric.items()))
    eval_score = round(raw_score, 4)

    print(f"{i+1}/{CONFIG['number-of-epochs']} | Train loss: {loss} | Eval {metric_name}: {eval_score}")

Epoch ...:   0%|          | 0/3 [00:00<?, ?it/s]

Training...:   0%|          | 0/16837 [00:00<?, ?it/s]

Evaluating...:   0%|          | 0/218 [00:00<?, ?it/s]

1/3 | Train loss: 0.0077 | Eval accuracy: 0.9209


Training...:   0%|          | 0/16837 [00:00<?, ?it/s]

Evaluating...:   0%|          | 0/218 [00:00<?, ?it/s]

2/3 | Train loss: 0.0325 | Eval accuracy: 0.9163


Training...:   0%|          | 0/16837 [00:00<?, ?it/s]

Evaluating...:   0%|          | 0/218 [00:00<?, ?it/s]

3/3 | Train loss: 0.0009 | Eval accuracy: 0.9232
CPU times: user 14min 18s, sys: 2min 47s, total: 17min 5s
Wall time: 26min 26s


## Verification

In [33]:
# Write the tokenizer and the fine-tuned parameters to the same directory.
model_directory = CONFIG['model-directory']
tokenizer.save_pretrained(model_directory)
model.save_pretrained(model_directory, state.params)

('/home/rflagg/model/sst2/tokenizer_config.json',
 '/home/rflagg/model/sst2/special_tokens_map.json',
 '/home/rflagg/model/sst2/vocab.txt',
 '/home/rflagg/model/sst2/added_tokens.json',
 '/home/rflagg/model/sst2/tokenizer.json')

In [51]:
# Evaluate one batch of the validation split with a fresh GLUE metric.
# NOTE(review): the loop breaks after the first batch, so the progress bar is
# never advanced and only that batch reaches the metric.
metric = load_metric('glue', "sst2")
with tqdm(total=len(validation_ds) // total_batch_size, desc="Evaluating...", leave=False) as progress_bar_eval:
    for batch in eval_data_loader(validation_ds, total_batch_size):
        labels = batch.pop("labels")
        predictions = jit_eval_step(state, batch)
        metric.add_batch(predictions=predictions, references=labels)
        break

Evaluating...:   0%|          | 0/218 [00:00<?, ?it/s]

In [54]:
# Reload tokenizer, config and weights from the directory the model was saved to.
model_directory = CONFIG['model-directory']
tokenizer = AutoTokenizer.from_pretrained(model_directory)
config = AutoConfig.from_pretrained(model_directory, num_labels=CONFIG["number-of-labels"])
model = FlaxAutoModelForSequenceClassification.from_pretrained(
    model_directory,
    config=config,
    seed=CONFIG['seed'],
)

In [57]:
# Rebuild the train state around the current `model` and its parameters.
tx = adamw(weight_decay=CONFIG['weight-decay'])
state = TrainState.create(
    apply_fn=model.__call__,
    params=model.params,
    tx=tx,
    logits_function=eval_function,
    loss_function=loss_function,
)

In [58]:
# Re-evaluate with the current state using a fresh GLUE metric.
# NOTE(review): despite the "test" naming below, this loop iterates over
# validation_ds — confirm which split is intended.
metric = load_metric('glue', "sst2")
with tqdm(total=len(validation_ds) // total_batch_size, desc="Evaluating...", leave=False) as progress_bar_eval:
    for batch in eval_data_loader(validation_ds, total_batch_size):
        labels = batch.pop("labels")
        predictions = jit_eval_step(state, batch)
        metric.add_batch(predictions=predictions, references=labels)
        progress_bar_eval.update(1)

test_metric = metric.compute()

# Take both the score and its name from test_metric. The name used to come from
# `eval_metric`, a leftover of the training loop, which breaks on a fresh kernel
# and can mislabel the score if the metric type changed in between.
test_score = round(list(test_metric.values())[0], 4)
metric_name = list(test_metric.keys())[0]

print(f"Test {metric_name}: {test_score}")

Evaluating...:   0%|          | 0/218 [00:00<?, ?it/s]

Test accuracy: 0.9232


In [37]:
# NOTE(review): this cell uses names that are neither defined nor imported in
# the visible notebook (DataLoader, test_set, torch, loaded_model); it looks
# like a leftover from a PyTorch notebook — confirm before running.
test_dataloader = DataLoader(test_set, batch_size=2)

from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# loaded_model.to(device)

# Accumulators filled batch by batch in the loop below.
predictions_total = []
labels_total = []
probabilities_total=[]
probability_total=[]
# Minimum top-class probability required to keep a prediction.
threshold=0.93
with torch.no_grad():
    loaded_model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    loaded_model.to(device)
    for batch in tqdm(test_dataloader):
        # get the inputs;
        labels = batch["labels"]
        del batch["labels"]

        # move everything to the GPU
        for k,v in batch.items(): batch[k] = batch[k].to(device)

        # forward pass
        outputs = loaded_model(**batch)
        logits = outputs.logits
        probabilities = logits.softmax(dim=-1).detach().cpu().numpy()
        # Highest class probability per example.
        probability= [max(x) for x in probabilities]
        predictions = logits.argmax(-1).tolist()
        # Predictions whose top probability is not above the threshold become class 0.
        predictions= [y if x>threshold else 0 for x,y in zip(probability, predictions)]
        probability_total.extend(probability)
        probabilities_total.extend(probabilities.tolist())
        predictions_total.extend(predictions)
        labels_total.extend(labels.tolist())

## Parallel Training

In [None]:
import functools

# train_step only averages gradients and metrics across devices when
# train_in_parallel is True, and its default is frozen from CONFIG when
# train_step is defined. Bind the flag explicitly so the pmapped step always
# synchronises the replicas, whatever CONFIG held at that time.
parallel_train_step = jax.pmap(
    functools.partial(train_step, train_in_parallel=True),
    axis_name="batch",
    donate_argnums=(0,),
)
parallel_eval_step = jax.pmap(eval_step, axis_name="batch")
# Replicate the train state for use with the pmapped steps.
state = flax.jax_utils.replicate(state)

In [None]:
# Restart from the configured seed and derive one dropout key per local device.
rng = jax.random.PRNGKey(CONFIG['seed'])
dropout_rngs = jax.random.split(rng, jax.local_device_count())

In [None]:
# Smoke test for the pmapped steps: one training step, one evaluation step.
rng, input_rng = jax.random.split(rng)

for batch in train_data_loader(input_rng, train_ds, total_batch_size):
    state, train_metrics, dropout_rngs = parallel_train_step(state, batch, dropout_rngs)
    break
for batch in eval_data_loader(validation_ds, total_batch_size):
    labels = batch.pop("labels")
    predictions = parallel_eval_step(state, batch)
    # chain(*...) flattens the leading axis before the values reach the metric.
    metric.add_batch(predictions=chain(*predictions), references=chain(*labels))
    eval_metric = metric.compute()
    print(eval_metric)
    break

In [None]:
%%time
# These counts only size the progress bars; the loaders decide how many
# batches are actually produced.
train_batches_per_epoch = len(train_ds) // total_batch_size
eval_batches_per_epoch = len(validation_ds) // total_batch_size

epoch_bar = tqdm(range(1, CONFIG['number-of-epochs'] + 1), desc="Epoch ...", position=0, leave=True)
for i, epoch in enumerate(epoch_bar):
    # Fresh key so every epoch shuffles the training data differently.
    rng, input_rng = jax.random.split(rng)

    # train
    with tqdm(total=train_batches_per_epoch, desc="Training...", leave=False) as progress_bar_train:
        for batch in train_data_loader(input_rng, train_ds, total_batch_size):
            state, train_metrics, dropout_rngs = parallel_train_step(state, batch, dropout_rngs)
            progress_bar_train.update(1)

    # evaluate
    with tqdm(total=eval_batches_per_epoch, desc="Evaluating...", leave=False) as progress_bar_eval:
        for batch in eval_data_loader(validation_ds, total_batch_size):
            labels = batch.pop("labels")
            predictions = parallel_eval_step(state, batch)
            # chain(*...) flattens the leading axis before the values reach the metric.
            metric.add_batch(predictions=chain(*predictions), references=chain(*labels))
            progress_bar_eval.update(1)

    eval_metric = metric.compute()

    # The metrics are replicated, so unreplicate before reading the loss.
    loss = round(flax.jax_utils.unreplicate(train_metrics)['loss'].item(), 4)
    # Entry 0 of the metric dict is reported; with F1EtcMetric, entry 2 is "f1".
    metric_name, raw_score = next(iter(eval_metric.items()))
    eval_score = round(raw_score, 4)

    print(f"{i+1}/{CONFIG['number-of-epochs']} | Train loss: {loss} | Eval {metric_name}: {eval_score}")

In [None]:
import numpy as np
# Import the loader directly: `from sklearn import datasets` would rebind the
# name `datasets`, which the metric cell earlier in this notebook uses for the
# Hugging Face package.
from sklearn.datasets import fetch_20newsgroups
#import gc

all_categories = ['alt.atheism','comp.graphics','comp.os.ms-windows.misc','comp.sys.ibm.pc.hardware',
                  'comp.sys.mac.hardware','comp.windows.x', 'misc.forsale','rec.autos','rec.motorcycles',
                  'rec.sport.baseball','rec.sport.hockey','sci.crypt','sci.electronics','sci.med',
                  'sci.space','soc.religion.christian','talk.politics.guns','talk.politics.mideast',
                  'talk.politics.misc','talk.religion.misc']

# Only these five categories are fetched.
selected_categories = ['comp.sys.mac.hardware','comp.windows.x','rec.motorcycles','sci.crypt','talk.politics.mideast']

X_train_text, Y_train = fetch_20newsgroups(subset="train", categories=selected_categories, return_X_y=True)
X_test_text , Y_test  = fetch_20newsgroups(subset="test", categories=selected_categories, return_X_y=True)

X_train_text = np.array(X_train_text)
X_test_text = np.array(X_test_text)

classes = np.unique(Y_train)
# Pairs each target id with a category name by position.
# NOTE(review): assumes target ids follow the order of selected_categories —
# confirm against the loader's target_names.
mapping = dict(zip(classes, selected_categories))

len(X_train_text), len(X_test_text), classes, mapping


In [None]:
import sklearn
from jax import numpy as jnp
from sklearn.feature_extraction.text import CountVectorizer

# Bag-of-words features capped at 50000 terms.
# NOTE(review): the vocabulary is fitted on train and test text together.
corpus = np.concatenate((X_train_text, X_test_text))
vectorizer = CountVectorizer(max_features=50000)
vectorizer.fit(corpus)

# Convert the count matrices to dense float16 jax arrays.
X_train = jnp.array(vectorizer.transform(X_train_text).toarray(), dtype=jnp.float16)
X_test = jnp.array(vectorizer.transform(X_test_text).toarray(), dtype=jnp.float16)

X_train.shape, X_test.shape



In [None]:
import gc

# Run a garbage-collection pass before building the next model.
gc.collect()


In [None]:
from flax import linen
from jax import random

class TextClassifier(linen.Module):
    """Dense(128) -> ReLU -> Dense(64) -> ReLU -> Dense(len(classes)), returning raw logits."""

    def setup(self):
        """Create the three named dense layers; the output width is the number of classes."""
        self.linear1 = linen.Dense(features=128, name="DENSE1")
        self.linear2 = linen.Dense(features=64, name="DENSE2")
        self.linear3 = linen.Dense(len(classes), name="DENSE3")

    def __call__(self, inputs):
        """Return the logits for ``inputs``; no softmax is applied."""
        hidden = linen.relu(self.linear1(inputs))
        hidden = linen.relu(self.linear2(hidden))
        return self.linear3(hidden)


In [None]:
seed = jax.random.PRNGKey(0)

# Initialise the parameters from a small slice of the training features.
model = TextClassifier()
params = model.init(seed, X_train[:5])

# Print the kernel and bias shape of every layer.
for layer_name, layer in params["params"].items():
    print("Layer Name : {}".format(layer_name))
    print("\tLayer Weights : {}, Biases : {}".format(layer["kernel"].shape, layer["bias"].shape))



In [None]:
# Forward pass on the first five training rows as a shape check.
preds = model.apply(params, X_train[:5])

preds.shape



In [None]:
def CrossEntropyLoss(weights, input_data, actual):
    """Summed softmax cross-entropy of the global ``model`` on one batch.

    ``actual`` is one-hot encoded over ``len(classes)`` classes; the per-example
    losses are added up rather than averaged.
    """
    targets = jax.nn.one_hot(actual, num_classes=len(classes))
    batch_losses = optax.softmax_cross_entropy(logits=model.apply(weights, input_data), labels=targets)
    return batch_losses.sum()



In [None]:
from jax import value_and_grad
from tqdm import tqdm
from sklearn.metrics import accuracy_score

def TrainModelInBatches(X, Y, X_val, Y_val, epochs, weights, optimizer_state, batch_size=32):
    """Train with mini-batches and report loss and validation accuracy per epoch.

    Relies on the module-level ``model`` and ``optimizer`` objects.

    Args:
        X, Y: Training features and labels, sliced along the first axis.
        X_val, Y_val: Data used for the accuracy printed after every epoch.
        epochs: Number of passes over ``X``.
        weights: Initial model parameters.
        optimizer_state: State matching the global ``optimizer``.
        batch_size: Examples per batch; the final batch may be smaller.

    Returns:
        The parameters after the last update.
    """
    # Build the differentiated loss once instead of on every batch.
    loss_and_grad = value_and_grad(CrossEntropyLoss)
    num_examples = X.shape[0]

    for _ in range(epochs):
        losses = [] ## Record loss of each batch
        # Plain Python offsets: the last batch is simply shorter, and no empty
        # batch is produced when num_examples is a multiple of batch_size.
        for start in tqdm(range(0, num_examples, batch_size)):
            end = start + batch_size
            X_batch, Y_batch = X[start:end], Y[start:end] ## Single batch of data

            loss, gradients = loss_and_grad(weights, X_batch, Y_batch)

            ## Update Weights
            updates, optimizer_state = optimizer.update(gradients, optimizer_state)
            weights = optax.apply_updates(weights, updates)

            losses.append(loss) ## Record Loss

        print("CrossEntropyLoss : {:.3f}".format(jnp.array(losses).mean()))

        Y_val_preds = model.apply(weights, X_val)
        val_acc = accuracy_score(Y_val, jnp.argmax(Y_val_preds, axis=1))
        print("Validation  Accuracy : {:.3f}".format(val_acc))

    return weights



In [None]:
# Hyperparameters for the bag-of-words classifier.
seed = random.PRNGKey(0)
batch_size = 256
epochs = 8
learning_rate = jnp.array(1/1e3)

# Fresh model parameters and a matching Adam optimizer state.
model = TextClassifier()
weights = model.init(seed, X_train[:5])
optimizer = optax.adam(learning_rate=learning_rate)
optimizer_state = optimizer.init(weights)

# The test split doubles as the validation data for the per-epoch accuracy.
final_weights = TrainModelInBatches(
    X_train, Y_train,
    X_test, Y_test,
    epochs, weights, optimizer_state,
    batch_size=batch_size,
)
