- [Text Classification Using Flax (JAX) Networks](https://coderzcolumn.com/tutorials/artificial-intelligence/text-classification-using-flax-jax-networks)

In [1]:
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

%load_ext autoreload
%autoreload 2

# Configure Training

In [2]:
CONFIG = {
    "b1":0.9, 
    "b2":0.999, 
    "dataset":"BBC-News", # BBC-News, sst2
    "eps":1e-6,
    "learning-rate":2e-5,
    "model":"bert-base-cased",
    "model-directory":f"/home/rflagg/model/BBC-News",
    "number-of-epochs":5,
    "number-of-labels":5,# BBC-News: 5, sst2: 2
    "train-in-parallel":True,
    "per-device-batch-size":4,
    "seed":0,
    "weight-decay":1e-2,
    "text-key":"text",  # text for BBC-News; sentence for sst2
    "label-key":"label",
}

# Load Data

In [3]:
from datasets import load_dataset

directory = f"/home/rflagg/data/{CONFIG['dataset']}"

if CONFIG['dataset'] == "BBC-News":   
    dataset = load_dataset(
        'csv', 
        data_files={
            'train': f"{directory}/train-df.csv", 
             'validation': f"{directory}/test-df.csv"
        }
    )
else:
    dataset = load_dataset(
        'csv', 
        data_files={
            'train': f"{directory}/train-df.csv", 
            'test': f"{directory}/test-df.csv", 
            'validation': f"{directory}/validation-df.csv"
        }
    )
dataset

Using custom data configuration default-07cfe4aff386be8d
Reusing dataset csv (/home/rflagg/.cache/huggingface/datasets/csv/default-07cfe4aff386be8d/0.0.0/51cce309a08df9c4d82ffd9363bbe090bf173197fc01a71b034e8594995a1a58)


  0%|          | 0/2 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['category', 'text', 'label'],
        num_rows: 2002
    })
    validation: Dataset({
        features: ['category', 'text', 'label'],
        num_rows: 223
    })
})

# Preprocess the data

In [4]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(CONFIG['model'])

def preprocess_function(data, text_key=CONFIG["text-key"], label_key=CONFIG["label-key"]):
    texts = (data[text_key],)
    processed = tokenizer(*texts, padding="max_length", max_length=128, truncation=True)
    processed["labels"] = data[label_key]
    return processed

dataset_tokenized = dataset.map(
    preprocess_function, batched=True, remove_columns=dataset["train"].column_names
)
dataset_tokenized

train_ds = dataset_tokenized["train"]
validation_ds = dataset_tokenized["validation"]
if CONFIG['dataset'] == "sst2": test_ds = dataset_tokenized["test"]



  0%|          | 0/3 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

# Define Metric

In [5]:
import datasets
from sklearn.metrics import precision_recall_fscore_support

#from datasets import load_metric
#metric = load_metric('glue', "sst2")

class F1EtcMetric(datasets.Metric):
    def _info(self):
        return datasets.MetricInfo(
            description="Calculates precision, recall, f1 score, and support.",
            citation="TODO: _CITATION",
            inputs_description="_KWARGS_DESCRIPTION",
            features=datasets.Features({
                'predictions': datasets.Value('int64'),
                'references': datasets.Value('int64'),
            }),
            codebase_urls=[],
            reference_urls=[],
            format='numpy'
        )

    def _compute(self, predictions, references):
        precision, recall, fscore, support = precision_recall_fscore_support(references, predictions, average='weighted')
        return {
            "precision":precision,
            "recall":recall,
            "f1":fscore
        }
    
metric = F1EtcMetric()

# Fine-tune the model

- [Huggingface Evaluate](https://huggingface.co/docs/evaluate/index)

In [6]:
from datasets import load_metric
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from itertools import chain
import jax
import jax.numpy as jnp
import optax
from tqdm.notebook import tqdm
from transformers import FlaxAutoModelForSequenceClassification, AutoConfig

In [7]:
config = AutoConfig.from_pretrained(CONFIG['model'], num_labels=CONFIG["number-of-labels"])
model = FlaxAutoModelForSequenceClassification.from_pretrained(CONFIG['model'], config=config, seed=CONFIG['seed'])

Some weights of the model checkpoint at bert-base-cased were not used when initializing FlaxBertForSequenceClassification: {('cls', 'predictions', 'transform', 'LayerNorm', 'scale'), ('cls', 'predictions', 'bias'), ('cls', 'predictions', 'transform', 'dense', 'kernel'), ('cls', 'predictions', 'transform', 'dense', 'bias'), ('cls', 'predictions', 'transform', 'LayerNorm', 'bias')}
- This IS expected if you are initializing FlaxBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of FlaxBertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased 

In [8]:
total_batch_size = CONFIG['per-device-batch-size']
if CONFIG['train-in-parallel']: total_batch_size *= jax.local_device_count()
print("The overall batch size (both for training and eval) is", total_batch_size)

num_train_steps = len(train_ds) // total_batch_size * CONFIG['number-of-epochs']
learning_rate_function = optax.linear_schedule(init_value=CONFIG['learning-rate'], end_value=0, transition_steps=num_train_steps)

The overall batch size (both for training and eval) is 32


## Defining the training state

In [9]:
import flax
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from typing import Callable

In [10]:
class TrainState(train_state.TrainState):
    logits_function: Callable = flax.struct.field(pytree_node=False)
    loss_function: Callable = flax.struct.field(pytree_node=False)
        
def decay_mask_fn(params):
    flat_params = flax.traverse_util.flatten_dict(params)
    flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
    return flax.traverse_util.unflatten_dict(flat_mask)

def adamw(weight_decay):
    return optax.adamw(
        learning_rate=learning_rate_function, b1=CONFIG['b1'], b2=CONFIG['b2'], eps=CONFIG['eps'], weight_decay=weight_decay, mask=decay_mask_fn
    )

def loss_function(logits, labels):
    xentropy = optax.softmax_cross_entropy(logits, onehot(labels, num_classes=CONFIG['number-of-labels']))
    return jnp.mean(xentropy)
     
def eval_function(logits): return logits.argmax(-1)

In [11]:
state = TrainState.create(
    apply_fn=model.__call__,
    params=model.params,
    tx=adamw(weight_decay=CONFIG['weight-decay']),
    logits_function=eval_function,
    loss_function=loss_function,
)

## Defining the training and evaluation step

In [12]:
def train_step(state, batch, dropout_rng, train_in_parallel=CONFIG['train-in-parallel']):
    targets = batch.pop("labels")
    dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)

    def loss_function(params):
        logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
        loss = state.loss_function(logits, targets)
        return loss

    grad_function = jax.value_and_grad(loss_function)
    loss, grad = grad_function(state.params)
    if train_in_parallel:
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)
        metrics = jax.lax.pmean({"loss": loss, "learning_rate": learning_rate_function(state.step)}, axis_name="batch")
    else:
        new_state = state.apply_gradients(grads=grad)
        metrics = {"loss": loss, "learning_rate": learning_rate_function(state.step)}
        
    return new_state, metrics, new_dropout_rng

In [13]:
def eval_step(state, batch):
    logits = state.apply_fn(**batch, params=state.params, train=False)[0]
    return state.logits_function(logits)

## Defining the data collators

In [14]:
def train_data_loader(rng, dataset, batch_size, train_in_parallel=CONFIG['train-in-parallel']):
    steps_per_epoch = len(dataset) // batch_size
    perms = jax.random.permutation(rng, len(dataset))
    perms = perms[: steps_per_epoch * batch_size]  # Skip incomplete batch.
    perms = perms.reshape((steps_per_epoch, batch_size))
    for perm in perms:
        batch = dataset[perm]
        batch = {k: jnp.array(v) for k, v in batch.items()}
        if train_in_parallel: batch = shard(batch)
        yield batch

In [15]:
def eval_data_loader(dataset, batch_size, train_in_parallel=CONFIG['train-in-parallel']):
    for i in range(len(dataset) // batch_size):
        batch = dataset[i * batch_size : (i + 1) * batch_size]
        batch = {k: jnp.array(v) for k, v in batch.items()}
        if train_in_parallel: batch = shard(batch)
        yield batch

## Training

In [None]:
jit_train_step = jax.jit(train_step, donate_argnums=(0,))
jit_eval_step = jax.jit(eval_step)

In [None]:
rng = jax.random.PRNGKey(CONFIG['seed'])
rng, dropout_rng = jax.random.split(rng)

In [None]:
rng, input_rng = jax.random.split(rng)

for batch in train_data_loader(input_rng, train_ds, total_batch_size):
    state, train_metrics, dropout_rngs = jit_train_step(state, batch, dropout_rng)
    break
for batch in eval_data_loader(validation_ds, total_batch_size):
    labels = batch.pop("labels")
    predictions = jit_eval_step(state, batch)
    break

In [None]:
for batch in eval_data_loader(validation_ds, total_batch_size):
    labels = batch.pop("labels")
    predictions = jit_eval_step(state, batch)
    break


In [None]:
with tqdm(total=len(validation_ds) // total_batch_size, desc="Evaluating...", leave=False) as progress_bar_eval:
    for batch in eval_data_loader(validation_ds, total_batch_size):
        labels = batch.pop("labels")
        predictions = jit_eval_step(state, batch)
        metric.add_batch(predictions=predictions, references=labels)
        progress_bar_eval.update(1)


In [None]:
metric.compute()

In [None]:
eval_metric = metric.compute()

loss = round(train_metrics['loss'].item(), 4)
eval_score = round(list(eval_metric.values())[0], 4)
metric_name = list(eval_metric.keys())[0]

print(f"{i+1}/{CONFIG['number-of-epochs']} | Train loss: {loss} | Eval {metric_name}: {eval_score}")

In [None]:
eval_metric = metric.compute()
loss = round(train_metrics['loss'].item(), 4)
eval_score = round(list(eval_metric.values())[0], 4)
metric_name = list(eval_metric.keys())[0]

print(f"{i+1}/{CONFIG['number-of-epochs']} | Train loss: {loss} | Eval {metric_name}: {eval_score}")

In [None]:
import flax
flax.__version__

In [None]:
%%time
for i, epoch in enumerate(tqdm(range(1, CONFIG['number-of-epochs'] + 1), desc=f"Epoch ...", position=0, leave=True)):
    rng, input_rng = jax.random.split(rng)

    # train
    with tqdm(total=len(train_ds) // total_batch_size, desc="Training...", leave=False) as progress_bar_train:
        for batch in train_data_loader(input_rng, train_ds, total_batch_size):
            state, train_metrics, dropout_rng = jit_train_step(state, batch, dropout_rng)
            progress_bar_train.update(1)

    # evaluate
    with tqdm(total=len(validation_ds) // total_batch_size, desc="Evaluating...", leave=False) as progress_bar_eval:
        for batch in eval_data_loader(validation_ds, total_batch_size):
            labels = batch.pop("labels")
            predictions = jit_eval_step(state, batch)
            metric.add_batch(predictions=predictions, references=labels)
            progress_bar_eval.update(1)

    eval_metric = metric.compute()

    loss = round(train_metrics['loss'].item(), 4)
    eval_score = round(list(eval_metric.values())[2], 4)
    metric_name = list(eval_metric.keys())[2]

    print(f"{i+1}/{CONFIG['number-of-epochs']} | Train loss: {loss} | Eval {metric_name}: {100 * eval_score:02f}")

In [None]:
tokenizer.save_pretrained(CONFIG['model-directory'])
model.save_pretrained(CONFIG['model-directory'], state.params)

### Verification

In [None]:
tokenizer = AutoTokenizer.from_pretrained(CONFIG['model-directory'])
config = AutoConfig.from_pretrained(CONFIG['model-directory'], num_labels=CONFIG["number-of-labels"])
model = FlaxAutoModelForSequenceClassification.from_pretrained(CONFIG['model-directory'], config=config, seed=CONFIG['seed'])

state = TrainState.create(
    apply_fn=model.__call__,
    params=model.params,
    tx=adamw(weight_decay=CONFIG['weight-decay']),
    logits_function=eval_function,
    loss_function=loss_function,
)

In [None]:
with tqdm(total=len(validation_ds) // total_batch_size, desc="Evaluating...", leave=False) as progress_bar_eval:
    for batch in eval_data_loader(validation_ds, total_batch_size):
        labels = batch.pop("labels")
        predictions = jit_eval_step(state, batch)
        metric.add_batch(predictions=predictions, references=labels)
        progress_bar_eval.update(1)

test_metric = metric.compute()

test_score = round(list(test_metric.values())[0], 4)
metric_name = list(test_metric.keys())[0]

print(f"Test {metric_name}: {100 * test_score:0.2f}%")

## Parallel Training

In [16]:
parallel_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))
parallel_eval_step = jax.pmap(eval_step, axis_name="batch")
state = flax.jax_utils.replicate(state)

In [17]:
rng = jax.random.PRNGKey(CONFIG['seed'])
dropout_rngs = jax.random.split(rng, jax.local_device_count())

In [18]:
rng, input_rng = jax.random.split(rng)

for batch in train_data_loader(input_rng, train_ds, total_batch_size):
    state, train_metrics, dropout_rngs = parallel_train_step(state, batch, dropout_rngs)
    break
for batch in eval_data_loader(validation_ds, total_batch_size):
    labels = batch.pop("labels")
    predictions = parallel_eval_step(state, batch)
    metric.add_batch(predictions=chain(*predictions), references=chain(*labels))
    eval_metric = metric.compute()
    print(eval_metric)
    break

{'precision': 0.19254032258064516, 'recall': 0.21875, 'f1': 0.11289414414414414}


  _warn_prf(average, modifier, msg_start, len(result))


In [21]:
%%time
for i, epoch in enumerate(tqdm(range(1, CONFIG['number-of-epochs'] + 1), desc=f"Epoch ...", position=0, leave=True)):
    rng, input_rng = jax.random.split(rng)

    # train
    with tqdm(total=len(train_ds) // total_batch_size, desc="Training...", leave=False) as progress_bar_train:
        for batch in train_data_loader(input_rng, train_ds, total_batch_size):
            state, train_metrics, dropout_rngs = parallel_train_step(state, batch, dropout_rngs)
            progress_bar_train.update(1)

    # evaluate
    with tqdm(total=len(validation_ds) // total_batch_size, desc="Evaluating...", leave=False) as progress_bar_eval:
        for batch in eval_data_loader(validation_ds, total_batch_size):
            labels = batch.pop("labels")
            predictions = parallel_eval_step(state, batch)
            metric.add_batch(predictions=chain(*predictions), references=chain(*labels))
            progress_bar_eval.update(1)

    eval_metric = metric.compute()

    loss = round(flax.jax_utils.unreplicate(train_metrics)['loss'].item(), 4)
    eval_score = round(list(eval_metric.values())[2], 4)
    metric_name = list(eval_metric.keys())[2]
    #eval_score = round(list(eval_metric.values())[0], 4)
    #metric_name = list(eval_metric.keys())[0]

    print(f"{i+1}/{CONFIG['number-of-epochs']} | Train loss: {loss:0.4f} | Eval {metric_name}: {100 * eval_score:0.2f}")

Epoch ...:   0%|          | 0/5 [00:00<?, ?it/s]

Training...:   0%|          | 0/62 [00:00<?, ?it/s]

Evaluating...:   0%|          | 0/6 [00:00<?, ?it/s]

1/5 | Train loss: 0.0051 | Eval f1: 95.80


Training...:   0%|          | 0/62 [00:00<?, ?it/s]

Evaluating...:   0%|          | 0/6 [00:00<?, ?it/s]

2/5 | Train loss: 0.0034 | Eval f1: 95.80


Training...:   0%|          | 0/62 [00:00<?, ?it/s]

Evaluating...:   0%|          | 0/6 [00:00<?, ?it/s]

3/5 | Train loss: 0.0058 | Eval f1: 95.80


Training...:   0%|          | 0/62 [00:00<?, ?it/s]

Evaluating...:   0%|          | 0/6 [00:00<?, ?it/s]

4/5 | Train loss: 0.0038 | Eval f1: 95.80


Training...:   0%|          | 0/62 [00:00<?, ?it/s]

Evaluating...:   0%|          | 0/6 [00:00<?, ?it/s]

5/5 | Train loss: 0.0040 | Eval f1: 95.80
CPU times: user 50.9 s, sys: 8.07 s, total: 58.9 s
Wall time: 37.1 s


In [25]:
tokenizer.save_pretrained(CONFIG['model-directory'])
model.save_pretrained(CONFIG['model-directory'], flax.jax_utils.unreplicate(state.params))

### Verification

In [28]:
tokenizer = AutoTokenizer.from_pretrained(CONFIG['model-directory'])
config = AutoConfig.from_pretrained(CONFIG['model-directory'], num_labels=CONFIG["number-of-labels"])
model = FlaxAutoModelForSequenceClassification.from_pretrained(CONFIG['model-directory'], config=config, seed=CONFIG['seed'])

state = TrainState.create(
    apply_fn=model.__call__,
    params=model.params,
    tx=adamw(weight_decay=CONFIG['weight-decay']),
    logits_function=eval_function,
    loss_function=loss_function,
)
state = flax.jax_utils.replicate(state)

In [29]:
with tqdm(total=len(validation_ds) // total_batch_size, desc="Evaluating...", leave=False) as progress_bar_eval:
    for batch in eval_data_loader(validation_ds, total_batch_size):
        labels = batch.pop("labels")
        predictions = parallel_eval_step(state, batch)
        metric.add_batch(predictions=chain(*predictions), references=chain(*labels))
        progress_bar_eval.update(1)

test_metric = metric.compute()

test_score = round(list(test_metric.values())[0], 4)
metric_name = list(test_metric.keys())[0]

print(f"Test {metric_name}: {100 * test_score:0.2f}%")

Evaluating...:   0%|          | 0/6 [00:00<?, ?it/s]

Test precision: 95.90%
