### Setup

In [6]:
import torch
import jax

torch.cuda.is_available()


# JAX setup
JAX_SEED=42
print('jax device count:', jax.device_count())  # total number of accelerator devices in the cluster
print('jax local device count: ', jax.local_device_count())  # number of accelerator devices attached to this host

print(jax.devices())

jax device count: 1
jax local device count:  1
[CudaDevice(id=0)]


### Simple Speed Comparison: BERT vs JAX

In [2]:
from transformers import BertTokenizer, BertModel, FlaxBertModel
import jax
from jax import grad, jit
import jax.numpy as np
np.set_printoptions(linewidth=240)

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
jax_model = FlaxBertModel.from_pretrained('bert-base-uncased')
pt_model = BertModel.from_pretrained('bert-base-uncased')

  from .autonotebook import tqdm as notebook_tqdm
Some weights of FlaxBertModel were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: {('pooler', 'dense', 'bias'), ('pooler', 'dense', 'kernel')}
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
def pt_forward():
    inputs = tokenizer("You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.", return_tensors="pt")
    outputs = pt_model(**inputs)
    return outputs.last_hidden_state

pt_forward()

tensor([[[-0.1914, -0.5446, -0.1544,  ..., -0.5098,  0.1434,  0.4844],
         [-0.4116, -0.5880, -0.7113,  ...,  0.3556,  0.7013, -0.3451],
         [ 0.1175, -0.4358, -0.1766,  ...,  0.0241,  0.0956, -0.1309],
         ...,
         [ 0.0185, -0.2742, -0.5359,  ..., -0.6504, -0.0117, -0.2525],
         [ 0.7980, -0.0621, -0.4312,  ...,  0.0909, -0.4904, -0.2666],
         [ 0.6253, -0.1266, -0.2116,  ...,  0.5051, -0.5687, -0.3665]]],
       grad_fn=<NativeLayerNormBackward0>)

In [4]:
%timeit pt_forward()

40.5 ms ± 765 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [5]:
def jax_forward():
    inputs = tokenizer("You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.", return_tensors='jax')
    outputs = jit(jax_model)(**inputs)
    return outputs.last_hidden_state

jax_forward()

Array([[[-0.19064319, -0.54382896, -0.15333818, ..., -0.5101217 ,  0.14285733,  0.48461798],
        [-0.41225404, -0.5879406 , -0.7097802 , ...,  0.35409293,  0.70039594, -0.34506193],
        [ 0.11897378, -0.43416384, -0.17663035, ...,  0.02296186,  0.09597686, -0.13082281],
        ...,
        [ 0.01874121, -0.27286923, -0.53406894, ..., -0.6514603 , -0.01144049, -0.2528037 ],
        [ 0.7958135 , -0.06071458, -0.4304184 , ...,  0.08984526, -0.48966262, -0.2661015 ],
        [ 0.62361646, -0.12573534, -0.21107566, ...,  0.5039953 , -0.5679369 , -0.36577556]]], dtype=float32)

In [6]:
%timeit jax_forward().block_until_ready()

4.12 ms ± 143 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


### Langid Data Load

In [42]:
from datasets import load_dataset
from difflib import get_close_matches

DATAPATH = '../data/language_detection.csv'

DATA_PERCENT_LIMIT = 100
TEST_SPLIT = 0.1
SEED = 42


split = f'train[:{DATA_PERCENT_LIMIT}%]' if DATA_PERCENT_LIMIT else 'train'
dataset = load_dataset("csv", split=split, data_files=DATAPATH, encoding='utf-8').shuffle(seed=SEED)
dataset = dataset.train_test_split(test_size=TEST_SPLIT, seed=SEED)

N_LABELS = len(set(dataset['train']['Language']))

print(dataset)
print(dataset['train'][:10]['Language'])

LANG2ID = {
    'English': 0,
    'Malayalam': 1,
    'Hindi': 2,
    'Tamil': 3,
    'Kannada': 4,
    'French': 5,
    'Spanish': 6,
    'Portuguese': 7,
    'Italian': 8,
    'Russian': 9,
    'Sweedish': 10,
    'Dutch': 11,
    'Arabic': 12,
    'Turkish': 13,
    'German': 14,
    'Danish': 15,
    'Greek': 16
    }

def lang_to_id(lang):
      return LANG2ID[get_close_matches(lang, LANG2ID.keys())[0]]


DatasetDict({
    train: Dataset({
        features: ['Text', 'Language'],
        num_rows: 9303
    })
    test: Dataset({
        features: ['Text', 'Language'],
        num_rows: 1034
    })
})
['English', 'German', 'Dutch', 'Tamil', 'Greek', 'Greek', 'French', 'Spanish', 'Russian', 'Malayalam']


### Tokenize

In [43]:
# BERT tokenize
from transformers import BertTokenizer

# tokenize
tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-cased")

def tokenize_function(examples):
    batch = tokenizer(examples['Text'], padding="max_length", truncation=True)
    batch['labels'] = [lang_to_id(lang) for lang in examples['Language']]
    return batch

tokenized_datasets = dataset.map(tokenize_function, batched=True)
print(tokenized_datasets)

Map:   0%|          | 0/9303 [00:00<?, ? examples/s]

Map:   0%|          | 0/1034 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['Text', 'Language', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 9303
    })
    test: Dataset({
        features: ['Text', 'Language', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 1034
    })
})


### Evaluation Settings

In [132]:
TRAIN_BATCH_SIZE = 4
EVAL_BATCH_SIZE = 1

### PT BERT Train

In [133]:
import os
import gc
import sys
from time import gmtime, strftime
import numpy as np

from transformers import BertForSequenceClassification, TrainingArguments, Trainer
import evaluate

TRAIN_STEPS_LIMIT = -1
N_EPOCHS = 1

OUTPUT_PATH = '../models/pt'

# Free memory
gc.collect()

# load pre-trained
pt_model = BertForSequenceClassification.from_pretrained("google-bert/bert-base-cased", num_labels=N_LABELS)

# fine-tune
log_dir = os.path.join(OUTPUT_PATH, strftime("%Y%m%d-%H%M", gmtime()))
try:
    os.system(f'mkdir {log_dir}')
except:
    print('log dir exists, aborting')
    sys.exit(1)

training_args = TrainingArguments(output_dir=log_dir,
                                  label_names=['labels'],
                                  num_train_epochs=N_EPOCHS,
                                  max_steps = TRAIN_STEPS_LIMIT, #overrides num_train_epochs
                                  per_device_train_batch_size=TRAIN_BATCH_SIZE,
                                  per_device_eval_batch_size=EVAL_BATCH_SIZE,
                                  eval_strategy="steps",
                                  eval_steps=500)

metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

pt_trainer = Trainer(
    model=pt_model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['test'],
    compute_metrics=compute_metrics,
)

pt_train_start_time = time.time()
pt_trainer.train()
pt_train_end_time = time.time()
print("PT train time: %ss" % (pt_train_end_time - pt_train_start_time))    # PT train time: 9713.031286001205s

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google-bert/bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss,Validation Loss,Accuracy
200,No log,0.362411,0.87911
400,No log,0.324692,0.901354
600,0.663100,0.293299,0.90619
800,0.663100,0.26809,0.90619
1000,0.298300,0.217692,0.913926
1200,0.298300,0.28007,0.90619
1400,0.298300,0.255958,0.907157
1600,0.276100,0.275717,0.911025
1800,0.276100,0.296333,0.908124
2000,0.237700,0.225104,0.917795


PT train time: 9713.031286001205s


In [135]:
# PT BERT evaluate

CHECKPOINT = '../../models/langid/pt/20250318-0945/checkpoint-2326'

# load saved checkpoint
pt_checkpoint = BertForSequenceClassification.from_pretrained(CHECKPOINT, num_labels=N_LABELS)

pt_eval_trainer = Trainer(
    model=pt_checkpoint,
    args=training_args,
    eval_dataset=tokenized_datasets['test'],
    compute_metrics=compute_metrics,
)

pt_eval_start_time = time.time()
pt_eval_trainer.evaluate()
pt_eval_end_time = time.time()
print(f"PT eval time: {(pt_eval_end_time - pt_eval_start_time) / len(tokenized_datasets['test']) * 1000:0.2f}ms per iteration "
      f"({(pt_eval_end_time - pt_eval_start_time):0.2f}s / {len(tokenized_datasets['test'])} data points)") # PT eval time: 378.33ms per iteration (391.20s / 1034 data points)

PT eval time: 378.33ms per iteration (391.20s / 1034 data points)


### JAX BERT Train

In [91]:
# JAX BERT train
import os
import gc
import sys
from time import gmtime, strftime

import flax
import jax
import optax

from itertools import chain
from tqdm.notebook import tqdm
from typing import Callable

import jax.numpy as jnp

from flax import traverse_util
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from flax.training import train_state

N_EPOCHS = 1
LEARNING_RATE = 2e-5

# Free memory
gc.collect()

# setup
num_train_steps = len(dataset['train']) // TRAIN_BATCH_SIZE * N_EPOCHS
learning_rate_function = optax.linear_schedule(init_value=LEARNING_RATE, end_value=0, transition_steps=num_train_steps)

class TrainState(train_state.TrainState):
    logits_function: Callable = flax.struct.field(pytree_node=False)
    loss_function: Callable = flax.struct.field(pytree_node=False)

def decay_mask_fn(params):
    flat_params = traverse_util.flatten_dict(params)
    flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
    return traverse_util.unflatten_dict(flat_mask)

def adamw(weight_decay):
    return optax.adamw(learning_rate=learning_rate_function, b1=0.9, b2=0.999, eps=1e-6, weight_decay=weight_decay, mask=decay_mask_fn)

def loss_function(logits, labels):
  xentropy = optax.softmax_cross_entropy(logits, onehot(labels, num_classes=N_LABELS))
  return jnp.mean(xentropy)
     
def eval_function(logits):
    return logits.argmax(-1)

In [92]:
from transformers import FlaxBertForSequenceClassification, BertConfig

# load pre-trained
config = BertConfig.from_pretrained('google-bert/bert-base-cased', num_labels=N_LABELS)
jax_model = FlaxBertForSequenceClassification.from_pretrained('google-bert/bert-base-cased', config=config, seed=JAX_SEED)

Some weights of the model checkpoint at google-bert/bert-base-cased were not used when initializing FlaxBertForSequenceClassification: {('cls', 'predictions', 'transform', 'dense', 'bias'), ('cls', 'predictions', 'transform', 'LayerNorm', 'bias'), ('cls', 'predictions', 'bias'), ('cls', 'predictions', 'transform', 'dense', 'kernel'), ('cls', 'predictions', 'transform', 'LayerNorm', 'scale')}
- This IS expected if you are initializing FlaxBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of FlaxBertForSequenceClassification were not initialized from the model checkpoint at goog

In [93]:
OUTPUT_PATH = '../models/jax'
log_dir = os.path.join(OUTPUT_PATH, strftime("%Y%m%d-%H%M", gmtime()))
try:
    os.system(f'mkdir {log_dir}')
except:
    print('log dir exists')

state = TrainState.create(
    apply_fn=jax_model.__call__,
    params=jax_model.params,
    tx=adamw(weight_decay=0.01),
    logits_function=eval_function,
    loss_function=loss_function,
)

In [94]:
def train_step(state, batch, dropout_rng):
    targets = batch.pop("labels")
    dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)

    def loss_function(params):
        logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
        loss = state.loss_function(logits, targets)
        return loss

    grad_function = jax.value_and_grad(loss_function)
    loss, grad = grad_function(state.params)
    grad = jax.lax.pmean(grad, "batch")
    new_state = state.apply_gradients(grads=grad)
    metrics = jax.lax.pmean({"loss": loss, "learning_rate": learning_rate_function(state.step)}, axis_name="batch")
    return new_state, metrics, new_dropout_rng

parallel_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))

In [95]:
def eval_step(state, batch):
    logits = state.apply_fn(**batch, params=state.params, train=False)[0]
    return state.logits_function(logits)

parallel_eval_step = jax.pmap(eval_step, axis_name="batch")

In [96]:
def train_data_loader(rng, dataset, batch_size):
    steps_per_epoch = len(dataset) // batch_size
    perms = jax.random.permutation(rng, len(dataset))
    perms = perms[: steps_per_epoch * batch_size]  # Skip incomplete batch.
    perms = perms.reshape((steps_per_epoch, batch_size))
    for perm in perms:
        batch = dataset[perm]
        del batch['Text']
        del batch['Language']
        batch = {k: jnp.array(v) for k, v in batch.items()}
        batch = shard(batch)
        yield batch

def eval_data_loader(dataset, batch_size):
    for i in range(len(dataset) // batch_size):
        batch = dataset[i * batch_size : (i + 1) * batch_size]
        del batch['Text']
        del batch['Language']        
        batch = {k: jnp.array(v) for k, v in batch.items()}
        batch = shard(batch)
        yield batch

In [97]:
state = flax.jax_utils.replicate(state)
num_labels = flax.jax_utils.replicate(N_LABELS)

In [98]:
rng = jax.random.PRNGKey(JAX_SEED)
dropout_rngs = jax.random.split(rng, jax.local_device_count())

#for i, epoch in enumerate(tqdm(range(1, N_EPOCHS + 1), desc=f"Epoch ...", position=0, leave=True)):
rng, input_rng = jax.random.split(rng)

In [99]:
import time

# train
jax_train_start_time = time.time()
with tqdm(total=len(tokenized_datasets['train']) // TRAIN_BATCH_SIZE, desc="Training...", leave=False) as progress_bar_train:
  for batch in train_data_loader(input_rng, tokenized_datasets['train'], TRAIN_BATCH_SIZE):
    state, train_metrics, dropout_rngs = parallel_train_step(state, batch, dropout_rngs)
jax_train_end_time = time.time()
print("Jax train time: %ss" % (jax_train_end_time - jax_train_start_time)) #Jax train time: 411.633828163147s

Training...:   0%|          | 0/2325 [00:00<?, ?it/s]

Jax train time: 417.2018005847931s


In [127]:
import itertools
import time
import numpy as np

from transformers import FlaxBertForSequenceClassification, BertConfig
import evaluate


CHECKPOINT = '../../langid/models/jax/20250318-0945/checkpoint-2000'

# load saved checkpoint
config = BertConfig.from_pretrained(CHECKPOINT, num_labels=N_LABELS)
jax_checkpoint = FlaxBertForSequenceClassification.from_pretrained(CHECKPOINT, config=config, seed=JAX_SEED)

state = TrainState.create(
    apply_fn=jax_checkpoint.__call__,
    params=jax_checkpoint.params,
    tx=adamw(weight_decay=0.01),
    logits_function=eval_function,
    loss_function=loss_function,
)

jax_acc_metric = evaluate.load("accuracy")

# evaluate
jax_eval_start_time = time.time()
with tqdm(total=len(tokenized_datasets['test'])/ EVAL_BATCH_SIZE, desc="Evaluating...", leave=False) as progress_bar_eval:
  for batch in eval_data_loader(tokenized_datasets['test'], EVAL_BATCH_SIZE):
      labels = batch.pop("labels")   
      predictions = parallel_eval_step(state, batch)    
      jax_acc_metric.add_batch(predictions=list(itertools.chain.from_iterable(predictions)), references=list(itertools.chain.from_iterable(labels)))
      progress_bar_eval.update(1)
jax_eval_end_time = time.time()
print(f"Jax eval time: {(jax_eval_end_time - jax_eval_start_time) / len(tokenized_datasets['test']) * 1000:0.2f}ms per iteration "
      f"({(jax_eval_end_time - jax_eval_start_time):0.2f}s / {len(tokenized_datasets['test'])} data points)")

jax_eval_acc_metric = jax_acc_metric.compute()

loss = round(flax.jax_utils.unreplicate(train_metrics)['loss'].item(), 3)
jax_eval_acc_score = round(list(eval_metric.values())[0], 3)
metric_name = list(eval_metric.keys())[0]

print(f"Eval {metric_name}: {jax_eval_acc_score}")

Evaluating...:   0%|          | 0/1034.0 [00:00<?, ?it/s]

Jax eval time: 19.69ms per iteration (20.36s / 1034 data points)
Eval accuracy: 0.918


### Fasttext train

In [140]:
import evaluate
import fasttext
from difflib import get_close_matches
from huggingface_hub import hf_hub_download

'''
LANG2ID = {
    '__label__eng': 0,
    'Malayalam': 1,
    'Hindi': 2,
    'Tamil': 3,
    'Kannada': 4,
    'French': 5,
    'Spanish': 6,
    'Portuguese': 7,
    'Italian': 8,
    'Russian': 9,
    'Sweedish': 10,
    'Dutch': 11,
    'Arabic': 12,
    'Turkish': 13,
    'German': 14,
    'Danish': 15,
    'Greek': 16
    }

def lang_to_id(lang):
      return LANG2ID[get_close_matches(lang, LANG2ID.keys())[0]]


'__label__eng_Latn'
af als am an ar arz as ast av az azb ba bar bcl be bg bh bn bo bpy br bs bxr ca cbk ce ceb ckb co cs cv cy da de diq dsb dty dv el eml en eo es et eu fa fi fr frr fy ga gd gl gn gom gu gv he hi hif hr hsb ht hu hy ia id ie ilo io is it ja jbo jv ka kk km kn ko krc ku kv kw ky la lb lez li lmo lo lrc lt lv mai mg mhr min mk ml mn mr mrj ms mt mwl my myv mzn nah nap nds ne new nl nn no oc or os pa pam pfl pl pms pnb ps pt qu rm ro ru rue sa sah sc scn sco sd sh si sk sl so sq sr su sv sw ta te tg th tk tl tr tt tyv ug uk ur uz vec vep vi vls vo wa war wuu xal xmf yi yo yue zh
'''

fasttext_model_path = hf_hub_download(repo_id="facebook/fasttext-language-identification", 
                                      filename="model.bin", 
                                      cache_dir="../../models/langid/fasttext/cached")
fasttext_model = fasttext.load_model(fasttext_model_path)

fasttest_acc_metric = evaluate.load("accuracy")

fasttext_eval_start_time = time.time()
for data in tqdm(tokenized_datasets['test']):
    input = data['Text']
    prediction = fasttext_model.predict(input)
    fasttext_acc_metric.add_batch(predictions=[prediction], references=[lang_to_id(data['Language'])])
fasttext_eval_end_time = time.time()

fasttext_eval_acc_metric = fasttext_acc_metric.compute()
fasttext_eval_acc_score = round(list(fasttext_metric.values())[0], 3)

print(f"fasttext eval {metric_name}: {fasttext_eval_acc_score}")

model.bin:   0%|          | 0.00/1.18G [00:00<?, ?B/s]

  0%|          | 0/1034 [00:00<?, ?it/s]

ValueError: Unable to avoid copy while creating an array as requested.
If using `np.array(obj, copy=False)` replace it with `np.asarray(obj)` to allow a copy when needed (no behavior change in NumPy 1.x).
For more details, see https://numpy.org/devdocs/numpy_2_0_migration_guide.html#adapting-to-changes-in-the-copy-keyword.

### Langid Run

In [None]:
# DO NOT RUN IN PARALLEL -- BATCHES OF 1!!

### Comparisons

import pandas as pd
import seaborn as sns

stats_list = {'pt': [pt_train_end_time - pt_train_start_time,
                     (pt_eval_end_time - pt_eval_start_time) / len(tokenized_datasets['test'],
                     pt_eval_acc_score,
                     pt_eval_roc_auc_score],
              'jax': [jax_train_end_time - jax_train_start_time,
                      (jax_eval_end_time - jax_eval_start_time) / len(tokenized_datasets['test'],
                      jax_eval_acc_score,
                      jax_eval_roc_auc_score]
              'fasttext': [fasttext_train_end_time - fasttext_train_start_time,
                    (fasttext_eval_end_time - fasttext_eval_start_time) / len(tokenized_datasets['test'],
                    fasttext_eval_acc_score,
                    fasttext_eval_roc_auc_score]                    
                    
                    'jax': ['a', 'b', 'c', 'd'], 'fasttext': [0, 0]}
stats_list_pd = pd.DataFrame.from_dict(data)

runtimes = {
    'train'= {
        'pt': pt_train_end_time - pt_train_start_time,
        'jax': jax_train_end_time - jax_train_start_time,
        'fasttext': 0
    },
    'inference' = {
        'pt': pt_eval_end_time - pt_evak_start_time,
        'jax': jax_eval_end_time - jax_eval_start_time,
        'fasttext': 0
    }    
}

sns.barplot(penguins, x="island", y="body_mass_g", hue="sex")